|
|
|
@@ -31,7 +31,7 @@ from PIL import Image as PILImage
|
|
|
|
|
import io
|
|
|
|
|
import traceback
|
|
|
|
|
from . global_object import GlobalObject
|
|
|
|
|
from typing import Union
|
|
|
|
|
from typing import Union, Callable
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 缓存的会话
|
|
|
|
@@ -116,15 +116,20 @@ gocq_app = CQHTTP(
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
gocq_loop = None
|
|
|
|
|
bing_cache_loop = None
|
|
|
|
|
|
|
|
|
|
# 全局对象
|
|
|
|
|
_global_object: GlobalObject = None
|
|
|
|
|
|
|
|
|
|
def new_sub_thread(func, args=()):
|
|
|
|
|
thread = threading.Thread(target=func, args=args, daemon=True)
|
|
|
|
|
thread = threading.Thread(target=_runner, args=(func, args), daemon=True)
|
|
|
|
|
thread.start()
|
|
|
|
|
|
|
|
|
|
def _runner(func: Callable, args: tuple):
|
|
|
|
|
loop = asyncio.new_event_loop()
|
|
|
|
|
asyncio.set_event_loop(loop)
|
|
|
|
|
loop.run_until_complete(func(*args))
|
|
|
|
|
loop.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# [Deprecated] 写入统计信息
|
|
|
|
|
def toggle_count(at: bool, message):
|
|
|
|
@@ -399,9 +404,6 @@ def run_qqchan_bot(cfg, loop, qqchannel_bot: QQChan):
|
|
|
|
|
abs_path = os.path.abspath("QQChannelChatGPT/configs/config.yaml")
|
|
|
|
|
print("配置文件地址:" + abs_path)
|
|
|
|
|
os.system(f"notepad \"{abs_path}\"")
|
|
|
|
|
# gu.log("如果你使用了go-cqhttp, 则可以忽略上面的报错。" + str(e), gu.LEVEL_CRITICAL, tag="QQ频道")
|
|
|
|
|
# input(f"\n[System-Error] 启动QQ频道机器人时出现错误,原因如下:{e}\n可能是没有填写QQBOT appid和token?请在config中完善你的appid和token\n配置教程:https://soulter.top/posts/qpdg.html\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
运行GOCQ机器人
|
|
|
|
@@ -424,7 +426,6 @@ def run_gocq_bot(loop, gocq_bot, gocq_app):
|
|
|
|
|
except BaseException as e:
|
|
|
|
|
input("启动QQ机器人出现错误"+str(e))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
检查发言频率
|
|
|
|
|
'''
|
|
|
|
@@ -450,7 +451,7 @@ def check_frequency(id) -> bool:
|
|
|
|
|
'''
|
|
|
|
|
通用消息回复
|
|
|
|
|
'''
|
|
|
|
|
def send_message(platform, message, res, session_id = None):
|
|
|
|
|
async def send_message(platform, message, res, session_id = None):
|
|
|
|
|
global qqchannel_bot, qqchannel_bot, gocq_loop, session_dict
|
|
|
|
|
if session_id is not None:
|
|
|
|
|
if session_id not in session_dict:
|
|
|
|
@@ -462,9 +463,9 @@ def send_message(platform, message, res, session_id = None):
|
|
|
|
|
if platform == PLATFORM_QQCHAN:
|
|
|
|
|
qqchannel_bot.send_qq_msg(message, res)
|
|
|
|
|
if platform == PLATFORM_GOCQ:
|
|
|
|
|
asyncio.run_coroutine_threadsafe(gocq_bot.send_qq_msg(message, res), gocq_loop).result()
|
|
|
|
|
await gocq_bot.send_qq_msg(message, res)
|
|
|
|
|
|
|
|
|
|
def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGuildMessage],
|
|
|
|
|
async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGuildMessage],
|
|
|
|
|
group: bool=False,
|
|
|
|
|
platform: str = None):
|
|
|
|
|
"""
|
|
|
|
@@ -474,7 +475,7 @@ def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGui
|
|
|
|
|
msg_ref: 引用消息(频道)
|
|
|
|
|
platform: 平台(gocq, qqchan)
|
|
|
|
|
"""
|
|
|
|
|
global chosen_provider, keywords, qqchannel_bot, gocq_bot, gocq_loop, bing_cache_loop, qqchan_loop
|
|
|
|
|
global chosen_provider, keywords, qqchannel_bot, gocq_bot
|
|
|
|
|
global _global_object
|
|
|
|
|
qq_msg = ''
|
|
|
|
|
session_id = ''
|
|
|
|
@@ -543,13 +544,13 @@ def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGui
|
|
|
|
|
with_tag = True
|
|
|
|
|
|
|
|
|
|
if qq_msg == "":
|
|
|
|
|
send_message(platform, message, f"Hi~", session_id=session_id)
|
|
|
|
|
await send_message(platform, message, f"Hi~", session_id=session_id)
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
if with_tag:
|
|
|
|
|
# 检查发言频率
|
|
|
|
|
if not check_frequency(user_id):
|
|
|
|
|
send_message(platform, message, f'你的发言超过频率限制(╯▔皿▔)╯。\n管理员设置{frequency_time}秒内只能提问{frequency_count}次。', session_id=session_id)
|
|
|
|
|
await send_message(platform, message, f'你的发言超过频率限制(╯▔皿▔)╯。\n管理员设置{frequency_time}秒内只能提问{frequency_count}次。', session_id=session_id)
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
# logf.write("[GOCQBOT] "+ qq_msg+'\n')
|
|
|
|
@@ -568,9 +569,9 @@ def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGui
|
|
|
|
|
image_url = keywords[k]['image_url']
|
|
|
|
|
if image_url != "":
|
|
|
|
|
res = [Plain(plain_text), Image.fromURL(image_url)]
|
|
|
|
|
send_message(platform, message, res, session_id=session_id)
|
|
|
|
|
await send_message(platform, message, res, session_id=session_id)
|
|
|
|
|
else:
|
|
|
|
|
send_message(platform, message, plain_text, session_id=session_id)
|
|
|
|
|
await send_message(platform, message, plain_text, session_id=session_id)
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
# 检查是否是更换语言模型的请求
|
|
|
|
@@ -592,7 +593,7 @@ def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGui
|
|
|
|
|
else:
|
|
|
|
|
chosen_provider = target
|
|
|
|
|
cc.put("chosen_provider", chosen_provider)
|
|
|
|
|
send_message(platform, message, f"已切换至【{chosen_provider}】", session_id=session_id)
|
|
|
|
|
await send_message(platform, message, f"已切换至【{chosen_provider}】", session_id=session_id)
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
chatgpt_res = ""
|
|
|
|
@@ -603,7 +604,6 @@ def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGui
|
|
|
|
|
hit, command_result = llm_command_instance[chosen_provider].check_command(
|
|
|
|
|
qq_msg,
|
|
|
|
|
session_id,
|
|
|
|
|
bing_cache_loop,
|
|
|
|
|
role,
|
|
|
|
|
platform,
|
|
|
|
|
message,
|
|
|
|
@@ -617,15 +617,15 @@ def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGui
|
|
|
|
|
for i in uw.unfit_words_q:
|
|
|
|
|
matches = re.match(i, qq_msg.strip(), re.I | re.M)
|
|
|
|
|
if matches:
|
|
|
|
|
send_message(platform, message, f"你的提问得到的回复未通过【自有关键词拦截】服务, 不予回复。", session_id=session_id)
|
|
|
|
|
await send_message(platform, message, f"你的提问得到的回复未通过【自有关键词拦截】服务, 不予回复。", session_id=session_id)
|
|
|
|
|
return
|
|
|
|
|
if baidu_judge != None:
|
|
|
|
|
check, msg = baidu_judge.judge(qq_msg)
|
|
|
|
|
if not check:
|
|
|
|
|
send_message(platform, message, f"你的提问得到的回复未通过【百度AI内容审核】服务, 不予回复。\n\n{msg}", session_id=session_id)
|
|
|
|
|
await send_message(platform, message, f"你的提问得到的回复未通过【百度AI内容审核】服务, 不予回复。\n\n{msg}", session_id=session_id)
|
|
|
|
|
return
|
|
|
|
|
if chosen_provider == None:
|
|
|
|
|
send_message(platform, message, f"管理员未启动任何语言模型或者语言模型初始化时失败。", session_id=session_id)
|
|
|
|
|
await send_message(platform, message, f"管理员未启动任何语言模型或者语言模型初始化时失败。", session_id=session_id)
|
|
|
|
|
return
|
|
|
|
|
try:
|
|
|
|
|
if chosen_provider == REV_CHATGPT or chosen_provider == OPENAI_OFFICIAL:
|
|
|
|
@@ -634,14 +634,14 @@ def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGui
|
|
|
|
|
else:
|
|
|
|
|
chatgpt_res = str(llm_instance[chosen_provider].text_chat(qq_msg, session_id))
|
|
|
|
|
elif chosen_provider == REV_EDGEGPT:
|
|
|
|
|
res, res_code = asyncio.run_coroutine_threadsafe(llm_instance[chosen_provider].text_chat(qq_msg, platform), bing_cache_loop).result()
|
|
|
|
|
res, res_code = await llm_instance[chosen_provider].text_chat(qq_msg, platform)
|
|
|
|
|
if res_code == 0: # bing不想继续话题,重置会话后重试。
|
|
|
|
|
send_message(platform, message, "Bing不想继续话题了, 正在自动重置会话并重试。", session_id=session_id)
|
|
|
|
|
asyncio.run_coroutine_threadsafe(llm_instance[chosen_provider].forget(), bing_cache_loop).result()
|
|
|
|
|
res, res_code = asyncio.run_coroutine_threadsafe(llm_instance[chosen_provider].text_chat(qq_msg, platform), bing_cache_loop).result()
|
|
|
|
|
await send_message(platform, message, "Bing不想继续话题了, 正在自动重置会话并重试。", session_id=session_id)
|
|
|
|
|
await llm_instance[chosen_provider].forget()
|
|
|
|
|
res, res_code = await llm_instance[chosen_provider].text_chat(qq_msg, platform)
|
|
|
|
|
if res_code == 0: # bing还是不想继续话题,大概率说明提问有问题。
|
|
|
|
|
asyncio.run_coroutine_threadsafe(llm_instance[chosen_provider].forget(), bing_cache_loop).result()
|
|
|
|
|
send_message(platform, message, "Bing仍然不想继续话题, 会话已重置, 请检查您的提问后重试。", session_id=session_id)
|
|
|
|
|
await llm_instance[chosen_provider].forget()
|
|
|
|
|
await send_message(platform, message, "Bing仍然不想继续话题, 会话已重置, 请检查您的提问后重试。", session_id=session_id)
|
|
|
|
|
res = ""
|
|
|
|
|
chatgpt_res = str(res)
|
|
|
|
|
|
|
|
|
@@ -650,7 +650,7 @@ def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGui
|
|
|
|
|
except BaseException as e:
|
|
|
|
|
gu.log(f"调用异常:{traceback.format_exc()}", gu.LEVEL_ERROR, max_len=100000)
|
|
|
|
|
gu.log("调用语言模型例程时出现异常。原因: "+str(e), gu.LEVEL_ERROR)
|
|
|
|
|
send_message(platform, message, "调用语言模型例程时出现异常。原因: "+str(e), session_id=session_id)
|
|
|
|
|
await send_message(platform, message, "调用语言模型例程时出现异常。原因: "+str(e), session_id=session_id)
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
# 切换回原来的语言模型
|
|
|
|
@@ -661,7 +661,7 @@ def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGui
|
|
|
|
|
if hit:
|
|
|
|
|
# 检查指令. command_result是一个元组:(指令调用是否成功, 指令返回的文本结果, 指令类型)
|
|
|
|
|
if command_result == None:
|
|
|
|
|
# send_message(platform, message, "指令调用未返回任何信息。", session_id=session_id)
|
|
|
|
|
# await send_message(platform, message, "指令调用未返回任何信息。", session_id=session_id)
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
command = command_result[2]
|
|
|
|
@@ -671,17 +671,17 @@ def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGui
|
|
|
|
|
keywords = json.load(f)
|
|
|
|
|
else:
|
|
|
|
|
try:
|
|
|
|
|
send_message(platform, message, command_result[1], session_id=session_id)
|
|
|
|
|
await send_message(platform, message, command_result[1], session_id=session_id)
|
|
|
|
|
except BaseException as e:
|
|
|
|
|
send_message(platform, message, f"回复消息出错: {str(e)}", session_id=session_id)
|
|
|
|
|
await send_message(platform, message, f"回复消息出错: {str(e)}", session_id=session_id)
|
|
|
|
|
|
|
|
|
|
if command == "update latest r":
|
|
|
|
|
send_message(platform, message, command_result[1] + "\n\n即将自动重启。", session_id=session_id)
|
|
|
|
|
await send_message(platform, message, command_result[1] + "\n\n即将自动重启。", session_id=session_id)
|
|
|
|
|
py = sys.executable
|
|
|
|
|
os.execl(py, py, *sys.argv)
|
|
|
|
|
|
|
|
|
|
if not command_result[0]:
|
|
|
|
|
send_message(platform, message, f"指令调用错误: \n{str(command_result[1])}", session_id=session_id)
|
|
|
|
|
await send_message(platform, message, f"指令调用错误: \n{str(command_result[1])}", session_id=session_id)
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
# 画图指令
|
|
|
|
@@ -692,14 +692,14 @@ def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGui
|
|
|
|
|
pic_res = requests.get(i, stream = True)
|
|
|
|
|
if pic_res.status_code == 200:
|
|
|
|
|
image = PILImage.open(io.BytesIO(pic_res.content))
|
|
|
|
|
send_message(platform, message, [Image.fromFileSystem(gu.save_temp_img(image))], session_id=session_id)
|
|
|
|
|
await send_message(platform, message, [Image.fromFileSystem(gu.save_temp_img(image))], session_id=session_id)
|
|
|
|
|
|
|
|
|
|
# 其他指令
|
|
|
|
|
else:
|
|
|
|
|
try:
|
|
|
|
|
send_message(platform, message, command_result[1], session_id=session_id)
|
|
|
|
|
await send_message(platform, message, command_result[1], session_id=session_id)
|
|
|
|
|
except BaseException as e:
|
|
|
|
|
send_message(platform, message, f"回复消息出错: {str(e)}", session_id=session_id)
|
|
|
|
|
await send_message(platform, message, f"回复消息出错: {str(e)}", session_id=session_id)
|
|
|
|
|
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
@@ -715,12 +715,12 @@ def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGui
|
|
|
|
|
if baidu_judge != None:
|
|
|
|
|
check, msg = baidu_judge.judge(chatgpt_res)
|
|
|
|
|
if not check:
|
|
|
|
|
send_message(platform, message, f"你的提问得到的回复【百度内容审核】未通过,不予回复。\n\n{msg}", session_id=session_id)
|
|
|
|
|
await send_message(platform, message, f"你的提问得到的回复【百度内容审核】未通过,不予回复。\n\n{msg}", session_id=session_id)
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
# 发送信息
|
|
|
|
|
try:
|
|
|
|
|
send_message(platform, message, chatgpt_res, session_id=session_id)
|
|
|
|
|
await send_message(platform, message, chatgpt_res, session_id=session_id)
|
|
|
|
|
except BaseException as e:
|
|
|
|
|
gu.log("回复消息错误: \n"+str(e), gu.LEVEL_ERROR)
|
|
|
|
|
|
|
|
|
@@ -733,7 +733,6 @@ class botClient(botpy.Client):
|
|
|
|
|
# 转换层
|
|
|
|
|
nakuru_guild_message = qqchannel_bot.gocq_compatible_receive(message)
|
|
|
|
|
gu.log(f"转换后: {str(nakuru_guild_message)}", gu.LEVEL_DEBUG, max_len=9999)
|
|
|
|
|
|
|
|
|
|
new_sub_thread(oper_msg, (nakuru_guild_message, True, PLATFORM_QQCHAN))
|
|
|
|
|
|
|
|
|
|
# 收到私聊消息
|
|
|
|
|