From 3e26ea282fd9e82d2d71e31c25126f08a95ff2ca Mon Sep 17 00:00:00 2001 From: Soulter <37870767+Soulter@users.noreply.github.com> Date: Sat, 10 Dec 2022 02:42:14 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E4=BA=86prompts?= =?UTF-8?q?=E8=B6=85=E5=87=BA=E7=95=8C=E9=99=90=E7=9A=84bug=20feat?= =?UTF-8?q?=EF=BC=9A=E6=B7=BB=E5=8A=A0=E5=8E=86=E5=8F=B2=E8=AE=B0=E5=BD=95?= =?UTF-8?q?=E5=8A=9F=E8=83=BD=20perf=EF=BC=9A=E4=BC=98=E5=8C=96prompt?= =?UTF-8?q?=E7=BC=93=E5=AD=98=E6=A8=A1=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cores/openai/core.py | 15 ++++--- cores/qqbot/core.py | 105 ++++++++++++++++++++++++++++--------------- 2 files changed, 79 insertions(+), 41 deletions(-) diff --git a/cores/openai/core.py b/cores/openai/core.py index e239c1aee..30c5d6611 100644 --- a/cores/openai/core.py +++ b/cores/openai/core.py @@ -1,5 +1,6 @@ import openai import yaml +from util.errors.errors import PromptExceededError inst = None @@ -18,11 +19,15 @@ class ChatGPT: inst = self async def chat(self, prompt): - print("[ChatGPT] 接收到prompt:\n"+prompt) - response = openai.Completion.create( - prompt=prompt, - **self.chatGPT_configs - ) + print("[ChatGPT] 接收到prompt:\n") + try: + response = openai.Completion.create( + prompt=prompt, + **self.chatGPT_configs + ) + except(openai.error.InvalidRequestError) as e: + raise PromptExceededError("OpenAI遇到错误:输入了一个不合法的请求。\n"+str(e)) + return response["choices"][0]["text"] def newSession(self): diff --git a/cores/qqbot/core.py b/cores/qqbot/core.py index 1cdb2a00a..f59f94b3f 100644 --- a/cores/qqbot/core.py +++ b/cores/qqbot/core.py @@ -4,6 +4,7 @@ import yaml import asyncio import cores.openai.core import re +from util.errors.errors import PromptExceededError chatgpt = "" @@ -11,48 +12,68 @@ session_dict = {} class botClient(botpy.Client): async def on_at_message_create(self, message: Message): - try: - print("[QQBOT] 接收到消息:"+ message.content) - qq_msg = "" - # 过滤@头 - pattern = r"<@!\d+>\s+(.+)" - result = re.search(pattern, message.content) - if result: - qq_msg = result.group(1).strip() + print("[QQBOT] 接收到消息:"+ message.content) + qq_msg = "" + # 过滤用户id + pattern = r"<@!\d+>\s+(.+)" + # 多行匹配 + pattern = re.compile(pattern, flags=re.MULTILINE) + result = re.search(pattern, message.content) + if result: + qq_msg = result.group(1).strip() - # 检测@头,返回对应缓存的prompt - session_id_pattern = r"<@!\d+>" - session_id_result = re.search(session_id_pattern, message.content) - if session_id_result: - # 匹配出sessionid + # 检测用户id,返回对应缓存的prompt + session_id_pattern = r"<@!\d+>" + session_id_result = re.search(session_id_pattern, message.content) + if session_id_result: + # 匹配出sessionid + session_id = session_id_result.group(0) + + if qq_msg == "/reset": session_id = session_id_result.group(0) + session_dict[session_id] = [] + await message.reply(content=f"{message.member.nick} 的历史记录重置成功") + return + if qq_msg == "/his": + p = getPromptsByCacheList(session_dict[session_id], divide=True) + await message.reply(content=f"{message.member.nick} 的历史记录如下:\n{p}") + return - if qq_msg == "/reset": - session_id = session_id_result.group(0) - session_dict[session_id] = "" - await message.reply(content=f"{message.member.nick} [ChatGPT] 重置成功") - return + if session_id not in session_dict: + session_dict[session_id] = [] - if session_id not in session_dict: - session_dict[session_id] = "" - # 添加新条目进入缓存的prompt - session_dict[session_id] += "Human: "+ qq_msg + "\nAI: " - # 请求chatGPT获得结果 - chatgpt_res = await chatgpt.chat(session_dict[session_id]) - # 处理结果文本 - chatgpt_res = chatgpt_res.strip() + # 获取缓存 + cache_prompt = '' + cache_prompt_list = session_dict[session_id] + cache_prompt = getPromptsByCacheList(cache_prompt_list) + cache_prompt += "Human: "+ qq_msg + "\nAI: " + # 请求chatGPT获得结果 + try: + chatgpt_res = await getChatGPTResponse(cache_prompt) + except (PromptExceededError) as e: + print(e) + + # 超过了4096个tokens,清空cache + session_dict[session_id] = [] + cache_prompt_list = [] + cache_prompt = "Human: "+ qq_msg + "\nAI: " + chatgpt_res = await getChatGPTResponse(cache_prompt) + - session_dict[session_id] += chatgpt_res + "\n" + # 处理结果文本 + chatgpt_res = chatgpt_res.strip() - # #检测是否存在url,如果存在,则去除url 防止被qq过滤 - # chatgpt_res = re.sub(r"([\s]+)(http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+)([\s]+)", r"\1\3", chatgpt_res) - chatgpt_res = chatgpt_res.replace(".", " . ") + cache_prompt += chatgpt_res + "\n"; + # 添加新条目进入缓存的prompt + cache_prompt_list.append(f'Human: {qq_msg}\nAI: {chatgpt_res}\n') + session_dict[session_id] = cache_prompt_list - print(f'{session_id} 目前prompt: {session_dict[session_id]}' ) - # 发送qq信息 - await message.reply(content=f"[ChatGPT]{chatgpt_res}") - except botpy.errors.Forbidden: - print("无法发送消息,可能是因为没有给botpy发消息的权限") + # #检测是否存在url,如果存在,则去除url 防止被qq频道过滤 + # chatgpt_res = re.sub(r"([\s]+)(http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+)([\s]+)", r"\1\3", chatgpt_res) + chatgpt_res = chatgpt_res.replace(".", " . ") + + # 发送qq信息 + await message.reply(content=f"[ChatGPT]{chatgpt_res}") def initBot(chatgpt_inst): global chatgpt @@ -65,4 +86,16 @@ def initBot(chatgpt_inst): client = botClient(intents=intents) client.run(appid=cfg['qqbot']['appid'], token=cfg['qqbot']['token']) else: - raise BaseException("请在config中完善你的appid和token") \ No newline at end of file + raise BaseException("请在config中完善你的appid和token") + +async def getChatGPTResponse(prompts_str): + return await chatgpt.chat(prompts_str) + +def getPromptsByCacheList(cache_prompt_list, divide=False): + prompts = "" + for item in cache_prompt_list: + prompts += str(item) + if divide: + prompts += "----------\n" + return prompts + \ No newline at end of file