Merge pull request #23 from Soulter/ChatGPTAI-perf
perf: 重构ChatGPT API,更稳定。
This commit is contained in:
+22
-40
@@ -29,55 +29,37 @@ class ChatGPT:
|
||||
self.chatGPT_configs = chatGPT_configs
|
||||
self.openai_configs = cfg
|
||||
|
||||
def chat(self, prompt, image_mode = False):
|
||||
def chat(self, req, image_mode = False):
|
||||
# ChatGPT API 2023/3/2
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
# messages = [{"role": "user", "content": prompt}]
|
||||
try:
|
||||
if not image_mode:
|
||||
|
||||
response = openai.ChatCompletion.create(
|
||||
messages=messages,
|
||||
**self.chatGPT_configs
|
||||
)
|
||||
else:
|
||||
response = openai.Image.create(
|
||||
prompt=prompt,
|
||||
n=1,
|
||||
size="512x512",
|
||||
)
|
||||
response = openai.ChatCompletion.create(
|
||||
messages=req,
|
||||
**self.chatGPT_configs
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(e) or 'No API key provided.' in str(e):
|
||||
print("[System] 当前Key已超额,正在切换")
|
||||
if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e):
|
||||
print("[System] 当前Key已超额或者不正常,正在切换")
|
||||
self.key_stat[openai.api_key]['exceed'] = True
|
||||
self.save_key_record()
|
||||
|
||||
response, is_switched = self.handle_switch_key(prompt)
|
||||
response, is_switched = self.handle_switch_key(req)
|
||||
if not is_switched:
|
||||
# 所有Key都超额
|
||||
# 所有Key都超额或不正常
|
||||
raise e
|
||||
else:
|
||||
if not image_mode:
|
||||
response = openai.ChatCompletion.create(
|
||||
messages=messages,
|
||||
**self.chatGPT_configs
|
||||
)
|
||||
else:
|
||||
response = openai.Image.create(
|
||||
prompt=prompt,
|
||||
n=1,
|
||||
size="512x512",
|
||||
)
|
||||
if not image_mode:
|
||||
self.key_stat[openai.api_key]['used'] += response['usage']['total_tokens']
|
||||
self.save_key_record()
|
||||
print("[ChatGPT] "+str(response["choices"][0]["message"]["content"]))
|
||||
return str(response["choices"][0]["message"]["content"]).strip(), response['usage']['total_tokens']
|
||||
else:
|
||||
return response['data'][0]['url']
|
||||
response = openai.ChatCompletion.create(
|
||||
messages=req,
|
||||
**self.chatGPT_configs
|
||||
)
|
||||
self.key_stat[openai.api_key]['used'] += response['usage']['total_tokens']
|
||||
self.save_key_record()
|
||||
print("[ChatGPT] "+str(response["choices"][0]["message"]["content"]))
|
||||
return str(response["choices"][0]["message"]["content"]).strip(), response['usage']['total_tokens']
|
||||
|
||||
def handle_switch_key(self, prompt):
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
def handle_switch_key(self, req):
|
||||
# messages = [{"role": "user", "content": prompt}]
|
||||
while True:
|
||||
is_all_exceed = True
|
||||
for key in self.key_stat:
|
||||
@@ -85,10 +67,10 @@ class ChatGPT:
|
||||
is_all_exceed = False
|
||||
openai.api_key = key
|
||||
print(f"[System] 切换到Key: {key}, 已使用token: {self.key_stat[key]['used']}")
|
||||
if prompt != '':
|
||||
if len(req) > 0:
|
||||
try:
|
||||
response = openai.ChatCompletion.create(
|
||||
messages=messages,
|
||||
messages=req,
|
||||
**self.chatGPT_configs
|
||||
)
|
||||
return response, True
|
||||
|
||||
+72
-28
@@ -308,16 +308,23 @@ def run_bot(appid, token):
|
||||
'''
|
||||
得到OpenAI官方API的回复
|
||||
'''
|
||||
def get_chatGPT_response(prompts_str, image_mode=False):
|
||||
def get_chatGPT_response(context, request, image_mode=False):
|
||||
res = ''
|
||||
usage = ''
|
||||
|
||||
req_list = []
|
||||
for i in context:
|
||||
req_list.append(i['user'])
|
||||
req_list.append(i['AI'])
|
||||
req_list.append(request['user'])
|
||||
|
||||
if not image_mode:
|
||||
res, usage = chatgpt.chat(prompts_str)
|
||||
res, usage = chatgpt.chat(req_list)
|
||||
# 处理结果文本
|
||||
chatgpt_res = res.strip()
|
||||
return res, usage
|
||||
else:
|
||||
res = chatgpt.chat(prompts_str, image_mode = True)
|
||||
res = chatgpt.chat(req_list, image_mode = True)
|
||||
return res
|
||||
|
||||
'''
|
||||
@@ -369,10 +376,13 @@ def get_prompts_by_cache_list(cache_data_list, divide=False, paging=False, size=
|
||||
page_end = len(cache_data_list)
|
||||
cache_data_list = cache_data_list[page_begin:page_end]
|
||||
for item in cache_data_list:
|
||||
prompts += str(item['prompt'])
|
||||
prompts += str(item['user']['role']) + ":\n" + str(item['user']['content']) + "\n"
|
||||
prompts += str(item['AI']['role']) + ":\n" + str(item['AI']['content']) + "\n"
|
||||
|
||||
if divide:
|
||||
prompts += "----------\n"
|
||||
return prompts
|
||||
|
||||
|
||||
def get_user_usage_tokens(cache_list):
|
||||
usage_tokens = 0
|
||||
@@ -497,26 +507,44 @@ def oper_msg(message, at=False, loop=None):
|
||||
if provider == OPENAI_OFFICIAL:
|
||||
|
||||
# 获取缓存
|
||||
cache_prompt = ''
|
||||
# cache_prompt = ''
|
||||
cache_data_list = session_dict[session_id]
|
||||
cache_prompt = get_prompts_by_cache_list(cache_data_list)
|
||||
cache_prompt += "\nHuman: "+ qq_msg + "\nAI: "
|
||||
# cache_prompt = get_prompts_by_cache_list(cache_data_list)
|
||||
# cache_prompt += "\nHuman: "+ qq_msg + "\nAI: "
|
||||
|
||||
# 创建一个新的Record
|
||||
|
||||
record_obj = {
|
||||
"user": {
|
||||
"role": "user",
|
||||
"content": qq_msg,
|
||||
},
|
||||
"AI": {},
|
||||
'usage_tokens': 0,
|
||||
'level': 'normal',
|
||||
}
|
||||
|
||||
if command_type == 1:
|
||||
record_obj['user'] = 'system'
|
||||
|
||||
print("[Debug] "+ str(cache_data_list))
|
||||
|
||||
# 请求chatGPT获得结果
|
||||
try:
|
||||
chatgpt_res, current_usage_tokens = get_chatGPT_response(prompts_str=cache_prompt)
|
||||
chatgpt_res, current_usage_tokens = get_chatGPT_response(context=cache_data_list, request=record_obj)
|
||||
except (BaseException) as e:
|
||||
print("[System-Err] OpenAI API错误。原因如下:\n"+str(e))
|
||||
if 'maximum context length' in str(e):
|
||||
print("token超限, 清空对应缓存")
|
||||
session_dict[session_id] = []
|
||||
cache_data_list = []
|
||||
cache_prompt = "Human: "+ qq_msg + "\nAI: "
|
||||
chatgpt_res, current_usage_tokens = get_chatGPT_response(prompts_str=cache_prompt)
|
||||
chatgpt_res, current_usage_tokens = get_chatGPT_response(context=cache_data_list, request=record_obj)
|
||||
elif 'exceeded' in str(e):
|
||||
send_qq_msg(message, f"OpenAI API错误。原因:\n{str(e)} \n超额了。可自己搭建一个机器人(Github仓库:QQChannelChatGPT)")
|
||||
else:
|
||||
send_qq_msg(message, f"OpenAI API错误。原因如下:\n{str(e)} \n前往官方频道反馈~")
|
||||
f_res = re.sub(r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '[被隐藏的链接]', str(e), flags=re.MULTILINE)
|
||||
f_res = f_res.replace(".", "·")
|
||||
send_qq_msg(message, f"OpenAI API错误。原因如下:\n{f_res} \n前往官方频道反馈~")
|
||||
return
|
||||
|
||||
# 超过指定tokens, 尽可能的保留最多的条目,直到小于max_tokens
|
||||
@@ -533,28 +561,44 @@ def oper_msg(message, at=False, loop=None):
|
||||
index += 1
|
||||
# 删除完后更新相关字段
|
||||
session_dict[session_id] = cache_data_list
|
||||
cache_prompt = get_prompts_by_cache_list(cache_data_list)
|
||||
# cache_prompt = get_prompts_by_cache_list(cache_data_list)
|
||||
|
||||
# 添加新条目进入缓存的prompt
|
||||
# 人格置顶
|
||||
if command_type == 1:
|
||||
level = 'max'
|
||||
else:
|
||||
level = 'normal'
|
||||
if len(cache_data_list) > 0:
|
||||
single_record = {
|
||||
"prompt": f'Human: {qq_msg}\nAI: {chatgpt_res}\n',
|
||||
"usage_tokens": current_usage_tokens,
|
||||
"single_tokens": current_usage_tokens - int(cache_data_list[-1]['usage_tokens']),
|
||||
"level": level
|
||||
}
|
||||
|
||||
# 添加新条目进入缓存的prompt
|
||||
record_obj['AI'] = {
|
||||
'role': 'assistant',
|
||||
'content': chatgpt_res,
|
||||
}
|
||||
record_obj['usage_tokens'] = current_usage_tokens
|
||||
if len(cache_data_list) > 0:
|
||||
record_obj['single_tokens'] = current_usage_tokens - int(cache_data_list[-1]['usage_tokens'])
|
||||
else:
|
||||
single_record = {
|
||||
"prompt": f'Human: {qq_msg}\nAI: {chatgpt_res}\n',
|
||||
"usage_tokens": current_usage_tokens,
|
||||
"single_tokens": current_usage_tokens,
|
||||
"level": level
|
||||
}
|
||||
cache_data_list.append(single_record)
|
||||
record_obj['single_tokens'] = current_usage_tokens
|
||||
record_obj['level'] = level
|
||||
|
||||
cache_data_list.append(record_obj)
|
||||
# if len(cache_data_list) > 0:
|
||||
# single_record = {
|
||||
# 'role': 'assistant',
|
||||
# "content": chatgpt_res,
|
||||
# "usage_tokens": current_usage_tokens,
|
||||
# "single_tokens": current_usage_tokens - int(cache_data_list[-1]['usage_tokens']),
|
||||
# "level": level
|
||||
# }
|
||||
# else:
|
||||
# single_record = {
|
||||
# 'role': 'assistant',
|
||||
# "prompt": f'Human: {qq_msg}\nAI: {chatgpt_res}\n',
|
||||
# "usage_tokens": current_usage_tokens,
|
||||
# "single_tokens": current_usage_tokens,
|
||||
# "level": level
|
||||
# }
|
||||
# cache_data_list.append(single_record)
|
||||
session_dict[session_id] = cache_data_list
|
||||
|
||||
elif provider == REV_CHATGPT:
|
||||
@@ -587,7 +631,7 @@ def oper_msg(message, at=False, loop=None):
|
||||
# send(message, f"QQ频道API错误:{str(e)}\n下面是格式化后的回答:\n{f_res}")
|
||||
except BaseException as e:
|
||||
# 如果还是不行则过滤url
|
||||
f_res = re.sub(r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '', f_res, flags=re.MULTILINE)
|
||||
f_res = re.sub(r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '[被隐藏的链接]', str(e), flags=re.MULTILINE)
|
||||
f_res = f_res.replace(".", "·")
|
||||
send_qq_msg(message, ''+f_res)
|
||||
# send(message, f"QQ频道API错误:{str(e)}\n下面是格式化后的回答:\n{f_res}")
|
||||
|
||||
Reference in New Issue
Block a user