From 9c284b84b1988f9d3cea6562896ff8e895043ec9 Mon Sep 17 00:00:00 2001 From: Soulter <37870767+Soulter@users.noreply.github.com> Date: Mon, 15 May 2023 20:03:17 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E5=8D=87=E7=BA=A7=E6=8F=92=E4=BB=B6?= =?UTF-8?q?=E5=8D=8F=E8=AE=AE=E7=B0=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cores/qqbot/core.py | 16 +++++++---- model/command/command.py | 6 ++-- model/command/command_openai_official.py | 6 ++-- model/command/command_rev_chatgpt.py | 6 ++-- model/command/command_rev_edgegpt.py | 7 +++-- model/platform/qq.py | 36 ++++++++++++++++++++++-- 6 files changed, 60 insertions(+), 17 deletions(-) diff --git a/cores/qqbot/core.py b/cores/qqbot/core.py index 16d179e3b..afc526d3e 100644 --- a/cores/qqbot/core.py +++ b/cores/qqbot/core.py @@ -309,6 +309,7 @@ def initBot(cfg, prov): gu.log("--------加载平台--------", gu.LEVEL_INFO, fg=gu.FG_COLORS['yellow']) # GOCQ + global gocq_bot if 'gocqbot' in cfg and cfg['gocqbot']['enable']: gu.log("- 启用QQ机器人 -", gu.LEVEL_INFO) @@ -326,11 +327,14 @@ def initBot(cfg, prov): with open("cmd_config.json", 'w', encoding='utf-8') as f: json.dump(cmd_config, f, indent=4) f.flush() - global gocq_app, gocq_bot, gocq_loop - gocq_bot = QQ() + global gocq_app, gocq_loop gocq_loop = asyncio.new_event_loop() + gocq_bot = QQ(True, gocq_loop) thread_inst = threading.Thread(target=run_gocq_bot, args=(gocq_loop, gocq_bot, gocq_app), daemon=False) thread_inst.start() + else: + gocq_bot = QQ(False) + # QQ频道 if 'qqbot' in cfg and cfg['qqbot']['enable']: @@ -437,7 +441,7 @@ def oper_msg(message, role = "member" # 角色 hit = False # 是否命中指令 command_result = () # 调用指令返回的结果 - global admin_qq, cached_plugins + global admin_qq, cached_plugins, gocq_bot if platform == PLATFORM_QQCHAN: gu.log(f"接收到消息:{message.content}", gu.LEVEL_INFO, tag="QQ频道") @@ -556,7 +560,7 @@ def oper_msg(message, chatgpt_res = "" if chosen_provider == OPENAI_OFFICIAL: - hit, command_result = command_openai_official.check_command(qq_msg, session_id, user_name, role, platform=platform, message_obj=message, cached_plugins=cached_plugins) + hit, command_result = command_openai_official.check_command(qq_msg, session_id, user_name, role, platform=platform, message_obj=message, cached_plugins=cached_plugins, qq_platform=gocq_bot) # hit: 是否触发了指令 if not hit: # 请求ChatGPT获得结果 @@ -569,7 +573,7 @@ def oper_msg(message, send_message(platform, message, f"OpenAI API错误, 原因: {str(e)}", msg_ref=msg_ref, gocq_loop=gocq_loop, qqchannel_bot=qqchannel_bot, gocq_bot=gocq_bot) elif chosen_provider == REV_CHATGPT: - hit, command_result = command_rev_chatgpt.check_command(qq_msg, role, platform=platform, message_obj=message, cached_plugins=cached_plugins) + hit, command_result = command_rev_chatgpt.check_command(qq_msg, role, platform=platform, message_obj=message, cached_plugins=cached_plugins, qq_platform=gocq_bot) if not hit: try: chatgpt_res = str(rev_chatgpt.text_chat(qq_msg)) @@ -585,7 +589,7 @@ def oper_msg(message, bing_cache_loop = gocq_loop elif platform == PLATFORM_QQCHAN: bing_cache_loop = qqchan_loop - hit, command_result = command_rev_edgegpt.check_command(qq_msg, bing_cache_loop, role, platform=platform, message_obj=message, cached_plugins=cached_plugins) + hit, command_result = command_rev_edgegpt.check_command(qq_msg, bing_cache_loop, role, platform=platform, message_obj=message, cached_plugins=cached_plugins, qq_platform=gocq_bot) if not hit: try: while rev_edgegpt.is_busy(): diff --git a/model/command/command.py b/model/command/command.py index 251787cbd..4e94c2f20 100644 --- a/model/command/command.py +++ b/model/command/command.py @@ -10,7 +10,7 @@ import util.plugin_util as putil import shutil import importlib from util import general_utils as gu - +from model.platform.qq import QQ PLATFORM_QQCHAN = 'qqchan' PLATFORM_GOCQ = 'gocq' @@ -34,12 +34,12 @@ class Command: except BaseException as e: raise e - def check_command(self, message, role, platform, message_obj, cached_plugins: dict): + def check_command(self, message, role, platform, message_obj, cached_plugins: dict, qq_platform: QQ): # 插件 for k, v in cached_plugins.items(): try: - hit, res = v["clsobj"].run(message, role, platform, message_obj) + hit, res = v["clsobj"].run(message, role, platform, message_obj, qq_platform) if hit: return True, res except BaseException as e: diff --git a/model/command/command_openai_official.py b/model/command/command_openai_official.py index 28c34d852..7cc949feb 100644 --- a/model/command/command_openai_official.py +++ b/model/command/command_openai_official.py @@ -1,6 +1,7 @@ from model.command.command import Command from model.provider.provider_openai_official import ProviderOpenAIOfficial from cores.qqbot.personality import personalities +from model.platform.qq import QQ class CommandOpenAIOfficial(Command): def __init__(self, provider: ProviderOpenAIOfficial): @@ -14,8 +15,9 @@ class CommandOpenAIOfficial(Command): role: str, platform: str, message_obj, - cached_plugins: dict): - hit, res = super().check_command(message, role, platform, message_obj=message_obj, cached_plugins=cached_plugins) + cached_plugins: dict, + qq_platform: QQ): + hit, res = super().check_command(message, role, platform, message_obj=message_obj, cached_plugins=cached_plugins, qq_platform=qq_platform) if hit: return True, res if self.command_start_with(message, "reset", "重置"): diff --git a/model/command/command_rev_chatgpt.py b/model/command/command_rev_chatgpt.py index 7cdebd6ec..676f094b3 100644 --- a/model/command/command_rev_chatgpt.py +++ b/model/command/command_rev_chatgpt.py @@ -1,5 +1,6 @@ from model.command.command import Command from model.provider.provider_rev_chatgpt import ProviderRevChatGPT +from model.platform.qq import QQ class CommandRevChatGPT(Command): def __init__(self, provider: ProviderRevChatGPT): @@ -11,8 +12,9 @@ class CommandRevChatGPT(Command): role: str, platform: str, message_obj, - cached_plugins: dict): - hit, res = super().check_command(message, role, platform, message_obj=message_obj, cached_plugins=cached_plugins) + cached_plugins: dict, + qq_platform: QQ): + hit, res = super().check_command(message, role, platform, message_obj=message_obj, cached_plugins=cached_plugins, qq_platform=qq_platform) if hit: return True, res if self.command_start_with(message, "help", "帮助"): diff --git a/model/command/command_rev_edgegpt.py b/model/command/command_rev_edgegpt.py index 838dfae92..39a4dc9b2 100644 --- a/model/command/command_rev_edgegpt.py +++ b/model/command/command_rev_edgegpt.py @@ -1,6 +1,8 @@ from model.command.command import Command from model.provider.provider_rev_edgegpt import ProviderRevEdgeGPT import asyncio +from model.platform.qq import QQ + class CommandRevEdgeGPT(Command): def __init__(self, provider: ProviderRevEdgeGPT): self.provider = provider @@ -13,8 +15,9 @@ class CommandRevEdgeGPT(Command): role: str, platform: str, message_obj, - cached_plugins: dict): - hit, res = super().check_command(message, role, platform, message_obj=message_obj, cached_plugins=cached_plugins) + cached_plugins: dict, + qq_platform: QQ): + hit, res = super().check_command(message, role, platform, message_obj=message_obj, cached_plugins=cached_plugins, qq_platform=qq_platform) if hit: return True, res if self.command_start_with(message, "reset"): diff --git a/model/platform/qq.py b/model/platform/qq.py index 679e7854d..29302d7dc 100644 --- a/model/platform/qq.py +++ b/model/platform/qq.py @@ -1,22 +1,40 @@ from nakuru.entities.components import Plain, At, Image from util import general_utils as gu +import asyncio class QQ: + def __init__(self, is_start: bool, gocq_loop = None) -> None: + self.is_start = is_start + self.gocq_loop = gocq_loop + def run_bot(self, gocq): self.client = gocq self.client.run() + def get_msg_loop(self): + return self.gocq_loop + async def send_qq_msg(self, source, res, image_mode: bool = False): + if not self.is_start: + raise Exception("管理员未启动QQ平台") """ - res可以是一个数组,也就是gocq的消息链. + res可以是一个数组, 也就是gocq的消息链。 + 插件开发者请使用send方法, 可以不用直接调用这个方法。 """ gu.log("回复QQ消息: "+str(res), level=gu.LEVEL_INFO, tag="QQ", max_len=30) + if isinstance(source, int): + source = { + "type": "GroupMessage", + "group_id": source + } + if isinstance(res, list) and len(res) > 0: await self.client.sendGroupMessage(source.group_id, res) return + # 通过消息链处理 if not image_mode: if source.type == "GroupMessage": @@ -39,4 +57,18 @@ class QQ: await self.client.sendFriendMessage(source.user_id, [ Plain(text="好的,我根据你的需要为你生成了一张图片😊"), Image.fromURL(url=res) - ]) \ No newline at end of file + ]) + + def send(self, + to, + res): + ''' + 提供给插件的发送QQ消息接口, 不用在外部await。 + 参数说明:第一个参数可以是消息对象,也可以是QQ群号。第二个参数是消息内容(消息内容可以是消息链列表,也可以是纯文字信息)。 + ''' + if isinstance(to, int): + + try: + asyncio.run_coroutine_threadsafe(self.send_qq_msg(message_obj, res), self.gocq_loop).result() + except BaseException as e: + raise e \ No newline at end of file