feat: add-volcengine-tts-support

This commit is contained in:
YOO_koishi
2025-05-18 03:18:36 +08:00
parent 6439917cbe
commit db13a60274
3 changed files with 75 additions and 21 deletions
+3 -2
View File
@@ -803,13 +803,14 @@ CONFIG_METADATA_2 = {
"火山引擎_TTS(API)": {
"id": "volcengine_tts",
"type": "volcengine_tts",
"provider_type": "text_to_speech",
"enable": False,
"api_key": "",
"appid": "",
"cluster": "",
"voice_type": "xiaoyun",
"voice_type": "",
"api_base": "https://openspeech.bytedance.com/api/v1/tts",
"timeout": "20",
"timeout": 20,
},
},
"items": {
+4
View File
@@ -206,6 +206,10 @@ class ProviderManager:
from .sources.azure_tts_source import (
AzureTTSProvider as AzureTTSProvider,
)
case "volcengine_tts":
from .sources.volcengine_tts import (
ProviderVolcengineTTS as ProviderVolcengineTTS,
)
except (ImportError, ModuleNotFoundError) as e:
logger.critical(
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。"
+68 -19
View File
@@ -1,6 +1,10 @@
import uuid
import base64
import json
import os
import traceback
import asyncio
import aiohttp
import requests
from ..provider import TTSProvider
from ..entities import ProviderType
@@ -16,8 +20,11 @@ class ProviderVolcengineTTS(TTSProvider):
self.appid = provider_config.get("appid", "")
self.cluster = provider_config.get("cluster", "")
self.voice_type = provider_config.get("voice_type", "xiaoyun")
self.api_base = provider_config.get("api_base", "https://openspeech.bytedance.com/api/v1/tts")
self.timeout = provider_config.get("timeout", "20")
host = "openspeech.bytedance.com"
self.api_base = provider_config.get("api_base", f"https://{host}/api/v1/tts")
self.timeout = provider_config.get("timeout", 20)
def _build_request_payload(self, text: str) -> dict:
return {
@@ -27,7 +34,7 @@ class ProviderVolcengineTTS(TTSProvider):
"cluster": self.cluster
},
"user": {
"uid": str(uuid.uuid4())
"uid": str(uuid.uuid4())
},
"audio": {
"voice_type": self.voice_type,
@@ -46,20 +53,62 @@ class ProviderVolcengineTTS(TTSProvider):
}
}
def get_audio(self, text: str) -> str:
headers = {"Authorization": f"Bearer {self.api_key}"}
async def get_audio(self, text: str) -> str:
"""异步方法获取语音文件路径"""
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer;{self.api_key}"
}
payload = self._build_request_payload(text)
response = requests.post(self.api_base, json=payload, headers=headers, timeout=self.timeout)
if response.status_code == 200:
resp_data = response.json()
if "data" in resp_data:
audio_data = base64.b64decode(resp_data["data"])
file_path = f"data/temp/volcengine_tts_{uuid.uuid4()}.mp3"
with open(file_path, "wb") as audio_file:
audio_file.write(audio_data)
return file_path
else:
raise Exception(f"火山引擎 TTS API 返回错误: {resp_data}")
else:
raise Exception(f"火山引擎 TTS API 请求失败: {response.status_code}, {response.text}")
# 打印请求信息以便调试
print(f"请求 URL: {self.api_base}")
print(f"请求头: {headers}")
print(f"请求体: {json.dumps(payload, ensure_ascii=False)[:100]}...")
try:
# 使用 aiohttp 进行异步请求
async with aiohttp.ClientSession() as session:
async with session.post(
self.api_base,
data=json.dumps(payload), # 使用 data 而不是 json 参数
headers=headers,
timeout=self.timeout
) as response:
print(f"响应状态码: {response.status}")
# 获取响应内容
response_text = await response.text()
print(f"响应内容: {response_text[:200]}...")
if response.status == 200:
resp_data = json.loads(response_text)
if "data" in resp_data:
audio_data = base64.b64decode(resp_data["data"])
# 确保目录存在
os.makedirs("data/temp", exist_ok=True)
file_path = f"data/temp/volcengine_tts_{uuid.uuid4()}.mp3"
# 使用线程运行I/O操作,避免阻塞
loop = asyncio.get_running_loop()
await loop.run_in_executor(
None,
lambda: open(file_path, "wb").write(audio_data)
)
return file_path
else:
error_msg = resp_data.get("message", "未知错误")
raise Exception(f"火山引擎 TTS API 返回错误: {error_msg}")
else:
raise Exception(f"火山引擎 TTS API 请求失败: {response.status}, {response_text}")
except Exception as e:
# 添加更详细的异常捕获
error_details = traceback.format_exc()
print(f"火山引擎 TTS 异常详情: {error_details}")
raise Exception(f"火山引擎 TTS 异常: {str(e)}")