diff --git a/astrbot/core/utils/t2i/template/default_template.html.bak b/astrbot/core/utils/t2i/template/default_template.html.bak deleted file mode 100644 index 257cff3ff..000000000 --- a/astrbot/core/utils/t2i/template/default_template.html.bak +++ /dev/null @@ -1,247 +0,0 @@ - - - - - - - - - - - - -
- # AstrBot - {{ version }} -
-
- - - - - - \ No newline at end of file diff --git a/astrbot/core/utils/t2i/template_manager.py b/astrbot/core/utils/t2i/template_manager.py index ccc5492fd..b441a908e 100644 --- a/astrbot/core/utils/t2i/template_manager.py +++ b/astrbot/core/utils/t2i/template_manager.py @@ -2,94 +2,111 @@ import os import shutil -from astrbot.core.utils.astrbot_path import get_astrbot_path +from astrbot.core.utils.astrbot_path import get_astrbot_data_path, get_astrbot_path class TemplateManager: """ 负责管理 t2i HTML 模板的 CRUD 和重置操作。 + 采用“用户覆盖内置”策略:用户模板存储在 data 目录中,并优先于内置模板加载。 + 所有创建、更新、删除操作仅影响用户目录,以确保更新框架时用户数据安全。 """ + CORE_TEMPLATES = ["base.html", "astrbot_powershell.html"] + def __init__(self): - # 修正路径拼接,加入缺失的 'astrbot' 目录 - self.template_dir = os.path.join( + self.builtin_template_dir = os.path.join( get_astrbot_path(), "astrbot", "core", "utils", "t2i", "template" ) - self.backup_template_path = os.path.join( - self.template_dir, "default_template.html.bak" - ) - # 确保模板目录存在 - os.makedirs(self.template_dir, exist_ok=True) + self.user_template_dir = os.path.join(get_astrbot_data_path(), "t2i_templates") - # 检查模板目录中是否有 .html 文件 - html_files = [f for f in os.listdir(self.template_dir) if f.endswith(".html")] - if not html_files and os.path.exists(self.backup_template_path): - self.reset_default_template() + os.makedirs(self.user_template_dir, exist_ok=True) + self._initialize_user_templates() - def _get_template_path(self, name: str) -> str: - """获取模板的完整路径,防止路径遍历漏洞。""" + def _copy_core_templates(self, overwrite: bool = False): + """从内置目录复制核心模板到用户目录。""" + for filename in self.CORE_TEMPLATES: + src = os.path.join(self.builtin_template_dir, filename) + dst = os.path.join(self.user_template_dir, filename) + if os.path.exists(src) and (overwrite or not os.path.exists(dst)): + shutil.copyfile(src, dst) + + def _initialize_user_templates(self): + """如果用户目录下缺少核心模板,则进行复制。""" + self._copy_core_templates(overwrite=False) + + def _get_user_template_path(self, name: str) -> str: + """获取用户模板的完整路径,防止路径遍历漏洞。""" if ".." in name or "/" in name or "\\" in name: raise ValueError("模板名称包含非法字符。") - return os.path.join(self.template_dir, f"{name}.html") + return os.path.join(self.user_template_dir, f"{name}.html") - def list_templates(self) -> list[dict]: - """列出所有可用的模板。""" - templates = [] - for filename in os.listdir(self.template_dir): - if filename.endswith(".html"): - templates.append( - { - "name": os.path.splitext(filename)[0], - "is_default": filename == "base.html", - } - ) - return templates - - def get_template(self, name: str) -> str: - """获取指定模板的内容。""" - path = self._get_template_path(name) - if not os.path.exists(path): - raise FileNotFoundError("模板不存在。") + def _read_file(self, path: str) -> str: + """读取文件内容。""" with open(path, "r", encoding="utf-8") as f: return f.read() + def list_templates(self) -> list[dict]: + """ + 列出所有可用模板。 + 该列表是内置模板和用户模板的合并视图,用户模板将覆盖同名的内置模板。 + """ + dirs_to_scan = [self.builtin_template_dir, self.user_template_dir] + all_names = { + os.path.splitext(f)[0] + for d in dirs_to_scan + for f in os.listdir(d) + if f.endswith(".html") + } + return [ + {"name": name, "is_default": name == "base"} for name in sorted(all_names) + ] + + def get_template(self, name: str) -> str: + """ + 获取指定模板的内容。 + 优先从用户目录加载,如果不存在则回退到内置目录。 + """ + user_path = self._get_user_template_path(name) + if os.path.exists(user_path): + return self._read_file(user_path) + + builtin_path = os.path.join(self.builtin_template_dir, f"{name}.html") + if os.path.exists(builtin_path): + return self._read_file(builtin_path) + + raise FileNotFoundError("模板不存在。") + def create_template(self, name: str, content: str): - """创建一个新的模板文件。""" - path = self._get_template_path(name) + """在用户目录中创建一个新的模板文件。""" + path = self._get_user_template_path(name) if os.path.exists(path): raise FileExistsError("同名模板已存在。") with open(path, "w", encoding="utf-8") as f: f.write(content) def update_template(self, name: str, content: str): - """更新一个已存在的模板文件。""" - path = self._get_template_path(name) - if not os.path.exists(path): - raise FileNotFoundError("模板不存在。") + """ + 更新一个模板。此操作始终写入用户目录。 + 如果更新的是一个内置模板,此操作实际上会在用户目录中创建一个修改后的副本, + 从而实现对内置模板的“覆盖”。 + """ + path = self._get_user_template_path(name) with open(path, "w", encoding="utf-8") as f: f.write(content) def delete_template(self, name: str): - """删除一个模板文件。""" - if name == "base": - raise ValueError("不能删除默认的 base 模板。") - path = self._get_template_path(name) + """ + 仅删除用户目录中的模板文件。 + 如果删除的是一个覆盖了内置模板的用户模板,这将有效地“恢复”到内置版本。 + """ + path = self._get_user_template_path(name) if not os.path.exists(path): - raise FileNotFoundError("模板不存在。") + raise FileNotFoundError("用户模板不存在,无法删除。") os.remove(path) - def backup_default_template_if_not_exist(self): - """如果备份不存在,则创建默认模板的备份。""" - default_path = os.path.join(self.template_dir, "base.html") - if not os.path.exists(self.backup_template_path) and os.path.exists( - default_path - ): - shutil.copyfile(default_path, self.backup_template_path) - def reset_default_template(self): - """重置默认模板。""" - if not os.path.exists(self.backup_template_path): - raise FileNotFoundError("默认模板的备份文件不存在,无法重置。") - - default_path = os.path.join(self.template_dir, "base.html") - shutil.copyfile(self.backup_template_path, default_path) + """ + 将核心模板从内置目录强制重置到用户目录。 + """ + self._copy_core_templates(overwrite=True) diff --git a/astrbot/dashboard/routes/t2i.py b/astrbot/dashboard/routes/t2i.py index 31cdc0bb4..04f87bc99 100644 --- a/astrbot/dashboard/routes/t2i.py +++ b/astrbot/dashboard/routes/t2i.py @@ -32,10 +32,6 @@ class T2iRoute(Route): ], ), ] - - # 应用启动时,确保备份存在 - self.manager.backup_default_template_if_not_exist() - self.register_routes() async def list_templates(self): @@ -89,6 +85,7 @@ class T2iRoute(Route): ) response.status_code = 400 return response + name = name.strip() self.manager.create_template(name, content) response = jsonify( @@ -118,6 +115,7 @@ class T2iRoute(Route): async def update_template(self, name: str): """更新一个已存在的T2I模板""" try: + name = name.strip() data = await request.json content = data.get("content") if content is None: @@ -126,17 +124,16 @@ class T2iRoute(Route): return response self.manager.update_template(name, content) - return jsonify( - asdict( - Response().ok( - data={"name": name}, message="Template updated successfully." - ) - ) - ) - except FileNotFoundError: - response = jsonify(asdict(Response().error("Template not found."))) - response.status_code = 404 - return response + + # 检查更新的是否为当前激活的模板,如果是,则热重载 + active_template = self.config.get("t2i_active_template", "base") + if name == active_template: + await self.core_lifecycle.reload_pipeline_scheduler("default") + message = f"模板 '{name}' 已更新并重新加载。" + else: + message = f"模板 '{name}' 已更新。" + + return jsonify(asdict(Response().ok(data={"name": name}, message=message))) except ValueError as e: response = jsonify(asdict(Response().error(str(e)))) response.status_code = 400 @@ -149,6 +146,7 @@ class T2iRoute(Route): async def delete_template(self, name: str): """删除一个T2I模板""" try: + name = name.strip() self.manager.delete_template(name) return jsonify( asdict(Response().ok(message="Template deleted successfully."))