diff --git a/cores/openai/core.py b/cores/openai/core.py index 6a5bcd897..a70c6278a 100644 --- a/cores/openai/core.py +++ b/cores/openai/core.py @@ -31,35 +31,70 @@ class ChatGPT: self.chatGPT_configs = chatGPT_configs self.openai_configs = cfg - def chat(self, req, image_mode = False): + def chat(self, req, image_mode = False, img_num = 1, img_size = "1024x1024"): # ChatGPT API 2023/3/2 # messages = [{"role": "user", "content": prompt}] - try: - response = openai.ChatCompletion.create( - messages=req, - **self.chatGPT_configs - ) - except Exception as e: - print(e) - if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e): - print("[System] 当前Key已超额或者不正常,正在切换") - self.key_stat[openai.api_key]['exceed'] = True - self.save_key_record() - - response, is_switched = self.handle_switch_key(req) - if not is_switched: - # 所有Key都超额或不正常 - raise e - else: + if not image_mode: + try: response = openai.ChatCompletion.create( messages=req, **self.chatGPT_configs ) - self.key_stat[openai.api_key]['used'] += response['usage']['total_tokens'] - self.save_key_record() - print("[ChatGPT] "+str(response["choices"][0]["message"]["content"])) - return str(response["choices"][0]["message"]["content"]).strip(), response['usage']['total_tokens'] - + except Exception as e: + print(e) + if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e): + print("[System] 当前Key已超额或者不正常,正在切换") + self.key_stat[openai.api_key]['exceed'] = True + self.save_key_record() + + response, is_switched = self.handle_switch_key(req) + if not is_switched: + # 所有Key都超额或不正常 + raise e + else: + response = openai.ChatCompletion.create( + messages=req, + **self.chatGPT_configs + ) + self.key_stat[openai.api_key]['used'] += response['usage']['total_tokens'] + self.save_key_record() + print("[ChatGPT] "+str(response["choices"][0]["message"]["content"])) + return str(response["choices"][0]["message"]["content"]).strip(), response['usage']['total_tokens'] + else: + try: + # print("test1") + response = openai.Image.create( + prompt=req[0]['content'], + n=img_num, + size=img_size + ) + # print("test2") + image_url = [] + for i in range(img_num): + image_url.append(response['data'][i]['url']) + print(image_url) + except Exception as e: + print(e) + if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str( + e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e): + print("[System] 当前Key已超额或者不正常,正在切换") + self.key_stat[openai.api_key]['exceed'] = True + self.save_key_record() + + response, is_switched = self.handle_switch_key(req) + if not is_switched: + # 所有Key都超额或不正常 + raise e + else: + response = openai.Image.create( + prompt=req[0]['content'], + n=img_num, + size=img_size + ) + image_url = [] + for i in range(img_num): + image_url.append(response['data'][i]['url']) + return image_url def handle_switch_key(self, req): # messages = [{"role": "user", "content": prompt}] while True: @@ -153,4 +188,4 @@ class ChatGPT: pass if openai.api_key == None: self.handle_switch_key("") - self.save_key_record() \ No newline at end of file + self.save_key_record()