Files
AstrBot/astrbot/dashboard/routes/live_chat.py
T
Soulter 9d93bda3fe feat: temporary file handling and introduce TempDirCleaner (#5026)
* feat: temporary file handling and introduce TempDirCleaner

- Updated various modules to use `get_astrbot_temp_path()` instead of `get_astrbot_data_path()` for temporary file storage.
- Renamed temporary files for better identification and organization.
- Introduced `TempDirCleaner` to manage the size of the temporary directory, ensuring it does not exceed a specified limit by deleting the oldest files.
- Added configuration option for maximum temporary directory size in the dashboard.
- Implemented tests for `TempDirCleaner` to verify cleanup functionality and size management.

* ruff
2026-02-12 01:04:48 +08:00

429 lines
16 KiB
Python

import asyncio
import json
import os
import time
import uuid
import wave
from typing import Any
import jwt
from quart import websocket
from astrbot import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from .route import Route, RouteContext
class LiveChatSession:
"""Live Chat 会话管理器"""
def __init__(self, session_id: str, username: str) -> None:
self.session_id = session_id
self.username = username
self.conversation_id = str(uuid.uuid4())
self.is_speaking = False
self.is_processing = False
self.should_interrupt = False
self.audio_frames: list[bytes] = []
self.current_stamp: str | None = None
self.temp_audio_path: str | None = None
def start_speaking(self, stamp: str) -> None:
"""开始说话"""
self.is_speaking = True
self.current_stamp = stamp
self.audio_frames = []
logger.debug(f"[Live Chat] {self.username} 开始说话 stamp={stamp}")
def add_audio_frame(self, data: bytes) -> None:
"""添加音频帧"""
if self.is_speaking:
self.audio_frames.append(data)
async def end_speaking(self, stamp: str) -> tuple[str | None, float]:
"""结束说话,返回组装的 WAV 文件路径和耗时"""
start_time = time.time()
if not self.is_speaking or stamp != self.current_stamp:
logger.warning(
f"[Live Chat] stamp 不匹配或未在说话状态: {stamp} vs {self.current_stamp}"
)
return None, 0.0
self.is_speaking = False
if not self.audio_frames:
logger.warning("[Live Chat] 没有音频帧数据")
return None, 0.0
# 组装 WAV 文件
try:
temp_dir = get_astrbot_temp_path()
os.makedirs(temp_dir, exist_ok=True)
audio_path = os.path.join(temp_dir, f"live_audio_{uuid.uuid4()}.wav")
# 假设前端发送的是 PCM 数据,采样率 16000Hz,单声道,16位
with wave.open(audio_path, "wb") as wav_file:
wav_file.setnchannels(1) # 单声道
wav_file.setsampwidth(2) # 16位 = 2字节
wav_file.setframerate(16000) # 采样率 16000Hz
for frame in self.audio_frames:
wav_file.writeframes(frame)
self.temp_audio_path = audio_path
logger.info(
f"[Live Chat] 音频文件已保存: {audio_path}, 大小: {os.path.getsize(audio_path)} bytes"
)
return audio_path, time.time() - start_time
except Exception as e:
logger.error(f"[Live Chat] 组装 WAV 文件失败: {e}", exc_info=True)
return None, 0.0
def cleanup(self) -> None:
"""清理临时文件"""
if self.temp_audio_path and os.path.exists(self.temp_audio_path):
try:
os.remove(self.temp_audio_path)
logger.debug(f"[Live Chat] 已删除临时文件: {self.temp_audio_path}")
except Exception as e:
logger.warning(f"[Live Chat] 删除临时文件失败: {e}")
self.temp_audio_path = None
class LiveChatRoute(Route):
    """Live Chat WebSocket route.

    Authenticates the caller from a JWT passed as a query parameter, receives
    base64-encoded audio over the socket and, for each finished utterance,
    runs STT, pushes the recognised text onto the webchat queue, and relays
    the reply (text, audio chunks and timing metrics) back to the client.
    """

    def __init__(
        self,
        context: RouteContext,
        db: Any,
        core_lifecycle: AstrBotCoreLifecycle,
    ) -> None:
        """Store collaborators and register the WebSocket endpoint."""
        super().__init__(context)
        self.core_lifecycle = core_lifecycle
        self.db = db
        self.plugin_manager = core_lifecycle.plugin_manager
        # Active sessions keyed by session id.
        self.sessions: dict[str, LiveChatSession] = {}
        # Register the WebSocket route.
        self.app.websocket("/api/live_chat/ws")(self.live_chat_ws)

    async def live_chat_ws(self) -> None:
        """Handle one Live Chat WebSocket connection from auth to teardown."""
        # The token cannot be passed via a header for a WebSocket, so it is
        # read from the query string. Note: the WebSocket context exposes
        # websocket.args rather than request.args.
        token = websocket.args.get("token")
        if not token:
            await websocket.close(1008, "Missing authentication token")
            return

        jwt_secret = self.config["dashboard"].get("jwt_secret")
        try:
            payload = jwt.decode(token, jwt_secret, algorithms=["HS256"])
        except jwt.ExpiredSignatureError:
            await websocket.close(1008, "Token expired")
            return
        except jwt.InvalidTokenError:
            await websocket.close(1008, "Invalid token")
            return

        username = payload.get("username")
        if username is None:
            # A correctly signed token without a username claim is rejected
            # with a policy close code instead of an unhandled KeyError.
            await websocket.close(1008, "Invalid token")
            return

        session_id = f"webchat_live!{username}!{uuid.uuid4()}"
        live_session = LiveChatSession(session_id, username)
        self.sessions[session_id] = live_session
        logger.info(f"[Live Chat] WebSocket 连接建立: {username}")

        try:
            # NOTE(review): messages are handled one at a time, so an
            # "interrupt" sent while _process_audio is running is not read
            # until that call returns — confirm whether processing should
            # run concurrently with the receive loop.
            while True:
                message = await websocket.receive_json()
                await self._handle_message(live_session, message)
        except Exception as e:
            logger.error(f"[Live Chat] WebSocket 错误: {e}", exc_info=True)
        finally:
            # Tear down the session and its temporary file.
            if session_id in self.sessions:
                live_session.cleanup()
                del self.sessions[session_id]
            logger.info(f"[Live Chat] WebSocket 连接关闭: {username}")

    async def _handle_message(self, session: LiveChatSession, message: dict) -> None:
        """Dispatch one decoded WebSocket message according to its ``t`` field.

        Recognised types are ``start_speaking``, ``speaking_part``,
        ``end_speaking`` and ``interrupt``; anything else is ignored.
        """
        if not isinstance(message, dict):
            # Valid JSON that is not an object (list, string, number) would
            # otherwise raise on .get() and tear down the whole connection.
            logger.warning("[Live Chat] 收到非 JSON 对象消息,已忽略")
            return

        msg_type = message.get("t")  # "t" is used in place of "type"

        if msg_type == "start_speaking":
            stamp = message.get("stamp")
            if not stamp:
                logger.warning("[Live Chat] start_speaking 缺少 stamp")
                return
            session.start_speaking(stamp)

        elif msg_type == "speaking_part":
            # One audio fragment, base64-encoded.
            import base64

            audio_data_b64 = message.get("data")
            if not audio_data_b64:
                return
            try:
                audio_data = base64.b64decode(audio_data_b64)
            except (ValueError, TypeError) as e:
                # binascii.Error is a ValueError; TypeError covers payloads
                # that are not str/bytes.
                logger.error(f"[Live Chat] 解码音频数据失败: {e}")
            else:
                session.add_audio_frame(audio_data)

        elif msg_type == "end_speaking":
            stamp = message.get("stamp")
            if not stamp:
                logger.warning("[Live Chat] end_speaking 缺少 stamp")
                return
            audio_path, assemble_duration = await session.end_speaking(stamp)
            if not audio_path:
                await websocket.send_json({"t": "error", "data": "音频组装失败"})
                return
            # Process the audio: STT -> LLM -> TTS.
            await self._process_audio(session, audio_path, assemble_duration)

        elif msg_type == "interrupt":
            # The user interrupted playback.
            session.should_interrupt = True
            logger.info(f"[Live Chat] 用户打断: {session.username}")

    async def _process_audio(
        self, session: LiveChatSession, audio_path: str, assemble_duration: float
    ) -> None:
        """Run one utterance through STT, the webchat queue and reply streaming.

        Sends metrics, the recognised user text, reply text/audio chunks and
        a final ``end`` marker over the WebSocket. Any exception is logged
        and reported to the client as an ``error`` message; the session's
        processing and interrupt flags are always reset on exit.
        """
        try:
            # Report how long WAV assembly took.
            await websocket.send_json(
                {"t": "metrics", "data": {"wav_assemble_time": assemble_duration}}
            )
            wav_assembly_finish_time = time.time()
            session.is_processing = True
            session.should_interrupt = False

            # 1. STT - speech to text.
            ctx = self.plugin_manager.context
            # Guard the lookup: indexing an empty list would raise IndexError
            # and bypass the "not configured" branch below.
            stt_insts = ctx.provider_manager.stt_provider_insts
            stt_provider = stt_insts[0] if stt_insts else None
            if not stt_provider:
                logger.error("[Live Chat] STT Provider 未配置")
                await websocket.send_json({"t": "error", "data": "语音识别服务未配置"})
                return

            await websocket.send_json(
                {"t": "metrics", "data": {"stt": stt_provider.meta().type}}
            )
            user_text = await stt_provider.get_text(audio_path)
            if not user_text:
                logger.warning("[Live Chat] STT 识别结果为空")
                return
            logger.info(f"[Live Chat] STT 结果: {user_text}")
            await websocket.send_json(
                {
                    "t": "user_msg",
                    "data": {"text": user_text, "ts": int(time.time() * 1000)},
                }
            )

            # 2. Build the message event and hand it to the pipeline through
            #    the webchat queue mechanism.
            cid = session.conversation_id
            queue = webchat_queue_mgr.get_or_create_queue(cid)
            message_id = str(uuid.uuid4())
            payload = {
                "message_id": message_id,
                "message": [{"type": "plain", "text": user_text}],  # plain text
                "action_type": "live",  # marks live mode
            }
            await queue.put((session.username, cid, payload))

            # 3. Wait for the response and stream TTS audio to the client.
            back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id, cid)
            bot_text = ""
            audio_playing = False
            try:
                while True:
                    if session.should_interrupt:
                        # The user interrupted: stop processing.
                        logger.info("[Live Chat] 检测到用户打断")
                        await websocket.send_json({"t": "stop_play"})
                        # Record the exchange, marked as interrupted.
                        await self._save_interrupted_message(
                            session, user_text, bot_text
                        )
                        # Discard whatever is still pending in the queue.
                        while not back_queue.empty():
                            try:
                                back_queue.get_nowait()
                            except asyncio.QueueEmpty:
                                break
                        break

                    # Poll with a short timeout so the interrupt flag is
                    # re-checked regularly.
                    try:
                        result = await asyncio.wait_for(back_queue.get(), timeout=0.5)
                    except asyncio.TimeoutError:
                        continue
                    if not result:
                        continue

                    result_message_id = result.get("message_id")
                    if result_message_id != message_id:
                        logger.warning(
                            f"[Live Chat] 消息 ID 不匹配: {result_message_id} != {message_id}"
                        )
                        continue

                    result_type = result.get("type")
                    result_chain_type = result.get("chain_type")
                    data = result.get("data", "")

                    if result_chain_type == "agent_stats":
                        try:
                            stats = json.loads(data)
                            await websocket.send_json(
                                {
                                    "t": "metrics",
                                    "data": {
                                        "llm_ttft": stats.get("time_to_first_token", 0),
                                        "llm_total_time": stats.get("end_time", 0)
                                        - stats.get("start_time", 0),
                                    },
                                }
                            )
                        except Exception as e:
                            logger.error(f"[Live Chat] 解析 AgentStats 失败: {e}")
                        continue

                    if result_chain_type == "tts_stats":
                        try:
                            stats = json.loads(data)
                            await websocket.send_json(
                                {
                                    "t": "metrics",
                                    "data": stats,
                                }
                            )
                        except Exception as e:
                            logger.error(f"[Live Chat] 解析 TTSStats 失败: {e}")
                        continue

                    if result_type == "plain":
                        # Plain text: accumulate the reply.
                        bot_text += data

                    elif result_type == "audio_chunk":
                        # Streamed audio data.
                        if not audio_playing:
                            audio_playing = True
                            logger.debug("[Live Chat] 开始播放音频流")
                            # Latency from WAV assembly finish to the first
                            # audio chunk.
                            speak_to_first_frame_latency = (
                                time.time() - wav_assembly_finish_time
                            )
                            await websocket.send_json(
                                {
                                    "t": "metrics",
                                    "data": {
                                        "speak_to_first_frame": speak_to_first_frame_latency
                                    },
                                }
                            )
                        text = result.get("text")
                        if text:
                            await websocket.send_json(
                                {
                                    "t": "bot_text_chunk",
                                    "data": {"text": text},
                                }
                            )
                        # Forward the audio data to the frontend.
                        await websocket.send_json(
                            {
                                "t": "response",
                                "data": data,  # base64-encoded audio data
                            }
                        )

                    elif result_type in ["complete", "end"]:
                        # Processing finished.
                        logger.info(f"[Live Chat] Bot 回复完成: {bot_text}")
                        # Without an audio stream, send the reply as text.
                        if not audio_playing:
                            await websocket.send_json(
                                {
                                    "t": "bot_msg",
                                    "data": {
                                        "text": bot_text,
                                        "ts": int(time.time() * 1000),
                                    },
                                }
                            )
                        # End marker.
                        await websocket.send_json({"t": "end"})
                        # Total elapsed time since WAV assembly finished.
                        wav_to_tts_duration = time.time() - wav_assembly_finish_time
                        await websocket.send_json(
                            {
                                "t": "metrics",
                                "data": {"wav_to_tts_total_time": wav_to_tts_duration},
                            }
                        )
                        break
            finally:
                webchat_queue_mgr.remove_back_queue(message_id)
        except Exception as e:
            logger.error(f"[Live Chat] 处理音频失败: {e}", exc_info=True)
            await websocket.send_json({"t": "error", "data": f"处理失败: {e}"})
        finally:
            session.is_processing = False
            session.should_interrupt = False

    async def _save_interrupted_message(
        self, session: LiveChatSession, user_text: str, bot_text: str
    ) -> None:
        """Record an interrupted exchange.

        Currently this only writes log entries; real persistence can be
        added later.
        """
        interrupted_text = bot_text + " [用户打断]"
        logger.info(f"[Live Chat] 保存打断消息: {interrupted_text}")
        try:
            timestamp = int(time.time() * 1000)
            logger.info(
                f"[Live Chat] 用户消息: {user_text} (session: {session.session_id}, ts: {timestamp})"
            )
            if bot_text:
                logger.info(
                    f"[Live Chat] Bot 消息(打断): {interrupted_text} (session: {session.session_id}, ts: {timestamp})"
                )
        except Exception as e:
            logger.error(f"[Live Chat] 记录消息失败: {e}", exc_info=True)