feat: 异步重写
perf: 优化网页搜索回答规范
This commit is contained in:
+30
-27
@@ -1,33 +1,35 @@
|
||||
import re
|
||||
import json
|
||||
import threading
|
||||
import asyncio
|
||||
import time
|
||||
import requests
|
||||
import aiohttp
|
||||
import util.unfit_words as uw
|
||||
import os
|
||||
import sys
|
||||
from addons.baidu_aip_judge import BaiduJudge
|
||||
import io
|
||||
import traceback
|
||||
|
||||
import util.function_calling.gplugin as gplugin
|
||||
import util.plugin_util as putil
|
||||
|
||||
from PIL import Image as PILImage
|
||||
from typing import Union
|
||||
from nakuru import (
|
||||
GroupMessage,
|
||||
FriendMessage,
|
||||
GuildMessage,
|
||||
)
|
||||
from nakuru.entities.components import Plain, At, Image
|
||||
|
||||
from addons.baidu_aip_judge import BaiduJudge
|
||||
from model.platform._nakuru_translation_layer import NakuruGuildMessage
|
||||
from nakuru.entities.components import Plain,At,Image
|
||||
from model.provider.provider import Provider
|
||||
from model.command.command import Command
|
||||
from util import general_utils as gu
|
||||
from util.general_utils import Logger, upload, run_monitor
|
||||
from util.cmd_config import CmdConfig as cc
|
||||
from util.cmd_config import init_astrbot_config_items
|
||||
import util.function_calling.gplugin as gplugin
|
||||
import util.plugin_util as putil
|
||||
from PIL import Image as PILImage
|
||||
import io
|
||||
import traceback
|
||||
from . global_object import GlobalObject
|
||||
from typing import Union
|
||||
from addons.dashboard.helper import DashBoardHelper
|
||||
from addons.dashboard.server import DashBoardData
|
||||
from cores.database.conn import dbConn
|
||||
@@ -41,7 +43,7 @@ frequency_time = 60
|
||||
frequency_count = 10
|
||||
|
||||
# 版本
|
||||
version = '3.1.5'
|
||||
version = '3.1.6'
|
||||
|
||||
# 语言模型
|
||||
REV_CHATGPT = 'rev_chatgpt'
|
||||
@@ -325,7 +327,7 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
|
||||
command_result = () # 调用指令返回的结果
|
||||
|
||||
# 统计数据,如频道消息量
|
||||
record_message(platform, session_id)
|
||||
await record_message(platform, session_id)
|
||||
|
||||
for i in message.message:
|
||||
if isinstance(i, Plain):
|
||||
@@ -334,8 +336,7 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
|
||||
return MessageResult("Hi~")
|
||||
|
||||
# 检查发言频率
|
||||
user_id = message.user_id
|
||||
if not check_frequency(user_id):
|
||||
if not check_frequency(message.user_id):
|
||||
return MessageResult(f'你的发言超过频率限制(╯▔皿▔)╯。\n管理员设置{frequency_time}秒内只能提问{frequency_count}次。')
|
||||
|
||||
# 检查是否是更换语言模型的请求
|
||||
@@ -359,7 +360,8 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
|
||||
|
||||
llm_result_str = ""
|
||||
|
||||
hit, command_result = llm_command_instance[chosen_provider].check_command(
|
||||
# check commands and plugins
|
||||
hit, command_result = await llm_command_instance[chosen_provider].check_command(
|
||||
message_str,
|
||||
session_id,
|
||||
role,
|
||||
@@ -375,11 +377,12 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
|
||||
if matches:
|
||||
return MessageResult(f"你的提问得到的回复未通过【默认关键词拦截】服务, 不予回复。")
|
||||
if baidu_judge != None:
|
||||
check, msg = baidu_judge.judge(message_str)
|
||||
check, msg = await asyncio.to_thread(baidu_judge.judge, message_str)
|
||||
if not check:
|
||||
return MessageResult(f"你的提问得到的回复未通过【百度AI内容审核】服务, 不予回复。\n\n{msg}")
|
||||
if chosen_provider == NONE_LLM:
|
||||
return MessageResult("没有启动任何 LLM 并且未触发任何指令。")
|
||||
logger.log("一条消息由于 Bot 未启动任何语言模型并且未触发指令而将被忽略。", gu.LEVEL_WARNING)
|
||||
return
|
||||
try:
|
||||
if llm_wake_prefix != "" and not message_str.startswith(llm_wake_prefix):
|
||||
return
|
||||
@@ -403,9 +406,9 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
|
||||
if chosen_provider == REV_CHATGPT or chosen_provider == OPENAI_OFFICIAL:
|
||||
if _global_object.web_search or web_sch_flag:
|
||||
official_fc = chosen_provider == OPENAI_OFFICIAL
|
||||
llm_result_str = gplugin.web_search(message_str, llm_instance[chosen_provider], session_id, official_fc)
|
||||
llm_result_str = await gplugin.web_search(message_str, llm_instance[chosen_provider], session_id, official_fc)
|
||||
else:
|
||||
llm_result_str = str(llm_instance[chosen_provider].text_chat(message_str, session_id, image_url, default_personality = _global_object.default_personality))
|
||||
llm_result_str = await llm_instance[chosen_provider].text_chat(message_str, session_id, image_url, default_personality = _global_object.default_personality)
|
||||
|
||||
llm_result_str = _global_object.reply_prefix + llm_result_str
|
||||
except BaseException as e:
|
||||
@@ -416,9 +419,9 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
|
||||
if temp_switch != "":
|
||||
chosen_provider = temp_switch
|
||||
|
||||
# 指令回复
|
||||
if hit:
|
||||
# 检查指令。command_result 是一个元组:(指令调用是否成功, 指令返回的文本结果, 指令类型)
|
||||
# 有指令或者插件触发
|
||||
# command_result 是一个元组:(指令调用是否成功, 指令返回的文本结果, 指令类型)
|
||||
if command_result == None:
|
||||
return
|
||||
command = command_result[2]
|
||||
@@ -436,11 +439,11 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
|
||||
if isinstance(command_result[1], list) and len(command_result) == 3 and command == 'draw':
|
||||
for i in command_result[1]:
|
||||
# 保存到本地
|
||||
pic_res = requests.get(i, stream = True)
|
||||
if pic_res.status_code == 200:
|
||||
image = PILImage.open(io.BytesIO(pic_res.content))
|
||||
return MessageResult([Image.fromFileSystem(gu.save_temp_img(image))])
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(i) as resp:
|
||||
if resp.status == 200:
|
||||
image = PILImage.open(io.BytesIO(await resp.read()))
|
||||
return MessageResult([Image.fromFileSystem(gu.save_temp_img(image))])
|
||||
# 其他指令
|
||||
else:
|
||||
try:
|
||||
@@ -455,7 +458,7 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
|
||||
llm_result_str = re.sub(i, "***", llm_result_str)
|
||||
# 百度内容审核服务二次审核
|
||||
if baidu_judge != None:
|
||||
check, msg = baidu_judge.judge(llm_result_str)
|
||||
check, msg = await asyncio.to_thread(baidu_judge.judge, llm_result_str)
|
||||
if not check:
|
||||
return MessageResult(f"你的提问得到的回复【百度内容审核】未通过,不予回复。\n\n{msg}")
|
||||
# 发送信息
|
||||
|
||||
+28
-24
@@ -1,17 +1,19 @@
|
||||
import json
|
||||
from util import general_utils as gu
|
||||
import os
|
||||
import requests
|
||||
from model.provider.provider import Provider
|
||||
import inspect
|
||||
import aiohttp
|
||||
import json
|
||||
|
||||
import util.plugin_util as putil
|
||||
from util.cmd_config import CmdConfig as cc
|
||||
from util.general_utils import Logger
|
||||
import util.updator
|
||||
|
||||
from nakuru.entities.components import (
|
||||
Plain,
|
||||
Image
|
||||
)
|
||||
from util import general_utils as gu
|
||||
from model.provider.provider import Provider
|
||||
from util.cmd_config import CmdConfig as cc
|
||||
from util.general_utils import Logger
|
||||
from cores.qqbot.global_object import GlobalObject, AstrMessageEvent
|
||||
from cores.qqbot.global_object import CommandResult
|
||||
|
||||
@@ -25,7 +27,7 @@ class Command:
|
||||
self.global_object = global_object
|
||||
self.logger: Logger = global_object.logger
|
||||
|
||||
def check_command(self,
|
||||
async def check_command(self,
|
||||
message,
|
||||
session_id: str,
|
||||
role,
|
||||
@@ -51,7 +53,10 @@ class Command:
|
||||
if "type" in v["info"] and v["info"]["plugin_type"] == "platform":
|
||||
continue
|
||||
try:
|
||||
result = v["clsobj"].run(ame)
|
||||
if inspect.iscoroutinefunction(v["clsobj"].run):
|
||||
result = await v["clsobj"].run(ame)
|
||||
else:
|
||||
result = v["clsobj"].run(ame)
|
||||
if isinstance(result, CommandResult):
|
||||
hit = result.hit
|
||||
res = result._result_tuple()
|
||||
@@ -65,13 +70,16 @@ class Command:
|
||||
except TypeError as e:
|
||||
# 参数不匹配,尝试使用旧的参数方案
|
||||
try:
|
||||
hit, res = v["clsobj"].run(message, role, platform, message_obj, self.global_object.platform_qq)
|
||||
if inspect.iscoroutinefunction(v["clsobj"].run):
|
||||
hit, res = await v["clsobj"].run(message, role, platform, message_obj, self.global_object.platform_qq)
|
||||
else:
|
||||
hit, res = v["clsobj"].run(message, role, platform, message_obj, self.global_object.platform_qq)
|
||||
if hit:
|
||||
return True, res
|
||||
except BaseException as e:
|
||||
self.logger.log(f"{k}插件异常,原因: {str(e)}\n已安装插件: {cached_plugins.keys}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。", level=gu.LEVEL_WARNING)
|
||||
self.logger.log(f"{k} 插件异常,原因: {str(e)}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。", level=gu.LEVEL_WARNING)
|
||||
except BaseException as e:
|
||||
self.logger.log(f"{k} 插件异常,原因: {str(e)}\n已安装插件: {cached_plugins.keys}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。", level=gu.LEVEL_WARNING)
|
||||
self.logger.log(f"{k} 插件异常,原因: {str(e)}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。", level=gu.LEVEL_WARNING)
|
||||
|
||||
if self.command_start_with(message, "nick"):
|
||||
return True, self.set_nick(message, platform, role)
|
||||
@@ -79,18 +87,13 @@ class Command:
|
||||
return True, self.plugin_oper(message, role, cached_plugins, platform)
|
||||
if self.command_start_with(message, "myid") or self.command_start_with(message, "!myid"):
|
||||
return True, self.get_my_id(message_obj, platform)
|
||||
if self.command_start_with(message, "nconf") or self.command_start_with(message, "newconf"):
|
||||
return True, self.get_new_conf(message, role)
|
||||
if self.command_start_with(message, "web"): # 网页搜索
|
||||
return True, self.web_search(message)
|
||||
if self.command_start_with(message, "ip"):
|
||||
ip = requests.get("https://myip.ipip.net", timeout=5).text
|
||||
return True, f"机器人 IP 信息:{ip}", "ip"
|
||||
if not self.provider and self.command_start_with(message, "help"):
|
||||
return True, self.help()
|
||||
return True, await self.help()
|
||||
|
||||
return False, None
|
||||
|
||||
|
||||
def web_search(self, message):
|
||||
l = message.split(' ')
|
||||
if len(l) == 1:
|
||||
@@ -202,10 +205,11 @@ class Command:
|
||||
"/revgpt": "切换到网页版ChatGPT",
|
||||
}
|
||||
|
||||
def help_messager(self, commands: dict, platform: str, cached_plugins: dict = None):
|
||||
async def help_messager(self, commands: dict, platform: str, cached_plugins: dict = None):
|
||||
try:
|
||||
resp = requests.get("https://soulter.top/channelbot/notice.json").text
|
||||
notice = json.loads(resp)["notice"]
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get("https://soulter.top/channelbot/notice.json") as resp:
|
||||
notice = (await resp.json())["notice"]
|
||||
except BaseException as e:
|
||||
notice = ""
|
||||
msg = "# Help Center\n## 指令列表\n"
|
||||
@@ -279,9 +283,9 @@ class Command:
|
||||
def key(self):
|
||||
return False
|
||||
|
||||
def help(self):
|
||||
return True, self.help_messager(self.general_commands(), self.platform, self.global_object.cached_plugins), "help"
|
||||
|
||||
async def help(self):
|
||||
ret = await self.help_messager(self.general_commands(), self.platform, self.global_object.cached_plugins)
|
||||
return True, ret, "help"
|
||||
|
||||
def status(self):
|
||||
return False
|
||||
|
||||
@@ -11,7 +11,7 @@ class CommandOpenAIOfficial(Command):
|
||||
self.personality_str = ""
|
||||
super().__init__(provider, global_object)
|
||||
|
||||
def check_command(self,
|
||||
async def check_command(self,
|
||||
message: str,
|
||||
session_id: str,
|
||||
role: str,
|
||||
@@ -20,7 +20,7 @@ class CommandOpenAIOfficial(Command):
|
||||
self.platform = platform
|
||||
|
||||
# 检查基础指令
|
||||
hit, res = super().check_command(
|
||||
hit, res = await super().check_command(
|
||||
message,
|
||||
session_id,
|
||||
role,
|
||||
@@ -32,7 +32,7 @@ class CommandOpenAIOfficial(Command):
|
||||
if hit:
|
||||
return True, res
|
||||
if self.command_start_with(message, "reset", "重置"):
|
||||
return True, self.reset(session_id, message)
|
||||
return True, await self.reset(session_id, message)
|
||||
elif self.command_start_with(message, "his", "历史"):
|
||||
return True, self.his(message, session_id)
|
||||
elif self.command_start_with(message, "token"):
|
||||
@@ -42,7 +42,7 @@ class CommandOpenAIOfficial(Command):
|
||||
elif self.command_start_with(message, "status"):
|
||||
return True, self.status()
|
||||
elif self.command_start_with(message, "help", "帮助"):
|
||||
return True, self.help()
|
||||
return True, await self.help()
|
||||
elif self.command_start_with(message, "unset"):
|
||||
return True, self.unset(session_id)
|
||||
elif self.command_start_with(message, "set"):
|
||||
@@ -54,11 +54,11 @@ class CommandOpenAIOfficial(Command):
|
||||
elif self.command_start_with(message, "key"):
|
||||
return True, self.key(message)
|
||||
elif self.command_start_with(message, "switch"):
|
||||
return True, self.switch(message)
|
||||
return True, await self.switch(message)
|
||||
|
||||
return False, None
|
||||
|
||||
def help(self):
|
||||
async def help(self):
|
||||
commands = super().general_commands()
|
||||
commands['画'] = '画画'
|
||||
commands['key'] = '添加OpenAI key'
|
||||
@@ -66,15 +66,15 @@ class CommandOpenAIOfficial(Command):
|
||||
commands['gpt'] = '查看gpt配置信息'
|
||||
commands['status'] = '查看key使用状态'
|
||||
commands['token'] = '查看本轮会话token'
|
||||
return True, super().help_messager(commands, self.platform, self.global_object.cached_plugins), "help"
|
||||
return True, await super().help_messager(commands, self.platform, self.global_object.cached_plugins), "help"
|
||||
|
||||
|
||||
def reset(self, session_id: str, message: str = "reset"):
|
||||
async def reset(self, session_id: str, message: str = "reset"):
|
||||
if self.provider is None:
|
||||
return False, "未启用 OpenAI 官方 API", "reset"
|
||||
l = message.split(" ")
|
||||
if len(l) == 1:
|
||||
self.provider.forget(session_id)
|
||||
await self.provider.forget(session_id)
|
||||
return True, "重置成功", "reset"
|
||||
if len(l) == 2 and l[1] == "p":
|
||||
self.provider.forget(session_id)
|
||||
@@ -146,7 +146,7 @@ class CommandOpenAIOfficial(Command):
|
||||
else:
|
||||
return True, "该Key被验证为无效。也许是输入错误了,或者重试。", "key"
|
||||
|
||||
def switch(self, message: str):
|
||||
async def switch(self, message: str):
|
||||
'''
|
||||
切换账号
|
||||
'''
|
||||
@@ -168,7 +168,7 @@ class CommandOpenAIOfficial(Command):
|
||||
else:
|
||||
try:
|
||||
new_key = list(key_stat.keys())[index-1]
|
||||
ret = self.provider.check_key(new_key)
|
||||
ret = await self.provider.check_key(new_key)
|
||||
self.provider.set_key(new_key)
|
||||
except BaseException as e:
|
||||
return True, "账号切换失败,原因: " + str(e), "switch"
|
||||
|
||||
@@ -11,14 +11,14 @@ class CommandRevChatGPT(Command):
|
||||
self.personality_str = ""
|
||||
super().__init__(provider, global_object)
|
||||
|
||||
def check_command(self,
|
||||
async def check_command(self,
|
||||
message: str,
|
||||
session_id: str,
|
||||
role: str,
|
||||
platform: str,
|
||||
message_obj):
|
||||
self.platform = platform
|
||||
hit, res = super().check_command(
|
||||
hit, res = await super().check_command(
|
||||
message,
|
||||
session_id,
|
||||
role,
|
||||
@@ -29,7 +29,7 @@ class CommandRevChatGPT(Command):
|
||||
if hit:
|
||||
return True, res
|
||||
if self.command_start_with(message, "help", "帮助"):
|
||||
return True, self.help()
|
||||
return True, await self.help()
|
||||
elif self.command_start_with(message, "reset"):
|
||||
return True, self.reset(session_id, message)
|
||||
elif self.command_start_with(message, "update"):
|
||||
@@ -127,7 +127,7 @@ class CommandRevChatGPT(Command):
|
||||
else:
|
||||
return True, "参数过多。", "switch"
|
||||
|
||||
def help(self):
|
||||
async def help(self):
|
||||
commands = super().general_commands()
|
||||
commands['set'] = '设置人格'
|
||||
return True, super().help_messager(commands, self.platform, self.global_object.cached_plugins), "help"
|
||||
return True, await super().help_messager(commands, self.platform, self.global_object.cached_plugins), "help"
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import abc
|
||||
import threading
|
||||
import asyncio
|
||||
from typing import Callable, Union
|
||||
from typing import Union
|
||||
from nakuru import (
|
||||
GuildMessage,
|
||||
GroupMessage,
|
||||
@@ -70,14 +68,3 @@ class Platform():
|
||||
pass
|
||||
ret.replace('\n', '')
|
||||
return ret
|
||||
|
||||
|
||||
def new_sub_thread(self, func, args=()):
|
||||
thread = threading.Thread(target=self._runner, args=(func, args), daemon=True)
|
||||
thread.start()
|
||||
|
||||
def _runner(self, func: Callable, args: tuple):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_until_complete(func(*args))
|
||||
loop.close()
|
||||
|
||||
+19
-33
@@ -58,10 +58,10 @@ class QQGOCQ(Platform):
|
||||
async def _(app: CQHTTP, source: GroupMessage):
|
||||
if self.cc.get("gocq_react_group", True):
|
||||
if isinstance(source.message[0], Plain):
|
||||
self.new_sub_thread(self.handle_msg, (source, True))
|
||||
await self.handle_msg(source, True)
|
||||
elif isinstance(source.message[0], At):
|
||||
if source.message[0].qq == source.self_id:
|
||||
self.new_sub_thread(self.handle_msg, (source, True))
|
||||
await self.handle_msg(source, True)
|
||||
else:
|
||||
return
|
||||
|
||||
@@ -69,7 +69,7 @@ class QQGOCQ(Platform):
|
||||
async def _(app: CQHTTP, source: FriendMessage):
|
||||
if self.cc.get("gocq_react_friend", True):
|
||||
if isinstance(source.message[0], Plain):
|
||||
self.new_sub_thread(self.handle_msg, (source, False))
|
||||
await self.handle_msg(source, False)
|
||||
else:
|
||||
return
|
||||
|
||||
@@ -85,19 +85,16 @@ class QQGOCQ(Platform):
|
||||
async def _(app: CQHTTP, source: Notify):
|
||||
print(source)
|
||||
if source.sub_type == "poke" and source.target_id == source.self_id:
|
||||
# await self.handle_msg(source, False)
|
||||
self.new_sub_thread(self.handle_msg, (source, False))
|
||||
await self.handle_msg(source, False)
|
||||
|
||||
@gocq_app.receiver("GuildMessage")
|
||||
async def _(app: CQHTTP, source: GuildMessage):
|
||||
if self.cc.get("gocq_react_guild", True):
|
||||
if isinstance(source.message[0], Plain):
|
||||
# await self.handle_msg(source, True)
|
||||
self.new_sub_thread(self.handle_msg, (source, True))
|
||||
await self.handle_msg(source, True)
|
||||
elif isinstance(source.message[0], At):
|
||||
if source.message[0].qq == source.self_tiny_id:
|
||||
# await self.handle_msg(source, True)
|
||||
self.new_sub_thread(self.handle_msg, (source, True))
|
||||
await self.handle_msg(source, True)
|
||||
else:
|
||||
return
|
||||
|
||||
@@ -157,7 +154,7 @@ class QQGOCQ(Platform):
|
||||
|
||||
if message_result is None:
|
||||
return
|
||||
self.reply_msg(message, message_result.result_message)
|
||||
await self.reply_msg(message, message_result.result_message)
|
||||
if message_result.callback is not None:
|
||||
message_result.callback()
|
||||
|
||||
@@ -165,11 +162,11 @@ class QQGOCQ(Platform):
|
||||
if session_id in self.waiting and self.waiting[session_id] == '':
|
||||
self.waiting[session_id] = message
|
||||
|
||||
def reply_msg(self,
|
||||
async def reply_msg(self,
|
||||
message: Union[GroupMessage, FriendMessage, GuildMessage, Notify],
|
||||
result_message: list):
|
||||
"""
|
||||
插件开发者请使用send方法, 可以不用直接调用这个方法。
|
||||
插件开发者请使用send方法, 可以不用直接调用这个方法。
|
||||
"""
|
||||
source = message
|
||||
res = result_message
|
||||
@@ -205,12 +202,10 @@ class QQGOCQ(Platform):
|
||||
# 回复消息链
|
||||
if isinstance(res, list) and len(res) > 0:
|
||||
if source.type == "GuildMessage":
|
||||
# await self.client.sendGuildChannelMessage(source.guild_id, source.channel_id, res)
|
||||
asyncio.run_coroutine_threadsafe(self.client.sendGuildChannelMessage(source.guild_id, source.channel_id, res), self.loop).result()
|
||||
await self.client.sendGuildChannelMessage(source.guild_id, source.channel_id, res)
|
||||
return
|
||||
elif source.type == "FriendMessage":
|
||||
# await self.client.sendFriendMessage(source.user_id, res)
|
||||
asyncio.run_coroutine_threadsafe(self.client.sendFriendMessage(source.user_id, res), self.loop).result()
|
||||
await self.client.sendFriendMessage(source.user_id, res)
|
||||
return
|
||||
elif source.type == "GroupMessage":
|
||||
# 过长时forward发送
|
||||
@@ -233,37 +228,28 @@ class QQGOCQ(Platform):
|
||||
node.time = int(time.time())
|
||||
# print(node)
|
||||
nodes=[node]
|
||||
# await self.client.sendGroupForwardMessage(source.group_id, nodes)
|
||||
asyncio.run_coroutine_threadsafe(self.client.sendGroupForwardMessage(source.group_id, nodes), self.loop).result()
|
||||
await self.client.sendGroupForwardMessage(source.group_id, nodes)
|
||||
return
|
||||
# await self.client.sendGroupMessage(source.group_id, res)
|
||||
asyncio.run_coroutine_threadsafe(self.client.sendGroupMessage(source.group_id, res), self.loop).result()
|
||||
await self.client.sendGroupMessage(source.group_id, res)
|
||||
return
|
||||
|
||||
def send_msg(self, message: Union[GroupMessage, FriendMessage, GuildMessage, Notify], result_message: list):
|
||||
async def send_msg(self, message: Union[GroupMessage, FriendMessage, GuildMessage, Notify], result_message: list):
|
||||
'''
|
||||
提供给插件的发送QQ消息接口。
|
||||
参数说明:第一个参数可以是消息对象,也可以是QQ群号。第二个参数是消息内容(消息内容可以是消息链列表,也可以是纯文字信息)。
|
||||
非异步
|
||||
'''
|
||||
try:
|
||||
# await self.reply_msg(message, result_message)
|
||||
self.reply_msg(message, result_message)
|
||||
await self.reply_msg(message, result_message)
|
||||
except BaseException as e:
|
||||
raise e
|
||||
|
||||
def send(self,
|
||||
async def send(self,
|
||||
to,
|
||||
res):
|
||||
'''
|
||||
同 send_msg()
|
||||
非异步
|
||||
'''
|
||||
try:
|
||||
# await self.send_msg(to, res)
|
||||
self.reply_msg(to, res)
|
||||
except BaseException as e:
|
||||
raise e
|
||||
await self.reply_msg(to, res)
|
||||
|
||||
def create_text_image(title: str, text: str, max_width=30, font_size=20):
|
||||
'''
|
||||
@@ -302,12 +288,12 @@ class QQGOCQ(Platform):
|
||||
def get_client(self):
|
||||
return self.client
|
||||
|
||||
def nakuru_method_invoker(self, func, *args, **kwargs):
|
||||
async def nakuru_method_invoker(self, func, *args, **kwargs):
|
||||
"""
|
||||
返回一个方法调用器,可以用来立即调用nakuru的方法。
|
||||
"""
|
||||
try:
|
||||
ret = asyncio.run_coroutine_threadsafe(func(*args, **kwargs), self.loop).result()
|
||||
ret = func(*args, **kwargs)
|
||||
return ret
|
||||
except BaseException as e:
|
||||
raise e
|
||||
|
||||
@@ -4,7 +4,7 @@ from PIL import Image as PILImage
|
||||
from botpy.message import Message, DirectMessage
|
||||
import re
|
||||
import asyncio
|
||||
import requests
|
||||
import aiohttp
|
||||
from util import general_utils as gu
|
||||
|
||||
from botpy.types.message import Reference
|
||||
@@ -15,7 +15,7 @@ from ._nakuru_translation_layer import(
|
||||
NakuruGuildMessage,
|
||||
gocq_compatible_receive,
|
||||
gocq_compatible_send
|
||||
)
|
||||
)
|
||||
from typing import Union
|
||||
|
||||
# QQ 机器人官方框架
|
||||
@@ -27,13 +27,13 @@ class botClient(Client):
|
||||
async def on_at_message_create(self, message: Message):
|
||||
# 转换层
|
||||
nakuru_guild_message = gocq_compatible_receive(message)
|
||||
self.platform.new_sub_thread(self.platform.handle_msg, (nakuru_guild_message, True))
|
||||
await self.platform.handle_msg(nakuru_guild_message, True)
|
||||
|
||||
# 收到私聊消息
|
||||
async def on_direct_message_create(self, message: DirectMessage):
|
||||
# 转换层
|
||||
nakuru_guild_message = gocq_compatible_receive(message)
|
||||
self.platform.new_sub_thread(self.platform.handle_msg, (nakuru_guild_message, False))
|
||||
await self.platform.handle_msg(nakuru_guild_message, False)
|
||||
|
||||
class QQOfficial(Platform):
|
||||
|
||||
@@ -107,7 +107,7 @@ class QQOfficial(Platform):
|
||||
if message_result is None:
|
||||
return
|
||||
|
||||
self.reply_msg(is_group, message, message_result.result_message)
|
||||
await self.reply_msg(is_group, message, message_result.result_message)
|
||||
if message_result.callback is not None:
|
||||
message_result.callback()
|
||||
|
||||
@@ -115,7 +115,7 @@ class QQOfficial(Platform):
|
||||
if session_id in self.waiting and self.waiting[session_id] == '':
|
||||
self.waiting[session_id] = message
|
||||
|
||||
def reply_msg(self,
|
||||
async def reply_msg(self,
|
||||
is_group: bool,
|
||||
message: NakuruGuildMessage,
|
||||
res: Union[str, list]):
|
||||
@@ -148,10 +148,11 @@ class QQOfficial(Platform):
|
||||
if image_path is not None and image_path != '':
|
||||
msg_ref = None
|
||||
if image_path.startswith("http"):
|
||||
pic_res = requests.get(image_path, stream = True)
|
||||
if pic_res.status_code == 200:
|
||||
image = PILImage.open(io.BytesIO(pic_res.content))
|
||||
image_path = gu.save_temp_img(image)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(image_path) as response:
|
||||
if response.status == 200:
|
||||
image = PILImage.open(io.BytesIO(await response.read()))
|
||||
image_path = gu.save_temp_img(image)
|
||||
|
||||
if message.raw_message is not None and image_path == '': # file_image与message_reference不能同时传入
|
||||
msg_ref = Reference(message_id=message.raw_message.id, ignore_get_message_error=False)
|
||||
@@ -170,8 +171,7 @@ class QQOfficial(Platform):
|
||||
data['file_image'] = image_path
|
||||
|
||||
try:
|
||||
# await self._send_wrapper(**data)
|
||||
self._send_wrapper(**data)
|
||||
await self._send_wrapper(**data)
|
||||
except BaseException as e:
|
||||
print(e)
|
||||
# 分割过长的消息
|
||||
@@ -181,51 +181,44 @@ class QQOfficial(Platform):
|
||||
split_res.append(plain_text[len(plain_text)//2:])
|
||||
for i in split_res:
|
||||
data['content'] = i
|
||||
# await self._send_wrapper(**data)
|
||||
self._send_wrapper(**data)
|
||||
await self._send_wrapper(**data)
|
||||
else:
|
||||
# 发送qq信息
|
||||
try:
|
||||
# 防止被qq频道过滤消息
|
||||
plain_text = plain_text.replace(".", " . ")
|
||||
# await self._send_wrapper(**data)
|
||||
self._send_wrapper(**data)
|
||||
await self._send_wrapper(**data)
|
||||
|
||||
except BaseException as e:
|
||||
try:
|
||||
data['content'] = str.join(" ", plain_text)
|
||||
# await self._send_wrapper(**data)
|
||||
self._send_wrapper(**data)
|
||||
await self._send_wrapper(**data)
|
||||
except BaseException as e:
|
||||
plain_text = re.sub(r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '[被隐藏的链接]', str(e), flags=re.MULTILINE)
|
||||
plain_text = plain_text.replace(".", "·")
|
||||
data['content'] = plain_text
|
||||
# await self._send_wrapper(**data)
|
||||
self._send_wrapper(**data)
|
||||
await self._send_wrapper(**data)
|
||||
|
||||
def _send_wrapper(self, **kwargs):
|
||||
async def _send_wrapper(self, **kwargs):
|
||||
if 'channel_id' in kwargs:
|
||||
asyncio.run_coroutine_threadsafe(self.client.api.post_message(**kwargs), self.loop).result()
|
||||
await self.client.api.post_message(**kwargs)
|
||||
else:
|
||||
asyncio.run_coroutine_threadsafe(self.client.api.post_dms(**kwargs), self.loop).result()
|
||||
await self.client.api.post_dms(**kwargs)
|
||||
|
||||
|
||||
def send_msg(self, channel_id: int, message_chain: list, message_id: int = None):
|
||||
async def send_msg(self, channel_id: int, message_chain: list, message_id: int = None):
|
||||
'''
|
||||
推送消息, 如果有 message_id,那么就是回复消息。非异步。
|
||||
推送消息, 如果有 message_id,那么就是回复消息。
|
||||
'''
|
||||
_n = NakuruGuildMessage()
|
||||
_n.channel_id = channel_id
|
||||
_n.message_id = message_id
|
||||
# await self.reply_msg(_n, message_chain)
|
||||
self.reply_msg(_n, message_chain)
|
||||
await self.reply_msg(_n, message_chain)
|
||||
|
||||
def send(self, message_obj, message_chain: list):
|
||||
async def send(self, message_obj, message_chain: list):
|
||||
'''
|
||||
发送信息。内容同 reply_msg。非异步。
|
||||
发送信息。内容同 reply_msg。
|
||||
'''
|
||||
# await self.reply_msg(message_obj, message_chain)
|
||||
self.reply_msg(message_obj, message_chain)
|
||||
await self.reply_msg(message_obj, message_chain)
|
||||
|
||||
def wait_for_message(self, channel_id: int) -> NakuruGuildMessage:
|
||||
'''
|
||||
|
||||
@@ -1,18 +1,21 @@
|
||||
from openai import OpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.images_response import ImagesResponse
|
||||
import json
|
||||
import time
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
import tiktoken
|
||||
import threading
|
||||
import traceback
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.images_response import ImagesResponse
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
|
||||
from cores.database.conn import dbConn
|
||||
from model.provider.provider import Provider
|
||||
import threading
|
||||
from util import general_utils as gu
|
||||
from util.cmd_config import CmdConfig
|
||||
from util.general_utils import Logger
|
||||
import traceback
|
||||
import tiktoken
|
||||
|
||||
|
||||
|
||||
abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/'
|
||||
@@ -42,7 +45,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
self.logger.log(f"设置 api_base 为: {self.api_base}", tag="OpenAI")
|
||||
|
||||
# 创建 OpenAI Client
|
||||
self.client = OpenAI(
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=self.key_list[0],
|
||||
base_url=self.api_base
|
||||
)
|
||||
@@ -113,7 +116,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
}
|
||||
self.session_dict[session_id].append(new_record)
|
||||
|
||||
def text_chat(self, prompt,
|
||||
async def text_chat(self, prompt,
|
||||
session_id = None,
|
||||
image_url = None,
|
||||
function_call=None,
|
||||
@@ -132,7 +135,6 @@ class ProviderOpenAIOfficial(Provider):
|
||||
if default_personality is not None:
|
||||
self.personality_set(default_personality, session_id)
|
||||
|
||||
|
||||
# 使用 tictoken 截断消息
|
||||
_encoded_prompt = self.enc.encode(prompt)
|
||||
if self.openai_model_configs['max_tokens'] < len(_encoded_prompt):
|
||||
@@ -140,8 +142,8 @@ class ProviderOpenAIOfficial(Provider):
|
||||
self.logger.log(f"注意,有一部分 prompt 文本由于超出 token 限制而被截断。", level=gu.LEVEL_WARNING, tag="OpenAI")
|
||||
|
||||
cache_data_list, new_record, req = self.wrap(prompt, session_id, image_url)
|
||||
self.logger.log(f"CACHE_DATA_: {str(cache_data_list)}", level=gu.LEVEL_DEBUG, tag="OpenAI")
|
||||
self.logger.log(f"OPENAI REQUEST: {str(req)}", level=gu.LEVEL_DEBUG, tag="OpenAI")
|
||||
self.logger.log(f"cache: {str(cache_data_list)}", level=gu.LEVEL_DEBUG, tag="OpenAI")
|
||||
self.logger.log(f"request: {str(req)}", level=gu.LEVEL_DEBUG, tag="OpenAI")
|
||||
retry = 0
|
||||
response = None
|
||||
err = ''
|
||||
@@ -168,19 +170,19 @@ class ProviderOpenAIOfficial(Provider):
|
||||
while retry < 10:
|
||||
try:
|
||||
if function_call is None:
|
||||
response = self.client.chat.completions.create(
|
||||
response = await self.client.chat.completions.create(
|
||||
messages=req,
|
||||
**conf
|
||||
)
|
||||
else:
|
||||
response = self.client.chat.completions.create(
|
||||
response = await self.client.chat.completions.create(
|
||||
messages=req,
|
||||
tools = function_call,
|
||||
**conf
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
traceback.print_exc()
|
||||
if 'Invalid content type. image_url is only supported by certain models.' in str(e):
|
||||
raise e
|
||||
if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e):
|
||||
@@ -188,7 +190,6 @@ class ProviderOpenAIOfficial(Provider):
|
||||
self.key_stat[self.client.api_key]['exceed'] = True
|
||||
is_switched = self.handle_switch_key()
|
||||
if not is_switched:
|
||||
# 所有Key都超额或不正常
|
||||
raise e
|
||||
retry -= 1
|
||||
elif 'maximum context length' in str(e):
|
||||
@@ -239,7 +240,6 @@ class ProviderOpenAIOfficial(Provider):
|
||||
index += 1
|
||||
# 删除完后更新相关字段
|
||||
self.session_dict[session_id] = cache_data_list
|
||||
# cache_prompt = get_prompts_by_cache_list(cache_data_list)
|
||||
|
||||
# 添加新条目进入缓存的prompt
|
||||
new_record['AI'] = {
|
||||
@@ -258,7 +258,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
return chatgpt_res
|
||||
|
||||
def image_chat(self, prompt, img_num = 1, img_size = "1024x1024"):
|
||||
async def image_chat(self, prompt, img_num = 1, img_size = "1024x1024"):
|
||||
retry = 0
|
||||
image_url = ''
|
||||
|
||||
@@ -266,7 +266,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
while retry < 5:
|
||||
try:
|
||||
response: ImagesResponse = self.client.images.generate(
|
||||
response: ImagesResponse = await self.client.images.generate(
|
||||
prompt=prompt,
|
||||
**image_generate_configs
|
||||
)
|
||||
@@ -282,7 +282,6 @@ class ProviderOpenAIOfficial(Provider):
|
||||
self.key_stat[self.client.api_key]['exceed'] = True
|
||||
is_switched = self.handle_switch_key()
|
||||
if not is_switched:
|
||||
# 所有Key都超额或不正常
|
||||
raise e
|
||||
elif 'Your request was rejected as a result of our safety system.' in str(e):
|
||||
self.logger.log("您的请求被 OpenAI 安全系统拒绝, 请稍后再试", level=gu.LEVEL_WARNING, tag="OpenAI")
|
||||
@@ -294,16 +293,16 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
return image_url
|
||||
|
||||
def forget(self, session_id = None) -> bool:
|
||||
async def forget(self, session_id = None) -> bool:
|
||||
if session_id is None:
|
||||
return False
|
||||
self.session_dict[session_id] = []
|
||||
return True
|
||||
|
||||
'''
|
||||
获取缓存的会话
|
||||
'''
|
||||
def get_prompts_by_cache_list(self, cache_data_list, divide=False, paging=False, size=5, page=1):
|
||||
'''
|
||||
获取缓存的会话
|
||||
'''
|
||||
prompts = ""
|
||||
if paging:
|
||||
page_begin = (page-1)*size
|
||||
@@ -320,15 +319,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
if divide:
|
||||
prompts += "----------\n"
|
||||
return prompts
|
||||
|
||||
|
||||
def get_user_usage_tokens(self,cache_list):
|
||||
usage_tokens = 0
|
||||
for item in cache_list:
|
||||
usage_tokens += int(item['single_tokens'])
|
||||
return usage_tokens
|
||||
|
||||
# 包装信息
|
||||
|
||||
def wrap(self, prompt, session_id, image_url = None):
|
||||
if image_url is not None:
|
||||
prompt = [
|
||||
@@ -364,7 +355,6 @@ class ProviderOpenAIOfficial(Provider):
|
||||
return context, new_record, req_list
|
||||
|
||||
def handle_switch_key(self):
|
||||
# messages = [{"role": "user", "content": prompt}]
|
||||
is_all_exceed = True
|
||||
for key in self.key_stat:
|
||||
if key == None or self.key_stat[key]['exceed']:
|
||||
@@ -399,13 +389,13 @@ class ProviderOpenAIOfficial(Provider):
|
||||
self.key_stat[key] = {'exceed': False, 'used': 0, 'sponsor': sponsor}
|
||||
|
||||
# 检查key是否可用
|
||||
def check_key(self, key):
|
||||
client_ = OpenAI(
|
||||
async def check_key(self, key):
|
||||
client_ = AsyncOpenAI(
|
||||
api_key=key,
|
||||
base_url=self.api_base
|
||||
)
|
||||
messages = [{"role": "user", "content": "please just echo `test`"}]
|
||||
client_.chat.completions.create(
|
||||
await client_.chat.completions.create(
|
||||
messages=messages,
|
||||
**self.openai_model_configs
|
||||
)
|
||||
|
||||
@@ -5,9 +5,9 @@ class Provider:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def text_chat(self, prompt, session_id, image_url: None, function_call: None, extra_conf: dict = None, default_personality: dict = None) -> str:
|
||||
async def text_chat(self, prompt, session_id, image_url: None, function_call: None, extra_conf: dict = None, default_personality: dict = None) -> str:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def forget(self, session_id = None) -> bool:
|
||||
async def forget(self, session_id = None) -> bool:
|
||||
pass
|
||||
+2
-1
@@ -1,5 +1,6 @@
|
||||
pydantic~=1.10.4
|
||||
requests~=2.28.1
|
||||
aiohttp
|
||||
requests
|
||||
openai~=1.2.3
|
||||
qq-botpy
|
||||
chardet~=5.1.0
|
||||
|
||||
@@ -1,18 +1,19 @@
|
||||
import requests
|
||||
import util.general_utils as gu
|
||||
from bs4 import BeautifulSoup
|
||||
import traceback
|
||||
import time
|
||||
import json
|
||||
import asyncio
|
||||
from googlesearch import search, SearchResult
|
||||
from readability import Document
|
||||
from bs4 import BeautifulSoup
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function
|
||||
from util.function_calling.func_call import (
|
||||
FuncCall,
|
||||
FuncCallJsonFormatError,
|
||||
FuncNotFoundError
|
||||
)
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function
|
||||
import traceback
|
||||
from googlesearch import search, SearchResult
|
||||
from model.provider.provider import Provider
|
||||
import json
|
||||
from readability import Document
|
||||
|
||||
|
||||
def tidy_text(text: str) -> str:
|
||||
@@ -53,11 +54,11 @@ def google_web_search(keyword) -> str:
|
||||
for i in ls:
|
||||
desc = i.description
|
||||
try:
|
||||
gu.log(f"搜索网页: {i.url}", tag="网页搜索", level=gu.LEVEL_INFO)
|
||||
# gu.log(f"搜索网页: {i.url}", tag="网页搜索", level=gu.LEVEL_INFO)
|
||||
desc = fetch_website_content(i.url)
|
||||
except BaseException as e:
|
||||
print(f"(google) fetch_website_content err: {str(e)}")
|
||||
gu.log(f"# No.{str(index)}\ntitle: {i.title}\nurl: {i.url}\ncontent: {desc}\n\n", level=gu.LEVEL_DEBUG, max_len=9999)
|
||||
# gu.log(f"# No.{str(index)}\ntitle: {i.title}\nurl: {i.url}\ncontent: {desc}\n\n", level=gu.LEVEL_DEBUG, max_len=9999)
|
||||
ret += f"# No.{str(index)}\ntitle: {i.title}\nurl: {i.url}\ncontent: {desc}\n\n"
|
||||
index += 1
|
||||
except Exception as e:
|
||||
@@ -80,7 +81,7 @@ def web_keyword_search_via_bing(keyword) -> str:
|
||||
try:
|
||||
response = requests.get(url, headers=headers)
|
||||
response.encoding = "utf-8"
|
||||
gu.log(f"bing response: {response.text}", tag="bing", level=gu.LEVEL_DEBUG, max_len=9999)
|
||||
# gu.log(f"bing response: {response.text}", tag="bing", level=gu.LEVEL_DEBUG, max_len=9999)
|
||||
soup = BeautifulSoup(response.text, "html.parser")
|
||||
res = ""
|
||||
result_cnt = 0
|
||||
@@ -96,7 +97,7 @@ def web_keyword_search_via_bing(keyword) -> str:
|
||||
# "link": link,
|
||||
# })
|
||||
try:
|
||||
gu.log(f"搜索网页: {link}", tag="网页搜索", level=gu.LEVEL_INFO)
|
||||
# gu.log(f"搜索网页: {link}", tag="网页搜索", level=gu.LEVEL_INFO)
|
||||
desc = fetch_website_content(link)
|
||||
except BaseException as e:
|
||||
print(f"(bing) fetch_website_content err: {str(e)}")
|
||||
@@ -124,11 +125,11 @@ def web_keyword_search_via_bing(keyword) -> str:
|
||||
if result_cnt == 0: break
|
||||
return res
|
||||
except Exception as e:
|
||||
gu.log(f"bing fetch err: {str(e)}")
|
||||
# gu.log(f"bing fetch err: {str(e)}")
|
||||
_cnt += 1
|
||||
time.sleep(1)
|
||||
|
||||
gu.log("fail to fetch bing info, using sougou.")
|
||||
# gu.log("fail to fetch bing info, using sougou.")
|
||||
return web_keyword_search_via_sougou(keyword)
|
||||
|
||||
def web_keyword_search_via_sougou(keyword) -> str:
|
||||
@@ -157,7 +158,7 @@ def web_keyword_search_via_sougou(keyword) -> str:
|
||||
break
|
||||
except Exception as e:
|
||||
pass
|
||||
gu.log(f"sougou parse err: {str(e)}", tag="web_keyword_search_via_sougou", level=gu.LEVEL_ERROR)
|
||||
# gu.log(f"sougou parse err: {str(e)}", tag="web_keyword_search_via_sougou", level=gu.LEVEL_ERROR)
|
||||
# 爬取网页内容
|
||||
_detail_store = []
|
||||
for i in res:
|
||||
@@ -173,7 +174,7 @@ def web_keyword_search_via_sougou(keyword) -> str:
|
||||
return ret
|
||||
|
||||
def fetch_website_content(url):
|
||||
gu.log(f"fetch_website_content: {url}", tag="fetch_website_content", level=gu.LEVEL_DEBUG)
|
||||
# gu.log(f"fetch_website_content: {url}", tag="fetch_website_content", level=gu.LEVEL_DEBUG)
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \
|
||||
AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
|
||||
@@ -187,7 +188,7 @@ def fetch_website_content(url):
|
||||
ret = tidy_text(soup.get_text())
|
||||
return ret
|
||||
|
||||
def web_search(question, provider: Provider, session_id, official_fc=False):
|
||||
async def web_search(question, provider: Provider, session_id, official_fc=False):
|
||||
'''
|
||||
official_fc: 使用官方 function-calling
|
||||
'''
|
||||
@@ -197,7 +198,7 @@ def web_search(question, provider: Provider, session_id, official_fc=False):
|
||||
"name": "keyword",
|
||||
"description": "google search query (分词,尽量保留所有信息)"
|
||||
}],
|
||||
"通过搜索引擎搜索。如果问题需要在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。",
|
||||
"通过搜索引擎搜索。如果问题需要获取近期、实时的消息,在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。",
|
||||
web_keyword_search_via_bing
|
||||
)
|
||||
new_func_call.add_func("fetch_website_content", [{
|
||||
@@ -205,16 +206,16 @@ def web_search(question, provider: Provider, session_id, official_fc=False):
|
||||
"name": "url",
|
||||
"description": "网址"
|
||||
}],
|
||||
"获取网页的内容。如果问题带有合法的网页链接(例如: `帮我总结一下https://github.com的内容`), 就调用此函数。如果没有,不要调用此函数。",
|
||||
"获取网页的内容。如果问题带有合法的网页链接(例如: `帮我总结一下 https://github.com 的内容`), 就调用此函数。如果没有,不要调用此函数。",
|
||||
fetch_website_content
|
||||
)
|
||||
question1 = f"{question} \n> hint: 最多只能调用1个function, 并且存在不会调用任何function的可能性。"
|
||||
has_func = False
|
||||
function_invoked_ret = ""
|
||||
if official_fc:
|
||||
func = provider.text_chat(question1, session_id, function_call=new_func_call.get_func())
|
||||
# we use official function-calling
|
||||
func = await provider.text_chat(question1, session_id, function_call=new_func_call.get_func())
|
||||
if isinstance(func, Function):
|
||||
# arguments='{\n "keyword": "北京今天的天气"\n}', name='google_web_search'
|
||||
# 执行对应的结果:
|
||||
func_obj = None
|
||||
for i in new_func_call.func_list:
|
||||
@@ -222,49 +223,68 @@ def web_search(question, provider: Provider, session_id, official_fc=False):
|
||||
func_obj = i["func_obj"]
|
||||
break
|
||||
if not func_obj:
|
||||
gu.log("找不到返回的 func name " + func.name, level=gu.LEVEL_ERROR)
|
||||
return provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)"
|
||||
# gu.log("找不到返回的 func name " + func.name, level=gu.LEVEL_ERROR)
|
||||
return await provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)"
|
||||
try:
|
||||
args = json.loads(func.arguments)
|
||||
function_invoked_ret = func_obj(**args)
|
||||
# we use to_thread to avoid blocking the event loop
|
||||
function_invoked_ret = await asyncio.to_thread(func_obj, **args)
|
||||
has_func = True
|
||||
except BaseException as e:
|
||||
traceback.print_exc()
|
||||
return provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)"
|
||||
return await provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)"
|
||||
else:
|
||||
# now func is a string
|
||||
return func
|
||||
else:
|
||||
# we use our own function-calling
|
||||
try:
|
||||
function_invoked_ret, has_func = new_func_call.func_call(question1, new_func_call.func_dump(), is_task=False, is_summary=False)
|
||||
args = {
|
||||
'question': question1,
|
||||
'func_definition': new_func_call.func_dump(),
|
||||
'is_task': False,
|
||||
'is_summary': False,
|
||||
}
|
||||
function_invoked_ret, has_func = await asyncio.to_thread(new_func_call.func_call, **args)
|
||||
except BaseException as e:
|
||||
res = provider.text_chat(question) + "\n(网页搜索失败, 此为默认回复)"
|
||||
res = await provider.text_chat(question) + "\n(网页搜索失败, 此为默认回复)"
|
||||
return res
|
||||
has_func = True
|
||||
|
||||
if has_func:
|
||||
provider.forget(session_id)
|
||||
await provider.forget(session_id)
|
||||
question3 = f"""
|
||||
以下是相关材料,你的任务是:
|
||||
1. 根据材料对问题`{question}`做切题的总结回答;
|
||||
2. 发表你对这个问题的看法.
|
||||
你的任务是:
|
||||
1. 根据末尾的材料对问题`{question}`做切题的总结(详细);
|
||||
2. 简单地发表你对这个问题的看法(简略)。
|
||||
你的总结末尾应当有对材料的引用, 如果有链接, 请在末尾附上引用网页链接。引用格式严格按照 `\n[1] title url \n`。
|
||||
不要提到任何函数调用的信息。以下是相关材料:
|
||||
不要提到任何函数调用的信息。
|
||||
|
||||
一些回复的消息模板:
|
||||
模板1:
|
||||
```
|
||||
从网上的信息来看,可以知道...我个人认为...你觉得呢?
|
||||
```
|
||||
模板2:
|
||||
```
|
||||
根据网上的最新信息,可以得知...我觉得...你怎么看?
|
||||
```
|
||||
你可以根据这些模板来组织回答,但可以不照搬,要根据问题的内容来回答。
|
||||
|
||||
以下是相关材料:
|
||||
"""
|
||||
|
||||
gu.log(f"web_search: {question3}", tag="web_search", level=gu.LEVEL_DEBUG, max_len=99999)
|
||||
_c = 0
|
||||
while _c < 3:
|
||||
try:
|
||||
print('text chat')
|
||||
final_ret = provider.text_chat(question3 + "```" + function_invoked_ret + "```", session_id)
|
||||
final_ret = await provider.text_chat(question3 + "```" + function_invoked_ret + "```", session_id)
|
||||
return final_ret
|
||||
except Exception as e:
|
||||
print(e)
|
||||
_c += 1
|
||||
if _c == 3: raise e
|
||||
if "The message you submitted was too long" in str(e):
|
||||
provider.forget(session_id)
|
||||
await provider.forget(session_id)
|
||||
function_invoked_ret = function_invoked_ret[:int(len(function_invoked_ret) / 2)]
|
||||
time.sleep(3)
|
||||
return function_invoked_ret
|
||||
|
||||
@@ -9,7 +9,6 @@ from util.cmd_config import CmdConfig
|
||||
import socket
|
||||
from cores.qqbot.global_object import GlobalObject
|
||||
import platform
|
||||
import requests
|
||||
import logging
|
||||
import json
|
||||
import sys
|
||||
|
||||
Reference in New Issue
Block a user