fix: 优化发送流程

feat: 1.支持显示历史记录
2. 支持配置缓存最大token
This commit is contained in:
Soulter
2022-12-10 11:49:31 +08:00
parent f3c9d4577c
commit 0117591a20
4 changed files with 73 additions and 30 deletions
+3 -1
View File
@@ -7,12 +7,14 @@ openai:
top_p: 1
frequency_penalty: 0.4
presence_penalty: 0.3
total_tokens_limit: 1000 #这是默认的单个用户最大缓存token数。超过了这个数字会自动截断最早的prompt记录.
qqbot:
appid:
token:
database:
# 未实现
database:
url:
port:
user:
+10 -5
View File
@@ -6,7 +6,7 @@ from util.errors.errors import PromptExceededError
inst = None
class ChatGPT:
def __init__(self, chatGPT_configs):
def __init__(self):
with open("./configs/config.yaml", 'r', encoding='utf-8') as ymlfile:
cfg = yaml.safe_load(ymlfile)
if cfg['openai']['key'] != '':
@@ -14,7 +14,11 @@ class ChatGPT:
openai.api_key = cfg['openai']['key']
else:
print("请先去完善ChatGPT的Key。详情请前往https://beta.openai.com/account/api-keys")
chatGPT_configs = cfg['openai']['chatGPTConfigs']
print(f'加载ChatGPTConfigs: {chatGPT_configs}')
self.chatGPT_configs = chatGPT_configs
self.openai_configs = cfg['openai']
global inst
inst = self
@@ -28,10 +32,11 @@ class ChatGPT:
except(openai.error.InvalidRequestError) as e:
raise PromptExceededError("OpenAI遇到错误:输入了一个不合法的请求。\n"+str(e))
return response["choices"][0]["text"]
def newSession(self):
return openai.Session()
# print(response['usage'])
return response["choices"][0]["text"], response['usage']['total_tokens']
def getConfigs(self):
return self.openai_configs
def getInst() -> ChatGPT:
global inst
+55 -17
View File
@@ -7,8 +7,8 @@ import re
from util.errors.errors import PromptExceededError
chatgpt = ""
session_dict = {}
max_tokens = 2000
class botClient(botpy.Client):
async def on_at_message_create(self, message: Message):
@@ -35,9 +35,12 @@ class botClient(botpy.Client):
await message.reply(content=f"{message.member.nick} 的历史记录重置成功")
return
if qq_msg == "/his":
p = getPromptsByCacheList(session_dict[session_id], divide=True)
p = get_prompts_by_cache_list(session_dict[session_id], divide=True)
await message.reply(content=f"{message.member.nick} 的历史记录如下:\n{p}")
return
if qq_msg == "/token":
await message.reply(content=f"{message.member.nick} 会话的token数: {get_user_usage_tokens(session_dict[session_id])}")
return
if session_id not in session_dict:
session_dict[session_id] = []
@@ -45,31 +48,53 @@ class botClient(botpy.Client):
# 获取缓存
cache_prompt = ''
cache_prompt_list = session_dict[session_id]
cache_prompt = getPromptsByCacheList(cache_prompt_list)
cache_prompt = get_prompts_by_cache_list(cache_prompt_list)
cache_prompt += "Human: "+ qq_msg + "\nAI: "
# 请求chatGPT获得结果
try:
chatgpt_res = await getChatGPTResponse(cache_prompt)
chatgpt_res, current_usage_tokens = await get_chatGPT_response(cache_prompt)
except (PromptExceededError) as e:
print(e)
# 超过4096个tokens,清空cache
# 超过4097tokens错误,清空缓存
session_dict[session_id] = []
cache_prompt_list = []
cache_prompt = "Human: "+ qq_msg + "\nAI: "
chatgpt_res = await getChatGPTResponse(cache_prompt)
chatgpt_res, current_usage_tokens = await get_chatGPT_response(cache_prompt)
# 处理结果文本
chatgpt_res = chatgpt_res.strip()
# 超过指定tokens, 尽可能的保留最多的条目,直到小于max_tokens
# print("current_usage_tokens: ", current_usage_tokens)
# print("max_tokens: ", max_tokens)
if current_usage_tokens > max_tokens:
t = current_usage_tokens
cache_list = session_dict[session_id]
index = 0
while t > max_tokens:
t -= int(cache_list[index]['single_tokens'])
index += 1
session_dict[session_id] = cache_list[index:]
cache_prompt_list = session_dict[session_id]
cache_prompt = get_prompts_by_cache_list(cache_prompt_list)
cache_prompt += chatgpt_res + "\n";
# cache_prompt += chatgpt_res + "\n";
# 添加新条目进入缓存的prompt
cache_prompt_list.append(f'Human: {qq_msg}\nAI: {chatgpt_res}\n')
if len(cache_prompt_list) > 0:
single_record = {
"prompt": f'Human: {qq_msg}\nAI: {chatgpt_res}\n',
"usage_tokens": current_usage_tokens,
"single_tokens": current_usage_tokens - int(cache_prompt_list[-1]['usage_tokens'])
}
else:
single_record = {
"prompt": f'Human: {qq_msg}\nAI: {chatgpt_res}\n',
"usage_tokens": current_usage_tokens,
"single_tokens": current_usage_tokens
}
# print(single_record)
cache_prompt_list.append(single_record)
session_dict[session_id] = cache_prompt_list
# #检测是否存在url,如果存在,则去除url 防止被qq频道过滤
# chatgpt_res = re.sub(r"([\s]+)(http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+)([\s]+)", r"\1\3", chatgpt_res)
chatgpt_res = chatgpt_res.replace(".", " . ")
# 发送qq信息
@@ -78,6 +103,8 @@ class botClient(botpy.Client):
def initBot(chatgpt_inst):
global chatgpt
chatgpt = chatgpt_inst
global max_tokens
max_tokens = int(chatgpt_inst.getConfigs()['total_tokens_limit'])
with open("./configs/config.yaml", 'r', encoding='utf-8') as ymlfile:
cfg = yaml.safe_load(ymlfile)
if cfg['qqbot']['appid'] != '' or cfg['qqbot']['token'] != '':
@@ -88,14 +115,25 @@ def initBot(chatgpt_inst):
else:
raise BaseException("请在config中完善你的appid和token")
async def getChatGPTResponse(prompts_str):
return await chatgpt.chat(prompts_str)
def getPromptsByCacheList(cache_prompt_list, divide=False):
async def get_chatGPT_response(prompts_str):
res = ''
usage = ''
res, usage = await chatgpt.chat(prompts_str)
# 处理结果文本
chatgpt_res = res.strip()
return res, usage
def get_prompts_by_cache_list(cache_prompt_list, divide=False):
prompts = ""
for item in cache_prompt_list:
prompts += str(item)
prompts += str(item['prompt'])
if divide:
prompts += "----------\n"
return prompts
def get_user_usage_tokens(cache_list):
usage_tokens = 0
for item in cache_list:
usage_tokens += int(item['single_tokens'])
return usage_tokens
+5 -7
View File
@@ -5,14 +5,12 @@ import yaml
def main():
# 读取参数
with open('configs/config.yaml', 'r', encoding='utf-8') as f:
cfg = yaml.safe_load(f)
chatGPT_configs = cfg['openai']['chatGPTConfigs']
print(chatGPT_configs)
# with open('configs/config.yaml', 'r', encoding='utf-8') as f:
# cfg = yaml.safe_load(f)
# chatGPT_configs = cfg['openai']['chatGPTConfigs']
# print(chatGPT_configs)
#实例化ChatGPT
chatgpt = ChatGPT(chatGPT_configs=chatGPT_configs)
chatgpt = ChatGPT()
#执行qqBot
qqBot.initBot(chatgpt)