diff --git a/astrbot/core/utils/t2i/template/default_template.html.bak b/astrbot/core/utils/t2i/template/default_template.html.bak
deleted file mode 100644
index 257cff3ff..000000000
--- a/astrbot/core/utils/t2i/template/default_template.html.bak
+++ /dev/null
@@ -1,247 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
- # AstrBot
- {{ version }}
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/astrbot/core/utils/t2i/template_manager.py b/astrbot/core/utils/t2i/template_manager.py
index ccc5492fd..b441a908e 100644
--- a/astrbot/core/utils/t2i/template_manager.py
+++ b/astrbot/core/utils/t2i/template_manager.py
@@ -2,94 +2,111 @@
import os
import shutil
-from astrbot.core.utils.astrbot_path import get_astrbot_path
+from astrbot.core.utils.astrbot_path import get_astrbot_data_path, get_astrbot_path
class TemplateManager:
"""
负责管理 t2i HTML 模板的 CRUD 和重置操作。
+ 采用“用户覆盖内置”策略:用户模板存储在 data 目录中,并优先于内置模板加载。
+ 所有创建、更新、删除操作仅影响用户目录,以确保更新框架时用户数据安全。
"""
+ CORE_TEMPLATES = ["base.html", "astrbot_powershell.html"]
+
def __init__(self):
- # 修正路径拼接,加入缺失的 'astrbot' 目录
- self.template_dir = os.path.join(
+ self.builtin_template_dir = os.path.join(
get_astrbot_path(), "astrbot", "core", "utils", "t2i", "template"
)
- self.backup_template_path = os.path.join(
- self.template_dir, "default_template.html.bak"
- )
- # 确保模板目录存在
- os.makedirs(self.template_dir, exist_ok=True)
+ self.user_template_dir = os.path.join(get_astrbot_data_path(), "t2i_templates")
- # 检查模板目录中是否有 .html 文件
- html_files = [f for f in os.listdir(self.template_dir) if f.endswith(".html")]
- if not html_files and os.path.exists(self.backup_template_path):
- self.reset_default_template()
+ os.makedirs(self.user_template_dir, exist_ok=True)
+ self._initialize_user_templates()
- def _get_template_path(self, name: str) -> str:
- """获取模板的完整路径,防止路径遍历漏洞。"""
+ def _copy_core_templates(self, overwrite: bool = False):
+ """从内置目录复制核心模板到用户目录。"""
+ for filename in self.CORE_TEMPLATES:
+ src = os.path.join(self.builtin_template_dir, filename)
+ dst = os.path.join(self.user_template_dir, filename)
+ if os.path.exists(src) and (overwrite or not os.path.exists(dst)):
+ shutil.copyfile(src, dst)
+
+ def _initialize_user_templates(self):
+ """如果用户目录下缺少核心模板,则进行复制。"""
+ self._copy_core_templates(overwrite=False)
+
+ def _get_user_template_path(self, name: str) -> str:
+ """获取用户模板的完整路径,防止路径遍历漏洞。"""
if ".." in name or "/" in name or "\\" in name:
raise ValueError("模板名称包含非法字符。")
- return os.path.join(self.template_dir, f"{name}.html")
+ return os.path.join(self.user_template_dir, f"{name}.html")
- def list_templates(self) -> list[dict]:
- """列出所有可用的模板。"""
- templates = []
- for filename in os.listdir(self.template_dir):
- if filename.endswith(".html"):
- templates.append(
- {
- "name": os.path.splitext(filename)[0],
- "is_default": filename == "base.html",
- }
- )
- return templates
-
- def get_template(self, name: str) -> str:
- """获取指定模板的内容。"""
- path = self._get_template_path(name)
- if not os.path.exists(path):
- raise FileNotFoundError("模板不存在。")
+ def _read_file(self, path: str) -> str:
+ """读取文件内容。"""
with open(path, "r", encoding="utf-8") as f:
return f.read()
+ def list_templates(self) -> list[dict]:
+ """
+ 列出所有可用模板。
+ 该列表是内置模板和用户模板的合并视图,用户模板将覆盖同名的内置模板。
+ """
+ dirs_to_scan = [self.builtin_template_dir, self.user_template_dir]
+ all_names = {
+ os.path.splitext(f)[0]
+ for d in dirs_to_scan
+ for f in os.listdir(d)
+ if f.endswith(".html")
+ }
+ return [
+ {"name": name, "is_default": name == "base"} for name in sorted(all_names)
+ ]
+
+ def get_template(self, name: str) -> str:
+ """
+ 获取指定模板的内容。
+ 优先从用户目录加载,如果不存在则回退到内置目录。
+ """
+ user_path = self._get_user_template_path(name)
+ if os.path.exists(user_path):
+ return self._read_file(user_path)
+
+ builtin_path = os.path.join(self.builtin_template_dir, f"{name}.html")
+ if os.path.exists(builtin_path):
+ return self._read_file(builtin_path)
+
+ raise FileNotFoundError("模板不存在。")
+
def create_template(self, name: str, content: str):
- """创建一个新的模板文件。"""
- path = self._get_template_path(name)
+ """在用户目录中创建一个新的模板文件。"""
+ path = self._get_user_template_path(name)
if os.path.exists(path):
raise FileExistsError("同名模板已存在。")
with open(path, "w", encoding="utf-8") as f:
f.write(content)
def update_template(self, name: str, content: str):
- """更新一个已存在的模板文件。"""
- path = self._get_template_path(name)
- if not os.path.exists(path):
- raise FileNotFoundError("模板不存在。")
+ """
+ 更新一个模板。此操作始终写入用户目录。
+ 如果更新的是一个内置模板,此操作实际上会在用户目录中创建一个修改后的副本,
+ 从而实现对内置模板的“覆盖”。
+ """
+ path = self._get_user_template_path(name)
with open(path, "w", encoding="utf-8") as f:
f.write(content)
def delete_template(self, name: str):
- """删除一个模板文件。"""
- if name == "base":
- raise ValueError("不能删除默认的 base 模板。")
- path = self._get_template_path(name)
+ """
+ 仅删除用户目录中的模板文件。
+ 如果删除的是一个覆盖了内置模板的用户模板,这将有效地“恢复”到内置版本。
+ """
+ path = self._get_user_template_path(name)
if not os.path.exists(path):
- raise FileNotFoundError("模板不存在。")
+ raise FileNotFoundError("用户模板不存在,无法删除。")
os.remove(path)
- def backup_default_template_if_not_exist(self):
- """如果备份不存在,则创建默认模板的备份。"""
- default_path = os.path.join(self.template_dir, "base.html")
- if not os.path.exists(self.backup_template_path) and os.path.exists(
- default_path
- ):
- shutil.copyfile(default_path, self.backup_template_path)
-
def reset_default_template(self):
- """重置默认模板。"""
- if not os.path.exists(self.backup_template_path):
- raise FileNotFoundError("默认模板的备份文件不存在,无法重置。")
-
- default_path = os.path.join(self.template_dir, "base.html")
- shutil.copyfile(self.backup_template_path, default_path)
+ """
+ 将核心模板从内置目录强制重置到用户目录。
+ """
+ self._copy_core_templates(overwrite=True)
diff --git a/astrbot/dashboard/routes/t2i.py b/astrbot/dashboard/routes/t2i.py
index 31cdc0bb4..04f87bc99 100644
--- a/astrbot/dashboard/routes/t2i.py
+++ b/astrbot/dashboard/routes/t2i.py
@@ -32,10 +32,6 @@ class T2iRoute(Route):
],
),
]
-
- # 应用启动时,确保备份存在
- self.manager.backup_default_template_if_not_exist()
-
self.register_routes()
async def list_templates(self):
@@ -89,6 +85,7 @@ class T2iRoute(Route):
)
response.status_code = 400
return response
+ name = name.strip()
self.manager.create_template(name, content)
response = jsonify(
@@ -118,6 +115,7 @@ class T2iRoute(Route):
async def update_template(self, name: str):
"""更新一个已存在的T2I模板"""
try:
+ name = name.strip()
data = await request.json
content = data.get("content")
if content is None:
@@ -126,17 +124,16 @@ class T2iRoute(Route):
return response
self.manager.update_template(name, content)
- return jsonify(
- asdict(
- Response().ok(
- data={"name": name}, message="Template updated successfully."
- )
- )
- )
- except FileNotFoundError:
- response = jsonify(asdict(Response().error("Template not found.")))
- response.status_code = 404
- return response
+
+ # 检查更新的是否为当前激活的模板,如果是,则热重载
+ active_template = self.config.get("t2i_active_template", "base")
+ if name == active_template:
+ await self.core_lifecycle.reload_pipeline_scheduler("default")
+ message = f"模板 '{name}' 已更新并重新加载。"
+ else:
+ message = f"模板 '{name}' 已更新。"
+
+ return jsonify(asdict(Response().ok(data={"name": name}, message=message)))
except ValueError as e:
response = jsonify(asdict(Response().error(str(e))))
response.status_code = 400
@@ -149,6 +146,7 @@ class T2iRoute(Route):
async def delete_template(self, name: str):
"""删除一个T2I模板"""
try:
+ name = name.strip()
self.manager.delete_template(name)
return jsonify(
asdict(Response().ok(message="Template deleted successfully."))