This commit is contained in:
Zhalslar
2025-06-17 10:55:35 +08:00
parent 825e3dbcf5
commit 14c29f07bd
2 changed files with 88 additions and 46 deletions
+3 -2
View File
@@ -25,8 +25,8 @@ DEFAULT_CONFIG = {
"id_whitelist_log": True,
"wl_ignore_admin_on_group": True,
"wl_ignore_admin_on_friend": True,
"reply_with_mention": 0.0,
"reply_with_quote": 0.0,
"reply_with_mention": False,
"reply_with_quote": False,
"path_mapping": [],
"segmented_reply": {
"enable": False,
@@ -808,6 +808,7 @@ CONFIG_METADATA_2 = {
"api_base": "http://127.0.0.1:9880",
"gpt_weights_path": "",
"sovits_weights_path": "",
"timeout": 60,
"gsv_default_parms": {
"gsv_ref_audio_path": "",
"gsv_prompt_text": "",
@@ -23,78 +23,101 @@ class ProviderGSVTTS(TTSProvider):
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
# 基础URL
self.api_base = provider_config.get("api_base", "http://127.0.0.1:9880")
if self.api_base.endswith("/"):
self.api_base = self.api_base[:-1]
# 模型文件路径
self.api_base = provider_config.get("api_base", "http://127.0.0.1:9880").rstrip(
"/"
)
self.gpt_weights_path: str = provider_config.get("gpt_weights_path", "")
self.sovits_weights_path: str = provider_config.get("sovits_weights_path", "")
asyncio.create_task(self._set_model_weights())
# 默认参数
raw_params = provider_config.get("gsv_default_parms", {})
# TTS 请求的默认参数,移除前缀gsv_
self.default_params: dict = {
key.removeprefix("gsv_"): str(value).lower()
for key, value in raw_params.items()
for key, value in provider_config.get("gsv_default_parms", {}).items()
}
self.timeout = provider_config.get("timeout", 60)
self._session = aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=self.timeout)
)
asyncio.create_task(self._async_init()).add_done_callback(
self._handle_init_exception
)
# 情绪预设
self.emotions = provider_config.get("emotions", {})
async def _async_init(self):
await self._set_model_weights()
async def _make_request(
self,
endpoint: str,
params=None,
) -> str | bytes:
"""通用的异步请求方法"""
async with aiohttp.ClientSession() as session:
async with session.request("GET", endpoint, params=params) as response:
if response.status != 200:
return await response.text()
else:
def _handle_init_exception(self, task: asyncio.Task):
if task.exception():
logger.error(f"[GSV TTS] 初始化失败:{task.exception()}")
def get_session(self) -> aiohttp.ClientSession:
if not self._session or self._session.closed:
raise RuntimeError("[GSV TTS] Provider HTTP session is not ready or closed.")
return self._session
async def _make_request(self, endpoint: str, params=None, retries: int = 3) -> bytes | None:
"""发起请求"""
for attempt in range(retries):
logger.debug(f"[GSV TTS] 请求地址:{endpoint},参数:{params}")
try:
async with self.get_session().get(endpoint, params=params) as response:
if response.status != 200:
error_text = await response.text()
raise Exception(
f"[GSV TTS] Request to {endpoint} failed with status {response.status}: {error_text}"
)
return await response.read()
except Exception as e:
if attempt < retries - 1:
logger.warning(
f"[GSV TTS] 请求 {endpoint}{attempt + 1} 次失败:{e},重试中..."
)
await asyncio.sleep(1)
else:
logger.error(f"[GSV TTS] 请求 {endpoint} 最终失败:{e}")
raise
async def _set_model_weights(self):
"""设置模型"""
"""设置模型路径"""
try:
# 设置 GPT 模型
if self.gpt_weights_path:
gpt_endpoint = f"{self.api_base}/set_gpt_weights"
gpt_params = {"weights_path": self.gpt_weights_path}
if await self._make_request(endpoint=gpt_endpoint, params=gpt_params):
logger.info(f"成功设置 GPT 模型路径:{self.gpt_weights_path}")
await self._make_request(
f"{self.api_base}/set_gpt_weights",
{"weights_path": self.gpt_weights_path},
)
logger.info(f"[GSV TTS] 成功设置 GPT 模型路径:{self.gpt_weights_path}")
else:
logger.info("GPT 模型路径未配置,将使用GPT_SoVITS内置GPT模型")
logger.info("[GSV TTS] GPT 模型路径未配置,将使用内置 GPT 模型")
# 设置 SoVITS 模型
if self.sovits_weights_path:
sovits_endpoint = f"{self.api_base}/set_sovits_weights"
sovits_params = {"weights_path": self.sovits_weights_path}
if await self._make_request(
endpoint=sovits_endpoint, params=sovits_params
):
logger.info(f"成功设置 SoVITS 模型路径:{self.sovits_weights_path}")
await self._make_request(
f"{self.api_base}/set_sovits_weights",
{"weights_path": self.sovits_weights_path},
)
logger.info(
f"[GSV TTS] 成功设置 SoVITS 模型路径:{self.sovits_weights_path}"
)
else:
logger.info("SoVITS 模型路径未配置,将使用GPT_SoVITS内置SoVITS模型")
logger.info("[GSV TTS] SoVITS 模型路径未配置,将使用内置 SoVITS 模型")
except aiohttp.ClientError as e:
logger.error(f"设置模型路径时发生错误:{e}")
logger.error(f"[GSV TTS] 设置模型路径时发生网络错误:{e}")
except Exception as e:
logger.error(f"发生未知错误:{e}")
logger.error(f"[GSV TTS] 设置模型路径时发生未知错误:{e}")
async def get_audio(self, text: str) -> str:
"""实现 TTS 核心方法,根据文本内容自动切换情绪"""
if not text:
raise ValueError("[GSV TTS] TTS 文本不能为空")
endpoint = f"{self.api_base}/tts"
params = self.default_params.copy()
params["text"] = text
params = self.build_synthesis_params(text)
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(temp_dir, exist_ok=True)
path = os.path.join(temp_dir, f"gsvi_tts_{uuid.uuid4()}.wav")
path = os.path.join(temp_dir, f"gsv_tts_{uuid.uuid4().hex}.wav")
logger.debug(f"正在调用GSV语音合成接口,参数:{params}")
logger.debug(f"[GSV TTS] 正在调用语音合成接口,参数:{params}")
result = await self._make_request(endpoint, params)
if isinstance(result, bytes):
@@ -102,5 +125,23 @@ class ProviderGSVTTS(TTSProvider):
f.write(result)
return path
else:
raise Exception(f"GSVI TTS API 请求失败: {result}")
raise Exception(f"[GSV TTS] 合成失败,输入文本:{text},错误信息:{result}")
def build_synthesis_params(self, text: str) -> dict:
"""
构建语音合成所需的参数字典。
当前仅包含默认参数 + 文本,未来可在此基础上动态添加如情绪、角色等语义控制字段。
"""
params = self.default_params.copy()
params["text"] = text
# TODO: 在此处添加情绪分析,例如 params["emotion"] = detect_emotion(text)
return params
async def shutdown(self):
if self._session and not self._session.closed:
await self._session.close()
def __del__(self):
if hasattr(self, "_session") and self._session and not self._session.closed:
logger.warning("[GSV TTS] ProviderGSVTTS 已被销毁但 session 未关闭,请确保调用 shutdown()")