diff --git a/astrbot/bootstrap.py b/astrbot/bootstrap.py index 574dd47b2..31963c53b 100644 --- a/astrbot/bootstrap.py +++ b/astrbot/bootstrap.py @@ -47,6 +47,12 @@ class AstrBotBootstrap(): logger.info("未使用代理。") self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on' + + # set t2i endpoint + if self.context.config_helper.t2i_endpoint: + self.context.image_renderer.set_network_endpoint( + self.context.config_helper.t2i_endpoint + ) async def run(self): self.command_manager = CommandManager() diff --git a/type/config.py b/type/config.py index 309d7a00a..0745bd9fb 100644 --- a/type/config.py +++ b/type/config.py @@ -170,6 +170,7 @@ DEFAULT_CONFIG_VERSION_2 = { "password": "", }, "log_level": "INFO", + "t2i_endpoint": "", } # 这个是用于迁移旧版本配置文件的映射表 @@ -352,4 +353,5 @@ CONFIG_METADATA_2 = { } }, "log_level": {"description": "控制台日志级别(DEBUG, INFO, WARNING, ERROR)", "type": "string"}, + "t2i_endpoint": {"description": "文本转图像服务接口(为空时使用公共服务器)", "type": "string"}, } diff --git a/util/cmd_config.py b/util/cmd_config.py index 6f9a6d60f..85285bc5b 100644 --- a/util/cmd_config.py +++ b/util/cmd_config.py @@ -134,6 +134,7 @@ class AstrBotConfig(): platform: List[PlatformConfig] = field(default_factory=list) wake_prefix: List[str] = field(default_factory=list) log_level: str = "INFO" + t2i_endpoint: str = "" def __init__(self) -> None: self.init_configs() @@ -176,6 +177,7 @@ class AstrBotConfig(): self.dashboard=DashboardConfig(**data.get("dashboard", {})) self.wake_prefix=data.get("wake_prefix", []) self.log_level=data.get("log_level", "INFO") + self.t2i_endpoint=data.get("t2i_endpoint", "") def migrate_config_1_2(self, old: dict) -> dict: '''将配置文件从版本 1 迁移至版本 2''' diff --git a/util/t2i/renderer.py b/util/t2i/renderer.py index ab8d9c820..5db6be6ee 100644 --- a/util/t2i/renderer.py +++ b/util/t2i/renderer.py @@ -7,11 +7,17 @@ from logging import Logger logger: Logger = LogManager.GetLogger(log_name='astrbot') class TextToImageRenderer: - def __init__(self): - self.network_strategy = NetworkRenderStrategy() + def __init__(self, endpoint_url: str = None): + self.network_strategy = NetworkRenderStrategy(endpoint_url) self.local_strategy = LocalRenderStrategy() self.context = RenderContext(self.network_strategy) + def set_network_endpoint(self, endpoint_url: str): + '''设置 t2i 的网络端点。 + ''' + logger.info("文本转图像服务接口: " + endpoint_url) + self.network_strategy.set_endpoint(endpoint_url) + async def render_custom_template(self, tmpl_str: str, tmpl_data: dict, return_url: bool = False): '''使用自定义文转图模板。该方法会通过网络调用 t2i 终结点图文渲染API。 @param tmpl_str: HTML Jinja2 模板。 diff --git a/util/t2i/strategies/network_strategy.py b/util/t2i/strategies/network_strategy.py index 124c8bfe4..7b73f1f70 100644 --- a/util/t2i/strategies/network_strategy.py +++ b/util/t2i/strategies/network_strategy.py @@ -10,9 +10,16 @@ ASTRBOT_T2I_DEFAULT_ENDPOINT = "https://t2i.soulter.top/text2img" class NetworkRenderStrategy(RenderStrategy): def __init__(self, base_url: str = ASTRBOT_T2I_DEFAULT_ENDPOINT) -> None: super().__init__() + if not base_url: + base_url = ASTRBOT_T2I_DEFAULT_ENDPOINT self.BASE_RENDER_URL = base_url self.TEMPLATE_PATH = os.path.join(os.path.dirname(__file__), "template") + def set_endpoint(self, base_url: str): + if not base_url: + base_url = ASTRBOT_T2I_DEFAULT_ENDPOINT + self.BASE_RENDER_URL = base_url + async def render_custom_template(self, tmpl_str: str, tmpl_data: dict, return_url: bool=True) -> str: '''使用自定义文转图模板''' post_data = {