diff --git a/cores/qqbot/core.py b/cores/qqbot/core.py index 07749f52b..ad009aa23 100644 --- a/cores/qqbot/core.py +++ b/cores/qqbot/core.py @@ -28,6 +28,9 @@ from model.command.command_rev_edgegpt import CommandRevEdgeGPT from model.command.command_openai_official import CommandOpenAIOfficial from util import general_utils as gu from util.cmd_config import CmdConfig as cc +import util.gplugin as gplugin +from PIL import Image as PILImage +import io @@ -114,6 +117,9 @@ bing_cache_loop = None # 插件 cached_plugins = {} +# 全局对象 +_global_object = {} + # 统计 cnt_total = 0 cnt_valid = 0 @@ -165,7 +171,6 @@ def upload(): addr = requests.get('http://myip.ipip.net', timeout=5).text addr_ip = re.findall(r'\d+.\d+.\d+.\d+', addr)[0] except BaseException as e: - print(e) pass try: o = {"cnt_total": cnt_total,"admin": admin_qq,"addr": addr,} @@ -179,7 +184,6 @@ def upload(): cnt_valid = 0 cnt_total = 0 except BaseException as e: - print(e) pass time.sleep(60*10) @@ -188,8 +192,8 @@ def upload(): ''' def initBot(cfg, prov): global chatgpt, provider, rev_chatgpt, baidu_judge, rev_edgegpt, chosen_provider - global reply_prefix, gpt_config, config, uniqueSession, frequency_count, frequency_time,announcement, direct_message_mode, version - global command_openai_official, command_rev_chatgpt, command_rev_edgegpt,reply_prefix, keywords, cached_plugins + global reply_prefix, gpt_config, config, uniqueSession, frequency_count, frequency_time, announcement, direct_message_mode, version + global command_openai_official, command_rev_chatgpt, command_rev_edgegpt,reply_prefix, keywords, cached_plugins, _global_object provider = prov config = cfg if 'reply_prefix' in cfg: @@ -228,9 +232,9 @@ def initBot(cfg, prov): chatgpt = ProviderOpenAIOfficial(cfg['openai']) chosen_provider = OPENAI_OFFICIAL - command_rev_edgegpt = CommandRevEdgeGPT(rev_edgegpt) - command_rev_chatgpt = CommandRevChatGPT(rev_chatgpt) - command_openai_official = CommandOpenAIOfficial(chatgpt) + command_rev_edgegpt = CommandRevEdgeGPT(rev_edgegpt, _global_object) + command_rev_chatgpt = CommandRevChatGPT(rev_chatgpt, _global_object) + command_openai_official = CommandOpenAIOfficial(chatgpt, _global_object) gu.log("--------加载个性化配置--------", gu.LEVEL_INFO, fg=gu.FG_COLORS['yellow']) # 得到关键词 @@ -270,8 +274,11 @@ def initBot(cfg, prov): # 得到公告配置 if 'notice' in cfg: - gu.log("公告配置: "+cfg['notice'], gu.LEVEL_INFO) - announcement += cfg['notice'] + if cc.get("qq_welcome", None) != None and cfg['notice'] == '此机器人由Github项目QQChannelChatGPT驱动。': + announcement = cc.get("qq_welcome", None) + else: + announcement = cfg['notice'] + gu.log("公告配置: " + announcement, gu.LEVEL_INFO) try: if 'uniqueSessionMode' in cfg and cfg['uniqueSessionMode']: uniqueSession = True @@ -387,7 +394,6 @@ def run_gocq_bot(loop, gocq_bot, gocq_app): gu.log("检查完毕,未发现问题。", tag="QQ") break - global gocq_client gocq_client = gocqClient() try: @@ -424,22 +430,15 @@ def save_provider_preference(chosen_provider): ''' 通用回复方法 ''' -def send_message(platform, message, res, msg_ref = None, image = None, gocq_loop = None, qqchannel_bot = None, gocq_bot = None, image_mode=False): +def send_message(platform, message, res, msg_ref = None, image = None, image_mode=False): # imagemode: # For GOCQ: when image_mode is true, ALL plain texts in res will change into a new pic - global cnt_valid + global cnt_valid, qqchannel_bot, qqchannel_bot, gocq_loop cnt_valid += 1 if platform == PLATFORM_QQCHAN: - if image != None: - qqchannel_bot.send_qq_msg(message, str(res), image_mode=True, msg_ref=msg_ref) - else: - qqchannel_bot.send_qq_msg(message, str(res), msg_ref=msg_ref) - if platform == PLATFORM_GOCQ: - if image != None: - # image is a url string - asyncio.run_coroutine_threadsafe(gocq_bot.send_qq_msg(message, [Plain(text="好的,我根据你的需要为你生成了一张图片😊"),Image.fromURL(image)], False), gocq_loop).result() - else: - asyncio.run_coroutine_threadsafe(gocq_bot.send_qq_msg(message, res, image_mode), gocq_loop).result() + qqchannel_bot.send_qq_msg(message, res, msg_ref=msg_ref) + if platform == PLATFORM_GOCQ: + asyncio.run_coroutine_threadsafe(gocq_bot.send_qq_msg(message, res, image_mode), gocq_loop).result() def oper_msg(message, @@ -461,7 +460,7 @@ def oper_msg(message, hit = False # 是否命中指令 command_result = () # 调用指令返回的结果 global admin_qq, admin_qqchan, cached_plugins, gocq_bot, nick_qq - global cnt_total + global cnt_total, _global_object cnt_total += 1 @@ -551,13 +550,13 @@ def oper_msg(message, role = "admin" if qq_msg == "": - send_message(platform, message, f"Hi~", msg_ref=msg_ref, gocq_loop=gocq_loop, qqchannel_bot=qqchannel_bot, gocq_bot=gocq_bot) + send_message(platform, message, f"Hi~", msg_ref=msg_ref) return if with_tag: # 检查发言频率 if not check_frequency(user_id): - send_message(platform, message, f'你的发言超过频率限制(╯▔皿▔)╯。\n管理员设置{frequency_time}秒内只能提问{frequency_count}次。', msg_ref=msg_ref, gocq_loop=gocq_loop, qqchannel_bot=qqchannel_bot, gocq_bot=gocq_bot) + send_message(platform, message, f'你的发言超过频率限制(╯▔皿▔)╯。\n管理员设置{frequency_time}秒内只能提问{frequency_count}次。', msg_ref=msg_ref) return # logf.write("[GOCQBOT] "+ qq_msg+'\n') @@ -566,19 +565,19 @@ def oper_msg(message, # 关键词回复 for k in keywords: if qq_msg == k: - send_message(platform, message, keywords[k], msg_ref=msg_ref, gocq_loop=gocq_loop, qqchannel_bot=qqchannel_bot, gocq_bot=gocq_bot) + send_message(platform, message, keywords[k], msg_ref=msg_ref) return # 关键词拦截器 for i in uw.unfit_words_q: matches = re.match(i, qq_msg.strip(), re.I | re.M) if matches: - send_message(platform, message, f"你的提问得到的回复未通过【自有关键词拦截】服务, 不予回复。", msg_ref=msg_ref, gocq_loop=gocq_loop, qqchannel_bot=qqchannel_bot, gocq_bot=gocq_bot) + send_message(platform, message, f"你的提问得到的回复未通过【自有关键词拦截】服务, 不予回复。", msg_ref=msg_ref) return if baidu_judge != None: check, msg = baidu_judge.judge(qq_msg) if not check: - send_message(platform, message, f"你的提问得到的回复未通过【百度AI内容审核】服务, 不予回复。\n\n{msg}", msg_ref=msg_ref, gocq_loop=gocq_loop, qqchannel_bot=qqchannel_bot, gocq_bot=gocq_bot) + send_message(platform, message, f"你的提问得到的回复未通过【百度AI内容审核】服务, 不予回复。\n\n{msg}", msg_ref=msg_ref) return # 检查是否是更换语言模型的请求 @@ -599,50 +598,67 @@ def oper_msg(message, qq_msg = l[1] else: # if role != "admin": - # send_message(platform, message, "你没有权限更换语言模型。", msg_ref=msg_ref, gocq_loop=gocq_loop, qqchannel_bot=qqchannel_bot, gocq_bot=gocq_bot) + # send_message(platform, message, "你没有权限更换语言模型。", msg_ref=msg_ref) # return chosen_provider = target save_provider_preference(chosen_provider) - send_message(platform, message, f"已切换至【{chosen_provider}】", msg_ref=msg_ref, gocq_loop=gocq_loop, qqchannel_bot=qqchannel_bot, gocq_bot=gocq_bot) + send_message(platform, message, f"已切换至【{chosen_provider}】", msg_ref=msg_ref) return chatgpt_res = "" if chosen_provider == OPENAI_OFFICIAL: - hit, command_result = command_openai_official.check_command(qq_msg, session_id, user_name, role, platform=platform, message_obj=message, cached_plugins=cached_plugins, qq_platform=gocq_bot) + hit, command_result = command_openai_official.check_command(qq_msg, session_id, user_name, role, + platform=platform, message_obj=message, + cached_plugins=cached_plugins, + qq_platform=gocq_bot) # hit: 是否触发了指令 if not hit: if not with_tag: return if chatgpt == None: - send_message(platform, message, f"管理员未启动OpenAI模型或初始化时失败。", msg_ref=msg_ref, gocq_loop=gocq_loop, qqchannel_bot=qqchannel_bot, gocq_bot=gocq_bot) + send_message(platform, message, f"管理员未启动OpenAI模型或初始化时失败。", msg_ref=msg_ref) return # 请求ChatGPT获得结果 try: - chatgpt_res = chatgpt.text_chat(qq_msg, session_id) + if _global_object != None and "web_search" in _global_object and _global_object["web_search"]: + chatgpt_res = gplugin.web_search(qq_msg, chatgpt) + else: + chatgpt_res = str(chatgpt.text_chat(qq_msg)) if OPENAI_OFFICIAL in reply_prefix: chatgpt_res = reply_prefix[OPENAI_OFFICIAL] + chatgpt_res except (BaseException) as e: gu.log("OpenAI API请求错误, 原因: "+str(e), gu.LEVEL_ERROR) - send_message(platform, message, f"OpenAI API错误, 原因: {str(e)}", msg_ref=msg_ref, gocq_loop=gocq_loop, qqchannel_bot=qqchannel_bot, gocq_bot=gocq_bot) + send_message(platform, message, f"OpenAI API错误, 原因: {str(e)}", msg_ref=msg_ref) elif chosen_provider == REV_CHATGPT: - hit, command_result = command_rev_chatgpt.check_command(qq_msg, role, platform=platform, message_obj=message, cached_plugins=cached_plugins, qq_platform=gocq_bot) + hit, command_result = command_rev_chatgpt.check_command(qq_msg, role, + platform=platform, + message_obj=message, + cached_plugins=cached_plugins, + qq_platform=gocq_bot) if not hit: if not with_tag: return if rev_chatgpt == None: - send_message(platform, message, f"管理员未启动此模型或者此模型初始化时失败。", msg_ref=msg_ref, gocq_loop=gocq_loop, qqchannel_bot=qqchannel_bot, gocq_bot=gocq_bot) + send_message(platform, message, f"管理员未启动此模型或者此模型初始化时失败。", msg_ref=msg_ref) return try: while rev_chatgpt.is_all_busy(): time.sleep(1) - chatgpt_res = str(rev_chatgpt.text_chat(qq_msg)) + + # ws_prompt = f"{qq_msg}\n\n提示:" + # chatgpt_res = str(rev_chatgpt.text_chat(ws_prompt)) + if _global_object != None and "web_search" in _global_object and _global_object["web_search"]: + chatgpt_res = gplugin.web_search(qq_msg, rev_chatgpt) + else: + chatgpt_res = str(rev_chatgpt.text_chat(qq_msg)) + if REV_CHATGPT in reply_prefix: chatgpt_res = reply_prefix[REV_CHATGPT] + chatgpt_res except BaseException as e: gu.log("逆向ChatGPT请求错误, 原因: "+str(e), gu.LEVEL_ERROR) - send_message(platform, message, f"RevChatGPT错误, 原因: \n{str(e)}", msg_ref=msg_ref, gocq_loop=gocq_loop, qqchannel_bot=qqchannel_bot, gocq_bot=gocq_bot) + send_message(platform, message, f"RevChatGPT错误, 原因: \n{str(e)}", msg_ref=msg_ref) elif chosen_provider == REV_EDGEGPT: if bing_cache_loop == None: @@ -650,32 +666,35 @@ def oper_msg(message, bing_cache_loop = gocq_loop elif platform == PLATFORM_QQCHAN: bing_cache_loop = qqchan_loop - hit, command_result = command_rev_edgegpt.check_command(qq_msg, bing_cache_loop, role, platform=platform, message_obj=message, cached_plugins=cached_plugins, qq_platform=gocq_bot) + hit, command_result = command_rev_edgegpt.check_command(qq_msg, bing_cache_loop, role, + platform=platform, message_obj=message, + cached_plugins=cached_plugins, + qq_platform=gocq_bot) if not hit: try: if not with_tag: return if rev_edgegpt == None: - send_message(platform, message, f"管理员未启动此模型或者此模型初始化时失败。", msg_ref=msg_ref, gocq_loop=gocq_loop, qqchannel_bot=qqchannel_bot, gocq_bot=gocq_bot) + send_message(platform, message, f"管理员未启动此模型或者此模型初始化时失败。", msg_ref=msg_ref) return while rev_edgegpt.is_busy(): time.sleep(1) res, res_code = asyncio.run_coroutine_threadsafe(rev_edgegpt.text_chat(qq_msg, platform), bing_cache_loop).result() if res_code == 0: # bing不想继续话题,重置会话后重试。 - send_message(platform, message, "Bing不想继续话题了, 正在自动重置会话并重试。", msg_ref=msg_ref, gocq_loop=gocq_loop, qqchannel_bot=qqchannel_bot, gocq_bot=gocq_bot) + send_message(platform, message, "Bing不想继续话题了, 正在自动重置会话并重试。", msg_ref=msg_ref) asyncio.run_coroutine_threadsafe(rev_edgegpt.forget(), bing_cache_loop).result() res, res_code = asyncio.run_coroutine_threadsafe(rev_edgegpt.text_chat(qq_msg, platform), bing_cache_loop).result() if res_code == 0: # bing还是不想继续话题,大概率说明提问有问题。 asyncio.run_coroutine_threadsafe(rev_edgegpt.forget(), bing_cache_loop).result() - send_message(platform, message, "Bing仍然不想继续话题, 会话已重置, 请检查您的提问后重试。", msg_ref=msg_ref, gocq_loop=gocq_loop, qqchannel_bot=qqchannel_bot, gocq_bot=gocq_bot) + send_message(platform, message, "Bing仍然不想继续话题, 会话已重置, 请检查您的提问后重试。", msg_ref=msg_ref) res = "" chatgpt_res = str(res) if REV_EDGEGPT in reply_prefix: chatgpt_res = reply_prefix[REV_EDGEGPT] + chatgpt_res except BaseException as e: gu.log("NewBing请求错误, 原因: "+str(e), gu.LEVEL_ERROR) - send_message(platform, message, f"Rev NewBing API错误。原因如下:\n{str(e)} \n前往官方频道反馈~", msg_ref=msg_ref, gocq_loop=gocq_loop, qqchannel_bot=qqchannel_bot, gocq_bot=gocq_bot) + send_message(platform, message, f"Rev NewBing API错误。原因如下:\n{str(e)} \n前往官方频道反馈~", msg_ref=msg_ref) # 切换回原来的语言模型 if temp_switch != "": @@ -700,17 +719,22 @@ def oper_msg(message, if isinstance(command_result[1], list) and len(command_result) == 3 and command_result[2] == 'draw': if chatgpt != None: for i in command_result[1]: - send_message(platform, message, i, msg_ref=msg_ref, image=i, gocq_loop=gocq_loop, qqchannel_bot=qqchannel_bot, gocq_bot=gocq_bot) + # i is a link + # 保存到本地 + pic_res = requests.get(i, stream = True) + if pic_res.status_code == 200: + image = PILImage.open(io.BytesIO(pic_res.content)) + send_message(platform, message, [Image.fromFileSystem(gu.save_temp_img(image))], msg_ref=msg_ref) else: - send_message(platform, message, "画图指令需要启用OpenAI官方模型.", msg_ref=msg_ref, gocq_loop=gocq_loop, qqchannel_bot=qqchannel_bot, gocq_bot=gocq_bot) + send_message(platform, message, "画图指令需要启用OpenAI官方模型.", msg_ref=msg_ref) else: try: - send_message(platform, message, command_result[1], msg_ref=msg_ref, gocq_loop=gocq_loop, qqchannel_bot=qqchannel_bot, gocq_bot=gocq_bot) + send_message(platform, message, command_result[1], msg_ref=msg_ref) except BaseException as e: - send_message(platform, message, f"回复消息出错: {str(e)}", msg_ref=msg_ref, gocq_loop=gocq_loop, qqchannel_bot=qqchannel_bot, gocq_bot=gocq_bot) + send_message(platform, message, f"回复消息出错: {str(e)}", msg_ref=msg_ref) else: - send_message(platform, message, f"指令调用错误: \n{str(command_result[1])}", msg_ref=msg_ref, gocq_loop=gocq_loop, qqchannel_bot=qqchannel_bot, gocq_bot=gocq_bot) + send_message(platform, message, f"指令调用错误: \n{str(command_result[1])}", msg_ref=msg_ref) return @@ -729,18 +753,18 @@ def oper_msg(message, if baidu_judge != None: check, msg = baidu_judge.judge(chatgpt_res) if not check: - send_message(platform, message, f"你的提问得到的回复【百度内容审核】未通过,不予回复。\n\n{msg}", msg_ref=msg_ref, gocq_loop=gocq_loop, qqchannel_bot=qqchannel_bot, gocq_bot=gocq_bot) + send_message(platform, message, f"你的提问得到的回复【百度内容审核】未通过,不予回复。\n\n{msg}", msg_ref=msg_ref) return # 发送qq信息 try: if platform==PLATFORM_GOCQ: if cc.get("qq_pic_mode", False): - send_message(platform, message, chatgpt_res, image_mode=True, msg_ref=msg_ref, gocq_loop=gocq_loop, qqchannel_bot=qqchannel_bot, gocq_bot=gocq_bot) + send_message(platform, message, chatgpt_res, image_mode=True, msg_ref=msg_ref) else: - send_message(platform, message, chatgpt_res, msg_ref=msg_ref, gocq_loop=gocq_loop, qqchannel_bot=qqchannel_bot, gocq_bot=gocq_bot) + send_message(platform, message, chatgpt_res, msg_ref=msg_ref) else: - send_message(platform, message, chatgpt_res, msg_ref=msg_ref, gocq_loop=gocq_loop, qqchannel_bot=qqchannel_bot, gocq_bot=gocq_bot) + send_message(platform, message, chatgpt_res, msg_ref=msg_ref) except BaseException as e: gu.log("回复消息错误: \n"+str(e), gu.LEVEL_ERROR) @@ -811,9 +835,9 @@ class gocqClient(): @gocq_app.receiver("GroupMemberIncrease") async def _(app: CQHTTP, source: GroupMemberIncrease): - global nick_qq, cc + global nick_qq, announcement await app.sendGroupMessage(source.group_id, [ - Plain(text=cc.get("qq_welcome", "欢迎新人~")), + Plain(text = announcement), ]) @gocq_app.receiver("GuildMessage") diff --git a/main.py b/main.py index d870ce13a..245226d1f 100644 --- a/main.py +++ b/main.py @@ -18,6 +18,12 @@ def main(): os.environ['HTTP_PROXY'] = cfg['http_proxy'] if 'https_proxy' in cfg: os.environ['HTTPS_PROXY'] = cfg['https_proxy'] + + os.environ['NO_PROXY'] = 'cn.bing.com,https://api.sgroup.qq.com' + + # 检查temp文件夹 + if not os.path.exists(abs_path+"temp"): + os.mkdir(abs_path+"temp") provider = privider_chooser(cfg) if len(provider) == 0: diff --git a/model/command/command.py b/model/command/command.py index 1b4c97ed1..578c5ee39 100644 --- a/model/command/command.py +++ b/model/command/command.py @@ -41,7 +41,11 @@ class Command: except BaseException as e: raise e - def check_command(self, message, role, platform, message_obj, cached_plugins: dict, qq_platform: QQ): + def check_command(self, message, role, platform, + message_obj, + cached_plugins: dict, + qq_platform: QQ, + global_object: dict): # 插件 for k, v in cached_plugins.items(): @@ -62,9 +66,21 @@ class Command: return True, self.get_my_id(message_obj, platform) if self.command_start_with(message, "nconf") or self.command_start_with(message, "newconf"): return True, self.get_new_conf(message, role, platform) + if self.command_start_with(message, "web"): # 网页搜索 + return True, self.web_search(message, global_object) return False, None + def web_search(self, message, global_object): + if "web_search" not in global_object: + global_object["web_search"] = False + if message == "web on": + global_object["web_search"] = True + return True, "已开启网页搜索", "web" + elif message == "web off": + global_object["web_search"] = False + return True, "已关闭网页搜索", "web" + return True, f"网页搜索功能当前状态: {global_object['web_search']}", "web" def get_my_id(self, message_obj, platform): print(message_obj) if platform == "gocq": @@ -303,6 +319,7 @@ class Command: "reset": "重置会话", "nick": "设置机器人昵称", "plugin": "插件安装、卸载和重载", + "web on/off": "启动或关闭网页搜索能力", "/bing": "切换到bing模型", "/gpt": "切换到OpenAI ChatGPT API", "/revgpt": "切换到网页版ChatGPT", diff --git a/model/command/command_openai_official.py b/model/command/command_openai_official.py index 21e3e07f8..54fcdfe3b 100644 --- a/model/command/command_openai_official.py +++ b/model/command/command_openai_official.py @@ -6,9 +6,10 @@ from util import general_utils as gu class CommandOpenAIOfficial(Command): - def __init__(self, provider: ProviderOpenAIOfficial): + def __init__(self, provider: ProviderOpenAIOfficial, global_object: dict): self.provider = provider self.cached_plugins = {} + self.global_object = global_object def check_command(self, message: str, @@ -18,9 +19,12 @@ class CommandOpenAIOfficial(Command): platform: str, message_obj, cached_plugins: dict, - qq_platform: QQ): + qq_platform: QQ,): self.platform = platform - hit, res = super().check_command(message, role, platform, message_obj=message_obj, cached_plugins=cached_plugins, qq_platform=qq_platform) + hit, res = super().check_command(message, role, platform, message_obj=message_obj, + cached_plugins=cached_plugins, + qq_platform=qq_platform, + global_object=self.global_object) if hit: return True, res if self.command_start_with(message, "reset", "重置"): @@ -43,7 +47,7 @@ class CommandOpenAIOfficial(Command): return True, self.set(message, session_id) elif self.command_start_with(message, "update"): return True, self.update(message, role) - elif self.command_start_with(message, "画"): + elif self.command_start_with(message, "画", "draw"): return True, self.draw(message) elif self.command_start_with(message, "keyword"): return True, self.keyword(message, role) diff --git a/model/command/command_rev_chatgpt.py b/model/command/command_rev_chatgpt.py index 655ac2aa7..ea687141e 100644 --- a/model/command/command_rev_chatgpt.py +++ b/model/command/command_rev_chatgpt.py @@ -3,9 +3,10 @@ from model.provider.provider_rev_chatgpt import ProviderRevChatGPT from model.platform.qq import QQ class CommandRevChatGPT(Command): - def __init__(self, provider: ProviderRevChatGPT): + def __init__(self, provider: ProviderRevChatGPT, global_object: dict): self.provider = provider self.cached_plugins = {} + self.global_object = global_object def check_command(self, message: str, @@ -15,7 +16,10 @@ class CommandRevChatGPT(Command): cached_plugins: dict, qq_platform: QQ): self.platform = platform - hit, res = super().check_command(message, role, platform, message_obj=message_obj, cached_plugins=cached_plugins, qq_platform=qq_platform) + hit, res = super().check_command(message, role, platform, message_obj=message_obj, + cached_plugins=cached_plugins, + qq_platform=qq_platform, + global_object=self.global_object) if hit: return True, res if self.command_start_with(message, "help", "帮助"): diff --git a/model/command/command_rev_edgegpt.py b/model/command/command_rev_edgegpt.py index 466e799b9..eee02d7ba 100644 --- a/model/command/command_rev_edgegpt.py +++ b/model/command/command_rev_edgegpt.py @@ -4,11 +4,11 @@ import asyncio from model.platform.qq import QQ class CommandRevEdgeGPT(Command): - def __init__(self, provider: ProviderRevEdgeGPT): + def __init__(self, provider: ProviderRevEdgeGPT, global_object: dict): self.provider = provider self.cached_plugins = {} + self.global_object = global_object - def check_command(self, message: str, loop, @@ -18,7 +18,10 @@ class CommandRevEdgeGPT(Command): cached_plugins: dict, qq_platform: QQ): self.platform = platform - hit, res = super().check_command(message, role, platform, message_obj=message_obj, cached_plugins=cached_plugins, qq_platform=qq_platform) + hit, res = super().check_command(message, role, platform, message_obj=message_obj, + cached_plugins=cached_plugins, + qq_platform=qq_platform, + global_object=self.global_object) if hit: return True, res if self.command_start_with(message, "reset"): diff --git a/model/platform/qqchan.py b/model/platform/qqchan.py index 0d1a93a5f..aae08a5a8 100644 --- a/model/platform/qqchan.py +++ b/model/platform/qqchan.py @@ -1,12 +1,13 @@ import io import botpy -from PIL import Image +from PIL import Image as PILImage from botpy.message import Message, DirectMessage import re import asyncio import requests from cores.qqbot.personality import personalities from util import general_utils as gu +from nakuru.entities.components import Plain, At, Image class QQChan(): @@ -15,51 +16,57 @@ class QQChan(): self.client = botclient self.client.run(appid=appid, token=token) - def send_qq_msg(self, message, res, image_mode=False, msg_ref = None): - gu.log("回复QQ频道消息: "+str(res), level=gu.LEVEL_INFO, tag="QQ频道", max_len=30) + # gocq兼容层 + def gocq_compatible(self, gocq_message_chain: list): + plain_text = "" + image_path = None # only one img supported + for i in gocq_message_chain: + if isinstance(i, Plain): + plain_text += i.text + elif isinstance(i, Image) and image_path == None: + image_path = i.path + return plain_text, image_path - if not image_mode: - try: - if msg_ref is not None: - reply_res = asyncio.run_coroutine_threadsafe(message.reply(content=str(res), message_reference = msg_ref), self.client.loop) - else: - reply_res = asyncio.run_coroutine_threadsafe(message.reply(content=str(res)), self.client.loop) - reply_res.result() - except BaseException as e: - # 分割过长的消息 - if "msg over length" in str(e): - split_res = [] - split_res.append(res[:len(res)//2]) - split_res.append(res[len(res)//2:]) - for i in split_res: - if msg_ref is not None: - reply_res = asyncio.run_coroutine_threadsafe(message.reply(content=i, message_reference = msg_ref), self.client.loop) - else: - reply_res = asyncio.run_coroutine_threadsafe(message.reply(content=i), self.client.loop) - reply_res.result() - else: - # 发送qq信息 + + + def send_qq_msg(self, message: Message, res, msg_ref = None): + gu.log("回复QQ频道消息: "+str(res), level=gu.LEVEL_INFO, tag="QQ频道", max_len=500) + + plain_text = "" + image_path = None + if isinstance(res, list): + # 兼容gocq + plain_text, image_path = self.gocq_compatible(res) + elif isinstance(res, str): + plain_text = res + + print(plain_text, image_path) + + try: + reply_res = asyncio.run_coroutine_threadsafe(message.reply(content=str(plain_text), message_reference = msg_ref, file_image=image_path), self.client.loop) + reply_res.result() + except BaseException as e: + # 分割过长的消息 + if "msg over length" in str(e): + split_res = [] + split_res.append(plain_text[:len(plain_text)//2]) + split_res.append(plain_text[len(plain_text)//2:]) + for i in split_res: + reply_res = asyncio.run_coroutine_threadsafe(message.reply(content=str(i), message_reference = msg_ref, file_image=image_path), self.client.loop) + reply_res.result() + else: + # 发送qq信息 + try: + # 防止被qq频道过滤消息 + plain_text = plain_text.replace(".", " . ") + reply_res = asyncio.run_coroutine_threadsafe(message.reply(content=str(plain_text), message_reference = msg_ref, file_image=image_path), self.client.loop) + # 发送信息 + except BaseException as e: + print("QQ频道API错误: \n"+str(e)) try: - # 防止被qq频道过滤消息 - res = res.replace(".", " . ") - asyncio.run_coroutine_threadsafe(message.reply(content=res), self.client.loop).result() - # 发送信息 + reply_res = asyncio.run_coroutine_threadsafe(message.reply(content=str(str.join(" ", plain_text)), message_reference = msg_ref, file_image=image_path), self.client.loop) except BaseException as e: - print("QQ频道API错误: \n"+str(e)) - res = str.join(" ", res) - try: - asyncio.run_coroutine_threadsafe(message.reply(content=res), self.client.loop).result() - except BaseException as e: - # 如果还是不行则报出错误 - res = re.sub(r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '[被隐藏的链接]', str(e), flags=re.MULTILINE) - res = res.replace(".", "·") - asyncio.run_coroutine_threadsafe(message.reply(content=res), self.client.loop).result() - # send(message, f"QQ频道API错误:{str(e)}\n下面是格式化后的回答:\n{f_res}") - else: - pic_res = requests.get(str(res), stream=True) - if pic_res.status_code == 200: - # 将二进制数据转换成图片对象 - image = Image.open(io.BytesIO(pic_res.content)) - # 保存图片到本地 - image.save('tmp_image.jpg') - asyncio.run_coroutine_threadsafe(message.reply(file_image='tmp_image.jpg', content=""), self.client.loop) \ No newline at end of file + plain_text = re.sub(r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '[被隐藏的链接]', str(e), flags=re.MULTILINE) + plain_text = plain_text.replace(".", "·") + asyncio.run_coroutine_threadsafe(message.reply(content=plain_text), self.client.loop).result() + # send(message, f"QQ频道API错误:{str(e)}\n下面是格式化后的回答:\n{f_res}") \ No newline at end of file diff --git a/model/provider/provider_openai_official.py b/model/provider/provider_openai_official.py index 01e099fb3..fec7d67a6 100644 --- a/model/provider/provider_openai_official.py +++ b/model/provider/provider_openai_official.py @@ -90,7 +90,10 @@ class ProviderOpenAIOfficial(Provider): # 每隔10分钟转储一次 time.sleep(10*self.history_dump_interval) - def text_chat(self, prompt, session_id): + def text_chat(self, prompt, session_id = None): + if session_id is None: + session_id = "unknown" + del self.session_dict["unknown"] # 会话机制 if session_id not in self.session_dict: self.session_dict[session_id] = [] @@ -136,6 +139,8 @@ class ProviderOpenAIOfficial(Provider): gu.log("token超限, 清空对应缓存") self.session_dict[session_id] = [] cache_data_list, new_record, req = self.wrap(prompt, session_id) + elif 'Limit: 3 / min. Please try again in 20s.' in str(e): + time.sleep(60) else: gu.log(str(e), level=gu.LEVEL_ERROR) err = str(e) diff --git a/model/provider/provider_rev_chatgpt.py b/model/provider/provider_rev_chatgpt.py index c42474a4c..f54b74d34 100644 --- a/model/provider/provider_rev_chatgpt.py +++ b/model/provider/provider_rev_chatgpt.py @@ -3,6 +3,7 @@ from revChatGPT import typings from model.provider.provider import Provider from util import general_utils as gu from util import cmd_config as cc +import time class ProviderRevChatGPT(Provider): @@ -27,9 +28,10 @@ class ProviderRevChatGPT(Provider): rev_account_config['PUID'] = self.cc.get("rev_chatgpt_PUID") if len(self.cc.get("rev_chatgpt_unverified_plugin_domains")) > 0: rev_account_config['unverified_plugin_domains'] = self.cc.get("rev_chatgpt_unverified_plugin_domains") - + cb = Chatbot(config=rev_account_config) + # cb.captcha_solver = self.__captcha_solver revstat = { - 'obj': Chatbot(config=rev_account_config), + 'obj': cb, 'busy': False } self.rev_chatgpt.append(revstat) @@ -39,6 +41,14 @@ class ProviderRevChatGPT(Provider): def forget(self) -> bool: return False + # def __captcha_solver(images: list[str], challenge_details: dict) -> int: + # # Create tempfile + # print("Captcha solver called") + # print(images) + # print(challenge_details) + # input("Press Enter to continue...") + # return 0 + def request_text(self, prompt: str, bot) -> str: resp = '' err_count = 0 @@ -50,8 +60,6 @@ class ProviderRevChatGPT(Provider): resp = data["message"] break except typings.Error as e: - if e.code == typings.ErrorType.RATE_LIMIT_ERROR: - raise e if e.code == typings.ErrorType.INVALID_ACCESS_TOKEN_ERROR: raise e if e.code == typings.ErrorType.EXPIRED_ACCESS_TOKEN_ERROR: @@ -59,17 +67,25 @@ class ProviderRevChatGPT(Provider): if e.code == typings.ErrorType.PROHIBITED_CONCURRENT_QUERY_ERROR: raise e + if "The message you submitted was too long" in str(e): + raise e + if "You've reached our limit of messages per hour." in str(e): + raise e + if "Rate limited by proxy" in str(e): + gu.log(f"触发请求频率限制, 60秒后自动重试。", level=gu.LEVEL_WARNING, tag="RevChatGPT") + time.sleep(60) + err_count += 1 - gu.log(f"请求出现问题: {str(e)} | 正在重试: {str(err_count)}", level=gu.LEVEL_WARNING, tag="RevChatGPT") + gu.log(f"请求异常: {str(e)},正在重试。({str(err_count)})", level=gu.LEVEL_WARNING, tag="RevChatGPT") if err_count >= retry_count: raise e except BaseException as e: err_count += 1 - gu.log(f"请求出现问题: {str(e)} | 正在重试: {str(err_count)}", level=gu.LEVEL_WARNING, tag="RevChatGPT") + gu.log(f"请求异常: {str(e)},正在重试。({str(err_count)})", level=gu.LEVEL_WARNING, tag="RevChatGPT") if err_count >= retry_count: raise e if resp == '': - resp = "RevChatGPT出现故障." + resp = "RevChatGPT请求异常。" # print("[RevChatGPT] "+str(resp)) return resp @@ -95,8 +111,7 @@ class ProviderRevChatGPT(Provider): else: err_msg += f"账号{cursor} - 错误原因: 忙碌" continue - res = f'回复失败。错误跟踪:{err_msg}' - return res + raise Exception(f'回复失败。错误跟踪:{err_msg}') def is_all_busy(self) -> bool: for revstat in self.rev_chatgpt: diff --git a/util/func_call.py b/util/func_call.py new file mode 100644 index 000000000..fd90b125b --- /dev/null +++ b/util/func_call.py @@ -0,0 +1,214 @@ + +import json +import util.general_utils as gu + +class FuncCallJsonFormatError(Exception): + def __init__(self, msg): + self.msg = msg + + def __str__(self): + return self.msg + +class FuncNotFoundError(Exception): + def __init__(self, msg): + self.msg = msg + + def __str__(self): + return self.msg + +class FuncCall(): + def __init__(self, provider) -> None: + self.func_list = [] + self.provider = provider + + def add_func(self, name: str = None, func_args: list = None, desc: str = None, func_obj = None) -> None: + if name == None or func_args == None or desc == None or func_obj == None: + raise FuncCallJsonFormatError("name, func_args, desc must be provided.") + self._func = { + "name": name, + "args": func_args, + "description": desc, + "func_obj": func_obj, + } + self.func_list.append(self._func) + + def func_dump(self, intent: int = 2) -> str: + _l = [] + for f in self.func_list: + _l.append({ + "name": f["name"], + "args": f["args"], + "description": f["description"], + }) + + return json.dumps(_l, indent=intent, ensure_ascii=False) + + def func_call(self, question, func_definition, is_task = False, tasks = None, taskindex = -1, is_summary = True): + + funccall_prompt = """ +我正在实现function call功能,该功能旨在让你变成给定的问题到给定的函数的解析器(这意味着你不是创造函数)。 +下面会给你提供可能会用到函数的相关信息,和一个问题,你需要将其转换成给定的函数调用。 +- 你的返回信息只含json,且严格仿照以下内容(不含注释): +``` +{ + "res": string // 如果没有找到对应的函数,那么你可以在这里正常输出内容。如果有,这里是空字符串。 + "func_call": [ // 这是一个数组,里面包含了所有的函数调用,如果没有函数调用,那么这个数组是空数组。 + { + "res": string // 如果没有找到对应的函数,那么你可以在这里正常输出内容。如果有,这里是空字符串。 + "name": str, // 函数的名字 + "args_type": { + "arg1": str, // 函数的参数的类型 + "arg2": str, + ... + }, + "args": { + "arg1": any, // 函数的参数 + "arg2": any, + ... + } + }, + ... // 可能在这个问题中会有多个函数调用 + ], +} +``` +- 如果用户的要求较复杂,允许返回多个函数调用,但需保证这些函数调用的顺序正确。 +- 当问题没有提到给定的函数时,相当于提问方不打算使用function call功能,这时你可以在res中正常输出这个问题的回答(以AI的身份正常回答该问题,并将答案输出在res字段中,回答不要涉及到任何函数调用的内容,就只是正常讨论这个问题。) + +提供的函数是: + +""" + + prompt = f"{funccall_prompt}\n```\n{func_definition}\n```\n" + prompt += f""" +用户的提问是: +``` +{question} +``` +""" + + # if is_task: + # # task_prompt = f"\n任务列表为{str(tasks)}\n你目前进行到了任务{str(taskindex)}, **你不需要重新进行已经进行过的任务, 不要生成已经进行过的**" + # prompt += task_prompt + + # provider.forget() + + _c = 0 + while _c < 3: + try: + res = self.provider.text_chat(prompt) + if res.find('```') != -1: + res = res[res.find('```json') + 7: res.rfind('```')] + gu.log("REVGPT func_call json result", bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"]) + print(res) + res = json.loads(res) + break + except Exception as e: + _c += 1 + if _c == 3: + raise e + if "The message you submitted was too long" in str(e): + raise e + + invoke_func_res = "" + + if len(res["func_call"]) > 0: + task_list = res["func_call"] + + invoke_func_res_list = [] + + for res in task_list: + # 说明有函数调用 + func_name = res["name"] + # args_type = res["args_type"] + args = res["args"] + # 调用函数 + # func = eval(func_name) + func_target = None + for func in self.func_list: + if func["name"] == func_name: + func_target = func["func_obj"] + break + if func_target == None: + raise FuncNotFoundError(f"Request function {func_name} not found.") + t_res = str(func_target(**args)) + invoke_func_res += f"{func_name} 调用结果:\n```\n{t_res}\n```\n" + invoke_func_res_list.append(invoke_func_res) + gu.log(f"[FUNC| {func_name} invoked]", bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"]) + # print(str(t_res)) + + if is_summary: + + # 生成返回结果 + after_prompt = """ +函数返回以下内容:"""+invoke_func_res+""" +请以AI助手的身份结合返回的内容对用户提问做详细全面的回答。 +用户的提问是: +```""" + question + """``` +- 在res字段中,不要输出函数的返回值,也不要针对返回值的字段进行分析,也不要输出用户的提问,而是理解这一段返回的结果,并以AI助手的身份回答问题,只需要输出回答的内容,不需要在回答的前面加上身份词。 +- 你的返回信息必须只能是json,且需严格遵循以下内容(不含注释): +```json +{ + "res": string, // 回答的内容 + "func_call_again": bool // 如果函数返回的结果有错误或者问题,可将其设置为true,否则为false +} +``` +- 如果func_call_again为true,res请你设为空值,否则请你填写回答的内容。""" + + _c = 0 + while _c < 5: + try: + res = self.provider.text_chat(after_prompt) + # 截取```之间的内容 + gu.log("DEBUG BEGIN", bg=gu.BG_COLORS["yellow"], fg=gu.FG_COLORS["white"]) + print(res) + gu.log("DEBUG END", bg=gu.BG_COLORS["yellow"], fg=gu.FG_COLORS["white"]) + if res.find('```') != -1: + res = res[res.find('```json') + 7: res.rfind('```')] + gu.log("REVGPT after_func_call json result", bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"]) + after_prompt_res = res + after_prompt_res = json.loads(after_prompt_res) + break + except Exception as e: + _c += 1 + if _c == 5: + raise e + if "The message you submitted was too long" in str(e): + # 如果返回的内容太长了,那么就截取一部分 + invoke_func_res = invoke_func_res[:int(len(invoke_func_res) / 2)] + after_prompt = """ +函数返回以下内容:"""+invoke_func_res+""" +请以AI助手的身份结合返回的内容对用户提问做详细全面的回答。 +用户的提问是: +```""" + question + """``` +- 在res字段中,不要输出函数的返回值,也不要针对返回值的字段进行分析,也不要输出用户的提问,而是理解这一段返回的结果,并以AI助手的身份回答问题,只需要输出回答的内容,不需要在回答的前面加上身份词。 +- 你的返回信息必须只能是json,且需严格遵循以下内容(不含注释): +```json +{ + "res": string, // 回答的内容 + "func_call_again": bool // 如果函数返回的结果有错误或者问题,可将其设置为true,否则为false +} +``` +- 如果func_call_again为true,res请你设为空值,否则请你填写回答的内容。""" + else: + raise e + + if "func_call_again" in after_prompt_res and after_prompt_res["func_call_again"]: + # 如果需要重新调用函数 + # 重新调用函数 + gu.log("REVGPT func_call_again", bg=gu.BG_COLORS["purple"], fg=gu.FG_COLORS["white"]) + res = self.func_call(question, func_definition) + return res, True + + gu.log("REVGPT func callback:", bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"]) + # print(after_prompt_res["res"]) + return after_prompt_res["res"], True + else: + return str(invoke_func_res_list), True + else: + # print(res["res"]) + return res["res"], False + + + + + diff --git a/util/gplugin.py b/util/gplugin.py new file mode 100644 index 000000000..d902b34a9 --- /dev/null +++ b/util/gplugin.py @@ -0,0 +1,150 @@ +import requests +import util.general_utils as gu +from bs4 import BeautifulSoup +import time +from util.func_call import ( + FuncCall, + FuncCallJsonFormatError, + FuncNotFoundError +) +def tidy_text(text: str) -> str: + return text.strip().replace("\n", "").replace(" ", "").replace("\r", "") + +def special_fetch_zhihu(link: str) -> str: + headers = { + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \ + AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" + } + response = requests.get(link, headers=headers) + soup = BeautifulSoup(response.text, "html.parser") + r = soup.find(class_="List-item").find(class_="RichContent-inner") + if r is None: + print("debug: zhihu none") + raise Exception("zhihu none") + return tidy_text(r.text) + +def web_keyword_search_via_bing(keyword) -> str: + headers = { + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \ + AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" + } + url = "https://cn.bing.com/search?q="+keyword + _cnt = 0 + _detail_store = [] + while _cnt < 5: + try: + response = requests.get(url, headers=headers) + soup = BeautifulSoup(response.text, "html.parser") + res = [] + ols = soup.find(id="b_results") + for i in ols.find_all("li", class_="b_algo"): + try: + title = i.find("h2").text + desc = i.find("p").text + link = i.find("h2").find("a").get("href") + res.append({ + "title": title, + "desc": desc, + "link": link, + }) + if len(_detail_store) < 2 and "zhihu.com" in link: + try: + _detail_store.append(special_fetch_zhihu(link)[:800]) + except BaseException as e: + print(f"zhihu parse err: {str(e)}") + if len(res) >= 5: # 限制5条 + break + except Exception as e: + print(f"bing parse err: {str(e)}") + if len(res) == 0: + break + if len(_detail_store) > 0: + ret = f"{str(res)} \n来源知乎的具体资料: {str(_detail_store)}" + else: + ret = f"{str(res)}" + return str(ret) + except Exception as e: + print(f"bing fetch err: {str(e)}") + _cnt += 1 + time.sleep(1) + print("fail to fetch bing info, using sougou.") + return web_keyword_search_via_sougou(keyword) + +def web_keyword_search_via_sougou(keyword) -> str: + headers = { + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \ + AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36", + } + url = f"https://sogou.com/web?query={keyword}" + response = requests.get(url, headers=headers) + response.encoding = "utf-8" + soup = BeautifulSoup(response.text, "html.parser") + + res = [] + results = soup.find("div", class_="results") + for i in results.find_all("div", class_="vrwrap"): + try: + title = tidy_text(i.find("h3").text) + link = tidy_text(i.find("h3").find("a").get("href")) + if link.startswith("/link?url="): + link = "https://www.sogou.com" + link + res.append({ + "title": title, + "link": link, + }) + except: + pass + ret = f"{str(res)} \n全部内容: {tidy_text(soup.text)}" + return ret + +def fetch_website_content(url): + headers = { + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \ + AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" + } + response = requests.get(url, headers=headers) + soup = BeautifulSoup(response.text, "html.parser") + res = soup.text + res = res.replace("\n", "") + with open(f"temp_{time.time()}.html", "w", encoding="utf-8") as f: + f.write(res) + return res + +def web_search(question, provider): + + new_func_call = FuncCall(provider) + + new_func_call.add_func("web_keyword_search_via_bing", [{ + "type": "string", + "name": "keyword", + "brief": "必应搜索的关键词(分词,尽量保留所有信息)" + }], + "在必应搜索引擎上搜索给定的关键词,并且返回第一页的搜索结果列表(标题,简介和链接)", + web_keyword_search_via_bing + ) + + func_definition1 = new_func_call.func_dump() + question1 = f"{question} \n(只能调用一个函数。)" + res1, has_func = new_func_call.func_call(question1, func_definition1, is_task=False, is_summary=False) + has_func = True + if has_func: + provider.forget() + question3 = f"""请你回答`{question}`问题。\n以下是相关材料,你请直接拿此材料针对问题进行总结回答,然后再给出参考链接。不要提到任何函数调用的信息。```\n{res1}\n```\n""" + print(question3) + _c = 0 + while _c < 5: + try: + print('text chat') + res3 = provider.text_chat(question3) + break + except Exception as e: + print(e) + _c += 1 + if _c == 5: + raise e + if "The message you submitted was too long" in str(e): + res2 = res2[:int(len(res2) / 2)] + question3 = f"""请你回答`{question}`问题。\n以下是相关材料,请直接拿此材料针对问题进行回答,然后再给出参考链接。```\n{res1}\n{res2}\n```\n""" + return res3 + else: + return res1 \ No newline at end of file