diff --git a/cores/qqbot/core.py b/cores/qqbot/core.py index 144f512c6..7e3db3e81 100644 --- a/cores/qqbot/core.py +++ b/cores/qqbot/core.py @@ -1,33 +1,35 @@ import re -import json import threading import asyncio import time -import requests +import aiohttp import util.unfit_words as uw import os import sys -from addons.baidu_aip_judge import BaiduJudge +import io +import traceback + +import util.function_calling.gplugin as gplugin +import util.plugin_util as putil + +from PIL import Image as PILImage +from typing import Union from nakuru import ( GroupMessage, FriendMessage, GuildMessage, ) +from nakuru.entities.components import Plain, At, Image + +from addons.baidu_aip_judge import BaiduJudge from model.platform._nakuru_translation_layer import NakuruGuildMessage -from nakuru.entities.components import Plain,At,Image from model.provider.provider import Provider from model.command.command import Command from util import general_utils as gu from util.general_utils import Logger, upload, run_monitor from util.cmd_config import CmdConfig as cc from util.cmd_config import init_astrbot_config_items -import util.function_calling.gplugin as gplugin -import util.plugin_util as putil -from PIL import Image as PILImage -import io -import traceback from . global_object import GlobalObject -from typing import Union from addons.dashboard.helper import DashBoardHelper from addons.dashboard.server import DashBoardData from cores.database.conn import dbConn @@ -41,7 +43,7 @@ frequency_time = 60 frequency_count = 10 # 版本 -version = '3.1.5' +version = '3.1.6' # 语言模型 REV_CHATGPT = 'rev_chatgpt' @@ -325,7 +327,7 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak command_result = () # 调用指令返回的结果 # 统计数据,如频道消息量 - record_message(platform, session_id) + await record_message(platform, session_id) for i in message.message: if isinstance(i, Plain): @@ -334,8 +336,7 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak return MessageResult("Hi~") # 检查发言频率 - user_id = message.user_id - if not check_frequency(user_id): + if not check_frequency(message.user_id): return MessageResult(f'你的发言超过频率限制(╯▔皿▔)╯。\n管理员设置{frequency_time}秒内只能提问{frequency_count}次。') # 检查是否是更换语言模型的请求 @@ -359,7 +360,8 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak llm_result_str = "" - hit, command_result = llm_command_instance[chosen_provider].check_command( + # check commands and plugins + hit, command_result = await llm_command_instance[chosen_provider].check_command( message_str, session_id, role, @@ -375,11 +377,12 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak if matches: return MessageResult(f"你的提问得到的回复未通过【默认关键词拦截】服务, 不予回复。") if baidu_judge != None: - check, msg = baidu_judge.judge(message_str) + check, msg = await asyncio.to_thread(baidu_judge.judge, message_str) if not check: return MessageResult(f"你的提问得到的回复未通过【百度AI内容审核】服务, 不予回复。\n\n{msg}") if chosen_provider == NONE_LLM: - return MessageResult("没有启动任何 LLM 并且未触发任何指令。") + logger.log("一条消息由于 Bot 未启动任何语言模型并且未触发指令而将被忽略。", gu.LEVEL_WARNING) + return try: if llm_wake_prefix != "" and not message_str.startswith(llm_wake_prefix): return @@ -403,9 +406,9 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak if chosen_provider == REV_CHATGPT or chosen_provider == OPENAI_OFFICIAL: if _global_object.web_search or web_sch_flag: official_fc = chosen_provider == OPENAI_OFFICIAL - llm_result_str = gplugin.web_search(message_str, llm_instance[chosen_provider], session_id, official_fc) + llm_result_str = await gplugin.web_search(message_str, llm_instance[chosen_provider], session_id, official_fc) else: - llm_result_str = str(llm_instance[chosen_provider].text_chat(message_str, session_id, image_url, default_personality = _global_object.default_personality)) + llm_result_str = await llm_instance[chosen_provider].text_chat(message_str, session_id, image_url, default_personality = _global_object.default_personality) llm_result_str = _global_object.reply_prefix + llm_result_str except BaseException as e: @@ -416,9 +419,9 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak if temp_switch != "": chosen_provider = temp_switch - # 指令回复 if hit: - # 检查指令。command_result 是一个元组:(指令调用是否成功, 指令返回的文本结果, 指令类型) + # 有指令或者插件触发 + # command_result 是一个元组:(指令调用是否成功, 指令返回的文本结果, 指令类型) if command_result == None: return command = command_result[2] @@ -436,11 +439,11 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak if isinstance(command_result[1], list) and len(command_result) == 3 and command == 'draw': for i in command_result[1]: # 保存到本地 - pic_res = requests.get(i, stream = True) - if pic_res.status_code == 200: - image = PILImage.open(io.BytesIO(pic_res.content)) - return MessageResult([Image.fromFileSystem(gu.save_temp_img(image))]) - + async with aiohttp.ClientSession() as session: + async with session.get(i) as resp: + if resp.status == 200: + image = PILImage.open(io.BytesIO(await resp.read())) + return MessageResult([Image.fromFileSystem(gu.save_temp_img(image))]) # 其他指令 else: try: @@ -455,7 +458,7 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak llm_result_str = re.sub(i, "***", llm_result_str) # 百度内容审核服务二次审核 if baidu_judge != None: - check, msg = baidu_judge.judge(llm_result_str) + check, msg = await asyncio.to_thread(baidu_judge.judge, llm_result_str) if not check: return MessageResult(f"你的提问得到的回复【百度内容审核】未通过,不予回复。\n\n{msg}") # 发送信息 diff --git a/model/command/command.py b/model/command/command.py index 205babfc4..632adca38 100644 --- a/model/command/command.py +++ b/model/command/command.py @@ -1,17 +1,19 @@ import json -from util import general_utils as gu -import os -import requests -from model.provider.provider import Provider +import inspect +import aiohttp import json + import util.plugin_util as putil -from util.cmd_config import CmdConfig as cc -from util.general_utils import Logger import util.updator + from nakuru.entities.components import ( Plain, Image ) +from util import general_utils as gu +from model.provider.provider import Provider +from util.cmd_config import CmdConfig as cc +from util.general_utils import Logger from cores.qqbot.global_object import GlobalObject, AstrMessageEvent from cores.qqbot.global_object import CommandResult @@ -25,7 +27,7 @@ class Command: self.global_object = global_object self.logger: Logger = global_object.logger - def check_command(self, + async def check_command(self, message, session_id: str, role, @@ -51,7 +53,10 @@ class Command: if "type" in v["info"] and v["info"]["plugin_type"] == "platform": continue try: - result = v["clsobj"].run(ame) + if inspect.iscoroutinefunction(v["clsobj"].run): + result = await v["clsobj"].run(ame) + else: + result = v["clsobj"].run(ame) if isinstance(result, CommandResult): hit = result.hit res = result._result_tuple() @@ -65,13 +70,16 @@ class Command: except TypeError as e: # 参数不匹配,尝试使用旧的参数方案 try: - hit, res = v["clsobj"].run(message, role, platform, message_obj, self.global_object.platform_qq) + if inspect.iscoroutinefunction(v["clsobj"].run): + hit, res = await v["clsobj"].run(message, role, platform, message_obj, self.global_object.platform_qq) + else: + hit, res = v["clsobj"].run(message, role, platform, message_obj, self.global_object.platform_qq) if hit: return True, res except BaseException as e: - self.logger.log(f"{k}插件异常,原因: {str(e)}\n已安装插件: {cached_plugins.keys}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。", level=gu.LEVEL_WARNING) + self.logger.log(f"{k} 插件异常,原因: {str(e)}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。", level=gu.LEVEL_WARNING) except BaseException as e: - self.logger.log(f"{k} 插件异常,原因: {str(e)}\n已安装插件: {cached_plugins.keys}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。", level=gu.LEVEL_WARNING) + self.logger.log(f"{k} 插件异常,原因: {str(e)}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。", level=gu.LEVEL_WARNING) if self.command_start_with(message, "nick"): return True, self.set_nick(message, platform, role) @@ -79,18 +87,13 @@ class Command: return True, self.plugin_oper(message, role, cached_plugins, platform) if self.command_start_with(message, "myid") or self.command_start_with(message, "!myid"): return True, self.get_my_id(message_obj, platform) - if self.command_start_with(message, "nconf") or self.command_start_with(message, "newconf"): - return True, self.get_new_conf(message, role) if self.command_start_with(message, "web"): # 网页搜索 return True, self.web_search(message) - if self.command_start_with(message, "ip"): - ip = requests.get("https://myip.ipip.net", timeout=5).text - return True, f"机器人 IP 信息:{ip}", "ip" if not self.provider and self.command_start_with(message, "help"): - return True, self.help() + return True, await self.help() return False, None - + def web_search(self, message): l = message.split(' ') if len(l) == 1: @@ -202,10 +205,11 @@ class Command: "/revgpt": "切换到网页版ChatGPT", } - def help_messager(self, commands: dict, platform: str, cached_plugins: dict = None): + async def help_messager(self, commands: dict, platform: str, cached_plugins: dict = None): try: - resp = requests.get("https://soulter.top/channelbot/notice.json").text - notice = json.loads(resp)["notice"] + async with aiohttp.ClientSession() as session: + async with session.get("https://soulter.top/channelbot/notice.json") as resp: + notice = (await resp.json())["notice"] except BaseException as e: notice = "" msg = "# Help Center\n## 指令列表\n" @@ -279,9 +283,9 @@ class Command: def key(self): return False - def help(self): - return True, self.help_messager(self.general_commands(), self.platform, self.global_object.cached_plugins), "help" - + async def help(self): + ret = await self.help_messager(self.general_commands(), self.platform, self.global_object.cached_plugins) + return True, ret, "help" def status(self): return False diff --git a/model/command/openai_official.py b/model/command/openai_official.py index d73ab0ae9..9d632790f 100644 --- a/model/command/openai_official.py +++ b/model/command/openai_official.py @@ -11,7 +11,7 @@ class CommandOpenAIOfficial(Command): self.personality_str = "" super().__init__(provider, global_object) - def check_command(self, + async def check_command(self, message: str, session_id: str, role: str, @@ -20,7 +20,7 @@ class CommandOpenAIOfficial(Command): self.platform = platform # 检查基础指令 - hit, res = super().check_command( + hit, res = await super().check_command( message, session_id, role, @@ -32,7 +32,7 @@ class CommandOpenAIOfficial(Command): if hit: return True, res if self.command_start_with(message, "reset", "重置"): - return True, self.reset(session_id, message) + return True, await self.reset(session_id, message) elif self.command_start_with(message, "his", "历史"): return True, self.his(message, session_id) elif self.command_start_with(message, "token"): @@ -42,7 +42,7 @@ class CommandOpenAIOfficial(Command): elif self.command_start_with(message, "status"): return True, self.status() elif self.command_start_with(message, "help", "帮助"): - return True, self.help() + return True, await self.help() elif self.command_start_with(message, "unset"): return True, self.unset(session_id) elif self.command_start_with(message, "set"): @@ -54,11 +54,11 @@ class CommandOpenAIOfficial(Command): elif self.command_start_with(message, "key"): return True, self.key(message) elif self.command_start_with(message, "switch"): - return True, self.switch(message) + return True, await self.switch(message) return False, None - def help(self): + async def help(self): commands = super().general_commands() commands['画'] = '画画' commands['key'] = '添加OpenAI key' @@ -66,15 +66,15 @@ class CommandOpenAIOfficial(Command): commands['gpt'] = '查看gpt配置信息' commands['status'] = '查看key使用状态' commands['token'] = '查看本轮会话token' - return True, super().help_messager(commands, self.platform, self.global_object.cached_plugins), "help" + return True, await super().help_messager(commands, self.platform, self.global_object.cached_plugins), "help" - def reset(self, session_id: str, message: str = "reset"): + async def reset(self, session_id: str, message: str = "reset"): if self.provider is None: return False, "未启用 OpenAI 官方 API", "reset" l = message.split(" ") if len(l) == 1: - self.provider.forget(session_id) + await self.provider.forget(session_id) return True, "重置成功", "reset" if len(l) == 2 and l[1] == "p": self.provider.forget(session_id) @@ -146,7 +146,7 @@ class CommandOpenAIOfficial(Command): else: return True, "该Key被验证为无效。也许是输入错误了,或者重试。", "key" - def switch(self, message: str): + async def switch(self, message: str): ''' 切换账号 ''' @@ -168,7 +168,7 @@ class CommandOpenAIOfficial(Command): else: try: new_key = list(key_stat.keys())[index-1] - ret = self.provider.check_key(new_key) + ret = await self.provider.check_key(new_key) self.provider.set_key(new_key) except BaseException as e: return True, "账号切换失败,原因: " + str(e), "switch" diff --git a/model/command/rev_chatgpt.py b/model/command/rev_chatgpt.py index 43c96f7c1..1e7b91aa4 100644 --- a/model/command/rev_chatgpt.py +++ b/model/command/rev_chatgpt.py @@ -11,14 +11,14 @@ class CommandRevChatGPT(Command): self.personality_str = "" super().__init__(provider, global_object) - def check_command(self, + async def check_command(self, message: str, session_id: str, role: str, platform: str, message_obj): self.platform = platform - hit, res = super().check_command( + hit, res = await super().check_command( message, session_id, role, @@ -29,7 +29,7 @@ class CommandRevChatGPT(Command): if hit: return True, res if self.command_start_with(message, "help", "帮助"): - return True, self.help() + return True, await self.help() elif self.command_start_with(message, "reset"): return True, self.reset(session_id, message) elif self.command_start_with(message, "update"): @@ -127,7 +127,7 @@ class CommandRevChatGPT(Command): else: return True, "参数过多。", "switch" - def help(self): + async def help(self): commands = super().general_commands() commands['set'] = '设置人格' - return True, super().help_messager(commands, self.platform, self.global_object.cached_plugins), "help" + return True, await super().help_messager(commands, self.platform, self.global_object.cached_plugins), "help" diff --git a/model/platform/_platfrom.py b/model/platform/_platfrom.py index bc72ea68a..106f94df2 100644 --- a/model/platform/_platfrom.py +++ b/model/platform/_platfrom.py @@ -1,7 +1,5 @@ import abc -import threading -import asyncio -from typing import Callable, Union +from typing import Union from nakuru import ( GuildMessage, GroupMessage, @@ -70,14 +68,3 @@ class Platform(): pass ret.replace('\n', '') return ret - - - def new_sub_thread(self, func, args=()): - thread = threading.Thread(target=self._runner, args=(func, args), daemon=True) - thread.start() - - def _runner(self, func: Callable, args: tuple): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.run_until_complete(func(*args)) - loop.close() diff --git a/model/platform/qq_gocq.py b/model/platform/qq_gocq.py index ba48c2c96..6b13f3987 100644 --- a/model/platform/qq_gocq.py +++ b/model/platform/qq_gocq.py @@ -58,10 +58,10 @@ class QQGOCQ(Platform): async def _(app: CQHTTP, source: GroupMessage): if self.cc.get("gocq_react_group", True): if isinstance(source.message[0], Plain): - self.new_sub_thread(self.handle_msg, (source, True)) + await self.handle_msg(source, True) elif isinstance(source.message[0], At): if source.message[0].qq == source.self_id: - self.new_sub_thread(self.handle_msg, (source, True)) + await self.handle_msg(source, True) else: return @@ -69,7 +69,7 @@ class QQGOCQ(Platform): async def _(app: CQHTTP, source: FriendMessage): if self.cc.get("gocq_react_friend", True): if isinstance(source.message[0], Plain): - self.new_sub_thread(self.handle_msg, (source, False)) + await self.handle_msg(source, False) else: return @@ -85,19 +85,16 @@ class QQGOCQ(Platform): async def _(app: CQHTTP, source: Notify): print(source) if source.sub_type == "poke" and source.target_id == source.self_id: - # await self.handle_msg(source, False) - self.new_sub_thread(self.handle_msg, (source, False)) + await self.handle_msg(source, False) @gocq_app.receiver("GuildMessage") async def _(app: CQHTTP, source: GuildMessage): if self.cc.get("gocq_react_guild", True): if isinstance(source.message[0], Plain): - # await self.handle_msg(source, True) - self.new_sub_thread(self.handle_msg, (source, True)) + await self.handle_msg(source, True) elif isinstance(source.message[0], At): if source.message[0].qq == source.self_tiny_id: - # await self.handle_msg(source, True) - self.new_sub_thread(self.handle_msg, (source, True)) + await self.handle_msg(source, True) else: return @@ -157,7 +154,7 @@ class QQGOCQ(Platform): if message_result is None: return - self.reply_msg(message, message_result.result_message) + await self.reply_msg(message, message_result.result_message) if message_result.callback is not None: message_result.callback() @@ -165,11 +162,11 @@ class QQGOCQ(Platform): if session_id in self.waiting and self.waiting[session_id] == '': self.waiting[session_id] = message - def reply_msg(self, + async def reply_msg(self, message: Union[GroupMessage, FriendMessage, GuildMessage, Notify], result_message: list): """ - 插件开发者请使用send方法, 可以不用直接调用这个方法。 + 插件开发者请使用send方法, 可以不用直接调用这个方法。 """ source = message res = result_message @@ -205,12 +202,10 @@ class QQGOCQ(Platform): # 回复消息链 if isinstance(res, list) and len(res) > 0: if source.type == "GuildMessage": - # await self.client.sendGuildChannelMessage(source.guild_id, source.channel_id, res) - asyncio.run_coroutine_threadsafe(self.client.sendGuildChannelMessage(source.guild_id, source.channel_id, res), self.loop).result() + await self.client.sendGuildChannelMessage(source.guild_id, source.channel_id, res) return elif source.type == "FriendMessage": - # await self.client.sendFriendMessage(source.user_id, res) - asyncio.run_coroutine_threadsafe(self.client.sendFriendMessage(source.user_id, res), self.loop).result() + await self.client.sendFriendMessage(source.user_id, res) return elif source.type == "GroupMessage": # 过长时forward发送 @@ -233,37 +228,28 @@ class QQGOCQ(Platform): node.time = int(time.time()) # print(node) nodes=[node] - # await self.client.sendGroupForwardMessage(source.group_id, nodes) - asyncio.run_coroutine_threadsafe(self.client.sendGroupForwardMessage(source.group_id, nodes), self.loop).result() + await self.client.sendGroupForwardMessage(source.group_id, nodes) return - # await self.client.sendGroupMessage(source.group_id, res) - asyncio.run_coroutine_threadsafe(self.client.sendGroupMessage(source.group_id, res), self.loop).result() + await self.client.sendGroupMessage(source.group_id, res) return - def send_msg(self, message: Union[GroupMessage, FriendMessage, GuildMessage, Notify], result_message: list): + async def send_msg(self, message: Union[GroupMessage, FriendMessage, GuildMessage, Notify], result_message: list): ''' 提供给插件的发送QQ消息接口。 参数说明:第一个参数可以是消息对象,也可以是QQ群号。第二个参数是消息内容(消息内容可以是消息链列表,也可以是纯文字信息)。 - 非异步 ''' try: - # await self.reply_msg(message, result_message) - self.reply_msg(message, result_message) + await self.reply_msg(message, result_message) except BaseException as e: raise e - def send(self, + async def send(self, to, res): ''' 同 send_msg() - 非异步 ''' - try: - # await self.send_msg(to, res) - self.reply_msg(to, res) - except BaseException as e: - raise e + await self.reply_msg(to, res) def create_text_image(title: str, text: str, max_width=30, font_size=20): ''' @@ -302,12 +288,12 @@ class QQGOCQ(Platform): def get_client(self): return self.client - def nakuru_method_invoker(self, func, *args, **kwargs): + async def nakuru_method_invoker(self, func, *args, **kwargs): """ 返回一个方法调用器,可以用来立即调用nakuru的方法。 """ try: - ret = asyncio.run_coroutine_threadsafe(func(*args, **kwargs), self.loop).result() + ret = func(*args, **kwargs) return ret except BaseException as e: raise e diff --git a/model/platform/qq_official.py b/model/platform/qq_official.py index 389b4cb71..5f1b1f2a7 100644 --- a/model/platform/qq_official.py +++ b/model/platform/qq_official.py @@ -4,7 +4,7 @@ from PIL import Image as PILImage from botpy.message import Message, DirectMessage import re import asyncio -import requests +import aiohttp from util import general_utils as gu from botpy.types.message import Reference @@ -15,7 +15,7 @@ from ._nakuru_translation_layer import( NakuruGuildMessage, gocq_compatible_receive, gocq_compatible_send -) +) from typing import Union # QQ 机器人官方框架 @@ -27,13 +27,13 @@ class botClient(Client): async def on_at_message_create(self, message: Message): # 转换层 nakuru_guild_message = gocq_compatible_receive(message) - self.platform.new_sub_thread(self.platform.handle_msg, (nakuru_guild_message, True)) + await self.platform.handle_msg(nakuru_guild_message, True) # 收到私聊消息 async def on_direct_message_create(self, message: DirectMessage): # 转换层 nakuru_guild_message = gocq_compatible_receive(message) - self.platform.new_sub_thread(self.platform.handle_msg, (nakuru_guild_message, False)) + await self.platform.handle_msg(nakuru_guild_message, False) class QQOfficial(Platform): @@ -107,7 +107,7 @@ class QQOfficial(Platform): if message_result is None: return - self.reply_msg(is_group, message, message_result.result_message) + await self.reply_msg(is_group, message, message_result.result_message) if message_result.callback is not None: message_result.callback() @@ -115,7 +115,7 @@ class QQOfficial(Platform): if session_id in self.waiting and self.waiting[session_id] == '': self.waiting[session_id] = message - def reply_msg(self, + async def reply_msg(self, is_group: bool, message: NakuruGuildMessage, res: Union[str, list]): @@ -148,10 +148,11 @@ class QQOfficial(Platform): if image_path is not None and image_path != '': msg_ref = None if image_path.startswith("http"): - pic_res = requests.get(image_path, stream = True) - if pic_res.status_code == 200: - image = PILImage.open(io.BytesIO(pic_res.content)) - image_path = gu.save_temp_img(image) + async with aiohttp.ClientSession() as session: + async with session.get(image_path) as response: + if response.status == 200: + image = PILImage.open(io.BytesIO(await response.read())) + image_path = gu.save_temp_img(image) if message.raw_message is not None and image_path == '': # file_image与message_reference不能同时传入 msg_ref = Reference(message_id=message.raw_message.id, ignore_get_message_error=False) @@ -170,8 +171,7 @@ class QQOfficial(Platform): data['file_image'] = image_path try: - # await self._send_wrapper(**data) - self._send_wrapper(**data) + await self._send_wrapper(**data) except BaseException as e: print(e) # 分割过长的消息 @@ -181,51 +181,44 @@ class QQOfficial(Platform): split_res.append(plain_text[len(plain_text)//2:]) for i in split_res: data['content'] = i - # await self._send_wrapper(**data) - self._send_wrapper(**data) + await self._send_wrapper(**data) else: # 发送qq信息 try: # 防止被qq频道过滤消息 plain_text = plain_text.replace(".", " . ") - # await self._send_wrapper(**data) - self._send_wrapper(**data) + await self._send_wrapper(**data) except BaseException as e: try: data['content'] = str.join(" ", plain_text) - # await self._send_wrapper(**data) - self._send_wrapper(**data) + await self._send_wrapper(**data) except BaseException as e: plain_text = re.sub(r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '[被隐藏的链接]', str(e), flags=re.MULTILINE) plain_text = plain_text.replace(".", "·") data['content'] = plain_text - # await self._send_wrapper(**data) - self._send_wrapper(**data) + await self._send_wrapper(**data) - def _send_wrapper(self, **kwargs): + async def _send_wrapper(self, **kwargs): if 'channel_id' in kwargs: - asyncio.run_coroutine_threadsafe(self.client.api.post_message(**kwargs), self.loop).result() + await self.client.api.post_message(**kwargs) else: - asyncio.run_coroutine_threadsafe(self.client.api.post_dms(**kwargs), self.loop).result() + await self.client.api.post_dms(**kwargs) - - def send_msg(self, channel_id: int, message_chain: list, message_id: int = None): + async def send_msg(self, channel_id: int, message_chain: list, message_id: int = None): ''' - 推送消息, 如果有 message_id,那么就是回复消息。非异步。 + 推送消息, 如果有 message_id,那么就是回复消息。 ''' _n = NakuruGuildMessage() _n.channel_id = channel_id _n.message_id = message_id - # await self.reply_msg(_n, message_chain) - self.reply_msg(_n, message_chain) + await self.reply_msg(_n, message_chain) - def send(self, message_obj, message_chain: list): + async def send(self, message_obj, message_chain: list): ''' - 发送信息。内容同 reply_msg。非异步。 + 发送信息。内容同 reply_msg。 ''' - # await self.reply_msg(message_obj, message_chain) - self.reply_msg(message_obj, message_chain) + await self.reply_msg(message_obj, message_chain) def wait_for_message(self, channel_id: int) -> NakuruGuildMessage: ''' diff --git a/model/provider/openai_official.py b/model/provider/openai_official.py index 9f9847d2b..6ce008785 100644 --- a/model/provider/openai_official.py +++ b/model/provider/openai_official.py @@ -1,18 +1,21 @@ -from openai import OpenAI -from openai.types.chat.chat_completion import ChatCompletion -from openai.types.images_response import ImagesResponse -import json -import time import os import sys +import json +import time +import tiktoken +import threading +import traceback + +from openai import AsyncOpenAI +from openai.types.images_response import ImagesResponse +from openai.types.chat.chat_completion import ChatCompletion + from cores.database.conn import dbConn from model.provider.provider import Provider -import threading from util import general_utils as gu from util.cmd_config import CmdConfig from util.general_utils import Logger -import traceback -import tiktoken + abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/' @@ -42,7 +45,7 @@ class ProviderOpenAIOfficial(Provider): self.logger.log(f"设置 api_base 为: {self.api_base}", tag="OpenAI") # 创建 OpenAI Client - self.client = OpenAI( + self.client = AsyncOpenAI( api_key=self.key_list[0], base_url=self.api_base ) @@ -113,7 +116,7 @@ class ProviderOpenAIOfficial(Provider): } self.session_dict[session_id].append(new_record) - def text_chat(self, prompt, + async def text_chat(self, prompt, session_id = None, image_url = None, function_call=None, @@ -132,7 +135,6 @@ class ProviderOpenAIOfficial(Provider): if default_personality is not None: self.personality_set(default_personality, session_id) - # 使用 tictoken 截断消息 _encoded_prompt = self.enc.encode(prompt) if self.openai_model_configs['max_tokens'] < len(_encoded_prompt): @@ -140,8 +142,8 @@ class ProviderOpenAIOfficial(Provider): self.logger.log(f"注意,有一部分 prompt 文本由于超出 token 限制而被截断。", level=gu.LEVEL_WARNING, tag="OpenAI") cache_data_list, new_record, req = self.wrap(prompt, session_id, image_url) - self.logger.log(f"CACHE_DATA_: {str(cache_data_list)}", level=gu.LEVEL_DEBUG, tag="OpenAI") - self.logger.log(f"OPENAI REQUEST: {str(req)}", level=gu.LEVEL_DEBUG, tag="OpenAI") + self.logger.log(f"cache: {str(cache_data_list)}", level=gu.LEVEL_DEBUG, tag="OpenAI") + self.logger.log(f"request: {str(req)}", level=gu.LEVEL_DEBUG, tag="OpenAI") retry = 0 response = None err = '' @@ -168,19 +170,19 @@ class ProviderOpenAIOfficial(Provider): while retry < 10: try: if function_call is None: - response = self.client.chat.completions.create( + response = await self.client.chat.completions.create( messages=req, **conf ) else: - response = self.client.chat.completions.create( + response = await self.client.chat.completions.create( messages=req, tools = function_call, **conf ) break except Exception as e: - print(traceback.format_exc()) + traceback.print_exc() if 'Invalid content type. image_url is only supported by certain models.' in str(e): raise e if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e): @@ -188,7 +190,6 @@ class ProviderOpenAIOfficial(Provider): self.key_stat[self.client.api_key]['exceed'] = True is_switched = self.handle_switch_key() if not is_switched: - # 所有Key都超额或不正常 raise e retry -= 1 elif 'maximum context length' in str(e): @@ -239,7 +240,6 @@ class ProviderOpenAIOfficial(Provider): index += 1 # 删除完后更新相关字段 self.session_dict[session_id] = cache_data_list - # cache_prompt = get_prompts_by_cache_list(cache_data_list) # 添加新条目进入缓存的prompt new_record['AI'] = { @@ -258,7 +258,7 @@ class ProviderOpenAIOfficial(Provider): return chatgpt_res - def image_chat(self, prompt, img_num = 1, img_size = "1024x1024"): + async def image_chat(self, prompt, img_num = 1, img_size = "1024x1024"): retry = 0 image_url = '' @@ -266,7 +266,7 @@ class ProviderOpenAIOfficial(Provider): while retry < 5: try: - response: ImagesResponse = self.client.images.generate( + response: ImagesResponse = await self.client.images.generate( prompt=prompt, **image_generate_configs ) @@ -282,7 +282,6 @@ class ProviderOpenAIOfficial(Provider): self.key_stat[self.client.api_key]['exceed'] = True is_switched = self.handle_switch_key() if not is_switched: - # 所有Key都超额或不正常 raise e elif 'Your request was rejected as a result of our safety system.' in str(e): self.logger.log("您的请求被 OpenAI 安全系统拒绝, 请稍后再试", level=gu.LEVEL_WARNING, tag="OpenAI") @@ -294,16 +293,16 @@ class ProviderOpenAIOfficial(Provider): return image_url - def forget(self, session_id = None) -> bool: + async def forget(self, session_id = None) -> bool: if session_id is None: return False self.session_dict[session_id] = [] return True - ''' - 获取缓存的会话 - ''' def get_prompts_by_cache_list(self, cache_data_list, divide=False, paging=False, size=5, page=1): + ''' + 获取缓存的会话 + ''' prompts = "" if paging: page_begin = (page-1)*size @@ -320,15 +319,7 @@ class ProviderOpenAIOfficial(Provider): if divide: prompts += "----------\n" return prompts - - - def get_user_usage_tokens(self,cache_list): - usage_tokens = 0 - for item in cache_list: - usage_tokens += int(item['single_tokens']) - return usage_tokens - - # 包装信息 + def wrap(self, prompt, session_id, image_url = None): if image_url is not None: prompt = [ @@ -364,7 +355,6 @@ class ProviderOpenAIOfficial(Provider): return context, new_record, req_list def handle_switch_key(self): - # messages = [{"role": "user", "content": prompt}] is_all_exceed = True for key in self.key_stat: if key == None or self.key_stat[key]['exceed']: @@ -399,13 +389,13 @@ class ProviderOpenAIOfficial(Provider): self.key_stat[key] = {'exceed': False, 'used': 0, 'sponsor': sponsor} # 检查key是否可用 - def check_key(self, key): - client_ = OpenAI( + async def check_key(self, key): + client_ = AsyncOpenAI( api_key=key, base_url=self.api_base ) messages = [{"role": "user", "content": "please just echo `test`"}] - client_.chat.completions.create( + await client_.chat.completions.create( messages=messages, **self.openai_model_configs ) diff --git a/model/provider/provider.py b/model/provider/provider.py index a3a77867c..6f4157e6e 100644 --- a/model/provider/provider.py +++ b/model/provider/provider.py @@ -5,9 +5,9 @@ class Provider: pass @abc.abstractmethod - def text_chat(self, prompt, session_id, image_url: None, function_call: None, extra_conf: dict = None, default_personality: dict = None) -> str: + async def text_chat(self, prompt, session_id, image_url: None, function_call: None, extra_conf: dict = None, default_personality: dict = None) -> str: pass @abc.abstractmethod - def forget(self, session_id = None) -> bool: + async def forget(self, session_id = None) -> bool: pass \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 0f00413d5..6dd7c7c44 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ pydantic~=1.10.4 -requests~=2.28.1 +aiohttp +requests openai~=1.2.3 qq-botpy chardet~=5.1.0 diff --git a/util/function_calling/gplugin.py b/util/function_calling/gplugin.py index 428410c3d..3e62511af 100644 --- a/util/function_calling/gplugin.py +++ b/util/function_calling/gplugin.py @@ -1,18 +1,19 @@ import requests import util.general_utils as gu -from bs4 import BeautifulSoup +import traceback import time +import json +import asyncio +from googlesearch import search, SearchResult +from readability import Document +from bs4 import BeautifulSoup +from openai.types.chat.chat_completion_message_tool_call import Function from util.function_calling.func_call import ( FuncCall, FuncCallJsonFormatError, FuncNotFoundError ) -from openai.types.chat.chat_completion_message_tool_call import Function -import traceback -from googlesearch import search, SearchResult from model.provider.provider import Provider -import json -from readability import Document def tidy_text(text: str) -> str: @@ -53,11 +54,11 @@ def google_web_search(keyword) -> str: for i in ls: desc = i.description try: - gu.log(f"搜索网页: {i.url}", tag="网页搜索", level=gu.LEVEL_INFO) + # gu.log(f"搜索网页: {i.url}", tag="网页搜索", level=gu.LEVEL_INFO) desc = fetch_website_content(i.url) except BaseException as e: print(f"(google) fetch_website_content err: {str(e)}") - gu.log(f"# No.{str(index)}\ntitle: {i.title}\nurl: {i.url}\ncontent: {desc}\n\n", level=gu.LEVEL_DEBUG, max_len=9999) + # gu.log(f"# No.{str(index)}\ntitle: {i.title}\nurl: {i.url}\ncontent: {desc}\n\n", level=gu.LEVEL_DEBUG, max_len=9999) ret += f"# No.{str(index)}\ntitle: {i.title}\nurl: {i.url}\ncontent: {desc}\n\n" index += 1 except Exception as e: @@ -80,7 +81,7 @@ def web_keyword_search_via_bing(keyword) -> str: try: response = requests.get(url, headers=headers) response.encoding = "utf-8" - gu.log(f"bing response: {response.text}", tag="bing", level=gu.LEVEL_DEBUG, max_len=9999) + # gu.log(f"bing response: {response.text}", tag="bing", level=gu.LEVEL_DEBUG, max_len=9999) soup = BeautifulSoup(response.text, "html.parser") res = "" result_cnt = 0 @@ -96,7 +97,7 @@ def web_keyword_search_via_bing(keyword) -> str: # "link": link, # }) try: - gu.log(f"搜索网页: {link}", tag="网页搜索", level=gu.LEVEL_INFO) + # gu.log(f"搜索网页: {link}", tag="网页搜索", level=gu.LEVEL_INFO) desc = fetch_website_content(link) except BaseException as e: print(f"(bing) fetch_website_content err: {str(e)}") @@ -124,11 +125,11 @@ def web_keyword_search_via_bing(keyword) -> str: if result_cnt == 0: break return res except Exception as e: - gu.log(f"bing fetch err: {str(e)}") + # gu.log(f"bing fetch err: {str(e)}") _cnt += 1 time.sleep(1) - gu.log("fail to fetch bing info, using sougou.") + # gu.log("fail to fetch bing info, using sougou.") return web_keyword_search_via_sougou(keyword) def web_keyword_search_via_sougou(keyword) -> str: @@ -157,7 +158,7 @@ def web_keyword_search_via_sougou(keyword) -> str: break except Exception as e: pass - gu.log(f"sougou parse err: {str(e)}", tag="web_keyword_search_via_sougou", level=gu.LEVEL_ERROR) + # gu.log(f"sougou parse err: {str(e)}", tag="web_keyword_search_via_sougou", level=gu.LEVEL_ERROR) # 爬取网页内容 _detail_store = [] for i in res: @@ -173,7 +174,7 @@ def web_keyword_search_via_sougou(keyword) -> str: return ret def fetch_website_content(url): - gu.log(f"fetch_website_content: {url}", tag="fetch_website_content", level=gu.LEVEL_DEBUG) + # gu.log(f"fetch_website_content: {url}", tag="fetch_website_content", level=gu.LEVEL_DEBUG) headers = { "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \ AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" @@ -187,7 +188,7 @@ def fetch_website_content(url): ret = tidy_text(soup.get_text()) return ret -def web_search(question, provider: Provider, session_id, official_fc=False): +async def web_search(question, provider: Provider, session_id, official_fc=False): ''' official_fc: 使用官方 function-calling ''' @@ -197,7 +198,7 @@ def web_search(question, provider: Provider, session_id, official_fc=False): "name": "keyword", "description": "google search query (分词,尽量保留所有信息)" }], - "通过搜索引擎搜索。如果问题需要在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。", + "通过搜索引擎搜索。如果问题需要获取近期、实时的消息,在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。", web_keyword_search_via_bing ) new_func_call.add_func("fetch_website_content", [{ @@ -205,16 +206,16 @@ def web_search(question, provider: Provider, session_id, official_fc=False): "name": "url", "description": "网址" }], - "获取网页的内容。如果问题带有合法的网页链接(例如: `帮我总结一下https://github.com的内容`), 就调用此函数。如果没有,不要调用此函数。", + "获取网页的内容。如果问题带有合法的网页链接(例如: `帮我总结一下 https://github.com 的内容`), 就调用此函数。如果没有,不要调用此函数。", fetch_website_content ) question1 = f"{question} \n> hint: 最多只能调用1个function, 并且存在不会调用任何function的可能性。" has_func = False function_invoked_ret = "" if official_fc: - func = provider.text_chat(question1, session_id, function_call=new_func_call.get_func()) + # we use official function-calling + func = await provider.text_chat(question1, session_id, function_call=new_func_call.get_func()) if isinstance(func, Function): - # arguments='{\n "keyword": "北京今天的天气"\n}', name='google_web_search' # 执行对应的结果: func_obj = None for i in new_func_call.func_list: @@ -222,49 +223,68 @@ def web_search(question, provider: Provider, session_id, official_fc=False): func_obj = i["func_obj"] break if not func_obj: - gu.log("找不到返回的 func name " + func.name, level=gu.LEVEL_ERROR) - return provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)" + # gu.log("找不到返回的 func name " + func.name, level=gu.LEVEL_ERROR) + return await provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)" try: args = json.loads(func.arguments) - function_invoked_ret = func_obj(**args) + # we use to_thread to avoid blocking the event loop + function_invoked_ret = await asyncio.to_thread(func_obj, **args) has_func = True except BaseException as e: traceback.print_exc() - return provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)" + return await provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)" else: # now func is a string return func else: + # we use our own function-calling try: - function_invoked_ret, has_func = new_func_call.func_call(question1, new_func_call.func_dump(), is_task=False, is_summary=False) + args = { + 'question': question1, + 'func_definition': new_func_call.func_dump(), + 'is_task': False, + 'is_summary': False, + } + function_invoked_ret, has_func = await asyncio.to_thread(new_func_call.func_call, **args) except BaseException as e: - res = provider.text_chat(question) + "\n(网页搜索失败, 此为默认回复)" + res = await provider.text_chat(question) + "\n(网页搜索失败, 此为默认回复)" return res has_func = True if has_func: - provider.forget(session_id) + await provider.forget(session_id) question3 = f""" -以下是相关材料,你的任务是: -1. 根据材料对问题`{question}`做切题的总结回答; -2. 发表你对这个问题的看法. +你的任务是: +1. 根据末尾的材料对问题`{question}`做切题的总结(详细); +2. 简单地发表你对这个问题的看法(简略)。 你的总结末尾应当有对材料的引用, 如果有链接, 请在末尾附上引用网页链接。引用格式严格按照 `\n[1] title url \n`。 -不要提到任何函数调用的信息。以下是相关材料: +不要提到任何函数调用的信息。 + +一些回复的消息模板: +模板1: +``` +从网上的信息来看,可以知道...我个人认为...你觉得呢? +``` +模板2: +``` +根据网上的最新信息,可以得知...我觉得...你怎么看? +``` +你可以根据这些模板来组织回答,但可以不照搬,要根据问题的内容来回答。 + +以下是相关材料: """ - - gu.log(f"web_search: {question3}", tag="web_search", level=gu.LEVEL_DEBUG, max_len=99999) _c = 0 while _c < 3: try: print('text chat') - final_ret = provider.text_chat(question3 + "```" + function_invoked_ret + "```", session_id) + final_ret = await provider.text_chat(question3 + "```" + function_invoked_ret + "```", session_id) return final_ret except Exception as e: print(e) _c += 1 if _c == 3: raise e if "The message you submitted was too long" in str(e): - provider.forget(session_id) + await provider.forget(session_id) function_invoked_ret = function_invoked_ret[:int(len(function_invoked_ret) / 2)] time.sleep(3) return function_invoked_ret diff --git a/util/general_utils.py b/util/general_utils.py index 3500edac3..d24cecbcf 100644 --- a/util/general_utils.py +++ b/util/general_utils.py @@ -9,7 +9,6 @@ from util.cmd_config import CmdConfig import socket from cores.qqbot.global_object import GlobalObject import platform -import requests import logging import json import sys