feat: 异步重写

perf: 优化网页搜索回答规范
This commit is contained in:
Soulter
2024-03-03 18:54:50 +08:00
parent 53ef3bbf4f
commit 1cd1c8ea0d
12 changed files with 205 additions and 222 deletions
+30 -27
View File
@@ -1,33 +1,35 @@
import re
import json
import threading
import asyncio
import time
import requests
import aiohttp
import util.unfit_words as uw
import os
import sys
from addons.baidu_aip_judge import BaiduJudge
import io
import traceback
import util.function_calling.gplugin as gplugin
import util.plugin_util as putil
from PIL import Image as PILImage
from typing import Union
from nakuru import (
GroupMessage,
FriendMessage,
GuildMessage,
)
from nakuru.entities.components import Plain, At, Image
from addons.baidu_aip_judge import BaiduJudge
from model.platform._nakuru_translation_layer import NakuruGuildMessage
from nakuru.entities.components import Plain,At,Image
from model.provider.provider import Provider
from model.command.command import Command
from util import general_utils as gu
from util.general_utils import Logger, upload, run_monitor
from util.cmd_config import CmdConfig as cc
from util.cmd_config import init_astrbot_config_items
import util.function_calling.gplugin as gplugin
import util.plugin_util as putil
from PIL import Image as PILImage
import io
import traceback
from . global_object import GlobalObject
from typing import Union
from addons.dashboard.helper import DashBoardHelper
from addons.dashboard.server import DashBoardData
from cores.database.conn import dbConn
@@ -41,7 +43,7 @@ frequency_time = 60
frequency_count = 10
# 版本
version = '3.1.5'
version = '3.1.6'
# 语言模型
REV_CHATGPT = 'rev_chatgpt'
@@ -325,7 +327,7 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
command_result = () # 调用指令返回的结果
# 统计数据,如频道消息量
record_message(platform, session_id)
await record_message(platform, session_id)
for i in message.message:
if isinstance(i, Plain):
@@ -334,8 +336,7 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
return MessageResult("Hi~")
# 检查发言频率
user_id = message.user_id
if not check_frequency(user_id):
if not check_frequency(message.user_id):
return MessageResult(f'你的发言超过频率限制(╯▔皿▔)╯。\n管理员设置{frequency_time}秒内只能提问{frequency_count}次。')
# 检查是否是更换语言模型的请求
@@ -359,7 +360,8 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
llm_result_str = ""
hit, command_result = llm_command_instance[chosen_provider].check_command(
# check commands and plugins
hit, command_result = await llm_command_instance[chosen_provider].check_command(
message_str,
session_id,
role,
@@ -375,11 +377,12 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
if matches:
return MessageResult(f"你的提问得到的回复未通过【默认关键词拦截】服务, 不予回复。")
if baidu_judge != None:
check, msg = baidu_judge.judge(message_str)
check, msg = await asyncio.to_thread(baidu_judge.judge, message_str)
if not check:
return MessageResult(f"你的提问得到的回复未通过【百度AI内容审核】服务, 不予回复。\n\n{msg}")
if chosen_provider == NONE_LLM:
return MessageResult("没有启动任何 LLM 并且未触发任何指令。")
logger.log("一条消息由于 Bot 未启动任何语言模型并且未触发指令而将被忽略。", gu.LEVEL_WARNING)
return
try:
if llm_wake_prefix != "" and not message_str.startswith(llm_wake_prefix):
return
@@ -403,9 +406,9 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
if chosen_provider == REV_CHATGPT or chosen_provider == OPENAI_OFFICIAL:
if _global_object.web_search or web_sch_flag:
official_fc = chosen_provider == OPENAI_OFFICIAL
llm_result_str = gplugin.web_search(message_str, llm_instance[chosen_provider], session_id, official_fc)
llm_result_str = await gplugin.web_search(message_str, llm_instance[chosen_provider], session_id, official_fc)
else:
llm_result_str = str(llm_instance[chosen_provider].text_chat(message_str, session_id, image_url, default_personality = _global_object.default_personality))
llm_result_str = await llm_instance[chosen_provider].text_chat(message_str, session_id, image_url, default_personality = _global_object.default_personality)
llm_result_str = _global_object.reply_prefix + llm_result_str
except BaseException as e:
@@ -416,9 +419,9 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
if temp_switch != "":
chosen_provider = temp_switch
# 指令回复
if hit:
# 检查指令。command_result 是一个元组:(指令调用是否成功, 指令返回的文本结果, 指令类型)
# 有指令或者插件触发
# command_result 是一个元组:(指令调用是否成功, 指令返回的文本结果, 指令类型)
if command_result == None:
return
command = command_result[2]
@@ -436,11 +439,11 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
if isinstance(command_result[1], list) and len(command_result) == 3 and command == 'draw':
for i in command_result[1]:
# 保存到本地
pic_res = requests.get(i, stream = True)
if pic_res.status_code == 200:
image = PILImage.open(io.BytesIO(pic_res.content))
return MessageResult([Image.fromFileSystem(gu.save_temp_img(image))])
async with aiohttp.ClientSession() as session:
async with session.get(i) as resp:
if resp.status == 200:
image = PILImage.open(io.BytesIO(await resp.read()))
return MessageResult([Image.fromFileSystem(gu.save_temp_img(image))])
# 其他指令
else:
try:
@@ -455,7 +458,7 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
llm_result_str = re.sub(i, "***", llm_result_str)
# 百度内容审核服务二次审核
if baidu_judge != None:
check, msg = baidu_judge.judge(llm_result_str)
check, msg = await asyncio.to_thread(baidu_judge.judge, llm_result_str)
if not check:
return MessageResult(f"你的提问得到的回复【百度内容审核】未通过,不予回复。\n\n{msg}")
# 发送信息
+28 -24
View File
@@ -1,17 +1,19 @@
import json
from util import general_utils as gu
import os
import requests
from model.provider.provider import Provider
import inspect
import aiohttp
import json
import util.plugin_util as putil
from util.cmd_config import CmdConfig as cc
from util.general_utils import Logger
import util.updator
from nakuru.entities.components import (
Plain,
Image
)
from util import general_utils as gu
from model.provider.provider import Provider
from util.cmd_config import CmdConfig as cc
from util.general_utils import Logger
from cores.qqbot.global_object import GlobalObject, AstrMessageEvent
from cores.qqbot.global_object import CommandResult
@@ -25,7 +27,7 @@ class Command:
self.global_object = global_object
self.logger: Logger = global_object.logger
def check_command(self,
async def check_command(self,
message,
session_id: str,
role,
@@ -51,7 +53,10 @@ class Command:
if "type" in v["info"] and v["info"]["plugin_type"] == "platform":
continue
try:
result = v["clsobj"].run(ame)
if inspect.iscoroutinefunction(v["clsobj"].run):
result = await v["clsobj"].run(ame)
else:
result = v["clsobj"].run(ame)
if isinstance(result, CommandResult):
hit = result.hit
res = result._result_tuple()
@@ -65,13 +70,16 @@ class Command:
except TypeError as e:
# 参数不匹配,尝试使用旧的参数方案
try:
hit, res = v["clsobj"].run(message, role, platform, message_obj, self.global_object.platform_qq)
if inspect.iscoroutinefunction(v["clsobj"].run):
hit, res = await v["clsobj"].run(message, role, platform, message_obj, self.global_object.platform_qq)
else:
hit, res = v["clsobj"].run(message, role, platform, message_obj, self.global_object.platform_qq)
if hit:
return True, res
except BaseException as e:
self.logger.log(f"{k}插件异常,原因: {str(e)}\n已安装插件: {cached_plugins.keys}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。", level=gu.LEVEL_WARNING)
self.logger.log(f"{k} 插件异常,原因: {str(e)}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。", level=gu.LEVEL_WARNING)
except BaseException as e:
self.logger.log(f"{k} 插件异常,原因: {str(e)}\n已安装插件: {cached_plugins.keys}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。", level=gu.LEVEL_WARNING)
self.logger.log(f"{k} 插件异常,原因: {str(e)}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。", level=gu.LEVEL_WARNING)
if self.command_start_with(message, "nick"):
return True, self.set_nick(message, platform, role)
@@ -79,18 +87,13 @@ class Command:
return True, self.plugin_oper(message, role, cached_plugins, platform)
if self.command_start_with(message, "myid") or self.command_start_with(message, "!myid"):
return True, self.get_my_id(message_obj, platform)
if self.command_start_with(message, "nconf") or self.command_start_with(message, "newconf"):
return True, self.get_new_conf(message, role)
if self.command_start_with(message, "web"): # 网页搜索
return True, self.web_search(message)
if self.command_start_with(message, "ip"):
ip = requests.get("https://myip.ipip.net", timeout=5).text
return True, f"机器人 IP 信息:{ip}", "ip"
if not self.provider and self.command_start_with(message, "help"):
return True, self.help()
return True, await self.help()
return False, None
def web_search(self, message):
l = message.split(' ')
if len(l) == 1:
@@ -202,10 +205,11 @@ class Command:
"/revgpt": "切换到网页版ChatGPT",
}
def help_messager(self, commands: dict, platform: str, cached_plugins: dict = None):
async def help_messager(self, commands: dict, platform: str, cached_plugins: dict = None):
try:
resp = requests.get("https://soulter.top/channelbot/notice.json").text
notice = json.loads(resp)["notice"]
async with aiohttp.ClientSession() as session:
async with session.get("https://soulter.top/channelbot/notice.json") as resp:
notice = (await resp.json())["notice"]
except BaseException as e:
notice = ""
msg = "# Help Center\n## 指令列表\n"
@@ -279,9 +283,9 @@ class Command:
def key(self):
return False
def help(self):
return True, self.help_messager(self.general_commands(), self.platform, self.global_object.cached_plugins), "help"
async def help(self):
ret = await self.help_messager(self.general_commands(), self.platform, self.global_object.cached_plugins)
return True, ret, "help"
def status(self):
return False
+11 -11
View File
@@ -11,7 +11,7 @@ class CommandOpenAIOfficial(Command):
self.personality_str = ""
super().__init__(provider, global_object)
def check_command(self,
async def check_command(self,
message: str,
session_id: str,
role: str,
@@ -20,7 +20,7 @@ class CommandOpenAIOfficial(Command):
self.platform = platform
# 检查基础指令
hit, res = super().check_command(
hit, res = await super().check_command(
message,
session_id,
role,
@@ -32,7 +32,7 @@ class CommandOpenAIOfficial(Command):
if hit:
return True, res
if self.command_start_with(message, "reset", "重置"):
return True, self.reset(session_id, message)
return True, await self.reset(session_id, message)
elif self.command_start_with(message, "his", "历史"):
return True, self.his(message, session_id)
elif self.command_start_with(message, "token"):
@@ -42,7 +42,7 @@ class CommandOpenAIOfficial(Command):
elif self.command_start_with(message, "status"):
return True, self.status()
elif self.command_start_with(message, "help", "帮助"):
return True, self.help()
return True, await self.help()
elif self.command_start_with(message, "unset"):
return True, self.unset(session_id)
elif self.command_start_with(message, "set"):
@@ -54,11 +54,11 @@ class CommandOpenAIOfficial(Command):
elif self.command_start_with(message, "key"):
return True, self.key(message)
elif self.command_start_with(message, "switch"):
return True, self.switch(message)
return True, await self.switch(message)
return False, None
def help(self):
async def help(self):
commands = super().general_commands()
commands[''] = '画画'
commands['key'] = '添加OpenAI key'
@@ -66,15 +66,15 @@ class CommandOpenAIOfficial(Command):
commands['gpt'] = '查看gpt配置信息'
commands['status'] = '查看key使用状态'
commands['token'] = '查看本轮会话token'
return True, super().help_messager(commands, self.platform, self.global_object.cached_plugins), "help"
return True, await super().help_messager(commands, self.platform, self.global_object.cached_plugins), "help"
def reset(self, session_id: str, message: str = "reset"):
async def reset(self, session_id: str, message: str = "reset"):
if self.provider is None:
return False, "未启用 OpenAI 官方 API", "reset"
l = message.split(" ")
if len(l) == 1:
self.provider.forget(session_id)
await self.provider.forget(session_id)
return True, "重置成功", "reset"
if len(l) == 2 and l[1] == "p":
self.provider.forget(session_id)
@@ -146,7 +146,7 @@ class CommandOpenAIOfficial(Command):
else:
return True, "该Key被验证为无效。也许是输入错误了,或者重试。", "key"
def switch(self, message: str):
async def switch(self, message: str):
'''
切换账号
'''
@@ -168,7 +168,7 @@ class CommandOpenAIOfficial(Command):
else:
try:
new_key = list(key_stat.keys())[index-1]
ret = self.provider.check_key(new_key)
ret = await self.provider.check_key(new_key)
self.provider.set_key(new_key)
except BaseException as e:
return True, "账号切换失败,原因: " + str(e), "switch"
+5 -5
View File
@@ -11,14 +11,14 @@ class CommandRevChatGPT(Command):
self.personality_str = ""
super().__init__(provider, global_object)
def check_command(self,
async def check_command(self,
message: str,
session_id: str,
role: str,
platform: str,
message_obj):
self.platform = platform
hit, res = super().check_command(
hit, res = await super().check_command(
message,
session_id,
role,
@@ -29,7 +29,7 @@ class CommandRevChatGPT(Command):
if hit:
return True, res
if self.command_start_with(message, "help", "帮助"):
return True, self.help()
return True, await self.help()
elif self.command_start_with(message, "reset"):
return True, self.reset(session_id, message)
elif self.command_start_with(message, "update"):
@@ -127,7 +127,7 @@ class CommandRevChatGPT(Command):
else:
return True, "参数过多。", "switch"
def help(self):
async def help(self):
commands = super().general_commands()
commands['set'] = '设置人格'
return True, super().help_messager(commands, self.platform, self.global_object.cached_plugins), "help"
return True, await super().help_messager(commands, self.platform, self.global_object.cached_plugins), "help"
+1 -14
View File
@@ -1,7 +1,5 @@
import abc
import threading
import asyncio
from typing import Callable, Union
from typing import Union
from nakuru import (
GuildMessage,
GroupMessage,
@@ -70,14 +68,3 @@ class Platform():
pass
ret.replace('\n', '')
return ret
def new_sub_thread(self, func, args=()):
thread = threading.Thread(target=self._runner, args=(func, args), daemon=True)
thread.start()
def _runner(self, func: Callable, args: tuple):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(func(*args))
loop.close()
+19 -33
View File
@@ -58,10 +58,10 @@ class QQGOCQ(Platform):
async def _(app: CQHTTP, source: GroupMessage):
if self.cc.get("gocq_react_group", True):
if isinstance(source.message[0], Plain):
self.new_sub_thread(self.handle_msg, (source, True))
await self.handle_msg(source, True)
elif isinstance(source.message[0], At):
if source.message[0].qq == source.self_id:
self.new_sub_thread(self.handle_msg, (source, True))
await self.handle_msg(source, True)
else:
return
@@ -69,7 +69,7 @@ class QQGOCQ(Platform):
async def _(app: CQHTTP, source: FriendMessage):
if self.cc.get("gocq_react_friend", True):
if isinstance(source.message[0], Plain):
self.new_sub_thread(self.handle_msg, (source, False))
await self.handle_msg(source, False)
else:
return
@@ -85,19 +85,16 @@ class QQGOCQ(Platform):
async def _(app: CQHTTP, source: Notify):
print(source)
if source.sub_type == "poke" and source.target_id == source.self_id:
# await self.handle_msg(source, False)
self.new_sub_thread(self.handle_msg, (source, False))
await self.handle_msg(source, False)
@gocq_app.receiver("GuildMessage")
async def _(app: CQHTTP, source: GuildMessage):
if self.cc.get("gocq_react_guild", True):
if isinstance(source.message[0], Plain):
# await self.handle_msg(source, True)
self.new_sub_thread(self.handle_msg, (source, True))
await self.handle_msg(source, True)
elif isinstance(source.message[0], At):
if source.message[0].qq == source.self_tiny_id:
# await self.handle_msg(source, True)
self.new_sub_thread(self.handle_msg, (source, True))
await self.handle_msg(source, True)
else:
return
@@ -157,7 +154,7 @@ class QQGOCQ(Platform):
if message_result is None:
return
self.reply_msg(message, message_result.result_message)
await self.reply_msg(message, message_result.result_message)
if message_result.callback is not None:
message_result.callback()
@@ -165,11 +162,11 @@ class QQGOCQ(Platform):
if session_id in self.waiting and self.waiting[session_id] == '':
self.waiting[session_id] = message
def reply_msg(self,
async def reply_msg(self,
message: Union[GroupMessage, FriendMessage, GuildMessage, Notify],
result_message: list):
"""
插件开发者请使用send方法, 可以不用直接调用这个方法。
插件开发者请使用send方法, 可以不用直接调用这个方法。
"""
source = message
res = result_message
@@ -205,12 +202,10 @@ class QQGOCQ(Platform):
# 回复消息链
if isinstance(res, list) and len(res) > 0:
if source.type == "GuildMessage":
# await self.client.sendGuildChannelMessage(source.guild_id, source.channel_id, res)
asyncio.run_coroutine_threadsafe(self.client.sendGuildChannelMessage(source.guild_id, source.channel_id, res), self.loop).result()
await self.client.sendGuildChannelMessage(source.guild_id, source.channel_id, res)
return
elif source.type == "FriendMessage":
# await self.client.sendFriendMessage(source.user_id, res)
asyncio.run_coroutine_threadsafe(self.client.sendFriendMessage(source.user_id, res), self.loop).result()
await self.client.sendFriendMessage(source.user_id, res)
return
elif source.type == "GroupMessage":
# 过长时forward发送
@@ -233,37 +228,28 @@ class QQGOCQ(Platform):
node.time = int(time.time())
# print(node)
nodes=[node]
# await self.client.sendGroupForwardMessage(source.group_id, nodes)
asyncio.run_coroutine_threadsafe(self.client.sendGroupForwardMessage(source.group_id, nodes), self.loop).result()
await self.client.sendGroupForwardMessage(source.group_id, nodes)
return
# await self.client.sendGroupMessage(source.group_id, res)
asyncio.run_coroutine_threadsafe(self.client.sendGroupMessage(source.group_id, res), self.loop).result()
await self.client.sendGroupMessage(source.group_id, res)
return
def send_msg(self, message: Union[GroupMessage, FriendMessage, GuildMessage, Notify], result_message: list):
async def send_msg(self, message: Union[GroupMessage, FriendMessage, GuildMessage, Notify], result_message: list):
'''
提供给插件的发送QQ消息接口。
参数说明:第一个参数可以是消息对象,也可以是QQ群号。第二个参数是消息内容(消息内容可以是消息链列表,也可以是纯文字信息)。
非异步
'''
try:
# await self.reply_msg(message, result_message)
self.reply_msg(message, result_message)
await self.reply_msg(message, result_message)
except BaseException as e:
raise e
def send(self,
async def send(self,
to,
res):
'''
同 send_msg()
非异步
'''
try:
# await self.send_msg(to, res)
self.reply_msg(to, res)
except BaseException as e:
raise e
await self.reply_msg(to, res)
def create_text_image(title: str, text: str, max_width=30, font_size=20):
'''
@@ -302,12 +288,12 @@ class QQGOCQ(Platform):
def get_client(self):
return self.client
def nakuru_method_invoker(self, func, *args, **kwargs):
async def nakuru_method_invoker(self, func, *args, **kwargs):
"""
返回一个方法调用器,可以用来立即调用nakuru的方法。
"""
try:
ret = asyncio.run_coroutine_threadsafe(func(*args, **kwargs), self.loop).result()
ret = func(*args, **kwargs)
return ret
except BaseException as e:
raise e
+25 -32
View File
@@ -4,7 +4,7 @@ from PIL import Image as PILImage
from botpy.message import Message, DirectMessage
import re
import asyncio
import requests
import aiohttp
from util import general_utils as gu
from botpy.types.message import Reference
@@ -15,7 +15,7 @@ from ._nakuru_translation_layer import(
NakuruGuildMessage,
gocq_compatible_receive,
gocq_compatible_send
)
)
from typing import Union
# QQ 机器人官方框架
@@ -27,13 +27,13 @@ class botClient(Client):
async def on_at_message_create(self, message: Message):
# 转换层
nakuru_guild_message = gocq_compatible_receive(message)
self.platform.new_sub_thread(self.platform.handle_msg, (nakuru_guild_message, True))
await self.platform.handle_msg(nakuru_guild_message, True)
# 收到私聊消息
async def on_direct_message_create(self, message: DirectMessage):
# 转换层
nakuru_guild_message = gocq_compatible_receive(message)
self.platform.new_sub_thread(self.platform.handle_msg, (nakuru_guild_message, False))
await self.platform.handle_msg(nakuru_guild_message, False)
class QQOfficial(Platform):
@@ -107,7 +107,7 @@ class QQOfficial(Platform):
if message_result is None:
return
self.reply_msg(is_group, message, message_result.result_message)
await self.reply_msg(is_group, message, message_result.result_message)
if message_result.callback is not None:
message_result.callback()
@@ -115,7 +115,7 @@ class QQOfficial(Platform):
if session_id in self.waiting and self.waiting[session_id] == '':
self.waiting[session_id] = message
def reply_msg(self,
async def reply_msg(self,
is_group: bool,
message: NakuruGuildMessage,
res: Union[str, list]):
@@ -148,10 +148,11 @@ class QQOfficial(Platform):
if image_path is not None and image_path != '':
msg_ref = None
if image_path.startswith("http"):
pic_res = requests.get(image_path, stream = True)
if pic_res.status_code == 200:
image = PILImage.open(io.BytesIO(pic_res.content))
image_path = gu.save_temp_img(image)
async with aiohttp.ClientSession() as session:
async with session.get(image_path) as response:
if response.status == 200:
image = PILImage.open(io.BytesIO(await response.read()))
image_path = gu.save_temp_img(image)
if message.raw_message is not None and image_path == '': # file_image与message_reference不能同时传入
msg_ref = Reference(message_id=message.raw_message.id, ignore_get_message_error=False)
@@ -170,8 +171,7 @@ class QQOfficial(Platform):
data['file_image'] = image_path
try:
# await self._send_wrapper(**data)
self._send_wrapper(**data)
await self._send_wrapper(**data)
except BaseException as e:
print(e)
# 分割过长的消息
@@ -181,51 +181,44 @@ class QQOfficial(Platform):
split_res.append(plain_text[len(plain_text)//2:])
for i in split_res:
data['content'] = i
# await self._send_wrapper(**data)
self._send_wrapper(**data)
await self._send_wrapper(**data)
else:
# 发送qq信息
try:
# 防止被qq频道过滤消息
plain_text = plain_text.replace(".", " . ")
# await self._send_wrapper(**data)
self._send_wrapper(**data)
await self._send_wrapper(**data)
except BaseException as e:
try:
data['content'] = str.join(" ", plain_text)
# await self._send_wrapper(**data)
self._send_wrapper(**data)
await self._send_wrapper(**data)
except BaseException as e:
plain_text = re.sub(r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '[被隐藏的链接]', str(e), flags=re.MULTILINE)
plain_text = plain_text.replace(".", "·")
data['content'] = plain_text
# await self._send_wrapper(**data)
self._send_wrapper(**data)
await self._send_wrapper(**data)
def _send_wrapper(self, **kwargs):
async def _send_wrapper(self, **kwargs):
if 'channel_id' in kwargs:
asyncio.run_coroutine_threadsafe(self.client.api.post_message(**kwargs), self.loop).result()
await self.client.api.post_message(**kwargs)
else:
asyncio.run_coroutine_threadsafe(self.client.api.post_dms(**kwargs), self.loop).result()
await self.client.api.post_dms(**kwargs)
def send_msg(self, channel_id: int, message_chain: list, message_id: int = None):
async def send_msg(self, channel_id: int, message_chain: list, message_id: int = None):
'''
推送消息, 如果有 message_id,那么就是回复消息。非异步。
推送消息, 如果有 message_id,那么就是回复消息。
'''
_n = NakuruGuildMessage()
_n.channel_id = channel_id
_n.message_id = message_id
# await self.reply_msg(_n, message_chain)
self.reply_msg(_n, message_chain)
await self.reply_msg(_n, message_chain)
def send(self, message_obj, message_chain: list):
async def send(self, message_obj, message_chain: list):
'''
发送信息。内容同 reply_msg。非异步。
发送信息。内容同 reply_msg。
'''
# await self.reply_msg(message_obj, message_chain)
self.reply_msg(message_obj, message_chain)
await self.reply_msg(message_obj, message_chain)
def wait_for_message(self, channel_id: int) -> NakuruGuildMessage:
'''
+28 -38
View File
@@ -1,18 +1,21 @@
from openai import OpenAI
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.images_response import ImagesResponse
import json
import time
import os
import sys
import json
import time
import tiktoken
import threading
import traceback
from openai import AsyncOpenAI
from openai.types.images_response import ImagesResponse
from openai.types.chat.chat_completion import ChatCompletion
from cores.database.conn import dbConn
from model.provider.provider import Provider
import threading
from util import general_utils as gu
from util.cmd_config import CmdConfig
from util.general_utils import Logger
import traceback
import tiktoken
abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/'
@@ -42,7 +45,7 @@ class ProviderOpenAIOfficial(Provider):
self.logger.log(f"设置 api_base 为: {self.api_base}", tag="OpenAI")
# 创建 OpenAI Client
self.client = OpenAI(
self.client = AsyncOpenAI(
api_key=self.key_list[0],
base_url=self.api_base
)
@@ -113,7 +116,7 @@ class ProviderOpenAIOfficial(Provider):
}
self.session_dict[session_id].append(new_record)
def text_chat(self, prompt,
async def text_chat(self, prompt,
session_id = None,
image_url = None,
function_call=None,
@@ -132,7 +135,6 @@ class ProviderOpenAIOfficial(Provider):
if default_personality is not None:
self.personality_set(default_personality, session_id)
# 使用 tictoken 截断消息
_encoded_prompt = self.enc.encode(prompt)
if self.openai_model_configs['max_tokens'] < len(_encoded_prompt):
@@ -140,8 +142,8 @@ class ProviderOpenAIOfficial(Provider):
self.logger.log(f"注意,有一部分 prompt 文本由于超出 token 限制而被截断。", level=gu.LEVEL_WARNING, tag="OpenAI")
cache_data_list, new_record, req = self.wrap(prompt, session_id, image_url)
self.logger.log(f"CACHE_DATA_: {str(cache_data_list)}", level=gu.LEVEL_DEBUG, tag="OpenAI")
self.logger.log(f"OPENAI REQUEST: {str(req)}", level=gu.LEVEL_DEBUG, tag="OpenAI")
self.logger.log(f"cache: {str(cache_data_list)}", level=gu.LEVEL_DEBUG, tag="OpenAI")
self.logger.log(f"request: {str(req)}", level=gu.LEVEL_DEBUG, tag="OpenAI")
retry = 0
response = None
err = ''
@@ -168,19 +170,19 @@ class ProviderOpenAIOfficial(Provider):
while retry < 10:
try:
if function_call is None:
response = self.client.chat.completions.create(
response = await self.client.chat.completions.create(
messages=req,
**conf
)
else:
response = self.client.chat.completions.create(
response = await self.client.chat.completions.create(
messages=req,
tools = function_call,
**conf
)
break
except Exception as e:
print(traceback.format_exc())
traceback.print_exc()
if 'Invalid content type. image_url is only supported by certain models.' in str(e):
raise e
if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e):
@@ -188,7 +190,6 @@ class ProviderOpenAIOfficial(Provider):
self.key_stat[self.client.api_key]['exceed'] = True
is_switched = self.handle_switch_key()
if not is_switched:
# 所有Key都超额或不正常
raise e
retry -= 1
elif 'maximum context length' in str(e):
@@ -239,7 +240,6 @@ class ProviderOpenAIOfficial(Provider):
index += 1
# 删除完后更新相关字段
self.session_dict[session_id] = cache_data_list
# cache_prompt = get_prompts_by_cache_list(cache_data_list)
# 添加新条目进入缓存的prompt
new_record['AI'] = {
@@ -258,7 +258,7 @@ class ProviderOpenAIOfficial(Provider):
return chatgpt_res
def image_chat(self, prompt, img_num = 1, img_size = "1024x1024"):
async def image_chat(self, prompt, img_num = 1, img_size = "1024x1024"):
retry = 0
image_url = ''
@@ -266,7 +266,7 @@ class ProviderOpenAIOfficial(Provider):
while retry < 5:
try:
response: ImagesResponse = self.client.images.generate(
response: ImagesResponse = await self.client.images.generate(
prompt=prompt,
**image_generate_configs
)
@@ -282,7 +282,6 @@ class ProviderOpenAIOfficial(Provider):
self.key_stat[self.client.api_key]['exceed'] = True
is_switched = self.handle_switch_key()
if not is_switched:
# 所有Key都超额或不正常
raise e
elif 'Your request was rejected as a result of our safety system.' in str(e):
self.logger.log("您的请求被 OpenAI 安全系统拒绝, 请稍后再试", level=gu.LEVEL_WARNING, tag="OpenAI")
@@ -294,16 +293,16 @@ class ProviderOpenAIOfficial(Provider):
return image_url
def forget(self, session_id = None) -> bool:
async def forget(self, session_id = None) -> bool:
if session_id is None:
return False
self.session_dict[session_id] = []
return True
'''
获取缓存的会话
'''
def get_prompts_by_cache_list(self, cache_data_list, divide=False, paging=False, size=5, page=1):
'''
获取缓存的会话
'''
prompts = ""
if paging:
page_begin = (page-1)*size
@@ -320,15 +319,7 @@ class ProviderOpenAIOfficial(Provider):
if divide:
prompts += "----------\n"
return prompts
def get_user_usage_tokens(self,cache_list):
usage_tokens = 0
for item in cache_list:
usage_tokens += int(item['single_tokens'])
return usage_tokens
# 包装信息
def wrap(self, prompt, session_id, image_url = None):
if image_url is not None:
prompt = [
@@ -364,7 +355,6 @@ class ProviderOpenAIOfficial(Provider):
return context, new_record, req_list
def handle_switch_key(self):
# messages = [{"role": "user", "content": prompt}]
is_all_exceed = True
for key in self.key_stat:
if key == None or self.key_stat[key]['exceed']:
@@ -399,13 +389,13 @@ class ProviderOpenAIOfficial(Provider):
self.key_stat[key] = {'exceed': False, 'used': 0, 'sponsor': sponsor}
# 检查key是否可用
def check_key(self, key):
client_ = OpenAI(
async def check_key(self, key):
client_ = AsyncOpenAI(
api_key=key,
base_url=self.api_base
)
messages = [{"role": "user", "content": "please just echo `test`"}]
client_.chat.completions.create(
await client_.chat.completions.create(
messages=messages,
**self.openai_model_configs
)
+2 -2
View File
@@ -5,9 +5,9 @@ class Provider:
pass
@abc.abstractmethod
def text_chat(self, prompt, session_id, image_url: None, function_call: None, extra_conf: dict = None, default_personality: dict = None) -> str:
async def text_chat(self, prompt, session_id, image_url: None, function_call: None, extra_conf: dict = None, default_personality: dict = None) -> str:
pass
@abc.abstractmethod
def forget(self, session_id = None) -> bool:
async def forget(self, session_id = None) -> bool:
pass
+2 -1
View File
@@ -1,5 +1,6 @@
pydantic~=1.10.4
requests~=2.28.1
aiohttp
requests
openai~=1.2.3
qq-botpy
chardet~=5.1.0
+54 -34
View File
@@ -1,18 +1,19 @@
import requests
import util.general_utils as gu
from bs4 import BeautifulSoup
import traceback
import time
import json
import asyncio
from googlesearch import search, SearchResult
from readability import Document
from bs4 import BeautifulSoup
from openai.types.chat.chat_completion_message_tool_call import Function
from util.function_calling.func_call import (
FuncCall,
FuncCallJsonFormatError,
FuncNotFoundError
)
from openai.types.chat.chat_completion_message_tool_call import Function
import traceback
from googlesearch import search, SearchResult
from model.provider.provider import Provider
import json
from readability import Document
def tidy_text(text: str) -> str:
@@ -53,11 +54,11 @@ def google_web_search(keyword) -> str:
for i in ls:
desc = i.description
try:
gu.log(f"搜索网页: {i.url}", tag="网页搜索", level=gu.LEVEL_INFO)
# gu.log(f"搜索网页: {i.url}", tag="网页搜索", level=gu.LEVEL_INFO)
desc = fetch_website_content(i.url)
except BaseException as e:
print(f"(google) fetch_website_content err: {str(e)}")
gu.log(f"# No.{str(index)}\ntitle: {i.title}\nurl: {i.url}\ncontent: {desc}\n\n", level=gu.LEVEL_DEBUG, max_len=9999)
# gu.log(f"# No.{str(index)}\ntitle: {i.title}\nurl: {i.url}\ncontent: {desc}\n\n", level=gu.LEVEL_DEBUG, max_len=9999)
ret += f"# No.{str(index)}\ntitle: {i.title}\nurl: {i.url}\ncontent: {desc}\n\n"
index += 1
except Exception as e:
@@ -80,7 +81,7 @@ def web_keyword_search_via_bing(keyword) -> str:
try:
response = requests.get(url, headers=headers)
response.encoding = "utf-8"
gu.log(f"bing response: {response.text}", tag="bing", level=gu.LEVEL_DEBUG, max_len=9999)
# gu.log(f"bing response: {response.text}", tag="bing", level=gu.LEVEL_DEBUG, max_len=9999)
soup = BeautifulSoup(response.text, "html.parser")
res = ""
result_cnt = 0
@@ -96,7 +97,7 @@ def web_keyword_search_via_bing(keyword) -> str:
# "link": link,
# })
try:
gu.log(f"搜索网页: {link}", tag="网页搜索", level=gu.LEVEL_INFO)
# gu.log(f"搜索网页: {link}", tag="网页搜索", level=gu.LEVEL_INFO)
desc = fetch_website_content(link)
except BaseException as e:
print(f"(bing) fetch_website_content err: {str(e)}")
@@ -124,11 +125,11 @@ def web_keyword_search_via_bing(keyword) -> str:
if result_cnt == 0: break
return res
except Exception as e:
gu.log(f"bing fetch err: {str(e)}")
# gu.log(f"bing fetch err: {str(e)}")
_cnt += 1
time.sleep(1)
gu.log("fail to fetch bing info, using sougou.")
# gu.log("fail to fetch bing info, using sougou.")
return web_keyword_search_via_sougou(keyword)
def web_keyword_search_via_sougou(keyword) -> str:
@@ -157,7 +158,7 @@ def web_keyword_search_via_sougou(keyword) -> str:
break
except Exception as e:
pass
gu.log(f"sougou parse err: {str(e)}", tag="web_keyword_search_via_sougou", level=gu.LEVEL_ERROR)
# gu.log(f"sougou parse err: {str(e)}", tag="web_keyword_search_via_sougou", level=gu.LEVEL_ERROR)
# 爬取网页内容
_detail_store = []
for i in res:
@@ -173,7 +174,7 @@ def web_keyword_search_via_sougou(keyword) -> str:
return ret
def fetch_website_content(url):
gu.log(f"fetch_website_content: {url}", tag="fetch_website_content", level=gu.LEVEL_DEBUG)
# gu.log(f"fetch_website_content: {url}", tag="fetch_website_content", level=gu.LEVEL_DEBUG)
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \
AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
@@ -187,7 +188,7 @@ def fetch_website_content(url):
ret = tidy_text(soup.get_text())
return ret
def web_search(question, provider: Provider, session_id, official_fc=False):
async def web_search(question, provider: Provider, session_id, official_fc=False):
'''
official_fc: 使用官方 function-calling
'''
@@ -197,7 +198,7 @@ def web_search(question, provider: Provider, session_id, official_fc=False):
"name": "keyword",
"description": "google search query (分词,尽量保留所有信息)"
}],
"通过搜索引擎搜索。如果问题需要在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。",
"通过搜索引擎搜索。如果问题需要获取近期、实时的消息,在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。",
web_keyword_search_via_bing
)
new_func_call.add_func("fetch_website_content", [{
@@ -205,16 +206,16 @@ def web_search(question, provider: Provider, session_id, official_fc=False):
"name": "url",
"description": "网址"
}],
"获取网页的内容。如果问题带有合法的网页链接(例如: `帮我总结一下https://github.com的内容`), 就调用此函数。如果没有,不要调用此函数。",
"获取网页的内容。如果问题带有合法的网页链接(例如: `帮我总结一下 https://github.com 的内容`), 就调用此函数。如果没有,不要调用此函数。",
fetch_website_content
)
question1 = f"{question} \n> hint: 最多只能调用1个function, 并且存在不会调用任何function的可能性。"
has_func = False
function_invoked_ret = ""
if official_fc:
func = provider.text_chat(question1, session_id, function_call=new_func_call.get_func())
# we use official function-calling
func = await provider.text_chat(question1, session_id, function_call=new_func_call.get_func())
if isinstance(func, Function):
# arguments='{\n "keyword": "北京今天的天气"\n}', name='google_web_search'
# 执行对应的结果:
func_obj = None
for i in new_func_call.func_list:
@@ -222,49 +223,68 @@ def web_search(question, provider: Provider, session_id, official_fc=False):
func_obj = i["func_obj"]
break
if not func_obj:
gu.log("找不到返回的 func name " + func.name, level=gu.LEVEL_ERROR)
return provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)"
# gu.log("找不到返回的 func name " + func.name, level=gu.LEVEL_ERROR)
return await provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)"
try:
args = json.loads(func.arguments)
function_invoked_ret = func_obj(**args)
# we use to_thread to avoid blocking the event loop
function_invoked_ret = await asyncio.to_thread(func_obj, **args)
has_func = True
except BaseException as e:
traceback.print_exc()
return provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)"
return await provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)"
else:
# now func is a string
return func
else:
# we use our own function-calling
try:
function_invoked_ret, has_func = new_func_call.func_call(question1, new_func_call.func_dump(), is_task=False, is_summary=False)
args = {
'question': question1,
'func_definition': new_func_call.func_dump(),
'is_task': False,
'is_summary': False,
}
function_invoked_ret, has_func = await asyncio.to_thread(new_func_call.func_call, **args)
except BaseException as e:
res = provider.text_chat(question) + "\n(网页搜索失败, 此为默认回复)"
res = await provider.text_chat(question) + "\n(网页搜索失败, 此为默认回复)"
return res
has_func = True
if has_func:
provider.forget(session_id)
await provider.forget(session_id)
question3 = f"""
以下是相关材料,你的任务是:
1. 根据材料对问题`{question}`做切题的总结回答;
2. 发表你对这个问题的看法.
你的任务是:
1. 根据末尾的材料对问题`{question}`做切题的总结(详细);
2. 简单地发表你对这个问题的看法(简略)。
你的总结末尾应当有对材料的引用, 如果有链接, 请在末尾附上引用网页链接。引用格式严格按照 `\n[1] title url \n`。
不要提到任何函数调用的信息。以下是相关材料:
不要提到任何函数调用的信息。
一些回复的消息模板:
模板1:
```
从网上的信息来看,可以知道...我个人认为...你觉得呢?
```
模板2:
```
根据网上的最新信息,可以得知...我觉得...你怎么看?
```
你可以根据这些模板来组织回答,但可以不照搬,要根据问题的内容来回答。
以下是相关材料:
"""
gu.log(f"web_search: {question3}", tag="web_search", level=gu.LEVEL_DEBUG, max_len=99999)
_c = 0
while _c < 3:
try:
print('text chat')
final_ret = provider.text_chat(question3 + "```" + function_invoked_ret + "```", session_id)
final_ret = await provider.text_chat(question3 + "```" + function_invoked_ret + "```", session_id)
return final_ret
except Exception as e:
print(e)
_c += 1
if _c == 3: raise e
if "The message you submitted was too long" in str(e):
provider.forget(session_id)
await provider.forget(session_id)
function_invoked_ret = function_invoked_ret[:int(len(function_invoked_ret) / 2)]
time.sleep(3)
return function_invoked_ret
-1
View File
@@ -9,7 +9,6 @@ from util.cmd_config import CmdConfig
import socket
from cores.qqbot.global_object import GlobalObject
import platform
import requests
import logging
import json
import sys