feat: enhance audio processing and metrics display in live mode

This commit is contained in:
Soulter
2026-01-17 15:35:02 +08:00
parent 19e6253d5d
commit 856d3496fa
2 changed files with 273 additions and 156 deletions
+150 -127
View File
@@ -1,4 +1,5 @@
import asyncio
import re
import time
import traceback
from collections.abc import AsyncGenerator
@@ -155,55 +156,85 @@ async def run_live_agent(
Yields:
MessageChain: 包含文本或音频数据的消息链
"""
support_stream = tts_provider.support_stream() if tts_provider else False
if support_stream:
logger.info("[Live Agent] 使用流式 TTS(原生支持 get_audio_stream")
elif tts_provider:
logger.info(
f"[Live Agent] 使用 TTS{tts_provider.meta().type} "
"使用 get_audio,将累积完整文本后生成音频)"
)
# 收集 LLM 输出
llm_stream_chunks: list[MessageChain] = []
# 运行普通 agent
async for chain in run_agent(
agent_runner,
max_step=max_step,
show_tool_use=show_tool_use,
stream_to_general=False,
show_reasoning=show_reasoning,
):
if chain is not None:
llm_stream_chunks.append(chain)
# 如果没有 TTS Provider,直接发送文本
if not tts_provider:
for chain in llm_stream_chunks:
async for chain in run_agent(
agent_runner,
max_step=max_step,
show_tool_use=show_tool_use,
stream_to_general=False,
show_reasoning=show_reasoning,
):
yield chain
return
# 处理 TTS
support_stream = tts_provider.support_stream()
if support_stream:
logger.info("[Live Agent] 使用流式 TTS(原生支持 get_audio_stream")
else:
logger.info(
f"[Live Agent] 使用 TTS{tts_provider.meta().type} "
"使用 get_audio,将按句子分块生成音频)"
)
# 统计数据初始化
tts_start_time = time.time()
tts_first_frame_time = 0.0
first_chunk_received = False
# 创建队列
text_queue: asyncio.Queue[str | None] = asyncio.Queue()
audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
# 1. 启动 Agent Feeder 任务:负责运行 Agent 并将文本分句喂给 text_queue
feeder_task = asyncio.create_task(
_run_agent_feeder(
agent_runner, text_queue, max_step, show_tool_use, show_reasoning
)
)
# 2. 启动 TTS 任务:负责从 text_queue 读取文本并生成音频到 audio_queue
if support_stream:
# 使用流式 TTS
async for audio_chunk in _process_stream_tts(llm_stream_chunks, tts_provider):
if not first_chunk_received:
tts_first_frame_time = time.time() - tts_start_time
first_chunk_received = True
yield audio_chunk
tts_task = asyncio.create_task(
_safe_tts_stream_wrapper(tts_provider, text_queue, audio_queue)
)
else:
# 使用完整音频 TTS
async for audio_chunk in _process_full_tts(llm_stream_chunks, tts_provider):
tts_task = asyncio.create_task(
_simulated_stream_tts(tts_provider, text_queue, audio_queue)
)
# 3. 主循环:从 audio_queue 读取音频并 yield
try:
while True:
audio_data = await audio_queue.get()
if audio_data is None:
break
if not first_chunk_received:
# 记录首帧延迟(从开始处理到收到第一个音频块)
tts_first_frame_time = time.time() - tts_start_time
first_chunk_received = True
yield audio_chunk
# 将音频数据封装为 MessageChain
import base64
audio_b64 = base64.b64encode(audio_data).decode("utf-8")
chain = MessageChain(chain=[Plain(audio_b64)], type="audio_chunk")
yield chain
except Exception as e:
logger.error(f"[Live Agent] 运行时发生错误: {e}", exc_info=True)
finally:
# 清理任务
if not feeder_task.done():
feeder_task.cancel()
if not tts_task.done():
tts_task.cancel()
# 确保队列被消费
pass
tts_end_time = time.time()
# 发送 TTS 统计信息
@@ -228,113 +259,105 @@ async def run_live_agent(
logger.error(f"发送 TTS 统计信息失败: {e}")
async def _process_stream_tts(chunks: list[MessageChain], tts_provider):
"""处理流式 TTS"""
text_queue: asyncio.Queue[str | None] = asyncio.Queue()
audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
# 启动 TTS 处理任务
tts_task = asyncio.create_task(
tts_provider.get_audio_stream(text_queue, audio_queue)
)
chunk_size = 50 # 每 50 个字符发送一次给 TTS
try:
# 喂文本给 TTS
feed_task = asyncio.create_task(
_feed_text_to_tts(chunks, text_queue, chunk_size)
)
# 从 TTS 输出队列中读取音频数据
while True:
audio_data = await audio_queue.get()
if audio_data is None:
break
# 将音频数据封装为 MessageChain
import base64
audio_b64 = base64.b64encode(audio_data).decode("utf-8")
chain = MessageChain(chain=[Plain(audio_b64)], type="audio_chunk")
yield chain
await feed_task
except Exception as e:
logger.error(f"[Live TTS] 流式处理失败: {e}", exc_info=True)
await text_queue.put(None)
finally:
try:
await asyncio.wait_for(tts_task, timeout=5.0)
except asyncio.TimeoutError:
logger.warning("[Live TTS] TTS 任务超时,强制取消")
tts_task.cancel()
async def _feed_text_to_tts(
chunks: list[MessageChain], text_queue: asyncio.Queue, chunk_size: int
async def _run_agent_feeder(
    agent_runner: AgentRunner,
    text_queue: asyncio.Queue,
    max_step: int,
    show_tool_use: bool,
    show_reasoning: bool,
):
    """Run the agent and feed its text output to ``text_queue`` in sentence chunks.

    Text from every chain yielded by ``run_agent`` is appended to a buffer.
    The buffer is split on sentence-ending punctuation; complete sentences
    are merged until the merged chunk is at least 10 characters long and are
    then queued. Text left over when the agent finishes is queued as a final
    chunk. ``None`` is always queued last as the end-of-stream marker, also
    when an exception was logged.

    Args:
        agent_runner: Runner forwarded to ``run_agent``.
        text_queue: Queue receiving ``str`` chunks followed by ``None``.
        max_step: Forwarded to ``run_agent``.
        show_tool_use: Forwarded to ``run_agent``.
        show_reasoning: Forwarded to ``run_agent``.
    """
    # Sentence terminators in both ASCII and full-width form, so that
    # "Hi! How are you?" and "你好!最近怎么样?" are both split. The capture
    # group makes split() keep the delimiters: [text, delim, text, ..., tail].
    sentence_end = re.compile(r"([.。!?!?\n]+)")
    # Very short sentences are merged so the TTS is not called per word.
    min_chunk_len = 10

    buffer = ""
    try:
        async for chain in run_agent(
            agent_runner,
            max_step=max_step,
            show_tool_use=show_tool_use,
            stream_to_general=False,
            show_reasoning=show_reasoning,
        ):
            if chain is None:
                continue

            text = chain.get_plain_text()
            if not text:
                continue
            buffer += text

            parts = sentence_end.split(buffer)
            if len(parts) <= 1:
                # No terminator yet; keep accumulating.
                continue

            pending = ""
            # parts has odd length; the last element is the unterminated tail.
            for sentence, delim in zip(parts[:-1:2], parts[1::2]):
                pending += sentence + delim
                if len(pending) >= min_chunk_len:
                    if pending.strip():
                        logger.info(f"[Live Agent Feeder] 分句: {pending}")
                        await text_queue.put(pending)
                    pending = ""

            # Carry over short, not-yet-sent sentences plus the tail.
            buffer = pending + parts[-1]

        # Flush whatever is left once the agent is done.
        if buffer.strip():
            await text_queue.put(buffer)

    except Exception as e:
        logger.error(f"[Live Agent Feeder] Error: {e}", exc_info=True)
    finally:
        # End-of-stream marker for the TTS consumer.
        await text_queue.put(None)
async def _process_full_tts(chunks: list[MessageChain], tts_provider):
"""处理完整音频 TTS"""
accumulated_text = ""
async def _safe_tts_stream_wrapper(
    tts_provider: TTSProvider,
    text_queue: asyncio.Queue[str | None],
    audio_queue: asyncio.Queue[bytes | None],
):
    """Run the provider's native streaming TTS and always close ``audio_queue``.

    An exception raised by ``get_audio_stream`` is logged instead of being
    propagated. In every case ``None`` is queued on ``audio_queue``
    afterwards as the end marker for whoever reads the audio.
    """
    try:
        await tts_provider.get_audio_stream(text_queue, audio_queue)
    except Exception as exc:
        logger.error(f"[Live TTS Stream] Error: {exc}", exc_info=True)
    finally:
        # Sentinel: lets the reader of audio_queue leave its loop.
        await audio_queue.put(None)
# 如果没有文本,直接返回
if not accumulated_text:
return
logger.info(f"[Live TTS] 累积完整文本,长度: {len(accumulated_text)}")
async def _simulated_stream_tts(
    tts_provider: TTSProvider,
    text_queue: asyncio.Queue[str | None],
    audio_queue: asyncio.Queue[bytes | None],
):
    """Emulate streaming TTS by synthesising one queued text chunk at a time.

    Reads chunks from ``text_queue`` until ``None`` arrives. For each chunk
    ``tts_provider.get_audio`` is awaited, the file at the returned path is
    read, and its bytes are queued on ``audio_queue``. A failure on a single
    chunk is logged and the loop moves on to the next chunk. ``None`` is
    always queued on ``audio_queue`` at the end as the end marker.
    """

    def _read_bytes(path) -> bytes:
        with open(path, "rb") as f:
            return f.read()

    try:
        while True:
            text = await text_queue.get()
            if text is None:
                break

            try:
                audio_path = await tts_provider.get_audio(text)
                if audio_path:
                    # Read in a worker thread: a blocking read here would
                    # stall the event loop and with it the audio consumer.
                    audio_data = await asyncio.to_thread(_read_bytes, audio_path)
                    await audio_queue.put(audio_data)
            except Exception as e:
                # One bad chunk must not end the stream; keep the traceback
                # so the failure can be diagnosed.
                logger.error(
                    f"[Live TTS Simulated] Error processing text '{text[:20]}...': {e}",
                    exc_info=True,
                )

    except Exception as e:
        logger.error(f"[Live TTS Simulated] Critical Error: {e}", exc_info=True)
    finally:
        await audio_queue.put(None)
+123 -29
View File
@@ -21,13 +21,20 @@
</div>
<div class="metrics-container" v-if="Object.keys(metrics).length > 0">
<span v-if="metrics.wav_assemble_time">WAV Assemble: {{ (metrics.wav_assemble_time * 1000).toFixed(0) }}ms</span>
<span v-if="metrics.llm_ttft">LLM First Token Latency: {{ (metrics.llm_ttft * 1000).toFixed(0) }}ms</span>
<span v-if="metrics.llm_total_time">LLM Total Latency: {{ (metrics.llm_total_time * 1000).toFixed(0) }}ms</span>
<span v-if="metrics.tts_first_frame_time">TTS First Frame Latency: {{ (metrics.tts_first_frame_time * 1000).toFixed(0) }}ms</span>
<span v-if="metrics.tts_total_time">TTS Total Larency: {{ (metrics.tts_total_time * 1000).toFixed(0) }}ms</span>
<span v-if="metrics.speak_to_first_frame">Speak -> First TTS Frame: {{ (metrics.speak_to_first_frame * 1000).toFixed(0) }}ms</span>
<span v-if="metrics.wav_to_tts_total_time">Speak -> End: {{ (metrics.wav_to_tts_total_time * 1000).toFixed(0) }}ms</span>
<span v-if="metrics.wav_assemble_time">WAV Assemble: {{ (metrics.wav_assemble_time * 1000).toFixed(0)
}}ms</span>
<span v-if="metrics.llm_ttft">LLM First Token Latency: {{ (metrics.llm_ttft * 1000).toFixed(0)
}}ms</span>
<span v-if="metrics.llm_total_time">LLM Total Latency: {{ (metrics.llm_total_time * 1000).toFixed(0)
}}ms</span>
<span v-if="metrics.tts_first_frame_time">TTS First Frame Latency: {{ (metrics.tts_first_frame_time *
1000).toFixed(0) }}ms</span>
<span v-if="metrics.tts_total_time">TTS Total Latency: {{ (metrics.tts_total_time * 1000).toFixed(0)
}}ms</span>
<span v-if="metrics.speak_to_first_frame">Speak -> First TTS Frame: {{ (metrics.speak_to_first_frame *
1000).toFixed(0) }}ms</span>
<span v-if="metrics.wav_to_tts_total_time">Speak -> End: {{ (metrics.wav_to_tts_total_time *
1000).toFixed(0) }}ms</span>
</div>
</div>
</div>
@@ -65,7 +72,15 @@ let audioContext: AudioContext | null = null;
let analyser: AnalyserNode | null = null;
const botEnergy = ref(0);
let energyLoopId: number;
let isPlaying = ref(false);
let isPlaying = ref(false); // UI 状态:是否正在播放
// 音频播放队列管理
const rawAudioQueue: Uint8Array[] = []; // 待解码队列
const audioBufferQueue: AudioBuffer[] = []; // 待播放队列
let isDecoding = false;
let isPlayingAudio = false; // 内部状态:是否正在播放音频
let currentSource: AudioBufferSourceNode | null = null;
// 消息历史
const messages = ref<Array<{ type: 'user' | 'bot', text: string }>>([]);
@@ -324,7 +339,7 @@ function handleWebSocketMessage(event: MessageEvent) {
isProcessing.value = false;
isListening.value = true;
break;
case 'metrics':
metrics.value = { ...metrics.value, ...message.data };
break;
@@ -345,35 +360,112 @@ function playAudioChunk(base64Data: string) {
bytes[i] = binaryString.charCodeAt(i);
}
// 解码 WAV 音频
audioContext.decodeAudioData(bytes.buffer).then(audioBuffer => {
const source = audioContext!.createBufferSource();
source.buffer = audioBuffer;
// 连接到分析器
if (analyser) {
source.connect(analyser);
analyser.connect(audioContext!.destination);
} else {
source.connect(audioContext!.destination);
}
source.start();
isPlaying.value = true;
// 放入待解码队列
rawAudioQueue.push(bytes);
source.onended = () => {
isPlaying.value = false;
};
}).catch(error => {
console.error('[Live Mode] 解码音频失败:', error);
});
// 触发解码处理
processRawAudioQueue();
} catch (error) {
console.error('[Live Mode] 接收音频数据失败:', error);
}
}
/**
 * Decode every chunk waiting in `rawAudioQueue`, in arrival order.
 *
 * Each decoded buffer is appended to `audioBufferQueue`, and playback is
 * kicked off when nothing is playing. The `isDecoding` flag keeps a second
 * call from starting a parallel drain loop.
 */
async function processRawAudioQueue() {
  if (isDecoding) return;
  if (rawAudioQueue.length === 0) return;

  isDecoding = true;
  try {
    let pending: Uint8Array | undefined;
    while ((pending = rawAudioQueue.shift()) !== undefined) {
      // Without an audio context the chunk cannot be decoded; drop it.
      if (!audioContext) continue;

      try {
        const decoded = await audioContext.decodeAudioData(pending.buffer as ArrayBuffer);
        audioBufferQueue.push(decoded);

        // Nothing playing right now: start immediately.
        if (!isPlayingAudio) {
          playNextAudio();
        }
      } catch (err) {
        console.error('[Live Mode] 解码音频失败:', err);
      }
    }
  } finally {
    isDecoding = false;
    // Chunks that arrived while we were finishing up get their own pass.
    if (rawAudioQueue.length !== 0) {
      processRawAudioQueue();
    }
  }
}
/**
 * Play the next buffer from `audioBufferQueue` and chain to the following
 * one when it ends.
 *
 * When there is nothing to play (empty queue) or nothing to play with (no
 * audio context), both the internal and the UI "playing" flags are cleared.
 */
function playNextAudio() {
  const ctx = audioContext;
  const audioBuffer = ctx ? audioBufferQueue.shift() : undefined;
  if (!ctx || !audioBuffer) {
    isPlayingAudio = false;
    isPlaying.value = false;
    return;
  }

  isPlayingAudio = true;
  isPlaying.value = true;

  try {
    const source = ctx.createBufferSource();
    source.buffer = audioBuffer;

    // Route through the analyser when present.
    if (analyser) {
      source.connect(analyser);
      analyser.connect(ctx.destination);
    } else {
      source.connect(ctx.destination);
    }

    // Register before start() so the handler is in place from the outset.
    source.onended = () => {
      // `ended` also fires for a source halted via stop(). If this source
      // is no longer the current one, it was stopped or superseded:
      // advancing the queue here would start a second, overlapping
      // playback chain.
      if (currentSource !== source) return;
      currentSource = null;
      playNextAudio();
    };

    currentSource = source;
    source.start();
  } catch (error) {
    console.error('[Live Mode] 播放音频失败:', error);
    // Do not keep a reference to a source that never started.
    currentSource = null;
    isPlayingAudio = false;
    isPlaying.value = false;
    playNextAudio(); // try the next buffer
  }
}
/**
 * Stop the audio that is currently playing and discard everything queued
 * (both undecoded chunks and decoded buffers).
 */
function stopAudioPlayback() {
  const source = currentSource;
  currentSource = null;

  if (source) {
    // Detach the handler first: stop() still dispatches `ended`, and the
    // handler would otherwise call playNextAudio() later, possibly while
    // newer audio is already playing.
    source.onended = null;
    try {
      source.stop();
    } catch (e) {
      // Source was not started or already stopped; nothing to do.
    }
    try {
      source.disconnect();
    } catch (e) {
      // Already disconnected; nothing to do.
    }
  }

  // Drop everything that has not been played yet.
  rawAudioQueue.length = 0;
  audioBufferQueue.length = 0;

  isPlayingAudio = false;
  isPlaying.value = false;
  // `isDecoding` is deliberately left alone: it is owned by
  // processRawAudioQueue(), which clears it in its `finally`. Clearing it
  // here while a decode is still awaited would let a second drain loop
  // start in parallel and push buffers out of order.
  // NOTE(review): a chunk whose decode is in flight at this moment is
  // still queued for playback once it resolves.
}
function generateStamp(): string {
@@ -415,6 +507,8 @@ watch(isSpeaking, (newVal) => {
if (ws && ws.readyState === WebSocket.OPEN) {
ws.send(JSON.stringify({ t: 'interrupt' }));
}
// 本地立即停止播放
stopAudioPlayback();
}
});