feat: 支持自定义文转图服务地址

This commit is contained in:
Soulter
2024-09-22 10:50:47 -04:00
parent 90815b1ac5
commit 353b6ed761
5 changed files with 25 additions and 2 deletions
+6
View File
@@ -47,6 +47,12 @@ class AstrBotBootstrap():
logger.info("未使用代理。")
self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on'
# set t2i endpoint
if self.context.config_helper.t2i_endpoint:
self.context.image_renderer.set_network_endpoint(
self.context.config_helper.t2i_endpoint
)
async def run(self):
self.command_manager = CommandManager()
+2
View File
@@ -170,6 +170,7 @@ DEFAULT_CONFIG_VERSION_2 = {
"password": "",
},
"log_level": "INFO",
"t2i_endpoint": "",
}
# 这个是用于迁移旧版本配置文件的映射表
@@ -352,4 +353,5 @@ CONFIG_METADATA_2 = {
}
},
"log_level": {"description": "控制台日志级别(DEBUG, INFO, WARNING, ERROR)", "type": "string"},
"t2i_endpoint": {"description": "文本转图像服务接口(为空时使用公共服务器)", "type": "string"},
}
+2
View File
@@ -134,6 +134,7 @@ class AstrBotConfig():
platform: List[PlatformConfig] = field(default_factory=list)
wake_prefix: List[str] = field(default_factory=list)
log_level: str = "INFO"
t2i_endpoint: str = ""
def __init__(self) -> None:
self.init_configs()
@@ -176,6 +177,7 @@ class AstrBotConfig():
self.dashboard=DashboardConfig(**data.get("dashboard", {}))
self.wake_prefix=data.get("wake_prefix", [])
self.log_level=data.get("log_level", "INFO")
self.t2i_endpoint=data.get("t2i_endpoint", "")
def migrate_config_1_2(self, old: dict) -> dict:
'''将配置文件从版本 1 迁移至版本 2'''
+8 -2
View File
@@ -7,11 +7,17 @@ from logging import Logger
logger: Logger = LogManager.GetLogger(log_name='astrbot')
class TextToImageRenderer:
def __init__(self):
self.network_strategy = NetworkRenderStrategy()
def __init__(self, endpoint_url: str = None):
self.network_strategy = NetworkRenderStrategy(endpoint_url)
self.local_strategy = LocalRenderStrategy()
self.context = RenderContext(self.network_strategy)
def set_network_endpoint(self, endpoint_url: str):
'''设置 t2i 的网络端点。
'''
logger.info("文本转图像服务接口: " + endpoint_url)
self.network_strategy.set_endpoint(endpoint_url)
async def render_custom_template(self, tmpl_str: str, tmpl_data: dict, return_url: bool = False):
'''使用自定义文转图模板。该方法会通过网络调用 t2i 终结点图文渲染API。
@param tmpl_str: HTML Jinja2 模板。
+7
View File
@@ -10,9 +10,16 @@ ASTRBOT_T2I_DEFAULT_ENDPOINT = "https://t2i.soulter.top/text2img"
class NetworkRenderStrategy(RenderStrategy):
def __init__(self, base_url: str = ASTRBOT_T2I_DEFAULT_ENDPOINT) -> None:
super().__init__()
if not base_url:
base_url = ASTRBOT_T2I_DEFAULT_ENDPOINT
self.BASE_RENDER_URL = base_url
self.TEMPLATE_PATH = os.path.join(os.path.dirname(__file__), "template")
def set_endpoint(self, base_url: str):
if not base_url:
base_url = ASTRBOT_T2I_DEFAULT_ENDPOINT
self.BASE_RENDER_URL = base_url
async def render_custom_template(self, tmpl_str: str, tmpl_data: dict, return_url: bool=True) -> str:
'''使用自定义文转图模板'''
post_data = {