feat: supports data backup (#4105)

* feat: 添加数据迁移功能

* test: 添加迁移相关测试

* feat: 备份插件及相关持久化目录

* fix: 修复版本号比较逻辑,添加相关测试

* fix: 清洗文件名,添加相关测试

* fix: 修复安全文件名测试用例断言

* refactor: 优化代码,为备份模块提取公用常量

* feat: 修改备份版本校验逻辑,允许强制小版本间导入

* fix: 修复备份创建时间读取,修复备份相关i18n

* refactor(backup): 使用 astrbot_path 统一管理备份目录路径

* fix(backup): 清理备份模块中未使用的导入

* refactor(backup): 统一备份路径与参数并移除未用附件目录

- 通过 astrbot_path 动态获取备份/知识库/数据相关路径
- 移除 exporter/importer 未使用的 attachments_dir/data_root 传参
- 更新备份路由与测试用例的构造参数

* fix(dashboard): alias mermaid to dist entry for Vite prebundle

* fix(backup): 放行start-time接口到白名单以处理备份导入后jwt token变化导致无法自动刷新webui的问题

* chore(backup): 统一配置路径以使用动态数据目录

* refactor(backup): 使用 VersionComparator 替代重复的版本比较函数

* style(backup test): format code

---------

Co-authored-by: Soulter <905617992@qq.com>
This commit is contained in:
RC-CHN
2025-12-26 15:47:50 +08:00
committed by GitHub
parent 701399c00c
commit aa38fe776a
14 changed files with 3557 additions and 3 deletions
+26
View File
@@ -0,0 +1,26 @@
"""AstrBot 备份与恢复模块
提供数据导出和导入功能,支持用户在服务器迁移时一键备份和恢复所有数据。
"""
# 从 constants 模块导入共享常量
from .constants import (
BACKUP_MANIFEST_VERSION,
KB_METADATA_MODELS,
MAIN_DB_MODELS,
get_backup_directories,
)
# 导入导出器和导入器
from .exporter import AstrBotExporter
from .importer import AstrBotImporter, ImportPreCheckResult
__all__ = [
"AstrBotExporter",
"AstrBotImporter",
"ImportPreCheckResult",
"MAIN_DB_MODELS",
"KB_METADATA_MODELS",
"get_backup_directories",
"BACKUP_MANIFEST_VERSION",
]
+77
View File
@@ -0,0 +1,77 @@
"""AstrBot 备份模块共享常量
此文件定义了导出器和导入器共享的常量,确保两端配置一致。
"""
from sqlmodel import SQLModel
from astrbot.core.db.po import (
Attachment,
CommandConfig,
CommandConflict,
ConversationV2,
Persona,
PlatformMessageHistory,
PlatformSession,
PlatformStat,
Preference,
)
from astrbot.core.knowledge_base.models import (
KBDocument,
KBMedia,
KnowledgeBase,
)
from astrbot.core.utils.astrbot_path import (
get_astrbot_config_path,
get_astrbot_plugin_data_path,
get_astrbot_plugin_path,
get_astrbot_t2i_templates_path,
get_astrbot_temp_path,
get_astrbot_webchat_path,
)
# ============================================================
# 共享常量 - 确保导出和导入端配置一致
# ============================================================
# 主数据库模型类映射
MAIN_DB_MODELS: dict[str, type[SQLModel]] = {
"platform_stats": PlatformStat,
"conversations": ConversationV2,
"personas": Persona,
"preferences": Preference,
"platform_message_history": PlatformMessageHistory,
"platform_sessions": PlatformSession,
"attachments": Attachment,
"command_configs": CommandConfig,
"command_conflicts": CommandConflict,
}
# 知识库元数据模型类映射
KB_METADATA_MODELS: dict[str, type[SQLModel]] = {
"knowledge_bases": KnowledgeBase,
"kb_documents": KBDocument,
"kb_media": KBMedia,
}
def get_backup_directories() -> dict[str, str]:
"""获取需要备份的目录列表
使用 astrbot_path 模块动态获取路径,支持通过环境变量 ASTRBOT_ROOT 自定义根目录。
Returns:
dict: 键为备份文件中的目录名称,值为目录的绝对路径
"""
return {
"plugins": get_astrbot_plugin_path(), # 插件本体
"plugin_data": get_astrbot_plugin_data_path(), # 插件数据
"config": get_astrbot_config_path(), # 配置目录
"t2i_templates": get_astrbot_t2i_templates_path(), # T2I 模板
"webchat": get_astrbot_webchat_path(), # WebChat 数据
"temp": get_astrbot_temp_path(), # 临时文件
}
# 备份清单版本号
BACKUP_MANIFEST_VERSION = "1.1"
+476
View File
@@ -0,0 +1,476 @@
"""AstrBot 数据导出器
负责将所有数据导出为 ZIP 备份文件。
导出格式为 JSON,这是数据库无关的方案,支持未来向 MySQL/PostgreSQL 迁移。
"""
import hashlib
import json
import os
import zipfile
from datetime import datetime, timezone
from pathlib import Path
from typing import TYPE_CHECKING, Any
from sqlalchemy import select
from astrbot.core import logger
from astrbot.core.config.default import VERSION
from astrbot.core.db import BaseDatabase
from astrbot.core.utils.astrbot_path import (
get_astrbot_backups_path,
get_astrbot_data_path,
)
# 从共享常量模块导入
from .constants import (
BACKUP_MANIFEST_VERSION,
KB_METADATA_MODELS,
MAIN_DB_MODELS,
get_backup_directories,
)
if TYPE_CHECKING:
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
CMD_CONFIG_FILE_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
class AstrBotExporter:
"""AstrBot 数据导出器
导出内容:
- 主数据库所有表(data/data_v4.db
- 知识库元数据(data/knowledge_base/kb.db
- 每个知识库的向量文档数据
- 配置文件(data/cmd_config.json
- 附件文件
- 知识库多媒体文件
- 插件目录(data/plugins
- 插件数据目录(data/plugin_data
- 配置目录(data/config
- T2I 模板目录(data/t2i_templates
- WebChat 数据目录(data/webchat
- 临时文件目录(data/temp
"""
def __init__(
self,
main_db: BaseDatabase,
kb_manager: "KnowledgeBaseManager | None" = None,
config_path: str = CMD_CONFIG_FILE_PATH,
):
self.main_db = main_db
self.kb_manager = kb_manager
self.config_path = config_path
self._checksums: dict[str, str] = {}
async def export_all(
self,
output_dir: str | None = None,
progress_callback: Any | None = None,
) -> str:
"""导出所有数据到 ZIP 文件
Args:
output_dir: 输出目录
progress_callback: 进度回调函数,接收参数 (stage, current, total, message)
Returns:
str: 生成的 ZIP 文件路径
"""
if output_dir is None:
output_dir = get_astrbot_backups_path()
# 确保输出目录存在
Path(output_dir).mkdir(parents=True, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
zip_filename = f"astrbot_backup_{timestamp}.zip"
zip_path = os.path.join(output_dir, zip_filename)
logger.info(f"开始导出备份到 {zip_path}")
try:
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
# 1. 导出主数据库
if progress_callback:
await progress_callback("main_db", 0, 100, "正在导出主数据库...")
main_data = await self._export_main_database()
main_db_json = json.dumps(
main_data, ensure_ascii=False, indent=2, default=str
)
zf.writestr("databases/main_db.json", main_db_json)
self._add_checksum("databases/main_db.json", main_db_json)
if progress_callback:
await progress_callback("main_db", 100, 100, "主数据库导出完成")
# 2. 导出知识库数据
kb_meta_data: dict[str, Any] = {
"knowledge_bases": [],
"kb_documents": [],
"kb_media": [],
}
if self.kb_manager:
if progress_callback:
await progress_callback(
"kb_metadata", 0, 100, "正在导出知识库元数据..."
)
kb_meta_data = await self._export_kb_metadata()
kb_meta_json = json.dumps(
kb_meta_data, ensure_ascii=False, indent=2, default=str
)
zf.writestr("databases/kb_metadata.json", kb_meta_json)
self._add_checksum("databases/kb_metadata.json", kb_meta_json)
if progress_callback:
await progress_callback(
"kb_metadata", 100, 100, "知识库元数据导出完成"
)
# 导出每个知识库的文档数据
kb_insts = self.kb_manager.kb_insts
total_kbs = len(kb_insts)
for idx, (kb_id, kb_helper) in enumerate(kb_insts.items()):
if progress_callback:
await progress_callback(
"kb_documents",
idx,
total_kbs,
f"正在导出知识库 {kb_helper.kb.kb_name} 的文档数据...",
)
doc_data = await self._export_kb_documents(kb_helper)
doc_json = json.dumps(
doc_data, ensure_ascii=False, indent=2, default=str
)
doc_path = f"databases/kb_{kb_id}/documents.json"
zf.writestr(doc_path, doc_json)
self._add_checksum(doc_path, doc_json)
# 导出 FAISS 索引文件
await self._export_faiss_index(zf, kb_helper, kb_id)
# 导出知识库多媒体文件
await self._export_kb_media_files(zf, kb_helper, kb_id)
if progress_callback:
await progress_callback(
"kb_documents", total_kbs, total_kbs, "知识库文档导出完成"
)
# 3. 导出配置文件
if progress_callback:
await progress_callback("config", 0, 100, "正在导出配置文件...")
if os.path.exists(self.config_path):
with open(self.config_path, encoding="utf-8") as f:
config_content = f.read()
zf.writestr("config/cmd_config.json", config_content)
self._add_checksum("config/cmd_config.json", config_content)
if progress_callback:
await progress_callback("config", 100, 100, "配置文件导出完成")
# 4. 导出附件文件
if progress_callback:
await progress_callback("attachments", 0, 100, "正在导出附件...")
await self._export_attachments(zf, main_data.get("attachments", []))
if progress_callback:
await progress_callback("attachments", 100, 100, "附件导出完成")
# 5. 导出插件和其他目录
if progress_callback:
await progress_callback(
"directories", 0, 100, "正在导出插件和数据目录..."
)
dir_stats = await self._export_directories(zf)
if progress_callback:
await progress_callback("directories", 100, 100, "目录导出完成")
# 6. 生成 manifest
if progress_callback:
await progress_callback("manifest", 0, 100, "正在生成清单...")
manifest = self._generate_manifest(main_data, kb_meta_data, dir_stats)
manifest_json = json.dumps(manifest, ensure_ascii=False, indent=2)
zf.writestr("manifest.json", manifest_json)
if progress_callback:
await progress_callback("manifest", 100, 100, "清单生成完成")
logger.info(f"备份导出完成: {zip_path}")
return zip_path
except Exception as e:
logger.error(f"备份导出失败: {e}")
# 清理失败的文件
if os.path.exists(zip_path):
os.remove(zip_path)
raise
async def _export_main_database(self) -> dict[str, list[dict]]:
"""导出主数据库所有表"""
export_data: dict[str, list[dict]] = {}
async with self.main_db.get_db() as session:
for table_name, model_class in MAIN_DB_MODELS.items():
try:
result = await session.execute(select(model_class))
records = result.scalars().all()
export_data[table_name] = [
self._model_to_dict(record) for record in records
]
logger.debug(
f"导出表 {table_name}: {len(export_data[table_name])} 条记录"
)
except Exception as e:
logger.warning(f"导出表 {table_name} 失败: {e}")
export_data[table_name] = []
return export_data
async def _export_kb_metadata(self) -> dict[str, list[dict]]:
"""导出知识库元数据库"""
if not self.kb_manager:
return {"knowledge_bases": [], "kb_documents": [], "kb_media": []}
export_data: dict[str, list[dict]] = {}
async with self.kb_manager.kb_db.get_db() as session:
for table_name, model_class in KB_METADATA_MODELS.items():
try:
result = await session.execute(select(model_class))
records = result.scalars().all()
export_data[table_name] = [
self._model_to_dict(record) for record in records
]
logger.debug(
f"导出知识库表 {table_name}: {len(export_data[table_name])} 条记录"
)
except Exception as e:
logger.warning(f"导出知识库表 {table_name} 失败: {e}")
export_data[table_name] = []
return export_data
async def _export_kb_documents(self, kb_helper: Any) -> dict[str, Any]:
"""导出知识库的文档块数据"""
try:
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
vec_db: FaissVecDB = kb_helper.vec_db
if not vec_db or not vec_db.document_storage:
return {"documents": []}
# 获取所有文档
docs = await vec_db.document_storage.get_documents(
metadata_filters={},
offset=0,
limit=None, # 获取全部
)
return {"documents": docs}
except Exception as e:
logger.warning(f"导出知识库文档失败: {e}")
return {"documents": []}
async def _export_faiss_index(
self,
zf: zipfile.ZipFile,
kb_helper: Any,
kb_id: str,
) -> None:
"""导出 FAISS 索引文件"""
try:
index_path = kb_helper.kb_dir / "index.faiss"
if index_path.exists():
archive_path = f"databases/kb_{kb_id}/index.faiss"
zf.write(str(index_path), archive_path)
logger.debug(f"导出 FAISS 索引: {archive_path}")
except Exception as e:
logger.warning(f"导出 FAISS 索引失败: {e}")
async def _export_kb_media_files(
self, zf: zipfile.ZipFile, kb_helper: Any, kb_id: str
) -> None:
"""导出知识库的多媒体文件"""
try:
media_dir = kb_helper.kb_medias_dir
if not media_dir.exists():
return
for root, _, files in os.walk(media_dir):
for file in files:
file_path = Path(root) / file
# 计算相对路径
rel_path = file_path.relative_to(kb_helper.kb_dir)
archive_path = f"files/kb_media/{kb_id}/{rel_path}"
zf.write(str(file_path), archive_path)
except Exception as e:
logger.warning(f"导出知识库媒体文件失败: {e}")
async def _export_directories(
self, zf: zipfile.ZipFile
) -> dict[str, dict[str, int]]:
"""导出插件和其他数据目录
Returns:
dict: 每个目录的统计信息 {dir_name: {"files": count, "size": bytes}}
"""
stats: dict[str, dict[str, int]] = {}
backup_directories = get_backup_directories()
for dir_name, dir_path in backup_directories.items():
full_path = Path(dir_path)
if not full_path.exists():
logger.debug(f"目录不存在,跳过: {full_path}")
continue
file_count = 0
total_size = 0
try:
for root, dirs, files in os.walk(full_path):
# 跳过 __pycache__ 目录
dirs[:] = [d for d in dirs if d != "__pycache__"]
for file in files:
# 跳过 .pyc 文件
if file.endswith(".pyc"):
continue
file_path = Path(root) / file
try:
# 计算相对路径
rel_path = file_path.relative_to(full_path)
archive_path = f"directories/{dir_name}/{rel_path}"
zf.write(str(file_path), archive_path)
file_count += 1
total_size += file_path.stat().st_size
except Exception as e:
logger.warning(f"导出文件 {file_path} 失败: {e}")
stats[dir_name] = {"files": file_count, "size": total_size}
logger.debug(
f"导出目录 {dir_name}: {file_count} 个文件, {total_size} 字节"
)
except Exception as e:
logger.warning(f"导出目录 {dir_path} 失败: {e}")
stats[dir_name] = {"files": 0, "size": 0}
return stats
async def _export_attachments(
self, zf: zipfile.ZipFile, attachments: list[dict]
) -> None:
"""导出附件文件"""
for attachment in attachments:
try:
file_path = attachment.get("path", "")
if file_path and os.path.exists(file_path):
# 使用 attachment_id 作为文件名
attachment_id = attachment.get("attachment_id", "")
ext = os.path.splitext(file_path)[1]
archive_path = f"files/attachments/{attachment_id}{ext}"
zf.write(file_path, archive_path)
except Exception as e:
logger.warning(f"导出附件失败: {e}")
def _model_to_dict(self, record: Any) -> dict:
"""将 SQLModel 实例转换为字典
这是数据库无关的序列化方式,支持未来迁移到其他数据库。
"""
# 使用 SQLModel 内置的 model_dump 方法(如果可用)
if hasattr(record, "model_dump"):
data = record.model_dump(mode="python")
# 处理 datetime 类型
for key, value in data.items():
if isinstance(value, datetime):
data[key] = value.isoformat()
return data
# 回退到手动提取
data = {}
# 使用 inspect 获取表信息
from sqlalchemy import inspect as sa_inspect
mapper = sa_inspect(record.__class__)
for column in mapper.columns:
value = getattr(record, column.name)
# 处理 datetime 类型 - 统一转为 ISO 格式字符串
if isinstance(value, datetime):
value = value.isoformat()
data[column.name] = value
return data
def _add_checksum(self, path: str, content: str | bytes) -> None:
"""计算并添加文件校验和"""
if isinstance(content, str):
content = content.encode("utf-8")
checksum = hashlib.sha256(content).hexdigest()
self._checksums[path] = f"sha256:{checksum}"
def _generate_manifest(
self,
main_data: dict[str, list[dict]],
kb_meta_data: dict[str, list[dict]],
dir_stats: dict[str, dict[str, int]] | None = None,
) -> dict:
"""生成备份清单"""
if dir_stats is None:
dir_stats = {}
# 收集知识库 ID
kb_document_tables = {}
if self.kb_manager:
for kb_id in self.kb_manager.kb_insts.keys():
kb_document_tables[kb_id] = "documents"
# 收集附件文件列表
attachment_files = []
for attachment in main_data.get("attachments", []):
attachment_id = attachment.get("attachment_id", "")
path = attachment.get("path", "")
if attachment_id and path:
ext = os.path.splitext(path)[1]
attachment_files.append(f"{attachment_id}{ext}")
# 收集知识库媒体文件
kb_media_files: dict[str, list[str]] = {}
if self.kb_manager:
for kb_id, kb_helper in self.kb_manager.kb_insts.items():
media_files: list[str] = []
media_dir = kb_helper.kb_medias_dir
if media_dir.exists():
for root, _, files in os.walk(media_dir):
for file in files:
media_files.append(file)
if media_files:
kb_media_files[kb_id] = media_files
manifest = {
"version": BACKUP_MANIFEST_VERSION,
"astrbot_version": VERSION,
"exported_at": datetime.now(timezone.utc).isoformat(),
"schema_version": {
"main_db": "v4",
"kb_db": "v1",
},
"tables": {
"main_db": list(main_data.keys()),
"kb_metadata": list(kb_meta_data.keys()),
"kb_documents": kb_document_tables,
},
"files": {
"attachments": attachment_files,
"kb_media": kb_media_files,
},
"directories": list(dir_stats.keys()),
"checksums": self._checksums,
"statistics": {
"main_db": {
table: len(records) for table, records in main_data.items()
},
"kb_metadata": {
table: len(records) for table, records in kb_meta_data.items()
},
"directories": dir_stats,
},
}
return manifest
+761
View File
@@ -0,0 +1,761 @@
"""AstrBot 数据导入器
负责从 ZIP 备份文件恢复所有数据。
导入时进行版本校验:
- 主版本(前两位)不同时直接拒绝导入
- 小版本(第三位)不同时提示警告,用户可选择强制导入
- 版本匹配时也需要用户确认
"""
import json
import os
import shutil
import zipfile
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Any
from sqlalchemy import delete
from astrbot.core import logger
from astrbot.core.config.default import VERSION
from astrbot.core.db import BaseDatabase
from astrbot.core.utils.astrbot_path import (
get_astrbot_data_path,
get_astrbot_knowledge_base_path,
)
from astrbot.core.utils.version_comparator import VersionComparator
# 从共享常量模块导入
from .constants import (
KB_METADATA_MODELS,
MAIN_DB_MODELS,
get_backup_directories,
)
if TYPE_CHECKING:
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
def _get_major_version(version_str: str) -> str:
"""提取版本的主版本部分(前两位)
Args:
version_str: 版本字符串,如 "4.9.1", "4.10.0-beta"
Returns:
主版本字符串,如 "4.9", "4.10"
"""
if not version_str:
return "0.0"
# 移除 v 前缀和预发布标签
version = version_str.lower().replace("v", "").split("-")[0].split("+")[0]
parts = [p for p in version.split(".") if p] # 过滤空字符串
if len(parts) >= 2:
return f"{parts[0]}.{parts[1]}"
elif len(parts) == 1 and parts[0]:
return f"{parts[0]}.0"
return "0.0"
CMD_CONFIG_FILE_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
KB_PATH = get_astrbot_knowledge_base_path()
@dataclass
class ImportPreCheckResult:
"""导入预检查结果
用于在实际导入前检查备份文件的版本兼容性,
并返回确认信息让用户决定是否继续导入。
"""
# 检查是否通过(文件有效且版本可导入)
valid: bool = False
# 是否可以导入(版本兼容)
can_import: bool = False
# 版本状态: match(完全匹配), minor_diff(小版本差异), major_diff(主版本不同,拒绝)
version_status: str = ""
# 备份文件中的 AstrBot 版本
backup_version: str = ""
# 当前运行的 AstrBot 版本
current_version: str = VERSION
# 备份创建时间
backup_time: str = ""
# 确认消息(显示给用户)
confirm_message: str = ""
# 警告消息列表
warnings: list[str] = field(default_factory=list)
# 错误消息(如果检查失败)
error: str = ""
# 备份包含的内容摘要
backup_summary: dict = field(default_factory=dict)
def to_dict(self) -> dict:
return {
"valid": self.valid,
"can_import": self.can_import,
"version_status": self.version_status,
"backup_version": self.backup_version,
"current_version": self.current_version,
"backup_time": self.backup_time,
"confirm_message": self.confirm_message,
"warnings": self.warnings,
"error": self.error,
"backup_summary": self.backup_summary,
}
class ImportResult:
"""导入结果"""
def __init__(self):
self.success = True
self.imported_tables: dict[str, int] = {}
self.imported_files: dict[str, int] = {}
self.imported_directories: dict[str, int] = {}
self.warnings: list[str] = []
self.errors: list[str] = []
def add_warning(self, msg: str) -> None:
self.warnings.append(msg)
logger.warning(msg)
def add_error(self, msg: str) -> None:
self.errors.append(msg)
self.success = False
logger.error(msg)
def to_dict(self) -> dict:
return {
"success": self.success,
"imported_tables": self.imported_tables,
"imported_files": self.imported_files,
"imported_directories": self.imported_directories,
"warnings": self.warnings,
"errors": self.errors,
}
class AstrBotImporter:
"""AstrBot 数据导入器
导入备份文件中的所有数据,包括:
- 主数据库所有表
- 知识库元数据和文档
- 配置文件
- 附件文件
- 知识库多媒体文件
- 插件目录(data/plugins
- 插件数据目录(data/plugin_data
- 配置目录(data/config
- T2I 模板目录(data/t2i_templates
- WebChat 数据目录(data/webchat
- 临时文件目录(data/temp
"""
def __init__(
self,
main_db: BaseDatabase,
kb_manager: "KnowledgeBaseManager | None" = None,
config_path: str = CMD_CONFIG_FILE_PATH,
kb_root_dir: str = KB_PATH,
):
self.main_db = main_db
self.kb_manager = kb_manager
self.config_path = config_path
self.kb_root_dir = kb_root_dir
def pre_check(self, zip_path: str) -> ImportPreCheckResult:
"""预检查备份文件
在实际导入前检查备份文件的有效性和版本兼容性。
返回检查结果供前端显示确认对话框。
Args:
zip_path: ZIP 备份文件路径
Returns:
ImportPreCheckResult: 预检查结果
"""
result = ImportPreCheckResult()
result.current_version = VERSION
if not os.path.exists(zip_path):
result.error = f"备份文件不存在: {zip_path}"
return result
try:
with zipfile.ZipFile(zip_path, "r") as zf:
# 读取 manifest
try:
manifest_data = zf.read("manifest.json")
manifest = json.loads(manifest_data)
except KeyError:
result.error = "备份文件缺少 manifest.json,不是有效的 AstrBot 备份"
return result
except json.JSONDecodeError as e:
result.error = f"manifest.json 格式错误: {e}"
return result
# 提取基本信息
result.backup_version = manifest.get("astrbot_version", "未知")
result.backup_time = manifest.get("exported_at", "未知")
result.valid = True
# 构建备份摘要
result.backup_summary = {
"tables": list(manifest.get("tables", {}).keys()),
"has_knowledge_bases": manifest.get("has_knowledge_bases", False),
"has_config": manifest.get("has_config", False),
"directories": manifest.get("directories", []),
}
# 检查版本兼容性
version_check = self._check_version_compatibility(result.backup_version)
result.version_status = version_check["status"]
result.can_import = version_check["can_import"]
# 版本信息由前端根据 version_status 和 i18n 生成显示
# 不再将版本消息添加到 warnings 列表中,避免中文硬编码
# warnings 列表保留用于其他非版本相关的警告
return result
except zipfile.BadZipFile:
result.error = "无效的 ZIP 文件"
return result
except Exception as e:
result.error = f"检查备份文件失败: {e}"
return result
def _check_version_compatibility(self, backup_version: str) -> dict:
"""检查版本兼容性
规则:
- 主版本(前两位,如 4.9)必须一致,否则拒绝
- 小版本(第三位,如 4.9.1 vs 4.9.2)不同时,警告但允许导入
Returns:
dict: {status, can_import, message}
"""
if not backup_version:
return {
"status": "major_diff",
"can_import": False,
"message": "备份文件缺少版本信息",
}
# 提取主版本(前两位)进行比较
backup_major = _get_major_version(backup_version)
current_major = _get_major_version(VERSION)
# 比较主版本
if VersionComparator.compare_version(backup_major, current_major) != 0:
return {
"status": "major_diff",
"can_import": False,
"message": (
f"主版本不兼容: 备份版本 {backup_version}, 当前版本 {VERSION}"
f"跨主版本导入可能导致数据损坏,请使用相同主版本的 AstrBot。"
),
}
# 比较完整版本
version_cmp = VersionComparator.compare_version(backup_version, VERSION)
if version_cmp != 0:
return {
"status": "minor_diff",
"can_import": True,
"message": (
f"小版本差异: 备份版本 {backup_version}, 当前版本 {VERSION}"
),
}
return {
"status": "match",
"can_import": True,
"message": "版本匹配",
}
async def import_all(
self,
zip_path: str,
mode: str = "replace", # "replace" 清空后导入
progress_callback: Any | None = None,
) -> ImportResult:
"""从 ZIP 文件导入所有数据
Args:
zip_path: ZIP 备份文件路径
mode: 导入模式,目前仅支持 "replace"(清空后导入)
progress_callback: 进度回调函数,接收参数 (stage, current, total, message)
Returns:
ImportResult: 导入结果
"""
result = ImportResult()
if not os.path.exists(zip_path):
result.add_error(f"备份文件不存在: {zip_path}")
return result
logger.info(f"开始从 {zip_path} 导入备份")
try:
with zipfile.ZipFile(zip_path, "r") as zf:
# 1. 读取并验证 manifest
if progress_callback:
await progress_callback("validate", 0, 100, "正在验证备份文件...")
try:
manifest_data = zf.read("manifest.json")
manifest = json.loads(manifest_data)
except KeyError:
result.add_error("备份文件缺少 manifest.json")
return result
except json.JSONDecodeError as e:
result.add_error(f"manifest.json 格式错误: {e}")
return result
# 版本校验
try:
self._validate_version(manifest)
except ValueError as e:
result.add_error(str(e))
return result
if progress_callback:
await progress_callback("validate", 100, 100, "验证完成")
# 2. 导入主数据库
if progress_callback:
await progress_callback("main_db", 0, 100, "正在导入主数据库...")
try:
main_data_content = zf.read("databases/main_db.json")
main_data = json.loads(main_data_content)
if mode == "replace":
await self._clear_main_db()
imported = await self._import_main_database(main_data)
result.imported_tables.update(imported)
except Exception as e:
result.add_error(f"导入主数据库失败: {e}")
return result
if progress_callback:
await progress_callback("main_db", 100, 100, "主数据库导入完成")
# 3. 导入知识库
if self.kb_manager and "databases/kb_metadata.json" in zf.namelist():
if progress_callback:
await progress_callback("kb", 0, 100, "正在导入知识库...")
try:
kb_meta_content = zf.read("databases/kb_metadata.json")
kb_meta_data = json.loads(kb_meta_content)
if mode == "replace":
await self._clear_kb_data()
await self._import_knowledge_bases(zf, kb_meta_data, result)
except Exception as e:
result.add_warning(f"导入知识库失败: {e}")
if progress_callback:
await progress_callback("kb", 100, 100, "知识库导入完成")
# 4. 导入配置文件
if progress_callback:
await progress_callback("config", 0, 100, "正在导入配置文件...")
if "config/cmd_config.json" in zf.namelist():
try:
config_content = zf.read("config/cmd_config.json")
# 备份现有配置
if os.path.exists(self.config_path):
backup_path = f"{self.config_path}.bak"
shutil.copy2(self.config_path, backup_path)
with open(self.config_path, "wb") as f:
f.write(config_content)
result.imported_files["config"] = 1
except Exception as e:
result.add_warning(f"导入配置文件失败: {e}")
if progress_callback:
await progress_callback("config", 100, 100, "配置文件导入完成")
# 5. 导入附件文件
if progress_callback:
await progress_callback("attachments", 0, 100, "正在导入附件...")
attachment_count = await self._import_attachments(
zf, main_data.get("attachments", [])
)
result.imported_files["attachments"] = attachment_count
if progress_callback:
await progress_callback("attachments", 100, 100, "附件导入完成")
# 6. 导入插件和其他目录
if progress_callback:
await progress_callback(
"directories", 0, 100, "正在导入插件和数据目录..."
)
dir_stats = await self._import_directories(zf, manifest, result)
result.imported_directories = dir_stats
if progress_callback:
await progress_callback("directories", 100, 100, "目录导入完成")
logger.info(f"备份导入完成: {result.to_dict()}")
return result
except zipfile.BadZipFile:
result.add_error("无效的 ZIP 文件")
return result
except Exception as e:
result.add_error(f"导入失败: {e}")
return result
def _validate_version(self, manifest: dict) -> None:
"""验证版本兼容性 - 仅允许相同主版本导入
注意:此方法仅在 import_all 中调用,用于双重校验。
前端应先调用 pre_check 获取详细的版本信息并让用户确认。
"""
backup_version = manifest.get("astrbot_version")
if not backup_version:
raise ValueError("备份文件缺少版本信息")
# 使用新的版本兼容性检查
version_check = self._check_version_compatibility(backup_version)
if version_check["status"] == "major_diff":
raise ValueError(version_check["message"])
# minor_diff 和 match 都允许导入
if version_check["status"] == "minor_diff":
logger.warning(f"版本差异警告: {version_check['message']}")
async def _clear_main_db(self) -> None:
"""清空主数据库所有表"""
async with self.main_db.get_db() as session:
async with session.begin():
for table_name, model_class in MAIN_DB_MODELS.items():
try:
await session.execute(delete(model_class))
logger.debug(f"已清空表 {table_name}")
except Exception as e:
logger.warning(f"清空表 {table_name} 失败: {e}")
async def _clear_kb_data(self) -> None:
"""清空知识库数据"""
if not self.kb_manager:
return
# 清空知识库元数据表
async with self.kb_manager.kb_db.get_db() as session:
async with session.begin():
for table_name, model_class in KB_METADATA_MODELS.items():
try:
await session.execute(delete(model_class))
logger.debug(f"已清空知识库表 {table_name}")
except Exception as e:
logger.warning(f"清空知识库表 {table_name} 失败: {e}")
# 删除知识库文件目录
for kb_id in list(self.kb_manager.kb_insts.keys()):
try:
kb_helper = self.kb_manager.kb_insts[kb_id]
await kb_helper.terminate()
if kb_helper.kb_dir.exists():
shutil.rmtree(kb_helper.kb_dir)
except Exception as e:
logger.warning(f"清理知识库 {kb_id} 失败: {e}")
self.kb_manager.kb_insts.clear()
async def _import_main_database(
self, data: dict[str, list[dict]]
) -> dict[str, int]:
"""导入主数据库数据"""
imported: dict[str, int] = {}
async with self.main_db.get_db() as session:
async with session.begin():
for table_name, rows in data.items():
model_class = MAIN_DB_MODELS.get(table_name)
if not model_class:
logger.warning(f"未知的表: {table_name}")
continue
count = 0
for row in rows:
try:
# 转换 datetime 字符串为 datetime 对象
row = self._convert_datetime_fields(row, model_class)
obj = model_class(**row)
session.add(obj)
count += 1
except Exception as e:
logger.warning(f"导入记录到 {table_name} 失败: {e}")
imported[table_name] = count
logger.debug(f"导入表 {table_name}: {count} 条记录")
return imported
async def _import_knowledge_bases(
self,
zf: zipfile.ZipFile,
kb_meta_data: dict[str, list[dict]],
result: ImportResult,
) -> None:
"""导入知识库数据"""
if not self.kb_manager:
return
# 1. 导入知识库元数据
async with self.kb_manager.kb_db.get_db() as session:
async with session.begin():
for table_name, rows in kb_meta_data.items():
model_class = KB_METADATA_MODELS.get(table_name)
if not model_class:
continue
count = 0
for row in rows:
try:
row = self._convert_datetime_fields(row, model_class)
obj = model_class(**row)
session.add(obj)
count += 1
except Exception as e:
logger.warning(f"导入知识库记录到 {table_name} 失败: {e}")
result.imported_tables[f"kb_{table_name}"] = count
# 2. 导入每个知识库的文档和文件
for kb_data in kb_meta_data.get("knowledge_bases", []):
kb_id = kb_data.get("kb_id")
if not kb_id:
continue
# 创建知识库目录
kb_dir = Path(self.kb_root_dir) / kb_id
kb_dir.mkdir(parents=True, exist_ok=True)
# 导入文档数据
doc_path = f"databases/kb_{kb_id}/documents.json"
if doc_path in zf.namelist():
try:
doc_content = zf.read(doc_path)
doc_data = json.loads(doc_content)
# 导入到文档存储数据库
await self._import_kb_documents(kb_id, doc_data)
except Exception as e:
result.add_warning(f"导入知识库 {kb_id} 的文档失败: {e}")
# 导入 FAISS 索引
faiss_path = f"databases/kb_{kb_id}/index.faiss"
if faiss_path in zf.namelist():
try:
target_path = kb_dir / "index.faiss"
with zf.open(faiss_path) as src, open(target_path, "wb") as dst:
dst.write(src.read())
except Exception as e:
result.add_warning(f"导入知识库 {kb_id} 的 FAISS 索引失败: {e}")
# 导入媒体文件
media_prefix = f"files/kb_media/{kb_id}/"
for name in zf.namelist():
if name.startswith(media_prefix):
try:
rel_path = name[len(media_prefix) :]
target_path = kb_dir / rel_path
target_path.parent.mkdir(parents=True, exist_ok=True)
with zf.open(name) as src, open(target_path, "wb") as dst:
dst.write(src.read())
except Exception as e:
result.add_warning(f"导入媒体文件 {name} 失败: {e}")
# 3. 重新加载知识库实例
await self.kb_manager.load_kbs()
async def _import_kb_documents(self, kb_id: str, doc_data: dict) -> None:
"""导入知识库文档到向量数据库"""
from astrbot.core.db.vec_db.faiss_impl.document_storage import DocumentStorage
kb_dir = Path(self.kb_root_dir) / kb_id
doc_db_path = kb_dir / "doc.db"
# 初始化文档存储
doc_storage = DocumentStorage(str(doc_db_path))
await doc_storage.initialize()
try:
documents = doc_data.get("documents", [])
for doc in documents:
try:
await doc_storage.insert_document(
doc_id=doc.get("doc_id", ""),
text=doc.get("text", ""),
metadata=json.loads(doc.get("metadata", "{}")),
)
except Exception as e:
logger.warning(f"导入文档块失败: {e}")
finally:
await doc_storage.close()
async def _import_attachments(
self,
zf: zipfile.ZipFile,
attachments: list[dict],
) -> int:
"""导入附件文件"""
count = 0
attachments_dir = Path(self.config_path).parent / "attachments"
attachments_dir.mkdir(parents=True, exist_ok=True)
attachment_prefix = "files/attachments/"
for name in zf.namelist():
if name.startswith(attachment_prefix) and name != attachment_prefix:
try:
# 从附件记录中找到原始路径
attachment_id = os.path.splitext(os.path.basename(name))[0]
original_path = None
for att in attachments:
if att.get("attachment_id") == attachment_id:
original_path = att.get("path")
break
if original_path:
target_path = Path(original_path)
else:
target_path = attachments_dir / os.path.basename(name)
target_path.parent.mkdir(parents=True, exist_ok=True)
with zf.open(name) as src, open(target_path, "wb") as dst:
dst.write(src.read())
count += 1
except Exception as e:
logger.warning(f"导入附件 {name} 失败: {e}")
return count
async def _import_directories(
self,
zf: zipfile.ZipFile,
manifest: dict,
result: ImportResult,
) -> dict[str, int]:
"""导入插件和其他数据目录
Args:
zf: ZIP 文件对象
manifest: 备份清单
result: 导入结果对象
Returns:
dict: 每个目录导入的文件数量
"""
dir_stats: dict[str, int] = {}
# 检查备份版本是否支持目录备份(需要版本 >= 1.1)
backup_version = manifest.get("version", "1.0")
if VersionComparator.compare_version(backup_version, "1.1") < 0:
logger.info("备份版本不支持目录备份,跳过目录导入")
return dir_stats
backed_up_dirs = manifest.get("directories", [])
backup_directories = get_backup_directories()
for dir_name in backed_up_dirs:
if dir_name not in backup_directories:
result.add_warning(f"未知的目录类型: {dir_name}")
continue
target_dir = Path(backup_directories[dir_name])
archive_prefix = f"directories/{dir_name}/"
file_count = 0
try:
# 获取该目录下的所有文件
dir_files = [
name
for name in zf.namelist()
if name.startswith(archive_prefix) and name != archive_prefix
]
if not dir_files:
continue
# 备份现有目录(如果存在)
if target_dir.exists():
backup_path = Path(f"{target_dir}.bak")
if backup_path.exists():
shutil.rmtree(backup_path)
shutil.move(str(target_dir), str(backup_path))
logger.debug(f"已备份现有目录 {target_dir}{backup_path}")
# 创建目标目录
target_dir.mkdir(parents=True, exist_ok=True)
# 解压文件
for name in dir_files:
try:
# 计算相对路径
rel_path = name[len(archive_prefix) :]
if not rel_path: # 跳过目录条目
continue
target_path = target_dir / rel_path
target_path.parent.mkdir(parents=True, exist_ok=True)
with zf.open(name) as src, open(target_path, "wb") as dst:
dst.write(src.read())
file_count += 1
except Exception as e:
result.add_warning(f"导入文件 {name} 失败: {e}")
dir_stats[dir_name] = file_count
logger.debug(f"导入目录 {dir_name}: {file_count} 个文件")
except Exception as e:
result.add_warning(f"导入目录 {dir_name} 失败: {e}")
dir_stats[dir_name] = 0
return dir_stats
def _convert_datetime_fields(self, row: dict, model_class: type) -> dict:
"""转换 datetime 字符串字段为 datetime 对象"""
result = row.copy()
# 获取模型的 datetime 字段
from sqlalchemy import inspect as sa_inspect
try:
mapper = sa_inspect(model_class)
for column in mapper.columns:
if column.name in result and result[column.name] is not None:
# 检查是否是 datetime 类型的列
from sqlalchemy import DateTime
if isinstance(column.type, DateTime):
value = result[column.name]
if isinstance(value, str):
# 解析 ISO 格式的日期时间字符串
result[column.name] = datetime.fromisoformat(value)
except Exception:
pass
return result
+34
View File
@@ -5,6 +5,10 @@
数据目录路径:固定为根目录下的 data 目录
配置文件路径:固定为数据目录下的 config 目录
插件目录路径:固定为数据目录下的 plugins 目录
插件数据目录路径:固定为数据目录下的 plugin_data 目录
T2I 模板目录路径:固定为数据目录下的 t2i_templates 目录
WebChat 数据目录路径:固定为数据目录下的 webchat 目录
临时文件目录路径:固定为数据目录下的 temp 目录
"""
import os
@@ -37,3 +41,33 @@ def get_astrbot_config_path() -> str:
def get_astrbot_plugin_path() -> str:
"""获取Astrbot插件目录路径"""
return os.path.realpath(os.path.join(get_astrbot_data_path(), "plugins"))
def get_astrbot_plugin_data_path() -> str:
"""获取Astrbot插件数据目录路径"""
return os.path.realpath(os.path.join(get_astrbot_data_path(), "plugin_data"))
def get_astrbot_t2i_templates_path() -> str:
"""获取Astrbot T2I 模板目录路径"""
return os.path.realpath(os.path.join(get_astrbot_data_path(), "t2i_templates"))
def get_astrbot_webchat_path() -> str:
"""获取Astrbot WebChat 数据目录路径"""
return os.path.realpath(os.path.join(get_astrbot_data_path(), "webchat"))
def get_astrbot_temp_path() -> str:
"""获取Astrbot临时文件目录路径"""
return os.path.realpath(os.path.join(get_astrbot_data_path(), "temp"))
def get_astrbot_knowledge_base_path() -> str:
"""获取Astrbot知识库根目录路径"""
return os.path.realpath(os.path.join(get_astrbot_data_path(), "knowledge_base"))
def get_astrbot_backups_path() -> str:
"""获取Astrbot备份目录路径"""
return os.path.realpath(os.path.join(get_astrbot_data_path(), "backups"))
+2
View File
@@ -1,4 +1,5 @@
from .auth import AuthRoute
from .backup import BackupRoute
from .chat import ChatRoute
from .command import CommandRoute
from .config import ConfigRoute
@@ -17,6 +18,7 @@ from .update import UpdateRoute
__all__ = [
"AuthRoute",
"BackupRoute",
"ChatRoute",
"CommandRoute",
"ConfigRoute",
+589
View File
@@ -0,0 +1,589 @@
"""备份管理 API 路由"""
import asyncio
import os
import re
import traceback
import uuid
from datetime import datetime
from pathlib import Path
from quart import request, send_file
from astrbot.core import logger
from astrbot.core.backup.exporter import AstrBotExporter
from astrbot.core.backup.importer import AstrBotImporter
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from astrbot.core.utils.astrbot_path import (
get_astrbot_backups_path,
get_astrbot_data_path,
)
from .route import Response, Route, RouteContext
def secure_filename(filename: str) -> str:
"""清洗文件名,移除路径遍历字符和危险字符
Args:
filename: 原始文件名
Returns:
安全的文件名
"""
# 跨平台处理:先将反斜杠替换为正斜杠,再取文件名
filename = filename.replace("\\", "/")
# 仅保留文件名部分,移除路径
filename = os.path.basename(filename)
# 替换路径遍历字符
filename = filename.replace("..", "_")
# 仅保留字母、数字、下划线、连字符、点
filename = re.sub(r"[^\w\-.]", "_", filename)
# 移除前导点(隐藏文件)和尾部点
filename = filename.strip(".")
# 如果文件名为空或只包含下划线,生成一个默认名称
if not filename or filename.replace("_", "") == "":
filename = "backup"
return filename
def generate_unique_filename(original_filename: str) -> str:
"""生成唯一的文件名,添加时间戳前缀
Args:
original_filename: 原始文件名(已清洗)
Returns:
唯一的文件名
"""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
name, ext = os.path.splitext(original_filename)
return f"uploaded_{timestamp}_{name}{ext}"
class BackupRoute(Route):
"""备份管理路由
提供备份导出、导入、列表等 API 接口
"""
def __init__(
self,
context: RouteContext,
db: BaseDatabase,
core_lifecycle: AstrBotCoreLifecycle,
) -> None:
super().__init__(context)
self.db = db
self.core_lifecycle = core_lifecycle
self.backup_dir = get_astrbot_backups_path()
self.data_dir = get_astrbot_data_path()
# 任务状态跟踪
self.backup_tasks: dict[str, dict] = {}
self.backup_progress: dict[str, dict] = {}
# 注册路由
self.routes = {
"/backup/list": ("GET", self.list_backups),
"/backup/export": ("POST", self.export_backup),
"/backup/upload": ("POST", self.upload_backup), # 上传文件
"/backup/check": ("POST", self.check_backup), # 预检查
"/backup/import": ("POST", self.import_backup), # 确认导入
"/backup/progress": ("GET", self.get_progress),
"/backup/download": ("GET", self.download_backup),
"/backup/delete": ("POST", self.delete_backup),
}
self.register_routes()
def _init_task(self, task_id: str, task_type: str, status: str = "pending") -> None:
"""初始化任务状态"""
self.backup_tasks[task_id] = {
"type": task_type,
"status": status,
"result": None,
"error": None,
}
self.backup_progress[task_id] = {
"status": status,
"stage": "waiting",
"current": 0,
"total": 100,
"message": "",
}
def _set_task_result(
self,
task_id: str,
status: str,
result: dict | None = None,
error: str | None = None,
) -> None:
"""设置任务结果"""
if task_id in self.backup_tasks:
self.backup_tasks[task_id]["status"] = status
self.backup_tasks[task_id]["result"] = result
self.backup_tasks[task_id]["error"] = error
if task_id in self.backup_progress:
self.backup_progress[task_id]["status"] = status
def _update_progress(
self,
task_id: str,
*,
status: str | None = None,
stage: str | None = None,
current: int | None = None,
total: int | None = None,
message: str | None = None,
) -> None:
"""更新任务进度"""
if task_id not in self.backup_progress:
return
p = self.backup_progress[task_id]
if status is not None:
p["status"] = status
if stage is not None:
p["stage"] = stage
if current is not None:
p["current"] = current
if total is not None:
p["total"] = total
if message is not None:
p["message"] = message
def _make_progress_callback(self, task_id: str):
"""创建进度回调函数"""
async def _callback(stage: str, current: int, total: int, message: str = ""):
self._update_progress(
task_id,
status="processing",
stage=stage,
current=current,
total=total,
message=message,
)
return _callback
async def list_backups(self):
"""获取备份列表
Query 参数:
- page: 页码 (默认 1)
- page_size: 每页数量 (默认 20)
"""
try:
page = request.args.get("page", 1, type=int)
page_size = request.args.get("page_size", 20, type=int)
# 确保备份目录存在
Path(self.backup_dir).mkdir(parents=True, exist_ok=True)
# 获取所有备份文件
backup_files = []
for filename in os.listdir(self.backup_dir):
if filename.endswith(".zip") and filename.startswith("astrbot_backup_"):
file_path = os.path.join(self.backup_dir, filename)
stat = os.stat(file_path)
backup_files.append(
{
"filename": filename,
"size": stat.st_size,
"created_at": stat.st_mtime,
}
)
# 按创建时间倒序排序
backup_files.sort(key=lambda x: x["created_at"], reverse=True)
# 分页
start = (page - 1) * page_size
end = start + page_size
items = backup_files[start:end]
return (
Response()
.ok(
{
"items": items,
"total": len(backup_files),
"page": page,
"page_size": page_size,
}
)
.__dict__
)
except Exception as e:
logger.error(f"获取备份列表失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"获取备份列表失败: {e!s}").__dict__
async def export_backup(self):
"""创建备份
返回:
- task_id: 任务ID,用于查询导出进度
"""
try:
# 生成任务ID
task_id = str(uuid.uuid4())
# 初始化任务状态
self._init_task(task_id, "export", "pending")
# 启动后台导出任务
asyncio.create_task(self._background_export_task(task_id))
return (
Response()
.ok(
{
"task_id": task_id,
"message": "export task created, processing in background",
}
)
.__dict__
)
except Exception as e:
logger.error(f"创建备份失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"创建备份失败: {e!s}").__dict__
async def _background_export_task(self, task_id: str):
"""后台导出任务"""
try:
self._update_progress(task_id, status="processing", message="正在初始化...")
# 获取知识库管理器
kb_manager = getattr(self.core_lifecycle, "kb_manager", None)
exporter = AstrBotExporter(
main_db=self.db,
kb_manager=kb_manager,
config_path=os.path.join(self.data_dir, "cmd_config.json"),
)
# 创建进度回调
progress_callback = self._make_progress_callback(task_id)
# 执行导出
zip_path = await exporter.export_all(
output_dir=self.backup_dir,
progress_callback=progress_callback,
)
# 设置成功结果
self._set_task_result(
task_id,
"completed",
result={
"filename": os.path.basename(zip_path),
"path": zip_path,
"size": os.path.getsize(zip_path),
},
)
except Exception as e:
logger.error(f"后台导出任务 {task_id} 失败: {e}")
logger.error(traceback.format_exc())
self._set_task_result(task_id, "failed", error=str(e))
async def upload_backup(self):
"""上传备份文件
将备份文件上传到服务器,返回保存的文件名。
上传后应调用 check_backup 进行预检查。
Form Data:
- file: 备份文件 (.zip)
返回:
- filename: 保存的文件名
"""
try:
files = await request.files
if "file" not in files:
return Response().error("缺少备份文件").__dict__
file = files["file"]
if not file.filename or not file.filename.endswith(".zip"):
return Response().error("请上传 ZIP 格式的备份文件").__dict__
# 清洗文件名并生成唯一名称,防止路径遍历和覆盖
safe_filename = secure_filename(file.filename)
unique_filename = generate_unique_filename(safe_filename)
# 保存上传的文件
Path(self.backup_dir).mkdir(parents=True, exist_ok=True)
zip_path = os.path.join(self.backup_dir, unique_filename)
await file.save(zip_path)
logger.info(
f"上传的备份文件已保存: {unique_filename} (原始名称: {file.filename})"
)
return (
Response()
.ok(
{
"filename": unique_filename,
"original_filename": file.filename,
"size": os.path.getsize(zip_path),
}
)
.__dict__
)
except Exception as e:
logger.error(f"上传备份文件失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"上传备份文件失败: {e!s}").__dict__
async def check_backup(self):
"""预检查备份文件
检查备份文件的版本兼容性,返回确认信息。
用户确认后调用 import_backup 执行导入。
JSON Body:
- filename: 已上传的备份文件名
返回:
- ImportPreCheckResult: 预检查结果
"""
try:
data = await request.json
filename = data.get("filename")
if not filename:
return Response().error("缺少 filename 参数").__dict__
# 安全检查 - 防止路径遍历
if ".." in filename or "/" in filename or "\\" in filename:
return Response().error("无效的文件名").__dict__
zip_path = os.path.join(self.backup_dir, filename)
if not os.path.exists(zip_path):
return Response().error(f"备份文件不存在: {filename}").__dict__
# 获取知识库管理器(用于构造 importer)
kb_manager = getattr(self.core_lifecycle, "kb_manager", None)
importer = AstrBotImporter(
main_db=self.db,
kb_manager=kb_manager,
config_path=os.path.join(self.data_dir, "cmd_config.json"),
)
# 执行预检查
check_result = importer.pre_check(zip_path)
return Response().ok(check_result.to_dict()).__dict__
except Exception as e:
logger.error(f"预检查备份文件失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"预检查备份文件失败: {e!s}").__dict__
async def import_backup(self):
"""执行备份导入
在用户确认后执行实际的导入操作。
需要先调用 upload_backup 上传文件,再调用 check_backup 预检查。
JSON Body:
- filename: 已上传的备份文件名(必填)
- confirmed: 用户已确认(必填,必须为 true)
返回:
- task_id: 任务ID,用于查询导入进度
"""
try:
data = await request.json
filename = data.get("filename")
confirmed = data.get("confirmed", False)
if not filename:
return Response().error("缺少 filename 参数").__dict__
if not confirmed:
return (
Response()
.error("请先确认导入。导入将会清空并覆盖现有数据,此操作不可撤销。")
.__dict__
)
# 安全检查 - 防止路径遍历
if ".." in filename or "/" in filename or "\\" in filename:
return Response().error("无效的文件名").__dict__
zip_path = os.path.join(self.backup_dir, filename)
if not os.path.exists(zip_path):
return Response().error(f"备份文件不存在: {filename}").__dict__
# 生成任务ID
task_id = str(uuid.uuid4())
# 初始化任务状态
self._init_task(task_id, "import", "pending")
# 启动后台导入任务
asyncio.create_task(self._background_import_task(task_id, zip_path))
return (
Response()
.ok(
{
"task_id": task_id,
"message": "import task created, processing in background",
}
)
.__dict__
)
except Exception as e:
logger.error(f"导入备份失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"导入备份失败: {e!s}").__dict__
async def _background_import_task(self, task_id: str, zip_path: str):
"""后台导入任务"""
try:
self._update_progress(task_id, status="processing", message="正在初始化...")
# 获取知识库管理器
kb_manager = getattr(self.core_lifecycle, "kb_manager", None)
importer = AstrBotImporter(
main_db=self.db,
kb_manager=kb_manager,
config_path=os.path.join(self.data_dir, "cmd_config.json"),
)
# 创建进度回调
progress_callback = self._make_progress_callback(task_id)
# 执行导入
result = await importer.import_all(
zip_path=zip_path,
mode="replace",
progress_callback=progress_callback,
)
# 设置结果
if result.success:
self._set_task_result(
task_id,
"completed",
result=result.to_dict(),
)
else:
self._set_task_result(
task_id,
"failed",
error="; ".join(result.errors),
)
except Exception as e:
logger.error(f"后台导入任务 {task_id} 失败: {e}")
logger.error(traceback.format_exc())
self._set_task_result(task_id, "failed", error=str(e))
async def get_progress(self):
"""获取任务进度
Query 参数:
- task_id: 任务 ID (必填)
"""
try:
task_id = request.args.get("task_id")
if not task_id:
return Response().error("缺少参数 task_id").__dict__
if task_id not in self.backup_tasks:
return Response().error("找不到该任务").__dict__
task_info = self.backup_tasks[task_id]
status = task_info["status"]
response_data = {
"task_id": task_id,
"type": task_info["type"],
"status": status,
}
# 如果任务正在处理,返回进度信息
if status == "processing" and task_id in self.backup_progress:
response_data["progress"] = self.backup_progress[task_id]
# 如果任务完成,返回结果
if status == "completed":
response_data["result"] = task_info["result"]
# 如果任务失败,返回错误信息
if status == "failed":
response_data["error"] = task_info["error"]
return Response().ok(response_data).__dict__
except Exception as e:
logger.error(f"获取任务进度失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"获取任务进度失败: {e!s}").__dict__
async def download_backup(self):
"""下载备份文件
Query 参数:
- filename: 备份文件名 (必填)
"""
try:
filename = request.args.get("filename")
if not filename:
return Response().error("缺少参数 filename").__dict__
# 安全检查 - 防止路径遍历
if ".." in filename or "/" in filename or "\\" in filename:
return Response().error("无效的文件名").__dict__
file_path = os.path.join(self.backup_dir, filename)
if not os.path.exists(file_path):
return Response().error("备份文件不存在").__dict__
return await send_file(
file_path,
as_attachment=True,
attachment_filename=filename,
)
except Exception as e:
logger.error(f"下载备份失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"下载备份失败: {e!s}").__dict__
async def delete_backup(self):
"""删除备份文件
Body:
- filename: 备份文件名 (必填)
"""
try:
data = await request.json
filename = data.get("filename")
if not filename:
return Response().error("缺少参数 filename").__dict__
# 安全检查 - 防止路径遍历
if ".." in filename or "/" in filename or "\\" in filename:
return Response().error("无效的文件名").__dict__
file_path = os.path.join(self.backup_dir, filename)
if not os.path.exists(file_path):
return Response().error("备份文件不存在").__dict__
os.remove(file_path)
return Response().ok(message="删除备份成功").__dict__
except Exception as e:
logger.error(f"删除备份失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"删除备份失败: {e!s}").__dict__
+8 -1
View File
@@ -19,6 +19,7 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import get_local_ip_addresses
from .routes import *
from .routes.backup import BackupRoute
from .routes.platform import PlatformRoute
from .routes.route import Response, RouteContext
from .routes.session_management import SessionManagementRoute
@@ -85,6 +86,7 @@ class AstrBotDashboard:
self.t2i_route = T2iRoute(self.context, core_lifecycle)
self.kb_route = KnowledgeBaseRoute(self.context, core_lifecycle)
self.platform_route = PlatformRoute(self.context, core_lifecycle)
self.backup_route = BackupRoute(self.context, db, core_lifecycle)
self.app.add_url_rule(
"/api/plug/<path:subpath>",
@@ -108,7 +110,12 @@ class AstrBotDashboard:
async def auth_middleware(self):
if not request.path.startswith("/api"):
return None
allowed_endpoints = ["/api/auth/login", "/api/file", "/api/platform/webhook"]
allowed_endpoints = [
"/api/auth/login",
"/api/file",
"/api/platform/webhook",
"/api/stat/start-time",
]
if any(request.path.startswith(prefix) for prefix in allowed_endpoints):
return None
# 声明 JWT
@@ -0,0 +1,673 @@
<template>
<v-dialog v-model="isOpen" persistent max-width="700" scrollable>
<v-card>
<v-card-title class="d-flex align-center">
<v-icon class="mr-2">mdi-backup-restore</v-icon>
{{ t('features.settings.backup.dialog.title') }}
</v-card-title>
<v-card-text class="pa-6">
<!-- 选项卡 -->
<v-tabs v-model="activeTab" color="primary" class="mb-4">
<v-tab value="export">
<v-icon class="mr-2">mdi-export</v-icon>
{{ t('features.settings.backup.tabs.export') }}
</v-tab>
<v-tab value="import">
<v-icon class="mr-2">mdi-import</v-icon>
{{ t('features.settings.backup.tabs.import') }}
</v-tab>
<v-tab value="list">
<v-icon class="mr-2">mdi-format-list-bulleted</v-icon>
{{ t('features.settings.backup.tabs.list') }}
</v-tab>
</v-tabs>
<v-window v-model="activeTab">
<!-- 导出标签页 -->
<v-window-item value="export">
<div v-if="exportStatus === 'idle'" class="text-center py-8">
<v-icon size="64" color="primary" class="mb-4">mdi-cloud-upload</v-icon>
<h3 class="mb-4">{{ t('features.settings.backup.export.title') }}</h3>
<p class="mb-4 text-grey">{{ t('features.settings.backup.export.description') }}</p>
<v-alert type="info" variant="tonal" class="mb-4 text-left">
<template v-slot:prepend>
<v-icon>mdi-information</v-icon>
</template>
{{ t('features.settings.backup.export.includes') }}
</v-alert>
<v-btn color="primary" size="large" @click="startExport" :loading="exportStatus === 'processing'">
<v-icon class="mr-2">mdi-export</v-icon>
{{ t('features.settings.backup.export.button') }}
</v-btn>
</div>
<div v-else-if="exportStatus === 'processing'" class="text-center py-8">
<v-progress-circular indeterminate color="primary" size="64" class="mb-4"></v-progress-circular>
<h3 class="mb-4">{{ t('features.settings.backup.export.processing') }}</h3>
<p class="text-grey">{{ exportProgress.message || t('features.settings.backup.export.wait') }}</p>
<v-progress-linear :model-value="exportProgress.current" :max="exportProgress.total" class="mt-4" color="primary"></v-progress-linear>
</div>
<div v-else-if="exportStatus === 'completed'" class="text-center py-8">
<v-icon size="64" color="success" class="mb-4">mdi-check-circle</v-icon>
<h3 class="mb-4">{{ t('features.settings.backup.export.completed') }}</h3>
<p class="mb-4">{{ exportResult?.filename }}</p>
<v-btn color="primary" @click="downloadBackup(exportResult?.filename)" class="mr-2">
<v-icon class="mr-2">mdi-download</v-icon>
{{ t('features.settings.backup.export.download') }}
</v-btn>
<v-btn color="grey" variant="text" @click="resetExport">
{{ t('features.settings.backup.export.another') }}
</v-btn>
</div>
<div v-else-if="exportStatus === 'failed'" class="text-center py-8">
<v-icon size="64" color="error" class="mb-4">mdi-alert-circle</v-icon>
<h3 class="mb-4">{{ t('features.settings.backup.export.failed') }}</h3>
<v-alert type="error" variant="tonal" class="mb-4">
{{ exportError }}
</v-alert>
<v-btn color="primary" @click="resetExport">
{{ t('features.settings.backup.export.retry') }}
</v-btn>
</div>
</v-window-item>
<!-- 导入标签页 -->
<v-window-item value="import">
<!-- 步骤1: 选择文件 -->
<div v-if="importStatus === 'idle'" class="py-4">
<v-alert type="warning" variant="tonal" class="mb-4">
<template v-slot:prepend>
<v-icon>mdi-alert</v-icon>
</template>
{{ t('features.settings.backup.import.warning') }}
</v-alert>
<v-file-input
v-model="importFile"
:label="t('features.settings.backup.import.selectFile')"
accept=".zip"
prepend-icon="mdi-file-upload"
show-size
class="mb-4"
></v-file-input>
<div class="d-flex justify-center">
<v-btn
color="primary"
size="large"
@click="uploadAndCheck"
:disabled="!importFile"
:loading="importStatus === 'uploading'"
>
<v-icon class="mr-2">mdi-upload</v-icon>
{{ t('features.settings.backup.import.uploadAndCheck') }}
</v-btn>
</div>
</div>
<!-- 步骤1.5: 上传中 -->
<div v-else-if="importStatus === 'uploading'" class="text-center py-8">
<v-progress-circular indeterminate color="primary" size="64" class="mb-4"></v-progress-circular>
<h3 class="mb-4">{{ t('features.settings.backup.import.uploading') }}</h3>
<p class="text-grey">{{ t('features.settings.backup.import.uploadWait') }}</p>
</div>
<!-- 步骤2: 确认导入 -->
<div v-else-if="importStatus === 'confirm'" class="py-4">
<v-alert
:type="versionAlertType"
variant="tonal"
class="mb-4"
>
<template v-slot:prepend>
<v-icon>{{ versionAlertIcon }}</v-icon>
</template>
<div class="confirm-message">
<div class="text-h6 mb-2">{{ versionAlertTitle }}</div>
<div class="mb-2">
<strong>{{ t('features.settings.backup.import.version.backupVersion') }}:</strong> {{ checkResult?.backup_version }}<br>
<strong>{{ t('features.settings.backup.import.version.currentVersion') }}:</strong> {{ checkResult?.current_version }}
</div>
<div v-if="checkResult?.backup_time && checkResult?.backup_time !== '未知'" class="mb-2">
<strong>{{ t('features.settings.backup.import.version.backupTime') }}:</strong> {{ formatISODate(checkResult?.backup_time) }}
</div>
<div class="mt-3" style="white-space: pre-line;">{{ versionAlertMessage }}</div>
</div>
</v-alert>
<!-- 备份摘要 -->
<v-card variant="outlined" class="mb-4" v-if="checkResult?.backup_summary">
<v-card-title class="text-subtitle-1">
<v-icon class="mr-2">mdi-package-variant</v-icon>
{{ t('features.settings.backup.import.backupContents') }}
</v-card-title>
<v-card-text>
<div class="d-flex flex-wrap ga-2">
<v-chip v-if="checkResult.backup_summary.tables?.length" size="small" color="primary" variant="tonal" :ripple="false" class="non-interactive-chip">
{{ checkResult.backup_summary.tables.length }} {{ t('features.settings.backup.import.tables') }}
</v-chip>
<v-chip v-if="checkResult.backup_summary.has_knowledge_bases" size="small" color="success" variant="tonal" :ripple="false" class="non-interactive-chip">
{{ t('features.settings.backup.import.knowledgeBases') }}
</v-chip>
<v-chip v-if="checkResult.backup_summary.has_config" size="small" color="info" variant="tonal" :ripple="false" class="non-interactive-chip">
{{ t('features.settings.backup.import.configFiles') }}
</v-chip>
<v-chip v-for="dir in (checkResult.backup_summary.directories || [])" :key="dir" size="small" color="warning" variant="tonal" :ripple="false" class="non-interactive-chip">
{{ dir }}
</v-chip>
</div>
</v-card-text>
</v-card>
<!-- 警告信息 -->
<v-alert v-if="checkResult?.warnings?.length" type="warning" variant="tonal" class="mb-4">
<div v-for="(warning, idx) in checkResult.warnings" :key="idx">{{ warning }}</div>
</v-alert>
<div class="d-flex justify-center align-center mt-4" style="gap: 16px;">
<v-btn
color="grey-darken-1"
variant="outlined"
size="large"
@click="resetImport"
>
<v-icon class="mr-2">mdi-close</v-icon>
{{ t('core.common.cancel') }}
</v-btn>
<v-btn
v-if="checkResult?.can_import"
color="error"
size="large"
variant="flat"
@click="confirmImport"
>
<v-icon class="mr-2">mdi-alert</v-icon>
{{ t('features.settings.backup.import.confirmImport') }}
</v-btn>
</div>
</div>
<!-- 步骤3: 导入进行中 -->
<div v-else-if="importStatus === 'processing'" class="text-center py-8">
<v-progress-circular indeterminate color="primary" size="64" class="mb-4"></v-progress-circular>
<h3 class="mb-4">{{ t('features.settings.backup.import.processing') }}</h3>
<p class="text-grey">{{ importProgress.message || t('features.settings.backup.import.wait') }}</p>
<v-progress-linear :model-value="importProgress.current" :max="importProgress.total" class="mt-4" color="primary"></v-progress-linear>
</div>
<div v-else-if="importStatus === 'completed'" class="text-center py-8">
<v-icon size="64" color="success" class="mb-4">mdi-check-circle</v-icon>
<h3 class="mb-4">{{ t('features.settings.backup.import.completed') }}</h3>
<v-alert type="info" variant="tonal" class="mb-4">
{{ t('features.settings.backup.import.restartRequired') }}
</v-alert>
<v-btn color="primary" @click="restartAstrBot" class="mr-2">
<v-icon class="mr-2">mdi-restart</v-icon>
{{ t('features.settings.backup.import.restartNow') }}
</v-btn>
<v-btn color="grey" variant="text" @click="resetImport">
{{ t('core.common.close') }}
</v-btn>
</div>
<div v-else-if="importStatus === 'failed'" class="text-center py-8">
<v-icon size="64" color="error" class="mb-4">mdi-alert-circle</v-icon>
<h3 class="mb-4">{{ t('features.settings.backup.import.failed') }}</h3>
<v-alert type="error" variant="tonal" class="mb-4">
{{ importError }}
</v-alert>
<v-btn color="primary" @click="resetImport">
{{ t('features.settings.backup.import.retry') }}
</v-btn>
</div>
</v-window-item>
<!-- 备份列表标签页 -->
<v-window-item value="list">
<div v-if="loadingList" class="text-center py-8">
<v-progress-circular indeterminate color="primary"></v-progress-circular>
</div>
<div v-else-if="backupList.length === 0" class="text-center py-8">
<v-icon size="64" color="grey" class="mb-4">mdi-folder-open-outline</v-icon>
<p class="text-grey">{{ t('features.settings.backup.list.empty') }}</p>
</div>
<v-list v-else lines="two">
<v-list-item
v-for="backup in backupList"
:key="backup.filename"
>
<template v-slot:prepend>
<v-icon color="primary">mdi-zip-box</v-icon>
</template>
<v-list-item-title>{{ backup.filename }}</v-list-item-title>
<v-list-item-subtitle>
{{ formatFileSize(backup.size) }} · {{ formatDate(backup.created_at) }}
</v-list-item-subtitle>
<template v-slot:append>
<v-btn icon="mdi-download" variant="text" size="small" @click="downloadBackup(backup.filename)"></v-btn>
<v-btn icon="mdi-delete" variant="text" size="small" color="error" @click="deleteBackup(backup.filename)"></v-btn>
</template>
</v-list-item>
</v-list>
<div class="d-flex justify-center mt-4">
<v-btn color="primary" variant="text" @click="loadBackupList">
<v-icon class="mr-2">mdi-refresh</v-icon>
{{ t('features.settings.backup.list.refresh') }}
</v-btn>
</div>
</v-window-item>
</v-window>
</v-card-text>
<v-card-actions class="px-6 py-4">
<v-spacer></v-spacer>
<v-btn color="grey" variant="text" @click="handleClose" :disabled="isProcessing">
{{ t('core.common.close') }}
</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
<WaitingForRestart ref="wfr"></WaitingForRestart>
</template>
<script setup>
import { ref, computed, watch } from 'vue'
import axios from 'axios'
import { useI18n } from '@/i18n/composables'
import WaitingForRestart from './WaitingForRestart.vue'
const { t } = useI18n()
const isOpen = ref(false)
const activeTab = ref('export')
const wfr = ref(null)
// 导出状态
const exportStatus = ref('idle') // idle, processing, completed, failed
const exportTaskId = ref(null)
const exportProgress = ref({ current: 0, total: 100, message: '' })
const exportResult = ref(null)
const exportError = ref('')
// 导入状态
const importStatus = ref('idle') // idle, uploading, confirm, processing, completed, failed
const importFile = ref(null)
const importTaskId = ref(null)
const importProgress = ref({ current: 0, total: 100, message: '' })
const importError = ref('')
const uploadedFilename = ref('') // 已上传的文件名
const checkResult = ref(null) // 预检查结果
// 备份列表
const loadingList = ref(false)
const backupList = ref([])
// 计算属性
const isProcessing = computed(() => {
return exportStatus.value === 'processing' || importStatus.value === 'processing'
})
// 版本检查相关的计算属性
const versionAlertType = computed(() => {
const status = checkResult.value?.version_status
if (status === 'major_diff') return 'error'
if (status === 'minor_diff') return 'warning'
return 'info'
})
const versionAlertIcon = computed(() => {
const status = checkResult.value?.version_status
if (status === 'major_diff') return 'mdi-close-circle'
if (status === 'minor_diff') return 'mdi-alert'
return 'mdi-check-circle'
})
const versionAlertTitle = computed(() => {
const status = checkResult.value?.version_status
if (status === 'major_diff') return t('features.settings.backup.import.version.majorDiffTitle')
if (status === 'minor_diff') return t('features.settings.backup.import.version.minorDiffTitle')
return t('features.settings.backup.import.version.matchTitle')
})
const versionAlertMessage = computed(() => {
const status = checkResult.value?.version_status
if (status === 'major_diff') return t('features.settings.backup.import.version.majorDiffMessage')
if (status === 'minor_diff') return t('features.settings.backup.import.version.minorDiffMessage')
return t('features.settings.backup.import.version.matchMessage')
})
// 监听对话框打开
watch(isOpen, (newVal) => {
if (newVal) {
loadBackupList()
} else {
resetAll()
}
})
// 监听标签页切换
watch(activeTab, (newVal) => {
if (newVal === 'list') {
loadBackupList()
}
})
// 加载备份列表
const loadBackupList = async () => {
loadingList.value = true
try {
const response = await axios.get('/api/backup/list')
if (response.data.status === 'ok') {
backupList.value = response.data.data.items || []
}
} catch (error) {
console.error('Failed to load backup list:', error)
} finally {
loadingList.value = false
}
}
// 开始导出
const startExport = async () => {
exportStatus.value = 'processing'
exportProgress.value = { current: 0, total: 100, message: '' }
try {
const response = await axios.post('/api/backup/export')
if (response.data.status === 'ok') {
exportTaskId.value = response.data.data.task_id
pollExportProgress()
} else {
throw new Error(response.data.message)
}
} catch (error) {
exportStatus.value = 'failed'
exportError.value = error.message || 'Export failed'
}
}
// 轮询导出进度
const pollExportProgress = async () => {
if (!exportTaskId.value) return
try {
const response = await axios.get('/api/backup/progress', {
params: { task_id: exportTaskId.value }
})
if (response.data.status === 'ok') {
const data = response.data.data
if (data.status === 'processing' && data.progress) {
exportProgress.value = {
current: data.progress.current || 0,
total: data.progress.total || 100,
message: data.progress.message || ''
}
setTimeout(pollExportProgress, 1000)
} else if (data.status === 'completed') {
exportStatus.value = 'completed'
exportResult.value = data.result
loadBackupList()
} else if (data.status === 'failed') {
exportStatus.value = 'failed'
exportError.value = data.error || 'Export failed'
} else {
setTimeout(pollExportProgress, 1000)
}
}
} catch (error) {
exportStatus.value = 'failed'
exportError.value = error.message || 'Failed to get export progress'
}
}
// 重置导出状态
const resetExport = () => {
exportStatus.value = 'idle'
exportTaskId.value = null
exportProgress.value = { current: 0, total: 100, message: '' }
exportResult.value = null
exportError.value = ''
}
// 上传并检查
const uploadAndCheck = async () => {
if (!importFile.value) return
importStatus.value = 'uploading'
try {
// 步骤1: 上传文件
const formData = new FormData()
formData.append('file', importFile.value)
const uploadResponse = await axios.post('/api/backup/upload', formData, {
headers: { 'Content-Type': 'multipart/form-data' }
})
if (uploadResponse.data.status !== 'ok') {
throw new Error(uploadResponse.data.message)
}
uploadedFilename.value = uploadResponse.data.data.filename
// 步骤2: 预检查
const checkResponse = await axios.post('/api/backup/check', {
filename: uploadedFilename.value
})
if (checkResponse.data.status !== 'ok') {
throw new Error(checkResponse.data.message)
}
checkResult.value = checkResponse.data.data
// 检查是否有效
if (!checkResult.value.valid) {
importStatus.value = 'failed'
importError.value = checkResult.value.error || t('features.settings.backup.import.invalidBackup')
return
}
// 显示确认对话框
importStatus.value = 'confirm'
} catch (error) {
importStatus.value = 'failed'
importError.value = error.response?.data?.message || error.message || 'Upload failed'
}
}
// 确认导入
const confirmImport = async () => {
if (!uploadedFilename.value) return
importStatus.value = 'processing'
importProgress.value = { current: 0, total: 100, message: '' }
try {
const response = await axios.post('/api/backup/import', {
filename: uploadedFilename.value,
confirmed: true
})
if (response.data.status === 'ok') {
importTaskId.value = response.data.data.task_id
pollImportProgress()
} else {
throw new Error(response.data.message)
}
} catch (error) {
importStatus.value = 'failed'
importError.value = error.response?.data?.message || error.message || 'Import failed'
}
}
// 轮询导入进度
const pollImportProgress = async () => {
if (!importTaskId.value) return
try {
const response = await axios.get('/api/backup/progress', {
params: { task_id: importTaskId.value }
})
if (response.data.status === 'ok') {
const data = response.data.data
if (data.status === 'processing' && data.progress) {
importProgress.value = {
current: data.progress.current || 0,
total: data.progress.total || 100,
message: data.progress.message || ''
}
setTimeout(pollImportProgress, 1000)
} else if (data.status === 'completed') {
importStatus.value = 'completed'
} else if (data.status === 'failed') {
importStatus.value = 'failed'
importError.value = data.error || 'Import failed'
} else {
setTimeout(pollImportProgress, 1000)
}
}
} catch (error) {
importStatus.value = 'failed'
importError.value = error.message || 'Failed to get import progress'
}
}
// 重置导入状态
const resetImport = () => {
importStatus.value = 'idle'
importFile.value = null
importTaskId.value = null
importProgress.value = { current: 0, total: 100, message: '' }
importError.value = ''
uploadedFilename.value = ''
checkResult.value = null
}
// 下载备份
const downloadBackup = async (filename) => {
try {
const response = await axios.get('/api/backup/download', {
params: { filename },
responseType: 'blob'
})
// 创建 Blob URL 并触发下载
const blob = new Blob([response.data], { type: 'application/zip' })
const url = window.URL.createObjectURL(blob)
const link = document.createElement('a')
link.href = url
link.download = filename
document.body.appendChild(link)
link.click()
document.body.removeChild(link)
window.URL.revokeObjectURL(url)
} catch (error) {
console.error('Download failed:', error)
alert(t('features.settings.backup.export.failed') + ': ' + (error.message || 'Unknown error'))
}
}
// 删除备份
const deleteBackup = async (filename) => {
if (!confirm(t('features.settings.backup.list.confirmDelete'))) return
try {
const response = await axios.post('/api/backup/delete', { filename })
if (response.data.status === 'ok') {
loadBackupList()
} else {
alert(response.data.message || 'Delete failed')
}
} catch (error) {
alert(error.message || 'Delete failed')
}
}
// 格式化文件大小
const formatFileSize = (bytes) => {
if (bytes === 0) return '0 B'
const k = 1024
const sizes = ['B', 'KB', 'MB', 'GB']
const i = Math.floor(Math.log(bytes) / Math.log(k))
return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i]
}
// 格式化日期(从时间戳)
const formatDate = (timestamp) => {
return new Date(timestamp * 1000).toLocaleString()
}
// 格式化 ISO 日期字符串
const formatISODate = (isoString) => {
if (!isoString) return ''
try {
return new Date(isoString).toLocaleString()
} catch {
return isoString
}
}
// 重启 AstrBot
const restartAstrBot = () => {
axios.post('/api/stat/restart-core').then(() => {
if (wfr.value) {
wfr.value.check()
}
})
}
// 重置所有状态
const resetAll = () => {
resetExport()
resetImport()
activeTab.value = 'export'
}
// 关闭对话框
const handleClose = () => {
if (isProcessing.value) return
isOpen.value = false
}
// 打开对话框
const open = () => {
isOpen.value = true
}
defineExpose({ open })
</script>
<style scoped>
.v-list-item {
border-bottom: 1px solid rgba(0, 0, 0, 0.08);
}
.v-list-item:last-child {
border-bottom: none;
}
/* 禁用 Chip 的交互效果 */
.non-interactive-chip {
pointer-events: none;
cursor: default;
}
.non-interactive-chip:hover {
box-shadow: none !important;
}
</style>
@@ -18,6 +18,11 @@
"title": "Data Migration to v4.0.0",
"subtitle": "If you encounter data compatibility issues, you can manually start the database migration assistant",
"button": "Start Migration Assistant"
},
"backup": {
"title": "Backup & Restore",
"subtitle": "Export or import all AstrBot data for easy migration to a new server",
"button": "Backup Manager"
}
},
"sidebar": {
@@ -29,5 +34,66 @@
"mainItems": "Main Modules",
"moreItems": "More Features"
}
},
"backup": {
"dialog": {
"title": "Backup Manager"
},
"tabs": {
"export": "Export Backup",
"import": "Import Backup",
"list": "Backup List"
},
"export": {
"title": "Create Backup",
"description": "Export all data as a ZIP backup file, including database, knowledge base, config and attachments.",
"includes": "Backup includes: Main database, Knowledge bases (metadata + vector index + documents), Config files, Attachment files",
"button": "Start Export",
"processing": "Exporting...",
"wait": "Please wait, packaging data...",
"completed": "Export Completed!",
"download": "Download Backup",
"another": "Create New Backup",
"failed": "Export Failed",
"retry": "Retry"
},
"import": {
"title": "Import Backup",
"warning": "⚠️ Import will clear and overwrite existing data! Please make sure you have backed up your current data.",
"selectFile": "Select backup file (.zip)",
"uploadAndCheck": "Upload & Check",
"uploading": "Uploading...",
"uploadWait": "Please wait, uploading backup file...",
"invalidBackup": "Invalid backup file",
"backupContents": "Backup Contents",
"tables": "tables",
"knowledgeBases": "Knowledge Bases",
"configFiles": "Config Files",
"confirmImport": "Confirm Import",
"button": "Start Import",
"processing": "Importing...",
"wait": "Please wait, restoring data...",
"completed": "Import Completed!",
"restartRequired": "Data has been successfully imported. It is recommended to restart AstrBot immediately for all changes to take effect.",
"restartNow": "Restart Now",
"failed": "Import Failed",
"retry": "Retry",
"version": {
"backupVersion": "Backup Version",
"currentVersion": "Current Version",
"backupTime": "Backup Time",
"matchTitle": "✅ Version Match",
"matchMessage": "Import will clear and overwrite all existing data, including:\n• Main database (conversations, settings, etc.)\n• Knowledge bases\n• Plugins and plugin data\n• Configuration files\n\nThis action cannot be undone! Do you want to continue?",
"minorDiffTitle": "⚠️ Version Difference Warning",
"minorDiffMessage": "Minor version differences are usually compatible, but there may be some data structure changes.\nImport will clear and overwrite all existing data!\n\nDo you want to continue?",
"majorDiffTitle": "⛔ Cannot Import",
"majorDiffMessage": "Major version numbers are different. Cross-major-version import may cause data corruption.\nPlease use the same major version of AstrBot for import."
}
},
"list": {
"empty": "No backup files",
"refresh": "Refresh List",
"confirmDelete": "Are you sure you want to delete this backup file? This action cannot be undone."
}
}
}
}
@@ -18,6 +18,11 @@
"title": "数据迁移到 v4.0.0 格式",
"subtitle": "如果您遇到数据兼容性问题,可以手动启动数据库迁移助手",
"button": "启动迁移助手"
},
"backup": {
"title": "数据备份与恢复",
"subtitle": "导出或导入 AstrBot 的所有数据,方便迁移到新服务器",
"button": "备份管理"
}
},
"sidebar": {
@@ -29,5 +34,66 @@
"mainItems": "主要模块",
"moreItems": "更多功能"
}
},
"backup": {
"dialog": {
"title": "备份管理"
},
"tabs": {
"export": "导出备份",
"import": "导入备份",
"list": "备份列表"
},
"export": {
"title": "创建备份",
"description": "将所有数据导出为 ZIP 备份文件,包括数据库、知识库、配置和附件。",
"includes": "备份包含:主数据库、知识库(元数据+向量索引+文档)、配置文件、附件文件",
"button": "开始导出",
"processing": "正在导出...",
"wait": "请稍候,正在打包数据...",
"completed": "导出完成!",
"download": "下载备份",
"another": "创建新备份",
"failed": "导出失败",
"retry": "重试"
},
"import": {
"title": "导入备份",
"warning": "⚠️ 导入将会清空并覆盖现有数据!请确保已备份当前数据。",
"selectFile": "选择备份文件 (.zip)",
"uploadAndCheck": "上传并检查",
"uploading": "正在上传...",
"uploadWait": "请稍候,正在上传备份文件...",
"invalidBackup": "无效的备份文件",
"backupContents": "备份内容",
"tables": "个数据表",
"knowledgeBases": "知识库",
"configFiles": "配置文件",
"confirmImport": "确认导入",
"button": "开始导入",
"processing": "正在导入...",
"wait": "请稍候,正在恢复数据...",
"completed": "导入完成!",
"restartRequired": "数据已成功导入。建议立即重启 AstrBot 以使所有更改生效。",
"restartNow": "立即重启",
"failed": "导入失败",
"retry": "重试",
"version": {
"backupVersion": "备份版本",
"currentVersion": "当前版本",
"backupTime": "备份时间",
"matchTitle": "✅ 版本匹配",
"matchMessage": "导入将会清空并覆盖现有的所有数据,包括:\n• 主数据库(对话记录、配置等)\n• 知识库数据\n• 插件及插件数据\n• 配置文件\n\n此操作不可撤销!是否确认继续?",
"minorDiffTitle": "⚠️ 版本差异警告",
"minorDiffMessage": "小版本差异通常是兼容的,但可能存在少量数据结构变化。\n导入将会清空并覆盖现有的所有数据!\n\n是否确认继续导入?",
"majorDiffTitle": "⛔ 无法导入",
"majorDiffMessage": "主版本号不同,跨主版本导入可能导致数据损坏。\n请使用相同主版本的 AstrBot 进行导入。"
}
},
"list": {
"empty": "暂无备份文件",
"refresh": "刷新列表",
"confirmDelete": "确定要删除这个备份文件吗?此操作不可撤销。"
}
}
}
}
+16
View File
@@ -17,6 +17,13 @@
<v-list-subheader>{{ tm('system.title') }}</v-list-subheader>
<v-list-item :subtitle="tm('system.backup.subtitle')" :title="tm('system.backup.title')">
<v-btn style="margin-top: 16px;" color="primary" @click="openBackupDialog">
<v-icon class="mr-2">mdi-backup-restore</v-icon>
{{ tm('system.backup.button') }}
</v-btn>
</v-list-item>
<v-list-item :subtitle="tm('system.restart.subtitle')" :title="tm('system.restart.title')">
<v-btn style="margin-top: 16px;" color="error" @click="restartAstrBot">{{ tm('system.restart.button') }}</v-btn>
</v-list-item>
@@ -30,6 +37,7 @@
<WaitingForRestart ref="wfr"></WaitingForRestart>
<MigrationDialog ref="migrationDialog"></MigrationDialog>
<BackupDialog ref="backupDialog"></BackupDialog>
</template>
@@ -40,12 +48,14 @@ import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
import ProxySelector from '@/components/shared/ProxySelector.vue';
import MigrationDialog from '@/components/shared/MigrationDialog.vue';
import SidebarCustomizer from '@/components/shared/SidebarCustomizer.vue';
import BackupDialog from '@/components/shared/BackupDialog.vue';
import { useModuleI18n } from '@/i18n/composables';
const { tm } = useModuleI18n('features/settings');
const wfr = ref(null);
const migrationDialog = ref(null);
const backupDialog = ref(null);
const restartAstrBot = () => {
axios.post('/api/stat/restart-core').then(() => {
@@ -65,4 +75,10 @@ const startMigration = async () => {
}
}
}
const openBackupDialog = () => {
if (backupDialog.value) {
backupDialog.value.open();
}
}
</script>
+1
View File
@@ -19,6 +19,7 @@ export default defineConfig({
],
resolve: {
alias: {
mermaid: 'mermaid/dist/mermaid.js',
'@': fileURLToPath(new URL('./src', import.meta.url))
}
},
+760
View File
@@ -0,0 +1,760 @@
"""备份功能单元测试"""
import json
import os
import re
import zipfile
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock
import pytest
from astrbot.core.backup import (
BACKUP_MANIFEST_VERSION,
KB_METADATA_MODELS,
MAIN_DB_MODELS,
ImportPreCheckResult,
)
from astrbot.core.backup.exporter import AstrBotExporter
from astrbot.core.backup.importer import (
AstrBotImporter,
ImportResult,
_get_major_version,
)
from astrbot.core.config.default import VERSION
from astrbot.core.db.po import (
ConversationV2,
)
from astrbot.core.utils.version_comparator import VersionComparator
from astrbot.dashboard.routes.backup import (
generate_unique_filename,
secure_filename,
)
@pytest.fixture
def temp_backup_dir(tmp_path):
"""创建临时备份目录"""
backup_dir = tmp_path / "backups"
backup_dir.mkdir()
return backup_dir
@pytest.fixture
def temp_data_dir(tmp_path):
"""创建临时数据目录"""
data_dir = tmp_path / "data"
data_dir.mkdir()
# 创建配置文件
config_path = data_dir / "cmd_config.json"
config_path.write_text(json.dumps({"test": "config"}))
# 创建附件目录
attachments_dir = data_dir / "attachments"
attachments_dir.mkdir()
return data_dir
@pytest.fixture
def mock_main_db():
"""创建模拟的主数据库"""
db = MagicMock()
# 模拟异步上下文管理器
session = AsyncMock()
db.get_db = MagicMock(
return_value=AsyncMock(__aenter__=AsyncMock(return_value=session))
)
return db
@pytest.fixture
def mock_kb_manager():
"""创建模拟的知识库管理器"""
kb_manager = MagicMock()
kb_manager.kb_insts = {}
# 模拟 kb_db
kb_db = MagicMock()
session = AsyncMock()
kb_db.get_db = MagicMock(
return_value=AsyncMock(__aenter__=AsyncMock(return_value=session))
)
kb_manager.kb_db = kb_db
return kb_manager
class TestImportResult:
"""ImportResult 类测试"""
def test_init(self):
"""测试初始化"""
result = ImportResult()
assert result.success is True
assert result.imported_tables == {}
assert result.imported_files == {}
assert result.warnings == []
assert result.errors == []
def test_add_warning(self):
"""测试添加警告"""
result = ImportResult()
result.add_warning("test warning")
assert "test warning" in result.warnings
assert result.success is True # 警告不影响成功状态
def test_add_error(self):
"""测试添加错误"""
result = ImportResult()
result.add_error("test error")
assert "test error" in result.errors
assert result.success is False # 错误会导致失败
def test_to_dict(self):
"""测试转换为字典"""
result = ImportResult()
result.imported_tables = {"test_table": 10}
result.add_warning("warning")
d = result.to_dict()
assert d["success"] is True
assert d["imported_tables"] == {"test_table": 10}
assert "warning" in d["warnings"]
class TestAstrBotExporter:
"""AstrBotExporter 类测试"""
def test_init(self, mock_main_db, mock_kb_manager, temp_data_dir):
"""测试初始化"""
exporter = AstrBotExporter(
main_db=mock_main_db,
kb_manager=mock_kb_manager,
config_path=str(temp_data_dir / "cmd_config.json"),
)
assert exporter.main_db is mock_main_db
assert exporter.kb_manager is mock_kb_manager
def test_model_to_dict_with_model_dump(self):
"""测试 _model_to_dict 使用 model_dump 方法"""
exporter = AstrBotExporter(main_db=MagicMock())
# 创建一个有 model_dump 方法的模拟对象
mock_record = MagicMock()
mock_record.model_dump.return_value = {"id": 1, "name": "test"}
result = exporter._model_to_dict(mock_record)
assert result == {"id": 1, "name": "test"}
def test_model_to_dict_with_datetime(self):
"""测试 _model_to_dict 处理 datetime 字段"""
exporter = AstrBotExporter(main_db=MagicMock())
now = datetime.now()
mock_record = MagicMock()
mock_record.model_dump.return_value = {"id": 1, "created_at": now}
result = exporter._model_to_dict(mock_record)
assert result["created_at"] == now.isoformat()
def test_add_checksum(self):
"""测试添加校验和"""
exporter = AstrBotExporter(main_db=MagicMock())
exporter._add_checksum("test.json", '{"test": "data"}')
assert "test.json" in exporter._checksums
assert exporter._checksums["test.json"].startswith("sha256:")
def test_generate_manifest(self, mock_main_db, mock_kb_manager):
"""测试生成清单"""
exporter = AstrBotExporter(
main_db=mock_main_db,
kb_manager=mock_kb_manager,
)
main_data = {
"platform_stats": [{"id": 1}],
"conversations": [],
"attachments": [],
}
kb_meta_data = {
"knowledge_bases": [],
"kb_documents": [],
}
dir_stats = {
"plugins": {"files": 10, "size": 1024},
"plugin_data": {"files": 5, "size": 512},
}
manifest = exporter._generate_manifest(main_data, kb_meta_data, dir_stats)
assert manifest["version"] == BACKUP_MANIFEST_VERSION
assert manifest["astrbot_version"] == VERSION
assert "exported_at" in manifest
assert "tables" in manifest
assert "statistics" in manifest
assert "directories" in manifest
assert manifest["statistics"]["main_db"]["platform_stats"] == 1
assert manifest["statistics"]["directories"] == dir_stats
@pytest.mark.asyncio
async def test_export_all_creates_zip(
self, mock_main_db, temp_backup_dir, temp_data_dir
):
"""测试导出创建 ZIP 文件"""
# 设置模拟数据库返回空数据
session = AsyncMock()
result = MagicMock()
result.scalars.return_value.all.return_value = []
session.execute = AsyncMock(return_value=result)
mock_main_db.get_db.return_value = AsyncMock(
__aenter__=AsyncMock(return_value=session),
__aexit__=AsyncMock(return_value=None),
)
exporter = AstrBotExporter(
main_db=mock_main_db,
kb_manager=None,
config_path=str(temp_data_dir / "cmd_config.json"),
)
zip_path = await exporter.export_all(output_dir=str(temp_backup_dir))
assert os.path.exists(zip_path)
assert zip_path.endswith(".zip")
assert "astrbot_backup_" in zip_path
# 验证 ZIP 文件内容
with zipfile.ZipFile(zip_path, "r") as zf:
namelist = zf.namelist()
assert "manifest.json" in namelist
assert "databases/main_db.json" in namelist
assert "config/cmd_config.json" in namelist
class TestAstrBotImporter:
"""AstrBotImporter 类测试"""
def test_init(self, mock_main_db, mock_kb_manager, temp_data_dir):
"""测试初始化"""
importer = AstrBotImporter(
main_db=mock_main_db,
kb_manager=mock_kb_manager,
config_path=str(temp_data_dir / "cmd_config.json"),
)
assert importer.main_db is mock_main_db
assert importer.kb_manager is mock_kb_manager
def test_validate_version_match(self):
"""测试版本匹配验证"""
importer = AstrBotImporter(main_db=MagicMock())
manifest = {"astrbot_version": VERSION}
# 不应该抛出异常
importer._validate_version(manifest)
def test_validate_version_major_diff_rejected(self):
"""测试主版本不同被拒绝"""
importer = AstrBotImporter(main_db=MagicMock())
# 使用一个明显不同的主版本
manifest = {"astrbot_version": "0.0.1"}
with pytest.raises(ValueError, match="主版本不兼容"):
importer._validate_version(manifest)
def test_validate_version_minor_diff_allowed(self):
"""测试小版本不同被允许(仅警告)"""
importer = AstrBotImporter(main_db=MagicMock())
# 获取当前主版本
major_version = _get_major_version(VERSION)
# 构造一个同主版本但小版本不同的版本
minor_diff_version = f"{major_version}.999"
manifest = {"astrbot_version": minor_diff_version}
# 不应该抛出异常
importer._validate_version(manifest)
def test_validate_version_missing(self):
"""测试缺少版本信息"""
importer = AstrBotImporter(main_db=MagicMock())
manifest = {}
with pytest.raises(ValueError, match="缺少版本信息"):
importer._validate_version(manifest)
def test_convert_datetime_fields(self):
"""测试 datetime 字段转换"""
importer = AstrBotImporter(main_db=MagicMock())
# 使用 ConversationV2 作为测试模型(它有 created_at 和 updated_at 字段)
row = {
"conversation_id": "test-123",
"platform_id": "test",
"user_id": "user1",
"created_at": "2024-01-01T12:00:00",
"updated_at": "2024-01-01T12:00:00",
}
result = importer._convert_datetime_fields(row, ConversationV2)
# created_at 应该被转换为 datetime 对象
assert isinstance(result["created_at"], datetime)
assert isinstance(result["updated_at"], datetime)
@pytest.mark.asyncio
async def test_import_file_not_exists(self, mock_main_db, tmp_path):
"""测试导入不存在的文件"""
importer = AstrBotImporter(main_db=mock_main_db)
result = await importer.import_all(str(tmp_path / "nonexistent.zip"))
assert result.success is False
assert any("不存在" in err for err in result.errors)
@pytest.mark.asyncio
async def test_import_invalid_zip(self, mock_main_db, tmp_path):
"""测试导入无效的 ZIP 文件"""
# 创建一个无效的文件
invalid_zip = tmp_path / "invalid.zip"
invalid_zip.write_text("not a zip file")
importer = AstrBotImporter(main_db=mock_main_db)
result = await importer.import_all(str(invalid_zip))
assert result.success is False
assert any("无效" in err or "ZIP" in err for err in result.errors)
@pytest.mark.asyncio
async def test_import_missing_manifest(self, mock_main_db, tmp_path):
"""测试导入缺少 manifest 的 ZIP 文件"""
# 创建一个没有 manifest 的 ZIP 文件
zip_path = tmp_path / "no_manifest.zip"
with zipfile.ZipFile(zip_path, "w") as zf:
zf.writestr("test.txt", "test content")
importer = AstrBotImporter(main_db=mock_main_db)
result = await importer.import_all(str(zip_path))
assert result.success is False
assert any("manifest" in err.lower() for err in result.errors)
@pytest.mark.asyncio
async def test_import_major_version_mismatch(self, mock_main_db, tmp_path):
"""测试导入主版本不匹配的备份"""
# 创建一个主版本不匹配的备份
zip_path = tmp_path / "old_version.zip"
manifest = {
"version": "1.0",
"astrbot_version": "0.0.1", # 主版本不同
"tables": {"main_db": []},
}
with zipfile.ZipFile(zip_path, "w") as zf:
zf.writestr("manifest.json", json.dumps(manifest))
importer = AstrBotImporter(main_db=mock_main_db)
result = await importer.import_all(str(zip_path))
assert result.success is False
assert any("主版本不兼容" in err for err in result.errors)
class TestSecureFilename:
"""安全文件名函数测试"""
def test_secure_filename_normal(self):
"""测试正常文件名"""
assert secure_filename("backup.zip") == "backup.zip"
assert secure_filename("my_backup_2024.zip") == "my_backup_2024.zip"
def test_secure_filename_path_traversal(self):
"""测试路径遍历攻击"""
assert ".." not in secure_filename("../../../etc/passwd")
assert "/" not in secure_filename("/etc/passwd")
assert "\\" not in secure_filename("..\\..\\windows\\system32")
def test_secure_filename_with_path(self):
"""测试带路径的文件名"""
result = secure_filename("/path/to/backup.zip")
assert result == "backup.zip"
result = secure_filename("C:\\Users\\test\\backup.zip")
assert result == "backup.zip"
def test_secure_filename_special_chars(self):
"""测试特殊字符"""
result = secure_filename('backup<>:"|?*.zip')
# 特殊字符应被替换为下划线
assert "<" not in result
assert ">" not in result
assert ":" not in result
assert '"' not in result
assert "|" not in result
assert "?" not in result
assert "*" not in result
def test_secure_filename_hidden_file(self):
"""测试隐藏文件(前导点)"""
result = secure_filename(".hidden_backup.zip")
assert not result.startswith(".")
def test_secure_filename_empty(self):
"""测试空文件名"""
assert secure_filename("") == "backup"
assert secure_filename("...") == "backup"
def test_generate_unique_filename(self):
"""测试生成唯一文件名"""
result = generate_unique_filename("backup.zip")
# 应包含 uploaded_ 前缀和时间戳
assert result.startswith("uploaded_")
assert result.endswith("_backup.zip")
# 应包含时间戳格式 YYYYMMDD_HHMMSS
assert re.search(r"uploaded_\d{8}_\d{6}_backup\.zip", result)
class TestVersionComparison:
"""版本比较函数测试 - 使用 VersionComparator"""
def test_get_major_version_simple(self):
"""测试提取简单主版本号"""
assert _get_major_version("1.0") == "1.0"
assert _get_major_version("2.1") == "2.1"
assert _get_major_version("4.9.1") == "4.9"
def test_get_major_version_with_prefix(self):
"""测试带 v 前缀的版本号"""
assert _get_major_version("v1.0") == "1.0"
assert _get_major_version("V4.9.1") == "4.9"
def test_get_major_version_with_prerelease(self):
"""测试带预发布标签的版本号"""
assert _get_major_version("4.9.1-beta") == "4.9"
assert _get_major_version("4.9.1-alpha.1") == "4.9"
assert _get_major_version("4.9.1+build123") == "4.9"
def test_get_major_version_single_part(self):
"""测试单部分版本号"""
assert _get_major_version("1") == "1.0"
def test_get_major_version_empty(self):
"""测试空版本号"""
assert _get_major_version("") == "0.0"
def test_compare_versions_equal(self):
"""测试版本相等"""
assert VersionComparator.compare_version("1.0", "1.0") == 0
assert VersionComparator.compare_version("1.0.0", "1.0") == 0
assert VersionComparator.compare_version("2.10", "2.10") == 0
def test_compare_versions_less_than(self):
"""测试版本小于"""
assert VersionComparator.compare_version("1.0", "1.1") == -1
assert (
VersionComparator.compare_version("1.9", "1.10") == -1
) # 关键测试:多位数版本比较
assert VersionComparator.compare_version("1.2", "1.10") == -1
assert VersionComparator.compare_version("1.0", "2.0") == -1
def test_compare_versions_greater_than(self):
"""测试版本大于"""
assert VersionComparator.compare_version("1.1", "1.0") == 1
assert (
VersionComparator.compare_version("1.10", "1.9") == 1
) # 关键测试:多位数版本比较
assert VersionComparator.compare_version("1.10", "1.2") == 1
assert VersionComparator.compare_version("2.0", "1.0") == 1
def test_compare_versions_different_lengths(self):
"""测试不同长度版本比较"""
assert VersionComparator.compare_version("1.0", "1.0.0") == 0
assert VersionComparator.compare_version("1.0", "1.0.1") == -1
assert VersionComparator.compare_version("1.0.1", "1.0") == 1
def test_compare_versions_prerelease(self):
"""测试预发布版本比较"""
# 预发布版本低于正式版本
assert VersionComparator.compare_version("1.0.0-alpha", "1.0.0") == -1
assert VersionComparator.compare_version("1.0.0", "1.0.0-beta") == 1
# alpha < beta
assert VersionComparator.compare_version("1.0.0-alpha", "1.0.0-beta") == -1
class TestImportPreCheckResult:
"""ImportPreCheckResult 类测试"""
def test_init_default_values(self):
"""测试默认值初始化"""
result = ImportPreCheckResult()
assert result.valid is False
assert result.can_import is False
assert result.version_status == ""
assert result.backup_version == ""
assert result.current_version == VERSION
assert result.confirm_message == ""
assert result.warnings == []
assert result.error == ""
assert result.backup_summary == {}
def test_to_dict(self):
"""测试转换为字典"""
result = ImportPreCheckResult(
valid=True,
can_import=True,
version_status="match",
backup_version="4.9.0",
confirm_message="确认导入?",
warnings=["警告1"],
backup_summary={"tables": ["table1"]},
)
d = result.to_dict()
assert d["valid"] is True
assert d["can_import"] is True
assert d["version_status"] == "match"
assert d["backup_version"] == "4.9.0"
assert d["confirm_message"] == "确认导入?"
assert "警告1" in d["warnings"]
assert d["backup_summary"]["tables"] == ["table1"]
class TestPreCheck:
"""预检查功能测试"""
def test_pre_check_file_not_exists(self, mock_main_db):
"""测试预检查不存在的文件"""
importer = AstrBotImporter(main_db=mock_main_db)
result = importer.pre_check("/nonexistent/file.zip")
assert result.valid is False
assert "不存在" in result.error
def test_pre_check_invalid_zip(self, mock_main_db, tmp_path):
"""测试预检查无效的 ZIP 文件"""
invalid_zip = tmp_path / "invalid.zip"
invalid_zip.write_text("not a zip file")
importer = AstrBotImporter(main_db=mock_main_db)
result = importer.pre_check(str(invalid_zip))
assert result.valid is False
assert "ZIP" in result.error or "无效" in result.error
def test_pre_check_missing_manifest(self, mock_main_db, tmp_path):
"""测试预检查缺少 manifest 的 ZIP 文件"""
zip_path = tmp_path / "no_manifest.zip"
with zipfile.ZipFile(zip_path, "w") as zf:
zf.writestr("test.txt", "test content")
importer = AstrBotImporter(main_db=mock_main_db)
result = importer.pre_check(str(zip_path))
assert result.valid is False
assert "manifest" in result.error.lower()
def test_pre_check_version_match(self, mock_main_db, tmp_path):
"""测试预检查版本匹配"""
zip_path = tmp_path / "backup.zip"
manifest = {
"version": "1.1",
"astrbot_version": VERSION,
"created_at": "2024-01-01T12:00:00",
"tables": {"platform_stats": 1},
"has_knowledge_bases": True,
"has_config": True,
"directories": ["plugins"],
}
with zipfile.ZipFile(zip_path, "w") as zf:
zf.writestr("manifest.json", json.dumps(manifest))
importer = AstrBotImporter(main_db=mock_main_db)
result = importer.pre_check(str(zip_path))
assert result.valid is True
assert result.can_import is True
assert result.version_status == "match"
assert result.backup_version == VERSION
# confirm_message 现在由前端生成,后端不再生成
assert result.backup_summary["has_knowledge_bases"] is True
def test_pre_check_minor_version_diff(self, mock_main_db, tmp_path):
"""测试预检查小版本差异"""
# 构造一个同主版本但小版本不同的版本
major_version = _get_major_version(VERSION)
minor_diff_version = f"{major_version}.999"
zip_path = tmp_path / "backup.zip"
manifest = {
"version": "1.1",
"astrbot_version": minor_diff_version,
"created_at": "2024-01-01T12:00:00",
"tables": {},
}
with zipfile.ZipFile(zip_path, "w") as zf:
zf.writestr("manifest.json", json.dumps(manifest))
importer = AstrBotImporter(main_db=mock_main_db)
result = importer.pre_check(str(zip_path))
assert result.valid is True
assert result.can_import is True
assert result.version_status == "minor_diff"
# 版本消息由前端 i18n 生成,后端 warnings 列表不再包含版本相关消息
# warnings 列表保留用于其他非版本相关的警告
def test_pre_check_major_version_diff(self, mock_main_db, tmp_path):
"""测试预检查主版本差异"""
zip_path = tmp_path / "backup.zip"
manifest = {
"version": "1.1",
"astrbot_version": "0.0.1", # 主版本不同
"created_at": "2024-01-01T12:00:00",
"tables": {},
}
with zipfile.ZipFile(zip_path, "w") as zf:
zf.writestr("manifest.json", json.dumps(manifest))
importer = AstrBotImporter(main_db=mock_main_db)
result = importer.pre_check(str(zip_path))
assert result.valid is True # 文件有效
assert result.can_import is False # 但不能导入
assert result.version_status == "major_diff"
# 版本消息由前端 i18n 生成,后端 warnings 列表不再包含版本相关消息
class TestVersionCompatibility:
"""版本兼容性检查测试"""
def test_check_version_compatibility_match(self, mock_main_db):
"""测试版本完全匹配"""
importer = AstrBotImporter(main_db=mock_main_db)
result = importer._check_version_compatibility(VERSION)
assert result["status"] == "match"
assert result["can_import"] is True
def test_check_version_compatibility_minor_diff(self, mock_main_db):
"""测试小版本差异"""
major_version = _get_major_version(VERSION)
minor_diff_version = f"{major_version}.999"
importer = AstrBotImporter(main_db=mock_main_db)
result = importer._check_version_compatibility(minor_diff_version)
assert result["status"] == "minor_diff"
assert result["can_import"] is True
def test_check_version_compatibility_major_diff(self, mock_main_db):
"""测试主版本差异"""
importer = AstrBotImporter(main_db=mock_main_db)
result = importer._check_version_compatibility("0.0.1")
assert result["status"] == "major_diff"
assert result["can_import"] is False
def test_check_version_compatibility_empty_version(self, mock_main_db):
"""测试空版本号"""
importer = AstrBotImporter(main_db=mock_main_db)
result = importer._check_version_compatibility("")
assert result["status"] == "major_diff"
assert result["can_import"] is False
class TestModelMappings:
"""测试模型映射配置"""
def test_main_db_models_not_empty(self):
"""测试主数据库模型映射非空"""
assert len(MAIN_DB_MODELS) > 0
def test_main_db_models_contain_expected_tables(self):
"""测试主数据库模型映射包含预期的表"""
expected_tables = [
"platform_stats",
"conversations",
"personas",
"preferences",
"attachments",
]
for table in expected_tables:
assert table in MAIN_DB_MODELS, f"Missing table: {table}"
def test_kb_metadata_models_not_empty(self):
"""测试知识库元数据模型映射非空"""
assert len(KB_METADATA_MODELS) > 0
def test_kb_metadata_models_contain_expected_tables(self):
"""测试知识库元数据模型映射包含预期的表"""
expected_tables = [
"knowledge_bases",
"kb_documents",
"kb_media",
]
for table in expected_tables:
assert table in KB_METADATA_MODELS, f"Missing table: {table}"
class TestBackupIntegration:
"""备份集成测试"""
@pytest.mark.asyncio
async def test_export_import_roundtrip(self, tmp_path):
"""测试导出-导入往返"""
backup_dir = tmp_path / "backups"
backup_dir.mkdir()
data_dir = tmp_path / "data"
data_dir.mkdir()
config_path = data_dir / "cmd_config.json"
config_path.write_text(json.dumps({"setting": "value"}))
attachments_dir = data_dir / "attachments"
attachments_dir.mkdir()
# 创建模拟数据库
mock_db = MagicMock()
session = AsyncMock()
result = MagicMock()
result.scalars.return_value.all.return_value = []
session.execute = AsyncMock(return_value=result)
mock_db.get_db.return_value = AsyncMock(
__aenter__=AsyncMock(return_value=session),
__aexit__=AsyncMock(return_value=None),
)
# 导出
exporter = AstrBotExporter(
main_db=mock_db,
kb_manager=None,
config_path=str(config_path),
)
zip_path = await exporter.export_all(output_dir=str(backup_dir))
assert os.path.exists(zip_path)
# 验证 ZIP 内容
with zipfile.ZipFile(zip_path, "r") as zf:
# 读取 manifest
manifest = json.loads(zf.read("manifest.json"))
assert manifest["astrbot_version"] == VERSION
# 读取配置
config = json.loads(zf.read("config/cmd_config.json"))
assert config["setting"] == "value"
# 读取主数据库
main_db = json.loads(zf.read("databases/main_db.json"))
assert "platform_stats" in main_db