feat: enhance audio processing and metrics display in live mode
This commit is contained in:
+150
-127
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
from collections.abc import AsyncGenerator
|
||||
@@ -155,55 +156,85 @@ async def run_live_agent(
|
||||
Yields:
|
||||
MessageChain: 包含文本或音频数据的消息链
|
||||
"""
|
||||
support_stream = tts_provider.support_stream() if tts_provider else False
|
||||
|
||||
if support_stream:
|
||||
logger.info("[Live Agent] 使用流式 TTS(原生支持 get_audio_stream)")
|
||||
elif tts_provider:
|
||||
logger.info(
|
||||
f"[Live Agent] 使用 TTS({tts_provider.meta().type} "
|
||||
"使用 get_audio,将累积完整文本后生成音频)"
|
||||
)
|
||||
|
||||
# 收集 LLM 输出
|
||||
llm_stream_chunks: list[MessageChain] = []
|
||||
|
||||
# 运行普通 agent
|
||||
async for chain in run_agent(
|
||||
agent_runner,
|
||||
max_step=max_step,
|
||||
show_tool_use=show_tool_use,
|
||||
stream_to_general=False,
|
||||
show_reasoning=show_reasoning,
|
||||
):
|
||||
if chain is not None:
|
||||
llm_stream_chunks.append(chain)
|
||||
|
||||
# 如果没有 TTS Provider,直接发送文本
|
||||
if not tts_provider:
|
||||
for chain in llm_stream_chunks:
|
||||
async for chain in run_agent(
|
||||
agent_runner,
|
||||
max_step=max_step,
|
||||
show_tool_use=show_tool_use,
|
||||
stream_to_general=False,
|
||||
show_reasoning=show_reasoning,
|
||||
):
|
||||
yield chain
|
||||
return
|
||||
|
||||
# 处理 TTS
|
||||
support_stream = tts_provider.support_stream()
|
||||
if support_stream:
|
||||
logger.info("[Live Agent] 使用流式 TTS(原生支持 get_audio_stream)")
|
||||
else:
|
||||
logger.info(
|
||||
f"[Live Agent] 使用 TTS({tts_provider.meta().type} "
|
||||
"使用 get_audio,将按句子分块生成音频)"
|
||||
)
|
||||
|
||||
# 统计数据初始化
|
||||
tts_start_time = time.time()
|
||||
tts_first_frame_time = 0.0
|
||||
first_chunk_received = False
|
||||
|
||||
# 创建队列
|
||||
text_queue: asyncio.Queue[str | None] = asyncio.Queue()
|
||||
audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
|
||||
|
||||
# 1. 启动 Agent Feeder 任务:负责运行 Agent 并将文本分句喂给 text_queue
|
||||
feeder_task = asyncio.create_task(
|
||||
_run_agent_feeder(
|
||||
agent_runner, text_queue, max_step, show_tool_use, show_reasoning
|
||||
)
|
||||
)
|
||||
|
||||
# 2. 启动 TTS 任务:负责从 text_queue 读取文本并生成音频到 audio_queue
|
||||
if support_stream:
|
||||
# 使用流式 TTS
|
||||
async for audio_chunk in _process_stream_tts(llm_stream_chunks, tts_provider):
|
||||
if not first_chunk_received:
|
||||
tts_first_frame_time = time.time() - tts_start_time
|
||||
first_chunk_received = True
|
||||
yield audio_chunk
|
||||
tts_task = asyncio.create_task(
|
||||
_safe_tts_stream_wrapper(tts_provider, text_queue, audio_queue)
|
||||
)
|
||||
else:
|
||||
# 使用完整音频 TTS
|
||||
async for audio_chunk in _process_full_tts(llm_stream_chunks, tts_provider):
|
||||
tts_task = asyncio.create_task(
|
||||
_simulated_stream_tts(tts_provider, text_queue, audio_queue)
|
||||
)
|
||||
|
||||
# 3. 主循环:从 audio_queue 读取音频并 yield
|
||||
try:
|
||||
while True:
|
||||
audio_data = await audio_queue.get()
|
||||
|
||||
if audio_data is None:
|
||||
break
|
||||
|
||||
if not first_chunk_received:
|
||||
# 记录首帧延迟(从开始处理到收到第一个音频块)
|
||||
tts_first_frame_time = time.time() - tts_start_time
|
||||
first_chunk_received = True
|
||||
yield audio_chunk
|
||||
|
||||
# 将音频数据封装为 MessageChain
|
||||
import base64
|
||||
|
||||
audio_b64 = base64.b64encode(audio_data).decode("utf-8")
|
||||
chain = MessageChain(chain=[Plain(audio_b64)], type="audio_chunk")
|
||||
yield chain
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Live Agent] 运行时发生错误: {e}", exc_info=True)
|
||||
finally:
|
||||
# 清理任务
|
||||
if not feeder_task.done():
|
||||
feeder_task.cancel()
|
||||
if not tts_task.done():
|
||||
tts_task.cancel()
|
||||
|
||||
# 确保队列被消费
|
||||
pass
|
||||
|
||||
tts_end_time = time.time()
|
||||
|
||||
# 发送 TTS 统计信息
|
||||
@@ -228,113 +259,105 @@ async def run_live_agent(
|
||||
logger.error(f"发送 TTS 统计信息失败: {e}")
|
||||
|
||||
|
||||
async def _process_stream_tts(chunks: list[MessageChain], tts_provider):
|
||||
"""处理流式 TTS"""
|
||||
text_queue: asyncio.Queue[str | None] = asyncio.Queue()
|
||||
audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
|
||||
|
||||
# 启动 TTS 处理任务
|
||||
tts_task = asyncio.create_task(
|
||||
tts_provider.get_audio_stream(text_queue, audio_queue)
|
||||
)
|
||||
|
||||
chunk_size = 50 # 每 50 个字符发送一次给 TTS
|
||||
|
||||
try:
|
||||
# 喂文本给 TTS
|
||||
feed_task = asyncio.create_task(
|
||||
_feed_text_to_tts(chunks, text_queue, chunk_size)
|
||||
)
|
||||
|
||||
# 从 TTS 输出队列中读取音频数据
|
||||
while True:
|
||||
audio_data = await audio_queue.get()
|
||||
|
||||
if audio_data is None:
|
||||
break
|
||||
|
||||
# 将音频数据封装为 MessageChain
|
||||
import base64
|
||||
|
||||
audio_b64 = base64.b64encode(audio_data).decode("utf-8")
|
||||
|
||||
chain = MessageChain(chain=[Plain(audio_b64)], type="audio_chunk")
|
||||
yield chain
|
||||
|
||||
await feed_task
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Live TTS] 流式处理失败: {e}", exc_info=True)
|
||||
await text_queue.put(None)
|
||||
|
||||
finally:
|
||||
try:
|
||||
await asyncio.wait_for(tts_task, timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("[Live TTS] TTS 任务超时,强制取消")
|
||||
tts_task.cancel()
|
||||
|
||||
|
||||
async def _feed_text_to_tts(
|
||||
chunks: list[MessageChain], text_queue: asyncio.Queue, chunk_size: int
|
||||
async def _run_agent_feeder(
|
||||
agent_runner: AgentRunner,
|
||||
text_queue: asyncio.Queue,
|
||||
max_step: int,
|
||||
show_tool_use: bool,
|
||||
show_reasoning: bool,
|
||||
):
|
||||
"""从消息链中提取文本并分块发送给 TTS"""
|
||||
accumulated_text = ""
|
||||
|
||||
"""运行 Agent 并将文本输出分句放入队列"""
|
||||
buffer = ""
|
||||
try:
|
||||
for chain in chunks:
|
||||
text = chain.get_plain_text()
|
||||
if not text:
|
||||
async for chain in run_agent(
|
||||
agent_runner,
|
||||
max_step=max_step,
|
||||
show_tool_use=show_tool_use,
|
||||
stream_to_general=False,
|
||||
show_reasoning=show_reasoning,
|
||||
):
|
||||
if chain is None:
|
||||
continue
|
||||
|
||||
accumulated_text += text
|
||||
# 提取文本
|
||||
text = chain.get_plain_text()
|
||||
if text:
|
||||
buffer += text
|
||||
|
||||
# 当累积的文本达到chunk_size时,发送给TTS
|
||||
while len(accumulated_text) >= chunk_size:
|
||||
chunk = accumulated_text[:chunk_size]
|
||||
await text_queue.put(chunk)
|
||||
accumulated_text = accumulated_text[chunk_size:]
|
||||
# 分句逻辑:匹配标点符号
|
||||
# r"([.。!!??\n]+)" 会保留分隔符
|
||||
parts = re.split(r"([.。!!??\n]+)", buffer)
|
||||
|
||||
# 处理剩余文本
|
||||
if accumulated_text:
|
||||
await text_queue.put(accumulated_text)
|
||||
if len(parts) > 1:
|
||||
# 处理完整的句子
|
||||
# range step 2 因为 split 后是 [text, delim, text, delim, ...]
|
||||
temp_buffer = ""
|
||||
for i in range(0, len(parts) - 1, 2):
|
||||
sentence = parts[i]
|
||||
delim = parts[i + 1]
|
||||
full_sentence = sentence + delim
|
||||
temp_buffer += full_sentence
|
||||
|
||||
if len(temp_buffer) >= 10:
|
||||
if temp_buffer.strip():
|
||||
logger.info(f"[Live Agent Feeder] 分句: {temp_buffer}")
|
||||
await text_queue.put(temp_buffer)
|
||||
temp_buffer = ""
|
||||
|
||||
# 更新 buffer 为剩余部分
|
||||
buffer = temp_buffer + parts[-1]
|
||||
|
||||
# 处理剩余 buffer
|
||||
if buffer.strip():
|
||||
await text_queue.put(buffer)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Live Agent Feeder] Error: {e}", exc_info=True)
|
||||
finally:
|
||||
# 发送结束标记
|
||||
# 发送结束信号
|
||||
await text_queue.put(None)
|
||||
|
||||
|
||||
async def _process_full_tts(chunks: list[MessageChain], tts_provider):
|
||||
"""处理完整音频 TTS"""
|
||||
accumulated_text = ""
|
||||
|
||||
async def _safe_tts_stream_wrapper(
|
||||
tts_provider: TTSProvider,
|
||||
text_queue: asyncio.Queue[str | None],
|
||||
audio_queue: asyncio.Queue[bytes | None],
|
||||
):
|
||||
"""包装原生流式 TTS 确保异常处理和队列关闭"""
|
||||
try:
|
||||
# 累积所有文本
|
||||
for chain in chunks:
|
||||
text = chain.get_plain_text()
|
||||
if text:
|
||||
accumulated_text += text
|
||||
await tts_provider.get_audio_stream(text_queue, audio_queue)
|
||||
except Exception as e:
|
||||
logger.error(f"[Live TTS Stream] Error: {e}", exc_info=True)
|
||||
finally:
|
||||
await audio_queue.put(None)
|
||||
|
||||
# 如果没有文本,直接返回
|
||||
if not accumulated_text:
|
||||
return
|
||||
|
||||
logger.info(f"[Live TTS] 累积完整文本,长度: {len(accumulated_text)}")
|
||||
async def _simulated_stream_tts(
|
||||
tts_provider: TTSProvider,
|
||||
text_queue: asyncio.Queue[str | None],
|
||||
audio_queue: asyncio.Queue[bytes | None],
|
||||
):
|
||||
"""模拟流式 TTS 分句生成音频"""
|
||||
try:
|
||||
while True:
|
||||
text = await text_queue.get()
|
||||
if text is None:
|
||||
break
|
||||
|
||||
# 调用 get_audio 生成完整音频
|
||||
audio_path = await tts_provider.get_audio(accumulated_text)
|
||||
try:
|
||||
audio_path = await tts_provider.get_audio(text)
|
||||
|
||||
# 读取音频文件
|
||||
with open(audio_path, "rb") as f:
|
||||
audio_data = f.read()
|
||||
|
||||
# 将音频数据封装为 MessageChain
|
||||
import base64
|
||||
|
||||
audio_b64 = base64.b64encode(audio_data).decode("utf-8")
|
||||
|
||||
chain = MessageChain(chain=[Plain(audio_b64)], type="audio_chunk")
|
||||
yield chain
|
||||
if audio_path:
|
||||
with open(audio_path, "rb") as f:
|
||||
audio_data = f.read()
|
||||
await audio_queue.put(audio_data)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[Live TTS Simulated] Error processing text '{text[:20]}...': {e}"
|
||||
)
|
||||
# 继续处理下一句
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Live TTS] 完整音频生成失败: {e}", exc_info=True)
|
||||
logger.error(f"[Live TTS Simulated] Critical Error: {e}", exc_info=True)
|
||||
finally:
|
||||
await audio_queue.put(None)
|
||||
|
||||
@@ -21,13 +21,20 @@
|
||||
</div>
|
||||
|
||||
<div class="metrics-container" v-if="Object.keys(metrics).length > 0">
|
||||
<span v-if="metrics.wav_assemble_time">WAV Assemble: {{ (metrics.wav_assemble_time * 1000).toFixed(0) }}ms</span>
|
||||
<span v-if="metrics.llm_ttft">LLM First Token Latency: {{ (metrics.llm_ttft * 1000).toFixed(0) }}ms</span>
|
||||
<span v-if="metrics.llm_total_time">LLM Total Latency: {{ (metrics.llm_total_time * 1000).toFixed(0) }}ms</span>
|
||||
<span v-if="metrics.tts_first_frame_time">TTS First Frame Latency: {{ (metrics.tts_first_frame_time * 1000).toFixed(0) }}ms</span>
|
||||
<span v-if="metrics.tts_total_time">TTS Total Larency: {{ (metrics.tts_total_time * 1000).toFixed(0) }}ms</span>
|
||||
<span v-if="metrics.speak_to_first_frame">Speak -> First TTS Frame: {{ (metrics.speak_to_first_frame * 1000).toFixed(0) }}ms</span>
|
||||
<span v-if="metrics.wav_to_tts_total_time">Speak -> End: {{ (metrics.wav_to_tts_total_time * 1000).toFixed(0) }}ms</span>
|
||||
<span v-if="metrics.wav_assemble_time">WAV Assemble: {{ (metrics.wav_assemble_time * 1000).toFixed(0)
|
||||
}}ms</span>
|
||||
<span v-if="metrics.llm_ttft">LLM First Token Latency: {{ (metrics.llm_ttft * 1000).toFixed(0)
|
||||
}}ms</span>
|
||||
<span v-if="metrics.llm_total_time">LLM Total Latency: {{ (metrics.llm_total_time * 1000).toFixed(0)
|
||||
}}ms</span>
|
||||
<span v-if="metrics.tts_first_frame_time">TTS First Frame Latency: {{ (metrics.tts_first_frame_time *
|
||||
1000).toFixed(0) }}ms</span>
|
||||
<span v-if="metrics.tts_total_time">TTS Total Larency: {{ (metrics.tts_total_time * 1000).toFixed(0)
|
||||
}}ms</span>
|
||||
<span v-if="metrics.speak_to_first_frame">Speak -> First TTS Frame: {{ (metrics.speak_to_first_frame *
|
||||
1000).toFixed(0) }}ms</span>
|
||||
<span v-if="metrics.wav_to_tts_total_time">Speak -> End: {{ (metrics.wav_to_tts_total_time *
|
||||
1000).toFixed(0) }}ms</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -65,7 +72,15 @@ let audioContext: AudioContext | null = null;
|
||||
let analyser: AnalyserNode | null = null;
|
||||
const botEnergy = ref(0);
|
||||
let energyLoopId: number;
|
||||
let isPlaying = ref(false);
|
||||
let isPlaying = ref(false); // UI 状态:是否正在播放
|
||||
|
||||
// 音频播放队列管理
|
||||
const rawAudioQueue: Uint8Array[] = []; // 待解码队列
|
||||
const audioBufferQueue: AudioBuffer[] = []; // 待播放队列
|
||||
let isDecoding = false;
|
||||
let isPlayingAudio = false; // 内部状态:是否正在播放音频
|
||||
let currentSource: AudioBufferSourceNode | null = null;
|
||||
|
||||
|
||||
// 消息历史
|
||||
const messages = ref<Array<{ type: 'user' | 'bot', text: string }>>([]);
|
||||
@@ -324,7 +339,7 @@ function handleWebSocketMessage(event: MessageEvent) {
|
||||
isProcessing.value = false;
|
||||
isListening.value = true;
|
||||
break;
|
||||
|
||||
|
||||
case 'metrics':
|
||||
metrics.value = { ...metrics.value, ...message.data };
|
||||
break;
|
||||
@@ -345,35 +360,112 @@ function playAudioChunk(base64Data: string) {
|
||||
bytes[i] = binaryString.charCodeAt(i);
|
||||
}
|
||||
|
||||
// 解码 WAV 音频
|
||||
audioContext.decodeAudioData(bytes.buffer).then(audioBuffer => {
|
||||
const source = audioContext!.createBufferSource();
|
||||
source.buffer = audioBuffer;
|
||||
// 连接到分析器
|
||||
if (analyser) {
|
||||
source.connect(analyser);
|
||||
analyser.connect(audioContext!.destination);
|
||||
} else {
|
||||
source.connect(audioContext!.destination);
|
||||
}
|
||||
source.start();
|
||||
isPlaying.value = true;
|
||||
// 放入待解码队列
|
||||
rawAudioQueue.push(bytes);
|
||||
|
||||
source.onended = () => {
|
||||
isPlaying.value = false;
|
||||
};
|
||||
}).catch(error => {
|
||||
console.error('[Live Mode] 解码音频失败:', error);
|
||||
});
|
||||
// 触发解码处理
|
||||
processRawAudioQueue();
|
||||
|
||||
} catch (error) {
|
||||
console.error('[Live Mode] 接收音频数据失败:', error);
|
||||
}
|
||||
}
|
||||
|
||||
async function processRawAudioQueue() {
|
||||
if (isDecoding || rawAudioQueue.length === 0) return;
|
||||
|
||||
isDecoding = true;
|
||||
|
||||
try {
|
||||
while (rawAudioQueue.length > 0) {
|
||||
const bytes = rawAudioQueue.shift();
|
||||
if (!bytes || !audioContext) continue;
|
||||
|
||||
try {
|
||||
// 解码
|
||||
const audioBuffer = await audioContext.decodeAudioData(bytes.buffer as ArrayBuffer);
|
||||
audioBufferQueue.push(audioBuffer);
|
||||
|
||||
// 如果当前没有播放,立即开始播放
|
||||
if (!isPlayingAudio) {
|
||||
playNextAudio();
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('[Live Mode] 解码音频失败:', err);
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
isDecoding = false;
|
||||
// 如果在解码过程中又有新数据进来,继续处理
|
||||
if (rawAudioQueue.length > 0) {
|
||||
processRawAudioQueue();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function playNextAudio() {
|
||||
if (audioBufferQueue.length === 0) {
|
||||
isPlayingAudio = false;
|
||||
isPlaying.value = false;
|
||||
return;
|
||||
}
|
||||
|
||||
if (!audioContext) return;
|
||||
|
||||
isPlayingAudio = true;
|
||||
isPlaying.value = true;
|
||||
|
||||
try {
|
||||
const audioBuffer = audioBufferQueue.shift();
|
||||
if (!audioBuffer) return;
|
||||
|
||||
const source = audioContext.createBufferSource();
|
||||
source.buffer = audioBuffer;
|
||||
|
||||
// 连接到分析器
|
||||
if (analyser) {
|
||||
source.connect(analyser);
|
||||
analyser.connect(audioContext.destination);
|
||||
} else {
|
||||
source.connect(audioContext.destination);
|
||||
}
|
||||
|
||||
currentSource = source;
|
||||
source.start();
|
||||
|
||||
source.onended = () => {
|
||||
currentSource = null;
|
||||
playNextAudio();
|
||||
};
|
||||
|
||||
} catch (error) {
|
||||
console.error('[Live Mode] 播放音频失败:', error);
|
||||
isPlayingAudio = false;
|
||||
isPlaying.value = false;
|
||||
playNextAudio(); // 尝试播放下一个
|
||||
}
|
||||
}
|
||||
|
||||
function stopAudioPlayback() {
|
||||
// TODO: 实现停止当前播放的音频
|
||||
// 停止当前播放源
|
||||
if (currentSource) {
|
||||
try {
|
||||
currentSource.stop();
|
||||
currentSource.disconnect();
|
||||
} catch (e) {
|
||||
// ignore
|
||||
}
|
||||
currentSource = null;
|
||||
}
|
||||
|
||||
// 清空队列
|
||||
rawAudioQueue.length = 0;
|
||||
audioBufferQueue.length = 0;
|
||||
|
||||
// 重置状态
|
||||
isPlayingAudio = false;
|
||||
isPlaying.value = false;
|
||||
isDecoding = false;
|
||||
}
|
||||
|
||||
function generateStamp(): string {
|
||||
@@ -415,6 +507,8 @@ watch(isSpeaking, (newVal) => {
|
||||
if (ws && ws.readyState === WebSocket.OPEN) {
|
||||
ws.send(JSON.stringify({ t: 'interrupt' }));
|
||||
}
|
||||
// 本地立即停止播放
|
||||
stopAudioPlayback();
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user