feat: 画图指令支持 DallE3

This commit is contained in:
Soulter
2023-12-04 13:50:49 +08:00
parent 9bc8ac10fa
commit 9c6bdeea9d
2 changed files with 20 additions and 5 deletions
+6
View File
@@ -122,6 +122,12 @@ cc.init_attributes("qqbot_appid", "")
cc.init_attributes("qqbot_secret", "")
cc.init_attributes("llm_env_prompt", "> hint: 末尾根据内容和心情添加 1-2 个emoji")
cc.init_attributes("default_personality_str", "")
cc.init_attributes("openai_image_generate", {
"model": "dall-e-3",
"size": "1024x1024",
"style": "vivid",
"quality": "standard",
})
# cc.init_attributes(["qq_forward_mode"], False)
# QQ机器人
+14 -5
View File
@@ -1,5 +1,6 @@
from openai import OpenAI
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.images_response import ImagesResponse
import json
import time
import os
@@ -8,13 +9,17 @@ from cores.database.conn import dbConn
from model.provider.provider import Provider
import threading
from util import general_utils as gu
from util.cmd_config import CmdConfig
import traceback
import tiktoken
abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/'
class ProviderOpenAIOfficial(Provider):
def __init__(self, cfg):
self.cc = CmdConfig()
self.key_list = []
# 如果 cfg['key']中有长度为1的字符串,那么是格式错误,直接报错
for key in cfg['key']:
@@ -126,7 +131,6 @@ class ProviderOpenAIOfficial(Provider):
}
self.session_dict[session_id].append(new_record)
def text_chat(self, prompt,
session_id = None,
image_url = None,
@@ -289,16 +293,18 @@ class ProviderOpenAIOfficial(Provider):
def image_chat(self, prompt, img_num = 1, img_size = "1024x1024"):
retry = 0
image_url = ''
image_generate_configs = self.cc.get("openai_image_generate", None)
while retry < 5:
try:
response = self.client.images.generate(
response: ImagesResponse = self.client.images.generate(
prompt=prompt,
n=img_num,
size=img_size
**image_generate_configs
)
image_url = []
for i in range(img_num):
image_url.append(response['data'][i]['url'])
image_url.append(response.data[i].url)
break
except Exception as e:
gu.log(str(e), level=gu.LEVEL_ERROR)
@@ -310,6 +316,9 @@ class ProviderOpenAIOfficial(Provider):
if not is_switched:
# 所有Key都超额或不正常
raise e
elif 'Your request was rejected as a result of our safety system.' in str(e):
gu.log("您的请求被 OpenAI 安全系统拒绝, 请稍后再试", level=gu.LEVEL_WARNING)
raise e
else:
retry += 1
if retry >= 5: