diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 32dd1b454..27d28dc82 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -25,8 +25,8 @@ DEFAULT_CONFIG = { "id_whitelist_log": True, "wl_ignore_admin_on_group": True, "wl_ignore_admin_on_friend": True, - "reply_with_mention": False, - "reply_with_quote": False, + "reply_with_mention": 0.0, + "reply_with_quote": 0.0, "path_mapping": [], "segmented_reply": { "enable": False, @@ -466,13 +466,13 @@ CONFIG_METADATA_2 = { }, "reply_with_mention": { "description": "回复时 @ 发送者", - "type": "bool", - "hint": "启用后,机器人回复消息时会 @ 发送者。实际效果以具体的平台适配器为准。", + "type": "float", + "hint": "启用后,机器人回复消息时会 @ 发送者。0.0-1.0 之间的概率值,0.0 表示从不,1.0 表示总是。实际效果以具体的平台适配器为准。", }, "reply_with_quote": { "description": "回复时引用消息", - "type": "bool", - "hint": "启用后,机器人回复消息时会引用原消息。实际效果以具体的平台适配器为准。", + "type": "float", + "hint": "启用后,机器人回复消息时会引用原消息。0.0-1.0 之间的概率值,0.0 表示从不,1.0 表示总是。实际效果以具体的平台适配器为准。", }, "path_mapping": { "description": "路径映射", @@ -800,6 +800,36 @@ CONFIG_METADATA_2 = { "edge-tts-voice": "zh-CN-XiaoxiaoNeural", "timeout": 20, }, + "GSV TTS(本地加载)": { + "id": "gsv_tts", + "enable": False, + "type": "gsv_tts_selfhost", + "provider_type": "text_to_speech", + "api_base": "http://127.0.0.1:9880", + "gpt_weights_path": "", + "sovits_weights_path": "", + "gsv_default_parms": { + "gsv_ref_audio_path": "", + "gsv_prompt_text": "", + "gsv_prompt_lang": "zh", + "gsv_aux_ref_audio_paths": "", + "gsv_text_lang": "zh", + "gsv_top_k": 5, + "gsv_top_p": 1.0, + "gsv_temperature": 1.0, + "gsv_text_split_method": "cut3", + "gsv_batch_size": 1, + "gsv_batch_threshold": 0.75, + "gsv_split_bucket": True, + "gsv_speed_factor": 1, + "gsv_fragment_interval": 0.3, + "gsv_streaming_mode": False, + "gsv_seed": -1, + "gsv_parallel_infer": True, + "gsv_repetition_penalty": 1.35, + "gsv_media_type": "wav", + }, + }, "GSVI TTS(API)": { "id": "gsvi_tts", "type": "gsvi_tts_api", @@ -901,6 +931,130 @@ CONFIG_METADATA_2 = { }, }, "items": { + "gpt_weights_path": { + "description": "GPT模型文件路径", + "type": "string", + "hint": "即“.ckpt”后缀的文件,请使用绝对路径,路径两端不要带双引号,不填则默认用GPT_SoVITS内置的SoVITS模型(建议直接在GPT_SoVITS中改默认模型)", + "obvious_hint": True, + }, + "sovits_weights_path": { + "description": "SoVITS模型文件路径", + "type": "string", + "hint": "即“.pth”后缀的文件,请使用绝对路径,路径两端不要带双引号,不填则默认用GPT_SoVITS内置的SoVITS模型(建议直接在GPT_SoVITS中改默认模型)", + "obvious_hint": True, + }, + "gsv_default_parms": { + "description": "GPT_SoVITS默认参数", + "hint": "参考音频文件路径、参考音频文本必填,其他参数根据个人爱好自行填写", + "type": "object", + "items": { + "gsv_ref_audio_path": { + "description": "参考音频文件路径", + "type": "string", + "hint": "必填!请使用绝对路径!路径两端不要带双引号!", + "obvious_hint": True, + }, + "gsv_prompt_text": { + "description": "参考音频文本", + "type": "string", + "hint": "必填!请填写参考音频讲述的文本", + "obvious_hint": True, + }, + "gsv_prompt_lang": { + "description": "参考音频文本语言", + "type": "string", + "hint": "请填写参考音频讲述的文本的语言,默认为中文", + }, + "gsv_aux_ref_audio_paths": { + "description": "辅助参考音频文件路径", + "type": "string", + "hint": "辅助参考音频文件,可不填", + }, + "gsv_text_lang": { + "description": "文本语言", + "type": "string", + "hint": "默认为中文", + }, + "gsv_top_k": { + "description": "生成语音的多样性", + "type": "int", + "hint": "", + }, + "gsv_top_p": { + "description": "核采样的阈值", + "type": "float", + "hint": "", + }, + "gsv_temperature": { + "description": "生成语音的随机性", + "type": "float", + "hint": "", + }, + "gsv_text_split_method": { + "description": "切分文本的方法", + "type": "string", + "hint": "可选值: `cut0`:不切分 `cut1`:四句一切 `cut2`:50字一切 `cut3`:按中文句号切 `cut4`:按英文句号切 `cut5`:按标点符号切", + "options": [ + "cut0", + "cut1", + "cut2", + "cut3", + "cut4", + "cut5", + ], + }, + "gsv_batch_size": { + "description": "批处理大小", + "type": "int", + "hint": "", + }, + "gsv_batch_threshold": { + "description": "批处理阈值", + "type": "float", + "hint": "", + }, + "gsv_split_bucket": { + "description": "将文本分割成桶以便并行处理", + "type": "bool", + "hint": "", + }, + "gsv_speed_factor": { + "description": "语音播放速度", + "type": "float", + "hint": "1为原始语速", + }, + "gsv_fragment_interval": { + "description": "语音片段之间的间隔时间", + "type": "float", + "hint": "", + }, + "gsv_streaming_mode": { + "description": "启用流模式", + "type": "bool", + "hint": "", + }, + "gsv_seed": { + "description": "随机种子", + "type": "int", + "hint": "用于结果的可重复性", + }, + "gsv_parallel_infer": { + "description": "并行执行推理", + "type": "bool", + "hint": "", + }, + "gsv_repetition_penalty": { + "description": "重复惩罚因子", + "type": "float", + "hint": "", + }, + "gsv_media_type": { + "description": "输出媒体的类型", + "type": "string", + "hint": "建议用wav", + }, + }, + }, "embedding_dimensions": { "description": "嵌入维度", "type": "int", diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index b11a3361a..382f469fe 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -225,6 +225,10 @@ class ProviderManager: from .sources.edge_tts_source import ( ProviderEdgeTTS as ProviderEdgeTTS, ) + case "gsv_tts_selfhost": + from .sources.gsv_selfhosted_source import ( + ProviderGSVTTS as ProviderGSVTTS, + ) case "gsvi_tts_api": from .sources.gsvi_tts_source import ( ProviderGSVITTS as ProviderGSVITTS, diff --git a/astrbot/core/provider/sources/gsv_selfhosted_source.py b/astrbot/core/provider/sources/gsv_selfhosted_source.py new file mode 100644 index 000000000..dbde29f0d --- /dev/null +++ b/astrbot/core/provider/sources/gsv_selfhosted_source.py @@ -0,0 +1,106 @@ + +import asyncio +import os +import uuid + +import aiohttp +from ..provider import TTSProvider +from ..entities import ProviderType +from ..register import register_provider_adapter +from astrbot import logger +from astrbot.core.utils.astrbot_path import get_astrbot_data_path + + +@register_provider_adapter( + provider_type_name="gsv_tts_selfhost", + desc=" GPT-SoVITS TTS(本地加载)", + provider_type=ProviderType.TEXT_TO_SPEECH, +) +class ProviderGSVTTS(TTSProvider): + def __init__( + self, + provider_config: dict, + provider_settings: dict, + ) -> None: + super().__init__(provider_config, provider_settings) + # 基础URL + self.api_base = provider_config.get("api_base", "http://127.0.0.1:9880") + if self.api_base.endswith("/"): + self.api_base = self.api_base[:-1] + + # 模型文件路径 + self.gpt_weights_path: str = provider_config.get("gpt_weights_path", "") + self.sovits_weights_path: str = provider_config.get("sovits_weights_path", "") + asyncio.create_task(self._set_model_weights()) + + # 默认参数 + raw_params = provider_config.get("gsv_default_parms", {}) + self.default_params: dict = { + key.removeprefix("gsv_"): str(value).lower() + for key, value in raw_params.items() + } + + # 情绪预设 + self.emotions = provider_config.get("emotions", {}) + + async def _make_request( + self, + endpoint: str, + params=None, + ) -> str | bytes: + """通用的异步请求方法""" + async with aiohttp.ClientSession() as session: + async with session.request("GET", endpoint, params=params) as response: + if response.status != 200: + return await response.text() + else: + return await response.read() + + async def _set_model_weights(self): + """设置模型""" + try: + # 设置 GPT 模型 + if self.gpt_weights_path: + gpt_endpoint = f"{self.api_base}/set_gpt_weights" + gpt_params = {"weights_path": self.gpt_weights_path} + if await self._make_request(endpoint=gpt_endpoint, params=gpt_params): + logger.info(f"成功设置 GPT 模型路径:{self.gpt_weights_path}") + else: + logger.info("GPT 模型路径未配置,将使用GPT_SoVITS内置的GPT模型") + + # 设置 SoVITS 模型 + if self.sovits_weights_path: + sovits_endpoint = f"{self.api_base}/set_sovits_weights" + sovits_params = {"weights_path": self.sovits_weights_path} + if await self._make_request( + endpoint=sovits_endpoint, params=sovits_params + ): + logger.info(f"成功设置 SoVITS 模型路径:{self.sovits_weights_path}") + else: + logger.info("SoVITS 模型路径未配置,将使用GPT_SoVITS内置的SoVITS模型") + except aiohttp.ClientError as e: + logger.error(f"设置模型路径时发生错误:{e}") + except Exception as e: + logger.error(f"发生未知错误:{e}") + + async def get_audio(self, text: str) -> str: + """实现 TTS 核心方法,根据文本内容自动切换情绪""" + endpoint = f"{self.api_base}/tts" + + params = self.default_params.copy() + params["text"] = text + + temp_dir = os.path.join(get_astrbot_data_path(), "temp") + os.makedirs(temp_dir, exist_ok=True) + path = os.path.join(temp_dir, f"gsvi_tts_{uuid.uuid4()}.wav") + + logger.debug(f"正在调用GSV语音合成接口,参数:{params}") + + result = await self._make_request(endpoint, params) + if isinstance(result, bytes): + with open(path, "wb") as f: + f.write(result) + return path + else: + raise Exception(f"GSVI TTS API 请求失败: {result}") +