✨ feat: add-volcengine-tts-support
This commit is contained in:
@@ -803,13 +803,14 @@ CONFIG_METADATA_2 = {
|
||||
"火山引擎_TTS(API)": {
|
||||
"id": "volcengine_tts",
|
||||
"type": "volcengine_tts",
|
||||
"provider_type": "text_to_speech",
|
||||
"enable": False,
|
||||
"api_key": "",
|
||||
"appid": "",
|
||||
"cluster": "",
|
||||
"voice_type": "xiaoyun",
|
||||
"voice_type": "",
|
||||
"api_base": "https://openspeech.bytedance.com/api/v1/tts",
|
||||
"timeout": "20",
|
||||
"timeout": 20,
|
||||
},
|
||||
},
|
||||
"items": {
|
||||
|
||||
@@ -206,6 +206,10 @@ class ProviderManager:
|
||||
from .sources.azure_tts_source import (
|
||||
AzureTTSProvider as AzureTTSProvider,
|
||||
)
|
||||
case "volcengine_tts":
|
||||
from .sources.volcengine_tts import (
|
||||
ProviderVolcengineTTS as ProviderVolcengineTTS,
|
||||
)
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.critical(
|
||||
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。"
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import uuid
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import traceback
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import requests
|
||||
from ..provider import TTSProvider
|
||||
from ..entities import ProviderType
|
||||
@@ -16,8 +20,11 @@ class ProviderVolcengineTTS(TTSProvider):
|
||||
self.appid = provider_config.get("appid", "")
|
||||
self.cluster = provider_config.get("cluster", "")
|
||||
self.voice_type = provider_config.get("voice_type", "xiaoyun")
|
||||
self.api_base = provider_config.get("api_base", "https://openspeech.bytedance.com/api/v1/tts")
|
||||
self.timeout = provider_config.get("timeout", "20")
|
||||
|
||||
host = "openspeech.bytedance.com"
|
||||
self.api_base = provider_config.get("api_base", f"https://{host}/api/v1/tts")
|
||||
|
||||
self.timeout = provider_config.get("timeout", 20)
|
||||
|
||||
def _build_request_payload(self, text: str) -> dict:
|
||||
return {
|
||||
@@ -27,7 +34,7 @@ class ProviderVolcengineTTS(TTSProvider):
|
||||
"cluster": self.cluster
|
||||
},
|
||||
"user": {
|
||||
"uid": str(uuid.uuid4())
|
||||
"uid": str(uuid.uuid4())
|
||||
},
|
||||
"audio": {
|
||||
"voice_type": self.voice_type,
|
||||
@@ -46,20 +53,62 @@ class ProviderVolcengineTTS(TTSProvider):
|
||||
}
|
||||
}
|
||||
|
||||
def get_audio(self, text: str) -> str:
|
||||
headers = {"Authorization": f"Bearer {self.api_key}"}
|
||||
async def get_audio(self, text: str) -> str:
|
||||
"""异步方法获取语音文件路径"""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer;{self.api_key}"
|
||||
}
|
||||
|
||||
payload = self._build_request_payload(text)
|
||||
response = requests.post(self.api_base, json=payload, headers=headers, timeout=self.timeout)
|
||||
|
||||
if response.status_code == 200:
|
||||
resp_data = response.json()
|
||||
if "data" in resp_data:
|
||||
audio_data = base64.b64decode(resp_data["data"])
|
||||
file_path = f"data/temp/volcengine_tts_{uuid.uuid4()}.mp3"
|
||||
with open(file_path, "wb") as audio_file:
|
||||
audio_file.write(audio_data)
|
||||
return file_path
|
||||
else:
|
||||
raise Exception(f"火山引擎 TTS API 返回错误: {resp_data}")
|
||||
else:
|
||||
raise Exception(f"火山引擎 TTS API 请求失败: {response.status_code}, {response.text}")
|
||||
|
||||
# 打印请求信息以便调试
|
||||
print(f"请求 URL: {self.api_base}")
|
||||
print(f"请求头: {headers}")
|
||||
print(f"请求体: {json.dumps(payload, ensure_ascii=False)[:100]}...")
|
||||
|
||||
try:
|
||||
# 使用 aiohttp 进行异步请求
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
self.api_base,
|
||||
data=json.dumps(payload), # 使用 data 而不是 json 参数
|
||||
headers=headers,
|
||||
timeout=self.timeout
|
||||
) as response:
|
||||
print(f"响应状态码: {response.status}")
|
||||
|
||||
# 获取响应内容
|
||||
response_text = await response.text()
|
||||
print(f"响应内容: {response_text[:200]}...")
|
||||
|
||||
if response.status == 200:
|
||||
resp_data = json.loads(response_text)
|
||||
|
||||
if "data" in resp_data:
|
||||
audio_data = base64.b64decode(resp_data["data"])
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs("data/temp", exist_ok=True)
|
||||
|
||||
file_path = f"data/temp/volcengine_tts_{uuid.uuid4()}.mp3"
|
||||
|
||||
# 使用线程运行I/O操作,避免阻塞
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
lambda: open(file_path, "wb").write(audio_data)
|
||||
)
|
||||
|
||||
return file_path
|
||||
else:
|
||||
error_msg = resp_data.get("message", "未知错误")
|
||||
raise Exception(f"火山引擎 TTS API 返回错误: {error_msg}")
|
||||
else:
|
||||
raise Exception(f"火山引擎 TTS API 请求失败: {response.status}, {response_text}")
|
||||
|
||||
except Exception as e:
|
||||
# 添加更详细的异常捕获
|
||||
error_details = traceback.format_exc()
|
||||
print(f"火山引擎 TTS 异常详情: {error_details}")
|
||||
raise Exception(f"火山引擎 TTS 异常: {str(e)}")
|
||||
Reference in New Issue
Block a user