diff --git a/astrbot/core/provider/sources/minimax_tts_api_source.py b/astrbot/core/provider/sources/minimax_tts_api_source.py index 26d08ac75..04170b930 100644 --- a/astrbot/core/provider/sources/minimax_tts_api_source.py +++ b/astrbot/core/provider/sources/minimax_tts_api_source.py @@ -1,12 +1,10 @@ import json import os import uuid -from typing import Dict, Iterator, List, Union - -import requests - +import aiohttp +from typing import Dict, List, Union, AsyncIterator from astrbot.core.utils.astrbot_path import get_astrbot_data_path - +from astrbot.api import logger from ..entities import ProviderType from ..provider import TTSProvider from ..register import register_provider_adapter @@ -81,43 +79,54 @@ class ProviderMiniMaxTTSAPI(TTSProvider): return json.dumps(dict_body) - def _call_tts_stream(self, text: str) -> Iterator[bytes]: + async def _call_tts_stream(self, text: str) -> AsyncIterator[bytes]: """进行流式请求""" try: - response = requests.post( - self.concat_base_url, - stream=True, - headers=self.headers, - data=self._build_tts_stream_body(text), - ) - response.raise_for_status() + async with aiohttp.ClientSession() as session: + async with session.post( + self.concat_base_url, + headers=self.headers, + data=self._build_tts_stream_body(text), + timeout=aiohttp.ClientTimeout(total=60), + ) as response: + response.raise_for_status() - for chunk in response.raw: - if not chunk or not chunk.startswith(b"data:"): - continue - data = json.loads(chunk[5:]) - if "extra_info" in data: - continue - audio = data.get("data", {}).get("audio") - if audio is not None: - yield audio + async for chunk in response.content.iter_any(): + if not chunk or not chunk.startswith(b"data:"): + logger.warning(f"Minimax TTS resp: {chunk}") + if "invalid api key" in chunk.decode("utf-8"): + raise Exception("MiniMax TTS: 无效的 API 密钥") + continue + try: + data = json.loads(chunk[5:]) + if "extra_info" in data: + continue + audio = data.get("data", {}).get("audio") + if audio is not None: + yield audio + except json.JSONDecodeError: + continue - except requests.exceptions.RequestException as e: + except aiohttp.ClientError as e: raise Exception(f"MiniMax TTS API请求失败: {str(e)}") - def _audio_play(self, audio_stream: Iterator[bytes]) -> bytes: - """解码数据流到audio比特流""" - return b"".join( - bytes.fromhex(chunk) for chunk in audio_stream if chunk and chunk != b"\n" - ) + async def _audio_play(self, audio_stream: AsyncIterator[bytes]) -> bytes: + """解码数据流到 audio 比特流""" + chunks = [] + async for chunk in audio_stream: + if chunk and chunk != b"\n": + chunks.append(bytes.fromhex(chunk.decode("utf-8"))) + return b"".join(chunks) async def get_audio(self, text: str) -> str: temp_dir = os.path.join(get_astrbot_data_path(), "temp") + os.makedirs(temp_dir, exist_ok=True) path = os.path.join(temp_dir, f"minimax_tts_api_{uuid.uuid4()}.mp3") try: - audio_chunk_iterator = self._call_tts_stream(text) - audio = self._audio_play(audio_chunk_iterator) + # 直接将异步生成器传递给 _audio_play 方法 + audio_stream = self._call_tts_stream(text) + audio = await self._audio_play(audio_stream) # 结果保存至文件 with open(path, "wb") as file: @@ -125,5 +134,5 @@ class ProviderMiniMaxTTSAPI(TTSProvider): return path - except requests.exceptions.RequestException as e: + except aiohttp.ClientError as e: raise e