diff --git a/configs/config.yaml b/configs/config.yaml index b7d3c641d..c9d4e1067 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -7,12 +7,14 @@ openai: top_p: 1 frequency_penalty: 0.4 presence_penalty: 0.3 + total_tokens_limit: 1000 #这是默认的单个用户最大缓存token数。超过了这个数字会自动截断最早的prompt记录. qqbot: appid: token: -database: +# 未实现 +database: url: port: user: diff --git a/cores/openai/core.py b/cores/openai/core.py index aee7ca2bf..5bcf9115c 100644 --- a/cores/openai/core.py +++ b/cores/openai/core.py @@ -6,7 +6,7 @@ from util.errors.errors import PromptExceededError inst = None class ChatGPT: - def __init__(self, chatGPT_configs): + def __init__(self): with open("./configs/config.yaml", 'r', encoding='utf-8') as ymlfile: cfg = yaml.safe_load(ymlfile) if cfg['openai']['key'] != '': @@ -14,7 +14,11 @@ class ChatGPT: openai.api_key = cfg['openai']['key'] else: print("请先去完善ChatGPT的Key。详情请前往https://beta.openai.com/account/api-keys") + + chatGPT_configs = cfg['openai']['chatGPTConfigs'] + print(f'加载ChatGPTConfigs: {chatGPT_configs}') self.chatGPT_configs = chatGPT_configs + self.openai_configs = cfg['openai'] global inst inst = self @@ -28,10 +32,11 @@ class ChatGPT: except(openai.error.InvalidRequestError) as e: raise PromptExceededError("OpenAI遇到错误:输入了一个不合法的请求。\n"+str(e)) - return response["choices"][0]["text"] - - def newSession(self): - return openai.Session() + # print(response['usage']) + return response["choices"][0]["text"], response['usage']['total_tokens'] + + def getConfigs(self): + return self.openai_configs def getInst() -> ChatGPT: global inst diff --git a/cores/qqbot/core.py b/cores/qqbot/core.py index f59f94b3f..3f9adb483 100644 --- a/cores/qqbot/core.py +++ b/cores/qqbot/core.py @@ -7,8 +7,8 @@ import re from util.errors.errors import PromptExceededError chatgpt = "" - session_dict = {} +max_tokens = 2000 class botClient(botpy.Client): async def on_at_message_create(self, message: Message): @@ -35,9 +35,12 @@ class botClient(botpy.Client): await message.reply(content=f"{message.member.nick} 的历史记录重置成功") return if qq_msg == "/his": - p = getPromptsByCacheList(session_dict[session_id], divide=True) + p = get_prompts_by_cache_list(session_dict[session_id], divide=True) await message.reply(content=f"{message.member.nick} 的历史记录如下:\n{p}") return + if qq_msg == "/token": + await message.reply(content=f"{message.member.nick} 会话的token数: {get_user_usage_tokens(session_dict[session_id])}") + return if session_id not in session_dict: session_dict[session_id] = [] @@ -45,31 +48,53 @@ class botClient(botpy.Client): # 获取缓存 cache_prompt = '' cache_prompt_list = session_dict[session_id] - cache_prompt = getPromptsByCacheList(cache_prompt_list) + cache_prompt = get_prompts_by_cache_list(cache_prompt_list) cache_prompt += "Human: "+ qq_msg + "\nAI: " # 请求chatGPT获得结果 try: - chatgpt_res = await getChatGPTResponse(cache_prompt) + chatgpt_res, current_usage_tokens = await get_chatGPT_response(cache_prompt) except (PromptExceededError) as e: print(e) - # 超过了4096个tokens,清空cache + # 超过4097tokens错误,清空缓存 session_dict[session_id] = [] cache_prompt_list = [] cache_prompt = "Human: "+ qq_msg + "\nAI: " - chatgpt_res = await getChatGPTResponse(cache_prompt) - + chatgpt_res, current_usage_tokens = await get_chatGPT_response(cache_prompt) - # 处理结果文本 - chatgpt_res = chatgpt_res.strip() + # 超过指定tokens, 尽可能的保留最多的条目,直到小于max_tokens + # print("current_usage_tokens: ", current_usage_tokens) + # print("max_tokens: ", max_tokens) + if current_usage_tokens > max_tokens: + t = current_usage_tokens + cache_list = session_dict[session_id] + index = 0 + while t > max_tokens: + t -= int(cache_list[index]['single_tokens']) + index += 1 + session_dict[session_id] = cache_list[index:] + cache_prompt_list = session_dict[session_id] + cache_prompt = get_prompts_by_cache_list(cache_prompt_list) - cache_prompt += chatgpt_res + "\n"; + # cache_prompt += chatgpt_res + "\n"; # 添加新条目进入缓存的prompt - cache_prompt_list.append(f'Human: {qq_msg}\nAI: {chatgpt_res}\n') + if len(cache_prompt_list) > 0: + single_record = { + "prompt": f'Human: {qq_msg}\nAI: {chatgpt_res}\n', + "usage_tokens": current_usage_tokens, + "single_tokens": current_usage_tokens - int(cache_prompt_list[-1]['usage_tokens']) + } + else: + single_record = { + "prompt": f'Human: {qq_msg}\nAI: {chatgpt_res}\n', + "usage_tokens": current_usage_tokens, + "single_tokens": current_usage_tokens + } + # print(single_record) + cache_prompt_list.append(single_record) session_dict[session_id] = cache_prompt_list # #检测是否存在url,如果存在,则去除url 防止被qq频道过滤 - # chatgpt_res = re.sub(r"([\s]+)(http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+)([\s]+)", r"\1\3", chatgpt_res) chatgpt_res = chatgpt_res.replace(".", " . ") # 发送qq信息 @@ -78,6 +103,8 @@ class botClient(botpy.Client): def initBot(chatgpt_inst): global chatgpt chatgpt = chatgpt_inst + global max_tokens + max_tokens = int(chatgpt_inst.getConfigs()['total_tokens_limit']) with open("./configs/config.yaml", 'r', encoding='utf-8') as ymlfile: cfg = yaml.safe_load(ymlfile) if cfg['qqbot']['appid'] != '' or cfg['qqbot']['token'] != '': @@ -88,14 +115,25 @@ def initBot(chatgpt_inst): else: raise BaseException("请在config中完善你的appid和token") -async def getChatGPTResponse(prompts_str): - return await chatgpt.chat(prompts_str) -def getPromptsByCacheList(cache_prompt_list, divide=False): +async def get_chatGPT_response(prompts_str): + res = '' + usage = '' + res, usage = await chatgpt.chat(prompts_str) + # 处理结果文本 + chatgpt_res = res.strip() + return res, usage + +def get_prompts_by_cache_list(cache_prompt_list, divide=False): prompts = "" for item in cache_prompt_list: - prompts += str(item) + prompts += str(item['prompt']) if divide: prompts += "----------\n" return prompts - \ No newline at end of file + +def get_user_usage_tokens(cache_list): + usage_tokens = 0 + for item in cache_list: + usage_tokens += int(item['single_tokens']) + return usage_tokens \ No newline at end of file diff --git a/main.py b/main.py index 408d0bd70..3f4c0ce8d 100644 --- a/main.py +++ b/main.py @@ -5,14 +5,12 @@ import yaml def main(): # 读取参数 - with open('configs/config.yaml', 'r', encoding='utf-8') as f: - cfg = yaml.safe_load(f) - chatGPT_configs = cfg['openai']['chatGPTConfigs'] - print(chatGPT_configs) - - + # with open('configs/config.yaml', 'r', encoding='utf-8') as f: + # cfg = yaml.safe_load(f) + # chatGPT_configs = cfg['openai']['chatGPTConfigs'] + # print(chatGPT_configs) #实例化ChatGPT - chatgpt = ChatGPT(chatGPT_configs=chatGPT_configs) + chatgpt = ChatGPT() #执行qqBot qqBot.initBot(chatgpt)