Compare commits

...

4 Commits

Author SHA1 Message Date
copilot-swe-agent[bot] 2a7c8b44bf feat: switch Monaco editor from CDN to local deployment
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2026-03-01 15:15:03 +00:00
copilot-swe-agent[bot] b8e83b772d Initial plan 2026-03-01 15:09:13 +00:00
sanyekana 4abea2bd30 fix: harden backup import for duplicate platform stats (#5594)
* fix: harden backup import for duplicate platform stats

- 修复 replace 模式下主库清空失败仍继续导入的问题。
- 导入前对 platform_stats 重复键做聚合(count 累加),并统一时间戳判重格式。
- 非法 count 按 0 处理并告警(限流),补充对应测试。

* refactor: improve robustness and readability of platform stats import

- 告警上限魔法数字提取为模块常量 PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT
- 抽取 parse_count 内联函数,消除重复的 try/except 分支
- 存储行的 timestamp 同步写入规范化值,避免落库格式混用
- 补充测试:已有行 count 非法、告警限流、replace 模式中断断言

* fix: normalize invalid platform_stats count for non-duplicate rows

* fix: avoid merging invalid platform_stats timestamps

* refactor: simplify platform stats merge and normalize naive UTC

* refactor: inline platform stats merge helpers

* refactor: flatten platform stats merge flow

* refactor: harden platform stats merge key handling

* refactor: streamline platform stats preprocessing

* refactor: simplify platform stats merge helpers

* refactor: inline platform stats merge normalization

* refactor: extract platform stats merge helpers

* refactor: simplify platform stats preprocessing flow

* refactor: flatten platform stats preprocess helpers

* refactor: streamline platform stats merge helpers

* refactor: isolate platform stats warning limiter

---------

Co-authored-by: 邹永赫 <1259085392@qq.com>
2026-03-01 20:46:35 +09:00
pandyzhou 267abfd552 fix: resolve /model command misleading behavior when switching to model from different provider (#5578)
* fix: /model command now auto-switches provider when model exists elsewhere

Made-with: Cursor

* fix: address Sourcery review - log get_models() failures in cross-provider lookup

Made-with: Cursor

* fix: integer branch exception handling and API key masking in model command

Made-with: Cursor

* fix: harden cross-provider model resolution

* fix: improve model lookup resilience and cache hygiene

* refactor: simplify model switch lookup flow

* refactor: streamline provider model cache updates

* fix: align provider annotations and key error flow

* fix: narrow provider command exception handling

* refactor: harden provider command error redaction and flow

* fix: improve provider model lookup and secret redaction

* refactor: cache normalized model names in provider lookup

* refactor: simplify provider model lookup helpers

* refactor: extract provider model lookup helpers

* fix: harden provider lookup cancellation and redaction

* refactor: streamline provider cache and lookup settings

* refactor: simplify provider command setting and update helpers

* refactor: streamline provider model lookup config usage

* refactor: flatten provider lookup settings and filter model lookup providers

* refactor: simplify provider cache and callback flow

* refactor: simplify provider command model cache flow

* refactor: scope provider model cache by session

* fix: preserve redaction context and restore provider hooks

* refactor: unify provider model lookup config flow

* refactor: inline provider model cache access flow

* fix: align provider lookup cache and callback semantics

* refactor: centralize provider model fetch error handling

* refactor: simplify provider model cache and lookup flow

---------

Co-authored-by: 邹永赫 <1259085392@qq.com>
2026-03-01 19:11:31 +09:00
9 changed files with 1116 additions and 51 deletions
@@ -1,15 +1,262 @@
from __future__ import annotations
import asyncio
import re
import time
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING
from astrbot import logger
from astrbot.api import star
from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.core.provider.entities import ProviderType
from astrbot.core.utils.error_redaction import safe_error
if TYPE_CHECKING:
from astrbot.core.provider.provider import Provider
# Default lifetime, in seconds, of a cached provider model list.
MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT = 30.0
# Default number of providers queried concurrently during a cross-provider lookup.
MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT = 4
# Hard ceiling applied to the configured lookup concurrency.
MODEL_LOOKUP_MAX_CONCURRENCY_UPPER_BOUND = 16
# Keys looked up in the "provider_settings" section of the session config.
MODEL_LIST_CACHE_TTL_KEY = "model_list_cache_ttl_seconds"
MODEL_LOOKUP_MAX_CONCURRENCY_KEY = "model_lookup_max_concurrency"
# Maximum number of entries the model cache keeps before evicting the oldest.
MODEL_CACHE_MAX_ENTRIES = 512
@dataclass(frozen=True)
class _ModelLookupConfig:
    """Immutable bundle of the settings used for one model lookup."""

    # Session origin the lookup (and its cache entries) is scoped to, if any.
    umo: str | None
    # Lifetime of cached model lists in seconds; values <= 0 bypass the cache.
    cache_ttl_seconds: float
    # Size of the semaphore limiting concurrent provider queries.
    max_concurrency: int
class _ModelCache:
def __init__(self) -> None:
self._store: dict[tuple[str, str | None], tuple[float, list[str]]] = {}
def get(self, provider_id: str, umo: str | None, ttl: float) -> list[str] | None:
if ttl <= 0:
return None
entry = self._store.get((provider_id, umo))
if not entry:
return None
timestamp, models = entry
if time.monotonic() - timestamp > ttl:
self._store.pop((provider_id, umo), None)
return None
return models
def set(
self, provider_id: str, umo: str | None, models: list[str], ttl: float
) -> None:
if ttl <= 0:
return
self._store[(provider_id, umo)] = (time.monotonic(), list(models))
self._evict_if_needed()
def _evict_if_needed(self) -> None:
if len(self._store) <= MODEL_CACHE_MAX_ENTRIES:
return
# Drop oldest entries first when cache grows too large.
overflow = len(self._store) - MODEL_CACHE_MAX_ENTRIES
for key, _ in sorted(
self._store.items(),
key=lambda item: item[1][0],
)[:overflow]:
self._store.pop(key, None)
def invalidate(
self, provider_id: str | None = None, *, umo: str | None = None
) -> None:
if provider_id is None:
self._store.clear()
return
if umo is not None:
self._store.pop((provider_id, umo), None)
return
stale_keys = [
cache_key for cache_key in self._store if cache_key[0] == provider_id
]
for cache_key in stale_keys:
self._store.pop(cache_key, None)
class ProviderCommands:
def __init__(self, context: star.Context) -> None:
    """Keep the plugin context, create the model-list cache and subscribe to provider changes."""
    self.context = context
    self._model_cache = _ModelCache()
    # Subscribing lets provider switches invalidate the cached model lists.
    self._register_provider_change_hook()
def _register_provider_change_hook(self) -> None:
set_change_callback = getattr(
self.context.provider_manager,
"set_provider_change_callback",
None,
)
if callable(set_change_callback):
set_change_callback(self._on_provider_manager_changed)
return
register_change_hook = getattr(
self.context.provider_manager,
"register_provider_change_hook",
None,
)
if callable(register_change_hook):
register_change_hook(self._on_provider_manager_changed)
def invalidate_provider_models_cache(
    self, provider_id: str | None = None, *, umo: str | None = None
) -> None:
    """Public hook for cache invalidation on external provider config changes.

    Arguments are forwarded unchanged to the model cache's ``invalidate``:
    no ``provider_id`` clears everything, ``provider_id`` plus ``umo`` drops
    a single entry, and ``provider_id`` alone drops all of that provider's
    entries.
    """
    self._model_cache.invalidate(provider_id, umo=umo)
def _on_provider_manager_changed(
    self,
    provider_id: str,
    provider_type: ProviderType,
    umo: str | None,
) -> None:
    """React to a provider change by dropping the affected cached model list.

    Only chat-completion providers have cached model lists, so every other
    provider type is ignored.
    """
    if provider_type != ProviderType.CHAT_COMPLETION:
        return
    self.invalidate_provider_models_cache(provider_id, umo=umo)
def _get_provider_settings(self, umo: str | None) -> dict:
if not umo:
return {}
try:
return self.context.get_config(umo).get("provider_settings", {}) or {}
except Exception as e:
logger.debug(
"读取 provider_settings 失败,使用默认值: %s",
safe_error("", e),
)
return {}
def _get_model_cache_ttl(self, umo: str | None) -> float:
    """Return the model-list cache TTL (seconds) configured for the session.

    The value is read from ``provider_settings`` under
    ``MODEL_LIST_CACHE_TTL_KEY``. Negative values are clamped to ``0.0``
    (which disables caching). Values that cannot be converted to ``float``,
    or that convert to NaN, fall back to
    ``MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT``.
    """
    import math  # local import: only needed for the NaN guard below

    settings = self._get_provider_settings(umo)
    raw = settings.get(
        MODEL_LIST_CACHE_TTL_KEY,
        MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT,
    )
    try:
        ttl = float(raw)
        if math.isnan(ttl):
            # max(nan, 0.0) is nan and every comparison against nan is False,
            # so a NaN TTL would neither disable the cache nor ever let an
            # entry expire. Treat it like any other invalid value.
            raise ValueError("TTL must not be NaN")
        return max(ttl, 0.0)
    except Exception as e:
        logger.debug(
            "读取 %s 失败,回退默认值 %r: %s",
            MODEL_LIST_CACHE_TTL_KEY,
            MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT,
            safe_error("", e),
        )
        return MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT
def _get_model_lookup_concurrency(self, umo: str | None) -> int:
    """Return the configured cross-provider lookup concurrency.

    The setting is parsed as an integer, falling back to the default when
    parsing fails, and is then clamped to the range
    ``1 .. MODEL_LOOKUP_MAX_CONCURRENCY_UPPER_BOUND``.
    """
    configured = self._get_provider_settings(umo).get(
        MODEL_LOOKUP_MAX_CONCURRENCY_KEY,
        MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT,
    )
    try:
        concurrency = int(configured)
    except Exception as exc:
        logger.debug(
            "读取 %s 失败,回退默认值 %r: %s",
            MODEL_LOOKUP_MAX_CONCURRENCY_KEY,
            MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT,
            safe_error("", exc),
        )
        concurrency = MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT
    if concurrency < 1:
        return 1
    if concurrency > MODEL_LOOKUP_MAX_CONCURRENCY_UPPER_BOUND:
        return MODEL_LOOKUP_MAX_CONCURRENCY_UPPER_BOUND
    return concurrency
def _get_model_lookup_config(self, umo: str | None) -> _ModelLookupConfig:
    """Collect the cache TTL and concurrency settings for ``umo`` into one object."""
    ttl_seconds = self._get_model_cache_ttl(umo)
    concurrency = self._get_model_lookup_concurrency(umo)
    return _ModelLookupConfig(
        umo=umo,
        cache_ttl_seconds=ttl_seconds,
        max_concurrency=concurrency,
    )
def _resolve_model_name(
self,
model_name: str,
models: Sequence[str],
) -> str | None:
"""Resolve model name with precedence:
exact > case-insensitive > provider-qualified suffix.
"""
requested = model_name.strip()
if not requested:
return None
requested_norm = requested.casefold()
# exact / case-insensitive match
for candidate in models:
if candidate == requested or candidate.casefold() == requested_norm:
return candidate
# provider-qualified suffix match:
# e.g. candidate `openai/gpt-4o` should match requested `gpt-4o`.
for candidate in models:
cand_norm = candidate.casefold()
if cand_norm.endswith(f"/{requested_norm}") or cand_norm.endswith(
f":{requested_norm}"
):
return candidate
return None
def _apply_model(
self, prov: Provider, model_name: str, *, umo: str | None = None
) -> str:
prov.set_model(model_name)
self.invalidate_provider_models_cache(prov.meta().id, umo=umo)
return f"切换模型成功。当前提供商: [{prov.meta().id}] 当前模型: [{prov.get_model()}]"
async def _get_provider_models(
self,
provider: Provider,
*,
config: _ModelLookupConfig,
use_cache: bool = True,
) -> list[str]:
provider_id = provider.meta().id
ttl_seconds = config.cache_ttl_seconds
umo = config.umo
if use_cache:
cached = self._model_cache.get(provider_id, umo, ttl_seconds)
if cached is not None:
return cached
models = list(await provider.get_models())
if use_cache:
self._model_cache.set(provider_id, umo, models, ttl_seconds)
return models
async def _get_models_or_reply_error(
self,
message: AstrMessageEvent,
prov: Provider,
config: _ModelLookupConfig,
*,
error_prefix: str,
disable_t2i: bool = False,
warning_log: str | None = None,
) -> list[str] | None:
try:
return await self._get_provider_models(prov, config=config)
except asyncio.CancelledError:
raise
except Exception as e:
if warning_log is not None:
logger.warning(
warning_log,
prov.meta().id,
safe_error("", e),
)
result = MessageEventResult().message(safe_error(error_prefix, e))
if disable_t2i:
result = result.use_t2i(False)
message.set_result(result)
return None
def _log_reachability_failure(
self,
@@ -38,12 +285,96 @@ class ProviderCommands:
return True, None, None
except Exception as e:
err_code = "TEST_FAILED"
err_reason = str(e)
err_reason = safe_error("", e)
self._log_reachability_failure(
provider, provider_capability_type, err_code, err_reason
)
return False, err_code, err_reason
async def _find_provider_for_model(
    self,
    model_name: str,
    *,
    exclude_provider_id: str | None = None,
    config: _ModelLookupConfig,
    use_cache: bool = True,
) -> tuple[Provider | None, str | None]:
    """Search the other chat-completion providers for ``model_name``.

    Model lists are fetched concurrently, bounded by
    ``config.max_concurrency``. Results are then examined in provider order
    and the first provider whose list resolves the name wins.

    Returns:
        ``(provider, resolved_model_name)`` on success, ``(None, None)``
        when there is no candidate provider or no match. Fetch failures are
        logged: at error level if every provider failed, otherwise at debug
        level.
    """

    def is_candidate(provider: Provider) -> bool:
        # Only chat-completion providers other than the excluded one qualify.
        meta = provider.meta()
        if meta.provider_type != ProviderType.CHAT_COMPLETION:
            return False
        return exclude_provider_id is None or meta.id != exclude_provider_id

    candidates = [p for p in self.context.get_all_providers() if is_candidate(p)]
    if not candidates:
        return None, None

    limiter = asyncio.Semaphore(config.max_concurrency)

    async def fetch_models(
        provider: Provider,
    ) -> tuple[Provider, list[str] | None, str | None]:
        # Yields (provider, models, None) on success and
        # (provider, None, redacted_error) on failure.
        async with limiter:
            try:
                listed = await self._get_provider_models(
                    provider,
                    config=config,
                    use_cache=use_cache,
                )
            except asyncio.CancelledError:
                raise
            except Exception as exc:
                err = safe_error("", exc)
                logger.debug(
                    "跨提供商查找模型 %s 获取 %s 模型列表失败: %s",
                    model_name,
                    provider.meta().id,
                    err,
                )
                return provider, None, err
            return provider, listed, None

    outcomes = await asyncio.gather(*(fetch_models(p) for p in candidates))

    failures: list[tuple[str, str]] = []
    for provider, listed, err in outcomes:
        if err is not None:
            failures.append((provider.meta().id, err))
            continue
        if listed is None:
            continue
        resolved = self._resolve_model_name(model_name, listed)
        if resolved is not None:
            return provider, resolved

    if failures and len(failures) == len(candidates):
        failed_ids = ",".join(provider_id for provider_id, _ in failures)
        logger.error(
            "跨提供商查找模型 %s 时,所有 %d 个提供商的 get_models() 均失败: %s。请检查配置或网络",
            model_name,
            len(candidates),
            failed_ids,
        )
    elif failures:
        logger.debug(
            "跨提供商查找模型 %s 时有 %d 个提供商获取模型失败: %s",
            model_name,
            len(failures),
            ",".join(f"{provider_id}({error})" for provider_id, error in failures),
        )
    return None, None
async def provider(
self,
event: AstrMessageEvent,
@@ -92,13 +423,15 @@ class ProviderCommands:
id_ = meta.id
error_code = None
if isinstance(reachable, asyncio.CancelledError):
raise reachable
if isinstance(reachable, Exception):
# 异常情况下兜底处理,避免单个 provider 导致列表失败
self._log_reachability_failure(
p,
None,
reachable.__class__.__name__,
str(reachable),
safe_error("", reachable),
)
reachable_flag = False
error_code = reachable.__class__.__name__
@@ -224,6 +557,73 @@ class ProviderCommands:
else:
event.set_result(MessageEventResult().message("无效的参数。"))
async def _switch_model_by_name(
    self, message: AstrMessageEvent, model_name: str, prov: Provider
) -> None:
    """Switch to ``model_name``, changing provider automatically if needed.

    The current provider ``prov`` is tried first. If it does not list the
    model, the other chat-completion providers are searched; on a hit the
    session is switched to that provider and the model is applied there.
    Every outcome (success or failure) is reported through
    ``message.set_result``; cancellation is propagated.
    """
    model_name = model_name.strip()
    if not model_name:
        message.set_result(MessageEventResult().message("模型名不能为空。"))
        return
    umo = message.unified_msg_origin
    config = self._get_model_lookup_config(umo)
    curr_provider_id = prov.meta().id
    # If the current provider's list cannot be fetched, the helper has already
    # replied with the error and the cross-provider search is not attempted.
    models = await self._get_models_or_reply_error(
        message,
        prov,
        config,
        error_prefix="获取当前提供商模型列表失败: ",
        warning_log="获取当前提供商 %s 模型列表失败,停止跨提供商查找: %s",
    )
    if models is None:
        return
    matched_model_name = self._resolve_model_name(model_name, models)
    if matched_model_name is not None:
        # The current provider offers the model: no provider switch needed.
        message.set_result(
            MessageEventResult().message(
                self._apply_model(prov, matched_model_name, umo=umo)
            ),
        )
        return
    # Fall back to the remaining providers, skipping the one just checked.
    target_prov, matched_target_model_name = await self._find_provider_for_model(
        model_name,
        exclude_provider_id=curr_provider_id,
        config=config,
    )
    if target_prov is None or matched_target_model_name is None:
        message.set_result(
            MessageEventResult().message(
                f"模型 [{model_name}] 未在任何已配置的提供商中找到,或所有提供商模型列表获取失败,请检查配置或网络后重试。",
            ),
        )
        return
    target_id = target_prov.meta().id
    try:
        # Switch the session's provider first, then set the model on it.
        await self.context.provider_manager.set_provider(
            provider_id=target_id,
            provider_type=ProviderType.CHAT_COMPLETION,
            umo=umo,
        )
        self._apply_model(target_prov, matched_target_model_name, umo=umo)
        message.set_result(
            MessageEventResult().message(
                f"检测到模型 [{matched_target_model_name}] 属于提供商 [{target_id}],已自动切换提供商并设置模型。",
            ),
        )
    except asyncio.CancelledError:
        raise
    except Exception as e:
        # Report the failure with secrets redacted from the error text.
        message.set_result(
            MessageEventResult().message(
                safe_error("跨提供商切换并设置模型失败: ", e)
            ),
        )
async def model_ls(
self,
message: AstrMessageEvent,
@@ -236,20 +636,17 @@ class ProviderCommands:
MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"),
)
return
# 定义正则表达式匹配 API 密钥
api_key_pattern = re.compile(r"key=[^&'\" ]+")
config = self._get_model_lookup_config(message.unified_msg_origin)
if idx_or_name is None:
models = []
try:
models = await prov.get_models()
except BaseException as e:
err_msg = api_key_pattern.sub("key=***", str(e))
message.set_result(
MessageEventResult()
.message("获取模型列表失败: " + err_msg)
.use_t2i(False),
)
models = await self._get_models_or_reply_error(
message,
prov,
config,
error_prefix="获取模型列表失败: ",
disable_t2i=True,
)
if models is None:
return
parts = ["下面列出了此模型提供商可用模型:"]
for i, model in enumerate(models, 1):
@@ -258,40 +655,43 @@ class ProviderCommands:
curr_model = prov.get_model() or ""
parts.append(f"\n当前模型: [{curr_model}]")
parts.append(
"\nTips: 使用 /model <模型名/编号>,即可实时更换模型。如目标模型不存在于上表,请输入模型名"
"\nTips: 使用 /model <模型名/编号> 切换模型。输入模型名时可自动跨提供商查找并切换;跨提供商也可使用 /provider 切换"
)
ret = "".join(parts)
message.set_result(MessageEventResult().message(ret).use_t2i(False))
elif isinstance(idx_or_name, int):
models = []
try:
models = await prov.get_models()
except BaseException as e:
message.set_result(
MessageEventResult().message("获取模型列表失败: " + str(e)),
)
models = await self._get_models_or_reply_error(
message,
prov,
config,
error_prefix="获取模型列表失败: ",
)
if models is None:
return
if idx_or_name > len(models) or idx_or_name < 1:
message.set_result(MessageEventResult().message("模型序号错误。"))
else:
try:
new_model = models[idx_or_name - 1]
prov.set_model(new_model)
except BaseException as e:
message.set_result(
MessageEventResult().message("切换模型未知错误: " + str(e)),
MessageEventResult().message(
self._apply_model(
prov,
new_model,
umo=message.unified_msg_origin,
)
),
)
message.set_result(
MessageEventResult().message(
f"切换模型成功。当前提供商: [{prov.meta().id}] 当前模型: [{prov.get_model()}]",
),
)
except Exception as e:
message.set_result(
MessageEventResult().message(
safe_error("切换模型未知错误: ", e)
),
)
return
else:
prov.set_model(idx_or_name)
message.set_result(
MessageEventResult().message(f"切换模型到 {prov.get_model()}"),
)
await self._switch_model_by_name(message, idx_or_name, prov)
async def key(self, message: AstrMessageEvent, index: int | None = None) -> None:
prov = self.context.get_using_provider(message.unified_msg_origin)
@@ -322,8 +722,15 @@ class ProviderCommands:
try:
new_key = keys_data[index - 1]
prov.set_key(new_key)
except BaseException as e:
message.set_result(
MessageEventResult().message(f"切换 Key 未知错误: {e!s}"),
self.invalidate_provider_models_cache(
prov.meta().id,
umo=message.unified_msg_origin,
)
message.set_result(MessageEventResult().message("切换 Key 成功。"))
message.set_result(MessageEventResult().message("切换 Key 成功。"))
except Exception as e:
message.set_result(
MessageEventResult().message(
safe_error("切换 Key 未知错误: ", e)
),
)
return
+188 -3
View File
@@ -12,7 +12,7 @@ import os
import shutil
import zipfile
from dataclasses import dataclass, field
from datetime import datetime
from datetime import datetime, timezone
from pathlib import Path
from typing import TYPE_CHECKING, Any
@@ -61,6 +61,69 @@ def _get_major_version(version_str: str) -> str:
CMD_CONFIG_FILE_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
KB_PATH = get_astrbot_knowledge_base_path()
# Fallback cap on "invalid count" warnings when the environment gives no valid override.
DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT = 5
# Name of the environment variable that overrides the default cap.
PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV = (
    "ASTRBOT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT"
)
def _load_platform_stats_invalid_count_warn_limit() -> int:
    """Read the invalid-count warning cap from the environment.

    Returns the default when the variable is unset. A value that is not a
    non-negative integer is reported with a warning and also replaced by
    the default.
    """
    env_text = os.getenv(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV)
    if env_text is None:
        return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT

    try:
        parsed: int | None = int(env_text)
    except (TypeError, ValueError):
        parsed = None

    if parsed is not None and parsed >= 0:
        return parsed

    logger.warning(
        "Invalid env %s=%r, fallback to default %d",
        PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV,
        env_text,
        DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT,
    )
    return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT
# Effective warning cap, resolved once when this module is imported.
PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT = (
    _load_platform_stats_invalid_count_warn_limit()
)
class _InvalidCountWarnLimiter:
    """Rate-limit warnings for invalid platform_stats count values.

    At most ``limit`` per-value warnings are emitted, followed by a single
    "suppressed from now on" notice. With ``limit <= 0`` only that notice
    is ever logged.
    """

    def __init__(self, limit: int) -> None:
        self.limit = limit
        self._count = 0
        self._suppression_logged = False

    def warn_invalid_count(self, value: Any, key_for_log: tuple[Any, ...]) -> None:
        """Log one invalid-count warning unless the cap has been reached."""
        if self.limit <= 0:
            # No per-value warnings allowed: only announce suppression once.
            self._announce_suppression()
            return
        if self._count >= self.limit:
            return
        logger.warning(
            "platform_stats count 非法,已按 0 处理: value=%r, key=%s",
            value,
            key_for_log,
        )
        self._count += 1
        if self._count == self.limit:
            self._announce_suppression()

    def _announce_suppression(self) -> None:
        """Emit the one-time notice that further warnings are suppressed."""
        if self._suppression_logged:
            return
        logger.warning(
            "platform_stats 非法 count 告警已达到上限 (%d),后续将抑制",
            self.limit,
        )
        self._suppression_logged = True
@dataclass
@@ -138,6 +201,10 @@ class ImportResult:
}
class DatabaseClearError(RuntimeError):
    """Raised when clearing the main database in replace mode fails.

    The importer raises it chained to the underlying exception, and the
    main-database import step catches it to record the error and stop
    instead of importing into a database that was not cleared.
    """
class AstrBotImporter:
"""AstrBot 数据导入器
@@ -342,6 +409,9 @@ class AstrBotImporter:
imported = await self._import_main_database(main_data)
result.imported_tables.update(imported)
except DatabaseClearError as e:
result.add_error(f"清空主数据库失败: {e}")
return result
except Exception as e:
result.add_error(f"导入主数据库失败: {e}")
return result
@@ -452,7 +522,9 @@ class AstrBotImporter:
await session.execute(delete(model_class))
logger.debug(f"已清空表 {table_name}")
except Exception as e:
logger.warning(f"清空表 {table_name} 失败: {e}")
raise DatabaseClearError(
f"清空表 {table_name} 失败: {e}"
) from e
async def _clear_kb_data(self) -> None:
"""清空知识库数据"""
@@ -494,9 +566,10 @@ class AstrBotImporter:
if not model_class:
logger.warning(f"未知的表: {table_name}")
continue
normalized_rows = self._preprocess_main_table_rows(table_name, rows)
count = 0
for row in rows:
for row in normalized_rows:
try:
# 转换 datetime 字符串为 datetime 对象
row = self._convert_datetime_fields(row, model_class)
@@ -511,6 +584,118 @@ class AstrBotImporter:
return imported
def _preprocess_main_table_rows(
self, table_name: str, rows: list[dict[str, Any]]
) -> list[dict[str, Any]]:
if table_name == "platform_stats":
normalized_rows = self._merge_platform_stats_rows(rows)
duplicate_count = len(rows) - len(normalized_rows)
if duplicate_count > 0:
logger.warning(
"检测到 %s 重复键 %d 条,已在导入前聚合",
table_name,
duplicate_count,
)
return normalized_rows
return rows
def _merge_platform_stats_rows(
    self, rows: list[dict[str, Any]]
) -> list[dict[str, Any]]:
    """Collapse platform_stats rows that share a normalized key.

    The key is ``(normalized timestamp, platform_id, platform_type)``; the
    first row seen for a key is kept and the counts of later duplicates are
    added to it.

    Note:
        - Rows whose timestamp cannot be normalized are never merged.
        - Rows with a non-string platform_id/platform_type are never merged.
        - Invalid-count warnings are rate-limited per call.
    """
    first_by_key: dict[tuple[str, str, str], dict[str, Any]] = {}
    output: list[dict[str, Any]] = []
    limiter = _InvalidCountWarnLimiter(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT)

    for raw_row in rows:
        entry, timestamp_key, count = self._normalize_platform_stats_entry(
            raw_row, limiter
        )
        platform_id = entry.get("platform_id")
        platform_type = entry.get("platform_type")
        mergeable = (
            timestamp_key is not None
            and isinstance(platform_id, str)
            and isinstance(platform_type, str)
        )
        if mergeable:
            merge_key = (timestamp_key, platform_id, platform_type)
            if merge_key in first_by_key:
                # Duplicate key: fold its count into the first occurrence.
                first_by_key[merge_key]["count"] += count
                continue
            first_by_key[merge_key] = entry
        output.append(entry)
    return output
def _normalize_platform_stats_entry(
self,
row: dict[str, Any],
warn_limiter: _InvalidCountWarnLimiter,
) -> tuple[dict[str, Any], str | None, int]:
normalized_row = dict(row)
raw_timestamp = normalized_row.get("timestamp")
normalized_timestamp = self._normalize_platform_stats_timestamp(raw_timestamp)
if normalized_timestamp is not None:
normalized_row["timestamp"] = normalized_timestamp
elif isinstance(raw_timestamp, str):
normalized_row["timestamp"] = raw_timestamp.strip()
elif raw_timestamp is None:
normalized_row["timestamp"] = ""
else:
normalized_row["timestamp"] = str(raw_timestamp)
raw_count = normalized_row.get("count", 0)
try:
count = int(raw_count)
except (TypeError, ValueError):
key_for_log = (
normalized_row.get("timestamp"),
repr(normalized_row.get("platform_id")),
repr(normalized_row.get("platform_type")),
)
warn_limiter.warn_invalid_count(raw_count, key_for_log)
count = 0
normalized_row["count"] = count
return normalized_row, normalized_timestamp, count
def _normalize_platform_stats_timestamp(self, value: Any) -> str | None:
if isinstance(value, datetime):
dt = value
if dt.tzinfo is None:
dt = dt.replace(tzinfo=timezone.utc)
else:
dt = dt.astimezone(timezone.utc)
return dt.isoformat()
if isinstance(value, str):
timestamp = value.strip()
if not timestamp:
return None
if timestamp.endswith("Z"):
timestamp = f"{timestamp[:-1]}+00:00"
try:
dt = datetime.fromisoformat(timestamp)
if dt.tzinfo is None:
dt = dt.replace(tzinfo=timezone.utc)
else:
dt = dt.astimezone(timezone.utc)
return dt.isoformat()
except ValueError:
return None
return None
async def _import_knowledge_bases(
self,
zf: zipfile.ZipFile,
+56
View File
@@ -2,11 +2,13 @@ import asyncio
import copy
import os
import traceback
from collections.abc import Callable
from typing import Protocol, runtime_checkable
from astrbot.core import astrbot_config, logger, sp
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.db import BaseDatabase
from astrbot.core.utils.error_redaction import safe_error
from ..persona_mgr import PersonaManager
from .entities import ProviderType
@@ -71,6 +73,56 @@ class ProviderManager:
self.curr_tts_provider_inst: TTSProvider | None = None
"""默认的 Text To Speech Provider 实例。已弃用,请使用 get_using_provider() 方法获取当前使用的 Provider 实例。"""
self.db_helper = db_helper
self._provider_change_callback: (
Callable[[str, ProviderType, str | None], None] | None
) = None
self._provider_change_hooks: list[
Callable[[str, ProviderType, str | None], None]
] = []
def set_provider_change_callback(
    self,
    cb: Callable[[str, ProviderType, str | None], None] | None,
) -> None:
    """Install (or, with ``None``, remove) the single provider-change callback.

    The callback is invoked as ``cb(provider_id, provider_type, umo)``.
    Setting a new callback replaces the previous one.
    """
    # Backward-compatible single-callback setter.
    # This callback coexists with register_provider_change_hook subscriptions.
    self._provider_change_callback = cb
def register_provider_change_hook(
    self,
    hook: Callable[[str, ProviderType, str | None], None],
) -> None:
    """Subscribe ``hook`` to provider-change notifications.

    Registering a hook that is already subscribed is a no-op, so each hook
    is stored at most once.
    """
    subscribers = self._provider_change_hooks
    if hook in subscribers:
        return
    subscribers.append(hook)
def _notify_provider_changed(
self,
provider_id: str,
provider_type: ProviderType,
umo: str | None,
) -> None:
if self._provider_change_callback is not None:
try:
self._provider_change_callback(provider_id, provider_type, umo)
except Exception as e:
logger.warning(
"调用 provider 变更回调失败: provider_id=%s, type=%s, err=%s",
provider_id,
provider_type,
safe_error("", e),
)
for hook in list(self._provider_change_hooks):
if hook is self._provider_change_callback:
continue
try:
hook(provider_id, provider_type, umo)
except Exception as e:
logger.warning(
"调用 provider 变更钩子失败: provider_id=%s, type=%s, err=%s",
provider_id,
provider_type,
safe_error("", e),
)
@property
def persona_configs(self) -> list:
@@ -111,6 +163,7 @@ class ProviderManager:
f"provider_perf_{provider_type.value}",
provider_id,
)
self._notify_provider_changed(provider_id, provider_type, umo)
return
# 不启用提供商会话隔离模式的情况
@@ -126,6 +179,7 @@ class ProviderManager:
scope="global",
scope_id="global",
)
self._notify_provider_changed(provider_id, provider_type, umo)
elif provider_type == ProviderType.SPEECH_TO_TEXT and isinstance(
prov,
STTProvider,
@@ -137,6 +191,7 @@ class ProviderManager:
scope="global",
scope_id="global",
)
self._notify_provider_changed(provider_id, provider_type, umo)
elif provider_type == ProviderType.CHAT_COMPLETION and isinstance(
prov,
Provider,
@@ -148,6 +203,7 @@ class ProviderManager:
scope="global",
scope_id="global",
)
self._notify_provider_changed(provider_id, provider_type, umo)
async def get_provider_by_id(self, provider_id: str) -> Providers | None:
"""根据提供商 ID 获取提供商实例"""
+82
View File
@@ -0,0 +1,82 @@
import re
# Field names whose values are treated as secrets (matched case-insensitively).
_SECRET_KEYS = (
    r"(?:api_?key|access_?token|auth_?token|refresh_?token|session_?id|secret|password)"
)
# Quoted key/value pairs such as "api_key": "value" (JSON or dict repr).
_JSON_FIELD_PATTERN = re.compile(
    rf"(?i)(?P<prefix>(?P<kq>['\"]){_SECRET_KEYS}(?P=kq)\s*:\s*)(?P<vq>['\"])(?P<value>[^'\"]+)(?P=vq)"
)
# Quoted "authorization": "Bearer <token>" pairs.
_AUTH_JSON_FIELD_PATTERN = re.compile(
    r"(?i)(?P<prefix>(?P<kq>['\"])authorization(?P=kq)\s*:\s*)(?P<vq>['\"])bearer\s+[^'\"]+(?P=vq)"
)
# Unquoted name=value assignments for the secret field names.
_QUERY_FIELD_PATTERN = re.compile(
    rf"(?i)(?P<prefix>{_SECRET_KEYS}\s*=\s*)(?P<value>[^&'\" ]+)"
)
# URL query parameters, including the bare ``key=`` form.
_QUERY_PARAM_PATTERN = re.compile(
    r"(?i)(?P<prefix>[?&](?:api_?key|key|access_?token|auth_?token)=)(?P<value>[^&'\" ]+)"
)
# "Authorization: Bearer <token>" header lines.
_AUTH_HEADER_PATTERN = re.compile(
    r"(?i)(?P<prefix>\bauthorization\s*:\s*bearer\s+)(?P<token>[A-Za-z0-9._\-]+)"
)
# Any remaining "Bearer <token>" occurrence.
_BEARER_PATTERN = re.compile(r"(?i)(?P<prefix>\bbearer\s+)(?P<token>[A-Za-z0-9._\-]+)")
# Stand-alone "sk-" style API keys. "-" and "_" are allowed inside the key so
# that segmented keys such as "sk-proj-..." or "sk-ant-..." are redacted too;
# an alphanumeric-only class would leave those keys in the output untouched.
_SK_PATTERN = re.compile(r"\bsk-[A-Za-z0-9_\-]{16,}")


def _redact_json_field(match: re.Match[str]) -> str:
    """Replace the value of a quoted secret field, keeping key and quote style."""
    quote = match.group("vq")
    return f"{match.group('prefix')}{quote}[REDACTED]{quote}"


def _redact_auth_json_field(match: re.Match[str]) -> str:
    """Replace the token of a quoted authorization field, keeping the scheme."""
    quote = match.group("vq")
    return f"{match.group('prefix')}{quote}Bearer [REDACTED]{quote}"


def _redact_prefixed_value(match: re.Match[str]) -> str:
    """Keep the matched ``prefix`` group and replace what follows it."""
    return f"{match.group('prefix')}[REDACTED]"


def _redact_bearer_token(match: re.Match[str]) -> str:
    """Keep the ``Bearer`` prefix and replace the token."""
    return f"{match.group('prefix')}[REDACTED]"


def _redact_json_like(text: str) -> str:
    """Redact secrets that appear as quoted key/value pairs."""
    text = _JSON_FIELD_PATTERN.sub(_redact_json_field, text)
    return _AUTH_JSON_FIELD_PATTERN.sub(_redact_auth_json_field, text)


def _redact_query_like(text: str) -> str:
    """Redact secrets that appear as ``name=value`` assignments or URL parameters."""
    text = _QUERY_FIELD_PATTERN.sub(_redact_prefixed_value, text)
    return _QUERY_PARAM_PATTERN.sub(_redact_prefixed_value, text)


def _redact_tokens(text: str) -> str:
    """Redact bearer tokens and stand-alone ``sk-`` keys."""
    text = _AUTH_HEADER_PATTERN.sub(_redact_bearer_token, text)
    text = _BEARER_PATTERN.sub(_redact_bearer_token, text)
    return _SK_PATTERN.sub("[REDACTED]", text)


def redact_sensitive_text(text: str) -> str:
    """Return ``text`` with recognised secrets replaced by ``[REDACTED]``.

    Quoted fields are handled first, then ``name=value`` forms, then bare
    tokens, so that more specific patterns keep their surrounding context.
    """
    text = _redact_json_like(text)
    text = _redact_query_like(text)
    text = _redact_tokens(text)
    return text


def safe_error(
    prefix: str,
    error: Exception | BaseException | str,
    *,
    redact: bool = True,
) -> str:
    """Build a loggable/user-facing error message with secrets removed.

    Args:
        prefix: Text placed in front of the error description.
        error: Exception (or ready-made string) to describe.
        redact: When true (default) the description is passed through
            ``redact_sensitive_text``.

    Returns:
        ``prefix`` followed by the (optionally redacted) description. This
        function itself never raises because of an unprintable error.
    """
    try:
        text = str(error)
    except Exception:
        # __str__ may itself be broken; fall back to repr, then a fixed marker.
        try:
            text = repr(error)
        except Exception:
            text = "<unprintable error>"
    if redact:
        text = redact_sensitive_text(text)
    return prefix + text
+1
View File
@@ -65,6 +65,7 @@
"sass-loader": "13.3.2",
"typescript": "5.1.6",
"vite": "4.4.9",
"vite-plugin-monaco-editor": "1.1.0",
"vue-cli-plugin-vuetify": "2.5.8",
"vue-tsc": "1.8.8",
"vuetify-loader": "^2.0.0-alpha.9"
+12
View File
@@ -159,6 +159,9 @@ importers:
vite:
specifier: 4.4.9
version: 4.4.9(@types/node@20.19.32)(sass@1.66.1)(terser@5.46.0)
vite-plugin-monaco-editor:
specifier: 1.1.0
version: 1.1.0(monaco-editor@0.52.2)
vue-cli-plugin-vuetify:
specifier: 2.5.8
version: 2.5.8(sass-loader@13.3.2(sass@1.66.1)(webpack@5.105.0))(vue@3.3.4)(vuetify-loader@2.0.0-alpha.9(@vue/compiler-sfc@3.3.4)(vue@3.3.4)(vuetify@3.7.11)(webpack@5.105.0))(webpack@5.105.0)
@@ -2568,6 +2571,11 @@ packages:
vfile@6.0.3:
resolution: {integrity: sha512-KzIbH/9tXat2u30jf+smMwFCsno4wHVdNmzFyL+T/L3UGqqk6JKfVqOFOZEpZSHADH1k40ab6NUIXZq422ov3Q==}
vite-plugin-monaco-editor@1.1.0:
resolution: {integrity: sha512-IvtUqZotrRoVqwT0PBBDIZPNraya3BxN/bfcNfnxZ5rkJiGcNtO5eAOWWSgT7zullIAEqQwxMU83yL9J5k7gww==}
peerDependencies:
monaco-editor: '>=0.33.0'
vite-plugin-vuetify@1.0.2:
resolution: {integrity: sha512-MubIcKD33O8wtgQXlbEXE7ccTEpHZ8nPpe77y9Wy3my2MWw/PgehP9VqTp92BLqr0R1dSL970Lynvisx3UxBFw==}
engines: {node: '>=12'}
@@ -5297,6 +5305,10 @@ snapshots:
'@types/unist': 3.0.3
vfile-message: 4.0.3
vite-plugin-monaco-editor@1.1.0(monaco-editor@0.52.2):
dependencies:
monaco-editor: 0.52.2
vite-plugin-vuetify@1.0.2(vite@4.4.9(@types/node@20.19.32)(sass@1.66.1)(terser@5.46.0))(vue@3.3.4)(vuetify@3.7.11):
dependencies:
'@vuetify/loader-shared': 1.7.1(vue@3.3.4)(vuetify@3.7.11)
+4 -7
View File
@@ -9,9 +9,12 @@ import '@/scss/style.scss';
import VueApexCharts from 'vue3-apexcharts';
import print from 'vue3-print-nb';
import { loader } from '@guolao/vue-monaco-editor'
import { loader } from '@guolao/vue-monaco-editor';
import * as monaco from 'monaco-editor';
import axios from 'axios';
loader.config({ monaco });
// 初始化新的i18n系统,等待完成后再挂载应用
setupI18n().then(() => {
console.log('🌍 新i18n系统初始化完成');
@@ -108,9 +111,3 @@ window.fetch = (input: RequestInfo | URL, init?: RequestInit) => {
}
return _origFetch(input, { ...init, headers });
};
loader.config({
paths: {
vs: 'https://cdn.jsdelivr.net/npm/monaco-editor@0.54.0/min/vs',
},
})
+3 -1
View File
@@ -2,6 +2,7 @@ import { fileURLToPath, URL } from 'url';
import { defineConfig } from 'vite';
import vue from '@vitejs/plugin-vue';
import vuetify from 'vite-plugin-vuetify';
import monacoEditorPlugin from 'vite-plugin-monaco-editor';
// https://vitejs.dev/config/
export default defineConfig({
@@ -15,7 +16,8 @@ export default defineConfig({
}),
vuetify({
autoImport: true
})
}),
monacoEditorPlugin({})
],
resolve: {
alias: {
+324 -1
View File
@@ -5,7 +5,7 @@ import os
import re
import zipfile
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -17,6 +17,8 @@ from astrbot.core.backup import (
)
from astrbot.core.backup.exporter import AstrBotExporter
from astrbot.core.backup.importer import (
DatabaseClearError,
PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT,
AstrBotImporter,
ImportResult,
_get_major_version,
@@ -308,6 +310,298 @@ class TestAstrBotImporter:
assert isinstance(result["created_at"], datetime)
assert isinstance(result["updated_at"], datetime)
def test_merge_platform_stats_rows(self):
    """Duplicate platform_stats keys are aggregated before import.

    Three ``webchat`` rows carry the same instant written with a ``Z``
    suffix, a ``+00:00`` offset and no offset at all; the test expects them
    to collapse into one row whose count is the sum (14 + 3 + 2), while the
    unrelated ``aiocqhttp`` row stays separate.
    """

    def make_row(row_id, stamp, platform, hits):
        # Key order mirrors the shape of an exported platform_stats row.
        return {
            "id": row_id,
            "timestamp": stamp,
            "platform_id": platform,
            "platform_type": "unknown",
            "count": hits,
        }

    def pick(candidates, **wanted):
        # First row whose fields all equal the requested values, else None.
        for candidate in candidates:
            if all(candidate.get(k) == v for k, v in wanted.items()):
                return candidate
        return None

    importer = AstrBotImporter(main_db=MagicMock())
    rows = [
        make_row(1, "2025-12-13T20:00:00Z", "webchat", 14),
        make_row(80, "2025-12-13T20:00:00+00:00", "webchat", 3),
        make_row(81, "2025-12-13T20:00:00", "webchat", 2),
        make_row(2, "2025-12-13T21:00:00", "aiocqhttp", 1),
    ]

    merged_rows = importer._merge_platform_stats_rows(rows)

    # Two of the four input rows were duplicates of the first one.
    duplicate_count = len(rows) - len(merged_rows)
    assert duplicate_count == 2
    assert len(merged_rows) == 2

    webchat_row = pick(
        merged_rows,
        timestamp="2025-12-13T20:00:00+00:00",
        platform_id="webchat",
        platform_type="unknown",
    )
    assert webchat_row is not None
    assert webchat_row["timestamp"] == "2025-12-13T20:00:00+00:00"
    assert webchat_row["platform_id"] == "webchat"
    assert webchat_row["platform_type"] == "unknown"
    assert webchat_row["count"] == 19

    aiocq_row = pick(
        merged_rows,
        platform_id="aiocqhttp",
        platform_type="unknown",
    )
    assert aiocq_row is not None
    # The offset-less input timestamp is expected back with an explicit UTC offset.
    assert aiocq_row["timestamp"] == "2025-12-13T21:00:00+00:00"
def test_merge_platform_stats_rows_normalizes_naive_timestamp_to_utc(self):
    """Naive timestamps come back from the merge with an explicit UTC offset.

    Covers both an offset-less ISO string and a naive ``datetime`` object.
    """
    importer = AstrBotImporter(main_db=MagicMock())
    string_stamp_row = {
        "timestamp": "2025-12-13T21:00:00",
        "platform_id": "webchat",
        "platform_type": "unknown",
        "count": 1,
    }
    datetime_stamp_row = {
        "timestamp": datetime(2025, 12, 13, 22, 0, 0),
        "platform_id": "telegram",
        "platform_type": "unknown",
        "count": 1,
    }
    rows = [string_stamp_row, datetime_stamp_row]

    merged_rows = importer._merge_platform_stats_rows(rows)

    assert len(merged_rows) == 2
    # Index the result by platform so the assertions do not depend on order.
    by_platform = {}
    for merged in merged_rows:
        by_platform[merged["platform_id"]] = merged
    assert by_platform["webchat"]["timestamp"] == "2025-12-13T21:00:00+00:00"
    assert by_platform["telegram"]["timestamp"] == "2025-12-13T22:00:00+00:00"
def test_merge_platform_stats_rows_warns_on_invalid_count(self):
    """An invalid platform_stats count is warned about and counted as 0.

    Four scenarios run against the same patched ``logger.warning``:
    an invalid duplicate after a valid row, an invalid row followed by a
    valid duplicate, more invalid rows than the warning limit, and a single
    invalid row that has no duplicate.
    """

    def stat(stamp, platform, hits):
        # Fresh dict on every call so no two rows share an object.
        return {
            "timestamp": stamp,
            "platform_id": platform,
            "platform_type": "unknown",
            "count": hits,
        }

    importer = AstrBotImporter(main_db=MagicMock())

    with patch("astrbot.core.backup.importer.logger.warning") as warning_mock:
        # Scenario 1: valid row first, invalid duplicate second.
        rows = [
            stat("2025-12-13T20:00:00+00:00", "webchat", 5),
            stat("2025-12-13T20:00:00Z", "webchat", "bad-count"),
        ]
        merged_rows = importer._merge_platform_stats_rows(rows)
        duplicate_count = len(rows) - len(merged_rows)
        assert duplicate_count == 1
        assert len(merged_rows) == 1
        assert merged_rows[0]["count"] == 5
        assert warning_mock.call_count == 1

        # Scenario 2: the already-stored row is the invalid one.
        warning_mock.reset_mock()
        rows_existing_invalid = [
            stat("2025-12-13T21:00:00+00:00", "webchat", "bad-count"),
            stat("2025-12-13T21:00:00Z", "webchat", 7),
        ]
        merged_rows = importer._merge_platform_stats_rows(rows_existing_invalid)
        duplicate_count = len(rows_existing_invalid) - len(merged_rows)
        assert duplicate_count == 1
        assert len(merged_rows) == 1
        assert merged_rows[0]["count"] == 7
        assert warning_mock.call_count == 1

        # Scenario 3: exceed the warning limit; expect LIMIT warnings plus
        # one extra call, and a message announcing that the cap was reached.
        warning_mock.reset_mock()
        overflow = PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT + 5
        many_invalid_rows = [stat("2025-12-13T22:00:00+00:00", "webchat", 1)]
        many_invalid_rows.extend(
            stat("2025-12-13T22:00:00Z", "webchat", "bad-count")
            for _ in range(overflow)
        )
        importer._merge_platform_stats_rows(many_invalid_rows)
        assert (
            warning_mock.call_count == PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT + 1
        )
        assert any(
            "告警已达到上限" in str(call.args[0])
            for call in warning_mock.call_args_list
        )

        # Scenario 4: a lone invalid row is kept, with its count zeroed.
        warning_mock.reset_mock()
        single_invalid_row = [
            stat("2025-12-13T23:00:00+00:00", "telegram", "still-bad"),
        ]
        merged_rows = importer._merge_platform_stats_rows(single_invalid_row)
        duplicate_count = len(single_invalid_row) - len(merged_rows)
        assert duplicate_count == 0
        assert len(merged_rows) == 1
        assert merged_rows[0]["count"] == 0
        assert warning_mock.call_count == 1
def test_merge_platform_stats_rows_keeps_invalid_timestamps_distinct(self):
    """Rows with empty or unparseable timestamps are never merged together.

    Even two rows sharing the same invalid timestamp string must stay
    separate, each keeping its own count.
    """
    importer = AstrBotImporter(main_db=MagicMock())
    stamps_and_counts = [
        ("", 2),
        ("not-a-datetime", 3),
        ("not-a-datetime", 4),
    ]
    rows = [
        {
            "timestamp": stamp,
            "platform_id": "webchat",
            "platform_type": "unknown",
            "count": hits,
        }
        for stamp, hits in stamps_and_counts
    ]

    merged_rows = importer._merge_platform_stats_rows(rows)

    duplicate_count = len(rows) - len(merged_rows)
    assert duplicate_count == 0
    assert len(merged_rows) == 3
    observed_counts = [merged["count"] for merged in merged_rows]
    assert observed_counts == [2, 3, 4]
def test_merge_platform_stats_rows_keeps_non_string_platform_keys_distinct(self):
    """Rows whose platform_id or platform_type is not a string are not merged.

    Two pairs of rows share an instant and a non-string key (``None`` as the
    platform id, ``1`` as the platform type); all four must survive.
    """
    importer = AstrBotImporter(main_db=MagicMock())
    offset_stamp = "2025-12-13T20:00:00+00:00"
    zulu_stamp = "2025-12-13T20:00:00Z"
    specs = [
        (offset_stamp, None, "unknown", 2),
        (zulu_stamp, None, "unknown", 3),
        (offset_stamp, "webchat", 1, 4),
        (zulu_stamp, "webchat", 1, 5),
    ]
    rows = []
    for stamp, platform_id, platform_type, hits in specs:
        rows.append(
            {
                "timestamp": stamp,
                "platform_id": platform_id,
                "platform_type": platform_type,
                "count": hits,
            }
        )

    merged_rows = importer._merge_platform_stats_rows(rows)

    duplicate_count = len(rows) - len(merged_rows)
    assert duplicate_count == 0
    assert len(merged_rows) == 4
def test_merge_platform_stats_rows_preserves_input_order(self):
    """Merged output keeps input order, keyed by each row's first appearance.

    Row 3 duplicates row 1, so it is folded into row 1's slot (count 2 + 5)
    and the surviving ids are expected in the order 1, 2, 4.
    """

    def make_row(row_id, stamp, platform, hits):
        return {
            "id": row_id,
            "timestamp": stamp,
            "platform_id": platform,
            "platform_type": "unknown",
            "count": hits,
        }

    importer = AstrBotImporter(main_db=MagicMock())
    rows = [
        make_row(1, "2025-12-13T20:00:00Z", "webchat", 2),
        # Empty timestamp: sits between the two duplicates in the input.
        make_row(2, "", "webchat", 3),
        make_row(3, "2025-12-13T20:00:00+00:00", "webchat", 5),
        make_row(4, "2025-12-13T21:00:00+00:00", "telegram", 7),
    ]

    merged_rows = importer._merge_platform_stats_rows(rows)

    assert len(merged_rows) == 3
    surviving_ids = [merged["id"] for merged in merged_rows]
    assert surviving_ids == [1, 2, 4]
    assert merged_rows[0]["count"] == 7
@pytest.mark.asyncio
async def test_import_file_not_exists(self, mock_main_db, tmp_path):
"""测试导入不存在的文件"""
@@ -365,6 +659,35 @@ class TestAstrBotImporter:
assert result.success is False
assert any("主版本不兼容" in err for err in result.errors)
@pytest.mark.asyncio
async def test_import_replace_fails_when_clear_main_db_fails(
    self, mock_main_db, tmp_path
):
    """In replace mode, a failure to clear the main DB aborts the import.

    ``_clear_main_db`` is stubbed to raise ``DatabaseClearError``; the result
    must report failure, carry both error messages, and the main-database
    import step must never be awaited.
    """
    # Build a minimal backup archive: a manifest plus an empty main DB dump.
    manifest_text = json.dumps(
        {
            "version": "1.1",
            "astrbot_version": VERSION,
            "tables": {"platform_stats": 0},
        }
    )
    main_db_text = json.dumps({"platform_stats": []})
    zip_path = tmp_path / "valid_backup.zip"
    with zipfile.ZipFile(zip_path, "w") as archive:
        archive.writestr("manifest.json", manifest_text)
        archive.writestr("databases/main_db.json", main_db_text)

    importer = AstrBotImporter(main_db=mock_main_db)
    clear_failure = DatabaseClearError("清空表 platform_stats 失败: db locked")
    importer._clear_main_db = AsyncMock(side_effect=clear_failure)
    importer._import_main_database = AsyncMock(return_value={})

    result = await importer.import_all(str(zip_path), mode="replace")

    assert result.success is False
    assert any("清空主数据库失败" in err for err in result.errors)
    assert any("清空表 platform_stats 失败" in err for err in result.errors)
    importer._import_main_database.assert_not_awaited()
class TestSecureFilename:
"""安全文件名函数测试"""