4abea2bd30
* fix: harden backup import for duplicate platform stats - 修复 replace 模式下主库清空失败仍继续导入的问题。 - 导入前对 platform_stats 重复键做聚合(count 累加),并统一时间戳判重格式。 - 非法 count 按 0 处理并告警(限流),补充对应测试。 * refactor: improve robustness and readability of platform stats import - 告警上限魔法数字提取为模块常量 PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT - 抽取 parse_count 内联函数,消除重复的 try/except 分支 - 存储行的 timestamp 同步写入规范化值,避免落库格式混用 - 补充测试:已有行 count 非法、告警限流、replace 模式中断断言 * fix: normalize invalid platform_stats count for non-duplicate rows * fix: avoid merging invalid platform_stats timestamps * refactor: simplify platform stats merge and normalize naive UTC * refactor: inline platform stats merge helpers * refactor: flatten platform stats merge flow * refactor: harden platform stats merge key handling * refactor: streamline platform stats preprocessing * refactor: simplify platform stats merge helpers * refactor: inline platform stats merge normalization * refactor: extract platform stats merge helpers * refactor: simplify platform stats preprocessing flow * refactor: flatten platform stats preprocess helpers * refactor: streamline platform stats merge helpers * refactor: isolate platform stats warning limiter --------- Co-authored-by: 邹永赫 <1259085392@qq.com>
947 lines
35 KiB
Python
947 lines
35 KiB
Python
"""AstrBot 数据导入器
|
||
|
||
负责从 ZIP 备份文件恢复所有数据。
|
||
导入时进行版本校验:
|
||
- 主版本(前两位)不同时直接拒绝导入
|
||
- 小版本(第三位)不同时提示警告,用户可选择强制导入
|
||
- 版本匹配时也需要用户确认
|
||
"""
|
||
|
||
import json
|
||
import os
|
||
import shutil
|
||
import zipfile
|
||
from dataclasses import dataclass, field
|
||
from datetime import datetime, timezone
|
||
from pathlib import Path
|
||
from typing import TYPE_CHECKING, Any
|
||
|
||
from sqlalchemy import delete
|
||
|
||
from astrbot.core import logger
|
||
from astrbot.core.config.default import VERSION
|
||
from astrbot.core.db import BaseDatabase
|
||
from astrbot.core.utils.astrbot_path import (
|
||
get_astrbot_data_path,
|
||
get_astrbot_knowledge_base_path,
|
||
)
|
||
from astrbot.core.utils.version_comparator import VersionComparator
|
||
|
||
# 从共享常量模块导入
|
||
from .constants import (
|
||
KB_METADATA_MODELS,
|
||
MAIN_DB_MODELS,
|
||
get_backup_directories,
|
||
)
|
||
|
||
if TYPE_CHECKING:
|
||
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
||
|
||
|
||
def _get_major_version(version_str: str) -> str:
|
||
"""提取版本的主版本部分(前两位)
|
||
|
||
Args:
|
||
version_str: 版本字符串,如 "4.9.1", "4.10.0-beta"
|
||
|
||
Returns:
|
||
主版本字符串,如 "4.9", "4.10"
|
||
"""
|
||
if not version_str:
|
||
return "0.0"
|
||
# 移除 v 前缀和预发布标签
|
||
version = version_str.lower().replace("v", "").split("-")[0].split("+")[0]
|
||
parts = [p for p in version.split(".") if p] # 过滤空字符串
|
||
if len(parts) >= 2:
|
||
return f"{parts[0]}.{parts[1]}"
|
||
elif len(parts) == 1 and parts[0]:
|
||
return f"{parts[0]}.0"
|
||
return "0.0"
|
||
|
||
|
||
# Default location of the command configuration file inside the data directory.
CMD_CONFIG_FILE_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")

# Default root directory under which knowledge bases are stored.
KB_PATH = get_astrbot_knowledge_base_path()

# Number of "invalid platform_stats count" warnings emitted per merge pass
# before further warnings are suppressed, unless overridden via the
# environment variable named below.
DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT = 5
PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV = "ASTRBOT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT"
|
||
|
||
|
||
def _load_platform_stats_invalid_count_warn_limit() -> int:
    """Read the invalid-count warning limit from the environment.

    Returns:
        The non-negative integer stored in the environment variable named by
        ``PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV``. When the variable is
        unset the default is returned silently; when it holds a negative or
        non-integer value a warning is logged and the default is returned.
    """
    fallback = DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT
    env_text = os.getenv(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV)
    if env_text is None:
        return fallback

    try:
        parsed = int(env_text)
    except (TypeError, ValueError):
        # Treat unparsable input exactly like a negative number: rejected below.
        parsed = -1

    if parsed >= 0:
        return parsed

    logger.warning(
        "Invalid env %s=%r, fallback to default %d",
        PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV,
        env_text,
        fallback,
    )
    return fallback
