fix: harden backup import for duplicate platform stats (#5594)
* fix: harden backup import for duplicate platform stats - 修复 replace 模式下主库清空失败仍继续导入的问题。 - 导入前对 platform_stats 重复键做聚合(count 累加),并统一时间戳判重格式。 - 非法 count 按 0 处理并告警(限流),补充对应测试。 * refactor: improve robustness and readability of platform stats import - 告警上限魔法数字提取为模块常量 PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT - 抽取 parse_count 内联函数,消除重复的 try/except 分支 - 存储行的 timestamp 同步写入规范化值,避免落库格式混用 - 补充测试:已有行 count 非法、告警限流、replace 模式中断断言 * fix: normalize invalid platform_stats count for non-duplicate rows * fix: avoid merging invalid platform_stats timestamps * refactor: simplify platform stats merge and normalize naive UTC * refactor: inline platform stats merge helpers * refactor: flatten platform stats merge flow * refactor: harden platform stats merge key handling * refactor: streamline platform stats preprocessing * refactor: simplify platform stats merge helpers * refactor: inline platform stats merge normalization * refactor: extract platform stats merge helpers * refactor: simplify platform stats preprocessing flow * refactor: flatten platform stats preprocess helpers * refactor: streamline platform stats merge helpers * refactor: isolate platform stats warning limiter --------- Co-authored-by: 邹永赫 <1259085392@qq.com>
This commit is contained in:
@@ -12,7 +12,7 @@ import os
|
||||
import shutil
|
||||
import zipfile
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
@@ -61,6 +61,69 @@ def _get_major_version(version_str: str) -> str:
|
||||
|
||||
CMD_CONFIG_FILE_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
|
||||
KB_PATH = get_astrbot_knowledge_base_path()
|
||||
DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT = 5
|
||||
PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV = (
|
||||
"ASTRBOT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT"
|
||||
)
|
||||
|
||||
|
||||
def _load_platform_stats_invalid_count_warn_limit() -> int:
|
||||
raw_value = os.getenv(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV)
|
||||
if raw_value is None:
|
||||
return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT
|
||||
|
||||
try:
|
||||
value = int(raw_value)
|
||||
if value < 0:
|
||||
raise ValueError("negative")
|
||||
return value
|
||||
except (TypeError, ValueError):
|
||||
logger.warning(
|
||||
"Invalid env %s=%r, fallback to default %d",
|
||||
PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV,
|
||||
raw_value,
|
||||
DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT,
|
||||
)
|
||||
return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT
|
||||
|
||||
|
||||
PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT = (
|
||||
_load_platform_stats_invalid_count_warn_limit()
|
||||
)
|
||||
|
||||
|
||||
class _InvalidCountWarnLimiter:
|
||||
"""Rate-limit warnings for invalid platform_stats count values."""
|
||||
|
||||
def __init__(self, limit: int) -> None:
|
||||
self.limit = limit
|
||||
self._count = 0
|
||||
self._suppression_logged = False
|
||||
|
||||
def warn_invalid_count(self, value: Any, key_for_log: tuple[Any, ...]) -> None:
|
||||
if self.limit > 0:
|
||||
if self._count < self.limit:
|
||||
logger.warning(
|
||||
"platform_stats count 非法,已按 0 处理: value=%r, key=%s",
|
||||
value,
|
||||
key_for_log,
|
||||
)
|
||||
self._count += 1
|
||||
if self._count == self.limit and not self._suppression_logged:
|
||||
logger.warning(
|
||||
"platform_stats 非法 count 告警已达到上限 (%d),后续将抑制",
|
||||
self.limit,
|
||||
)
|
||||
self._suppression_logged = True
|
||||
return
|
||||
|
||||
if not self._suppression_logged:
|
||||
# limit <= 0: emit only one suppression warning.
|
||||
logger.warning(
|
||||
"platform_stats 非法 count 告警已达到上限 (%d),后续将抑制",
|
||||
self.limit,
|
||||
)
|
||||
self._suppression_logged = True
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -138,6 +201,10 @@ class ImportResult:
|
||||
}
|
||||
|
||||
|
||||
class DatabaseClearError(RuntimeError):
|
||||
"""Raised when clearing the main database in replace mode fails."""
|
||||
|
||||
|
||||
class AstrBotImporter:
|
||||
"""AstrBot 数据导入器
|
||||
|
||||
@@ -342,6 +409,9 @@ class AstrBotImporter:
|
||||
|
||||
imported = await self._import_main_database(main_data)
|
||||
result.imported_tables.update(imported)
|
||||
except DatabaseClearError as e:
|
||||
result.add_error(f"清空主数据库失败: {e}")
|
||||
return result
|
||||
except Exception as e:
|
||||
result.add_error(f"导入主数据库失败: {e}")
|
||||
return result
|
||||
@@ -452,7 +522,9 @@ class AstrBotImporter:
|
||||
await session.execute(delete(model_class))
|
||||
logger.debug(f"已清空表 {table_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"清空表 {table_name} 失败: {e}")
|
||||
raise DatabaseClearError(
|
||||
f"清空表 {table_name} 失败: {e}"
|
||||
) from e
|
||||
|
||||
async def _clear_kb_data(self) -> None:
|
||||
"""清空知识库数据"""
|
||||
@@ -494,9 +566,10 @@ class AstrBotImporter:
|
||||
if not model_class:
|
||||
logger.warning(f"未知的表: {table_name}")
|
||||
continue
|
||||
normalized_rows = self._preprocess_main_table_rows(table_name, rows)
|
||||
|
||||
count = 0
|
||||
for row in rows:
|
||||
for row in normalized_rows:
|
||||
try:
|
||||
# 转换 datetime 字符串为 datetime 对象
|
||||
row = self._convert_datetime_fields(row, model_class)
|
||||
@@ -511,6 +584,118 @@ class AstrBotImporter:
|
||||
|
||||
return imported
|
||||
|
||||
def _preprocess_main_table_rows(
|
||||
self, table_name: str, rows: list[dict[str, Any]]
|
||||
) -> list[dict[str, Any]]:
|
||||
if table_name == "platform_stats":
|
||||
normalized_rows = self._merge_platform_stats_rows(rows)
|
||||
duplicate_count = len(rows) - len(normalized_rows)
|
||||
if duplicate_count > 0:
|
||||
logger.warning(
|
||||
"检测到 %s 重复键 %d 条,已在导入前聚合",
|
||||
table_name,
|
||||
duplicate_count,
|
||||
)
|
||||
return normalized_rows
|
||||
return rows
|
||||
|
||||
def _merge_platform_stats_rows(
|
||||
self, rows: list[dict[str, Any]]
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Merge duplicate platform_stats rows by normalized timestamp/platform key.
|
||||
|
||||
Note:
|
||||
- Invalid/empty timestamps are kept as distinct rows to avoid accidental merging.
|
||||
- Non-string platform_id/platform_type are kept as distinct rows.
|
||||
- Invalid count warnings are rate-limited per function invocation.
|
||||
"""
|
||||
merged: dict[tuple[str, str, str], dict[str, Any]] = {}
|
||||
result: list[dict[str, Any]] = []
|
||||
warn_limiter = _InvalidCountWarnLimiter(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT)
|
||||
|
||||
for row in rows:
|
||||
normalized_row, normalized_timestamp, count = (
|
||||
self._normalize_platform_stats_entry(row, warn_limiter)
|
||||
)
|
||||
platform_id = normalized_row.get("platform_id")
|
||||
platform_type = normalized_row.get("platform_type")
|
||||
|
||||
if (
|
||||
normalized_timestamp is None
|
||||
or not isinstance(platform_id, str)
|
||||
or not isinstance(platform_type, str)
|
||||
):
|
||||
result.append(normalized_row)
|
||||
continue
|
||||
|
||||
merge_key = (normalized_timestamp, platform_id, platform_type)
|
||||
existing = merged.get(merge_key)
|
||||
if existing is None:
|
||||
merged[merge_key] = normalized_row
|
||||
result.append(normalized_row)
|
||||
else:
|
||||
existing["count"] += count
|
||||
|
||||
return result
|
||||
|
||||
def _normalize_platform_stats_entry(
|
||||
self,
|
||||
row: dict[str, Any],
|
||||
warn_limiter: _InvalidCountWarnLimiter,
|
||||
) -> tuple[dict[str, Any], str | None, int]:
|
||||
normalized_row = dict(row)
|
||||
raw_timestamp = normalized_row.get("timestamp")
|
||||
normalized_timestamp = self._normalize_platform_stats_timestamp(raw_timestamp)
|
||||
|
||||
if normalized_timestamp is not None:
|
||||
normalized_row["timestamp"] = normalized_timestamp
|
||||
elif isinstance(raw_timestamp, str):
|
||||
normalized_row["timestamp"] = raw_timestamp.strip()
|
||||
elif raw_timestamp is None:
|
||||
normalized_row["timestamp"] = ""
|
||||
else:
|
||||
normalized_row["timestamp"] = str(raw_timestamp)
|
||||
|
||||
raw_count = normalized_row.get("count", 0)
|
||||
try:
|
||||
count = int(raw_count)
|
||||
except (TypeError, ValueError):
|
||||
key_for_log = (
|
||||
normalized_row.get("timestamp"),
|
||||
repr(normalized_row.get("platform_id")),
|
||||
repr(normalized_row.get("platform_type")),
|
||||
)
|
||||
warn_limiter.warn_invalid_count(raw_count, key_for_log)
|
||||
count = 0
|
||||
|
||||
normalized_row["count"] = count
|
||||
return normalized_row, normalized_timestamp, count
|
||||
|
||||
def _normalize_platform_stats_timestamp(self, value: Any) -> str | None:
|
||||
if isinstance(value, datetime):
|
||||
dt = value
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
dt = dt.astimezone(timezone.utc)
|
||||
return dt.isoformat()
|
||||
if isinstance(value, str):
|
||||
timestamp = value.strip()
|
||||
if not timestamp:
|
||||
return None
|
||||
if timestamp.endswith("Z"):
|
||||
timestamp = f"{timestamp[:-1]}+00:00"
|
||||
try:
|
||||
dt = datetime.fromisoformat(timestamp)
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
dt = dt.astimezone(timezone.utc)
|
||||
return dt.isoformat()
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
|
||||
async def _import_knowledge_bases(
|
||||
self,
|
||||
zf: zipfile.ZipFile,
|
||||
|
||||
+324
-1
@@ -5,7 +5,7 @@ import os
|
||||
import re
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -17,6 +17,8 @@ from astrbot.core.backup import (
|
||||
)
|
||||
from astrbot.core.backup.exporter import AstrBotExporter
|
||||
from astrbot.core.backup.importer import (
|
||||
DatabaseClearError,
|
||||
PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT,
|
||||
AstrBotImporter,
|
||||
ImportResult,
|
||||
_get_major_version,
|
||||
@@ -308,6 +310,298 @@ class TestAstrBotImporter:
|
||||
assert isinstance(result["created_at"], datetime)
|
||||
assert isinstance(result["updated_at"], datetime)
|
||||
|
||||
def test_merge_platform_stats_rows(self):
|
||||
"""测试 platform_stats 重复键会在导入前聚合"""
|
||||
importer = AstrBotImporter(main_db=MagicMock())
|
||||
rows = [
|
||||
{
|
||||
"id": 1,
|
||||
"timestamp": "2025-12-13T20:00:00Z",
|
||||
"platform_id": "webchat",
|
||||
"platform_type": "unknown",
|
||||
"count": 14,
|
||||
},
|
||||
{
|
||||
"id": 80,
|
||||
"timestamp": "2025-12-13T20:00:00+00:00",
|
||||
"platform_id": "webchat",
|
||||
"platform_type": "unknown",
|
||||
"count": 3,
|
||||
},
|
||||
{
|
||||
"id": 81,
|
||||
"timestamp": "2025-12-13T20:00:00",
|
||||
"platform_id": "webchat",
|
||||
"platform_type": "unknown",
|
||||
"count": 2,
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"timestamp": "2025-12-13T21:00:00",
|
||||
"platform_id": "aiocqhttp",
|
||||
"platform_type": "unknown",
|
||||
"count": 1,
|
||||
},
|
||||
]
|
||||
|
||||
merged_rows = importer._merge_platform_stats_rows(rows)
|
||||
duplicate_count = len(rows) - len(merged_rows)
|
||||
|
||||
assert duplicate_count == 2
|
||||
assert len(merged_rows) == 2
|
||||
webchat_row = next(
|
||||
(
|
||||
r
|
||||
for r in merged_rows
|
||||
if r.get("timestamp") == "2025-12-13T20:00:00+00:00"
|
||||
and r.get("platform_id") == "webchat"
|
||||
and r.get("platform_type") == "unknown"
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert webchat_row is not None
|
||||
assert webchat_row["timestamp"] == "2025-12-13T20:00:00+00:00"
|
||||
assert webchat_row["platform_id"] == "webchat"
|
||||
assert webchat_row["platform_type"] == "unknown"
|
||||
assert webchat_row["count"] == 19
|
||||
|
||||
aiocq_row = next(
|
||||
(
|
||||
r
|
||||
for r in merged_rows
|
||||
if r.get("platform_id") == "aiocqhttp"
|
||||
and r.get("platform_type") == "unknown"
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert aiocq_row is not None
|
||||
assert aiocq_row["timestamp"] == "2025-12-13T21:00:00+00:00"
|
||||
|
||||
def test_merge_platform_stats_rows_normalizes_naive_timestamp_to_utc(self):
|
||||
"""测试 platform_stats 合并前会将 naive timestamp 标准化为 UTC 偏移"""
|
||||
importer = AstrBotImporter(main_db=MagicMock())
|
||||
|
||||
rows = [
|
||||
{
|
||||
"timestamp": "2025-12-13T21:00:00",
|
||||
"platform_id": "webchat",
|
||||
"platform_type": "unknown",
|
||||
"count": 1,
|
||||
},
|
||||
{
|
||||
"timestamp": datetime(2025, 12, 13, 22, 0, 0),
|
||||
"platform_id": "telegram",
|
||||
"platform_type": "unknown",
|
||||
"count": 1,
|
||||
},
|
||||
]
|
||||
|
||||
merged_rows = importer._merge_platform_stats_rows(rows)
|
||||
assert len(merged_rows) == 2
|
||||
by_platform = {row["platform_id"]: row for row in merged_rows}
|
||||
assert by_platform["webchat"]["timestamp"] == "2025-12-13T21:00:00+00:00"
|
||||
assert by_platform["telegram"]["timestamp"] == "2025-12-13T22:00:00+00:00"
|
||||
|
||||
def test_merge_platform_stats_rows_warns_on_invalid_count(self):
|
||||
"""测试 platform_stats count 非法时会告警并按 0 处理(含上限)"""
|
||||
importer = AstrBotImporter(main_db=MagicMock())
|
||||
with patch("astrbot.core.backup.importer.logger.warning") as warning_mock:
|
||||
rows = [
|
||||
{
|
||||
"timestamp": "2025-12-13T20:00:00+00:00",
|
||||
"platform_id": "webchat",
|
||||
"platform_type": "unknown",
|
||||
"count": 5,
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-12-13T20:00:00Z",
|
||||
"platform_id": "webchat",
|
||||
"platform_type": "unknown",
|
||||
"count": "bad-count",
|
||||
},
|
||||
]
|
||||
merged_rows = importer._merge_platform_stats_rows(rows)
|
||||
duplicate_count = len(rows) - len(merged_rows)
|
||||
assert duplicate_count == 1
|
||||
assert len(merged_rows) == 1
|
||||
assert merged_rows[0]["count"] == 5
|
||||
assert warning_mock.call_count == 1
|
||||
|
||||
warning_mock.reset_mock()
|
||||
|
||||
rows_existing_invalid = [
|
||||
{
|
||||
"timestamp": "2025-12-13T21:00:00+00:00",
|
||||
"platform_id": "webchat",
|
||||
"platform_type": "unknown",
|
||||
"count": "bad-count",
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-12-13T21:00:00Z",
|
||||
"platform_id": "webchat",
|
||||
"platform_type": "unknown",
|
||||
"count": 7,
|
||||
},
|
||||
]
|
||||
merged_rows = importer._merge_platform_stats_rows(rows_existing_invalid)
|
||||
duplicate_count = len(rows_existing_invalid) - len(merged_rows)
|
||||
assert duplicate_count == 1
|
||||
assert len(merged_rows) == 1
|
||||
assert merged_rows[0]["count"] == 7
|
||||
assert warning_mock.call_count == 1
|
||||
|
||||
warning_mock.reset_mock()
|
||||
|
||||
many_invalid_rows = [
|
||||
{
|
||||
"timestamp": "2025-12-13T22:00:00+00:00",
|
||||
"platform_id": "webchat",
|
||||
"platform_type": "unknown",
|
||||
"count": 1,
|
||||
},
|
||||
*[
|
||||
{
|
||||
"timestamp": "2025-12-13T22:00:00Z",
|
||||
"platform_id": "webchat",
|
||||
"platform_type": "unknown",
|
||||
"count": "bad-count",
|
||||
}
|
||||
for _ in range(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT + 5)
|
||||
],
|
||||
]
|
||||
importer._merge_platform_stats_rows(many_invalid_rows)
|
||||
assert (
|
||||
warning_mock.call_count == PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT + 1
|
||||
)
|
||||
assert any(
|
||||
"告警已达到上限" in str(call.args[0])
|
||||
for call in warning_mock.call_args_list
|
||||
)
|
||||
|
||||
warning_mock.reset_mock()
|
||||
|
||||
single_invalid_row = [
|
||||
{
|
||||
"timestamp": "2025-12-13T23:00:00+00:00",
|
||||
"platform_id": "telegram",
|
||||
"platform_type": "unknown",
|
||||
"count": "still-bad",
|
||||
},
|
||||
]
|
||||
merged_rows = importer._merge_platform_stats_rows(single_invalid_row)
|
||||
duplicate_count = len(single_invalid_row) - len(merged_rows)
|
||||
assert duplicate_count == 0
|
||||
assert len(merged_rows) == 1
|
||||
assert merged_rows[0]["count"] == 0
|
||||
assert warning_mock.call_count == 1
|
||||
|
||||
def test_merge_platform_stats_rows_keeps_invalid_timestamps_distinct(self):
|
||||
"""测试空/非法 timestamp 不参与聚合,避免误合并"""
|
||||
importer = AstrBotImporter(main_db=MagicMock())
|
||||
rows = [
|
||||
{
|
||||
"timestamp": "",
|
||||
"platform_id": "webchat",
|
||||
"platform_type": "unknown",
|
||||
"count": 2,
|
||||
},
|
||||
{
|
||||
"timestamp": "not-a-datetime",
|
||||
"platform_id": "webchat",
|
||||
"platform_type": "unknown",
|
||||
"count": 3,
|
||||
},
|
||||
{
|
||||
"timestamp": "not-a-datetime",
|
||||
"platform_id": "webchat",
|
||||
"platform_type": "unknown",
|
||||
"count": 4,
|
||||
},
|
||||
]
|
||||
|
||||
merged_rows = importer._merge_platform_stats_rows(rows)
|
||||
duplicate_count = len(rows) - len(merged_rows)
|
||||
|
||||
assert duplicate_count == 0
|
||||
assert len(merged_rows) == 3
|
||||
assert [row["count"] for row in merged_rows] == [2, 3, 4]
|
||||
|
||||
def test_merge_platform_stats_rows_keeps_non_string_platform_keys_distinct(self):
|
||||
"""测试非字符串 platform_id/platform_type 不参与聚合"""
|
||||
importer = AstrBotImporter(main_db=MagicMock())
|
||||
rows = [
|
||||
{
|
||||
"timestamp": "2025-12-13T20:00:00+00:00",
|
||||
"platform_id": None,
|
||||
"platform_type": "unknown",
|
||||
"count": 2,
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-12-13T20:00:00Z",
|
||||
"platform_id": None,
|
||||
"platform_type": "unknown",
|
||||
"count": 3,
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-12-13T20:00:00+00:00",
|
||||
"platform_id": "webchat",
|
||||
"platform_type": 1,
|
||||
"count": 4,
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-12-13T20:00:00Z",
|
||||
"platform_id": "webchat",
|
||||
"platform_type": 1,
|
||||
"count": 5,
|
||||
},
|
||||
]
|
||||
|
||||
merged_rows = importer._merge_platform_stats_rows(rows)
|
||||
duplicate_count = len(rows) - len(merged_rows)
|
||||
|
||||
assert duplicate_count == 0
|
||||
assert len(merged_rows) == 4
|
||||
|
||||
def test_merge_platform_stats_rows_preserves_input_order(self):
|
||||
"""测试 platform_stats 聚合后仍保持输入顺序(按首次出现位置)"""
|
||||
importer = AstrBotImporter(main_db=MagicMock())
|
||||
rows = [
|
||||
{
|
||||
"id": 1,
|
||||
"timestamp": "2025-12-13T20:00:00Z",
|
||||
"platform_id": "webchat",
|
||||
"platform_type": "unknown",
|
||||
"count": 2,
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"timestamp": "",
|
||||
"platform_id": "webchat",
|
||||
"platform_type": "unknown",
|
||||
"count": 3,
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"timestamp": "2025-12-13T20:00:00+00:00",
|
||||
"platform_id": "webchat",
|
||||
"platform_type": "unknown",
|
||||
"count": 5,
|
||||
},
|
||||
{
|
||||
"id": 4,
|
||||
"timestamp": "2025-12-13T21:00:00+00:00",
|
||||
"platform_id": "telegram",
|
||||
"platform_type": "unknown",
|
||||
"count": 7,
|
||||
},
|
||||
]
|
||||
|
||||
merged_rows = importer._merge_platform_stats_rows(rows)
|
||||
|
||||
assert len(merged_rows) == 3
|
||||
assert [row["id"] for row in merged_rows] == [1, 2, 4]
|
||||
assert merged_rows[0]["count"] == 7
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_import_file_not_exists(self, mock_main_db, tmp_path):
|
||||
"""测试导入不存在的文件"""
|
||||
@@ -365,6 +659,35 @@ class TestAstrBotImporter:
|
||||
assert result.success is False
|
||||
assert any("主版本不兼容" in err for err in result.errors)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_import_replace_fails_when_clear_main_db_fails(
|
||||
self, mock_main_db, tmp_path
|
||||
):
|
||||
"""测试 replace 模式下主库清空失败会直接终止导入"""
|
||||
zip_path = tmp_path / "valid_backup.zip"
|
||||
manifest = {
|
||||
"version": "1.1",
|
||||
"astrbot_version": VERSION,
|
||||
"tables": {"platform_stats": 0},
|
||||
}
|
||||
main_data = {"platform_stats": []}
|
||||
with zipfile.ZipFile(zip_path, "w") as zf:
|
||||
zf.writestr("manifest.json", json.dumps(manifest))
|
||||
zf.writestr("databases/main_db.json", json.dumps(main_data))
|
||||
|
||||
importer = AstrBotImporter(main_db=mock_main_db)
|
||||
importer._clear_main_db = AsyncMock(
|
||||
side_effect=DatabaseClearError("清空表 platform_stats 失败: db locked")
|
||||
)
|
||||
importer._import_main_database = AsyncMock(return_value={})
|
||||
|
||||
result = await importer.import_all(str(zip_path), mode="replace")
|
||||
|
||||
assert result.success is False
|
||||
assert any("清空主数据库失败" in err for err in result.errors)
|
||||
assert any("清空表 platform_stats 失败" in err for err in result.errors)
|
||||
importer._import_main_database.assert_not_awaited()
|
||||
|
||||
|
||||
class TestSecureFilename:
|
||||
"""安全文件名函数测试"""
|
||||
|
||||
Reference in New Issue
Block a user