feat: 支持自部署 Whisper 模型
This commit is contained in:
@@ -323,13 +323,26 @@ CONFIG_METADATA_2 = {
|
||||
"whisper(API)": {
|
||||
"id": "whisper",
|
||||
"type": "openai_whisper_api",
|
||||
"enable": True,
|
||||
"enable": False,
|
||||
"api_key": "",
|
||||
"api_base": "",
|
||||
"model": "whisper-1",
|
||||
},
|
||||
"whisper(本地加载)": {
|
||||
"whisper_hint": "(不用修改我)",
|
||||
"enable": False,
|
||||
"id": "whisper",
|
||||
"type": "openai_whisper_selfhost",
|
||||
"model": "tiny",
|
||||
}
|
||||
},
|
||||
"items": {
|
||||
"whisper_hint": {
|
||||
"description": "本地部署 Whisper 模型须知",
|
||||
"type": "string",
|
||||
"hint": "启用前请 pip 安装 openai-whisper 库(N卡用户大约下载 2GB,主要是 torch 和 cuda,CPU 用户大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
|
||||
"obvious_hint": True
|
||||
},
|
||||
"id": {
|
||||
"description": "ID",
|
||||
"type": "string",
|
||||
|
||||
@@ -43,8 +43,9 @@ class PreProcessStage(Stage):
|
||||
event.message_str += result
|
||||
event.message_obj.message_str += result
|
||||
break
|
||||
except FileNotFoundError:
|
||||
except FileNotFoundError as e:
|
||||
# napcat workaround
|
||||
logger.warning(e)
|
||||
logger.warning(f"语音文件不存在: {path}, 重试中: {i + 1}/{retry}")
|
||||
await asyncio.sleep(0.5)
|
||||
continue
|
||||
|
||||
@@ -13,12 +13,12 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
if not message:
|
||||
await web_chat_back_queue.put_nowait(None)
|
||||
web_chat_back_queue.put_nowait(None)
|
||||
return
|
||||
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
await web_chat_back_queue.put_nowait(comp.text)
|
||||
web_chat_back_queue.put_nowait(comp.text)
|
||||
elif isinstance(comp, Image):
|
||||
# save image to local
|
||||
filename = str(uuid.uuid4()) + ".jpg"
|
||||
@@ -30,6 +30,6 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
f.write(f2.read())
|
||||
elif comp.file and comp.file.startswith("http"):
|
||||
await download_image_by_url(comp.file, path=path)
|
||||
await web_chat_back_queue.put_nowait(f"[IMAGE]{filename}")
|
||||
await web_chat_back_queue.put_nowait(None)
|
||||
web_chat_back_queue.put_nowait(f"[IMAGE]{filename}")
|
||||
web_chat_back_queue.put_nowait(None)
|
||||
await super().send(message)
|
||||
@@ -39,21 +39,29 @@ class ProviderManager():
|
||||
raise ValueError(f"Provider ID 重复:{provider_cfg['id']}。")
|
||||
self.loaded_ids[provider_cfg['id']] = True
|
||||
|
||||
match provider_cfg['type']:
|
||||
case "openai_chat_completion":
|
||||
from .sources.openai_source import ProviderOpenAIOfficial # noqa: F401
|
||||
case "zhipu_chat_completion":
|
||||
from .sources.zhipu_source import ProviderZhipu # noqa: F401
|
||||
case "llm_tuner":
|
||||
logger.info("加载 LLM Tuner 工具 ...")
|
||||
from .sources.llmtuner_source import LLMTunerModelLoader # noqa: F401
|
||||
case "dify":
|
||||
from .sources.dify_source import ProviderDify # noqa: F401
|
||||
case "googlegenai_chat_completion":
|
||||
from .sources.gemini_source import ProviderGoogleGenAI # noqa: F401
|
||||
case "openai_whisper_api":
|
||||
from .sources.whisper_api_source import ProviderOpenAIWhisperAPI # noqa: F401
|
||||
|
||||
try:
|
||||
match provider_cfg['type']:
|
||||
case "openai_chat_completion":
|
||||
from .sources.openai_source import ProviderOpenAIOfficial # noqa: F401
|
||||
case "zhipu_chat_completion":
|
||||
from .sources.zhipu_source import ProviderZhipu # noqa: F401
|
||||
case "llm_tuner":
|
||||
logger.info("加载 LLM Tuner 工具 ...")
|
||||
from .sources.llmtuner_source import LLMTunerModelLoader # noqa: F401
|
||||
case "dify":
|
||||
from .sources.dify_source import ProviderDify # noqa: F401
|
||||
case "googlegenai_chat_completion":
|
||||
from .sources.gemini_source import ProviderGoogleGenAI # noqa: F401
|
||||
case "openai_whisper_api":
|
||||
from .sources.whisper_api_source import ProviderOpenAIWhisperAPI # noqa: F401
|
||||
case "openai_whisper_selfhost":
|
||||
from .sources.whisper_selfhosted_source import ProviderOpenAIWhisperSelfHost # noqa: F401
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。未知原因")
|
||||
continue
|
||||
|
||||
async def initialize(self):
|
||||
for provider_config in self.providers_config:
|
||||
@@ -75,6 +83,10 @@ class ProviderManager():
|
||||
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
|
||||
# STT 任务
|
||||
inst = provider_metadata.cls_type(provider_config, self.provider_settings)
|
||||
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
|
||||
self.stt_provider_insts.append(inst)
|
||||
if selected_stt_provider_id == provider_config['id'] and stt_enabled:
|
||||
self.curr_stt_provider_inst = inst
|
||||
@@ -83,6 +95,10 @@ class ProviderManager():
|
||||
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
|
||||
# 文本生成任务
|
||||
inst = provider_metadata.cls_type(provider_config, self.provider_settings, self.db_helper, self.provider_settings.get('persistant_history', True))
|
||||
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
|
||||
self.provider_insts.append(inst)
|
||||
if selected_provider_id == provider_config['id'] and provider_enabled:
|
||||
self.curr_provider_inst = inst
|
||||
|
||||
@@ -0,0 +1,92 @@
|
||||
import uuid
|
||||
import os
|
||||
import io
|
||||
import asyncio
|
||||
import whisper
|
||||
from ..provider import STTProvider
|
||||
from ..entites import ProviderType
|
||||
from astrbot.core.utils.io import download_file
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
@register_provider_adapter("openai_whisper_selfhost", "OpenAI Whisper 模型部署", provider_type=ProviderType.SPEECH_TO_TEXT)
|
||||
class ProviderOpenAIWhisperSelfHost(STTProvider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.set_model(provider_config.get("model", None))
|
||||
self.model = None
|
||||
|
||||
async def initialize(self):
|
||||
loop = asyncio.get_event_loop()
|
||||
logger.info("下载或者加载 Whisper 模型中,这可能需要一些时间 ...")
|
||||
self.model = await loop.run_in_executor(None, whisper.load_model, self.model_name)
|
||||
logger.info("Whisper 模型加载完成。")
|
||||
|
||||
async def _convert_audio(self, path: str) -> str:
|
||||
from pyffmpeg import FFmpeg
|
||||
filename = str(uuid.uuid4()) + '.mp3'
|
||||
ff = FFmpeg()
|
||||
output_path = ff.convert(path, os.path.join('data/temp', filename))
|
||||
return output_path
|
||||
|
||||
async def _pcm_to_wav(self, input_io: io.BytesIO, output_path: str) -> str:
|
||||
import wave
|
||||
|
||||
with wave.open(output_path, 'wb') as wav:
|
||||
wav.setnchannels(1)
|
||||
wav.setsampwidth(2)
|
||||
wav.setframerate(24000)
|
||||
wav.writeframes(input_io.read())
|
||||
|
||||
return output_path
|
||||
|
||||
async def _convert_silk(self, path: str) -> str:
|
||||
import pysilk
|
||||
filename = str(uuid.uuid4()) + '.wav'
|
||||
output_path = os.path.join('data/temp', filename)
|
||||
with open(path, "rb") as f:
|
||||
input_data = f.read()
|
||||
if input_data.startswith(b'\x02'):
|
||||
# tencent 我爱你
|
||||
input_data = input_data[1:]
|
||||
input_io = io.BytesIO(input_data)
|
||||
output_io = io.BytesIO()
|
||||
pysilk.decode(input_io, output_io, 24000)
|
||||
output_io.seek(0)
|
||||
await self._pcm_to_wav(output_io, output_path)
|
||||
|
||||
return output_path
|
||||
|
||||
async def _is_silk_file(self, file_path):
|
||||
silk_header = b"SILK"
|
||||
with open(file_path, "rb") as f:
|
||||
file_header = f.read(8)
|
||||
|
||||
if silk_header in file_header:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
async def get_text(self, audio_url: str) -> str:
|
||||
loop = asyncio.get_event_loop()
|
||||
if audio_url.startswith("http"):
|
||||
name = str(uuid.uuid4())
|
||||
path = os.path.join("data/temp", name)
|
||||
audio_url = await download_file(audio_url, path)
|
||||
|
||||
if not os.path.exists(audio_url):
|
||||
raise FileNotFoundError(f"文件不存在: {audio_url}")
|
||||
|
||||
if audio_url.endswith(".amr") or audio_url.endswith(".silk"):
|
||||
is_silk = await self._is_silk_file(audio_url)
|
||||
if is_silk:
|
||||
logger.info("Converting silk file to wav ...")
|
||||
audio_url = await self._convert_silk(audio_url)
|
||||
|
||||
result = await loop.run_in_executor(None, self.model.transcribe, audio_url)
|
||||
return result['text']
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from .server import AstrBotDashboard
|
||||
@@ -13,8 +14,16 @@ class AstrBotDashBoardLifecycle:
|
||||
|
||||
async def start(self):
|
||||
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
|
||||
await core_lifecycle.initialize()
|
||||
core_task = core_lifecycle.start()
|
||||
|
||||
core_task = []
|
||||
try:
|
||||
await core_lifecycle.initialize()
|
||||
core_task = core_lifecycle.start()
|
||||
except Exception as e:
|
||||
logger.critical(f"初始化 AstrBot 失败:{e} !!!!!!!")
|
||||
logger.critical(f"初始化 AstrBot 失败:{e} !!!!!!!")
|
||||
logger.critical(f"初始化 AstrBot 失败:{e} !!!!!!!")
|
||||
|
||||
self.dashboard_server = AstrBotDashboard(core_lifecycle, self.db)
|
||||
task = asyncio.gather(core_task, self.dashboard_server.run())
|
||||
|
||||
|
||||
Reference in New Issue
Block a user