refactor: use aiohttp
This commit is contained in:
@@ -1,12 +1,10 @@
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from typing import Dict, Iterator, List, Union
|
||||
|
||||
import requests
|
||||
|
||||
import aiohttp
|
||||
from typing import Dict, List, Union, AsyncIterator
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
from astrbot.api import logger
|
||||
from ..entities import ProviderType
|
||||
from ..provider import TTSProvider
|
||||
from ..register import register_provider_adapter
|
||||
@@ -81,43 +79,54 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
|
||||
|
||||
return json.dumps(dict_body)
|
||||
|
||||
def _call_tts_stream(self, text: str) -> Iterator[bytes]:
|
||||
async def _call_tts_stream(self, text: str) -> AsyncIterator[bytes]:
|
||||
"""进行流式请求"""
|
||||
try:
|
||||
response = requests.post(
|
||||
self.concat_base_url,
|
||||
stream=True,
|
||||
headers=self.headers,
|
||||
data=self._build_tts_stream_body(text),
|
||||
)
|
||||
response.raise_for_status()
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
self.concat_base_url,
|
||||
headers=self.headers,
|
||||
data=self._build_tts_stream_body(text),
|
||||
timeout=aiohttp.ClientTimeout(total=60),
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
|
||||
for chunk in response.raw:
|
||||
if not chunk or not chunk.startswith(b"data:"):
|
||||
continue
|
||||
data = json.loads(chunk[5:])
|
||||
if "extra_info" in data:
|
||||
continue
|
||||
audio = data.get("data", {}).get("audio")
|
||||
if audio is not None:
|
||||
yield audio
|
||||
async for chunk in response.content.iter_any():
|
||||
if not chunk or not chunk.startswith(b"data:"):
|
||||
logger.warning(f"Minimax TTS resp: {chunk}")
|
||||
if "invalid api key" in chunk.decode("utf-8"):
|
||||
raise Exception("MiniMax TTS: 无效的 API 密钥")
|
||||
continue
|
||||
try:
|
||||
data = json.loads(chunk[5:])
|
||||
if "extra_info" in data:
|
||||
continue
|
||||
audio = data.get("data", {}).get("audio")
|
||||
if audio is not None:
|
||||
yield audio
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
except aiohttp.ClientError as e:
|
||||
raise Exception(f"MiniMax TTS API请求失败: {str(e)}")
|
||||
|
||||
def _audio_play(self, audio_stream: Iterator[bytes]) -> bytes:
|
||||
"""解码数据流到audio比特流"""
|
||||
return b"".join(
|
||||
bytes.fromhex(chunk) for chunk in audio_stream if chunk and chunk != b"\n"
|
||||
)
|
||||
async def _audio_play(self, audio_stream: AsyncIterator[bytes]) -> bytes:
|
||||
"""解码数据流到 audio 比特流"""
|
||||
chunks = []
|
||||
async for chunk in audio_stream:
|
||||
if chunk and chunk != b"\n":
|
||||
chunks.append(bytes.fromhex(chunk.decode("utf-8")))
|
||||
return b"".join(chunks)
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
path = os.path.join(temp_dir, f"minimax_tts_api_{uuid.uuid4()}.mp3")
|
||||
|
||||
try:
|
||||
audio_chunk_iterator = self._call_tts_stream(text)
|
||||
audio = self._audio_play(audio_chunk_iterator)
|
||||
# 直接将异步生成器传递给 _audio_play 方法
|
||||
audio_stream = self._call_tts_stream(text)
|
||||
audio = await self._audio_play(audio_stream)
|
||||
|
||||
# 结果保存至文件
|
||||
with open(path, "wb") as file:
|
||||
@@ -125,5 +134,5 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
|
||||
|
||||
return path
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
except aiohttp.ClientError as e:
|
||||
raise e
|
||||
|
||||
Reference in New Issue
Block a user