diff --git a/cores/openai/core.py b/cores/openai/core.py index 270fdc9a8..f1aeac501 100644 --- a/cores/openai/core.py +++ b/cores/openai/core.py @@ -49,17 +49,20 @@ class ChatGPT: global inst inst = self - def chat(self, prompt): + def chat(self, prompt, image_mode = False): try: - response = openai.Completion.create( - prompt=prompt, - **self.chatGPT_configs - ) + if not image_mode: + response = openai.Completion.create( + prompt=prompt, + **self.chatGPT_configs + ) + else: + pass # except(openai.error.InvalidRequestError) as e: # raise PromptExceededError("OpenAI遇到错误:输入了一个不合法的请求。\n"+str(e)) except Exception as e: print(e) - if 'You exceeded' in str(e): + if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(e): print("当前Key已超额,正在切换") self.key_stat[openai.api_key]['exceed'] = True self.save_key_record() @@ -68,12 +71,22 @@ class ChatGPT: if not is_switched: # 所有Key都超额 raise e - # print(response['usage']) - self.key_stat[openai.api_key]['used'] += response['usage']['total_tokens'] - self.save_key_record() - print("[ChatGPT] "+response["choices"][0]["text"]) - return response["choices"][0]["text"].strip(), response['usage']['total_tokens'] - + else: + if not image_mode: + response = openai.Completion.create( + prompt=prompt, + **self.chatGPT_configs + ) + else: + pass + if not image_mode: + self.key_stat[openai.api_key]['used'] += response['usage']['total_tokens'] + self.save_key_record() + print("[ChatGPT] "+response["choices"][0]["text"]) + return response["choices"][0]["text"].strip(), response['usage']['total_tokens'] + else: + pass + def handle_switch_key(self, prompt): while True: is_all_exceed = True diff --git a/cores/qqbot/core.py b/cores/qqbot/core.py index 5cf2182fb..16ef0dc42 100644 --- a/cores/qqbot/core.py +++ b/cores/qqbot/core.py @@ -47,6 +47,7 @@ class botClient(botpy.Client): # 收到At消息 async def on_at_message_create(self, message: Message): toggle_count(at=True, message=message) + # executor.submit(oper_msg, message, True) new_sub_thread(oper_msg, (message, True)) # await oper_msg(message=message, at=True) @@ -169,19 +170,26 @@ def initBot(chatgpt_inst): ''' 得到OpenAI的回复 ''' -def get_chatGPT_response(prompts_str): +def get_chatGPT_response(prompts_str, image_mode=False): res = '' usage = '' - res, usage = chatgpt.chat(prompts_str) - # 处理结果文本 - chatgpt_res = res.strip() - return res, usage + if not image_mode: + res, usage = chatgpt.chat(prompts_str) + # 处理结果文本 + chatgpt_res = res.strip() + return res, usage + else: + res = chatgpt.chat(prompts_str, image_mode = True) + return res ''' 回复QQ消息 ''' -def send_qq_msg(message, res): - asyncio.run_coroutine_threadsafe(message.reply(content=res), client.loop) +def send_qq_msg(message, res, image_mode=False): + if not image_mode: + asyncio.run_coroutine_threadsafe(message.reply(content=res), client.loop) + else: + asyncio.run_coroutine_threadsafe(message.reply(image=res, content="【此功能未完全实现】\n"), client.loop) ''' 获取缓存的会话 @@ -325,7 +333,6 @@ def oper_msg(message, at=False, loop=None): if qq_msg == "/help": send_qq_msg(message, "请联系频道管理员或者前往github(仓库名: QQChannelChatGPT)提issue~") return - # 统计历史会话 if session_id not in session_dict: session_dict[session_id] = [] @@ -412,4 +419,5 @@ def oper_msg(message, at=False, loop=None): "single_tokens": current_usage_tokens } cache_data_list.append(single_record) - session_dict[session_id] = cache_data_list \ No newline at end of file + session_dict[session_id] = cache_data_list +