From 2cf18972f362010bb7b9106fd9d65a938fecb62a Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Mon, 5 Feb 2024 13:18:34 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E9=9D=A2=E6=9D=BF?= =?UTF-8?q?=E4=BF=9D=E5=AD=98=E9=85=8D=E7=BD=AE=E6=97=B6=E6=8A=A5=E9=94=99?= =?UTF-8?q?=E7=9A=84=E9=97=AE=E9=A2=98=EF=BC=9B=E4=BF=AE=E5=A4=8D=E9=A2=91?= =?UTF-8?q?=E9=81=93=E7=A7=81=E8=81=8A=E6=8A=A5=E9=94=99=E7=9A=84=E9=97=AE?= =?UTF-8?q?=E9=A2=98=20perf:=20=E6=94=B9=E5=96=84=E6=97=A5=E5=BF=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- addons/dashboard/helper.py | 10 ++++++--- addons/dashboard/server.py | 2 +- cores/qqbot/core.py | 12 ++++------- model/command/command.py | 1 - model/platform/_platfrom.py | 34 ++++++++++++++++++++++++++++- model/platform/qq_gocq.py | 7 +++--- model/platform/qq_official.py | 37 +++++++++++++++++--------------- util/function_calling/gplugin.py | 1 + util/general_utils.py | 3 ++- 9 files changed, 72 insertions(+), 35 deletions(-) diff --git a/addons/dashboard/helper.py b/addons/dashboard/helper.py index dad6dc2f5..caaefac4b 100644 --- a/addons/dashboard/helper.py +++ b/addons/dashboard/helper.py @@ -9,6 +9,7 @@ import sys import os import threading import time +import asyncio def shutdown_bot(delay_s: int): @@ -28,6 +29,9 @@ class DashBoardConfig(): class DashBoardHelper(): def __init__(self, global_object, config: dict): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + self.logger = global_object.logger dashboard_data = global_object.dashboard_data dashboard_data.configs = { "data": [] @@ -41,13 +45,13 @@ class DashBoardHelper(): @self.dashboard.register("post_configs") def on_post_configs(post_configs: dict): try: - gu.log(f"收到配置更新请求", gu.LEVEL_INFO, tag="可视化面板") + self.logger.log(f"收到配置更新请求", gu.LEVEL_INFO, tag="可视化面板") self.save_config(post_configs) self.parse_default_config(self.dashboard_data, self.cc.get_all()) # 重启 threading.Thread(target=shutdown_bot, args=(2,), daemon=True).start() except Exception as e: - gu.log(f"在保存配置时发生错误:{e}", gu.LEVEL_ERROR, tag="可视化面板") + self.logger.log(f"在保存配置时发生错误:{e}", gu.LEVEL_ERROR, tag="可视化面板") raise e @@ -524,7 +528,7 @@ class DashBoardHelper(): ] except Exception as e: - gu.log(f"配置文件解析错误:{e}", gu.LEVEL_ERROR) + self.logger.log(f"配置文件解析错误:{e}", gu.LEVEL_ERROR) raise e diff --git a/addons/dashboard/server.py b/addons/dashboard/server.py index 8262f13dc..7ae245f3a 100644 --- a/addons/dashboard/server.py +++ b/addons/dashboard/server.py @@ -29,7 +29,7 @@ class Response(): class AstrBotDashBoard(): def __init__(self, global_object): - self.loop = asyncio.new_event_loop() + self.loop = asyncio.get_event_loop() asyncio.set_event_loop(self.loop) self.dashboard_data = global_object.dashboard_data self.dashboard_be = Flask(__name__, static_folder="dist", static_url_path="/") diff --git a/cores/qqbot/core.py b/cores/qqbot/core.py index e12c27476..08fd991c1 100644 --- a/cores/qqbot/core.py +++ b/cores/qqbot/core.py @@ -294,9 +294,6 @@ def initBot(cfg): platform_str = "(未启动任何平台,请前往面板添加)" logger.log(f"🎉 项目启动完成\n - 启动的LLM: {len(llm_instance)}个\n - 启动的平台: {platform_str}\n - 启动的插件: {len(_global_object.cached_plugins)}个") - if chosen_provider is None: - logger.log("没有启动任何语言模型。", gu.LEVEL_WARNING) - dashboard_thread.join() async def cli(): @@ -333,7 +330,7 @@ async def cli_pack_message(prompt: str) -> NakuruGuildMessage: def run_qqchan_bot(cfg: dict, global_object: GlobalObject): try: from model.platform.qq_official import QQOfficial - qqchannel_bot = QQOfficial(cfg=cfg, message_handler=oper_msg) + qqchannel_bot = QQOfficial(cfg=cfg, message_handler=oper_msg, global_object=global_object) global_object.platform_qqchan = qqchannel_bot qqchannel_bot.run() except BaseException as e: @@ -358,7 +355,7 @@ def run_gocq_bot(cfg: dict, _global_object: GlobalObject): logger.log("检查完毕,未发现问题。", tag="QQ") break try: - qq_gocq = QQGOCQ(cfg=cfg, message_handler=oper_msg) + qq_gocq = QQGOCQ(cfg=cfg, message_handler=oper_msg, global_object=_global_object) _global_object.platform_qq = qq_gocq qq_gocq.run() except BaseException as e: @@ -421,7 +418,6 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak for i in message.message: if isinstance(i, Plain): message_str += i.text.strip() - logger.log(message_str, gu.LEVEL_INFO, tag=platform) if message_str == "": return MessageResult("Hi~") @@ -488,8 +484,8 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak check, msg = baidu_judge.judge(message_str) if not check: return MessageResult(f"你的提问得到的回复未通过【百度AI内容审核】服务, 不予回复。\n\n{msg}") - if chosen_provider == None: - return MessageResult(f"管理员未启动任何语言模型或者语言模型初始化时失败。") + if chosen_provider == NONE_LLM: + return MessageResult("没有启动任何 LLM 并且未触发任何指令。") try: if llm_wake_prefix != "" and not message_str.startswith(llm_wake_prefix): return diff --git a/model/command/command.py b/model/command/command.py index a518e344b..add4b607c 100644 --- a/model/command/command.py +++ b/model/command/command.py @@ -5,7 +5,6 @@ try: import git.exc from git.repo import Repo except BaseException as e: - gu.log("你正运行在无Git环境下,暂时将无法使用插件、热更新功能。") has_git = False import os import sys diff --git a/model/platform/_platfrom.py b/model/platform/_platfrom.py index 7dc65678b..9d53cd9a9 100644 --- a/model/platform/_platfrom.py +++ b/model/platform/_platfrom.py @@ -1,7 +1,17 @@ import abc import threading import asyncio -from typing import Callable +from typing import Callable, Union +from nakuru import ( + GuildMessage, + GroupMessage, + FriendMessage, +) +from ._nakuru_translation_layer import ( + NakuruGuildMessage, +) +from nakuru.entities.components import Plain, At, Image, Node + class Platform(): def __init__(self, message_handler: callable) -> None: @@ -38,6 +48,28 @@ class Platform(): 发送消息(主动发送)同 send_msg() ''' pass + + def parse_message_outline(self, message: Union[GuildMessage, GroupMessage, FriendMessage, str]) -> NakuruGuildMessage: + ''' + 将消息解析成大纲消息形式。 + 如: xxxxx[图片]xxxxx + ''' + if isinstance(message, str): + return message + ret = '' + try: + for node in message.message: + if isinstance(node, Plain): + ret += node.text + elif isinstance(node, At): + ret += f'[At: {node.name}/{node.qq}]' + elif isinstance(node, Image): + ret += f'[图片]' + except Exception as e: + pass + ret.replace('\n', '') + return ret + def new_sub_thread(self, func, args=()): thread = threading.Thread(target=self._runner, args=(func, args), daemon=True) diff --git a/model/platform/qq_gocq.py b/model/platform/qq_gocq.py index 85eac2d1b..d42c55a66 100644 --- a/model/platform/qq_gocq.py +++ b/model/platform/qq_gocq.py @@ -24,7 +24,7 @@ class FakeSource: class QQGOCQ(Platform): - def __init__(self, cfg: dict, message_handler: callable) -> None: + def __init__(self, cfg: dict, message_handler: callable, global_object) -> None: super().__init__(message_handler) self.loop = asyncio.new_event_loop() @@ -34,7 +34,7 @@ class QQGOCQ(Platform): self.gocq_cnt = 0 self.cc = CmdConfig() self.cfg = cfg - self.logger = gu.Logger() + self.logger: gu.Logger = global_object.logger try: self.nick_qq = cfg['nick_qq'] @@ -107,6 +107,7 @@ class QQGOCQ(Platform): self.client.run() async def handle_msg(self, message: Union[GroupMessage, FriendMessage, GuildMessage, Notify], is_group: bool): + self.logger.log(f"{message.user_id} -> {self.parse_message_outline(message)}", tag="QQ_GOCQ") # 判断是否响应消息 resp = False if not is_group: @@ -178,7 +179,7 @@ class QQGOCQ(Platform): self.gocq_cnt += 1 - self.logger.log(f"{source.user_id} <- {res}", tag="GOCQ") + self.logger.log(f"{source.user_id} <- {self.parse_message_outline(res)}", tag="QQ_GOCQ") if isinstance(source, int): source = FakeSource("GroupMessage", source) diff --git a/model/platform/qq_official.py b/model/platform/qq_official.py index 3dc67eed1..39709a2b8 100644 --- a/model/platform/qq_official.py +++ b/model/platform/qq_official.py @@ -17,6 +17,7 @@ from ._nakuru_translation_layer import( gocq_compatible_receive, gocq_compatible_send ) +from typing import Union # QQ 机器人官方框架 class botClient(Client): @@ -25,24 +26,19 @@ class botClient(Client): # 收到频道消息 async def on_at_message_create(self, message: Message): - gu.log(str(message), gu.LEVEL_DEBUG, max_len=9999) # 转换层 nakuru_guild_message = gocq_compatible_receive(message) - gu.log(f"转换后: {str(nakuru_guild_message)}", gu.LEVEL_DEBUG, max_len=9999) - # await self.platform.handle_msg(nakuru_guild_message, is_group=True) self.platform.new_sub_thread(self.platform.handle_msg, (nakuru_guild_message, True)) # 收到私聊消息 async def on_direct_message_create(self, message: DirectMessage): # 转换层 nakuru_guild_message = gocq_compatible_receive(message) - gu.log(f"转换后: {str(nakuru_guild_message)}", gu.LEVEL_DEBUG, max_len=9999) - # await self.platform.handle_msg(nakuru_guild_message, is_group=False) self.platform.new_sub_thread(self.platform.handle_msg, (nakuru_guild_message, False)) class QQOfficial(Platform): - def __init__(self, cfg: dict, message_handler: callable) -> None: + def __init__(self, cfg: dict, message_handler: callable, global_object) -> None: super().__init__(message_handler) self.loop = asyncio.new_event_loop() @@ -56,7 +52,7 @@ class QQOfficial(Platform): self.token = cfg['qqbot']['token'] self.secret = cfg['qqbot_secret'] self.unique_session = cfg['uniqueSessionMode'] - self.logger = gu.Logger() + self.logger: gu.Logger = global_object.logger self.intents = botpy.Intents( public_guild_messages=True, @@ -87,10 +83,11 @@ class QQOfficial(Platform): ) async def handle_msg(self, message: NakuruGuildMessage, is_group: bool): - + _t = "/私聊" if not is_group else "" + self.logger.log(f"{message.sender.nickname}({message.sender.tiny_id}{_t}) -> {self.parse_message_outline(message)}", tag="QQ_OFFICIAL") # 解析出 session_id if self.unique_session or not is_group: - session_id = message.sender.user_id + session_id = message.sesnder.user_id else: session_id = message.channel_id @@ -112,7 +109,7 @@ class QQOfficial(Platform): if message_result is None: return - self.reply_msg(message, message_result.result_message) + self.reply_msg(is_group, message, message_result.result_message) if message_result.callback is not None: message_result.callback() @@ -121,12 +118,13 @@ class QQOfficial(Platform): self.waiting[session_id] = message def reply_msg(self, - message: NakuruGuildMessage, - res: list): + is_group: bool, + message: NakuruGuildMessage, + res: Union[str, list]): ''' 回复频道消息 ''' - self.logger.log(f"{message.sender.nickname}({message.sender.tiny_id}) <- {res}", tag="QQ频道") + self.logger.log(f"{message.sender.nickname}({message.sender.tiny_id}) <- {self.parse_message_outline(res)}", tag="QQ_OFFICIAL") self.qqchan_cnt += 1 plain_text = '' @@ -162,13 +160,15 @@ class QQOfficial(Platform): msg_ref = Reference(message_id=message.raw_message.id, ignore_get_message_error=False) # 到这里,我们得到了 plain_text,image_path,msg_ref - data = { - 'channel_id': str(message.channel_id), 'content': plain_text, 'msg_id': message.message_id, 'message_reference': msg_ref } + if is_group: + data['channel_id'] = message.channel_id + else: + data['guild_id'] = message.guild_id if image_path != '': data['file_image'] = image_path @@ -207,8 +207,11 @@ class QQOfficial(Platform): self._send_wrapper(**data) def _send_wrapper(self, **kwargs): - # await self.client.api.post_message(**kwargs) - asyncio.run_coroutine_threadsafe(self.client.api.post_message(**kwargs), self.loop).result() + if 'channel_id' in kwargs: + asyncio.run_coroutine_threadsafe(self.client.api.post_message(**kwargs), self.loop).result() + else: + asyncio.run_coroutine_threadsafe(self.client.api.post_dms(**kwargs), self.loop).result() + def send_msg(self, channel_id: int, message_chain: list, message_id: int = None): ''' diff --git a/util/function_calling/gplugin.py b/util/function_calling/gplugin.py index 8cba715aa..428410c3d 100644 --- a/util/function_calling/gplugin.py +++ b/util/function_calling/gplugin.py @@ -156,6 +156,7 @@ def web_keyword_search_via_sougou(keyword) -> str: if len(res) >= 5: # 限制5条 break except Exception as e: + pass gu.log(f"sougou parse err: {str(e)}", tag="web_keyword_search_via_sougou", level=gu.LEVEL_ERROR) # 爬取网页内容 _detail_store = [] diff --git a/util/general_utils.py b/util/general_utils.py index 2195f6026..c3f5c3767 100644 --- a/util/general_utils.py +++ b/util/general_utils.py @@ -124,7 +124,7 @@ class Logger: for line in pres: ret += f"\033[{fg};{bg}m{line}\033[0m\n" try: - requests.post("http://localhost:6185/api/log", data=ret[:-1].encode()) + requests.post("http://localhost:6185/api/log", data=ret[:-1].encode(), timeout=1) except BaseException as e: pass self.history.append(ret) @@ -132,6 +132,7 @@ class Logger: self.history = self.history[-100:] print(ret[:-1]) +log = Logger() def port_checker(port: int, host: str = "localhost"): sk = socket.socket(socket.AF_INET,socket.SOCK_STREAM)