fix: 修复私信机器人的一些问题

This commit is contained in:
Soulter
2022-12-10 17:57:49 +08:00
parent 0b8eeb6e68
commit 616108aa45
+32 -25
View File
@@ -6,22 +6,26 @@ import cores.openai.core
import re
from util.errors.errors import PromptExceededError
from botpy.message import DirectMessage
import json
chatgpt = ""
db = ""
session_dict = {}
max_tokens = 2000
class botClient(botpy.Client):
async def on_at_message_create(self, message: Message):
await oper_msg(message=message)
await oper_msg(message=message, at=True)
async def on_direct_message_create(self, message: DirectMessage):
print(message.content)
await oper_msg(message=message)
await oper_msg(message=message, at=False)
def initBot(chatgpt_inst):
def initBot(chatgpt_inst, db_inst):
global chatgpt
chatgpt = chatgpt_inst
global db
db = db_inst
global max_tokens
max_tokens = int(chatgpt_inst.getConfigs()['total_tokens_limit'])
with open("./configs/config.yaml", 'r', encoding='utf-8') as ymlfile:
@@ -43,17 +47,17 @@ async def get_chatGPT_response(prompts_str):
chatgpt_res = res.strip()
return res, usage
def get_prompts_by_cache_list(cache_prompt_list, divide=False, paging=False, size=5, page=1):
def get_prompts_by_cache_list(cache_data_list, divide=False, paging=False, size=5, page=1):
prompts = ""
if paging:
page_begin = (page-1)*size
page_end = page*size
if page_begin < 0:
page_begin = 0
if page_end > len(cache_prompt_list):
page_end = len(cache_prompt_list)
cache_prompt_list = cache_prompt_list[page_begin:page_end]
for item in cache_prompt_list:
if page_end > len(cache_data_list):
page_end = len(cache_data_list)
cache_data_list = cache_data_list[page_begin:page_end]
for item in cache_data_list:
prompts += str(item['prompt'])
if divide:
prompts += "----------\n"
@@ -65,16 +69,18 @@ def get_user_usage_tokens(cache_list):
usage_tokens += int(item['single_tokens'])
return usage_tokens
async def oper_msg(message):
async def oper_msg(message, at=False):
print("[QQBOT] 接收到消息:"+ message.content)
qq_msg = ""
# 过滤用户id
pattern = r"<@!\d+>\s+(.+)"
# 多行匹配
pattern = re.compile(pattern, flags=re.MULTILINE)
result = re.search(pattern, message.content)
if result:
qq_msg = result.group(1).strip()
if at:
# 过滤用户id
pattern = r"<@!\d+>\s+(.+)"
# 多行匹配
pattern = re.compile(pattern, flags=re.MULTILINE)
result = re.search(pattern, message.content)
if result:
qq_msg = result.group(1).strip()
# 检测用户id,返回对应缓存的prompt
# session_id_pattern = r"<@!\d+>"
@@ -106,8 +112,8 @@ async def oper_msg(message):
# 获取缓存
cache_prompt = ''
cache_prompt_list = session_dict[session_id]
cache_prompt = get_prompts_by_cache_list(cache_prompt_list)
cache_data_list = session_dict[session_id]
cache_prompt = get_prompts_by_cache_list(cache_data_list)
cache_prompt += "Human: "+ qq_msg + "\nAI: "
# 请求chatGPT获得结果
try:
@@ -117,7 +123,7 @@ async def oper_msg(message):
# 超过4097tokens错误,清空缓存
session_dict[session_id] = []
cache_prompt_list = []
cache_data_list = []
cache_prompt = "Human: "+ qq_msg + "\nAI: "
chatgpt_res, current_usage_tokens = await get_chatGPT_response(cache_prompt)
@@ -132,16 +138,16 @@ async def oper_msg(message):
t -= int(cache_list[index]['single_tokens'])
index += 1
session_dict[session_id] = cache_list[index:]
cache_prompt_list = session_dict[session_id]
cache_prompt = get_prompts_by_cache_list(cache_prompt_list)
cache_data_list = session_dict[session_id]
cache_prompt = get_prompts_by_cache_list(cache_data_list)
# cache_prompt += chatgpt_res + "\n";
# 添加新条目进入缓存的prompt
if len(cache_prompt_list) > 0:
if len(cache_data_list) > 0:
single_record = {
"prompt": f'Human: {qq_msg}\nAI: {chatgpt_res}\n',
"usage_tokens": current_usage_tokens,
"single_tokens": current_usage_tokens - int(cache_prompt_list[-1]['usage_tokens'])
"single_tokens": current_usage_tokens - int(cache_data_list[-1]['usage_tokens'])
}
else:
single_record = {
@@ -150,8 +156,9 @@ async def oper_msg(message):
"single_tokens": current_usage_tokens
}
# print(single_record)
cache_prompt_list.append(single_record)
session_dict[session_id] = cache_prompt_list
cache_data_list.append(single_record)
session_dict[session_id] = cache_data_list
# #检测是否存在url,如果存在,则去除url 防止被qq频道过滤
chatgpt_res = chatgpt_res.replace(".", " . ")