From 4edd11f2f7563d06e5ff2f82c5c3e9dd1a0ec30b Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Wed, 24 Jul 2024 09:19:43 -0400 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E4=BA=86=E4=B8=80?= =?UTF-8?q?=E4=BA=9Bbug=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/bootstrap.py | 21 +++++--- dashboard/server.py | 6 +-- model/command/internal_handler.py | 2 +- model/command/manager.py | 8 ++- model/command/openai_official_handler.py | 8 ++- model/platform/__init__.py | 7 ++- model/platform/manager.py | 14 +++--- model/platform/qq_aiocqhttp.py | 64 +++++++++++++++++------- model/platform/qq_nakuru.py | 7 ++- model/platform/qq_official.py | 10 ++-- util/t2i/context.py | 4 +- util/t2i/renderer.py | 2 +- util/t2i/strategies/local_strategy.py | 2 +- 13 files changed, 101 insertions(+), 54 deletions(-) diff --git a/astrbot/bootstrap.py b/astrbot/bootstrap.py index fdfc922be..bcaa765f4 100644 --- a/astrbot/bootstrap.py +++ b/astrbot/bootstrap.py @@ -1,5 +1,5 @@ import asyncio -import threading +import traceback from astrbot.message.handler import MessageHandler from astrbot.persist.helper import dbConn from dashboard.server import AstrBotDashBoard @@ -72,13 +72,21 @@ class AstrBotBootstrap(): # load platforms platform_tasks = self.load_platform() # load metrics uploader - metrics_upload_task = upload_metrics(self.context) + metrics_upload_task = asyncio.create_task(upload_metrics(self.context)) # load dashboard self.dashboard.run_http_server() - dashboard_task = self.dashboard.ws_server() - - await asyncio.gather(metrics_upload_task, dashboard_task, *platform_tasks) - + dashboard_task = asyncio.create_task(self.dashboard.ws_server()) + tasks = [metrics_upload_task, dashboard_task, *platform_tasks] + tasks = [self.handle_task(task) for task in tasks] + await asyncio.gather(*tasks) + + async def handle_task(self, task: Union[asyncio.Task, asyncio.Future]): + try: + result = await task + return result + except Exception as e: + logger.error(traceback.format_exc()) + return None def load_llm(self): if 'openai' in self.configs and \ @@ -88,6 +96,7 @@ class AstrBotBootstrap(): from model.command.openai_official_handler import OpenAIOfficialCommandHandler self.openai_command_handler = OpenAIOfficialCommandHandler(self.command_manager) self.llm_instance = ProviderOpenAIOfficial(self.context) + self.openai_command_handler.set_provider(self.llm_instance) logger.info("已启用 OpenAI API 支持。") def load_plugins(self): diff --git a/dashboard/server.py b/dashboard/server.py index 8da717f2e..bbcd35de6 100644 --- a/dashboard/server.py +++ b/dashboard/server.py @@ -419,19 +419,19 @@ class AstrBotDashBoard(): "tag": "" }, { - "title": "QQ(官方机器人 API)", + "title": "QQ(官方)", "desc": "QQ官方API。支持频道、群、私聊(需获得群权限)", "namespace": "internal_platform_qq_official", "tag": "" }, { - "title": "QQ(nakuru 适配器)", + "title": "QQ(nakuru)", "desc": "适用于 go-cqhttp", "namespace": "internal_platform_qq_gocq", "tag": "" }, { - "title": "QQ(aiocqhttp 适配器)", + "title": "QQ(aiocqhttp)", "desc": "适用于 Lagrange, LLBot, Shamrock 等支持反向WS的协议实现。", "namespace": "internal_platform_qq_aiocqhttp", "tag": "" diff --git a/model/command/internal_handler.py b/model/command/internal_handler.py index 9a4e6a43b..5f0c58cc9 100644 --- a/model/command/internal_handler.py +++ b/model/command/internal_handler.py @@ -23,7 +23,7 @@ class InternalCommandHandler: self.manager.register("update", "更新 AstrBot", 10, self.update) self.manager.register("plugin", "插件管理", 10, self.plugin) self.manager.register("reboot", "重启 AstrBot", 10, self.reboot) - self.manager.register("web_search", "网页搜索开关", 10, self.web_search) + self.manager.register("websearch", "网页搜索开关", 10, self.web_search) self.manager.register("t2i", "文本转图片开关", 10, self.t2i_toggle) self.manager.register("myid", "获取你在此平台上的ID", 10, self.myid) diff --git a/model/command/manager.py b/model/command/manager.py index 34b3cc6fb..bb4ba7cb2 100644 --- a/model/command/manager.py +++ b/model/command/manager.py @@ -93,7 +93,11 @@ class CommandManager(): return command_result except BaseException as e: logger.error(traceback.format_exc()) + if not command_metadata.inner_command: - logger.error(f"当执行 {command}/({command_metadata.plugin_metadata.plugin_name} By {command_metadata.plugin_metadata.author}) 指令时,发生了异常。") + text = f"执行 {command}/({command_metadata.plugin_metadata.plugin_name} By {command_metadata.plugin_metadata.author}) 指令时发生了异常。{e}" + logger.error(text) else: - logger.error(f"当执行 {command} 指令时,发生了异常。") \ No newline at end of file + text = f"执行 {command} 指令时发生了异常。{e}" + logger.error(text) + return CommandResult().message(text) \ No newline at end of file diff --git a/model/command/openai_official_handler.py b/model/command/openai_official_handler.py index 759149bdf..49b645aba 100644 --- a/model/command/openai_official_handler.py +++ b/model/command/openai_official_handler.py @@ -145,8 +145,6 @@ class OpenAIOfficialCommandHandler(): async def draw(self, message: AstrMessageEvent, context: Context): message = message.message_str.removeprefix("画") img_url = await self.provider.image_generate(message) - p = await download_image_by_url(url=img_url) - with open(p, 'rb') as f: - return CommandResult( - message_chain=[Image.fromBytes(f.read())], - ) \ No newline at end of file + return CommandResult( + message_chain=[Image.fromURL(img_url)], + ) \ No newline at end of file diff --git a/model/platform/__init__.py b/model/platform/__init__.py index 2df17e33a..807362530 100644 --- a/model/platform/__init__.py +++ b/model/platform/__init__.py @@ -64,7 +64,10 @@ class Platform(): if isinstance(i, Plain): plain_str += i.text if plain_str and len(plain_str) > 50: - p = await self.context.image_renderer.render(plain_str) - rendered_images.append(Image.fromFileSystem(p)) + p = await self.context.image_renderer.render(plain_str, return_url=True) + if p.startswith('http'): + rendered_images.append(Image.fromURL(p)) + else: + rendered_images.append(Image.fromFileSystem(p)) return rendered_images \ No newline at end of file diff --git a/model/platform/manager.py b/model/platform/manager.py index bec7df7ea..1f287f466 100644 --- a/model/platform/manager.py +++ b/model/platform/manager.py @@ -1,4 +1,4 @@ -import time, threading +import asyncio from util.io import port_checker from type.register import RegisteredPlatform @@ -21,20 +21,20 @@ class PlatformManager(): if 'gocqbot' in self.config and self.config['gocqbot']['enable']: logger.info("启用 QQ(nakuru 适配器)") - tasks.append(self.gocq_bot()) + tasks.append(asyncio.create_task(self.gocq_bot())) if 'aiocqhttp' in self.config and self.config['aiocqhttp']['enable']: logger.info("启用 QQ(aiocqhttp 适配器)") - tasks.append(self.aiocq_bot()) + tasks.append(asyncio.create_task(self.aiocq_bot())) # QQ频道 if 'qqbot' in self.config and self.config['qqbot']['enable'] and self.config['qqbot']['appid'] != None: logger.info("启用 QQ(官方 API) 机器人消息平台") - tasks.append(self.qqchan_bot()) + tasks.append(asyncio.create_task(self.qqchan_bot())) return tasks - def gocq_bot(self): + async def gocq_bot(self): ''' 运行 QQ(nakuru 适配器) ''' @@ -51,7 +51,7 @@ class PlatformManager(): noticed = True logger.warning( f"连接到{host}:{port}(或{http_port})失败。程序会每隔 5s 自动重试。") - time.sleep(5) + await asyncio.sleep(5) else: logger.info("nakuru 适配器已连接。") break @@ -59,7 +59,7 @@ class PlatformManager(): qq_gocq = QQGOCQ(self.context, self.msg_handler) self.context.platforms.append(RegisteredPlatform( platform_name="gocq", platform_instance=qq_gocq, origin="internal")) - return qq_gocq.run() + await qq_gocq.run() except BaseException as e: logger.error("启动 nakuru 适配器时出现错误: " + str(e)) diff --git a/model/platform/qq_aiocqhttp.py b/model/platform/qq_aiocqhttp.py index 777f4cc3d..84ffcf4f3 100644 --- a/model/platform/qq_aiocqhttp.py +++ b/model/platform/qq_aiocqhttp.py @@ -22,11 +22,8 @@ class AIOCQHTTP(Platform): self.announcement = self.context.base_config.get("announcement", "欢迎新人!") self.host = self.context.base_config['aiocqhttp']['ws_reverse_host'] self.port = self.context.base_config['aiocqhttp']['ws_reverse_port'] - - for handler in logging.root.handlers[:]: - logging.root.removeHandler(handler) - def compat_onebot2astrbotmsg(self, event: Event) -> AstrBotMessage: + def convert_message(self, event: Event) -> AstrBotMessage: abm = AstrBotMessage() abm.self_id = str(event.self_id) @@ -69,28 +66,48 @@ class AIOCQHTTP(Platform): def run_aiocqhttp(self): if not self.host or not self.port: return - self.bot = CQHttp(use_ws_reverse=True) - + self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp') @self.bot.on_message('group') async def group(event: Event): - abm = self.compat_onebot2astrbotmsg(event) + abm = self.convert_message(event) if abm: - await self.handle_msg(event, abm) - return {'reply': event.message} + await self.handle_msg(abm) + # return {'reply': event.message} @self.bot.on_message('private') async def private(event: Event): - abm = self.compat_onebot2astrbotmsg(event) + abm = self.convert_message(event) if abm: - await self.handle_msg(event, abm) - return {'reply': event.message} + await self.handle_msg(abm) + # return {'reply': event.message} - return self.bot.run_task(host=self.host, port=int(self.port)) + bot = self.bot.run_task(host=self.host, port=int(self.port)) + + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + logging.getLogger('aiocqhttp').setLevel(logging.ERROR) + + return bot + + def pre_check(self, message: AstrBotMessage) -> bool: + # if message chain contains Plain components or At components which points to self_id, return True + if message.type == MessageType.FRIEND_MESSAGE: + return True + for comp in message.message: + if isinstance(comp, At) and str(comp.qq) == message.self_id: + return True + # check nicks + if self.check_nick(message.message_str): + return True + return False async def handle_msg(self, message: AstrBotMessage): logger.info( f"{message.sender.nickname}/{message.sender.user_id} -> {self.parse_message_outline(message)}") + if not self.pre_check(message): + return + # 解析 role sender_id = str(message.sender.user_id) if sender_id == self.context.config_helper.get('admin_qq', '') or \ @@ -100,7 +117,7 @@ class AIOCQHTTP(Platform): role = 'member' # construct astrbot message event - ame = AstrMessageEvent().from_astrbot_message(message, self.context, "gocq", message.session_id, role) + ame = AstrMessageEvent.from_astrbot_message(message, self.context, "aiocqhttp", message.session_id, role) # transfer control to message handler message_result = await self.message_handler.handle(ame) @@ -118,13 +135,14 @@ class AIOCQHTTP(Platform): async def reply_msg(self, message: AstrBotMessage, result_message: list): - await super().reply_msg() """ 回复用户唤醒机器人的消息。(被动回复) """ logger.info( f"{message.sender.user_id} <- {self.parse_message_outline(message)}") + res = result_message + if isinstance(res, str): res = [Plain(text=res), ] @@ -138,9 +156,21 @@ class AIOCQHTTP(Platform): except BaseException as e: logger.warn(traceback.format_exc()) logger.warn(f"以文本转图片的形式回复消息时发生错误: {e},将尝试默认方式。") + + await self._reply(message, res) async def _reply(self, message: AstrBotMessage, message_chain: List[BaseMessageComponent]): if isinstance(message_chain, str): message_chain = [Plain(text=message_chain), ] - - await self.bot.send(message.raw_message, message_chain) \ No newline at end of file + + ret = [] + for segment in message_chain: + d = segment.toDict() + if isinstance(segment, Plain): + d['type'] = 'text' + if isinstance(segment, Image): + # d['data']['file'] = + pass + ret.append(d) + + await self.bot.send(message.raw_message, ret) \ No newline at end of file diff --git a/model/platform/qq_nakuru.py b/model/platform/qq_nakuru.py index a9cc62510..d332b0b58 100644 --- a/model/platform/qq_nakuru.py +++ b/model/platform/qq_nakuru.py @@ -75,8 +75,6 @@ class QQGOCQ(Platform): if message.type == MessageType.FRIEND_MESSAGE: return True for comp in message.message: - if isinstance(comp, Plain): - return True if isinstance(comp, At) and str(comp.qq) == message.self_id: return True # check nicks @@ -85,7 +83,8 @@ class QQGOCQ(Platform): return False def run(self): - return self.client._run() + coro = self.client._run() + return coro async def handle_msg(self, message: AstrBotMessage): logger.info( @@ -119,7 +118,7 @@ class QQGOCQ(Platform): role = 'member' # construct astrbot message event - ame = AstrMessageEvent().from_astrbot_message(message, self.context, "gocq", session_id, role) + ame = AstrMessageEvent.from_astrbot_message(message, self.context, "gocq", session_id, role) # transfer control to message handler message_result = await self.message_handler.handle(ame) diff --git a/model/platform/qq_official.py b/model/platform/qq_official.py index 060ea5da4..f2d60ea92 100644 --- a/model/platform/qq_official.py +++ b/model/platform/qq_official.py @@ -309,9 +309,13 @@ class QQOfficial(Platform): if 'file_image' in kwargs: file_image_path = kwargs['file_image'].replace("file:///", "") if file_image_path: - logger.debug(f"上传图片: {file_image_path}") - image_url = await self.context.image_uploader.upload_image(file_image_path) - logger.debug(f"上传成功: {image_url}") + + if file_image_path.startswith("http"): + image_url = file_image_path + else: + logger.debug(f"上传图片: {file_image_path}") + image_url = await self.context.image_uploader.upload_image(file_image_path) + logger.debug(f"上传成功: {image_url}") media = await self.client.api.post_group_file(kwargs['group_openid'], 1, image_url) del kwargs['file_image'] kwargs['media'] = media diff --git a/util/t2i/context.py b/util/t2i/context.py index 255fb89af..ac9f837a0 100644 --- a/util/t2i/context.py +++ b/util/t2i/context.py @@ -7,5 +7,5 @@ class RenderContext: def set_strategy(self, strategy: RenderStrategy): self._strategy = strategy - async def render(self, text: str) -> str: - return await self._strategy.render(text) + async def render(self, text: str, return_url: bool = False): + return await self._strategy.render(text, return_url) diff --git a/util/t2i/renderer.py b/util/t2i/renderer.py index 28df04b6e..bf141bce2 100644 --- a/util/t2i/renderer.py +++ b/util/t2i/renderer.py @@ -15,7 +15,7 @@ class TextToImageRenderer: async def render(self, text: str, use_network: bool = True, return_url: bool = False): if use_network: try: - return await self.context.render(text) + return await self.context.render(text, return_url=return_url) except BaseException as e: logger.error(f"Failed to render image via AstrBot API: {e}. Falling back to local rendering.") self.context.set_strategy(self.local_strategy) diff --git a/util/t2i/strategies/local_strategy.py b/util/t2i/strategies/local_strategy.py index 56210bc90..5711a4830 100644 --- a/util/t2i/strategies/local_strategy.py +++ b/util/t2i/strategies/local_strategy.py @@ -17,7 +17,7 @@ class LocalRenderStrategy(RenderStrategy): except Exception as e: pass - async def render(self, text: str): + async def render(self, text: str, **kwargs): font_size = 26 image_width = 800 image_height = 600