From 72702beb0bb6f01a37d832d60267c59c11b44222 Mon Sep 17 00:00:00 2001 From: Raven95676 Date: Wed, 2 Jul 2025 10:29:10 +0800 Subject: [PATCH] chore: clean code --- .../core/provider/sources/azure_tts_source.py | 48 ++++++++++------- astrbot/core/provider/sources/dify_source.py | 16 +++--- .../core/provider/sources/gemini_source.py | 20 +++---- .../core/provider/sources/openai_source.py | 9 ++-- .../core/provider/sources/volcengine_tts.py | 53 ++++++++++--------- 5 files changed, 76 insertions(+), 70 deletions(-) diff --git a/astrbot/core/provider/sources/azure_tts_source.py b/astrbot/core/provider/sources/azure_tts_source.py index c35c7ec6c..6ddf452d4 100644 --- a/astrbot/core/provider/sources/azure_tts_source.py +++ b/astrbot/core/provider/sources/azure_tts_source.py @@ -19,6 +19,7 @@ from ..register import register_provider_adapter TEMP_DIR = Path("data/temp/azure_tts") TEMP_DIR.mkdir(parents=True, exist_ok=True) + class OTTSProvider: def __init__(self, config: Dict): self.skey = config["OTTS_SKEY"] @@ -70,12 +71,12 @@ class OTTSProvider: "style": voice_params["style"], "role": voice_params["role"], "rate": voice_params["rate"], - "volume": voice_params["volume"] + "volume": voice_params["volume"], }, headers={ "User-Agent": f"AstrBot/{VERSION}", - "UAK": "AstrBot/AzureTTS" - } + "UAK": "AstrBot/AzureTTS", + }, ) response.raise_for_status() file_path.parent.mkdir(parents=True, exist_ok=True) @@ -88,14 +89,19 @@ class OTTSProvider: raise RuntimeError(f"OTTS请求失败: {str(e)}") from e await asyncio.sleep(0.5 * (attempt + 1)) + class AzureNativeProvider(TTSProvider): def __init__(self, provider_config: dict, provider_settings: dict): super().__init__(provider_config, provider_settings) - self.subscription_key = provider_config.get("azure_tts_subscription_key", "").strip() + self.subscription_key = provider_config.get( + "azure_tts_subscription_key", "" + ).strip() if not re.fullmatch(r"^[a-zA-Z0-9]{32}$", self.subscription_key): raise ValueError("无效的Azure订阅密钥") self.region = provider_config.get("azure_tts_region", "eastus").strip() - self.endpoint = f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1" + self.endpoint = ( + f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1" + ) self.client = None self.token = None self.token_expire = 0 @@ -104,15 +110,17 @@ class AzureNativeProvider(TTSProvider): "style": provider_config.get("azure_tts_style", "cheerful"), "role": provider_config.get("azure_tts_role", "Boy"), "rate": provider_config.get("azure_tts_rate", "1"), - "volume": provider_config.get("azure_tts_volume", "100") + "volume": provider_config.get("azure_tts_volume", "100"), } async def __aenter__(self): - self.client = AsyncClient(headers={ - "User-Agent": f"AstrBot/{VERSION}", - "Content-Type": "application/ssml+xml", - "X-Microsoft-OutputFormat": "riff-48khz-16bit-mono-pcm" - }) + self.client = AsyncClient( + headers={ + "User-Agent": f"AstrBot/{VERSION}", + "Content-Type": "application/ssml+xml", + "X-Microsoft-OutputFormat": "riff-48khz-16bit-mono-pcm", + } + ) return self async def __aexit__(self, exc_type, exc_val, exc_tb): @@ -120,10 +128,11 @@ class AzureNativeProvider(TTSProvider): await self.client.aclose() async def _refresh_token(self): - token_url = f"https://{self.region}.api.cognitive.microsoft.com/sts/v1.0/issuetoken" + token_url = ( + f"https://{self.region}.api.cognitive.microsoft.com/sts/v1.0/issuetoken" + ) response = await self.client.post( - token_url, - headers={"Ocp-Apim-Subscription-Key": self.subscription_key} + token_url, headers={"Ocp-Apim-Subscription-Key": self.subscription_key} ) response.raise_for_status() self.token = response.text @@ -150,8 +159,8 @@ class AzureNativeProvider(TTSProvider): content=ssml, headers={ "Authorization": f"Bearer {self.token}", - "User-Agent": f"AstrBot/{VERSION}" - } + "User-Agent": f"AstrBot/{VERSION}", + }, ) response.raise_for_status() file_path.parent.mkdir(parents=True, exist_ok=True) @@ -160,6 +169,7 @@ class AzureNativeProvider(TTSProvider): f.write(chunk) return str(file_path.resolve()) + @register_provider_adapter("azure_tts", "Azure TTS", ProviderType.TEXT_TO_SPEECH) class AzureTTSProvider(TTSProvider): def __init__(self, provider_config: dict, provider_settings: dict): @@ -183,7 +193,7 @@ class AzureTTSProvider(TTSProvider): error_msg = ( f"JSON解析失败,请检查格式(错误位置:行 {e.lineno} 列 {e.colno})\n" f"错误详情: {e.msg}\n" - f"错误上下文: {json_str[max(0, e.pos-30):e.pos+30]}" + f"错误上下文: {json_str[max(0, e.pos - 30) : e.pos + 30]}" ) raise ValueError(error_msg) from e except KeyError as e: @@ -202,8 +212,8 @@ class AzureTTSProvider(TTSProvider): "style": self.provider_config.get("azure_tts_style"), "role": self.provider_config.get("azure_tts_role"), "rate": self.provider_config.get("azure_tts_rate"), - "volume": self.provider_config.get("azure_tts_volume") - } + "volume": self.provider_config.get("azure_tts_volume"), + }, ) else: async with self.provider as provider: diff --git a/astrbot/core/provider/sources/dify_source.py b/astrbot/core/provider/sources/dify_source.py index 81c910d66..b3a0ccccf 100644 --- a/astrbot/core/provider/sources/dify_source.py +++ b/astrbot/core/provider/sources/dify_source.py @@ -18,7 +18,7 @@ class ProviderDify(Provider): self, provider_config, provider_settings, - default_persona = None, + default_persona=None, ) -> None: super().__init__( provider_config, @@ -65,7 +65,7 @@ class ProviderDify(Provider): if image_urls is None: image_urls = [] result = "" - session_id = session_id or kwargs.get("user") # 1734 + session_id = session_id or kwargs.get("user") # 1734 conversation_id = self.conversation_ids.get(session_id, "") files_payload = [] @@ -84,13 +84,11 @@ class ProviderDify(Provider): f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。" ) continue - files_payload.append( - { - "type": "image", - "transfer_method": "local_file", - "upload_file_id": file_response["id"], - } - ) + files_payload.append({ + "type": "image", + "transfer_method": "local_file", + "upload_file_id": file_response["id"], + }) # 获得会话变量 payload_vars = self.variables.copy() diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index d67dd2a94..573fe7684 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -259,12 +259,10 @@ class ProviderGoogleGenAI(Provider): contents.append(content_cls(parts=part)) gemini_contents: list[types.Content] = [] - native_tool_enabled = any( - [ - self.provider_config.get("gm_native_coderunner", False), - self.provider_config.get("gm_native_search", False), - ] - ) + native_tool_enabled = any([ + self.provider_config.get("gm_native_coderunner", False), + self.provider_config.get("gm_native_search", False), + ]) for message in payloads["messages"]: role, content = message["role"], message.get("content") @@ -634,12 +632,10 @@ class ProviderGoogleGenAI(Provider): if not image_data: logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") continue - user_content["content"].append( - { - "type": "image_url", - "image_url": {"url": image_data}, - } - ) + user_content["content"].append({ + "type": "image_url", + "image_url": {"url": image_data}, + }) return user_content else: return {"role": "user", "content": text} diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index ef6131d8c..936fc2e34 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -30,7 +30,7 @@ class ProviderOpenAIOfficial(Provider): self, provider_config, provider_settings, - default_persona = None, + default_persona=None, ) -> None: super().__init__( provider_config, @@ -525,9 +525,10 @@ class ProviderOpenAIOfficial(Provider): if not image_data: logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") continue - user_content["content"].append( - {"type": "image_url", "image_url": {"url": image_data}} - ) + user_content["content"].append({ + "type": "image_url", + "image_url": {"url": image_data}, + }) return user_content else: return {"role": "user", "content": text} diff --git a/astrbot/core/provider/sources/volcengine_tts.py b/astrbot/core/provider/sources/volcengine_tts.py index dca0196b1..12e7ed9cd 100644 --- a/astrbot/core/provider/sources/volcengine_tts.py +++ b/astrbot/core/provider/sources/volcengine_tts.py @@ -5,12 +5,12 @@ import os import traceback import asyncio import aiohttp -import requests from ..provider import TTSProvider from ..entities import ProviderType from ..register import register_provider_adapter from astrbot import logger + @register_provider_adapter( "volcengine_tts", "火山引擎 TTS", provider_type=ProviderType.TEXT_TO_SPEECH ) @@ -22,7 +22,9 @@ class ProviderVolcengineTTS(TTSProvider): self.cluster = provider_config.get("volcengine_cluster", "") self.voice_type = provider_config.get("volcengine_voice_type", "") self.speed_ratio = provider_config.get("volcengine_speed_ratio", 1.0) - self.api_base = provider_config.get("api_base", f"https://openspeech.bytedance.com/api/v1/tts") + self.api_base = provider_config.get( + "api_base", "https://openspeech.bytedance.com/api/v1/tts" + ) self.timeout = provider_config.get("timeout", 20) def _build_request_payload(self, text: str) -> dict: @@ -30,11 +32,9 @@ class ProviderVolcengineTTS(TTSProvider): "app": { "appid": self.appid, "token": self.api_key, - "cluster": self.cluster - }, - "user": { - "uid": str(uuid.uuid4()) + "cluster": self.cluster, }, + "user": {"uid": str(uuid.uuid4())}, "audio": { "voice_type": self.voice_type, "encoding": "mp3", @@ -48,60 +48,61 @@ class ProviderVolcengineTTS(TTSProvider): "text_type": "plain", "operation": "query", "with_frontend": 1, - "frontend_type": "unitTson" - } + "frontend_type": "unitTson", + }, } async def get_audio(self, text: str) -> str: """异步方法获取语音文件路径""" headers = { "Content-Type": "application/json", - "Authorization": f"Bearer; {self.api_key}" + "Authorization": f"Bearer; {self.api_key}", } - + payload = self._build_request_payload(text) - + logger.debug(f"请求头: {headers}") logger.debug(f"请求 URL: {self.api_base}") logger.debug(f"请求体: {json.dumps(payload, ensure_ascii=False)[:100]}...") - + try: async with aiohttp.ClientSession() as session: async with session.post( self.api_base, - data=json.dumps(payload), + data=json.dumps(payload), headers=headers, - timeout=self.timeout + timeout=self.timeout, ) as response: logger.debug(f"响应状态码: {response.status}") - + response_text = await response.text() logger.debug(f"响应内容: {response_text[:200]}...") - + if response.status == 200: resp_data = json.loads(response_text) - + if "data" in resp_data: audio_data = base64.b64decode(resp_data["data"]) - + os.makedirs("data/temp", exist_ok=True) - + file_path = f"data/temp/volcengine_tts_{uuid.uuid4()}.mp3" - + loop = asyncio.get_running_loop() await loop.run_in_executor( - None, - lambda: open(file_path, "wb").write(audio_data) + None, lambda: open(file_path, "wb").write(audio_data) ) - + return file_path else: error_msg = resp_data.get("message", "未知错误") raise Exception(f"火山引擎 TTS API 返回错误: {error_msg}") else: - raise Exception(f"火山引擎 TTS API 请求失败: {response.status}, {response_text}") - + raise Exception( + f"火山引擎 TTS API 请求失败: {response.status}, {response_text}" + ) + except Exception as e: error_details = traceback.format_exc() logger.debug(f"火山引擎 TTS 异常详情: {error_details}") - raise Exception(f"火山引擎 TTS 异常: {str(e)}") \ No newline at end of file + raise Exception(f"火山引擎 TTS 异常: {str(e)}")