From c0846bc7894b1df49510ba2ee562fb8a8f8d8cbe Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sat, 17 Jan 2026 14:41:05 +0800 Subject: [PATCH] feat: astr live --- astrbot/core/astr_agent_run_util.py | 175 +++++- .../method/agent_sub_stages/internal.py | 48 +- .../sources/webchat/webchat_adapter.py | 1 + .../platform/sources/webchat/webchat_event.py | 14 + astrbot/core/provider/provider.py | 54 ++ astrbot/dashboard/routes/live_chat.py | 350 ++++++++++++ astrbot/dashboard/server.py | 2 + dashboard/index.html | 3 + dashboard/src/components/chat/Chat.vue | 201 ++++--- dashboard/src/components/chat/ChatInput.vue | 25 +- dashboard/src/components/chat/LiveMode.vue | 518 ++++++++++++++++++ dashboard/src/components/chat/LiveOrb.vue | 248 +++++++++ .../src/components/chat/StandaloneChat.vue | 1 + dashboard/src/composables/useVADRecording.ts | 163 ++++++ .../src/i18n/locales/en-US/features/chat.json | 6 +- .../src/i18n/locales/zh-CN/features/chat.json | 6 +- pyproject.toml | 4 + 17 files changed, 1721 insertions(+), 98 deletions(-) create mode 100644 astrbot/dashboard/routes/live_chat.py create mode 100644 dashboard/src/components/chat/LiveMode.vue create mode 100644 dashboard/src/components/chat/LiveOrb.vue create mode 100644 dashboard/src/composables/useVADRecording.ts diff --git a/astrbot/core/astr_agent_run_util.py b/astrbot/core/astr_agent_run_util.py index d57cf5e93..c9b0ea04c 100644 --- a/astrbot/core/astr_agent_run_util.py +++ b/astrbot/core/astr_agent_run_util.py @@ -1,3 +1,4 @@ +import asyncio import traceback from collections.abc import AsyncGenerator @@ -5,7 +6,7 @@ from astrbot.core import logger from astrbot.core.agent.message import Message from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner from astrbot.core.astr_agent_context import AstrAgentContext -from astrbot.core.message.components import Json +from astrbot.core.message.components import Json, Plain from astrbot.core.message.message_event_result import ( MessageChain, MessageEventResult, @@ -131,3 +132,175 @@ async def run_agent( else: astr_event.set_result(MessageEventResult().message(err_msg)) return + + +async def run_live_agent( + agent_runner: AgentRunner, + tts_provider, + max_step: int = 30, + show_tool_use: bool = True, + show_reasoning: bool = False, +) -> AsyncGenerator[MessageChain | None, None]: + """Live Mode 的 Agent 运行器,支持流式 TTS + + Args: + agent_runner: Agent 运行器 + tts_provider: TTS Provider 实例 + max_step: 最大步数 + show_tool_use: 是否显示工具使用 + show_reasoning: 是否显示推理过程 + + Yields: + MessageChain: 包含文本或音频数据的消息链 + """ + support_stream = tts_provider.support_stream() if tts_provider else False + + if support_stream: + logger.info("[Live Agent] 使用流式 TTS(原生支持 get_audio_stream)") + elif tts_provider: + logger.info( + f"[Live Agent] 使用 TTS({tts_provider.meta().type} " + "使用 get_audio,将累积完整文本后生成音频)" + ) + + # 收集 LLM 输出 + llm_stream_chunks: list[MessageChain] = [] + + # 运行普通 agent + async for chain in run_agent( + agent_runner, + max_step=max_step, + show_tool_use=show_tool_use, + stream_to_general=False, + show_reasoning=show_reasoning, + ): + if chain is not None: + llm_stream_chunks.append(chain) + + # 如果没有 TTS Provider,直接发送文本 + if not tts_provider: + for chain in llm_stream_chunks: + yield chain + return + + # 处理 TTS + if support_stream: + # 使用流式 TTS + async for audio_chunk in _process_stream_tts(llm_stream_chunks, tts_provider): + yield audio_chunk + else: + # 使用完整音频 TTS + async for audio_chunk in _process_full_tts(llm_stream_chunks, tts_provider): + yield audio_chunk + + +async def _process_stream_tts(chunks: list[MessageChain], tts_provider): + """处理流式 TTS""" + text_queue: asyncio.Queue[str | None] = asyncio.Queue() + audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue() + + # 启动 TTS 处理任务 + tts_task = asyncio.create_task( + tts_provider.get_audio_stream(text_queue, audio_queue) + ) + + chunk_size = 50 # 每 50 个字符发送一次给 TTS + + try: + # 喂文本给 TTS + feed_task = asyncio.create_task( + _feed_text_to_tts(chunks, text_queue, chunk_size) + ) + + # 从 TTS 输出队列中读取音频数据 + while True: + audio_data = await audio_queue.get() + + if audio_data is None: + break + + # 将音频数据封装为 MessageChain + import base64 + + audio_b64 = base64.b64encode(audio_data).decode("utf-8") + + chain = MessageChain(chain=[Plain(audio_b64)], type="audio_chunk") + yield chain + + await feed_task + + except Exception as e: + logger.error(f"[Live TTS] 流式处理失败: {e}", exc_info=True) + await text_queue.put(None) + + finally: + try: + await asyncio.wait_for(tts_task, timeout=5.0) + except asyncio.TimeoutError: + logger.warning("[Live TTS] TTS 任务超时,强制取消") + tts_task.cancel() + + +async def _feed_text_to_tts( + chunks: list[MessageChain], text_queue: asyncio.Queue, chunk_size: int +): + """从消息链中提取文本并分块发送给 TTS""" + accumulated_text = "" + + try: + for chain in chunks: + text = chain.get_plain_text() + if not text: + continue + + accumulated_text += text + + # 当累积的文本达到chunk_size时,发送给TTS + while len(accumulated_text) >= chunk_size: + chunk = accumulated_text[:chunk_size] + await text_queue.put(chunk) + accumulated_text = accumulated_text[chunk_size:] + + # 处理剩余文本 + if accumulated_text: + await text_queue.put(accumulated_text) + + finally: + # 发送结束标记 + await text_queue.put(None) + + +async def _process_full_tts(chunks: list[MessageChain], tts_provider): + """处理完整音频 TTS""" + accumulated_text = "" + + try: + # 累积所有文本 + for chain in chunks: + text = chain.get_plain_text() + if text: + accumulated_text += text + + # 如果没有文本,直接返回 + if not accumulated_text: + return + + logger.info(f"[Live TTS] 累积完整文本,长度: {len(accumulated_text)}") + + # 调用 get_audio 生成完整音频 + audio_path = await tts_provider.get_audio(accumulated_text) + + # 读取音频文件 + with open(audio_path, "rb") as f: + audio_data = f.read() + + # 将音频数据封装为 MessageChain + import base64 + + audio_b64 = base64.b64encode(audio_data).decode("utf-8") + + chain = MessageChain(chain=[Plain(audio_b64)], type="audio_chunk") + yield chain + + except Exception as e: + logger.error(f"[Live TTS] 完整音频生成失败: {e}", exc_info=True) diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index 43d88c5ad..2c6583fb3 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -31,7 +31,7 @@ from astrbot.core.utils.session_lock import session_lock_manager from .....astr_agent_context import AgentContextWrapper from .....astr_agent_hooks import MAIN_AGENT_HOOKS -from .....astr_agent_run_util import AgentRunner, run_agent +from .....astr_agent_run_util import AgentRunner, run_agent, run_live_agent from .....astr_agent_tool_exec import FunctionToolExecutor from ....context import PipelineContext, call_event_hook from ...stage import Stage @@ -684,7 +684,51 @@ class InternalAgentSubStage(Stage): enforce_max_turns=self.max_context_length, ) - if streaming_response and not stream_to_general: + # 检测 Live Mode + action_type = event.get_extra("action_type") + if action_type == "live": + # Live Mode: 使用 run_live_agent + logger.info("[Internal Agent] 检测到 Live Mode,启用 TTS 处理") + + # 获取 TTS Provider + tts_provider = ( + self.ctx.plugin_manager.context.get_using_tts_provider( + event.unified_msg_origin + ) + ) + + if not tts_provider: + logger.warning( + "[Live Mode] TTS Provider 未配置,将使用普通流式模式" + ) + + # 使用 run_live_agent,总是使用流式响应 + event.set_result( + MessageEventResult() + .set_result_content_type(ResultContentType.STREAMING_RESULT) + .set_async_stream( + run_live_agent( + agent_runner, + tts_provider, + self.max_step, + self.show_tool_use, + show_reasoning=self.show_reasoning, + ), + ), + ) + yield + + # 保存历史记录 + if not event.is_stopped() and agent_runner.done(): + await self._save_to_history( + event, + req, + agent_runner.get_final_llm_resp(), + agent_runner.run_context.messages, + agent_runner.stats, + ) + + elif streaming_response and not stream_to_general: # 流式响应 event.set_result( MessageEventResult() diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py index e799e396e..36a451fbd 100644 --- a/astrbot/core/platform/sources/webchat/webchat_adapter.py +++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py @@ -235,6 +235,7 @@ class WebChatAdapter(Platform): message_event.set_extra( "enable_streaming", payload.get("enable_streaming", True) ) + message_event.set_extra("action_type", payload.get("action_type")) self.commit_event(message_event) diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index 7d1c966e4..d62559b8a 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -128,6 +128,20 @@ class WebChatMessageEvent(AstrMessageEvent): web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid) message_id = self.message_obj.message_id async for chain in generator: + # 处理音频流(Live Mode) + if chain.type == "audio_chunk": + # 音频流数据,直接发送 + audio_b64 = chain.get_plain_text() + await web_chat_back_queue.put( + { + "type": "audio_chunk", + "data": audio_b64, + "streaming": True, + "message_id": message_id, + }, + ) + continue + # if chain.type == "break" and final_data: # # 分割符 # await web_chat_back_queue.put( diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 6fb6d8953..04f567805 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -221,11 +221,65 @@ class TTSProvider(AbstractProvider): self.provider_config = provider_config self.provider_settings = provider_settings + def support_stream(self) -> bool: + """是否支持流式 TTS + + Returns: + bool: True 表示支持流式处理,False 表示不支持(默认) + + Notes: + 子类可以重写此方法返回 True 来启用流式 TTS 支持 + """ + return False + @abc.abstractmethod async def get_audio(self, text: str) -> str: """获取文本的音频,返回音频文件路径""" raise NotImplementedError + async def get_audio_stream( + self, + text_queue: asyncio.Queue[str | None], + audio_queue: asyncio.Queue[bytes | None], + ) -> None: + """流式 TTS 处理方法。 + + 从 text_queue 中读取文本片段,将生成的音频数据(WAV 格式的 in-memory bytes)放入 audio_queue。 + 当 text_queue 收到 None 时,表示文本输入结束,此时应该处理完所有剩余文本并向 audio_queue 发送 None 表示结束。 + + Args: + text_queue: 输入文本队列,None 表示输入结束 + audio_queue: 输出音频队列(bytes),None 表示输出结束 + + Notes: + - 默认实现会将文本累积后一次性调用 get_audio 生成完整音频 + - 子类可以重写此方法实现真正的流式 TTS + - 音频数据应该是 WAV 格式的 bytes + """ + accumulated_text = "" + + while True: + text_part = await text_queue.get() + + if text_part is None: + # 输入结束,处理累积的文本 + if accumulated_text: + try: + # 调用原有的 get_audio 方法获取音频文件路径 + audio_path = await self.get_audio(accumulated_text) + # 读取音频文件内容 + with open(audio_path, "rb") as f: + audio_data = f.read() + await audio_queue.put(audio_data) + except Exception: + # 出错时也要发送 None 结束标记 + pass + # 发送结束标记 + await audio_queue.put(None) + break + + accumulated_text += text_part + async def test(self): await self.get_audio("hi") diff --git a/astrbot/dashboard/routes/live_chat.py b/astrbot/dashboard/routes/live_chat.py new file mode 100644 index 000000000..db1f51e14 --- /dev/null +++ b/astrbot/dashboard/routes/live_chat.py @@ -0,0 +1,350 @@ +import asyncio +import os +import time +import uuid +import wave +from typing import Any + +import jwt +from quart import websocket + +from astrbot import logger +from astrbot.core.core_lifecycle import AstrBotCoreLifecycle +from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr +from astrbot.core.utils.astrbot_path import get_astrbot_data_path + +from .route import Route, RouteContext + + +class LiveChatSession: + """Live Chat 会话管理器""" + + def __init__(self, session_id: str, username: str): + self.session_id = session_id + self.username = username + self.conversation_id = str(uuid.uuid4()) + self.is_speaking = False + self.is_processing = False + self.should_interrupt = False + self.audio_frames: list[bytes] = [] + self.current_stamp: str | None = None + self.temp_audio_path: str | None = None + + def start_speaking(self, stamp: str): + """开始说话""" + self.is_speaking = True + self.current_stamp = stamp + self.audio_frames = [] + logger.debug(f"[Live Chat] {self.username} 开始说话 stamp={stamp}") + + def add_audio_frame(self, data: bytes): + """添加音频帧""" + if self.is_speaking: + self.audio_frames.append(data) + + async def end_speaking(self, stamp: str) -> str | None: + """结束说话,返回组装的 WAV 文件路径""" + if not self.is_speaking or stamp != self.current_stamp: + logger.warning( + f"[Live Chat] stamp 不匹配或未在说话状态: {stamp} vs {self.current_stamp}" + ) + return None + + self.is_speaking = False + + if not self.audio_frames: + logger.warning("[Live Chat] 没有音频帧数据") + return None + + # 组装 WAV 文件 + try: + temp_dir = os.path.join(get_astrbot_data_path(), "temp") + os.makedirs(temp_dir, exist_ok=True) + audio_path = os.path.join(temp_dir, f"live_audio_{uuid.uuid4()}.wav") + + # 假设前端发送的是 PCM 数据,采样率 16000Hz,单声道,16位 + with wave.open(audio_path, "wb") as wav_file: + wav_file.setnchannels(1) # 单声道 + wav_file.setsampwidth(2) # 16位 = 2字节 + wav_file.setframerate(16000) # 采样率 16000Hz + for frame in self.audio_frames: + wav_file.writeframes(frame) + + self.temp_audio_path = audio_path + logger.info( + f"[Live Chat] 音频文件已保存: {audio_path}, 大小: {os.path.getsize(audio_path)} bytes" + ) + return audio_path + + except Exception as e: + logger.error(f"[Live Chat] 组装 WAV 文件失败: {e}", exc_info=True) + return None + + def cleanup(self): + """清理临时文件""" + if self.temp_audio_path and os.path.exists(self.temp_audio_path): + try: + os.remove(self.temp_audio_path) + logger.debug(f"[Live Chat] 已删除临时文件: {self.temp_audio_path}") + except Exception as e: + logger.warning(f"[Live Chat] 删除临时文件失败: {e}") + self.temp_audio_path = None + + +class LiveChatRoute(Route): + """Live Chat WebSocket 路由""" + + def __init__( + self, + context: RouteContext, + db: Any, + core_lifecycle: AstrBotCoreLifecycle, + ) -> None: + super().__init__(context) + self.core_lifecycle = core_lifecycle + self.db = db + self.plugin_manager = core_lifecycle.plugin_manager + self.sessions: dict[str, LiveChatSession] = {} + + # 注册 WebSocket 路由 + self.app.websocket("/api/live_chat/ws")(self.live_chat_ws) + + async def live_chat_ws(self): + """Live Chat WebSocket 处理器""" + # WebSocket 不能通过 header 传递 token,需要从 query 参数获取 + # 注意:WebSocket 上下文使用 websocket.args 而不是 request.args + token = websocket.args.get("token") + if not token: + await websocket.close(1008, "Missing authentication token") + return + + try: + jwt_secret = self.config["dashboard"].get("jwt_secret") + payload = jwt.decode(token, jwt_secret, algorithms=["HS256"]) + username = payload["username"] + except jwt.ExpiredSignatureError: + await websocket.close(1008, "Token expired") + return + except jwt.InvalidTokenError: + await websocket.close(1008, "Invalid token") + return + + session_id = f"webchat_live!{username}!{uuid.uuid4()}" + live_session = LiveChatSession(session_id, username) + self.sessions[session_id] = live_session + + logger.info(f"[Live Chat] WebSocket 连接建立: {username}") + + try: + while True: + message = await websocket.receive_json() + await self._handle_message(live_session, message) + + except Exception as e: + logger.error(f"[Live Chat] WebSocket 错误: {e}", exc_info=True) + + finally: + # 清理会话 + if session_id in self.sessions: + live_session.cleanup() + del self.sessions[session_id] + logger.info(f"[Live Chat] WebSocket 连接关闭: {username}") + + async def _handle_message(self, session: LiveChatSession, message: dict): + """处理 WebSocket 消息""" + msg_type = message.get("t") # 使用 t 代替 type + + if msg_type == "start_speaking": + # 开始说话 + stamp = message.get("stamp") + if not stamp: + logger.warning("[Live Chat] start_speaking 缺少 stamp") + return + session.start_speaking(stamp) + + elif msg_type == "speaking_part": + # 音频片段 + audio_data_b64 = message.get("data") + if not audio_data_b64: + return + + # 解码 base64 + import base64 + + try: + audio_data = base64.b64decode(audio_data_b64) + session.add_audio_frame(audio_data) + except Exception as e: + logger.error(f"[Live Chat] 解码音频数据失败: {e}") + + elif msg_type == "end_speaking": + # 结束说话 + stamp = message.get("stamp") + if not stamp: + logger.warning("[Live Chat] end_speaking 缺少 stamp") + return + + audio_path = await session.end_speaking(stamp) + if not audio_path: + await websocket.send_json({"t": "error", "data": "音频组装失败"}) + return + + # 处理音频:STT -> LLM -> TTS + await self._process_audio(session, audio_path) + + elif msg_type == "interrupt": + # 用户打断 + session.should_interrupt = True + logger.info(f"[Live Chat] 用户打断: {session.username}") + + async def _process_audio(self, session: LiveChatSession, audio_path: str): + """处理音频:STT -> LLM -> 流式 TTS""" + try: + session.is_processing = True + session.should_interrupt = False + + # 1. STT - 语音转文字 + ctx = self.plugin_manager.context + stt_provider = ctx.provider_manager.stt_provider_insts[0] + + if not stt_provider: + logger.error("[Live Chat] STT Provider 未配置") + await websocket.send_json({"t": "error", "data": "语音识别服务未配置"}) + return + + user_text = await stt_provider.get_text(audio_path) + if not user_text: + logger.warning("[Live Chat] STT 识别结果为空") + return + + logger.info(f"[Live Chat] STT 结果: {user_text}") + + # 发送用户消息 + import time + + await websocket.send_json( + { + "t": "user_msg", + "data": {"text": user_text, "ts": int(time.time() * 1000)}, + } + ) + + # 2. 构造消息事件并发送到 pipeline + # 使用 webchat queue 机制 + cid = session.conversation_id + queue = webchat_queue_mgr.get_or_create_queue(cid) + + message_id = str(uuid.uuid4()) + payload = { + "message_id": message_id, + "message": [{"type": "plain", "text": user_text}], # 直接发送文本 + "action_type": "live", # 标记为 live mode + } + + # 将消息放入队列 + await queue.put((session.username, cid, payload)) + + # 3. 等待响应并流式发送 TTS 音频 + back_queue = webchat_queue_mgr.get_or_create_back_queue(cid) + + bot_text = "" + audio_playing = False + + while True: + if session.should_interrupt: + # 用户打断,停止处理 + logger.info("[Live Chat] 检测到用户打断") + await websocket.send_json({"t": "stop_play"}) + # 保存消息并标记为被打断 + await self._save_interrupted_message(session, user_text, bot_text) + # 清空队列中未处理的消息 + while not back_queue.empty(): + try: + back_queue.get_nowait() + except asyncio.QueueEmpty: + break + break + + try: + result = await asyncio.wait_for(back_queue.get(), timeout=0.5) + except asyncio.TimeoutError: + continue + + if not result: + continue + + result_message_id = result.get("message_id") + if result_message_id != message_id: + logger.warning( + f"[Live Chat] 消息 ID 不匹配: {result_message_id} != {message_id}" + ) + continue + + result_type = result.get("type") + data = result.get("data", "") + + if result_type == "plain": + # 普通文本消息 + bot_text += data + + elif result_type == "audio_chunk": + # 流式音频数据 + if not audio_playing: + audio_playing = True + logger.debug("[Live Chat] 开始播放音频流") + + # 发送音频数据给前端 + await websocket.send_json( + { + "t": "response", + "data": data, # base64 编码的音频数据 + } + ) + + elif result_type in ["complete", "end"]: + # 处理完成 + logger.info(f"[Live Chat] Bot 回复完成: {bot_text}") + + # 如果没有音频流,发送 bot 消息文本 + if not audio_playing: + await websocket.send_json( + { + "t": "bot_msg", + "data": { + "text": bot_text, + "ts": int(time.time() * 1000), + }, + } + ) + + # 发送结束标记 + await websocket.send_json({"t": "end"}) + break + + except Exception as e: + logger.error(f"[Live Chat] 处理音频失败: {e}", exc_info=True) + await websocket.send_json({"t": "error", "data": f"处理失败: {str(e)}"}) + + finally: + session.is_processing = False + session.should_interrupt = False + + async def _save_interrupted_message( + self, session: LiveChatSession, user_text: str, bot_text: str + ): + """保存被打断的消息""" + interrupted_text = bot_text + " [用户打断]" + logger.info(f"[Live Chat] 保存打断消息: {interrupted_text}") + + # 简单记录到日志,实际保存逻辑可以后续完善 + try: + timestamp = int(time.time() * 1000) + logger.info( + f"[Live Chat] 用户消息: {user_text} (session: {session.session_id}, ts: {timestamp})" + ) + if bot_text: + logger.info( + f"[Live Chat] Bot 消息(打断): {interrupted_text} (session: {session.session_id}, ts: {timestamp})" + ) + except Exception as e: + logger.error(f"[Live Chat] 记录消息失败: {e}", exc_info=True) diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index afac7fedb..0afee6037 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -20,6 +20,7 @@ from astrbot.core.utils.io import get_local_ip_addresses from .routes import * from .routes.backup import BackupRoute +from .routes.live_chat import LiveChatRoute from .routes.platform import PlatformRoute from .routes.route import Response, RouteContext from .routes.session_management import SessionManagementRoute @@ -88,6 +89,7 @@ class AstrBotDashboard: self.kb_route = KnowledgeBaseRoute(self.context, core_lifecycle) self.platform_route = PlatformRoute(self.context, core_lifecycle) self.backup_route = BackupRoute(self.context, db, core_lifecycle) + self.live_chat_route = LiveChatRoute(self.context, db, core_lifecycle) self.app.add_url_rule( "/api/plug/", diff --git a/dashboard/index.html b/dashboard/index.html index 367bec27b..d016f8748 100644 --- a/dashboard/index.html +++ b/dashboard/index.html @@ -10,6 +10,9 @@ rel="stylesheet" href="https://fonts.googleapis.com/css2?family=Outfit&family=Poppins:wght@400;500;600;700&family=Roboto:wght@400;500;700&display=swap" /> + + + AstrBot - 仪表盘 diff --git a/dashboard/src/components/chat/Chat.vue b/dashboard/src/components/chat/Chat.vue index a2c85b946..adc8d65ee 100644 --- a/dashboard/src/components/chat/Chat.vue +++ b/dashboard/src/components/chat/Chat.vue @@ -30,44 +30,105 @@
+ + -
- - - mdi-menu - -
- - - + @@ -202,13 +211,14 @@ import ProjectDialog from '@/components/chat/ProjectDialog.vue'; import ProjectView from '@/components/chat/ProjectView.vue'; import WelcomeView from '@/components/chat/WelcomeView.vue'; import RefsSidebar from '@/components/chat/message_list_comps/RefsSidebar.vue'; +import LiveMode from '@/components/chat/LiveMode.vue'; import type { ProjectFormData } from '@/components/chat/ProjectDialog.vue'; import { useSessions } from '@/composables/useSessions'; import { useMessages } from '@/composables/useMessages'; import { useMediaHandling } from '@/composables/useMediaHandling'; -import { useRecording } from '@/composables/useRecording'; import { useProjects } from '@/composables/useProjects'; import type { Project } from '@/components/chat/ProjectList.vue'; +import { useRecording } from '@/composables/useRecording'; interface Props { chatboxMode?: boolean; @@ -230,6 +240,7 @@ const mobileMenuOpen = ref(false); const imagePreviewDialog = ref(false); const previewImageUrl = ref(''); const isLoadingMessages = ref(false); +const liveModeOpen = ref(false); // 使用 composables const { @@ -266,7 +277,7 @@ const { cleanupMediaCache } = useMediaHandling(); -const { isRecording, startRecording: startRec, stopRecording: stopRec } = useRecording(); +const { isRecording: isRecording, startRecording: startRec, stopRecording: stopRec } = useRecording(); const { projects, @@ -551,6 +562,14 @@ async function handleFileSelect(files: FileList) { } } +function openLiveMode() { + liveModeOpen.value = true; +} + +function closeLiveMode() { + liveModeOpen.value = false; +} + async function handleSendMessage() { // 只有引用不能发送,必须有输入内容 if (!prompt.value.trim() && stagedFiles.value.length === 0 && !stagedAudioUrl.value) { diff --git a/dashboard/src/components/chat/ChatInput.vue b/dashboard/src/components/chat/ChatInput.vue index b28e1edc1..740b15ffc 100644 --- a/dashboard/src/components/chat/ChatInput.vue +++ b/dashboard/src/components/chat/ChatInput.vue @@ -85,9 +85,29 @@ + + + + {{ tm('voice.liveMode') }} + + + icon + variant="text" + :color="isRecording ? 'error' : 'deep-purple'" + class="record-btn" + size="small" + > + + + {{ isRecording ? tm('voice.speaking') : tm('voice.startRecording') }} + +
@@ -179,6 +199,7 @@ const emit = defineEmits<{ pasteImage: [event: ClipboardEvent]; fileSelect: [files: FileList]; clearReply: []; + openLiveMode: []; }>(); const { tm } = useModuleI18n('features/chat'); diff --git a/dashboard/src/components/chat/LiveMode.vue b/dashboard/src/components/chat/LiveMode.vue new file mode 100644 index 000000000..737f05742 --- /dev/null +++ b/dashboard/src/components/chat/LiveMode.vue @@ -0,0 +1,518 @@ + + + + + diff --git a/dashboard/src/components/chat/LiveOrb.vue b/dashboard/src/components/chat/LiveOrb.vue new file mode 100644 index 000000000..7ca851eb7 --- /dev/null +++ b/dashboard/src/components/chat/LiveOrb.vue @@ -0,0 +1,248 @@ + + + + + diff --git a/dashboard/src/components/chat/StandaloneChat.vue b/dashboard/src/components/chat/StandaloneChat.vue index 2dcc8aeb8..25ca7faf9 100644 --- a/dashboard/src/components/chat/StandaloneChat.vue +++ b/dashboard/src/components/chat/StandaloneChat.vue @@ -36,6 +36,7 @@ @stopRecording="handleStopRecording" @pasteImage="handlePaste" @fileSelect="handleFileSelect" + @openLiveMode="" ref="chatInputRef" /> diff --git a/dashboard/src/composables/useVADRecording.ts b/dashboard/src/composables/useVADRecording.ts new file mode 100644 index 000000000..7a7998c68 --- /dev/null +++ b/dashboard/src/composables/useVADRecording.ts @@ -0,0 +1,163 @@ +import { ref, onBeforeUnmount } from 'vue'; +import axios from 'axios'; + +interface VADOptions { + onSpeechStart?: () => void; + onSpeechRealStart?: () => void; + onSpeechEnd: (audio: Float32Array) => void; + onVADMisfire?: () => void; + onFrameProcessed?: (probabilities: { isSpeech: number; notSpeech: number }, frame: Float32Array) => void; + positiveSpeechThreshold?: number; + negativeSpeechThreshold?: number; + redemptionMs?: number; + preSpeechPadMs?: number; + minSpeechMs?: number; + submitUserSpeechOnPause?: boolean; + model?: 'v5' | 'legacy'; + baseAssetPath?: string; + onnxWASMBasePath?: string; +} + +interface VADInstance { + start(): void; + pause(): void; + listening: boolean; +} + +// 声明全局 vad 对象类型 +declare global { + interface Window { + vad: { + MicVAD: { + new(options: VADOptions): Promise; + }; + }; + } +} + +/** + * 使用 VAD (Voice Activity Detection) 进行录音的 composable + * VAD 会自动检测用户何时开始和停止说话,无需手动控制 + */ +export function useVADRecording() { + const isRecording = ref(false); + const isSpeaking = ref(false); + const audioEnergy = ref(0); // 0-1 之间的能量值 + const vadInstance = ref(null); + const isInitialized = ref(false); + const onSpeechStartCallback = ref<(() => void) | null>(null); + const onSpeechEndCallback = ref<((audio: Float32Array) => void) | null>(null); + + // Live Mode 不需要上传音频,直接通过 WebSocket 实时发送 + + // 初始化 VAD + async function initVAD() { + if (!window.vad) { + console.error('VAD library not loaded. Please ensure the scripts are included in index.html'); + return; + } + + try { + vadInstance.value = await (window.vad.MicVAD as any).new({ + onSpeechStart: () => { + console.log('[VAD] Speech started'); + isSpeaking.value = true; + // 调用开始说话回调 + if (onSpeechStartCallback.value) { + onSpeechStartCallback.value(); + } + }, + onSpeechRealStart: () => { + console.log('[VAD] Real speech started'); + }, + onSpeechEnd: (audio: Float32Array) => { + console.log('[VAD] Speech ended, audio length:', audio.length); + isSpeaking.value = false; + // 调用语音结束回调,传递原始音频数据 + if (onSpeechEndCallback.value) { + onSpeechEndCallback.value(audio); + } + }, + onVADMisfire: () => { + console.log('[VAD] VAD misfire - speech segment too short'); + isSpeaking.value = false; + }, + onFrameProcessed: (probabilities: { isSpeech: number; notSpeech: number }, frame: Float32Array) => { + // 计算 RMS (Root Mean Square) 作为能量 + let sum = 0; + for (let i = 0; i < frame.length; i++) { + sum += frame[i] * frame[i]; + } + const rms = Math.sqrt(sum / frame.length); + // 简单的归一化及平滑处理,根据经验 RMS 通常较小 + // 放大系数可以根据实际情况调整 + const targetEnergy = Math.min(rms * 5, 1); + audioEnergy.value = audioEnergy.value * 0.8 + targetEnergy * 0.2; + }, + // VAD 配置参数 + positiveSpeechThreshold: 0.3, + negativeSpeechThreshold: 0.25, + redemptionMs: 1400, + preSpeechPadMs: 800, + minSpeechMs: 400, + submitUserSpeechOnPause: false, + model: 'v5', + baseAssetPath: 'https://cdn.jsdelivr.net/npm/@ricky0123/vad-web@0.0.29/dist/', + onnxWASMBasePath: 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.22.0/dist/' + }); + + isInitialized.value = true; + console.log('VAD initialized successfully'); + } catch (error) { + console.error('Failed to initialize VAD:', error); + isInitialized.value = false; + } + } + + // 开始录音(启动 VAD) + async function startRecording( + onSpeechStart: () => void, + onSpeechEnd: (audio: Float32Array) => void + ) { + // 存储回调函数 + onSpeechStartCallback.value = onSpeechStart; + onSpeechEndCallback.value = onSpeechEnd; + + if (!isInitialized.value) { + await initVAD(); + } + + if (vadInstance.value) { + vadInstance.value.start(); + isRecording.value = true; + console.log('[VAD] Started'); + } + } + + // 停止录音(暂停 VAD) + function stopRecording() { + if (vadInstance.value) { + vadInstance.value.pause(); + isRecording.value = false; + isSpeaking.value = false; + onSpeechStartCallback.value = null; + onSpeechEndCallback.value = null; + console.log('[VAD] Stopped'); + } + } + + // 清理资源 + onBeforeUnmount(() => { + if (vadInstance.value && isRecording.value) { + stopRecording(); + } + }); + + return { + isRecording, + isSpeaking, // 用户是否正在说话 + audioEnergy, // 当前音频能量 + startRecording, + stopRecording + }; +} diff --git a/dashboard/src/i18n/locales/en-US/features/chat.json b/dashboard/src/i18n/locales/en-US/features/chat.json index cb1695978..684afe23e 100644 --- a/dashboard/src/i18n/locales/en-US/features/chat.json +++ b/dashboard/src/i18n/locales/en-US/features/chat.json @@ -22,7 +22,11 @@ "stop": "Stop Recording", "recording": "New Recording", "processing": "Processing...", - "error": "Recording Failed" + "error": "Recording Failed", + "listening": "Listening...", + "speaking": "Speaking", + "startRecording": "Start Voice Input", + "liveMode": "Live Mode" }, "welcome": { "title": "Welcome to AstrBot", diff --git a/dashboard/src/i18n/locales/zh-CN/features/chat.json b/dashboard/src/i18n/locales/zh-CN/features/chat.json index c08e6ccd6..96c0931ce 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/chat.json +++ b/dashboard/src/i18n/locales/zh-CN/features/chat.json @@ -22,7 +22,11 @@ "stop": "停止录音", "recording": "新录音", "processing": "处理中...", - "error": "录音失败" + "error": "录音失败", + "listening": "等待语音...", + "speaking": "正在说话", + "startRecording": "开始语音输入", + "liveMode": "实时对话" }, "welcome": { "title": "欢迎使用 AstrBot", diff --git a/pyproject.toml b/pyproject.toml index 1fa8e056c..f0e05c634 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,10 @@ dependencies = [ "xinference-client", "tenacity>=9.1.2", "shipyard-python-sdk>=0.2.4", + "funasr-onnx>=0.4.1", + "modelscope>=1.33.0", + "funasr>=1.3.0", + "torchaudio>=2.9.1", ] [dependency-groups]