diff --git a/cores/qqbot/core.py b/cores/qqbot/core.py index c2c8b7f60..775593591 100644 --- a/cores/qqbot/core.py +++ b/cores/qqbot/core.py @@ -6,22 +6,26 @@ import cores.openai.core import re from util.errors.errors import PromptExceededError from botpy.message import DirectMessage +import json chatgpt = "" +db = "" session_dict = {} max_tokens = 2000 class botClient(botpy.Client): async def on_at_message_create(self, message: Message): - await oper_msg(message=message) + await oper_msg(message=message, at=True) async def on_direct_message_create(self, message: DirectMessage): print(message.content) - await oper_msg(message=message) + await oper_msg(message=message, at=False) -def initBot(chatgpt_inst): +def initBot(chatgpt_inst, db_inst): global chatgpt chatgpt = chatgpt_inst + global db + db = db_inst global max_tokens max_tokens = int(chatgpt_inst.getConfigs()['total_tokens_limit']) with open("./configs/config.yaml", 'r', encoding='utf-8') as ymlfile: @@ -43,17 +47,17 @@ async def get_chatGPT_response(prompts_str): chatgpt_res = res.strip() return res, usage -def get_prompts_by_cache_list(cache_prompt_list, divide=False, paging=False, size=5, page=1): +def get_prompts_by_cache_list(cache_data_list, divide=False, paging=False, size=5, page=1): prompts = "" if paging: page_begin = (page-1)*size page_end = page*size if page_begin < 0: page_begin = 0 - if page_end > len(cache_prompt_list): - page_end = len(cache_prompt_list) - cache_prompt_list = cache_prompt_list[page_begin:page_end] - for item in cache_prompt_list: + if page_end > len(cache_data_list): + page_end = len(cache_data_list) + cache_data_list = cache_data_list[page_begin:page_end] + for item in cache_data_list: prompts += str(item['prompt']) if divide: prompts += "----------\n" @@ -65,16 +69,18 @@ def get_user_usage_tokens(cache_list): usage_tokens += int(item['single_tokens']) return usage_tokens -async def oper_msg(message): +async def oper_msg(message, at=False): print("[QQBOT] 接收到消息:"+ message.content) qq_msg = "" - # 过滤用户id - pattern = r"<@!\d+>\s+(.+)" - # 多行匹配 - pattern = re.compile(pattern, flags=re.MULTILINE) - result = re.search(pattern, message.content) - if result: - qq_msg = result.group(1).strip() + + if at: + # 过滤用户id + pattern = r"<@!\d+>\s+(.+)" + # 多行匹配 + pattern = re.compile(pattern, flags=re.MULTILINE) + result = re.search(pattern, message.content) + if result: + qq_msg = result.group(1).strip() # 检测用户id,返回对应缓存的prompt # session_id_pattern = r"<@!\d+>" @@ -106,8 +112,8 @@ async def oper_msg(message): # 获取缓存 cache_prompt = '' - cache_prompt_list = session_dict[session_id] - cache_prompt = get_prompts_by_cache_list(cache_prompt_list) + cache_data_list = session_dict[session_id] + cache_prompt = get_prompts_by_cache_list(cache_data_list) cache_prompt += "Human: "+ qq_msg + "\nAI: " # 请求chatGPT获得结果 try: @@ -117,7 +123,7 @@ async def oper_msg(message): # 超过4097tokens错误,清空缓存 session_dict[session_id] = [] - cache_prompt_list = [] + cache_data_list = [] cache_prompt = "Human: "+ qq_msg + "\nAI: " chatgpt_res, current_usage_tokens = await get_chatGPT_response(cache_prompt) @@ -132,16 +138,16 @@ async def oper_msg(message): t -= int(cache_list[index]['single_tokens']) index += 1 session_dict[session_id] = cache_list[index:] - cache_prompt_list = session_dict[session_id] - cache_prompt = get_prompts_by_cache_list(cache_prompt_list) + cache_data_list = session_dict[session_id] + cache_prompt = get_prompts_by_cache_list(cache_data_list) # cache_prompt += chatgpt_res + "\n"; # 添加新条目进入缓存的prompt - if len(cache_prompt_list) > 0: + if len(cache_data_list) > 0: single_record = { "prompt": f'Human: {qq_msg}\nAI: {chatgpt_res}\n', "usage_tokens": current_usage_tokens, - "single_tokens": current_usage_tokens - int(cache_prompt_list[-1]['usage_tokens']) + "single_tokens": current_usage_tokens - int(cache_data_list[-1]['usage_tokens']) } else: single_record = { @@ -150,8 +156,9 @@ async def oper_msg(message): "single_tokens": current_usage_tokens } # print(single_record) - cache_prompt_list.append(single_record) - session_dict[session_id] = cache_prompt_list + cache_data_list.append(single_record) + + session_dict[session_id] = cache_data_list # #检测是否存在url,如果存在,则去除url 防止被qq频道过滤 chatgpt_res = chatgpt_res.replace(".", " . ")