diff --git a/astrbot/api/provider/__init__.py b/astrbot/api/provider/__init__.py index 377f8d4b3..17e379478 100644 --- a/astrbot/api/provider/__init__.py +++ b/astrbot/api/provider/__init__.py @@ -1,2 +1,2 @@ -from astrbot.core.provider import Provider, Personality, ProviderMetaData -from astrbot.core.provider.entites import ProviderRequest \ No newline at end of file +from astrbot.core.provider import Provider, STTProvider, Personality +from astrbot.core.provider.entites import ProviderRequest, ProviderType, ProviderMetaData \ No newline at end of file diff --git a/astrbot/core/__init__.py b/astrbot/core/__init__.py index 0ef8e039d..aac8fc117 100644 --- a/astrbot/core/__init__.py +++ b/astrbot/core/__init__.py @@ -20,6 +20,6 @@ if os.environ.get('TESTING', ""): db_helper = SQLiteDatabase(DB_PATH) sp = SharedPreferences() # 简单的偏好设置存储 pip_installer = PipInstaller(astrbot_config.get('pip_install_arg', '')) -web_chat_queue = asyncio.Queue() -web_chat_back_queue = asyncio.Queue() +web_chat_queue = asyncio.Queue(maxsize=32) +web_chat_back_queue = asyncio.Queue(maxsize=32) WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool" diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 43a010d45..f9e2a3485 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -33,6 +33,10 @@ DEFAULT_CONFIG = { "default_personality": "如果用户寻求帮助或者打招呼,请告诉他可以用 /help 查看 AstrBot 帮助。", "prompt_prefix": "", }, + "provider_stt_settings": { + "enable": False, + "provider_id": "", + }, "content_safety": { "internal_keywords": {"enable": True, "extra_keywords": []}, "baidu_aip": {"enable": False, "app_id": "", "api_key": "", "secret_key": ""}, @@ -315,9 +319,30 @@ CONFIG_METADATA_2 = { "dify_api_key": "", "dify_api_base": "https://api.dify.ai/v1", "dify_workflow_output_key": "", + }, + "whisper(API)": { + "id": "whisper", + "type": "openai_whisper_api", + "enable": False, + "api_key": "", + "api_base": "", + "model": "whisper-1", + }, + "whisper(本地加载)": { + "whisper_hint": "(不用修改我)", + "enable": False, + "id": "whisper", + "type": "openai_whisper_selfhost", + "model": "tiny", } }, "items": { + "whisper_hint": { + "description": "本地部署 Whisper 模型须知", + "type": "string", + "hint": "启用前请 pip 安装 openai-whisper 库(N卡用户大约下载 2GB,主要是 torch 和 cuda,CPU 用户大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。", + "obvious_hint": True + }, "id": { "description": "ID", "type": "string", @@ -416,7 +441,8 @@ CONFIG_METADATA_2 = { "enable": { "description": "启用大语言模型聊天", "type": "bool", - "hint": "是否启用大语言模型聊天。默认启用", + "hint": "如需切换大语言模型提供商,请使用 `/provider` 命令。", + "obvious_hint": True }, "wake_prefix": { "description": "LLM 聊天额外唤醒前缀", @@ -450,6 +476,23 @@ CONFIG_METADATA_2 = { }, }, }, + "provider_stt_settings": { + "description": "语音转文本(STT)", + "type": "object", + "items": { + "enable": { + "description": "启用语音转文本(STT)", + "type": "bool", + "hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 whisper。", + "obvious_hint": True + }, + "provider_id": { + "description": "提供商 ID,不填则默认第一个STT提供商", + "type": "string", + "hint": "语音转文本提供商 ID。如果不填写将使用载入的第一个提供商。", + }, + }, + }, }, }, "misc_config_group": { diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 50e81c86d..74d90c535 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -123,7 +123,7 @@ class Record(BaseMessageComponent): proxy: T.Optional[bool] = True timeout: T.Optional[int] = 0 # 额外 - path: T.Optional[str] + path: T.Optional[str] # 用这个 def __init__(self, file: T.Optional[str], **_): for k in _.keys(): diff --git a/astrbot/core/pipeline/__init__.py b/astrbot/core/pipeline/__init__.py index 108b1f134..bdba27699 100644 --- a/astrbot/core/pipeline/__init__.py +++ b/astrbot/core/pipeline/__init__.py @@ -3,6 +3,7 @@ from astrbot.core.message.message_event_result import MessageEventResult, EventR from .waking_check.stage import WakingCheckStage from .whitelist_check.stage import WhitelistCheckStage from .content_safety_check.stage import ContentSafetyCheckStage +from .preprocess_stage.stage import PreProcessStage from .process_stage.stage import ProcessStage from .result_decorate.stage import ResultDecorateStage from .respond.stage import RespondStage @@ -12,6 +13,7 @@ STAGES_ORDER = [ "WhitelistCheckStage", # 检查是否在群聊/私聊白名单 "RateLimitCheckStage", # 检查会话是否超过频率限制 "ContentSafetyCheckStage", # 检查内容安全 + "PreProcessStage", # 预处理 "ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用 "ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等 "RespondStage" # 发送消息 @@ -21,6 +23,7 @@ __all__ = [ "WakingCheckStage", "WhitelistCheckStage", "ContentSafetyCheckStage", + "PreProcessStage", "ProcessStage", "ResultDecorateStage", "RespondStage", diff --git a/astrbot/core/pipeline/preprocess_stage/stage.py b/astrbot/core/pipeline/preprocess_stage/stage.py new file mode 100644 index 000000000..455fd35c8 --- /dev/null +++ b/astrbot/core/pipeline/preprocess_stage/stage.py @@ -0,0 +1,55 @@ +import traceback +import asyncio +from typing import Union, AsyncGenerator +from ..stage import Stage, register_stage +from ..context import PipelineContext +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core import logger +from astrbot.core.message.components import Plain, Record + +@register_stage +class PreProcessStage(Stage): + + async def initialize(self, ctx: PipelineContext) -> None: + self.ctx = ctx + self.config = ctx.astrbot_config + self.plugin_manager = ctx.plugin_manager + + self.stt_settings: dict = self.config.get('provider_stt_settings', {}) + + + async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]: + '''在处理事件之前的预处理''' + + if self.stt_settings.get('enable', False): + # STT 处理 + # TODO: 独立 + stt_provider = self.plugin_manager.context.provider_manager.curr_stt_provider_inst + if stt_provider: + message_chain = event.get_messages() + for idx, component in enumerate(message_chain): + if isinstance(component, Record) and component.path: + + path = component.path + + retry = 5 + + for i in range(retry): + try: + result = await stt_provider.get_text(audio_url=path) + if result: + logger.info("语音转文本结果: " + result) + message_chain[idx] = Plain(result) + event.message_str += result + event.message_obj.message_str += result + break + except FileNotFoundError as e: + # napcat workaround + logger.warning(e) + logger.warning(f"语音文件不存在: {path}, 重试中: {i + 1}/{retry}") + await asyncio.sleep(0.5) + continue + except BaseException as e: + logger.error(traceback.format_exc()) + logger.error(f"语音转文本失败: {e}") + break diff --git a/astrbot/core/pipeline/scheduler.py b/astrbot/core/pipeline/scheduler.py index 8842057b1..e18ee92be 100644 --- a/astrbot/core/pipeline/scheduler.py +++ b/astrbot/core/pipeline/scheduler.py @@ -41,4 +41,8 @@ class PipelineScheduler(): async def execute(self, event: AstrMessageEvent): '''执行 pipeline''' await self._process_stages(event) + + if not event._has_send_oper and event.get_platform_name() == "webchat": + await event.send(None) + logger.debug("pipeline 执行完毕。") \ No newline at end of file diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py index 87e9d5f7b..e2a438caf 100644 --- a/astrbot/core/platform/sources/webchat/webchat_adapter.py +++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py @@ -5,7 +5,7 @@ import os from typing import Awaitable, Any from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata from astrbot.api.event import MessageChain -from astrbot.api.message_components import Plain, Image # noqa: F403 +from astrbot.api.message_components import Plain, Image, Record # noqa: F403 from astrbot.api import logger from astrbot.core import web_chat_queue, web_chat_back_queue from .webchat_event import WebChatMessageEvent @@ -70,6 +70,14 @@ class WebChatAdapter(Platform): abm.message.append(Image.fromFileSystem(os.path.join(self.imgs_dir, img))) else: abm.message.append(Image.fromFileSystem(os.path.join(self.imgs_dir, payload['image_url']))) + if payload['audio_url']: + if isinstance(payload['audio_url'], list): + for audio in payload['audio_url']: + path = os.path.join(self.imgs_dir, audio) + abm.message.append(Record(file=path, path=path)) + else: + path = os.path.join(self.imgs_dir, payload['audio_url']) + abm.message.append(Record(file=path, path=path)) logger.debug(f"WebChatAdapter: {abm.message}") diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index c988724be..0ef57ed5f 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -12,9 +12,13 @@ class WebChatMessageEvent(AstrMessageEvent): os.makedirs(self.imgs_dir, exist_ok=True) async def send(self, message: MessageChain): + if not message: + web_chat_back_queue.put_nowait(None) + return + for comp in message.chain: if isinstance(comp, Plain): - await web_chat_back_queue.put(comp.text) + web_chat_back_queue.put_nowait(comp.text) elif isinstance(comp, Image): # save image to local filename = str(uuid.uuid4()) + ".jpg" @@ -26,6 +30,6 @@ class WebChatMessageEvent(AstrMessageEvent): f.write(f2.read()) elif comp.file and comp.file.startswith("http"): await download_image_by_url(comp.file, path=path) - await web_chat_back_queue.put(f"[IMAGE]{filename}") - await web_chat_back_queue.put(None) + web_chat_back_queue.put_nowait(f"[IMAGE]{filename}") + web_chat_back_queue.put_nowait(None) await super().send(message) \ No newline at end of file diff --git a/astrbot/core/provider/__init__.py b/astrbot/core/provider/__init__.py index b1dfe8732..246a74e57 100644 --- a/astrbot/core/provider/__init__.py +++ b/astrbot/core/provider/__init__.py @@ -1,4 +1,4 @@ -from .provider import Provider, Personality +from .provider import Provider, Personality, STTProvider from .entites import ProviderMetaData @@ -6,4 +6,5 @@ __all__ = [ "Provider", "Personality", "ProviderMetaData", + "STTProvider" ] \ No newline at end of file diff --git a/astrbot/core/provider/entites.py b/astrbot/core/provider/entites.py index 8dae2680d..0a733e3b9 100644 --- a/astrbot/core/provider/entites.py +++ b/astrbot/core/provider/entites.py @@ -1,13 +1,22 @@ +import enum from dataclasses import dataclass, field -from typing import List, Dict +from typing import List, Dict, Type from .func_tool_manager import FuncCall +class ProviderType(enum.Enum): + CHAT_COMPLETION = "chat_completion" + SPEECH_TO_TEXT = "speech_to_text" + TEXT_TO_SPEECH = "text_to_speech" + @dataclass class ProviderMetaData(): - type: str # 提供商适配器名称,如 openai, ollama - desc: str = "" # 提供商适配器描述. - + type: str + '''提供商适配器名称,如 openai, ollama''' + desc: str = "" + '''提供商适配器描述.''' + provider_type: ProviderType = ProviderType.CHAT_COMPLETION + cls_type: Type = None @dataclass class ProviderRequest(): diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 19075338c..3b64126f4 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -1,6 +1,7 @@ import traceback from astrbot.core.config.astrbot_config import AstrBotConfig -from .provider import Provider +from .provider import Provider, STTProvider +from .entites import ProviderType from typing import List from astrbot.core.db import BaseDatabase from collections import defaultdict @@ -11,10 +12,17 @@ class ProviderManager(): def __init__(self, config: AstrBotConfig, db_helper: BaseDatabase): self.providers_config: List = config['provider'] self.provider_settings: dict = config['provider_settings'] + self.provider_stt_settings: dict = config.get('provider_stt_settings', {}) + self.provider_insts: List[Provider] = [] '''加载的 Provider 的实例''' + self.stt_provider_insts: List[STTProvider] = [] + '''加载的 Speech To Text Provider 的实例''' self.llm_tools = llm_tools self.curr_provider_inst: Provider = None + '''当前使用的 Provider 实例''' + self.curr_stt_provider_inst: STTProvider = None + '''当前使用的 Speech To Text Provider 实例''' self.loaded_ids = defaultdict(bool) self.db_helper = db_helper @@ -31,19 +39,29 @@ class ProviderManager(): raise ValueError(f"Provider ID 重复:{provider_cfg['id']}。") self.loaded_ids[provider_cfg['id']] = True - match provider_cfg['type']: - case "openai_chat_completion": - from .sources.openai_source import ProviderOpenAIOfficial # noqa: F401 - case "zhipu_chat_completion": - from .sources.zhipu_source import ProviderZhipu # noqa: F401 - case "llm_tuner": - logger.info("加载 LLM Tuner 工具 ...") - from .sources.llmtuner_source import LLMTunerModelLoader # noqa: F401 - case "dify": - from .sources.dify_source import ProviderDify # noqa: F401 - case "googlegenai_chat_completion": - from .sources.gemini_source import ProviderGoogleGenAI # noqa: F401 - + try: + match provider_cfg['type']: + case "openai_chat_completion": + from .sources.openai_source import ProviderOpenAIOfficial # noqa: F401 + case "zhipu_chat_completion": + from .sources.zhipu_source import ProviderZhipu # noqa: F401 + case "llm_tuner": + logger.info("加载 LLM Tuner 工具 ...") + from .sources.llmtuner_source import LLMTunerModelLoader # noqa: F401 + case "dify": + from .sources.dify_source import ProviderDify # noqa: F401 + case "googlegenai_chat_completion": + from .sources.gemini_source import ProviderGoogleGenAI # noqa: F401 + case "openai_whisper_api": + from .sources.whisper_api_source import ProviderOpenAIWhisperAPI # noqa: F401 + case "openai_whisper_selfhost": + from .sources.whisper_selfhosted_source import ProviderOpenAIWhisperSelfHost # noqa: F401 + except (ImportError, ModuleNotFoundError) as e: + logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。") + continue + except Exception as e: + logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。未知原因") + continue async def initialize(self): for provider_config in self.providers_config: @@ -53,23 +71,54 @@ class ProviderManager(): logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。") continue selected_provider_id = sp.get("curr_provider") - cls_type = provider_cls_map[provider_config['type']] + selected_stt_provider_id = self.provider_stt_settings.get("provider_id") + provider_enabled = self.provider_settings.get("enable", False) + stt_enabled = self.provider_stt_settings.get("enable", False) + + provider_metadata = provider_cls_map[provider_config['type']] logger.info(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器 ...") try: - inst = cls_type(provider_config, self.provider_settings, self.db_helper, self.provider_settings.get('persistant_history', True)) - self.provider_insts.append(inst) - if selected_provider_id == provider_config['id']: - self.curr_provider_inst = inst - logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。") + # 按任务实例化提供商 + + if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT: + # STT 任务 + inst = provider_metadata.cls_type(provider_config, self.provider_settings) + + if getattr(inst, "initialize", None): + await inst.initialize() + + self.stt_provider_insts.append(inst) + if selected_stt_provider_id == provider_config['id'] and stt_enabled: + self.curr_stt_provider_inst = inst + logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。") + + elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION: + # 文本生成任务 + inst = provider_metadata.cls_type(provider_config, self.provider_settings, self.db_helper, self.provider_settings.get('persistant_history', True)) + + if getattr(inst, "initialize", None): + await inst.initialize() + + self.provider_insts.append(inst) + if selected_provider_id == provider_config['id'] and provider_enabled: + self.curr_provider_inst = inst + logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。") + except Exception as e: traceback.print_exc() logger.error(f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}") - if len(self.provider_insts) > 0 and not self.curr_provider_inst: + if len(self.provider_insts) > 0 and not self.curr_provider_inst and provider_enabled: self.curr_provider_inst = self.provider_insts[0] + if len(self.stt_provider_insts) > 0 and not self.curr_stt_provider_inst and stt_enabled: + self.curr_stt_provider_inst = self.stt_provider_insts[0] + if not self.curr_provider_inst: - logger.warning("未启用任何提供商适配器。") + logger.warning("未启用任何用于 文本生成 的提供商适配器。") + if self.provider_stt_settings.get("enable"): + if not self.curr_stt_provider_inst: + logger.warning("未启用任何用于 语音转文本 的提供商适配器。") def get_insts(self): return self.provider_insts diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 553dd78dc..30c6b02d0 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -125,6 +125,33 @@ class Provider(abc.ABC): '''重置某一个 session_id 的上下文''' raise NotImplementedError() + def meta(self) -> ProviderMeta: + '''获取 Provider 的元数据''' + return ProviderMeta( + id=self.provider_config['id'], + model=self.get_model(), + type=self.provider_config['type'] + ) + + +class STTProvider(): + def __init__(self, provider_config: dict, provider_settings: dict) -> None: + self.provider_config = provider_config + self.provider_settings = provider_settings + + @abc.abstractmethod + async def get_text(self, audio_url: str) -> str: + '''获取音频的文本''' + raise NotImplementedError() + + def set_model(self, model_name: str): + '''设置当前使用的模型名称''' + self.model_name = model_name + + def get_model(self) -> str: + '''获取当前使用的模型''' + return self.provider_config.get("model", "") + def meta(self) -> ProviderMeta: '''获取 Provider 的元数据''' return ProviderMeta( diff --git a/astrbot/core/provider/register.py b/astrbot/core/provider/register.py index 00c3ad877..61e64408f 100644 --- a/astrbot/core/provider/register.py +++ b/astrbot/core/provider/register.py @@ -1,16 +1,20 @@ from typing import List, Dict, Type -from .entites import ProviderMetaData +from .entites import ProviderMetaData, ProviderType from astrbot.core import logger from .func_tool_manager import FuncCall provider_registry: List[ProviderMetaData] = [] '''维护了通过装饰器注册的 Provider''' -provider_cls_map: Dict[str, Type] = {} -'''维护了 Provider 类型名称和 Provider 类的映射''' +provider_cls_map: Dict[str, ProviderMetaData] = {} +'''维护了 Provider 类型名称和 ProviderMetadata 的映射''' llm_tools = FuncCall() -def register_provider_adapter(provider_type_name: str, desc: str): +def register_provider_adapter( + provider_type_name: str, + desc: str, + provider_type: ProviderType = ProviderType.CHAT_COMPLETION +): '''用于注册平台适配器的带参装饰器''' def decorator(cls): if provider_type_name in provider_cls_map: @@ -19,9 +23,11 @@ def register_provider_adapter(provider_type_name: str, desc: str): pm = ProviderMetaData( type=provider_type_name, desc=desc, + provider_type=provider_type, + cls_type=cls ) provider_registry.append(pm) - provider_cls_map[provider_type_name] = cls + provider_cls_map[provider_type_name] = pm logger.debug(f"Provider {provider_type_name} 已注册") return cls diff --git a/astrbot/core/provider/sources/whisper_api_source.py b/astrbot/core/provider/sources/whisper_api_source.py new file mode 100644 index 000000000..7159bf9d4 --- /dev/null +++ b/astrbot/core/provider/sources/whisper_api_source.py @@ -0,0 +1,95 @@ +import uuid +import os +import io +from openai import AsyncOpenAI, NOT_GIVEN +from ..provider import STTProvider +from ..entites import ProviderType +from astrbot.core.utils.io import download_file +from ..register import register_provider_adapter +from astrbot.core import logger + +@register_provider_adapter("openai_whisper_api", "OpenAI Whisper API", provider_type=ProviderType.SPEECH_TO_TEXT) +class ProviderOpenAIWhisperAPI(STTProvider): + def __init__( + self, + provider_config: dict, + provider_settings: dict, + ) -> None: + super().__init__(provider_config, provider_settings) + self.chosen_api_key = provider_config.get("api_key", "") + + self.client = AsyncOpenAI( + api_key=self.chosen_api_key, + base_url=provider_config.get("api_base", None), + timeout=provider_config.get("timeout", NOT_GIVEN), + ) + + self.set_model(provider_config.get("model", None)) + + async def _convert_audio(self, path: str) -> str: + from pyffmpeg import FFmpeg + filename = str(uuid.uuid4()) + '.mp3' + ff = FFmpeg() + output_path = ff.convert(path, os.path.join('data/temp', filename)) + return output_path + + async def _pcm_to_wav(self, input_io: io.BytesIO, output_path: str) -> str: + import wave + + with wave.open(output_path, 'wb') as wav: + wav.setnchannels(1) + wav.setsampwidth(2) + wav.setframerate(24000) + wav.writeframes(input_io.read()) + + return output_path + + async def _convert_silk(self, path: str) -> str: + import pysilk + filename = str(uuid.uuid4()) + '.wav' + output_path = os.path.join('data/temp', filename) + with open(path, "rb") as f: + input_data = f.read() + if input_data.startswith(b'\x02'): + # tencent 我爱你 + input_data = input_data[1:] + input_io = io.BytesIO(input_data) + output_io = io.BytesIO() + pysilk.decode(input_io, output_io, 24000) + output_io.seek(0) + await self._pcm_to_wav(output_io, output_path) + + return output_path + + async def _is_silk_file(self, file_path): + silk_header = b"SILK" + with open(file_path, "rb") as f: + file_header = f.read(8) + + if silk_header in file_header: + return True + else: + return False + + async def get_text(self, audio_url: str) -> str: + '''only supports mp3, mp4, mpeg, m4a, wav, webm''' + if audio_url.startswith("http"): + name = str(uuid.uuid4()) + path = os.path.join("data/temp", name) + audio_url = await download_file(audio_url, path) + + if not os.path.exists(audio_url): + raise FileNotFoundError(f"文件不存在: {audio_url}") + + if audio_url.endswith(".amr") or audio_url.endswith(".silk"): + is_silk = await self._is_silk_file(audio_url) + if is_silk: + logger.info("Converting silk file to wav ...") + audio_url = await self._convert_silk(audio_url) + + + result = await self.client.audio.transcriptions.create( + model=self.model_name, + file=open(audio_url, "rb"), + ) + return result.text \ No newline at end of file diff --git a/astrbot/core/provider/sources/whisper_selfhosted_source.py b/astrbot/core/provider/sources/whisper_selfhosted_source.py new file mode 100644 index 000000000..6f16d559a --- /dev/null +++ b/astrbot/core/provider/sources/whisper_selfhosted_source.py @@ -0,0 +1,92 @@ +import uuid +import os +import io +import asyncio +import whisper +from ..provider import STTProvider +from ..entites import ProviderType +from astrbot.core.utils.io import download_file +from ..register import register_provider_adapter +from astrbot.core import logger + + +@register_provider_adapter("openai_whisper_selfhost", "OpenAI Whisper 模型部署", provider_type=ProviderType.SPEECH_TO_TEXT) +class ProviderOpenAIWhisperSelfHost(STTProvider): + def __init__( + self, + provider_config: dict, + provider_settings: dict, + ) -> None: + super().__init__(provider_config, provider_settings) + self.set_model(provider_config.get("model", None)) + self.model = None + + async def initialize(self): + loop = asyncio.get_event_loop() + logger.info("下载或者加载 Whisper 模型中,这可能需要一些时间 ...") + self.model = await loop.run_in_executor(None, whisper.load_model, self.model_name) + logger.info("Whisper 模型加载完成。") + + async def _convert_audio(self, path: str) -> str: + from pyffmpeg import FFmpeg + filename = str(uuid.uuid4()) + '.mp3' + ff = FFmpeg() + output_path = ff.convert(path, os.path.join('data/temp', filename)) + return output_path + + async def _pcm_to_wav(self, input_io: io.BytesIO, output_path: str) -> str: + import wave + + with wave.open(output_path, 'wb') as wav: + wav.setnchannels(1) + wav.setsampwidth(2) + wav.setframerate(24000) + wav.writeframes(input_io.read()) + + return output_path + + async def _convert_silk(self, path: str) -> str: + import pysilk + filename = str(uuid.uuid4()) + '.wav' + output_path = os.path.join('data/temp', filename) + with open(path, "rb") as f: + input_data = f.read() + if input_data.startswith(b'\x02'): + # tencent 我爱你 + input_data = input_data[1:] + input_io = io.BytesIO(input_data) + output_io = io.BytesIO() + pysilk.decode(input_io, output_io, 24000) + output_io.seek(0) + await self._pcm_to_wav(output_io, output_path) + + return output_path + + async def _is_silk_file(self, file_path): + silk_header = b"SILK" + with open(file_path, "rb") as f: + file_header = f.read(8) + + if silk_header in file_header: + return True + else: + return False + + async def get_text(self, audio_url: str) -> str: + loop = asyncio.get_event_loop() + if audio_url.startswith("http"): + name = str(uuid.uuid4()) + path = os.path.join("data/temp", name) + audio_url = await download_file(audio_url, path) + + if not os.path.exists(audio_url): + raise FileNotFoundError(f"文件不存在: {audio_url}") + + if audio_url.endswith(".amr") or audio_url.endswith(".silk"): + is_silk = await self._is_silk_file(audio_url) + if is_silk: + logger.info("Converting silk file to wav ...") + audio_url = await self._convert_silk(audio_url) + + result = await loop.run_in_executor(None, self.model.transcribe, audio_url) + return result['text'] \ No newline at end of file diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 58810de96..b94b50a52 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -17,10 +17,6 @@ from .filter.regex import RegexFilter from typing import Awaitable from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager -class StarCommand(TypedDict): - full_command_name: str - command_name: str - class Context: ''' 暴露给插件的接口上下文。 @@ -168,13 +164,13 @@ class Context: def register_provider(self, provider: Provider): ''' - 注册一个 LLM Provider。 + 注册一个 LLM Provider(Chat_Completion 类型)。 ''' self.provider_manager.provider_insts.append(provider) def get_provider_by_id(self, provider_id: str) -> Provider: ''' - 通过 ID 获取 LLM Provider。 + 通过 ID 获取 LLM Provider(Chat_Completion 类型)。 ''' for provider in self.provider_manager.provider_insts: if provider.meta().id == provider_id: @@ -183,13 +179,13 @@ class Context: def get_all_providers(self) -> List[Provider]: ''' - 获取所有 LLM Provider。 + 获取所有 LLM Provider(Chat_Completion 类型)。 ''' return self.provider_manager.provider_insts def get_using_provider(self) -> Provider: ''' - 获取当前使用的 LLM Provider。 + 获取当前使用的 LLM Provider(Chat_Completion 类型)。 通过 /provider 指令切换。 ''' diff --git a/astrbot/dashboard/dashboard_lifecycle.py b/astrbot/dashboard/dashboard_lifecycle.py index 176ca4dc1..b363ae3a7 100644 --- a/astrbot/dashboard/dashboard_lifecycle.py +++ b/astrbot/dashboard/dashboard_lifecycle.py @@ -1,4 +1,5 @@ import asyncio +import traceback from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from .server import AstrBotDashboard @@ -13,8 +14,16 @@ class AstrBotDashBoardLifecycle: async def start(self): core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db) - await core_lifecycle.initialize() - core_task = core_lifecycle.start() + + core_task = [] + try: + await core_lifecycle.initialize() + core_task = core_lifecycle.start() + except Exception as e: + logger.critical(f"初始化 AstrBot 失败:{e} !!!!!!!") + logger.critical(f"初始化 AstrBot 失败:{e} !!!!!!!") + logger.critical(f"初始化 AstrBot 失败:{e} !!!!!!!") + self.dashboard_server = AstrBotDashboard(core_lifecycle, self.db) task = asyncio.gather(core_task, self.dashboard_server.run()) diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index cef7e563f..639c7cafa 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -6,9 +6,11 @@ from astrbot.core import web_chat_queue, web_chat_back_queue from quart import request, Response as QuartResponse, g from astrbot.core.db import BaseDatabase import asyncio +from astrbot.core.core_lifecycle import AstrBotCoreLifecycle + class ChatRoute(Route): - def __init__(self, context: RouteContext, db: BaseDatabase) -> None: + def __init__(self, context: RouteContext, db: BaseDatabase, core_lifecycle: AstrBotCoreLifecycle) -> None: super().__init__(context) self.routes = { '/chat/send': ('POST', self.chat), @@ -17,11 +19,24 @@ class ChatRoute(Route): '/chat/get_conversation': ('GET', self.get_conversation), '/chat/delete_conversation': ('GET', self.delete_conversation), '/chat/get_file': ('GET', self.get_file), - '/chat/post_image': ('POST', self.post_image) + '/chat/post_image': ('POST', self.post_image), + '/chat/post_file': ('POST', self.post_file), + '/chat/status': ('GET', self.status), } self.db = db + self.core_lifecycle = core_lifecycle self.register_routes() self.imgs_dir = "data/webchat/imgs" + + self.supported_imgs = ['jpg', 'jpeg', 'png', 'gif', 'webp'] + + async def status(self): + has_llm_enabled = self.core_lifecycle.provider_manager.curr_provider_inst is not None + has_stt_enabled = self.core_lifecycle.provider_manager.curr_stt_provider_inst is not None + return Response().ok(data={ + 'llm_enabled': has_llm_enabled, + 'stt_enabled': has_stt_enabled + }).__dict__ async def get_file(self): filename = request.args.get('filename') @@ -30,7 +45,13 @@ class ChatRoute(Route): try: with open(os.path.join(self.imgs_dir, filename), "rb") as f: - return QuartResponse(f.read(), mimetype="image/jpeg") + if filename.endswith(".wav"): + return QuartResponse(f.read(), mimetype="audio/wav") + elif filename.split('.')[-1] in self.supported_imgs: + return QuartResponse(f.read(), mimetype="image/jpeg") + else: + return QuartResponse(f.read()) + except FileNotFoundError: return Response().error("File not found").__dict__ @@ -47,6 +68,25 @@ class ChatRoute(Route): return Response().ok(data={ 'filename': filename }).__dict__ + + async def post_file(self): + post_data = await request.files + if 'file' not in post_data: + return Response().error("Missing key: file").__dict__ + + file = post_data['file'] + filename = f"{str(uuid.uuid4())}" + print(file) + # 通过文件格式判断文件类型 + if file.content_type.startswith('audio'): + filename += ".wav" + + path = os.path.join(self.imgs_dir, filename) + await file.save(path) + + return Response().ok(data={ + 'filename': filename + }).__dict__ async def chat(self): username = g.get('username', 'guest') @@ -61,20 +101,26 @@ class ChatRoute(Route): message = post_data['message'] conversation_id = post_data['conversation_id'] image_url = post_data.get('image_url') - if not message and not image_url: - return Response().error("Message and image_url are empty").__dict__ + audio_url = post_data.get('audio_url') + if not message and not image_url and not audio_url: + return Response().error("Message and image_url and audio_url are empty").__dict__ if not conversation_id: return Response().error("conversation_id is empty").__dict__ await web_chat_queue.put((username, conversation_id, { 'message': message, - 'image_url': image_url # list + 'image_url': image_url, # list + 'audio_url': audio_url })) async def stream(): ret = [] while True: - result = await web_chat_back_queue.get() + try: + result = await asyncio.wait_for(web_chat_back_queue.get(), timeout=30) # 设置超时时间为5秒 + except asyncio.TimeoutError: + yield '[Error] 30 秒内没有返回数据,已放弃。\n' + return if result is None: break @@ -98,6 +144,8 @@ class ChatRoute(Route): } if image_url: new_his['image_url'] = image_url + if audio_url: + new_his['audio_url'] = audio_url history.append(new_his) for r in ret: history.append({ diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index 97013b423..798b212c2 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -31,7 +31,7 @@ class AstrBotDashboard(): self.lr = LogRoute(self.context, core_lifecycle.log_broker) self.sfr = StaticFileRoute(self.context) self.ar = AuthRoute(self.context) - self.chat_route = ChatRoute(self.context, db) + self.chat_route = ChatRoute(self.context, db, core_lifecycle) async def auth_middleware(self): if not request.path.startswith("/api"): diff --git a/dashboard/package.json b/dashboard/package.json index 2888d4415..9b59cfa92 100644 --- a/dashboard/package.json +++ b/dashboard/package.json @@ -33,8 +33,6 @@ "vue3-apexcharts": "1.4.4", "vue3-print-nb": "0.1.4", "vuetify": "3.3.14", - "xterm": "^5.3.0", - "xterm-addon-fit": "^0.8.0", "yup": "1.2.0" }, "devDependencies": { diff --git a/dashboard/src/views/ChatPage.vue b/dashboard/src/views/ChatPage.vue index f4cb9b0e5..cfe96e544 100644 --- a/dashboard/src/views/ChatPage.vue +++ b/dashboard/src/views/ChatPage.vue @@ -20,7 +20,7 @@ marked.setOptions({ :disabled="!currCid">+ 创建对话 - @@ -31,12 +31,24 @@ marked.setOptions({ +
+ + + LLM + + + + 语音转文本 + +
+ 删除此对话 +
-
+
@@ -49,6 +61,12 @@ marked.setOptions({ style="background-color: #eee; padding-left: 4px; padding-right: 4px; margin: 2px; border-radius: 4px;">/help 获取帮助 😊
+
+ + K + 开始语音 🎤 +
@@ -58,13 +76,21 @@ marked.setOptions({
{{ msg.message }} -
+
+ +
+ +
@@ -79,26 +105,36 @@ marked.setOptions({
- - + @click:clear="clearMessage" style="width: 100%; max-width: 850px;"> -
+
mdi-close-circle
+
+
+ 新录音 + mdi-close-circle +
+ +
@@ -128,18 +173,95 @@ export default { conversations: [], currCid: '', stagedImagesUrl: [], - loadingChat: false + loadingChat: false, + + inputFieldLabel: '聊天吧!', + + isRecording: false, + audioChunks: [], + stagedAudioUrl: "", + mediaRecorder: null, + + status: {}, + statusText: '' } }, mounted() { + this.checkStatus(); this.getConversations(); let inputField = document.getElementById('input-field'); inputField.addEventListener('paste', this.handlePaste); - + inputField.addEventListener('keydown', function (e) { + if (e.keyCode == 13 && !e.shiftKey) { + e.preventDefault(); + this.sendMessage(); + } + }.bind(this)); + document.addEventListener('keydown', function (e) { + if (e.keyCode == 75) { + this.isRecording ? this.stopRecording() : this.startRecording(); + } + }.bind(this)); }, methods: { + + removeAudio() { + this.stagedAudioUrl = null; + }, + + checkStatus() { + axios.get('/api/chat/status').then(response => { + console.log(response.data); + this.status = response.data.data; + }).catch(err => { + console.error(err); + }); + }, + + async startRecording() { + const stream = await navigator.mediaDevices.getUserMedia({ audio: true }); + this.mediaRecorder = new MediaRecorder(stream); + this.mediaRecorder.ondataavailable = (event) => { + this.audioChunks.push(event.data); + }; + this.mediaRecorder.start(); + this.isRecording = true; + this.inputFieldLabel = "录音中,请说话..."; + }, + + async stopRecording() { + this.isRecording = false; + this.inputFieldLabel = "聊天吧!"; + this.mediaRecorder.stop(); + this.mediaRecorder.onstop = async () => { + const audioBlob = new Blob(this.audioChunks, { type: 'audio/wav' }); + this.audioChunks = []; + + this.mediaRecorder.stream.getTracks().forEach(track => track.stop()); + + const formData = new FormData(); + formData.append('file', audioBlob); + + try { + const response = await axios.post('/api/chat/post_file', formData, { + headers: { + 'Content-Type': 'multipart/form-data', + 'Authorization': 'Bearer ' + localStorage.getItem('token') + } + }); + + const audio = response.data.data.filename; + console.log('Audio uploaded:', audio); + + this.stagedAudioUrl = `/api/chat/get_file?filename=${audio}`; + } catch (err) { + console.error('Error uploading audio:', err); + } + }; + }, + async handlePaste(event) { console.log('Pasting image...'); const items = event.clipboardData.items; @@ -160,7 +282,6 @@ export default { const img = response.data.data.filename; this.stagedImagesUrl.push(`/api/chat/get_file?filename=${img}`); - scrollToBottom(); } catch (err) { console.error('Error uploading image:', err); } @@ -198,6 +319,9 @@ export default { message[i].image_url[j] = `/api/chat/get_file?filename=${message[i].image_url[j]}`; } } + if (message[i].audio_url) { + message[i].audio_url = `/api/chat/get_file?filename=${message[i].audio_url}`; + } } this.messages = message; }).catch(err => { @@ -250,24 +374,26 @@ export default { this.messages.push({ type: 'user', message: this.prompt, - image_url: this.stagedImagesUrl + image_url: this.stagedImagesUrl, + audio_url: this.stagedAudioUrl }); - // let bot_resp = { - // type: 'bot', - // message: ref('') - // } - - // this.messages.push(bot_resp); - this.scrollToBottom(); + // images let image_filenames = []; for (let i = 0; i < this.stagedImagesUrl.length; i++) { let img = this.stagedImagesUrl[i].replace('/api/chat/get_file?filename=', ''); image_filenames.push(img); } + // audio + let audio_filenames = []; + if (this.stagedAudioUrl) { + let audio = this.stagedAudioUrl.replace('/api/chat/get_file?filename=', ''); + audio_filenames.push(audio); + } + this.loadingChat = true; @@ -277,11 +403,17 @@ export default { 'Content-Type': 'application/json', 'Authorization': 'Bearer ' + localStorage.getItem('token') }, - body: JSON.stringify({ message: this.prompt, conversation_id: this.currCid, image_url: image_filenames }) // 发送请求体 + body: JSON.stringify({ + message: this.prompt, + conversation_id: this.currCid, + image_url: image_filenames, + audio_url: audio_filenames + }) // 发送请求体 }) .then(response => { this.prompt = ''; this.stagedImagesUrl = []; + this.stagedAudioUrl = ""; this.loadingChat = false; diff --git a/requirements.txt b/requirements.txt index 965651870..432d5bb19 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,4 +16,5 @@ aiocqhttp pyjwt apscheduler docstring_parser -aiodocker \ No newline at end of file +aiodocker +silk-python \ No newline at end of file