fix: 修复了prompts超出界限的bug

feat:添加历史记录功能
perf:优化prompt缓存模式
This commit is contained in:
Soulter
2022-12-10 02:42:14 +08:00
parent d9d7eca04a
commit 3e26ea282f
2 changed files with 79 additions and 41 deletions
+6 -1
View File
@@ -1,5 +1,6 @@
import openai
import yaml
from util.errors.errors import PromptExceededError
inst = None
@@ -18,11 +19,15 @@ class ChatGPT:
inst = self
async def chat(self, prompt):
print("[ChatGPT] 接收到prompt:\n"+prompt)
print("[ChatGPT] 接收到prompt:\n")
try:
response = openai.Completion.create(
prompt=prompt,
**self.chatGPT_configs
)
except(openai.error.InvalidRequestError) as e:
raise PromptExceededError("OpenAI遇到错误:输入了一个不合法的请求。\n"+str(e))
return response["choices"][0]["text"]
def newSession(self):
+47 -14
View File
@@ -4,6 +4,7 @@ import yaml
import asyncio
import cores.openai.core
import re
from util.errors.errors import PromptExceededError
chatgpt = ""
@@ -11,16 +12,17 @@ session_dict = {}
class botClient(botpy.Client):
async def on_at_message_create(self, message: Message):
try:
print("[QQBOT] 接收到消息:"+ message.content)
qq_msg = ""
# 过滤@头
# 过滤用户id
pattern = r"<@!\d+>\s+(.+)"
# 多行匹配
pattern = re.compile(pattern, flags=re.MULTILINE)
result = re.search(pattern, message.content)
if result:
qq_msg = result.group(1).strip()
# 检测@头,返回对应缓存的prompt
# 检测用户id,返回对应缓存的prompt
session_id_pattern = r"<@!\d+>"
session_id_result = re.search(session_id_pattern, message.content)
if session_id_result:
@@ -29,30 +31,49 @@ class botClient(botpy.Client):
if qq_msg == "/reset":
session_id = session_id_result.group(0)
session_dict[session_id] = ""
await message.reply(content=f"{message.member.nick} [ChatGPT] 重置成功")
session_dict[session_id] = []
await message.reply(content=f"{message.member.nick} 的历史记录重置成功")
return
if qq_msg == "/his":
p = getPromptsByCacheList(session_dict[session_id], divide=True)
await message.reply(content=f"{message.member.nick} 的历史记录如下:\n{p}")
return
if session_id not in session_dict:
session_dict[session_id] = ""
# 添加新条目进入缓存的prompt
session_dict[session_id] += "Human: "+ qq_msg + "\nAI: "
session_dict[session_id] = []
# 获取缓存
cache_prompt = ''
cache_prompt_list = session_dict[session_id]
cache_prompt = getPromptsByCacheList(cache_prompt_list)
cache_prompt += "Human: "+ qq_msg + "\nAI: "
# 请求chatGPT获得结果
chatgpt_res = await chatgpt.chat(session_dict[session_id])
try:
chatgpt_res = await getChatGPTResponse(cache_prompt)
except (PromptExceededError) as e:
print(e)
# 超过了4096个tokens,清空cache
session_dict[session_id] = []
cache_prompt_list = []
cache_prompt = "Human: "+ qq_msg + "\nAI: "
chatgpt_res = await getChatGPTResponse(cache_prompt)
# 处理结果文本
chatgpt_res = chatgpt_res.strip()
session_dict[session_id] += chatgpt_res + "\n"
cache_prompt += chatgpt_res + "\n";
# 添加新条目进入缓存的prompt
cache_prompt_list.append(f'Human: {qq_msg}\nAI: {chatgpt_res}\n')
session_dict[session_id] = cache_prompt_list
# #检测是否存在url,如果存在,则去除url 防止被qq过滤
# #检测是否存在url,如果存在,则去除url 防止被qq频道过滤
# chatgpt_res = re.sub(r"([\s]+)(http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+)([\s]+)", r"\1\3", chatgpt_res)
chatgpt_res = chatgpt_res.replace(".", " . ")
print(f'{session_id} 目前prompt: {session_dict[session_id]}' )
# 发送qq信息
await message.reply(content=f"[ChatGPT]{chatgpt_res}")
except botpy.errors.Forbidden:
print("无法发送消息,可能是因为没有给botpy发消息的权限")
def initBot(chatgpt_inst):
global chatgpt
@@ -66,3 +87,15 @@ def initBot(chatgpt_inst):
client.run(appid=cfg['qqbot']['appid'], token=cfg['qqbot']['token'])
else:
raise BaseException("请在config中完善你的appid和token")
async def getChatGPTResponse(prompts_str):
return await chatgpt.chat(prompts_str)
def getPromptsByCacheList(cache_prompt_list, divide=False):
prompts = ""
for item in cache_prompt_list:
prompts += str(item)
if divide:
prompts += "----------\n"
return prompts