From ba198490fa1f619c08c673dbc598dfd5a0f1aa1e Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sat, 11 Jan 2025 20:31:21 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E8=87=AA=E9=83=A8?= =?UTF-8?q?=E7=BD=B2=20Whisper=20=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/config/default.py | 15 ++- .../core/pipeline/preprocess_stage/stage.py | 3 +- .../platform/sources/webchat/webchat_event.py | 8 +- astrbot/core/provider/manager.py | 46 +++++++--- .../sources/whisper_selfhosted_source.py | 92 +++++++++++++++++++ astrbot/dashboard/dashboard_lifecycle.py | 13 ++- 6 files changed, 154 insertions(+), 23 deletions(-) create mode 100644 astrbot/core/provider/sources/whisper_selfhosted_source.py diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 983ca3cd3..f9e2a3485 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -323,13 +323,26 @@ CONFIG_METADATA_2 = { "whisper(API)": { "id": "whisper", "type": "openai_whisper_api", - "enable": True, + "enable": False, "api_key": "", "api_base": "", "model": "whisper-1", + }, + "whisper(本地加载)": { + "whisper_hint": "(不用修改我)", + "enable": False, + "id": "whisper", + "type": "openai_whisper_selfhost", + "model": "tiny", } }, "items": { + "whisper_hint": { + "description": "本地部署 Whisper 模型须知", + "type": "string", + "hint": "启用前请 pip 安装 openai-whisper 库(N卡用户大约下载 2GB,主要是 torch 和 cuda,CPU 用户大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。", + "obvious_hint": True + }, "id": { "description": "ID", "type": "string", diff --git a/astrbot/core/pipeline/preprocess_stage/stage.py b/astrbot/core/pipeline/preprocess_stage/stage.py index a28e15485..455fd35c8 100644 --- a/astrbot/core/pipeline/preprocess_stage/stage.py +++ b/astrbot/core/pipeline/preprocess_stage/stage.py @@ -43,8 +43,9 @@ class PreProcessStage(Stage): event.message_str += result event.message_obj.message_str += result break - except FileNotFoundError: + except FileNotFoundError as e: # napcat workaround + logger.warning(e) logger.warning(f"语音文件不存在: {path}, 重试中: {i + 1}/{retry}") await asyncio.sleep(0.5) continue diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index 4312b0cb1..0ef57ed5f 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -13,12 +13,12 @@ class WebChatMessageEvent(AstrMessageEvent): async def send(self, message: MessageChain): if not message: - await web_chat_back_queue.put_nowait(None) + web_chat_back_queue.put_nowait(None) return for comp in message.chain: if isinstance(comp, Plain): - await web_chat_back_queue.put_nowait(comp.text) + web_chat_back_queue.put_nowait(comp.text) elif isinstance(comp, Image): # save image to local filename = str(uuid.uuid4()) + ".jpg" @@ -30,6 +30,6 @@ class WebChatMessageEvent(AstrMessageEvent): f.write(f2.read()) elif comp.file and comp.file.startswith("http"): await download_image_by_url(comp.file, path=path) - await web_chat_back_queue.put_nowait(f"[IMAGE]{filename}") - await web_chat_back_queue.put_nowait(None) + web_chat_back_queue.put_nowait(f"[IMAGE]{filename}") + web_chat_back_queue.put_nowait(None) await super().send(message) \ No newline at end of file diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 7523960ef..3b64126f4 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -39,21 +39,29 @@ class ProviderManager(): raise ValueError(f"Provider ID 重复:{provider_cfg['id']}。") self.loaded_ids[provider_cfg['id']] = True - match provider_cfg['type']: - case "openai_chat_completion": - from .sources.openai_source import ProviderOpenAIOfficial # noqa: F401 - case "zhipu_chat_completion": - from .sources.zhipu_source import ProviderZhipu # noqa: F401 - case "llm_tuner": - logger.info("加载 LLM Tuner 工具 ...") - from .sources.llmtuner_source import LLMTunerModelLoader # noqa: F401 - case "dify": - from .sources.dify_source import ProviderDify # noqa: F401 - case "googlegenai_chat_completion": - from .sources.gemini_source import ProviderGoogleGenAI # noqa: F401 - case "openai_whisper_api": - from .sources.whisper_api_source import ProviderOpenAIWhisperAPI # noqa: F401 - + try: + match provider_cfg['type']: + case "openai_chat_completion": + from .sources.openai_source import ProviderOpenAIOfficial # noqa: F401 + case "zhipu_chat_completion": + from .sources.zhipu_source import ProviderZhipu # noqa: F401 + case "llm_tuner": + logger.info("加载 LLM Tuner 工具 ...") + from .sources.llmtuner_source import LLMTunerModelLoader # noqa: F401 + case "dify": + from .sources.dify_source import ProviderDify # noqa: F401 + case "googlegenai_chat_completion": + from .sources.gemini_source import ProviderGoogleGenAI # noqa: F401 + case "openai_whisper_api": + from .sources.whisper_api_source import ProviderOpenAIWhisperAPI # noqa: F401 + case "openai_whisper_selfhost": + from .sources.whisper_selfhosted_source import ProviderOpenAIWhisperSelfHost # noqa: F401 + except (ImportError, ModuleNotFoundError) as e: + logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。") + continue + except Exception as e: + logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。未知原因") + continue async def initialize(self): for provider_config in self.providers_config: @@ -75,6 +83,10 @@ class ProviderManager(): if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT: # STT 任务 inst = provider_metadata.cls_type(provider_config, self.provider_settings) + + if getattr(inst, "initialize", None): + await inst.initialize() + self.stt_provider_insts.append(inst) if selected_stt_provider_id == provider_config['id'] and stt_enabled: self.curr_stt_provider_inst = inst @@ -83,6 +95,10 @@ class ProviderManager(): elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION: # 文本生成任务 inst = provider_metadata.cls_type(provider_config, self.provider_settings, self.db_helper, self.provider_settings.get('persistant_history', True)) + + if getattr(inst, "initialize", None): + await inst.initialize() + self.provider_insts.append(inst) if selected_provider_id == provider_config['id'] and provider_enabled: self.curr_provider_inst = inst diff --git a/astrbot/core/provider/sources/whisper_selfhosted_source.py b/astrbot/core/provider/sources/whisper_selfhosted_source.py new file mode 100644 index 000000000..6f16d559a --- /dev/null +++ b/astrbot/core/provider/sources/whisper_selfhosted_source.py @@ -0,0 +1,92 @@ +import uuid +import os +import io +import asyncio +import whisper +from ..provider import STTProvider +from ..entites import ProviderType +from astrbot.core.utils.io import download_file +from ..register import register_provider_adapter +from astrbot.core import logger + + +@register_provider_adapter("openai_whisper_selfhost", "OpenAI Whisper 模型部署", provider_type=ProviderType.SPEECH_TO_TEXT) +class ProviderOpenAIWhisperSelfHost(STTProvider): + def __init__( + self, + provider_config: dict, + provider_settings: dict, + ) -> None: + super().__init__(provider_config, provider_settings) + self.set_model(provider_config.get("model", None)) + self.model = None + + async def initialize(self): + loop = asyncio.get_event_loop() + logger.info("下载或者加载 Whisper 模型中,这可能需要一些时间 ...") + self.model = await loop.run_in_executor(None, whisper.load_model, self.model_name) + logger.info("Whisper 模型加载完成。") + + async def _convert_audio(self, path: str) -> str: + from pyffmpeg import FFmpeg + filename = str(uuid.uuid4()) + '.mp3' + ff = FFmpeg() + output_path = ff.convert(path, os.path.join('data/temp', filename)) + return output_path + + async def _pcm_to_wav(self, input_io: io.BytesIO, output_path: str) -> str: + import wave + + with wave.open(output_path, 'wb') as wav: + wav.setnchannels(1) + wav.setsampwidth(2) + wav.setframerate(24000) + wav.writeframes(input_io.read()) + + return output_path + + async def _convert_silk(self, path: str) -> str: + import pysilk + filename = str(uuid.uuid4()) + '.wav' + output_path = os.path.join('data/temp', filename) + with open(path, "rb") as f: + input_data = f.read() + if input_data.startswith(b'\x02'): + # tencent 我爱你 + input_data = input_data[1:] + input_io = io.BytesIO(input_data) + output_io = io.BytesIO() + pysilk.decode(input_io, output_io, 24000) + output_io.seek(0) + await self._pcm_to_wav(output_io, output_path) + + return output_path + + async def _is_silk_file(self, file_path): + silk_header = b"SILK" + with open(file_path, "rb") as f: + file_header = f.read(8) + + if silk_header in file_header: + return True + else: + return False + + async def get_text(self, audio_url: str) -> str: + loop = asyncio.get_event_loop() + if audio_url.startswith("http"): + name = str(uuid.uuid4()) + path = os.path.join("data/temp", name) + audio_url = await download_file(audio_url, path) + + if not os.path.exists(audio_url): + raise FileNotFoundError(f"文件不存在: {audio_url}") + + if audio_url.endswith(".amr") or audio_url.endswith(".silk"): + is_silk = await self._is_silk_file(audio_url) + if is_silk: + logger.info("Converting silk file to wav ...") + audio_url = await self._convert_silk(audio_url) + + result = await loop.run_in_executor(None, self.model.transcribe, audio_url) + return result['text'] \ No newline at end of file diff --git a/astrbot/dashboard/dashboard_lifecycle.py b/astrbot/dashboard/dashboard_lifecycle.py index 176ca4dc1..b363ae3a7 100644 --- a/astrbot/dashboard/dashboard_lifecycle.py +++ b/astrbot/dashboard/dashboard_lifecycle.py @@ -1,4 +1,5 @@ import asyncio +import traceback from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from .server import AstrBotDashboard @@ -13,8 +14,16 @@ class AstrBotDashBoardLifecycle: async def start(self): core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db) - await core_lifecycle.initialize() - core_task = core_lifecycle.start() + + core_task = [] + try: + await core_lifecycle.initialize() + core_task = core_lifecycle.start() + except Exception as e: + logger.critical(f"初始化 AstrBot 失败:{e} !!!!!!!") + logger.critical(f"初始化 AstrBot 失败:{e} !!!!!!!") + logger.critical(f"初始化 AstrBot 失败:{e} !!!!!!!") + self.dashboard_server = AstrBotDashboard(core_lifecycle, self.db) task = asyncio.gather(core_task, self.dashboard_server.run())