feat: supports data backup (#4105)
* feat: 添加数据迁移功能 * test: 添加迁移相关测试 * feat: 备份插件及相关持久化目录 * fix: 修复版本号比较逻辑,添加相关测试 * fix: 清洗文件名,添加相关测试 * fix: 修复安全文件名测试用例断言 * refactor: 优化代码,为备份模块提取公用常量 * feat: 修改备份版本校验逻辑,允许强制小版本间导入 * fix: 修复备份创建时间读取,修复备份相关i18n * refactor(backup): 使用 astrbot_path 统一管理备份目录路径 * fix(backup): 清理备份模块中未使用的导入 * refactor(backup): 统一备份路径与参数并移除未用附件目录 - 通过 astrbot_path 动态获取备份/知识库/数据相关路径 - 移除 exporter/importer 未使用的 attachments_dir/data_root 传参 - 更新备份路由与测试用例的构造参数 * fix(dashboard): alias mermaid to dist entry for Vite prebundle * fix(backup): 放行start-time接口到白名单以处理备份导入后jwt token变化导致无法自动刷新webui的问题 * chore(backup): 统一配置路径以使用动态数据目录 * refactor(backup): 使用 VersionComparator 替代重复的版本比较函数 * style(backup test): format code --------- Co-authored-by: Soulter <905617992@qq.com>
This commit is contained in:
@@ -0,0 +1,26 @@
|
||||
"""AstrBot 备份与恢复模块
|
||||
|
||||
提供数据导出和导入功能,支持用户在服务器迁移时一键备份和恢复所有数据。
|
||||
"""
|
||||
|
||||
# 从 constants 模块导入共享常量
|
||||
from .constants import (
|
||||
BACKUP_MANIFEST_VERSION,
|
||||
KB_METADATA_MODELS,
|
||||
MAIN_DB_MODELS,
|
||||
get_backup_directories,
|
||||
)
|
||||
|
||||
# 导入导出器和导入器
|
||||
from .exporter import AstrBotExporter
|
||||
from .importer import AstrBotImporter, ImportPreCheckResult
|
||||
|
||||
__all__ = [
|
||||
"AstrBotExporter",
|
||||
"AstrBotImporter",
|
||||
"ImportPreCheckResult",
|
||||
"MAIN_DB_MODELS",
|
||||
"KB_METADATA_MODELS",
|
||||
"get_backup_directories",
|
||||
"BACKUP_MANIFEST_VERSION",
|
||||
]
|
||||
@@ -0,0 +1,77 @@
|
||||
"""AstrBot 备份模块共享常量
|
||||
|
||||
此文件定义了导出器和导入器共享的常量,确保两端配置一致。
|
||||
"""
|
||||
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from astrbot.core.db.po import (
|
||||
Attachment,
|
||||
CommandConfig,
|
||||
CommandConflict,
|
||||
ConversationV2,
|
||||
Persona,
|
||||
PlatformMessageHistory,
|
||||
PlatformSession,
|
||||
PlatformStat,
|
||||
Preference,
|
||||
)
|
||||
from astrbot.core.knowledge_base.models import (
|
||||
KBDocument,
|
||||
KBMedia,
|
||||
KnowledgeBase,
|
||||
)
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_config_path,
|
||||
get_astrbot_plugin_data_path,
|
||||
get_astrbot_plugin_path,
|
||||
get_astrbot_t2i_templates_path,
|
||||
get_astrbot_temp_path,
|
||||
get_astrbot_webchat_path,
|
||||
)
|
||||
|
||||
# ============================================================
|
||||
# 共享常量 - 确保导出和导入端配置一致
|
||||
# ============================================================
|
||||
|
||||
# 主数据库模型类映射
|
||||
MAIN_DB_MODELS: dict[str, type[SQLModel]] = {
|
||||
"platform_stats": PlatformStat,
|
||||
"conversations": ConversationV2,
|
||||
"personas": Persona,
|
||||
"preferences": Preference,
|
||||
"platform_message_history": PlatformMessageHistory,
|
||||
"platform_sessions": PlatformSession,
|
||||
"attachments": Attachment,
|
||||
"command_configs": CommandConfig,
|
||||
"command_conflicts": CommandConflict,
|
||||
}
|
||||
|
||||
# 知识库元数据模型类映射
|
||||
KB_METADATA_MODELS: dict[str, type[SQLModel]] = {
|
||||
"knowledge_bases": KnowledgeBase,
|
||||
"kb_documents": KBDocument,
|
||||
"kb_media": KBMedia,
|
||||
}
|
||||
|
||||
|
||||
def get_backup_directories() -> dict[str, str]:
|
||||
"""获取需要备份的目录列表
|
||||
|
||||
使用 astrbot_path 模块动态获取路径,支持通过环境变量 ASTRBOT_ROOT 自定义根目录。
|
||||
|
||||
Returns:
|
||||
dict: 键为备份文件中的目录名称,值为目录的绝对路径
|
||||
"""
|
||||
return {
|
||||
"plugins": get_astrbot_plugin_path(), # 插件本体
|
||||
"plugin_data": get_astrbot_plugin_data_path(), # 插件数据
|
||||
"config": get_astrbot_config_path(), # 配置目录
|
||||
"t2i_templates": get_astrbot_t2i_templates_path(), # T2I 模板
|
||||
"webchat": get_astrbot_webchat_path(), # WebChat 数据
|
||||
"temp": get_astrbot_temp_path(), # 临时文件
|
||||
}
|
||||
|
||||
|
||||
# 备份清单版本号
|
||||
BACKUP_MANIFEST_VERSION = "1.1"
|
||||
@@ -0,0 +1,476 @@
|
||||
"""AstrBot 数据导出器
|
||||
|
||||
负责将所有数据导出为 ZIP 备份文件。
|
||||
导出格式为 JSON,这是数据库无关的方案,支持未来向 MySQL/PostgreSQL 迁移。
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import zipfile
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_backups_path,
|
||||
get_astrbot_data_path,
|
||||
)
|
||||
|
||||
# 从共享常量模块导入
|
||||
from .constants import (
|
||||
BACKUP_MANIFEST_VERSION,
|
||||
KB_METADATA_MODELS,
|
||||
MAIN_DB_MODELS,
|
||||
get_backup_directories,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
||||
|
||||
CMD_CONFIG_FILE_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
|
||||
|
||||
|
||||
class AstrBotExporter:
|
||||
"""AstrBot 数据导出器
|
||||
|
||||
导出内容:
|
||||
- 主数据库所有表(data/data_v4.db)
|
||||
- 知识库元数据(data/knowledge_base/kb.db)
|
||||
- 每个知识库的向量文档数据
|
||||
- 配置文件(data/cmd_config.json)
|
||||
- 附件文件
|
||||
- 知识库多媒体文件
|
||||
- 插件目录(data/plugins)
|
||||
- 插件数据目录(data/plugin_data)
|
||||
- 配置目录(data/config)
|
||||
- T2I 模板目录(data/t2i_templates)
|
||||
- WebChat 数据目录(data/webchat)
|
||||
- 临时文件目录(data/temp)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
main_db: BaseDatabase,
|
||||
kb_manager: "KnowledgeBaseManager | None" = None,
|
||||
config_path: str = CMD_CONFIG_FILE_PATH,
|
||||
):
|
||||
self.main_db = main_db
|
||||
self.kb_manager = kb_manager
|
||||
self.config_path = config_path
|
||||
self._checksums: dict[str, str] = {}
|
||||
|
||||
async def export_all(
|
||||
self,
|
||||
output_dir: str | None = None,
|
||||
progress_callback: Any | None = None,
|
||||
) -> str:
|
||||
"""导出所有数据到 ZIP 文件
|
||||
|
||||
Args:
|
||||
output_dir: 输出目录
|
||||
progress_callback: 进度回调函数,接收参数 (stage, current, total, message)
|
||||
|
||||
Returns:
|
||||
str: 生成的 ZIP 文件路径
|
||||
"""
|
||||
if output_dir is None:
|
||||
output_dir = get_astrbot_backups_path()
|
||||
|
||||
# 确保输出目录存在
|
||||
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
zip_filename = f"astrbot_backup_{timestamp}.zip"
|
||||
zip_path = os.path.join(output_dir, zip_filename)
|
||||
|
||||
logger.info(f"开始导出备份到 {zip_path}")
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
# 1. 导出主数据库
|
||||
if progress_callback:
|
||||
await progress_callback("main_db", 0, 100, "正在导出主数据库...")
|
||||
main_data = await self._export_main_database()
|
||||
main_db_json = json.dumps(
|
||||
main_data, ensure_ascii=False, indent=2, default=str
|
||||
)
|
||||
zf.writestr("databases/main_db.json", main_db_json)
|
||||
self._add_checksum("databases/main_db.json", main_db_json)
|
||||
if progress_callback:
|
||||
await progress_callback("main_db", 100, 100, "主数据库导出完成")
|
||||
|
||||
# 2. 导出知识库数据
|
||||
kb_meta_data: dict[str, Any] = {
|
||||
"knowledge_bases": [],
|
||||
"kb_documents": [],
|
||||
"kb_media": [],
|
||||
}
|
||||
if self.kb_manager:
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"kb_metadata", 0, 100, "正在导出知识库元数据..."
|
||||
)
|
||||
kb_meta_data = await self._export_kb_metadata()
|
||||
kb_meta_json = json.dumps(
|
||||
kb_meta_data, ensure_ascii=False, indent=2, default=str
|
||||
)
|
||||
zf.writestr("databases/kb_metadata.json", kb_meta_json)
|
||||
self._add_checksum("databases/kb_metadata.json", kb_meta_json)
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"kb_metadata", 100, 100, "知识库元数据导出完成"
|
||||
)
|
||||
|
||||
# 导出每个知识库的文档数据
|
||||
kb_insts = self.kb_manager.kb_insts
|
||||
total_kbs = len(kb_insts)
|
||||
for idx, (kb_id, kb_helper) in enumerate(kb_insts.items()):
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"kb_documents",
|
||||
idx,
|
||||
total_kbs,
|
||||
f"正在导出知识库 {kb_helper.kb.kb_name} 的文档数据...",
|
||||
)
|
||||
doc_data = await self._export_kb_documents(kb_helper)
|
||||
doc_json = json.dumps(
|
||||
doc_data, ensure_ascii=False, indent=2, default=str
|
||||
)
|
||||
doc_path = f"databases/kb_{kb_id}/documents.json"
|
||||
zf.writestr(doc_path, doc_json)
|
||||
self._add_checksum(doc_path, doc_json)
|
||||
|
||||
# 导出 FAISS 索引文件
|
||||
await self._export_faiss_index(zf, kb_helper, kb_id)
|
||||
|
||||
# 导出知识库多媒体文件
|
||||
await self._export_kb_media_files(zf, kb_helper, kb_id)
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"kb_documents", total_kbs, total_kbs, "知识库文档导出完成"
|
||||
)
|
||||
|
||||
# 3. 导出配置文件
|
||||
if progress_callback:
|
||||
await progress_callback("config", 0, 100, "正在导出配置文件...")
|
||||
if os.path.exists(self.config_path):
|
||||
with open(self.config_path, encoding="utf-8") as f:
|
||||
config_content = f.read()
|
||||
zf.writestr("config/cmd_config.json", config_content)
|
||||
self._add_checksum("config/cmd_config.json", config_content)
|
||||
if progress_callback:
|
||||
await progress_callback("config", 100, 100, "配置文件导出完成")
|
||||
|
||||
# 4. 导出附件文件
|
||||
if progress_callback:
|
||||
await progress_callback("attachments", 0, 100, "正在导出附件...")
|
||||
await self._export_attachments(zf, main_data.get("attachments", []))
|
||||
if progress_callback:
|
||||
await progress_callback("attachments", 100, 100, "附件导出完成")
|
||||
|
||||
# 5. 导出插件和其他目录
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"directories", 0, 100, "正在导出插件和数据目录..."
|
||||
)
|
||||
dir_stats = await self._export_directories(zf)
|
||||
if progress_callback:
|
||||
await progress_callback("directories", 100, 100, "目录导出完成")
|
||||
|
||||
# 6. 生成 manifest
|
||||
if progress_callback:
|
||||
await progress_callback("manifest", 0, 100, "正在生成清单...")
|
||||
manifest = self._generate_manifest(main_data, kb_meta_data, dir_stats)
|
||||
manifest_json = json.dumps(manifest, ensure_ascii=False, indent=2)
|
||||
zf.writestr("manifest.json", manifest_json)
|
||||
if progress_callback:
|
||||
await progress_callback("manifest", 100, 100, "清单生成完成")
|
||||
|
||||
logger.info(f"备份导出完成: {zip_path}")
|
||||
return zip_path
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"备份导出失败: {e}")
|
||||
# 清理失败的文件
|
||||
if os.path.exists(zip_path):
|
||||
os.remove(zip_path)
|
||||
raise
|
||||
|
||||
async def _export_main_database(self) -> dict[str, list[dict]]:
|
||||
"""导出主数据库所有表"""
|
||||
export_data: dict[str, list[dict]] = {}
|
||||
|
||||
async with self.main_db.get_db() as session:
|
||||
for table_name, model_class in MAIN_DB_MODELS.items():
|
||||
try:
|
||||
result = await session.execute(select(model_class))
|
||||
records = result.scalars().all()
|
||||
export_data[table_name] = [
|
||||
self._model_to_dict(record) for record in records
|
||||
]
|
||||
logger.debug(
|
||||
f"导出表 {table_name}: {len(export_data[table_name])} 条记录"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"导出表 {table_name} 失败: {e}")
|
||||
export_data[table_name] = []
|
||||
|
||||
return export_data
|
||||
|
||||
async def _export_kb_metadata(self) -> dict[str, list[dict]]:
|
||||
"""导出知识库元数据库"""
|
||||
if not self.kb_manager:
|
||||
return {"knowledge_bases": [], "kb_documents": [], "kb_media": []}
|
||||
|
||||
export_data: dict[str, list[dict]] = {}
|
||||
|
||||
async with self.kb_manager.kb_db.get_db() as session:
|
||||
for table_name, model_class in KB_METADATA_MODELS.items():
|
||||
try:
|
||||
result = await session.execute(select(model_class))
|
||||
records = result.scalars().all()
|
||||
export_data[table_name] = [
|
||||
self._model_to_dict(record) for record in records
|
||||
]
|
||||
logger.debug(
|
||||
f"导出知识库表 {table_name}: {len(export_data[table_name])} 条记录"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"导出知识库表 {table_name} 失败: {e}")
|
||||
export_data[table_name] = []
|
||||
|
||||
return export_data
|
||||
|
||||
async def _export_kb_documents(self, kb_helper: Any) -> dict[str, Any]:
|
||||
"""导出知识库的文档块数据"""
|
||||
try:
|
||||
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
|
||||
|
||||
vec_db: FaissVecDB = kb_helper.vec_db
|
||||
if not vec_db or not vec_db.document_storage:
|
||||
return {"documents": []}
|
||||
|
||||
# 获取所有文档
|
||||
docs = await vec_db.document_storage.get_documents(
|
||||
metadata_filters={},
|
||||
offset=0,
|
||||
limit=None, # 获取全部
|
||||
)
|
||||
|
||||
return {"documents": docs}
|
||||
except Exception as e:
|
||||
logger.warning(f"导出知识库文档失败: {e}")
|
||||
return {"documents": []}
|
||||
|
||||
async def _export_faiss_index(
|
||||
self,
|
||||
zf: zipfile.ZipFile,
|
||||
kb_helper: Any,
|
||||
kb_id: str,
|
||||
) -> None:
|
||||
"""导出 FAISS 索引文件"""
|
||||
try:
|
||||
index_path = kb_helper.kb_dir / "index.faiss"
|
||||
if index_path.exists():
|
||||
archive_path = f"databases/kb_{kb_id}/index.faiss"
|
||||
zf.write(str(index_path), archive_path)
|
||||
logger.debug(f"导出 FAISS 索引: {archive_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"导出 FAISS 索引失败: {e}")
|
||||
|
||||
async def _export_kb_media_files(
|
||||
self, zf: zipfile.ZipFile, kb_helper: Any, kb_id: str
|
||||
) -> None:
|
||||
"""导出知识库的多媒体文件"""
|
||||
try:
|
||||
media_dir = kb_helper.kb_medias_dir
|
||||
if not media_dir.exists():
|
||||
return
|
||||
|
||||
for root, _, files in os.walk(media_dir):
|
||||
for file in files:
|
||||
file_path = Path(root) / file
|
||||
# 计算相对路径
|
||||
rel_path = file_path.relative_to(kb_helper.kb_dir)
|
||||
archive_path = f"files/kb_media/{kb_id}/{rel_path}"
|
||||
zf.write(str(file_path), archive_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"导出知识库媒体文件失败: {e}")
|
||||
|
||||
async def _export_directories(
|
||||
self, zf: zipfile.ZipFile
|
||||
) -> dict[str, dict[str, int]]:
|
||||
"""导出插件和其他数据目录
|
||||
|
||||
Returns:
|
||||
dict: 每个目录的统计信息 {dir_name: {"files": count, "size": bytes}}
|
||||
"""
|
||||
stats: dict[str, dict[str, int]] = {}
|
||||
backup_directories = get_backup_directories()
|
||||
|
||||
for dir_name, dir_path in backup_directories.items():
|
||||
full_path = Path(dir_path)
|
||||
if not full_path.exists():
|
||||
logger.debug(f"目录不存在,跳过: {full_path}")
|
||||
continue
|
||||
|
||||
file_count = 0
|
||||
total_size = 0
|
||||
|
||||
try:
|
||||
for root, dirs, files in os.walk(full_path):
|
||||
# 跳过 __pycache__ 目录
|
||||
dirs[:] = [d for d in dirs if d != "__pycache__"]
|
||||
|
||||
for file in files:
|
||||
# 跳过 .pyc 文件
|
||||
if file.endswith(".pyc"):
|
||||
continue
|
||||
|
||||
file_path = Path(root) / file
|
||||
try:
|
||||
# 计算相对路径
|
||||
rel_path = file_path.relative_to(full_path)
|
||||
archive_path = f"directories/{dir_name}/{rel_path}"
|
||||
zf.write(str(file_path), archive_path)
|
||||
file_count += 1
|
||||
total_size += file_path.stat().st_size
|
||||
except Exception as e:
|
||||
logger.warning(f"导出文件 {file_path} 失败: {e}")
|
||||
|
||||
stats[dir_name] = {"files": file_count, "size": total_size}
|
||||
logger.debug(
|
||||
f"导出目录 {dir_name}: {file_count} 个文件, {total_size} 字节"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"导出目录 {dir_path} 失败: {e}")
|
||||
stats[dir_name] = {"files": 0, "size": 0}
|
||||
|
||||
return stats
|
||||
|
||||
async def _export_attachments(
|
||||
self, zf: zipfile.ZipFile, attachments: list[dict]
|
||||
) -> None:
|
||||
"""导出附件文件"""
|
||||
for attachment in attachments:
|
||||
try:
|
||||
file_path = attachment.get("path", "")
|
||||
if file_path and os.path.exists(file_path):
|
||||
# 使用 attachment_id 作为文件名
|
||||
attachment_id = attachment.get("attachment_id", "")
|
||||
ext = os.path.splitext(file_path)[1]
|
||||
archive_path = f"files/attachments/{attachment_id}{ext}"
|
||||
zf.write(file_path, archive_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"导出附件失败: {e}")
|
||||
|
||||
def _model_to_dict(self, record: Any) -> dict:
|
||||
"""将 SQLModel 实例转换为字典
|
||||
|
||||
这是数据库无关的序列化方式,支持未来迁移到其他数据库。
|
||||
"""
|
||||
# 使用 SQLModel 内置的 model_dump 方法(如果可用)
|
||||
if hasattr(record, "model_dump"):
|
||||
data = record.model_dump(mode="python")
|
||||
# 处理 datetime 类型
|
||||
for key, value in data.items():
|
||||
if isinstance(value, datetime):
|
||||
data[key] = value.isoformat()
|
||||
return data
|
||||
|
||||
# 回退到手动提取
|
||||
data = {}
|
||||
# 使用 inspect 获取表信息
|
||||
from sqlalchemy import inspect as sa_inspect
|
||||
|
||||
mapper = sa_inspect(record.__class__)
|
||||
for column in mapper.columns:
|
||||
value = getattr(record, column.name)
|
||||
# 处理 datetime 类型 - 统一转为 ISO 格式字符串
|
||||
if isinstance(value, datetime):
|
||||
value = value.isoformat()
|
||||
data[column.name] = value
|
||||
return data
|
||||
|
||||
def _add_checksum(self, path: str, content: str | bytes) -> None:
|
||||
"""计算并添加文件校验和"""
|
||||
if isinstance(content, str):
|
||||
content = content.encode("utf-8")
|
||||
checksum = hashlib.sha256(content).hexdigest()
|
||||
self._checksums[path] = f"sha256:{checksum}"
|
||||
|
||||
def _generate_manifest(
|
||||
self,
|
||||
main_data: dict[str, list[dict]],
|
||||
kb_meta_data: dict[str, list[dict]],
|
||||
dir_stats: dict[str, dict[str, int]] | None = None,
|
||||
) -> dict:
|
||||
"""生成备份清单"""
|
||||
if dir_stats is None:
|
||||
dir_stats = {}
|
||||
# 收集知识库 ID
|
||||
kb_document_tables = {}
|
||||
if self.kb_manager:
|
||||
for kb_id in self.kb_manager.kb_insts.keys():
|
||||
kb_document_tables[kb_id] = "documents"
|
||||
|
||||
# 收集附件文件列表
|
||||
attachment_files = []
|
||||
for attachment in main_data.get("attachments", []):
|
||||
attachment_id = attachment.get("attachment_id", "")
|
||||
path = attachment.get("path", "")
|
||||
if attachment_id and path:
|
||||
ext = os.path.splitext(path)[1]
|
||||
attachment_files.append(f"{attachment_id}{ext}")
|
||||
|
||||
# 收集知识库媒体文件
|
||||
kb_media_files: dict[str, list[str]] = {}
|
||||
if self.kb_manager:
|
||||
for kb_id, kb_helper in self.kb_manager.kb_insts.items():
|
||||
media_files: list[str] = []
|
||||
media_dir = kb_helper.kb_medias_dir
|
||||
if media_dir.exists():
|
||||
for root, _, files in os.walk(media_dir):
|
||||
for file in files:
|
||||
media_files.append(file)
|
||||
if media_files:
|
||||
kb_media_files[kb_id] = media_files
|
||||
|
||||
manifest = {
|
||||
"version": BACKUP_MANIFEST_VERSION,
|
||||
"astrbot_version": VERSION,
|
||||
"exported_at": datetime.now(timezone.utc).isoformat(),
|
||||
"schema_version": {
|
||||
"main_db": "v4",
|
||||
"kb_db": "v1",
|
||||
},
|
||||
"tables": {
|
||||
"main_db": list(main_data.keys()),
|
||||
"kb_metadata": list(kb_meta_data.keys()),
|
||||
"kb_documents": kb_document_tables,
|
||||
},
|
||||
"files": {
|
||||
"attachments": attachment_files,
|
||||
"kb_media": kb_media_files,
|
||||
},
|
||||
"directories": list(dir_stats.keys()),
|
||||
"checksums": self._checksums,
|
||||
"statistics": {
|
||||
"main_db": {
|
||||
table: len(records) for table, records in main_data.items()
|
||||
},
|
||||
"kb_metadata": {
|
||||
table: len(records) for table, records in kb_meta_data.items()
|
||||
},
|
||||
"directories": dir_stats,
|
||||
},
|
||||
}
|
||||
|
||||
return manifest
|
||||
@@ -0,0 +1,761 @@
|
||||
"""AstrBot 数据导入器
|
||||
|
||||
负责从 ZIP 备份文件恢复所有数据。
|
||||
导入时进行版本校验:
|
||||
- 主版本(前两位)不同时直接拒绝导入
|
||||
- 小版本(第三位)不同时提示警告,用户可选择强制导入
|
||||
- 版本匹配时也需要用户确认
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import zipfile
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from sqlalchemy import delete
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_data_path,
|
||||
get_astrbot_knowledge_base_path,
|
||||
)
|
||||
from astrbot.core.utils.version_comparator import VersionComparator
|
||||
|
||||
# 从共享常量模块导入
|
||||
from .constants import (
|
||||
KB_METADATA_MODELS,
|
||||
MAIN_DB_MODELS,
|
||||
get_backup_directories,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
||||
|
||||
|
||||
def _get_major_version(version_str: str) -> str:
|
||||
"""提取版本的主版本部分(前两位)
|
||||
|
||||
Args:
|
||||
version_str: 版本字符串,如 "4.9.1", "4.10.0-beta"
|
||||
|
||||
Returns:
|
||||
主版本字符串,如 "4.9", "4.10"
|
||||
"""
|
||||
if not version_str:
|
||||
return "0.0"
|
||||
# 移除 v 前缀和预发布标签
|
||||
version = version_str.lower().replace("v", "").split("-")[0].split("+")[0]
|
||||
parts = [p for p in version.split(".") if p] # 过滤空字符串
|
||||
if len(parts) >= 2:
|
||||
return f"{parts[0]}.{parts[1]}"
|
||||
elif len(parts) == 1 and parts[0]:
|
||||
return f"{parts[0]}.0"
|
||||
return "0.0"
|
||||
|
||||
|
||||
CMD_CONFIG_FILE_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
|
||||
KB_PATH = get_astrbot_knowledge_base_path()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImportPreCheckResult:
|
||||
"""导入预检查结果
|
||||
|
||||
用于在实际导入前检查备份文件的版本兼容性,
|
||||
并返回确认信息让用户决定是否继续导入。
|
||||
"""
|
||||
|
||||
# 检查是否通过(文件有效且版本可导入)
|
||||
valid: bool = False
|
||||
# 是否可以导入(版本兼容)
|
||||
can_import: bool = False
|
||||
# 版本状态: match(完全匹配), minor_diff(小版本差异), major_diff(主版本不同,拒绝)
|
||||
version_status: str = ""
|
||||
# 备份文件中的 AstrBot 版本
|
||||
backup_version: str = ""
|
||||
# 当前运行的 AstrBot 版本
|
||||
current_version: str = VERSION
|
||||
# 备份创建时间
|
||||
backup_time: str = ""
|
||||
# 确认消息(显示给用户)
|
||||
confirm_message: str = ""
|
||||
# 警告消息列表
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
# 错误消息(如果检查失败)
|
||||
error: str = ""
|
||||
# 备份包含的内容摘要
|
||||
backup_summary: dict = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"valid": self.valid,
|
||||
"can_import": self.can_import,
|
||||
"version_status": self.version_status,
|
||||
"backup_version": self.backup_version,
|
||||
"current_version": self.current_version,
|
||||
"backup_time": self.backup_time,
|
||||
"confirm_message": self.confirm_message,
|
||||
"warnings": self.warnings,
|
||||
"error": self.error,
|
||||
"backup_summary": self.backup_summary,
|
||||
}
|
||||
|
||||
|
||||
class ImportResult:
|
||||
"""导入结果"""
|
||||
|
||||
def __init__(self):
|
||||
self.success = True
|
||||
self.imported_tables: dict[str, int] = {}
|
||||
self.imported_files: dict[str, int] = {}
|
||||
self.imported_directories: dict[str, int] = {}
|
||||
self.warnings: list[str] = []
|
||||
self.errors: list[str] = []
|
||||
|
||||
def add_warning(self, msg: str) -> None:
|
||||
self.warnings.append(msg)
|
||||
logger.warning(msg)
|
||||
|
||||
def add_error(self, msg: str) -> None:
|
||||
self.errors.append(msg)
|
||||
self.success = False
|
||||
logger.error(msg)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"success": self.success,
|
||||
"imported_tables": self.imported_tables,
|
||||
"imported_files": self.imported_files,
|
||||
"imported_directories": self.imported_directories,
|
||||
"warnings": self.warnings,
|
||||
"errors": self.errors,
|
||||
}
|
||||
|
||||
|
||||
class AstrBotImporter:
|
||||
"""AstrBot 数据导入器
|
||||
|
||||
导入备份文件中的所有数据,包括:
|
||||
- 主数据库所有表
|
||||
- 知识库元数据和文档
|
||||
- 配置文件
|
||||
- 附件文件
|
||||
- 知识库多媒体文件
|
||||
- 插件目录(data/plugins)
|
||||
- 插件数据目录(data/plugin_data)
|
||||
- 配置目录(data/config)
|
||||
- T2I 模板目录(data/t2i_templates)
|
||||
- WebChat 数据目录(data/webchat)
|
||||
- 临时文件目录(data/temp)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
main_db: BaseDatabase,
|
||||
kb_manager: "KnowledgeBaseManager | None" = None,
|
||||
config_path: str = CMD_CONFIG_FILE_PATH,
|
||||
kb_root_dir: str = KB_PATH,
|
||||
):
|
||||
self.main_db = main_db
|
||||
self.kb_manager = kb_manager
|
||||
self.config_path = config_path
|
||||
self.kb_root_dir = kb_root_dir
|
||||
|
||||
def pre_check(self, zip_path: str) -> ImportPreCheckResult:
|
||||
"""预检查备份文件
|
||||
|
||||
在实际导入前检查备份文件的有效性和版本兼容性。
|
||||
返回检查结果供前端显示确认对话框。
|
||||
|
||||
Args:
|
||||
zip_path: ZIP 备份文件路径
|
||||
|
||||
Returns:
|
||||
ImportPreCheckResult: 预检查结果
|
||||
"""
|
||||
result = ImportPreCheckResult()
|
||||
result.current_version = VERSION
|
||||
|
||||
if not os.path.exists(zip_path):
|
||||
result.error = f"备份文件不存在: {zip_path}"
|
||||
return result
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
# 读取 manifest
|
||||
try:
|
||||
manifest_data = zf.read("manifest.json")
|
||||
manifest = json.loads(manifest_data)
|
||||
except KeyError:
|
||||
result.error = "备份文件缺少 manifest.json,不是有效的 AstrBot 备份"
|
||||
return result
|
||||
except json.JSONDecodeError as e:
|
||||
result.error = f"manifest.json 格式错误: {e}"
|
||||
return result
|
||||
|
||||
# 提取基本信息
|
||||
result.backup_version = manifest.get("astrbot_version", "未知")
|
||||
result.backup_time = manifest.get("exported_at", "未知")
|
||||
result.valid = True
|
||||
|
||||
# 构建备份摘要
|
||||
result.backup_summary = {
|
||||
"tables": list(manifest.get("tables", {}).keys()),
|
||||
"has_knowledge_bases": manifest.get("has_knowledge_bases", False),
|
||||
"has_config": manifest.get("has_config", False),
|
||||
"directories": manifest.get("directories", []),
|
||||
}
|
||||
|
||||
# 检查版本兼容性
|
||||
version_check = self._check_version_compatibility(result.backup_version)
|
||||
result.version_status = version_check["status"]
|
||||
result.can_import = version_check["can_import"]
|
||||
|
||||
# 版本信息由前端根据 version_status 和 i18n 生成显示
|
||||
# 不再将版本消息添加到 warnings 列表中,避免中文硬编码
|
||||
# warnings 列表保留用于其他非版本相关的警告
|
||||
|
||||
return result
|
||||
|
||||
except zipfile.BadZipFile:
|
||||
result.error = "无效的 ZIP 文件"
|
||||
return result
|
||||
except Exception as e:
|
||||
result.error = f"检查备份文件失败: {e}"
|
||||
return result
|
||||
|
||||
def _check_version_compatibility(self, backup_version: str) -> dict:
|
||||
"""检查版本兼容性
|
||||
|
||||
规则:
|
||||
- 主版本(前两位,如 4.9)必须一致,否则拒绝
|
||||
- 小版本(第三位,如 4.9.1 vs 4.9.2)不同时,警告但允许导入
|
||||
|
||||
Returns:
|
||||
dict: {status, can_import, message}
|
||||
"""
|
||||
if not backup_version:
|
||||
return {
|
||||
"status": "major_diff",
|
||||
"can_import": False,
|
||||
"message": "备份文件缺少版本信息",
|
||||
}
|
||||
|
||||
# 提取主版本(前两位)进行比较
|
||||
backup_major = _get_major_version(backup_version)
|
||||
current_major = _get_major_version(VERSION)
|
||||
|
||||
# 比较主版本
|
||||
if VersionComparator.compare_version(backup_major, current_major) != 0:
|
||||
return {
|
||||
"status": "major_diff",
|
||||
"can_import": False,
|
||||
"message": (
|
||||
f"主版本不兼容: 备份版本 {backup_version}, 当前版本 {VERSION}。"
|
||||
f"跨主版本导入可能导致数据损坏,请使用相同主版本的 AstrBot。"
|
||||
),
|
||||
}
|
||||
|
||||
# 比较完整版本
|
||||
version_cmp = VersionComparator.compare_version(backup_version, VERSION)
|
||||
if version_cmp != 0:
|
||||
return {
|
||||
"status": "minor_diff",
|
||||
"can_import": True,
|
||||
"message": (
|
||||
f"小版本差异: 备份版本 {backup_version}, 当前版本 {VERSION}。"
|
||||
),
|
||||
}
|
||||
|
||||
return {
|
||||
"status": "match",
|
||||
"can_import": True,
|
||||
"message": "版本匹配",
|
||||
}
|
||||
|
||||
async def import_all(
|
||||
self,
|
||||
zip_path: str,
|
||||
mode: str = "replace", # "replace" 清空后导入
|
||||
progress_callback: Any | None = None,
|
||||
) -> ImportResult:
|
||||
"""从 ZIP 文件导入所有数据
|
||||
|
||||
Args:
|
||||
zip_path: ZIP 备份文件路径
|
||||
mode: 导入模式,目前仅支持 "replace"(清空后导入)
|
||||
progress_callback: 进度回调函数,接收参数 (stage, current, total, message)
|
||||
|
||||
Returns:
|
||||
ImportResult: 导入结果
|
||||
"""
|
||||
result = ImportResult()
|
||||
|
||||
if not os.path.exists(zip_path):
|
||||
result.add_error(f"备份文件不存在: {zip_path}")
|
||||
return result
|
||||
|
||||
logger.info(f"开始从 {zip_path} 导入备份")
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
# 1. 读取并验证 manifest
|
||||
if progress_callback:
|
||||
await progress_callback("validate", 0, 100, "正在验证备份文件...")
|
||||
|
||||
try:
|
||||
manifest_data = zf.read("manifest.json")
|
||||
manifest = json.loads(manifest_data)
|
||||
except KeyError:
|
||||
result.add_error("备份文件缺少 manifest.json")
|
||||
return result
|
||||
except json.JSONDecodeError as e:
|
||||
result.add_error(f"manifest.json 格式错误: {e}")
|
||||
return result
|
||||
|
||||
# 版本校验
|
||||
try:
|
||||
self._validate_version(manifest)
|
||||
except ValueError as e:
|
||||
result.add_error(str(e))
|
||||
return result
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("validate", 100, 100, "验证完成")
|
||||
|
||||
# 2. 导入主数据库
|
||||
if progress_callback:
|
||||
await progress_callback("main_db", 0, 100, "正在导入主数据库...")
|
||||
|
||||
try:
|
||||
main_data_content = zf.read("databases/main_db.json")
|
||||
main_data = json.loads(main_data_content)
|
||||
|
||||
if mode == "replace":
|
||||
await self._clear_main_db()
|
||||
|
||||
imported = await self._import_main_database(main_data)
|
||||
result.imported_tables.update(imported)
|
||||
except Exception as e:
|
||||
result.add_error(f"导入主数据库失败: {e}")
|
||||
return result
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("main_db", 100, 100, "主数据库导入完成")
|
||||
|
||||
# 3. 导入知识库
|
||||
if self.kb_manager and "databases/kb_metadata.json" in zf.namelist():
|
||||
if progress_callback:
|
||||
await progress_callback("kb", 0, 100, "正在导入知识库...")
|
||||
|
||||
try:
|
||||
kb_meta_content = zf.read("databases/kb_metadata.json")
|
||||
kb_meta_data = json.loads(kb_meta_content)
|
||||
|
||||
if mode == "replace":
|
||||
await self._clear_kb_data()
|
||||
|
||||
await self._import_knowledge_bases(zf, kb_meta_data, result)
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入知识库失败: {e}")
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("kb", 100, 100, "知识库导入完成")
|
||||
|
||||
# 4. 导入配置文件
|
||||
if progress_callback:
|
||||
await progress_callback("config", 0, 100, "正在导入配置文件...")
|
||||
|
||||
if "config/cmd_config.json" in zf.namelist():
|
||||
try:
|
||||
config_content = zf.read("config/cmd_config.json")
|
||||
# 备份现有配置
|
||||
if os.path.exists(self.config_path):
|
||||
backup_path = f"{self.config_path}.bak"
|
||||
shutil.copy2(self.config_path, backup_path)
|
||||
|
||||
with open(self.config_path, "wb") as f:
|
||||
f.write(config_content)
|
||||
result.imported_files["config"] = 1
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入配置文件失败: {e}")
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("config", 100, 100, "配置文件导入完成")
|
||||
|
||||
# 5. 导入附件文件
|
||||
if progress_callback:
|
||||
await progress_callback("attachments", 0, 100, "正在导入附件...")
|
||||
|
||||
attachment_count = await self._import_attachments(
|
||||
zf, main_data.get("attachments", [])
|
||||
)
|
||||
result.imported_files["attachments"] = attachment_count
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("attachments", 100, 100, "附件导入完成")
|
||||
|
||||
# 6. 导入插件和其他目录
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"directories", 0, 100, "正在导入插件和数据目录..."
|
||||
)
|
||||
|
||||
dir_stats = await self._import_directories(zf, manifest, result)
|
||||
result.imported_directories = dir_stats
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("directories", 100, 100, "目录导入完成")
|
||||
|
||||
logger.info(f"备份导入完成: {result.to_dict()}")
|
||||
return result
|
||||
|
||||
except zipfile.BadZipFile:
|
||||
result.add_error("无效的 ZIP 文件")
|
||||
return result
|
||||
except Exception as e:
|
||||
result.add_error(f"导入失败: {e}")
|
||||
return result
|
||||
|
||||
def _validate_version(self, manifest: dict) -> None:
|
||||
"""验证版本兼容性 - 仅允许相同主版本导入
|
||||
|
||||
注意:此方法仅在 import_all 中调用,用于双重校验。
|
||||
前端应先调用 pre_check 获取详细的版本信息并让用户确认。
|
||||
"""
|
||||
backup_version = manifest.get("astrbot_version")
|
||||
if not backup_version:
|
||||
raise ValueError("备份文件缺少版本信息")
|
||||
|
||||
# 使用新的版本兼容性检查
|
||||
version_check = self._check_version_compatibility(backup_version)
|
||||
|
||||
if version_check["status"] == "major_diff":
|
||||
raise ValueError(version_check["message"])
|
||||
|
||||
# minor_diff 和 match 都允许导入
|
||||
if version_check["status"] == "minor_diff":
|
||||
logger.warning(f"版本差异警告: {version_check['message']}")
|
||||
|
||||
async def _clear_main_db(self) -> None:
|
||||
"""清空主数据库所有表"""
|
||||
async with self.main_db.get_db() as session:
|
||||
async with session.begin():
|
||||
for table_name, model_class in MAIN_DB_MODELS.items():
|
||||
try:
|
||||
await session.execute(delete(model_class))
|
||||
logger.debug(f"已清空表 {table_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"清空表 {table_name} 失败: {e}")
|
||||
|
||||
async def _clear_kb_data(self) -> None:
|
||||
"""清空知识库数据"""
|
||||
if not self.kb_manager:
|
||||
return
|
||||
|
||||
# 清空知识库元数据表
|
||||
async with self.kb_manager.kb_db.get_db() as session:
|
||||
async with session.begin():
|
||||
for table_name, model_class in KB_METADATA_MODELS.items():
|
||||
try:
|
||||
await session.execute(delete(model_class))
|
||||
logger.debug(f"已清空知识库表 {table_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"清空知识库表 {table_name} 失败: {e}")
|
||||
|
||||
# 删除知识库文件目录
|
||||
for kb_id in list(self.kb_manager.kb_insts.keys()):
|
||||
try:
|
||||
kb_helper = self.kb_manager.kb_insts[kb_id]
|
||||
await kb_helper.terminate()
|
||||
if kb_helper.kb_dir.exists():
|
||||
shutil.rmtree(kb_helper.kb_dir)
|
||||
except Exception as e:
|
||||
logger.warning(f"清理知识库 {kb_id} 失败: {e}")
|
||||
|
||||
self.kb_manager.kb_insts.clear()
|
||||
|
||||
async def _import_main_database(
|
||||
self, data: dict[str, list[dict]]
|
||||
) -> dict[str, int]:
|
||||
"""导入主数据库数据"""
|
||||
imported: dict[str, int] = {}
|
||||
|
||||
async with self.main_db.get_db() as session:
|
||||
async with session.begin():
|
||||
for table_name, rows in data.items():
|
||||
model_class = MAIN_DB_MODELS.get(table_name)
|
||||
if not model_class:
|
||||
logger.warning(f"未知的表: {table_name}")
|
||||
continue
|
||||
|
||||
count = 0
|
||||
for row in rows:
|
||||
try:
|
||||
# 转换 datetime 字符串为 datetime 对象
|
||||
row = self._convert_datetime_fields(row, model_class)
|
||||
obj = model_class(**row)
|
||||
session.add(obj)
|
||||
count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"导入记录到 {table_name} 失败: {e}")
|
||||
|
||||
imported[table_name] = count
|
||||
logger.debug(f"导入表 {table_name}: {count} 条记录")
|
||||
|
||||
return imported
|
||||
|
||||
async def _import_knowledge_bases(
|
||||
self,
|
||||
zf: zipfile.ZipFile,
|
||||
kb_meta_data: dict[str, list[dict]],
|
||||
result: ImportResult,
|
||||
) -> None:
|
||||
"""导入知识库数据"""
|
||||
if not self.kb_manager:
|
||||
return
|
||||
|
||||
# 1. 导入知识库元数据
|
||||
async with self.kb_manager.kb_db.get_db() as session:
|
||||
async with session.begin():
|
||||
for table_name, rows in kb_meta_data.items():
|
||||
model_class = KB_METADATA_MODELS.get(table_name)
|
||||
if not model_class:
|
||||
continue
|
||||
|
||||
count = 0
|
||||
for row in rows:
|
||||
try:
|
||||
row = self._convert_datetime_fields(row, model_class)
|
||||
obj = model_class(**row)
|
||||
session.add(obj)
|
||||
count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"导入知识库记录到 {table_name} 失败: {e}")
|
||||
|
||||
result.imported_tables[f"kb_{table_name}"] = count
|
||||
|
||||
# 2. 导入每个知识库的文档和文件
|
||||
for kb_data in kb_meta_data.get("knowledge_bases", []):
|
||||
kb_id = kb_data.get("kb_id")
|
||||
if not kb_id:
|
||||
continue
|
||||
|
||||
# 创建知识库目录
|
||||
kb_dir = Path(self.kb_root_dir) / kb_id
|
||||
kb_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 导入文档数据
|
||||
doc_path = f"databases/kb_{kb_id}/documents.json"
|
||||
if doc_path in zf.namelist():
|
||||
try:
|
||||
doc_content = zf.read(doc_path)
|
||||
doc_data = json.loads(doc_content)
|
||||
|
||||
# 导入到文档存储数据库
|
||||
await self._import_kb_documents(kb_id, doc_data)
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入知识库 {kb_id} 的文档失败: {e}")
|
||||
|
||||
# 导入 FAISS 索引
|
||||
faiss_path = f"databases/kb_{kb_id}/index.faiss"
|
||||
if faiss_path in zf.namelist():
|
||||
try:
|
||||
target_path = kb_dir / "index.faiss"
|
||||
with zf.open(faiss_path) as src, open(target_path, "wb") as dst:
|
||||
dst.write(src.read())
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入知识库 {kb_id} 的 FAISS 索引失败: {e}")
|
||||
|
||||
# 导入媒体文件
|
||||
media_prefix = f"files/kb_media/{kb_id}/"
|
||||
for name in zf.namelist():
|
||||
if name.startswith(media_prefix):
|
||||
try:
|
||||
rel_path = name[len(media_prefix) :]
|
||||
target_path = kb_dir / rel_path
|
||||
target_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with zf.open(name) as src, open(target_path, "wb") as dst:
|
||||
dst.write(src.read())
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入媒体文件 {name} 失败: {e}")
|
||||
|
||||
# 3. 重新加载知识库实例
|
||||
await self.kb_manager.load_kbs()
|
||||
|
||||
async def _import_kb_documents(self, kb_id: str, doc_data: dict) -> None:
|
||||
"""导入知识库文档到向量数据库"""
|
||||
from astrbot.core.db.vec_db.faiss_impl.document_storage import DocumentStorage
|
||||
|
||||
kb_dir = Path(self.kb_root_dir) / kb_id
|
||||
doc_db_path = kb_dir / "doc.db"
|
||||
|
||||
# 初始化文档存储
|
||||
doc_storage = DocumentStorage(str(doc_db_path))
|
||||
await doc_storage.initialize()
|
||||
|
||||
try:
|
||||
documents = doc_data.get("documents", [])
|
||||
for doc in documents:
|
||||
try:
|
||||
await doc_storage.insert_document(
|
||||
doc_id=doc.get("doc_id", ""),
|
||||
text=doc.get("text", ""),
|
||||
metadata=json.loads(doc.get("metadata", "{}")),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"导入文档块失败: {e}")
|
||||
finally:
|
||||
await doc_storage.close()
|
||||
|
||||
async def _import_attachments(
|
||||
self,
|
||||
zf: zipfile.ZipFile,
|
||||
attachments: list[dict],
|
||||
) -> int:
|
||||
"""导入附件文件"""
|
||||
count = 0
|
||||
|
||||
attachments_dir = Path(self.config_path).parent / "attachments"
|
||||
attachments_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
attachment_prefix = "files/attachments/"
|
||||
for name in zf.namelist():
|
||||
if name.startswith(attachment_prefix) and name != attachment_prefix:
|
||||
try:
|
||||
# 从附件记录中找到原始路径
|
||||
attachment_id = os.path.splitext(os.path.basename(name))[0]
|
||||
original_path = None
|
||||
for att in attachments:
|
||||
if att.get("attachment_id") == attachment_id:
|
||||
original_path = att.get("path")
|
||||
break
|
||||
|
||||
if original_path:
|
||||
target_path = Path(original_path)
|
||||
else:
|
||||
target_path = attachments_dir / os.path.basename(name)
|
||||
|
||||
target_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with zf.open(name) as src, open(target_path, "wb") as dst:
|
||||
dst.write(src.read())
|
||||
count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"导入附件 {name} 失败: {e}")
|
||||
|
||||
return count
|
||||
|
||||
async def _import_directories(
|
||||
self,
|
||||
zf: zipfile.ZipFile,
|
||||
manifest: dict,
|
||||
result: ImportResult,
|
||||
) -> dict[str, int]:
|
||||
"""导入插件和其他数据目录
|
||||
|
||||
Args:
|
||||
zf: ZIP 文件对象
|
||||
manifest: 备份清单
|
||||
result: 导入结果对象
|
||||
|
||||
Returns:
|
||||
dict: 每个目录导入的文件数量
|
||||
"""
|
||||
dir_stats: dict[str, int] = {}
|
||||
|
||||
# 检查备份版本是否支持目录备份(需要版本 >= 1.1)
|
||||
backup_version = manifest.get("version", "1.0")
|
||||
if VersionComparator.compare_version(backup_version, "1.1") < 0:
|
||||
logger.info("备份版本不支持目录备份,跳过目录导入")
|
||||
return dir_stats
|
||||
|
||||
backed_up_dirs = manifest.get("directories", [])
|
||||
backup_directories = get_backup_directories()
|
||||
|
||||
for dir_name in backed_up_dirs:
|
||||
if dir_name not in backup_directories:
|
||||
result.add_warning(f"未知的目录类型: {dir_name}")
|
||||
continue
|
||||
|
||||
target_dir = Path(backup_directories[dir_name])
|
||||
archive_prefix = f"directories/{dir_name}/"
|
||||
|
||||
file_count = 0
|
||||
|
||||
try:
|
||||
# 获取该目录下的所有文件
|
||||
dir_files = [
|
||||
name
|
||||
for name in zf.namelist()
|
||||
if name.startswith(archive_prefix) and name != archive_prefix
|
||||
]
|
||||
|
||||
if not dir_files:
|
||||
continue
|
||||
|
||||
# 备份现有目录(如果存在)
|
||||
if target_dir.exists():
|
||||
backup_path = Path(f"{target_dir}.bak")
|
||||
if backup_path.exists():
|
||||
shutil.rmtree(backup_path)
|
||||
shutil.move(str(target_dir), str(backup_path))
|
||||
logger.debug(f"已备份现有目录 {target_dir} 到 {backup_path}")
|
||||
|
||||
# 创建目标目录
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 解压文件
|
||||
for name in dir_files:
|
||||
try:
|
||||
# 计算相对路径
|
||||
rel_path = name[len(archive_prefix) :]
|
||||
if not rel_path: # 跳过目录条目
|
||||
continue
|
||||
|
||||
target_path = target_dir / rel_path
|
||||
target_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with zf.open(name) as src, open(target_path, "wb") as dst:
|
||||
dst.write(src.read())
|
||||
file_count += 1
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入文件 {name} 失败: {e}")
|
||||
|
||||
dir_stats[dir_name] = file_count
|
||||
logger.debug(f"导入目录 {dir_name}: {file_count} 个文件")
|
||||
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入目录 {dir_name} 失败: {e}")
|
||||
dir_stats[dir_name] = 0
|
||||
|
||||
return dir_stats
|
||||
|
||||
def _convert_datetime_fields(self, row: dict, model_class: type) -> dict:
|
||||
"""转换 datetime 字符串字段为 datetime 对象"""
|
||||
result = row.copy()
|
||||
|
||||
# 获取模型的 datetime 字段
|
||||
from sqlalchemy import inspect as sa_inspect
|
||||
|
||||
try:
|
||||
mapper = sa_inspect(model_class)
|
||||
for column in mapper.columns:
|
||||
if column.name in result and result[column.name] is not None:
|
||||
# 检查是否是 datetime 类型的列
|
||||
from sqlalchemy import DateTime
|
||||
|
||||
if isinstance(column.type, DateTime):
|
||||
value = result[column.name]
|
||||
if isinstance(value, str):
|
||||
# 解析 ISO 格式的日期时间字符串
|
||||
result[column.name] = datetime.fromisoformat(value)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return result
|
||||
@@ -5,6 +5,10 @@
|
||||
数据目录路径:固定为根目录下的 data 目录
|
||||
配置文件路径:固定为数据目录下的 config 目录
|
||||
插件目录路径:固定为数据目录下的 plugins 目录
|
||||
插件数据目录路径:固定为数据目录下的 plugin_data 目录
|
||||
T2I 模板目录路径:固定为数据目录下的 t2i_templates 目录
|
||||
WebChat 数据目录路径:固定为数据目录下的 webchat 目录
|
||||
临时文件目录路径:固定为数据目录下的 temp 目录
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -37,3 +41,33 @@ def get_astrbot_config_path() -> str:
|
||||
def get_astrbot_plugin_path() -> str:
|
||||
"""获取Astrbot插件目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "plugins"))
|
||||
|
||||
|
||||
def get_astrbot_plugin_data_path() -> str:
|
||||
"""获取Astrbot插件数据目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "plugin_data"))
|
||||
|
||||
|
||||
def get_astrbot_t2i_templates_path() -> str:
|
||||
"""获取Astrbot T2I 模板目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "t2i_templates"))
|
||||
|
||||
|
||||
def get_astrbot_webchat_path() -> str:
|
||||
"""获取Astrbot WebChat 数据目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "webchat"))
|
||||
|
||||
|
||||
def get_astrbot_temp_path() -> str:
|
||||
"""获取Astrbot临时文件目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "temp"))
|
||||
|
||||
|
||||
def get_astrbot_knowledge_base_path() -> str:
|
||||
"""获取Astrbot知识库根目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "knowledge_base"))
|
||||
|
||||
|
||||
def get_astrbot_backups_path() -> str:
|
||||
"""获取Astrbot备份目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "backups"))
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from .auth import AuthRoute
|
||||
from .backup import BackupRoute
|
||||
from .chat import ChatRoute
|
||||
from .command import CommandRoute
|
||||
from .config import ConfigRoute
|
||||
@@ -17,6 +18,7 @@ from .update import UpdateRoute
|
||||
|
||||
__all__ = [
|
||||
"AuthRoute",
|
||||
"BackupRoute",
|
||||
"ChatRoute",
|
||||
"CommandRoute",
|
||||
"ConfigRoute",
|
||||
|
||||
@@ -0,0 +1,589 @@
|
||||
"""备份管理 API 路由"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import traceback
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from quart import request, send_file
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.backup.exporter import AstrBotExporter
|
||||
from astrbot.core.backup.importer import AstrBotImporter
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_backups_path,
|
||||
get_astrbot_data_path,
|
||||
)
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
|
||||
def secure_filename(filename: str) -> str:
|
||||
"""清洗文件名,移除路径遍历字符和危险字符
|
||||
|
||||
Args:
|
||||
filename: 原始文件名
|
||||
|
||||
Returns:
|
||||
安全的文件名
|
||||
"""
|
||||
# 跨平台处理:先将反斜杠替换为正斜杠,再取文件名
|
||||
filename = filename.replace("\\", "/")
|
||||
# 仅保留文件名部分,移除路径
|
||||
filename = os.path.basename(filename)
|
||||
|
||||
# 替换路径遍历字符
|
||||
filename = filename.replace("..", "_")
|
||||
|
||||
# 仅保留字母、数字、下划线、连字符、点
|
||||
filename = re.sub(r"[^\w\-.]", "_", filename)
|
||||
|
||||
# 移除前导点(隐藏文件)和尾部点
|
||||
filename = filename.strip(".")
|
||||
|
||||
# 如果文件名为空或只包含下划线,生成一个默认名称
|
||||
if not filename or filename.replace("_", "") == "":
|
||||
filename = "backup"
|
||||
|
||||
return filename
|
||||
|
||||
|
||||
def generate_unique_filename(original_filename: str) -> str:
|
||||
"""生成唯一的文件名,添加时间戳前缀
|
||||
|
||||
Args:
|
||||
original_filename: 原始文件名(已清洗)
|
||||
|
||||
Returns:
|
||||
唯一的文件名
|
||||
"""
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
name, ext = os.path.splitext(original_filename)
|
||||
return f"uploaded_{timestamp}_{name}{ext}"
|
||||
|
||||
|
||||
class BackupRoute(Route):
|
||||
"""备份管理路由
|
||||
|
||||
提供备份导出、导入、列表等 API 接口
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: RouteContext,
|
||||
db: BaseDatabase,
|
||||
core_lifecycle: AstrBotCoreLifecycle,
|
||||
) -> None:
|
||||
super().__init__(context)
|
||||
self.db = db
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.backup_dir = get_astrbot_backups_path()
|
||||
self.data_dir = get_astrbot_data_path()
|
||||
|
||||
# 任务状态跟踪
|
||||
self.backup_tasks: dict[str, dict] = {}
|
||||
self.backup_progress: dict[str, dict] = {}
|
||||
|
||||
# 注册路由
|
||||
self.routes = {
|
||||
"/backup/list": ("GET", self.list_backups),
|
||||
"/backup/export": ("POST", self.export_backup),
|
||||
"/backup/upload": ("POST", self.upload_backup), # 上传文件
|
||||
"/backup/check": ("POST", self.check_backup), # 预检查
|
||||
"/backup/import": ("POST", self.import_backup), # 确认导入
|
||||
"/backup/progress": ("GET", self.get_progress),
|
||||
"/backup/download": ("GET", self.download_backup),
|
||||
"/backup/delete": ("POST", self.delete_backup),
|
||||
}
|
||||
self.register_routes()
|
||||
|
||||
def _init_task(self, task_id: str, task_type: str, status: str = "pending") -> None:
|
||||
"""初始化任务状态"""
|
||||
self.backup_tasks[task_id] = {
|
||||
"type": task_type,
|
||||
"status": status,
|
||||
"result": None,
|
||||
"error": None,
|
||||
}
|
||||
self.backup_progress[task_id] = {
|
||||
"status": status,
|
||||
"stage": "waiting",
|
||||
"current": 0,
|
||||
"total": 100,
|
||||
"message": "",
|
||||
}
|
||||
|
||||
def _set_task_result(
|
||||
self,
|
||||
task_id: str,
|
||||
status: str,
|
||||
result: dict | None = None,
|
||||
error: str | None = None,
|
||||
) -> None:
|
||||
"""设置任务结果"""
|
||||
if task_id in self.backup_tasks:
|
||||
self.backup_tasks[task_id]["status"] = status
|
||||
self.backup_tasks[task_id]["result"] = result
|
||||
self.backup_tasks[task_id]["error"] = error
|
||||
if task_id in self.backup_progress:
|
||||
self.backup_progress[task_id]["status"] = status
|
||||
|
||||
def _update_progress(
|
||||
self,
|
||||
task_id: str,
|
||||
*,
|
||||
status: str | None = None,
|
||||
stage: str | None = None,
|
||||
current: int | None = None,
|
||||
total: int | None = None,
|
||||
message: str | None = None,
|
||||
) -> None:
|
||||
"""更新任务进度"""
|
||||
if task_id not in self.backup_progress:
|
||||
return
|
||||
p = self.backup_progress[task_id]
|
||||
if status is not None:
|
||||
p["status"] = status
|
||||
if stage is not None:
|
||||
p["stage"] = stage
|
||||
if current is not None:
|
||||
p["current"] = current
|
||||
if total is not None:
|
||||
p["total"] = total
|
||||
if message is not None:
|
||||
p["message"] = message
|
||||
|
||||
def _make_progress_callback(self, task_id: str):
|
||||
"""创建进度回调函数"""
|
||||
|
||||
async def _callback(stage: str, current: int, total: int, message: str = ""):
|
||||
self._update_progress(
|
||||
task_id,
|
||||
status="processing",
|
||||
stage=stage,
|
||||
current=current,
|
||||
total=total,
|
||||
message=message,
|
||||
)
|
||||
|
||||
return _callback
|
||||
|
||||
async def list_backups(self):
|
||||
"""获取备份列表
|
||||
|
||||
Query 参数:
|
||||
- page: 页码 (默认 1)
|
||||
- page_size: 每页数量 (默认 20)
|
||||
"""
|
||||
try:
|
||||
page = request.args.get("page", 1, type=int)
|
||||
page_size = request.args.get("page_size", 20, type=int)
|
||||
|
||||
# 确保备份目录存在
|
||||
Path(self.backup_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 获取所有备份文件
|
||||
backup_files = []
|
||||
for filename in os.listdir(self.backup_dir):
|
||||
if filename.endswith(".zip") and filename.startswith("astrbot_backup_"):
|
||||
file_path = os.path.join(self.backup_dir, filename)
|
||||
stat = os.stat(file_path)
|
||||
backup_files.append(
|
||||
{
|
||||
"filename": filename,
|
||||
"size": stat.st_size,
|
||||
"created_at": stat.st_mtime,
|
||||
}
|
||||
)
|
||||
|
||||
# 按创建时间倒序排序
|
||||
backup_files.sort(key=lambda x: x["created_at"], reverse=True)
|
||||
|
||||
# 分页
|
||||
start = (page - 1) * page_size
|
||||
end = start + page_size
|
||||
items = backup_files[start:end]
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"items": items,
|
||||
"total": len(backup_files),
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取备份列表失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"获取备份列表失败: {e!s}").__dict__
|
||||
|
||||
async def export_backup(self):
|
||||
"""创建备份
|
||||
|
||||
返回:
|
||||
- task_id: 任务ID,用于查询导出进度
|
||||
"""
|
||||
try:
|
||||
# 生成任务ID
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
# 初始化任务状态
|
||||
self._init_task(task_id, "export", "pending")
|
||||
|
||||
# 启动后台导出任务
|
||||
asyncio.create_task(self._background_export_task(task_id))
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"task_id": task_id,
|
||||
"message": "export task created, processing in background",
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"创建备份失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"创建备份失败: {e!s}").__dict__
|
||||
|
||||
async def _background_export_task(self, task_id: str):
|
||||
"""后台导出任务"""
|
||||
try:
|
||||
self._update_progress(task_id, status="processing", message="正在初始化...")
|
||||
|
||||
# 获取知识库管理器
|
||||
kb_manager = getattr(self.core_lifecycle, "kb_manager", None)
|
||||
|
||||
exporter = AstrBotExporter(
|
||||
main_db=self.db,
|
||||
kb_manager=kb_manager,
|
||||
config_path=os.path.join(self.data_dir, "cmd_config.json"),
|
||||
)
|
||||
|
||||
# 创建进度回调
|
||||
progress_callback = self._make_progress_callback(task_id)
|
||||
|
||||
# 执行导出
|
||||
zip_path = await exporter.export_all(
|
||||
output_dir=self.backup_dir,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
# 设置成功结果
|
||||
self._set_task_result(
|
||||
task_id,
|
||||
"completed",
|
||||
result={
|
||||
"filename": os.path.basename(zip_path),
|
||||
"path": zip_path,
|
||||
"size": os.path.getsize(zip_path),
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"后台导出任务 {task_id} 失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
self._set_task_result(task_id, "failed", error=str(e))
|
||||
|
||||
async def upload_backup(self):
|
||||
"""上传备份文件
|
||||
|
||||
将备份文件上传到服务器,返回保存的文件名。
|
||||
上传后应调用 check_backup 进行预检查。
|
||||
|
||||
Form Data:
|
||||
- file: 备份文件 (.zip)
|
||||
|
||||
返回:
|
||||
- filename: 保存的文件名
|
||||
"""
|
||||
try:
|
||||
files = await request.files
|
||||
if "file" not in files:
|
||||
return Response().error("缺少备份文件").__dict__
|
||||
|
||||
file = files["file"]
|
||||
if not file.filename or not file.filename.endswith(".zip"):
|
||||
return Response().error("请上传 ZIP 格式的备份文件").__dict__
|
||||
|
||||
# 清洗文件名并生成唯一名称,防止路径遍历和覆盖
|
||||
safe_filename = secure_filename(file.filename)
|
||||
unique_filename = generate_unique_filename(safe_filename)
|
||||
|
||||
# 保存上传的文件
|
||||
Path(self.backup_dir).mkdir(parents=True, exist_ok=True)
|
||||
zip_path = os.path.join(self.backup_dir, unique_filename)
|
||||
await file.save(zip_path)
|
||||
|
||||
logger.info(
|
||||
f"上传的备份文件已保存: {unique_filename} (原始名称: {file.filename})"
|
||||
)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"filename": unique_filename,
|
||||
"original_filename": file.filename,
|
||||
"size": os.path.getsize(zip_path),
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"上传备份文件失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"上传备份文件失败: {e!s}").__dict__
|
||||
|
||||
async def check_backup(self):
|
||||
"""预检查备份文件
|
||||
|
||||
检查备份文件的版本兼容性,返回确认信息。
|
||||
用户确认后调用 import_backup 执行导入。
|
||||
|
||||
JSON Body:
|
||||
- filename: 已上传的备份文件名
|
||||
|
||||
返回:
|
||||
- ImportPreCheckResult: 预检查结果
|
||||
"""
|
||||
try:
|
||||
data = await request.json
|
||||
filename = data.get("filename")
|
||||
if not filename:
|
||||
return Response().error("缺少 filename 参数").__dict__
|
||||
|
||||
# 安全检查 - 防止路径遍历
|
||||
if ".." in filename or "/" in filename or "\\" in filename:
|
||||
return Response().error("无效的文件名").__dict__
|
||||
|
||||
zip_path = os.path.join(self.backup_dir, filename)
|
||||
if not os.path.exists(zip_path):
|
||||
return Response().error(f"备份文件不存在: {filename}").__dict__
|
||||
|
||||
# 获取知识库管理器(用于构造 importer)
|
||||
kb_manager = getattr(self.core_lifecycle, "kb_manager", None)
|
||||
|
||||
importer = AstrBotImporter(
|
||||
main_db=self.db,
|
||||
kb_manager=kb_manager,
|
||||
config_path=os.path.join(self.data_dir, "cmd_config.json"),
|
||||
)
|
||||
|
||||
# 执行预检查
|
||||
check_result = importer.pre_check(zip_path)
|
||||
|
||||
return Response().ok(check_result.to_dict()).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"预检查备份文件失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"预检查备份文件失败: {e!s}").__dict__
|
||||
|
||||
async def import_backup(self):
|
||||
"""执行备份导入
|
||||
|
||||
在用户确认后执行实际的导入操作。
|
||||
需要先调用 upload_backup 上传文件,再调用 check_backup 预检查。
|
||||
|
||||
JSON Body:
|
||||
- filename: 已上传的备份文件名(必填)
|
||||
- confirmed: 用户已确认(必填,必须为 true)
|
||||
|
||||
返回:
|
||||
- task_id: 任务ID,用于查询导入进度
|
||||
"""
|
||||
try:
|
||||
data = await request.json
|
||||
filename = data.get("filename")
|
||||
confirmed = data.get("confirmed", False)
|
||||
|
||||
if not filename:
|
||||
return Response().error("缺少 filename 参数").__dict__
|
||||
|
||||
if not confirmed:
|
||||
return (
|
||||
Response()
|
||||
.error("请先确认导入。导入将会清空并覆盖现有数据,此操作不可撤销。")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
# 安全检查 - 防止路径遍历
|
||||
if ".." in filename or "/" in filename or "\\" in filename:
|
||||
return Response().error("无效的文件名").__dict__
|
||||
|
||||
zip_path = os.path.join(self.backup_dir, filename)
|
||||
if not os.path.exists(zip_path):
|
||||
return Response().error(f"备份文件不存在: {filename}").__dict__
|
||||
|
||||
# 生成任务ID
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
# 初始化任务状态
|
||||
self._init_task(task_id, "import", "pending")
|
||||
|
||||
# 启动后台导入任务
|
||||
asyncio.create_task(self._background_import_task(task_id, zip_path))
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"task_id": task_id,
|
||||
"message": "import task created, processing in background",
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"导入备份失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"导入备份失败: {e!s}").__dict__
|
||||
|
||||
async def _background_import_task(self, task_id: str, zip_path: str):
|
||||
"""后台导入任务"""
|
||||
try:
|
||||
self._update_progress(task_id, status="processing", message="正在初始化...")
|
||||
|
||||
# 获取知识库管理器
|
||||
kb_manager = getattr(self.core_lifecycle, "kb_manager", None)
|
||||
|
||||
importer = AstrBotImporter(
|
||||
main_db=self.db,
|
||||
kb_manager=kb_manager,
|
||||
config_path=os.path.join(self.data_dir, "cmd_config.json"),
|
||||
)
|
||||
|
||||
# 创建进度回调
|
||||
progress_callback = self._make_progress_callback(task_id)
|
||||
|
||||
# 执行导入
|
||||
result = await importer.import_all(
|
||||
zip_path=zip_path,
|
||||
mode="replace",
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
# 设置结果
|
||||
if result.success:
|
||||
self._set_task_result(
|
||||
task_id,
|
||||
"completed",
|
||||
result=result.to_dict(),
|
||||
)
|
||||
else:
|
||||
self._set_task_result(
|
||||
task_id,
|
||||
"failed",
|
||||
error="; ".join(result.errors),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"后台导入任务 {task_id} 失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
self._set_task_result(task_id, "failed", error=str(e))
|
||||
|
||||
async def get_progress(self):
|
||||
"""获取任务进度
|
||||
|
||||
Query 参数:
|
||||
- task_id: 任务 ID (必填)
|
||||
"""
|
||||
try:
|
||||
task_id = request.args.get("task_id")
|
||||
if not task_id:
|
||||
return Response().error("缺少参数 task_id").__dict__
|
||||
|
||||
if task_id not in self.backup_tasks:
|
||||
return Response().error("找不到该任务").__dict__
|
||||
|
||||
task_info = self.backup_tasks[task_id]
|
||||
status = task_info["status"]
|
||||
|
||||
response_data = {
|
||||
"task_id": task_id,
|
||||
"type": task_info["type"],
|
||||
"status": status,
|
||||
}
|
||||
|
||||
# 如果任务正在处理,返回进度信息
|
||||
if status == "processing" and task_id in self.backup_progress:
|
||||
response_data["progress"] = self.backup_progress[task_id]
|
||||
|
||||
# 如果任务完成,返回结果
|
||||
if status == "completed":
|
||||
response_data["result"] = task_info["result"]
|
||||
|
||||
# 如果任务失败,返回错误信息
|
||||
if status == "failed":
|
||||
response_data["error"] = task_info["error"]
|
||||
|
||||
return Response().ok(response_data).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"获取任务进度失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"获取任务进度失败: {e!s}").__dict__
|
||||
|
||||
async def download_backup(self):
|
||||
"""下载备份文件
|
||||
|
||||
Query 参数:
|
||||
- filename: 备份文件名 (必填)
|
||||
"""
|
||||
try:
|
||||
filename = request.args.get("filename")
|
||||
if not filename:
|
||||
return Response().error("缺少参数 filename").__dict__
|
||||
|
||||
# 安全检查 - 防止路径遍历
|
||||
if ".." in filename or "/" in filename or "\\" in filename:
|
||||
return Response().error("无效的文件名").__dict__
|
||||
|
||||
file_path = os.path.join(self.backup_dir, filename)
|
||||
if not os.path.exists(file_path):
|
||||
return Response().error("备份文件不存在").__dict__
|
||||
|
||||
return await send_file(
|
||||
file_path,
|
||||
as_attachment=True,
|
||||
attachment_filename=filename,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"下载备份失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"下载备份失败: {e!s}").__dict__
|
||||
|
||||
async def delete_backup(self):
|
||||
"""删除备份文件
|
||||
|
||||
Body:
|
||||
- filename: 备份文件名 (必填)
|
||||
"""
|
||||
try:
|
||||
data = await request.json
|
||||
filename = data.get("filename")
|
||||
if not filename:
|
||||
return Response().error("缺少参数 filename").__dict__
|
||||
|
||||
# 安全检查 - 防止路径遍历
|
||||
if ".." in filename or "/" in filename or "\\" in filename:
|
||||
return Response().error("无效的文件名").__dict__
|
||||
|
||||
file_path = os.path.join(self.backup_dir, filename)
|
||||
if not os.path.exists(file_path):
|
||||
return Response().error("备份文件不存在").__dict__
|
||||
|
||||
os.remove(file_path)
|
||||
return Response().ok(message="删除备份成功").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"删除备份失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"删除备份失败: {e!s}").__dict__
|
||||
@@ -19,6 +19,7 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.io import get_local_ip_addresses
|
||||
|
||||
from .routes import *
|
||||
from .routes.backup import BackupRoute
|
||||
from .routes.platform import PlatformRoute
|
||||
from .routes.route import Response, RouteContext
|
||||
from .routes.session_management import SessionManagementRoute
|
||||
@@ -85,6 +86,7 @@ class AstrBotDashboard:
|
||||
self.t2i_route = T2iRoute(self.context, core_lifecycle)
|
||||
self.kb_route = KnowledgeBaseRoute(self.context, core_lifecycle)
|
||||
self.platform_route = PlatformRoute(self.context, core_lifecycle)
|
||||
self.backup_route = BackupRoute(self.context, db, core_lifecycle)
|
||||
|
||||
self.app.add_url_rule(
|
||||
"/api/plug/<path:subpath>",
|
||||
@@ -108,7 +110,12 @@ class AstrBotDashboard:
|
||||
async def auth_middleware(self):
|
||||
if not request.path.startswith("/api"):
|
||||
return None
|
||||
allowed_endpoints = ["/api/auth/login", "/api/file", "/api/platform/webhook"]
|
||||
allowed_endpoints = [
|
||||
"/api/auth/login",
|
||||
"/api/file",
|
||||
"/api/platform/webhook",
|
||||
"/api/stat/start-time",
|
||||
]
|
||||
if any(request.path.startswith(prefix) for prefix in allowed_endpoints):
|
||||
return None
|
||||
# 声明 JWT
|
||||
|
||||
@@ -0,0 +1,673 @@
|
||||
<template>
|
||||
<v-dialog v-model="isOpen" persistent max-width="700" scrollable>
|
||||
<v-card>
|
||||
<v-card-title class="d-flex align-center">
|
||||
<v-icon class="mr-2">mdi-backup-restore</v-icon>
|
||||
{{ t('features.settings.backup.dialog.title') }}
|
||||
</v-card-title>
|
||||
|
||||
<v-card-text class="pa-6">
|
||||
<!-- 选项卡 -->
|
||||
<v-tabs v-model="activeTab" color="primary" class="mb-4">
|
||||
<v-tab value="export">
|
||||
<v-icon class="mr-2">mdi-export</v-icon>
|
||||
{{ t('features.settings.backup.tabs.export') }}
|
||||
</v-tab>
|
||||
<v-tab value="import">
|
||||
<v-icon class="mr-2">mdi-import</v-icon>
|
||||
{{ t('features.settings.backup.tabs.import') }}
|
||||
</v-tab>
|
||||
<v-tab value="list">
|
||||
<v-icon class="mr-2">mdi-format-list-bulleted</v-icon>
|
||||
{{ t('features.settings.backup.tabs.list') }}
|
||||
</v-tab>
|
||||
</v-tabs>
|
||||
|
||||
<v-window v-model="activeTab">
|
||||
<!-- 导出标签页 -->
|
||||
<v-window-item value="export">
|
||||
<div v-if="exportStatus === 'idle'" class="text-center py-8">
|
||||
<v-icon size="64" color="primary" class="mb-4">mdi-cloud-upload</v-icon>
|
||||
<h3 class="mb-4">{{ t('features.settings.backup.export.title') }}</h3>
|
||||
<p class="mb-4 text-grey">{{ t('features.settings.backup.export.description') }}</p>
|
||||
<v-alert type="info" variant="tonal" class="mb-4 text-left">
|
||||
<template v-slot:prepend>
|
||||
<v-icon>mdi-information</v-icon>
|
||||
</template>
|
||||
{{ t('features.settings.backup.export.includes') }}
|
||||
</v-alert>
|
||||
<v-btn color="primary" size="large" @click="startExport" :loading="exportStatus === 'processing'">
|
||||
<v-icon class="mr-2">mdi-export</v-icon>
|
||||
{{ t('features.settings.backup.export.button') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<div v-else-if="exportStatus === 'processing'" class="text-center py-8">
|
||||
<v-progress-circular indeterminate color="primary" size="64" class="mb-4"></v-progress-circular>
|
||||
<h3 class="mb-4">{{ t('features.settings.backup.export.processing') }}</h3>
|
||||
<p class="text-grey">{{ exportProgress.message || t('features.settings.backup.export.wait') }}</p>
|
||||
<v-progress-linear :model-value="exportProgress.current" :max="exportProgress.total" class="mt-4" color="primary"></v-progress-linear>
|
||||
</div>
|
||||
|
||||
<div v-else-if="exportStatus === 'completed'" class="text-center py-8">
|
||||
<v-icon size="64" color="success" class="mb-4">mdi-check-circle</v-icon>
|
||||
<h3 class="mb-4">{{ t('features.settings.backup.export.completed') }}</h3>
|
||||
<p class="mb-4">{{ exportResult?.filename }}</p>
|
||||
<v-btn color="primary" @click="downloadBackup(exportResult?.filename)" class="mr-2">
|
||||
<v-icon class="mr-2">mdi-download</v-icon>
|
||||
{{ t('features.settings.backup.export.download') }}
|
||||
</v-btn>
|
||||
<v-btn color="grey" variant="text" @click="resetExport">
|
||||
{{ t('features.settings.backup.export.another') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<div v-else-if="exportStatus === 'failed'" class="text-center py-8">
|
||||
<v-icon size="64" color="error" class="mb-4">mdi-alert-circle</v-icon>
|
||||
<h3 class="mb-4">{{ t('features.settings.backup.export.failed') }}</h3>
|
||||
<v-alert type="error" variant="tonal" class="mb-4">
|
||||
{{ exportError }}
|
||||
</v-alert>
|
||||
<v-btn color="primary" @click="resetExport">
|
||||
{{ t('features.settings.backup.export.retry') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
</v-window-item>
|
||||
|
||||
<!-- 导入标签页 -->
|
||||
<v-window-item value="import">
|
||||
<!-- 步骤1: 选择文件 -->
|
||||
<div v-if="importStatus === 'idle'" class="py-4">
|
||||
<v-alert type="warning" variant="tonal" class="mb-4">
|
||||
<template v-slot:prepend>
|
||||
<v-icon>mdi-alert</v-icon>
|
||||
</template>
|
||||
{{ t('features.settings.backup.import.warning') }}
|
||||
</v-alert>
|
||||
|
||||
<v-file-input
|
||||
v-model="importFile"
|
||||
:label="t('features.settings.backup.import.selectFile')"
|
||||
accept=".zip"
|
||||
prepend-icon="mdi-file-upload"
|
||||
show-size
|
||||
class="mb-4"
|
||||
></v-file-input>
|
||||
|
||||
<div class="d-flex justify-center">
|
||||
<v-btn
|
||||
color="primary"
|
||||
size="large"
|
||||
@click="uploadAndCheck"
|
||||
:disabled="!importFile"
|
||||
:loading="importStatus === 'uploading'"
|
||||
>
|
||||
<v-icon class="mr-2">mdi-upload</v-icon>
|
||||
{{ t('features.settings.backup.import.uploadAndCheck') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 步骤1.5: 上传中 -->
|
||||
<div v-else-if="importStatus === 'uploading'" class="text-center py-8">
|
||||
<v-progress-circular indeterminate color="primary" size="64" class="mb-4"></v-progress-circular>
|
||||
<h3 class="mb-4">{{ t('features.settings.backup.import.uploading') }}</h3>
|
||||
<p class="text-grey">{{ t('features.settings.backup.import.uploadWait') }}</p>
|
||||
</div>
|
||||
|
||||
<!-- 步骤2: 确认导入 -->
|
||||
<div v-else-if="importStatus === 'confirm'" class="py-4">
|
||||
<v-alert
|
||||
:type="versionAlertType"
|
||||
variant="tonal"
|
||||
class="mb-4"
|
||||
>
|
||||
<template v-slot:prepend>
|
||||
<v-icon>{{ versionAlertIcon }}</v-icon>
|
||||
</template>
|
||||
<div class="confirm-message">
|
||||
<div class="text-h6 mb-2">{{ versionAlertTitle }}</div>
|
||||
<div class="mb-2">
|
||||
<strong>{{ t('features.settings.backup.import.version.backupVersion') }}:</strong> {{ checkResult?.backup_version }}<br>
|
||||
<strong>{{ t('features.settings.backup.import.version.currentVersion') }}:</strong> {{ checkResult?.current_version }}
|
||||
</div>
|
||||
<div v-if="checkResult?.backup_time && checkResult?.backup_time !== '未知'" class="mb-2">
|
||||
<strong>{{ t('features.settings.backup.import.version.backupTime') }}:</strong> {{ formatISODate(checkResult?.backup_time) }}
|
||||
</div>
|
||||
<div class="mt-3" style="white-space: pre-line;">{{ versionAlertMessage }}</div>
|
||||
</div>
|
||||
</v-alert>
|
||||
|
||||
<!-- 备份摘要 -->
|
||||
<v-card variant="outlined" class="mb-4" v-if="checkResult?.backup_summary">
|
||||
<v-card-title class="text-subtitle-1">
|
||||
<v-icon class="mr-2">mdi-package-variant</v-icon>
|
||||
{{ t('features.settings.backup.import.backupContents') }}
|
||||
</v-card-title>
|
||||
<v-card-text>
|
||||
<div class="d-flex flex-wrap ga-2">
|
||||
<v-chip v-if="checkResult.backup_summary.tables?.length" size="small" color="primary" variant="tonal" :ripple="false" class="non-interactive-chip">
|
||||
{{ checkResult.backup_summary.tables.length }} {{ t('features.settings.backup.import.tables') }}
|
||||
</v-chip>
|
||||
<v-chip v-if="checkResult.backup_summary.has_knowledge_bases" size="small" color="success" variant="tonal" :ripple="false" class="non-interactive-chip">
|
||||
{{ t('features.settings.backup.import.knowledgeBases') }}
|
||||
</v-chip>
|
||||
<v-chip v-if="checkResult.backup_summary.has_config" size="small" color="info" variant="tonal" :ripple="false" class="non-interactive-chip">
|
||||
{{ t('features.settings.backup.import.configFiles') }}
|
||||
</v-chip>
|
||||
<v-chip v-for="dir in (checkResult.backup_summary.directories || [])" :key="dir" size="small" color="warning" variant="tonal" :ripple="false" class="non-interactive-chip">
|
||||
{{ dir }}
|
||||
</v-chip>
|
||||
</div>
|
||||
</v-card-text>
|
||||
</v-card>
|
||||
|
||||
<!-- 警告信息 -->
|
||||
<v-alert v-if="checkResult?.warnings?.length" type="warning" variant="tonal" class="mb-4">
|
||||
<div v-for="(warning, idx) in checkResult.warnings" :key="idx">{{ warning }}</div>
|
||||
</v-alert>
|
||||
|
||||
<div class="d-flex justify-center align-center mt-4" style="gap: 16px;">
|
||||
<v-btn
|
||||
color="grey-darken-1"
|
||||
variant="outlined"
|
||||
size="large"
|
||||
@click="resetImport"
|
||||
>
|
||||
<v-icon class="mr-2">mdi-close</v-icon>
|
||||
{{ t('core.common.cancel') }}
|
||||
</v-btn>
|
||||
<v-btn
|
||||
v-if="checkResult?.can_import"
|
||||
color="error"
|
||||
size="large"
|
||||
variant="flat"
|
||||
@click="confirmImport"
|
||||
>
|
||||
<v-icon class="mr-2">mdi-alert</v-icon>
|
||||
{{ t('features.settings.backup.import.confirmImport') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 步骤3: 导入进行中 -->
|
||||
<div v-else-if="importStatus === 'processing'" class="text-center py-8">
|
||||
<v-progress-circular indeterminate color="primary" size="64" class="mb-4"></v-progress-circular>
|
||||
<h3 class="mb-4">{{ t('features.settings.backup.import.processing') }}</h3>
|
||||
<p class="text-grey">{{ importProgress.message || t('features.settings.backup.import.wait') }}</p>
|
||||
<v-progress-linear :model-value="importProgress.current" :max="importProgress.total" class="mt-4" color="primary"></v-progress-linear>
|
||||
</div>
|
||||
|
||||
<div v-else-if="importStatus === 'completed'" class="text-center py-8">
|
||||
<v-icon size="64" color="success" class="mb-4">mdi-check-circle</v-icon>
|
||||
<h3 class="mb-4">{{ t('features.settings.backup.import.completed') }}</h3>
|
||||
<v-alert type="info" variant="tonal" class="mb-4">
|
||||
{{ t('features.settings.backup.import.restartRequired') }}
|
||||
</v-alert>
|
||||
<v-btn color="primary" @click="restartAstrBot" class="mr-2">
|
||||
<v-icon class="mr-2">mdi-restart</v-icon>
|
||||
{{ t('features.settings.backup.import.restartNow') }}
|
||||
</v-btn>
|
||||
<v-btn color="grey" variant="text" @click="resetImport">
|
||||
{{ t('core.common.close') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<div v-else-if="importStatus === 'failed'" class="text-center py-8">
|
||||
<v-icon size="64" color="error" class="mb-4">mdi-alert-circle</v-icon>
|
||||
<h3 class="mb-4">{{ t('features.settings.backup.import.failed') }}</h3>
|
||||
<v-alert type="error" variant="tonal" class="mb-4">
|
||||
{{ importError }}
|
||||
</v-alert>
|
||||
<v-btn color="primary" @click="resetImport">
|
||||
{{ t('features.settings.backup.import.retry') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
</v-window-item>
|
||||
|
||||
<!-- 备份列表标签页 -->
|
||||
<v-window-item value="list">
|
||||
<div v-if="loadingList" class="text-center py-8">
|
||||
<v-progress-circular indeterminate color="primary"></v-progress-circular>
|
||||
</div>
|
||||
|
||||
<div v-else-if="backupList.length === 0" class="text-center py-8">
|
||||
<v-icon size="64" color="grey" class="mb-4">mdi-folder-open-outline</v-icon>
|
||||
<p class="text-grey">{{ t('features.settings.backup.list.empty') }}</p>
|
||||
</div>
|
||||
|
||||
<v-list v-else lines="two">
|
||||
<v-list-item
|
||||
v-for="backup in backupList"
|
||||
:key="backup.filename"
|
||||
>
|
||||
<template v-slot:prepend>
|
||||
<v-icon color="primary">mdi-zip-box</v-icon>
|
||||
</template>
|
||||
|
||||
<v-list-item-title>{{ backup.filename }}</v-list-item-title>
|
||||
<v-list-item-subtitle>
|
||||
{{ formatFileSize(backup.size) }} · {{ formatDate(backup.created_at) }}
|
||||
</v-list-item-subtitle>
|
||||
|
||||
<template v-slot:append>
|
||||
<v-btn icon="mdi-download" variant="text" size="small" @click="downloadBackup(backup.filename)"></v-btn>
|
||||
<v-btn icon="mdi-delete" variant="text" size="small" color="error" @click="deleteBackup(backup.filename)"></v-btn>
|
||||
</template>
|
||||
</v-list-item>
|
||||
</v-list>
|
||||
|
||||
<div class="d-flex justify-center mt-4">
|
||||
<v-btn color="primary" variant="text" @click="loadBackupList">
|
||||
<v-icon class="mr-2">mdi-refresh</v-icon>
|
||||
{{ t('features.settings.backup.list.refresh') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
</v-window-item>
|
||||
</v-window>
|
||||
</v-card-text>
|
||||
|
||||
<v-card-actions class="px-6 py-4">
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn color="grey" variant="text" @click="handleClose" :disabled="isProcessing">
|
||||
{{ t('core.common.close') }}
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
|
||||
<WaitingForRestart ref="wfr"></WaitingForRestart>
|
||||
</template>
|
||||
|
||||
<script setup>
|
||||
import { ref, computed, watch } from 'vue'
|
||||
import axios from 'axios'
|
||||
import { useI18n } from '@/i18n/composables'
|
||||
import WaitingForRestart from './WaitingForRestart.vue'
|
||||
|
||||
const { t } = useI18n()
|
||||
|
||||
const isOpen = ref(false)
|
||||
const activeTab = ref('export')
|
||||
const wfr = ref(null)
|
||||
|
||||
// 导出状态
|
||||
const exportStatus = ref('idle') // idle, processing, completed, failed
|
||||
const exportTaskId = ref(null)
|
||||
const exportProgress = ref({ current: 0, total: 100, message: '' })
|
||||
const exportResult = ref(null)
|
||||
const exportError = ref('')
|
||||
|
||||
// 导入状态
|
||||
const importStatus = ref('idle') // idle, uploading, confirm, processing, completed, failed
|
||||
const importFile = ref(null)
|
||||
const importTaskId = ref(null)
|
||||
const importProgress = ref({ current: 0, total: 100, message: '' })
|
||||
const importError = ref('')
|
||||
const uploadedFilename = ref('') // 已上传的文件名
|
||||
const checkResult = ref(null) // 预检查结果
|
||||
|
||||
// 备份列表
|
||||
const loadingList = ref(false)
|
||||
const backupList = ref([])
|
||||
|
||||
// 计算属性
|
||||
const isProcessing = computed(() => {
|
||||
return exportStatus.value === 'processing' || importStatus.value === 'processing'
|
||||
})
|
||||
|
||||
// 版本检查相关的计算属性
|
||||
const versionAlertType = computed(() => {
|
||||
const status = checkResult.value?.version_status
|
||||
if (status === 'major_diff') return 'error'
|
||||
if (status === 'minor_diff') return 'warning'
|
||||
return 'info'
|
||||
})
|
||||
|
||||
const versionAlertIcon = computed(() => {
|
||||
const status = checkResult.value?.version_status
|
||||
if (status === 'major_diff') return 'mdi-close-circle'
|
||||
if (status === 'minor_diff') return 'mdi-alert'
|
||||
return 'mdi-check-circle'
|
||||
})
|
||||
|
||||
const versionAlertTitle = computed(() => {
|
||||
const status = checkResult.value?.version_status
|
||||
if (status === 'major_diff') return t('features.settings.backup.import.version.majorDiffTitle')
|
||||
if (status === 'minor_diff') return t('features.settings.backup.import.version.minorDiffTitle')
|
||||
return t('features.settings.backup.import.version.matchTitle')
|
||||
})
|
||||
|
||||
const versionAlertMessage = computed(() => {
|
||||
const status = checkResult.value?.version_status
|
||||
if (status === 'major_diff') return t('features.settings.backup.import.version.majorDiffMessage')
|
||||
if (status === 'minor_diff') return t('features.settings.backup.import.version.minorDiffMessage')
|
||||
return t('features.settings.backup.import.version.matchMessage')
|
||||
})
|
||||
|
||||
// 监听对话框打开
|
||||
watch(isOpen, (newVal) => {
|
||||
if (newVal) {
|
||||
loadBackupList()
|
||||
} else {
|
||||
resetAll()
|
||||
}
|
||||
})
|
||||
|
||||
// 监听标签页切换
|
||||
watch(activeTab, (newVal) => {
|
||||
if (newVal === 'list') {
|
||||
loadBackupList()
|
||||
}
|
||||
})
|
||||
|
||||
// 加载备份列表
|
||||
const loadBackupList = async () => {
|
||||
loadingList.value = true
|
||||
try {
|
||||
const response = await axios.get('/api/backup/list')
|
||||
if (response.data.status === 'ok') {
|
||||
backupList.value = response.data.data.items || []
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to load backup list:', error)
|
||||
} finally {
|
||||
loadingList.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// 开始导出
|
||||
const startExport = async () => {
|
||||
exportStatus.value = 'processing'
|
||||
exportProgress.value = { current: 0, total: 100, message: '' }
|
||||
|
||||
try {
|
||||
const response = await axios.post('/api/backup/export')
|
||||
if (response.data.status === 'ok') {
|
||||
exportTaskId.value = response.data.data.task_id
|
||||
pollExportProgress()
|
||||
} else {
|
||||
throw new Error(response.data.message)
|
||||
}
|
||||
} catch (error) {
|
||||
exportStatus.value = 'failed'
|
||||
exportError.value = error.message || 'Export failed'
|
||||
}
|
||||
}
|
||||
|
||||
// 轮询导出进度
|
||||
const pollExportProgress = async () => {
|
||||
if (!exportTaskId.value) return
|
||||
|
||||
try {
|
||||
const response = await axios.get('/api/backup/progress', {
|
||||
params: { task_id: exportTaskId.value }
|
||||
})
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
const data = response.data.data
|
||||
|
||||
if (data.status === 'processing' && data.progress) {
|
||||
exportProgress.value = {
|
||||
current: data.progress.current || 0,
|
||||
total: data.progress.total || 100,
|
||||
message: data.progress.message || ''
|
||||
}
|
||||
setTimeout(pollExportProgress, 1000)
|
||||
} else if (data.status === 'completed') {
|
||||
exportStatus.value = 'completed'
|
||||
exportResult.value = data.result
|
||||
loadBackupList()
|
||||
} else if (data.status === 'failed') {
|
||||
exportStatus.value = 'failed'
|
||||
exportError.value = data.error || 'Export failed'
|
||||
} else {
|
||||
setTimeout(pollExportProgress, 1000)
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
exportStatus.value = 'failed'
|
||||
exportError.value = error.message || 'Failed to get export progress'
|
||||
}
|
||||
}
|
||||
|
||||
// 重置导出状态
|
||||
const resetExport = () => {
|
||||
exportStatus.value = 'idle'
|
||||
exportTaskId.value = null
|
||||
exportProgress.value = { current: 0, total: 100, message: '' }
|
||||
exportResult.value = null
|
||||
exportError.value = ''
|
||||
}
|
||||
|
||||
// 上传并检查
|
||||
const uploadAndCheck = async () => {
|
||||
if (!importFile.value) return
|
||||
|
||||
importStatus.value = 'uploading'
|
||||
|
||||
try {
|
||||
// 步骤1: 上传文件
|
||||
const formData = new FormData()
|
||||
formData.append('file', importFile.value)
|
||||
|
||||
const uploadResponse = await axios.post('/api/backup/upload', formData, {
|
||||
headers: { 'Content-Type': 'multipart/form-data' }
|
||||
})
|
||||
|
||||
if (uploadResponse.data.status !== 'ok') {
|
||||
throw new Error(uploadResponse.data.message)
|
||||
}
|
||||
|
||||
uploadedFilename.value = uploadResponse.data.data.filename
|
||||
|
||||
// 步骤2: 预检查
|
||||
const checkResponse = await axios.post('/api/backup/check', {
|
||||
filename: uploadedFilename.value
|
||||
})
|
||||
|
||||
if (checkResponse.data.status !== 'ok') {
|
||||
throw new Error(checkResponse.data.message)
|
||||
}
|
||||
|
||||
checkResult.value = checkResponse.data.data
|
||||
|
||||
// 检查是否有效
|
||||
if (!checkResult.value.valid) {
|
||||
importStatus.value = 'failed'
|
||||
importError.value = checkResult.value.error || t('features.settings.backup.import.invalidBackup')
|
||||
return
|
||||
}
|
||||
|
||||
// 显示确认对话框
|
||||
importStatus.value = 'confirm'
|
||||
|
||||
} catch (error) {
|
||||
importStatus.value = 'failed'
|
||||
importError.value = error.response?.data?.message || error.message || 'Upload failed'
|
||||
}
|
||||
}
|
||||
|
||||
// 确认导入
|
||||
const confirmImport = async () => {
|
||||
if (!uploadedFilename.value) return
|
||||
|
||||
importStatus.value = 'processing'
|
||||
importProgress.value = { current: 0, total: 100, message: '' }
|
||||
|
||||
try {
|
||||
const response = await axios.post('/api/backup/import', {
|
||||
filename: uploadedFilename.value,
|
||||
confirmed: true
|
||||
})
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
importTaskId.value = response.data.data.task_id
|
||||
pollImportProgress()
|
||||
} else {
|
||||
throw new Error(response.data.message)
|
||||
}
|
||||
} catch (error) {
|
||||
importStatus.value = 'failed'
|
||||
importError.value = error.response?.data?.message || error.message || 'Import failed'
|
||||
}
|
||||
}
|
||||
|
||||
// 轮询导入进度
|
||||
const pollImportProgress = async () => {
|
||||
if (!importTaskId.value) return
|
||||
|
||||
try {
|
||||
const response = await axios.get('/api/backup/progress', {
|
||||
params: { task_id: importTaskId.value }
|
||||
})
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
const data = response.data.data
|
||||
|
||||
if (data.status === 'processing' && data.progress) {
|
||||
importProgress.value = {
|
||||
current: data.progress.current || 0,
|
||||
total: data.progress.total || 100,
|
||||
message: data.progress.message || ''
|
||||
}
|
||||
setTimeout(pollImportProgress, 1000)
|
||||
} else if (data.status === 'completed') {
|
||||
importStatus.value = 'completed'
|
||||
} else if (data.status === 'failed') {
|
||||
importStatus.value = 'failed'
|
||||
importError.value = data.error || 'Import failed'
|
||||
} else {
|
||||
setTimeout(pollImportProgress, 1000)
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
importStatus.value = 'failed'
|
||||
importError.value = error.message || 'Failed to get import progress'
|
||||
}
|
||||
}
|
||||
|
||||
// 重置导入状态
|
||||
const resetImport = () => {
|
||||
importStatus.value = 'idle'
|
||||
importFile.value = null
|
||||
importTaskId.value = null
|
||||
importProgress.value = { current: 0, total: 100, message: '' }
|
||||
importError.value = ''
|
||||
uploadedFilename.value = ''
|
||||
checkResult.value = null
|
||||
}
|
||||
|
||||
// 下载备份
|
||||
const downloadBackup = async (filename) => {
|
||||
try {
|
||||
const response = await axios.get('/api/backup/download', {
|
||||
params: { filename },
|
||||
responseType: 'blob'
|
||||
})
|
||||
|
||||
// 创建 Blob URL 并触发下载
|
||||
const blob = new Blob([response.data], { type: 'application/zip' })
|
||||
const url = window.URL.createObjectURL(blob)
|
||||
const link = document.createElement('a')
|
||||
link.href = url
|
||||
link.download = filename
|
||||
document.body.appendChild(link)
|
||||
link.click()
|
||||
document.body.removeChild(link)
|
||||
window.URL.revokeObjectURL(url)
|
||||
} catch (error) {
|
||||
console.error('Download failed:', error)
|
||||
alert(t('features.settings.backup.export.failed') + ': ' + (error.message || 'Unknown error'))
|
||||
}
|
||||
}
|
||||
|
||||
// 删除备份
|
||||
const deleteBackup = async (filename) => {
|
||||
if (!confirm(t('features.settings.backup.list.confirmDelete'))) return
|
||||
|
||||
try {
|
||||
const response = await axios.post('/api/backup/delete', { filename })
|
||||
if (response.data.status === 'ok') {
|
||||
loadBackupList()
|
||||
} else {
|
||||
alert(response.data.message || 'Delete failed')
|
||||
}
|
||||
} catch (error) {
|
||||
alert(error.message || 'Delete failed')
|
||||
}
|
||||
}
|
||||
|
||||
// 格式化文件大小
|
||||
const formatFileSize = (bytes) => {
|
||||
if (bytes === 0) return '0 B'
|
||||
const k = 1024
|
||||
const sizes = ['B', 'KB', 'MB', 'GB']
|
||||
const i = Math.floor(Math.log(bytes) / Math.log(k))
|
||||
return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i]
|
||||
}
|
||||
|
||||
// 格式化日期(从时间戳)
|
||||
const formatDate = (timestamp) => {
|
||||
return new Date(timestamp * 1000).toLocaleString()
|
||||
}
|
||||
|
||||
// 格式化 ISO 日期字符串
|
||||
const formatISODate = (isoString) => {
|
||||
if (!isoString) return ''
|
||||
try {
|
||||
return new Date(isoString).toLocaleString()
|
||||
} catch {
|
||||
return isoString
|
||||
}
|
||||
}
|
||||
|
||||
// 重启 AstrBot
|
||||
const restartAstrBot = () => {
|
||||
axios.post('/api/stat/restart-core').then(() => {
|
||||
if (wfr.value) {
|
||||
wfr.value.check()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// 重置所有状态
|
||||
const resetAll = () => {
|
||||
resetExport()
|
||||
resetImport()
|
||||
activeTab.value = 'export'
|
||||
}
|
||||
|
||||
// 关闭对话框
|
||||
const handleClose = () => {
|
||||
if (isProcessing.value) return
|
||||
isOpen.value = false
|
||||
}
|
||||
|
||||
// 打开对话框
|
||||
const open = () => {
|
||||
isOpen.value = true
|
||||
}
|
||||
|
||||
defineExpose({ open })
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.v-list-item {
|
||||
border-bottom: 1px solid rgba(0, 0, 0, 0.08);
|
||||
}
|
||||
|
||||
.v-list-item:last-child {
|
||||
border-bottom: none;
|
||||
}
|
||||
|
||||
/* 禁用 Chip 的交互效果 */
|
||||
.non-interactive-chip {
|
||||
pointer-events: none;
|
||||
cursor: default;
|
||||
}
|
||||
|
||||
.non-interactive-chip:hover {
|
||||
box-shadow: none !important;
|
||||
}
|
||||
</style>
|
||||
@@ -18,6 +18,11 @@
|
||||
"title": "Data Migration to v4.0.0",
|
||||
"subtitle": "If you encounter data compatibility issues, you can manually start the database migration assistant",
|
||||
"button": "Start Migration Assistant"
|
||||
},
|
||||
"backup": {
|
||||
"title": "Backup & Restore",
|
||||
"subtitle": "Export or import all AstrBot data for easy migration to a new server",
|
||||
"button": "Backup Manager"
|
||||
}
|
||||
},
|
||||
"sidebar": {
|
||||
@@ -29,5 +34,66 @@
|
||||
"mainItems": "Main Modules",
|
||||
"moreItems": "More Features"
|
||||
}
|
||||
},
|
||||
"backup": {
|
||||
"dialog": {
|
||||
"title": "Backup Manager"
|
||||
},
|
||||
"tabs": {
|
||||
"export": "Export Backup",
|
||||
"import": "Import Backup",
|
||||
"list": "Backup List"
|
||||
},
|
||||
"export": {
|
||||
"title": "Create Backup",
|
||||
"description": "Export all data as a ZIP backup file, including database, knowledge base, config and attachments.",
|
||||
"includes": "Backup includes: Main database, Knowledge bases (metadata + vector index + documents), Config files, Attachment files",
|
||||
"button": "Start Export",
|
||||
"processing": "Exporting...",
|
||||
"wait": "Please wait, packaging data...",
|
||||
"completed": "Export Completed!",
|
||||
"download": "Download Backup",
|
||||
"another": "Create New Backup",
|
||||
"failed": "Export Failed",
|
||||
"retry": "Retry"
|
||||
},
|
||||
"import": {
|
||||
"title": "Import Backup",
|
||||
"warning": "⚠️ Import will clear and overwrite existing data! Please make sure you have backed up your current data.",
|
||||
"selectFile": "Select backup file (.zip)",
|
||||
"uploadAndCheck": "Upload & Check",
|
||||
"uploading": "Uploading...",
|
||||
"uploadWait": "Please wait, uploading backup file...",
|
||||
"invalidBackup": "Invalid backup file",
|
||||
"backupContents": "Backup Contents",
|
||||
"tables": "tables",
|
||||
"knowledgeBases": "Knowledge Bases",
|
||||
"configFiles": "Config Files",
|
||||
"confirmImport": "Confirm Import",
|
||||
"button": "Start Import",
|
||||
"processing": "Importing...",
|
||||
"wait": "Please wait, restoring data...",
|
||||
"completed": "Import Completed!",
|
||||
"restartRequired": "Data has been successfully imported. It is recommended to restart AstrBot immediately for all changes to take effect.",
|
||||
"restartNow": "Restart Now",
|
||||
"failed": "Import Failed",
|
||||
"retry": "Retry",
|
||||
"version": {
|
||||
"backupVersion": "Backup Version",
|
||||
"currentVersion": "Current Version",
|
||||
"backupTime": "Backup Time",
|
||||
"matchTitle": "✅ Version Match",
|
||||
"matchMessage": "Import will clear and overwrite all existing data, including:\n• Main database (conversations, settings, etc.)\n• Knowledge bases\n• Plugins and plugin data\n• Configuration files\n\nThis action cannot be undone! Do you want to continue?",
|
||||
"minorDiffTitle": "⚠️ Version Difference Warning",
|
||||
"minorDiffMessage": "Minor version differences are usually compatible, but there may be some data structure changes.\nImport will clear and overwrite all existing data!\n\nDo you want to continue?",
|
||||
"majorDiffTitle": "⛔ Cannot Import",
|
||||
"majorDiffMessage": "Major version numbers are different. Cross-major-version import may cause data corruption.\nPlease use the same major version of AstrBot for import."
|
||||
}
|
||||
},
|
||||
"list": {
|
||||
"empty": "No backup files",
|
||||
"refresh": "Refresh List",
|
||||
"confirmDelete": "Are you sure you want to delete this backup file? This action cannot be undone."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -18,6 +18,11 @@
|
||||
"title": "数据迁移到 v4.0.0 格式",
|
||||
"subtitle": "如果您遇到数据兼容性问题,可以手动启动数据库迁移助手",
|
||||
"button": "启动迁移助手"
|
||||
},
|
||||
"backup": {
|
||||
"title": "数据备份与恢复",
|
||||
"subtitle": "导出或导入 AstrBot 的所有数据,方便迁移到新服务器",
|
||||
"button": "备份管理"
|
||||
}
|
||||
},
|
||||
"sidebar": {
|
||||
@@ -29,5 +34,66 @@
|
||||
"mainItems": "主要模块",
|
||||
"moreItems": "更多功能"
|
||||
}
|
||||
},
|
||||
"backup": {
|
||||
"dialog": {
|
||||
"title": "备份管理"
|
||||
},
|
||||
"tabs": {
|
||||
"export": "导出备份",
|
||||
"import": "导入备份",
|
||||
"list": "备份列表"
|
||||
},
|
||||
"export": {
|
||||
"title": "创建备份",
|
||||
"description": "将所有数据导出为 ZIP 备份文件,包括数据库、知识库、配置和附件。",
|
||||
"includes": "备份包含:主数据库、知识库(元数据+向量索引+文档)、配置文件、附件文件",
|
||||
"button": "开始导出",
|
||||
"processing": "正在导出...",
|
||||
"wait": "请稍候,正在打包数据...",
|
||||
"completed": "导出完成!",
|
||||
"download": "下载备份",
|
||||
"another": "创建新备份",
|
||||
"failed": "导出失败",
|
||||
"retry": "重试"
|
||||
},
|
||||
"import": {
|
||||
"title": "导入备份",
|
||||
"warning": "⚠️ 导入将会清空并覆盖现有数据!请确保已备份当前数据。",
|
||||
"selectFile": "选择备份文件 (.zip)",
|
||||
"uploadAndCheck": "上传并检查",
|
||||
"uploading": "正在上传...",
|
||||
"uploadWait": "请稍候,正在上传备份文件...",
|
||||
"invalidBackup": "无效的备份文件",
|
||||
"backupContents": "备份内容",
|
||||
"tables": "个数据表",
|
||||
"knowledgeBases": "知识库",
|
||||
"configFiles": "配置文件",
|
||||
"confirmImport": "确认导入",
|
||||
"button": "开始导入",
|
||||
"processing": "正在导入...",
|
||||
"wait": "请稍候,正在恢复数据...",
|
||||
"completed": "导入完成!",
|
||||
"restartRequired": "数据已成功导入。建议立即重启 AstrBot 以使所有更改生效。",
|
||||
"restartNow": "立即重启",
|
||||
"failed": "导入失败",
|
||||
"retry": "重试",
|
||||
"version": {
|
||||
"backupVersion": "备份版本",
|
||||
"currentVersion": "当前版本",
|
||||
"backupTime": "备份时间",
|
||||
"matchTitle": "✅ 版本匹配",
|
||||
"matchMessage": "导入将会清空并覆盖现有的所有数据,包括:\n• 主数据库(对话记录、配置等)\n• 知识库数据\n• 插件及插件数据\n• 配置文件\n\n此操作不可撤销!是否确认继续?",
|
||||
"minorDiffTitle": "⚠️ 版本差异警告",
|
||||
"minorDiffMessage": "小版本差异通常是兼容的,但可能存在少量数据结构变化。\n导入将会清空并覆盖现有的所有数据!\n\n是否确认继续导入?",
|
||||
"majorDiffTitle": "⛔ 无法导入",
|
||||
"majorDiffMessage": "主版本号不同,跨主版本导入可能导致数据损坏。\n请使用相同主版本的 AstrBot 进行导入。"
|
||||
}
|
||||
},
|
||||
"list": {
|
||||
"empty": "暂无备份文件",
|
||||
"refresh": "刷新列表",
|
||||
"confirmDelete": "确定要删除这个备份文件吗?此操作不可撤销。"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -17,6 +17,13 @@
|
||||
|
||||
<v-list-subheader>{{ tm('system.title') }}</v-list-subheader>
|
||||
|
||||
<v-list-item :subtitle="tm('system.backup.subtitle')" :title="tm('system.backup.title')">
|
||||
<v-btn style="margin-top: 16px;" color="primary" @click="openBackupDialog">
|
||||
<v-icon class="mr-2">mdi-backup-restore</v-icon>
|
||||
{{ tm('system.backup.button') }}
|
||||
</v-btn>
|
||||
</v-list-item>
|
||||
|
||||
<v-list-item :subtitle="tm('system.restart.subtitle')" :title="tm('system.restart.title')">
|
||||
<v-btn style="margin-top: 16px;" color="error" @click="restartAstrBot">{{ tm('system.restart.button') }}</v-btn>
|
||||
</v-list-item>
|
||||
@@ -30,6 +37,7 @@
|
||||
|
||||
<WaitingForRestart ref="wfr"></WaitingForRestart>
|
||||
<MigrationDialog ref="migrationDialog"></MigrationDialog>
|
||||
<BackupDialog ref="backupDialog"></BackupDialog>
|
||||
|
||||
</template>
|
||||
|
||||
@@ -40,12 +48,14 @@ import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
|
||||
import ProxySelector from '@/components/shared/ProxySelector.vue';
|
||||
import MigrationDialog from '@/components/shared/MigrationDialog.vue';
|
||||
import SidebarCustomizer from '@/components/shared/SidebarCustomizer.vue';
|
||||
import BackupDialog from '@/components/shared/BackupDialog.vue';
|
||||
import { useModuleI18n } from '@/i18n/composables';
|
||||
|
||||
const { tm } = useModuleI18n('features/settings');
|
||||
|
||||
const wfr = ref(null);
|
||||
const migrationDialog = ref(null);
|
||||
const backupDialog = ref(null);
|
||||
|
||||
const restartAstrBot = () => {
|
||||
axios.post('/api/stat/restart-core').then(() => {
|
||||
@@ -65,4 +75,10 @@ const startMigration = async () => {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const openBackupDialog = () => {
|
||||
if (backupDialog.value) {
|
||||
backupDialog.value.open();
|
||||
}
|
||||
}
|
||||
</script>
|
||||
@@ -19,6 +19,7 @@ export default defineConfig({
|
||||
],
|
||||
resolve: {
|
||||
alias: {
|
||||
mermaid: 'mermaid/dist/mermaid.js',
|
||||
'@': fileURLToPath(new URL('./src', import.meta.url))
|
||||
}
|
||||
},
|
||||
|
||||
@@ -0,0 +1,760 @@
|
||||
"""备份功能单元测试"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from astrbot.core.backup import (
|
||||
BACKUP_MANIFEST_VERSION,
|
||||
KB_METADATA_MODELS,
|
||||
MAIN_DB_MODELS,
|
||||
ImportPreCheckResult,
|
||||
)
|
||||
from astrbot.core.backup.exporter import AstrBotExporter
|
||||
from astrbot.core.backup.importer import (
|
||||
AstrBotImporter,
|
||||
ImportResult,
|
||||
_get_major_version,
|
||||
)
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.db.po import (
|
||||
ConversationV2,
|
||||
)
|
||||
from astrbot.core.utils.version_comparator import VersionComparator
|
||||
from astrbot.dashboard.routes.backup import (
|
||||
generate_unique_filename,
|
||||
secure_filename,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_backup_dir(tmp_path):
|
||||
"""创建临时备份目录"""
|
||||
backup_dir = tmp_path / "backups"
|
||||
backup_dir.mkdir()
|
||||
return backup_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_data_dir(tmp_path):
|
||||
"""创建临时数据目录"""
|
||||
data_dir = tmp_path / "data"
|
||||
data_dir.mkdir()
|
||||
|
||||
# 创建配置文件
|
||||
config_path = data_dir / "cmd_config.json"
|
||||
config_path.write_text(json.dumps({"test": "config"}))
|
||||
|
||||
# 创建附件目录
|
||||
attachments_dir = data_dir / "attachments"
|
||||
attachments_dir.mkdir()
|
||||
|
||||
return data_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_main_db():
|
||||
"""创建模拟的主数据库"""
|
||||
db = MagicMock()
|
||||
|
||||
# 模拟异步上下文管理器
|
||||
session = AsyncMock()
|
||||
db.get_db = MagicMock(
|
||||
return_value=AsyncMock(__aenter__=AsyncMock(return_value=session))
|
||||
)
|
||||
|
||||
return db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_kb_manager():
|
||||
"""创建模拟的知识库管理器"""
|
||||
kb_manager = MagicMock()
|
||||
kb_manager.kb_insts = {}
|
||||
|
||||
# 模拟 kb_db
|
||||
kb_db = MagicMock()
|
||||
session = AsyncMock()
|
||||
kb_db.get_db = MagicMock(
|
||||
return_value=AsyncMock(__aenter__=AsyncMock(return_value=session))
|
||||
)
|
||||
kb_manager.kb_db = kb_db
|
||||
|
||||
return kb_manager
|
||||
|
||||
|
||||
class TestImportResult:
|
||||
"""ImportResult 类测试"""
|
||||
|
||||
def test_init(self):
|
||||
"""测试初始化"""
|
||||
result = ImportResult()
|
||||
assert result.success is True
|
||||
assert result.imported_tables == {}
|
||||
assert result.imported_files == {}
|
||||
assert result.warnings == []
|
||||
assert result.errors == []
|
||||
|
||||
def test_add_warning(self):
|
||||
"""测试添加警告"""
|
||||
result = ImportResult()
|
||||
result.add_warning("test warning")
|
||||
assert "test warning" in result.warnings
|
||||
assert result.success is True # 警告不影响成功状态
|
||||
|
||||
def test_add_error(self):
|
||||
"""测试添加错误"""
|
||||
result = ImportResult()
|
||||
result.add_error("test error")
|
||||
assert "test error" in result.errors
|
||||
assert result.success is False # 错误会导致失败
|
||||
|
||||
def test_to_dict(self):
|
||||
"""测试转换为字典"""
|
||||
result = ImportResult()
|
||||
result.imported_tables = {"test_table": 10}
|
||||
result.add_warning("warning")
|
||||
|
||||
d = result.to_dict()
|
||||
assert d["success"] is True
|
||||
assert d["imported_tables"] == {"test_table": 10}
|
||||
assert "warning" in d["warnings"]
|
||||
|
||||
|
||||
class TestAstrBotExporter:
|
||||
"""AstrBotExporter 类测试"""
|
||||
|
||||
def test_init(self, mock_main_db, mock_kb_manager, temp_data_dir):
|
||||
"""测试初始化"""
|
||||
exporter = AstrBotExporter(
|
||||
main_db=mock_main_db,
|
||||
kb_manager=mock_kb_manager,
|
||||
config_path=str(temp_data_dir / "cmd_config.json"),
|
||||
)
|
||||
assert exporter.main_db is mock_main_db
|
||||
assert exporter.kb_manager is mock_kb_manager
|
||||
|
||||
def test_model_to_dict_with_model_dump(self):
|
||||
"""测试 _model_to_dict 使用 model_dump 方法"""
|
||||
exporter = AstrBotExporter(main_db=MagicMock())
|
||||
|
||||
# 创建一个有 model_dump 方法的模拟对象
|
||||
mock_record = MagicMock()
|
||||
mock_record.model_dump.return_value = {"id": 1, "name": "test"}
|
||||
|
||||
result = exporter._model_to_dict(mock_record)
|
||||
assert result == {"id": 1, "name": "test"}
|
||||
|
||||
def test_model_to_dict_with_datetime(self):
|
||||
"""测试 _model_to_dict 处理 datetime 字段"""
|
||||
exporter = AstrBotExporter(main_db=MagicMock())
|
||||
|
||||
now = datetime.now()
|
||||
mock_record = MagicMock()
|
||||
mock_record.model_dump.return_value = {"id": 1, "created_at": now}
|
||||
|
||||
result = exporter._model_to_dict(mock_record)
|
||||
assert result["created_at"] == now.isoformat()
|
||||
|
||||
def test_add_checksum(self):
|
||||
"""测试添加校验和"""
|
||||
exporter = AstrBotExporter(main_db=MagicMock())
|
||||
|
||||
exporter._add_checksum("test.json", '{"test": "data"}')
|
||||
|
||||
assert "test.json" in exporter._checksums
|
||||
assert exporter._checksums["test.json"].startswith("sha256:")
|
||||
|
||||
def test_generate_manifest(self, mock_main_db, mock_kb_manager):
|
||||
"""测试生成清单"""
|
||||
exporter = AstrBotExporter(
|
||||
main_db=mock_main_db,
|
||||
kb_manager=mock_kb_manager,
|
||||
)
|
||||
|
||||
main_data = {
|
||||
"platform_stats": [{"id": 1}],
|
||||
"conversations": [],
|
||||
"attachments": [],
|
||||
}
|
||||
kb_meta_data = {
|
||||
"knowledge_bases": [],
|
||||
"kb_documents": [],
|
||||
}
|
||||
dir_stats = {
|
||||
"plugins": {"files": 10, "size": 1024},
|
||||
"plugin_data": {"files": 5, "size": 512},
|
||||
}
|
||||
|
||||
manifest = exporter._generate_manifest(main_data, kb_meta_data, dir_stats)
|
||||
|
||||
assert manifest["version"] == BACKUP_MANIFEST_VERSION
|
||||
assert manifest["astrbot_version"] == VERSION
|
||||
assert "exported_at" in manifest
|
||||
assert "tables" in manifest
|
||||
assert "statistics" in manifest
|
||||
assert "directories" in manifest
|
||||
assert manifest["statistics"]["main_db"]["platform_stats"] == 1
|
||||
assert manifest["statistics"]["directories"] == dir_stats
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_export_all_creates_zip(
|
||||
self, mock_main_db, temp_backup_dir, temp_data_dir
|
||||
):
|
||||
"""测试导出创建 ZIP 文件"""
|
||||
# 设置模拟数据库返回空数据
|
||||
session = AsyncMock()
|
||||
result = MagicMock()
|
||||
result.scalars.return_value.all.return_value = []
|
||||
session.execute = AsyncMock(return_value=result)
|
||||
|
||||
mock_main_db.get_db.return_value = AsyncMock(
|
||||
__aenter__=AsyncMock(return_value=session),
|
||||
__aexit__=AsyncMock(return_value=None),
|
||||
)
|
||||
|
||||
exporter = AstrBotExporter(
|
||||
main_db=mock_main_db,
|
||||
kb_manager=None,
|
||||
config_path=str(temp_data_dir / "cmd_config.json"),
|
||||
)
|
||||
|
||||
zip_path = await exporter.export_all(output_dir=str(temp_backup_dir))
|
||||
|
||||
assert os.path.exists(zip_path)
|
||||
assert zip_path.endswith(".zip")
|
||||
assert "astrbot_backup_" in zip_path
|
||||
|
||||
# 验证 ZIP 文件内容
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
namelist = zf.namelist()
|
||||
assert "manifest.json" in namelist
|
||||
assert "databases/main_db.json" in namelist
|
||||
assert "config/cmd_config.json" in namelist
|
||||
|
||||
|
||||
class TestAstrBotImporter:
|
||||
"""AstrBotImporter 类测试"""
|
||||
|
||||
def test_init(self, mock_main_db, mock_kb_manager, temp_data_dir):
|
||||
"""测试初始化"""
|
||||
importer = AstrBotImporter(
|
||||
main_db=mock_main_db,
|
||||
kb_manager=mock_kb_manager,
|
||||
config_path=str(temp_data_dir / "cmd_config.json"),
|
||||
)
|
||||
assert importer.main_db is mock_main_db
|
||||
assert importer.kb_manager is mock_kb_manager
|
||||
|
||||
def test_validate_version_match(self):
|
||||
"""测试版本匹配验证"""
|
||||
importer = AstrBotImporter(main_db=MagicMock())
|
||||
|
||||
manifest = {"astrbot_version": VERSION}
|
||||
# 不应该抛出异常
|
||||
importer._validate_version(manifest)
|
||||
|
||||
def test_validate_version_major_diff_rejected(self):
|
||||
"""测试主版本不同被拒绝"""
|
||||
importer = AstrBotImporter(main_db=MagicMock())
|
||||
|
||||
# 使用一个明显不同的主版本
|
||||
manifest = {"astrbot_version": "0.0.1"}
|
||||
with pytest.raises(ValueError, match="主版本不兼容"):
|
||||
importer._validate_version(manifest)
|
||||
|
||||
def test_validate_version_minor_diff_allowed(self):
|
||||
"""测试小版本不同被允许(仅警告)"""
|
||||
importer = AstrBotImporter(main_db=MagicMock())
|
||||
|
||||
# 获取当前主版本
|
||||
major_version = _get_major_version(VERSION)
|
||||
# 构造一个同主版本但小版本不同的版本
|
||||
minor_diff_version = f"{major_version}.999"
|
||||
manifest = {"astrbot_version": minor_diff_version}
|
||||
# 不应该抛出异常
|
||||
importer._validate_version(manifest)
|
||||
|
||||
def test_validate_version_missing(self):
|
||||
"""测试缺少版本信息"""
|
||||
importer = AstrBotImporter(main_db=MagicMock())
|
||||
|
||||
manifest = {}
|
||||
with pytest.raises(ValueError, match="缺少版本信息"):
|
||||
importer._validate_version(manifest)
|
||||
|
||||
def test_convert_datetime_fields(self):
|
||||
"""测试 datetime 字段转换"""
|
||||
importer = AstrBotImporter(main_db=MagicMock())
|
||||
|
||||
# 使用 ConversationV2 作为测试模型(它有 created_at 和 updated_at 字段)
|
||||
row = {
|
||||
"conversation_id": "test-123",
|
||||
"platform_id": "test",
|
||||
"user_id": "user1",
|
||||
"created_at": "2024-01-01T12:00:00",
|
||||
"updated_at": "2024-01-01T12:00:00",
|
||||
}
|
||||
|
||||
result = importer._convert_datetime_fields(row, ConversationV2)
|
||||
|
||||
# created_at 应该被转换为 datetime 对象
|
||||
assert isinstance(result["created_at"], datetime)
|
||||
assert isinstance(result["updated_at"], datetime)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_import_file_not_exists(self, mock_main_db, tmp_path):
|
||||
"""测试导入不存在的文件"""
|
||||
importer = AstrBotImporter(main_db=mock_main_db)
|
||||
|
||||
result = await importer.import_all(str(tmp_path / "nonexistent.zip"))
|
||||
|
||||
assert result.success is False
|
||||
assert any("不存在" in err for err in result.errors)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_import_invalid_zip(self, mock_main_db, tmp_path):
|
||||
"""测试导入无效的 ZIP 文件"""
|
||||
# 创建一个无效的文件
|
||||
invalid_zip = tmp_path / "invalid.zip"
|
||||
invalid_zip.write_text("not a zip file")
|
||||
|
||||
importer = AstrBotImporter(main_db=mock_main_db)
|
||||
result = await importer.import_all(str(invalid_zip))
|
||||
|
||||
assert result.success is False
|
||||
assert any("无效" in err or "ZIP" in err for err in result.errors)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_import_missing_manifest(self, mock_main_db, tmp_path):
|
||||
"""测试导入缺少 manifest 的 ZIP 文件"""
|
||||
# 创建一个没有 manifest 的 ZIP 文件
|
||||
zip_path = tmp_path / "no_manifest.zip"
|
||||
with zipfile.ZipFile(zip_path, "w") as zf:
|
||||
zf.writestr("test.txt", "test content")
|
||||
|
||||
importer = AstrBotImporter(main_db=mock_main_db)
|
||||
result = await importer.import_all(str(zip_path))
|
||||
|
||||
assert result.success is False
|
||||
assert any("manifest" in err.lower() for err in result.errors)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_import_major_version_mismatch(self, mock_main_db, tmp_path):
|
||||
"""测试导入主版本不匹配的备份"""
|
||||
# 创建一个主版本不匹配的备份
|
||||
zip_path = tmp_path / "old_version.zip"
|
||||
manifest = {
|
||||
"version": "1.0",
|
||||
"astrbot_version": "0.0.1", # 主版本不同
|
||||
"tables": {"main_db": []},
|
||||
}
|
||||
|
||||
with zipfile.ZipFile(zip_path, "w") as zf:
|
||||
zf.writestr("manifest.json", json.dumps(manifest))
|
||||
|
||||
importer = AstrBotImporter(main_db=mock_main_db)
|
||||
result = await importer.import_all(str(zip_path))
|
||||
|
||||
assert result.success is False
|
||||
assert any("主版本不兼容" in err for err in result.errors)
|
||||
|
||||
|
||||
class TestSecureFilename:
|
||||
"""安全文件名函数测试"""
|
||||
|
||||
def test_secure_filename_normal(self):
|
||||
"""测试正常文件名"""
|
||||
assert secure_filename("backup.zip") == "backup.zip"
|
||||
assert secure_filename("my_backup_2024.zip") == "my_backup_2024.zip"
|
||||
|
||||
def test_secure_filename_path_traversal(self):
|
||||
"""测试路径遍历攻击"""
|
||||
assert ".." not in secure_filename("../../../etc/passwd")
|
||||
assert "/" not in secure_filename("/etc/passwd")
|
||||
assert "\\" not in secure_filename("..\\..\\windows\\system32")
|
||||
|
||||
def test_secure_filename_with_path(self):
|
||||
"""测试带路径的文件名"""
|
||||
result = secure_filename("/path/to/backup.zip")
|
||||
assert result == "backup.zip"
|
||||
|
||||
result = secure_filename("C:\\Users\\test\\backup.zip")
|
||||
assert result == "backup.zip"
|
||||
|
||||
def test_secure_filename_special_chars(self):
|
||||
"""测试特殊字符"""
|
||||
result = secure_filename('backup<>:"|?*.zip')
|
||||
# 特殊字符应被替换为下划线
|
||||
assert "<" not in result
|
||||
assert ">" not in result
|
||||
assert ":" not in result
|
||||
assert '"' not in result
|
||||
assert "|" not in result
|
||||
assert "?" not in result
|
||||
assert "*" not in result
|
||||
|
||||
def test_secure_filename_hidden_file(self):
|
||||
"""测试隐藏文件(前导点)"""
|
||||
result = secure_filename(".hidden_backup.zip")
|
||||
assert not result.startswith(".")
|
||||
|
||||
def test_secure_filename_empty(self):
|
||||
"""测试空文件名"""
|
||||
assert secure_filename("") == "backup"
|
||||
assert secure_filename("...") == "backup"
|
||||
|
||||
def test_generate_unique_filename(self):
|
||||
"""测试生成唯一文件名"""
|
||||
result = generate_unique_filename("backup.zip")
|
||||
# 应包含 uploaded_ 前缀和时间戳
|
||||
assert result.startswith("uploaded_")
|
||||
assert result.endswith("_backup.zip")
|
||||
# 应包含时间戳格式 YYYYMMDD_HHMMSS
|
||||
assert re.search(r"uploaded_\d{8}_\d{6}_backup\.zip", result)
|
||||
|
||||
|
||||
class TestVersionComparison:
|
||||
"""版本比较函数测试 - 使用 VersionComparator"""
|
||||
|
||||
def test_get_major_version_simple(self):
|
||||
"""测试提取简单主版本号"""
|
||||
assert _get_major_version("1.0") == "1.0"
|
||||
assert _get_major_version("2.1") == "2.1"
|
||||
assert _get_major_version("4.9.1") == "4.9"
|
||||
|
||||
def test_get_major_version_with_prefix(self):
|
||||
"""测试带 v 前缀的版本号"""
|
||||
assert _get_major_version("v1.0") == "1.0"
|
||||
assert _get_major_version("V4.9.1") == "4.9"
|
||||
|
||||
def test_get_major_version_with_prerelease(self):
|
||||
"""测试带预发布标签的版本号"""
|
||||
assert _get_major_version("4.9.1-beta") == "4.9"
|
||||
assert _get_major_version("4.9.1-alpha.1") == "4.9"
|
||||
assert _get_major_version("4.9.1+build123") == "4.9"
|
||||
|
||||
def test_get_major_version_single_part(self):
|
||||
"""测试单部分版本号"""
|
||||
assert _get_major_version("1") == "1.0"
|
||||
|
||||
def test_get_major_version_empty(self):
|
||||
"""测试空版本号"""
|
||||
assert _get_major_version("") == "0.0"
|
||||
|
||||
def test_compare_versions_equal(self):
|
||||
"""测试版本相等"""
|
||||
assert VersionComparator.compare_version("1.0", "1.0") == 0
|
||||
assert VersionComparator.compare_version("1.0.0", "1.0") == 0
|
||||
assert VersionComparator.compare_version("2.10", "2.10") == 0
|
||||
|
||||
def test_compare_versions_less_than(self):
|
||||
"""测试版本小于"""
|
||||
assert VersionComparator.compare_version("1.0", "1.1") == -1
|
||||
assert (
|
||||
VersionComparator.compare_version("1.9", "1.10") == -1
|
||||
) # 关键测试:多位数版本比较
|
||||
assert VersionComparator.compare_version("1.2", "1.10") == -1
|
||||
assert VersionComparator.compare_version("1.0", "2.0") == -1
|
||||
|
||||
def test_compare_versions_greater_than(self):
|
||||
"""测试版本大于"""
|
||||
assert VersionComparator.compare_version("1.1", "1.0") == 1
|
||||
assert (
|
||||
VersionComparator.compare_version("1.10", "1.9") == 1
|
||||
) # 关键测试:多位数版本比较
|
||||
assert VersionComparator.compare_version("1.10", "1.2") == 1
|
||||
assert VersionComparator.compare_version("2.0", "1.0") == 1
|
||||
|
||||
def test_compare_versions_different_lengths(self):
|
||||
"""测试不同长度版本比较"""
|
||||
assert VersionComparator.compare_version("1.0", "1.0.0") == 0
|
||||
assert VersionComparator.compare_version("1.0", "1.0.1") == -1
|
||||
assert VersionComparator.compare_version("1.0.1", "1.0") == 1
|
||||
|
||||
def test_compare_versions_prerelease(self):
|
||||
"""测试预发布版本比较"""
|
||||
# 预发布版本低于正式版本
|
||||
assert VersionComparator.compare_version("1.0.0-alpha", "1.0.0") == -1
|
||||
assert VersionComparator.compare_version("1.0.0", "1.0.0-beta") == 1
|
||||
# alpha < beta
|
||||
assert VersionComparator.compare_version("1.0.0-alpha", "1.0.0-beta") == -1
|
||||
|
||||
|
||||
class TestImportPreCheckResult:
|
||||
"""ImportPreCheckResult 类测试"""
|
||||
|
||||
def test_init_default_values(self):
|
||||
"""测试默认值初始化"""
|
||||
result = ImportPreCheckResult()
|
||||
assert result.valid is False
|
||||
assert result.can_import is False
|
||||
assert result.version_status == ""
|
||||
assert result.backup_version == ""
|
||||
assert result.current_version == VERSION
|
||||
assert result.confirm_message == ""
|
||||
assert result.warnings == []
|
||||
assert result.error == ""
|
||||
assert result.backup_summary == {}
|
||||
|
||||
def test_to_dict(self):
|
||||
"""测试转换为字典"""
|
||||
result = ImportPreCheckResult(
|
||||
valid=True,
|
||||
can_import=True,
|
||||
version_status="match",
|
||||
backup_version="4.9.0",
|
||||
confirm_message="确认导入?",
|
||||
warnings=["警告1"],
|
||||
backup_summary={"tables": ["table1"]},
|
||||
)
|
||||
|
||||
d = result.to_dict()
|
||||
assert d["valid"] is True
|
||||
assert d["can_import"] is True
|
||||
assert d["version_status"] == "match"
|
||||
assert d["backup_version"] == "4.9.0"
|
||||
assert d["confirm_message"] == "确认导入?"
|
||||
assert "警告1" in d["warnings"]
|
||||
assert d["backup_summary"]["tables"] == ["table1"]
|
||||
|
||||
|
||||
class TestPreCheck:
|
||||
"""预检查功能测试"""
|
||||
|
||||
def test_pre_check_file_not_exists(self, mock_main_db):
|
||||
"""测试预检查不存在的文件"""
|
||||
importer = AstrBotImporter(main_db=mock_main_db)
|
||||
result = importer.pre_check("/nonexistent/file.zip")
|
||||
|
||||
assert result.valid is False
|
||||
assert "不存在" in result.error
|
||||
|
||||
def test_pre_check_invalid_zip(self, mock_main_db, tmp_path):
|
||||
"""测试预检查无效的 ZIP 文件"""
|
||||
invalid_zip = tmp_path / "invalid.zip"
|
||||
invalid_zip.write_text("not a zip file")
|
||||
|
||||
importer = AstrBotImporter(main_db=mock_main_db)
|
||||
result = importer.pre_check(str(invalid_zip))
|
||||
|
||||
assert result.valid is False
|
||||
assert "ZIP" in result.error or "无效" in result.error
|
||||
|
||||
def test_pre_check_missing_manifest(self, mock_main_db, tmp_path):
|
||||
"""测试预检查缺少 manifest 的 ZIP 文件"""
|
||||
zip_path = tmp_path / "no_manifest.zip"
|
||||
with zipfile.ZipFile(zip_path, "w") as zf:
|
||||
zf.writestr("test.txt", "test content")
|
||||
|
||||
importer = AstrBotImporter(main_db=mock_main_db)
|
||||
result = importer.pre_check(str(zip_path))
|
||||
|
||||
assert result.valid is False
|
||||
assert "manifest" in result.error.lower()
|
||||
|
||||
def test_pre_check_version_match(self, mock_main_db, tmp_path):
|
||||
"""测试预检查版本匹配"""
|
||||
zip_path = tmp_path / "backup.zip"
|
||||
manifest = {
|
||||
"version": "1.1",
|
||||
"astrbot_version": VERSION,
|
||||
"created_at": "2024-01-01T12:00:00",
|
||||
"tables": {"platform_stats": 1},
|
||||
"has_knowledge_bases": True,
|
||||
"has_config": True,
|
||||
"directories": ["plugins"],
|
||||
}
|
||||
|
||||
with zipfile.ZipFile(zip_path, "w") as zf:
|
||||
zf.writestr("manifest.json", json.dumps(manifest))
|
||||
|
||||
importer = AstrBotImporter(main_db=mock_main_db)
|
||||
result = importer.pre_check(str(zip_path))
|
||||
|
||||
assert result.valid is True
|
||||
assert result.can_import is True
|
||||
assert result.version_status == "match"
|
||||
assert result.backup_version == VERSION
|
||||
# confirm_message 现在由前端生成,后端不再生成
|
||||
assert result.backup_summary["has_knowledge_bases"] is True
|
||||
|
||||
def test_pre_check_minor_version_diff(self, mock_main_db, tmp_path):
|
||||
"""测试预检查小版本差异"""
|
||||
# 构造一个同主版本但小版本不同的版本
|
||||
major_version = _get_major_version(VERSION)
|
||||
minor_diff_version = f"{major_version}.999"
|
||||
|
||||
zip_path = tmp_path / "backup.zip"
|
||||
manifest = {
|
||||
"version": "1.1",
|
||||
"astrbot_version": minor_diff_version,
|
||||
"created_at": "2024-01-01T12:00:00",
|
||||
"tables": {},
|
||||
}
|
||||
|
||||
with zipfile.ZipFile(zip_path, "w") as zf:
|
||||
zf.writestr("manifest.json", json.dumps(manifest))
|
||||
|
||||
importer = AstrBotImporter(main_db=mock_main_db)
|
||||
result = importer.pre_check(str(zip_path))
|
||||
|
||||
assert result.valid is True
|
||||
assert result.can_import is True
|
||||
assert result.version_status == "minor_diff"
|
||||
# 版本消息由前端 i18n 生成,后端 warnings 列表不再包含版本相关消息
|
||||
# warnings 列表保留用于其他非版本相关的警告
|
||||
|
||||
def test_pre_check_major_version_diff(self, mock_main_db, tmp_path):
|
||||
"""测试预检查主版本差异"""
|
||||
zip_path = tmp_path / "backup.zip"
|
||||
manifest = {
|
||||
"version": "1.1",
|
||||
"astrbot_version": "0.0.1", # 主版本不同
|
||||
"created_at": "2024-01-01T12:00:00",
|
||||
"tables": {},
|
||||
}
|
||||
|
||||
with zipfile.ZipFile(zip_path, "w") as zf:
|
||||
zf.writestr("manifest.json", json.dumps(manifest))
|
||||
|
||||
importer = AstrBotImporter(main_db=mock_main_db)
|
||||
result = importer.pre_check(str(zip_path))
|
||||
|
||||
assert result.valid is True # 文件有效
|
||||
assert result.can_import is False # 但不能导入
|
||||
assert result.version_status == "major_diff"
|
||||
# 版本消息由前端 i18n 生成,后端 warnings 列表不再包含版本相关消息
|
||||
|
||||
|
||||
class TestVersionCompatibility:
|
||||
"""版本兼容性检查测试"""
|
||||
|
||||
def test_check_version_compatibility_match(self, mock_main_db):
|
||||
"""测试版本完全匹配"""
|
||||
importer = AstrBotImporter(main_db=mock_main_db)
|
||||
result = importer._check_version_compatibility(VERSION)
|
||||
|
||||
assert result["status"] == "match"
|
||||
assert result["can_import"] is True
|
||||
|
||||
def test_check_version_compatibility_minor_diff(self, mock_main_db):
|
||||
"""测试小版本差异"""
|
||||
major_version = _get_major_version(VERSION)
|
||||
minor_diff_version = f"{major_version}.999"
|
||||
|
||||
importer = AstrBotImporter(main_db=mock_main_db)
|
||||
result = importer._check_version_compatibility(minor_diff_version)
|
||||
|
||||
assert result["status"] == "minor_diff"
|
||||
assert result["can_import"] is True
|
||||
|
||||
def test_check_version_compatibility_major_diff(self, mock_main_db):
|
||||
"""测试主版本差异"""
|
||||
importer = AstrBotImporter(main_db=mock_main_db)
|
||||
result = importer._check_version_compatibility("0.0.1")
|
||||
|
||||
assert result["status"] == "major_diff"
|
||||
assert result["can_import"] is False
|
||||
|
||||
def test_check_version_compatibility_empty_version(self, mock_main_db):
|
||||
"""测试空版本号"""
|
||||
importer = AstrBotImporter(main_db=mock_main_db)
|
||||
result = importer._check_version_compatibility("")
|
||||
|
||||
assert result["status"] == "major_diff"
|
||||
assert result["can_import"] is False
|
||||
|
||||
|
||||
class TestModelMappings:
|
||||
"""测试模型映射配置"""
|
||||
|
||||
def test_main_db_models_not_empty(self):
|
||||
"""测试主数据库模型映射非空"""
|
||||
assert len(MAIN_DB_MODELS) > 0
|
||||
|
||||
def test_main_db_models_contain_expected_tables(self):
|
||||
"""测试主数据库模型映射包含预期的表"""
|
||||
expected_tables = [
|
||||
"platform_stats",
|
||||
"conversations",
|
||||
"personas",
|
||||
"preferences",
|
||||
"attachments",
|
||||
]
|
||||
for table in expected_tables:
|
||||
assert table in MAIN_DB_MODELS, f"Missing table: {table}"
|
||||
|
||||
def test_kb_metadata_models_not_empty(self):
|
||||
"""测试知识库元数据模型映射非空"""
|
||||
assert len(KB_METADATA_MODELS) > 0
|
||||
|
||||
def test_kb_metadata_models_contain_expected_tables(self):
|
||||
"""测试知识库元数据模型映射包含预期的表"""
|
||||
expected_tables = [
|
||||
"knowledge_bases",
|
||||
"kb_documents",
|
||||
"kb_media",
|
||||
]
|
||||
for table in expected_tables:
|
||||
assert table in KB_METADATA_MODELS, f"Missing table: {table}"
|
||||
|
||||
|
||||
class TestBackupIntegration:
|
||||
"""备份集成测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_export_import_roundtrip(self, tmp_path):
|
||||
"""测试导出-导入往返"""
|
||||
backup_dir = tmp_path / "backups"
|
||||
backup_dir.mkdir()
|
||||
|
||||
data_dir = tmp_path / "data"
|
||||
data_dir.mkdir()
|
||||
|
||||
config_path = data_dir / "cmd_config.json"
|
||||
config_path.write_text(json.dumps({"setting": "value"}))
|
||||
|
||||
attachments_dir = data_dir / "attachments"
|
||||
attachments_dir.mkdir()
|
||||
|
||||
# 创建模拟数据库
|
||||
mock_db = MagicMock()
|
||||
session = AsyncMock()
|
||||
result = MagicMock()
|
||||
result.scalars.return_value.all.return_value = []
|
||||
session.execute = AsyncMock(return_value=result)
|
||||
|
||||
mock_db.get_db.return_value = AsyncMock(
|
||||
__aenter__=AsyncMock(return_value=session),
|
||||
__aexit__=AsyncMock(return_value=None),
|
||||
)
|
||||
|
||||
# 导出
|
||||
exporter = AstrBotExporter(
|
||||
main_db=mock_db,
|
||||
kb_manager=None,
|
||||
config_path=str(config_path),
|
||||
)
|
||||
|
||||
zip_path = await exporter.export_all(output_dir=str(backup_dir))
|
||||
assert os.path.exists(zip_path)
|
||||
|
||||
# 验证 ZIP 内容
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
# 读取 manifest
|
||||
manifest = json.loads(zf.read("manifest.json"))
|
||||
assert manifest["astrbot_version"] == VERSION
|
||||
|
||||
# 读取配置
|
||||
config = json.loads(zf.read("config/cmd_config.json"))
|
||||
assert config["setting"] == "value"
|
||||
|
||||
# 读取主数据库
|
||||
main_db = json.loads(zf.read("databases/main_db.json"))
|
||||
assert "platform_stats" in main_db
|
||||
Reference in New Issue
Block a user