feat: 支持消息并发处理
perf: 优化代码结构
This commit is contained in:
@@ -22,8 +22,8 @@ class ChatGPT:
|
||||
global inst
|
||||
inst = self
|
||||
|
||||
async def chat(self, prompt):
|
||||
print("[ChatGPT] 接收到prompt")
|
||||
def chat(self, prompt):
|
||||
print("[OpenAI API]收到")
|
||||
try:
|
||||
response = openai.Completion.create(
|
||||
prompt=prompt,
|
||||
@@ -33,7 +33,8 @@ class ChatGPT:
|
||||
raise PromptExceededError("OpenAI遇到错误:输入了一个不合法的请求。\n"+str(e))
|
||||
|
||||
# print(response['usage'])
|
||||
return response["choices"][0]["text"], response['usage']['total_tokens']
|
||||
print("[ChatGPT] "+response["choices"][0]["text"])
|
||||
return response["choices"][0]["text"].strip(), response['usage']['total_tokens']
|
||||
|
||||
def getConfigs(self):
|
||||
return self.openai_configs
|
||||
|
||||
+231
-223
@@ -1,18 +1,25 @@
|
||||
import botpy
|
||||
from botpy.message import Message
|
||||
import yaml
|
||||
import asyncio
|
||||
import cores.openai.core
|
||||
import re
|
||||
from util.errors.errors import PromptExceededError
|
||||
from botpy.message import DirectMessage
|
||||
import json
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import threading
|
||||
import asyncio
|
||||
|
||||
client = ''
|
||||
# executor = ThreadPoolExecutor(max_workers=10)
|
||||
# ChatGPT的实例
|
||||
chatgpt = ""
|
||||
# db = ""
|
||||
# 缓存的会话
|
||||
session_dict = {}
|
||||
# 最大缓存token(在配置里改 configs/config.yaml)
|
||||
max_tokens = 2000
|
||||
version = "1.3"
|
||||
# 版本
|
||||
version = "1.4"
|
||||
# gpt配置(在配置改)
|
||||
gpt_config = {
|
||||
'engine': '',
|
||||
'temperature': '',
|
||||
@@ -21,50 +28,55 @@ gpt_config = {
|
||||
'presence_penalty': '',
|
||||
'max_tokens': '',
|
||||
}
|
||||
# 统计信息
|
||||
count = {
|
||||
}
|
||||
# 统计信息
|
||||
stat_file = ''
|
||||
# 是否是独立会话(在配置改)
|
||||
uniqueSession = False
|
||||
|
||||
def new_sub_thread(func, args=()):
|
||||
thread = threading.Thread(target=func, args=args, daemon=True)
|
||||
thread.start()
|
||||
|
||||
class botClient(botpy.Client):
|
||||
# 收到At消息
|
||||
async def on_at_message_create(self, message: Message):
|
||||
global stat_file
|
||||
try:
|
||||
if str(message.guild_id) not in count:
|
||||
count[str(message.guild_id)] = {
|
||||
'count': 1,
|
||||
'direct_count': 0,
|
||||
}
|
||||
else:
|
||||
count[str(message.guild_id)]['count'] += 1
|
||||
stat_file = open("./configs/stat", 'w', encoding='utf-8')
|
||||
stat_file.write(json.dumps(count))
|
||||
stat_file.flush()
|
||||
stat_file.close()
|
||||
except BaseException:
|
||||
pass
|
||||
|
||||
await oper_msg(message=message, at=True)
|
||||
toggle_count(at=True, message=message)
|
||||
# executor.submit(oper_msg, message, True)
|
||||
new_sub_thread(oper_msg, (message, True))
|
||||
# await oper_msg(message=message, at=True)
|
||||
|
||||
# 收到私聊消息
|
||||
async def on_direct_message_create(self, message: DirectMessage):
|
||||
global stat_file
|
||||
try:
|
||||
if str(message.guild_id) not in count:
|
||||
count[str(message.guild_id)] = {
|
||||
'count': 1,
|
||||
'direct_count': 1,
|
||||
}
|
||||
else:
|
||||
count[str(message.guild_id)]['count'] += 1
|
||||
count[str(message.guild_id)]['direct_count'] += 1
|
||||
stat_file = open("./configs/stat", 'w', encoding='utf-8')
|
||||
stat_file.write(json.dumps(count))
|
||||
stat_file.flush()
|
||||
stat_file.close()
|
||||
except BaseException:
|
||||
pass
|
||||
toggle_count(at=False, message=message)
|
||||
print("收到私聊消息")
|
||||
# executor.submit(oper_msg, message, True)
|
||||
# await oper_msg(message=message, at=False)
|
||||
new_sub_thread(oper_msg, (message, False))
|
||||
print("私聊消息处理完毕")
|
||||
|
||||
# 写入统计信息
|
||||
def toggle_count(at: bool, message):
|
||||
global stat_file
|
||||
try:
|
||||
if str(message.guild_id) not in count:
|
||||
count[str(message.guild_id)] = {
|
||||
'count': 1,
|
||||
'direct_count': 1,
|
||||
}
|
||||
else:
|
||||
count[str(message.guild_id)]['count'] += 1
|
||||
if not at:
|
||||
count[str(message.guild_id)]['direct_count'] += 1
|
||||
stat_file = open("./configs/stat", 'w', encoding='utf-8')
|
||||
stat_file.write(json.dumps(count))
|
||||
stat_file.flush()
|
||||
stat_file.close()
|
||||
except BaseException:
|
||||
pass
|
||||
|
||||
await oper_msg(message=message, at=False)
|
||||
|
||||
def initBot(chatgpt_inst):
|
||||
global chatgpt
|
||||
@@ -91,7 +103,6 @@ def initBot(chatgpt_inst):
|
||||
pass
|
||||
|
||||
global uniqueSession
|
||||
|
||||
with open("./configs/config.yaml", 'r', encoding='utf-8') as ymlfile:
|
||||
cfg = yaml.safe_load(ymlfile)
|
||||
if 'uniqueSessionMode' in cfg['qqbot'] and cfg['qqbot']['uniqueSessionMode'] == 'true':
|
||||
@@ -101,21 +112,32 @@ def initBot(chatgpt_inst):
|
||||
if cfg['qqbot']['appid'] != '' or cfg['qqbot']['token'] != '':
|
||||
print("读取QQBot appid token 成功")
|
||||
intents = botpy.Intents(public_guild_messages=True, direct_message=True)
|
||||
global client
|
||||
client = botClient(intents=intents)
|
||||
client.run(appid=cfg['qqbot']['appid'], token=cfg['qqbot']['token'])
|
||||
else:
|
||||
raise BaseException("请在config中完善你的appid和token")
|
||||
|
||||
|
||||
|
||||
async def get_chatGPT_response(prompts_str):
|
||||
'''
|
||||
得到OpenAI的回复
|
||||
'''
|
||||
def get_chatGPT_response(prompts_str):
|
||||
res = ''
|
||||
usage = ''
|
||||
res, usage = await chatgpt.chat(prompts_str)
|
||||
res, usage = chatgpt.chat(prompts_str)
|
||||
# 处理结果文本
|
||||
chatgpt_res = res.strip()
|
||||
return res, usage
|
||||
|
||||
'''
|
||||
回复QQ消息
|
||||
'''
|
||||
def send_qq_msg(message, res):
|
||||
asyncio.run_coroutine_threadsafe(message.reply(content=res), client.loop)
|
||||
|
||||
'''
|
||||
获取缓存的会话
|
||||
'''
|
||||
def get_prompts_by_cache_list(cache_data_list, divide=False, paging=False, size=5, page=1):
|
||||
prompts = ""
|
||||
if paging:
|
||||
@@ -138,9 +160,11 @@ def get_user_usage_tokens(cache_list):
|
||||
usage_tokens += int(item['single_tokens'])
|
||||
return usage_tokens
|
||||
|
||||
async def oper_msg(message, at=False):
|
||||
def oper_msg(message, at=False, loop=None):
|
||||
print("[QQBOT] 接收到消息:"+ message.content)
|
||||
qq_msg = ""
|
||||
qq_msg = ''
|
||||
session_id = ''
|
||||
name = ''
|
||||
|
||||
if at:
|
||||
# 过滤用户id
|
||||
@@ -150,193 +174,177 @@ async def oper_msg(message, at=False):
|
||||
result = re.search(pattern, message.content)
|
||||
if result:
|
||||
qq_msg = result.group(1).strip()
|
||||
else:
|
||||
qq_msg = message.content
|
||||
|
||||
user_id = message.author.id
|
||||
if not at:
|
||||
session_id = message.author.id
|
||||
else:
|
||||
if uniqueSession:
|
||||
session_id = message.author.id
|
||||
else:
|
||||
session_id = message.guild_id
|
||||
if session_id:
|
||||
name = ''
|
||||
if uniqueSession:
|
||||
name = message.member.nick
|
||||
else:
|
||||
qq_msg = message.content
|
||||
session_id = message.author.id
|
||||
|
||||
if uniqueSession:
|
||||
name = message.member.nick
|
||||
else:
|
||||
name = "频道"
|
||||
|
||||
# 指令控制
|
||||
if qq_msg == "/reset":
|
||||
msg = ''
|
||||
session_dict[session_id] = []
|
||||
if at:
|
||||
msg = f"{name}(id: {session_id}) 的历史记录重置成功"
|
||||
else:
|
||||
name = "频道"
|
||||
if qq_msg == "/reset":
|
||||
|
||||
session_dict[session_id] = []
|
||||
if at:
|
||||
await message.reply(content=f"{name}(id: {session_id}) 的历史记录重置成功")
|
||||
else:
|
||||
await message.reply(content=f"你的历史记录重置成功")
|
||||
return
|
||||
if qq_msg[:4] == "/his":
|
||||
#分页,每页5条
|
||||
size_per_page = 3
|
||||
page = 1
|
||||
if qq_msg[5:]:
|
||||
page = int(qq_msg[5:])
|
||||
# 检查是否有过历史记录
|
||||
if session_id not in session_dict:
|
||||
await message.reply(content=f"{name} 的历史记录为空")
|
||||
return
|
||||
l = session_dict[session_id]
|
||||
max_page = len(l)//size_per_page + 1 if len(l)%size_per_page != 0 else len(l)//size_per_page
|
||||
p = get_prompts_by_cache_list(session_dict[session_id], divide=True, paging=True, size=size_per_page, page=page)
|
||||
if at:
|
||||
await message.reply(content=f"{name} 的历史记录如下:\n{p}\n第{page}页 | 共{max_page}页\n*输入/his 2跳转到第2页")
|
||||
else:
|
||||
await message.reply(content=f"历史记录如下:\n{p}\n第{page}页 | 共{max_page}页\n*输入/his 2跳转到第2页")
|
||||
|
||||
return
|
||||
if qq_msg == "/token":
|
||||
if at:
|
||||
await message.reply(content=f"{name} 会话的token数: {get_user_usage_tokens(session_dict[session_id])}\n系统最大缓存token数: {max_tokens}")
|
||||
else:
|
||||
await message.reply(content=f"会话的token数: {get_user_usage_tokens(session_dict[session_id])}\n系统最大缓存token数: {max_tokens}")
|
||||
|
||||
return
|
||||
if qq_msg == "/status":
|
||||
chatgpt_cfg_str = ""
|
||||
for k, v in gpt_config.items():
|
||||
if k == "key":
|
||||
continue
|
||||
chatgpt_cfg_str += f"{k}: {v}"
|
||||
|
||||
await message.reply(content=f"ChatGPT配置:\n - {chatgpt_cfg_str}\n QQChannelChatGPT 版本: {version}")
|
||||
return
|
||||
|
||||
if qq_msg == "/count":
|
||||
try:
|
||||
f = open("./configs/stat", "r", encoding="utf-8")
|
||||
fjson = json.loads(f.read())
|
||||
f.close()
|
||||
guild_count = 0
|
||||
guild_msg_count = 0
|
||||
guild_direct_msg_count = 0
|
||||
|
||||
for k,v in fjson.items():
|
||||
guild_count += 1
|
||||
guild_msg_count += v['count']
|
||||
guild_direct_msg_count += v['direct_count']
|
||||
|
||||
session_count = 0
|
||||
|
||||
f = open("./configs/session", "r", encoding="utf-8")
|
||||
fjson = json.loads(f.read())
|
||||
f.close()
|
||||
for k,v in fjson.items():
|
||||
session_count += 1
|
||||
except:
|
||||
pass
|
||||
|
||||
await message.reply(content=f"当前会话数: {len(session_dict)}\n共有频道数: {guild_count} \n共有消息数: {guild_msg_count}\n私信数: {guild_direct_msg_count}\n历史会话数: {session_count}")
|
||||
return
|
||||
|
||||
if qq_msg == "/help":
|
||||
await message.reply(content=f"请联系频道管理员或者前往github(仓库名: QQChannelChatGPT)提issue~")
|
||||
return
|
||||
|
||||
msg = f"你的历史记录重置成功"
|
||||
send_qq_msg(message, msg)
|
||||
return
|
||||
if qq_msg[:4] == "/his":
|
||||
#分页,每页5条
|
||||
msg = ''
|
||||
size_per_page = 3
|
||||
page = 1
|
||||
if qq_msg[5:]:
|
||||
page = int(qq_msg[5:])
|
||||
# 检查是否有过历史记录
|
||||
if session_id not in session_dict:
|
||||
session_dict[session_id] = []
|
||||
msg = f"{name} 的历史记录为空"
|
||||
l = session_dict[session_id]
|
||||
max_page = len(l)//size_per_page + 1 if len(l)%size_per_page != 0 else len(l)//size_per_page
|
||||
p = get_prompts_by_cache_list(session_dict[session_id], divide=True, paging=True, size=size_per_page, page=page)
|
||||
if at:
|
||||
msg=f"{name} 的历史记录如下:\n{p}\n第{page}页 | 共{max_page}页\n*输入/his 2跳转到第2页"
|
||||
else:
|
||||
msg=f"历史记录如下:\n{p}\n第{page}页 | 共{max_page}页\n*输入/his 2跳转到第2页"
|
||||
send_qq_msg(message, msg)
|
||||
return
|
||||
if qq_msg == "/token":
|
||||
msg = ''
|
||||
if at:
|
||||
msg=f"{name} 会话的token数: {get_user_usage_tokens(session_dict[session_id])}\n系统最大缓存token数: {max_tokens}"
|
||||
else:
|
||||
msg=f"会话的token数: {get_user_usage_tokens(session_dict[session_id])}\n系统最大缓存token数: {max_tokens}"
|
||||
send_qq_msg(message, msg)
|
||||
return
|
||||
if qq_msg == "/status":
|
||||
chatgpt_cfg_str = ""
|
||||
for k, v in gpt_config.items():
|
||||
if k == "key":
|
||||
continue
|
||||
chatgpt_cfg_str += f"{k}: {v}"
|
||||
send_qq_msg(message, f"ChatGPT配置:\n - {chatgpt_cfg_str}\n QQChannelChatGPT 版本: {version}")
|
||||
return
|
||||
if qq_msg == "/count":
|
||||
try:
|
||||
f = open("./configs/stat", "r", encoding="utf-8")
|
||||
fjson = json.loads(f.read())
|
||||
f.close()
|
||||
guild_count = 0
|
||||
guild_msg_count = 0
|
||||
guild_direct_msg_count = 0
|
||||
|
||||
fjson = {}
|
||||
try:
|
||||
f = open("./configs/session", "r", encoding="utf-8")
|
||||
fjson = json.loads(f.read())
|
||||
f.close()
|
||||
except:
|
||||
pass
|
||||
finally:
|
||||
fjson[session_id] = 'true'
|
||||
f = open("./configs/session", "w", encoding="utf-8")
|
||||
f.write(json.dumps(fjson))
|
||||
f.flush()
|
||||
f.close()
|
||||
for k,v in fjson.items():
|
||||
guild_count += 1
|
||||
guild_msg_count += v['count']
|
||||
guild_direct_msg_count += v['direct_count']
|
||||
|
||||
session_count = 0
|
||||
|
||||
f = open("./configs/session", "r", encoding="utf-8")
|
||||
fjson = json.loads(f.read())
|
||||
f.close()
|
||||
for k,v in fjson.items():
|
||||
session_count += 1
|
||||
except:
|
||||
pass
|
||||
send_qq_msg(message, f"当前会话数: {len(session_dict)}\n共有频道数: {guild_count} \n共有消息数: {guild_msg_count}\n私信数: {guild_direct_msg_count}\n历史会话数: {session_count}")
|
||||
return
|
||||
if qq_msg == "/help":
|
||||
send_qq_msg(message, "请联系频道管理员或者前往github(仓库名: QQChannelChatGPT)提issue~")
|
||||
return
|
||||
|
||||
# 统计历史会话
|
||||
if session_id not in session_dict:
|
||||
session_dict[session_id] = []
|
||||
|
||||
# 获取缓存
|
||||
cache_prompt = ''
|
||||
fjson = {}
|
||||
try:
|
||||
f = open("./configs/session", "r", encoding="utf-8")
|
||||
fjson = json.loads(f.read())
|
||||
f.close()
|
||||
except:
|
||||
pass
|
||||
finally:
|
||||
fjson[session_id] = 'true'
|
||||
f = open("./configs/session", "w", encoding="utf-8")
|
||||
f.write(json.dumps(fjson))
|
||||
f.flush()
|
||||
f.close()
|
||||
|
||||
# 获取缓存
|
||||
cache_prompt = ''
|
||||
cache_data_list = session_dict[session_id]
|
||||
cache_prompt = get_prompts_by_cache_list(cache_data_list)
|
||||
cache_prompt += "Human: "+ qq_msg + "\nAI: "
|
||||
# 请求chatGPT获得结果
|
||||
try:
|
||||
chatgpt_res, current_usage_tokens = get_chatGPT_response(cache_prompt)
|
||||
except (PromptExceededError) as e:
|
||||
print("出现token超限, 清空对应缓存")
|
||||
# 超过4097tokens错误,清空缓存
|
||||
session_dict[session_id] = []
|
||||
cache_data_list = []
|
||||
cache_prompt = "Human: "+ qq_msg + "\nAI: "
|
||||
chatgpt_res, current_usage_tokens = get_chatGPT_response(cache_prompt)
|
||||
except (BaseException) as e:
|
||||
print("OpenAI API错误:(")
|
||||
send_qq_msg(message, f"OpenAI API错误:( 原因如下:\n{str(e)} \n*前往github(仓库名: QQChannelChatGPT)反馈~")
|
||||
return
|
||||
|
||||
# 发送qq信息
|
||||
try:
|
||||
# 防止被qq频道过滤消息
|
||||
gap_chatgpt_res = chatgpt_res.replace(".", " . ")
|
||||
# 发送信息
|
||||
send_qq_msg(message, '[ChatGPT]'+gap_chatgpt_res)
|
||||
except BaseException as e:
|
||||
print("QQ频道API错误: \n"+str(e))
|
||||
f_res = ""
|
||||
for t in chatgpt_res:
|
||||
f_res += t + ' '
|
||||
try:
|
||||
pass
|
||||
# send(message, f"QQ频道API错误:{str(e)}\n下面是格式化后的回答:\n{f_res}")
|
||||
except BaseException as e:
|
||||
# 如果还是不行则过滤url
|
||||
f_res = re.sub(r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '', f_res, flags=re.MULTILINE)
|
||||
f_res = f_res.replace(".", " . ")
|
||||
# send(message, f"QQ频道API错误:{str(e)}\n下面是格式化后的回答:\n{f_res}")
|
||||
|
||||
# 超过指定tokens, 尽可能的保留最多的条目,直到小于max_tokens
|
||||
if current_usage_tokens > max_tokens:
|
||||
t = current_usage_tokens
|
||||
cache_list = session_dict[session_id]
|
||||
index = 0
|
||||
while t > max_tokens:
|
||||
if index >= len(cache_list):
|
||||
break
|
||||
t -= int(cache_list[index]['single_tokens'])
|
||||
index += 1
|
||||
session_dict[session_id] = cache_list[index:]
|
||||
cache_data_list = session_dict[session_id]
|
||||
cache_prompt = get_prompts_by_cache_list(cache_data_list)
|
||||
cache_prompt += "Human: "+ qq_msg + "\nAI: "
|
||||
# 请求chatGPT获得结果
|
||||
try:
|
||||
chatgpt_res, current_usage_tokens = await get_chatGPT_response(cache_prompt)
|
||||
except (PromptExceededError) as e:
|
||||
print("出现token超限, 清空对应缓存")
|
||||
# 超过4097tokens错误,清空缓存
|
||||
session_dict[session_id] = []
|
||||
cache_data_list = []
|
||||
cache_prompt = "Human: "+ qq_msg + "\nAI: "
|
||||
chatgpt_res, current_usage_tokens = await get_chatGPT_response(cache_prompt)
|
||||
except (BaseException) as e:
|
||||
print("OpenAI API错误:(")
|
||||
await message.reply(content=f"OpenAI API错误:( 原因如下:\n{str(e)} \n*前往github(仓库名: QQChannelChatGPT)反馈~")
|
||||
|
||||
# 超过指定tokens, 尽可能的保留最多的条目,直到小于max_tokens
|
||||
# print("current_usage_tokens: ", current_usage_tokens)
|
||||
# print("max_tokens: ", max_tokens)
|
||||
if current_usage_tokens > max_tokens:
|
||||
t = current_usage_tokens
|
||||
cache_list = session_dict[session_id]
|
||||
index = 0
|
||||
while t > max_tokens:
|
||||
if index >= len(cache_list):
|
||||
break
|
||||
t -= int(cache_list[index]['single_tokens'])
|
||||
index += 1
|
||||
session_dict[session_id] = cache_list[index:]
|
||||
cache_data_list = session_dict[session_id]
|
||||
cache_prompt = get_prompts_by_cache_list(cache_data_list)
|
||||
|
||||
# cache_prompt += chatgpt_res + "\n";
|
||||
# 添加新条目进入缓存的prompt
|
||||
if len(cache_data_list) > 0:
|
||||
single_record = {
|
||||
"prompt": f'Human: {qq_msg}\nAI: {chatgpt_res}\n',
|
||||
"usage_tokens": current_usage_tokens,
|
||||
"single_tokens": current_usage_tokens - int(cache_data_list[-1]['usage_tokens'])
|
||||
}
|
||||
else:
|
||||
single_record = {
|
||||
"prompt": f'Human: {qq_msg}\nAI: {chatgpt_res}\n',
|
||||
"usage_tokens": current_usage_tokens,
|
||||
"single_tokens": current_usage_tokens
|
||||
}
|
||||
# print(single_record)
|
||||
cache_data_list.append(single_record)
|
||||
|
||||
# 写入数据库
|
||||
# try:
|
||||
# data = {
|
||||
# "data": cache_data_list
|
||||
# }
|
||||
# data_str = json.dumps(data)
|
||||
# if len(cache_data_list) > 1:
|
||||
# db.update_session(session_id, data_str)
|
||||
# else:
|
||||
# db.insert_session(session_id, data_str)
|
||||
# except Exception as e:
|
||||
# print(e)
|
||||
# print("数据库写入失败")
|
||||
|
||||
session_dict[session_id] = cache_data_list
|
||||
|
||||
# #检测是否存在url,如果存在,则去除url 防止被qq频道过滤
|
||||
chatgpt_res = chatgpt_res.replace(".", " . ")
|
||||
|
||||
# 发送qq信息
|
||||
try:
|
||||
await message.reply(content=f"[ChatGPT]{chatgpt_res}")
|
||||
except BaseException as e:
|
||||
print("QQ频道API错误: \n"+str(e))
|
||||
f_res = ""
|
||||
for t in chatgpt_res:
|
||||
f_res += t + ' '
|
||||
await message.reply(content=f"QQ频道API错误:{str(e)}\n下面是格式化后的回答:\n{f_res}")
|
||||
# 添加新条目进入缓存的prompt
|
||||
if len(cache_data_list) > 0:
|
||||
single_record = {
|
||||
"prompt": f'Human: {qq_msg}\nAI: {chatgpt_res}\n',
|
||||
"usage_tokens": current_usage_tokens,
|
||||
"single_tokens": current_usage_tokens - int(cache_data_list[-1]['usage_tokens'])
|
||||
}
|
||||
else:
|
||||
single_record = {
|
||||
"prompt": f'Human: {qq_msg}\nAI: {chatgpt_res}\n',
|
||||
"usage_tokens": current_usage_tokens,
|
||||
"single_tokens": current_usage_tokens
|
||||
}
|
||||
cache_data_list.append(single_record)
|
||||
session_dict[session_id] = cache_data_list
|
||||
Reference in New Issue
Block a user