feat: 支持消息并发处理

perf: 优化代码结构
This commit is contained in:
Soulter
2022-12-16 22:07:51 +08:00
parent e723fedce8
commit a5fbc82e11
2 changed files with 235 additions and 226 deletions
+4 -3
View File
@@ -22,8 +22,8 @@ class ChatGPT:
global inst
inst = self
async def chat(self, prompt):
print("[ChatGPT] 接收到prompt")
def chat(self, prompt):
print("[OpenAI API]收到")
try:
response = openai.Completion.create(
prompt=prompt,
@@ -33,7 +33,8 @@ class ChatGPT:
raise PromptExceededError("OpenAI遇到错误:输入了一个不合法的请求。\n"+str(e))
# print(response['usage'])
return response["choices"][0]["text"], response['usage']['total_tokens']
print("[ChatGPT] "+response["choices"][0]["text"])
return response["choices"][0]["text"].strip(), response['usage']['total_tokens']
def getConfigs(self):
return self.openai_configs
+231 -223
View File
@@ -1,18 +1,25 @@
import botpy
from botpy.message import Message
import yaml
import asyncio
import cores.openai.core
import re
from util.errors.errors import PromptExceededError
from botpy.message import DirectMessage
import json
from concurrent.futures import ThreadPoolExecutor
import threading
import asyncio
client = ''
# executor = ThreadPoolExecutor(max_workers=10)
# ChatGPT的实例
chatgpt = ""
# db = ""
# 缓存的会话
session_dict = {}
# 最大缓存token(在配置里改 configs/config.yaml
max_tokens = 2000
version = "1.3"
# 版本
version = "1.4"
# gpt配置(在配置改)
gpt_config = {
'engine': '',
'temperature': '',
@@ -21,50 +28,55 @@ gpt_config = {
'presence_penalty': '',
'max_tokens': '',
}
# 统计信息
count = {
}
# 统计信息
stat_file = ''
# 是否是独立会话(在配置改)
uniqueSession = False
def new_sub_thread(func, args=()):
thread = threading.Thread(target=func, args=args, daemon=True)
thread.start()
class botClient(botpy.Client):
# 收到At消息
async def on_at_message_create(self, message: Message):
global stat_file
try:
if str(message.guild_id) not in count:
count[str(message.guild_id)] = {
'count': 1,
'direct_count': 0,
}
else:
count[str(message.guild_id)]['count'] += 1
stat_file = open("./configs/stat", 'w', encoding='utf-8')
stat_file.write(json.dumps(count))
stat_file.flush()
stat_file.close()
except BaseException:
pass
await oper_msg(message=message, at=True)
toggle_count(at=True, message=message)
# executor.submit(oper_msg, message, True)
new_sub_thread(oper_msg, (message, True))
# await oper_msg(message=message, at=True)
# 收到私聊消息
async def on_direct_message_create(self, message: DirectMessage):
global stat_file
try:
if str(message.guild_id) not in count:
count[str(message.guild_id)] = {
'count': 1,
'direct_count': 1,
}
else:
count[str(message.guild_id)]['count'] += 1
count[str(message.guild_id)]['direct_count'] += 1
stat_file = open("./configs/stat", 'w', encoding='utf-8')
stat_file.write(json.dumps(count))
stat_file.flush()
stat_file.close()
except BaseException:
pass
toggle_count(at=False, message=message)
print("收到私聊消息")
# executor.submit(oper_msg, message, True)
# await oper_msg(message=message, at=False)
new_sub_thread(oper_msg, (message, False))
print("私聊消息处理完毕")
# 写入统计信息
def toggle_count(at: bool, message):
global stat_file
try:
if str(message.guild_id) not in count:
count[str(message.guild_id)] = {
'count': 1,
'direct_count': 1,
}
else:
count[str(message.guild_id)]['count'] += 1
if not at:
count[str(message.guild_id)]['direct_count'] += 1
stat_file = open("./configs/stat", 'w', encoding='utf-8')
stat_file.write(json.dumps(count))
stat_file.flush()
stat_file.close()
except BaseException:
pass
await oper_msg(message=message, at=False)
def initBot(chatgpt_inst):
global chatgpt
@@ -91,7 +103,6 @@ def initBot(chatgpt_inst):
pass
global uniqueSession
with open("./configs/config.yaml", 'r', encoding='utf-8') as ymlfile:
cfg = yaml.safe_load(ymlfile)
if 'uniqueSessionMode' in cfg['qqbot'] and cfg['qqbot']['uniqueSessionMode'] == 'true':
@@ -101,21 +112,32 @@ def initBot(chatgpt_inst):
if cfg['qqbot']['appid'] != '' or cfg['qqbot']['token'] != '':
print("读取QQBot appid token 成功")
intents = botpy.Intents(public_guild_messages=True, direct_message=True)
global client
client = botClient(intents=intents)
client.run(appid=cfg['qqbot']['appid'], token=cfg['qqbot']['token'])
else:
raise BaseException("请在config中完善你的appid和token")
async def get_chatGPT_response(prompts_str):
'''
得到OpenAI的回复
'''
def get_chatGPT_response(prompts_str):
res = ''
usage = ''
res, usage = await chatgpt.chat(prompts_str)
res, usage = chatgpt.chat(prompts_str)
# 处理结果文本
chatgpt_res = res.strip()
return res, usage
'''
回复QQ消息
'''
def send_qq_msg(message, res):
asyncio.run_coroutine_threadsafe(message.reply(content=res), client.loop)
'''
获取缓存的会话
'''
def get_prompts_by_cache_list(cache_data_list, divide=False, paging=False, size=5, page=1):
prompts = ""
if paging:
@@ -138,9 +160,11 @@ def get_user_usage_tokens(cache_list):
usage_tokens += int(item['single_tokens'])
return usage_tokens
async def oper_msg(message, at=False):
def oper_msg(message, at=False, loop=None):
print("[QQBOT] 接收到消息:"+ message.content)
qq_msg = ""
qq_msg = ''
session_id = ''
name = ''
if at:
# 过滤用户id
@@ -150,193 +174,177 @@ async def oper_msg(message, at=False):
result = re.search(pattern, message.content)
if result:
qq_msg = result.group(1).strip()
else:
qq_msg = message.content
user_id = message.author.id
if not at:
session_id = message.author.id
else:
if uniqueSession:
session_id = message.author.id
else:
session_id = message.guild_id
if session_id:
name = ''
if uniqueSession:
name = message.member.nick
else:
qq_msg = message.content
session_id = message.author.id
if uniqueSession:
name = message.member.nick
else:
name = "频道"
# 指令控制
if qq_msg == "/reset":
msg = ''
session_dict[session_id] = []
if at:
msg = f"{name}(id: {session_id}) 的历史记录重置成功"
else:
name = "频道"
if qq_msg == "/reset":
session_dict[session_id] = []
if at:
await message.reply(content=f"{name}(id: {session_id}) 的历史记录重置成功")
else:
await message.reply(content=f"你的历史记录重置成功")
return
if qq_msg[:4] == "/his":
#分页,每页5条
size_per_page = 3
page = 1
if qq_msg[5:]:
page = int(qq_msg[5:])
# 检查是否有过历史记录
if session_id not in session_dict:
await message.reply(content=f"{name} 的历史记录为空")
return
l = session_dict[session_id]
max_page = len(l)//size_per_page + 1 if len(l)%size_per_page != 0 else len(l)//size_per_page
p = get_prompts_by_cache_list(session_dict[session_id], divide=True, paging=True, size=size_per_page, page=page)
if at:
await message.reply(content=f"{name} 的历史记录如下:\n{p}\n{page}页 | 共{max_page}\n*输入/his 2跳转到第2页")
else:
await message.reply(content=f"历史记录如下:\n{p}\n{page}页 | 共{max_page}\n*输入/his 2跳转到第2页")
return
if qq_msg == "/token":
if at:
await message.reply(content=f"{name} 会话的token数: {get_user_usage_tokens(session_dict[session_id])}\n系统最大缓存token数: {max_tokens}")
else:
await message.reply(content=f"会话的token数: {get_user_usage_tokens(session_dict[session_id])}\n系统最大缓存token数: {max_tokens}")
return
if qq_msg == "/status":
chatgpt_cfg_str = ""
for k, v in gpt_config.items():
if k == "key":
continue
chatgpt_cfg_str += f"{k}: {v}"
await message.reply(content=f"ChatGPT配置:\n - {chatgpt_cfg_str}\n QQChannelChatGPT 版本: {version}")
return
if qq_msg == "/count":
try:
f = open("./configs/stat", "r", encoding="utf-8")
fjson = json.loads(f.read())
f.close()
guild_count = 0
guild_msg_count = 0
guild_direct_msg_count = 0
for k,v in fjson.items():
guild_count += 1
guild_msg_count += v['count']
guild_direct_msg_count += v['direct_count']
session_count = 0
f = open("./configs/session", "r", encoding="utf-8")
fjson = json.loads(f.read())
f.close()
for k,v in fjson.items():
session_count += 1
except:
pass
await message.reply(content=f"当前会话数: {len(session_dict)}\n共有频道数: {guild_count} \n共有消息数: {guild_msg_count}\n私信数: {guild_direct_msg_count}\n历史会话数: {session_count}")
return
if qq_msg == "/help":
await message.reply(content=f"请联系频道管理员或者前往github(仓库名: QQChannelChatGPT)提issue~")
return
msg = f"你的历史记录重置成功"
send_qq_msg(message, msg)
return
if qq_msg[:4] == "/his":
#分页,每页5条
msg = ''
size_per_page = 3
page = 1
if qq_msg[5:]:
page = int(qq_msg[5:])
# 检查是否有过历史记录
if session_id not in session_dict:
session_dict[session_id] = []
msg = f"{name} 的历史记录为空"
l = session_dict[session_id]
max_page = len(l)//size_per_page + 1 if len(l)%size_per_page != 0 else len(l)//size_per_page
p = get_prompts_by_cache_list(session_dict[session_id], divide=True, paging=True, size=size_per_page, page=page)
if at:
msg=f"{name} 的历史记录如下:\n{p}\n{page}页 | 共{max_page}\n*输入/his 2跳转到第2页"
else:
msg=f"历史记录如下:\n{p}\n{page}页 | 共{max_page}\n*输入/his 2跳转到第2页"
send_qq_msg(message, msg)
return
if qq_msg == "/token":
msg = ''
if at:
msg=f"{name} 会话的token数: {get_user_usage_tokens(session_dict[session_id])}\n系统最大缓存token数: {max_tokens}"
else:
msg=f"会话的token数: {get_user_usage_tokens(session_dict[session_id])}\n系统最大缓存token数: {max_tokens}"
send_qq_msg(message, msg)
return
if qq_msg == "/status":
chatgpt_cfg_str = ""
for k, v in gpt_config.items():
if k == "key":
continue
chatgpt_cfg_str += f"{k}: {v}"
send_qq_msg(message, f"ChatGPT配置:\n - {chatgpt_cfg_str}\n QQChannelChatGPT 版本: {version}")
return
if qq_msg == "/count":
try:
f = open("./configs/stat", "r", encoding="utf-8")
fjson = json.loads(f.read())
f.close()
guild_count = 0
guild_msg_count = 0
guild_direct_msg_count = 0
fjson = {}
try:
f = open("./configs/session", "r", encoding="utf-8")
fjson = json.loads(f.read())
f.close()
except:
pass
finally:
fjson[session_id] = 'true'
f = open("./configs/session", "w", encoding="utf-8")
f.write(json.dumps(fjson))
f.flush()
f.close()
for k,v in fjson.items():
guild_count += 1
guild_msg_count += v['count']
guild_direct_msg_count += v['direct_count']
session_count = 0
f = open("./configs/session", "r", encoding="utf-8")
fjson = json.loads(f.read())
f.close()
for k,v in fjson.items():
session_count += 1
except:
pass
send_qq_msg(message, f"当前会话数: {len(session_dict)}\n共有频道数: {guild_count} \n共有消息数: {guild_msg_count}\n私信数: {guild_direct_msg_count}\n历史会话数: {session_count}")
return
if qq_msg == "/help":
send_qq_msg(message, "请联系频道管理员或者前往github(仓库名: QQChannelChatGPT)提issue~")
return
# 统计历史会话
if session_id not in session_dict:
session_dict[session_id] = []
# 获取缓存
cache_prompt = ''
fjson = {}
try:
f = open("./configs/session", "r", encoding="utf-8")
fjson = json.loads(f.read())
f.close()
except:
pass
finally:
fjson[session_id] = 'true'
f = open("./configs/session", "w", encoding="utf-8")
f.write(json.dumps(fjson))
f.flush()
f.close()
# 获取缓存
cache_prompt = ''
cache_data_list = session_dict[session_id]
cache_prompt = get_prompts_by_cache_list(cache_data_list)
cache_prompt += "Human: "+ qq_msg + "\nAI: "
# 请求chatGPT获得结果
try:
chatgpt_res, current_usage_tokens = get_chatGPT_response(cache_prompt)
except (PromptExceededError) as e:
print("出现token超限, 清空对应缓存")
# 超过4097tokens错误,清空缓存
session_dict[session_id] = []
cache_data_list = []
cache_prompt = "Human: "+ qq_msg + "\nAI: "
chatgpt_res, current_usage_tokens = get_chatGPT_response(cache_prompt)
except (BaseException) as e:
print("OpenAI API错误:(")
send_qq_msg(message, f"OpenAI API错误:( 原因如下:\n{str(e)} \n*前往github(仓库名: QQChannelChatGPT)反馈~")
return
# 发送qq信息
try:
# 防止被qq频道过滤消息
gap_chatgpt_res = chatgpt_res.replace(".", " . ")
# 发送信息
send_qq_msg(message, '[ChatGPT]'+gap_chatgpt_res)
except BaseException as e:
print("QQ频道API错误: \n"+str(e))
f_res = ""
for t in chatgpt_res:
f_res += t + ' '
try:
pass
# send(message, f"QQ频道API错误:{str(e)}\n下面是格式化后的回答:\n{f_res}")
except BaseException as e:
# 如果还是不行则过滤url
f_res = re.sub(r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '', f_res, flags=re.MULTILINE)
f_res = f_res.replace(".", " . ")
# send(message, f"QQ频道API错误:{str(e)}\n下面是格式化后的回答:\n{f_res}")
# 超过指定tokens, 尽可能的保留最多的条目,直到小于max_tokens
if current_usage_tokens > max_tokens:
t = current_usage_tokens
cache_list = session_dict[session_id]
index = 0
while t > max_tokens:
if index >= len(cache_list):
break
t -= int(cache_list[index]['single_tokens'])
index += 1
session_dict[session_id] = cache_list[index:]
cache_data_list = session_dict[session_id]
cache_prompt = get_prompts_by_cache_list(cache_data_list)
cache_prompt += "Human: "+ qq_msg + "\nAI: "
# 请求chatGPT获得结果
try:
chatgpt_res, current_usage_tokens = await get_chatGPT_response(cache_prompt)
except (PromptExceededError) as e:
print("出现token超限, 清空对应缓存")
# 超过4097tokens错误,清空缓存
session_dict[session_id] = []
cache_data_list = []
cache_prompt = "Human: "+ qq_msg + "\nAI: "
chatgpt_res, current_usage_tokens = await get_chatGPT_response(cache_prompt)
except (BaseException) as e:
print("OpenAI API错误:(")
await message.reply(content=f"OpenAI API错误:( 原因如下:\n{str(e)} \n*前往github(仓库名: QQChannelChatGPT)反馈~")
# 超过指定tokens, 尽可能的保留最多的条目,直到小于max_tokens
# print("current_usage_tokens: ", current_usage_tokens)
# print("max_tokens: ", max_tokens)
if current_usage_tokens > max_tokens:
t = current_usage_tokens
cache_list = session_dict[session_id]
index = 0
while t > max_tokens:
if index >= len(cache_list):
break
t -= int(cache_list[index]['single_tokens'])
index += 1
session_dict[session_id] = cache_list[index:]
cache_data_list = session_dict[session_id]
cache_prompt = get_prompts_by_cache_list(cache_data_list)
# cache_prompt += chatgpt_res + "\n";
# 添加新条目进入缓存的prompt
if len(cache_data_list) > 0:
single_record = {
"prompt": f'Human: {qq_msg}\nAI: {chatgpt_res}\n',
"usage_tokens": current_usage_tokens,
"single_tokens": current_usage_tokens - int(cache_data_list[-1]['usage_tokens'])
}
else:
single_record = {
"prompt": f'Human: {qq_msg}\nAI: {chatgpt_res}\n',
"usage_tokens": current_usage_tokens,
"single_tokens": current_usage_tokens
}
# print(single_record)
cache_data_list.append(single_record)
# 写入数据库
# try:
# data = {
# "data": cache_data_list
# }
# data_str = json.dumps(data)
# if len(cache_data_list) > 1:
# db.update_session(session_id, data_str)
# else:
# db.insert_session(session_id, data_str)
# except Exception as e:
# print(e)
# print("数据库写入失败")
session_dict[session_id] = cache_data_list
# #检测是否存在url,如果存在,则去除url 防止被qq频道过滤
chatgpt_res = chatgpt_res.replace(".", " . ")
# 发送qq信息
try:
await message.reply(content=f"[ChatGPT]{chatgpt_res}")
except BaseException as e:
print("QQ频道API错误: \n"+str(e))
f_res = ""
for t in chatgpt_res:
f_res += t + ' '
await message.reply(content=f"QQ频道API错误:{str(e)}\n下面是格式化后的回答:\n{f_res}")
# 添加新条目进入缓存的prompt
if len(cache_data_list) > 0:
single_record = {
"prompt": f'Human: {qq_msg}\nAI: {chatgpt_res}\n',
"usage_tokens": current_usage_tokens,
"single_tokens": current_usage_tokens - int(cache_data_list[-1]['usage_tokens'])
}
else:
single_record = {
"prompt": f'Human: {qq_msg}\nAI: {chatgpt_res}\n',
"usage_tokens": current_usage_tokens,
"single_tokens": current_usage_tokens
}
cache_data_list.append(single_record)
session_dict[session_id] = cache_data_list