feat: 支持自定义文转图服务地址
This commit is contained in:
@@ -47,6 +47,12 @@ class AstrBotBootstrap():
|
||||
logger.info("未使用代理。")
|
||||
|
||||
self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on'
|
||||
|
||||
# set t2i endpoint
|
||||
if self.context.config_helper.t2i_endpoint:
|
||||
self.context.image_renderer.set_network_endpoint(
|
||||
self.context.config_helper.t2i_endpoint
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
self.command_manager = CommandManager()
|
||||
|
||||
@@ -170,6 +170,7 @@ DEFAULT_CONFIG_VERSION_2 = {
|
||||
"password": "",
|
||||
},
|
||||
"log_level": "INFO",
|
||||
"t2i_endpoint": "",
|
||||
}
|
||||
|
||||
# 这个是用于迁移旧版本配置文件的映射表
|
||||
@@ -352,4 +353,5 @@ CONFIG_METADATA_2 = {
|
||||
}
|
||||
},
|
||||
"log_level": {"description": "控制台日志级别(DEBUG, INFO, WARNING, ERROR)", "type": "string"},
|
||||
"t2i_endpoint": {"description": "文本转图像服务接口(为空时使用公共服务器)", "type": "string"},
|
||||
}
|
||||
|
||||
@@ -134,6 +134,7 @@ class AstrBotConfig():
|
||||
platform: List[PlatformConfig] = field(default_factory=list)
|
||||
wake_prefix: List[str] = field(default_factory=list)
|
||||
log_level: str = "INFO"
|
||||
t2i_endpoint: str = ""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.init_configs()
|
||||
@@ -176,6 +177,7 @@ class AstrBotConfig():
|
||||
self.dashboard=DashboardConfig(**data.get("dashboard", {}))
|
||||
self.wake_prefix=data.get("wake_prefix", [])
|
||||
self.log_level=data.get("log_level", "INFO")
|
||||
self.t2i_endpoint=data.get("t2i_endpoint", "")
|
||||
|
||||
def migrate_config_1_2(self, old: dict) -> dict:
|
||||
'''将配置文件从版本 1 迁移至版本 2'''
|
||||
|
||||
@@ -7,11 +7,17 @@ from logging import Logger
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
class TextToImageRenderer:
|
||||
def __init__(self):
|
||||
self.network_strategy = NetworkRenderStrategy()
|
||||
def __init__(self, endpoint_url: str = None):
|
||||
self.network_strategy = NetworkRenderStrategy(endpoint_url)
|
||||
self.local_strategy = LocalRenderStrategy()
|
||||
self.context = RenderContext(self.network_strategy)
|
||||
|
||||
def set_network_endpoint(self, endpoint_url: str):
|
||||
'''设置 t2i 的网络端点。
|
||||
'''
|
||||
logger.info("文本转图像服务接口: " + endpoint_url)
|
||||
self.network_strategy.set_endpoint(endpoint_url)
|
||||
|
||||
async def render_custom_template(self, tmpl_str: str, tmpl_data: dict, return_url: bool = False):
|
||||
'''使用自定义文转图模板。该方法会通过网络调用 t2i 终结点图文渲染API。
|
||||
@param tmpl_str: HTML Jinja2 模板。
|
||||
|
||||
@@ -10,9 +10,16 @@ ASTRBOT_T2I_DEFAULT_ENDPOINT = "https://t2i.soulter.top/text2img"
|
||||
class NetworkRenderStrategy(RenderStrategy):
|
||||
def __init__(self, base_url: str = ASTRBOT_T2I_DEFAULT_ENDPOINT) -> None:
|
||||
super().__init__()
|
||||
if not base_url:
|
||||
base_url = ASTRBOT_T2I_DEFAULT_ENDPOINT
|
||||
self.BASE_RENDER_URL = base_url
|
||||
self.TEMPLATE_PATH = os.path.join(os.path.dirname(__file__), "template")
|
||||
|
||||
def set_endpoint(self, base_url: str):
|
||||
if not base_url:
|
||||
base_url = ASTRBOT_T2I_DEFAULT_ENDPOINT
|
||||
self.BASE_RENDER_URL = base_url
|
||||
|
||||
async def render_custom_template(self, tmpl_str: str, tmpl_data: dict, return_url: bool=True) -> str:
|
||||
'''使用自定义文转图模板'''
|
||||
post_data = {
|
||||
|
||||
Reference in New Issue
Block a user