37e9373561
添加画图功能,需调用OpenAI api启用
192 lines
7.7 KiB
Python
192 lines
7.7 KiB
Python
import openai
|
|
import yaml
|
|
from util.errors.errors import PromptExceededError
|
|
import json
|
|
import time
|
|
import os
|
|
import sys
|
|
|
|
inst = None
|
|
# 适配pyinstaller
|
|
abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/'
|
|
key_record_path = abs_path+'chatgpt_key_record'
|
|
|
|
class ChatGPT:
|
|
def __init__(self, cfg):
|
|
self.key_list = []
|
|
if 'api_base' in cfg and cfg['api_base'] != 'none' and cfg['api_base'] != '':
|
|
openai.api_base = cfg['api_base']
|
|
if cfg['key'] != '' and cfg['key'] != None:
|
|
print("[System] 读取ChatGPT Key成功")
|
|
self.key_list = cfg['key']
|
|
# openai.api_key = cfg['key']
|
|
else:
|
|
input("[System] 请先去完善ChatGPT的Key。详情请前往https://beta.openai.com/account/api-keys")
|
|
|
|
# init key record
|
|
self.init_key_record()
|
|
|
|
chatGPT_configs = cfg['chatGPTConfigs']
|
|
print(f'[System] 加载ChatGPTConfigs: {chatGPT_configs}')
|
|
self.chatGPT_configs = chatGPT_configs
|
|
self.openai_configs = cfg
|
|
|
|
def chat(self, req, image_mode = False, img_num = 1, img_size = "1024x1024"):
|
|
# ChatGPT API 2023/3/2
|
|
# messages = [{"role": "user", "content": prompt}]
|
|
if not image_mode:
|
|
try:
|
|
response = openai.ChatCompletion.create(
|
|
messages=req,
|
|
**self.chatGPT_configs
|
|
)
|
|
except Exception as e:
|
|
print(e)
|
|
if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e):
|
|
print("[System] 当前Key已超额或者不正常,正在切换")
|
|
self.key_stat[openai.api_key]['exceed'] = True
|
|
self.save_key_record()
|
|
|
|
response, is_switched = self.handle_switch_key(req)
|
|
if not is_switched:
|
|
# 所有Key都超额或不正常
|
|
raise e
|
|
else:
|
|
response = openai.ChatCompletion.create(
|
|
messages=req,
|
|
**self.chatGPT_configs
|
|
)
|
|
self.key_stat[openai.api_key]['used'] += response['usage']['total_tokens']
|
|
self.save_key_record()
|
|
print("[ChatGPT] "+str(response["choices"][0]["message"]["content"]))
|
|
return str(response["choices"][0]["message"]["content"]).strip(), response['usage']['total_tokens']
|
|
else:
|
|
try:
|
|
# print("test1")
|
|
response = openai.Image.create(
|
|
prompt=req[0]['content'],
|
|
n=img_num,
|
|
size=img_size
|
|
)
|
|
# print("test2")
|
|
image_url = []
|
|
for i in range(img_num):
|
|
image_url.append(response['data'][i]['url'])
|
|
print(image_url)
|
|
except Exception as e:
|
|
print(e)
|
|
if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(
|
|
e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e):
|
|
print("[System] 当前Key已超额或者不正常,正在切换")
|
|
self.key_stat[openai.api_key]['exceed'] = True
|
|
self.save_key_record()
|
|
|
|
response, is_switched = self.handle_switch_key(req)
|
|
if not is_switched:
|
|
# 所有Key都超额或不正常
|
|
raise e
|
|
else:
|
|
response = openai.Image.create(
|
|
prompt=req[0]['content'],
|
|
n=img_num,
|
|
size=img_size
|
|
)
|
|
image_url = []
|
|
for i in range(img_num):
|
|
image_url.append(response['data'][i]['url'])
|
|
return image_url
|
|
def handle_switch_key(self, req):
|
|
# messages = [{"role": "user", "content": prompt}]
|
|
while True:
|
|
is_all_exceed = True
|
|
for key in self.key_stat:
|
|
if not self.key_stat[key]['exceed']:
|
|
is_all_exceed = False
|
|
openai.api_key = key
|
|
print(f"[System] 切换到Key: {key}, 已使用token: {self.key_stat[key]['used']}")
|
|
if len(req) > 0:
|
|
try:
|
|
response = openai.ChatCompletion.create(
|
|
messages=req,
|
|
**self.chatGPT_configs
|
|
)
|
|
return response, True
|
|
except Exception as e:
|
|
print(e)
|
|
if 'You exceeded' in str(e):
|
|
print("[System] 当前Key已超额,正在切换")
|
|
self.key_stat[openai.api_key]['exceed'] = True
|
|
self.save_key_record()
|
|
time.sleep(1)
|
|
continue
|
|
else:
|
|
return True
|
|
if is_all_exceed:
|
|
print("[System] 所有Key已超额")
|
|
return None, False
|
|
|
|
def getConfigs(self):
|
|
return self.openai_configs
|
|
|
|
def save_key_record(self):
|
|
with open(key_record_path, 'w', encoding='utf-8') as f:
|
|
json.dump(self.key_stat, f)
|
|
|
|
def get_key_stat(self):
|
|
return self.key_stat
|
|
def get_key_list(self):
|
|
return self.key_list
|
|
|
|
# 添加key
|
|
def append_key(self, key, sponsor):
|
|
self.key_list.append(key)
|
|
self.key_stat[key] = {'exceed': False, 'used': 0, 'sponsor': sponsor}
|
|
self.save_key_record()
|
|
self.init_key_record()
|
|
# 检查key是否可用
|
|
def check_key(self, key):
|
|
pre_key = openai.api_key
|
|
openai.api_key = key
|
|
messages = [{"role": "user", "content": "1"}]
|
|
try:
|
|
response = openai.ChatCompletion.create(
|
|
messages=messages,
|
|
**self.chatGPT_configs
|
|
)
|
|
openai.api_key = pre_key
|
|
return True
|
|
except Exception as e:
|
|
pass
|
|
openai.api_key = pre_key
|
|
return False
|
|
|
|
#将key_list的key转储到key_record中,并记录相关数据
|
|
def init_key_record(self):
|
|
if not os.path.exists(key_record_path):
|
|
with open(key_record_path, 'w', encoding='utf-8') as f:
|
|
json.dump({}, f)
|
|
with open(key_record_path, 'r', encoding='utf-8') as keyfile:
|
|
try:
|
|
self.key_stat = json.load(keyfile)
|
|
except Exception as e:
|
|
print(e)
|
|
self.key_stat = {}
|
|
finally:
|
|
for key in self.key_list:
|
|
if key not in self.key_stat:
|
|
self.key_stat[key] = {'exceed': False, 'used': 0}
|
|
# if openai.api_key is None:
|
|
# openai.api_key = key
|
|
else:
|
|
# if self.key_stat[key]['exceed']:
|
|
# print(f"Key: {key} 已超额")
|
|
# continue
|
|
# else:
|
|
# if openai.api_key is None:
|
|
# openai.api_key = key
|
|
# print(f"使用Key: {key}, 已使用token: {self.key_stat[key]['used']}")
|
|
pass
|
|
if openai.api_key == None:
|
|
self.handle_switch_key("")
|
|
self.save_key_record()
|