refactor: use aiohttp

This commit is contained in:
Soulter
2025-05-16 11:04:01 +08:00
parent 6723fe8271
commit c6eaf3d010
@@ -1,12 +1,10 @@
import json
import os
import uuid
from typing import Dict, Iterator, List, Union
import requests
import aiohttp
from typing import Dict, List, Union, AsyncIterator
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.api import logger
from ..entities import ProviderType
from ..provider import TTSProvider
from ..register import register_provider_adapter
@@ -81,43 +79,54 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
return json.dumps(dict_body)
def _call_tts_stream(self, text: str) -> Iterator[bytes]:
async def _call_tts_stream(self, text: str) -> AsyncIterator[bytes]:
"""进行流式请求"""
try:
response = requests.post(
self.concat_base_url,
stream=True,
headers=self.headers,
data=self._build_tts_stream_body(text),
)
response.raise_for_status()
async with aiohttp.ClientSession() as session:
async with session.post(
self.concat_base_url,
headers=self.headers,
data=self._build_tts_stream_body(text),
timeout=aiohttp.ClientTimeout(total=60),
) as response:
response.raise_for_status()
for chunk in response.raw:
if not chunk or not chunk.startswith(b"data:"):
continue
data = json.loads(chunk[5:])
if "extra_info" in data:
continue
audio = data.get("data", {}).get("audio")
if audio is not None:
yield audio
async for chunk in response.content.iter_any():
if not chunk or not chunk.startswith(b"data:"):
logger.warning(f"Minimax TTS resp: {chunk}")
if "invalid api key" in chunk.decode("utf-8"):
raise Exception("MiniMax TTS: 无效的 API 密钥")
continue
try:
data = json.loads(chunk[5:])
if "extra_info" in data:
continue
audio = data.get("data", {}).get("audio")
if audio is not None:
yield audio
except json.JSONDecodeError:
continue
except requests.exceptions.RequestException as e:
except aiohttp.ClientError as e:
raise Exception(f"MiniMax TTS API请求失败: {str(e)}")
def _audio_play(self, audio_stream: Iterator[bytes]) -> bytes:
"""解码数据流到audio比特流"""
return b"".join(
bytes.fromhex(chunk) for chunk in audio_stream if chunk and chunk != b"\n"
)
async def _audio_play(self, audio_stream: AsyncIterator[bytes]) -> bytes:
"""解码数据流到 audio 比特流"""
chunks = []
async for chunk in audio_stream:
if chunk and chunk != b"\n":
chunks.append(bytes.fromhex(chunk.decode("utf-8")))
return b"".join(chunks)
async def get_audio(self, text: str) -> str:
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(temp_dir, exist_ok=True)
path = os.path.join(temp_dir, f"minimax_tts_api_{uuid.uuid4()}.mp3")
try:
audio_chunk_iterator = self._call_tts_stream(text)
audio = self._audio_play(audio_chunk_iterator)
# 直接将异步生成器传递给 _audio_play 方法
audio_stream = self._call_tts_stream(text)
audio = await self._audio_play(audio_stream)
# 结果保存至文件
with open(path, "wb") as file:
@@ -125,5 +134,5 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
return path
except requests.exceptions.RequestException as e:
except aiohttp.ClientError as e:
raise e