diff --git a/astrbot/core/provider/sources/azure_tts_source.py b/astrbot/core/provider/sources/azure_tts_source.py index 70e90ea70..95ce4d3e0 100644 --- a/astrbot/core/provider/sources/azure_tts_source.py +++ b/astrbot/core/provider/sources/azure_tts_source.py @@ -28,7 +28,15 @@ class OTTSProvider: self.last_sync_time = 0 self.timeout = Timeout(10.0) self.retry_count = 3 + self.client = None + + async def __aenter__(self): self.client = AsyncClient(timeout=self.timeout) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if self.client: + await self.client.aclose() async def _sync_time(self): try: @@ -63,9 +71,10 @@ class OTTSProvider: "role": voice_params["role"], "rate": voice_params["rate"], "volume": voice_params["volume"] - },headers={ + }, + headers={ "User-Agent": f"AstrBot/{VERSION}", - "UAK": f"AstrBot/AzureTTS" + "UAK": "AstrBot/AzureTTS" } ) response.raise_for_status() @@ -85,17 +94,11 @@ class AzureNativeProvider(TTSProvider): self.subscription_key = provider_config["azure_tts_subscription_key"].strip() if not re.fullmatch(r'^[a-zA-Z0-9]{32}$', self.subscription_key): raise ValueError("无效的Azure订阅密钥") - self.region = provider_config.get("azure_tts_region", "eastus").strip() self.endpoint = f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1" - self.client = AsyncClient(headers={ - "User-Agent": f"AstrBot/{VERSION}", - "Content-Type": "application/ssml+xml", - "X-Microsoft-OutputFormat": "riff-48khz-16bit-mono-pcm" - }) + self.client = None self.token = None self.token_expire = 0 - self.voice_params = { "voice": provider_config.get("azure_tts_voice", "zh-CN-YunxiaNeural"), "style": provider_config.get("azure_tts_style", "cheerful"), @@ -104,6 +107,18 @@ class AzureNativeProvider(TTSProvider): "volume": provider_config.get("azure_tts_volume", "100") } + async def __aenter__(self): + self.client = AsyncClient(headers={ + "User-Agent": f"AstrBot/{VERSION}", + "Content-Type": "application/ssml+xml", + "X-Microsoft-OutputFormat": "riff-48khz-16bit-mono-pcm" + }) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if self.client: + await self.client.aclose() + async def _refresh_token(self): token_url = f"https://{self.region}.api.cognitive.microsoft.com/sts/v1.0/issuetoken" response = await self.client.post( @@ -136,7 +151,7 @@ class AzureNativeProvider(TTSProvider): headers={ "Authorization": f"Bearer {self.token}", "User-Agent": f"AstrBot/{VERSION}" - } + } ) response.raise_for_status() file_path.parent.mkdir(parents=True, exist_ok=True) @@ -163,7 +178,6 @@ class AzureTTSProvider(TTSProvider): required = {"OTTS_SKEY", "OTTS_URL", "OTTS_AUTH_TIME"} if missing := required - otts_config.keys(): raise ValueError(f"缺少OTTS参数: {', '.join(missing)}") - return OTTSProvider(otts_config) except json.JSONDecodeError as e: error_msg = ( @@ -177,16 +191,20 @@ class AzureTTSProvider(TTSProvider): if re.fullmatch(r'^[a-zA-Z0-9]{32}$', key_value): return AzureNativeProvider(config, self.provider_settings) raise ValueError("订阅密钥格式无效,应为32位字母数字或other[...]格式") + async def get_audio(self, text: str) -> str: if isinstance(self.provider, OTTSProvider): - return await self.provider.get_audio( - text, - { - "voice": self.provider_config.get("azure_tts_voice"), - "style": self.provider_config.get("azure_tts_style"), - "role": self.provider_config.get("azure_tts_role"), - "rate": self.provider_config.get("azure_tts_rate"), - "volume": self.provider_config.get("azure_tts_volume") - } - ) - return await self.provider.get_audio(text) \ No newline at end of file + async with self.provider as provider: + return await provider.get_audio( + text, + { + "voice": self.provider_config.get("azure_tts_voice"), + "style": self.provider_config.get("azure_tts_style"), + "role": self.provider_config.get("azure_tts_role"), + "rate": self.provider_config.get("azure_tts_rate"), + "volume": self.provider_config.get("azure_tts_volume") + } + ) + else: + async with self.provider as provider: + return await provider.get_audio(text) \ No newline at end of file