perf: 使用异步重写部分代码

This commit is contained in:
Soulter
2023-10-12 11:16:49 +08:00
parent 123c21fcb3
commit 89fc7b0553
5 changed files with 40 additions and 48 deletions
+37 -38
View File
@@ -31,7 +31,7 @@ from PIL import Image as PILImage
import io
import traceback
from . global_object import GlobalObject
from typing import Union
from typing import Union, Callable
# 缓存的会话
@@ -116,15 +116,20 @@ gocq_app = CQHTTP(
)
gocq_loop = None
bing_cache_loop = None
# 全局对象
_global_object: GlobalObject = None
def new_sub_thread(func, args=()):
thread = threading.Thread(target=func, args=args, daemon=True)
thread = threading.Thread(target=_runner, args=(func, args), daemon=True)
thread.start()
def _runner(func: Callable, args: tuple):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(func(*args))
loop.close()
# [Deprecated] 写入统计信息
def toggle_count(at: bool, message):
@@ -399,9 +404,6 @@ def run_qqchan_bot(cfg, loop, qqchannel_bot: QQChan):
abs_path = os.path.abspath("QQChannelChatGPT/configs/config.yaml")
print("配置文件地址:" + abs_path)
os.system(f"notepad \"{abs_path}\"")
# gu.log("如果你使用了go-cqhttp, 则可以忽略上面的报错。" + str(e), gu.LEVEL_CRITICAL, tag="QQ频道")
# input(f"\n[System-Error] 启动QQ频道机器人时出现错误,原因如下:{e}\n可能是没有填写QQBOT appid和token?请在config中完善你的appid和token\n配置教程:https://soulter.top/posts/qpdg.html\n")
'''
运行GOCQ机器人
@@ -424,7 +426,6 @@ def run_gocq_bot(loop, gocq_bot, gocq_app):
except BaseException as e:
input("启动QQ机器人出现错误"+str(e))
'''
检查发言频率
'''
@@ -450,7 +451,7 @@ def check_frequency(id) -> bool:
'''
通用消息回复
'''
def send_message(platform, message, res, session_id = None):
async def send_message(platform, message, res, session_id = None):
global qqchannel_bot, qqchannel_bot, gocq_loop, session_dict
if session_id is not None:
if session_id not in session_dict:
@@ -462,9 +463,9 @@ def send_message(platform, message, res, session_id = None):
if platform == PLATFORM_QQCHAN:
qqchannel_bot.send_qq_msg(message, res)
if platform == PLATFORM_GOCQ:
asyncio.run_coroutine_threadsafe(gocq_bot.send_qq_msg(message, res), gocq_loop).result()
await gocq_bot.send_qq_msg(message, res)
def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGuildMessage],
async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGuildMessage],
group: bool=False,
platform: str = None):
"""
@@ -474,7 +475,7 @@ def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGui
msg_ref: 引用消息(频道)
platform: 平台(gocq, qqchan)
"""
global chosen_provider, keywords, qqchannel_bot, gocq_bot, gocq_loop, bing_cache_loop, qqchan_loop
global chosen_provider, keywords, qqchannel_bot, gocq_bot
global _global_object
qq_msg = ''
session_id = ''
@@ -543,13 +544,13 @@ def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGui
with_tag = True
if qq_msg == "":
send_message(platform, message, f"Hi~", session_id=session_id)
await send_message(platform, message, f"Hi~", session_id=session_id)
return
if with_tag:
# 检查发言频率
if not check_frequency(user_id):
send_message(platform, message, f'你的发言超过频率限制(╯▔皿▔)╯。\n管理员设置{frequency_time}秒内只能提问{frequency_count}次。', session_id=session_id)
await send_message(platform, message, f'你的发言超过频率限制(╯▔皿▔)╯。\n管理员设置{frequency_time}秒内只能提问{frequency_count}次。', session_id=session_id)
return
# logf.write("[GOCQBOT] "+ qq_msg+'\n')
@@ -568,9 +569,9 @@ def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGui
image_url = keywords[k]['image_url']
if image_url != "":
res = [Plain(plain_text), Image.fromURL(image_url)]
send_message(platform, message, res, session_id=session_id)
await send_message(platform, message, res, session_id=session_id)
else:
send_message(platform, message, plain_text, session_id=session_id)
await send_message(platform, message, plain_text, session_id=session_id)
return
# 检查是否是更换语言模型的请求
@@ -592,7 +593,7 @@ def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGui
else:
chosen_provider = target
cc.put("chosen_provider", chosen_provider)
send_message(platform, message, f"已切换至【{chosen_provider}", session_id=session_id)
await send_message(platform, message, f"已切换至【{chosen_provider}", session_id=session_id)
return
chatgpt_res = ""
@@ -603,7 +604,6 @@ def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGui
hit, command_result = llm_command_instance[chosen_provider].check_command(
qq_msg,
session_id,
bing_cache_loop,
role,
platform,
message,
@@ -617,15 +617,15 @@ def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGui
for i in uw.unfit_words_q:
matches = re.match(i, qq_msg.strip(), re.I | re.M)
if matches:
send_message(platform, message, f"你的提问得到的回复未通过【自有关键词拦截】服务, 不予回复。", session_id=session_id)
await send_message(platform, message, f"你的提问得到的回复未通过【自有关键词拦截】服务, 不予回复。", session_id=session_id)
return
if baidu_judge != None:
check, msg = baidu_judge.judge(qq_msg)
if not check:
send_message(platform, message, f"你的提问得到的回复未通过【百度AI内容审核】服务, 不予回复。\n\n{msg}", session_id=session_id)
await send_message(platform, message, f"你的提问得到的回复未通过【百度AI内容审核】服务, 不予回复。\n\n{msg}", session_id=session_id)
return
if chosen_provider == None:
send_message(platform, message, f"管理员未启动任何语言模型或者语言模型初始化时失败。", session_id=session_id)
await send_message(platform, message, f"管理员未启动任何语言模型或者语言模型初始化时失败。", session_id=session_id)
return
try:
if chosen_provider == REV_CHATGPT or chosen_provider == OPENAI_OFFICIAL:
@@ -634,14 +634,14 @@ def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGui
else:
chatgpt_res = str(llm_instance[chosen_provider].text_chat(qq_msg, session_id))
elif chosen_provider == REV_EDGEGPT:
res, res_code = asyncio.run_coroutine_threadsafe(llm_instance[chosen_provider].text_chat(qq_msg, platform), bing_cache_loop).result()
res, res_code = await llm_instance[chosen_provider].text_chat(qq_msg, platform)
if res_code == 0: # bing不想继续话题,重置会话后重试。
send_message(platform, message, "Bing不想继续话题了, 正在自动重置会话并重试。", session_id=session_id)
asyncio.run_coroutine_threadsafe(llm_instance[chosen_provider].forget(), bing_cache_loop).result()
res, res_code = asyncio.run_coroutine_threadsafe(llm_instance[chosen_provider].text_chat(qq_msg, platform), bing_cache_loop).result()
await send_message(platform, message, "Bing不想继续话题了, 正在自动重置会话并重试。", session_id=session_id)
await llm_instance[chosen_provider].forget()
res, res_code = await llm_instance[chosen_provider].text_chat(qq_msg, platform)
if res_code == 0: # bing还是不想继续话题,大概率说明提问有问题。
asyncio.run_coroutine_threadsafe(llm_instance[chosen_provider].forget(), bing_cache_loop).result()
send_message(platform, message, "Bing仍然不想继续话题, 会话已重置, 请检查您的提问后重试。", session_id=session_id)
await llm_instance[chosen_provider].forget()
await send_message(platform, message, "Bing仍然不想继续话题, 会话已重置, 请检查您的提问后重试。", session_id=session_id)
res = ""
chatgpt_res = str(res)
@@ -650,7 +650,7 @@ def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGui
except BaseException as e:
gu.log(f"调用异常:{traceback.format_exc()}", gu.LEVEL_ERROR, max_len=100000)
gu.log("调用语言模型例程时出现异常。原因: "+str(e), gu.LEVEL_ERROR)
send_message(platform, message, "调用语言模型例程时出现异常。原因: "+str(e), session_id=session_id)
await send_message(platform, message, "调用语言模型例程时出现异常。原因: "+str(e), session_id=session_id)
return
# 切换回原来的语言模型
@@ -661,7 +661,7 @@ def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGui
if hit:
# 检查指令. command_result是一个元组:(指令调用是否成功, 指令返回的文本结果, 指令类型)
if command_result == None:
# send_message(platform, message, "指令调用未返回任何信息。", session_id=session_id)
# await send_message(platform, message, "指令调用未返回任何信息。", session_id=session_id)
return
command = command_result[2]
@@ -671,17 +671,17 @@ def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGui
keywords = json.load(f)
else:
try:
send_message(platform, message, command_result[1], session_id=session_id)
await send_message(platform, message, command_result[1], session_id=session_id)
except BaseException as e:
send_message(platform, message, f"回复消息出错: {str(e)}", session_id=session_id)
await send_message(platform, message, f"回复消息出错: {str(e)}", session_id=session_id)
if command == "update latest r":
send_message(platform, message, command_result[1] + "\n\n即将自动重启。", session_id=session_id)
await send_message(platform, message, command_result[1] + "\n\n即将自动重启。", session_id=session_id)
py = sys.executable
os.execl(py, py, *sys.argv)
if not command_result[0]:
send_message(platform, message, f"指令调用错误: \n{str(command_result[1])}", session_id=session_id)
await send_message(platform, message, f"指令调用错误: \n{str(command_result[1])}", session_id=session_id)
return
# 画图指令
@@ -692,14 +692,14 @@ def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGui
pic_res = requests.get(i, stream = True)
if pic_res.status_code == 200:
image = PILImage.open(io.BytesIO(pic_res.content))
send_message(platform, message, [Image.fromFileSystem(gu.save_temp_img(image))], session_id=session_id)
await send_message(platform, message, [Image.fromFileSystem(gu.save_temp_img(image))], session_id=session_id)
# 其他指令
else:
try:
send_message(platform, message, command_result[1], session_id=session_id)
await send_message(platform, message, command_result[1], session_id=session_id)
except BaseException as e:
send_message(platform, message, f"回复消息出错: {str(e)}", session_id=session_id)
await send_message(platform, message, f"回复消息出错: {str(e)}", session_id=session_id)
return
@@ -715,12 +715,12 @@ def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGui
if baidu_judge != None:
check, msg = baidu_judge.judge(chatgpt_res)
if not check:
send_message(platform, message, f"你的提问得到的回复【百度内容审核】未通过,不予回复。\n\n{msg}", session_id=session_id)
await send_message(platform, message, f"你的提问得到的回复【百度内容审核】未通过,不予回复。\n\n{msg}", session_id=session_id)
return
# 发送信息
try:
send_message(platform, message, chatgpt_res, session_id=session_id)
await send_message(platform, message, chatgpt_res, session_id=session_id)
except BaseException as e:
gu.log("回复消息错误: \n"+str(e), gu.LEVEL_ERROR)
@@ -733,7 +733,6 @@ class botClient(botpy.Client):
# 转换层
nakuru_guild_message = qqchannel_bot.gocq_compatible_receive(message)
gu.log(f"转换后: {str(nakuru_guild_message)}", gu.LEVEL_DEBUG, max_len=9999)
new_sub_thread(oper_msg, (nakuru_guild_message, True, PLATFORM_QQCHAN))
# 收到私聊消息
-1
View File
@@ -40,7 +40,6 @@ class Command:
def check_command(self,
message,
session_id: str,
loop,
role,
platform,
message_obj):
-2
View File
@@ -17,7 +17,6 @@ class CommandOpenAIOfficial(Command):
def check_command(self,
message: str,
session_id: str,
loop,
role: str,
platform: str,
message_obj):
@@ -25,7 +24,6 @@ class CommandOpenAIOfficial(Command):
hit, res = super().check_command(
message,
session_id,
loop,
role,
platform,
message_obj
-2
View File
@@ -15,7 +15,6 @@ class CommandRevChatGPT(Command):
def check_command(self,
message: str,
session_id: str,
loop,
role: str,
platform: str,
message_obj):
@@ -23,7 +22,6 @@ class CommandRevChatGPT(Command):
hit, res = super().check_command(
message,
session_id,
loop,
role,
platform,
message_obj
+3 -5
View File
@@ -13,8 +13,7 @@ class CommandRevEdgeGPT(Command):
def check_command(self,
message: str,
session_id: str,
loop,
session_id: str,
role: str,
platform: str,
message_obj):
@@ -23,7 +22,6 @@ class CommandRevEdgeGPT(Command):
hit, res = super().check_command(
message,
session_id,
loop,
role,
platform,
message_obj
@@ -32,7 +30,7 @@ class CommandRevEdgeGPT(Command):
if hit:
return True, res
if self.command_start_with(message, "reset"):
return True, self.reset(loop)
return True, self.reset()
elif self.command_start_with(message, "help"):
return True, self.help()
elif self.command_start_with(message, "update"):
@@ -40,7 +38,7 @@ class CommandRevEdgeGPT(Command):
return False, None
def reset(self, loop):
def reset(self, loop = None):
if self.provider is None:
return False, "未启动Bing语言模型.", "reset"
res = asyncio.run_coroutine_threadsafe(self.provider.forget(), loop).result()