Fix AsyncClient

This commit is contained in:
NanoRocky
2025-05-11 01:54:44 +08:00
parent da4cd7fb65
commit 6d7c40eb76
@@ -28,7 +28,15 @@ class OTTSProvider:
self.last_sync_time = 0
self.timeout = Timeout(10.0)
self.retry_count = 3
self.client = None
async def __aenter__(self):
self.client = AsyncClient(timeout=self.timeout)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.client:
await self.client.aclose()
async def _sync_time(self):
try:
@@ -63,9 +71,10 @@ class OTTSProvider:
"role": voice_params["role"],
"rate": voice_params["rate"],
"volume": voice_params["volume"]
},headers={
},
headers={
"User-Agent": f"AstrBot/{VERSION}",
"UAK": f"AstrBot/AzureTTS"
"UAK": "AstrBot/AzureTTS"
}
)
response.raise_for_status()
@@ -85,17 +94,11 @@ class AzureNativeProvider(TTSProvider):
self.subscription_key = provider_config["azure_tts_subscription_key"].strip()
if not re.fullmatch(r'^[a-zA-Z0-9]{32}$', self.subscription_key):
raise ValueError("无效的Azure订阅密钥")
self.region = provider_config.get("azure_tts_region", "eastus").strip()
self.endpoint = f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1"
self.client = AsyncClient(headers={
"User-Agent": f"AstrBot/{VERSION}",
"Content-Type": "application/ssml+xml",
"X-Microsoft-OutputFormat": "riff-48khz-16bit-mono-pcm"
})
self.client = None
self.token = None
self.token_expire = 0
self.voice_params = {
"voice": provider_config.get("azure_tts_voice", "zh-CN-YunxiaNeural"),
"style": provider_config.get("azure_tts_style", "cheerful"),
@@ -104,6 +107,18 @@ class AzureNativeProvider(TTSProvider):
"volume": provider_config.get("azure_tts_volume", "100")
}
async def __aenter__(self):
self.client = AsyncClient(headers={
"User-Agent": f"AstrBot/{VERSION}",
"Content-Type": "application/ssml+xml",
"X-Microsoft-OutputFormat": "riff-48khz-16bit-mono-pcm"
})
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.client:
await self.client.aclose()
async def _refresh_token(self):
token_url = f"https://{self.region}.api.cognitive.microsoft.com/sts/v1.0/issuetoken"
response = await self.client.post(
@@ -136,7 +151,7 @@ class AzureNativeProvider(TTSProvider):
headers={
"Authorization": f"Bearer {self.token}",
"User-Agent": f"AstrBot/{VERSION}"
}
}
)
response.raise_for_status()
file_path.parent.mkdir(parents=True, exist_ok=True)
@@ -163,7 +178,6 @@ class AzureTTSProvider(TTSProvider):
required = {"OTTS_SKEY", "OTTS_URL", "OTTS_AUTH_TIME"}
if missing := required - otts_config.keys():
raise ValueError(f"缺少OTTS参数: {', '.join(missing)}")
return OTTSProvider(otts_config)
except json.JSONDecodeError as e:
error_msg = (
@@ -177,16 +191,20 @@ class AzureTTSProvider(TTSProvider):
if re.fullmatch(r'^[a-zA-Z0-9]{32}$', key_value):
return AzureNativeProvider(config, self.provider_settings)
raise ValueError("订阅密钥格式无效,应为32位字母数字或other[...]格式")
async def get_audio(self, text: str) -> str:
if isinstance(self.provider, OTTSProvider):
return await self.provider.get_audio(
text,
{
"voice": self.provider_config.get("azure_tts_voice"),
"style": self.provider_config.get("azure_tts_style"),
"role": self.provider_config.get("azure_tts_role"),
"rate": self.provider_config.get("azure_tts_rate"),
"volume": self.provider_config.get("azure_tts_volume")
}
)
return await self.provider.get_audio(text)
async with self.provider as provider:
return await provider.get_audio(
text,
{
"voice": self.provider_config.get("azure_tts_voice"),
"style": self.provider_config.get("azure_tts_style"),
"role": self.provider_config.get("azure_tts_role"),
"rate": self.provider_config.get("azure_tts_rate"),
"volume": self.provider_config.get("azure_tts_volume")
}
)
else:
async with self.provider as provider:
return await provider.get_audio(text)