优化
This commit is contained in:
@@ -25,8 +25,8 @@ DEFAULT_CONFIG = {
|
||||
"id_whitelist_log": True,
|
||||
"wl_ignore_admin_on_group": True,
|
||||
"wl_ignore_admin_on_friend": True,
|
||||
"reply_with_mention": 0.0,
|
||||
"reply_with_quote": 0.0,
|
||||
"reply_with_mention": False,
|
||||
"reply_with_quote": False,
|
||||
"path_mapping": [],
|
||||
"segmented_reply": {
|
||||
"enable": False,
|
||||
@@ -808,6 +808,7 @@ CONFIG_METADATA_2 = {
|
||||
"api_base": "http://127.0.0.1:9880",
|
||||
"gpt_weights_path": "",
|
||||
"sovits_weights_path": "",
|
||||
"timeout": 60,
|
||||
"gsv_default_parms": {
|
||||
"gsv_ref_audio_path": "",
|
||||
"gsv_prompt_text": "",
|
||||
|
||||
@@ -23,78 +23,101 @@ class ProviderGSVTTS(TTSProvider):
|
||||
provider_settings: dict,
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
# 基础URL
|
||||
self.api_base = provider_config.get("api_base", "http://127.0.0.1:9880")
|
||||
if self.api_base.endswith("/"):
|
||||
self.api_base = self.api_base[:-1]
|
||||
|
||||
# 模型文件路径
|
||||
self.api_base = provider_config.get("api_base", "http://127.0.0.1:9880").rstrip(
|
||||
"/"
|
||||
)
|
||||
self.gpt_weights_path: str = provider_config.get("gpt_weights_path", "")
|
||||
self.sovits_weights_path: str = provider_config.get("sovits_weights_path", "")
|
||||
asyncio.create_task(self._set_model_weights())
|
||||
|
||||
# 默认参数
|
||||
raw_params = provider_config.get("gsv_default_parms", {})
|
||||
# TTS 请求的默认参数,移除前缀gsv_
|
||||
self.default_params: dict = {
|
||||
key.removeprefix("gsv_"): str(value).lower()
|
||||
for key, value in raw_params.items()
|
||||
for key, value in provider_config.get("gsv_default_parms", {}).items()
|
||||
}
|
||||
self.timeout = provider_config.get("timeout", 60)
|
||||
self._session = aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=self.timeout)
|
||||
)
|
||||
asyncio.create_task(self._async_init()).add_done_callback(
|
||||
self._handle_init_exception
|
||||
)
|
||||
|
||||
# 情绪预设
|
||||
self.emotions = provider_config.get("emotions", {})
|
||||
async def _async_init(self):
|
||||
await self._set_model_weights()
|
||||
|
||||
async def _make_request(
|
||||
self,
|
||||
endpoint: str,
|
||||
params=None,
|
||||
) -> str | bytes:
|
||||
"""通用的异步请求方法"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.request("GET", endpoint, params=params) as response:
|
||||
if response.status != 200:
|
||||
return await response.text()
|
||||
else:
|
||||
def _handle_init_exception(self, task: asyncio.Task):
|
||||
if task.exception():
|
||||
logger.error(f"[GSV TTS] 初始化失败:{task.exception()}")
|
||||
|
||||
def get_session(self) -> aiohttp.ClientSession:
|
||||
if not self._session or self._session.closed:
|
||||
raise RuntimeError("[GSV TTS] Provider HTTP session is not ready or closed.")
|
||||
return self._session
|
||||
|
||||
async def _make_request(self, endpoint: str, params=None, retries: int = 3) -> bytes | None:
|
||||
"""发起请求"""
|
||||
for attempt in range(retries):
|
||||
logger.debug(f"[GSV TTS] 请求地址:{endpoint},参数:{params}")
|
||||
try:
|
||||
async with self.get_session().get(endpoint, params=params) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise Exception(
|
||||
f"[GSV TTS] Request to {endpoint} failed with status {response.status}: {error_text}"
|
||||
)
|
||||
return await response.read()
|
||||
except Exception as e:
|
||||
if attempt < retries - 1:
|
||||
logger.warning(
|
||||
f"[GSV TTS] 请求 {endpoint} 第 {attempt + 1} 次失败:{e},重试中..."
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
else:
|
||||
logger.error(f"[GSV TTS] 请求 {endpoint} 最终失败:{e}")
|
||||
raise
|
||||
|
||||
async def _set_model_weights(self):
|
||||
"""设置模型"""
|
||||
"""设置模型路径"""
|
||||
try:
|
||||
# 设置 GPT 模型
|
||||
if self.gpt_weights_path:
|
||||
gpt_endpoint = f"{self.api_base}/set_gpt_weights"
|
||||
gpt_params = {"weights_path": self.gpt_weights_path}
|
||||
if await self._make_request(endpoint=gpt_endpoint, params=gpt_params):
|
||||
logger.info(f"成功设置 GPT 模型路径:{self.gpt_weights_path}")
|
||||
await self._make_request(
|
||||
f"{self.api_base}/set_gpt_weights",
|
||||
{"weights_path": self.gpt_weights_path},
|
||||
)
|
||||
logger.info(f"[GSV TTS] 成功设置 GPT 模型路径:{self.gpt_weights_path}")
|
||||
else:
|
||||
logger.info("GPT 模型路径未配置,将使用GPT_SoVITS内置的GPT模型")
|
||||
logger.info("[GSV TTS] GPT 模型路径未配置,将使用内置 GPT 模型")
|
||||
|
||||
# 设置 SoVITS 模型
|
||||
if self.sovits_weights_path:
|
||||
sovits_endpoint = f"{self.api_base}/set_sovits_weights"
|
||||
sovits_params = {"weights_path": self.sovits_weights_path}
|
||||
if await self._make_request(
|
||||
endpoint=sovits_endpoint, params=sovits_params
|
||||
):
|
||||
logger.info(f"成功设置 SoVITS 模型路径:{self.sovits_weights_path}")
|
||||
await self._make_request(
|
||||
f"{self.api_base}/set_sovits_weights",
|
||||
{"weights_path": self.sovits_weights_path},
|
||||
)
|
||||
logger.info(
|
||||
f"[GSV TTS] 成功设置 SoVITS 模型路径:{self.sovits_weights_path}"
|
||||
)
|
||||
else:
|
||||
logger.info("SoVITS 模型路径未配置,将使用GPT_SoVITS内置的SoVITS模型")
|
||||
logger.info("[GSV TTS] SoVITS 模型路径未配置,将使用内置 SoVITS 模型")
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"设置模型路径时发生错误:{e}")
|
||||
logger.error(f"[GSV TTS] 设置模型路径时发生网络错误:{e}")
|
||||
except Exception as e:
|
||||
logger.error(f"发生未知错误:{e}")
|
||||
logger.error(f"[GSV TTS] 设置模型路径时发生未知错误:{e}")
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
"""实现 TTS 核心方法,根据文本内容自动切换情绪"""
|
||||
if not text:
|
||||
raise ValueError("[GSV TTS] TTS 文本不能为空")
|
||||
|
||||
endpoint = f"{self.api_base}/tts"
|
||||
|
||||
params = self.default_params.copy()
|
||||
params["text"] = text
|
||||
params = self.build_synthesis_params(text)
|
||||
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
path = os.path.join(temp_dir, f"gsvi_tts_{uuid.uuid4()}.wav")
|
||||
path = os.path.join(temp_dir, f"gsv_tts_{uuid.uuid4().hex}.wav")
|
||||
|
||||
logger.debug(f"正在调用GSV语音合成接口,参数:{params}")
|
||||
logger.debug(f"[GSV TTS] 正在调用语音合成接口,参数:{params}")
|
||||
|
||||
result = await self._make_request(endpoint, params)
|
||||
if isinstance(result, bytes):
|
||||
@@ -102,5 +125,23 @@ class ProviderGSVTTS(TTSProvider):
|
||||
f.write(result)
|
||||
return path
|
||||
else:
|
||||
raise Exception(f"GSVI TTS API 请求失败: {result}")
|
||||
raise Exception(f"[GSV TTS] 合成失败,输入文本:{text},错误信息:{result}")
|
||||
|
||||
def build_synthesis_params(self, text: str) -> dict:
|
||||
"""
|
||||
构建语音合成所需的参数字典。
|
||||
|
||||
当前仅包含默认参数 + 文本,未来可在此基础上动态添加如情绪、角色等语义控制字段。
|
||||
"""
|
||||
params = self.default_params.copy()
|
||||
params["text"] = text
|
||||
# TODO: 在此处添加情绪分析,例如 params["emotion"] = detect_emotion(text)
|
||||
return params
|
||||
|
||||
async def shutdown(self):
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, "_session") and self._session and not self._session.closed:
|
||||
logger.warning("[GSV TTS] ProviderGSVTTS 已被销毁但 session 未关闭,请确保调用 shutdown()")
|
||||
|
||||
Reference in New Issue
Block a user