Compare commits

...

4 Commits

Author SHA1 Message Date
Soulter 6b4498a554 chore: bump version to 4.14.6 2026-02-07 23:48:42 +08:00
Soulter 5e5207da95 perf: optimize webchat and wecom ai queue lifecycle (#4941)
* perf: optimize webchat and wecom ai queue lifecycle

* perf: enhance webchat back queue management with conversation ID support
2026-02-07 14:03:33 +08:00
Soulter def8b730b7 fix: correct spelling of 'temporary' in SharedPreferences class 2026-02-07 14:01:08 +08:00
Soulter 22a109c2ae feat: implement feishu / lark media file handling utilities for file, audio and video processing (#4938)
* feat: implement media file handling utilities for audio and video processing

* feat: refactor file upload handling for audio and video in LarkMessageEvent

* feat: add cleanup for failed audio and video conversion outputs in media_utils

* feat: add utility methods for sending messages and uploading files in LarkMessageEvent
2026-02-07 12:40:05 +08:00
15 changed files with 1050 additions and 289 deletions
+1 -1
View File
@@ -1 +1 @@
__version__ = "4.14.5"
__version__ = "4.14.6"
+1 -1
View File
@@ -5,7 +5,7 @@ from typing import Any, TypedDict
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.14.5"
VERSION = "4.14.6"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
WEBHOOK_SUPPORTED_PLATFORMS = [
@@ -3,13 +3,10 @@ import base64
import json
import re
import time
import uuid
from typing import Any, cast
import lark_oapi as lark
from lark_oapi.api.im.v1 import (
CreateMessageRequest,
CreateMessageRequestBody,
GetMessageResourceRequest,
)
from lark_oapi.api.im.v1.processor import P2ImMessageReceiveV1Processor
@@ -125,44 +122,23 @@ class LarkPlatformAdapter(Platform):
session: MessageSesion,
message_chain: MessageChain,
):
if self.lark_api.im is None:
logger.error("[Lark] API Client im 模块未初始化,无法发送消息")
return
res = await LarkMessageEvent._convert_to_lark(message_chain, self.lark_api)
wrapped = {
"zh_cn": {
"title": "",
"content": res,
},
}
if session.message_type == MessageType.GROUP_MESSAGE:
id_type = "chat_id"
if "%" in session.session_id:
session.session_id = session.session_id.split("%")[1]
receive_id = session.session_id
if "%" in receive_id:
receive_id = receive_id.split("%")[1]
else:
id_type = "open_id"
receive_id = session.session_id
request = (
CreateMessageRequest.builder()
.receive_id_type(id_type)
.request_body(
CreateMessageRequestBody.builder()
.receive_id(session.session_id)
.content(json.dumps(wrapped))
.msg_type("post")
.uuid(str(uuid.uuid4()))
.build(),
)
.build()
# 复用 LarkMessageEvent 中的通用发送逻辑
await LarkMessageEvent.send_message_chain(
message_chain,
self.lark_api,
receive_id=receive_id,
receive_id_type=id_type,
)
response = await self.lark_api.im.v1.message.acreate(request)
if not response.success():
logger.error(f"发送飞书消息失败({response.code}): {response.msg}")
await super().send_by_session(session, message_chain)
def meta(self) -> PlatformMetadata:
+415 -28
View File
@@ -6,6 +6,8 @@ from io import BytesIO
import lark_oapi as lark
from lark_oapi.api.im.v1 import (
CreateFileRequest,
CreateFileRequestBody,
CreateImageRequest,
CreateImageRequestBody,
CreateMessageReactionRequest,
@@ -17,10 +19,15 @@ from lark_oapi.api.im.v1 import (
from astrbot import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import At, Plain
from astrbot.api.message_components import At, File, Plain, Record, Video
from astrbot.api.message_components import Image as AstrBotImage
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_image_by_url
from astrbot.core.utils.media_utils import (
convert_audio_to_opus,
convert_video_format,
get_media_duration,
)
class LarkMessageEvent(AstrMessageEvent):
@@ -35,6 +42,144 @@ class LarkMessageEvent(AstrMessageEvent):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.bot = bot
@staticmethod
async def _send_im_message(
lark_client: lark.Client,
*,
content: str,
msg_type: str,
reply_message_id: str | None = None,
receive_id: str | None = None,
receive_id_type: str | None = None,
) -> bool:
"""发送飞书 IM 消息的通用辅助函数
Args:
lark_client: 飞书客户端
content: 消息内容(JSON字符串)
msg_type: 消息类型(post/file/audio/media等)
reply_message_id: 回复的消息ID(用于回复消息)
receive_id: 接收者ID(用于主动发送)
receive_id_type: 接收者ID类型(用于主动发送)
Returns:
是否发送成功
"""
if lark_client.im is None:
logger.error("[Lark] API Client im 模块未初始化")
return False
if reply_message_id:
request = (
ReplyMessageRequest.builder()
.message_id(reply_message_id)
.request_body(
ReplyMessageRequestBody.builder()
.content(content)
.msg_type(msg_type)
.uuid(str(uuid.uuid4()))
.reply_in_thread(False)
.build()
)
.build()
)
response = await lark_client.im.v1.message.areply(request)
else:
from lark_oapi.api.im.v1 import (
CreateMessageRequest,
CreateMessageRequestBody,
)
if receive_id_type is None or receive_id is None:
logger.error(
"[Lark] 主动发送消息时,receive_id 和 receive_id_type 不能为空",
)
return False
request = (
CreateMessageRequest.builder()
.receive_id_type(receive_id_type)
.request_body(
CreateMessageRequestBody.builder()
.receive_id(receive_id)
.content(content)
.msg_type(msg_type)
.uuid(str(uuid.uuid4()))
.build()
)
.build()
)
response = await lark_client.im.v1.message.acreate(request)
if not response.success():
logger.error(f"[Lark] 发送飞书消息失败({response.code}): {response.msg}")
return False
return True
@staticmethod
async def _upload_lark_file(
lark_client: lark.Client,
*,
path: str,
file_type: str,
duration: int | None = None,
) -> str | None:
"""上传文件到飞书的通用辅助函数
Args:
lark_client: 飞书客户端
path: 文件路径
file_type: 文件类型(stream/opus/mp4等)
duration: 媒体时长(毫秒),可选
Returns:
成功返回file_key,失败返回None
"""
if not path or not os.path.exists(path):
logger.error(f"[Lark] 文件不存在: {path}")
return None
if lark_client.im is None:
logger.error("[Lark] API Client im 模块未初始化,无法上传文件")
return None
try:
with open(path, "rb") as file_obj:
body_builder = (
CreateFileRequestBody.builder()
.file_type(file_type)
.file_name(os.path.basename(path))
.file(file_obj)
)
if duration is not None:
body_builder.duration(duration)
request = (
CreateFileRequest.builder()
.request_body(body_builder.build())
.build()
)
response = await lark_client.im.v1.file.acreate(request)
if not response.success():
logger.error(
f"[Lark] 无法上传文件({response.code}): {response.msg}"
)
return None
if response.data is None:
logger.error("[Lark] 上传文件成功但未返回数据(data is None)")
return None
file_key = response.data.file_key
logger.debug(f"[Lark] 文件上传成功: {file_key}")
return file_key
except Exception as e:
logger.error(f"[Lark] 无法打开或上传文件: {e}")
return None
@staticmethod
async def _convert_to_lark(message: MessageChain, lark_client: lark.Client) -> list:
ret = []
@@ -103,6 +248,18 @@ class LarkMessageEvent(AstrMessageEvent):
ret.append(_stage)
ret.append([{"tag": "img", "image_key": image_key}])
_stage.clear()
elif isinstance(comp, File):
# 文件将通过 _send_file_message 方法单独发送,这里跳过
logger.debug("[Lark] 检测到文件组件,将单独发送")
continue
elif isinstance(comp, Record):
# 音频将通过 _send_audio_message 方法单独发送,这里跳过
logger.debug("[Lark] 检测到音频组件,将单独发送")
continue
elif isinstance(comp, Video):
# 视频将通过 _send_media_message 方法单独发送,这里跳过
logger.debug("[Lark] 检测到视频组件,将单独发送")
continue
else:
logger.warning(f"飞书 暂时不支持消息段: {comp.type}")
@@ -110,40 +267,270 @@ class LarkMessageEvent(AstrMessageEvent):
ret.append(_stage)
return ret
async def send(self, message: MessageChain):
res = await LarkMessageEvent._convert_to_lark(message, self.bot)
wrapped = {
"zh_cn": {
"title": "",
"content": res,
},
}
@staticmethod
async def send_message_chain(
message_chain: MessageChain,
lark_client: lark.Client,
reply_message_id: str | None = None,
receive_id: str | None = None,
receive_id_type: str | None = None,
):
"""通用的消息链发送方法
request = (
ReplyMessageRequest.builder()
.message_id(self.message_obj.message_id)
.request_body(
ReplyMessageRequestBody.builder()
.content(json.dumps(wrapped))
.msg_type("post")
.uuid(str(uuid.uuid4()))
.reply_in_thread(False)
.build(),
)
.build()
)
if self.bot.im is None:
logger.error("[Lark] API Client im 模块未初始化,无法回复消息")
Args:
message_chain: 要发送的消息链
lark_client: 飞书客户端
reply_message_id: 回复的消息ID(用于回复消息)
receive_id: 接收者ID(用于主动发送)
receive_id_type: 接收者ID类型,如 'open_id', 'chat_id'(用于主动发送)
"""
if lark_client.im is None:
logger.error("[Lark] API Client im 模块未初始化")
return
response = await self.bot.im.v1.message.areply(request)
# 分离文件、音频、视频组件和其他组件
file_components: list[File] = []
audio_components: list[Record] = []
media_components: list[Video] = []
other_components = []
if not response.success():
logger.error(f"回复飞书消息失败({response.code}): {response.msg}")
for comp in message_chain.chain:
if isinstance(comp, File):
file_components.append(comp)
elif isinstance(comp, Record):
audio_components.append(comp)
elif isinstance(comp, Video):
media_components.append(comp)
else:
other_components.append(comp)
# 先发送非文件内容(如果有)
if other_components:
temp_chain = MessageChain()
temp_chain.chain = other_components
res = await LarkMessageEvent._convert_to_lark(temp_chain, lark_client)
if res: # 只在有内容时发送
wrapped = {
"zh_cn": {
"title": "",
"content": res,
},
}
await LarkMessageEvent._send_im_message(
lark_client,
content=json.dumps(wrapped),
msg_type="post",
reply_message_id=reply_message_id,
receive_id=receive_id,
receive_id_type=receive_id_type,
)
# 发送附件
for file_comp in file_components:
await LarkMessageEvent._send_file_message(
file_comp, lark_client, reply_message_id, receive_id, receive_id_type
)
for audio_comp in audio_components:
await LarkMessageEvent._send_audio_message(
audio_comp, lark_client, reply_message_id, receive_id, receive_id_type
)
for media_comp in media_components:
await LarkMessageEvent._send_media_message(
media_comp, lark_client, reply_message_id, receive_id, receive_id_type
)
async def send(self, message: MessageChain):
"""发送消息链到飞书,然后交给父类做框架级发送/记录"""
await LarkMessageEvent.send_message_chain(
message,
self.bot,
reply_message_id=self.message_obj.message_id,
)
await super().send(message)
@staticmethod
async def _send_file_message(
file_comp: File,
lark_client: lark.Client,
reply_message_id: str | None = None,
receive_id: str | None = None,
receive_id_type: str | None = None,
):
"""发送文件消息
Args:
file_comp: 文件组件
lark_client: 飞书客户端
reply_message_id: 回复的消息ID(用于回复消息)
receive_id: 接收者ID(用于主动发送)
receive_id_type: 接收者ID类型(用于主动发送)
"""
file_path = file_comp.file or ""
file_key = await LarkMessageEvent._upload_lark_file(
lark_client, path=file_path, file_type="stream"
)
if not file_key:
return
content = json.dumps({"file_key": file_key})
await LarkMessageEvent._send_im_message(
lark_client,
content=content,
msg_type="file",
reply_message_id=reply_message_id,
receive_id=receive_id,
receive_id_type=receive_id_type,
)
@staticmethod
async def _send_audio_message(
audio_comp: Record,
lark_client: lark.Client,
reply_message_id: str | None = None,
receive_id: str | None = None,
receive_id_type: str | None = None,
):
"""发送音频消息
Args:
audio_comp: 音频组件
lark_client: 飞书客户端
reply_message_id: 回复的消息ID(用于回复消息)
receive_id: 接收者ID(用于主动发送)
receive_id_type: 接收者ID类型(用于主动发送)
"""
# 获取音频文件路径
try:
original_audio_path = await audio_comp.convert_to_file_path()
except Exception as e:
logger.error(f"[Lark] 无法获取音频文件路径: {e}")
return
if not original_audio_path or not os.path.exists(original_audio_path):
logger.error(f"[Lark] 音频文件不存在: {original_audio_path}")
return
# 转换为opus格式
converted_audio_path = None
try:
audio_path = await convert_audio_to_opus(original_audio_path)
# 如果转换后路径与原路径不同,说明生成了新文件
if audio_path != original_audio_path:
converted_audio_path = audio_path
else:
audio_path = original_audio_path
except Exception as e:
logger.error(f"[Lark] 音频格式转换失败,将尝试直接上传: {e}")
# 如果转换失败,继续尝试直接上传原始文件
audio_path = original_audio_path
# 获取音频时长
duration = await get_media_duration(audio_path)
# 上传音频文件
file_key = await LarkMessageEvent._upload_lark_file(
lark_client,
path=audio_path,
file_type="opus",
duration=duration,
)
# 清理转换后的临时音频文件
if converted_audio_path and os.path.exists(converted_audio_path):
try:
os.remove(converted_audio_path)
logger.debug(f"[Lark] 已删除转换后的音频文件: {converted_audio_path}")
except Exception as e:
logger.warning(f"[Lark] 删除转换后的音频文件失败: {e}")
if not file_key:
return
await LarkMessageEvent._send_im_message(
lark_client,
content=json.dumps({"file_key": file_key}),
msg_type="audio",
reply_message_id=reply_message_id,
receive_id=receive_id,
receive_id_type=receive_id_type,
)
@staticmethod
async def _send_media_message(
media_comp: Video,
lark_client: lark.Client,
reply_message_id: str | None = None,
receive_id: str | None = None,
receive_id_type: str | None = None,
):
"""发送视频消息
Args:
media_comp: 视频组件
lark_client: 飞书客户端
reply_message_id: 回复的消息ID(用于回复消息)
receive_id: 接收者ID(用于主动发送)
receive_id_type: 接收者ID类型(用于主动发送)
"""
# 获取视频文件路径
try:
original_video_path = await media_comp.convert_to_file_path()
except Exception as e:
logger.error(f"[Lark] 无法获取视频文件路径: {e}")
return
if not original_video_path or not os.path.exists(original_video_path):
logger.error(f"[Lark] 视频文件不存在: {original_video_path}")
return
# 转换为mp4格式
converted_video_path = None
try:
video_path = await convert_video_format(original_video_path, "mp4")
# 如果转换后路径与原路径不同,说明生成了新文件
if video_path != original_video_path:
converted_video_path = video_path
else:
video_path = original_video_path
except Exception as e:
logger.error(f"[Lark] 视频格式转换失败,将尝试直接上传: {e}")
# 如果转换失败,继续尝试直接上传原始文件
video_path = original_video_path
# 获取视频时长
duration = await get_media_duration(video_path)
# 上传视频文件
file_key = await LarkMessageEvent._upload_lark_file(
lark_client,
path=video_path,
file_type="mp4",
duration=duration,
)
# 清理转换后的临时视频文件
if converted_video_path and os.path.exists(converted_video_path):
try:
os.remove(converted_video_path)
logger.debug(f"[Lark] 已删除转换后的视频文件: {converted_video_path}")
except Exception as e:
logger.warning(f"[Lark] 删除转换后的视频文件失败: {e}")
if not file_key:
return
await LarkMessageEvent._send_im_message(
lark_client,
content=json.dumps({"file_key": file_key}),
msg_type="media",
reply_message_id=reply_message_id,
receive_id=receive_id,
receive_id_type=receive_id_type,
)
async def react(self, emoji: str):
if self.bot.im is None:
logger.error("[Lark] API Client im 模块未初始化,无法发送表情")
@@ -29,43 +29,11 @@ class QueueListener:
def __init__(self, webchat_queue_mgr: WebChatQueueMgr, callback: Callable) -> None:
self.webchat_queue_mgr = webchat_queue_mgr
self.callback = callback
self.running_tasks = set()
async def listen_to_queue(self, conversation_id: str):
"""Listen to a specific conversation queue"""
queue = self.webchat_queue_mgr.get_or_create_queue(conversation_id)
while True:
try:
data = await queue.get()
await self.callback(data)
except Exception as e:
logger.error(
f"Error processing message from conversation {conversation_id}: {e}",
)
break
async def run(self):
"""Monitor for new conversation queues and start listeners"""
monitored_conversations = set()
while True:
# Check for new conversations
current_conversations = set(self.webchat_queue_mgr.queues.keys())
new_conversations = current_conversations - monitored_conversations
# Start listeners for new conversations
for conversation_id in new_conversations:
task = asyncio.create_task(self.listen_to_queue(conversation_id))
self.running_tasks.add(task)
task.add_done_callback(self.running_tasks.discard)
monitored_conversations.add(conversation_id)
logger.debug(f"Started listener for conversation: {conversation_id}")
# Clean up monitored conversations that no longer exist
removed_conversations = monitored_conversations - current_conversations
monitored_conversations -= removed_conversations
await asyncio.sleep(1) # Check for new conversations every second
"""Register callback and keep adapter task alive."""
self.webchat_queue_mgr.set_listener(self.callback)
await asyncio.Event().wait()
@register_platform_adapter("webchat", "webchat")
@@ -26,8 +26,12 @@ class WebChatMessageEvent(AstrMessageEvent):
session_id: str,
streaming: bool = False,
) -> str | None:
cid = session_id.split("!")[-1]
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
request_id = str(message_id)
conversation_id = session_id.split("!")[-1]
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(
request_id,
conversation_id,
)
if not message:
await web_chat_back_queue.put(
{
@@ -124,9 +128,13 @@ class WebChatMessageEvent(AstrMessageEvent):
async def send_streaming(self, generator, use_fallback: bool = False):
final_data = ""
reasoning_content = ""
cid = self.session_id.split("!")[-1]
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
message_id = self.message_obj.message_id
request_id = str(message_id)
conversation_id = self.session_id.split("!")[-1]
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(
request_id,
conversation_id,
)
async for chain in generator:
# 处理音频流(Live Mode
if chain.type == "audio_chunk":
@@ -1,35 +1,147 @@
import asyncio
from collections.abc import Awaitable, Callable
from astrbot import logger
class WebChatQueueMgr:
def __init__(self) -> None:
self.queues = {}
def __init__(self, queue_maxsize: int = 128, back_queue_maxsize: int = 512) -> None:
self.queues: dict[str, asyncio.Queue] = {}
"""Conversation ID to asyncio.Queue mapping"""
self.back_queues = {}
"""Conversation ID to asyncio.Queue mapping for responses"""
self.back_queues: dict[str, asyncio.Queue] = {}
"""Request ID to asyncio.Queue mapping for responses"""
self._conversation_back_requests: dict[str, set[str]] = {}
self._request_conversation: dict[str, str] = {}
self._queue_close_events: dict[str, asyncio.Event] = {}
self._listener_tasks: dict[str, asyncio.Task] = {}
self._listener_callback: Callable[[tuple], Awaitable[None]] | None = None
self.queue_maxsize = queue_maxsize
self.back_queue_maxsize = back_queue_maxsize
def get_or_create_queue(self, conversation_id: str) -> asyncio.Queue:
"""Get or create a queue for the given conversation ID"""
if conversation_id not in self.queues:
self.queues[conversation_id] = asyncio.Queue()
self.queues[conversation_id] = asyncio.Queue(maxsize=self.queue_maxsize)
self._queue_close_events[conversation_id] = asyncio.Event()
self._start_listener_if_needed(conversation_id)
return self.queues[conversation_id]
def get_or_create_back_queue(self, conversation_id: str) -> asyncio.Queue:
"""Get or create a back queue for the given conversation ID"""
if conversation_id not in self.back_queues:
self.back_queues[conversation_id] = asyncio.Queue()
return self.back_queues[conversation_id]
def get_or_create_back_queue(
self,
request_id: str,
conversation_id: str | None = None,
) -> asyncio.Queue:
"""Get or create a back queue for the given request ID"""
if request_id not in self.back_queues:
self.back_queues[request_id] = asyncio.Queue(
maxsize=self.back_queue_maxsize
)
if conversation_id:
self._request_conversation[request_id] = conversation_id
if conversation_id not in self._conversation_back_requests:
self._conversation_back_requests[conversation_id] = set()
self._conversation_back_requests[conversation_id].add(request_id)
return self.back_queues[request_id]
def remove_back_queue(self, request_id: str):
"""Remove back queue for the given request ID"""
self.back_queues.pop(request_id, None)
conversation_id = self._request_conversation.pop(request_id, None)
if conversation_id:
request_ids = self._conversation_back_requests.get(conversation_id)
if request_ids is not None:
request_ids.discard(request_id)
if not request_ids:
self._conversation_back_requests.pop(conversation_id, None)
def remove_queues(self, conversation_id: str):
"""Remove queues for the given conversation ID"""
if conversation_id in self.queues:
del self.queues[conversation_id]
if conversation_id in self.back_queues:
del self.back_queues[conversation_id]
for request_id in list(
self._conversation_back_requests.get(conversation_id, set())
):
self.remove_back_queue(request_id)
self._conversation_back_requests.pop(conversation_id, None)
self.remove_queue(conversation_id)
def remove_queue(self, conversation_id: str):
"""Remove input queue and listener for the given conversation ID"""
self.queues.pop(conversation_id, None)
close_event = self._queue_close_events.pop(conversation_id, None)
if close_event is not None:
close_event.set()
task = self._listener_tasks.pop(conversation_id, None)
if task is not None:
task.cancel()
def has_queue(self, conversation_id: str) -> bool:
"""Check if a queue exists for the given conversation ID"""
return conversation_id in self.queues
def set_listener(
self,
callback: Callable[[tuple], Awaitable[None]],
):
self._listener_callback = callback
for conversation_id in list(self.queues.keys()):
self._start_listener_if_needed(conversation_id)
def _start_listener_if_needed(self, conversation_id: str):
if self._listener_callback is None:
return
if conversation_id in self._listener_tasks:
task = self._listener_tasks[conversation_id]
if not task.done():
return
queue = self.queues.get(conversation_id)
close_event = self._queue_close_events.get(conversation_id)
if queue is None or close_event is None:
return
task = asyncio.create_task(
self._listen_to_queue(conversation_id, queue, close_event),
name=f"webchat_listener_{conversation_id}",
)
self._listener_tasks[conversation_id] = task
task.add_done_callback(
lambda _: self._listener_tasks.pop(conversation_id, None)
)
logger.debug(f"Started listener for conversation: {conversation_id}")
async def _listen_to_queue(
self,
conversation_id: str,
queue: asyncio.Queue,
close_event: asyncio.Event,
):
while True:
get_task = asyncio.create_task(queue.get())
close_task = asyncio.create_task(close_event.wait())
try:
done, pending = await asyncio.wait(
{get_task, close_task},
return_when=asyncio.FIRST_COMPLETED,
)
for task in pending:
task.cancel()
if close_task in done:
break
data = get_task.result()
if self._listener_callback is None:
continue
try:
await self._listener_callback(data)
except Exception as e:
logger.error(
f"Error processing message from conversation {conversation_id}: {e}"
)
except asyncio.CancelledError:
break
finally:
if not get_task.done():
get_task.cancel()
if not close_task.done():
close_task.cancel()
webchat_queue_mgr = WebChatQueueMgr()
@@ -51,44 +51,13 @@ class WecomAIQueueListener:
) -> None:
self.queue_mgr = queue_mgr
self.callback = callback
self.running_tasks = set()
async def listen_to_queue(self, session_id: str):
"""监听特定会话的队列"""
queue = self.queue_mgr.get_or_create_queue(session_id)
while True:
try:
data = await queue.get()
await self.callback(data)
except Exception as e:
logger.error(f"处理会话 {session_id} 消息时发生错误: {e}")
break
async def run(self):
"""监控新会话队列并启动监听器"""
monitored_sessions = set()
"""注册监听回调并定期清理过期响应。"""
self.queue_mgr.set_listener(self.callback)
while True:
# 检查新会话
current_sessions = set(self.queue_mgr.queues.keys())
new_sessions = current_sessions - monitored_sessions
# 为新会话启动监听器
for session_id in new_sessions:
task = asyncio.create_task(self.listen_to_queue(session_id))
self.running_tasks.add(task)
task.add_done_callback(self.running_tasks.discard)
monitored_sessions.add(session_id)
logger.debug(f"[WecomAI] 为会话启动监听器: {session_id}")
# 清理已不存在的会话
removed_sessions = monitored_sessions - current_sessions
monitored_sessions -= removed_sessions
# 清理过期的待处理响应
self.queue_mgr.cleanup_expired_responses()
await asyncio.sleep(1) # 每秒检查一次新会话
await asyncio.sleep(1)
@register_platform_adapter(
@@ -212,7 +181,12 @@ class WecomAIBotAdapter(Platform):
# wechat server is requesting for updates of a stream
stream_id = message_data["stream"]["id"]
if not self.queue_mgr.has_back_queue(stream_id):
logger.error(f"Cannot find back queue for stream_id: {stream_id}")
if self.queue_mgr.is_stream_finished(stream_id):
logger.debug(
f"Stream already finished, returning end message: {stream_id}"
)
else:
logger.warning(f"Cannot find back queue for stream_id: {stream_id}")
# 返回结束标志,告诉微信服务器流已结束
end_message = WecomAIBotStreamMessageBuilder.make_text_stream(
@@ -243,10 +217,10 @@ class WecomAIBotAdapter(Platform):
latest_plain_content = msg["data"] or ""
elif msg["type"] == "image":
image_base64.append(msg["image_data"])
elif msg["type"] == "end":
elif msg["type"] in {"end", "complete"}:
# stream end
finish = True
self.queue_mgr.remove_queues(stream_id)
self.queue_mgr.remove_queues(stream_id, mark_finished=True)
break
logger.debug(
@@ -4,6 +4,7 @@
"""
import asyncio
from collections.abc import Awaitable, Callable
from typing import Any
from astrbot.api import logger
@@ -12,7 +13,7 @@ from astrbot.api import logger
class WecomAIQueueMgr:
"""企业微信智能机器人队列管理器"""
def __init__(self) -> None:
def __init__(self, queue_maxsize: int = 128, back_queue_maxsize: int = 512) -> None:
self.queues: dict[str, asyncio.Queue] = {}
"""StreamID 到输入队列的映射 - 用于接收用户消息"""
@@ -21,6 +22,13 @@ class WecomAIQueueMgr:
self.pending_responses: dict[str, dict[str, Any]] = {}
"""待处理的响应缓存,用于流式响应"""
self.completed_streams: dict[str, float] = {}
"""已结束的 stream 缓存,用于兼容平台后续重复轮询"""
self._queue_close_events: dict[str, asyncio.Event] = {}
self._listener_tasks: dict[str, asyncio.Task] = {}
self._listener_callback: Callable[[dict], Awaitable[None]] | None = None
self.queue_maxsize = queue_maxsize
self.back_queue_maxsize = back_queue_maxsize
def get_or_create_queue(self, session_id: str) -> asyncio.Queue:
"""获取或创建指定会话的输入队列
@@ -33,7 +41,9 @@ class WecomAIQueueMgr:
"""
if session_id not in self.queues:
self.queues[session_id] = asyncio.Queue()
self.queues[session_id] = asyncio.Queue(maxsize=self.queue_maxsize)
self._queue_close_events[session_id] = asyncio.Event()
self._start_listener_if_needed(session_id)
logger.debug(f"[WecomAI] 创建输入队列: {session_id}")
return self.queues[session_id]
@@ -48,20 +58,21 @@ class WecomAIQueueMgr:
"""
if session_id not in self.back_queues:
self.back_queues[session_id] = asyncio.Queue()
self.back_queues[session_id] = asyncio.Queue(
maxsize=self.back_queue_maxsize
)
logger.debug(f"[WecomAI] 创建输出队列: {session_id}")
return self.back_queues[session_id]
def remove_queues(self, session_id: str):
def remove_queues(self, session_id: str, mark_finished: bool = False):
"""移除指定会话的所有队列
Args:
session_id: 会话ID
mark_finished: 是否标记为已正常结束
"""
if session_id in self.queues:
del self.queues[session_id]
logger.debug(f"[WecomAI] 移除输入队列: {session_id}")
self.remove_queue(session_id)
if session_id in self.back_queues:
del self.back_queues[session_id]
@@ -70,6 +81,23 @@ class WecomAIQueueMgr:
if session_id in self.pending_responses:
del self.pending_responses[session_id]
logger.debug(f"[WecomAI] 移除待处理响应: {session_id}")
if mark_finished:
self.completed_streams[session_id] = asyncio.get_event_loop().time()
logger.debug(f"[WecomAI] 标记流已结束: {session_id}")
def remove_queue(self, session_id: str):
"""仅移除输入队列和对应监听任务"""
if session_id in self.queues:
del self.queues[session_id]
logger.debug(f"[WecomAI] 移除输入队列: {session_id}")
close_event = self._queue_close_events.pop(session_id, None)
if close_event is not None:
close_event.set()
task = self._listener_tasks.pop(session_id, None)
if task is not None:
task.cancel()
def has_queue(self, session_id: str) -> bool:
"""检查是否存在指定会话的队列
@@ -121,6 +149,20 @@ class WecomAIQueueMgr:
"""
return self.pending_responses.get(session_id)
def is_stream_finished(
self,
session_id: str,
max_age_seconds: int = 60,
) -> bool:
"""判断 stream 是否在短期内已结束"""
finished_at = self.completed_streams.get(session_id)
if finished_at is None:
return False
if asyncio.get_event_loop().time() - finished_at > max_age_seconds:
self.completed_streams.pop(session_id, None)
return False
return True
def cleanup_expired_responses(self, max_age_seconds: int = 300):
"""清理过期的待处理响应
@@ -136,8 +178,75 @@ class WecomAIQueueMgr:
expired_sessions.append(session_id)
for session_id in expired_sessions:
del self.pending_responses[session_id]
logger.debug(f"[WecomAI] 清理过期响应: {session_id}")
self.remove_queues(session_id)
logger.debug(f"[WecomAI] 清理过期响应及队列: {session_id}")
expired_finished = [
session_id
for session_id, finished_at in self.completed_streams.items()
if current_time - finished_at > 60
]
for session_id in expired_finished:
self.completed_streams.pop(session_id, None)
def set_listener(
self,
callback: Callable[[dict], Awaitable[None]],
):
self._listener_callback = callback
for session_id in list(self.queues.keys()):
self._start_listener_if_needed(session_id)
def _start_listener_if_needed(self, session_id: str):
if self._listener_callback is None:
return
if session_id in self._listener_tasks:
task = self._listener_tasks[session_id]
if not task.done():
return
queue = self.queues.get(session_id)
close_event = self._queue_close_events.get(session_id)
if queue is None or close_event is None:
return
task = asyncio.create_task(
self._listen_to_queue(session_id, queue, close_event),
name=f"wecomai_listener_{session_id}",
)
self._listener_tasks[session_id] = task
task.add_done_callback(lambda _: self._listener_tasks.pop(session_id, None))
logger.debug(f"[WecomAI] 为会话启动监听器: {session_id}")
async def _listen_to_queue(
self,
session_id: str,
queue: asyncio.Queue,
close_event: asyncio.Event,
):
while True:
get_task = asyncio.create_task(queue.get())
close_task = asyncio.create_task(close_event.wait())
try:
done, pending = await asyncio.wait(
{get_task, close_task},
return_when=asyncio.FIRST_COMPLETED,
)
for task in pending:
task.cancel()
if close_task in done:
break
data = get_task.result()
if self._listener_callback is None:
continue
try:
await self._listener_callback(data)
except Exception as e:
logger.error(f"处理会话 {session_id} 消息时发生错误: {e}")
except asyncio.CancelledError:
break
finally:
if not get_task.done():
get_task.cancel()
if not close_task.done():
close_task.cancel()
def get_stats(self) -> dict[str, int]:
"""获取队列统计信息
+207
View File
@@ -0,0 +1,207 @@
"""媒体文件处理工具
提供音视频格式转换、时长获取等功能。
"""
import asyncio
import os
import subprocess
import uuid
from astrbot import logger
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
async def get_media_duration(file_path: str) -> int | None:
"""使用ffprobe获取媒体文件时长
Args:
file_path: 媒体文件路径
Returns:
时长(毫秒),如果获取失败返回None
"""
try:
# 使用ffprobe获取时长
process = await asyncio.create_subprocess_exec(
"ffprobe",
"-v",
"error",
"-show_entries",
"format=duration",
"-of",
"default=noprint_wrappers=1:nokey=1",
file_path,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
stdout, stderr = await process.communicate()
if process.returncode == 0 and stdout:
duration_seconds = float(stdout.decode().strip())
duration_ms = int(duration_seconds * 1000)
logger.debug(f"[Media Utils] 获取媒体时长: {duration_ms}ms")
return duration_ms
else:
logger.warning(f"[Media Utils] 无法获取媒体文件时长: {file_path}")
return None
except FileNotFoundError:
logger.warning(
"[Media Utils] ffprobe未安装或不在PATH中,无法获取媒体时长。请安装ffmpeg: https://ffmpeg.org/"
)
return None
except Exception as e:
logger.warning(f"[Media Utils] 获取媒体时长时出错: {e}")
return None
async def convert_audio_to_opus(audio_path: str, output_path: str | None = None) -> str:
"""使用ffmpeg将音频转换为opus格式
Args:
audio_path: 原始音频文件路径
output_path: 输出文件路径,如果为None则自动生成
Returns:
转换后的opus文件路径
Raises:
Exception: 转换失败时抛出异常
"""
# 如果已经是opus格式,直接返回
if audio_path.lower().endswith(".opus"):
return audio_path
# 生成输出文件路径
if output_path is None:
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(temp_dir, exist_ok=True)
output_path = os.path.join(temp_dir, f"{uuid.uuid4()}.opus")
try:
# 使用ffmpeg转换为opus格式
# -y: 覆盖输出文件
# -i: 输入文件
# -acodec libopus: 使用opus编码器
# -ac 1: 单声道
# -ar 16000: 采样率16kHz
process = await asyncio.create_subprocess_exec(
"ffmpeg",
"-y",
"-i",
audio_path,
"-acodec",
"libopus",
"-ac",
"1",
"-ar",
"16000",
output_path,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
stdout, stderr = await process.communicate()
if process.returncode != 0:
# 清理可能已生成但无效的临时文件
if output_path and os.path.exists(output_path):
try:
os.remove(output_path)
logger.debug(
f"[Media Utils] 已清理失败的opus输出文件: {output_path}"
)
except OSError as e:
logger.warning(f"[Media Utils] 清理失败的opus输出文件时出错: {e}")
error_msg = stderr.decode() if stderr else "未知错误"
logger.error(f"[Media Utils] ffmpeg转换音频失败: {error_msg}")
raise Exception(f"ffmpeg conversion failed: {error_msg}")
logger.debug(f"[Media Utils] 音频转换成功: {audio_path} -> {output_path}")
return output_path
except FileNotFoundError:
logger.error(
"[Media Utils] ffmpeg未安装或不在PATH中,无法转换音频格式。请安装ffmpeg: https://ffmpeg.org/"
)
raise Exception("ffmpeg not found")
except Exception as e:
logger.error(f"[Media Utils] 转换音频格式时出错: {e}")
raise
async def convert_video_format(
video_path: str, output_format: str = "mp4", output_path: str | None = None
) -> str:
"""使用ffmpeg转换视频格式
Args:
video_path: 原始视频文件路径
output_format: 目标格式,默认mp4
output_path: 输出文件路径,如果为None则自动生成
Returns:
转换后的视频文件路径
Raises:
Exception: 转换失败时抛出异常
"""
# 如果已经是目标格式,直接返回
if video_path.lower().endswith(f".{output_format}"):
return video_path
# 生成输出文件路径
if output_path is None:
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(temp_dir, exist_ok=True)
output_path = os.path.join(temp_dir, f"{uuid.uuid4()}.{output_format}")
try:
# 使用ffmpeg转换视频格式
process = await asyncio.create_subprocess_exec(
"ffmpeg",
"-y",
"-i",
video_path,
"-c:v",
"libx264",
"-c:a",
"aac",
output_path,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
stdout, stderr = await process.communicate()
if process.returncode != 0:
# 清理可能已生成但无效的临时文件
if output_path and os.path.exists(output_path):
try:
os.remove(output_path)
logger.debug(
f"[Media Utils] 已清理失败的{output_format}输出文件: {output_path}"
)
except OSError as e:
logger.warning(
f"[Media Utils] 清理失败的{output_format}输出文件时出错: {e}"
)
error_msg = stderr.decode() if stderr else "未知错误"
logger.error(f"[Media Utils] ffmpeg转换视频失败: {error_msg}")
raise Exception(f"ffmpeg conversion failed: {error_msg}")
logger.debug(f"[Media Utils] 视频转换成功: {video_path} -> {output_path}")
return output_path
except FileNotFoundError:
logger.error(
"[Media Utils] ffmpeg未安装或不在PATH中,无法转换视频格式。请安装ffmpeg: https://ffmpeg.org/"
)
raise Exception("ffmpeg not found")
except Exception as e:
logger.error(f"[Media Utils] 转换视频格式时出错: {e}")
raise
+2 -2
View File
@@ -23,7 +23,7 @@ class SharedPreferences:
)
self.path = json_storage_path
self.db_helper = db_helper
self.temorary_cache: dict[str, dict[str, Any]] = defaultdict(dict)
self.temporary_cache: dict[str, dict[str, Any]] = defaultdict(dict)
"""automatically clear per 24 hours. Might be helpful in some cases XD"""
self._sync_loop = asyncio.new_event_loop()
@@ -37,7 +37,7 @@ class SharedPreferences:
self._scheduler.start()
def _clear_temporary_cache(self):
self.temorary_cache.clear()
self.temporary_cache.clear()
async def get_async(
self,
+6 -1
View File
@@ -354,12 +354,15 @@ class ChatRoute(Route):
return Response().error("session_id is empty").__dict__
webchat_conv_id = session_id
back_queue = webchat_queue_mgr.get_or_create_back_queue(webchat_conv_id)
# 构建用户消息段(包含 path 用于传递给 adapter
message_parts = await self._build_user_message_parts(message)
message_id = str(uuid.uuid4())
back_queue = webchat_queue_mgr.get_or_create_back_queue(
message_id,
webchat_conv_id,
)
async def stream():
client_disconnected = False
@@ -532,6 +535,8 @@ class ChatRoute(Route):
refs = {}
except BaseException as e:
logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True)
finally:
webchat_queue_mgr.remove_back_queue(message_id)
# 将消息放入会话特定的队列
chat_queue = webchat_queue_mgr.get_or_create_queue(webchat_conv_id)
+127 -122
View File
@@ -256,143 +256,148 @@ class LiveChatRoute(Route):
await queue.put((session.username, cid, payload))
# 3. 等待响应并流式发送 TTS 音频
back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id, cid)
bot_text = ""
audio_playing = False
while True:
if session.should_interrupt:
# 用户打断,停止处理
logger.info("[Live Chat] 检测到用户打断")
await websocket.send_json({"t": "stop_play"})
# 保存消息并标记为被打断
await self._save_interrupted_message(session, user_text, bot_text)
# 清空队列中未处理的消息
while not back_queue.empty():
try:
while True:
if session.should_interrupt:
# 用户打断,停止处理
logger.info("[Live Chat] 检测到用户打断")
await websocket.send_json({"t": "stop_play"})
# 保存消息并标记为被打断
await self._save_interrupted_message(
session, user_text, bot_text
)
# 清空队列中未处理的消息
while not back_queue.empty():
try:
back_queue.get_nowait()
except asyncio.QueueEmpty:
break
break
try:
result = await asyncio.wait_for(back_queue.get(), timeout=0.5)
except asyncio.TimeoutError:
continue
if not result:
continue
result_message_id = result.get("message_id")
if result_message_id != message_id:
logger.warning(
f"[Live Chat] 消息 ID 不匹配: {result_message_id} != {message_id}"
)
continue
result_type = result.get("type")
result_chain_type = result.get("chain_type")
data = result.get("data", "")
if result_chain_type == "agent_stats":
try:
back_queue.get_nowait()
except asyncio.QueueEmpty:
break
break
stats = json.loads(data)
await websocket.send_json(
{
"t": "metrics",
"data": {
"llm_ttft": stats.get("time_to_first_token", 0),
"llm_total_time": stats.get("end_time", 0)
- stats.get("start_time", 0),
},
}
)
except Exception as e:
logger.error(f"[Live Chat] 解析 AgentStats 失败: {e}")
continue
try:
result = await asyncio.wait_for(back_queue.get(), timeout=0.5)
except asyncio.TimeoutError:
continue
if result_chain_type == "tts_stats":
try:
stats = json.loads(data)
await websocket.send_json(
{
"t": "metrics",
"data": stats,
}
)
except Exception as e:
logger.error(f"[Live Chat] 解析 TTSStats 失败: {e}")
continue
if not result:
continue
if result_type == "plain":
# 普通文本消息
bot_text += data
result_message_id = result.get("message_id")
if result_message_id != message_id:
logger.warning(
f"[Live Chat] 消息 ID 不匹配: {result_message_id} != {message_id}"
)
continue
elif result_type == "audio_chunk":
# 流式音频数据
if not audio_playing:
audio_playing = True
logger.debug("[Live Chat] 开始播放音频流")
result_type = result.get("type")
result_chain_type = result.get("chain_type")
data = result.get("data", "")
# Calculate latency from wav assembly finish to first audio chunk
speak_to_first_frame_latency = (
time.time() - wav_assembly_finish_time
)
await websocket.send_json(
{
"t": "metrics",
"data": {
"speak_to_first_frame": speak_to_first_frame_latency
},
}
)
if result_chain_type == "agent_stats":
try:
stats = json.loads(data)
text = result.get("text")
if text:
await websocket.send_json(
{
"t": "bot_text_chunk",
"data": {"text": text},
}
)
# 发送音频数据给前端
await websocket.send_json(
{
"t": "response",
"data": data, # base64 编码的音频数据
}
)
elif result_type in ["complete", "end"]:
# 处理完成
logger.info(f"[Live Chat] Bot 回复完成: {bot_text}")
# 如果没有音频流,发送 bot 消息文本
if not audio_playing:
await websocket.send_json(
{
"t": "bot_msg",
"data": {
"text": bot_text,
"ts": int(time.time() * 1000),
},
}
)
# 发送结束标记
await websocket.send_json({"t": "end"})
# 发送总耗时
wav_to_tts_duration = time.time() - wav_assembly_finish_time
await websocket.send_json(
{
"t": "metrics",
"data": {
"llm_ttft": stats.get("time_to_first_token", 0),
"llm_total_time": stats.get("end_time", 0)
- stats.get("start_time", 0),
},
"data": {"wav_to_tts_total_time": wav_to_tts_duration},
}
)
except Exception as e:
logger.error(f"[Live Chat] 解析 AgentStats 失败: {e}")
continue
if result_chain_type == "tts_stats":
try:
stats = json.loads(data)
await websocket.send_json(
{
"t": "metrics",
"data": stats,
}
)
except Exception as e:
logger.error(f"[Live Chat] 解析 TTSStats 失败: {e}")
continue
if result_type == "plain":
# 普通文本消息
bot_text += data
elif result_type == "audio_chunk":
# 流式音频数据
if not audio_playing:
audio_playing = True
logger.debug("[Live Chat] 开始播放音频流")
# Calculate latency from wav assembly finish to first audio chunk
speak_to_first_frame_latency = (
time.time() - wav_assembly_finish_time
)
await websocket.send_json(
{
"t": "metrics",
"data": {
"speak_to_first_frame": speak_to_first_frame_latency
},
}
)
text = result.get("text")
if text:
await websocket.send_json(
{
"t": "bot_text_chunk",
"data": {"text": text},
}
)
# 发送音频数据给前端
await websocket.send_json(
{
"t": "response",
"data": data, # base64 编码的音频数据
}
)
elif result_type in ["complete", "end"]:
# 处理完成
logger.info(f"[Live Chat] Bot 回复完成: {bot_text}")
# 如果没有音频流,发送 bot 消息文本
if not audio_playing:
await websocket.send_json(
{
"t": "bot_msg",
"data": {
"text": bot_text,
"ts": int(time.time() * 1000),
},
}
)
# 发送结束标记
await websocket.send_json({"t": "end"})
# 发送总耗时
wav_to_tts_duration = time.time() - wav_assembly_finish_time
await websocket.send_json(
{
"t": "metrics",
"data": {"wav_to_tts_total_time": wav_to_tts_duration},
}
)
break
break
finally:
webchat_queue_mgr.remove_back_queue(message_id)
except Exception as e:
logger.error(f"[Live Chat] 处理音频失败: {e}", exc_info=True)
+10
View File
@@ -0,0 +1,10 @@
## What's Changed
### 修复
- 修复一些原因导致 Tavily WebSearch、Bocha WebSearch 无法使用的问题
### xinzeng
- 飞书支持 Bot 发送文件、图片和视频消息类型。
### 优化
- 优化 WebChat 和 企业微信 AI 会话队列生命周期管理,减少内存泄漏,提高性能。
+1 -1
View File
@@ -1,6 +1,6 @@
[project]
name = "AstrBot"
version = "4.14.5"
version = "4.14.6"
description = "Easy-to-use multi-platform LLM chatbot and development framework"
readme = "README.md"
requires-python = ">=3.10"