refactor: 重构部分代码-officialapi
This commit is contained in:
+11
-193
@@ -12,7 +12,6 @@ import json
|
||||
import threading
|
||||
import asyncio
|
||||
import time
|
||||
from cores.database.conn import dbConn
|
||||
import requests
|
||||
import util.unfit_words as uw
|
||||
import os
|
||||
@@ -21,7 +20,6 @@ from cores.qqbot.personality import personalities
|
||||
from addons.baidu_aip_judge import BaiduJudge
|
||||
|
||||
|
||||
history_dump_interval = 10
|
||||
# QQBotClient实例
|
||||
client = ''
|
||||
# ChatGPT实例
|
||||
@@ -127,30 +125,6 @@ def toggle_count(at: bool, message):
|
||||
except BaseException:
|
||||
pass
|
||||
|
||||
# 转储历史记录的定时器~ Soulter
|
||||
def dump_history():
|
||||
time.sleep(10)
|
||||
global session_dict, history_dump_interval
|
||||
db = dbConn()
|
||||
while True:
|
||||
try:
|
||||
# print("转储历史记录...")
|
||||
for key in session_dict:
|
||||
# print("TEST: "+str(db.get_session(key)))
|
||||
data = session_dict[key]
|
||||
data_json = {
|
||||
'data': data
|
||||
}
|
||||
if db.check_session(key):
|
||||
db.update_session(key, json.dumps(data_json))
|
||||
else:
|
||||
db.insert_session(key, json.dumps(data_json))
|
||||
# print("转储历史记录完毕")
|
||||
except BaseException as e:
|
||||
print(e)
|
||||
# 每隔10分钟转储一次
|
||||
time.sleep(10*history_dump_interval)
|
||||
|
||||
# 上传统计信息并检查更新
|
||||
def upload():
|
||||
global object_id
|
||||
@@ -192,7 +166,7 @@ def upload():
|
||||
'''
|
||||
def initBot(cfg, prov):
|
||||
global chatgpt, provider, rev_chatgpt, baidu_judge, rev_ernie, rev_edgegpt
|
||||
global reply_prefix, now_personality, gpt_config, config, uniqueSession, history_dump_interval, frequency_count, frequency_time,announcement, direct_message_mode, version
|
||||
global reply_prefix, now_personality, gpt_config, config, uniqueSession, frequency_count, frequency_time,announcement, direct_message_mode, version
|
||||
|
||||
provider = prov
|
||||
config = cfg
|
||||
@@ -220,42 +194,8 @@ def initBot(cfg, prov):
|
||||
else:
|
||||
input("[System-err] 请退出本程序, 然后在配置文件中填写rev_ChatGPT相关配置")
|
||||
elif prov == OPENAI_OFFICIAL:
|
||||
from cores.openai.core import ChatGPT
|
||||
chatgpt = ChatGPT(cfg['openai'])
|
||||
global max_tokens
|
||||
max_tokens = int(chatgpt.getConfigs()['total_tokens_limit'])
|
||||
|
||||
# 读取历史记录 Soulter
|
||||
try:
|
||||
db1 = dbConn()
|
||||
for session in db1.get_all_session():
|
||||
session_dict[session[0]] = json.loads(session[1])['data']
|
||||
print("[System] 历史记录读取成功喵")
|
||||
except BaseException as e:
|
||||
print("[System] 历史记录读取失败: " + str(e))
|
||||
|
||||
# 读统计信息
|
||||
global stat_file
|
||||
if not os.path.exists(abs_path+"configs/stat"):
|
||||
with open(abs_path+"configs/stat", 'w', encoding='utf-8') as f:
|
||||
json.dump({}, f)
|
||||
stat_file = open(abs_path+"configs/stat", 'r', encoding='utf-8')
|
||||
global count
|
||||
res = stat_file.read()
|
||||
if res == '':
|
||||
count = {}
|
||||
else:
|
||||
try:
|
||||
count = json.loads(res)
|
||||
except BaseException:
|
||||
pass
|
||||
# 创建转储定时器线程
|
||||
threading.Thread(target=dump_history, daemon=True).start()
|
||||
|
||||
# 得到GPT配置信息
|
||||
if 'openai' in cfg and 'chatGPTConfigs' in cfg['openai']:
|
||||
gpt_config = cfg['openai']['chatGPTConfigs']
|
||||
|
||||
from model.provider.provider_openai_official import ProviderOpenAIOfficial
|
||||
chatgpt = ProviderOpenAIOfficial(cfg['openai'])
|
||||
if OPENAI_OFFICIAL in reply_prefix_config:
|
||||
reply_prefix = reply_prefix_config[OPENAI_OFFICIAL]
|
||||
elif prov == REV_ERNIE:
|
||||
@@ -325,8 +265,7 @@ def initBot(cfg, prov):
|
||||
uniqueSession = False
|
||||
print("[System] 独立会话: " + str(uniqueSession))
|
||||
if 'dump_history_interval' in cfg:
|
||||
history_dump_interval = int(cfg['dump_history_interval'])
|
||||
print("[System] 历史记录转储时间周期: " + str(history_dump_interval) + "分钟")
|
||||
print("[System] 历史记录转储时间周期: " + cfg['dump_history_interval'] + "分钟")
|
||||
except BaseException:
|
||||
print("[System-Error] 读取uniqueSessionMode/version/dump_history_interval配置文件失败, 使用默认值。")
|
||||
|
||||
@@ -351,29 +290,6 @@ def run_bot(appid, token):
|
||||
client = botClient(intents=intents)
|
||||
client.run(appid=appid, token=token)
|
||||
|
||||
'''
|
||||
得到OpenAI官方API的回复
|
||||
'''
|
||||
def get_chatGPT_response(context, request, image_mode=False, img_num=1, img_size="1024*1024"):
|
||||
res = ''
|
||||
usage = ''
|
||||
|
||||
req_list = []
|
||||
for i in context:
|
||||
req_list.append(i['user'])
|
||||
req_list.append(i['AI'])
|
||||
req_list.append(request['user'])
|
||||
|
||||
if not image_mode:
|
||||
# print("[Debug] "+ str(req_list))
|
||||
res, usage = chatgpt.chat(req_list)
|
||||
# 处理结果文本
|
||||
chatgpt_res = res.strip()
|
||||
return res, usage
|
||||
else:
|
||||
res = chatgpt.chat(req_list, image_mode=True, img_num=img_num, img_size=img_size)
|
||||
return res
|
||||
|
||||
'''
|
||||
负载均衡,得到逆向ChatGPT回复
|
||||
'''
|
||||
@@ -572,64 +488,20 @@ def oper_msg(message, at=False, msg_ref = None):
|
||||
send_qq_msg(message, f"你的提问得到的回复未通过【百度AI内容审核】服务,不予回复。\n\n{msg}", msg_ref=msg_ref)
|
||||
return
|
||||
|
||||
# 会话机制
|
||||
if session_id not in session_dict:
|
||||
session_dict[session_id] = []
|
||||
|
||||
fjson = {}
|
||||
try:
|
||||
f = open(abs_path+"configs/session", "r", encoding="utf-8")
|
||||
fjson = json.loads(f.read())
|
||||
f.close()
|
||||
except:
|
||||
pass
|
||||
finally:
|
||||
fjson[session_id] = 'true'
|
||||
f = open(abs_path+"configs/session", "w", encoding="utf-8")
|
||||
f.write(json.dumps(fjson))
|
||||
f.flush()
|
||||
f.close()
|
||||
|
||||
chatgpt_res = "[Error] 占位符"
|
||||
|
||||
if provider == OPENAI_OFFICIAL:
|
||||
|
||||
# 获取缓存
|
||||
# cache_prompt = ''
|
||||
cache_data_list = session_dict[session_id]
|
||||
# cache_prompt = get_prompts_by_cache_list(cache_data_list)
|
||||
# cache_prompt += "\nHuman: "+ qq_msg + "\nAI: "
|
||||
|
||||
# 创建一个新的Record
|
||||
|
||||
record_obj = {
|
||||
"user": {
|
||||
"role": "user",
|
||||
"content": qq_msg,
|
||||
},
|
||||
"AI": {},
|
||||
'usage_tokens': 0,
|
||||
}
|
||||
record_obj_img = {
|
||||
"user": {
|
||||
"role": "user",
|
||||
"content": qq_msg[1:], # 去掉第一个字符
|
||||
},
|
||||
"AI": {},
|
||||
'usage_tokens': 0,
|
||||
}
|
||||
# ChatGPT API 回复倾向(人格)
|
||||
if command_type == 1:
|
||||
record_obj["user"]["role"] = "system"
|
||||
record_obj_img["user"]["role"] = "system"
|
||||
# print("[Debug] "+ str(cache_data_list))
|
||||
# print("qq_msg", qq_msg)
|
||||
# print("qq_msg.strip", qq_msg.strip())
|
||||
# if command_type == 1:
|
||||
# record_obj["user"]["role"] = "system"
|
||||
# record_obj_img["user"]["role"] = "system"
|
||||
|
||||
if qq_msg[0] == '画':
|
||||
print("[Debug] 画图模式")
|
||||
# 请求chatGPT获得结果
|
||||
try:
|
||||
chatgpt_res = get_chatGPT_response(context=[], request=record_obj_img, image_mode=True, img_num=1, img_size="1024x1024")
|
||||
chatgpt_res = chatgpt.image_chat(qq_msg)
|
||||
# print(chatgpt_res)
|
||||
for i in range(len(chatgpt_res)):
|
||||
send_qq_msg(message, chatgpt_res[i], image_mode=True)
|
||||
@@ -649,16 +521,10 @@ def oper_msg(message, at=False, msg_ref = None):
|
||||
else:
|
||||
# 请求chatGPT获得结果
|
||||
try:
|
||||
chatgpt_res, current_usage_tokens = get_chatGPT_response(context=cache_data_list, request=record_obj)
|
||||
chatgpt_res = reply_prefix + chatgpt_res
|
||||
chatgpt_res = reply_prefix + chatgpt.text_chat(qq_msg, session_id)
|
||||
except (BaseException) as e:
|
||||
print("[System-Err] OpenAI API错误。原因如下:\n"+str(e))
|
||||
if 'maximum context length' in str(e):
|
||||
print("token超限, 清空对应缓存")
|
||||
session_dict[session_id] = []
|
||||
cache_data_list = []
|
||||
chatgpt_res, current_usage_tokens = get_chatGPT_response(context=cache_data_list, request=record_obj)
|
||||
elif 'exceeded' in str(e):
|
||||
if 'exceeded' in str(e):
|
||||
send_qq_msg(message, f"OpenAI API错误。原因:\n{str(e)} \n超额了。可自己搭建一个机器人(Github仓库:QQChannelChatGPT)")
|
||||
return
|
||||
else:
|
||||
@@ -666,54 +532,6 @@ def oper_msg(message, at=False, msg_ref = None):
|
||||
f_res = f_res.replace(".", "·")
|
||||
send_qq_msg(message, f"OpenAI API错误。原因如下:\n{f_res} \n前往官方频道反馈~")
|
||||
return
|
||||
|
||||
# 超过指定tokens, 尽可能的保留最多的条目,直到小于max_tokens
|
||||
if current_usage_tokens > max_tokens:
|
||||
t = current_usage_tokens
|
||||
index = 0
|
||||
while t > max_tokens:
|
||||
if index >= len(cache_data_list):
|
||||
break
|
||||
# 保留倾向(人格)信息
|
||||
if 'user' in cache_data_list[index] and cache_data_list[index]['user']['role'] != 'system':
|
||||
t -= int(cache_data_list[index]['single_tokens'])
|
||||
del cache_data_list[index]
|
||||
else:
|
||||
index += 1
|
||||
# 删除完后更新相关字段
|
||||
session_dict[session_id] = cache_data_list
|
||||
# cache_prompt = get_prompts_by_cache_list(cache_data_list)
|
||||
|
||||
# 添加新条目进入缓存的prompt
|
||||
record_obj['AI'] = {
|
||||
'role': 'assistant',
|
||||
'content': chatgpt_res,
|
||||
}
|
||||
record_obj['usage_tokens'] = current_usage_tokens
|
||||
if len(cache_data_list) > 0:
|
||||
record_obj['single_tokens'] = current_usage_tokens - int(cache_data_list[-1]['usage_tokens'])
|
||||
else:
|
||||
record_obj['single_tokens'] = current_usage_tokens
|
||||
|
||||
cache_data_list.append(record_obj)
|
||||
# if len(cache_data_list) > 0:
|
||||
# single_record = {
|
||||
# 'role': 'assistant',
|
||||
# "content": chatgpt_res,
|
||||
# "usage_tokens": current_usage_tokens,
|
||||
# "single_tokens": current_usage_tokens - int(cache_data_list[-1]['usage_tokens']),
|
||||
# "level": level
|
||||
# }
|
||||
# else:
|
||||
# single_record = {
|
||||
# 'role': 'assistant',
|
||||
# "prompt": f'Human: {qq_msg}\nAI: {chatgpt_res}\n',
|
||||
# "usage_tokens": current_usage_tokens,
|
||||
# "single_tokens": current_usage_tokens,
|
||||
# "level": level
|
||||
# }
|
||||
# cache_data_list.append(single_record)
|
||||
session_dict[session_id] = cache_data_list
|
||||
|
||||
elif provider == REV_CHATGPT:
|
||||
try:
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
import abc
|
||||
|
||||
class Command:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def check_command(self, message):
|
||||
pass
|
||||
|
||||
def reset(self):
|
||||
return False
|
||||
|
||||
def set(self):
|
||||
return False
|
||||
|
||||
def unset(self):
|
||||
return False
|
||||
|
||||
def key(self):
|
||||
return False
|
||||
|
||||
def help(self):
|
||||
return False
|
||||
|
||||
def status(self):
|
||||
return False
|
||||
|
||||
def token(self):
|
||||
return False
|
||||
|
||||
def his(self):
|
||||
return False
|
||||
|
||||
def draw(self):
|
||||
return False
|
||||
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
import abc
|
||||
|
||||
class Provider:
|
||||
def __init__(self, cfg):
|
||||
pass
|
||||
|
||||
def text_chat(self, prompt):
|
||||
pass
|
||||
|
||||
def image_chat(self, prompt):
|
||||
pass
|
||||
|
||||
def memory(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def forget(self) -> bool:
|
||||
pass
|
||||
@@ -0,0 +1,331 @@
|
||||
import openai
|
||||
import yaml
|
||||
from util.errors.errors import PromptExceededError
|
||||
import json
|
||||
import time
|
||||
import os
|
||||
import sys
|
||||
from cores.database.conn import dbConn
|
||||
from model.provider.provider import Provider
|
||||
import threading
|
||||
|
||||
abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/'
|
||||
key_record_path = abs_path+'chatgpt_key_record'
|
||||
|
||||
class ProviderOpenAIOfficial(Provider):
|
||||
def __init__(self, cfg):
|
||||
self.key_list = []
|
||||
if 'api_base' in cfg and cfg['api_base'] != 'none' and cfg['api_base'] != '':
|
||||
openai.api_base = cfg['api_base']
|
||||
if cfg['key'] != '' and cfg['key'] != None:
|
||||
print("[System] 读取ChatGPT Key成功")
|
||||
self.key_list = cfg['key']
|
||||
else:
|
||||
input("[System] 请先去完善ChatGPT的Key。详情请前往https://beta.openai.com/account/api-keys")
|
||||
|
||||
# init key record
|
||||
self.init_key_record()
|
||||
|
||||
self.chatGPT_configs = cfg['chatGPTConfigs']
|
||||
print(f'[System] 加载ChatGPTConfigs: {self.chatGPT_configs}')
|
||||
self.openai_configs = cfg
|
||||
# 会话缓存
|
||||
self.session_dict = {}
|
||||
# 最大缓存token
|
||||
self.max_tokens = cfg['total_tokens_limit']
|
||||
# 历史记录持久化间隔时间
|
||||
self.history_dump_interval = 20
|
||||
|
||||
# 读取历史记录
|
||||
try:
|
||||
db1 = dbConn()
|
||||
for session in db1.get_all_session():
|
||||
self.session_dict[session[0]] = json.loads(session[1])['data']
|
||||
print("[System] 历史记录读取成功喵")
|
||||
except BaseException as e:
|
||||
print("[System] 历史记录读取失败: " + str(e))
|
||||
|
||||
# 读取统计信息
|
||||
if not os.path.exists(abs_path+"configs/stat"):
|
||||
with open(abs_path+"configs/stat", 'w', encoding='utf-8') as f:
|
||||
json.dump({}, f)
|
||||
self.stat_file = open(abs_path+"configs/stat", 'r', encoding='utf-8')
|
||||
global count
|
||||
res = self.stat_file.read()
|
||||
if res == '':
|
||||
count = {}
|
||||
else:
|
||||
try:
|
||||
count = json.loads(res)
|
||||
except BaseException:
|
||||
pass
|
||||
|
||||
# 创建转储定时器线程
|
||||
threading.Thread(target=self.dump_history, daemon=True).start()
|
||||
|
||||
# 转储历史记录的定时器~ Soulter
|
||||
def dump_history(self):
|
||||
time.sleep(10)
|
||||
db = dbConn()
|
||||
while True:
|
||||
try:
|
||||
# print("转储历史记录...")
|
||||
for key in self.session_dict:
|
||||
# print("TEST: "+str(db.get_session(key)))
|
||||
data = self.session_dict[key]
|
||||
data_json = {
|
||||
'data': data
|
||||
}
|
||||
if db.check_session(key):
|
||||
db.update_session(key, json.dumps(data_json))
|
||||
else:
|
||||
db.insert_session(key, json.dumps(data_json))
|
||||
# print("转储历史记录完毕")
|
||||
except BaseException as e:
|
||||
print(e)
|
||||
# 每隔10分钟转储一次
|
||||
time.sleep(10*self.history_dump_interval)
|
||||
|
||||
def text_chat(self, prompt, session_id):
|
||||
# 会话机制
|
||||
if session_id not in self.session_dict:
|
||||
self.session_dict[session_id] = []
|
||||
|
||||
fjson = {}
|
||||
try:
|
||||
f = open(abs_path+"configs/session", "r", encoding="utf-8")
|
||||
fjson = json.loads(f.read())
|
||||
f.close()
|
||||
except:
|
||||
pass
|
||||
finally:
|
||||
fjson[session_id] = 'true'
|
||||
f = open(abs_path+"configs/session", "w", encoding="utf-8")
|
||||
f.write(json.dumps(fjson))
|
||||
f.flush()
|
||||
f.close()
|
||||
|
||||
cache_data_list, new_record, req = self.wrap(prompt, session_id)
|
||||
retry = 0
|
||||
response = None
|
||||
while retry < 5:
|
||||
try:
|
||||
response = openai.ChatCompletion.create(
|
||||
messages=req,
|
||||
**self.chatGPT_configs
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
print(e)
|
||||
if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e):
|
||||
print("[System] 当前Key已超额或者不正常,正在切换")
|
||||
self.key_stat[openai.api_key]['exceed'] = True
|
||||
self.save_key_record()
|
||||
|
||||
response, is_switched = self.handle_switch_key(req)
|
||||
if not is_switched:
|
||||
# 所有Key都超额或不正常
|
||||
raise e
|
||||
else:
|
||||
break
|
||||
if 'maximum context length' in str(e):
|
||||
print("token超限, 清空对应缓存")
|
||||
self.session_dict[session_id] = []
|
||||
cache_data_list, new_record, req = self.wrap(prompt, session_id)
|
||||
retry+=1
|
||||
if retry >= 5:
|
||||
raise BaseException("连接超时")
|
||||
|
||||
self.key_stat[openai.api_key]['used'] += response['usage']['total_tokens']
|
||||
self.save_key_record()
|
||||
print("[ChatGPT] "+str(response["choices"][0]["message"]["content"]))
|
||||
chatgpt_res = str(response["choices"][0]["message"]["content"]).strip()
|
||||
current_usage_tokens = response['usage']['total_tokens']
|
||||
|
||||
# 超过指定tokens, 尽可能的保留最多的条目,直到小于max_tokens
|
||||
if current_usage_tokens > self.max_tokens:
|
||||
t = current_usage_tokens
|
||||
index = 0
|
||||
while t > self.max_tokens:
|
||||
if index >= len(cache_data_list):
|
||||
break
|
||||
# 保留倾向(人格)信息
|
||||
if 'user' in cache_data_list[index] and cache_data_list[index]['user']['role'] != 'system':
|
||||
t -= int(cache_data_list[index]['single_tokens'])
|
||||
del cache_data_list[index]
|
||||
else:
|
||||
index += 1
|
||||
# 删除完后更新相关字段
|
||||
self.session_dict[session_id] = cache_data_list
|
||||
# cache_prompt = get_prompts_by_cache_list(cache_data_list)
|
||||
|
||||
# 添加新条目进入缓存的prompt
|
||||
new_record['AI'] = {
|
||||
'role': 'assistant',
|
||||
'content': chatgpt_res,
|
||||
}
|
||||
new_record['usage_tokens'] = current_usage_tokens
|
||||
if len(cache_data_list) > 0:
|
||||
new_record['single_tokens'] = current_usage_tokens - int(cache_data_list[-1]['usage_tokens'])
|
||||
else:
|
||||
new_record['single_tokens'] = current_usage_tokens
|
||||
cache_data_list.append(new_record)
|
||||
|
||||
self.session_dict[session_id] = cache_data_list
|
||||
|
||||
return chatgpt_res
|
||||
|
||||
def image_chat(self, prompt, img_num = 1, img_size = "1024x1024"):
|
||||
retry = 0
|
||||
image_url = ''
|
||||
while retry < 5:
|
||||
try:
|
||||
# print("test1")
|
||||
response = openai.Image.create(
|
||||
prompt=prompt,
|
||||
n=img_num,
|
||||
size=img_size
|
||||
)
|
||||
# print("test2")
|
||||
image_url = []
|
||||
for i in range(img_num):
|
||||
image_url.append(response['data'][i]['url'])
|
||||
print(image_url)
|
||||
break
|
||||
except Exception as e:
|
||||
print(e)
|
||||
if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(
|
||||
e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e):
|
||||
print("[System] 当前Key已超额或者不正常,正在切换")
|
||||
self.key_stat[openai.api_key]['exceed'] = True
|
||||
self.save_key_record()
|
||||
|
||||
response, is_switched = self.handle_switch_key(req)
|
||||
if not is_switched:
|
||||
# 所有Key都超额或不正常
|
||||
raise e
|
||||
else:
|
||||
break
|
||||
retry += 1
|
||||
if retry >= 5:
|
||||
raise BaseException("连接超时")
|
||||
|
||||
return image_url
|
||||
|
||||
# 包装信息
|
||||
def wrap(self, prompt, session_id):
|
||||
# 获得缓存信息
|
||||
context = self.session_dict[session_id]
|
||||
new_record = {
|
||||
"user": {
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
},
|
||||
"AI": {},
|
||||
'usage_tokens': 0,
|
||||
}
|
||||
req_list = []
|
||||
for i in context:
|
||||
req_list.append(i['user'])
|
||||
req_list.append(i['AI'])
|
||||
req_list.append(new_record['user'])
|
||||
return context, new_record, req_list
|
||||
|
||||
|
||||
|
||||
def handle_switch_key(self, req):
|
||||
# messages = [{"role": "user", "content": prompt}]
|
||||
while True:
|
||||
is_all_exceed = True
|
||||
for key in self.key_stat:
|
||||
if not self.key_stat[key]['exceed']:
|
||||
is_all_exceed = False
|
||||
openai.api_key = key
|
||||
print(f"[System] 切换到Key: {key}, 已使用token: {self.key_stat[key]['used']}")
|
||||
if len(req) > 0:
|
||||
try:
|
||||
response = openai.ChatCompletion.create(
|
||||
messages=req,
|
||||
**self.chatGPT_configs
|
||||
)
|
||||
return response, True
|
||||
except Exception as e:
|
||||
print(e)
|
||||
if 'You exceeded' in str(e):
|
||||
print("[System] 当前Key已超额,正在切换")
|
||||
self.key_stat[openai.api_key]['exceed'] = True
|
||||
self.save_key_record()
|
||||
time.sleep(1)
|
||||
continue
|
||||
else:
|
||||
return True
|
||||
if is_all_exceed:
|
||||
print("[System] 所有Key已超额")
|
||||
return None, False
|
||||
|
||||
def getConfigs(self):
|
||||
return self.openai_configs
|
||||
|
||||
def save_key_record(self):
|
||||
with open(key_record_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(self.key_stat, f)
|
||||
|
||||
def get_key_stat(self):
|
||||
return self.key_stat
|
||||
def get_key_list(self):
|
||||
return self.key_list
|
||||
|
||||
# 添加key
|
||||
def append_key(self, key, sponsor):
|
||||
self.key_list.append(key)
|
||||
self.key_stat[key] = {'exceed': False, 'used': 0, 'sponsor': sponsor}
|
||||
self.save_key_record()
|
||||
self.init_key_record()
|
||||
|
||||
# 检查key是否可用
|
||||
def check_key(self, key):
|
||||
pre_key = openai.api_key
|
||||
openai.api_key = key
|
||||
messages = [{"role": "user", "content": "1"}]
|
||||
try:
|
||||
response = openai.ChatCompletion.create(
|
||||
messages=messages,
|
||||
**self.chatGPT_configs
|
||||
)
|
||||
openai.api_key = pre_key
|
||||
return True
|
||||
except Exception as e:
|
||||
pass
|
||||
openai.api_key = pre_key
|
||||
return False
|
||||
|
||||
#将key_list的key转储到key_record中,并记录相关数据
|
||||
def init_key_record(self):
|
||||
if not os.path.exists(key_record_path):
|
||||
with open(key_record_path, 'w', encoding='utf-8') as f:
|
||||
json.dump({}, f)
|
||||
with open(key_record_path, 'r', encoding='utf-8') as keyfile:
|
||||
try:
|
||||
self.key_stat = json.load(keyfile)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
self.key_stat = {}
|
||||
finally:
|
||||
for key in self.key_list:
|
||||
if key not in self.key_stat:
|
||||
self.key_stat[key] = {'exceed': False, 'used': 0}
|
||||
# if openai.api_key is None:
|
||||
# openai.api_key = key
|
||||
else:
|
||||
# if self.key_stat[key]['exceed']:
|
||||
# print(f"Key: {key} 已超额")
|
||||
# continue
|
||||
# else:
|
||||
# if openai.api_key is None:
|
||||
# openai.api_key = key
|
||||
# print(f"使用Key: {key}, 已使用token: {self.key_stat[key]['used']}")
|
||||
pass
|
||||
if openai.api_key == None:
|
||||
self.handle_switch_key("")
|
||||
self.save_key_record()
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
from revChatGPT.V1 import Chatbot
|
||||
from provider import Provider
|
||||
|
||||
class ProviderRevChatGPT(Provider):
|
||||
def __init__(self, config):
|
||||
if 'password' in config:
|
||||
config['password'] = str(config['password'])
|
||||
self.bot = Chatbot(config=config)
|
||||
|
||||
def forget(self) -> bool:
|
||||
self.bot.reset_chat()
|
||||
return True
|
||||
|
||||
def text_chat(self, prompt):
|
||||
resp = ''
|
||||
err_count = 0
|
||||
retry_count = 5
|
||||
|
||||
while err_count < retry_count:
|
||||
try:
|
||||
for data in self.bot.ask(prompt):
|
||||
resp = data["message"]
|
||||
break
|
||||
except BaseException as e:
|
||||
try:
|
||||
print("[RevChatGPT] 请求出现了一些问题, 正在重试。次数"+str(err_count))
|
||||
err_count += 1
|
||||
if err_count >= retry_count:
|
||||
raise e
|
||||
except BaseException:
|
||||
err_count += 1
|
||||
|
||||
print("[RevChatGPT] "+str(resp))
|
||||
return resp
|
||||
@@ -0,0 +1,46 @@
|
||||
from provider import Provider
|
||||
import asyncio
|
||||
from EdgeGPT import Chatbot, ConversationStyle
|
||||
import json
|
||||
|
||||
class ProviderRevEdgeGPT(Provider):
|
||||
def __init__(self):
|
||||
self.busy = False
|
||||
self.wait_stack = []
|
||||
with open('./cookies.json', 'r') as f:
|
||||
cookies = json.load(f)
|
||||
self.bot = Chatbot(cookies=cookies)
|
||||
|
||||
async def forget(self):
|
||||
try:
|
||||
await self.bot.reset()
|
||||
return False
|
||||
except BaseException:
|
||||
return True
|
||||
|
||||
async def text_chat(self, prompt):
|
||||
if self.busy:
|
||||
return
|
||||
self.busy = True
|
||||
resp = 'err'
|
||||
err_count = 0
|
||||
retry_count = 5
|
||||
|
||||
while err_count < retry_count:
|
||||
try:
|
||||
resp = await self.bot.ask(prompt=prompt, conversation_style=ConversationStyle.creative)
|
||||
resp = resp['item']['messages'][len(resp['item']['messages'])-1]['text']
|
||||
if resp == prompt:
|
||||
resp += '\n\n如果你没有让我复述你的话,那代表我可能不想和你继续这个话题了,请输入/reset重置会话😶'
|
||||
break
|
||||
except BaseException as e:
|
||||
print(e.with_traceback)
|
||||
err_count += 1
|
||||
if err_count >= retry_count:
|
||||
raise e
|
||||
print("[RevEdgeGPT] 请求出现了一些问题, 正在重试。次数"+str(err_count))
|
||||
self.busy = False
|
||||
|
||||
print("[RevEdgeGPT] "+str(resp))
|
||||
return resp
|
||||
|
||||
Reference in New Issue
Block a user