From be662b913c4b01643132aa3c70fc6bd988d349a0 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sat, 11 Jan 2025 17:19:28 +0800 Subject: [PATCH 1/6] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=20Whisper=20STT?= =?UTF-8?q?=EF=BC=8C=E5=B9=B6=E9=80=82=E9=85=8D=20Tencent=20=E8=AF=AD?= =?UTF-8?q?=E9=9F=B3=E6=A0=BC=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/api/provider/__init__.py | 4 +- astrbot/core/config/default.py | 32 ++++++- astrbot/core/message/components.py | 2 +- astrbot/core/pipeline/__init__.py | 3 + .../core/pipeline/preprocess_stage/stage.py | 54 +++++++++++ astrbot/core/provider/__init__.py | 3 +- astrbot/core/provider/entites.py | 17 +++- astrbot/core/provider/manager.py | 47 +++++++-- astrbot/core/provider/provider.py | 27 ++++++ astrbot/core/provider/register.py | 16 +++- .../provider/sources/whisper_api_source.py | 95 +++++++++++++++++++ astrbot/core/star/context.py | 12 +-- dashboard/package.json | 2 - requirements.txt | 3 +- 14 files changed, 284 insertions(+), 33 deletions(-) create mode 100644 astrbot/core/pipeline/preprocess_stage/stage.py create mode 100644 astrbot/core/provider/sources/whisper_api_source.py diff --git a/astrbot/api/provider/__init__.py b/astrbot/api/provider/__init__.py index 377f8d4b3..17e379478 100644 --- a/astrbot/api/provider/__init__.py +++ b/astrbot/api/provider/__init__.py @@ -1,2 +1,2 @@ -from astrbot.core.provider import Provider, Personality, ProviderMetaData -from astrbot.core.provider.entites import ProviderRequest \ No newline at end of file +from astrbot.core.provider import Provider, STTProvider, Personality +from astrbot.core.provider.entites import ProviderRequest, ProviderType, ProviderMetaData \ No newline at end of file diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 43a010d45..983ca3cd3 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -33,6 +33,10 @@ DEFAULT_CONFIG = { "default_personality": "如果用户寻求帮助或者打招呼,请告诉他可以用 /help 查看 AstrBot 帮助。", "prompt_prefix": "", }, + "provider_stt_settings": { + "enable": False, + "provider_id": "", + }, "content_safety": { "internal_keywords": {"enable": True, "extra_keywords": []}, "baidu_aip": {"enable": False, "app_id": "", "api_key": "", "secret_key": ""}, @@ -315,6 +319,14 @@ CONFIG_METADATA_2 = { "dify_api_key": "", "dify_api_base": "https://api.dify.ai/v1", "dify_workflow_output_key": "", + }, + "whisper(API)": { + "id": "whisper", + "type": "openai_whisper_api", + "enable": True, + "api_key": "", + "api_base": "", + "model": "whisper-1", } }, "items": { @@ -416,7 +428,8 @@ CONFIG_METADATA_2 = { "enable": { "description": "启用大语言模型聊天", "type": "bool", - "hint": "是否启用大语言模型聊天。默认启用", + "hint": "如需切换大语言模型提供商,请使用 `/provider` 命令。", + "obvious_hint": True }, "wake_prefix": { "description": "LLM 聊天额外唤醒前缀", @@ -450,6 +463,23 @@ CONFIG_METADATA_2 = { }, }, }, + "provider_stt_settings": { + "description": "语音转文本(STT)", + "type": "object", + "items": { + "enable": { + "description": "启用语音转文本(STT)", + "type": "bool", + "hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 whisper。", + "obvious_hint": True + }, + "provider_id": { + "description": "提供商 ID,不填则默认第一个STT提供商", + "type": "string", + "hint": "语音转文本提供商 ID。如果不填写将使用载入的第一个提供商。", + }, + }, + }, }, }, "misc_config_group": { diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 50e81c86d..74d90c535 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -123,7 +123,7 @@ class Record(BaseMessageComponent): proxy: T.Optional[bool] = True timeout: T.Optional[int] = 0 # 额外 - path: T.Optional[str] + path: T.Optional[str] # 用这个 def __init__(self, file: T.Optional[str], **_): for k in _.keys(): diff --git a/astrbot/core/pipeline/__init__.py b/astrbot/core/pipeline/__init__.py index 108b1f134..bdba27699 100644 --- a/astrbot/core/pipeline/__init__.py +++ b/astrbot/core/pipeline/__init__.py @@ -3,6 +3,7 @@ from astrbot.core.message.message_event_result import MessageEventResult, EventR from .waking_check.stage import WakingCheckStage from .whitelist_check.stage import WhitelistCheckStage from .content_safety_check.stage import ContentSafetyCheckStage +from .preprocess_stage.stage import PreProcessStage from .process_stage.stage import ProcessStage from .result_decorate.stage import ResultDecorateStage from .respond.stage import RespondStage @@ -12,6 +13,7 @@ STAGES_ORDER = [ "WhitelistCheckStage", # 检查是否在群聊/私聊白名单 "RateLimitCheckStage", # 检查会话是否超过频率限制 "ContentSafetyCheckStage", # 检查内容安全 + "PreProcessStage", # 预处理 "ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用 "ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等 "RespondStage" # 发送消息 @@ -21,6 +23,7 @@ __all__ = [ "WakingCheckStage", "WhitelistCheckStage", "ContentSafetyCheckStage", + "PreProcessStage", "ProcessStage", "ResultDecorateStage", "RespondStage", diff --git a/astrbot/core/pipeline/preprocess_stage/stage.py b/astrbot/core/pipeline/preprocess_stage/stage.py new file mode 100644 index 000000000..a28e15485 --- /dev/null +++ b/astrbot/core/pipeline/preprocess_stage/stage.py @@ -0,0 +1,54 @@ +import traceback +import asyncio +from typing import Union, AsyncGenerator +from ..stage import Stage, register_stage +from ..context import PipelineContext +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core import logger +from astrbot.core.message.components import Plain, Record + +@register_stage +class PreProcessStage(Stage): + + async def initialize(self, ctx: PipelineContext) -> None: + self.ctx = ctx + self.config = ctx.astrbot_config + self.plugin_manager = ctx.plugin_manager + + self.stt_settings: dict = self.config.get('provider_stt_settings', {}) + + + async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]: + '''在处理事件之前的预处理''' + + if self.stt_settings.get('enable', False): + # STT 处理 + # TODO: 独立 + stt_provider = self.plugin_manager.context.provider_manager.curr_stt_provider_inst + if stt_provider: + message_chain = event.get_messages() + for idx, component in enumerate(message_chain): + if isinstance(component, Record) and component.path: + + path = component.path + + retry = 5 + + for i in range(retry): + try: + result = await stt_provider.get_text(audio_url=path) + if result: + logger.info("语音转文本结果: " + result) + message_chain[idx] = Plain(result) + event.message_str += result + event.message_obj.message_str += result + break + except FileNotFoundError: + # napcat workaround + logger.warning(f"语音文件不存在: {path}, 重试中: {i + 1}/{retry}") + await asyncio.sleep(0.5) + continue + except BaseException as e: + logger.error(traceback.format_exc()) + logger.error(f"语音转文本失败: {e}") + break diff --git a/astrbot/core/provider/__init__.py b/astrbot/core/provider/__init__.py index b1dfe8732..246a74e57 100644 --- a/astrbot/core/provider/__init__.py +++ b/astrbot/core/provider/__init__.py @@ -1,4 +1,4 @@ -from .provider import Provider, Personality +from .provider import Provider, Personality, STTProvider from .entites import ProviderMetaData @@ -6,4 +6,5 @@ __all__ = [ "Provider", "Personality", "ProviderMetaData", + "STTProvider" ] \ No newline at end of file diff --git a/astrbot/core/provider/entites.py b/astrbot/core/provider/entites.py index 8dae2680d..0a733e3b9 100644 --- a/astrbot/core/provider/entites.py +++ b/astrbot/core/provider/entites.py @@ -1,13 +1,22 @@ +import enum from dataclasses import dataclass, field -from typing import List, Dict +from typing import List, Dict, Type from .func_tool_manager import FuncCall +class ProviderType(enum.Enum): + CHAT_COMPLETION = "chat_completion" + SPEECH_TO_TEXT = "speech_to_text" + TEXT_TO_SPEECH = "text_to_speech" + @dataclass class ProviderMetaData(): - type: str # 提供商适配器名称,如 openai, ollama - desc: str = "" # 提供商适配器描述. - + type: str + '''提供商适配器名称,如 openai, ollama''' + desc: str = "" + '''提供商适配器描述.''' + provider_type: ProviderType = ProviderType.CHAT_COMPLETION + cls_type: Type = None @dataclass class ProviderRequest(): diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 19075338c..4f70c33f0 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -1,6 +1,7 @@ import traceback from astrbot.core.config.astrbot_config import AstrBotConfig -from .provider import Provider +from .provider import Provider, STTProvider +from .entites import ProviderType from typing import List from astrbot.core.db import BaseDatabase from collections import defaultdict @@ -11,10 +12,17 @@ class ProviderManager(): def __init__(self, config: AstrBotConfig, db_helper: BaseDatabase): self.providers_config: List = config['provider'] self.provider_settings: dict = config['provider_settings'] + self.provider_stt_settings: dict = config.get('provider_stt_settings', {}) + self.provider_insts: List[Provider] = [] '''加载的 Provider 的实例''' + self.stt_provider_insts: List[STTProvider] = [] + '''加载的 Speech To Text Provider 的实例''' self.llm_tools = llm_tools self.curr_provider_inst: Provider = None + '''当前使用的 Provider 实例''' + self.curr_stt_provider_inst: STTProvider = None + '''当前使用的 Speech To Text Provider 实例''' self.loaded_ids = defaultdict(bool) self.db_helper = db_helper @@ -43,6 +51,8 @@ class ProviderManager(): from .sources.dify_source import ProviderDify # noqa: F401 case "googlegenai_chat_completion": from .sources.gemini_source import ProviderGoogleGenAI # noqa: F401 + case "openai_whisper_api": + from .sources.whisper_api_source import ProviderOpenAIWhisperAPI # noqa: F401 async def initialize(self): @@ -53,14 +63,29 @@ class ProviderManager(): logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。") continue selected_provider_id = sp.get("curr_provider") - cls_type = provider_cls_map[provider_config['type']] + selected_stt_provider_id = self.provider_stt_settings.get("provider_id") + + provider_metadata = provider_cls_map[provider_config['type']] logger.info(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器 ...") try: - inst = cls_type(provider_config, self.provider_settings, self.db_helper, self.provider_settings.get('persistant_history', True)) - self.provider_insts.append(inst) - if selected_provider_id == provider_config['id']: - self.curr_provider_inst = inst - logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。") + # 按任务实例化提供商 + + if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT: + # STT 任务 + inst = provider_metadata.cls_type(provider_config, self.provider_settings) + self.stt_provider_insts.append(inst) + if selected_stt_provider_id == provider_config['id']: + self.curr_stt_provider_inst = inst + logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。") + + elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION: + # 文本生成任务 + inst = provider_metadata.cls_type(provider_config, self.provider_settings, self.db_helper, self.provider_settings.get('persistant_history', True)) + self.provider_insts.append(inst) + if selected_provider_id == provider_config['id']: + self.curr_provider_inst = inst + logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。") + except Exception as e: traceback.print_exc() logger.error(f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}") @@ -68,8 +93,14 @@ class ProviderManager(): if len(self.provider_insts) > 0 and not self.curr_provider_inst: self.curr_provider_inst = self.provider_insts[0] + if len(self.stt_provider_insts) > 0 and not self.curr_stt_provider_inst: + self.curr_stt_provider_inst = self.stt_provider_insts[0] + if not self.curr_provider_inst: - logger.warning("未启用任何提供商适配器。") + logger.warning("未启用任何用于 文本生成 的提供商适配器。") + if self.provider_stt_settings.get("enable"): + if not self.curr_stt_provider_inst: + logger.warning("未启用任何用于 语音转文本 的提供商适配器。") def get_insts(self): return self.provider_insts diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 553dd78dc..30c6b02d0 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -125,6 +125,33 @@ class Provider(abc.ABC): '''重置某一个 session_id 的上下文''' raise NotImplementedError() + def meta(self) -> ProviderMeta: + '''获取 Provider 的元数据''' + return ProviderMeta( + id=self.provider_config['id'], + model=self.get_model(), + type=self.provider_config['type'] + ) + + +class STTProvider(): + def __init__(self, provider_config: dict, provider_settings: dict) -> None: + self.provider_config = provider_config + self.provider_settings = provider_settings + + @abc.abstractmethod + async def get_text(self, audio_url: str) -> str: + '''获取音频的文本''' + raise NotImplementedError() + + def set_model(self, model_name: str): + '''设置当前使用的模型名称''' + self.model_name = model_name + + def get_model(self) -> str: + '''获取当前使用的模型''' + return self.provider_config.get("model", "") + def meta(self) -> ProviderMeta: '''获取 Provider 的元数据''' return ProviderMeta( diff --git a/astrbot/core/provider/register.py b/astrbot/core/provider/register.py index 00c3ad877..61e64408f 100644 --- a/astrbot/core/provider/register.py +++ b/astrbot/core/provider/register.py @@ -1,16 +1,20 @@ from typing import List, Dict, Type -from .entites import ProviderMetaData +from .entites import ProviderMetaData, ProviderType from astrbot.core import logger from .func_tool_manager import FuncCall provider_registry: List[ProviderMetaData] = [] '''维护了通过装饰器注册的 Provider''' -provider_cls_map: Dict[str, Type] = {} -'''维护了 Provider 类型名称和 Provider 类的映射''' +provider_cls_map: Dict[str, ProviderMetaData] = {} +'''维护了 Provider 类型名称和 ProviderMetadata 的映射''' llm_tools = FuncCall() -def register_provider_adapter(provider_type_name: str, desc: str): +def register_provider_adapter( + provider_type_name: str, + desc: str, + provider_type: ProviderType = ProviderType.CHAT_COMPLETION +): '''用于注册平台适配器的带参装饰器''' def decorator(cls): if provider_type_name in provider_cls_map: @@ -19,9 +23,11 @@ def register_provider_adapter(provider_type_name: str, desc: str): pm = ProviderMetaData( type=provider_type_name, desc=desc, + provider_type=provider_type, + cls_type=cls ) provider_registry.append(pm) - provider_cls_map[provider_type_name] = cls + provider_cls_map[provider_type_name] = pm logger.debug(f"Provider {provider_type_name} 已注册") return cls diff --git a/astrbot/core/provider/sources/whisper_api_source.py b/astrbot/core/provider/sources/whisper_api_source.py new file mode 100644 index 000000000..7159bf9d4 --- /dev/null +++ b/astrbot/core/provider/sources/whisper_api_source.py @@ -0,0 +1,95 @@ +import uuid +import os +import io +from openai import AsyncOpenAI, NOT_GIVEN +from ..provider import STTProvider +from ..entites import ProviderType +from astrbot.core.utils.io import download_file +from ..register import register_provider_adapter +from astrbot.core import logger + +@register_provider_adapter("openai_whisper_api", "OpenAI Whisper API", provider_type=ProviderType.SPEECH_TO_TEXT) +class ProviderOpenAIWhisperAPI(STTProvider): + def __init__( + self, + provider_config: dict, + provider_settings: dict, + ) -> None: + super().__init__(provider_config, provider_settings) + self.chosen_api_key = provider_config.get("api_key", "") + + self.client = AsyncOpenAI( + api_key=self.chosen_api_key, + base_url=provider_config.get("api_base", None), + timeout=provider_config.get("timeout", NOT_GIVEN), + ) + + self.set_model(provider_config.get("model", None)) + + async def _convert_audio(self, path: str) -> str: + from pyffmpeg import FFmpeg + filename = str(uuid.uuid4()) + '.mp3' + ff = FFmpeg() + output_path = ff.convert(path, os.path.join('data/temp', filename)) + return output_path + + async def _pcm_to_wav(self, input_io: io.BytesIO, output_path: str) -> str: + import wave + + with wave.open(output_path, 'wb') as wav: + wav.setnchannels(1) + wav.setsampwidth(2) + wav.setframerate(24000) + wav.writeframes(input_io.read()) + + return output_path + + async def _convert_silk(self, path: str) -> str: + import pysilk + filename = str(uuid.uuid4()) + '.wav' + output_path = os.path.join('data/temp', filename) + with open(path, "rb") as f: + input_data = f.read() + if input_data.startswith(b'\x02'): + # tencent 我爱你 + input_data = input_data[1:] + input_io = io.BytesIO(input_data) + output_io = io.BytesIO() + pysilk.decode(input_io, output_io, 24000) + output_io.seek(0) + await self._pcm_to_wav(output_io, output_path) + + return output_path + + async def _is_silk_file(self, file_path): + silk_header = b"SILK" + with open(file_path, "rb") as f: + file_header = f.read(8) + + if silk_header in file_header: + return True + else: + return False + + async def get_text(self, audio_url: str) -> str: + '''only supports mp3, mp4, mpeg, m4a, wav, webm''' + if audio_url.startswith("http"): + name = str(uuid.uuid4()) + path = os.path.join("data/temp", name) + audio_url = await download_file(audio_url, path) + + if not os.path.exists(audio_url): + raise FileNotFoundError(f"文件不存在: {audio_url}") + + if audio_url.endswith(".amr") or audio_url.endswith(".silk"): + is_silk = await self._is_silk_file(audio_url) + if is_silk: + logger.info("Converting silk file to wav ...") + audio_url = await self._convert_silk(audio_url) + + + result = await self.client.audio.transcriptions.create( + model=self.model_name, + file=open(audio_url, "rb"), + ) + return result.text \ No newline at end of file diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 58810de96..b94b50a52 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -17,10 +17,6 @@ from .filter.regex import RegexFilter from typing import Awaitable from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager -class StarCommand(TypedDict): - full_command_name: str - command_name: str - class Context: ''' 暴露给插件的接口上下文。 @@ -168,13 +164,13 @@ class Context: def register_provider(self, provider: Provider): ''' - 注册一个 LLM Provider。 + 注册一个 LLM Provider(Chat_Completion 类型)。 ''' self.provider_manager.provider_insts.append(provider) def get_provider_by_id(self, provider_id: str) -> Provider: ''' - 通过 ID 获取 LLM Provider。 + 通过 ID 获取 LLM Provider(Chat_Completion 类型)。 ''' for provider in self.provider_manager.provider_insts: if provider.meta().id == provider_id: @@ -183,13 +179,13 @@ class Context: def get_all_providers(self) -> List[Provider]: ''' - 获取所有 LLM Provider。 + 获取所有 LLM Provider(Chat_Completion 类型)。 ''' return self.provider_manager.provider_insts def get_using_provider(self) -> Provider: ''' - 获取当前使用的 LLM Provider。 + 获取当前使用的 LLM Provider(Chat_Completion 类型)。 通过 /provider 指令切换。 ''' diff --git a/dashboard/package.json b/dashboard/package.json index 2888d4415..9b59cfa92 100644 --- a/dashboard/package.json +++ b/dashboard/package.json @@ -33,8 +33,6 @@ "vue3-apexcharts": "1.4.4", "vue3-print-nb": "0.1.4", "vuetify": "3.3.14", - "xterm": "^5.3.0", - "xterm-addon-fit": "^0.8.0", "yup": "1.2.0" }, "devDependencies": { diff --git a/requirements.txt b/requirements.txt index 965651870..432d5bb19 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,4 +16,5 @@ aiocqhttp pyjwt apscheduler docstring_parser -aiodocker \ No newline at end of file +aiodocker +silk-python \ No newline at end of file From a09998f910b36e52ee09ffb96730f4b0f39725bb Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sat, 11 Jan 2025 18:54:40 +0800 Subject: [PATCH 2/6] =?UTF-8?q?feat:=20webchat=20=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E8=AF=AD=E9=9F=B3=E8=BE=93=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../sources/webchat/webchat_adapter.py | 10 +- astrbot/dashboard/routes/chat.py | 42 +++++- dashboard/src/views/ChatPage.vue | 124 ++++++++++++++---- 3 files changed, 148 insertions(+), 28 deletions(-) diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py index 87e9d5f7b..e2a438caf 100644 --- a/astrbot/core/platform/sources/webchat/webchat_adapter.py +++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py @@ -5,7 +5,7 @@ import os from typing import Awaitable, Any from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata from astrbot.api.event import MessageChain -from astrbot.api.message_components import Plain, Image # noqa: F403 +from astrbot.api.message_components import Plain, Image, Record # noqa: F403 from astrbot.api import logger from astrbot.core import web_chat_queue, web_chat_back_queue from .webchat_event import WebChatMessageEvent @@ -70,6 +70,14 @@ class WebChatAdapter(Platform): abm.message.append(Image.fromFileSystem(os.path.join(self.imgs_dir, img))) else: abm.message.append(Image.fromFileSystem(os.path.join(self.imgs_dir, payload['image_url']))) + if payload['audio_url']: + if isinstance(payload['audio_url'], list): + for audio in payload['audio_url']: + path = os.path.join(self.imgs_dir, audio) + abm.message.append(Record(file=path, path=path)) + else: + path = os.path.join(self.imgs_dir, payload['audio_url']) + abm.message.append(Record(file=path, path=path)) logger.debug(f"WebChatAdapter: {abm.message}") diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index cef7e563f..02f40d648 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -17,11 +17,14 @@ class ChatRoute(Route): '/chat/get_conversation': ('GET', self.get_conversation), '/chat/delete_conversation': ('GET', self.delete_conversation), '/chat/get_file': ('GET', self.get_file), - '/chat/post_image': ('POST', self.post_image) + '/chat/post_image': ('POST', self.post_image), + '/chat/post_file': ('POST', self.post_file) } self.db = db self.register_routes() self.imgs_dir = "data/webchat/imgs" + + self.supported_imgs = ['jpg', 'jpeg', 'png', 'gif', 'webp'] async def get_file(self): filename = request.args.get('filename') @@ -30,7 +33,13 @@ class ChatRoute(Route): try: with open(os.path.join(self.imgs_dir, filename), "rb") as f: - return QuartResponse(f.read(), mimetype="image/jpeg") + if filename.endswith(".wav"): + return QuartResponse(f.read(), mimetype="audio/wav") + elif filename.split('.')[-1] in self.supported_imgs: + return QuartResponse(f.read(), mimetype="image/jpeg") + else: + return QuartResponse(f.read()) + except FileNotFoundError: return Response().error("File not found").__dict__ @@ -47,6 +56,25 @@ class ChatRoute(Route): return Response().ok(data={ 'filename': filename }).__dict__ + + async def post_file(self): + post_data = await request.files + if 'file' not in post_data: + return Response().error("Missing key: file").__dict__ + + file = post_data['file'] + filename = f"{str(uuid.uuid4())}" + print(file) + # 通过文件格式判断文件类型 + if file.content_type.startswith('audio'): + filename += ".wav" + + path = os.path.join(self.imgs_dir, filename) + await file.save(path) + + return Response().ok(data={ + 'filename': filename + }).__dict__ async def chat(self): username = g.get('username', 'guest') @@ -61,14 +89,16 @@ class ChatRoute(Route): message = post_data['message'] conversation_id = post_data['conversation_id'] image_url = post_data.get('image_url') - if not message and not image_url: - return Response().error("Message and image_url are empty").__dict__ + audio_url = post_data.get('audio_url') + if not message and not image_url and not audio_url: + return Response().error("Message and image_url and audio_url are empty").__dict__ if not conversation_id: return Response().error("conversation_id is empty").__dict__ await web_chat_queue.put((username, conversation_id, { 'message': message, - 'image_url': image_url # list + 'image_url': image_url, # list + 'audio_url': audio_url })) async def stream(): @@ -98,6 +128,8 @@ class ChatRoute(Route): } if image_url: new_his['image_url'] = image_url + if audio_url: + new_his['audio_url'] = audio_url history.append(new_his) for r in ret: history.append({ diff --git a/dashboard/src/views/ChatPage.vue b/dashboard/src/views/ChatPage.vue index f4cb9b0e5..f3c0227a5 100644 --- a/dashboard/src/views/ChatPage.vue +++ b/dashboard/src/views/ChatPage.vue @@ -58,13 +58,21 @@ marked.setOptions({
{{ msg.message }} -
+
+ +
+ +
@@ -79,26 +87,28 @@ marked.setOptions({
- - + style="width: 100%; max-width: 850px;"> -
+
mdi-close-circle
+
+
+ 新录音 + mdi-close-circle +
+ +
@@ -128,7 +146,14 @@ export default { conversations: [], currCid: '', stagedImagesUrl: [], - loadingChat: false + loadingChat: false, + + inputFieldLabel: '聊天吧!', + + isRecording: false, + audioChunks: [], + stagedAudioUrl: "", + mediaRecorder: null } }, @@ -136,10 +161,54 @@ export default { this.getConversations(); let inputField = document.getElementById('input-field'); inputField.addEventListener('paste', this.handlePaste); - }, methods: { + + removeAudio() { + this.stagedAudioUrl = null; + }, + + async startRecording() { + const stream = await navigator.mediaDevices.getUserMedia({ audio: true }); + this.mediaRecorder = new MediaRecorder(stream); + this.mediaRecorder.ondataavailable = (event) => { + this.audioChunks.push(event.data); + }; + this.mediaRecorder.start(); + this.isRecording = true; + this.inputFieldLabel = "录音中,请说话..."; + }, + + async stopRecording() { + this.isRecording = false; + this.inputFieldLabel = "聊天吧!"; + this.mediaRecorder.stop(); + this.mediaRecorder.onstop = async () => { + const audioBlob = new Blob(this.audioChunks, { type: 'audio/wav' }); + this.audioChunks = []; + + const formData = new FormData(); + formData.append('file', audioBlob); + + try { + const response = await axios.post('/api/chat/post_file', formData, { + headers: { + 'Content-Type': 'multipart/form-data', + 'Authorization': 'Bearer ' + localStorage.getItem('token') + } + }); + + const audio = response.data.data.filename; + console.log('Audio uploaded:', audio); + + this.stagedAudioUrl = `/api/chat/get_file?filename=${audio}`; + } catch (err) { + console.error('Error uploading audio:', err); + } + }; + }, + async handlePaste(event) { console.log('Pasting image...'); const items = event.clipboardData.items; @@ -198,6 +267,9 @@ export default { message[i].image_url[j] = `/api/chat/get_file?filename=${message[i].image_url[j]}`; } } + if (message[i].audio_url) { + message[i].audio_url = `/api/chat/get_file?filename=${message[i].audio_url}`; + } } this.messages = message; }).catch(err => { @@ -250,24 +322,26 @@ export default { this.messages.push({ type: 'user', message: this.prompt, - image_url: this.stagedImagesUrl + image_url: this.stagedImagesUrl, + audio_url: this.stagedAudioUrl }); - // let bot_resp = { - // type: 'bot', - // message: ref('') - // } - - // this.messages.push(bot_resp); - this.scrollToBottom(); + // images let image_filenames = []; for (let i = 0; i < this.stagedImagesUrl.length; i++) { let img = this.stagedImagesUrl[i].replace('/api/chat/get_file?filename=', ''); image_filenames.push(img); } + // audio + let audio_filenames = []; + if (this.stagedAudioUrl) { + let audio = this.stagedAudioUrl.replace('/api/chat/get_file?filename=', ''); + audio_filenames.push(audio); + } + this.loadingChat = true; @@ -277,11 +351,17 @@ export default { 'Content-Type': 'application/json', 'Authorization': 'Bearer ' + localStorage.getItem('token') }, - body: JSON.stringify({ message: this.prompt, conversation_id: this.currCid, image_url: image_filenames }) // 发送请求体 + body: JSON.stringify({ + message: this.prompt, + conversation_id: this.currCid, + image_url: image_filenames, + audio_url: audio_filenames + }) // 发送请求体 }) .then(response => { this.prompt = ''; this.stagedImagesUrl = []; + this.stagedAudioUrl = ""; this.loadingChat = false; From f2566c68e358fa1e84c36e782be268e255eeb87f Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sat, 11 Jan 2025 19:07:26 +0800 Subject: [PATCH 3/6] =?UTF-8?q?feat:=20=E6=8C=89=20K=20=E8=AF=AD=E9=9F=B3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dashboard/src/views/ChatPage.vue | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/dashboard/src/views/ChatPage.vue b/dashboard/src/views/ChatPage.vue index f3c0227a5..e921080fc 100644 --- a/dashboard/src/views/ChatPage.vue +++ b/dashboard/src/views/ChatPage.vue @@ -49,6 +49,12 @@ marked.setOptions({ style="background-color: #eee; padding-left: 4px; padding-right: 4px; margin: 2px; border-radius: 4px;">/help 获取帮助 😊
+
+ + K + 开始语音 🎤 +
@@ -90,15 +96,21 @@ marked.setOptions({