From db13a602742b2a2fe0d215f80cc6fe1fb5686359 Mon Sep 17 00:00:00 2001
From: YOO_koishi <2358181935@qq.com>
Date: Sun, 18 May 2025 03:18:36 +0800
Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20add-volcengine-tts-support?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 astrbot/core/config/default.py                |  5 +-
 astrbot/core/provider/manager.py              |  4 +
 .../core/provider/sources/volcengine_tts.py   | 95 ++++++++++++++++----
 3 files changed, 83 insertions(+), 21 deletions(-)

diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py
index dd50b110f..526280130 100644
--- a/astrbot/core/config/default.py
+++ b/astrbot/core/config/default.py
@@ -803,13 +803,14 @@ CONFIG_METADATA_2 = {
                     "火山引擎_TTS(API)": {
                         "id": "volcengine_tts",
                         "type": "volcengine_tts",
+                        "provider_type": "text_to_speech",
                         "enable": False,
                         "api_key": "",
                         "appid": "",
                         "cluster": "",
-                        "voice_type": "xiaoyun",
+                        "voice_type": "",
                         "api_base": "https://openspeech.bytedance.com/api/v1/tts",
-                        "timeout": "20",
+                        "timeout": 20,
                     },
                 },
                 "items": {
diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py
index e61fbf925..80269752d 100644
--- a/astrbot/core/provider/manager.py
+++ b/astrbot/core/provider/manager.py
@@ -206,6 +206,10 @@ class ProviderManager:
                     from .sources.azure_tts_source import (
                         AzureTTSProvider as AzureTTSProvider,
                     )
+                case "volcengine_tts":
+                    from .sources.volcengine_tts import (
+                        ProviderVolcengineTTS as ProviderVolcengineTTS,
+                    )
         except (ImportError, ModuleNotFoundError) as e:
             logger.critical(
                 f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。"
diff --git a/astrbot/core/provider/sources/volcengine_tts.py b/astrbot/core/provider/sources/volcengine_tts.py
index 76d08f078..90680f42b 100644
--- a/astrbot/core/provider/sources/volcengine_tts.py
+++ b/astrbot/core/provider/sources/volcengine_tts.py
@@ -1,6 +1,10 @@
 import uuid
 import base64
 import json
+import os
+import traceback
+import asyncio
+import aiohttp
 import requests
 from ..provider import TTSProvider
 from ..entities import ProviderType
@@ -16,8 +20,11 @@ class ProviderVolcengineTTS(TTSProvider):
         self.appid = provider_config.get("appid", "")
         self.cluster = provider_config.get("cluster", "")
         self.voice_type = provider_config.get("voice_type", "xiaoyun")
-        self.api_base = provider_config.get("api_base", "https://openspeech.bytedance.com/api/v1/tts")
-        self.timeout = provider_config.get("timeout", "20")
+
+        host = "openspeech.bytedance.com"
+        self.api_base = provider_config.get("api_base", f"https://{host}/api/v1/tts")
+
+        self.timeout = provider_config.get("timeout", 20)
 
     def _build_request_payload(self, text: str) -> dict:
         return {
@@ -27,7 +34,7 @@ class ProviderVolcengineTTS(TTSProvider):
                 "cluster": self.cluster
             },
             "user": {
-                "uid": str(uuid.uuid4())
+                "uid": str(uuid.uuid4())
             },
             "audio": {
                 "voice_type": self.voice_type,
@@ -46,20 +53,70 @@ class ProviderVolcengineTTS(TTSProvider):
             }
         }
 
-    def get_audio(self, text: str) -> str:
-        headers = {"Authorization": f"Bearer {self.api_key}"}
+    async def get_audio(self, text: str) -> str:
+        """Request speech synthesis for ``text`` and return the saved MP3 path.
+
+        Raises:
+            Exception: if the HTTP request fails, the API response carries no
+                audio data, or the audio file cannot be written.
+        """
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer;{self.api_key}"
+        }
+
         payload = self._build_request_payload(text)
-        response = requests.post(self.api_base, json=payload, headers=headers, timeout=self.timeout)
-
-        if response.status_code == 200:
-            resp_data = response.json()
-            if "data" in resp_data:
-                audio_data = base64.b64decode(resp_data["data"])
-                file_path = f"data/temp/volcengine_tts_{uuid.uuid4()}.mp3"
-                with open(file_path, "wb") as audio_file:
-                    audio_file.write(audio_data)
-                return file_path
-            else:
-                raise Exception(f"火山引擎 TTS API 返回错误: {resp_data}")
-        else:
-            raise Exception(f"火山引擎 TTS API 请求失败: {response.status_code}, {response.text}")
\ No newline at end of file
+
+        # Only the URL is printed: the headers and the payload both carry
+        # the API credentials and must not end up in stdout or log files.
+        print(f"请求 URL: {self.api_base}")
+
+        def _write_audio(path: str, content: bytes) -> None:
+            # A context manager closes the handle even if the write fails.
+            with open(path, "wb") as audio_file:
+                audio_file.write(content)
+
+        try:
+            # Configs saved before this change may hold the timeout as "20".
+            timeout = aiohttp.ClientTimeout(total=float(self.timeout))
+
+            async with aiohttp.ClientSession() as session:
+                async with session.post(
+                    self.api_base,
+                    data=json.dumps(payload),
+                    headers=headers,
+                    timeout=timeout,
+                ) as response:
+                    print(f"响应状态码: {response.status}")
+
+                    response_text = await response.text()
+                    print(f"响应内容: {response_text[:200]}...")
+
+                    if response.status == 200:
+                        resp_data = json.loads(response_text)
+
+                        if "data" in resp_data:
+                            audio_data = base64.b64decode(resp_data["data"])
+
+                            os.makedirs("data/temp", exist_ok=True)
+                            file_path = f"data/temp/volcengine_tts_{uuid.uuid4()}.mp3"
+
+                            # Write in the default executor so the event loop
+                            # is not blocked by disk I/O.
+                            loop = asyncio.get_running_loop()
+                            await loop.run_in_executor(
+                                None, _write_audio, file_path, audio_data
+                            )
+
+                            return file_path
+                        else:
+                            error_msg = resp_data.get("message", "未知错误")
+                            raise Exception(f"火山引擎 TTS API 返回错误: {error_msg}")
+                    else:
+                        raise Exception(f"火山引擎 TTS API 请求失败: {response.status}, {response_text}")
+
+        except Exception as e:
+            # Print the traceback, then re-raise with the cause chained.
+            error_details = traceback.format_exc()
+            print(f"火山引擎 TTS 异常详情: {error_details}")
+            raise Exception(f"火山引擎 TTS 异常: {str(e)}") from e