Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6b4498a554 | |||
| 5e5207da95 | |||
| def8b730b7 | |||
| 22a109c2ae |
@@ -1 +1 @@
|
||||
__version__ = "4.14.5"
|
||||
__version__ = "4.14.6"
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any, TypedDict
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "4.14.5"
|
||||
VERSION = "4.14.6"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||
|
||||
WEBHOOK_SUPPORTED_PLATFORMS = [
|
||||
|
||||
@@ -3,13 +3,10 @@ import base64
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, cast
|
||||
|
||||
import lark_oapi as lark
|
||||
from lark_oapi.api.im.v1 import (
|
||||
CreateMessageRequest,
|
||||
CreateMessageRequestBody,
|
||||
GetMessageResourceRequest,
|
||||
)
|
||||
from lark_oapi.api.im.v1.processor import P2ImMessageReceiveV1Processor
|
||||
@@ -125,44 +122,23 @@ class LarkPlatformAdapter(Platform):
|
||||
session: MessageSesion,
|
||||
message_chain: MessageChain,
|
||||
):
|
||||
if self.lark_api.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化,无法发送消息")
|
||||
return
|
||||
|
||||
res = await LarkMessageEvent._convert_to_lark(message_chain, self.lark_api)
|
||||
wrapped = {
|
||||
"zh_cn": {
|
||||
"title": "",
|
||||
"content": res,
|
||||
},
|
||||
}
|
||||
|
||||
if session.message_type == MessageType.GROUP_MESSAGE:
|
||||
id_type = "chat_id"
|
||||
if "%" in session.session_id:
|
||||
session.session_id = session.session_id.split("%")[1]
|
||||
receive_id = session.session_id
|
||||
if "%" in receive_id:
|
||||
receive_id = receive_id.split("%")[1]
|
||||
else:
|
||||
id_type = "open_id"
|
||||
receive_id = session.session_id
|
||||
|
||||
request = (
|
||||
CreateMessageRequest.builder()
|
||||
.receive_id_type(id_type)
|
||||
.request_body(
|
||||
CreateMessageRequestBody.builder()
|
||||
.receive_id(session.session_id)
|
||||
.content(json.dumps(wrapped))
|
||||
.msg_type("post")
|
||||
.uuid(str(uuid.uuid4()))
|
||||
.build(),
|
||||
)
|
||||
.build()
|
||||
# 复用 LarkMessageEvent 中的通用发送逻辑
|
||||
await LarkMessageEvent.send_message_chain(
|
||||
message_chain,
|
||||
self.lark_api,
|
||||
receive_id=receive_id,
|
||||
receive_id_type=id_type,
|
||||
)
|
||||
|
||||
response = await self.lark_api.im.v1.message.acreate(request)
|
||||
|
||||
if not response.success():
|
||||
logger.error(f"发送飞书消息失败({response.code}): {response.msg}")
|
||||
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
|
||||
@@ -6,6 +6,8 @@ from io import BytesIO
|
||||
|
||||
import lark_oapi as lark
|
||||
from lark_oapi.api.im.v1 import (
|
||||
CreateFileRequest,
|
||||
CreateFileRequestBody,
|
||||
CreateImageRequest,
|
||||
CreateImageRequestBody,
|
||||
CreateMessageReactionRequest,
|
||||
@@ -17,10 +19,15 @@ from lark_oapi.api.im.v1 import (
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import At, Plain
|
||||
from astrbot.api.message_components import At, File, Plain, Record, Video
|
||||
from astrbot.api.message_components import Image as AstrBotImage
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core.utils.media_utils import (
|
||||
convert_audio_to_opus,
|
||||
convert_video_format,
|
||||
get_media_duration,
|
||||
)
|
||||
|
||||
|
||||
class LarkMessageEvent(AstrMessageEvent):
|
||||
@@ -35,6 +42,144 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.bot = bot
|
||||
|
||||
@staticmethod
|
||||
async def _send_im_message(
|
||||
lark_client: lark.Client,
|
||||
*,
|
||||
content: str,
|
||||
msg_type: str,
|
||||
reply_message_id: str | None = None,
|
||||
receive_id: str | None = None,
|
||||
receive_id_type: str | None = None,
|
||||
) -> bool:
|
||||
"""发送飞书 IM 消息的通用辅助函数
|
||||
|
||||
Args:
|
||||
lark_client: 飞书客户端
|
||||
content: 消息内容(JSON字符串)
|
||||
msg_type: 消息类型(post/file/audio/media等)
|
||||
reply_message_id: 回复的消息ID(用于回复消息)
|
||||
receive_id: 接收者ID(用于主动发送)
|
||||
receive_id_type: 接收者ID类型(用于主动发送)
|
||||
|
||||
Returns:
|
||||
是否发送成功
|
||||
"""
|
||||
if lark_client.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化")
|
||||
return False
|
||||
|
||||
if reply_message_id:
|
||||
request = (
|
||||
ReplyMessageRequest.builder()
|
||||
.message_id(reply_message_id)
|
||||
.request_body(
|
||||
ReplyMessageRequestBody.builder()
|
||||
.content(content)
|
||||
.msg_type(msg_type)
|
||||
.uuid(str(uuid.uuid4()))
|
||||
.reply_in_thread(False)
|
||||
.build()
|
||||
)
|
||||
.build()
|
||||
)
|
||||
response = await lark_client.im.v1.message.areply(request)
|
||||
else:
|
||||
from lark_oapi.api.im.v1 import (
|
||||
CreateMessageRequest,
|
||||
CreateMessageRequestBody,
|
||||
)
|
||||
|
||||
if receive_id_type is None or receive_id is None:
|
||||
logger.error(
|
||||
"[Lark] 主动发送消息时,receive_id 和 receive_id_type 不能为空",
|
||||
)
|
||||
return False
|
||||
|
||||
request = (
|
||||
CreateMessageRequest.builder()
|
||||
.receive_id_type(receive_id_type)
|
||||
.request_body(
|
||||
CreateMessageRequestBody.builder()
|
||||
.receive_id(receive_id)
|
||||
.content(content)
|
||||
.msg_type(msg_type)
|
||||
.uuid(str(uuid.uuid4()))
|
||||
.build()
|
||||
)
|
||||
.build()
|
||||
)
|
||||
response = await lark_client.im.v1.message.acreate(request)
|
||||
|
||||
if not response.success():
|
||||
logger.error(f"[Lark] 发送飞书消息失败({response.code}): {response.msg}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def _upload_lark_file(
|
||||
lark_client: lark.Client,
|
||||
*,
|
||||
path: str,
|
||||
file_type: str,
|
||||
duration: int | None = None,
|
||||
) -> str | None:
|
||||
"""上传文件到飞书的通用辅助函数
|
||||
|
||||
Args:
|
||||
lark_client: 飞书客户端
|
||||
path: 文件路径
|
||||
file_type: 文件类型(stream/opus/mp4等)
|
||||
duration: 媒体时长(毫秒),可选
|
||||
|
||||
Returns:
|
||||
成功返回file_key,失败返回None
|
||||
"""
|
||||
if not path or not os.path.exists(path):
|
||||
logger.error(f"[Lark] 文件不存在: {path}")
|
||||
return None
|
||||
|
||||
if lark_client.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化,无法上传文件")
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(path, "rb") as file_obj:
|
||||
body_builder = (
|
||||
CreateFileRequestBody.builder()
|
||||
.file_type(file_type)
|
||||
.file_name(os.path.basename(path))
|
||||
.file(file_obj)
|
||||
)
|
||||
if duration is not None:
|
||||
body_builder.duration(duration)
|
||||
|
||||
request = (
|
||||
CreateFileRequest.builder()
|
||||
.request_body(body_builder.build())
|
||||
.build()
|
||||
)
|
||||
response = await lark_client.im.v1.file.acreate(request)
|
||||
|
||||
if not response.success():
|
||||
logger.error(
|
||||
f"[Lark] 无法上传文件({response.code}): {response.msg}"
|
||||
)
|
||||
return None
|
||||
|
||||
if response.data is None:
|
||||
logger.error("[Lark] 上传文件成功但未返回数据(data is None)")
|
||||
return None
|
||||
|
||||
file_key = response.data.file_key
|
||||
logger.debug(f"[Lark] 文件上传成功: {file_key}")
|
||||
return file_key
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Lark] 无法打开或上传文件: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def _convert_to_lark(message: MessageChain, lark_client: lark.Client) -> list:
|
||||
ret = []
|
||||
@@ -103,6 +248,18 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
ret.append(_stage)
|
||||
ret.append([{"tag": "img", "image_key": image_key}])
|
||||
_stage.clear()
|
||||
elif isinstance(comp, File):
|
||||
# 文件将通过 _send_file_message 方法单独发送,这里跳过
|
||||
logger.debug("[Lark] 检测到文件组件,将单独发送")
|
||||
continue
|
||||
elif isinstance(comp, Record):
|
||||
# 音频将通过 _send_audio_message 方法单独发送,这里跳过
|
||||
logger.debug("[Lark] 检测到音频组件,将单独发送")
|
||||
continue
|
||||
elif isinstance(comp, Video):
|
||||
# 视频将通过 _send_media_message 方法单独发送,这里跳过
|
||||
logger.debug("[Lark] 检测到视频组件,将单独发送")
|
||||
continue
|
||||
else:
|
||||
logger.warning(f"飞书 暂时不支持消息段: {comp.type}")
|
||||
|
||||
@@ -110,40 +267,270 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
ret.append(_stage)
|
||||
return ret
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
res = await LarkMessageEvent._convert_to_lark(message, self.bot)
|
||||
wrapped = {
|
||||
"zh_cn": {
|
||||
"title": "",
|
||||
"content": res,
|
||||
},
|
||||
}
|
||||
@staticmethod
|
||||
async def send_message_chain(
|
||||
message_chain: MessageChain,
|
||||
lark_client: lark.Client,
|
||||
reply_message_id: str | None = None,
|
||||
receive_id: str | None = None,
|
||||
receive_id_type: str | None = None,
|
||||
):
|
||||
"""通用的消息链发送方法
|
||||
|
||||
request = (
|
||||
ReplyMessageRequest.builder()
|
||||
.message_id(self.message_obj.message_id)
|
||||
.request_body(
|
||||
ReplyMessageRequestBody.builder()
|
||||
.content(json.dumps(wrapped))
|
||||
.msg_type("post")
|
||||
.uuid(str(uuid.uuid4()))
|
||||
.reply_in_thread(False)
|
||||
.build(),
|
||||
)
|
||||
.build()
|
||||
)
|
||||
|
||||
if self.bot.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化,无法回复消息")
|
||||
Args:
|
||||
message_chain: 要发送的消息链
|
||||
lark_client: 飞书客户端
|
||||
reply_message_id: 回复的消息ID(用于回复消息)
|
||||
receive_id: 接收者ID(用于主动发送)
|
||||
receive_id_type: 接收者ID类型,如 'open_id', 'chat_id'(用于主动发送)
|
||||
"""
|
||||
if lark_client.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化")
|
||||
return
|
||||
|
||||
response = await self.bot.im.v1.message.areply(request)
|
||||
# 分离文件、音频、视频组件和其他组件
|
||||
file_components: list[File] = []
|
||||
audio_components: list[Record] = []
|
||||
media_components: list[Video] = []
|
||||
other_components = []
|
||||
|
||||
if not response.success():
|
||||
logger.error(f"回复飞书消息失败({response.code}): {response.msg}")
|
||||
for comp in message_chain.chain:
|
||||
if isinstance(comp, File):
|
||||
file_components.append(comp)
|
||||
elif isinstance(comp, Record):
|
||||
audio_components.append(comp)
|
||||
elif isinstance(comp, Video):
|
||||
media_components.append(comp)
|
||||
else:
|
||||
other_components.append(comp)
|
||||
|
||||
# 先发送非文件内容(如果有)
|
||||
if other_components:
|
||||
temp_chain = MessageChain()
|
||||
temp_chain.chain = other_components
|
||||
res = await LarkMessageEvent._convert_to_lark(temp_chain, lark_client)
|
||||
|
||||
if res: # 只在有内容时发送
|
||||
wrapped = {
|
||||
"zh_cn": {
|
||||
"title": "",
|
||||
"content": res,
|
||||
},
|
||||
}
|
||||
await LarkMessageEvent._send_im_message(
|
||||
lark_client,
|
||||
content=json.dumps(wrapped),
|
||||
msg_type="post",
|
||||
reply_message_id=reply_message_id,
|
||||
receive_id=receive_id,
|
||||
receive_id_type=receive_id_type,
|
||||
)
|
||||
|
||||
# 发送附件
|
||||
for file_comp in file_components:
|
||||
await LarkMessageEvent._send_file_message(
|
||||
file_comp, lark_client, reply_message_id, receive_id, receive_id_type
|
||||
)
|
||||
|
||||
for audio_comp in audio_components:
|
||||
await LarkMessageEvent._send_audio_message(
|
||||
audio_comp, lark_client, reply_message_id, receive_id, receive_id_type
|
||||
)
|
||||
|
||||
for media_comp in media_components:
|
||||
await LarkMessageEvent._send_media_message(
|
||||
media_comp, lark_client, reply_message_id, receive_id, receive_id_type
|
||||
)
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
"""发送消息链到飞书,然后交给父类做框架级发送/记录"""
|
||||
await LarkMessageEvent.send_message_chain(
|
||||
message,
|
||||
self.bot,
|
||||
reply_message_id=self.message_obj.message_id,
|
||||
)
|
||||
await super().send(message)
|
||||
|
||||
@staticmethod
|
||||
async def _send_file_message(
|
||||
file_comp: File,
|
||||
lark_client: lark.Client,
|
||||
reply_message_id: str | None = None,
|
||||
receive_id: str | None = None,
|
||||
receive_id_type: str | None = None,
|
||||
):
|
||||
"""发送文件消息
|
||||
|
||||
Args:
|
||||
file_comp: 文件组件
|
||||
lark_client: 飞书客户端
|
||||
reply_message_id: 回复的消息ID(用于回复消息)
|
||||
receive_id: 接收者ID(用于主动发送)
|
||||
receive_id_type: 接收者ID类型(用于主动发送)
|
||||
"""
|
||||
file_path = file_comp.file or ""
|
||||
file_key = await LarkMessageEvent._upload_lark_file(
|
||||
lark_client, path=file_path, file_type="stream"
|
||||
)
|
||||
if not file_key:
|
||||
return
|
||||
|
||||
content = json.dumps({"file_key": file_key})
|
||||
await LarkMessageEvent._send_im_message(
|
||||
lark_client,
|
||||
content=content,
|
||||
msg_type="file",
|
||||
reply_message_id=reply_message_id,
|
||||
receive_id=receive_id,
|
||||
receive_id_type=receive_id_type,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _send_audio_message(
|
||||
audio_comp: Record,
|
||||
lark_client: lark.Client,
|
||||
reply_message_id: str | None = None,
|
||||
receive_id: str | None = None,
|
||||
receive_id_type: str | None = None,
|
||||
):
|
||||
"""发送音频消息
|
||||
|
||||
Args:
|
||||
audio_comp: 音频组件
|
||||
lark_client: 飞书客户端
|
||||
reply_message_id: 回复的消息ID(用于回复消息)
|
||||
receive_id: 接收者ID(用于主动发送)
|
||||
receive_id_type: 接收者ID类型(用于主动发送)
|
||||
"""
|
||||
# 获取音频文件路径
|
||||
try:
|
||||
original_audio_path = await audio_comp.convert_to_file_path()
|
||||
except Exception as e:
|
||||
logger.error(f"[Lark] 无法获取音频文件路径: {e}")
|
||||
return
|
||||
|
||||
if not original_audio_path or not os.path.exists(original_audio_path):
|
||||
logger.error(f"[Lark] 音频文件不存在: {original_audio_path}")
|
||||
return
|
||||
|
||||
# 转换为opus格式
|
||||
converted_audio_path = None
|
||||
try:
|
||||
audio_path = await convert_audio_to_opus(original_audio_path)
|
||||
# 如果转换后路径与原路径不同,说明生成了新文件
|
||||
if audio_path != original_audio_path:
|
||||
converted_audio_path = audio_path
|
||||
else:
|
||||
audio_path = original_audio_path
|
||||
except Exception as e:
|
||||
logger.error(f"[Lark] 音频格式转换失败,将尝试直接上传: {e}")
|
||||
# 如果转换失败,继续尝试直接上传原始文件
|
||||
audio_path = original_audio_path
|
||||
|
||||
# 获取音频时长
|
||||
duration = await get_media_duration(audio_path)
|
||||
|
||||
# 上传音频文件
|
||||
file_key = await LarkMessageEvent._upload_lark_file(
|
||||
lark_client,
|
||||
path=audio_path,
|
||||
file_type="opus",
|
||||
duration=duration,
|
||||
)
|
||||
|
||||
# 清理转换后的临时音频文件
|
||||
if converted_audio_path and os.path.exists(converted_audio_path):
|
||||
try:
|
||||
os.remove(converted_audio_path)
|
||||
logger.debug(f"[Lark] 已删除转换后的音频文件: {converted_audio_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[Lark] 删除转换后的音频文件失败: {e}")
|
||||
|
||||
if not file_key:
|
||||
return
|
||||
|
||||
await LarkMessageEvent._send_im_message(
|
||||
lark_client,
|
||||
content=json.dumps({"file_key": file_key}),
|
||||
msg_type="audio",
|
||||
reply_message_id=reply_message_id,
|
||||
receive_id=receive_id,
|
||||
receive_id_type=receive_id_type,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _send_media_message(
|
||||
media_comp: Video,
|
||||
lark_client: lark.Client,
|
||||
reply_message_id: str | None = None,
|
||||
receive_id: str | None = None,
|
||||
receive_id_type: str | None = None,
|
||||
):
|
||||
"""发送视频消息
|
||||
|
||||
Args:
|
||||
media_comp: 视频组件
|
||||
lark_client: 飞书客户端
|
||||
reply_message_id: 回复的消息ID(用于回复消息)
|
||||
receive_id: 接收者ID(用于主动发送)
|
||||
receive_id_type: 接收者ID类型(用于主动发送)
|
||||
"""
|
||||
# 获取视频文件路径
|
||||
try:
|
||||
original_video_path = await media_comp.convert_to_file_path()
|
||||
except Exception as e:
|
||||
logger.error(f"[Lark] 无法获取视频文件路径: {e}")
|
||||
return
|
||||
|
||||
if not original_video_path or not os.path.exists(original_video_path):
|
||||
logger.error(f"[Lark] 视频文件不存在: {original_video_path}")
|
||||
return
|
||||
|
||||
# 转换为mp4格式
|
||||
converted_video_path = None
|
||||
try:
|
||||
video_path = await convert_video_format(original_video_path, "mp4")
|
||||
# 如果转换后路径与原路径不同,说明生成了新文件
|
||||
if video_path != original_video_path:
|
||||
converted_video_path = video_path
|
||||
else:
|
||||
video_path = original_video_path
|
||||
except Exception as e:
|
||||
logger.error(f"[Lark] 视频格式转换失败,将尝试直接上传: {e}")
|
||||
# 如果转换失败,继续尝试直接上传原始文件
|
||||
video_path = original_video_path
|
||||
|
||||
# 获取视频时长
|
||||
duration = await get_media_duration(video_path)
|
||||
|
||||
# 上传视频文件
|
||||
file_key = await LarkMessageEvent._upload_lark_file(
|
||||
lark_client,
|
||||
path=video_path,
|
||||
file_type="mp4",
|
||||
duration=duration,
|
||||
)
|
||||
|
||||
# 清理转换后的临时视频文件
|
||||
if converted_video_path and os.path.exists(converted_video_path):
|
||||
try:
|
||||
os.remove(converted_video_path)
|
||||
logger.debug(f"[Lark] 已删除转换后的视频文件: {converted_video_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[Lark] 删除转换后的视频文件失败: {e}")
|
||||
|
||||
if not file_key:
|
||||
return
|
||||
|
||||
await LarkMessageEvent._send_im_message(
|
||||
lark_client,
|
||||
content=json.dumps({"file_key": file_key}),
|
||||
msg_type="media",
|
||||
reply_message_id=reply_message_id,
|
||||
receive_id=receive_id,
|
||||
receive_id_type=receive_id_type,
|
||||
)
|
||||
|
||||
async def react(self, emoji: str):
|
||||
if self.bot.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化,无法发送表情")
|
||||
|
||||
@@ -29,43 +29,11 @@ class QueueListener:
|
||||
def __init__(self, webchat_queue_mgr: WebChatQueueMgr, callback: Callable) -> None:
|
||||
self.webchat_queue_mgr = webchat_queue_mgr
|
||||
self.callback = callback
|
||||
self.running_tasks = set()
|
||||
|
||||
async def listen_to_queue(self, conversation_id: str):
|
||||
"""Listen to a specific conversation queue"""
|
||||
queue = self.webchat_queue_mgr.get_or_create_queue(conversation_id)
|
||||
while True:
|
||||
try:
|
||||
data = await queue.get()
|
||||
await self.callback(data)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing message from conversation {conversation_id}: {e}",
|
||||
)
|
||||
break
|
||||
|
||||
async def run(self):
|
||||
"""Monitor for new conversation queues and start listeners"""
|
||||
monitored_conversations = set()
|
||||
|
||||
while True:
|
||||
# Check for new conversations
|
||||
current_conversations = set(self.webchat_queue_mgr.queues.keys())
|
||||
new_conversations = current_conversations - monitored_conversations
|
||||
|
||||
# Start listeners for new conversations
|
||||
for conversation_id in new_conversations:
|
||||
task = asyncio.create_task(self.listen_to_queue(conversation_id))
|
||||
self.running_tasks.add(task)
|
||||
task.add_done_callback(self.running_tasks.discard)
|
||||
monitored_conversations.add(conversation_id)
|
||||
logger.debug(f"Started listener for conversation: {conversation_id}")
|
||||
|
||||
# Clean up monitored conversations that no longer exist
|
||||
removed_conversations = monitored_conversations - current_conversations
|
||||
monitored_conversations -= removed_conversations
|
||||
|
||||
await asyncio.sleep(1) # Check for new conversations every second
|
||||
"""Register callback and keep adapter task alive."""
|
||||
self.webchat_queue_mgr.set_listener(self.callback)
|
||||
await asyncio.Event().wait()
|
||||
|
||||
|
||||
@register_platform_adapter("webchat", "webchat")
|
||||
|
||||
@@ -26,8 +26,12 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
session_id: str,
|
||||
streaming: bool = False,
|
||||
) -> str | None:
|
||||
cid = session_id.split("!")[-1]
|
||||
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
|
||||
request_id = str(message_id)
|
||||
conversation_id = session_id.split("!")[-1]
|
||||
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(
|
||||
request_id,
|
||||
conversation_id,
|
||||
)
|
||||
if not message:
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
@@ -124,9 +128,13 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
final_data = ""
|
||||
reasoning_content = ""
|
||||
cid = self.session_id.split("!")[-1]
|
||||
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
|
||||
message_id = self.message_obj.message_id
|
||||
request_id = str(message_id)
|
||||
conversation_id = self.session_id.split("!")[-1]
|
||||
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(
|
||||
request_id,
|
||||
conversation_id,
|
||||
)
|
||||
async for chain in generator:
|
||||
# 处理音频流(Live Mode)
|
||||
if chain.type == "audio_chunk":
|
||||
|
||||
@@ -1,35 +1,147 @@
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from astrbot import logger
|
||||
|
||||
|
||||
class WebChatQueueMgr:
|
||||
def __init__(self) -> None:
|
||||
self.queues = {}
|
||||
def __init__(self, queue_maxsize: int = 128, back_queue_maxsize: int = 512) -> None:
|
||||
self.queues: dict[str, asyncio.Queue] = {}
|
||||
"""Conversation ID to asyncio.Queue mapping"""
|
||||
self.back_queues = {}
|
||||
"""Conversation ID to asyncio.Queue mapping for responses"""
|
||||
self.back_queues: dict[str, asyncio.Queue] = {}
|
||||
"""Request ID to asyncio.Queue mapping for responses"""
|
||||
self._conversation_back_requests: dict[str, set[str]] = {}
|
||||
self._request_conversation: dict[str, str] = {}
|
||||
self._queue_close_events: dict[str, asyncio.Event] = {}
|
||||
self._listener_tasks: dict[str, asyncio.Task] = {}
|
||||
self._listener_callback: Callable[[tuple], Awaitable[None]] | None = None
|
||||
self.queue_maxsize = queue_maxsize
|
||||
self.back_queue_maxsize = back_queue_maxsize
|
||||
|
||||
def get_or_create_queue(self, conversation_id: str) -> asyncio.Queue:
|
||||
"""Get or create a queue for the given conversation ID"""
|
||||
if conversation_id not in self.queues:
|
||||
self.queues[conversation_id] = asyncio.Queue()
|
||||
self.queues[conversation_id] = asyncio.Queue(maxsize=self.queue_maxsize)
|
||||
self._queue_close_events[conversation_id] = asyncio.Event()
|
||||
self._start_listener_if_needed(conversation_id)
|
||||
return self.queues[conversation_id]
|
||||
|
||||
def get_or_create_back_queue(self, conversation_id: str) -> asyncio.Queue:
|
||||
"""Get or create a back queue for the given conversation ID"""
|
||||
if conversation_id not in self.back_queues:
|
||||
self.back_queues[conversation_id] = asyncio.Queue()
|
||||
return self.back_queues[conversation_id]
|
||||
def get_or_create_back_queue(
|
||||
self,
|
||||
request_id: str,
|
||||
conversation_id: str | None = None,
|
||||
) -> asyncio.Queue:
|
||||
"""Get or create a back queue for the given request ID"""
|
||||
if request_id not in self.back_queues:
|
||||
self.back_queues[request_id] = asyncio.Queue(
|
||||
maxsize=self.back_queue_maxsize
|
||||
)
|
||||
if conversation_id:
|
||||
self._request_conversation[request_id] = conversation_id
|
||||
if conversation_id not in self._conversation_back_requests:
|
||||
self._conversation_back_requests[conversation_id] = set()
|
||||
self._conversation_back_requests[conversation_id].add(request_id)
|
||||
return self.back_queues[request_id]
|
||||
|
||||
def remove_back_queue(self, request_id: str):
|
||||
"""Remove back queue for the given request ID"""
|
||||
self.back_queues.pop(request_id, None)
|
||||
conversation_id = self._request_conversation.pop(request_id, None)
|
||||
if conversation_id:
|
||||
request_ids = self._conversation_back_requests.get(conversation_id)
|
||||
if request_ids is not None:
|
||||
request_ids.discard(request_id)
|
||||
if not request_ids:
|
||||
self._conversation_back_requests.pop(conversation_id, None)
|
||||
|
||||
def remove_queues(self, conversation_id: str):
|
||||
"""Remove queues for the given conversation ID"""
|
||||
if conversation_id in self.queues:
|
||||
del self.queues[conversation_id]
|
||||
if conversation_id in self.back_queues:
|
||||
del self.back_queues[conversation_id]
|
||||
for request_id in list(
|
||||
self._conversation_back_requests.get(conversation_id, set())
|
||||
):
|
||||
self.remove_back_queue(request_id)
|
||||
self._conversation_back_requests.pop(conversation_id, None)
|
||||
self.remove_queue(conversation_id)
|
||||
|
||||
def remove_queue(self, conversation_id: str):
|
||||
"""Remove input queue and listener for the given conversation ID"""
|
||||
self.queues.pop(conversation_id, None)
|
||||
|
||||
close_event = self._queue_close_events.pop(conversation_id, None)
|
||||
if close_event is not None:
|
||||
close_event.set()
|
||||
|
||||
task = self._listener_tasks.pop(conversation_id, None)
|
||||
if task is not None:
|
||||
task.cancel()
|
||||
|
||||
def has_queue(self, conversation_id: str) -> bool:
|
||||
"""Check if a queue exists for the given conversation ID"""
|
||||
return conversation_id in self.queues
|
||||
|
||||
def set_listener(
|
||||
self,
|
||||
callback: Callable[[tuple], Awaitable[None]],
|
||||
):
|
||||
self._listener_callback = callback
|
||||
for conversation_id in list(self.queues.keys()):
|
||||
self._start_listener_if_needed(conversation_id)
|
||||
|
||||
def _start_listener_if_needed(self, conversation_id: str):
|
||||
if self._listener_callback is None:
|
||||
return
|
||||
if conversation_id in self._listener_tasks:
|
||||
task = self._listener_tasks[conversation_id]
|
||||
if not task.done():
|
||||
return
|
||||
queue = self.queues.get(conversation_id)
|
||||
close_event = self._queue_close_events.get(conversation_id)
|
||||
if queue is None or close_event is None:
|
||||
return
|
||||
task = asyncio.create_task(
|
||||
self._listen_to_queue(conversation_id, queue, close_event),
|
||||
name=f"webchat_listener_{conversation_id}",
|
||||
)
|
||||
self._listener_tasks[conversation_id] = task
|
||||
task.add_done_callback(
|
||||
lambda _: self._listener_tasks.pop(conversation_id, None)
|
||||
)
|
||||
logger.debug(f"Started listener for conversation: {conversation_id}")
|
||||
|
||||
async def _listen_to_queue(
|
||||
self,
|
||||
conversation_id: str,
|
||||
queue: asyncio.Queue,
|
||||
close_event: asyncio.Event,
|
||||
):
|
||||
while True:
|
||||
get_task = asyncio.create_task(queue.get())
|
||||
close_task = asyncio.create_task(close_event.wait())
|
||||
try:
|
||||
done, pending = await asyncio.wait(
|
||||
{get_task, close_task},
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
if close_task in done:
|
||||
break
|
||||
data = get_task.result()
|
||||
if self._listener_callback is None:
|
||||
continue
|
||||
try:
|
||||
await self._listener_callback(data)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing message from conversation {conversation_id}: {e}"
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
finally:
|
||||
if not get_task.done():
|
||||
get_task.cancel()
|
||||
if not close_task.done():
|
||||
close_task.cancel()
|
||||
|
||||
|
||||
webchat_queue_mgr = WebChatQueueMgr()
|
||||
|
||||
@@ -51,44 +51,13 @@ class WecomAIQueueListener:
|
||||
) -> None:
|
||||
self.queue_mgr = queue_mgr
|
||||
self.callback = callback
|
||||
self.running_tasks = set()
|
||||
|
||||
async def listen_to_queue(self, session_id: str):
|
||||
"""监听特定会话的队列"""
|
||||
queue = self.queue_mgr.get_or_create_queue(session_id)
|
||||
while True:
|
||||
try:
|
||||
data = await queue.get()
|
||||
await self.callback(data)
|
||||
except Exception as e:
|
||||
logger.error(f"处理会话 {session_id} 消息时发生错误: {e}")
|
||||
break
|
||||
|
||||
async def run(self):
|
||||
"""监控新会话队列并启动监听器"""
|
||||
monitored_sessions = set()
|
||||
|
||||
"""注册监听回调并定期清理过期响应。"""
|
||||
self.queue_mgr.set_listener(self.callback)
|
||||
while True:
|
||||
# 检查新会话
|
||||
current_sessions = set(self.queue_mgr.queues.keys())
|
||||
new_sessions = current_sessions - monitored_sessions
|
||||
|
||||
# 为新会话启动监听器
|
||||
for session_id in new_sessions:
|
||||
task = asyncio.create_task(self.listen_to_queue(session_id))
|
||||
self.running_tasks.add(task)
|
||||
task.add_done_callback(self.running_tasks.discard)
|
||||
monitored_sessions.add(session_id)
|
||||
logger.debug(f"[WecomAI] 为会话启动监听器: {session_id}")
|
||||
|
||||
# 清理已不存在的会话
|
||||
removed_sessions = monitored_sessions - current_sessions
|
||||
monitored_sessions -= removed_sessions
|
||||
|
||||
# 清理过期的待处理响应
|
||||
self.queue_mgr.cleanup_expired_responses()
|
||||
|
||||
await asyncio.sleep(1) # 每秒检查一次新会话
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
@register_platform_adapter(
|
||||
@@ -212,7 +181,12 @@ class WecomAIBotAdapter(Platform):
|
||||
# wechat server is requesting for updates of a stream
|
||||
stream_id = message_data["stream"]["id"]
|
||||
if not self.queue_mgr.has_back_queue(stream_id):
|
||||
logger.error(f"Cannot find back queue for stream_id: {stream_id}")
|
||||
if self.queue_mgr.is_stream_finished(stream_id):
|
||||
logger.debug(
|
||||
f"Stream already finished, returning end message: {stream_id}"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Cannot find back queue for stream_id: {stream_id}")
|
||||
|
||||
# 返回结束标志,告诉微信服务器流已结束
|
||||
end_message = WecomAIBotStreamMessageBuilder.make_text_stream(
|
||||
@@ -243,10 +217,10 @@ class WecomAIBotAdapter(Platform):
|
||||
latest_plain_content = msg["data"] or ""
|
||||
elif msg["type"] == "image":
|
||||
image_base64.append(msg["image_data"])
|
||||
elif msg["type"] == "end":
|
||||
elif msg["type"] in {"end", "complete"}:
|
||||
# stream end
|
||||
finish = True
|
||||
self.queue_mgr.remove_queues(stream_id)
|
||||
self.queue_mgr.remove_queues(stream_id, mark_finished=True)
|
||||
break
|
||||
|
||||
logger.debug(
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from astrbot.api import logger
|
||||
@@ -12,7 +13,7 @@ from astrbot.api import logger
|
||||
class WecomAIQueueMgr:
|
||||
"""企业微信智能机器人队列管理器"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, queue_maxsize: int = 128, back_queue_maxsize: int = 512) -> None:
|
||||
self.queues: dict[str, asyncio.Queue] = {}
|
||||
"""StreamID 到输入队列的映射 - 用于接收用户消息"""
|
||||
|
||||
@@ -21,6 +22,13 @@ class WecomAIQueueMgr:
|
||||
|
||||
self.pending_responses: dict[str, dict[str, Any]] = {}
|
||||
"""待处理的响应缓存,用于流式响应"""
|
||||
self.completed_streams: dict[str, float] = {}
|
||||
"""已结束的 stream 缓存,用于兼容平台后续重复轮询"""
|
||||
self._queue_close_events: dict[str, asyncio.Event] = {}
|
||||
self._listener_tasks: dict[str, asyncio.Task] = {}
|
||||
self._listener_callback: Callable[[dict], Awaitable[None]] | None = None
|
||||
self.queue_maxsize = queue_maxsize
|
||||
self.back_queue_maxsize = back_queue_maxsize
|
||||
|
||||
def get_or_create_queue(self, session_id: str) -> asyncio.Queue:
|
||||
"""获取或创建指定会话的输入队列
|
||||
@@ -33,7 +41,9 @@ class WecomAIQueueMgr:
|
||||
|
||||
"""
|
||||
if session_id not in self.queues:
|
||||
self.queues[session_id] = asyncio.Queue()
|
||||
self.queues[session_id] = asyncio.Queue(maxsize=self.queue_maxsize)
|
||||
self._queue_close_events[session_id] = asyncio.Event()
|
||||
self._start_listener_if_needed(session_id)
|
||||
logger.debug(f"[WecomAI] 创建输入队列: {session_id}")
|
||||
return self.queues[session_id]
|
||||
|
||||
@@ -48,20 +58,21 @@ class WecomAIQueueMgr:
|
||||
|
||||
"""
|
||||
if session_id not in self.back_queues:
|
||||
self.back_queues[session_id] = asyncio.Queue()
|
||||
self.back_queues[session_id] = asyncio.Queue(
|
||||
maxsize=self.back_queue_maxsize
|
||||
)
|
||||
logger.debug(f"[WecomAI] 创建输出队列: {session_id}")
|
||||
return self.back_queues[session_id]
|
||||
|
||||
def remove_queues(self, session_id: str):
|
||||
def remove_queues(self, session_id: str, mark_finished: bool = False):
|
||||
"""移除指定会话的所有队列
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
mark_finished: 是否标记为已正常结束
|
||||
|
||||
"""
|
||||
if session_id in self.queues:
|
||||
del self.queues[session_id]
|
||||
logger.debug(f"[WecomAI] 移除输入队列: {session_id}")
|
||||
self.remove_queue(session_id)
|
||||
|
||||
if session_id in self.back_queues:
|
||||
del self.back_queues[session_id]
|
||||
@@ -70,6 +81,23 @@ class WecomAIQueueMgr:
|
||||
if session_id in self.pending_responses:
|
||||
del self.pending_responses[session_id]
|
||||
logger.debug(f"[WecomAI] 移除待处理响应: {session_id}")
|
||||
if mark_finished:
|
||||
self.completed_streams[session_id] = asyncio.get_event_loop().time()
|
||||
logger.debug(f"[WecomAI] 标记流已结束: {session_id}")
|
||||
|
||||
def remove_queue(self, session_id: str):
|
||||
"""仅移除输入队列和对应监听任务"""
|
||||
if session_id in self.queues:
|
||||
del self.queues[session_id]
|
||||
logger.debug(f"[WecomAI] 移除输入队列: {session_id}")
|
||||
|
||||
close_event = self._queue_close_events.pop(session_id, None)
|
||||
if close_event is not None:
|
||||
close_event.set()
|
||||
|
||||
task = self._listener_tasks.pop(session_id, None)
|
||||
if task is not None:
|
||||
task.cancel()
|
||||
|
||||
def has_queue(self, session_id: str) -> bool:
|
||||
"""检查是否存在指定会话的队列
|
||||
@@ -121,6 +149,20 @@ class WecomAIQueueMgr:
|
||||
"""
|
||||
return self.pending_responses.get(session_id)
|
||||
|
||||
def is_stream_finished(
|
||||
self,
|
||||
session_id: str,
|
||||
max_age_seconds: int = 60,
|
||||
) -> bool:
|
||||
"""判断 stream 是否在短期内已结束"""
|
||||
finished_at = self.completed_streams.get(session_id)
|
||||
if finished_at is None:
|
||||
return False
|
||||
if asyncio.get_event_loop().time() - finished_at > max_age_seconds:
|
||||
self.completed_streams.pop(session_id, None)
|
||||
return False
|
||||
return True
|
||||
|
||||
def cleanup_expired_responses(self, max_age_seconds: int = 300):
|
||||
"""清理过期的待处理响应
|
||||
|
||||
@@ -136,8 +178,75 @@ class WecomAIQueueMgr:
|
||||
expired_sessions.append(session_id)
|
||||
|
||||
for session_id in expired_sessions:
|
||||
del self.pending_responses[session_id]
|
||||
logger.debug(f"[WecomAI] 清理过期响应: {session_id}")
|
||||
self.remove_queues(session_id)
|
||||
logger.debug(f"[WecomAI] 清理过期响应及队列: {session_id}")
|
||||
expired_finished = [
|
||||
session_id
|
||||
for session_id, finished_at in self.completed_streams.items()
|
||||
if current_time - finished_at > 60
|
||||
]
|
||||
for session_id in expired_finished:
|
||||
self.completed_streams.pop(session_id, None)
|
||||
|
||||
def set_listener(
|
||||
self,
|
||||
callback: Callable[[dict], Awaitable[None]],
|
||||
):
|
||||
self._listener_callback = callback
|
||||
for session_id in list(self.queues.keys()):
|
||||
self._start_listener_if_needed(session_id)
|
||||
|
||||
def _start_listener_if_needed(self, session_id: str):
|
||||
if self._listener_callback is None:
|
||||
return
|
||||
if session_id in self._listener_tasks:
|
||||
task = self._listener_tasks[session_id]
|
||||
if not task.done():
|
||||
return
|
||||
queue = self.queues.get(session_id)
|
||||
close_event = self._queue_close_events.get(session_id)
|
||||
if queue is None or close_event is None:
|
||||
return
|
||||
task = asyncio.create_task(
|
||||
self._listen_to_queue(session_id, queue, close_event),
|
||||
name=f"wecomai_listener_{session_id}",
|
||||
)
|
||||
self._listener_tasks[session_id] = task
|
||||
task.add_done_callback(lambda _: self._listener_tasks.pop(session_id, None))
|
||||
logger.debug(f"[WecomAI] 为会话启动监听器: {session_id}")
|
||||
|
||||
async def _listen_to_queue(
|
||||
self,
|
||||
session_id: str,
|
||||
queue: asyncio.Queue,
|
||||
close_event: asyncio.Event,
|
||||
):
|
||||
while True:
|
||||
get_task = asyncio.create_task(queue.get())
|
||||
close_task = asyncio.create_task(close_event.wait())
|
||||
try:
|
||||
done, pending = await asyncio.wait(
|
||||
{get_task, close_task},
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
if close_task in done:
|
||||
break
|
||||
data = get_task.result()
|
||||
if self._listener_callback is None:
|
||||
continue
|
||||
try:
|
||||
await self._listener_callback(data)
|
||||
except Exception as e:
|
||||
logger.error(f"处理会话 {session_id} 消息时发生错误: {e}")
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
finally:
|
||||
if not get_task.done():
|
||||
get_task.cancel()
|
||||
if not close_task.done():
|
||||
close_task.cancel()
|
||||
|
||||
def get_stats(self) -> dict[str, int]:
|
||||
"""获取队列统计信息
|
||||
|
||||
@@ -0,0 +1,207 @@
|
||||
"""媒体文件处理工具
|
||||
|
||||
提供音视频格式转换、时长获取等功能。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import subprocess
|
||||
import uuid
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
async def get_media_duration(file_path: str) -> int | None:
|
||||
"""使用ffprobe获取媒体文件时长
|
||||
|
||||
Args:
|
||||
file_path: 媒体文件路径
|
||||
|
||||
Returns:
|
||||
时长(毫秒),如果获取失败返回None
|
||||
"""
|
||||
try:
|
||||
# 使用ffprobe获取时长
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
"ffprobe",
|
||||
"-v",
|
||||
"error",
|
||||
"-show_entries",
|
||||
"format=duration",
|
||||
"-of",
|
||||
"default=noprint_wrappers=1:nokey=1",
|
||||
file_path,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
|
||||
stdout, stderr = await process.communicate()
|
||||
|
||||
if process.returncode == 0 and stdout:
|
||||
duration_seconds = float(stdout.decode().strip())
|
||||
duration_ms = int(duration_seconds * 1000)
|
||||
logger.debug(f"[Media Utils] 获取媒体时长: {duration_ms}ms")
|
||||
return duration_ms
|
||||
else:
|
||||
logger.warning(f"[Media Utils] 无法获取媒体文件时长: {file_path}")
|
||||
return None
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.warning(
|
||||
"[Media Utils] ffprobe未安装或不在PATH中,无法获取媒体时长。请安装ffmpeg: https://ffmpeg.org/"
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"[Media Utils] 获取媒体时长时出错: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def convert_audio_to_opus(audio_path: str, output_path: str | None = None) -> str:
|
||||
"""使用ffmpeg将音频转换为opus格式
|
||||
|
||||
Args:
|
||||
audio_path: 原始音频文件路径
|
||||
output_path: 输出文件路径,如果为None则自动生成
|
||||
|
||||
Returns:
|
||||
转换后的opus文件路径
|
||||
|
||||
Raises:
|
||||
Exception: 转换失败时抛出异常
|
||||
"""
|
||||
# 如果已经是opus格式,直接返回
|
||||
if audio_path.lower().endswith(".opus"):
|
||||
return audio_path
|
||||
|
||||
# 生成输出文件路径
|
||||
if output_path is None:
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
output_path = os.path.join(temp_dir, f"{uuid.uuid4()}.opus")
|
||||
|
||||
try:
|
||||
# 使用ffmpeg转换为opus格式
|
||||
# -y: 覆盖输出文件
|
||||
# -i: 输入文件
|
||||
# -acodec libopus: 使用opus编码器
|
||||
# -ac 1: 单声道
|
||||
# -ar 16000: 采样率16kHz
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-i",
|
||||
audio_path,
|
||||
"-acodec",
|
||||
"libopus",
|
||||
"-ac",
|
||||
"1",
|
||||
"-ar",
|
||||
"16000",
|
||||
output_path,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
|
||||
stdout, stderr = await process.communicate()
|
||||
|
||||
if process.returncode != 0:
|
||||
# 清理可能已生成但无效的临时文件
|
||||
if output_path and os.path.exists(output_path):
|
||||
try:
|
||||
os.remove(output_path)
|
||||
logger.debug(
|
||||
f"[Media Utils] 已清理失败的opus输出文件: {output_path}"
|
||||
)
|
||||
except OSError as e:
|
||||
logger.warning(f"[Media Utils] 清理失败的opus输出文件时出错: {e}")
|
||||
|
||||
error_msg = stderr.decode() if stderr else "未知错误"
|
||||
logger.error(f"[Media Utils] ffmpeg转换音频失败: {error_msg}")
|
||||
raise Exception(f"ffmpeg conversion failed: {error_msg}")
|
||||
|
||||
logger.debug(f"[Media Utils] 音频转换成功: {audio_path} -> {output_path}")
|
||||
return output_path
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error(
|
||||
"[Media Utils] ffmpeg未安装或不在PATH中,无法转换音频格式。请安装ffmpeg: https://ffmpeg.org/"
|
||||
)
|
||||
raise Exception("ffmpeg not found")
|
||||
except Exception as e:
|
||||
logger.error(f"[Media Utils] 转换音频格式时出错: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def convert_video_format(
|
||||
video_path: str, output_format: str = "mp4", output_path: str | None = None
|
||||
) -> str:
|
||||
"""使用ffmpeg转换视频格式
|
||||
|
||||
Args:
|
||||
video_path: 原始视频文件路径
|
||||
output_format: 目标格式,默认mp4
|
||||
output_path: 输出文件路径,如果为None则自动生成
|
||||
|
||||
Returns:
|
||||
转换后的视频文件路径
|
||||
|
||||
Raises:
|
||||
Exception: 转换失败时抛出异常
|
||||
"""
|
||||
# 如果已经是目标格式,直接返回
|
||||
if video_path.lower().endswith(f".{output_format}"):
|
||||
return video_path
|
||||
|
||||
# 生成输出文件路径
|
||||
if output_path is None:
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
output_path = os.path.join(temp_dir, f"{uuid.uuid4()}.{output_format}")
|
||||
|
||||
try:
|
||||
# 使用ffmpeg转换视频格式
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-i",
|
||||
video_path,
|
||||
"-c:v",
|
||||
"libx264",
|
||||
"-c:a",
|
||||
"aac",
|
||||
output_path,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
|
||||
stdout, stderr = await process.communicate()
|
||||
|
||||
if process.returncode != 0:
|
||||
# 清理可能已生成但无效的临时文件
|
||||
if output_path and os.path.exists(output_path):
|
||||
try:
|
||||
os.remove(output_path)
|
||||
logger.debug(
|
||||
f"[Media Utils] 已清理失败的{output_format}输出文件: {output_path}"
|
||||
)
|
||||
except OSError as e:
|
||||
logger.warning(
|
||||
f"[Media Utils] 清理失败的{output_format}输出文件时出错: {e}"
|
||||
)
|
||||
|
||||
error_msg = stderr.decode() if stderr else "未知错误"
|
||||
logger.error(f"[Media Utils] ffmpeg转换视频失败: {error_msg}")
|
||||
raise Exception(f"ffmpeg conversion failed: {error_msg}")
|
||||
|
||||
logger.debug(f"[Media Utils] 视频转换成功: {video_path} -> {output_path}")
|
||||
return output_path
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error(
|
||||
"[Media Utils] ffmpeg未安装或不在PATH中,无法转换视频格式。请安装ffmpeg: https://ffmpeg.org/"
|
||||
)
|
||||
raise Exception("ffmpeg not found")
|
||||
except Exception as e:
|
||||
logger.error(f"[Media Utils] 转换视频格式时出错: {e}")
|
||||
raise
|
||||
@@ -23,7 +23,7 @@ class SharedPreferences:
|
||||
)
|
||||
self.path = json_storage_path
|
||||
self.db_helper = db_helper
|
||||
self.temorary_cache: dict[str, dict[str, Any]] = defaultdict(dict)
|
||||
self.temporary_cache: dict[str, dict[str, Any]] = defaultdict(dict)
|
||||
"""automatically clear per 24 hours. Might be helpful in some cases XD"""
|
||||
|
||||
self._sync_loop = asyncio.new_event_loop()
|
||||
@@ -37,7 +37,7 @@ class SharedPreferences:
|
||||
self._scheduler.start()
|
||||
|
||||
def _clear_temporary_cache(self):
|
||||
self.temorary_cache.clear()
|
||||
self.temporary_cache.clear()
|
||||
|
||||
async def get_async(
|
||||
self,
|
||||
|
||||
@@ -354,12 +354,15 @@ class ChatRoute(Route):
|
||||
return Response().error("session_id is empty").__dict__
|
||||
|
||||
webchat_conv_id = session_id
|
||||
back_queue = webchat_queue_mgr.get_or_create_back_queue(webchat_conv_id)
|
||||
|
||||
# 构建用户消息段(包含 path 用于传递给 adapter)
|
||||
message_parts = await self._build_user_message_parts(message)
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
back_queue = webchat_queue_mgr.get_or_create_back_queue(
|
||||
message_id,
|
||||
webchat_conv_id,
|
||||
)
|
||||
|
||||
async def stream():
|
||||
client_disconnected = False
|
||||
@@ -532,6 +535,8 @@ class ChatRoute(Route):
|
||||
refs = {}
|
||||
except BaseException as e:
|
||||
logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True)
|
||||
finally:
|
||||
webchat_queue_mgr.remove_back_queue(message_id)
|
||||
|
||||
# 将消息放入会话特定的队列
|
||||
chat_queue = webchat_queue_mgr.get_or_create_queue(webchat_conv_id)
|
||||
|
||||
@@ -256,143 +256,148 @@ class LiveChatRoute(Route):
|
||||
await queue.put((session.username, cid, payload))
|
||||
|
||||
# 3. 等待响应并流式发送 TTS 音频
|
||||
back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
|
||||
back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id, cid)
|
||||
|
||||
bot_text = ""
|
||||
audio_playing = False
|
||||
|
||||
while True:
|
||||
if session.should_interrupt:
|
||||
# 用户打断,停止处理
|
||||
logger.info("[Live Chat] 检测到用户打断")
|
||||
await websocket.send_json({"t": "stop_play"})
|
||||
# 保存消息并标记为被打断
|
||||
await self._save_interrupted_message(session, user_text, bot_text)
|
||||
# 清空队列中未处理的消息
|
||||
while not back_queue.empty():
|
||||
try:
|
||||
while True:
|
||||
if session.should_interrupt:
|
||||
# 用户打断,停止处理
|
||||
logger.info("[Live Chat] 检测到用户打断")
|
||||
await websocket.send_json({"t": "stop_play"})
|
||||
# 保存消息并标记为被打断
|
||||
await self._save_interrupted_message(
|
||||
session, user_text, bot_text
|
||||
)
|
||||
# 清空队列中未处理的消息
|
||||
while not back_queue.empty():
|
||||
try:
|
||||
back_queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
break
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(back_queue.get(), timeout=0.5)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
if not result:
|
||||
continue
|
||||
|
||||
result_message_id = result.get("message_id")
|
||||
if result_message_id != message_id:
|
||||
logger.warning(
|
||||
f"[Live Chat] 消息 ID 不匹配: {result_message_id} != {message_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
result_type = result.get("type")
|
||||
result_chain_type = result.get("chain_type")
|
||||
data = result.get("data", "")
|
||||
|
||||
if result_chain_type == "agent_stats":
|
||||
try:
|
||||
back_queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
break
|
||||
stats = json.loads(data)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "metrics",
|
||||
"data": {
|
||||
"llm_ttft": stats.get("time_to_first_token", 0),
|
||||
"llm_total_time": stats.get("end_time", 0)
|
||||
- stats.get("start_time", 0),
|
||||
},
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Live Chat] 解析 AgentStats 失败: {e}")
|
||||
continue
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(back_queue.get(), timeout=0.5)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
if result_chain_type == "tts_stats":
|
||||
try:
|
||||
stats = json.loads(data)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "metrics",
|
||||
"data": stats,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Live Chat] 解析 TTSStats 失败: {e}")
|
||||
continue
|
||||
|
||||
if not result:
|
||||
continue
|
||||
if result_type == "plain":
|
||||
# 普通文本消息
|
||||
bot_text += data
|
||||
|
||||
result_message_id = result.get("message_id")
|
||||
if result_message_id != message_id:
|
||||
logger.warning(
|
||||
f"[Live Chat] 消息 ID 不匹配: {result_message_id} != {message_id}"
|
||||
)
|
||||
continue
|
||||
elif result_type == "audio_chunk":
|
||||
# 流式音频数据
|
||||
if not audio_playing:
|
||||
audio_playing = True
|
||||
logger.debug("[Live Chat] 开始播放音频流")
|
||||
|
||||
result_type = result.get("type")
|
||||
result_chain_type = result.get("chain_type")
|
||||
data = result.get("data", "")
|
||||
# Calculate latency from wav assembly finish to first audio chunk
|
||||
speak_to_first_frame_latency = (
|
||||
time.time() - wav_assembly_finish_time
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "metrics",
|
||||
"data": {
|
||||
"speak_to_first_frame": speak_to_first_frame_latency
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
if result_chain_type == "agent_stats":
|
||||
try:
|
||||
stats = json.loads(data)
|
||||
text = result.get("text")
|
||||
if text:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "bot_text_chunk",
|
||||
"data": {"text": text},
|
||||
}
|
||||
)
|
||||
|
||||
# 发送音频数据给前端
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "response",
|
||||
"data": data, # base64 编码的音频数据
|
||||
}
|
||||
)
|
||||
|
||||
elif result_type in ["complete", "end"]:
|
||||
# 处理完成
|
||||
logger.info(f"[Live Chat] Bot 回复完成: {bot_text}")
|
||||
|
||||
# 如果没有音频流,发送 bot 消息文本
|
||||
if not audio_playing:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "bot_msg",
|
||||
"data": {
|
||||
"text": bot_text,
|
||||
"ts": int(time.time() * 1000),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# 发送结束标记
|
||||
await websocket.send_json({"t": "end"})
|
||||
|
||||
# 发送总耗时
|
||||
wav_to_tts_duration = time.time() - wav_assembly_finish_time
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "metrics",
|
||||
"data": {
|
||||
"llm_ttft": stats.get("time_to_first_token", 0),
|
||||
"llm_total_time": stats.get("end_time", 0)
|
||||
- stats.get("start_time", 0),
|
||||
},
|
||||
"data": {"wav_to_tts_total_time": wav_to_tts_duration},
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Live Chat] 解析 AgentStats 失败: {e}")
|
||||
continue
|
||||
|
||||
if result_chain_type == "tts_stats":
|
||||
try:
|
||||
stats = json.loads(data)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "metrics",
|
||||
"data": stats,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Live Chat] 解析 TTSStats 失败: {e}")
|
||||
continue
|
||||
|
||||
if result_type == "plain":
|
||||
# 普通文本消息
|
||||
bot_text += data
|
||||
|
||||
elif result_type == "audio_chunk":
|
||||
# 流式音频数据
|
||||
if not audio_playing:
|
||||
audio_playing = True
|
||||
logger.debug("[Live Chat] 开始播放音频流")
|
||||
|
||||
# Calculate latency from wav assembly finish to first audio chunk
|
||||
speak_to_first_frame_latency = (
|
||||
time.time() - wav_assembly_finish_time
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "metrics",
|
||||
"data": {
|
||||
"speak_to_first_frame": speak_to_first_frame_latency
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
text = result.get("text")
|
||||
if text:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "bot_text_chunk",
|
||||
"data": {"text": text},
|
||||
}
|
||||
)
|
||||
|
||||
# 发送音频数据给前端
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "response",
|
||||
"data": data, # base64 编码的音频数据
|
||||
}
|
||||
)
|
||||
|
||||
elif result_type in ["complete", "end"]:
|
||||
# 处理完成
|
||||
logger.info(f"[Live Chat] Bot 回复完成: {bot_text}")
|
||||
|
||||
# 如果没有音频流,发送 bot 消息文本
|
||||
if not audio_playing:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "bot_msg",
|
||||
"data": {
|
||||
"text": bot_text,
|
||||
"ts": int(time.time() * 1000),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# 发送结束标记
|
||||
await websocket.send_json({"t": "end"})
|
||||
|
||||
# 发送总耗时
|
||||
wav_to_tts_duration = time.time() - wav_assembly_finish_time
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "metrics",
|
||||
"data": {"wav_to_tts_total_time": wav_to_tts_duration},
|
||||
}
|
||||
)
|
||||
break
|
||||
break
|
||||
finally:
|
||||
webchat_queue_mgr.remove_back_queue(message_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Live Chat] 处理音频失败: {e}", exc_info=True)
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
## What's Changed
|
||||
|
||||
### 修复
|
||||
- 修复一些原因导致 Tavily WebSearch、Bocha WebSearch 无法使用的问题
|
||||
|
||||
### xinzeng
|
||||
- 飞书支持 Bot 发送文件、图片和视频消息类型。
|
||||
|
||||
### 优化
|
||||
- 优化 WebChat 和 企业微信 AI 会话队列生命周期管理,减少内存泄漏,提高性能。
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "AstrBot"
|
||||
version = "4.14.5"
|
||||
version = "4.14.6"
|
||||
description = "Easy-to-use multi-platform LLM chatbot and development framework"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
||||
Reference in New Issue
Block a user