From 14c29f07bd0c6552083cad17e342e022bb9d95f4 Mon Sep 17 00:00:00 2001 From: Zhalslar Date: Tue, 17 Jun 2025 10:55:35 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/config/default.py | 5 +- .../provider/sources/gsv_selfhosted_source.py | 129 ++++++++++++------ 2 files changed, 88 insertions(+), 46 deletions(-) diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 8eb9c0e97..61f911511 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -25,8 +25,8 @@ DEFAULT_CONFIG = { "id_whitelist_log": True, "wl_ignore_admin_on_group": True, "wl_ignore_admin_on_friend": True, - "reply_with_mention": 0.0, - "reply_with_quote": 0.0, + "reply_with_mention": False, + "reply_with_quote": False, "path_mapping": [], "segmented_reply": { "enable": False, @@ -808,6 +808,7 @@ CONFIG_METADATA_2 = { "api_base": "http://127.0.0.1:9880", "gpt_weights_path": "", "sovits_weights_path": "", + "timeout": 60, "gsv_default_parms": { "gsv_ref_audio_path": "", "gsv_prompt_text": "", diff --git a/astrbot/core/provider/sources/gsv_selfhosted_source.py b/astrbot/core/provider/sources/gsv_selfhosted_source.py index dbde29f0d..ff0083a1a 100644 --- a/astrbot/core/provider/sources/gsv_selfhosted_source.py +++ b/astrbot/core/provider/sources/gsv_selfhosted_source.py @@ -23,78 +23,101 @@ class ProviderGSVTTS(TTSProvider): provider_settings: dict, ) -> None: super().__init__(provider_config, provider_settings) - # 基础URL - self.api_base = provider_config.get("api_base", "http://127.0.0.1:9880") - if self.api_base.endswith("/"): - self.api_base = self.api_base[:-1] - # 模型文件路径 + self.api_base = provider_config.get("api_base", "http://127.0.0.1:9880").rstrip( + "/" + ) self.gpt_weights_path: str = provider_config.get("gpt_weights_path", "") self.sovits_weights_path: str = provider_config.get("sovits_weights_path", "") - asyncio.create_task(self._set_model_weights()) - # 默认参数 - raw_params = provider_config.get("gsv_default_parms", {}) + # TTS 请求的默认参数,移除前缀gsv_ self.default_params: dict = { key.removeprefix("gsv_"): str(value).lower() - for key, value in raw_params.items() + for key, value in provider_config.get("gsv_default_parms", {}).items() } + self.timeout = provider_config.get("timeout", 60) + self._session = aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=self.timeout) + ) + asyncio.create_task(self._async_init()).add_done_callback( + self._handle_init_exception + ) - # 情绪预设 - self.emotions = provider_config.get("emotions", {}) + async def _async_init(self): + await self._set_model_weights() - async def _make_request( - self, - endpoint: str, - params=None, - ) -> str | bytes: - """通用的异步请求方法""" - async with aiohttp.ClientSession() as session: - async with session.request("GET", endpoint, params=params) as response: - if response.status != 200: - return await response.text() - else: + def _handle_init_exception(self, task: asyncio.Task): + if task.exception(): + logger.error(f"[GSV TTS] 初始化失败:{task.exception()}") + + def get_session(self) -> aiohttp.ClientSession: + if not self._session or self._session.closed: + raise RuntimeError("[GSV TTS] Provider HTTP session is not ready or closed.") + return self._session + + async def _make_request(self, endpoint: str, params=None, retries: int = 3) -> bytes | None: + """发起请求""" + for attempt in range(retries): + logger.debug(f"[GSV TTS] 请求地址:{endpoint},参数:{params}") + try: + async with self.get_session().get(endpoint, params=params) as response: + if response.status != 200: + error_text = await response.text() + raise Exception( + f"[GSV TTS] Request to {endpoint} failed with status {response.status}: {error_text}" + ) return await response.read() + except Exception as e: + if attempt < retries - 1: + logger.warning( + f"[GSV TTS] 请求 {endpoint} 第 {attempt + 1} 次失败:{e},重试中..." + ) + await asyncio.sleep(1) + else: + logger.error(f"[GSV TTS] 请求 {endpoint} 最终失败:{e}") + raise async def _set_model_weights(self): - """设置模型""" + """设置模型路径""" try: - # 设置 GPT 模型 if self.gpt_weights_path: - gpt_endpoint = f"{self.api_base}/set_gpt_weights" - gpt_params = {"weights_path": self.gpt_weights_path} - if await self._make_request(endpoint=gpt_endpoint, params=gpt_params): - logger.info(f"成功设置 GPT 模型路径:{self.gpt_weights_path}") + await self._make_request( + f"{self.api_base}/set_gpt_weights", + {"weights_path": self.gpt_weights_path}, + ) + logger.info(f"[GSV TTS] 成功设置 GPT 模型路径:{self.gpt_weights_path}") else: - logger.info("GPT 模型路径未配置,将使用GPT_SoVITS内置的GPT模型") + logger.info("[GSV TTS] GPT 模型路径未配置,将使用内置 GPT 模型") - # 设置 SoVITS 模型 if self.sovits_weights_path: - sovits_endpoint = f"{self.api_base}/set_sovits_weights" - sovits_params = {"weights_path": self.sovits_weights_path} - if await self._make_request( - endpoint=sovits_endpoint, params=sovits_params - ): - logger.info(f"成功设置 SoVITS 模型路径:{self.sovits_weights_path}") + await self._make_request( + f"{self.api_base}/set_sovits_weights", + {"weights_path": self.sovits_weights_path}, + ) + logger.info( + f"[GSV TTS] 成功设置 SoVITS 模型路径:{self.sovits_weights_path}" + ) else: - logger.info("SoVITS 模型路径未配置,将使用GPT_SoVITS内置的SoVITS模型") + logger.info("[GSV TTS] SoVITS 模型路径未配置,将使用内置 SoVITS 模型") except aiohttp.ClientError as e: - logger.error(f"设置模型路径时发生错误:{e}") + logger.error(f"[GSV TTS] 设置模型路径时发生网络错误:{e}") except Exception as e: - logger.error(f"发生未知错误:{e}") + logger.error(f"[GSV TTS] 设置模型路径时发生未知错误:{e}") async def get_audio(self, text: str) -> str: """实现 TTS 核心方法,根据文本内容自动切换情绪""" + if not text: + raise ValueError("[GSV TTS] TTS 文本不能为空") + endpoint = f"{self.api_base}/tts" - params = self.default_params.copy() - params["text"] = text + params = self.build_synthesis_params(text) temp_dir = os.path.join(get_astrbot_data_path(), "temp") os.makedirs(temp_dir, exist_ok=True) - path = os.path.join(temp_dir, f"gsvi_tts_{uuid.uuid4()}.wav") + path = os.path.join(temp_dir, f"gsv_tts_{uuid.uuid4().hex}.wav") - logger.debug(f"正在调用GSV语音合成接口,参数:{params}") + logger.debug(f"[GSV TTS] 正在调用语音合成接口,参数:{params}") result = await self._make_request(endpoint, params) if isinstance(result, bytes): @@ -102,5 +125,23 @@ class ProviderGSVTTS(TTSProvider): f.write(result) return path else: - raise Exception(f"GSVI TTS API 请求失败: {result}") + raise Exception(f"[GSV TTS] 合成失败,输入文本:{text},错误信息:{result}") + def build_synthesis_params(self, text: str) -> dict: + """ + 构建语音合成所需的参数字典。 + + 当前仅包含默认参数 + 文本,未来可在此基础上动态添加如情绪、角色等语义控制字段。 + """ + params = self.default_params.copy() + params["text"] = text + # TODO: 在此处添加情绪分析,例如 params["emotion"] = detect_emotion(text) + return params + + async def shutdown(self): + if self._session and not self._session.closed: + await self._session.close() + + def __del__(self): + if hasattr(self, "_session") and self._session and not self._session.closed: + logger.warning("[GSV TTS] ProviderGSVTTS 已被销毁但 session 未关闭,请确保调用 shutdown()")