fix: 修复私信机器人的一些问题
This commit is contained in:
+32
-25
@@ -6,22 +6,26 @@ import cores.openai.core
|
||||
import re
|
||||
from util.errors.errors import PromptExceededError
|
||||
from botpy.message import DirectMessage
|
||||
import json
|
||||
|
||||
chatgpt = ""
|
||||
db = ""
|
||||
session_dict = {}
|
||||
max_tokens = 2000
|
||||
|
||||
class botClient(botpy.Client):
|
||||
async def on_at_message_create(self, message: Message):
|
||||
await oper_msg(message=message)
|
||||
await oper_msg(message=message, at=True)
|
||||
|
||||
async def on_direct_message_create(self, message: DirectMessage):
|
||||
print(message.content)
|
||||
await oper_msg(message=message)
|
||||
await oper_msg(message=message, at=False)
|
||||
|
||||
def initBot(chatgpt_inst):
|
||||
def initBot(chatgpt_inst, db_inst):
|
||||
global chatgpt
|
||||
chatgpt = chatgpt_inst
|
||||
global db
|
||||
db = db_inst
|
||||
global max_tokens
|
||||
max_tokens = int(chatgpt_inst.getConfigs()['total_tokens_limit'])
|
||||
with open("./configs/config.yaml", 'r', encoding='utf-8') as ymlfile:
|
||||
@@ -43,17 +47,17 @@ async def get_chatGPT_response(prompts_str):
|
||||
chatgpt_res = res.strip()
|
||||
return res, usage
|
||||
|
||||
def get_prompts_by_cache_list(cache_prompt_list, divide=False, paging=False, size=5, page=1):
|
||||
def get_prompts_by_cache_list(cache_data_list, divide=False, paging=False, size=5, page=1):
|
||||
prompts = ""
|
||||
if paging:
|
||||
page_begin = (page-1)*size
|
||||
page_end = page*size
|
||||
if page_begin < 0:
|
||||
page_begin = 0
|
||||
if page_end > len(cache_prompt_list):
|
||||
page_end = len(cache_prompt_list)
|
||||
cache_prompt_list = cache_prompt_list[page_begin:page_end]
|
||||
for item in cache_prompt_list:
|
||||
if page_end > len(cache_data_list):
|
||||
page_end = len(cache_data_list)
|
||||
cache_data_list = cache_data_list[page_begin:page_end]
|
||||
for item in cache_data_list:
|
||||
prompts += str(item['prompt'])
|
||||
if divide:
|
||||
prompts += "----------\n"
|
||||
@@ -65,16 +69,18 @@ def get_user_usage_tokens(cache_list):
|
||||
usage_tokens += int(item['single_tokens'])
|
||||
return usage_tokens
|
||||
|
||||
async def oper_msg(message):
|
||||
async def oper_msg(message, at=False):
|
||||
print("[QQBOT] 接收到消息:"+ message.content)
|
||||
qq_msg = ""
|
||||
# 过滤用户id
|
||||
pattern = r"<@!\d+>\s+(.+)"
|
||||
# 多行匹配
|
||||
pattern = re.compile(pattern, flags=re.MULTILINE)
|
||||
result = re.search(pattern, message.content)
|
||||
if result:
|
||||
qq_msg = result.group(1).strip()
|
||||
|
||||
if at:
|
||||
# 过滤用户id
|
||||
pattern = r"<@!\d+>\s+(.+)"
|
||||
# 多行匹配
|
||||
pattern = re.compile(pattern, flags=re.MULTILINE)
|
||||
result = re.search(pattern, message.content)
|
||||
if result:
|
||||
qq_msg = result.group(1).strip()
|
||||
|
||||
# 检测用户id,返回对应缓存的prompt
|
||||
# session_id_pattern = r"<@!\d+>"
|
||||
@@ -106,8 +112,8 @@ async def oper_msg(message):
|
||||
|
||||
# 获取缓存
|
||||
cache_prompt = ''
|
||||
cache_prompt_list = session_dict[session_id]
|
||||
cache_prompt = get_prompts_by_cache_list(cache_prompt_list)
|
||||
cache_data_list = session_dict[session_id]
|
||||
cache_prompt = get_prompts_by_cache_list(cache_data_list)
|
||||
cache_prompt += "Human: "+ qq_msg + "\nAI: "
|
||||
# 请求chatGPT获得结果
|
||||
try:
|
||||
@@ -117,7 +123,7 @@ async def oper_msg(message):
|
||||
|
||||
# 超过4097tokens错误,清空缓存
|
||||
session_dict[session_id] = []
|
||||
cache_prompt_list = []
|
||||
cache_data_list = []
|
||||
cache_prompt = "Human: "+ qq_msg + "\nAI: "
|
||||
chatgpt_res, current_usage_tokens = await get_chatGPT_response(cache_prompt)
|
||||
|
||||
@@ -132,16 +138,16 @@ async def oper_msg(message):
|
||||
t -= int(cache_list[index]['single_tokens'])
|
||||
index += 1
|
||||
session_dict[session_id] = cache_list[index:]
|
||||
cache_prompt_list = session_dict[session_id]
|
||||
cache_prompt = get_prompts_by_cache_list(cache_prompt_list)
|
||||
cache_data_list = session_dict[session_id]
|
||||
cache_prompt = get_prompts_by_cache_list(cache_data_list)
|
||||
|
||||
# cache_prompt += chatgpt_res + "\n";
|
||||
# 添加新条目进入缓存的prompt
|
||||
if len(cache_prompt_list) > 0:
|
||||
if len(cache_data_list) > 0:
|
||||
single_record = {
|
||||
"prompt": f'Human: {qq_msg}\nAI: {chatgpt_res}\n',
|
||||
"usage_tokens": current_usage_tokens,
|
||||
"single_tokens": current_usage_tokens - int(cache_prompt_list[-1]['usage_tokens'])
|
||||
"single_tokens": current_usage_tokens - int(cache_data_list[-1]['usage_tokens'])
|
||||
}
|
||||
else:
|
||||
single_record = {
|
||||
@@ -150,8 +156,9 @@ async def oper_msg(message):
|
||||
"single_tokens": current_usage_tokens
|
||||
}
|
||||
# print(single_record)
|
||||
cache_prompt_list.append(single_record)
|
||||
session_dict[session_id] = cache_prompt_list
|
||||
cache_data_list.append(single_record)
|
||||
|
||||
session_dict[session_id] = cache_data_list
|
||||
|
||||
# #检测是否存在url,如果存在,则去除url 防止被qq频道过滤
|
||||
chatgpt_res = chatgpt_res.replace(".", " . ")
|
||||
|
||||
Reference in New Issue
Block a user