diff --git a/cores/qqbot/core.py b/cores/qqbot/core.py index a78367184..98857914b 100644 --- a/cores/qqbot/core.py +++ b/cores/qqbot/core.py @@ -122,6 +122,12 @@ cc.init_attributes("qqbot_appid", "") cc.init_attributes("qqbot_secret", "") cc.init_attributes("llm_env_prompt", "> hint: 末尾根据内容和心情添加 1-2 个emoji") cc.init_attributes("default_personality_str", "") +cc.init_attributes("openai_image_generate", { + "model": "dall-e-3", + "size": "1024x1024", + "style": "vivid", + "quality": "standard", +}) # cc.init_attributes(["qq_forward_mode"], False) # QQ机器人 diff --git a/model/provider/openai_official.py b/model/provider/openai_official.py index 9f469501b..345a3fa45 100644 --- a/model/provider/openai_official.py +++ b/model/provider/openai_official.py @@ -1,5 +1,6 @@ from openai import OpenAI from openai.types.chat.chat_completion import ChatCompletion +from openai.types.images_response import ImagesResponse import json import time import os @@ -8,13 +9,17 @@ from cores.database.conn import dbConn from model.provider.provider import Provider import threading from util import general_utils as gu +from util.cmd_config import CmdConfig import traceback import tiktoken + abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/' class ProviderOpenAIOfficial(Provider): def __init__(self, cfg): + self.cc = CmdConfig() + self.key_list = [] # 如果 cfg['key']中有长度为1的字符串,那么是格式错误,直接报错 for key in cfg['key']: @@ -126,7 +131,6 @@ class ProviderOpenAIOfficial(Provider): } self.session_dict[session_id].append(new_record) - def text_chat(self, prompt, session_id = None, image_url = None, @@ -289,16 +293,18 @@ class ProviderOpenAIOfficial(Provider): def image_chat(self, prompt, img_num = 1, img_size = "1024x1024"): retry = 0 image_url = '' + + image_generate_configs = self.cc.get("openai_image_generate", None) + while retry < 5: try: - response = self.client.images.generate( + response: ImagesResponse = self.client.images.generate( prompt=prompt, - n=img_num, - size=img_size + **image_generate_configs ) image_url = [] for i in range(img_num): - image_url.append(response['data'][i]['url']) + image_url.append(response.data[i].url) break except Exception as e: gu.log(str(e), level=gu.LEVEL_ERROR) @@ -310,6 +316,9 @@ class ProviderOpenAIOfficial(Provider): if not is_switched: # 所有Key都超额或不正常 raise e + elif 'Your request was rejected as a result of our safety system.' in str(e): + gu.log("您的请求被 OpenAI 安全系统拒绝, 请稍后再试", level=gu.LEVEL_WARNING) + raise e else: retry += 1 if retry >= 5: