fix: 修复了prompts超出界限的bug
feat:添加历史记录功能 perf:优化prompt缓存模式
This commit is contained in:
+10
-5
@@ -1,5 +1,6 @@
|
||||
import openai
|
||||
import yaml
|
||||
from util.errors.errors import PromptExceededError
|
||||
|
||||
|
||||
inst = None
|
||||
@@ -18,11 +19,15 @@ class ChatGPT:
|
||||
inst = self
|
||||
|
||||
async def chat(self, prompt):
|
||||
print("[ChatGPT] 接收到prompt:\n"+prompt)
|
||||
response = openai.Completion.create(
|
||||
prompt=prompt,
|
||||
**self.chatGPT_configs
|
||||
)
|
||||
print("[ChatGPT] 接收到prompt:\n")
|
||||
try:
|
||||
response = openai.Completion.create(
|
||||
prompt=prompt,
|
||||
**self.chatGPT_configs
|
||||
)
|
||||
except(openai.error.InvalidRequestError) as e:
|
||||
raise PromptExceededError("OpenAI遇到错误:输入了一个不合法的请求。\n"+str(e))
|
||||
|
||||
return response["choices"][0]["text"]
|
||||
|
||||
def newSession(self):
|
||||
|
||||
+69
-36
@@ -4,6 +4,7 @@ import yaml
|
||||
import asyncio
|
||||
import cores.openai.core
|
||||
import re
|
||||
from util.errors.errors import PromptExceededError
|
||||
|
||||
chatgpt = ""
|
||||
|
||||
@@ -11,48 +12,68 @@ session_dict = {}
|
||||
|
||||
class botClient(botpy.Client):
|
||||
async def on_at_message_create(self, message: Message):
|
||||
try:
|
||||
print("[QQBOT] 接收到消息:"+ message.content)
|
||||
qq_msg = ""
|
||||
# 过滤@头
|
||||
pattern = r"<@!\d+>\s+(.+)"
|
||||
result = re.search(pattern, message.content)
|
||||
if result:
|
||||
qq_msg = result.group(1).strip()
|
||||
print("[QQBOT] 接收到消息:"+ message.content)
|
||||
qq_msg = ""
|
||||
# 过滤用户id
|
||||
pattern = r"<@!\d+>\s+(.+)"
|
||||
# 多行匹配
|
||||
pattern = re.compile(pattern, flags=re.MULTILINE)
|
||||
result = re.search(pattern, message.content)
|
||||
if result:
|
||||
qq_msg = result.group(1).strip()
|
||||
|
||||
# 检测@头,返回对应缓存的prompt
|
||||
session_id_pattern = r"<@!\d+>"
|
||||
session_id_result = re.search(session_id_pattern, message.content)
|
||||
if session_id_result:
|
||||
# 匹配出sessionid
|
||||
# 检测用户id,返回对应缓存的prompt
|
||||
session_id_pattern = r"<@!\d+>"
|
||||
session_id_result = re.search(session_id_pattern, message.content)
|
||||
if session_id_result:
|
||||
# 匹配出sessionid
|
||||
session_id = session_id_result.group(0)
|
||||
|
||||
if qq_msg == "/reset":
|
||||
session_id = session_id_result.group(0)
|
||||
session_dict[session_id] = []
|
||||
await message.reply(content=f"{message.member.nick} 的历史记录重置成功")
|
||||
return
|
||||
if qq_msg == "/his":
|
||||
p = getPromptsByCacheList(session_dict[session_id], divide=True)
|
||||
await message.reply(content=f"{message.member.nick} 的历史记录如下:\n{p}")
|
||||
return
|
||||
|
||||
if qq_msg == "/reset":
|
||||
session_id = session_id_result.group(0)
|
||||
session_dict[session_id] = ""
|
||||
await message.reply(content=f"{message.member.nick} [ChatGPT] 重置成功")
|
||||
return
|
||||
if session_id not in session_dict:
|
||||
session_dict[session_id] = []
|
||||
|
||||
if session_id not in session_dict:
|
||||
session_dict[session_id] = ""
|
||||
# 添加新条目进入缓存的prompt
|
||||
session_dict[session_id] += "Human: "+ qq_msg + "\nAI: "
|
||||
# 请求chatGPT获得结果
|
||||
chatgpt_res = await chatgpt.chat(session_dict[session_id])
|
||||
# 处理结果文本
|
||||
chatgpt_res = chatgpt_res.strip()
|
||||
# 获取缓存
|
||||
cache_prompt = ''
|
||||
cache_prompt_list = session_dict[session_id]
|
||||
cache_prompt = getPromptsByCacheList(cache_prompt_list)
|
||||
cache_prompt += "Human: "+ qq_msg + "\nAI: "
|
||||
# 请求chatGPT获得结果
|
||||
try:
|
||||
chatgpt_res = await getChatGPTResponse(cache_prompt)
|
||||
except (PromptExceededError) as e:
|
||||
print(e)
|
||||
|
||||
# 超过了4096个tokens,清空cache
|
||||
session_dict[session_id] = []
|
||||
cache_prompt_list = []
|
||||
cache_prompt = "Human: "+ qq_msg + "\nAI: "
|
||||
chatgpt_res = await getChatGPTResponse(cache_prompt)
|
||||
|
||||
|
||||
session_dict[session_id] += chatgpt_res + "\n"
|
||||
# 处理结果文本
|
||||
chatgpt_res = chatgpt_res.strip()
|
||||
|
||||
# #检测是否存在url,如果存在,则去除url 防止被qq过滤
|
||||
# chatgpt_res = re.sub(r"([\s]+)(http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+)([\s]+)", r"\1\3", chatgpt_res)
|
||||
chatgpt_res = chatgpt_res.replace(".", " . ")
|
||||
cache_prompt += chatgpt_res + "\n";
|
||||
# 添加新条目进入缓存的prompt
|
||||
cache_prompt_list.append(f'Human: {qq_msg}\nAI: {chatgpt_res}\n')
|
||||
session_dict[session_id] = cache_prompt_list
|
||||
|
||||
print(f'{session_id} 目前prompt: {session_dict[session_id]}' )
|
||||
# 发送qq信息
|
||||
await message.reply(content=f"[ChatGPT]{chatgpt_res}")
|
||||
except botpy.errors.Forbidden:
|
||||
print("无法发送消息,可能是因为没有给botpy发消息的权限")
|
||||
# #检测是否存在url,如果存在,则去除url 防止被qq频道过滤
|
||||
# chatgpt_res = re.sub(r"([\s]+)(http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+)([\s]+)", r"\1\3", chatgpt_res)
|
||||
chatgpt_res = chatgpt_res.replace(".", " . ")
|
||||
|
||||
# 发送qq信息
|
||||
await message.reply(content=f"[ChatGPT]{chatgpt_res}")
|
||||
|
||||
def initBot(chatgpt_inst):
|
||||
global chatgpt
|
||||
@@ -65,4 +86,16 @@ def initBot(chatgpt_inst):
|
||||
client = botClient(intents=intents)
|
||||
client.run(appid=cfg['qqbot']['appid'], token=cfg['qqbot']['token'])
|
||||
else:
|
||||
raise BaseException("请在config中完善你的appid和token")
|
||||
raise BaseException("请在config中完善你的appid和token")
|
||||
|
||||
async def getChatGPTResponse(prompts_str):
|
||||
return await chatgpt.chat(prompts_str)
|
||||
|
||||
def getPromptsByCacheList(cache_prompt_list, divide=False):
|
||||
prompts = ""
|
||||
for item in cache_prompt_list:
|
||||
prompts += str(item)
|
||||
if divide:
|
||||
prompts += "----------\n"
|
||||
return prompts
|
||||
|
||||
Reference in New Issue
Block a user