feat: add Xinference STT provider (#3197)

* feat: add Xinference STT provider

* chore:update comment in xinference_stt_provider

* style: ruff format xinference_stt_provider

* chore: remove unused import of base64 in xinference_stt_provider

* fix: enhance model initialization check in get_text method

---------

Co-authored-by: Soulter <905617992@qq.com>
This commit is contained in:
RC-CHN
2025-10-31 01:49:35 +08:00
committed by GitHub
parent 2ce653caad
commit bc7f01ba36
3 changed files with 203 additions and 0 deletions
+12
View File
@@ -1274,6 +1274,18 @@ CONFIG_METADATA_2 = {
"timeout": 20,
"launch_model_if_not_running": False,
},
"Xinference STT": {
"id": "xinference_stt",
"type": "xinference_stt",
"provider": "xinference",
"provider_type": "speech_to_text",
"enable": False,
"api_key": "",
"api_base": "http://127.0.0.1:9997",
"model": "whisper-large-v3",
"timeout": 180,
"launch_model_if_not_running": False,
},
},
"items": {
"rerank_api_base": {
+4
View File
@@ -259,6 +259,10 @@ class ProviderManager:
from .sources.whisper_selfhosted_source import (
ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost,
)
case "xinference_stt":
from .sources.xinference_stt_provider import (
ProviderXinferenceSTT as ProviderXinferenceSTT,
)
case "openai_tts_api":
from .sources.openai_tts_api_source import (
ProviderOpenAITTSAPI as ProviderOpenAITTSAPI,
@@ -0,0 +1,187 @@
import uuid
import os
import aiohttp
from xinference_client.client.restful.async_restful_client import (
AsyncClient as Client,
)
from ..provider import STTProvider
from ..entities import ProviderType
from ..register import register_provider_adapter
from astrbot.core import logger
from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
@register_provider_adapter(
"xinference_stt",
"Xinference STT",
provider_type=ProviderType.SPEECH_TO_TEXT,
)
class ProviderXinferenceSTT(STTProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config, provider_settings)
self.provider_config = provider_config
self.provider_settings = provider_settings
self.base_url = provider_config.get("api_base", "http://127.0.0.1:9997")
self.base_url = self.base_url.rstrip("/")
self.timeout = provider_config.get("timeout", 180)
self.model_name = provider_config.get("model", "whisper-large-v3")
self.api_key = provider_config.get("api_key")
self.launch_model_if_not_running = provider_config.get(
"launch_model_if_not_running", False
)
self.client = None
self.model_uid = None
async def initialize(self):
if self.api_key:
logger.info("Xinference STT: Using API key for authentication.")
self.client = Client(self.base_url, api_key=self.api_key)
else:
logger.info("Xinference STT: No API key provided.")
self.client = Client(self.base_url)
try:
running_models = await self.client.list_models()
for uid, model_spec in running_models.items():
if model_spec.get("model_name") == self.model_name:
logger.info(
f"Model '{self.model_name}' is already running with UID: {uid}"
)
self.model_uid = uid
break
if self.model_uid is None:
if self.launch_model_if_not_running:
logger.info(f"Launching {self.model_name} model...")
self.model_uid = await self.client.launch_model(
model_name=self.model_name, model_type="audio"
)
logger.info("Model launched.")
else:
logger.warning(
f"Model '{self.model_name}' is not running and auto-launch is disabled. Provider will not be available."
)
return
except Exception as e:
logger.error(f"Failed to initialize Xinference model: {e}")
logger.debug(
f"Xinference initialization failed with exception: {e}", exc_info=True
)
async def get_text(self, audio_url: str) -> str:
if not self.model_uid or self.client is None or self.client.session is None:
logger.error("Xinference STT model is not initialized.")
return ""
audio_bytes = None
temp_files = []
is_tencent = False
try:
# 1. Get audio bytes
if audio_url.startswith("http"):
if "multimedia.nt.qq.com.cn" in audio_url:
is_tencent = True
async with aiohttp.ClientSession() as session:
async with session.get(audio_url, timeout=self.timeout) as resp:
if resp.status == 200:
audio_bytes = await resp.read()
else:
logger.error(
f"Failed to download audio from {audio_url}, status: {resp.status}"
)
return ""
else:
if os.path.exists(audio_url):
with open(audio_url, "rb") as f:
audio_bytes = f.read()
else:
logger.error(f"File not found: {audio_url}")
return ""
if not audio_bytes:
logger.error("Audio bytes are empty.")
return ""
# 2. Check for conversion
needs_conversion = False
if (
audio_url.endswith((".amr", ".silk"))
or is_tencent
or b"SILK" in audio_bytes[:8]
):
needs_conversion = True
# 3. Perform conversion if needed
if needs_conversion:
logger.info("Audio requires conversion, using temporary files...")
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(temp_dir, exist_ok=True)
input_path = os.path.join(temp_dir, str(uuid.uuid4()))
output_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".wav")
temp_files.extend([input_path, output_path])
with open(input_path, "wb") as f:
f.write(audio_bytes)
logger.info("Converting silk/amr file to wav ...")
await tencent_silk_to_wav(input_path, output_path)
with open(output_path, "rb") as f:
audio_bytes = f.read()
# 4. Transcribe
# 官方asyncCLient的客户端似乎实现有点问题,这里直接用aiohttp实现openai标准兼容请求,提交issue等待官方修复后再改回来
url = f"{self.base_url}/v1/audio/transcriptions"
headers = {
"accept": "application/json",
}
if self.client and self.client._headers:
headers.update(self.client._headers)
data = aiohttp.FormData()
data.add_field("model", self.model_uid)
data.add_field(
"file", audio_bytes, filename="audio.wav", content_type="audio/wav"
)
async with self.client.session.post(
url, data=data, headers=headers, timeout=self.timeout
) as resp:
if resp.status == 200:
result = await resp.json()
text = result.get("text", "")
logger.debug(f"Xinference STT result: {text}")
return text
else:
error_text = await resp.text()
logger.error(
f"Xinference STT transcription failed with status {resp.status}: {error_text}"
)
return ""
except Exception as e:
logger.error(f"Xinference STT failed: {e}")
logger.debug(f"Xinference STT failed with exception: {e}", exc_info=True)
return ""
finally:
# 5. Cleanup
for temp_file in temp_files:
try:
if os.path.exists(temp_file):
os.remove(temp_file)
logger.debug(f"Removed temporary file: {temp_file}")
except Exception as e:
logger.error(f"Failed to remove temporary file {temp_file}: {e}")
async def terminate(self) -> None:
"""关闭客户端会话"""
if self.client:
logger.info("Closing Xinference STT client...")
try:
await self.client.close()
except Exception as e:
logger.error(f"Failed to close Xinference client: {e}", exc_info=True)