diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py index 29ac02653..9c13cfeff 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py @@ -30,7 +30,7 @@ from .wecomai_api import ( WecomAIBotStreamMessageBuilder, ) from .wecomai_event import WecomAIBotMessageEvent -from .wecomai_queue_mgr import WecomAIQueueMgr, wecomai_queue_mgr +from .wecomai_queue_mgr import WecomAIQueueMgr from .wecomai_server import WecomAIBotServer from .wecomai_utils import ( WecomAIBotConstants, @@ -144,9 +144,12 @@ class WecomAIBotAdapter(Platform): # 事件循环和关闭信号 self.shutdown_event = asyncio.Event() + # 队列管理器 + self.queue_mgr = WecomAIQueueMgr() + # 队列监听器 self.queue_listener = WecomAIQueueListener( - wecomai_queue_mgr, + self.queue_mgr, self._handle_queued_message, ) @@ -189,7 +192,7 @@ class WecomAIBotAdapter(Platform): stream_id, session_id, ) - wecomai_queue_mgr.set_pending_response(stream_id, callback_params) + self.queue_mgr.set_pending_response(stream_id, callback_params) resp = WecomAIBotStreamMessageBuilder.make_text_stream( stream_id, @@ -207,7 +210,7 @@ class WecomAIBotAdapter(Platform): elif msgtype == "stream": # wechat server is requesting for updates of a stream stream_id = message_data["stream"]["id"] - if not wecomai_queue_mgr.has_back_queue(stream_id): + if not self.queue_mgr.has_back_queue(stream_id): logger.error(f"Cannot find back queue for stream_id: {stream_id}") # 返回结束标志,告诉微信服务器流已结束 @@ -222,7 +225,7 @@ class WecomAIBotAdapter(Platform): callback_params["timestamp"], ) return resp - queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id) + queue = self.queue_mgr.get_or_create_back_queue(stream_id) if queue.empty(): logger.debug( f"No new messages in back queue for stream_id: {stream_id}", @@ -242,10 +245,9 @@ class WecomAIBotAdapter(Platform): elif msg["type"] == "end": # stream end finish = True - wecomai_queue_mgr.remove_queues(stream_id) + self.queue_mgr.remove_queues(stream_id) break - else: - pass + logger.debug( f"Aggregated content: {latest_plain_content}, image: {len(image_base64)}, finish: {finish}", ) @@ -313,8 +315,8 @@ class WecomAIBotAdapter(Platform): session_id: str, ): """将消息放入队列进行异步处理""" - input_queue = wecomai_queue_mgr.get_or_create_queue(stream_id) - _ = wecomai_queue_mgr.get_or_create_back_queue(stream_id) + input_queue = self.queue_mgr.get_or_create_queue(stream_id) + _ = self.queue_mgr.get_or_create_back_queue(stream_id) message_payload = { "message_data": message_data, "callback_params": callback_params, @@ -453,6 +455,7 @@ class WecomAIBotAdapter(Platform): platform_meta=self.meta(), session_id=message.session_id, api_client=self.api_client, + queue_mgr=self.queue_mgr, ) self.commit_event(message_event) diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py index 130182b48..0091783a4 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py @@ -8,7 +8,7 @@ from astrbot.api.message_components import ( ) from .wecomai_api import WecomAIBotAPIClient -from .wecomai_queue_mgr import wecomai_queue_mgr +from .wecomai_queue_mgr import WecomAIQueueMgr class WecomAIBotMessageEvent(AstrMessageEvent): @@ -21,6 +21,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent): platform_meta, session_id: str, api_client: WecomAIBotAPIClient, + queue_mgr: WecomAIQueueMgr, ): """初始化消息事件 @@ -34,14 +35,16 @@ class WecomAIBotMessageEvent(AstrMessageEvent): """ super().__init__(message_str, message_obj, platform_meta, session_id) self.api_client = api_client + self.queue_mgr = queue_mgr @staticmethod async def _send( message_chain: MessageChain, stream_id: str, + queue_mgr: WecomAIQueueMgr, streaming: bool = False, ): - back_queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id) + back_queue = queue_mgr.get_or_create_back_queue(stream_id) if not message_chain: await back_queue.put( @@ -94,7 +97,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent): "wecom_ai_bot platform event raw_message should be a dict" ) stream_id = raw.get("stream_id", self.session_id) - await WecomAIBotMessageEvent._send(message, stream_id) + await WecomAIBotMessageEvent._send(message, stream_id, self.queue_mgr) await super().send(message) async def send_streaming(self, generator, use_fallback=False): @@ -105,7 +108,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent): "wecom_ai_bot platform event raw_message should be a dict" ) stream_id = raw.get("stream_id", self.session_id) - back_queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id) + back_queue = self.queue_mgr.get_or_create_back_queue(stream_id) # 企业微信智能机器人不支持增量发送,因此我们需要在这里将增量内容累积起来,积累发送 increment_plain = "" @@ -134,6 +137,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent): final_data += await WecomAIBotMessageEvent._send( chain, stream_id=stream_id, + queue_mgr=self.queue_mgr, streaming=True, ) diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py index eb3455292..3a982bdf7 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py @@ -151,7 +151,3 @@ class WecomAIQueueMgr: "output_queues": len(self.back_queues), "pending_responses": len(self.pending_responses), } - - -# 全局队列管理器实例 -wecomai_queue_mgr = WecomAIQueueMgr()