From 856d3496fa567aa812011e02ef3a0b0aaff3fff6 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sat, 17 Jan 2026 15:35:02 +0800 Subject: [PATCH] feat: enhance audio processing and metrics display in live mode --- astrbot/core/astr_agent_run_util.py | 277 +++++++++++---------- dashboard/src/components/chat/LiveMode.vue | 152 ++++++++--- 2 files changed, 273 insertions(+), 156 deletions(-) diff --git a/astrbot/core/astr_agent_run_util.py b/astrbot/core/astr_agent_run_util.py index f5962e622..9d7301516 100644 --- a/astrbot/core/astr_agent_run_util.py +++ b/astrbot/core/astr_agent_run_util.py @@ -1,4 +1,5 @@ import asyncio +import re import time import traceback from collections.abc import AsyncGenerator @@ -155,55 +156,85 @@ async def run_live_agent( Yields: MessageChain: 包含文本或音频数据的消息链 """ - support_stream = tts_provider.support_stream() if tts_provider else False - - if support_stream: - logger.info("[Live Agent] 使用流式 TTS(原生支持 get_audio_stream)") - elif tts_provider: - logger.info( - f"[Live Agent] 使用 TTS({tts_provider.meta().type} " - "使用 get_audio,将累积完整文本后生成音频)" - ) - - # 收集 LLM 输出 - llm_stream_chunks: list[MessageChain] = [] - - # 运行普通 agent - async for chain in run_agent( - agent_runner, - max_step=max_step, - show_tool_use=show_tool_use, - stream_to_general=False, - show_reasoning=show_reasoning, - ): - if chain is not None: - llm_stream_chunks.append(chain) - # 如果没有 TTS Provider,直接发送文本 if not tts_provider: - for chain in llm_stream_chunks: + async for chain in run_agent( + agent_runner, + max_step=max_step, + show_tool_use=show_tool_use, + stream_to_general=False, + show_reasoning=show_reasoning, + ): yield chain return - # 处理 TTS + support_stream = tts_provider.support_stream() + if support_stream: + logger.info("[Live Agent] 使用流式 TTS(原生支持 get_audio_stream)") + else: + logger.info( + f"[Live Agent] 使用 TTS({tts_provider.meta().type} " + "使用 get_audio,将按句子分块生成音频)" + ) + + # 统计数据初始化 tts_start_time = time.time() tts_first_frame_time = 0.0 first_chunk_received = False + # 创建队列 + text_queue: asyncio.Queue[str | None] = asyncio.Queue() + audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue() + + # 1. 启动 Agent Feeder 任务:负责运行 Agent 并将文本分句喂给 text_queue + feeder_task = asyncio.create_task( + _run_agent_feeder( + agent_runner, text_queue, max_step, show_tool_use, show_reasoning + ) + ) + + # 2. 启动 TTS 任务:负责从 text_queue 读取文本并生成音频到 audio_queue if support_stream: - # 使用流式 TTS - async for audio_chunk in _process_stream_tts(llm_stream_chunks, tts_provider): - if not first_chunk_received: - tts_first_frame_time = time.time() - tts_start_time - first_chunk_received = True - yield audio_chunk + tts_task = asyncio.create_task( + _safe_tts_stream_wrapper(tts_provider, text_queue, audio_queue) + ) else: - # 使用完整音频 TTS - async for audio_chunk in _process_full_tts(llm_stream_chunks, tts_provider): + tts_task = asyncio.create_task( + _simulated_stream_tts(tts_provider, text_queue, audio_queue) + ) + + # 3. 主循环:从 audio_queue 读取音频并 yield + try: + while True: + audio_data = await audio_queue.get() + + if audio_data is None: + break + if not first_chunk_received: + # 记录首帧延迟(从开始处理到收到第一个音频块) tts_first_frame_time = time.time() - tts_start_time first_chunk_received = True - yield audio_chunk + + # 将音频数据封装为 MessageChain + import base64 + + audio_b64 = base64.b64encode(audio_data).decode("utf-8") + chain = MessageChain(chain=[Plain(audio_b64)], type="audio_chunk") + yield chain + + except Exception as e: + logger.error(f"[Live Agent] 运行时发生错误: {e}", exc_info=True) + finally: + # 清理任务 + if not feeder_task.done(): + feeder_task.cancel() + if not tts_task.done(): + tts_task.cancel() + + # 确保队列被消费 + pass + tts_end_time = time.time() # 发送 TTS 统计信息 @@ -228,113 +259,105 @@ async def run_live_agent( logger.error(f"发送 TTS 统计信息失败: {e}") -async def _process_stream_tts(chunks: list[MessageChain], tts_provider): - """处理流式 TTS""" - text_queue: asyncio.Queue[str | None] = asyncio.Queue() - audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue() - - # 启动 TTS 处理任务 - tts_task = asyncio.create_task( - tts_provider.get_audio_stream(text_queue, audio_queue) - ) - - chunk_size = 50 # 每 50 个字符发送一次给 TTS - - try: - # 喂文本给 TTS - feed_task = asyncio.create_task( - _feed_text_to_tts(chunks, text_queue, chunk_size) - ) - - # 从 TTS 输出队列中读取音频数据 - while True: - audio_data = await audio_queue.get() - - if audio_data is None: - break - - # 将音频数据封装为 MessageChain - import base64 - - audio_b64 = base64.b64encode(audio_data).decode("utf-8") - - chain = MessageChain(chain=[Plain(audio_b64)], type="audio_chunk") - yield chain - - await feed_task - - except Exception as e: - logger.error(f"[Live TTS] 流式处理失败: {e}", exc_info=True) - await text_queue.put(None) - - finally: - try: - await asyncio.wait_for(tts_task, timeout=5.0) - except asyncio.TimeoutError: - logger.warning("[Live TTS] TTS 任务超时,强制取消") - tts_task.cancel() - - -async def _feed_text_to_tts( - chunks: list[MessageChain], text_queue: asyncio.Queue, chunk_size: int +async def _run_agent_feeder( + agent_runner: AgentRunner, + text_queue: asyncio.Queue, + max_step: int, + show_tool_use: bool, + show_reasoning: bool, ): - """从消息链中提取文本并分块发送给 TTS""" - accumulated_text = "" - + """运行 Agent 并将文本输出分句放入队列""" + buffer = "" try: - for chain in chunks: - text = chain.get_plain_text() - if not text: + async for chain in run_agent( + agent_runner, + max_step=max_step, + show_tool_use=show_tool_use, + stream_to_general=False, + show_reasoning=show_reasoning, + ): + if chain is None: continue - accumulated_text += text + # 提取文本 + text = chain.get_plain_text() + if text: + buffer += text - # 当累积的文本达到chunk_size时,发送给TTS - while len(accumulated_text) >= chunk_size: - chunk = accumulated_text[:chunk_size] - await text_queue.put(chunk) - accumulated_text = accumulated_text[chunk_size:] + # 分句逻辑:匹配标点符号 + # r"([.。!!??\n]+)" 会保留分隔符 + parts = re.split(r"([.。!!??\n]+)", buffer) - # 处理剩余文本 - if accumulated_text: - await text_queue.put(accumulated_text) + if len(parts) > 1: + # 处理完整的句子 + # range step 2 因为 split 后是 [text, delim, text, delim, ...] + temp_buffer = "" + for i in range(0, len(parts) - 1, 2): + sentence = parts[i] + delim = parts[i + 1] + full_sentence = sentence + delim + temp_buffer += full_sentence + if len(temp_buffer) >= 10: + if temp_buffer.strip(): + logger.info(f"[Live Agent Feeder] 分句: {temp_buffer}") + await text_queue.put(temp_buffer) + temp_buffer = "" + + # 更新 buffer 为剩余部分 + buffer = temp_buffer + parts[-1] + + # 处理剩余 buffer + if buffer.strip(): + await text_queue.put(buffer) + + except Exception as e: + logger.error(f"[Live Agent Feeder] Error: {e}", exc_info=True) finally: - # 发送结束标记 + # 发送结束信号 await text_queue.put(None) -async def _process_full_tts(chunks: list[MessageChain], tts_provider): - """处理完整音频 TTS""" - accumulated_text = "" - +async def _safe_tts_stream_wrapper( + tts_provider: TTSProvider, + text_queue: asyncio.Queue[str | None], + audio_queue: asyncio.Queue[bytes | None], +): + """包装原生流式 TTS 确保异常处理和队列关闭""" try: - # 累积所有文本 - for chain in chunks: - text = chain.get_plain_text() - if text: - accumulated_text += text + await tts_provider.get_audio_stream(text_queue, audio_queue) + except Exception as e: + logger.error(f"[Live TTS Stream] Error: {e}", exc_info=True) + finally: + await audio_queue.put(None) - # 如果没有文本,直接返回 - if not accumulated_text: - return - logger.info(f"[Live TTS] 累积完整文本,长度: {len(accumulated_text)}") +async def _simulated_stream_tts( + tts_provider: TTSProvider, + text_queue: asyncio.Queue[str | None], + audio_queue: asyncio.Queue[bytes | None], +): + """模拟流式 TTS 分句生成音频""" + try: + while True: + text = await text_queue.get() + if text is None: + break - # 调用 get_audio 生成完整音频 - audio_path = await tts_provider.get_audio(accumulated_text) + try: + audio_path = await tts_provider.get_audio(text) - # 读取音频文件 - with open(audio_path, "rb") as f: - audio_data = f.read() - - # 将音频数据封装为 MessageChain - import base64 - - audio_b64 = base64.b64encode(audio_data).decode("utf-8") - - chain = MessageChain(chain=[Plain(audio_b64)], type="audio_chunk") - yield chain + if audio_path: + with open(audio_path, "rb") as f: + audio_data = f.read() + await audio_queue.put(audio_data) + except Exception as e: + logger.error( + f"[Live TTS Simulated] Error processing text '{text[:20]}...': {e}" + ) + # 继续处理下一句 except Exception as e: - logger.error(f"[Live TTS] 完整音频生成失败: {e}", exc_info=True) + logger.error(f"[Live TTS Simulated] Critical Error: {e}", exc_info=True) + finally: + await audio_queue.put(None) diff --git a/dashboard/src/components/chat/LiveMode.vue b/dashboard/src/components/chat/LiveMode.vue index bfd602c5a..81e333c34 100644 --- a/dashboard/src/components/chat/LiveMode.vue +++ b/dashboard/src/components/chat/LiveMode.vue @@ -21,13 +21,20 @@
- WAV Assemble: {{ (metrics.wav_assemble_time * 1000).toFixed(0) }}ms - LLM First Token Latency: {{ (metrics.llm_ttft * 1000).toFixed(0) }}ms - LLM Total Latency: {{ (metrics.llm_total_time * 1000).toFixed(0) }}ms - TTS First Frame Latency: {{ (metrics.tts_first_frame_time * 1000).toFixed(0) }}ms - TTS Total Larency: {{ (metrics.tts_total_time * 1000).toFixed(0) }}ms - Speak -> First TTS Frame: {{ (metrics.speak_to_first_frame * 1000).toFixed(0) }}ms - Speak -> End: {{ (metrics.wav_to_tts_total_time * 1000).toFixed(0) }}ms + WAV Assemble: {{ (metrics.wav_assemble_time * 1000).toFixed(0) + }}ms + LLM First Token Latency: {{ (metrics.llm_ttft * 1000).toFixed(0) + }}ms + LLM Total Latency: {{ (metrics.llm_total_time * 1000).toFixed(0) + }}ms + TTS First Frame Latency: {{ (metrics.tts_first_frame_time * + 1000).toFixed(0) }}ms + TTS Total Larency: {{ (metrics.tts_total_time * 1000).toFixed(0) + }}ms + Speak -> First TTS Frame: {{ (metrics.speak_to_first_frame * + 1000).toFixed(0) }}ms + Speak -> End: {{ (metrics.wav_to_tts_total_time * + 1000).toFixed(0) }}ms
@@ -65,7 +72,15 @@ let audioContext: AudioContext | null = null; let analyser: AnalyserNode | null = null; const botEnergy = ref(0); let energyLoopId: number; -let isPlaying = ref(false); +let isPlaying = ref(false); // UI 状态:是否正在播放 + +// 音频播放队列管理 +const rawAudioQueue: Uint8Array[] = []; // 待解码队列 +const audioBufferQueue: AudioBuffer[] = []; // 待播放队列 +let isDecoding = false; +let isPlayingAudio = false; // 内部状态:是否正在播放音频 +let currentSource: AudioBufferSourceNode | null = null; + // 消息历史 const messages = ref>([]); @@ -324,7 +339,7 @@ function handleWebSocketMessage(event: MessageEvent) { isProcessing.value = false; isListening.value = true; break; - + case 'metrics': metrics.value = { ...metrics.value, ...message.data }; break; @@ -345,35 +360,112 @@ function playAudioChunk(base64Data: string) { bytes[i] = binaryString.charCodeAt(i); } - // 解码 WAV 音频 - audioContext.decodeAudioData(bytes.buffer).then(audioBuffer => { - const source = audioContext!.createBufferSource(); - source.buffer = audioBuffer; - // 连接到分析器 - if (analyser) { - source.connect(analyser); - analyser.connect(audioContext!.destination); - } else { - source.connect(audioContext!.destination); - } - source.start(); - isPlaying.value = true; + // 放入待解码队列 + rawAudioQueue.push(bytes); - source.onended = () => { - isPlaying.value = false; - }; - }).catch(error => { - console.error('[Live Mode] 解码音频失败:', error); - }); + // 触发解码处理 + processRawAudioQueue(); + + } catch (error) { + console.error('[Live Mode] 接收音频数据失败:', error); + } +} + +async function processRawAudioQueue() { + if (isDecoding || rawAudioQueue.length === 0) return; + + isDecoding = true; + + try { + while (rawAudioQueue.length > 0) { + const bytes = rawAudioQueue.shift(); + if (!bytes || !audioContext) continue; + + try { + // 解码 + const audioBuffer = await audioContext.decodeAudioData(bytes.buffer as ArrayBuffer); + audioBufferQueue.push(audioBuffer); + + // 如果当前没有播放,立即开始播放 + if (!isPlayingAudio) { + playNextAudio(); + } + } catch (err) { + console.error('[Live Mode] 解码音频失败:', err); + } + } + } finally { + isDecoding = false; + // 如果在解码过程中又有新数据进来,继续处理 + if (rawAudioQueue.length > 0) { + processRawAudioQueue(); + } + } +} + +function playNextAudio() { + if (audioBufferQueue.length === 0) { + isPlayingAudio = false; + isPlaying.value = false; + return; + } + + if (!audioContext) return; + + isPlayingAudio = true; + isPlaying.value = true; + + try { + const audioBuffer = audioBufferQueue.shift(); + if (!audioBuffer) return; + + const source = audioContext.createBufferSource(); + source.buffer = audioBuffer; + + // 连接到分析器 + if (analyser) { + source.connect(analyser); + analyser.connect(audioContext.destination); + } else { + source.connect(audioContext.destination); + } + + currentSource = source; + source.start(); + + source.onended = () => { + currentSource = null; + playNextAudio(); + }; } catch (error) { console.error('[Live Mode] 播放音频失败:', error); + isPlayingAudio = false; + isPlaying.value = false; + playNextAudio(); // 尝试播放下一个 } } function stopAudioPlayback() { - // TODO: 实现停止当前播放的音频 + // 停止当前播放源 + if (currentSource) { + try { + currentSource.stop(); + currentSource.disconnect(); + } catch (e) { + // ignore + } + currentSource = null; + } + + // 清空队列 + rawAudioQueue.length = 0; + audioBufferQueue.length = 0; + + // 重置状态 + isPlayingAudio = false; isPlaying.value = false; + isDecoding = false; } function generateStamp(): string { @@ -415,6 +507,8 @@ watch(isSpeaking, (newVal) => { if (ws && ws.readyState === WebSocket.OPEN) { ws.send(JSON.stringify({ t: 'interrupt' })); } + // 本地立即停止播放 + stopAudioPlayback(); } });