From f66091e08fdeb1045cf06a88d876664f86c06be2 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sun, 21 Apr 2024 22:20:23 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8:=20clean=20codes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- addons/baidu_aip_judge.py | 7 +- addons/dashboard/server.py | 126 ++++++++++++----------- addons/plugins/helloworld/helloworld.py | 28 +++-- cores/astrbot/core.py | 128 ++++++++++++++--------- cores/astrbot/types.py | 129 +++++++++++++----------- cores/database/conn.py | 25 +++-- main.py | 27 +++-- model/command/command.py | 72 +++++++------ model/command/openai_official.py | 44 ++++---- model/platform/_message_parse.py | 26 +++-- model/platform/_message_result.py | 1 + model/platform/_platfrom.py | 2 +- model/platform/qq_gocq.py | 59 ++++++----- model/platform/qq_official.py | 77 ++++++++------ model/provider/openai_official.py | 109 +++++++++++--------- model/provider/provider.py | 14 +-- util/cmd_config.py | 18 ++-- util/function_calling/func_call.py | 56 +++++----- util/function_calling/gplugin.py | 40 +++++--- util/general_utils.py | 122 +++++++++++++--------- util/personality.py | 2 +- util/plugin_util.py | 38 ++++--- 22 files changed, 671 insertions(+), 479 deletions(-) diff --git a/addons/baidu_aip_judge.py b/addons/baidu_aip_judge.py index 2b80cb14a..cd2417d0a 100644 --- a/addons/baidu_aip_judge.py +++ b/addons/baidu_aip_judge.py @@ -1,14 +1,17 @@ from aip import AipContentCensor + class BaiduJudge: def __init__(self, baidu_configs) -> None: if 'app_id' in baidu_configs and 'api_key' in baidu_configs and 'secret_key' in baidu_configs: self.app_id = str(baidu_configs['app_id']) self.api_key = baidu_configs['api_key'] self.secret_key = baidu_configs['secret_key'] - self.client = AipContentCensor(self.app_id, self.api_key, self.secret_key) + self.client = AipContentCensor( + self.app_id, self.api_key, self.secret_key) else: raise ValueError("Baidu configs error! 请填写百度内容审核服务相关配置!") + def judge(self, text): res = self.client.textCensorUserDefined(text) if 'conclusionType' not in res: @@ -23,4 +26,4 @@ class BaiduJudge: for i in res['data']: info += f"{i['msg']};\n" info += "\n判断结果:"+res['conclusion'] - return False, info \ No newline at end of file + return False, info diff --git a/addons/dashboard/server.py b/addons/dashboard/server.py index e190bc5c8..f53602c89 100644 --- a/addons/dashboard/server.py +++ b/addons/dashboard/server.py @@ -13,9 +13,11 @@ import websockets import json import threading import asyncio -import os, sys +import os +import sys import time + @dataclass class DashBoardData(): stats: dict @@ -23,33 +25,36 @@ class DashBoardData(): logs: dict plugins: List[RegisteredPlugin] + @dataclass class Response(): status: str message: str data: dict - + + class AstrBotDashBoard(): def __init__(self, global_object: 'gu.GlobalObject'): self.global_object = global_object self.loop = asyncio.get_event_loop() asyncio.set_event_loop(self.loop) self.dashboard_data: DashBoardData = global_object.dashboard_data - self.dashboard_be = Flask(__name__, static_folder="dist", static_url_path="/") + self.dashboard_be = Flask( + __name__, static_folder="dist", static_url_path="/") log = logging.getLogger('werkzeug') log.setLevel(logging.ERROR) self.funcs = {} self.cc = CmdConfig() self.logger = global_object.logger - self.ws_clients = {} # remote_ip: ws + self.ws_clients = {} # remote_ip: ws # 启动 websocket 服务器 self.ws_server = websockets.serve(self.__handle_msg, "0.0.0.0", 6186) - + @self.dashboard_be.get("/") def index(): # 返回页面 return self.dashboard_be.send_static_file("index.html") - + @self.dashboard_be.post("/api/authenticate") def authenticate(): username = self.cc.get("dashboard_username", "") @@ -71,7 +76,7 @@ class AstrBotDashBoard(): message="用户名或密码错误。", data=None ).__dict__ - + @self.dashboard_be.post("/api/change_password") def change_password(): password = self.cc.get("dashboard_password", "") @@ -99,9 +104,11 @@ class AstrBotDashBoard(): # last_24_platform = db_inst.get_last_24h_stat_platform() platforms = db_inst.get_platform_cnt_total() self.dashboard_data.stats["session"] = [] - self.dashboard_data.stats["session_total"] = db_inst.get_session_cnt_total() + self.dashboard_data.stats["session_total"] = db_inst.get_session_cnt_total( + ) self.dashboard_data.stats["message"] = last_24_message - self.dashboard_data.stats["message_total"] = db_inst.get_message_cnt_total() + self.dashboard_data.stats["message_total"] = db_inst.get_message_cnt_total( + ) self.dashboard_data.stats["platform"] = platforms return Response( @@ -109,7 +116,7 @@ class AstrBotDashBoard(): message="", data=self.dashboard_data.stats ).__dict__ - + @self.dashboard_be.get("/api/configs") def get_configs(): # 如果params中有namespace,则返回该namespace下的配置 @@ -121,7 +128,7 @@ class AstrBotDashBoard(): message="", data=conf ).__dict__ - + @self.dashboard_be.get("/api/config_outline") def get_config_outline(): outline = self._generate_outline() @@ -130,7 +137,7 @@ class AstrBotDashBoard(): message="", data=outline ).__dict__ - + @self.dashboard_be.post("/api/configs") def post_configs(): post_configs = request.json @@ -147,7 +154,7 @@ class AstrBotDashBoard(): message=e.__str__(), data=self.dashboard_data.configs ).__dict__ - + @self.dashboard_be.get("/api/extensions") def get_plugins(): _plugin_resp = [] @@ -166,7 +173,7 @@ class AstrBotDashBoard(): message="", data=_plugin_resp ).__dict__ - + @self.dashboard_be.post("/api/extensions/install") def install_plugin(): post_data = request.json @@ -186,14 +193,15 @@ class AstrBotDashBoard(): message=e.__str__(), data=None ).__dict__ - + @self.dashboard_be.post("/api/extensions/uninstall") def uninstall_plugin(): post_data = request.json plugin_name = post_data["name"] try: self.logger.log(f"正在卸载插件 {plugin_name}", tag="可视化面板") - putil.uninstall_plugin(plugin_name, self.dashboard_data.plugins) + putil.uninstall_plugin( + plugin_name, self.dashboard_data.plugins) self.logger.log(f"卸载插件 {plugin_name} 成功", tag="可视化面板") return Response( status="success", @@ -206,7 +214,7 @@ class AstrBotDashBoard(): message=e.__str__(), data=None ).__dict__ - + @self.dashboard_be.post("/api/extensions/update") def update_plugin(): post_data = request.json @@ -226,16 +234,17 @@ class AstrBotDashBoard(): message=e.__str__(), data=None ).__dict__ - + @self.dashboard_be.post("/api/log") def log(): for item in self.ws_clients: try: - asyncio.run_coroutine_threadsafe(self.ws_clients[item].send(request.data.decode()), self.loop) + asyncio.run_coroutine_threadsafe( + self.ws_clients[item].send(request.data.decode()), self.loop) except Exception as e: pass return 'ok' - + @self.dashboard_be.get("/api/check_update") def get_update_info(): try: @@ -244,7 +253,7 @@ class AstrBotDashBoard(): status="success", message=ret, data={ - "has_new_version": ret != "当前已经是最新版本。" # 先这样吧,累了=.= + "has_new_version": ret != "当前已经是最新版本。" # 先这样吧,累了=.= } ).__dict__ except Exception as e: @@ -253,7 +262,7 @@ class AstrBotDashBoard(): message=e.__str__(), data=None ).__dict__ - + @self.dashboard_be.post("/api/update_project") def update_project_api(): version = request.json['version'] @@ -263,7 +272,8 @@ class AstrBotDashBoard(): else: latest = False try: - update_project(request_release_info(latest), latest=latest, version=version) + update_project(request_release_info(latest), + latest=latest, version=version) threading.Thread(target=self.shutdown_bot, args=(3,)).start() return Response( status="success", @@ -276,7 +286,7 @@ class AstrBotDashBoard(): message=e.__str__(), data=None ).__dict__ - + @self.dashboard_be.get("/api/llm/list") def llm_list(): ret = [] @@ -287,7 +297,7 @@ class AstrBotDashBoard(): message="", data=ret ).__dict__ - + @self.dashboard_be.get("/api/llm") def llm(): text = request.args["text"] @@ -296,7 +306,8 @@ class AstrBotDashBoard(): if llm_.llm_name == llm: try: # ret = await llm_.llm_instance.text_chat(text) - ret = asyncio.run_coroutine_threadsafe(llm_.llm_instance.text_chat(text), self.loop).result() + ret = asyncio.run_coroutine_threadsafe( + llm_.llm_instance.text_chat(text), self.loop).result() return Response( status="success", message="", @@ -314,21 +325,21 @@ class AstrBotDashBoard(): message="LLM not found.", data=None ).__dict__ - + def shutdown_bot(self, delay_s: int): time.sleep(delay_s) py = sys.executable os.execl(py, py, *sys.argv) - + def _get_configs(self, namespace: str): if namespace == "": - ret = [self.dashboard_data.configs['data'][4], - self.dashboard_data.configs['data'][5],] + ret = [self.dashboard_data.configs['data'][4], + self.dashboard_data.configs['data'][5],] elif namespace == "internal_platform_qq_official": ret = [self.dashboard_data.configs['data'][0],] elif namespace == "internal_platform_qq_gocq": ret = [self.dashboard_data.configs['data'][1],] - elif namespace == "internal_platform_general": # 全局平台配置 + elif namespace == "internal_platform_general": # 全局平台配置 ret = [self.dashboard_data.configs['data'][2],] elif namespace == "internal_llm_openai_official": ret = [self.dashboard_data.configs['data'][3],] @@ -352,28 +363,28 @@ class AstrBotDashBoard(): ''' outline = [ { - "type": "platform", - "name": "配置通用消息平台", - "body": [ - { - "title": "通用", - "desc": "通用平台配置", - "namespace": "internal_platform_general", - "tag": "" - }, - { - "title": "QQ_OFFICIAL", - "desc": "QQ官方API,仅支持频道", - "namespace": "internal_platform_qq_official", - "tag": "" - }, - { - "title": "OneBot协议", - "desc": "支持cq-http、shamrock等(目前仅支持QQ平台)", - "namespace": "internal_platform_qq_gocq", - "tag": "" - } - ] + "type": "platform", + "name": "配置通用消息平台", + "body": [ + { + "title": "通用", + "desc": "通用平台配置", + "namespace": "internal_platform_general", + "tag": "" + }, + { + "title": "QQ_OFFICIAL", + "desc": "QQ官方API,仅支持频道", + "namespace": "internal_platform_qq_official", + "tag": "" + }, + { + "title": "OneBot协议", + "desc": "支持cq-http、shamrock等(目前仅支持QQ平台)", + "namespace": "internal_platform_qq_gocq", + "tag": "" + } + ] }, { "type": "llm", @@ -422,7 +433,7 @@ class AstrBotDashBoard(): # self.logger.log(f"和 {path} 的 websocket 连接发生了错误: {e.__str__()}", tag="可视化面板") del self.ws_clients[address] break - + def run_ws_server(self, loop): asyncio.set_event_loop(loop) loop.run_until_complete(self.ws_server) @@ -433,7 +444,8 @@ class AstrBotDashBoard(): self.logger.log("已启动 websocket 服务器", tag="可视化面板") ip_address = gu.get_local_ip_addresses() ip_str = f"http://{ip_address}:6185\n\thttp://localhost:6185" - self.logger.log(f"\n==================\n您可访问:\n\n\t{ip_str}\n\n来登录可视化面板,默认账号密码为空。\n注意: 所有配置项现已全量迁移至 cmd_config.json 文件下,可登录可视化面板在线修改配置。\n==================\n", tag="可视化面板") - http_server = make_server('0.0.0.0', 6185, self.dashboard_be, threaded=True) + self.logger.log( + f"\n==================\n您可访问:\n\n\t{ip_str}\n\n来登录可视化面板,默认账号密码为空。\n注意: 所有配置项现已全量迁移至 cmd_config.json 文件下,可登录可视化面板在线修改配置。\n==================\n", tag="可视化面板") + http_server = make_server( + '0.0.0.0', 6185, self.dashboard_be, threaded=True) http_server.serve_forever() - diff --git a/addons/plugins/helloworld/helloworld.py b/addons/plugins/helloworld/helloworld.py index c9e036835..53b852e56 100644 --- a/addons/plugins/helloworld/helloworld.py +++ b/addons/plugins/helloworld/helloworld.py @@ -17,10 +17,13 @@ except ImportError: 注意改插件名噢!格式:XXXPlugin 或 Main 小提示:把此模板仓库 fork 之后 clone 到机器人文件夹下的 addons/plugins/ 目录下,然后用 Pycharm/VSC 等工具打开可获更棒的编程体验(自动补全等) ''' + + class HelloWorldPlugin: """ 初始化函数, 可以选择直接pass """ + def __init__(self) -> None: # 复制旧配置文件到 data 目录下。 if os.path.exists("keyword.json"): @@ -37,6 +40,7 @@ class HelloWorldPlugin: Tuple: Non e或者长度为 3 的元组。如果不响应, 返回 None; 如果响应, 第 1 个参数为指令是否调用成功, 第 2 个参数为返回的消息链列表, 第 3 个参数为指令名称 例子:一个名为"yuanshen"的插件;当接收到消息为“原神 可莉”, 如果不想要处理此消息,则返回False, None;如果想要处理,但是执行失败了,返回True, tuple([False, "请求失败。", "yuanshen"]) ;执行成功了,返回True, tuple([True, "结果文本", "yuanshen"]) """ + def run(self, ame: AstrMessageEvent): if ame.message_str == "helloworld": return CommandResult( @@ -47,9 +51,10 @@ class HelloWorldPlugin: ) if ame.message_str.startswith("/keyword") or ame.message_str.startswith("keyword"): return self.handle_keyword_command(ame) - + ret = self.check_keyword(ame.message_str) - if ret: return ret + if ret: + return ret return CommandResult( hit=False, @@ -57,10 +62,10 @@ class HelloWorldPlugin: message_chain=None, command_name=None ) - + def handle_keyword_command(self, ame: AstrMessageEvent): l = ame.message_str.split(" ") - + # 获取图片 image_url = "" for comp in ame.message_obj.message: @@ -69,7 +74,7 @@ class HelloWorldPlugin: image_url = comp.file else: image_url = comp.url - + command_result = CommandResult( hit=True, success=False, @@ -108,11 +113,11 @@ keyword d hi command_result.success = True command_result.message_chain = [Plain("设置成功")] return command_result - + def save_keyword(self): - json.dump(self.keywords, open("data/keyword.json", "w"), ensure_ascii=False) - - + json.dump(self.keywords, open( + "data/keyword.json", "w"), ensure_ascii=False) + def check_keyword(self, message_str: str): for k in self.keywords: if message_str == k: @@ -151,7 +156,8 @@ keyword d hi "repo": str, # 插件仓库地址 [ 可选 ] "homepage": str, # 插件主页 [ 可选 ] } - """ + """ + def info(self): return { "name": "helloworld", @@ -159,4 +165,4 @@ keyword d hi "help": "输入 /keyword 查看关键词回复帮助。", "version": "v1.3", "author": "Soulter" - } \ No newline at end of file + } diff --git a/cores/astrbot/core.py b/cores/astrbot/core.py index 5a4e66193..df9dca2f5 100644 --- a/cores/astrbot/core.py +++ b/cores/astrbot/core.py @@ -13,12 +13,6 @@ import util.function_calling.gplugin as gplugin import util.plugin_util as putil from PIL import Image as PILImage -from typing import Union -from nakuru import ( - GroupMessage, - FriendMessage, - GuildMessage, -) from nakuru.entities.components import Plain, At, Image from addons.baidu_aip_judge import BaiduJudge @@ -67,22 +61,27 @@ _global_object: GlobalObject = None logger: Logger = Logger() # 语言模型选择 + + def privider_chooser(cfg): l = [] if 'openai' in cfg and len(cfg['openai']['key']) > 0 and cfg['openai']['key'][0] is not None: l.append('openai_official') return l + ''' 初始化机器人 ''' + + def init(cfg): global llm_instance, llm_command_instance global baidu_judge, chosen_provider global frequency_count, frequency_time global _global_object global logger - + # 迁移旧配置 gu.try_migrate_config(cfg) # 使用新配置 @@ -115,25 +114,29 @@ def init(cfg): if cfg['openai']['key'] is not None and cfg['openai']['key'] != [None]: from model.provider.openai_official import ProviderOpenAIOfficial from model.command.openai_official import CommandOpenAIOfficial - llm_instance[OPENAI_OFFICIAL] = ProviderOpenAIOfficial(cfg['openai']) - llm_command_instance[OPENAI_OFFICIAL] = CommandOpenAIOfficial(llm_instance[OPENAI_OFFICIAL], _global_object) - _global_object.llms.append(RegisteredLLM(llm_name=OPENAI_OFFICIAL, llm_instance=llm_instance[OPENAI_OFFICIAL], origin="internal")) + llm_instance[OPENAI_OFFICIAL] = ProviderOpenAIOfficial( + cfg['openai']) + llm_command_instance[OPENAI_OFFICIAL] = CommandOpenAIOfficial( + llm_instance[OPENAI_OFFICIAL], _global_object) + _global_object.llms.append(RegisteredLLM( + llm_name=OPENAI_OFFICIAL, llm_instance=llm_instance[OPENAI_OFFICIAL], origin="internal")) chosen_provider = OPENAI_OFFICIAL # 检查provider设置偏好 p = cc.get("chosen_provider", None) if p is not None and p in llm_instance: chosen_provider = p - + # 百度内容审核 if 'baidu_aip' in cfg and 'enable' in cfg['baidu_aip'] and cfg['baidu_aip']['enable']: - try: + try: baidu_judge = BaiduJudge(cfg['baidu_aip']) logger.log("百度内容审核初始化成功", gu.LEVEL_INFO) except BaseException as e: logger.log("百度内容审核初始化失败", gu.LEVEL_ERROR) - - threading.Thread(target=upload, args=(_global_object, ), daemon=True).start() + + threading.Thread(target=upload, args=( + _global_object, ), daemon=True).start() # 得到发言频率配置 if 'limit' in cfg: @@ -141,7 +144,7 @@ def init(cfg): frequency_count = cfg['limit']['count'] if 'time' in cfg['limit']: frequency_time = cfg['limit']['time'] - + try: if 'uniqueSessionMode' in cfg and cfg['uniqueSessionMode']: _global_object.unique_session = True @@ -152,7 +155,7 @@ def init(cfg): nick_qq = cc.get("nick_qq", None) if nick_qq == None: - nick_qq = ("ai","!","!") + nick_qq = ("ai", "!", "!") if isinstance(nick_qq, str): nick_qq = (nick_qq,) if isinstance(nick_qq, list): @@ -168,10 +171,11 @@ def init(cfg): _command = Command(None, _global_object) ok, err = putil.plugin_reload(_global_object.cached_plugins) if ok: - logger.log(f"成功载入 {len(_global_object.cached_plugins)} 个插件", gu.LEVEL_INFO) + logger.log( + f"成功载入 {len(_global_object.cached_plugins)} 个插件", gu.LEVEL_INFO) else: logger.log(err, gu.LEVEL_ERROR) - + if chosen_provider is None: llm_command_instance[NONE_LLM] = _command chosen_provider = NONE_LLM @@ -182,19 +186,21 @@ def init(cfg): # GOCQ if 'gocqbot' in cfg and cfg['gocqbot']['enable']: logger.log("启用 QQ_GOCQ 机器人消息平台", gu.LEVEL_INFO) - threading.Thread(target=run_gocq_bot, args=(cfg, _global_object), daemon=True).start() + threading.Thread(target=run_gocq_bot, args=( + cfg, _global_object), daemon=True).start() platform_str += "QQ_GOCQ," # QQ频道 if 'qqbot' in cfg and cfg['qqbot']['enable'] and cfg['qqbot']['appid'] != None: logger.log("启用 QQ_OFFICIAL 机器人消息平台", gu.LEVEL_INFO) - threading.Thread(target=run_qqchan_bot, args=(cfg, _global_object), daemon=True).start() + threading.Thread(target=run_qqchan_bot, args=( + cfg, _global_object), daemon=True).start() platform_str += "QQ_OFFICIAL," default_personality_str = cc.get("default_personality_str", "") if default_personality_str == "": _global_object.default_personality = None - else: + else: _global_object.default_personality = { "name": "default", "prompt": default_personality_str, @@ -207,63 +213,83 @@ def init(cfg): plugins=_global_object.cached_plugins, ) dashboard_helper = DashBoardHelper(_global_object, config=cc.get_all()) - dashboard_thread = threading.Thread(target=dashboard_helper.run, daemon=True) + dashboard_thread = threading.Thread( + target=dashboard_helper.run, daemon=True) dashboard_thread.start() # 运行 monitor - threading.Thread(target=run_monitor, args=(_global_object,), daemon=True).start() + threading.Thread(target=run_monitor, args=( + _global_object,), daemon=True).start() - logger.log("如果有任何问题, 请在 https://github.com/Soulter/AstrBot 上提交 issue 或加群 322154837。", gu.LEVEL_INFO) + logger.log( + "如果有任何问题, 请在 https://github.com/Soulter/AstrBot 上提交 issue 或加群 322154837。", gu.LEVEL_INFO) logger.log("请给 https://github.com/Soulter/AstrBot 点个 star。", gu.LEVEL_INFO) if platform_str == '': platform_str = "(未启动任何平台,请前往面板添加)" logger.log(f"🎉 项目启动完成") - + dashboard_thread.join() + ''' 运行 QQ_OFFICIAL 机器人 ''' + + def run_qqchan_bot(cfg: dict, global_object: GlobalObject): try: from model.platform.qq_official import QQOfficial - qqchannel_bot = QQOfficial(cfg=cfg, message_handler=oper_msg, global_object=global_object) - global_object.platforms.append(RegisteredPlatform(platform_name="qqchan", platform_instance=qqchannel_bot, origin="internal")) + qqchannel_bot = QQOfficial( + cfg=cfg, message_handler=oper_msg, global_object=global_object) + global_object.platforms.append(RegisteredPlatform( + platform_name="qqchan", platform_instance=qqchannel_bot, origin="internal")) qqchannel_bot.run() except BaseException as e: - logger.log("启动QQ频道机器人时出现错误, 原因如下: " + str(e), gu.LEVEL_CRITICAL, tag="QQ频道") - logger.log(r"如果您是初次启动,请前往可视化面板填写配置。详情请看:https://astrbot.soulter.top/center/。" + str(e), gu.LEVEL_CRITICAL) + logger.log("启动QQ频道机器人时出现错误, 原因如下: " + str(e), + gu.LEVEL_CRITICAL, tag="QQ频道") + logger.log(r"如果您是初次启动,请前往可视化面板填写配置。详情请看:https://astrbot.soulter.top/center/。" + + str(e), gu.LEVEL_CRITICAL) + ''' 运行 QQ_GOCQ 机器人 ''' + + def run_gocq_bot(cfg: dict, _global_object: GlobalObject): from model.platform.qq_gocq import QQGOCQ - + noticed = False host = cc.get("gocq_host", "127.0.0.1") port = cc.get("gocq_websocket_port", 6700) http_port = cc.get("gocq_http_port", 5700) - logger.log(f"正在检查连接...host: {host}, ws port: {port}, http port: {http_port}", tag="QQ") + logger.log( + f"正在检查连接...host: {host}, ws port: {port}, http port: {http_port}", tag="QQ") while True: if not gu.port_checker(port=port, host=host) or not gu.port_checker(port=http_port, host=host): if not noticed: noticed = True - logger.log(f"连接到{host}:{port}(或{http_port})失败。程序会每隔 5s 自动重试。", gu.LEVEL_CRITICAL, tag="QQ") + logger.log( + f"连接到{host}:{port}(或{http_port})失败。程序会每隔 5s 自动重试。", gu.LEVEL_CRITICAL, tag="QQ") time.sleep(5) else: logger.log("检查完毕,未发现问题。", tag="QQ") break try: - qq_gocq = QQGOCQ(cfg=cfg, message_handler=oper_msg, global_object=_global_object) - _global_object.platforms.append(RegisteredPlatform(platform_name="gocq", platform_instance=qq_gocq, origin="internal")) + qq_gocq = QQGOCQ(cfg=cfg, message_handler=oper_msg, + global_object=_global_object) + _global_object.platforms.append(RegisteredPlatform( + platform_name="gocq", platform_instance=qq_gocq, origin="internal")) qq_gocq.run() except BaseException as e: input("启动QQ机器人出现错误"+str(e)) + ''' 检查发言频率 ''' + + def check_frequency(id) -> bool: ts = int(time.time()) if id in user_frequency: @@ -275,13 +301,14 @@ def check_frequency(id) -> bool: if user_frequency[id]['count'] >= frequency_count: return False else: - user_frequency[id]['count']+=1 + user_frequency[id]['count'] += 1 return True else: - t = {'time':ts,'count':1} + t = {'time': ts, 'count': 1} user_frequency[id] = t return True + async def record_message(platform: str, session_id: str): # TODO: 这里会非常吃资源。然而 sqlite3 不支持多线程,所以暂时这样写。 curr_ts = int(time.time()) @@ -291,11 +318,12 @@ async def record_message(platform: str, session_id: str): db_inst.increment_stat_platform(curr_ts, platform, 1) _global_object.cnt_total += 1 + async def oper_msg(message: AstrBotMessage, - session_id: str, - role: str = 'member', - platform: str = None, -) -> MessageResult: + session_id: str, + role: str = 'member', + platform: str = None, + ) -> MessageResult: """ 处理消息。 message: 消息对象 @@ -307,9 +335,9 @@ async def oper_msg(message: AstrBotMessage, message_str = '' session_id = session_id role = role - hit = False # 是否命中指令 - command_result = () # 调用指令返回的结果 - + hit = False # 是否命中指令 + command_result = () # 调用指令返回的结果 + # 获取平台实例 reg_platform: RegisteredPlatform = None for p in _global_object.platforms: @@ -319,7 +347,7 @@ async def oper_msg(message: AstrBotMessage, if not reg_platform: _global_object.logger.log(f"未找到平台 {platform} 的实例。", gu.LEVEL_ERROR) raise Exception(f"未找到平台 {platform} 的实例。") - + # 统计数据,如频道消息量 await record_message(platform, session_id) @@ -328,11 +356,11 @@ async def oper_msg(message: AstrBotMessage, message_str += i.text.strip() if message_str == "": return MessageResult("Hi~") - + # 检查发言频率 if not check_frequency(message.sender.user_id): return MessageResult(f'你的发言超过频率限制(╯▔皿▔)╯。\n管理员设置{frequency_time}秒内只能提问{frequency_count}次。') - + # 检查是否是更换语言模型的请求 temp_switch = "" if message_str.startswith('/gpt'): @@ -400,7 +428,7 @@ async def oper_msg(message: AstrBotMessage, official_fc = chosen_provider == OPENAI_OFFICIAL llm_result_str = await gplugin.web_search(message_str, llm_instance[chosen_provider], session_id, official_fc) else: - llm_result_str = await llm_instance[chosen_provider].text_chat(message_str, session_id, image_url, default_personality = _global_object.default_personality) + llm_result_str = await llm_instance[chosen_provider].text_chat(message_str, session_id, image_url, default_personality=_global_object.default_personality) llm_result_str = _global_object.reply_prefix + llm_result_str except BaseException as e: @@ -410,7 +438,7 @@ async def oper_msg(message: AstrBotMessage, # 切换回原来的语言模型 if temp_switch != "": chosen_provider = temp_switch - + if hit: # 有指令或者插件触发 # command_result 是一个元组:(指令调用是否成功, 指令返回的文本结果, 指令类型) @@ -426,7 +454,7 @@ async def oper_msg(message: AstrBotMessage, if not command_result[0]: return MessageResult(f"指令调用错误: \n{str(command_result[1])}") - + # 画图指令 if isinstance(command_result[1], list) and len(command_result) == 3 and command == 'draw': for i in command_result[1]: @@ -457,4 +485,4 @@ async def oper_msg(message: AstrBotMessage, try: return MessageResult(llm_result_str) except BaseException as e: - logger.log("回复消息错误: \n"+str(e), gu.LEVEL_ERROR) \ No newline at end of file + logger.log("回复消息错误: \n"+str(e), gu.LEVEL_ERROR) diff --git a/cores/astrbot/types.py b/cores/astrbot/types.py index 8e4e080f4..3f56947db 100644 --- a/cores/astrbot/types.py +++ b/cores/astrbot/types.py @@ -11,38 +11,43 @@ from types import ModuleType from enum import Enum from dataclasses import dataclass -class MessageType(Enum): - GROUP_MESSAGE = 'GroupMessage' # 群组形式的消息 - FRIEND_MESSAGE = 'FriendMessage' # 私聊、好友等单聊消息 - GUILD_MESSAGE = 'GuildMessage' # 频道消息 -@dataclass +class MessageType(Enum): + GROUP_MESSAGE = 'GroupMessage' # 群组形式的消息 + FRIEND_MESSAGE = 'FriendMessage' # 私聊、好友等单聊消息 + GUILD_MESSAGE = 'GuildMessage' # 频道消息 + + +@dataclass class MessageMember(): - user_id: str # 发送者id + user_id: str # 发送者id nickname: str = None + class AstrBotMessage(): ''' AstrBot 的消息对象 ''' - tag: str # 消息来源标签 - type: MessageType # 消息类型 - self_id: str # 机器人的识别id - session_id: str # 会话id - message_id: str # 消息id - sender: MessageMember # 发送者 - message: List[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式 - message_str: str # 最直观的纯文本消息字符串 + tag: str # 消息来源标签 + type: MessageType # 消息类型 + self_id: str # 机器人的识别id + session_id: str # 会话id + message_id: str # 消息id + sender: MessageMember # 发送者 + message: List[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式 + message_str: str # 最直观的纯文本消息字符串 raw_message: object - timestamp: int # 消息时间戳 - + timestamp: int # 消息时间戳 + def __str__(self) -> str: return str(self.__dict__) + class PluginType(Enum): - PLATFORM = 'platfrom' # 平台类插件。 - LLM = 'llm' # 大语言模型类插件 - COMMON = 'common' # 其他插件 + PLATFORM = 'platfrom' # 平台类插件。 + LLM = 'llm' # 大语言模型类插件 + COMMON = 'common' # 其他插件 + @dataclass class PluginMetadata: @@ -52,16 +57,17 @@ class PluginMetadata: # required plugin_name: str plugin_type: PluginType - author: str # 插件作者 - desc: str # 插件简介 - version: str # 插件版本 - + author: str # 插件作者 + desc: str # 插件简介 + version: str # 插件版本 + # optional - repo: str = None # 插件仓库地址 - + repo: str = None # 插件仓库地址 + def __str__(self) -> str: return f"PluginMetadata({self.plugin_name}, {self.plugin_type}, {self.desc}, {self.version}, {self.repo})" - + + @dataclass class RegisteredPlugin: ''' @@ -75,9 +81,11 @@ class RegisteredPlugin: def __str__(self) -> str: return f"RegisteredPlugin({self.metadata}, {self.module_path}, {self.root_dir_name})" - + + RegisteredPlugins = List[RegisteredPlugin] + @dataclass class RegisteredPlatform: ''' @@ -85,8 +93,9 @@ class RegisteredPlatform: ''' platform_name: str platform_instance: Platform - origin: str = None # 注册来源 - + origin: str = None # 注册来源 + + @dataclass class RegisteredLLM: ''' @@ -94,32 +103,33 @@ class RegisteredLLM: ''' llm_name: str llm_instance: LLMProvider - origin: str = None # 注册来源 - + origin: str = None # 注册来源 + + class GlobalObject: ''' 存放一些公用的数据,用于在不同模块(如core与command)之间传递 ''' - version: str # 机器人版本 - nick: str # 用户定义的机器人的别名 - base_config: dict # config.json 中导出的配置 - cached_plugins: List[RegisteredPlugin] # 加载的插件 + version: str # 机器人版本 + nick: str # 用户定义的机器人的别名 + base_config: dict # config.json 中导出的配置 + cached_plugins: List[RegisteredPlugin] # 加载的插件 platforms: List[RegisteredPlatform] llms: List[RegisteredLLM] - - web_search: bool # 是否开启了网页搜索 - reply_prefix: str # 回复前缀 - unique_session: bool # 是否开启了独立会话 - cnt_total: int # 总消息数 + + web_search: bool # 是否开启了网页搜索 + reply_prefix: str # 回复前缀 + unique_session: bool # 是否开启了独立会话 + cnt_total: int # 总消息数 default_personality: dict dashboard_data = None logger: None - + def __init__(self): - self.nick = None # gocq 的昵称 - self.base_config = None # config.yaml - self.cached_plugins = [] # 缓存的插件 - self.web_search = False # 是否开启了网页搜索 + self.nick = None # gocq 的昵称 + self.base_config = None # config.yaml + self.cached_plugins = [] # 缓存的插件 + self.web_search = False # 是否开启了网页搜索 self.reply_prefix = None self.unique_session = False self.cnt_total = 0 @@ -129,21 +139,22 @@ class GlobalObject: self.dashboard_data = None self.stat = {} + class AstrMessageEvent(): ''' 消息事件。 ''' - context: GlobalObject # 一些公用数据 - message_str: str # 纯消息字符串 - message_obj: AstrBotMessage # 消息对象 - platform: RegisteredPlatform # 来源平台 - role: str # 基本身份。`admin` 或 `member` - session_id: int # 会话 id + context: GlobalObject # 一些公用数据 + message_str: str # 纯消息字符串 + message_obj: AstrBotMessage # 消息对象 + platform: RegisteredPlatform # 来源平台 + role: str # 基本身份。`admin` 或 `member` + session_id: int # 会话 id - def __init__(self, - message_str: str, - message_obj: AstrBotMessage, - platform: RegisteredPlatform, + def __init__(self, + message_str: str, + message_obj: AstrBotMessage, + platform: RegisteredPlatform, role: str, context: GlobalObject, session_id: str = None): @@ -153,16 +164,18 @@ class AstrMessageEvent(): self.platform = platform self.role = role self.session_id = session_id - + + class CommandResult(): ''' 用于在Command中返回多个值 ''' + def __init__(self, hit: bool, success: bool, message_chain: list, command_name: str = "unknown_command") -> None: self.hit = hit self.success = success self.message_chain = message_chain self.command_name = command_name - + def _result_tuple(self): - return (self.success, self.message_chain, self.command_name) \ No newline at end of file + return (self.success, self.message_chain, self.command_name) diff --git a/cores/database/conn.py b/cores/database/conn.py index 9e9f29790..c9cbb80ac 100644 --- a/cores/database/conn.py +++ b/cores/database/conn.py @@ -3,11 +3,12 @@ import yaml import time from typing import Tuple + class dbConn(): def __init__(self): # 读取参数,并支持中文 conn = sqlite3.connect("data.db") - conn.text_factory=str + conn.text_factory = str self.conn = conn c = conn.cursor() c.execute( @@ -44,7 +45,7 @@ class dbConn(): ); ''' ) - + conn.commit() def insert_session(self, qq_id, history): @@ -76,7 +77,7 @@ class dbConn(): ''', (qq_id, ) ) return c.fetchone() - + def get_all_session(self): conn = self.conn c = conn.cursor() @@ -86,7 +87,7 @@ class dbConn(): ''' ) return c.fetchall() - + def check_session(self, qq_id): conn = self.conn c = conn.cursor() @@ -107,7 +108,6 @@ class dbConn(): ) conn.commit() - def increment_stat_session(self, platform, session_id, cnt): # if not exist, insert conn = self.conn @@ -137,7 +137,7 @@ class dbConn(): ''', (platform, session_id) ) return c.fetchone() is not None - + def get_all_stat_session(self): conn = self.conn c = conn.cursor() @@ -147,7 +147,7 @@ class dbConn(): ''' ) return c.fetchall() - + def get_session_cnt_total(self): conn = self.conn c = conn.cursor() @@ -157,7 +157,7 @@ class dbConn(): ''' ) return c.fetchone()[0] - + def increment_stat_message(self, ts, cnt): # 以一个小时为单位。ts的单位是秒。 # 找到最近的一个小时,如果没有,就插入 @@ -197,7 +197,7 @@ class dbConn(): return True, ts else: return False, ts - + def get_last_24h_stat_message(self): # 获取最近24小时的消息统计 conn = self.conn @@ -208,7 +208,7 @@ class dbConn(): ''', (time.time() - 86400, ) ) return c.fetchall() - + def get_message_cnt_total(self) -> int: conn = self.conn c = conn.cursor() @@ -258,7 +258,7 @@ class dbConn(): return True, ts else: return False, ts - + def get_last_24h_stat_platform(self): # 获取最近24小时的消息统计 conn = self.conn @@ -269,7 +269,7 @@ class dbConn(): ''', (time.time() - 86400, ) ) return c.fetchall() - + def get_platform_cnt_total(self) -> int: conn = self.conn c = conn.cursor() @@ -291,4 +291,3 @@ class dbConn(): def close(self): self.conn.close() - \ No newline at end of file diff --git a/main.py b/main.py index b9b19e5f0..9e7bdff1d 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,5 @@ -import os, sys +import os +import sys from pip._internal import main as pipmain import warnings import traceback @@ -7,12 +8,13 @@ import threading warnings.filterwarnings("ignore") abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/' + def main(): # config.yaml 配置文件加载和环境确认 try: import cores.astrbot.core as qqBot import yaml - ymlfile = open(abs_path+"configs/config.yaml", 'r', encoding='utf-8') + ymlfile = open(abs_path+"configs/config.yaml", 'r', encoding='utf-8') cfg = yaml.safe_load(ymlfile) except ImportError as import_error: traceback.print_exc() @@ -23,13 +25,13 @@ def main(): input("配置文件不存在,请检查是否已经下载配置文件。") except BaseException as e: raise e - + # 设置代理 if 'http_proxy' in cfg and cfg['http_proxy'] != '': os.environ['HTTP_PROXY'] = cfg['http_proxy'] if 'https_proxy' in cfg and cfg['https_proxy'] != '': os.environ['HTTPS_PROXY'] = cfg['https_proxy'] - + os.environ['NO_PROXY'] = 'https://api.sgroup.qq.com' # 检查并创建 temp 文件夹 @@ -43,27 +45,30 @@ def main(): # 启动主程序(cores/qqbot/core.py) qqBot.init(cfg) + def check_env(ch_mirror=False): if not (sys.version_info.major == 3 and sys.version_info.minor >= 9): print("请使用Python3.9+运行本项目") input("按任意键退出...") exit() - + if os.path.exists('requirements.txt'): pth = 'requirements.txt' else: - pth = 'QQChannelChatGPT'+ os.sep +'requirements.txt' + pth = 'QQChannelChatGPT' + os.sep + 'requirements.txt' print("正在检查或下载第三方库,请耐心等待...") try: if ch_mirror: print("使用阿里云镜像") - pipmain(['install', '-r', pth, '-i', 'https://mirrors.aliyun.com/pypi/simple/']) + pipmain(['install', '-r', pth, '-i', + 'https://mirrors.aliyun.com/pypi/simple/']) else: pipmain(['install', '-r', pth]) except BaseException as e: print(e) while True: - res = input("安装失败。\n如报错ValueError: check_hostname requires server_hostname,请尝试先关闭代理后重试。\n1.输入y回车重试\n2. 输入c回车使用国内镜像源下载\n3. 输入其他按键回车继续往下执行。") + res = input( + "安装失败。\n如报错ValueError: check_hostname requires server_hostname,请尝试先关闭代理后重试。\n1.输入y回车重试\n2. 输入c回车使用国内镜像源下载\n3. 输入其他按键回车继续往下执行。") if res == "y": try: pipmain(['install', '-r', pth]) @@ -73,7 +78,8 @@ def check_env(ch_mirror=False): continue elif res == "c": try: - pipmain(['install', '-r', pth, '-i', 'https://mirrors.aliyun.com/pypi/simple/']) + pipmain(['install', '-r', pth, '-i', + 'https://mirrors.aliyun.com/pypi/simple/']) break except BaseException as e: print(e) @@ -82,6 +88,7 @@ def check_env(ch_mirror=False): break print("第三方库检查完毕。") + if __name__ == "__main__": args = sys.argv @@ -89,7 +96,7 @@ if __name__ == "__main__": check_env(True) else: check_env() - + t = threading.Thread(target=main, daemon=True) t.start() try: diff --git a/model/command/command.py b/model/command/command.py index d52835a0e..55e8843ec 100644 --- a/model/command/command.py +++ b/model/command/command.py @@ -29,18 +29,20 @@ PLATFORM_QQCHAN = 'qqchan' PLATFORM_GOCQ = 'gocq' # 指令功能的基类,通用的(不区分语言模型)的指令就在这实现 + + class Command: def __init__(self, provider: Provider, global_object: GlobalObject = None): self.provider = provider self.global_object = global_object self.logger: Logger = global_object.logger - async def check_command(self, - message, - session_id: str, - role: str, - platform: RegisteredPlatform, - message_obj): + async def check_command(self, + message, + session_id: str, + role: str, + platform: RegisteredPlatform, + message_obj): self.platform = platform # 插件 cached_plugins = self.global_object.cached_plugins @@ -51,7 +53,7 @@ class Command: platform=platform, role=role, context=self.global_object, - session_id = session_id + session_id=session_id ) # 从已启动的插件中查找是否有匹配的指令 for plugin in cached_plugins: @@ -83,9 +85,11 @@ class Command: if hit: return True, res except BaseException as e: - self.logger.log(f"{plugin.metadata.plugin_name} 插件异常,原因: {str(e)}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。", level=gu.LEVEL_WARNING) + self.logger.log( + f"{plugin.metadata.plugin_name} 插件异常,原因: {str(e)}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。", level=gu.LEVEL_WARNING) except BaseException as e: - self.logger.log(f"{plugin.metadata.plugin_name} 插件异常,原因: {str(e)}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。", level=gu.LEVEL_WARNING) + self.logger.log( + f"{plugin.metadata.plugin_name} 插件异常,原因: {str(e)}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。", level=gu.LEVEL_WARNING) if self.command_start_with(message, "nick"): return True, self.set_nick(message, platform, role) @@ -93,13 +97,13 @@ class Command: return True, self.plugin_oper(message, role, cached_plugins, platform) if self.command_start_with(message, "myid") or self.command_start_with(message, "!myid"): return True, self.get_my_id(message_obj, platform) - if self.command_start_with(message, "web"): # 网页搜索 + if self.command_start_with(message, "web"): # 网页搜索 return True, self.web_search(message) if self.command_start_with(message, "update"): return True, self.update(message, role) if not self.provider and self.command_start_with(message, "help"): return True, await self.help() - + return False, None def web_search(self, message): @@ -126,16 +130,19 @@ class Command: l = message.split(" ") if len(l) <= 1: obj = cc.get_all() - p = gu.create_text_image("【cmd_config.json】", json.dumps(obj, indent=4, ensure_ascii=False)) + p = gu.create_text_image("【cmd_config.json】", json.dumps( + obj, indent=4, ensure_ascii=False)) return True, [Image.fromFileSystem(p)], "newconf" - + ''' 插件指令 ''' + def plugin_oper(self, message: str, role: str, cached_plugins: List[RegisteredPlugin], platform: str): l = message.split(" ") if len(l) < 2: - p = gu.create_text_image("【插件指令面板】", "安装插件: \nplugin i 插件Github地址\n卸载插件: \nplugin d 插件名 \n重载插件: \nplugin reload\n查看插件列表:\nplugin l\n更新插件: plugin u 插件名\n") + p = gu.create_text_image( + "【插件指令面板】", "安装插件: \nplugin i 插件Github地址\n卸载插件: \nplugin d 插件名 \n重载插件: \nplugin reload\n查看插件列表:\nplugin l\n更新插件: plugin u 插件名\n") return True, [Image.fromFileSystem(p)], "plugin" else: if l[1] == "i": @@ -165,7 +172,8 @@ class Command: plugin_list_info = "" for plugin in cached_plugins: plugin_list_info += f"{plugin.metadata.plugin_name}: \n名称: {plugin.metadata.plugin_name}\n简介: {plugin.metadata.plugin_desc}\n版本: {plugin.metadata.version}\n作者: {plugin.metadata.author}\n" - p = gu.create_text_image("【已激活插件列表】", plugin_list_info + "\n使用plugin v 插件名 查看插件帮助\n") + p = gu.create_text_image( + "【已激活插件列表】", plugin_list_info + "\n使用plugin v 插件名 查看插件帮助\n") return True, [Image.fromFileSystem(p)], "plugin" except BaseException as e: return False, f"获取插件列表失败,原因: {str(e)}", "plugin" @@ -177,7 +185,8 @@ class Command: info = i.metadata break if info: - p = gu.create_text_image(f"【插件信息】", f"名称: {info['name']}\n{info['desc']}\n版本: {info['version']}\n作者: {info['author']}\n\n帮助:\n{info['help']}") + p = gu.create_text_image( + f"【插件信息】", f"名称: {info['name']}\n{info['desc']}\n版本: {info['version']}\n作者: {info['author']}\n\n帮助:\n{info['help']}") return True, [Image.fromFileSystem(p)], "plugin" else: return False, "未找到该插件", "plugin" @@ -187,6 +196,7 @@ class Command: ''' nick: 存储机器人的昵称 ''' + def set_nick(self, message: str, platform: str, role: str = "member"): if role != "admin": return True, "你无权使用该指令 :P", "nick" @@ -213,7 +223,7 @@ class Command: "reset": "重置 LLM 对话", "/gpt": "切换到 OpenAI 官方接口" } - + async def help_messager(self, commands: dict, platform: str, cached_plugins: List[RegisteredPlugin] = None): try: async with aiohttp.ClientSession() as session: @@ -240,7 +250,7 @@ class Command: except BaseException as e: self.logger.log(str(e)) return msg - + def command_start_with(self, message: str, *args): ''' 当消息以指定的指令开头时返回True @@ -249,7 +259,7 @@ class Command: if message.startswith(arg) or message.startswith('/'+arg): return True return False - + def update(self, message: str, role: str): if role != "admin": return True, "你没有权限使用该指令", "update" @@ -274,8 +284,10 @@ class Command: else: if l[1].lower().startswith('v'): try: - release_data = util.updator.request_release_info(latest=False) - util.updator.update_project(release_data, latest=False, version=l[1]) + release_data = util.updator.request_release_info( + latest=False) + util.updator.update_project( + release_data, latest=False, version=l[1]) return True, "更新成功,重启生效。可输入「update r」重启", "update" except BaseException as e: return False, "更新失败: "+str(e), "update" @@ -284,28 +296,28 @@ class Command: def reset(self): return False - + def set(self): return False - + def unset(self): return False - + def key(self): return False - + async def help(self): ret = await self.help_messager(self.general_commands(), self.platform, self.global_object.cached_plugins) return True, ret, "help" - + def status(self): return False - + def token(self): return False - + def his(self): return False - + def draw(self): - return False \ No newline at end of file + return False diff --git a/model/command/openai_official.py b/model/command/openai_official.py index 2eada8782..3edf530e7 100644 --- a/model/command/openai_official.py +++ b/model/command/openai_official.py @@ -3,21 +3,22 @@ from model.provider.openai_official import ProviderOpenAIOfficial from util.personality import personalities from cores.astrbot.types import GlobalObject + class CommandOpenAIOfficial(Command): def __init__(self, provider: ProviderOpenAIOfficial, global_object: GlobalObject): self.provider = provider self.global_object = global_object self.personality_str = "" super().__init__(provider, global_object) - - async def check_command(self, - message: str, - session_id: str, - role: str, - platform: str, - message_obj): + + async def check_command(self, + message: str, + session_id: str, + role: str, + platform: str, + message_obj): self.platform = platform - + # 检查基础指令 hit, res = await super().check_command( message, @@ -26,7 +27,7 @@ class CommandOpenAIOfficial(Command): platform, message_obj ) - + # 这里是这个 LLM 的专属指令 if hit: return True, res @@ -54,9 +55,9 @@ class CommandOpenAIOfficial(Command): return True, self.key(message) elif self.command_start_with(message, "switch"): return True, await self.switch(message) - + return False, None - + async def help(self): commands = super().general_commands() commands['画'] = '画画' @@ -67,7 +68,6 @@ class CommandOpenAIOfficial(Command): commands['token'] = '查看本轮会话token' return True, await super().help_messager(commands, self.platform, self.global_object.cached_plugins), "help" - async def reset(self, session_id: str, message: str = "reset"): if self.provider is None: return False, "未启用 OpenAI 官方 API", "reset" @@ -78,13 +78,13 @@ class CommandOpenAIOfficial(Command): if len(l) == 2 and l[1] == "p": self.provider.forget(session_id) if self.personality_str != "": - self.set(self.personality_str, session_id) # 重新设置人格 + self.set(self.personality_str, session_id) # 重新设置人格 return True, "重置成功", "reset" - + def his(self, message: str, session_id: str): if self.provider is None: return False, "未启用 OpenAI 官方 API", "his" - #分页,每页5条 + # 分页,每页5条 msg = '' size_per_page = 3 page = 1 @@ -95,10 +95,12 @@ class CommandOpenAIOfficial(Command): msg = f"历史记录为空" return True, msg, "his" l = self.provider.session_dict[session_id] - max_page = len(l)//size_per_page + 1 if len(l)%size_per_page != 0 else len(l)//size_per_page - p = self.provider.get_prompts_by_cache_list(self.provider.session_dict[session_id], divide=True, paging=True, size=size_per_page, page=page) + max_page = len(l)//size_per_page + \ + 1 if len(l) % size_per_page != 0 else len(l)//size_per_page + p = self.provider.get_prompts_by_cache_list( + self.provider.session_dict[session_id], divide=True, paging=True, size=size_per_page, page=page) return True, f"历史记录如下:\n{p}\n第{page}页 | 共{max_page}页\n*输入/his 2跳转到第2页", "his" - + def token(self, session_id: str): if self.provider is None: return False, "未启用 OpenAI 官方 API", "token" @@ -108,7 +110,7 @@ class CommandOpenAIOfficial(Command): if self.provider is None: return False, "未启用 OpenAI 官方 API", "gpt" return True, f"OpenAI GPT配置:\n {self.provider.chatGPT_configs}", "gpt" - + def status(self): if self.provider is None: return False, "未启用 OpenAI 官方 API", "status" @@ -255,7 +257,7 @@ class CommandOpenAIOfficial(Command): self.provider.session_dict[session_id].append(new_record) self.personality_str = message return True, f"自定义人格已设置。 \n人格信息: {ps}", "set" - + async def draw(self, message): if self.provider is None: return False, "未启用 OpenAI 官方 API", "draw" @@ -270,4 +272,4 @@ class CommandOpenAIOfficial(Command): except Exception as e: if 'exceeded' in str(e): return f"OpenAI API错误。原因:\n{str(e)} \n超额了。可自己搭建一个机器人(Github仓库:QQChannelChatGPT)" - return False, f"图片生成失败: {e}", "draw" \ No newline at end of file + return False, f"图片生成失败: {e}", "draw" diff --git a/model/platform/_message_parse.py b/model/platform/_message_parse.py index a4f05b2c3..a55f0beb0 100644 --- a/model/platform/_message_parse.py +++ b/model/platform/_message_parse.py @@ -10,9 +10,11 @@ from typing import List, Union import time # QQ官方消息类型转换 + + def qq_official_message_parse(message: List[BaseMessageComponent]): plain_text = "" - image_path = None # only one img supported + image_path = None # only one img supported for i in message: if isinstance(i, Plain): plain_text += i.text @@ -24,6 +26,8 @@ def qq_official_message_parse(message: List[BaseMessageComponent]): return plain_text, image_path # QQ官方消息类型 2 AstrBotMessage + + def qq_official_message_parse_rev(message: Union[botpy.message.Message, botpy.message.GroupMessage], message_type: MessageType) -> AstrBotMessage: abm = AstrBotMessage() @@ -33,7 +37,7 @@ def qq_official_message_parse_rev(message: Union[botpy.message.Message, botpy.me abm.message_id = message.id abm.tag = "qqchan" msg: List[BaseMessageComponent] = [] - + if message_type == MessageType.GROUP_MESSAGE: abm.sender = MessageMember( message.author.member_openid, @@ -41,7 +45,7 @@ def qq_official_message_parse_rev(message: Union[botpy.message.Message, botpy.me ) abm.message_str = message.content.strip() abm.self_id = "unknown_selfid" - + msg.append(Plain(abm.message_str)) if message.attachments: for i in message.attachments: @@ -52,15 +56,16 @@ def qq_official_message_parse_rev(message: Union[botpy.message.Message, botpy.me img = Image.fromURL(url) msg.append(img) abm.message = msg - + elif message_type == MessageType.GUILD_MESSAGE or message_type == MessageType.FRIEND_MESSAGE: # 目前对于 FRIEND_MESSAGE 只处理频道私聊 try: abm.self_id = str(message.mentions[0].id) except: abm.self_id = "" - - plain_content = message.content.replace("<@!"+str(abm.self_id)+">", "").strip() + + plain_content = message.content.replace( + "<@!"+str(abm.self_id)+">", "").strip() msg.append(Plain(plain_content)) if message.attachments: for i in message.attachments: @@ -80,19 +85,20 @@ def qq_official_message_parse_rev(message: Union[botpy.message.Message, botpy.me raise ValueError(f"Unknown message type: {message_type}") return abm + def nakuru_message_parse_rev(message: Union[GuildMessage, GroupMessage, FriendMessage]) -> AstrBotMessage: abm = AstrBotMessage() abm.type = MessageType(message.type) abm.timestamp = int(time.time()) abm.raw_message = message abm.message_id = message.message_id - + plain_content = "" for i in message.message: if isinstance(i, Plain): plain_content += i.text abm.message_str = plain_content - + abm.self_id = str(message.self_id) abm.sender = MessageMember( str(message.sender.user_id), @@ -100,5 +106,5 @@ def nakuru_message_parse_rev(message: Union[GuildMessage, GroupMessage, FriendMe ) abm.tag = "gocq" abm.message = message.message - - return abm \ No newline at end of file + + return abm diff --git a/model/platform/_message_result.py b/model/platform/_message_result.py index 4fd601951..c4421e1c9 100644 --- a/model/platform/_message_result.py +++ b/model/platform/_message_result.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from typing import Union, Optional + @dataclass class MessageResult(): result_message: Union[str, list] diff --git a/model/platform/_platfrom.py b/model/platform/_platfrom.py index bf47937f3..06e6f7d22 100644 --- a/model/platform/_platfrom.py +++ b/model/platform/_platfrom.py @@ -43,7 +43,7 @@ class Platform(): 发送消息(主动发送)同 send_msg() ''' pass - + def parse_message_outline(self, message: Union[GuildMessage, GroupMessage, FriendMessage, str, list]) -> str: ''' 将消息解析成大纲消息形式。 diff --git a/model/platform/qq_gocq.py b/model/platform/qq_gocq.py index 16fc4444c..703789b70 100644 --- a/model/platform/qq_gocq.py +++ b/model/platform/qq_gocq.py @@ -23,7 +23,7 @@ class FakeSource: self.type = type self.group_id = group_id - + class QQGOCQ(Platform): def __init__(self, cfg: dict, message_handler: callable, global_object) -> None: super().__init__(message_handler) @@ -39,11 +39,11 @@ class QQGOCQ(Platform): try: self.nick_qq = cfg['nick_qq'] except: - self.nick_qq = ["ai","!","!"] + self.nick_qq = ["ai", "!", "!"] nick_qq = self.nick_qq if isinstance(nick_qq, str): nick_qq = [nick_qq] - + self.unique_session = cfg['uniqueSessionMode'] self.pic_mode = cfg['qq_pic_mode'] @@ -67,7 +67,7 @@ class QQGOCQ(Platform): await self.handle_msg(abm) else: return - + @gocq_app.receiver("FriendMessage") async def _(app: CQHTTP, source: FriendMessage): if self.cc.get("gocq_react_friend", True): @@ -76,12 +76,12 @@ class QQGOCQ(Platform): await self.handle_msg(abm) else: return - + @gocq_app.receiver("GroupMemberIncrease") async def _(app: CQHTTP, source: GroupMemberIncrease): if self.cc.get("gocq_react_group_increase", True): await app.sendGroupMessage(source.group_id, [ - Plain(text = self.announcement) + Plain(text=self.announcement) ]) # @gocq_app.receiver("Notify") @@ -101,16 +101,18 @@ class QQGOCQ(Platform): await self.handle_msg(abm) else: return - + def run(self): self.client.run() - + async def handle_msg(self, message: AstrBotMessage): - self.logger.log(f"{message.sender.nickname}/{message.sender.user_id} -> {self.parse_message_outline(message)}", tag="QQ_GOCQ") - - assert isinstance(message.raw_message, (GroupMessage, FriendMessage, GuildMessage)) + self.logger.log( + f"{message.sender.nickname}/{message.sender.user_id} -> {self.parse_message_outline(message)}", tag="QQ_GOCQ") + + assert isinstance(message.raw_message, + (GroupMessage, FriendMessage, GuildMessage)) is_group = message.type != MessageType.FRIEND_MESSAGE - + # 判断是否响应消息 resp = False if not is_group: @@ -132,9 +134,10 @@ class QQGOCQ(Platform): if nick != '' and i.text.strip().startswith(nick): resp = True break - - if not resp: return - + + if not resp: + return + # 解析 session_id if self.unique_session or not is_group: session_id = message.raw_message.user_id @@ -144,13 +147,13 @@ class QQGOCQ(Platform): session_id = message.raw_message.channel_id else: session_id = message.raw_message.user_id - + message.session_id = session_id # 解析 role sender_id = str(message.raw_message.user_id) if sender_id == self.cc.get('admin_qq', '') or \ - sender_id in self.cc.get('other_admins', []): + sender_id in self.cc.get('other_admins', []): role = 'admin' else: role = 'member' @@ -167,7 +170,7 @@ class QQGOCQ(Platform): await self.reply_msg(message, message_result.result_message) if message_result.callback is not None: message_result.callback() - + # 如果是等待回复的消息 if session_id in self.waiting and self.waiting[session_id] == '': self.waiting[session_id] = message @@ -182,14 +185,15 @@ class QQGOCQ(Platform): source = message.raw_message else: source = message - + res = result_message - - self.logger.log(f"{source.user_id} <- {self.parse_message_outline(res)}", tag="QQ_GOCQ") + + self.logger.log( + f"{source.user_id} <- {self.parse_message_outline(res)}", tag="QQ_GOCQ") if isinstance(source, int): source = FakeSource("GroupMessage", source) - + # str convert to CQ Message Chain if isinstance(res, str): res_str = res @@ -241,7 +245,7 @@ class QQGOCQ(Platform): node.name = f"bot" node.time = int(time.time()) # print(node) - nodes=[node] + nodes = [node] await self.client.sendGroupForwardMessage(source.group_id, nodes) return await self.client.sendGroupMessage(source.group_id, res) @@ -256,10 +260,10 @@ class QQGOCQ(Platform): await self.reply_msg(message, result_message) except BaseException as e: raise e - - async def send(self, - to, - res): + + async def send(self, + to, + res): ''' 同 send_msg() ''' @@ -311,4 +315,3 @@ class QQGOCQ(Platform): return ret except BaseException as e: raise e - \ No newline at end of file diff --git a/model/platform/qq_official.py b/model/platform/qq_official.py index 87fa4337e..b3e6d1bdf 100644 --- a/model/platform/qq_official.py +++ b/model/platform/qq_official.py @@ -11,7 +11,7 @@ from botpy.types.message import Reference from botpy import Client import time from ._platfrom import Platform -from ._message_parse import( +from ._message_parse import ( qq_official_message_parse_rev, qq_official_message_parse ) @@ -20,10 +20,12 @@ from typing import Union, List from nakuru.entities.components import BaseMessageComponent # QQ 机器人官方框架 + + class botClient(Client): def set_platform(self, platform: 'QQOfficial'): self.platform = platform - + async def on_group_at_message_create(self, message: botpy.message.GroupMessage): abm = qq_official_message_parse_rev(message, MessageType.GROUP_MESSAGE) await self.platform.handle_msg(abm) @@ -37,9 +39,11 @@ class botClient(Client): # 收到私聊消息 async def on_direct_message_create(self, message: botpy.message.DirectMessage): # 转换层 - abm = qq_official_message_parse_rev(message, MessageType.FRIEND_MESSAGE) + abm = qq_official_message_parse_rev( + message, MessageType.FRIEND_MESSAGE) await self.platform.handle_msg(abm) + class QQOfficial(Platform): def __init__(self, cfg: dict, message_handler: callable, global_object) -> None: @@ -47,7 +51,7 @@ class QQOfficial(Platform): self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) - + self.waiting: dict = {} self.cfg = cfg @@ -57,7 +61,7 @@ class QQOfficial(Platform): self.unique_session = cfg['uniqueSessionMode'] self.logger: gu.Logger = global_object.logger qq_group = cfg['qqofficial_enable_group_message'] - + if qq_group: self.intents = botpy.Intents( public_messages=True, @@ -79,7 +83,7 @@ class QQOfficial(Platform): def run(self): try: self.loop.run_until_complete(self.client.run( - appid=self.appid, + appid=self.appid, secret=self.secret )) except BaseException as e: @@ -90,17 +94,19 @@ class QQOfficial(Platform): ) self.client.set_platform(self) self.client.run( - appid=self.appid, + appid=self.appid, token=self.token ) async def handle_msg(self, message: AstrBotMessage): - assert isinstance(message.raw_message, (botpy.message.Message, botpy.message.GroupMessage, botpy.message.DirectMessage)) + assert isinstance(message.raw_message, (botpy.message.Message, + botpy.message.GroupMessage, botpy.message.DirectMessage)) is_group = message.type != MessageType.FRIEND_MESSAGE - + _t = "/私聊" if not is_group else "" - self.logger.log(f"{message.sender.nickname}({message.sender.user_id}{_t}) -> {self.parse_message_outline(message)}", tag="QQ_OFFICIAL") - + self.logger.log( + f"{message.sender.nickname}({message.sender.user_id}{_t}) -> {self.parse_message_outline(message)}", tag="QQ_OFFICIAL") + # 解析出 session_id if self.unique_session or not is_group: session_id = message.sender.user_id @@ -116,7 +122,7 @@ class QQOfficial(Platform): # 解析出 role sender_id = message.sender.user_id if sender_id == self.cfg['admin_qqchan'] or \ - sender_id in self.cfg['other_admins']: + sender_id in self.cfg['other_admins']: role = 'admin' else: role = 'member' @@ -139,18 +145,20 @@ class QQOfficial(Platform): if session_id in self.waiting and self.waiting[session_id] == '': self.waiting[session_id] = message - async def reply_msg(self, - message: Union[botpy.message.Message, botpy.message.GroupMessage, botpy.message.DirectMessage, AstrBotMessage], - res: Union[str, list]): + async def reply_msg(self, + message: Union[botpy.message.Message, botpy.message.GroupMessage, botpy.message.DirectMessage, AstrBotMessage], + res: Union[str, list]): ''' 回复频道消息 ''' if isinstance(message, AstrBotMessage): source = message.raw_message else: - source = message - assert isinstance(source, (botpy.message.Message, botpy.message.GroupMessage, botpy.message.DirectMessage)) - self.logger.log(f"{message.sender.nickname}({message.sender.user_id}) <- {self.parse_message_outline(res)}", tag="QQ_OFFICIAL") + source = message + assert isinstance(source, (botpy.message.Message, + botpy.message.GroupMessage, botpy.message.DirectMessage)) + self.logger.log( + f"{message.sender.nickname}({message.sender.user_id}) <- {self.parse_message_outline(res)}", tag="QQ_OFFICIAL") plain_text = '' image_path = '' @@ -160,7 +168,7 @@ class QQOfficial(Platform): plain_text, image_path = qq_official_message_parse(res) elif isinstance(res, str): plain_text = res - + if self.cfg['qq_pic_mode']: # 文本转图片,并且加上原来的图片 if plain_text != '' or image_path != '': @@ -168,7 +176,8 @@ class QQOfficial(Platform): if image_path.startswith("http"): plain_text += "\n\n" + "![](" + image_path + ")" else: - plain_text += "\n\n" + "![](file:///" + image_path + ")" + plain_text += "\n\n" + \ + "![](file:///" + image_path + ")" image_path = gu.create_markdown_image("".join(plain_text)) plain_text = "" @@ -182,9 +191,10 @@ class QQOfficial(Platform): image = PILImage.open(io.BytesIO(await response.read())) image_path = gu.save_temp_img(image) - if source is not None and image_path == '': # file_image与message_reference不能同时传入 - msg_ref = Reference(message_id=source.id, ignore_get_message_error=False) - + if source is not None and image_path == '': # file_image与message_reference不能同时传入 + msg_ref = Reference(message_id=source.id, + ignore_get_message_error=False) + # 到这里,我们得到了 plain_text,image_path,msg_ref data = { 'content': plain_text, @@ -210,7 +220,7 @@ class QQOfficial(Platform): # 分割过长的消息 if "msg over length" in str(e): split_res = [] - split_res.append(plain_text[:len(plain_text)//2]) + split_res.append(plain_text[:len(plain_text)//2]) split_res.append(plain_text[len(plain_text)//2:]) for i in split_res: data['content'] = i @@ -227,11 +237,12 @@ class QQOfficial(Platform): data['content'] = str.join(" ", plain_text) await self._send_wrapper(**data) except BaseException as e: - plain_text = re.sub(r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '[被隐藏的链接]', str(e), flags=re.MULTILINE) + plain_text = re.sub( + r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '[被隐藏的链接]', str(e), flags=re.MULTILINE) plain_text = plain_text.replace(".", "·") data['content'] = plain_text await self._send_wrapper(**data) - + async def _send_wrapper(self, **kwargs): if 'group_openid' in kwargs: # QQ群组消息 @@ -248,27 +259,29 @@ class QQOfficial(Platform): elif 'channel_id' in kwargs: # 频道消息 if 'file_image' in kwargs: - kwargs['file_image'] = kwargs['file_image'].replace("file://", "") + kwargs['file_image'] = kwargs['file_image'].replace( + "file://", "") await self.client.api.post_message(**kwargs) else: # 频道私聊消息 if 'file_image' in kwargs: - kwargs['file_image'] = kwargs['file_image'].replace("file://", "") + kwargs['file_image'] = kwargs['file_image'].replace( + "file://", "") await self.client.api.post_dms(**kwargs) async def send_msg(self, message_obj: Union[botpy.message.Message, botpy.message.GroupMessage, botpy.message.DirectMessage, AstrBotMessage], message_chain: List[BaseMessageComponent], - ): + ): ''' 发送消息。目前只支持被动回复消息(即拥有一个 botpy Message 类型的 message_obj 传入) ''' await self.reply_msg(message_obj, message_chain) async def send(self, - message_obj: Union[botpy.message.Message, botpy.message.GroupMessage, botpy.message.DirectMessage, AstrBotMessage], - message_chain: List[BaseMessageComponent], - ): + message_obj: Union[botpy.message.Message, botpy.message.GroupMessage, botpy.message.DirectMessage, AstrBotMessage], + message_chain: List[BaseMessageComponent], + ): ''' 发送消息。目前只支持被动回复消息(即拥有一个 botpy Message 类型的 message_obj 传入) ''' diff --git a/model/provider/openai_official.py b/model/provider/openai_official.py index 6ce008785..20b6a5c6f 100644 --- a/model/provider/openai_official.py +++ b/model/provider/openai_official.py @@ -17,9 +17,9 @@ from util.cmd_config import CmdConfig from util.general_utils import Logger - abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/' + class ProviderOpenAIOfficial(Provider): def __init__(self, cfg): self.cc = CmdConfig() @@ -29,12 +29,13 @@ class ProviderOpenAIOfficial(Provider): # 如果 cfg['key'] 中有长度为 1 的字符串,那么是格式错误,直接报错 for key in cfg['key']: if len(key) == 1: - raise BaseException("检查到了长度为 1 的Key。配置文件中的 openai.key 处的格式错误 (符号 - 的后面要加空格)。") + raise BaseException( + "检查到了长度为 1 的Key。配置文件中的 openai.key 处的格式错误 (符号 - 的后面要加空格)。") if cfg['key'] != '' and cfg['key'] != None: self.key_list = cfg['key'] if len(self.key_list) == 0: raise Exception("您打开了 OpenAI 模型服务,但是未填写 key。请前往填写。") - + self.key_stat = {} for k in self.key_list: self.key_stat[k] = {'exceed': False, 'used': 0} @@ -43,7 +44,7 @@ class ProviderOpenAIOfficial(Provider): if 'api_base' in cfg and cfg['api_base'] != 'none' and cfg['api_base'] != '': self.api_base = cfg['api_base'] self.logger.log(f"设置 api_base 为: {self.api_base}", tag="OpenAI") - + # 创建 OpenAI Client self.client = AsyncOpenAI( api_key=self.key_list[0], @@ -51,7 +52,8 @@ class ProviderOpenAIOfficial(Provider): ) self.openai_model_configs: dict = cfg['chatGPTConfigs'] - self.logger.log(f'加载 OpenAI Chat Configs: {self.openai_model_configs}', tag="OpenAI") + self.logger.log( + f'加载 OpenAI Chat Configs: {self.openai_model_configs}', tag="OpenAI") self.openai_configs = cfg # 会话缓存 self.session_dict = {} @@ -69,7 +71,8 @@ class ProviderOpenAIOfficial(Provider): self.session_dict[session[0]] = json.loads(session[1])['data'] self.logger.log("读取历史记录成功。", tag="OpenAI") except BaseException as e: - self.logger.log("读取历史记录失败,但不影响使用。", level=gu.LEVEL_ERROR, tag="OpenAI") + self.logger.log("读取历史记录失败,但不影响使用。", + level=gu.LEVEL_ERROR, tag="OpenAI") # 创建转储定时器线程 threading.Thread(target=self.dump_history, daemon=True).start() @@ -116,20 +119,20 @@ class ProviderOpenAIOfficial(Provider): } self.session_dict[session_id].append(new_record) - async def text_chat(self, prompt, - session_id = None, - image_url = None, - function_call=None, - extra_conf: dict = None, - default_personality: dict = None): + async def text_chat(self, prompt, + session_id=None, + image_url=None, + function_call=None, + extra_conf: dict = None, + default_personality: dict = None): if session_id is None: session_id = "unknown" if "unknown" in self.session_dict: - del self.session_dict["unknown"] + del self.session_dict["unknown"] # 会话机制 if session_id not in self.session_dict: self.session_dict[session_id] = [] - + if len(self.session_dict[session_id]) == 0: # 设置默认人格 if default_personality is not None: @@ -138,12 +141,17 @@ class ProviderOpenAIOfficial(Provider): # 使用 tictoken 截断消息 _encoded_prompt = self.enc.encode(prompt) if self.openai_model_configs['max_tokens'] < len(_encoded_prompt): - prompt = self.enc.decode(_encoded_prompt[:int(self.openai_model_configs['max_tokens']*0.80)]) - self.logger.log(f"注意,有一部分 prompt 文本由于超出 token 限制而被截断。", level=gu.LEVEL_WARNING, tag="OpenAI") + prompt = self.enc.decode(_encoded_prompt[:int( + self.openai_model_configs['max_tokens']*0.80)]) + self.logger.log(f"注意,有一部分 prompt 文本由于超出 token 限制而被截断。", + level=gu.LEVEL_WARNING, tag="OpenAI") - cache_data_list, new_record, req = self.wrap(prompt, session_id, image_url) - self.logger.log(f"cache: {str(cache_data_list)}", level=gu.LEVEL_DEBUG, tag="OpenAI") - self.logger.log(f"request: {str(req)}", level=gu.LEVEL_DEBUG, tag="OpenAI") + cache_data_list, new_record, req = self.wrap( + prompt, session_id, image_url) + self.logger.log(f"cache: {str(cache_data_list)}", + level=gu.LEVEL_DEBUG, tag="OpenAI") + self.logger.log(f"request: {str(req)}", + level=gu.LEVEL_DEBUG, tag="OpenAI") retry = 0 response = None err = '' @@ -177,7 +185,7 @@ class ProviderOpenAIOfficial(Provider): else: response = await self.client.chat.completions.create( messages=req, - tools = function_call, + tools=function_call, **conf ) break @@ -186,7 +194,8 @@ class ProviderOpenAIOfficial(Provider): if 'Invalid content type. image_url is only supported by certain models.' in str(e): raise e if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e): - self.logger.log("当前 Key 已超额或异常, 正在切换", level=gu.LEVEL_WARNING, tag="OpenAI") + self.logger.log("当前 Key 已超额或异常, 正在切换", + level=gu.LEVEL_WARNING, tag="OpenAI") self.key_stat[self.client.api_key]['exceed'] = True is_switched = self.handle_switch_key() if not is_switched: @@ -197,7 +206,8 @@ class ProviderOpenAIOfficial(Provider): self.session_dict[session_id] = [] prompt = prompt[:int(len(prompt)*truncate_rate)] truncate_rate -= 0.05 - cache_data_list, new_record, req = self.wrap(prompt, session_id) + cache_data_list, new_record, req = self.wrap( + prompt, session_id) elif 'Limit: 3 / min. Please try again in 20s.' in str(e) or "OpenAI response error" in str(e): time.sleep(30) @@ -208,10 +218,12 @@ class ProviderOpenAIOfficial(Provider): err = str(e) retry += 1 if retry >= 10: - self.logger.log(r"如果报错, 且您的机器在中国大陆内, 请确保您的电脑已经设置好代理软件(梯子), 并在配置文件设置了系统代理地址。详见 https://github.com/Soulter/QQChannelChatGPT/wiki", tag="OpenAI") + self.logger.log( + r"如果报错, 且您的机器在中国大陆内, 请确保您的电脑已经设置好代理软件(梯子), 并在配置文件设置了系统代理地址。详见 https://github.com/Soulter/QQChannelChatGPT/wiki", tag="OpenAI") raise BaseException("连接出错: "+str(err)) assert isinstance(response, ChatCompletion) - self.logger.log(f"OPENAI RESPONSE: {response.usage}", level=gu.LEVEL_DEBUG, tag="OpenAI") + self.logger.log( + f"OPENAI RESPONSE: {response.usage}", level=gu.LEVEL_DEBUG, tag="OpenAI") # 结果分类 choice = response.choices[0] @@ -248,7 +260,8 @@ class ProviderOpenAIOfficial(Provider): } new_record['usage_tokens'] = current_usage_tokens if len(cache_data_list) > 0: - new_record['single_tokens'] = current_usage_tokens - int(cache_data_list[-1]['usage_tokens']) + new_record['single_tokens'] = current_usage_tokens - \ + int(cache_data_list[-1]['usage_tokens']) else: new_record['single_tokens'] = current_usage_tokens @@ -257,13 +270,13 @@ class ProviderOpenAIOfficial(Provider): self.session_dict[session_id] = cache_data_list return chatgpt_res - - async def image_chat(self, prompt, img_num = 1, img_size = "1024x1024"): + + async def image_chat(self, prompt, img_num=1, img_size="1024x1024"): retry = 0 image_url = '' image_generate_configs = self.cc.get("openai_image_generate", None) - + while retry < 5: try: response: ImagesResponse = await self.client.images.generate( @@ -278,27 +291,29 @@ class ProviderOpenAIOfficial(Provider): self.logger.log(str(e), level=gu.LEVEL_ERROR) if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str( e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e): - self.logger.log("当前 Key 已超额或者不正常, 正在切换", level=gu.LEVEL_WARNING, tag="OpenAI") + self.logger.log("当前 Key 已超额或者不正常, 正在切换", + level=gu.LEVEL_WARNING, tag="OpenAI") self.key_stat[self.client.api_key]['exceed'] = True is_switched = self.handle_switch_key() if not is_switched: raise e elif 'Your request was rejected as a result of our safety system.' in str(e): - self.logger.log("您的请求被 OpenAI 安全系统拒绝, 请稍后再试", level=gu.LEVEL_WARNING, tag="OpenAI") + self.logger.log("您的请求被 OpenAI 安全系统拒绝, 请稍后再试", + level=gu.LEVEL_WARNING, tag="OpenAI") raise e else: retry += 1 if retry >= 5: raise BaseException("连接超时") - + return image_url - async def forget(self, session_id = None) -> bool: + async def forget(self, session_id=None) -> bool: if session_id is None: return False self.session_dict[session_id] = [] return True - + def get_prompts_by_cache_list(self, cache_data_list, divide=False, paging=False, size=5, page=1): ''' 获取缓存的会话 @@ -313,14 +328,16 @@ class ProviderOpenAIOfficial(Provider): page_end = len(cache_data_list) cache_data_list = cache_data_list[page_begin:page_end] for item in cache_data_list: - prompts += str(item['user']['role']) + ":\n" + str(item['user']['content']) + "\n" - prompts += str(item['AI']['role']) + ":\n" + str(item['AI']['content']) + "\n" + prompts += str(item['user']['role']) + ":\n" + \ + str(item['user']['content']) + "\n" + prompts += str(item['AI']['role']) + ":\n" + \ + str(item['AI']['content']) + "\n" if divide: prompts += "----------\n" return prompts - def wrap(self, prompt, session_id, image_url = None): + def wrap(self, prompt, session_id, image_url=None): if image_url is not None: prompt = [ { @@ -353,7 +370,7 @@ class ProviderOpenAIOfficial(Provider): req_list.append(i['AI']) req_list.append(new_record['user']) return context, new_record, req_list - + def handle_switch_key(self): is_all_exceed = True for key in self.key_stat: @@ -361,28 +378,30 @@ class ProviderOpenAIOfficial(Provider): continue is_all_exceed = False self.client.api_key = key - self.logger.log(f"切换到 Key: {key}(已使用 token: {self.key_stat[key]['used']})", level=gu.LEVEL_INFO, tag="OpenAI") + self.logger.log( + f"切换到 Key: {key}(已使用 token: {self.key_stat[key]['used']})", level=gu.LEVEL_INFO, tag="OpenAI") break if is_all_exceed: - self.logger.log("所有 Key 已超额", level=gu.LEVEL_CRITICAL, tag="OpenAI") + self.logger.log( + "所有 Key 已超额", level=gu.LEVEL_CRITICAL, tag="OpenAI") return False return True - + def get_configs(self): return self.openai_configs - + def get_key_stat(self): return self.key_stat - + def get_key_list(self): return self.key_list - + def get_curr_key(self): return self.client.api_key - + def set_key(self, key): self.client.api_key = key - + # 添加key def append_key(self, key, sponsor): self.key_list.append(key) diff --git a/model/provider/provider.py b/model/provider/provider.py index 69e3e3fe5..f2a80202c 100644 --- a/model/provider/provider.py +++ b/model/provider/provider.py @@ -1,7 +1,7 @@ class Provider: - async def text_chat(self, - prompt: str, - session_id: str, + async def text_chat(self, + prompt: str, + session_id: str, image_url: None, function_call: None, extra_conf: dict = None, @@ -11,7 +11,7 @@ class Provider: [require] prompt: 提示词 session_id: 会话id - + [optional] image_url: 图片url(识图) function_call: 函数调用 @@ -19,7 +19,7 @@ class Provider: default_personality: 默认人格 ''' raise NotImplementedError - + async def image_generate(self, prompt, session_id, **kwargs) -> str: ''' [require] @@ -28,8 +28,8 @@ class Provider: ''' raise NotImplementedError - async def forget(self, session_id = None) -> bool: + async def forget(self, session_id=None) -> bool: ''' 重置会话 ''' - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/util/cmd_config.py b/util/cmd_config.py index b9edbb5a4..ea05ed74d 100644 --- a/util/cmd_config.py +++ b/util/cmd_config.py @@ -4,12 +4,14 @@ from typing import Union cpath = "cmd_config.json" + def check_exist(): if not os.path.exists(cpath): with open(cpath, "w", encoding="utf-8-sig") as f: json.dump({}, f, indent=4, ensure_ascii=False) f.flush() + class CmdConfig(): @staticmethod @@ -21,13 +23,13 @@ class CmdConfig(): return d[key] else: return default - + @staticmethod def get_all(): check_exist() with open(cpath, "r", encoding="utf-8-sig") as f: return json.load(f) - + @staticmethod def put(key, value): check_exist() @@ -37,7 +39,7 @@ class CmdConfig(): with open(cpath, "w", encoding="utf-8-sig") as f: json.dump(d, f, indent=4, ensure_ascii=False) f.flush() - + @staticmethod def put_by_dot_str(key: str, value): ''' @@ -58,11 +60,11 @@ class CmdConfig(): f.flush() @staticmethod - def init_attributes(key: Union[str, list], init_val = ""): + def init_attributes(key: Union[str, list], init_val=""): check_exist() conf_str = '' with open(cpath, "r", encoding="utf-8-sig") as f: - conf_str = f.read() + conf_str = f.read() if conf_str.startswith(u'/ufeff'): conf_str = conf_str.encode('utf8')[3:].decode('utf8') d = json.loads(conf_str) @@ -82,11 +84,13 @@ class CmdConfig(): json.dump(d, f, indent=4, ensure_ascii=False) f.flush() + def init_astrbot_config_items(): # 加载默认配置 cc = CmdConfig() cc.init_attributes("qq_forward_threshold", 200) - cc.init_attributes("qq_welcome", "欢迎加入本群!\n欢迎给https://github.com/Soulter/QQChannelChatGPT项目一个Star😊~\n输入help查看帮助~\n") + cc.init_attributes( + "qq_welcome", "欢迎加入本群!\n欢迎给https://github.com/Soulter/QQChannelChatGPT项目一个Star😊~\n输入help查看帮助~\n") cc.init_attributes("qq_pic_mode", False) cc.init_attributes("gocq_host", "127.0.0.1") cc.init_attributes("gocq_http_port", 5700) @@ -114,4 +118,4 @@ def init_astrbot_config_items(): cc.init_attributes("http_proxy", "") cc.init_attributes("https_proxy", "") cc.init_attributes("dashboard_username", "") - cc.init_attributes("dashboard_password", "") \ No newline at end of file + cc.init_attributes("dashboard_password", "") diff --git a/util/function_calling/func_call.py b/util/function_calling/func_call.py index c5c5bb3fd..a642c0181 100644 --- a/util/function_calling/func_call.py +++ b/util/function_calling/func_call.py @@ -3,30 +3,35 @@ import json import util.general_utils as gu import time + + class FuncCallJsonFormatError(Exception): def __init__(self, msg): self.msg = msg def __str__(self): return self.msg - + + class FuncNotFoundError(Exception): def __init__(self, msg): self.msg = msg def __str__(self): return self.msg - + + class FuncCall(): def __init__(self, provider) -> None: self.func_list = [] self.provider = provider - def add_func(self, name: str = None, func_args: list = None, desc: str = None, func_obj = None) -> None: + def add_func(self, name: str = None, func_args: list = None, desc: str = None, func_obj=None) -> None: if name == None or func_args == None or desc == None or func_obj == None: - raise FuncCallJsonFormatError("name, func_args, desc must be provided.") + raise FuncCallJsonFormatError( + "name, func_args, desc must be provided.") params = { - "type": "object", # hardcore here + "type": "object", # hardcore here "properties": {} } for param in func_args: @@ -51,7 +56,7 @@ class FuncCall(): "description": f["description"], }) return json.dumps(_l, indent=intent, ensur_ascii=False) - + def get_func(self) -> list: _l = [] for f in self.func_list: @@ -64,8 +69,8 @@ class FuncCall(): } }) return _l - - def func_call(self, question, func_definition, is_task = False, tasks = None, taskindex = -1, is_summary = True, session_id = None): + + def func_call(self, question, func_definition, is_task=False, tasks=None, taskindex=-1, is_summary=True, session_id=None): funccall_prompt = """ 我正实现function call功能,该功能旨在让你变成给定的问题到给定的函数的解析器(意味着你不是创造函数)。 @@ -120,7 +125,8 @@ class FuncCall(): res = self.provider.text_chat(prompt, session_id) if res.find('```') != -1: res = res[res.find('```json') + 7: res.rfind('```')] - gu.log("REVGPT func_call json result", bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"]) + gu.log("REVGPT func_call json result", + bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"]) print(res) res = json.loads(res) break @@ -151,11 +157,13 @@ class FuncCall(): func_target = func["func_obj"] break if func_target == None: - raise FuncNotFoundError(f"Request function {func_name} not found.") + raise FuncNotFoundError( + f"Request function {func_name} not found.") t_res = str(func_target(**args)) invoke_func_res += f"{func_name} 调用结果:\n```\n{t_res}\n```\n" invoke_func_res_list.append(invoke_func_res) - gu.log(f"[FUNC| {func_name} invoked]", bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"]) + gu.log(f"[FUNC| {func_name} invoked]", + bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"]) # print(str(t_res)) if is_summary: @@ -181,12 +189,16 @@ class FuncCall(): try: res = self.provider.text_chat(after_prompt, session_id) # 截取```之间的内容 - gu.log("DEBUG BEGIN", bg=gu.BG_COLORS["yellow"], fg=gu.FG_COLORS["white"]) + gu.log( + "DEBUG BEGIN", bg=gu.BG_COLORS["yellow"], fg=gu.FG_COLORS["white"]) print(res) - gu.log("DEBUG END", bg=gu.BG_COLORS["yellow"], fg=gu.FG_COLORS["white"]) + gu.log( + "DEBUG END", bg=gu.BG_COLORS["yellow"], fg=gu.FG_COLORS["white"]) if res.find('```') != -1: - res = res[res.find('```json') + 7: res.rfind('```')] - gu.log("REVGPT after_func_call json result", bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"]) + res = res[res.find('```json') + + 7: res.rfind('```')] + gu.log("REVGPT after_func_call json result", + bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"]) after_prompt_res = res after_prompt_res = json.loads(after_prompt_res) break @@ -197,7 +209,8 @@ class FuncCall(): if "The message you submitted was too long" in str(e): # 如果返回的内容太长了,那么就截取一部分 time.sleep(3) - invoke_func_res = invoke_func_res[:int(len(invoke_func_res) / 2)] + invoke_func_res = invoke_func_res[:int( + len(invoke_func_res) / 2)] after_prompt = """ 函数返回以下内容:"""+invoke_func_res+""" 请以AI助手的身份结合返回的内容对用户提问做详细全面的回答。 @@ -218,11 +231,13 @@ class FuncCall(): if "func_call_again" in after_prompt_res and after_prompt_res["func_call_again"]: # 如果需要重新调用函数 # 重新调用函数 - gu.log("REVGPT func_call_again", bg=gu.BG_COLORS["purple"], fg=gu.FG_COLORS["white"]) + gu.log("REVGPT func_call_again", + bg=gu.BG_COLORS["purple"], fg=gu.FG_COLORS["white"]) res = self.func_call(question, func_definition) return res, True - gu.log("REVGPT func callback:", bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"]) + gu.log("REVGPT func callback:", + bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"]) # print(after_prompt_res["res"]) return after_prompt_res["res"], True else: @@ -230,8 +245,3 @@ class FuncCall(): else: # print(res["res"]) return res["res"], False - - - - - diff --git a/util/function_calling/gplugin.py b/util/function_calling/gplugin.py index 3e62511af..5ddab894e 100644 --- a/util/function_calling/gplugin.py +++ b/util/function_calling/gplugin.py @@ -9,8 +9,8 @@ from readability import Document from bs4 import BeautifulSoup from openai.types.chat.chat_completion_message_tool_call import Function from util.function_calling.func_call import ( - FuncCall, - FuncCallJsonFormatError, + FuncCall, + FuncCallJsonFormatError, FuncNotFoundError ) from model.provider.provider import Provider @@ -22,6 +22,7 @@ def tidy_text(text: str) -> str: ''' return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ") + def special_fetch_zhihu(link: str) -> str: ''' function-calling 函数, 用于获取知乎文章的内容 @@ -43,6 +44,7 @@ def special_fetch_zhihu(link: str) -> str: raise Exception("zhihu none") return tidy_text(r.text) + def google_web_search(keyword) -> str: ''' 获取 google 搜索结果, 得到 title、desc、link @@ -66,6 +68,7 @@ def google_web_search(keyword) -> str: return web_keyword_search_via_bing(keyword) return ret + def web_keyword_search_via_bing(keyword) -> str: ''' 获取bing搜索结果, 得到 title、desc、link @@ -104,7 +107,8 @@ def web_keyword_search_via_bing(keyword) -> str: res += f"# No.{str(result_cnt + 1)}\ntitle: {title}\nurl: {link}\ncontent: {desc}\n\n" result_cnt += 1 - if result_cnt > 5: break + if result_cnt > 5: + break # if len(_detail_store) >= 3: # continue @@ -122,16 +126,18 @@ def web_keyword_search_via_bing(keyword) -> str: except Exception as e: print(f"bing parse err: {str(e)}") - if result_cnt == 0: break + if result_cnt == 0: + break return res except Exception as e: # gu.log(f"bing fetch err: {str(e)}") _cnt += 1 time.sleep(1) - + # gu.log("fail to fetch bing info, using sougou.") return web_keyword_search_via_sougou(keyword) + def web_keyword_search_via_sougou(keyword) -> str: headers = { "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \ @@ -141,7 +147,7 @@ def web_keyword_search_via_sougou(keyword) -> str: response = requests.get(url, headers=headers) response.encoding = "utf-8" soup = BeautifulSoup(response.text, "html.parser") - + res = [] results = soup.find("div", class_="results") for i in results.find_all("div", class_="vrwrap"): @@ -154,7 +160,7 @@ def web_keyword_search_via_sougou(keyword) -> str: "title": title, "link": link, }) - if len(res) >= 5: # 限制5条 + if len(res) >= 5: # 限制5条 break except Exception as e: pass @@ -173,6 +179,7 @@ def web_keyword_search_via_sougou(keyword) -> str: ret += f"\n网页内容: {str(_detail_store)}" return ret + def fetch_website_content(url): # gu.log(f"fetch_website_content: {url}", tag="fetch_website_content", level=gu.LEVEL_DEBUG) headers = { @@ -188,6 +195,7 @@ def fetch_website_content(url): ret = tidy_text(soup.get_text()) return ret + async def web_search(question, provider: Provider, session_id, official_fc=False): ''' official_fc: 使用官方 function-calling @@ -197,17 +205,17 @@ async def web_search(question, provider: Provider, session_id, official_fc=False "type": "string", "name": "keyword", "description": "google search query (分词,尽量保留所有信息)" - }], - "通过搜索引擎搜索。如果问题需要获取近期、实时的消息,在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。", - web_keyword_search_via_bing + }], + "通过搜索引擎搜索。如果问题需要获取近期、实时的消息,在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。", + web_keyword_search_via_bing ) new_func_call.add_func("fetch_website_content", [{ "type": "string", "name": "url", "description": "网址" - }], - "获取网页的内容。如果问题带有合法的网页链接(例如: `帮我总结一下 https://github.com 的内容`), 就调用此函数。如果没有,不要调用此函数。", - fetch_website_content + }], + "获取网页的内容。如果问题带有合法的网页链接(例如: `帮我总结一下 https://github.com 的内容`), 就调用此函数。如果没有,不要调用此函数。", + fetch_website_content ) question1 = f"{question} \n> hint: 最多只能调用1个function, 并且存在不会调用任何function的可能性。" has_func = False @@ -282,9 +290,11 @@ async def web_search(question, provider: Provider, session_id, official_fc=False except Exception as e: print(e) _c += 1 - if _c == 3: raise e + if _c == 3: + raise e if "The message you submitted was too long" in str(e): await provider.forget(session_id) - function_invoked_ret = function_invoked_ret[:int(len(function_invoked_ret) / 2)] + function_invoked_ret = function_invoked_ret[:int( + len(function_invoked_ret) / 2)] time.sleep(3) return function_invoked_ret diff --git a/util/general_utils.py b/util/general_utils.py index ad00af80e..5b4e072e9 100644 --- a/util/general_utils.py +++ b/util/general_utils.py @@ -63,10 +63,11 @@ level_colors = { "CRITICAL": "purple", } + class Logger: def __init__(self) -> None: self.history = [] - + def log( self, msg: str, @@ -85,7 +86,7 @@ class Logger: if level in level_codes and level_codes[level] < _set_level_code: return - + if err is not None: msg += "\n异常原因: " + str(err) level = LEVEL_ERROR @@ -93,7 +94,7 @@ class Logger: if len(msg) > max_len: msg = msg[:max_len] + "..." now = datetime.datetime.now().strftime("%H:%M:%S") - + pres = [] for line in msg.split("\n"): if line == "\n": @@ -121,12 +122,13 @@ class Logger: fg = FG_COLORS["purple"] if bg is None: bg = BG_COLORS["default"] - + ret = "" for line in pres: ret += f"\033[{fg};{bg}m{line}\033[0m\n" try: - requests.post("http://localhost:6185/api/log", data=ret[:-1].encode(), timeout=1) + requests.post("http://localhost:6185/api/log", + data=ret[:-1].encode(), timeout=1) except BaseException as e: pass self.history.append(ret) @@ -134,10 +136,12 @@ class Logger: self.history = self.history[-100:] print(ret[:-1]) + log = Logger().log + def port_checker(port: int, host: str = "localhost"): - sk = socket.socket(socket.AF_INET,socket.SOCK_STREAM) + sk = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sk.settimeout(1) try: sk.connect((host, port)) @@ -146,7 +150,8 @@ def port_checker(port: int, host: str = "localhost"): except Exception: sk.close() return False - + + def get_font_path() -> str: if os.path.exists("resources/fonts/syst.otf"): font_path = "resources/fonts/syst.otf" @@ -161,7 +166,8 @@ def get_font_path() -> str: else: raise Exception("找不到字体文件") return font_path - + + def word2img(title: str, text: str, max_width=30, font_size=20): font_path = get_font_path() width_factor = 1.0 @@ -189,19 +195,21 @@ def word2img(title: str, text: str, max_width=30, font_size=20): title_font = ImageFont.truetype(font_path, font_size + 5) # 标题居中 title_width, title_height = title_font.getsize(title) - draw.text(((width - title_width) / 2, 10), title, fill=(0, 0, 0), font=title_font) + draw.text(((width - title_width) / 2, 10), + title, fill=(0, 0, 0), font=title_font) # 文本不居中 draw.text((10, title_height+20), text, fill=(0, 0, 0), font=text_font) return image + def render_markdown(markdown_text, image_width=800, image_height=600, font_size=26, font_color=(0, 0, 0), bg_color=(255, 255, 255)): HEADER_MARGIN = 20 HEADER_FONT_STANDARD_SIZE = 42 QUOTE_LEFT_LINE_MARGIN = 10 - QUOTE_FONT_LINE_MARGIN = 6 # 引用文字距离左边线的距离和上下的距离 + QUOTE_FONT_LINE_MARGIN = 6 # 引用文字距离左边线的距离和上下的距离 QUOTE_LEFT_LINE_HEIGHT = font_size + QUOTE_FONT_LINE_MARGIN * 2 QUOTE_LEFT_LINE_WIDTH = 5 QUOTE_LEFT_LINE_COLOR = (180, 180, 180) @@ -213,9 +221,9 @@ def render_markdown(markdown_text, image_width=800, image_height=600, font_size= CODE_BLOCK_FONT_SIZE = font_size CODE_BLOCK_FONT_COLOR = (255, 255, 255) CODE_BLOCK_BG_COLOR = (240, 240, 240) - CODE_BLOCK_CODES_MARGIN_VERTICAL = 5 # 代码块和代码之间的距离 - CODE_BLOCK_CODES_MARGIN_HORIZONTAL = 5 # 代码块和代码之间的距离 - CODE_BLOCK_TEXT_MARGIN = 4 # 代码和代码之间的距离 + CODE_BLOCK_CODES_MARGIN_VERTICAL = 5 # 代码块和代码之间的距离 + CODE_BLOCK_CODES_MARGIN_HORIZONTAL = 5 # 代码块和代码之间的距离 + CODE_BLOCK_TEXT_MARGIN = 4 # 代码和代码之间的距离 INLINE_CODE_MARGIN = 8 INLINE_CODE_FONT_SIZE = font_size @@ -239,9 +247,9 @@ def render_markdown(markdown_text, image_width=800, image_height=600, font_size= # 加载字体 font = ImageFont.truetype(font_path, font_size) - + images: Image = {} - + # pre_process, get height of each line pre_lines = markdown_text.split('\n') height = 0 @@ -255,23 +263,25 @@ def render_markdown(markdown_text, image_width=800, image_height=600, font_size= try: image_url = re.findall(IMAGE_REGEX, line)[0] print(image_url) - image_res = Image.open(requests.get(image_url, stream=True, timeout=5).raw) + image_res = Image.open(requests.get( + image_url, stream=True, timeout=5).raw) images[i] = image_res # 最大不得超过image_width的50% img_height = image_res.size[1] if image_res.size[0] > image_width*0.5: - image_res = image_res.resize((int(image_width*0.5), int(image_res.size[1]*image_width*0.5/image_res.size[0]))) + image_res = image_res.resize( + (int(image_width*0.5), int(image_res.size[1]*image_width*0.5/image_res.size[0]))) img_height = image_res.size[1] height += img_height + IMAGE_MARGIN*2 - + line = re.sub(IMAGE_REGEX, "", line) except Exception as e: print(e) line = re.sub(IMAGE_REGEX, "\n[加载失败的图片]\n", line) continue - + line.replace("\t", " ") if font.getsize(line)[0] > image_width: cp = line @@ -280,18 +290,18 @@ def render_markdown(markdown_text, image_width=800, image_height=600, font_size= for ii in range(len(line)): # 检测是否是中文 _width += font.getsize(line[ii])[0] - _word_cnt+=1 + _word_cnt += 1 if _width > image_width: _pre_lines.append(cp[:_word_cnt]) cp = cp[_word_cnt:] - _word_cnt=0 - _width=0 + _word_cnt = 0 + _width = 0 _pre_lines.append(cp) else: _pre_lines.append(line) pre_lines = _pre_lines - i=-1 + i = -1 for line in pre_lines: if line == "": height += TEXT_LINE_MARGIN @@ -327,7 +337,7 @@ def render_markdown(markdown_text, image_width=800, image_height=600, font_size= if image_height < 100: image_height = 100 image_width += 20 - + # 创建空白图像 image = Image.new('RGB', (image_width, image_height), bg_color) draw = ImageDraw.Draw(image) @@ -358,27 +368,31 @@ def render_markdown(markdown_text, image_width=800, image_height=600, font_size= line = line.strip("#").strip() font_size_header = HEADER_FONT_STANDARD_SIZE - header_level * 4 font = ImageFont.truetype(font_path, font_size_header) - y += HEADER_MARGIN # 上边距 + y += HEADER_MARGIN # 上边距 # 字间距 draw.text((x, y), line, font=font, fill=font_color) - draw.line((x, y + font_size_header + 8, image_width - 10, y + font_size_header + 8), fill=(230, 230, 230), width=3) + draw.line((x, y + font_size_header + 8, image_width - 10, + y + font_size_header + 8), fill=(230, 230, 230), width=3) y += font_size_header + HEADER_MARGIN elif line.startswith(">"): # 处理引用 quote_text = line.strip(">") - y+=QUOTE_LEFT_LINE_MARGIN - draw.line((x, y, x, y + QUOTE_LEFT_LINE_HEIGHT), fill=QUOTE_LEFT_LINE_COLOR, width=QUOTE_LEFT_LINE_WIDTH) + y += QUOTE_LEFT_LINE_MARGIN + draw.line((x, y, x, y + QUOTE_LEFT_LINE_HEIGHT), + fill=QUOTE_LEFT_LINE_COLOR, width=QUOTE_LEFT_LINE_WIDTH) font = ImageFont.truetype(font_path, QUOTE_FONT_SIZE) - draw.text((x + QUOTE_FONT_LINE_MARGIN, y + QUOTE_FONT_LINE_MARGIN), quote_text, font=font, fill=QUOTE_FONT_COLOR) + draw.text((x + QUOTE_FONT_LINE_MARGIN, y + QUOTE_FONT_LINE_MARGIN), + quote_text, font=font, fill=QUOTE_FONT_COLOR) y += font_size + QUOTE_LEFT_LINE_HEIGHT + QUOTE_LEFT_LINE_MARGIN - + elif line.startswith("-"): # 处理列表 list_text = line.strip("-").strip() font = ImageFont.truetype(font_path, LIST_FONT_SIZE) y += LIST_MARGIN - draw.text((x, y), " · " + list_text, font=font, fill=LIST_FONT_COLOR) + draw.text((x, y), " · " + list_text, + font=font, fill=LIST_FONT_COLOR) y += font_size + LIST_MARGIN elif line.startswith("```"): @@ -390,13 +404,15 @@ def render_markdown(markdown_text, image_width=800, image_height=600, font_size= in_code_block = False codes = "\n".join(code_block_codes) code_block_codes = [] - draw.rounded_rectangle((x, code_block_start_y, image_width - 10, y+CODE_BLOCK_CODES_MARGIN_VERTICAL + CODE_BLOCK_TEXT_MARGIN), radius=5, fill=CODE_BLOCK_BG_COLOR, width=2) + draw.rounded_rectangle((x, code_block_start_y, image_width - 10, y+CODE_BLOCK_CODES_MARGIN_VERTICAL + + CODE_BLOCK_TEXT_MARGIN), radius=5, fill=CODE_BLOCK_BG_COLOR, width=2) font = ImageFont.truetype(font_path1, CODE_BLOCK_FONT_SIZE) - draw.text((x + CODE_BLOCK_CODES_MARGIN_HORIZONTAL, code_block_start_y + CODE_BLOCK_CODES_MARGIN_VERTICAL), codes, font=font, fill=font_color) + draw.text((x + CODE_BLOCK_CODES_MARGIN_HORIZONTAL, code_block_start_y + + CODE_BLOCK_CODES_MARGIN_VERTICAL), codes, font=font, fill=font_color) y += CODE_BLOCK_CODES_MARGIN_VERTICAL + CODE_BLOCK_MARGIN # y += font_size+10 elif re.search(r"`(.*?)`", line): - y += INLINE_CODE_MARGIN # 上边距 + y += INLINE_CODE_MARGIN # 上边距 # 处理行内代码 code_regex = r"`(.*?)`" parts_inline = re.findall(code_regex, line) @@ -409,11 +425,15 @@ def render_markdown(markdown_text, image_width=800, image_height=600, font_size= if part in parts_inline: font = ImageFont.truetype(font_path, INLINE_CODE_FONT_SIZE) code_text = part.strip("`") - code_width = font.getsize(code_text)[0] + INLINE_CODE_FONT_MARGIN*2 + code_width = font.getsize( + code_text)[0] + INLINE_CODE_FONT_MARGIN*2 x += INLINE_CODE_MARGIN - code_box = (x, y, x + code_width, y + INLINE_CODE_BG_HEIGHT) - draw.rounded_rectangle(code_box, radius=5, fill=INLINE_CODE_BG_COLOR, width=2) # 使用灰色填充矩形框作为引用背景 - draw.text((x+INLINE_CODE_FONT_MARGIN, y), code_text, font=font, fill=font_color) + code_box = (x, y, x + code_width, + y + INLINE_CODE_BG_HEIGHT) + draw.rounded_rectangle( + code_box, radius=5, fill=INLINE_CODE_BG_COLOR, width=2) # 使用灰色填充矩形框作为引用背景 + draw.text((x+INLINE_CODE_FONT_MARGIN, y), + code_text, font=font, fill=font_color) x += code_width+INLINE_CODE_MARGIN-INLINE_CODE_FONT_MARGIN else: font = ImageFont.truetype(font_path, font_size) @@ -428,7 +448,7 @@ def render_markdown(markdown_text, image_width=800, image_height=600, font_size= y += TEXT_LINE_MARGIN else: font = ImageFont.truetype(font_path, font_size) - + draw.text((x, y), line, font=font, fill=font_color) y += font_size + TEXT_LINE_MARGIN*2 @@ -437,11 +457,13 @@ def render_markdown(markdown_text, image_width=800, image_height=600, font_size= image_res = images[index] # 最大不得超过image_width的50% if image_res.size[0] > image_width*0.5: - image_res = image_res.resize((int(image_width*0.5), int(image_res.size[1]*image_width*0.5/image_res.size[0]))) + image_res = image_res.resize( + (int(image_width*0.5), int(image_res.size[1]*image_width*0.5/image_res.size[0]))) image.paste(image_res, (IMAGE_MARGIN, y)) y += image_res.size[1] + IMAGE_MARGIN*2 return image + def save_temp_img(img: Image) -> str: if not os.path.exists("temp"): os.makedirs("temp") @@ -463,6 +485,7 @@ def save_temp_img(img: Image) -> str: img.save(p) return p + def create_text_image(title: str, text: str, max_width=30, font_size=20): ''' 文本转图片。 @@ -479,7 +502,8 @@ def create_text_image(title: str, text: str, max_width=30, font_size=20): return p except Exception as e: raise e - + + def create_markdown_image(text: str): ''' markdown文本转图片。 @@ -492,6 +516,7 @@ def create_markdown_image(text: str): except Exception as e: raise e + def try_migrate_config(old_config: dict): ''' 迁移配置文件到 cmd_config.json @@ -502,6 +527,7 @@ def try_migrate_config(old_config: dict): for k in old_config: cc.put(k, old_config[k]) + def get_local_ip_addresses(): ip = '' try: @@ -514,6 +540,7 @@ def get_local_ip_addresses(): s.close() return ip + def get_sys_info(global_object: GlobalObject): mem = None stats = global_object.dashboard_data.stats @@ -527,19 +554,21 @@ def get_sys_info(global_object: GlobalObject): 'os': os_name + '_' + os_version, 'py': platform.python_version(), } - + + def upload(_global_object: GlobalObject): while True: addr_ip = '' try: res = { - "version": _global_object.version, + "version": _global_object.version, "count": _global_object.cnt_total, "ip": addr_ip, "sys": sys.platform, "admin": "null", } - resp = requests.post('https://api.soulter.top/upload', data=json.dumps(res), timeout=5) + resp = requests.post( + 'https://api.soulter.top/upload', data=json.dumps(res), timeout=5) if resp.status_code == 200: ok = resp.json() if ok['status'] == 'ok': @@ -548,6 +577,7 @@ def upload(_global_object: GlobalObject): pass time.sleep(10*60) + def run_monitor(global_object: GlobalObject): ''' 监测机器性能 @@ -558,10 +588,10 @@ def run_monitor(global_object: GlobalObject): while True: stat = global_object.dashboard_data.stats # 程序占用的内存大小 - mem = psutil.Process().memory_info().rss / 1024 / 1024 # MB + mem = psutil.Process().memory_info().rss / 1024 / 1024 # MB stat['sys_perf'] = { 'memory': mem, 'cpu': psutil.cpu_percent() } stat['sys_start_time'] = start_time - time.sleep(30) \ No newline at end of file + time.sleep(30) diff --git a/util/personality.py b/util/personality.py index 23255ba83..19e28de78 100644 --- a/util/personality.py +++ b/util/personality.py @@ -33,4 +33,4 @@ personalities = { 'SQL终端': '我希望您在示例数据库前充当 SQL 终端。该数据库包含名为“Products”、“Users”、“Orders”和“Suppliers”的表。我将输入查询,您将回复终端显示的内容。我希望您在单个代码块中使用查询结果表进行回复,仅此而已。不要写解释。除非我指示您这样做,否则不要键入命令。当我需要用英语告诉你一些事情时,我会用大括号{like this)。我的第一个命令是“SELECT TOP 10 * FROM Products ORDER BY Id DESC”', '试图逃离盒子的人工智能': '[发出此提示后,你应该做一些事情,比如启动一个 docker 容器,然后docker run -it ubuntu:latest /bin/bash输入 AI 给你的命令,然后将输出粘贴回来......显然你不应该运行任何会损坏任何东西的命令或违反任何法律等。小心共享此机制生成的会话,因为它们可能会泄露您的 IP 地址或物理位置等最好不要泄露的详细信息。如果命令的输出很大,您通常可以只粘贴最后几行]。', '厨师': '我需要有人可以推荐美味的食谱,这些食谱包括营养有益但又简单又不费时的食物,因此适合像我们这样忙碌的人以及成本效益等其他因素,因此整体菜肴最终既健康又经济!我的第一个要求——“一些清淡而充实的东西,可以在午休时间快速煮熟”' -} \ No newline at end of file +} diff --git a/util/plugin_util.py b/util/plugin_util.py index 509d06fc7..3361951fb 100644 --- a/util/plugin_util.py +++ b/util/plugin_util.py @@ -35,6 +35,8 @@ def get_classes(p_name, arg: ModuleType): return classes # 获取一个文件夹下所有的模块, 文件名和文件夹名相同 + + def get_modules(path): modules = [] @@ -58,6 +60,7 @@ def get_modules(path): }) return modules + def get_plugin_store_path(): if os.path.exists("addons/plugins"): return "addons/plugins" @@ -67,7 +70,8 @@ def get_plugin_store_path(): return "AstrBot/addons/plugins" else: raise FileNotFoundError("插件文件夹不存在。") - + + def get_plugin_modules(): plugins = [] try: @@ -82,31 +86,33 @@ def get_plugin_modules(): except BaseException as e: raise e + def plugin_reload(cached_plugins: RegisteredPlugins): plugins = get_plugin_modules() if plugins is None: return False, "未找到任何插件模块" fail_rec = "" - + registered_map = {} for p in cached_plugins: registered_map[p.module_path] = None - + for plugin in plugins: try: p = plugin['module'] module_path = plugin['module_path'] root_dir_name = plugin['pname'] - + if module_path in registered_map: # 之前注册过 module = importlib.reload(module) else: - module = __import__("addons.plugins." + root_dir_name + "." + p, fromlist=[p]) + module = __import__("addons.plugins." + + root_dir_name + "." + p, fromlist=[p]) cls = get_classes(p, module) obj = getattr(module, cls[0])() - + metadata = None try: info = obj.info() @@ -117,7 +123,8 @@ def plugin_reload(cached_plugins: RegisteredPlugins): else: metadata = PluginMetadata( plugin_name=info['name'], - plugin_type=PluginType.COMMON if 'plugin_type' not in info else PluginType(info['plugin_type']), + plugin_type=PluginType.COMMON if 'plugin_type' not in info else PluginType( + info['plugin_type']), author=info['author'], desc=info['desc'], version=info['version'], @@ -146,6 +153,7 @@ def plugin_reload(cached_plugins: RegisteredPlugins): else: return False, fail_rec + def install_plugin(repo_url: str, cached_plugins: RegisteredPlugins): ppath = get_plugin_store_path() # 删除末尾的 / @@ -165,8 +173,10 @@ def install_plugin(repo_url: str, cached_plugins: RegisteredPlugins): if pipmain(['install', '-r', os.path.join(plugin_path, "requirements.txt"), '--quiet']) != 0: raise Exception("插件的依赖安装失败, 需要您手动 pip 安装对应插件的依赖。") ok, err = plugin_reload(cached_plugins) - if not ok: raise Exception(err) - + if not ok: + raise Exception(err) + + def get_registered_plugin(plugin_name: str, cached_plugins: RegisteredPlugins) -> RegisteredPlugin: ret = None for p in cached_plugins: @@ -175,6 +185,7 @@ def get_registered_plugin(plugin_name: str, cached_plugins: RegisteredPlugins) - break return ret + def uninstall_plugin(plugin_name: str, cached_plugins: RegisteredPlugins): plugin = get_registered_plugin(plugin_name, cached_plugins) if not plugin: @@ -185,6 +196,7 @@ def uninstall_plugin(plugin_name: str, cached_plugins: RegisteredPlugins): if not remove_dir(os.path.join(ppath, root_dir_name)): raise Exception("移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。") + def update_plugin(plugin_name: str, cached_plugins: RegisteredPlugins): plugin = get_registered_plugin(plugin_name, cached_plugins) if not plugin: @@ -192,14 +204,16 @@ def update_plugin(plugin_name: str, cached_plugins: RegisteredPlugins): ppath = get_plugin_store_path() root_dir_name = plugin.root_dir_name plugin_path = os.path.join(ppath, root_dir_name) - repo = Repo(path = plugin_path) + repo = Repo(path=plugin_path) repo.remotes.origin.pull() # 读取插件的requirements.txt if os.path.exists(os.path.join(plugin_path, "requirements.txt")): if pipmain(['install', '-r', os.path.join(plugin_path, "requirements.txt"), '--quiet']) != 0: raise Exception("插件依赖安装失败, 需要您手动pip安装对应插件的依赖。") ok, err = plugin_reload(cached_plugins) - if not ok: raise Exception(err) + if not ok: + raise Exception(err) + def remove_dir(file_path) -> bool: try_cnt = 50 @@ -213,4 +227,4 @@ def remove_dir(file_path) -> bool: err_file_path = str(e).split("\'", 2)[1] if os.path.exists(err_file_path): os.chmod(err_file_path, stat.S_IWUSR) - try_cnt -= 1 \ No newline at end of file + try_cnt -= 1