From 04fb4f88ad3cdd1859e7392a269b2cc9c2f7119c Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sat, 30 Dec 2023 20:08:28 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E9=87=8D=E6=9E=84=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- addons/plugins/helloworld/helloworld.py | 1 - cores/qqbot/core.py | 596 +++++--------------- cores/qqbot/global_object.py | 40 +- main.py | 18 +- model/command/command.py | 2 - model/command/openai_official.py | 3 - model/platform/_message_result.py | 8 + model/platform/_nakuru_translation_layer.py | 77 +++ model/platform/_platfrom.py | 30 + model/platform/qq.py | 190 ------- model/platform/qq_gocq.py | 305 ++++++++++ model/platform/qq_official.py | 227 ++++++++ model/platform/qqchan.py | 217 ------- model/platform/qqgroup.py | 188 ------ model/provider/openai_official.py | 2 +- requirements.txt | 2 +- util/cmd_config.py | 41 +- util/general_utils.py | 18 +- 18 files changed, 842 insertions(+), 1123 deletions(-) create mode 100644 model/platform/_message_result.py create mode 100644 model/platform/_nakuru_translation_layer.py create mode 100644 model/platform/_platfrom.py delete mode 100644 model/platform/qq.py create mode 100644 model/platform/qq_gocq.py create mode 100644 model/platform/qq_official.py delete mode 100644 model/platform/qqchan.py delete mode 100644 model/platform/qqgroup.py diff --git a/addons/plugins/helloworld/helloworld.py b/addons/plugins/helloworld/helloworld.py index 76a5d9b17..fbfdc0260 100644 --- a/addons/plugins/helloworld/helloworld.py +++ b/addons/plugins/helloworld/helloworld.py @@ -4,7 +4,6 @@ from nakuru import ( FriendMessage ) from botpy.message import Message, DirectMessage -from model.platform.qq import QQ from cores.qqbot.global_object import ( AstrMessageEvent, CommandResult diff --git a/cores/qqbot/core.py b/cores/qqbot/core.py index a76a20738..b8983c32a 100644 --- a/cores/qqbot/core.py +++ b/cores/qqbot/core.py @@ -1,5 +1,3 @@ -import botpy -from botpy.message import Message, DirectMessage import re import json import threading @@ -11,46 +9,30 @@ import os import sys from cores.qqbot.personality import personalities from addons.baidu_aip_judge import BaiduJudge -from model.platform.qqchan import QQChan, NakuruGuildMember, NakuruGuildMessage -from model.platform.qq import QQ -from model.platform.qqgroup import ( - UnofficialQQBotSDK, - Event as QQEvent, - Message as QQMessage, - MessageChain, - PlainText -) from nakuru import ( - CQHTTP, GroupMessage, - GroupMemberIncrease, FriendMessage, GuildMessage, - Notify ) +from model.platform._nakuru_translation_layer import NakuruGuildMember, NakuruGuildMessage from nakuru.entities.components import Plain,At,Image from model.provider.provider import Provider from model.command.command import Command from util import general_utils as gu from util.cmd_config import CmdConfig as cc +from util.cmd_config import init_astrbot_config_items import util.function_calling.gplugin as gplugin import util.plugin_util as putil from PIL import Image as PILImage import io import traceback from . global_object import GlobalObject -from typing import Union, Callable +from typing import Union from addons.dashboard.helper import DashBoardHelper from addons.dashboard.server import DashBoardData from cores.monitor.perf import run_monitor from cores.database.conn import dbConn - -# 缓存的会话 -session_dict = {} -# 统计信息 -count = {} -# 统计信息 -stat_file = '' +from model.platform._message_result import MessageResult # 用户发言频率 user_frequency = {} @@ -59,14 +41,8 @@ frequency_time = 60 # 计数默认值 frequency_count = 2 -# 公告(可自定义): -announcement = "" - -# 机器人私聊模式 -direct_message_mode = True - # 版本 -version = '3.1.0' +version = '3.1.1' # 语言模型 REV_CHATGPT = 'rev_chatgpt' @@ -75,7 +51,6 @@ REV_ERNIE = 'rev_ernie' REV_EDGEGPT = 'rev_edgegpt' NONE_LLM = 'none_llm' chosen_provider = None - # 语言模型对象 llm_instance: dict[str, Provider] = {} llm_command_instance: dict[str, Command] = {} @@ -85,123 +60,57 @@ baidu_judge = None # 关键词回复 keywords = {} -# QQ频道机器人 -qqchannel_bot: QQChan = None -PLATFORM_QQCHAN = 'qqchan' -qqchan_loop = None -client = None - -# QQ群机器人 -PLATFROM_QQBOT = 'qqbot' - # CLI PLATFORM_CLI = 'cli' -# 加载默认配置 -cc.init_attributes("qq_forward_threshold", 200) -cc.init_attributes("qq_welcome", "欢迎加入本群!\n欢迎给https://github.com/Soulter/QQChannelChatGPT项目一个Star😊~\n输入help查看帮助~\n") -cc.init_attributes("bing_proxy", "") -cc.init_attributes("qq_pic_mode", False) -cc.init_attributes("rev_chatgpt_model", "") -cc.init_attributes("rev_chatgpt_plugin_ids", []) -cc.init_attributes("rev_chatgpt_PUID", "") -cc.init_attributes("rev_chatgpt_unverified_plugin_domains", []) -cc.init_attributes("gocq_host", "127.0.0.1") -cc.init_attributes("gocq_http_port", 5700) -cc.init_attributes("gocq_websocket_port", 6700) -cc.init_attributes("gocq_react_group", True) -cc.init_attributes("gocq_react_guild", True) -cc.init_attributes("gocq_react_friend", True) -cc.init_attributes("gocq_react_group_increase", True) -cc.init_attributes("gocq_qqchan_admin", "") -cc.init_attributes("other_admins", []) -cc.init_attributes("CHATGPT_BASE_URL", "") -cc.init_attributes("qqbot_appid", "") -cc.init_attributes("qqbot_secret", "") -cc.init_attributes("llm_env_prompt", "> hint: 末尾根据内容和心情添加 1-2 个emoji") -cc.init_attributes("default_personality_str", "") -cc.init_attributes("openai_image_generate", { - "model": "dall-e-3", - "size": "1024x1024", - "style": "vivid", - "quality": "standard", -}) -cc.init_attributes("http_proxy", "") -cc.init_attributes("https_proxy", "") -cc.init_attributes("dashboard_username", "") -cc.init_attributes("dashboard_password", "") -# cc.init_attributes(["qq_forward_mode"], False) - -# QQ机器人 -gocq_bot = None -PLATFORM_GOCQ = 'gocq' -gocq_app = CQHTTP( - host=cc.get("gocq_host", "127.0.0.1"), - port=cc.get("gocq_websocket_port", 6700), - http_port=cc.get("gocq_http_port", 5700), -) -qq_bot: UnofficialQQBotSDK = UnofficialQQBotSDK( - cc.get("qqbot_appid", None), - cc.get("qqbot_secret", None) -) - -gocq_loop: asyncio.AbstractEventLoop = None -qqbot_loop: asyncio.AbstractEventLoop = None - +init_astrbot_config_items() # 全局对象 _global_object: GlobalObject = None -def new_sub_thread(func, args=()): - thread = threading.Thread(target=_runner, args=(func, args), daemon=True) - thread.start() +# def new_sub_thread(func, args=()): +# thread = threading.Thread(target=_runner, args=(func, args), daemon=True) +# thread.start() -def _runner(func: Callable, args: tuple): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.run_until_complete(func(*args)) - loop.close() +# def _runner(func: Callable, args: tuple): +# loop = asyncio.new_event_loop() +# asyncio.set_event_loop(loop) +# loop.run_until_complete(func(*args)) +# loop.close() # 统计消息数据 def upload(): - global version, gocq_bot, qqchannel_bot + global version while True: - addr = '' addr_ip = '' - session_dict_dump = '{}' try: - addr = requests.get('http://myip.ipip.net', timeout=5).text - addr_ip = re.findall(r'\d+.\d+.\d+.\d+', addr)[0] - except BaseException as e: - pass - try: - gocq_cnt = 0 - qqchan_cnt = 0 - if gocq_bot is not None: - gocq_cnt = gocq_bot.get_cnt() - if qqchannel_bot is not None: - qqchan_cnt = qqchannel_bot.get_cnt() - o = {"cnt_total": _global_object.cnt_total,"admin": _global_object.admin_qq,"addr": addr, 's': session_dict_dump} + + o = { + "cnt_total": _global_object.cnt_total, + "admin": _global_object.admin_qq, + } o_j = json.dumps(o) - res = {"version": version, "count": gocq_cnt+qqchan_cnt, "ip": addr_ip, "others": o_j, "cntqc": qqchan_cnt, "cntgc": gocq_cnt} + res = { + "version": version, + "count": _global_object.cnt_total, + "cntqc": -1, + "cntgc": -1, + "ip": addr_ip, + "others": o_j, + "sys": sys.platform, + } gu.log(res, gu.LEVEL_DEBUG, tag="Upload", fg = gu.FG_COLORS['yellow'], bg=gu.BG_COLORS['black']) resp = requests.post('https://api.soulter.top/upload', data=json.dumps(res), timeout=5) - # print(resp.text) if resp.status_code == 200: ok = resp.json() if ok['status'] == 'ok': _global_object.cnt_total = 0 - if gocq_bot is not None: - gocq_cnt = gocq_bot.set_cnt(0) - if qqchannel_bot is not None: - qqchan_cnt = qqchannel_bot.set_cnt(0) except BaseException as e: gu.log("上传统计信息时出现错误: " + str(e), gu.LEVEL_ERROR, tag="Upload") pass time.sleep(10*60) - # 语言模型选择 def privider_chooser(cfg): l = [] @@ -221,7 +130,7 @@ def privider_chooser(cfg): def initBot(cfg): global llm_instance, llm_command_instance global baidu_judge, chosen_provider - global frequency_count, frequency_time, announcement, direct_message_mode + global frequency_count, frequency_time global keywords, _global_object # 迁移旧配置 @@ -310,8 +219,7 @@ def initBot(cfg): # 得到私聊模式配置 if 'direct_message_mode' in cfg: - direct_message_mode = cfg['direct_message_mode'] - gu.log("私聊功能: "+str(direct_message_mode), gu.LEVEL_INFO) + gu.log("私聊功能: "+str(cfg['direct_message_mode']), gu.LEVEL_INFO) # 得到发言频率配置 if 'limit' in cfg: @@ -321,14 +229,6 @@ def initBot(cfg): if 'time' in cfg['limit']: frequency_time = cfg['limit']['time'] - # 得到公告配置 - if 'notice' in cfg: - if cc.get("qq_welcome", None) != None and cfg['notice'] == '此机器人由Github项目QQChannelChatGPT驱动。': - announcement = cc.get("qq_welcome", None) - else: - announcement = cfg['notice'] - gu.log("公告配置: " + announcement, gu.LEVEL_INFO) - try: if 'uniqueSessionMode' in cfg and cfg['uniqueSessionMode']: _global_object.uniqueSession = True @@ -338,9 +238,6 @@ def initBot(cfg): except BaseException as e: gu.log("独立会话配置错误: "+str(e), gu.LEVEL_ERROR) - - gu.log(f"QQ开放平台AppID: {cfg['qqbot']['appid']} 令牌: {cfg['qqbot']['token']}") - if chosen_provider is None: gu.log("检测到没有启动任何语言模型。", gu.LEVEL_CRITICAL) @@ -370,56 +267,17 @@ def initBot(cfg): gu.log("--------加载机器人平台--------", gu.LEVEL_INFO, fg=gu.FG_COLORS['yellow']) - admin_qq = cc.get('admin_qq', None) - admin_qqchan = cc.get('admin_qqchan', None) - if admin_qq == None: - gu.log("未设置管理者QQ号(管理者才能使用update/plugin等指令),如需设置,请编辑 cmd_config.json 文件", gu.LEVEL_WARNING) - - if admin_qqchan == None: - gu.log("未设置管理者QQ频道用户号(管理者才能使用update/plugin等指令),如需设置,请编辑 cmd_config.json 文件。可在频道发送指令 !myid 获取", gu.LEVEL_WARNING) - - _global_object.admin_qq = admin_qq - _global_object.admin_qqchan = admin_qqchan - - global qq_bot, qqbot_loop - qqbot_loop = asyncio.new_event_loop() - if cc.get("qqbot_appid", '') != '' and cc.get("qqbot_secret", '') != '': - gu.log("- 启用QQ群机器人 -", gu.LEVEL_INFO) - thread_inst = threading.Thread(target=run_qqbot, args=(qqbot_loop, qq_bot,), daemon=True) - thread_inst.start() + gu.log("提示:需要添加管理员 ID 才能使用 update/plugin 等指令),可在可视化面板添加。(如已添加可忽略)", gu.LEVEL_WARNING) # GOCQ - global gocq_bot if 'gocqbot' in cfg and cfg['gocqbot']['enable']: - gu.log("- 启用QQ机器人 -", gu.LEVEL_INFO) - - global gocq_app, gocq_loop - gocq_loop = asyncio.new_event_loop() - gocq_bot = QQ(True, cc, gocq_loop) - thread_inst = threading.Thread(target=run_gocq_bot, args=(gocq_loop, gocq_bot, gocq_app), daemon=True) - thread_inst.start() - else: - gocq_bot = QQ(False) - - _global_object.platform_qq = gocq_bot - - gu.log("机器人部署教程: https://github.com/Soulter/QQChannelChatGPT/wiki/", gu.LEVEL_INFO, fg=gu.FG_COLORS['yellow']) - gu.log("如果有任何问题, 请在 https://github.com/Soulter/QQChannelChatGPT 上提交 issue 或加群 322154837", gu.LEVEL_INFO, fg=gu.FG_COLORS['yellow']) - gu.log("请给 https://github.com/Soulter/QQChannelChatGPT 点个 star!", gu.LEVEL_INFO, fg=gu.FG_COLORS['yellow']) + gu.log("- 启用 QQ_GOCQ 机器人 -", gu.LEVEL_INFO) + threading.Thread(target=run_gocq_bot, args=(cfg, _global_object), daemon=True).start() # QQ频道 if 'qqbot' in cfg and cfg['qqbot']['enable']: - gu.log("- 启用QQ频道机器人 -", gu.LEVEL_INFO) - global qqchannel_bot, qqchan_loop - qqchannel_bot = QQChan() - qqchan_loop = asyncio.new_event_loop() - _global_object.platform_qqchan = qqchannel_bot - thread_inst = threading.Thread(target=run_qqchan_bot, args=(cfg, qqchan_loop, qqchannel_bot), daemon=True) - thread_inst.start() - # thread.join() - - if thread_inst == None: - gu.log("没有启用/成功启用任何机器人平台", gu.LEVEL_CRITICAL) + gu.log("- 启用 QQ_OFFICIAL 机器人 -", gu.LEVEL_INFO) + threading.Thread(target=run_qqchan_bot, args=(cfg, _global_object), daemon=True).start() default_personality_str = cc.get("default_personality_str", "") if default_personality_str == "": @@ -442,11 +300,11 @@ def initBot(cfg): # 运行 monitor threading.Thread(target=run_monitor, args=(_global_object,), daemon=False).start() - + + gu.log("如果有任何问题, 请在 https://github.com/Soulter/AstrBot 上提交 issue 或加群 322154837。", gu.LEVEL_INFO, fg=gu.FG_COLORS['yellow']) + gu.log("请给 https://github.com/Soulter/AstrBot 点个 star。", gu.LEVEL_INFO, fg=gu.FG_COLORS['yellow']) gu.log("🎉 项目启动完成。") - # asyncio.get_event_loop().run_until_complete(cli()) - dashboard_thread.join() async def cli(): @@ -478,29 +336,24 @@ async def cli_pack_message(prompt: str) -> NakuruGuildMessage: return ngm ''' -运行QQ频道机器人 +运行 QQ_OFFICIAL 机器人 ''' -def run_qqchan_bot(cfg, loop, qqchannel_bot: QQChan): - asyncio.set_event_loop(loop) - intents = botpy.Intents(public_guild_messages=True, direct_message=True) - global client - client = botClient( - intents=intents, - bot_log=False - ) +def run_qqchan_bot(cfg: dict, global_object: GlobalObject): try: - qqchannel_bot.run_bot(client, cfg['qqbot']['appid'], cfg['qqbot']['token']) + from model.platform.qq_official import QQOfficial + qqchannel_bot = QQOfficial(cfg=cfg, message_handler=oper_msg) + global_object.platform_qqchan = qqchannel_bot + qqchannel_bot.run() except BaseException as e: gu.log("启动QQ频道机器人时出现错误, 原因如下: " + str(e), gu.LEVEL_CRITICAL, tag="QQ频道") - gu.log(r"如果您是初次启动,请修改配置文件(QQChannelChatGPT/config.yaml)详情请看:https://github.com/Soulter/QQChannelChatGPT/wiki。" + str(e), gu.LEVEL_CRITICAL, tag="System") - - i = input("按回车退出程序。\n") + gu.log(r"如果您是初次启动,请前往可视化面板填写配置。详情请看:https://astrbot.soulter.top/center/。" + str(e), gu.LEVEL_CRITICAL, tag="System") ''' -运行GOCQ机器人 +运行 QQ_GOCQ 机器人 ''' -def run_gocq_bot(loop, gocq_bot, gocq_app): - asyncio.set_event_loop(loop) +def run_gocq_bot(cfg: dict, _global_object: GlobalObject): + from model.platform.qq_gocq import QQGOCQ + gu.log("正在检查本地GO-CQHTTP连接...端口5700, 6700", tag="QQ") noticed = False while True: @@ -512,22 +365,13 @@ def run_gocq_bot(loop, gocq_bot, gocq_app): else: gu.log("检查完毕,未发现问题。", tag="QQ") break - - global gocq_client - gocq_client = gocqClient() try: - gocq_bot.run_bot(gocq_app) + qq_gocq = QQGOCQ(cfg=cfg, message_handler=oper_msg) + _global_object.platform_qq = qq_gocq + qq_gocq.run() except BaseException as e: input("启动QQ机器人出现错误"+str(e)) -''' -启动QQ群机器人(官方接口) -''' -def run_qqbot(loop: asyncio.AbstractEventLoop, qq_bot: UnofficialQQBotSDK): - asyncio.set_event_loop(loop) - QQBotClient() - qq_bot.run_bot() - ''' 检查发言频率 @@ -550,133 +394,53 @@ def check_frequency(id) -> bool: user_frequency[id] = t return True - -''' -通用消息回复 -''' -async def send_message(platform, message, res, session_id = None): - global qqchannel_bot, qqchannel_bot, gocq_loop, session_dict - - # 统计会话信息 - if session_id is not None: - if session_id not in session_dict: - session_dict[session_id] = {'cnt': 1} - else: - session_dict[session_id]['cnt'] += 1 - else: - session_dict[session_id]['cnt'] += 1 - +async def record_message(platform: str, session_id: str): # TODO: 这里会非常吃资源。然而 sqlite3 不支持多线程,所以暂时这样写。 curr_ts = int(time.time()) db_inst = dbConn() db_inst.increment_stat_session(platform, session_id, 1) db_inst.increment_stat_message(curr_ts, 1) db_inst.increment_stat_platform(curr_ts, platform, 1) - - if platform == PLATFORM_QQCHAN: - qqchannel_bot.send_qq_msg(message, res) - elif platform == PLATFORM_GOCQ: - await gocq_bot.send_qq_msg(message, res) - elif platform == PLATFROM_QQBOT: - message_chain = MessageChain() - message_chain.parse_from_nakuru(res) - await qq_bot.send(message, message_chain) - elif platform == PLATFORM_CLI: - print(res) - -async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGuildMessage], - group: bool=False, - platform: str = None): - """ - 处理消息。 - group: 群聊模式, - message: 频道是频道的消息对象, QQ是nakuru-gocq的消息对象 - msg_ref: 引用消息(频道) - platform: 平台(gocq, qqchan) - """ - global chosen_provider, keywords, qqchannel_bot, gocq_bot - global _global_object - qq_msg = '' - session_id = '' - user_id = '' - role = "member" # 角色, member或admin - hit = False # 是否命中指令 - command_result = () # 调用指令返回的结果 - _global_object.cnt_total += 1 - with_tag = False # 是否带有昵称 +async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGuildMessage], + session_id: str, + role: str = 'member', + platform: str = None, +) -> MessageResult: + """ + 处理消息。 + message: 消息对象 + session_id: 该消息源的唯一识别号 + role: member | admin + platform: 平台(gocq, qqchan) + """ + global chosen_provider, keywords + global _global_object + message_str = '' + session_id = session_id + role = role + hit = False # 是否命中指令 + command_result = () # 调用指令返回的结果 + message_result = None # 消息返回结果 - if platform == PLATFORM_QQCHAN or platform == PLATFROM_QQBOT or platform == PLATFORM_CLI: - with_tag = True + record_message(platform, session_id) - _len = 0 for i in message.message: - if isinstance(i, Plain) or isinstance(i, PlainText): - qq_msg += str(i.text).strip() - if isinstance(i, At): - if message.type == "GuildMessage": - if i.qq == message.user_id or i.qq == message.self_tiny_id: - with_tag = True - if message.type == "FriendMessage": - if i.qq == message.self_id: - with_tag = True - if message.type == "GroupMessage": - if i.qq == message.self_id: - with_tag = True - - for i in _global_object.nick: - if i != '' and qq_msg.startswith(i): - _len = len(i) - with_tag = True - break - qq_msg = qq_msg[_len:].strip() - - gu.log(f"收到消息:{qq_msg}", gu.LEVEL_INFO, tag="QQ") - user_id = message.user_id - - if group: - # 适配GO-CQHTTP的频道功能 - if message.type == "GuildMessage": - session_id = message.channel_id - else: - session_id = message.group_id - else: - with_tag = True - session_id = message.user_id - - if message.type == "GuildMessage": - sender_id = str(message.sender.tiny_id) - else: - sender_id = str(message.sender.user_id) - if sender_id == _global_object.admin_qq or \ - sender_id == _global_object.admin_qqchan or \ - sender_id in cc.get("other_admins", []) or \ - sender_id == cc.get("gocq_qqchan_admin", "") or \ - platform == PLATFORM_CLI: - role = "admin" - - if _global_object.uniqueSession: - # 独立会话时,一个用户一个 session - session_id = sender_id - - - if qq_msg == "": - await send_message(platform, message, f"Hi~", session_id=session_id) - return + if isinstance(i, Plain): + message_str += i.text.strip() + gu.log(f"收到消息:{message_str}", gu.LEVEL_INFO, tag=platform) + if message_str == "": + return MessageResult("Hi~") - if with_tag: - # 检查发言频率 - if not check_frequency(user_id): - await send_message(platform, message, f'你的发言超过频率限制(╯▔皿▔)╯。\n管理员设置{frequency_time}秒内只能提问{frequency_count}次。', session_id=session_id) - return - - # logf.write("[GOCQBOT] "+ qq_msg+'\n') - # logf.flush() + # 检查发言频率 + user_id = message.user_id + if not check_frequency(user_id): + return MessageResult(f'你的发言超过频率限制(╯▔皿▔)╯。\n管理员设置{frequency_time}秒内只能提问{frequency_count}次。') # 关键词回复 for k in keywords: - if qq_msg == k: + if message_str == k: plain_text = "" if 'plain_text' in keywords[k]: plain_text = keywords[k]['plain_text'] @@ -687,45 +451,34 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak image_url = keywords[k]['image_url'] if image_url != "": res = [Plain(plain_text), Image.fromURL(image_url)] - await send_message(platform, message, res, session_id=session_id) - else: - await send_message(platform, message, plain_text, session_id=session_id) - return + return MessageResult(res) + return MessageResult(plain_text) # 检查是否是更换语言模型的请求 temp_switch = "" - if qq_msg.startswith('/bing') or qq_msg.startswith('/gpt') or qq_msg.startswith('/revgpt'): + if message_str.startswith('/bing') or message_str.startswith('/gpt') or message_str.startswith('/revgpt'): target = chosen_provider - if qq_msg.startswith('/bing'): + if message_str.startswith('/bing'): target = REV_EDGEGPT - elif qq_msg.startswith('/gpt'): + elif message_str.startswith('/gpt'): target = OPENAI_OFFICIAL - elif qq_msg.startswith('/revgpt'): + elif message_str.startswith('/revgpt'): target = REV_CHATGPT - l = qq_msg.split(' ') + l = message_str.split(' ') if len(l) > 1 and l[1] != "": # 临时对话模式,先记录下之前的语言模型,回答完毕后再切回 temp_switch = chosen_provider chosen_provider = target - qq_msg = l[1] + message_str = l[1] else: chosen_provider = target cc.put("chosen_provider", chosen_provider) - await send_message(platform, message, f"已切换至【{chosen_provider}】", session_id=session_id) - return - - chatgpt_res = "" + return MessageResult(f"已切换至【{chosen_provider}】") - # 如果是等待回复的消息 - if platform == PLATFORM_GOCQ and session_id in gocq_bot.waiting and gocq_bot.waiting[session_id] == '': - gocq_bot.waiting[session_id] = message - return - if platform == PLATFORM_QQCHAN and session_id in qqchannel_bot.waiting and qqchannel_bot.waiting[session_id] == '': - qqchannel_bot.waiting[session_id] = message - return + llm_result_str = "" hit, command_result = llm_command_instance[chosen_provider].check_command( - qq_msg, + message_str, session_id, role, platform, @@ -734,22 +487,17 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak # 没触发指令 if not hit: - if not with_tag: - return # 关键词拦截 for i in uw.unfit_words_q: - matches = re.match(i, qq_msg.strip(), re.I | re.M) + matches = re.match(i, message_str.strip(), re.I | re.M) if matches: - await send_message(platform, message, f"你的提问得到的回复未通过【自有关键词拦截】服务, 不予回复。", session_id=session_id) - return + return MessageResult(f"你的提问得到的回复未通过【自有关键词拦截】服务, 不予回复。") if baidu_judge != None: - check, msg = baidu_judge.judge(qq_msg) + check, msg = baidu_judge.judge(message_str) if not check: - await send_message(platform, message, f"你的提问得到的回复未通过【百度AI内容审核】服务, 不予回复。\n\n{msg}", session_id=session_id) - return + return MessageResult(f"你的提问得到的回复未通过【百度AI内容审核】服务, 不予回复。\n\n{msg}") if chosen_provider == None: - await send_message(platform, message, f"管理员未启动任何语言模型或者语言模型初始化时失败。", session_id=session_id) - return + return MessageResult(f"管理员未启动任何语言模型或者语言模型初始化时失败。") try: # check image url image_url = None @@ -763,35 +511,24 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak break # web search keyword web_sch_flag = False - if qq_msg.startswith("ws ") and qq_msg != "ws ": - qq_msg = qq_msg[3:] + if message_str.startswith("ws ") and message_str != "ws ": + message_str = message_str[3:] web_sch_flag = True else: - qq_msg += " " + cc.get("llm_env_prompt", "") + message_str += " " + cc.get("llm_env_prompt", "") if chosen_provider == REV_CHATGPT or chosen_provider == OPENAI_OFFICIAL: if _global_object.web_search or web_sch_flag: official_fc = chosen_provider == OPENAI_OFFICIAL - chatgpt_res = gplugin.web_search(qq_msg, llm_instance[chosen_provider], session_id, official_fc) + llm_result_str = gplugin.web_search(message_str, llm_instance[chosen_provider], session_id, official_fc) else: - chatgpt_res = str(llm_instance[chosen_provider].text_chat(qq_msg, session_id, image_url, default_personality = _global_object.default_personality)) + llm_result_str = str(llm_instance[chosen_provider].text_chat(message_str, session_id, image_url, default_personality = _global_object.default_personality)) elif chosen_provider == REV_EDGEGPT: - res, res_code = await llm_instance[chosen_provider].text_chat(qq_msg, platform) - if res_code == 0: # bing不想继续话题,重置会话后重试。 - await send_message(platform, message, "Bing不想继续话题了, 正在自动重置会话并重试。", session_id=session_id) - await llm_instance[chosen_provider].forget() - res, res_code = await llm_instance[chosen_provider].text_chat(qq_msg, platform) - if res_code == 0: # bing还是不想继续话题,大概率说明提问有问题。 - await llm_instance[chosen_provider].forget() - await send_message(platform, message, "Bing仍然不想继续话题, 会话已重置, 请检查您的提问后重试。", session_id=session_id) - res = "" - chatgpt_res = str(res) + return MessageResult("AstrBot 不再默认支持 NewBing 模型。") - chatgpt_res = _global_object.reply_prefix + chatgpt_res + llm_result_str = _global_object.reply_prefix + llm_result_str except BaseException as e: gu.log(f"调用异常:{traceback.format_exc()}", gu.LEVEL_ERROR, max_len=100000) - gu.log("调用语言模型例程时出现异常。原因: "+str(e), gu.LEVEL_ERROR) - await send_message(platform, message, "调用语言模型例程时出现异常。原因: "+str(e), session_id=session_id) - return + return MessageResult(f"调用语言模型例程时出现异常。原因: {str(e)}") # 切换回原来的语言模型 if temp_switch != "": @@ -799,139 +536,58 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak # 指令回复 if hit: - # 检查指令. command_result是一个元组:(指令调用是否成功, 指令返回的文本结果, 指令类型) + # 检查指令。 command_result 是一个元组:(指令调用是否成功, 指令返回的文本结果, 指令类型) if command_result == None: return - command = command_result[2] + if command == "keyword": if os.path.exists("keyword.json"): with open("keyword.json", "r", encoding="utf-8") as f: keywords = json.load(f) else: try: - await send_message(platform, message, command_result[1], session_id=session_id) + return MessageResult(command_result[1]) except BaseException as e: - await send_message(platform, message, f"回复消息出错: {str(e)}", session_id=session_id) + return MessageResult(f"回复消息出错: {str(e)}") if command == "update latest r": - await send_message(platform, message, command_result[1] + "\n\n即将自动重启。", session_id=session_id) - py = sys.executable - os.execl(py, py, *sys.argv) + def update_restart(): + py = sys.executable + os.execl(py, py, *sys.argv) + return MessageResult(command_result[1] + "\n\n即将自动重启。", callback=update_restart) if not command_result[0]: - await send_message(platform, message, f"指令调用错误: \n{str(command_result[1])}", session_id=session_id) - return + return MessageResult(f"指令调用错误: \n{str(command_result[1])}") # 画图指令 if isinstance(command_result[1], list) and len(command_result) == 3 and command == 'draw': for i in command_result[1]: - # i is a link # 保存到本地 pic_res = requests.get(i, stream = True) if pic_res.status_code == 200: image = PILImage.open(io.BytesIO(pic_res.content)) - await send_message(platform, message, [Image.fromFileSystem(gu.save_temp_img(image))], session_id=session_id) + return MessageResult([Image.fromFileSystem(gu.save_temp_img(image))]) # 其他指令 else: try: - await send_message(platform, message, command_result[1], session_id=session_id) + return MessageResult(command_result[1]) except BaseException as e: - await send_message(platform, message, f"回复消息出错: {str(e)}", session_id=session_id) - + return MessageResult(f"回复消息出错: {str(e)}") return - # 记录日志 - # logf.write(f"{reply_prefix} {str(chatgpt_res)}\n") - # logf.flush() - # 敏感过滤 # 过滤不合适的词 for i in uw.unfit_words: - chatgpt_res = re.sub(i, "***", chatgpt_res) + llm_result_str = re.sub(i, "***", llm_result_str) # 百度内容审核服务二次审核 if baidu_judge != None: - check, msg = baidu_judge.judge(chatgpt_res) + check, msg = baidu_judge.judge(llm_result_str) if not check: - await send_message(platform, message, f"你的提问得到的回复【百度内容审核】未通过,不予回复。\n\n{msg}", session_id=session_id) - return - + return MessageResult(f"你的提问得到的回复【百度内容审核】未通过,不予回复。\n\n{msg}") # 发送信息 try: - await send_message(platform, message, chatgpt_res, session_id=session_id) + return MessageResult(llm_result_str) except BaseException as e: - gu.log("回复消息错误: \n"+str(e), gu.LEVEL_ERROR) - -# QQ频道机器人 -class botClient(botpy.Client): - # 收到频道消息 - async def on_at_message_create(self, message: Message): - gu.log(str(message), gu.LEVEL_DEBUG, max_len=9999) - - # 转换层 - nakuru_guild_message = qqchannel_bot.gocq_compatible_receive(message) - gu.log(f"转换后: {str(nakuru_guild_message)}", gu.LEVEL_DEBUG, max_len=9999) - new_sub_thread(oper_msg, (nakuru_guild_message, True, PLATFORM_QQCHAN)) - - # 收到私聊消息 - async def on_direct_message_create(self, message: DirectMessage): - if direct_message_mode: - - # 转换层 - nakuru_guild_message = qqchannel_bot.gocq_compatible_receive(message) - gu.log(f"转换后: {str(nakuru_guild_message)}", gu.LEVEL_DEBUG, max_len=9999) - - new_sub_thread(oper_msg, (nakuru_guild_message, False, PLATFORM_QQCHAN)) -# QQ机器人 -class gocqClient(): - # 收到群聊消息 - @gocq_app.receiver("GroupMessage") - async def _(app: CQHTTP, source: GroupMessage): - if cc.get("gocq_react_group", True): - if isinstance(source.message[0], Plain): - new_sub_thread(oper_msg, (source, True, PLATFORM_GOCQ)) - if isinstance(source.message[0], At): - if source.message[0].qq == source.self_id: - new_sub_thread(oper_msg, (source, True, PLATFORM_GOCQ)) - else: - return - - @gocq_app.receiver("FriendMessage") - async def _(app: CQHTTP, source: FriendMessage): - if cc.get("gocq_react_friend", True): - if isinstance(source.message[0], Plain): - new_sub_thread(oper_msg, (source, False, PLATFORM_GOCQ)) - else: - return - - @gocq_app.receiver("GroupMemberIncrease") - async def _(app: CQHTTP, source: GroupMemberIncrease): - if cc.get("gocq_react_group_increase", True): - global announcement - await app.sendGroupMessage(source.group_id, [ - Plain(text = announcement), - ]) - - @gocq_app.receiver("Notify") - async def _(app: CQHTTP, source: Notify): - print(source) - if source.sub_type == "poke" and source.target_id == source.self_id: - new_sub_thread(oper_msg, (source, False, PLATFORM_GOCQ)) - - @gocq_app.receiver("GuildMessage") - async def _(app: CQHTTP, source: GuildMessage): - if cc.get("gocq_react_guild", True): - if isinstance(source.message[0], Plain): - new_sub_thread(oper_msg, (source, True, PLATFORM_GOCQ)) - if isinstance(source.message[0], At): - if source.message[0].qq == source.self_tiny_id: - new_sub_thread(oper_msg, (source, True, PLATFORM_GOCQ)) - else: - return - -class QQBotClient(): - @qq_bot.on('GroupMessage') - async def _(bot: UnofficialQQBotSDK, message: QQMessage): - print(message) - new_sub_thread(oper_msg, (message, True, PLATFROM_QQBOT)) \ No newline at end of file + gu.log("回复消息错误: \n"+str(e), gu.LEVEL_ERROR) \ No newline at end of file diff --git a/cores/qqbot/global_object.py b/cores/qqbot/global_object.py index 364fb92b7..de36b31a9 100644 --- a/cores/qqbot/global_object.py +++ b/cores/qqbot/global_object.py @@ -1,5 +1,5 @@ -from model.platform.qqchan import QQChan, NakuruGuildMember, NakuruGuildMessage -from model.platform.qq import QQ +from model.platform.qq_official import QQOfficial, NakuruGuildMember, NakuruGuildMessage +from model.platform.qq_gocq import QQGOCQ from model.provider.provider import Provider from addons.dashboard.server import DashBoardData from nakuru import ( @@ -17,7 +17,7 @@ class GlobalObject: 存放一些公用的数据,用于在不同模块(如core与command)之间传递 ''' nick: str # gocq 的昵称 - base_config: dict # config.yaml + base_config: dict # config.json cached_plugins: dict # 缓存的插件 web_search: bool # 是否开启了网页搜索 reply_prefix: str @@ -25,8 +25,8 @@ class GlobalObject: admin_qqchan: str uniqueSession: bool cnt_total: int - platform_qq: QQ - platform_qqchan: QQChan + platform_qq: QQGOCQ + platform_qqchan: QQOfficial default_personality: dict dashboard_data: DashBoardData stat: dict @@ -46,35 +46,13 @@ class GlobalObject: self.default_personality = None self.dashboard_data = None self.stat = {} - ''' - - { - "config": {}, - "session": [ - { - "platform": "qq", - "session_id": 123456, - "cnt": 0 - }, - {...} - ], - "message": [ - // 以一小时为单位 - { - "ts": 1234567, - "cnt": 0 - } - ] - } - - ''' class AstrMessageEvent(): message_str: str # 纯消息字符串 message_obj: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGuildMessage] # 消息对象 - gocq_platform: QQ - qq_sdk_platform: QQChan + gocq_platform: QQGOCQ + qq_sdk_platform: QQOfficial platform: str # `gocq` 或 `qqchan` role: str # `admin` 或 `member` global_object: GlobalObject # 一些公用数据 @@ -82,8 +60,8 @@ class AstrMessageEvent(): def __init__(self, message_str: str, message_obj: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGuildMessage], - gocq_platform: QQ, - qq_sdk_platform: QQChan, + gocq_platform: QQGOCQ, + qq_sdk_platform: QQOfficial, platform: str, role: str, global_object: GlobalObject, diff --git a/main.py b/main.py index 386e71b27..10e791f70 100644 --- a/main.py +++ b/main.py @@ -9,11 +9,6 @@ warnings.filterwarnings("ignore") abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/' def main(): - logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - datefmt='%H:%M:%S', - ) # config.yaml 配置文件加载和环境确认 try: import cores.qqbot.core as qqBot @@ -23,6 +18,7 @@ def main(): ymlfile = open(abs_path+"configs/config.yaml", 'r', encoding='utf-8') cfg = yaml.safe_load(ymlfile) except ImportError as import_error: + traceback.print_exc() print(import_error) input("第三方库未完全安装完毕,请退出程序重试。") except FileNotFoundError as file_not_found: @@ -86,18 +82,6 @@ def check_env(ch_mirror=False): break print("第三方库检查完毕。") -def get_platform(): - import platform - sys_platform = platform.platform().lower() - if "windows" in sys_platform: - return "win" - elif "macos" in sys_platform: - return "mac" - elif "linux" in sys_platform: - return "linux" - else: - print("other") - if __name__ == "__main__": args = sys.argv diff --git a/model/command/command.py b/model/command/command.py index ce59fef28..b0f6c7c94 100644 --- a/model/command/command.py +++ b/model/command/command.py @@ -16,7 +16,6 @@ import util.plugin_util as putil import shutil import importlib from util.cmd_config import CmdConfig as cc -from model.platform.qq import QQ import stat from nakuru.entities.components import ( Plain, @@ -60,7 +59,6 @@ class Command: if isinstance(result, CommandResult): hit = result.hit res = result._result_tuple() - print(hit, res) elif isinstance(result, tuple): hit = result[0] res = result[1] diff --git a/model/command/openai_official.py b/model/command/openai_official.py index 8d919a989..b53314f2c 100644 --- a/model/command/openai_official.py +++ b/model/command/openai_official.py @@ -1,9 +1,6 @@ from model.command.command import Command from model.provider.openai_official import ProviderOpenAIOfficial from cores.qqbot.personality import personalities - -from model.platform.qq import QQ -from util import general_utils as gu from cores.qqbot.global_object import GlobalObject class CommandOpenAIOfficial(Command): diff --git a/model/platform/_message_result.py b/model/platform/_message_result.py new file mode 100644 index 000000000..4fd601951 --- /dev/null +++ b/model/platform/_message_result.py @@ -0,0 +1,8 @@ +from dataclasses import dataclass +from typing import Union, Optional + +@dataclass +class MessageResult(): + result_message: Union[str, list] + is_command_call: Optional[bool] = False + callback: Optional[callable] = None diff --git a/model/platform/_nakuru_translation_layer.py b/model/platform/_nakuru_translation_layer.py new file mode 100644 index 000000000..5c1f7068e --- /dev/null +++ b/model/platform/_nakuru_translation_layer.py @@ -0,0 +1,77 @@ +from nakuru.entities.components import Plain, At, Image +from botpy.message import Message, DirectMessage + +class NakuruGuildMember(): + tiny_id: int # 发送者识别号 + user_id: int # 发送者识别号 + title: str + nickname: str # 昵称 + role: int # 角色 + icon_url: str # 头像url + +class NakuruGuildMessage(): + type: str = "GuildMessage" + self_id: int # bot的qq号 + self_tiny_id: int # bot的qq号 + sub_type: str # 消息类型 + message_id: str # 消息id + guild_id: int # 频道号 + channel_id: int # 子频道号 + user_id: int # 发送者qq号 + message: list # 消息内容 + sender: NakuruGuildMember # 发送者信息 + raw_message: Message + + def __str__(self) -> str: + return str(self.__dict__) + +# gocq-频道SDK兼容层(发) +def gocq_compatible_send(gocq_message_chain: list): + plain_text = "" + image_path = None # only one img supported + for i in gocq_message_chain: + if isinstance(i, Plain): + plain_text += i.text + elif isinstance(i, Image) and image_path == None: + if i.path is not None: + image_path = i.path + else: + image_path = i.file + return plain_text, image_path + +# gocq-频道SDK兼容层(收) +def gocq_compatible_receive(message: Message) -> NakuruGuildMessage: + ngm = NakuruGuildMessage() + try: + ngm.self_id = message.mentions[0].id + ngm.self_tiny_id = message.mentions[0].id + except: + ngm.self_id = 0 + ngm.self_tiny_id = 0 + + ngm.sub_type = "normal" + ngm.message_id = message.id + ngm.guild_id = int(message.guild_id) + ngm.channel_id = int(message.channel_id) + ngm.user_id = int(message.author.id) + msg = [] + plain_content = message.content.replace("<@!"+str(ngm.self_id)+">", "").strip() + msg.append(Plain(plain_content)) + if message.attachments: + for i in message.attachments: + if i.content_type.startswith("image"): + url = i.url + if not url.startswith("http"): + url = "https://"+url + img = Image.fromURL(url) + msg.append(img) + ngm.message = msg + ngm.sender = NakuruGuildMember() + ngm.sender.tiny_id = int(message.author.id) + ngm.sender.user_id = int(message.author.id) + ngm.sender.title = "" + ngm.sender.nickname = message.author.username + ngm.sender.role = 0 + ngm.sender.icon_url = message.author.avatar + ngm.raw_message = message + return ngm diff --git a/model/platform/_platfrom.py b/model/platform/_platfrom.py new file mode 100644 index 000000000..4afca9c70 --- /dev/null +++ b/model/platform/_platfrom.py @@ -0,0 +1,30 @@ +import abc + +class Platform(): + def __init__(self, message_handler: callable) -> None: + ''' + 初始化平台的各种接口 + ''' + self.message_handler = message_handler + pass + + @abc.abstractmethod + def handle_msg(): + ''' + 处理到来的消息 + ''' + pass + + @abc.abstractmethod + def reply_msg(): + ''' + 回复消息(被动发送) + ''' + pass + + @abc.abstractmethod + def send_msg(): + ''' + 发送消息(主动发送) + ''' + pass \ No newline at end of file diff --git a/model/platform/qq.py b/model/platform/qq.py deleted file mode 100644 index 1b7a910d8..000000000 --- a/model/platform/qq.py +++ /dev/null @@ -1,190 +0,0 @@ -from nakuru.entities.components import Plain, At, Image, Node -from util import general_utils as gu -from util.cmd_config import CmdConfig -import asyncio -from nakuru import ( - CQHTTP, - GuildMessage, - GroupMessage, - FriendMessage -) -from typing import Union -import time - - -class FakeSource: - def __init__(self, type, group_id): - self.type = type - self.group_id = group_id - -class QQ: - def __init__(self, is_start: bool, cc: CmdConfig = None, gocq_loop = None) -> None: - self.is_start = is_start - self.gocq_loop = gocq_loop - self.cc = cc - self.waiting = {} - self.gocq_cnt = 0 - - def run_bot(self, gocq): - self.client: CQHTTP = gocq - self.client.run() - - def get_msg_loop(self): - return self.gocq_loop - - def get_cnt(self): - return self.gocq_cnt - - def set_cnt(self, cnt): - self.gocq_cnt = cnt - - async def send_qq_msg(self, - source, - res, - image_mode=None): - self.gocq_cnt += 1 - if not self.is_start: - raise Exception("管理员未启动GOCQ平台") - """ - res可以是一个数组, 也就是gocq的消息链。 - 插件开发者请使用send方法, 可以不用直接调用这个方法。 - """ - gu.log("回复GOCQ消息: "+str(res), level=gu.LEVEL_INFO, tag="GOCQ", max_len=300) - - if isinstance(source, int): - source = FakeSource("GroupMessage", source) - - # str convert to CQ Message Chain - if isinstance(res, str): - res_str = res - res = [] - if source.type == "GroupMessage" and not isinstance(source, FakeSource): - res.append(At(qq=source.user_id)) - res.append(Plain(text=res_str)) - - # if image mode, put all Plain texts into a new picture. - if image_mode is None: - image_mode = self.cc.get('qq_pic_mode', False) - if image_mode and isinstance(res, list): - plains = [] - news = [] - for i in res: - if isinstance(i, Plain): - plains.append(i.text) - else: - news.append(i) - plains_str = "".join(plains).strip() - if plains_str != "" and len(plains_str) > 50: - p = gu.create_markdown_image("".join(plains)) - news.append(Image.fromFileSystem(p)) - res = news - - # 回复消息链 - if isinstance(res, list) and len(res) > 0: - if source.type == "GuildMessage": - await self.client.sendGuildChannelMessage(source.guild_id, source.channel_id, res) - return - elif source.type == "FriendMessage": - await self.client.sendFriendMessage(source.user_id, res) - return - elif source.type == "GroupMessage": - # 过长时forward发送 - plain_text_len = 0 - image_num = 0 - for i in res: - if isinstance(i, Plain): - plain_text_len += len(i.text) - elif isinstance(i, Image): - image_num += 1 - if plain_text_len > self.cc.get('qq_forward_threshold', 200): - # 删除At - for i in res: - if isinstance(i, At): - res.remove(i) - node = Node(res) - # node.content = res - node.uin = 123456 - node.name = f"bot" - node.time = int(time.time()) - # print(node) - nodes=[node] - await self.client.sendGroupForwardMessage(source.group_id, nodes) - return - await self.client.sendGroupMessage(source.group_id, res) - return - - def send(self, - to, - res, - image_mode=False, - ): - ''' - 提供给插件的发送QQ消息接口, 不用在外部await。 - 参数说明:第一个参数可以是消息对象,也可以是QQ群号。第二个参数是消息内容(消息内容可以是消息链列表,也可以是纯文字信息)。 - 第三个参数是是否开启图片模式,如果开启,那么所有纯文字信息都会被合并成一张图片。 - ''' - try: - asyncio.run_coroutine_threadsafe(self.send_qq_msg(to, res, image_mode), self.gocq_loop).result() - except BaseException as e: - raise e - - def send_guild(self, - message_obj, - res, - ): - ''' - 提供给插件的发送GOCQ QQ频道消息接口, 不用在外部await。 - 参数说明:第一个参数必须是消息对象, 第二个参数是消息内容(消息内容可以是消息链列表,也可以是纯文字信息)。 - ''' - try: - asyncio.run_coroutine_threadsafe(self.send_qq_msg(message_obj, res), self.gocq_loop).result() - except BaseException as e: - raise e - - def create_text_image(title: str, text: str, max_width=30, font_size=20): - ''' - 文本转图片。 - title: 标题 - text: 文本内容 - max_width: 文本宽度最大值(默认30) - font_size: 字体大小(默认20) - - 返回:文件路径 - ''' - try: - img = gu.word2img(title, text, max_width, font_size) - p = gu.save_temp_img(img) - return p - except Exception as e: - raise e - - def wait_for_message(self, group_id) -> Union[GroupMessage, FriendMessage, GuildMessage]: - ''' - 等待下一条消息,超时 300s 后抛出异常 - ''' - self.waiting[group_id] = '' - cnt = 0 - while True: - if group_id in self.waiting and self.waiting[group_id] != '': - # 去掉 - ret = self.waiting[group_id] - del self.waiting[group_id] - return ret - cnt += 1 - if cnt > 300: - raise Exception("等待消息超时。") - time.sleep(1) - - def get_client(self): - return self.client - - def nakuru_method_invoker(self, func, *args, **kwargs): - """ - 返回一个方法调用器,可以用来立即调用nakuru的方法。 - """ - try: - ret = asyncio.run_coroutine_threadsafe(func(*args, **kwargs), self.gocq_loop).result() - return ret - except BaseException as e: - raise e - diff --git a/model/platform/qq_gocq.py b/model/platform/qq_gocq.py new file mode 100644 index 000000000..60b6ca843 --- /dev/null +++ b/model/platform/qq_gocq.py @@ -0,0 +1,305 @@ +from nakuru.entities.components import Plain, At, Image, Node +from util import general_utils as gu +from util.cmd_config import CmdConfig +import asyncio +from nakuru import ( + CQHTTP, + GuildMessage, + GroupMessage, + FriendMessage, + GroupMemberIncrease, + Notify +) +from typing import Union +import time + +from ._platfrom import Platform + + +class FakeSource: + def __init__(self, type, group_id): + self.type = type + self.group_id = group_id + + +class QQGOCQ(Platform): + def __init__(self, cfg: dict, message_handler: callable) -> None: + super().__init__(message_handler) + asyncio.set_event_loop(asyncio.new_event_loop()) + + self.waiting = {} + self.gocq_cnt = 0 + self.cc = CmdConfig() + self.cfg = cfg + + self.nick_qq = cfg['nick_qq'] + nick_qq = self.nick_qq + if nick_qq == None: + nick_qq = ("ai","!","!") + if isinstance(nick_qq, str): + nick_qq = (nick_qq,) + if isinstance(nick_qq, list): + nick_qq = tuple(nick_qq) + + self.unique_session = cfg['uniqueSessionMode'] + self.pic_mode = cfg['qq_pic_mode'] + + self.client = CQHTTP( + host=self.cc.get("gocq_host", "127.0.0.1"), + port=self.cc.get("gocq_websocket_port", 6700), + http_port=self.cc.get("gocq_http_port", 5700), + ) + gocq_app = self.client + + self.announcement = self.cc.get("announcement", "欢迎新人!") + + @gocq_app.receiver("GroupMessage") + async def _(app: CQHTTP, source: GroupMessage): + if self.cc.get("gocq_react_group", True): + if isinstance(source.message[0], Plain): + await self.handle_msg(source, True) + elif isinstance(source.message[0], At): + if source.message[0].qq == source.self_id: + await self.handle_msg(source, True) + else: + return + + @gocq_app.receiver("FriendMessage") + async def _(app: CQHTTP, source: FriendMessage): + if self.cc.get("gocq_react_friend", True): + if isinstance(source.message[0], Plain): + await self.handle_msg(source, False) + else: + return + + @gocq_app.receiver("GroupMemberIncrease") + async def _(app: CQHTTP, source: GroupMemberIncrease): + if self.cc.get("gocq_react_group_increase", True): + + await app.sendGroupMessage(source.group_id, [ + Plain(text = self.announcement) + ]) + + @gocq_app.receiver("Notify") + async def _(app: CQHTTP, source: Notify): + print(source) + if source.sub_type == "poke" and source.target_id == source.self_id: + await self.handle_msg(source, False) + + @gocq_app.receiver("GuildMessage") + async def _(app: CQHTTP, source: GuildMessage): + if self.cc.get("gocq_react_guild", True): + if isinstance(source.message[0], Plain): + await self.handle_msg(source, True) + elif isinstance(source.message[0], At): + if source.message[0].qq == source.self_tiny_id: + await self.handle_msg(source, True) + else: + return + + def run(self): + self.client.run() + + async def handle_msg(self, message: Union[GroupMessage, FriendMessage, GuildMessage, Notify], is_group: bool): + # 判断是否响应消息 + resp = False + for i in message.message: + if isinstance(i, At): + if message.type == "GuildMessage": + if i.qq == message.user_id or i.qq == message.self_tiny_id: + resp = True + if message.type == "FriendMessage": + if i.qq == message.self_id: + resp = True + if message.type == "GroupMessage": + if i.qq == message.self_id: + resp = True + for i in self.nick_qq: + if i != '' and i in message.message[0].text: + resp = True + break + + if not resp: return + + # 解析 session_id + if self.unique_session or not is_group: + session_id = message.user_id + elif message.type == "GroupMessage": + session_id = message.group_id + elif message.type == "GuildMessage": + session_id = message.channel_id + else: + session_id = message.user_id + + # 解析 role + sender_id = str(message.user_id) + if sender_id == self.cc.get('admin_qq', '') or \ + sender_id == self.cc.get('gocq_qqchan_admin', '') or \ + sender_id in self.cc.get('other_admins', []): + role = 'admin' + else: + role = 'member' + + message_result = await self.message_handler( + message=message, + session_id=session_id, + role=role, + platform='gocq' + ) + + if message_result is None: + return + await self.reply_msg(message, message_result.result_message) + if message_result.callback is not None: + message_result.callback() + + # 如果是等待回复的消息 + if session_id in self.waiting and self.waiting[session_id] == '': + self.waiting[session_id] = message + + async def reply_msg(self, + message: Union[GroupMessage, FriendMessage, GuildMessage, Notify], + result_message: list): + """ + 插件开发者请使用send方法, 可以不用直接调用这个方法。 + """ + source = message + res = result_message + + self.gocq_cnt += 1 + + gu.log("回复GOCQ消息: "+str(res), level=gu.LEVEL_INFO, tag="GOCQ", max_len=300) + + if isinstance(source, int): + source = FakeSource("GroupMessage", source) + + # str convert to CQ Message Chain + if isinstance(res, str): + res_str = res + res = [] + if source.type == "GroupMessage" and not isinstance(source, FakeSource): + res.append(At(qq=source.user_id)) + res.append(Plain(text=res_str)) + + # if image mode, put all Plain texts into a new picture. + if self.pic_mode and isinstance(res, list): + plains = [] + news = [] + for i in res: + if isinstance(i, Plain): + plains.append(i.text) + else: + news.append(i) + plains_str = "".join(plains).strip() + if plains_str != "" and len(plains_str) > 50: + p = gu.create_markdown_image("".join(plains)) + news.append(Image.fromFileSystem(p)) + res = news + + # 回复消息链 + if isinstance(res, list) and len(res) > 0: + if source.type == "GuildMessage": + await self.client.sendGuildChannelMessage(source.guild_id, source.channel_id, res) + return + elif source.type == "FriendMessage": + await self.client.sendFriendMessage(source.user_id, res) + return + elif source.type == "GroupMessage": + # 过长时forward发送 + plain_text_len = 0 + image_num = 0 + for i in res: + if isinstance(i, Plain): + plain_text_len += len(i.text) + elif isinstance(i, Image): + image_num += 1 + if plain_text_len > self.cc.get('qq_forward_threshold', 200): + # 删除At + for i in res: + if isinstance(i, At): + res.remove(i) + node = Node(res) + # node.content = res + node.uin = 123456 + node.name = f"bot" + node.time = int(time.time()) + # print(node) + nodes=[node] + await self.client.sendGroupForwardMessage(source.group_id, nodes) + return + await self.client.sendGroupMessage(source.group_id, res) + return + + async def send_msg(self, message: Union[GroupMessage, FriendMessage, GuildMessage, Notify], result_message: list): + ''' + 提供给插件的发送QQ消息接口。 + 参数说明:第一个参数可以是消息对象,也可以是QQ群号。第二个参数是消息内容(消息内容可以是消息链列表,也可以是纯文字信息)。 + ''' + try: + await self.reply_msg(message, result_message) + except BaseException as e: + raise e + + async def send(self, + to, + res): + ''' + 同 send_msg() + ''' + try: + await self.send_msg(to, res) + except BaseException as e: + raise e + + def create_text_image(title: str, text: str, max_width=30, font_size=20): + ''' + 文本转图片。 + title: 标题 + text: 文本内容 + max_width: 文本宽度最大值(默认30) + font_size: 字体大小(默认20) + + 返回:文件路径 + ''' + try: + img = gu.word2img(title, text, max_width, font_size) + p = gu.save_temp_img(img) + return p + except Exception as e: + raise e + + def wait_for_message(self, group_id) -> Union[GroupMessage, FriendMessage, GuildMessage]: + ''' + 等待下一条消息,超时 300s 后抛出异常 + ''' + self.waiting[group_id] = '' + cnt = 0 + while True: + if group_id in self.waiting and self.waiting[group_id] != '': + # 去掉 + ret = self.waiting[group_id] + del self.waiting[group_id] + return ret + cnt += 1 + if cnt > 300: + raise Exception("等待消息超时。") + time.sleep(1) + + def get_client(self): + return self.client + + def nakuru_method_invoker(self, func, *args, **kwargs): + """ + 返回一个方法调用器,可以用来立即调用nakuru的方法。 + """ + try: + ret = asyncio.run_coroutine_threadsafe(func(*args, **kwargs), self.gocq_loop).result() + return ret + except BaseException as e: + raise e + + def get_cnt(self): + return self.gocq_cnt + + def set_cnt(self, cnt): + self.gocq_cnt = cnt \ No newline at end of file diff --git a/model/platform/qq_official.py b/model/platform/qq_official.py new file mode 100644 index 000000000..e7d0338eb --- /dev/null +++ b/model/platform/qq_official.py @@ -0,0 +1,227 @@ +import io +import botpy +from PIL import Image as PILImage +from botpy.message import Message, DirectMessage +import re +import asyncio +import requests +from util import general_utils as gu + +from botpy.types.message import Reference +from botpy import Client +import time +from ._platfrom import Platform +from ._nakuru_translation_layer import( + NakuruGuildMessage, + gocq_compatible_receive, + gocq_compatible_send, + NakuruGuildMember +) + +# QQ 机器人官方框架 +class botClient(Client): + def set_platform(self, platform: 'QQOfficial'): + self.platform = platform + + # 收到频道消息 + async def on_at_message_create(self, message: Message): + gu.log(str(message), gu.LEVEL_DEBUG, max_len=9999) + # 转换层 + nakuru_guild_message = gocq_compatible_receive(message) + gu.log(f"转换后: {str(nakuru_guild_message)}", gu.LEVEL_DEBUG, max_len=9999) + await self.platform.handle_msg(nakuru_guild_message, is_group=True) + + # 收到私聊消息 + async def on_direct_message_create(self, message: DirectMessage): + # 转换层 + nakuru_guild_message = gocq_compatible_receive(message) + gu.log(f"转换后: {str(nakuru_guild_message)}", gu.LEVEL_DEBUG, max_len=9999) + await self.platform.handle_msg(nakuru_guild_message, is_group=False) + +class QQOfficial(Platform): + + def __init__(self, cfg: dict, message_handler: callable) -> None: + super().__init__(message_handler) + + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + self.qqchan_cnt = 0 + self.waiting: dict = {} + + self.cfg = cfg + self.appid = cfg['qqbot']['appid'] + self.token = cfg['qqbot']['token'] + self.secret = cfg['qqbot_secret'] + self.unique_session = cfg['uniqueSessionMode'] + + self.intents = botpy.Intents( + public_guild_messages=True, + direct_message=cfg['direct_message_mode'] + ) + self.client = botClient( + intents=self.intents, + bot_log=False + ) + self.client.set_platform(self) + + def run(self): + try: + self.loop.run_until_complete(self.client.run( + appid=self.appid, + secret=self.secret + )) + except BaseException as e: + print(e) + self.client = botClient( + intents=self.intents, + bot_log=False + ) + self.client.set_platform(self) + self.client.run( + appid=self.appid, + token=self.token + ) + + async def handle_msg(self, message: NakuruGuildMessage, is_group: bool): + + # 解析出 session_id + if self.unique_session or not is_group: + session_id = message.sender.user_id + else: + session_id = message.channel_id + + # 解析出 role + sender_id = str(message.sender.tiny_id) + if sender_id == self.cfg['admin_qqchan'] or \ + sender_id in self.cfg['other_admins']: + role = 'admin' + else: + role = 'member' + + message_result = await self.message_handler( + message=message, + session_id=session_id, + role=role, + platform='qqchan' + ) + + if message_result is None: + return + + await self.reply_msg(message, message_result.result_message) + if message_result.callback is not None: + message_result.callback() + + # 如果是等待回复的消息 + if session_id in self.waiting and self.waiting[session_id] == '': + self.waiting[session_id] = message + + async def reply_msg(self, + message: NakuruGuildMessage, + res: list): + ''' + 回复频道消息 + ''' + gu.log("回复QQ频道消息: "+str(res), level=gu.LEVEL_INFO, tag="QQ频道", max_len=500) + self.qqchan_cnt += 1 + + plain_text = '' + image_path = '' + msg_ref = None + + if isinstance(res, list): + plain_text, image_path = gocq_compatible_send(res) + elif isinstance(res, str): + plain_text = res + + if image_path is not None and image_path != '': + msg_ref = None + if image_path.startswith("http"): + pic_res = requests.get(image_path, stream = True) + if pic_res.status_code == 200: + image = PILImage.open(io.BytesIO(pic_res.content)) + image_path = gu.save_temp_img(image) + + if message.raw_message is not None and image_path == '': # file_image与message_reference不能同时传入 + msg_ref = Reference(message_id=message.raw_message.id, ignore_get_message_error=False) + + # 到这里,我们得到了 plain_text,image_path,msg_ref + + data = { + 'channel_id': str(message.channel_id), + 'content': plain_text, + 'msg_id': message.message_id, + 'message_reference': msg_ref + } + if image_path != '': + data['file_image'] = image_path + + try: + await self._send_wrapper(**data) + except BaseException as e: + print(e) + # 分割过长的消息 + if "msg over length" in str(e): + split_res = [] + split_res.append(plain_text[:len(plain_text)//2]) + split_res.append(plain_text[len(plain_text)//2:]) + for i in split_res: + data['content'] = i + await self._send_wrapper(**data) + else: + # 发送qq信息 + try: + # 防止被qq频道过滤消息 + plain_text = plain_text.replace(".", " . ") + await self._send_wrapper(**data) + except BaseException as e: + try: + data['content'] = str.join(" ", plain_text) + await self._send_wrapper(**data) + except BaseException as e: + plain_text = re.sub(r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '[被隐藏的链接]', str(e), flags=re.MULTILINE) + plain_text = plain_text.replace(".", "·") + data['content'] = plain_text + await self._send_wrapper(**data) + + async def _send_wrapper(self, **kwargs): + await self.client.api.post_message(**kwargs) + + async def send_msg(self, channel_id: int, message_chain: list, message_id: int = None): + ''' + 推送消息, 如果有 message_id,那么就是回复消息。 + ''' + _n = NakuruGuildMessage() + _n.channel_id = channel_id + _n.message_id = message_id + await self.reply_msg(_n, message_chain) + + async def send(self, message_obj, message_chain: list): + ''' + 发送信息。内容同 reply_msg + ''' + await self.reply_msg(message_obj, message_chain) + + def wait_for_message(self, channel_id: int) -> NakuruGuildMessage: + ''' + 等待指定 channel_id 的下一条信息,超时 300s 后抛出异常 + ''' + self.waiting[channel_id] = '' + cnt = 0 + while True: + if channel_id in self.waiting and self.waiting[channel_id] != '': + # 去掉 + ret = self.waiting[channel_id] + del self.waiting[channel_id] + return ret + cnt += 1 + if cnt > 300: + raise Exception("等待消息超时。") + time.sleep(1) + + def get_cnt(self): + return self.qqchan_cnt + + def set_cnt(self, cnt): + self.qqchan_cnt = cnt diff --git a/model/platform/qqchan.py b/model/platform/qqchan.py deleted file mode 100644 index 1f891b6f4..000000000 --- a/model/platform/qqchan.py +++ /dev/null @@ -1,217 +0,0 @@ -import io -import botpy -from PIL import Image as PILImage -from botpy.message import Message, DirectMessage -import re -import asyncio -import requests -from cores.qqbot.personality import personalities -from util import general_utils as gu -from nakuru.entities.components import Plain, At, Image -from botpy.types.message import Reference -from botpy import Client -import time - -class NakuruGuildMember(): - tiny_id: int # 发送者识别号 - user_id: int # 发送者识别号 - title: str - nickname: str # 昵称 - role: int # 角色 - icon_url: str # 头像url - -class NakuruGuildMessage(): - type: str = "GuildMessage" - self_id: int # bot的qq号 - self_tiny_id: int # bot的qq号 - sub_type: str # 消息类型 - message_id: str # 消息id - guild_id: int # 频道号 - channel_id: int # 子频道号 - user_id: int # 发送者qq号 - message: list # 消息内容 - sender: NakuruGuildMember # 发送者信息 - raw_message: Message - - def __str__(self) -> str: - return str(self.__dict__) - -class QQChan(): - def __init__(self, cnt: dict = None) -> None: - self.qqchan_cnt = 0 - self.waiting: dict = {} - - def get_cnt(self): - return self.qqchan_cnt - - def set_cnt(self, cnt): - self.qqchan_cnt = cnt - - def run_bot(self, botclient: Client, appid, token): - intents = botpy.Intents(public_guild_messages=True, direct_message=True) - self.client = botclient - self.client.run(appid=appid, token=token) - - # gocq-频道SDK兼容层(发) - def gocq_compatible_send(self, gocq_message_chain: list): - plain_text = "" - image_path = None # only one img supported - for i in gocq_message_chain: - if isinstance(i, Plain): - plain_text += i.text - elif isinstance(i, Image) and image_path == None: - if i.path is not None: - image_path = i.path - else: - image_path = i.file - return plain_text, image_path - - # gocq-频道SDK兼容层(收) - def gocq_compatible_receive(self, message: Message) -> NakuruGuildMessage: - ngm = NakuruGuildMessage() - try: - ngm.self_id = message.mentions[0].id - ngm.self_tiny_id = message.mentions[0].id - except: - ngm.self_id = 0 - ngm.self_tiny_id = 0 - - ngm.sub_type = "normal" - ngm.message_id = message.id - ngm.guild_id = int(message.guild_id) - ngm.channel_id = int(message.channel_id) - ngm.user_id = int(message.author.id) - msg = [] - plain_content = message.content.replace("<@!"+str(ngm.self_id)+">", "").strip() - msg.append(Plain(plain_content)) - if message.attachments: - for i in message.attachments: - if i.content_type.startswith("image"): - url = i.url - if not url.startswith("http"): - url = "https://"+url - img = Image.fromURL(url) - msg.append(img) - ngm.message = msg - ngm.sender = NakuruGuildMember() - ngm.sender.tiny_id = int(message.author.id) - ngm.sender.user_id = int(message.author.id) - ngm.sender.title = "" - ngm.sender.nickname = message.author.username - ngm.sender.role = 0 - ngm.sender.icon_url = message.author.avatar - ngm.raw_message = message - return ngm - - def send_qq_msg(self, - message: NakuruGuildMessage, - res: list): - ''' - 回复频道消息 - ''' - gu.log("回复QQ频道消息: "+str(res), level=gu.LEVEL_INFO, tag="QQ频道", max_len=500) - self.qqchan_cnt += 1 - plain_text = "" - image_path = None - if isinstance(res, list): - # 兼容gocq - plain_text, image_path = self.gocq_compatible_send(res) - elif isinstance(res, str): - plain_text = res - - # print(plain_text, image_path) - msg_ref = None - if message.raw_message is not None: - msg_ref = Reference(message_id=message.raw_message.id, ignore_get_message_error=False) - if image_path is not None: - msg_ref = None - if image_path.startswith("http"): - pic_res = requests.get(image_path, stream = True) - if pic_res.status_code == 200: - image = PILImage.open(io.BytesIO(pic_res.content)) - image_path = gu.save_temp_img(image) - - - - try: - # reply_res = asyncio.run_coroutine_threadsafe(message.raw_message.reply(content=str(plain_text), message_reference = msg_ref, file_image=image_path), self.client.loop) - reply_res = asyncio.run_coroutine_threadsafe(self.client.api.post_message(channel_id=str(message.channel_id), - content=str(plain_text), - msg_id=message.message_id, - file_image=image_path, - message_reference=msg_ref), self.client.loop) - reply_res.result() - except BaseException as e: - # 分割过长的消息 - if "msg over length" in str(e): - split_res = [] - split_res.append(plain_text[:len(plain_text)//2]) - split_res.append(plain_text[len(plain_text)//2:]) - for i in split_res: - reply_res = asyncio.run_coroutine_threadsafe(self.client.api.post_message(channel_id=str(message.channel_id), - content=str(i), - msg_id=message.message_id, - file_image=image_path, - message_reference=msg_ref), self.client.loop) - reply_res.result() - else: - # 发送qq信息 - try: - # 防止被qq频道过滤消息 - plain_text = plain_text.replace(".", " . ") - reply_res = asyncio.run_coroutine_threadsafe(self.client.api.post_message(channel_id=str(message.channel_id), - content=str(plain_text), - msg_id=message.message_id, - file_image=image_path, - message_reference=msg_ref), self.client.loop).result() # 发送信息 - except BaseException as e: - print("QQ频道API错误: \n"+str(e)) - try: - # reply_res = asyncio.run_coroutine_threadsafe(message.raw_message.reply(content=str(str.join(" ", plain_text)), message_reference = msg_ref, file_image=image_path), self.client.loop) - reply_res = asyncio.run_coroutine_threadsafe(self.client.api.post_message(channel_id=str(message.channel_id), - content=str(str.join(" ", plain_text)), - msg_id=message.message_id, - file_image=image_path, - message_reference=msg_ref), self.client.loop).result() - except BaseException as e: - plain_text = re.sub(r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '[被隐藏的链接]', str(e), flags=re.MULTILINE) - plain_text = plain_text.replace(".", "·") - reply_res = asyncio.run_coroutine_threadsafe(self.client.api.post_message(channel_id=str(message.channel_id), - content=plain_text, - msg_id=message.message_id, - file_image=image_path, - message_reference=msg_ref), self.client.loop).result() - # send(message, f"QQ频道API错误:{str(e)}\n下面是格式化后的回答:\n{f_res}") - - def push_message(self, channel_id: int, message_chain: list, message_id: int = None): - ''' - 推送消息, 如果有 message_id,那么就是回复消息。 - ''' - _n = NakuruGuildMessage() - _n.channel_id = channel_id - _n.message_id = message_id - self.send_qq_msg(_n, message_chain) - - def send(self, message_obj, message_chain: list): - ''' - 发送信息 - ''' - self.send_qq_msg(message_obj, message_chain) - - def wait_for_message(self, channel_id: int) -> NakuruGuildMessage: - ''' - 等待指定 channel_id 的下一条信息,超时 300s 后抛出异常 - ''' - self.waiting[channel_id] = '' - cnt = 0 - while True: - if channel_id in self.waiting and self.waiting[channel_id] != '': - # 去掉 - ret = self.waiting[channel_id] - del self.waiting[channel_id] - return ret - cnt += 1 - if cnt > 300: - raise Exception("等待消息超时。") - time.sleep(1) - \ No newline at end of file diff --git a/model/platform/qqgroup.py b/model/platform/qqgroup.py deleted file mode 100644 index 753689991..000000000 --- a/model/platform/qqgroup.py +++ /dev/null @@ -1,188 +0,0 @@ -import requests -import asyncio -import websockets -from websockets import WebSocketClientProtocol -import json -import inspect -from typing import Callable, Awaitable, Union -from pydantic import BaseModel -import datetime - -class Event(BaseModel): - GroupMessage: str = "GuildMessage" - -class Sender(BaseModel): - user_id: str - member_openid: str - - -class MessageComponent(BaseModel): - type: str - -class PlainText(MessageComponent): - text: str - -class Image(MessageComponent): - path: str - file: str - url: str - -class MessageChain(list): - - def append(self, __object: MessageComponent) -> None: - if not isinstance(__object, MessageComponent): - raise TypeError("不受支持的消息链元素类型。回复的消息链必须是 MessageComponent 的子类。") - return super().append(__object) - - def insert(self, __index: int, __object: MessageComponent) -> None: - if not isinstance(__object, MessageComponent): - raise TypeError("不受支持的消息链元素类型。回复的消息链必须是 MessageComponent 的子类。") - return super().insert(__index, __object) - - def parse_from_nakuru(self, nakuru_message_chain: Union[list, str]) -> None: - if isinstance(nakuru_message_chain, str): - self.append(PlainText(type='Plain', text=nakuru_message_chain)) - else: - for i in nakuru_message_chain: - if i['type'] == 'Plain': - self.append(PlainText(type='Plain', text=i['text'])) - elif i['type'] == 'Image': - self.append(Image(path=i['path'], file=i['file'], url=i['url'])) - -class Message(BaseModel): - type: str - user_id: str - member_openid: str - message_id: str - group_id: str - group_openid: str - content: str - message: MessageChain - time: int - sender: Sender - -class UnofficialQQBotSDK: - - GET_APP_ACCESS_TOKEN_URL = "https://bots.qq.com/app/getAppAccessToken" - OPENAPI_BASE_URL = "https://api.sgroup.qq.com" - - def __init__(self, appid: str, client_secret: str) -> None: - self.appid = appid - self.client_secret = client_secret - self.events: dict[str, Awaitable] = {} - - - def run_bot(self) -> None: - self.__get_access_token() - self.__get_wss_endpoint() - asyncio.get_event_loop().run_until_complete(self.__ws_client()) - - def __get_access_token(self) -> None: - res = requests.post(self.GET_APP_ACCESS_TOKEN_URL, json={ - "appId": self.appid, - "clientSecret": self.client_secret - }, headers={ - "Content-Type": "application/json" - }) - res = res.json() - code = res['code'] if 'code' in res else 1 - if 'access_token' not in res: - raise Exception(f"获取 access_token 失败。原因:{res['message'] if 'message' in res else '未知'}") - self.access_token = 'QQBot ' + res['access_token'] - - def __auth_header(self) -> str: - return { - 'Authorization': self.access_token, - 'X-Union-Appid': self.appid, - } - - def __get_wss_endpoint(self): - res = requests.get(self.OPENAPI_BASE_URL + "/gateway", headers=self.__auth_header()) - self.wss_endpoint = res.json()['url'] - # print("wss_endpoint: " + self.wss_endpoint) - - async def __behav_heartbeat(self, ws: WebSocketClientProtocol, t: int): - while True: - await asyncio.sleep(t - 1) - try: - await ws.send(json.dumps({ - "op": 1, - "d": self.s - })) - except: - print("heartbeat error.") - - async def __handle_msg(self, ws: WebSocketClientProtocol, msg: dict): - if msg['op'] == 10: - asyncio.get_event_loop().create_task(self.__behav_heartbeat(ws, msg['d']['heartbeat_interval'] / 1000)) - # 鉴权,获得session - await ws.send(json.dumps({ - "op": 2, - "d": { - "token": self.access_token, - "intents": 33554432, - "shard": [0, 1], - "properties": { - "$os": "linux", - "$browser": "my_library", - "$device": "my_library" - } - } - })) - if msg['op'] == 0: - # ready - data = msg['d'] - event_typ: str = msg['t'] if 't' in msg else None - if event_typ == 'GROUP_AT_MESSAGE_CREATE': - if 'GroupMessage' in self.events: - coro = self.events['GroupMessage'] - else: - return - message_chain = MessageChain() - message_chain.append(PlainText(type="Plain", text=data['content'])) - group_message = Message( - type='GroupMessage', - user_id=data['author']['id'], - member_openid=data['author']['member_openid'], - message_id=data['id'], - group_id=data['group_id'], - group_openid=data['group_openid'], - content=data['content'], - # 2023-11-24T19:51:11+08:00 - time=int(datetime.datetime.strptime(data['timestamp'], "%Y-%m-%dT%H:%M:%S%z").timestamp()), - sender=Sender( - user_id=data['author']['id'], - member_openid=data['author']['member_openid'] - ), - message=message_chain - ) - await coro(self, group_message) - - async def send(self, message: Message, message_chain: MessageChain) -> None: - # todo: 消息链转换支持更多类型。 - plain_text = "" - for i in message_chain: - if isinstance(i, PlainText): - plain_text += i.text - requests.post(self.OPENAPI_BASE_URL + f"/v2/groups/{message.group_openid}/messages", headers=self.__auth_header(), json={ - "content": plain_text, - "message_type": 0, - "msg_id": message.message_id - }) - - async def __ws_client(self): - self.s = 0 - async with websockets.connect(self.wss_endpoint) as websocket: - while True: - msg = await websocket.recv() - msg = json.loads(msg) - if 's' in msg: - self.s = msg['s'] - await self.__handle_msg(websocket, msg) - - def on(self, event: str) -> None: - def wrapper(func: Awaitable): - if inspect.iscoroutinefunction(func) == False: - raise TypeError("func must be a coroutine function") - self.events[event] = func - return wrapper \ No newline at end of file diff --git a/model/provider/openai_official.py b/model/provider/openai_official.py index 345a3fa45..d24611eb4 100644 --- a/model/provider/openai_official.py +++ b/model/provider/openai_official.py @@ -239,7 +239,7 @@ class ProviderOpenAIOfficial(Provider): err = str(e) retry += 1 if retry >= 10: - gu.log(r"如果报错, 且您的机器在中国大陆内, 请确保您的电脑已经设置好代理软件(梯子), 并在配置文件设置了系统代理地址。详见https://github.com/Soulter/QQChannelChatGPT/wiki/%E4%BA%8C%E3%80%81%E9%A1%B9%E7%9B%AE%E9%85%8D%E7%BD%AE%E6%96%87%E4%BB%B6%E9%85%8D%E7%BD%AE", max_len=999) + gu.log(r"如果报错, 且您的机器在中国大陆内, 请确保您的电脑已经设置好代理软件(梯子), 并在配置文件设置了系统代理地址。详见 https://github.com/Soulter/QQChannelChatGPT/wiki", max_len=999) raise BaseException("连接出错: "+str(err)) assert isinstance(response, ChatCompletion) gu.log(f"OPENAI RESPONSE: {response.usage}", level=gu.LEVEL_DEBUG, max_len=9999) diff --git a/requirements.txt b/requirements.txt index c65e9b4af..0f00413d5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ pydantic~=1.10.4 requests~=2.28.1 openai~=1.2.3 -qq-botpy==1.1.2 +qq-botpy chardet~=5.1.0 Pillow~=9.4.0 GitPython~=3.1.31 diff --git a/util/cmd_config.py b/util/cmd_config.py index 047baed79..610a323b7 100644 --- a/util/cmd_config.py +++ b/util/cmd_config.py @@ -80,4 +80,43 @@ class CmdConfig(): if _tag: with open(cpath, "w", encoding="utf-8-sig") as f: json.dump(d, f, indent=4, ensure_ascii=False) - f.flush() \ No newline at end of file + f.flush() + +def init_astrbot_config_items(): + # 加载默认配置 + cc = CmdConfig() + cc.init_attributes("qq_forward_threshold", 200) + cc.init_attributes("qq_welcome", "欢迎加入本群!\n欢迎给https://github.com/Soulter/QQChannelChatGPT项目一个Star😊~\n输入help查看帮助~\n") + cc.init_attributes("bing_proxy", "") + cc.init_attributes("qq_pic_mode", False) + cc.init_attributes("rev_chatgpt_model", "") + cc.init_attributes("rev_chatgpt_plugin_ids", []) + cc.init_attributes("rev_chatgpt_PUID", "") + cc.init_attributes("rev_chatgpt_unverified_plugin_domains", []) + cc.init_attributes("gocq_host", "127.0.0.1") + cc.init_attributes("gocq_http_port", 5700) + cc.init_attributes("gocq_websocket_port", 6700) + cc.init_attributes("gocq_react_group", True) + cc.init_attributes("gocq_react_guild", True) + cc.init_attributes("gocq_react_friend", True) + cc.init_attributes("gocq_react_group_increase", True) + cc.init_attributes("gocq_qqchan_admin", "") + cc.init_attributes("other_admins", []) + cc.init_attributes("CHATGPT_BASE_URL", "") + cc.init_attributes("qqbot_appid", "") + cc.init_attributes("qqbot_secret", "") + cc.init_attributes("admin_qq", "") + cc.init_attributes("admin_qqchan", "") + cc.init_attributes("llm_env_prompt", "> hint: 末尾根据内容和心情添加 1-2 个emoji") + cc.init_attributes("default_personality_str", "") + cc.init_attributes("openai_image_generate", { + "model": "dall-e-3", + "size": "1024x1024", + "style": "vivid", + "quality": "standard", + }) + cc.init_attributes("http_proxy", "") + cc.init_attributes("https_proxy", "") + cc.init_attributes("dashboard_username", "") + cc.init_attributes("dashboard_password", "") + # cc.init_attributes(["qq_forward_mode"], False) diff --git a/util/general_utils.py b/util/general_utils.py index de68f2783..8f74e828e 100644 --- a/util/general_utils.py +++ b/util/general_utils.py @@ -7,6 +7,8 @@ import re import requests from util.cmd_config import CmdConfig import socket +from cores.qqbot.global_object import GlobalObject +import platform PLATFORM_GOCQ = 'gocq' PLATFORM_QQCHAN = 'qqchan' @@ -531,4 +533,18 @@ def get_local_ip_addresses(): finally: s.close() - return ip \ No newline at end of file + return ip + +def get_sys_info(global_object: GlobalObject): + mem = None + stats = global_object.dashboard_data.stats + os_name = platform.system() + os_version = platform.version() + + if 'sys_perf' in stats and 'memory' in stats['sys_perf']: + mem = stats['sys_perf']['memory'] + return { + 'mem': mem, + 'os': os_name + '_' + os_version, + 'py': platform.python_version(), + } \ No newline at end of file