fix: 优化发送流程
feat: 1.支持显示历史记录 2. 支持配置缓存最大token
This commit is contained in:
+3
-1
@@ -7,12 +7,14 @@ openai:
|
||||
top_p: 1
|
||||
frequency_penalty: 0.4
|
||||
presence_penalty: 0.3
|
||||
total_tokens_limit: 1000 #这是默认的单个用户最大缓存token数。超过了这个数字会自动截断最早的prompt记录.
|
||||
|
||||
qqbot:
|
||||
appid:
|
||||
token:
|
||||
|
||||
database:
|
||||
# 未实现
|
||||
database:
|
||||
url:
|
||||
port:
|
||||
user:
|
||||
|
||||
+10
-5
@@ -6,7 +6,7 @@ from util.errors.errors import PromptExceededError
|
||||
inst = None
|
||||
|
||||
class ChatGPT:
|
||||
def __init__(self, chatGPT_configs):
|
||||
def __init__(self):
|
||||
with open("./configs/config.yaml", 'r', encoding='utf-8') as ymlfile:
|
||||
cfg = yaml.safe_load(ymlfile)
|
||||
if cfg['openai']['key'] != '':
|
||||
@@ -14,7 +14,11 @@ class ChatGPT:
|
||||
openai.api_key = cfg['openai']['key']
|
||||
else:
|
||||
print("请先去完善ChatGPT的Key。详情请前往https://beta.openai.com/account/api-keys")
|
||||
|
||||
chatGPT_configs = cfg['openai']['chatGPTConfigs']
|
||||
print(f'加载ChatGPTConfigs: {chatGPT_configs}')
|
||||
self.chatGPT_configs = chatGPT_configs
|
||||
self.openai_configs = cfg['openai']
|
||||
global inst
|
||||
inst = self
|
||||
|
||||
@@ -28,10 +32,11 @@ class ChatGPT:
|
||||
except(openai.error.InvalidRequestError) as e:
|
||||
raise PromptExceededError("OpenAI遇到错误:输入了一个不合法的请求。\n"+str(e))
|
||||
|
||||
return response["choices"][0]["text"]
|
||||
|
||||
def newSession(self):
|
||||
return openai.Session()
|
||||
# print(response['usage'])
|
||||
return response["choices"][0]["text"], response['usage']['total_tokens']
|
||||
|
||||
def getConfigs(self):
|
||||
return self.openai_configs
|
||||
|
||||
def getInst() -> ChatGPT:
|
||||
global inst
|
||||
|
||||
+55
-17
@@ -7,8 +7,8 @@ import re
|
||||
from util.errors.errors import PromptExceededError
|
||||
|
||||
chatgpt = ""
|
||||
|
||||
session_dict = {}
|
||||
max_tokens = 2000
|
||||
|
||||
class botClient(botpy.Client):
|
||||
async def on_at_message_create(self, message: Message):
|
||||
@@ -35,9 +35,12 @@ class botClient(botpy.Client):
|
||||
await message.reply(content=f"{message.member.nick} 的历史记录重置成功")
|
||||
return
|
||||
if qq_msg == "/his":
|
||||
p = getPromptsByCacheList(session_dict[session_id], divide=True)
|
||||
p = get_prompts_by_cache_list(session_dict[session_id], divide=True)
|
||||
await message.reply(content=f"{message.member.nick} 的历史记录如下:\n{p}")
|
||||
return
|
||||
if qq_msg == "/token":
|
||||
await message.reply(content=f"{message.member.nick} 会话的token数: {get_user_usage_tokens(session_dict[session_id])}")
|
||||
return
|
||||
|
||||
if session_id not in session_dict:
|
||||
session_dict[session_id] = []
|
||||
@@ -45,31 +48,53 @@ class botClient(botpy.Client):
|
||||
# 获取缓存
|
||||
cache_prompt = ''
|
||||
cache_prompt_list = session_dict[session_id]
|
||||
cache_prompt = getPromptsByCacheList(cache_prompt_list)
|
||||
cache_prompt = get_prompts_by_cache_list(cache_prompt_list)
|
||||
cache_prompt += "Human: "+ qq_msg + "\nAI: "
|
||||
# 请求chatGPT获得结果
|
||||
try:
|
||||
chatgpt_res = await getChatGPTResponse(cache_prompt)
|
||||
chatgpt_res, current_usage_tokens = await get_chatGPT_response(cache_prompt)
|
||||
except (PromptExceededError) as e:
|
||||
print(e)
|
||||
|
||||
# 超过了4096个tokens,清空cache
|
||||
# 超过4097tokens错误,清空缓存
|
||||
session_dict[session_id] = []
|
||||
cache_prompt_list = []
|
||||
cache_prompt = "Human: "+ qq_msg + "\nAI: "
|
||||
chatgpt_res = await getChatGPTResponse(cache_prompt)
|
||||
|
||||
chatgpt_res, current_usage_tokens = await get_chatGPT_response(cache_prompt)
|
||||
|
||||
# 处理结果文本
|
||||
chatgpt_res = chatgpt_res.strip()
|
||||
# 超过指定tokens, 尽可能的保留最多的条目,直到小于max_tokens
|
||||
# print("current_usage_tokens: ", current_usage_tokens)
|
||||
# print("max_tokens: ", max_tokens)
|
||||
if current_usage_tokens > max_tokens:
|
||||
t = current_usage_tokens
|
||||
cache_list = session_dict[session_id]
|
||||
index = 0
|
||||
while t > max_tokens:
|
||||
t -= int(cache_list[index]['single_tokens'])
|
||||
index += 1
|
||||
session_dict[session_id] = cache_list[index:]
|
||||
cache_prompt_list = session_dict[session_id]
|
||||
cache_prompt = get_prompts_by_cache_list(cache_prompt_list)
|
||||
|
||||
cache_prompt += chatgpt_res + "\n";
|
||||
# cache_prompt += chatgpt_res + "\n";
|
||||
# 添加新条目进入缓存的prompt
|
||||
cache_prompt_list.append(f'Human: {qq_msg}\nAI: {chatgpt_res}\n')
|
||||
if len(cache_prompt_list) > 0:
|
||||
single_record = {
|
||||
"prompt": f'Human: {qq_msg}\nAI: {chatgpt_res}\n',
|
||||
"usage_tokens": current_usage_tokens,
|
||||
"single_tokens": current_usage_tokens - int(cache_prompt_list[-1]['usage_tokens'])
|
||||
}
|
||||
else:
|
||||
single_record = {
|
||||
"prompt": f'Human: {qq_msg}\nAI: {chatgpt_res}\n',
|
||||
"usage_tokens": current_usage_tokens,
|
||||
"single_tokens": current_usage_tokens
|
||||
}
|
||||
# print(single_record)
|
||||
cache_prompt_list.append(single_record)
|
||||
session_dict[session_id] = cache_prompt_list
|
||||
|
||||
# #检测是否存在url,如果存在,则去除url 防止被qq频道过滤
|
||||
# chatgpt_res = re.sub(r"([\s]+)(http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+)([\s]+)", r"\1\3", chatgpt_res)
|
||||
chatgpt_res = chatgpt_res.replace(".", " . ")
|
||||
|
||||
# 发送qq信息
|
||||
@@ -78,6 +103,8 @@ class botClient(botpy.Client):
|
||||
def initBot(chatgpt_inst):
|
||||
global chatgpt
|
||||
chatgpt = chatgpt_inst
|
||||
global max_tokens
|
||||
max_tokens = int(chatgpt_inst.getConfigs()['total_tokens_limit'])
|
||||
with open("./configs/config.yaml", 'r', encoding='utf-8') as ymlfile:
|
||||
cfg = yaml.safe_load(ymlfile)
|
||||
if cfg['qqbot']['appid'] != '' or cfg['qqbot']['token'] != '':
|
||||
@@ -88,14 +115,25 @@ def initBot(chatgpt_inst):
|
||||
else:
|
||||
raise BaseException("请在config中完善你的appid和token")
|
||||
|
||||
async def getChatGPTResponse(prompts_str):
|
||||
return await chatgpt.chat(prompts_str)
|
||||
|
||||
def getPromptsByCacheList(cache_prompt_list, divide=False):
|
||||
async def get_chatGPT_response(prompts_str):
|
||||
res = ''
|
||||
usage = ''
|
||||
res, usage = await chatgpt.chat(prompts_str)
|
||||
# 处理结果文本
|
||||
chatgpt_res = res.strip()
|
||||
return res, usage
|
||||
|
||||
def get_prompts_by_cache_list(cache_prompt_list, divide=False):
|
||||
prompts = ""
|
||||
for item in cache_prompt_list:
|
||||
prompts += str(item)
|
||||
prompts += str(item['prompt'])
|
||||
if divide:
|
||||
prompts += "----------\n"
|
||||
return prompts
|
||||
|
||||
|
||||
def get_user_usage_tokens(cache_list):
|
||||
usage_tokens = 0
|
||||
for item in cache_list:
|
||||
usage_tokens += int(item['single_tokens'])
|
||||
return usage_tokens
|
||||
@@ -5,14 +5,12 @@ import yaml
|
||||
|
||||
def main():
|
||||
# 读取参数
|
||||
with open('configs/config.yaml', 'r', encoding='utf-8') as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
chatGPT_configs = cfg['openai']['chatGPTConfigs']
|
||||
print(chatGPT_configs)
|
||||
|
||||
|
||||
# with open('configs/config.yaml', 'r', encoding='utf-8') as f:
|
||||
# cfg = yaml.safe_load(f)
|
||||
# chatGPT_configs = cfg['openai']['chatGPTConfigs']
|
||||
# print(chatGPT_configs)
|
||||
#实例化ChatGPT
|
||||
chatgpt = ChatGPT(chatGPT_configs=chatGPT_configs)
|
||||
chatgpt = ChatGPT()
|
||||
#执行qqBot
|
||||
qqBot.initBot(chatgpt)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user