From 856d3496fa567aa812011e02ef3a0b0aaff3fff6 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sat, 17 Jan 2026 15:35:02 +0800 Subject: [PATCH] feat: enhance audio processing and metrics display in live mode --- astrbot/core/astr_agent_run_util.py | 277 +++++++++++---------- dashboard/src/components/chat/LiveMode.vue | 152 ++++++++--- 2 files changed, 273 insertions(+), 156 deletions(-) diff --git a/astrbot/core/astr_agent_run_util.py b/astrbot/core/astr_agent_run_util.py index f5962e622..9d7301516 100644 --- a/astrbot/core/astr_agent_run_util.py +++ b/astrbot/core/astr_agent_run_util.py @@ -1,4 +1,5 @@ import asyncio +import re import time import traceback from collections.abc import AsyncGenerator @@ -155,55 +156,85 @@ async def run_live_agent( Yields: MessageChain: 包含文本或音频数据的消息链 """ - support_stream = tts_provider.support_stream() if tts_provider else False - - if support_stream: - logger.info("[Live Agent] 使用流式 TTS(原生支持 get_audio_stream)") - elif tts_provider: - logger.info( - f"[Live Agent] 使用 TTS({tts_provider.meta().type} " - "使用 get_audio,将累积完整文本后生成音频)" - ) - - # 收集 LLM 输出 - llm_stream_chunks: list[MessageChain] = [] - - # 运行普通 agent - async for chain in run_agent( - agent_runner, - max_step=max_step, - show_tool_use=show_tool_use, - stream_to_general=False, - show_reasoning=show_reasoning, - ): - if chain is not None: - llm_stream_chunks.append(chain) - # 如果没有 TTS Provider,直接发送文本 if not tts_provider: - for chain in llm_stream_chunks: + async for chain in run_agent( + agent_runner, + max_step=max_step, + show_tool_use=show_tool_use, + stream_to_general=False, + show_reasoning=show_reasoning, + ): yield chain return - # 处理 TTS + support_stream = tts_provider.support_stream() + if support_stream: + logger.info("[Live Agent] 使用流式 TTS(原生支持 get_audio_stream)") + else: + logger.info( + f"[Live Agent] 使用 TTS({tts_provider.meta().type} " + "使用 get_audio,将按句子分块生成音频)" + ) + + # 统计数据初始化 tts_start_time = time.time() tts_first_frame_time = 0.0 first_chunk_received = False + # 创建队列 + text_queue: asyncio.Queue[str | None] = asyncio.Queue() + audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue() + + # 1. 启动 Agent Feeder 任务:负责运行 Agent 并将文本分句喂给 text_queue + feeder_task = asyncio.create_task( + _run_agent_feeder( + agent_runner, text_queue, max_step, show_tool_use, show_reasoning + ) + ) + + # 2. 启动 TTS 任务:负责从 text_queue 读取文本并生成音频到 audio_queue if support_stream: - # 使用流式 TTS - async for audio_chunk in _process_stream_tts(llm_stream_chunks, tts_provider): - if not first_chunk_received: - tts_first_frame_time = time.time() - tts_start_time - first_chunk_received = True - yield audio_chunk + tts_task = asyncio.create_task( + _safe_tts_stream_wrapper(tts_provider, text_queue, audio_queue) + ) else: - # 使用完整音频 TTS - async for audio_chunk in _process_full_tts(llm_stream_chunks, tts_provider): + tts_task = asyncio.create_task( + _simulated_stream_tts(tts_provider, text_queue, audio_queue) + ) + + # 3. 主循环:从 audio_queue 读取音频并 yield + try: + while True: + audio_data = await audio_queue.get() + + if audio_data is None: + break + if not first_chunk_received: + # 记录首帧延迟(从开始处理到收到第一个音频块) tts_first_frame_time = time.time() - tts_start_time first_chunk_received = True - yield audio_chunk + + # 将音频数据封装为 MessageChain + import base64 + + audio_b64 = base64.b64encode(audio_data).decode("utf-8") + chain = MessageChain(chain=[Plain(audio_b64)], type="audio_chunk") + yield chain + + except Exception as e: + logger.error(f"[Live Agent] 运行时发生错误: {e}", exc_info=True) + finally: + # 清理任务 + if not feeder_task.done(): + feeder_task.cancel() + if not tts_task.done(): + tts_task.cancel() + + # 确保队列被消费 + pass + tts_end_time = time.time() # 发送 TTS 统计信息 @@ -228,113 +259,105 @@ async def run_live_agent( logger.error(f"发送 TTS 统计信息失败: {e}") -async def _process_stream_tts(chunks: list[MessageChain], tts_provider): - """处理流式 TTS""" - text_queue: asyncio.Queue[str | None] = asyncio.Queue() - audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue() - - # 启动 TTS 处理任务 - tts_task = asyncio.create_task( - tts_provider.get_audio_stream(text_queue, audio_queue) - ) - - chunk_size = 50 # 每 50 个字符发送一次给 TTS - - try: - # 喂文本给 TTS - feed_task = asyncio.create_task( - _feed_text_to_tts(chunks, text_queue, chunk_size) - ) - - # 从 TTS 输出队列中读取音频数据 - while True: - audio_data = await audio_queue.get() - - if audio_data is None: - break - - # 将音频数据封装为 MessageChain - import base64 - - audio_b64 = base64.b64encode(audio_data).decode("utf-8") - - chain = MessageChain(chain=[Plain(audio_b64)], type="audio_chunk") - yield chain - - await feed_task - - except Exception as e: - logger.error(f"[Live TTS] 流式处理失败: {e}", exc_info=True) - await text_queue.put(None) - - finally: - try: - await asyncio.wait_for(tts_task, timeout=5.0) - except asyncio.TimeoutError: - logger.warning("[Live TTS] TTS 任务超时,强制取消") - tts_task.cancel() - - -async def _feed_text_to_tts( - chunks: list[MessageChain], text_queue: asyncio.Queue, chunk_size: int +async def _run_agent_feeder( + agent_runner: AgentRunner, + text_queue: asyncio.Queue, + max_step: int, + show_tool_use: bool, + show_reasoning: bool, ): - """从消息链中提取文本并分块发送给 TTS""" - accumulated_text = "" - + """运行 Agent 并将文本输出分句放入队列""" + buffer = "" try: - for chain in chunks: - text = chain.get_plain_text() - if not text: + async for chain in run_agent( + agent_runner, + max_step=max_step, + show_tool_use=show_tool_use, + stream_to_general=False, + show_reasoning=show_reasoning, + ): + if chain is None: continue - accumulated_text += text + # 提取文本 + text = chain.get_plain_text() + if text: + buffer += text - # 当累积的文本达到chunk_size时,发送给TTS - while len(accumulated_text) >= chunk_size: - chunk = accumulated_text[:chunk_size] - await text_queue.put(chunk) - accumulated_text = accumulated_text[chunk_size:] + # 分句逻辑:匹配标点符号 + # r"([.。!!??\n]+)" 会保留分隔符 + parts = re.split(r"([.。!!??\n]+)", buffer) - # 处理剩余文本 - if accumulated_text: - await text_queue.put(accumulated_text) + if len(parts) > 1: + # 处理完整的句子 + # range step 2 因为 split 后是 [text, delim, text, delim, ...] + temp_buffer = "" + for i in range(0, len(parts) - 1, 2): + sentence = parts[i] + delim = parts[i + 1] + full_sentence = sentence + delim + temp_buffer += full_sentence + if len(temp_buffer) >= 10: + if temp_buffer.strip(): + logger.info(f"[Live Agent Feeder] 分句: {temp_buffer}") + await text_queue.put(temp_buffer) + temp_buffer = "" + + # 更新 buffer 为剩余部分 + buffer = temp_buffer + parts[-1] + + # 处理剩余 buffer + if buffer.strip(): + await text_queue.put(buffer) + + except Exception as e: + logger.error(f"[Live Agent Feeder] Error: {e}", exc_info=True) finally: - # 发送结束标记 + # 发送结束信号 await text_queue.put(None) -async def _process_full_tts(chunks: list[MessageChain], tts_provider): - """处理完整音频 TTS""" - accumulated_text = "" - +async def _safe_tts_stream_wrapper( + tts_provider: TTSProvider, + text_queue: asyncio.Queue[str | None], + audio_queue: asyncio.Queue[bytes | None], +): + """包装原生流式 TTS 确保异常处理和队列关闭""" try: - # 累积所有文本 - for chain in chunks: - text = chain.get_plain_text() - if text: - accumulated_text += text + await tts_provider.get_audio_stream(text_queue, audio_queue) + except Exception as e: + logger.error(f"[Live TTS Stream] Error: {e}", exc_info=True) + finally: + await audio_queue.put(None) - # 如果没有文本,直接返回 - if not accumulated_text: - return - logger.info(f"[Live TTS] 累积完整文本,长度: {len(accumulated_text)}") +async def _simulated_stream_tts( + tts_provider: TTSProvider, + text_queue: asyncio.Queue[str | None], + audio_queue: asyncio.Queue[bytes | None], +): + """模拟流式 TTS 分句生成音频""" + try: + while True: + text = await text_queue.get() + if text is None: + break - # 调用 get_audio 生成完整音频 - audio_path = await tts_provider.get_audio(accumulated_text) + try: + audio_path = await tts_provider.get_audio(text) - # 读取音频文件 - with open(audio_path, "rb") as f: - audio_data = f.read() - - # 将音频数据封装为 MessageChain - import base64 - - audio_b64 = base64.b64encode(audio_data).decode("utf-8") - - chain = MessageChain(chain=[Plain(audio_b64)], type="audio_chunk") - yield chain + if audio_path: + with open(audio_path, "rb") as f: + audio_data = f.read() + await audio_queue.put(audio_data) + except Exception as e: + logger.error( + f"[Live TTS Simulated] Error processing text '{text[:20]}...': {e}" + ) + # 继续处理下一句 except Exception as e: - logger.error(f"[Live TTS] 完整音频生成失败: {e}", exc_info=True) + logger.error(f"[Live TTS Simulated] Critical Error: {e}", exc_info=True) + finally: + await audio_queue.put(None) diff --git a/dashboard/src/components/chat/LiveMode.vue b/dashboard/src/components/chat/LiveMode.vue index bfd602c5a..81e333c34 100644 --- a/dashboard/src/components/chat/LiveMode.vue +++ b/dashboard/src/components/chat/LiveMode.vue @@ -21,13 +21,20 @@