From 50f62e66b03864b95ea49441c74879f147f29fef Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sun, 6 Oct 2024 00:20:42 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E6=96=87=E8=BD=AC=E5=9B=BE=E6=B8=B2?= =?UTF-8?q?=E6=9F=93=E5=A4=B1=E8=B4=A5=E6=97=B6=E5=8F=91=E9=80=81=E7=BA=AF?= =?UTF-8?q?=E6=96=87=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model/platform/__init__.py | 3 ++ model/platform/qq_aiocqhttp.py | 66 ++++++++++------------- model/platform/qq_nakuru.py | 97 +++++++++++++--------------------- model/platform/qq_official.py | 41 +++----------- type/message_event.py | 9 +++- 5 files changed, 84 insertions(+), 132 deletions(-) diff --git a/model/platform/__init__.py b/model/platform/__init__.py index d1d12a92c..97b3db0a3 100644 --- a/model/platform/__init__.py +++ b/model/platform/__init__.py @@ -5,6 +5,9 @@ from type.astrbot_message import AstrBotMessage from type.command import CommandResult from type.astrbot_message import MessageType +class T2IException(Exception): + def __init__(self, message: str = "文本转图片时发生错误") -> None: + super().__init__(message) class Platform(): def __init__(self, platform_name: str, context) -> None: diff --git a/model/platform/qq_aiocqhttp.py b/model/platform/qq_aiocqhttp.py index fad17f5d5..91f00b6d5 100644 --- a/model/platform/qq_aiocqhttp.py +++ b/model/platform/qq_aiocqhttp.py @@ -4,7 +4,7 @@ import traceback import logging from aiocqhttp import CQHttp, Event from aiocqhttp.exceptions import ActionFailed -from . import Platform +from . import Platform, T2IException from type.astrbot_message import * from type.message_event import * from type.command import * @@ -25,13 +25,11 @@ class AIOCQHTTP(Platform): assert isinstance(platform_config, AiocqhttpPlatformConfig), "aiocqhttp: 无法识别的配置类型。" self.message_handler = message_handler - self.waiting = {} self.context = context self.config = platform_config self.unique_session = context.config_helper.platform_settings.unique_session self.host = platform_config.ws_reverse_host self.port = platform_config.ws_reverse_port - self.admins = context.config_helper.admins_id def convert_message(self, event: Event) -> AstrBotMessage: @@ -134,13 +132,6 @@ class AIOCQHTTP(Platform): ok, reason = await self.pre_check(message) if not ok: return - - # 解析 role - sender_id = str(message.sender.user_id) - if sender_id in self.admins: - role = 'admin' - else: - role = 'member' # parse unified message origin unified_msg_origin = None @@ -157,7 +148,6 @@ class AIOCQHTTP(Platform): self.context, "aiocqhttp", message.session_id, - role, unified_msg_origin, reason == "command") # only_command @@ -169,10 +159,6 @@ class AIOCQHTTP(Platform): if message_result.callback: message_result.callback() - # 如果是等待回复的消息 - if message.session_id in self.waiting and self.waiting[message.session_id] == '': - self.waiting[message.session_id] = message - return message_result @@ -183,36 +169,35 @@ class AIOCQHTTP(Platform): """ 回复用户唤醒机器人的消息。(被动回复) """ - res = result_message - - if isinstance(res, str): - res = [Plain(text=res), ] + try: + await self._reply(message, result_message, use_t2i) + except T2IException as e: + logger.error(traceback.format_exc()) + logger.warning(f"文本转图片时发生错误,将使用纯文本发送。") + await self._reply(message, result_message, False) + return result_message - # if image mode, put all Plain texts into a new picture. - if (use_t2i or (use_t2i == None and self.context.config_helper.t2i)) and isinstance(result_message, list): - rendered_images = await self.convert_to_t2i_chain(res) - if rendered_images: - try: - await self._reply(message, rendered_images) - return rendered_images - except BaseException as e: - logger.warn(traceback.format_exc()) - logger.warn(f"以文本转图片的形式回复消息时发生错误: {e},将尝试默认方式。") - - await self._reply(message, res) - return res - - async def _reply(self, message: Union[AstrBotMessage, Dict], message_chain: List[BaseMessageComponent]): + async def _reply(self, message: Union[AstrBotMessage, Dict], message_chain: List[BaseMessageComponent], use_t2i: bool = None): await self.record_metrics() if isinstance(message_chain, str): message_chain = [Plain(text=message_chain), ] + + # 文转图处理 + if (use_t2i or (use_t2i == None and self.context.config_helper.t2i)) and isinstance(message_chain, list): + try: + message_chain = await self.convert_to_t2i_chain(message_chain) + if not message_chain: raise T2IException() + except BaseException as e: + raise T2IException() + # log if isinstance(message, AstrBotMessage): logger.info( f"{message.sender.nickname}/{message.sender.user_id} <- {self.parse_message_outline(message_chain)}") else: logger.info(f"回复消息: {message_chain}") - + + # 解析成 OneBot json 格式并发送 ret = [] image_idx = [] for idx, segment in enumerate(message_chain): @@ -232,10 +217,11 @@ class AIOCQHTTP(Platform): # ENOENT if not image_idx: raise e - logger.warn("回复失败。检测到失败原因为文件未找到,猜测用户的协议端与 AstrBot 位于不同的文件系统上。尝试采用上传图片的方式发图。") + logger.warning("回复失败。检测到失败原因为文件未找到,猜测用户的协议端与 AstrBot 位于不同的文件系统上。尝试采用上传图片的方式发图。") for idx in image_idx: if ret[idx]['data']['file'].startswith('file://'): logger.info(f"正在上传图片: {ret[idx]['data']['path']}") + # 除了上传到图床,想不到更好的办法。 image_url = await self.context.image_uploader.upload_image(ret[idx]['data']['path']) logger.info(f"上传成功。") ret[idx]['data']['file'] = image_url @@ -267,8 +253,12 @@ class AIOCQHTTP(Platform): - 要发给某个群聊,请添加 key `group_id`,值为 int 类型的 qq 群号; ''' - - await self._reply(target, result_message.message_chain) + try: + await self._reply(target, result_message, result_message.is_use_t2i) + except T2IException as e: + logger.error(traceback.format_exc()) + logger.warning(f"文本转图片时发生错误,将使用纯文本发送。") + await self._reply(target, result_message, False) async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult): if message_type == MessageType.GROUP_MESSAGE: diff --git a/model/platform/qq_nakuru.py b/model/platform/qq_nakuru.py index 2fd647b8b..4fef3d14c 100644 --- a/model/platform/qq_nakuru.py +++ b/model/platform/qq_nakuru.py @@ -11,7 +11,7 @@ from nakuru import ( ) from typing import Union, List, Dict from type.types import Context -from . import Platform +from . import Platform, T2IException from type.astrbot_message import * from type.message_event import * from type.command import * @@ -40,11 +40,9 @@ class QQNakuru(Platform): asyncio.set_event_loop(self.loop) self.message_handler = message_handler - self.waiting = {} self.context = context self.unique_session = context.config_helper.platform_settings.unique_session self.config = platform_config - self.admins = context.config_helper.admins_id self.client = CQHTTP( host=self.config.host, @@ -113,13 +111,6 @@ class QQNakuru(Platform): session_id = message.raw_message.user_id message.session_id = session_id - - # 解析 role - sender_id = str(message.raw_message.user_id) - if sender_id in self.admins: - role = 'admin' - else: - role = 'member' # parse unified message origin unified_msg_origin = None @@ -141,7 +132,6 @@ class QQNakuru(Platform): self.context, "nakuru", session_id, - role, unified_msg_origin, reason == 'command') # only_command @@ -153,49 +143,47 @@ class QQNakuru(Platform): if message_result.callback: message_result.callback() - # 如果是等待回复的消息 - if session_id in self.waiting and self.waiting[session_id] == '': - self.waiting[session_id] = message - async def reply_msg(self, message: AstrBotMessage, result_message: List[BaseMessageComponent], use_t2i: bool = None): """ 回复用户唤醒机器人的消息。(被动回复) - """ - source = message.raw_message - res = result_message + """ + assert isinstance(message.raw_message, (GroupMessage, FriendMessage, GuildMessage)) + + try: + await self._reply(message, result_message, use_t2i) + except T2IException as e: + logger.error(traceback.format_exc()) + logger.warning(f"文本转图片时发生错误,将使用纯文本发送。") + await self._reply(message, result_message, False) + return result_message - assert isinstance(source, - (GroupMessage, FriendMessage, GuildMessage)) - - logger.info( - f"{message.sender.nickname}/{message.sender.user_id} <- {self.parse_message_outline(res)}") - - if isinstance(res, str): - res = [Plain(text=res), ] - - # if image mode, put all Plain texts into a new picture. - if use_t2i or (use_t2i == None and self.context.config_helper.t2i) and isinstance(result_message, list): - rendered_images = await self.convert_to_t2i_chain(res) - if rendered_images: - try: - await self._reply(source, rendered_images) - return - except BaseException as e: - logger.warn(traceback.format_exc()) - logger.warn(f"以文本转图片的形式回复消息时发生错误: {e},将尝试默认方式。") - - await self._reply(source, res) - - async def _reply(self, source, message_chain: List[BaseMessageComponent]): + async def _reply(self, message: Union[AstrBotMessage, Dict], message_chain: List[BaseMessageComponent], use_t2i: bool = None): await self.record_metrics() if isinstance(message_chain, str): message_chain = [Plain(text=message_chain), ] + + # 文转图处理 + if (use_t2i or (use_t2i == None and self.context.config_helper.t2i)) and isinstance(message_chain, list): + try: + message_chain = await self.convert_to_t2i_chain(message_chain) + if not message_chain: raise T2IException() + except BaseException as e: + raise T2IException() + + # log + if isinstance(message, AstrBotMessage): + logger.info( + f"{message.sender.nickname}/{message.sender.user_id} <- {self.parse_message_outline(message_chain)}") + else: + logger.info(f"回复消息: {message_chain}") + source = message.raw_message is_dict = isinstance(source, dict) + # 发消息 typ = None if is_dict: if "group_id" in source: @@ -250,7 +238,13 @@ class QQNakuru(Platform): guild_id 不是频道号。 ''' - await self._reply(target, result_message.message_chain) + try: + await self._reply(target, result_message, result_message.is_use_t2i) + except T2IException as e: + logger.error(traceback.format_exc()) + logger.warning(f"文本转图片时发生错误,将使用纯文本发送。") + await self._reply(target, result_message, False) + return result_message async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult): ''' @@ -290,21 +284,4 @@ class QQNakuru(Platform): ) abm.tag = "nakuru" abm.message = message.message - return abm - - def wait_for_message(self, group_id) -> Union[GroupMessage, FriendMessage, GuildMessage]: - ''' - 等待下一条消息,超时 300s 后抛出异常 - ''' - self.waiting[group_id] = '' - cnt = 0 - while True: - if group_id in self.waiting and self.waiting[group_id] != '': - # 去掉 - ret = self.waiting[group_id] - del self.waiting[group_id] - return ret - cnt += 1 - if cnt > 300: - raise Exception("等待消息超时。") - time.sleep(1) \ No newline at end of file + return abm \ No newline at end of file diff --git a/model/platform/qq_official.py b/model/platform/qq_official.py index b1f17398b..2c534c590 100644 --- a/model/platform/qq_official.py +++ b/model/platform/qq_official.py @@ -63,10 +63,8 @@ class QQOfficial(Platform): asyncio.set_event_loop(self.loop) self.message_handler = message_handler - self.waiting: dict = {} self.context = context self.config = platform_config - self.admins = context.config_helper.admins_id self.appid = platform_config.appid self.secret = platform_config.secret @@ -201,15 +199,8 @@ class QQOfficial(Platform): session_id = str(message.raw_message.author.id) message.session_id = session_id - # 解析出 role - sender_id = message.sender.user_id - if sender_id in self.admins: - role = 'admin' - else: - role = 'member' - # construct astrbot message event - ame = AstrMessageEvent.from_astrbot_message(message, self.context, "qqofficial", session_id, role) + ame = AstrMessageEvent.from_astrbot_message(message, self.context, "qqofficial", session_id) message_result = await self.message_handler.handle(ame) if not message_result: @@ -219,10 +210,6 @@ class QQOfficial(Platform): if message_result.callback: message_result.callback() - # 如果是等待回复的消息 - if session_id in self.waiting and self.waiting[session_id] == '': - self.waiting[session_id] = message - return ret async def reply_msg(self, @@ -241,10 +228,15 @@ class QQOfficial(Platform): plain_text = '' image_path = '' msg_ref = None - rendered_images = [] + rendered_images = None if use_t2i or (use_t2i == None and self.context.config_helper.t2i) and isinstance(result_message, list): - rendered_images = await self.convert_to_t2i_chain(result_message) + try: + rendered_images = await self.convert_to_t2i_chain(result_message) + except BaseException as e: + logger.warning(traceback.format_exc()) + logger.warning(f"文本转图片时发生错误: {e},将尝试默认方式。") + rendered_images = None if isinstance(result_message, list): plain_text, image_path = await self._parse_to_qqofficial(result_message, message.type == MessageType.GROUP_MESSAGE) @@ -386,20 +378,3 @@ class QQOfficial(Platform): async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult): raise NotImplementedError("qqofficial 不支持此方法。") - - def wait_for_message(self, channel_id: int) -> AstrBotMessage: - ''' - 等待指定 channel_id 的下一条信息,超时 300s 后抛出异常 - ''' - self.waiting[channel_id] = '' - cnt = 0 - while True: - if channel_id in self.waiting and self.waiting[channel_id] != '': - # 去掉 - ret = self.waiting[channel_id] - del self.waiting[channel_id] - return ret - cnt += 1 - if cnt > 300: - raise Exception("等待消息超时。") - time.sleep(1) diff --git a/type/message_event.py b/type/message_event.py index dc0221897..5c106f9e1 100644 --- a/type/message_event.py +++ b/type/message_event.py @@ -47,10 +47,17 @@ class AstrMessageEvent(): context: Context, platform_name: str, session_id: str, - role: str = "member", + unified_msg_origin: str = None, only_command: bool = False): + # 解析 role + sender_id = str(message.sender.user_id) + if sender_id in context.config_helper.admins_id: + role = 'admin' + else: + role = 'member' + ame = AstrMessageEvent(message.message_str, message, context.find_platform(platform_name),