|
||
|
||
|
||
# Effective warning limit, resolved once when this module is imported.
PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT = _load_platform_stats_invalid_count_warn_limit()
|
||
|
||
|
||
class _InvalidCountWarnLimiter:
|
||
"""Rate-limit warnings for invalid platform_stats count values."""
|
||
|
||
def __init__(self, limit: int) -> None:
|
||
self.limit = limit
|
||
self._count = 0
|
||
self._suppression_logged = False
|
||
|
||
def warn_invalid_count(self, value: Any, key_for_log: tuple[Any, ...]) -> None:
|
||
if self.limit > 0:
|
||
if self._count < self.limit:
|
||
logger.warning(
|
||
"platform_stats count 非法,已按 0 处理: value=%r, key=%s",
|
||
value,
|
||
key_for_log,
|
||
)
|
||
self._count += 1
|
||
if self._count == self.limit and not self._suppression_logged:
|
||
logger.warning(
|
||
"platform_stats 非法 count 告警已达到上限 (%d),后续将抑制",
|
||
self.limit,
|
||
)
|
||
self._suppression_logged = True
|
||
return
|
||
|
||
if not self._suppression_logged:
|
||
# limit <= 0: emit only one suppression warning.
|
||
logger.warning(
|
||
"platform_stats 非法 count 告警已达到上限 (%d),后续将抑制",
|
||
self.limit,
|
||
)
|
||
self._suppression_logged = True
|
||
|
||
|
||
@dataclass
class ImportPreCheckResult:
    """Outcome of inspecting a backup archive before importing it.

    Produced by the pre-check step so the caller can show version
    compatibility information and ask the user whether to continue.
    """

    # Whether the check passed.
    valid: bool = False
    # Whether the backup may be imported (version compatible).
    can_import: bool = False
    # Version status: "match", "minor_diff" or "major_diff" (rejected).
    version_status: str = ""
    # AstrBot version recorded in the backup.
    backup_version: str = ""
    # AstrBot version currently running.
    current_version: str = VERSION
    # Time at which the backup was created.
    backup_time: str = ""
    # Confirmation message to display to the user.
    confirm_message: str = ""
    # Warning messages.
    warnings: list[str] = field(default_factory=list)
    # Error message, set when the check failed.
    error: str = ""
    # Summary of what the backup contains.
    backup_summary: dict = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Return the fields as a plain dict keyed by field name.

        The list and dict values are the instance's own objects, not copies.
        """
        exported = (
            "valid",
            "can_import",
            "version_status",
            "backup_version",
            "current_version",
            "backup_time",
            "confirm_message",
            "warnings",
            "error",
            "backup_summary",
        )
        return {name: getattr(self, name) for name in exported}
|
||
|
||
|
||
class ImportResult:
    """Accumulates the outcome of an import run.

    Tracks per-table, per-file and per-directory counts together with any
    warnings and errors. Recording an error marks the run as failed.
    """

    def __init__(self) -> None:
        self.success = True
        self.imported_tables: dict[str, int] = {}
        self.imported_files: dict[str, int] = {}
        self.imported_directories: dict[str, int] = {}
        self.warnings: list[str] = []
        self.errors: list[str] = []

    def add_warning(self, msg: str) -> None:
        """Record a warning and log it; the success flag is left unchanged."""
        self.warnings.append(msg)
        logger.warning(msg)

    def add_error(self, msg: str) -> None:
        """Record an error, mark the run as failed, and log it."""
        self.errors.append(msg)
        self.success = False
        logger.error(msg)

    def to_dict(self) -> dict:
        """Return the result as a plain dict keyed by attribute name.

        Container values are the instance's own objects, not copies.
        """
        exported = (
            "success",
            "imported_tables",
            "imported_files",
            "imported_directories",
            "warnings",
            "errors",
        )
        return {name: getattr(self, name) for name in exported}
|
||
|
||
|
||
class DatabaseClearError(RuntimeError):
    """Signal that emptying the main database failed in replace mode.

    Kept as a distinct type so callers can tell a failed clear apart from
    other import failures.
    """
|
||
|
||
|
||
class AstrBotImporter:
    """Restore AstrBot data from a ZIP backup archive.

    The importer restores everything contained in a backup:

    - all tables of the main database
    - knowledge base metadata and documents
    - the configuration file
    - attachment files
    - knowledge base media files
    - the plugin directory (data/plugins)
    - the plugin data directory (data/plugin_data)
    - the config directory (data/config)
    - the T2I template directory (data/t2i_templates)
    - the WebChat data directory (data/webchat)
    - the temporary file directory (data/temp)
    """

    def __init__(
        self,
        main_db: BaseDatabase,
        kb_manager: "KnowledgeBaseManager | None" = None,
        config_path: str = CMD_CONFIG_FILE_PATH,
        kb_root_dir: str = KB_PATH,
    ) -> None:
        """Store the collaborators used during an import.

        Args:
            main_db: Main database to clear and repopulate.
            kb_manager: Knowledge base manager; knowledge bases are skipped
                when it is None.
            config_path: Destination of the restored configuration file.
            kb_root_dir: Root directory under which knowledge bases live.
        """
        self.main_db = main_db
        self.kb_manager = kb_manager
        self.config_path = config_path
        self.kb_root_dir = kb_root_dir

    # ------------------------------------------------------------------
    # Archive helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _resolve_inside(base_dir: Path, rel_path: str) -> Path:
        """Join ``rel_path`` onto ``base_dir`` and reject paths that escape it.

        Archive member names and identifiers read from a backup are untrusted;
        ``..`` components or absolute names must not be able to write outside
        the intended directory.

        Args:
            base_dir: Directory the result must stay strictly inside.
            rel_path: Relative path taken from the archive.

        Returns:
            ``base_dir / rel_path`` (not resolved).

        Raises:
            ValueError: If the joined path is ``base_dir`` itself or lies
                outside of it once resolved.
        """
        target = base_dir / rel_path
        base_resolved = base_dir.resolve()
        target_resolved = target.resolve()
        if target_resolved == base_resolved or not target_resolved.is_relative_to(
            base_resolved
        ):
            raise ValueError(f"检测到不安全的路径: {rel_path!r}")
        return target

    @staticmethod
    def _extract_member(zf: zipfile.ZipFile, name: str, target_path: Path) -> None:
        """Stream one archive member to ``target_path`` without buffering it whole."""
        with zf.open(name) as src, open(target_path, "wb") as dst:
            shutil.copyfileobj(src, dst)

    # ------------------------------------------------------------------
    # Pre-check and version validation
    # ------------------------------------------------------------------

    def pre_check(self, zip_path: str) -> ImportPreCheckResult:
        """Inspect a backup archive without importing anything.

        Checks that the archive is readable, contains a manifest, and is
        version compatible, so the caller can show a confirmation dialog.

        Args:
            zip_path: Path of the ZIP backup file.

        Returns:
            ImportPreCheckResult: ``valid`` is True only when the manifest was
            read and evaluated without error; otherwise ``error`` is set.
        """
        result = ImportPreCheckResult()
        result.current_version = VERSION

        if not os.path.exists(zip_path):
            result.error = f"备份文件不存在: {zip_path}"
            return result

        try:
            with zipfile.ZipFile(zip_path, "r") as zf:
                try:
                    manifest_data = zf.read("manifest.json")
                    manifest = json.loads(manifest_data)
                except KeyError:
                    result.error = "备份文件缺少 manifest.json,不是有效的 AstrBot 备份"
                    return result
                except json.JSONDecodeError as e:
                    result.error = f"manifest.json 格式错误: {e}"
                    return result

                result.backup_version = manifest.get("astrbot_version", "未知")
                result.backup_time = manifest.get("exported_at", "未知")
                result.backup_summary = {
                    "tables": list(manifest.get("tables", {}).keys()),
                    "has_knowledge_bases": manifest.get("has_knowledge_bases", False),
                    "has_config": manifest.get("has_config", False),
                    "directories": manifest.get("directories", []),
                }

                version_check = self._check_version_compatibility(result.backup_version)
                result.version_status = version_check["status"]
                result.can_import = version_check["can_import"]

                # Version messages are rendered by the frontend from
                # ``version_status``; ``warnings`` is reserved for other,
                # non-version warnings.

                # Mark the result valid only after everything above succeeded,
                # so a malformed manifest never yields valid=True plus an error.
                result.valid = True
                return result

        except zipfile.BadZipFile:
            result.error = "无效的 ZIP 文件"
            return result
        except Exception as e:
            result.error = f"检查备份文件失败: {e}"
            return result

    def _check_version_compatibility(self, backup_version: str) -> dict:
        """Classify a backup version against the running version.

        Rules:
            - The first two version components (e.g. 4.9) must be equal,
              otherwise the import is rejected.
            - A difference elsewhere (e.g. 4.9.1 vs 4.9.2) is allowed.

        Args:
            backup_version: Version string recorded in the backup.

        Returns:
            dict with keys ``status`` ("match", "minor_diff" or "major_diff"),
            ``can_import`` (bool) and ``message`` (str).
        """
        if not backup_version:
            return {
                "status": "major_diff",
                "can_import": False,
                "message": "备份文件缺少版本信息",
            }

        backup_major = _get_major_version(backup_version)
        current_major = _get_major_version(VERSION)

        if VersionComparator.compare_version(backup_major, current_major) != 0:
            return {
                "status": "major_diff",
                "can_import": False,
                "message": (
                    f"主版本不兼容: 备份版本 {backup_version}, 当前版本 {VERSION}。"
                    "跨主版本导入可能导致数据损坏,请使用相同主版本的 AstrBot。"
                ),
            }

        if VersionComparator.compare_version(backup_version, VERSION) != 0:
            return {
                "status": "minor_diff",
                "can_import": True,
                "message": f"小版本差异: 备份版本 {backup_version}, 当前版本 {VERSION}。",
            }

        return {
            "status": "match",
            "can_import": True,
            "message": "版本匹配",
        }

    def _validate_version(self, manifest: dict) -> None:
        """Reject manifests whose version cannot be imported.

        Called from ``import_all`` as a second line of defence; callers are
        expected to have run ``pre_check`` first.

        Args:
            manifest: Parsed backup manifest.

        Raises:
            ValueError: If the version is missing or classified "major_diff".
        """
        backup_version = manifest.get("astrbot_version")
        if not backup_version:
            raise ValueError("备份文件缺少版本信息")

        version_check = self._check_version_compatibility(backup_version)

        if version_check["status"] == "major_diff":
            raise ValueError(version_check["message"])

        # Both "minor_diff" and "match" may proceed; the former is only logged.
        if version_check["status"] == "minor_diff":
            logger.warning(f"版本差异警告: {version_check['message']}")

    # ------------------------------------------------------------------
    # Import entry point
    # ------------------------------------------------------------------

    async def import_all(
        self,
        zip_path: str,
        mode: str = "replace",
        progress_callback: Any | None = None,
    ) -> ImportResult:
        """Import everything from a ZIP backup.

        Args:
            zip_path: Path of the ZIP backup file.
            mode: Import mode. With "replace" existing data is cleared before
                importing; any other value imports without clearing.
            progress_callback: Optional awaitable callable invoked as
                ``(stage, current, total, message)``.

        Returns:
            ImportResult: Counts, warnings and errors. Manifest, version and
            main database failures are errors and stop the import; failures
            of later stages are recorded as warnings.
        """
        result = ImportResult()

        if not os.path.exists(zip_path):
            result.add_error(f"备份文件不存在: {zip_path}")
            return result

        logger.info(f"开始从 {zip_path} 导入备份")

        async def report(stage: str, current: int, message: str) -> None:
            if progress_callback:
                await progress_callback(stage, current, 100, message)

        try:
            with zipfile.ZipFile(zip_path, "r") as zf:
                # Build the member index once instead of on every lookup.
                archive_names = set(zf.namelist())

                # 1. Read and validate the manifest.
                await report("validate", 0, "正在验证备份文件...")

                try:
                    manifest_data = zf.read("manifest.json")
                    manifest = json.loads(manifest_data)
                except KeyError:
                    result.add_error("备份文件缺少 manifest.json")
                    return result
                except json.JSONDecodeError as e:
                    result.add_error(f"manifest.json 格式错误: {e}")
                    return result

                try:
                    self._validate_version(manifest)
                except ValueError as e:
                    result.add_error(str(e))
                    return result

                await report("validate", 100, "验证完成")

                # 2. Import the main database.
                await report("main_db", 0, "正在导入主数据库...")

                try:
                    main_data_content = zf.read("databases/main_db.json")
                    main_data = json.loads(main_data_content)

                    if mode == "replace":
                        await self._clear_main_db()

                    imported = await self._import_main_database(main_data)
                    result.imported_tables.update(imported)
                except DatabaseClearError as e:
                    # Never import on top of a database that failed to clear.
                    result.add_error(f"清空主数据库失败: {e}")
                    return result
                except Exception as e:
                    result.add_error(f"导入主数据库失败: {e}")
                    return result

                await report("main_db", 100, "主数据库导入完成")

                # 3. Import knowledge bases.
                if self.kb_manager and "databases/kb_metadata.json" in archive_names:
                    await report("kb", 0, "正在导入知识库...")

                    try:
                        kb_meta_content = zf.read("databases/kb_metadata.json")
                        kb_meta_data = json.loads(kb_meta_content)

                        if mode == "replace":
                            await self._clear_kb_data()

                        await self._import_knowledge_bases(zf, kb_meta_data, result)
                    except Exception as e:
                        result.add_warning(f"导入知识库失败: {e}")

                    await report("kb", 100, "知识库导入完成")

                # 4. Import the configuration file.
                await report("config", 0, "正在导入配置文件...")

                if "config/cmd_config.json" in archive_names:
                    try:
                        config_content = zf.read("config/cmd_config.json")
                        # Keep a copy of the current configuration.
                        if os.path.exists(self.config_path):
                            backup_path = f"{self.config_path}.bak"
                            shutil.copy2(self.config_path, backup_path)

                        with open(self.config_path, "wb") as f:
                            f.write(config_content)
                        result.imported_files["config"] = 1
                    except Exception as e:
                        result.add_warning(f"导入配置文件失败: {e}")

                await report("config", 100, "配置文件导入完成")

                # 5. Import attachment files.
                await report("attachments", 0, "正在导入附件...")

                attachment_count = await self._import_attachments(
                    zf, main_data.get("attachments", [])
                )
                result.imported_files["attachments"] = attachment_count

                await report("attachments", 100, "附件导入完成")

                # 6. Import plugins and other data directories.
                await report("directories", 0, "正在导入插件和数据目录...")

                dir_stats = await self._import_directories(zf, manifest, result)
                result.imported_directories = dir_stats

                await report("directories", 100, "目录导入完成")

                logger.info(f"备份导入完成: {result.to_dict()}")
                return result

        except zipfile.BadZipFile:
            result.add_error("无效的 ZIP 文件")
            return result
        except Exception as e:
            result.add_error(f"导入失败: {e}")
            return result

    # ------------------------------------------------------------------
    # Clearing existing data
    # ------------------------------------------------------------------

    async def _clear_main_db(self) -> None:
        """Delete all rows from every main database table.

        Raises:
            DatabaseClearError: If deleting from any table fails.
        """
        async with self.main_db.get_db() as session:
            async with session.begin():
                for table_name, model_class in MAIN_DB_MODELS.items():
                    try:
                        await session.execute(delete(model_class))
                        logger.debug(f"已清空表 {table_name}")
                    except Exception as e:
                        raise DatabaseClearError(
                            f"清空表 {table_name} 失败: {e}"
                        ) from e

    async def _clear_kb_data(self) -> None:
        """Delete knowledge base metadata rows and on-disk directories.

        Best effort: individual failures are logged as warnings and do not
        stop the clean-up. Does nothing when no manager is configured.
        """
        if not self.kb_manager:
            return

        # Clear the metadata tables.
        async with self.kb_manager.kb_db.get_db() as session:
            async with session.begin():
                for table_name, model_class in KB_METADATA_MODELS.items():
                    try:
                        await session.execute(delete(model_class))
                        logger.debug(f"已清空知识库表 {table_name}")
                    except Exception as e:
                        logger.warning(f"清空知识库表 {table_name} 失败: {e}")

        # Terminate each loaded knowledge base and remove its directory.
        # Iterate over a snapshot; the mapping is cleared afterwards.
        for kb_id, kb_helper in list(self.kb_manager.kb_insts.items()):
            try:
                await kb_helper.terminate()
                if kb_helper.kb_dir.exists():
                    shutil.rmtree(kb_helper.kb_dir)
            except Exception as e:
                logger.warning(f"清理知识库 {kb_id} 失败: {e}")

        self.kb_manager.kb_insts.clear()

    # ------------------------------------------------------------------
    # Main database import
    # ------------------------------------------------------------------

    async def _import_main_database(
        self, data: dict[str, list[dict]]
    ) -> dict[str, int]:
        """Insert the rows of every known table in one transaction.

        Args:
            data: Mapping of table name to its list of row dicts.

        Returns:
            Mapping of table name to the number of rows added to the session.
            Unknown tables are skipped; rows that fail to build are logged
            and skipped.
        """
        imported: dict[str, int] = {}

        async with self.main_db.get_db() as session:
            async with session.begin():
                for table_name, rows in data.items():
                    model_class = MAIN_DB_MODELS.get(table_name)
                    if not model_class:
                        logger.warning(f"未知的表: {table_name}")
                        continue
                    normalized_rows = self._preprocess_main_table_rows(table_name, rows)

                    count = 0
                    for row in normalized_rows:
                        try:
                            # Turn ISO datetime strings into datetime objects.
                            row = self._convert_datetime_fields(row, model_class)
                            session.add(model_class(**row))
                            count += 1
                        except Exception as e:
                            logger.warning(f"导入记录到 {table_name} 失败: {e}")

                    imported[table_name] = count
                    logger.debug(f"导入表 {table_name}: {count} 条记录")

        return imported

    def _preprocess_main_table_rows(
        self, table_name: str, rows: list[dict[str, Any]]
    ) -> list[dict[str, Any]]:
        """Apply table-specific clean-up before rows are inserted.

        Only ``platform_stats`` is processed (duplicate keys are merged);
        rows of every other table are returned unchanged.
        """
        if table_name != "platform_stats":
            return rows

        normalized_rows = self._merge_platform_stats_rows(rows)
        duplicate_count = len(rows) - len(normalized_rows)
        if duplicate_count > 0:
            logger.warning(
                "检测到 %s 重复键 %d 条,已在导入前聚合",
                table_name,
                duplicate_count,
            )
        return normalized_rows

    def _merge_platform_stats_rows(
        self, rows: list[dict[str, Any]]
    ) -> list[dict[str, Any]]:
        """Merge duplicate platform_stats rows by normalized timestamp/platform key.

        Note:
            - Invalid/empty timestamps are kept as distinct rows to avoid
              accidental merging.
            - Non-string platform_id/platform_type are kept as distinct rows.
            - Invalid count warnings are rate-limited per function invocation.
            - Input rows are not mutated; the first row of each key carries
              the summed count and keeps its original position.
        """
        merged: dict[tuple[str, str, str], dict[str, Any]] = {}
        result: list[dict[str, Any]] = []
        warn_limiter = _InvalidCountWarnLimiter(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT)

        for row in rows:
            normalized_row, normalized_timestamp, count = (
                self._normalize_platform_stats_entry(row, warn_limiter)
            )
            platform_id = normalized_row.get("platform_id")
            platform_type = normalized_row.get("platform_type")

            if (
                normalized_timestamp is None
                or not isinstance(platform_id, str)
                or not isinstance(platform_type, str)
            ):
                result.append(normalized_row)
                continue

            merge_key = (normalized_timestamp, platform_id, platform_type)
            existing = merged.get(merge_key)
            if existing is None:
                merged[merge_key] = normalized_row
                result.append(normalized_row)
            else:
                existing["count"] += count

        return result

    def _normalize_platform_stats_entry(
        self,
        row: dict[str, Any],
        warn_limiter: _InvalidCountWarnLimiter,
    ) -> tuple[dict[str, Any], str | None, int]:
        """Normalize the timestamp and count of one platform_stats row.

        Args:
            row: Raw row; it is copied, not modified.
            warn_limiter: Limiter used to report an invalid count.

        Returns:
            ``(normalized_row, normalized_timestamp, count)``.
            ``normalized_timestamp`` is None when the timestamp cannot be
            parsed; the row's ``timestamp`` then holds the stripped string,
            "" for None, or ``str(value)``. ``count`` is 0 when the raw
            count is not convertible to int.
        """
        normalized_row = dict(row)
        raw_timestamp = normalized_row.get("timestamp")
        normalized_timestamp = self._normalize_platform_stats_timestamp(raw_timestamp)

        if normalized_timestamp is not None:
            normalized_row["timestamp"] = normalized_timestamp
        elif isinstance(raw_timestamp, str):
            normalized_row["timestamp"] = raw_timestamp.strip()
        elif raw_timestamp is None:
            normalized_row["timestamp"] = ""
        else:
            normalized_row["timestamp"] = str(raw_timestamp)

        raw_count = normalized_row.get("count", 0)
        try:
            count = int(raw_count)
        except (TypeError, ValueError, OverflowError):
            # OverflowError covers float infinity, which json.loads accepts.
            key_for_log = (
                normalized_row.get("timestamp"),
                repr(normalized_row.get("platform_id")),
                repr(normalized_row.get("platform_type")),
            )
            warn_limiter.warn_invalid_count(raw_count, key_for_log)
            count = 0

        normalized_row["count"] = count
        return normalized_row, normalized_timestamp, count

    def _normalize_platform_stats_timestamp(self, value: Any) -> str | None:
        """Return ``value`` as a UTC ISO-8601 string, or None if unusable.

        Accepts datetime objects and ISO-8601 strings (a trailing "Z" is
        understood as UTC). Naive values are interpreted as UTC; aware
        values are converted to UTC. Empty strings, unparsable strings and
        any other type yield None.
        """
        if isinstance(value, datetime):
            dt = value
        elif isinstance(value, str):
            text = value.strip()
            if not text:
                return None
            if text.endswith("Z"):
                text = f"{text[:-1]}+00:00"
            try:
                dt = datetime.fromisoformat(text)
            except ValueError:
                return None
        else:
            return None

        try:
            if dt.tzinfo is None:
                dt = dt.replace(tzinfo=timezone.utc)
            else:
                dt = dt.astimezone(timezone.utc)
        except (ValueError, OverflowError):
            # Conversion can leave the representable datetime range.
            return None
        return dt.isoformat()

    # ------------------------------------------------------------------
    # Knowledge base import
    # ------------------------------------------------------------------

    async def _import_knowledge_bases(
        self,
        zf: zipfile.ZipFile,
        kb_meta_data: dict[str, list[dict]],
        result: ImportResult,
    ) -> None:
        """Restore knowledge base metadata, documents, indexes and media.

        Args:
            zf: Open backup archive.
            kb_meta_data: Mapping of metadata table name to row dicts.
            result: Receives per-table counts (keys prefixed "kb_") and
                warnings for items that could not be restored.
        """
        if not self.kb_manager:
            return

        # 1. Import the metadata tables.
        async with self.kb_manager.kb_db.get_db() as session:
            async with session.begin():
                for table_name, rows in kb_meta_data.items():
                    model_class = KB_METADATA_MODELS.get(table_name)
                    if not model_class:
                        continue

                    count = 0
                    for row in rows:
                        try:
                            row = self._convert_datetime_fields(row, model_class)
                            session.add(model_class(**row))
                            count += 1
                        except Exception as e:
                            logger.warning(f"导入知识库记录到 {table_name} 失败: {e}")

                    result.imported_tables[f"kb_{table_name}"] = count

        # 2. Import the documents and files of each knowledge base.
        archive_names = zf.namelist()
        archive_name_set = set(archive_names)
        kb_root = Path(self.kb_root_dir)

        for kb_data in kb_meta_data.get("knowledge_bases", []):
            kb_id = kb_data.get("kb_id")
            if not kb_id:
                continue

            # kb_id comes from the backup; it must not point outside the root.
            try:
                kb_dir = self._resolve_inside(kb_root, kb_id)
            except (TypeError, ValueError) as e:
                result.add_warning(f"导入知识库 {kb_id} 失败: {e}")
                continue
            kb_dir.mkdir(parents=True, exist_ok=True)

            # Document data.
            doc_path = f"databases/kb_{kb_id}/documents.json"
            if doc_path in archive_name_set:
                try:
                    doc_data = json.loads(zf.read(doc_path))
                    await self._import_kb_documents(kb_id, doc_data)
                except Exception as e:
                    result.add_warning(f"导入知识库 {kb_id} 的文档失败: {e}")

            # FAISS index.
            faiss_path = f"databases/kb_{kb_id}/index.faiss"
            if faiss_path in archive_name_set:
                try:
                    self._extract_member(zf, faiss_path, kb_dir / "index.faiss")
                except Exception as e:
                    result.add_warning(f"导入知识库 {kb_id} 的 FAISS 索引失败: {e}")

            # Media files.
            media_prefix = f"files/kb_media/{kb_id}/"
            for name in archive_names:
                # Directory entries end with "/" and carry no file content.
                if not name.startswith(media_prefix) or name.endswith("/"):
                    continue
                try:
                    target_path = self._resolve_inside(kb_dir, name[len(media_prefix) :])
                    target_path.parent.mkdir(parents=True, exist_ok=True)
                    self._extract_member(zf, name, target_path)
                except Exception as e:
                    result.add_warning(f"导入媒体文件 {name} 失败: {e}")

        # 3. Reload the knowledge base instances.
        await self.kb_manager.load_kbs()

    async def _import_kb_documents(self, kb_id: str, doc_data: dict) -> None:
        """Insert the documents of one knowledge base into its document store.

        Args:
            kb_id: Knowledge base identifier; the store is ``doc.db`` inside
                that knowledge base's directory.
            doc_data: Dict whose "documents" list holds entries with
                "doc_id", "text" and "metadata" (a JSON string or a dict).
        """
        from astrbot.core.db.vec_db.faiss_impl.document_storage import DocumentStorage

        kb_dir = Path(self.kb_root_dir) / kb_id
        doc_db_path = kb_dir / "doc.db"

        doc_storage = DocumentStorage(str(doc_db_path))
        await doc_storage.initialize()

        try:
            for doc in doc_data.get("documents", []):
                try:
                    raw_metadata = doc.get("metadata", "{}")
                    # Accept metadata that is already decoded.
                    metadata = (
                        raw_metadata
                        if isinstance(raw_metadata, dict)
                        else json.loads(raw_metadata)
                    )
                    await doc_storage.insert_document(
                        doc_id=doc.get("doc_id", ""),
                        text=doc.get("text", ""),
                        metadata=metadata,
                    )
                except Exception as e:
                    logger.warning(f"导入文档块失败: {e}")
        finally:
            await doc_storage.close()

    # ------------------------------------------------------------------
    # Attachments and directories
    # ------------------------------------------------------------------

    async def _import_attachments(
        self,
        zf: zipfile.ZipFile,
        attachments: list[dict],
    ) -> int:
        """Extract attachment files from the archive.

        Each member under ``files/attachments/`` is matched to an attachment
        record by its file name without extension. A matched file is written
        to the record's ``path``; otherwise it goes into the ``attachments``
        directory next to the configuration file.

        Args:
            zf: Open backup archive.
            attachments: Attachment records from the main database dump.

        Returns:
            Number of files written.
        """
        count = 0

        attachments_dir = Path(self.config_path).parent / "attachments"
        attachments_dir.mkdir(parents=True, exist_ok=True)

        # Index records once; the first record for an id wins.
        path_by_id: dict[str, Any] = {}
        for att in attachments:
            if not isinstance(att, dict):
                continue
            att_id = att.get("attachment_id")
            if isinstance(att_id, str):
                path_by_id.setdefault(att_id, att.get("path"))

        attachment_prefix = "files/attachments/"
        for name in zf.namelist():
            # Skip unrelated members and directory entries.
            if not name.startswith(attachment_prefix) or name.endswith("/"):
                continue
            try:
                file_name = os.path.basename(name)
                attachment_id = os.path.splitext(file_name)[0]
                original_path = path_by_id.get(attachment_id)

                # NOTE(review): a recorded path is restored verbatim and may
                # point anywhere on disk; this trusts the backup's contents.
                if original_path:
                    target_path = Path(original_path)
                else:
                    target_path = attachments_dir / file_name

                target_path.parent.mkdir(parents=True, exist_ok=True)
                self._extract_member(zf, name, target_path)
                count += 1
            except Exception as e:
                logger.warning(f"导入附件 {name} 失败: {e}")

        return count

    async def _import_directories(
        self,
        zf: zipfile.ZipFile,
        manifest: dict,
        result: ImportResult,
    ) -> dict[str, int]:
        """Restore plugin and other data directories.

        For each directory listed in the manifest that has files in the
        archive, the existing directory (if any) is moved to ``<dir>.bak``
        and the archived files are extracted in its place.

        Args:
            zf: Open backup archive.
            manifest: Backup manifest.
            result: Receives warnings for anything that could not be restored.

        Returns:
            dict: Number of files written per restored directory.
        """
        dir_stats: dict[str, int] = {}

        # Directory backups require manifest format version >= 1.1.
        backup_version = manifest.get("version", "1.0")
        if VersionComparator.compare_version(backup_version, "1.1") < 0:
            logger.info("备份版本不支持目录备份,跳过目录导入")
            return dir_stats

        backed_up_dirs = manifest.get("directories", [])
        backup_directories = get_backup_directories()
        archive_names = zf.namelist()

        for dir_name in backed_up_dirs:
            if dir_name not in backup_directories:
                result.add_warning(f"未知的目录类型: {dir_name}")
                continue

            target_dir = Path(backup_directories[dir_name])
            archive_prefix = f"directories/{dir_name}/"
            file_count = 0

            try:
                # File members of this directory; entries ending in "/" are
                # directories and are recreated implicitly via mkdir below.
                dir_files = [
                    name
                    for name in archive_names
                    if name.startswith(archive_prefix) and not name.endswith("/")
                ]
                if not dir_files:
                    continue

                # Move the existing directory aside, replacing an older copy.
                if target_dir.exists():
                    backup_path = Path(f"{target_dir}.bak")
                    if backup_path.exists():
                        shutil.rmtree(backup_path)
                    shutil.move(str(target_dir), str(backup_path))
                    logger.debug(f"已备份现有目录 {target_dir} 到 {backup_path}")

                target_dir.mkdir(parents=True, exist_ok=True)

                for name in dir_files:
                    try:
                        target_path = self._resolve_inside(
                            target_dir, name[len(archive_prefix) :]
                        )
                        target_path.parent.mkdir(parents=True, exist_ok=True)
                        self._extract_member(zf, name, target_path)
                        file_count += 1
                    except Exception as e:
                        result.add_warning(f"导入文件 {name} 失败: {e}")

                dir_stats[dir_name] = file_count
                logger.debug(f"导入目录 {dir_name}: {file_count} 个文件")

            except Exception as e:
                result.add_warning(f"导入目录 {dir_name} 失败: {e}")
                dir_stats[dir_name] = 0

        return dir_stats

    # ------------------------------------------------------------------
    # Row conversion
    # ------------------------------------------------------------------

    def _convert_datetime_fields(self, row: dict, model_class: type) -> dict:
        """Convert ISO datetime strings in ``row`` to datetime objects.

        Only string values of columns whose SQLAlchemy type is ``DateTime``
        are converted. Each column is handled independently: a value that
        cannot be parsed is left unchanged and logged, and does not prevent
        other columns from being converted.

        Args:
            row: Row dict; it is copied, not modified.
            model_class: Model class whose columns are inspected.

        Returns:
            A shallow copy of ``row`` with convertible values replaced.
        """
        from sqlalchemy import DateTime
        from sqlalchemy import inspect as sa_inspect

        result = row.copy()

        try:
            columns = list(sa_inspect(model_class).columns)
        except Exception as e:
            # Not inspectable: return the row unconverted.
            logger.debug(f"无法获取 {model_class} 的列信息: {e}")
            return result

        for column in columns:
            value = result.get(column.name)
            if not isinstance(value, str) or not isinstance(column.type, DateTime):
                continue

            text = value.strip()
            # datetime.fromisoformat rejects a trailing "Z" before Python 3.11.
            if text.endswith("Z"):
                text = f"{text[:-1]}+00:00"
            try:
                result[column.name] = datetime.fromisoformat(text)
            except ValueError as e:
                logger.warning(f"无法解析 {column.name} 的日期时间值 {value!r}: {e}")

        return result
|