Compare commits

...

3 Commits

Author SHA1 Message Date
advent259141 d87d586c0a feat: add dashboard routes for session rule and group management, including available resource listings. 2026-03-11 18:18:06 +08:00
advent259141 410789311a feat: Add a new session management page with custom rules, batch operations, and group management, along with corresponding API routes. 2026-03-11 17:58:09 +08:00
エイカク 6da59cfb07 fix: 插件依赖自动安装逻辑与 Dashboard 安装体验优化 (#5954)
* fix: install plugin requirements before first load

* fix: handle pip option arguments correctly

* fix: harden pip install input parsing

* refactor: simplify pip install input parsing

* fix: align plugin dependency install handling

* fix: respect configured pip index overrides

* test: parameterize plugin dependency install flows

* refactor: simplify multiline pip input parsing

* fix: install plugin dependencies before loading

* fix: protect core dependencies from downgrades and simplify package input splitting

* fix: enhance dependency conflict reporting and improve user-facing warnings

* refactor: preserve pip log indentation and fix CodeQL URL sanitization alert

* fix: explicit re-export for DependencyConflictError to satisfy ruff F401

* test: enhance index override verification in pip installer tests

* fix: correctly map pip ERROR and WARNING outputs to proper log levels

* refactor: show specific version conflicts in DependencyConflictError and revert log level mapping

* refactor: simplify install() by decoupling pip logging, failure classification and constraint file management

* refactor: further simplify pip installer and requirement parsing logic

* refactor: simplify dependency installation logic and improve circular requirement reporting

* style: organize imports in astrbot/core/__init__.py

* refactor: optimize requirement parsing efficiency and flatten pip installer API

* style: fix import sorting in astrbot/core/__init__.py

* refactor: consolidate requirement parsing, optimize core protection, and improve exception propagation

* fix: preserve valid pip requirement parsing

* fix: skip empty pip installs and preserve blank output

* chore: normalize gitignore entry style

* fix: tighten pip trust and requirement parsing

* refactor: centralize pip install parsing and failure handling

* fix: redact pip argv credentials in logs

* fix: surface plugin dependency install errors

* fix: cache core constraints and clarify requirement installs

* fix: harden pip requirement parsing for plugin installs

* fix: simplify pip installer parsing internals

* fix: tighten pip installer parsing and redaction

* refactor: simplify plugin dependency install flow

* fix: preserve core constraint conflict errors

* fix: harden pip installer fallback resolution

* refactor: split pip requirement and constraint helpers

* refactor: simplify pip installer helper flow

* refactor: streamline requirement precheck helpers

* refactor: clarify core constraint resolution

* fix: surface pip install failures explicitly

* refactor: separate pip conflict context parsing

* fix: harden core constraint resolution

* test: cover pip installer failure call sites

* refactor: remove dead requirements fallback helper

* refactor: narrow core constraint error handling

* refactor: unify requirement iteration

* refactor: share requirement name parsing

* test: align pip helper coverage

* fix: bind pip output limit at runtime

* refactor: reuse core requirement parser for tokens
2026-03-11 14:21:55 +09:00
12 changed files with 3523 additions and 348 deletions
+2
View File
@@ -61,3 +61,5 @@ GenieData/
.codex/
.opencode/
.kilocode/
.worktrees/
docs/plans/
+15 -1
View File
@@ -4,7 +4,21 @@ from astrbot.core.config import AstrBotConfig
from astrbot.core.config.default import DB_PATH
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.file_token_service import FileTokenService
from astrbot.core.utils.pip_installer import PipInstaller
from astrbot.core.utils.pip_installer import (
DependencyConflictError as DependencyConflictError,
)
from astrbot.core.utils.pip_installer import (
PipInstaller,
)
from astrbot.core.utils.requirements_utils import (
RequirementsPrecheckFailed as RequirementsPrecheckFailed,
)
from astrbot.core.utils.requirements_utils import (
find_missing_requirements as find_missing_requirements,
)
from astrbot.core.utils.requirements_utils import (
find_missing_requirements_or_raise as find_missing_requirements_or_raise,
)
from astrbot.core.utils.shared_preferences import SharedPreferences
from astrbot.core.utils.t2i.renderer import HtmlRenderer
+97 -9
View File
@@ -14,7 +14,12 @@ import yaml
from packaging.specifiers import InvalidSpecifier, SpecifierSet
from packaging.version import InvalidVersion, Version
from astrbot.core import logger, pip_installer, sp
from astrbot.core import (
DependencyConflictError,
logger,
pip_installer,
sp,
)
from astrbot.core.agent.handoff import FunctionTool, HandoffTool
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.config.default import VERSION
@@ -27,6 +32,10 @@ from astrbot.core.utils.astrbot_path import (
)
from astrbot.core.utils.io import remove_dir
from astrbot.core.utils.metrics import Metric
from astrbot.core.utils.requirements_utils import (
RequirementsPrecheckFailed,
find_missing_requirements_or_raise,
)
from . import StarMetadata
from .command_management import sync_command_configs
@@ -48,6 +57,49 @@ class PluginVersionIncompatibleError(Exception):
"""Raised when plugin astrbot_version is incompatible with current AstrBot."""
class PluginDependencyInstallError(Exception):
"""Raised when plugin dependency installation fails."""
def __init__(
self,
*,
plugin_label: str,
requirements_path: str,
error: Exception,
) -> None:
message = f"插件 {plugin_label} 依赖安装失败: {error!s}"
super().__init__(message)
self.plugin_label = plugin_label
self.requirements_path = requirements_path
self.error = error
async def _install_requirements_with_precheck(
*,
plugin_label: str,
requirements_path: str,
) -> None:
try:
missing = find_missing_requirements_or_raise(requirements_path)
except RequirementsPrecheckFailed:
logger.info(
f"正在安装插件 {plugin_label} 的依赖库(预检查失败,回退到完整安装): "
f"{requirements_path}"
)
await pip_installer.install(requirements_path=requirements_path)
return
if not missing:
logger.info(f"插件 {plugin_label} 的依赖已满足,跳过安装。")
return
logger.info(
f"检测到插件 {plugin_label} 缺失依赖,正在按 requirements.txt 安装: "
f"{requirements_path} -> {sorted(missing)}"
)
await pip_installer.install(requirements_path=requirements_path)
class PluginManager:
def __init__(self, context: Context, config: AstrBotConfig) -> None:
from .star_tools import StarTools
@@ -198,15 +250,37 @@ class PluginManager:
to_update.append(p.root_dir_name)
for p in to_update:
plugin_path = os.path.join(plugin_dir, p)
if os.path.exists(os.path.join(plugin_path, "requirements.txt")):
pth = os.path.join(plugin_path, "requirements.txt")
logger.info(f"正在安装插件 {p} 所需的依赖库: {pth}")
try:
await pip_installer.install(requirements_path=pth)
except Exception as e:
logger.error(f"更新插件 {p} 的依赖失败。Code: {e!s}")
await self._ensure_plugin_requirements(plugin_path, p)
return True
async def _ensure_plugin_requirements(
self,
plugin_dir_path: str,
plugin_label: str,
) -> None:
requirements_path = os.path.join(plugin_dir_path, "requirements.txt")
if not os.path.exists(requirements_path):
return
try:
await _install_requirements_with_precheck(
plugin_label=plugin_label,
requirements_path=requirements_path,
)
except asyncio.CancelledError:
raise
except DependencyConflictError as e:
logger.error(f"插件 {plugin_label} 依赖冲突: {e!s}")
raise
except Exception as e:
dependency_error = PluginDependencyInstallError(
plugin_label=plugin_label,
requirements_path=requirements_path,
error=e,
)
logger.exception(str(dependency_error))
raise dependency_error from e
async def _import_plugin_with_dependency_recovery(
self,
path: str,
@@ -422,7 +496,7 @@ class PluginManager:
root_dir_name: str,
plugin_dir_path: str,
reserved: bool,
error: Exception | str,
error: BaseException | str,
error_trace: str,
) -> dict:
record: dict = {
@@ -495,6 +569,9 @@ class PluginManager:
self._cleanup_plugin_state(dir_name)
plugin_path = os.path.join(self.plugin_store_path, dir_name)
await self._ensure_plugin_requirements(plugin_path, dir_name)
success, error = await self.load(specified_dir_name=dir_name)
if success:
self.failed_plugin_dict.pop(dir_name, None)
@@ -1078,6 +1155,10 @@ class PluginManager:
# reload the plugin
dir_name = os.path.basename(plugin_path)
await self._ensure_plugin_requirements(
plugin_path,
dir_name,
)
success, error_message = await self.load(
specified_dir_name=dir_name,
ignore_version_check=ignore_version_check,
@@ -1317,6 +1398,12 @@ class PluginManager:
raise Exception("该插件是 AstrBot 保留插件,无法更新。")
await self.updator.update(plugin, proxy=proxy)
if plugin.root_dir_name:
plugin_dir_path = os.path.join(self.plugin_store_path, plugin.root_dir_name)
await self._ensure_plugin_requirements(
plugin_dir_path,
plugin_name,
)
await self.reload(plugin_name)
async def turn_off_plugin(self, plugin_name: str) -> None:
@@ -1488,6 +1575,7 @@ class PluginManager:
os.remove(zip_file_path)
except BaseException as e:
logger.warning(f"删除插件压缩包失败: {e!s}")
await self._ensure_plugin_requirements(desti_dir, dir_name)
# await self.reload()
success, error_message = await self.load(
specified_dir_name=dir_name,
+121
View File
@@ -0,0 +1,121 @@
import contextlib
import functools
import importlib.metadata as importlib_metadata
import logging
import os
from collections.abc import Iterator
from packaging.requirements import Requirement
from astrbot.core.utils.requirements_utils import (
canonicalize_distribution_name,
collect_installed_distribution_versions,
get_requirement_check_paths,
)
logger = logging.getLogger("astrbot")
def _resolve_core_dist_name(core_dist_name: str | None) -> str | None:
if core_dist_name:
try:
importlib_metadata.distribution(core_dist_name)
return core_dist_name
except importlib_metadata.PackageNotFoundError:
return None
try:
importlib_metadata.distribution("AstrBot")
return "AstrBot"
except importlib_metadata.PackageNotFoundError:
pass
if not __package__:
return None
top_pkg = __package__.split(".")[0]
for dist in importlib_metadata.distributions():
try:
top_level = dist.read_text("top_level.txt") or ""
except Exception:
continue
if top_pkg in top_level.splitlines():
if "Name" in dist.metadata:
return dist.metadata["Name"]
return None
@functools.cache
def _get_core_constraints(core_dist_name: str | None) -> tuple[str, ...]:
try:
resolved_core_dist_name = _resolve_core_dist_name(core_dist_name)
except Exception as exc:
logger.warning("解析核心分发名称失败: %s", exc)
return ()
if not resolved_core_dist_name:
return ()
try:
dist = importlib_metadata.distribution(resolved_core_dist_name)
except importlib_metadata.PackageNotFoundError:
return ()
except Exception as exc:
logger.warning("读取核心分发元数据失败 (%s): %s", resolved_core_dist_name, exc)
return ()
if not dist or not dist.requires:
return ()
installed = collect_installed_distribution_versions(get_requirement_check_paths())
if not installed:
return ()
constraints: list[str] = []
for req_str in dist.requires:
try:
req = Requirement(req_str)
if req.marker and not req.marker.evaluate():
continue
name = canonicalize_distribution_name(req.name)
if name in installed:
constraints.append(f"{name}=={installed[name]}")
except Exception:
continue
return tuple(constraints)
class CoreConstraintsProvider:
def __init__(self, core_dist_name: str | None) -> None:
self._core_dist_name = core_dist_name
@contextlib.contextmanager
def constraints_file(self) -> Iterator[str | None]:
constraints = _get_core_constraints(self._core_dist_name)
if not constraints:
yield None
return
path: str | None = None
try:
import tempfile
with tempfile.NamedTemporaryFile(
mode="w", suffix="_constraints.txt", delete=False, encoding="utf-8"
) as f:
f.write("\n".join(constraints))
path = f.name
logger.info("已启用核心依赖版本保护 (%d 个约束)", len(constraints))
except Exception as exc:
logger.warning("创建临时约束文件失败: %s", exc)
yield None
return
try:
yield path
finally:
if path and os.path.exists(path):
with contextlib.suppress(Exception):
os.remove(path)
+428 -96
View File
@@ -7,21 +7,71 @@ import io
import logging
import os
import re
import shlex
import sys
import threading
from collections import deque
from dataclasses import dataclass
from urllib.parse import urlparse
from astrbot.core.utils.astrbot_path import get_astrbot_site_packages_path
from astrbot.core.utils.core_constraints import CoreConstraintsProvider
from astrbot.core.utils.requirements_utils import (
canonicalize_distribution_name as _canonicalize_distribution_name,
)
from astrbot.core.utils.requirements_utils import (
extract_requirement_name,
extract_requirement_names,
parse_package_install_input,
)
from astrbot.core.utils.runtime_env import is_packaged_desktop_runtime
logger = logging.getLogger("astrbot")
_DISTLIB_FINDER_PATCH_ATTEMPTED = False
_SITE_PACKAGES_IMPORT_LOCK = threading.RLock()
_PIP_FAILURE_PATTERNS = {
"error_prefix": re.compile(r"^\s*error:", re.IGNORECASE),
"user_requested": re.compile(r"\bthe user requested\b", re.IGNORECASE),
"resolution_impossible": re.compile(r"\bresolutionimpossible\b", re.IGNORECASE),
"cannot_install": re.compile(r"\bcannot install\b", re.IGNORECASE),
"conflict": re.compile(r"\bconflict(?:ing|s)?\b", re.IGNORECASE),
"constraint": re.compile(r"\(constraint\)", re.IGNORECASE),
"dependency_detail": re.compile(r"\bdepends on\b", re.IGNORECASE),
}
_SENSITIVE_PIP_VALUE_KEYS = frozenset(
{"password", "passwd", "pass", "api_token", "token", "auth_token"}
)
_MAX_PIP_OUTPUT_LINES = 200
def _canonicalize_distribution_name(name: str) -> str:
return re.sub(r"[-_.]+", "-", name).strip("-").lower()
class DependencyConflictError(Exception):
"""Raised when pip encounters a dependency conflict."""
def __init__(
self, message: str, errors: list[str], *, is_core_conflict: bool
) -> None:
super().__init__(message)
self.errors = errors
self.is_core_conflict = is_core_conflict
class PipInstallError(Exception):
"""Raised when pip install fails without a classified dependency conflict."""
def __init__(self, message: str, *, code: int) -> None:
super().__init__(message)
self.code = code
@dataclass
class PipConflictContext:
relevant_lines: list[str]
requested_lines: list[str]
dependency_detail_lines: list[str]
constraint_lines: list[str]
has_strong_conflict_signal: bool
has_contextual_conflict_signal: bool
def _get_pip_main():
@@ -41,11 +91,12 @@ def _get_pip_main():
return pip_main
def _run_pip_main_with_output(pip_main, args: list[str]) -> tuple[int, str]:
stream = io.StringIO()
with contextlib.redirect_stdout(stream), contextlib.redirect_stderr(stream):
result_code = pip_main(args)
return result_code, stream.getvalue()
def _prepend_sys_path(path: str) -> None:
normalized_target = os.path.realpath(path)
sys.path[:] = [
item for item in sys.path if os.path.realpath(item) != normalized_target
]
sys.path.insert(0, normalized_target)
def _cleanup_added_root_handlers(original_handlers: list[logging.Handler]) -> None:
@@ -59,76 +110,258 @@ def _cleanup_added_root_handlers(original_handlers: list[logging.Handler]) -> No
handler.close()
def _prepend_sys_path(path: str) -> None:
normalized_target = os.path.realpath(path)
sys.path[:] = [
item for item in sys.path if os.path.realpath(item) != normalized_target
]
sys.path.insert(0, normalized_target)
def _get_trusted_host_for_index_url(index_url: str) -> str | None:
parsed = urlparse(index_url if "://" in index_url else f"//{index_url}")
host = parsed.hostname
if host == "mirrors.aliyun.com":
return host
return None
def _module_exists_in_site_packages(module_name: str, site_packages_path: str) -> bool:
base_path = os.path.join(site_packages_path, *module_name.split("."))
package_init = os.path.join(base_path, "__init__.py")
module_file = f"{base_path}.py"
return os.path.isfile(package_init) or os.path.isfile(module_file)
def _normalize_sensitive_pip_key(raw_key: str) -> str:
return raw_key.lstrip("-").replace("-", "_").lower()
def _is_module_loaded_from_site_packages(
module_name: str,
site_packages_path: str,
) -> bool:
module = sys.modules.get(module_name)
if module is None:
try:
module = importlib.import_module(module_name)
except Exception:
return False
def _is_sensitive_pip_value_key(raw_key: str) -> bool:
return _normalize_sensitive_pip_key(raw_key) in _SENSITIVE_PIP_VALUE_KEYS
module_file = getattr(module, "__file__", None)
if not module_file:
return False
module_path = os.path.realpath(module_file)
site_packages_real = os.path.realpath(site_packages_path)
try:
return (
os.path.commonpath([module_path, site_packages_real]) == site_packages_real
def _redact_url_credentials(raw_value: str) -> str:
"""Redact URL credentials and known inline secret values for safe logging."""
parsed = urlparse(raw_value)
if parsed.netloc and "@" in parsed.netloc:
hostname = parsed.hostname or ""
port = f":{parsed.port}" if parsed.port else ""
return parsed._replace(netloc=f"<redacted>@{hostname}{port}").geturl()
if raw_value.startswith("--"):
option, separator, _ = raw_value.partition("=")
if separator and _is_sensitive_pip_value_key(option):
return f"{option}=****"
return raw_value
key, separator, _ = raw_value.partition("=")
if separator and _is_sensitive_pip_value_key(key):
return f"{key}=****"
return raw_value
def _redact_pip_args_for_logging(args: list[str]) -> list[str]:
redacted_args: list[str] = []
redact_next_value = False
for arg in args:
if redact_next_value:
redacted_args.append("****")
redact_next_value = False
continue
if arg.startswith("--") and "=" in arg:
option, value = arg.split("=", 1)
if _is_sensitive_pip_value_key(option):
redacted_args.append(f"{option}=****")
else:
redacted_args.append(f"{option}={_redact_url_credentials(value)}")
continue
if arg.startswith("-i") and arg != "-i":
redacted_args.append(f"-i{_redact_url_credentials(arg[2:])}")
continue
if _is_sensitive_pip_value_key(arg):
redacted_args.append(arg)
redact_next_value = True
continue
redacted_args.append(_redact_url_credentials(arg))
return redacted_args
def _package_specs_override_index(package_specs: list[str]) -> bool:
for index, spec in enumerate(package_specs):
if spec == "--no-index":
return True
if spec in {"-i", "--index-url"}:
if index + 1 < len(package_specs):
return True
continue
if spec.startswith("--index-url="):
return True
if spec.startswith("-i") and spec != "-i":
return True
return False
class _StreamingLogWriter(io.TextIOBase):
def __init__(self, log_func, *, max_lines: int | None = None) -> None:
self._log_func = log_func
self._lines = deque(maxlen=max_lines or _MAX_PIP_OUTPUT_LINES)
self._buffer = ""
def write(self, text: str) -> int:
if not text:
return 0
self._buffer += text.replace("\r\n", "\n").replace("\r", "\n")
while "\n" in self._buffer:
raw_line, self._buffer = self._buffer.split("\n", 1)
line = raw_line.rstrip("\r\n")
self._log_func(line)
self._lines.append(line)
return len(text)
def flush(self) -> None:
line = self._buffer.rstrip("\r\n")
if line:
self._log_func(line)
self._lines.append(line)
self._buffer = ""
@property
def lines(self) -> list[str]:
return list(self._lines)
def _run_pip_main_streaming(pip_main, args: list[str]) -> tuple[int, list[str]]:
stream = _StreamingLogWriter(logger.info, max_lines=_MAX_PIP_OUTPUT_LINES)
with (
contextlib.redirect_stdout(stream),
contextlib.redirect_stderr(stream),
):
result_code = pip_main(args)
stream.flush()
return result_code, stream.lines
def _matches_pip_failure_pattern(line: str, *pattern_names: str) -> bool:
names = pattern_names or tuple(_PIP_FAILURE_PATTERNS)
return any(_PIP_FAILURE_PATTERNS[name].search(line) for name in names)
def _normalize_conflict_detail_line(line: str) -> str:
stripped = line.strip()
if _matches_pip_failure_pattern(stripped, "user_requested"):
return re.sub(
r"^\s*The user requested\s+",
"",
stripped,
flags=re.IGNORECASE,
)
except ValueError:
return False
return stripped
def _extract_requirement_name(raw_requirement: str) -> str | None:
line = raw_requirement.split("#", 1)[0].strip()
if not line:
return None
if line.startswith(("-r", "--requirement", "-c", "--constraint")):
return None
if line.startswith("-"):
def _build_pip_conflict_context(output_lines: list[str]) -> PipConflictContext | None:
matched_indices = [
index
for index, line in enumerate(output_lines)
if _matches_pip_failure_pattern(line)
]
if matched_indices:
relevant_index_set: set[int] = set()
for index in matched_indices:
start = max(0, index - 1)
end = min(len(output_lines), index + 2)
relevant_index_set.update(range(start, end))
relevant_output_lines = [
line
for index, line in enumerate(output_lines)
if index in relevant_index_set
]
else:
relevant_output_lines = output_lines[-5:]
if not relevant_output_lines:
return None
egg_match = re.search(r"#egg=([A-Za-z0-9_.-]+)", raw_requirement)
if egg_match:
return _canonicalize_distribution_name(egg_match.group(1))
dependency_detail_lines = [
line.strip()
for line in relevant_output_lines
if _matches_pip_failure_pattern(line, "dependency_detail")
]
requested_lines = [
line.strip()
for line in relevant_output_lines
if _matches_pip_failure_pattern(line, "user_requested")
and not _matches_pip_failure_pattern(line, "constraint")
]
if not requested_lines:
requested_lines = [
line
for line in dependency_detail_lines
if not _matches_pip_failure_pattern(line, "constraint")
]
constraint_lines = [
line.strip()
for line in relevant_output_lines
if _matches_pip_failure_pattern(line, "constraint")
]
candidate = re.split(r"[<>=!~;\s\[]", line, maxsplit=1)[0].strip()
if not candidate:
has_strong_conflict_signal = any(
_matches_pip_failure_pattern(
line,
"resolution_impossible",
"cannot_install",
)
for line in relevant_output_lines
)
has_contextual_conflict_signal = any(
_matches_pip_failure_pattern(line, "conflict") for line in relevant_output_lines
) and bool(dependency_detail_lines or requested_lines or constraint_lines)
return PipConflictContext(
relevant_lines=relevant_output_lines,
requested_lines=requested_lines,
dependency_detail_lines=dependency_detail_lines,
constraint_lines=constraint_lines,
has_strong_conflict_signal=has_strong_conflict_signal,
has_contextual_conflict_signal=has_contextual_conflict_signal,
)
def _classify_pip_failure(output_lines: list[str]) -> DependencyConflictError | None:
context = _build_pip_conflict_context(output_lines)
if context is None:
return None
return _canonicalize_distribution_name(candidate)
if (
not context.has_strong_conflict_signal
and not context.has_contextual_conflict_signal
and not (context.requested_lines and context.constraint_lines)
):
return None
def _extract_requirement_names(requirements_path: str) -> set[str]:
names: set[str] = set()
try:
with open(requirements_path, encoding="utf-8") as requirements_file:
for line in requirements_file:
requirement_name = _extract_requirement_name(line)
if requirement_name:
names.add(requirement_name)
except Exception as exc:
logger.warning("读取依赖文件失败,跳过冲突检测: %s", exc)
return names
is_core_conflict = bool(context.constraint_lines)
detail = ""
if context.constraint_lines and context.requested_lines:
detail = (
" 冲突详情: "
f"{_normalize_conflict_detail_line(context.requested_lines[0])} vs "
f"{_normalize_conflict_detail_line(context.constraint_lines[0])}"
)
elif len(context.dependency_detail_lines) >= 2:
detail = (
" 冲突详情: "
f"{_normalize_conflict_detail_line(context.dependency_detail_lines[0])} vs "
f"{_normalize_conflict_detail_line(context.dependency_detail_lines[1])}"
)
if is_core_conflict:
message = (
f"检测到核心依赖版本保护冲突。{detail}插件要求的依赖版本与 AstrBot 核心不兼容,"
"为了系统稳定,已阻止该降级行为。请联系插件作者或调整 requirements.txt。"
)
else:
message = f"检测到依赖冲突。{detail}"
return DependencyConflictError(
message,
context.relevant_lines,
is_core_conflict=is_core_conflict,
)
def _extract_top_level_modules(
@@ -155,7 +388,11 @@ def _collect_candidate_modules(
by_name: dict[str, list[importlib_metadata.Distribution]] = {}
try:
for distribution in importlib_metadata.distributions(path=[site_packages_path]):
distribution_name = distribution.metadata.get("Name")
distribution_name = (
distribution.metadata["Name"]
if "Name" in distribution.metadata
else None
)
if not distribution_name:
continue
canonical_name = _canonicalize_distribution_name(distribution_name)
@@ -173,7 +410,7 @@ def _collect_candidate_modules(
for distribution in by_name.get(requirement_name, []):
for dependency_line in distribution.requires or []:
dependency_name = _extract_requirement_name(dependency_line)
dependency_name = extract_requirement_name(dependency_line)
if not dependency_name:
continue
if dependency_name in expanded_requirement_names:
@@ -230,6 +467,38 @@ def _ensure_preferred_modules(
raise RuntimeError(conflict_message)
def _module_exists_in_site_packages(module_name: str, site_packages_path: str) -> bool:
base_path = os.path.join(site_packages_path, *module_name.split("."))
package_init = os.path.join(base_path, "__init__.py")
module_file = f"{base_path}.py"
return os.path.isfile(package_init) or os.path.isfile(module_file)
def _is_module_loaded_from_site_packages(
module_name: str,
site_packages_path: str,
) -> bool:
module = sys.modules.get(module_name)
if module is None:
try:
module = importlib.import_module(module_name)
except Exception:
return False
module_file = getattr(module, "__file__", None)
if not module_file:
return False
module_path = os.path.realpath(module_file)
site_packages_real = os.path.realpath(site_packages_path)
try:
return (
os.path.commonpath([module_path, site_packages_real]) == site_packages_real
)
except ValueError:
return False
def _prefer_module_from_site_packages(
module_name: str, site_packages_path: str
) -> bool:
@@ -531,9 +800,63 @@ def _patch_distlib_finder_for_frozen_runtime() -> None:
class PipInstaller:
def __init__(self, pip_install_arg: str, pypi_index_url: str | None = None) -> None:
def __init__(
self,
pip_install_arg: str,
pypi_index_url: str | None = None,
core_dist_name: str | None = "AstrBot",
) -> None:
self.pip_install_arg = pip_install_arg
self.pypi_index_url = pypi_index_url
self.core_dist_name = core_dist_name
self._core_constraints = CoreConstraintsProvider(core_dist_name)
def _build_pip_args(
self,
package_name: str | None,
requirements_path: str | None,
mirror: str | None,
) -> tuple[list[str], set[str]]:
args: list[str] = []
requested_requirements: set[str] = set()
normalized_requirements_path = (
requirements_path.strip() if requirements_path else ""
)
if package_name and normalized_requirements_path:
raise ValueError(
"package_name and requirements_path cannot be used together"
)
if package_name:
parsed_package = parse_package_install_input(package_name)
if parsed_package.specs:
args = ["install", *parsed_package.specs]
requested_requirements = set(parsed_package.requirement_names)
elif normalized_requirements_path:
args = ["install", "-r", normalized_requirements_path]
requested_requirements = extract_requirement_names(
normalized_requirements_path
)
if not args:
return [], requested_requirements
pip_install_args = (
shlex.split(self.pip_install_arg) if self.pip_install_arg else []
)
if not _package_specs_override_index([*args[1:], *pip_install_args]):
index_url = mirror or self.pypi_index_url or "https://pypi.org/simple"
trusted_host = _get_trusted_host_for_index_url(index_url)
if trusted_host:
args.extend(["--trusted-host", trusted_host])
args.extend(["-i", index_url])
if pip_install_args:
args.extend(pip_install_args)
return args, requested_requirements
async def install(
self,
@@ -541,36 +864,37 @@ class PipInstaller:
requirements_path: str | None = None,
mirror: str | None = None,
) -> None:
args = ["install"]
requested_requirements: set[str] = set()
if package_name:
args.append(package_name)
requirement_name = _extract_requirement_name(package_name)
if requirement_name:
requested_requirements.add(requirement_name)
elif requirements_path:
args.extend(["-r", requirements_path])
requested_requirements = _extract_requirement_names(requirements_path)
index_url = mirror or self.pypi_index_url or "https://pypi.org/simple"
args.extend(["--trusted-host", "mirrors.aliyun.com", "-i", index_url])
args, requested_requirements = self._build_pip_args(
package_name, requirements_path, mirror
)
if not args:
logger.info("Pip 包管理器跳过安装:未提供有效的包名或 requirements 文件。")
return
target_site_packages = None
if is_packaged_desktop_runtime():
target_site_packages = get_astrbot_site_packages_path()
os.makedirs(target_site_packages, exist_ok=True)
_prepend_sys_path(target_site_packages)
args.extend(["--target", target_site_packages])
args.extend(["--upgrade", "--force-reinstall"])
args.extend(
[
"--target",
target_site_packages,
"--upgrade",
"--upgrade-strategy",
"only-if-needed",
]
)
if self.pip_install_arg:
args.extend(self.pip_install_arg.split())
with self._core_constraints.constraints_file() as constraints_file_path:
if constraints_file_path:
args.extend(["-c", constraints_file_path])
logger.info(f"Pip 包管理器: pip {' '.join(args)}")
result_code = await self._run_pip_in_process(args)
if result_code != 0:
raise Exception(f"安装失败,错误码:{result_code}")
logger.info(
"Pip 包管理器 argv: %s",
["pip", *_redact_pip_args_for_logging(args)],
)
await self._run_pip_with_classification(args)
if target_site_packages:
_prepend_sys_path(target_site_packages)
@@ -589,7 +913,7 @@ class PipInstaller:
if not os.path.isdir(target_site_packages):
return
requested_requirements = _extract_requirement_names(requirements_path)
requested_requirements = extract_requirement_names(requirements_path)
if not requested_requirements:
return
@@ -605,13 +929,21 @@ class PipInstaller:
_patch_distlib_finder_for_frozen_runtime()
original_handlers = list(logging.getLogger().handlers)
result_code, output = await asyncio.to_thread(
_run_pip_main_with_output, pip_main, args
)
for line in output.splitlines():
line = line.strip()
if line:
logger.info(line)
try:
result_code, output_lines = await asyncio.to_thread(
_run_pip_main_streaming, pip_main, args
)
finally:
_cleanup_added_root_handlers(original_handlers)
if result_code != 0:
conflict = _classify_pip_failure(output_lines)
if conflict:
raise conflict
_cleanup_added_root_handlers(original_handlers)
return result_code
async def _run_pip_with_classification(self, args: list[str]) -> None:
result_code = await self._run_pip_in_process(args)
if result_code != 0:
raise PipInstallError(f"安装失败,错误码:{result_code}", code=result_code)
+408
View File
@@ -0,0 +1,408 @@
import importlib.metadata as importlib_metadata
import logging
import os
import re
import shlex
import sys
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from packaging.requirements import InvalidRequirement, Requirement
from packaging.specifiers import SpecifierSet
from packaging.version import InvalidVersion, Version
from astrbot.core.utils.astrbot_path import get_astrbot_site_packages_path
from astrbot.core.utils.runtime_env import is_packaged_desktop_runtime
logger = logging.getLogger("astrbot")
class RequirementsPrecheckFailed(Exception):
"""Raised when the pre-check of requirements fails."""
pass
@dataclass(frozen=True)
class ParsedPackageInput:
specs: tuple[str, ...]
requirement_names: frozenset[str]
def canonicalize_distribution_name(name: str) -> str:
return re.sub(r"[-_.]+", "-", name).strip("-").lower()
def strip_inline_requirement_comment(raw_input: str) -> str:
if raw_input.lstrip().startswith("#"):
return ""
return re.split(r"[ \t]+#", raw_input, maxsplit=1)[0].strip()
def _specifier_contains_version(specifier: SpecifierSet, version: str) -> bool:
try:
parsed_version = Version(version)
except InvalidVersion:
return False
return specifier.contains(parsed_version, prereleases=True)
def _looks_like_local_path_reference(token: str) -> bool:
candidate = token.strip()
if not candidate:
return False
return candidate in {".", ".."} or candidate.startswith(
("./", "../", "/", "~/", ".\\", "..\\", "\\")
)
def looks_like_direct_reference(token: str) -> bool:
candidate = token.strip()
if not candidate:
return False
return (
_looks_like_local_path_reference(candidate)
or candidate.startswith("git+")
or "://" in candidate
)
def extract_requirement_name(raw_requirement: str) -> str | None:
line = raw_requirement.split("#", 1)[0].strip()
if not line:
return None
if line.startswith(("-r", "--requirement", "-c", "--constraint")):
return None
egg_match = re.search(r"#egg=([A-Za-z0-9_.-]+)", raw_requirement)
if egg_match:
return canonicalize_distribution_name(egg_match.group(1))
if line.startswith("-"):
return None
candidate = re.split(r"[<>=!~;\s\[]", line, maxsplit=1)[0].strip()
if not candidate:
return None
return canonicalize_distribution_name(candidate)
def _parse_editable_or_direct_name(target: str) -> str | None:
name = extract_requirement_name(target)
if not name:
return None
if "#egg=" in target or not looks_like_direct_reference(target):
return name
return None
def _parse_requirement_name_and_spec(
line: str,
) -> tuple[str | None, SpecifierSet | None]:
if line.startswith(("-c", "--constraint")):
return None, None
try:
req = Requirement(line)
except InvalidRequirement:
tokens = shlex.split(line)
if not tokens:
return None, None
editable_target: str | None = None
if tokens[0] in {"-e", "--editable"} and len(tokens) > 1:
editable_target = tokens[1]
elif tokens[0].startswith("--editable="):
editable_target = tokens[0].split("=", 1)[1]
if editable_target:
name = _parse_editable_or_direct_name(editable_target)
return (name, None) if name else (None, None)
name = _parse_editable_or_direct_name(line)
return (name, None) if name else (None, None)
if req.marker and not req.marker.evaluate():
return None, None
return canonicalize_distribution_name(req.name), (req.specifier or None)
def _parse_requirement_line(
line: str,
) -> tuple[str, SpecifierSet | None] | None:
name, specifier = _parse_requirement_name_and_spec(line)
return (name, specifier) if name else None
def _extract_requirement_names_from_package_tokens(tokens: list[str]) -> frozenset[str]:
requirement_names: set[str] = set()
skip_next_for: str | None = None
for token in tokens:
if skip_next_for:
if skip_next_for == "editable":
name = _parse_editable_or_direct_name(token)
if name:
requirement_names.add(name)
skip_next_for = None
continue
if token in {"-e", "--editable"}:
skip_next_for = "editable"
continue
if token in {
"-i",
"--index-url",
"--extra-index-url",
"-f",
"--find-links",
"--trusted-host",
"-r",
"--requirement",
"-c",
"--constraint",
}:
skip_next_for = "option-value"
continue
if token.startswith(("--editable=",)):
editable_target = token.split("=", 1)[1]
name = _parse_editable_or_direct_name(editable_target)
if name:
requirement_names.add(name)
continue
if token.startswith(
(
"--index-url=",
"--extra-index-url=",
"--find-links=",
"--trusted-host=",
"--requirement=",
"--constraint=",
)
):
continue
if (
(token.startswith("-i") and token != "-i")
or (token.startswith("-f") and token != "-f")
or token == "--no-index"
):
continue
if token.startswith("-"):
continue
name, _ = _parse_requirement_name_and_spec(token)
if name:
requirement_names.add(name)
return frozenset(requirement_names)
def parse_package_install_input(raw_input: str) -> ParsedPackageInput:
specs: list[str] = []
requirement_names: set[str] = set()
normalized = raw_input.strip()
if not normalized:
return ParsedPackageInput(specs=(), requirement_names=frozenset())
for raw_line in normalized.splitlines():
line = strip_inline_requirement_comment(raw_line)
if not line:
continue
try:
Requirement(line)
except InvalidRequirement:
tokens = shlex.split(line)
if not tokens:
continue
specs.extend(tokens)
requirement_names.update(
_extract_requirement_names_from_package_tokens(tokens)
)
continue
specs.append(line)
name, _ = _parse_requirement_name_and_spec(line)
if name:
requirement_names.add(name)
return ParsedPackageInput(
specs=tuple(specs),
requirement_names=frozenset(requirement_names),
)
def _iter_requirement_lines(
requirements_path: str,
_visited: set[str] | None = None,
) -> Iterator[str]:
visited = _visited or set()
resolved_path = os.path.realpath(requirements_path)
if resolved_path in visited:
logger.warning(
"检测到循环依赖的 requirements 包含: %s,将跳过该文件", resolved_path
)
return
visited.add(resolved_path)
with open(resolved_path, encoding="utf-8") as f:
for raw_line in f:
line = strip_inline_requirement_comment(raw_line)
if not line:
continue
tokens = shlex.split(line)
if not tokens:
continue
nested: str | None = None
if tokens[0] in {"-r", "--requirement"} and len(tokens) > 1:
nested = tokens[1]
elif tokens[0].startswith("--requirement="):
nested = tokens[0].split("=", 1)[1]
if nested:
if not os.path.isabs(nested):
nested = os.path.join(os.path.dirname(resolved_path), nested)
yield from _iter_requirement_lines(nested, _visited=visited)
continue
yield line
def iter_requirements(
requirements_path: str | None = None,
lines: Iterable[str] | None = None,
) -> Iterator[tuple[str, SpecifierSet | None]]:
if lines is None:
if requirements_path is None:
raise ValueError("Either requirements_path or lines must be provided")
lines = _iter_requirement_lines(requirements_path)
for line in lines:
parsed = _parse_requirement_line(line)
if parsed is not None:
yield parsed
def extract_requirement_names(requirements_path: str) -> set[str]:
try:
return {
name for name, _ in iter_requirements(requirements_path=requirements_path)
}
except Exception as exc:
logger.warning("读取依赖文件失败,跳过冲突检测: %s", exc)
return set()
def get_requirement_check_paths() -> list[str]:
paths = list(sys.path)
if is_packaged_desktop_runtime():
target_site_packages = get_astrbot_site_packages_path()
if os.path.isdir(target_site_packages):
paths.insert(0, target_site_packages)
return paths
def _canonical_distribution_identity(distribution) -> tuple[str | None, str | None]:
distribution_name = (
distribution.metadata["Name"] if "Name" in distribution.metadata else None
)
if not distribution_name:
return None, None
return canonicalize_distribution_name(distribution_name), distribution.version
def collect_installed_distribution_versions(paths: list[str]) -> dict[str, str] | None:
installed: dict[str, str] = {}
try:
for distribution in importlib_metadata.distributions(path=paths):
distribution_name, version = _canonical_distribution_identity(distribution)
if not distribution_name or not version:
continue
installed.setdefault(distribution_name, version)
except Exception as exc:
logger.warning("读取已安装依赖失败,跳过缺失依赖预检查: %s", exc)
return None
return installed
def _load_requirement_lines_for_precheck(
requirements_path: str,
) -> tuple[bool, list[str] | None]:
try:
requirement_lines = list(_iter_requirement_lines(requirements_path))
except Exception as exc:
logger.warning(
"预检查缺失依赖失败,将回退到完整安装: %s (%s)",
requirements_path,
exc,
)
return False, None
fallback_line = next(
(
line
for line in requirement_lines
if (
(
line.startswith(("-e ", "--editable ", "--editable="))
and "#egg=" not in line
)
or (
_parse_requirement_line(line) is None
and looks_like_direct_reference(line)
)
)
),
None,
)
if fallback_line is not None:
logger.warning(
"预检查缺失依赖失败,将回退到完整安装: unresolved direct reference in %s: %s",
requirements_path,
fallback_line,
)
return False, None
return True, requirement_lines
def find_missing_requirements(requirements_path: str) -> set[str] | None:
can_precheck, requirement_lines = _load_requirement_lines_for_precheck(
requirements_path
)
if not can_precheck or requirement_lines is None:
return None
required = list(iter_requirements(lines=requirement_lines))
if not required:
return set()
installed = collect_installed_distribution_versions(get_requirement_check_paths())
if installed is None:
return None
missing: set[str] = set()
for name, specifier in required:
installed_version = installed.get(name)
if not installed_version:
missing.add(name)
continue
if specifier and not _specifier_contains_version(specifier, installed_version):
missing.add(name)
return missing
def find_missing_requirements_or_raise(requirements_path: str) -> set[str]:
missing = find_missing_requirements(requirements_path)
if missing is None:
raise RequirementsPrecheckFailed(f"预检查失败: {requirements_path}")
return missing
+153 -1
View File
@@ -43,6 +43,7 @@ class SessionManagementRoute(Route):
"/session/group/create": ("POST", self.create_group),
"/session/group/update": ("POST", self.update_group),
"/session/group/delete": ("POST", self.delete_group),
"/session/group/update-config": ("POST", self.update_group_config),
}
self.conv_mgr = core_lifecycle.conversation_manager
self.core_lifecycle = core_lifecycle
@@ -145,9 +146,20 @@ class SessionManagementRoute(Route):
page=page, page_size=page_size, search=search
)
# 构建规则列表
# 收集属于有配置分组的 UMO,避免重复显示
grouped_umos = set()
groups = self._get_groups()
for group_data in groups.values():
if group_data.get("config"):
grouped_umos.update(group_data.get("umos", []))
# 构建规则列表(排除已被分组管理的 UMO)
rules_list = []
filtered_count = 0
for umo, rules in umo_rules.items():
if umo in grouped_umos:
filtered_count += 1
continue
rule_info = {
"umo": umo,
"rules": rules,
@@ -159,6 +171,7 @@ class SessionManagementRoute(Route):
rule_info["message_type"] = parts[1]
rule_info["session_id"] = parts[2]
rules_list.append(rule_info)
total -= filtered_count
# 获取可用的 providers 和 personas
provider_manager = self.core_lifecycle.provider_manager
@@ -240,6 +253,7 @@ class SessionManagementRoute(Route):
"available_plugins": available_plugins,
"available_kbs": available_kbs,
"available_rule_keys": AVAILABLE_SESSION_RULE_KEYS,
"group_rules": self._get_group_rules(),
}
)
.__dict__
@@ -793,6 +807,51 @@ class SessionManagementRoute(Route):
"""保存分组"""
sp.put("session_groups", groups)
def _get_group_rules(self) -> list:
"""获取有配置的分组列表,用于在规则列表中显示"""
groups = self._get_groups()
group_rules = []
for group_id, group_data in groups.items():
config = group_data.get("config", {})
if config: # 只返回有配置的分组
group_rules.append(
{
"group_id": group_id,
"name": group_data.get("name", ""),
"umo_count": len(group_data.get("umos", [])),
"config": config,
}
)
return group_rules
async def _sync_group_config_to_umos(
self, config: dict, umos: list[str]
) -> tuple[int, list[str]]:
"""将分组配置同步到指定的 UMO 列表
Returns:
(success_count, failed_umos)
"""
success_count = 0
failed_umos = []
for umo in umos:
try:
for rule_key, rule_value in config.items():
if rule_key not in AVAILABLE_SESSION_RULE_KEYS:
continue
if rule_value is None:
continue
if rule_key == "session_plugin_config":
# session_plugin_config 需要包裹 umo key
await sp.session_put(umo, rule_key, {umo: rule_value})
else:
await sp.session_put(umo, rule_key, rule_value)
success_count += 1
except Exception as e:
logger.error(f"同步配置到 {umo} 失败: {e!s}")
failed_umos.append(umo)
return success_count, failed_umos
async def list_groups(self):
"""获取所有分组列表"""
try:
@@ -806,6 +865,7 @@ class SessionManagementRoute(Route):
"name": group_data.get("name", ""),
"umos": group_data.get("umos", []),
"umo_count": len(group_data.get("umos", [])),
"config": group_data.get("config", {}),
}
)
return Response().ok({"groups": groups_list}).__dict__
@@ -875,6 +935,7 @@ class SessionManagementRoute(Route):
return Response().error(f"分组 '{group_id}' 不存在").__dict__
group = groups[group_id]
old_umos = set(group.get("umos", []))
# 更新名称
if name is not None:
@@ -883,6 +944,7 @@ class SessionManagementRoute(Route):
# 直接设置 umos 列表
if umos is not None:
group["umos"] = umos
new_umos = set(umos)
else:
# 增量更新
current_umos = set(group.get("umos", []))
@@ -891,9 +953,21 @@ class SessionManagementRoute(Route):
if remove_umos:
current_umos.difference_update(remove_umos)
group["umos"] = list(current_umos)
new_umos = current_umos
self._save_groups(groups)
# 自动同步分组配置给新加入的成员
group_config = group.get("config", {})
newly_added = new_umos - old_umos
if group_config and newly_added:
sync_count, _ = await self._sync_group_config_to_umos(
group_config, list(newly_added)
)
logger.info(
f"自动同步分组 '{group['name']}' 配置到 {sync_count} 个新成员"
)
return (
Response()
.ok(
@@ -936,3 +1010,81 @@ class SessionManagementRoute(Route):
except Exception as e:
logger.error(f"删除分组失败: {e!s}")
return Response().error(f"删除分组失败: {e!s}").__dict__
async def update_group_config(self):
"""更新分组的配置,并同步到所有成员 UMO
请求体:
{
"group_id": "分组ID",
"config": {
"session_service_config": {...},
"session_plugin_config": {...},
"kb_config": {...},
"provider_perf_chat_completion": ...,
"provider_perf_speech_to_text": ...,
"provider_perf_text_to_speech": ...
}
}
"""
try:
data = await request.get_json()
group_id = data.get("group_id")
config = data.get("config", {})
if not group_id:
return Response().error("缺少必要参数: group_id").__dict__
groups = self._get_groups()
if group_id not in groups:
return Response().error(f"分组 '{group_id}' 不存在").__dict__
group = groups[group_id]
# 保存配置到分组
group["config"] = config
self._save_groups(groups)
# 同步到所有成员 UMO
umos = group.get("umos", [])
if not config:
# 空配置 → 清除成员上的所有分组下发规则
success_count = 0
failed_umos = []
for umo in umos:
try:
for rule_key in AVAILABLE_SESSION_RULE_KEYS:
try:
await sp.session_remove(umo, rule_key)
except Exception:
pass
success_count += 1
except Exception as e:
logger.error(f"清除 {umo} 规则失败: {e!s}")
failed_umos.append(umo)
else:
success_count, failed_umos = await self._sync_group_config_to_umos(
config, umos
)
msg = f"分组 '{group['name']}' 配置已保存并同步到 {success_count}/{len(umos)} 个会话"
if failed_umos:
msg += f"{len(failed_umos)} 个失败"
return (
Response()
.ok(
{
"message": msg,
"success_count": success_count,
"failed_count": len(failed_umos),
"failed_umos": failed_umos,
}
)
.__dict__
)
except Exception as e:
logger.error(f"更新分组配置失败: {e!s}")
return Response().error(f"更新分组配置失败: {e!s}").__dict__
+273 -51
View File
@@ -1,4 +1,4 @@
<template>
<template>
<div class="session-management-page">
<v-container fluid class="pa-0">
<v-card flat>
@@ -35,7 +35,16 @@
<!-- UMO 信息 -->
<template v-slot:item.umo_info="{ item }">
<div>
<div class="d-flex align-center">
<div class="d-flex align-center" v-if="item.isGroup">
<v-chip size="x-small" color="deep-purple" variant="flat" class="mr-2">
分组
</v-chip>
<span class="font-weight-medium">{{ item.groupName }}</span>
<v-chip size="x-small" variant="outlined" class="ml-2">
{{ item.umo_count }} 个会话
</v-chip>
</div>
<div class="d-flex align-center" v-else>
<v-chip size="x-small" :color="getPlatformColor(item.platform)" class="mr-2">
{{ item.platform || 'unknown' }}
</v-chip>
@@ -282,14 +291,24 @@
{{ tm('addRule.description') }}
</v-alert>
<v-autocomplete v-model="selectedNewUmo" :items="availableUmos" :loading="loadingUmos"
<v-radio-group v-model="addRuleTargetType" inline hide-details class="mb-4">
<v-radio label="单个会话" value="session"></v-radio>
<v-radio label="分组" value="group" :disabled="groups.length === 0"></v-radio>
</v-radio-group>
<v-autocomplete v-if="addRuleTargetType === 'session'" v-model="selectedNewUmo" :items="availableUmos" :loading="loadingUmos"
:label="tm('addRule.selectUmo')" variant="outlined" clearable :no-data-text="tm('addRule.noUmos')" />
<v-select v-if="addRuleTargetType === 'group'" v-model="selectedGroup" :items="groupSelectOptions"
item-title="label" item-value="value" return-object
label="选择分组" variant="outlined" clearable
:no-data-text="'暂无分组,请先创建分组'" />
</v-card-text>
<v-card-actions class="px-4 pb-4">
<v-spacer></v-spacer>
<v-btn variant="text" @click="addRuleDialog = false">{{ tm('buttons.cancel') }}</v-btn>
<v-btn color="primary" variant="tonal" @click="createNewRule" :disabled="!selectedNewUmo">
<v-btn color="primary" variant="tonal" @click="createNewRule" :disabled="addRuleTargetType === 'session' ? !selectedNewUmo : !selectedGroup">
{{ tm('buttons.next') }}
</v-btn>
</v-card-actions>
@@ -334,12 +353,7 @@
</v-col>
</v-row>
<div class="d-flex justify-end mt-4">
<v-btn color="primary" variant="tonal" size="small" @click="saveServiceConfig" :loading="saving"
prepend-icon="mdi-content-save">
{{ tm('buttons.save') }}
</v-btn>
</div>
<!-- Provider Config Section -->
<div class="d-flex align-center mb-4 mt-4">
@@ -364,12 +378,7 @@
</v-col>
</v-row>
<div class="d-flex justify-end mt-4">
<v-btn color="primary" variant="tonal" size="small" @click="saveProviderConfig" :loading="saving"
prepend-icon="mdi-content-save">
{{ tm('buttons.save') }}
</v-btn>
</div>
<!-- Persona Config Section -->
<div class="d-flex align-center mb-4 mt-4">
@@ -389,12 +398,7 @@
</v-col>
</v-row>
<div class="d-flex justify-end mt-4">
<v-btn color="primary" variant="tonal" size="small" @click="saveServiceConfig" :loading="saving"
prepend-icon="mdi-content-save">
{{ tm('buttons.save') }}
</v-btn>
</div>
<!-- Plugin Config Section -->
<div class="d-flex align-center mb-4 mt-4">
@@ -414,12 +418,7 @@
</v-col>
</v-row>
<div class="d-flex justify-end mt-4">
<v-btn color="primary" variant="tonal" size="small" @click="savePluginConfig" :loading="saving"
prepend-icon="mdi-content-save">
{{ tm('buttons.save') }}
</v-btn>
</div>
<!-- KB Config Section -->
<div class="d-flex align-center mb-4 mt-4">
@@ -442,14 +441,17 @@
</v-col>
</v-row>
<div class="d-flex justify-end mt-4">
<v-btn color="primary" variant="tonal" size="small" @click="saveKbConfig" :loading="saving"
prepend-icon="mdi-content-save">
{{ tm('buttons.save') }}
</v-btn>
</div>
</div>
</v-card-text>
<v-card-actions class="px-6 pb-4">
<v-spacer></v-spacer>
<v-btn variant="text" @click="closeRuleEditor">{{ tm('buttons.cancel') }}</v-btn>
<v-btn color="primary" variant="tonal" @click="saveAllConfigs" :loading="saving"
prepend-icon="mdi-content-save">
{{ tm('buttons.save') }}
</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
@@ -567,6 +569,8 @@ export default {
addRuleDialog: false,
availableUmos: [],
selectedNewUmo: null,
addRuleTargetType: 'session',
selectedGroup: null,
// 规则编辑
ruleDialog: false,
@@ -729,6 +733,13 @@ export default {
return options
},
groupSelectOptions() {
return this.groups.map(g => ({
label: `${g.name} (${g.umo_count} 个会话)`,
value: g,
}))
},
groupOptions() {
return this.groups.map(g => ({
label: `${g.name} (${g.umo_count} 个会话)`,
@@ -811,7 +822,7 @@ export default {
})
if (response.data.status === 'ok') {
const data = response.data.data
this.rulesList = data.rules
this.rulesList = data.rules || []
this.totalItems = data.total
this.availablePersonas = data.available_personas
this.availableChatProviders = data.available_chat_providers
@@ -819,6 +830,20 @@ export default {
this.availableTtsProviders = data.available_tts_providers
this.availablePlugins = data.available_plugins || []
this.availableKbs = data.available_kbs || []
// 合并分组规则到列表中
const groupRules = data.group_rules || []
for (const gr of groupRules) {
this.rulesList.unshift({
umo: `[\u5206\u7ec4] ${gr.name}`,
isGroup: true,
groupId: gr.group_id,
groupName: gr.name,
umo_count: gr.umo_count,
rules: gr.config || {},
})
}
this.totalItems += groupRules.length
} else {
this.showError(response.data.message || this.tm('messages.loadError'))
}
@@ -872,10 +897,89 @@ export default {
async openAddRuleDialog() {
this.addRuleDialog = true
this.selectedNewUmo = null
this.addRuleTargetType = 'session'
this.selectedGroup = null
await this.loadUmos()
},
async saveAllConfigs() {
if (!this.selectedUmo) return
// 分组模式:调用分组配置 API
if (this.selectedUmo.isGroup) {
this.saving = true
try {
const config = {
session_service_config: { ...this.serviceConfig },
provider_perf_chat_completion: this.providerConfig.chat_completion || null,
provider_perf_speech_to_text: this.providerConfig.speech_to_text || null,
provider_perf_text_to_speech: this.providerConfig.text_to_speech || null,
session_plugin_config: { ...this.pluginConfig },
kb_config: { ...this.kbConfig },
}
// 清理空值
if (!config.session_service_config.custom_name) delete config.session_service_config.custom_name
if (config.session_service_config.persona_id === null) delete config.session_service_config.persona_id
const response = await axios.post('/api/session/group/update-config', {
group_id: this.selectedUmo.groupId,
config: config
})
if (response.data.status === 'ok') {
this.showSuccess(response.data.data?.message || '分组配置已保存并同步')
await this.loadData()
} else {
this.showError(response.data.message || this.tm('messages.saveError'))
}
} catch (error) {
this.showError(error.response?.data?.message || this.tm('messages.saveError'))
} finally {
this.saving = false
}
return
}
// 单个会话模式
this.saving = true
this._batchSaving = true
try {
await this.saveServiceConfig()
await this.saveProviderConfig()
await this.savePluginConfig()
await this.saveKbConfig()
this.showSuccess(this.tm('messages.saveSuccess'))
} catch (error) {
this.showError(error.response?.data?.message || this.tm('messages.saveError'))
} finally {
this._batchSaving = false
this.saving = false
}
},
createNewRule() {
if (this.addRuleTargetType === 'group') {
// 分组模式
if (!this.selectedGroup) return
const group = this.selectedGroup.value || this.selectedGroup
if (!group.umos || group.umos.length === 0) {
this.showError('该分组没有成员会话')
return
}
// 创建一个特殊的规则项,标记为分组
const newItem = {
umo: `[分组] ${group.name}`,
isGroup: true,
groupId: group.id,
groupName: group.name,
groupUmos: group.umos,
rules: {},
}
this.addRuleDialog = false
this.openRuleEditor(newItem)
return
}
// 单个会话模式(原逻辑)
if (!this.selectedNewUmo) return
// 创建一个新的规则项并打开编辑器
@@ -943,13 +1047,37 @@ export default {
async saveServiceConfig() {
if (!this.selectedUmo) return
this.saving = true
if (!this._batchSaving) this.saving = true
try {
const config = { ...this.serviceConfig }
// 清理空值
if (!config.custom_name) delete config.custom_name
if (config.persona_id === null) delete config.persona_id
// 分组模式:批量下发给所有成员
if (this.selectedUmo.isGroup) {
const umos = this.selectedUmo.groupUmos
let successCount = 0
for (const umo of umos) {
try {
await axios.post('/api/session/update-rule', {
umo: umo,
rule_key: 'session_service_config',
rule_value: config
})
successCount++
} catch (e) {
console.error(`更新 ${umo} 失败:`, e)
}
}
if (!this._batchSaving) {
this.showSuccess(`已更新 ${successCount}/${umos.length} 个会话的服务配置`)
await this.loadData()
this.saving = false
}
return
}
const response = await axios.post('/api/session/update-rule', {
umo: this.selectedUmo.umo,
rule_key: 'session_service_config',
@@ -957,7 +1085,7 @@ export default {
})
if (response.data.status === 'ok') {
this.showSuccess(this.tm('messages.saveSuccess'))
if (!this._batchSaving) this.showSuccess(this.tm('messages.saveSuccess'))
this.editingRules.session_service_config = config
// 更新或添加到列表
@@ -980,17 +1108,45 @@ export default {
} catch (error) {
this.showError(error.response?.data?.message || this.tm('messages.saveError'))
}
this.saving = false
if (!this._batchSaving) this.saving = false
},
async saveProviderConfig() {
if (!this.selectedUmo) return
this.saving = true
if (!this._batchSaving) this.saving = true
try {
const providerTypes = ['chat_completion', 'speech_to_text', 'text_to_speech']
// 分组模式:批量下发给所有成员
if (this.selectedUmo.isGroup) {
const umos = this.selectedUmo.groupUmos
let successCount = 0
for (const umo of umos) {
try {
const tasks = []
for (const type of providerTypes) {
const value = this.providerConfig[type]
if (value) {
tasks.push(axios.post('/api/session/update-rule', { umo, rule_key: `provider_perf_${type}`, rule_value: value }))
}
}
if (tasks.length > 0) await Promise.all(tasks)
successCount++
} catch (e) {
console.error(`更新 ${umo} Provider 失败:`, e)
}
}
if (!this._batchSaving) {
this.showSuccess(`已更新 ${successCount}/${umos.length} 个会话的 Provider 配置`)
await this.loadData()
this.saving = false
}
return
}
const updateTasks = []
const deleteTasks = []
const providerTypes = ['chat_completion', 'speech_to_text', 'text_to_speech']
for (const type of providerTypes) {
const value = this.providerConfig[type]
@@ -1017,7 +1173,7 @@ export default {
const allTasks = [...updateTasks, ...deleteTasks]
if (allTasks.length > 0) {
await Promise.all(allTasks)
this.showSuccess(this.tm('messages.saveSuccess'))
if (!this._batchSaving) this.showSuccess(this.tm('messages.saveSuccess'))
// 更新或添加到列表
let item = this.rulesList.find(u => u.umo === this.selectedUmo.umo)
@@ -1042,24 +1198,48 @@ export default {
}
}
} else {
this.showSuccess(this.tm('messages.noChanges'))
if (!this._batchSaving) this.showSuccess(this.tm('messages.noChanges'))
}
} catch (error) {
this.showError(error.response?.data?.message || this.tm('messages.saveError'))
}
this.saving = false
if (!this._batchSaving) this.saving = false
},
async savePluginConfig() {
if (!this.selectedUmo) return
this.saving = true
if (!this._batchSaving) this.saving = true
try {
const config = {
enabled_plugins: this.pluginConfig.enabled_plugins,
disabled_plugins: this.pluginConfig.disabled_plugins,
}
// 分组模式:批量下发给所有成员
if (this.selectedUmo.isGroup) {
const umos = this.selectedUmo.groupUmos
let successCount = 0
for (const umo of umos) {
try {
if (config.enabled_plugins.length === 0 && config.disabled_plugins.length === 0) {
await axios.post('/api/session/delete-rule', { umo, rule_key: 'session_plugin_config' })
} else {
await axios.post('/api/session/update-rule', { umo, rule_key: 'session_plugin_config', rule_value: config })
}
successCount++
} catch (e) {
console.error(`更新 ${umo} 插件配置失败:`, e)
}
}
if (!this._batchSaving) {
this.showSuccess(`已更新 ${successCount}/${umos.length} 个会话的插件配置`)
await this.loadData()
this.saving = false
}
return
}
// 如果两个列表都为空,删除配置
if (config.enabled_plugins.length === 0 && config.disabled_plugins.length === 0) {
if (this.editingRules.session_plugin_config) {
@@ -1071,7 +1251,7 @@ export default {
let item = this.rulesList.find(u => u.umo === this.selectedUmo.umo)
if (item) delete item.rules.session_plugin_config
}
this.showSuccess(this.tm('messages.saveSuccess'))
if (!this._batchSaving) this.showSuccess(this.tm('messages.saveSuccess'))
} else {
const response = await axios.post('/api/session/update-rule', {
umo: this.selectedUmo.umo,
@@ -1080,7 +1260,7 @@ export default {
})
if (response.data.status === 'ok') {
this.showSuccess(this.tm('messages.saveSuccess'))
if (!this._batchSaving) this.showSuccess(this.tm('messages.saveSuccess'))
this.editingRules.session_plugin_config = config
let item = this.rulesList.find(u => u.umo === this.selectedUmo.umo)
@@ -1102,13 +1282,13 @@ export default {
} catch (error) {
this.showError(error.response?.data?.message || this.tm('messages.saveError'))
}
this.saving = false
if (!this._batchSaving) this.saving = false
},
async saveKbConfig() {
if (!this.selectedUmo) return
this.saving = true
if (!this._batchSaving) this.saving = true
try {
const config = {
kb_ids: this.kbConfig.kb_ids,
@@ -1116,6 +1296,30 @@ export default {
enable_rerank: this.kbConfig.enable_rerank,
}
// 分组模式:批量下发给所有成员
if (this.selectedUmo.isGroup) {
const umos = this.selectedUmo.groupUmos
let successCount = 0
for (const umo of umos) {
try {
if (config.kb_ids.length === 0) {
await axios.post('/api/session/delete-rule', { umo, rule_key: 'kb_config' })
} else {
await axios.post('/api/session/update-rule', { umo, rule_key: 'kb_config', rule_value: config })
}
successCount++
} catch (e) {
console.error(`更新 ${umo} 知识库配置失败:`, e)
}
}
if (!this._batchSaving) {
this.showSuccess(`已更新 ${successCount}/${umos.length} 个会话的知识库配置`)
await this.loadData()
this.saving = false
}
return
}
// 如果 kb_ids 为空,删除配置
if (config.kb_ids.length === 0) {
if (this.editingRules.kb_config) {
@@ -1127,7 +1331,7 @@ export default {
let item = this.rulesList.find(u => u.umo === this.selectedUmo.umo)
if (item) delete item.rules.kb_config
}
this.showSuccess(this.tm('messages.saveSuccess'))
if (!this._batchSaving) this.showSuccess(this.tm('messages.saveSuccess'))
} else {
const response = await axios.post('/api/session/update-rule', {
umo: this.selectedUmo.umo,
@@ -1136,7 +1340,7 @@ export default {
})
if (response.data.status === 'ok') {
this.showSuccess(this.tm('messages.saveSuccess'))
if (!this._batchSaving) this.showSuccess(this.tm('messages.saveSuccess'))
this.editingRules.kb_config = config
let item = this.rulesList.find(u => u.umo === this.selectedUmo.umo)
@@ -1158,7 +1362,7 @@ export default {
} catch (error) {
this.showError(error.response?.data?.message || this.tm('messages.saveError'))
}
this.saving = false
if (!this._batchSaving) this.saving = false
},
confirmDeleteRules(item) {
@@ -1171,6 +1375,24 @@ export default {
this.deleting = true
try {
// 分组规则:清空分组配置
if (this.deleteTarget.isGroup) {
const response = await axios.post('/api/session/group/update-config', {
group_id: this.deleteTarget.groupId,
config: {}
})
if (response.data.status === 'ok') {
this.showSuccess('分组配置已清除')
this.deleteDialog = false
this.deleteTarget = null
await this.loadData()
} else {
this.showError(response.data.message || this.tm('messages.deleteError'))
}
this.deleting = false
return
}
const response = await axios.post('/api/session/delete-rule', {
umo: this.deleteTarget.umo
})
+30
View File
@@ -16,6 +16,7 @@ from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.star.star import star_registry
from astrbot.core.star.star_handler import star_handlers_registry
from astrbot.core.utils.pip_installer import PipInstallError
from astrbot.dashboard.routes.plugin import PluginRoute
from astrbot.dashboard.server import AstrBotDashboard
from tests.fixtures.helpers import (
@@ -359,6 +360,35 @@ async def test_do_update(
assert os.path.exists(release_path)
@pytest.mark.asyncio
async def test_install_pip_package_returns_pip_install_error_message(
app: Quart,
authenticated_header: dict,
monkeypatch,
):
test_client = app.test_client()
async def mock_pip_install(*args, **kwargs):
del args, kwargs
raise PipInstallError("install failed", code=2)
monkeypatch.setattr(
"astrbot.dashboard.routes.update.pip_installer.install",
mock_pip_install,
)
response = await test_client.post(
"/api/update/pip-install",
headers=authenticated_header,
json={"package": "demo-package"},
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "error"
assert data["message"] == "install failed"
class _FakeNeoSkills:
async def list_candidates(self, **kwargs):
_ = kwargs
+266
View File
@@ -0,0 +1,266 @@
from pathlib import Path
import pytest
from astrbot.core.utils import core_constraints as core_constraints_module
from astrbot.core.utils import requirements_utils
from astrbot.core.utils.core_constraints import CoreConstraintsProvider
def test_requirements_utils_parse_package_install_input_collects_specs_and_names():
parsed = requirements_utils.parse_package_install_input(
"--index-url https://example.com/simple demo-package\nanother-package>=1.0\n"
)
assert parsed.specs == (
"--index-url",
"https://example.com/simple",
"demo-package",
"another-package>=1.0",
)
assert parsed.requirement_names == {"demo-package", "another-package"}
def test_core_constraints_provider_writes_constraints_file_from_fallback_distribution(
monkeypatch,
):
class FakeFallbackDistribution:
metadata = {"Name": "AstrBot-App"}
requires = ["shared-lib>=1.0"]
def read_text(self, name):
if name == "top_level.txt":
return "astrbot\n"
return ""
fake_distribution = FakeFallbackDistribution()
def mock_distribution(name):
if name == "AstrBot":
raise core_constraints_module.importlib_metadata.PackageNotFoundError
if name == "AstrBot-App":
return fake_distribution
raise core_constraints_module.importlib_metadata.PackageNotFoundError
def mock_distributions(path=None):
del path
return [fake_distribution]
monkeypatch.setattr(
core_constraints_module.importlib_metadata,
"distribution",
mock_distribution,
)
monkeypatch.setattr(
core_constraints_module.importlib_metadata,
"distributions",
mock_distributions,
)
monkeypatch.setattr(
core_constraints_module,
"collect_installed_distribution_versions",
lambda paths: {"shared-lib": "2.0"},
)
core_constraints_module._get_core_constraints.cache_clear()
try:
provider = CoreConstraintsProvider(None)
with provider.constraints_file() as constraints_path:
assert constraints_path is not None
assert (
Path(constraints_path).read_text(encoding="utf-8") == "shared-lib==2.0"
)
finally:
core_constraints_module._get_core_constraints.cache_clear()
def test_resolve_core_dist_name_skips_distribution_without_name(monkeypatch):
class NamelessDistribution:
metadata = {}
def read_text(self, name):
if name == "top_level.txt":
return "astrbot\n"
return ""
class NamedDistribution:
metadata = {"Name": "AstrBot-App"}
def read_text(self, name):
if name == "top_level.txt":
return "astrbot\n"
return ""
monkeypatch.setattr(
core_constraints_module.importlib_metadata,
"distribution",
lambda name: (_ for _ in ()).throw(
core_constraints_module.importlib_metadata.PackageNotFoundError
),
)
monkeypatch.setattr(
core_constraints_module.importlib_metadata,
"distributions",
lambda: [NamelessDistribution(), NamedDistribution()],
)
assert core_constraints_module._resolve_core_dist_name(None) == "AstrBot-App"
def test_find_missing_requirements_returns_none_when_precheck_gate_fails(
monkeypatch,
tmp_path,
):
requirements_path = tmp_path / "requirements.txt"
requirements_path.write_text("demo-package\n", encoding="utf-8")
monkeypatch.setattr(
requirements_utils,
"_load_requirement_lines_for_precheck",
lambda path: (False, None),
)
missing = requirements_utils.find_missing_requirements(str(requirements_path))
assert missing is None
def test_parse_package_install_input_tracks_only_named_direct_references():
named = requirements_utils.parse_package_install_input(
"git+https://example.com/demo.git#egg=demo-package"
)
unnamed = requirements_utils.parse_package_install_input(
"git+https://example.com/demo.git"
)
assert named.requirement_names == {"demo-package"}
assert unnamed.requirement_names == set()
def test_find_missing_requirements_or_raise_uses_requirements_exception(tmp_path):
requirements_path = tmp_path / "requirements.txt"
requirements_path.write_text("-e ../sharedlib\n", encoding="utf-8")
with pytest.raises(requirements_utils.RequirementsPrecheckFailed):
requirements_utils.find_missing_requirements_or_raise(str(requirements_path))
def test_find_missing_requirements_logs_path_and_reason_on_precheck_fallback(
monkeypatch,
tmp_path,
):
requirements_path = tmp_path / "requirements.txt"
requirements_path.write_text("git+https://example.com/demo.git\n", encoding="utf-8")
warning_logs = []
monkeypatch.setattr(
"astrbot.core.utils.requirements_utils.logger.warning",
lambda line, *args: warning_logs.append(line % args if args else line),
)
missing = requirements_utils.find_missing_requirements(str(requirements_path))
assert missing is None
assert any(str(requirements_path) in log for log in warning_logs)
assert any("direct reference" in log for log in warning_logs)
def test_load_requirement_lines_for_precheck_uses_parse_requirement_line_result(
monkeypatch,
tmp_path,
):
requirements_path = tmp_path / "requirements.txt"
requirements_path.write_text("git+https://example.com/demo.git\n", encoding="utf-8")
monkeypatch.setattr(
requirements_utils,
"_parse_requirement_line",
lambda line: ("demo-package", None) if line.startswith("git+") else None,
)
can_precheck, requirement_lines = (
requirements_utils._load_requirement_lines_for_precheck(str(requirements_path))
)
assert can_precheck is True
assert requirement_lines == ["git+https://example.com/demo.git"]
def test_collect_installed_distribution_versions_skips_nameless_distribution(
monkeypatch,
):
class NamelessDistribution:
metadata = {}
version = "1.0"
class NamedDistribution:
metadata = {"Name": "demo-package"}
version = "2.0"
monkeypatch.setattr(
requirements_utils.importlib_metadata,
"distributions",
lambda path: [NamelessDistribution(), NamedDistribution()],
)
installed = requirements_utils.collect_installed_distribution_versions(
["/tmp/test"]
)
assert installed == {"demo-package": "2.0"}
def test_get_core_constraints_logs_resolution_step_context(monkeypatch):
warning_logs = []
monkeypatch.setattr(
core_constraints_module,
"_resolve_core_dist_name",
lambda core_dist_name: (_ for _ in ()).throw(RuntimeError("boom")),
)
monkeypatch.setattr(
"astrbot.core.utils.core_constraints.logger.warning",
lambda line, *args: warning_logs.append(line % args if args else line),
)
core_constraints_module._get_core_constraints.cache_clear()
try:
constraints = core_constraints_module._get_core_constraints(None)
finally:
core_constraints_module._get_core_constraints.cache_clear()
assert constraints == ()
assert any("解析核心分发名称失败" in log for log in warning_logs)
def test_iter_requirements_supports_direct_line_input():
parsed = list(
requirements_utils.iter_requirements(
lines=["demo-package>=1.0", 'other-package; sys_platform == "win32"']
)
)
assert parsed == [
("demo-package", requirements_utils.Requirement("demo-package>=1.0").specifier)
]
def test_parse_requirement_name_and_spec_preserves_direct_reference_rules():
named = requirements_utils._parse_requirement_name_and_spec(
"git+https://example.com/demo.git#egg=demo-package"
)
unnamed = requirements_utils._parse_requirement_name_and_spec(
"git+https://example.com/demo.git"
)
assert named == ("demo-package", None)
assert unnamed == (None, None)
def test_parse_requirement_name_and_spec_handles_plain_requirement_token():
parsed = requirements_utils._parse_requirement_name_and_spec("demo-package>=1.0")
assert parsed == (
"demo-package",
requirements_utils.Requirement("demo-package>=1.0").specifier,
)
File diff suppressed because it is too large Load Diff
+426 -189
View File
@@ -1,235 +1,472 @@
import sys
from asyncio import Queue
import asyncio
from pathlib import Path
from unittest.mock import MagicMock
import pytest
import pytest_asyncio
import yaml
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.star.context import Context
from astrbot.core.star.star import star_map, star_registry
from astrbot.core.star.star_handler import star_handlers_registry
from astrbot.core.star.star_manager import PluginManager
from astrbot.core.star.star_manager import PluginDependencyInstallError, PluginManager
from astrbot.core.utils.pip_installer import PipInstallError
# --- Test Data & Helpers ---
def _clear_module_cache() -> None:
"""Clear module cache for data module tree to ensure test isolation."""
modules_to_remove = [
key for key in sys.modules if key == "data" or key.startswith("data.")
]
for key in modules_to_remove:
del sys.modules[key]
def _clear_registry(plugin_name: str) -> None:
"""Clear plugin from global registries."""
# Clear star_registry (list)
star_registry[:] = [md for md in star_registry if md.name != plugin_name]
# Clear star_map (dict)
keys_to_remove = [
key for key, md in star_map.items() if md.name == plugin_name
]
for key in keys_to_remove:
del star_map[key]
# Clear star_handlers_registry (StarHandlerRegistry)
for handler in list(star_handlers_registry):
if plugin_name in (handler.handler_module_path or ""):
star_handlers_registry.remove(handler)
TEST_PLUGIN_REPO = "https://github.com/Soulter/helloworld"
TEST_PLUGIN_DIR = "helloworld"
TEST_PLUGIN_NAME = "helloworld"
TEST_PLUGIN_REPO = "https://github.com/AstrBotDevs/astrbot_plugin_helloworld"
TEST_PLUGIN_DIR = "helloworld"
def _write_local_test_plugin(plugin_dir: Path, repo_url: str) -> None:
plugin_dir.mkdir(parents=True, exist_ok=True)
(plugin_dir / "metadata.yaml").write_text(
"\n".join(
[
f"name: {TEST_PLUGIN_NAME}",
"author: AstrBot Team",
"desc: Local test plugin",
"version: 1.0.0",
f"repo: {repo_url}",
],
)
+ "\n",
encoding="utf-8",
)
(plugin_dir / "main.py").write_text(
"\n".join(
[
"from astrbot.api import star",
"",
"class Main(star.Star):",
" pass",
"",
],
),
encoding="utf-8",
class MockStar:
def __init__(self):
self.root_dir_name = TEST_PLUGIN_DIR
self.name = TEST_PLUGIN_NAME
self.repo = TEST_PLUGIN_REPO
self.reserved = False
self.info = {"repo": TEST_PLUGIN_REPO, "readme": ""}
def _write_local_test_plugin(plugin_path: Path, repo_url: str):
"""Creates a minimal valid plugin structure."""
plugin_path.mkdir(parents=True, exist_ok=True)
metadata = {
"name": TEST_PLUGIN_NAME,
"repo": repo_url,
"version": "1.0.0",
"author": "AstrBot Team",
"desc": "Local test plugin",
}
with open(plugin_path / "info.yaml", "w", encoding="utf-8") as f:
yaml.dump(metadata, f)
with open(plugin_path / "main.py", "w", encoding="utf-8") as f:
f.write("from astrbot.api.star import Star, Context, StarManager\n")
f.write("@StarManager.register\n")
f.write("class HelloWorld(Star):\n")
f.write(" def __init__(self, context: Context): ...\n")
def _write_requirements(plugin_path: Path):
"""Creates a requirements.txt file."""
with open(plugin_path / "requirements.txt", "w", encoding="utf-8") as f:
f.write("networkx\n")
def _clear_module_cache():
"""Clear test-specific modules from sys.modules to allow reloading."""
import sys
to_del = [m for m in sys.modules if m.startswith("data.plugins.helloworld")]
for m in to_del:
del sys.modules[m]
def _build_load_mock(events):
async def mock_load(specified_dir_name=None, ignore_version_check=False):
del ignore_version_check
events.append(("load", specified_dir_name or TEST_PLUGIN_DIR))
return True, ""
return mock_load
def _build_reload_mock(events):
async def mock_reload(specified_dir_name=None):
events.append(("reload", specified_dir_name or TEST_PLUGIN_DIR))
return True, ""
return mock_reload
def _build_dependency_install_mock(events, fail: bool):
async def mock_install_requirements(
*, requirements_path: str = None, package_name: str = None, **kwargs
):
del kwargs
if requirements_path:
events.append(("deps", str(requirements_path)))
if package_name:
events.append(("deps_pkg", package_name))
if fail:
raise Exception("pip failed")
return mock_install_requirements
def _mock_missing_requirements(monkeypatch, missing: set[str]):
monkeypatch.setattr(
"astrbot.core.star.star_manager.find_missing_requirements_or_raise",
lambda requirements_path: missing,
)
@pytest_asyncio.fixture
async def plugin_manager_pm(tmp_path, monkeypatch):
def _mock_precheck_fails(monkeypatch):
from astrbot.core import RequirementsPrecheckFailed
def mock_fail(requirements_path):
raise RequirementsPrecheckFailed("mock precheck failure")
monkeypatch.setattr(
"astrbot.core.star.star_manager.find_missing_requirements_or_raise",
mock_fail,
)
# --- Fixtures ---
@pytest.fixture
def plugin_manager_pm(tmp_path, monkeypatch):
"""Provides a fully isolated PluginManager instance for testing."""
# Clear module cache before setup to ensure isolation
_clear_module_cache()
test_root = tmp_path / "astrbot_root"
data_dir = test_root / "data"
plugin_dir = data_dir / "plugins"
config_dir = data_dir / "config"
temp_dir = data_dir / "temp"
for path in (plugin_dir, config_dir, temp_dir):
path.mkdir(parents=True, exist_ok=True)
plugin_dir = tmp_path / "astrbot_root" / "data" / "plugins"
plugin_dir.mkdir(parents=True, exist_ok=True)
# Ensure `import data.plugins.<plugin>.main` resolves to this temp root.
(data_dir / "__init__.py").write_text("", encoding="utf-8")
(plugin_dir / "__init__.py").write_text("", encoding="utf-8")
class MockContext:
def __init__(self):
self.stars = []
# Use monkeypatch for both env var and sys.path to ensure proper cleanup
monkeypatch.setenv("ASTRBOT_ROOT", str(test_root))
monkeypatch.syspath_prepend(str(test_root))
def get_all_stars(self):
return self.stars
# Create fresh, isolated instances for the context
event_queue = Queue()
config = AstrBotConfig()
db = SQLiteDatabase(str(data_dir / "test_db.db"))
config.plugin_store_path = str(plugin_dir)
def get_registered_star(self, name):
for s in self.stars:
if s.root_dir_name == name or s.name == name:
return s
return None
provider_manager = MagicMock()
platform_manager = MagicMock()
conversation_manager = MagicMock()
message_history_manager = MagicMock()
persona_manager = MagicMock()
persona_manager.personas_v3 = []
astrbot_config_mgr = MagicMock()
knowledge_base_manager = MagicMock()
cron_manager = MagicMock()
mock_context = MockContext()
mock_config = {}
pm = PluginManager(mock_context, mock_config)
star_context = Context(
event_queue=event_queue,
config=config,
db=db,
provider_manager=provider_manager,
platform_manager=platform_manager,
conversation_manager=conversation_manager,
message_history_manager=message_history_manager,
persona_manager=persona_manager,
astrbot_config_mgr=astrbot_config_mgr,
knowledge_base_manager=knowledge_base_manager,
cron_manager=cron_manager,
subagent_orchestrator=None,
# Patch paths to use tmp_path
monkeypatch.setattr(pm, "plugin_store_path", str(plugin_dir))
monkeypatch.setattr(
"astrbot.core.star.star_manager.get_astrbot_plugin_path",
lambda: str(plugin_dir),
)
manager = PluginManager(star_context, config)
try:
yield manager
finally:
# Cleanup global registries and module cache
_clear_registry(TEST_PLUGIN_NAME)
_clear_module_cache()
await db.engine.dispose()
return pm
@pytest.fixture
def local_updator(plugin_manager_pm: PluginManager, monkeypatch):
plugin_path = Path(plugin_manager_pm.plugin_store_path) / TEST_PLUGIN_DIR
def local_updator(plugin_manager_pm):
"""Helper to setup a local plugin directory simulating a download."""
path = Path(plugin_manager_pm.plugin_store_path) / TEST_PLUGIN_DIR
_write_local_test_plugin(path, TEST_PLUGIN_REPO)
return path
async def mock_install(repo_url: str, proxy=""): # noqa: ARG001
if repo_url != TEST_PLUGIN_REPO:
raise Exception("Repo not found")
# --- Tests ---
@pytest.mark.asyncio
@pytest.mark.parametrize("dependency_install_fails", [False, True])
async def test_install_plugin_dependency_install_flow(
plugin_manager_pm: PluginManager, monkeypatch, dependency_install_fails: bool
):
plugin_path = Path(plugin_manager_pm.plugin_store_path) / TEST_PLUGIN_DIR
events = []
_mock_missing_requirements(monkeypatch, {"networkx"})
async def mock_install(repo_url: str, proxy=""):
assert repo_url == TEST_PLUGIN_REPO
_write_local_test_plugin(plugin_path, repo_url)
_write_requirements(plugin_path)
return str(plugin_path)
async def mock_update(plugin, proxy=""): # noqa: ARG001
if plugin.name != TEST_PLUGIN_NAME:
raise Exception("Plugin not found")
if not plugin_path.exists():
raise Exception("Plugin path missing")
(plugin_path / ".updated").write_text("ok", encoding="utf-8")
monkeypatch.setattr(plugin_manager_pm.updator, "install", mock_install)
monkeypatch.setattr(plugin_manager_pm.updator, "update", mock_update)
return plugin_path
monkeypatch.setattr(
"astrbot.core.star.star_manager.pip_installer.install",
_build_dependency_install_mock(events, dependency_install_fails),
)
def mock_load_and_register(*args, **kwargs):
plugin_manager_pm.context.stars.append(MockStar())
return _build_load_mock(events)(*args, **kwargs)
monkeypatch.setattr(plugin_manager_pm, "load", mock_load_and_register)
if dependency_install_fails:
with pytest.raises(PluginDependencyInstallError, match="pip failed"):
await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO)
assert events == [("deps", str(plugin_path / "requirements.txt"))]
else:
await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO)
assert events == [
("deps", str(plugin_path / "requirements.txt")),
("load", TEST_PLUGIN_DIR),
]
@pytest.mark.asyncio
async def test_plugin_manager_initialization(plugin_manager_pm: PluginManager):
assert plugin_manager_pm is not None
assert plugin_manager_pm.context is not None
assert plugin_manager_pm.config is not None
@pytest.mark.asyncio
async def test_plugin_manager_reload(plugin_manager_pm: PluginManager):
success, err_message = await plugin_manager_pm.reload()
assert success is True
assert err_message is None
@pytest.mark.asyncio
async def test_install_plugin(plugin_manager_pm: PluginManager, local_updator: Path):
"""Tests successful plugin installation without external network."""
plugin_info = await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO)
assert plugin_info is not None
assert plugin_info["name"] == TEST_PLUGIN_NAME
assert local_updator.exists()
assert any(md.name == TEST_PLUGIN_NAME for md in star_registry)
@pytest.mark.asyncio
async def test_install_nonexistent_plugin(
plugin_manager_pm: PluginManager, local_updator
@pytest.mark.parametrize("dependency_install_fails", [False, True])
async def test_install_plugin_from_file_dependency_install_flow(
plugin_manager_pm: PluginManager,
monkeypatch,
tmp_path,
dependency_install_fails: bool,
):
"""Tests that installing a non-existent plugin raises an exception."""
with pytest.raises(Exception):
await plugin_manager_pm.install_plugin(
"https://github.com/Soulter/non_existent_repo"
zip_file_path = tmp_path / f"{TEST_PLUGIN_DIR}.zip"
zip_file_path.write_text("placeholder", encoding="utf-8")
events = []
_mock_missing_requirements(monkeypatch, {"networkx"})
def mock_unzip_file(zip_path: str, target_dir: str) -> None:
assert zip_path == str(zip_file_path)
plugin_path = Path(target_dir)
_write_local_test_plugin(plugin_path, TEST_PLUGIN_REPO)
_write_requirements(plugin_path)
monkeypatch.setattr(plugin_manager_pm.updator, "unzip_file", mock_unzip_file)
monkeypatch.setattr(
"astrbot.core.star.star_manager.pip_installer.install",
_build_dependency_install_mock(events, dependency_install_fails),
)
def mock_load_and_register(*args, **kwargs):
plugin_manager_pm.context.stars.append(MockStar())
return _build_load_mock(events)(*args, **kwargs)
monkeypatch.setattr(plugin_manager_pm, "load", mock_load_and_register)
if dependency_install_fails:
with pytest.raises(PluginDependencyInstallError, match="pip failed"):
await plugin_manager_pm.install_plugin_from_file(str(zip_file_path))
assert any(e[0] == "deps" for e in events)
else:
await plugin_manager_pm.install_plugin_from_file(str(zip_file_path))
assert any(e[0] == "deps" for e in events)
assert ("load", TEST_PLUGIN_DIR) in events
@pytest.mark.asyncio
@pytest.mark.parametrize("dependency_install_fails", [False, True])
async def test_reload_failed_plugin_dependency_install_flow(
plugin_manager_pm: PluginManager,
local_updator: Path,
monkeypatch,
dependency_install_fails: bool,
):
_write_requirements(local_updator)
plugin_manager_pm.failed_plugin_dict[TEST_PLUGIN_DIR] = {"error": "init fail"}
events = []
_mock_missing_requirements(monkeypatch, {"networkx"})
monkeypatch.setattr(
"astrbot.core.star.star_manager.pip_installer.install",
_build_dependency_install_mock(events, dependency_install_fails),
)
def mock_load_and_register(*args, **kwargs):
plugin_manager_pm.context.stars.append(MockStar())
return _build_load_mock(events)(*args, **kwargs)
monkeypatch.setattr(plugin_manager_pm, "load", mock_load_and_register)
if dependency_install_fails:
with pytest.raises(PluginDependencyInstallError, match="pip failed"):
await plugin_manager_pm.reload_failed_plugin(TEST_PLUGIN_DIR)
assert events == [("deps", str(local_updator / "requirements.txt"))]
else:
await plugin_manager_pm.reload_failed_plugin(TEST_PLUGIN_DIR)
assert events == [
("deps", str(local_updator / "requirements.txt")),
("load", TEST_PLUGIN_DIR),
]
@pytest.mark.asyncio
async def test_ensure_plugin_requirements_reraises_cancelled_error(
plugin_manager_pm: PluginManager, local_updator: Path, monkeypatch
):
_write_requirements(local_updator)
_mock_missing_requirements(monkeypatch, {"networkx"})
async def mock_install_requirements(*args, **kwargs):
raise asyncio.CancelledError()
monkeypatch.setattr(
"astrbot.core.star.star_manager.pip_installer.install",
mock_install_requirements,
)
with pytest.raises(asyncio.CancelledError):
await plugin_manager_pm._ensure_plugin_requirements(
str(local_updator),
TEST_PLUGIN_DIR,
)
@pytest.mark.asyncio
async def test_update_plugin(plugin_manager_pm: PluginManager, local_updator: Path):
"""Tests updating an existing plugin without external network."""
plugin_info = await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO)
assert plugin_info is not None
plugin_name = plugin_info["name"]
await plugin_manager_pm.update_plugin(plugin_name)
assert (local_updator / ".updated").exists()
@pytest.mark.asyncio
async def test_update_nonexistent_plugin(
plugin_manager_pm: PluginManager, local_updator
async def test_ensure_plugin_requirements_wraps_generic_dependency_install_failure(
plugin_manager_pm: PluginManager, local_updator: Path, monkeypatch
):
"""Tests that updating a non-existent plugin raises an exception."""
with pytest.raises(Exception):
await plugin_manager_pm.update_plugin("non_existent_plugin")
_write_requirements(local_updator)
_mock_missing_requirements(monkeypatch, {"networkx"})
async def mock_install_requirements(*args, **kwargs):
raise RuntimeError("pip failed")
@pytest.mark.asyncio
async def test_uninstall_plugin(plugin_manager_pm: PluginManager, local_updator: Path):
"""Tests successful plugin uninstallation."""
plugin_info = await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO)
assert plugin_info is not None
plugin_name = plugin_info["name"]
assert local_updator.exists()
await plugin_manager_pm.uninstall_plugin(plugin_name)
assert not local_updator.exists()
assert not any(md.name == TEST_PLUGIN_NAME for md in star_registry)
assert not any(
TEST_PLUGIN_NAME in md.handler_module_path for md in star_handlers_registry
monkeypatch.setattr(
"astrbot.core.star.star_manager.pip_installer.install",
mock_install_requirements,
)
with pytest.raises(PluginDependencyInstallError, match="pip failed") as exc_info:
await plugin_manager_pm._ensure_plugin_requirements(
str(local_updator),
TEST_PLUGIN_DIR,
)
assert exc_info.value.plugin_label == TEST_PLUGIN_DIR
assert exc_info.value.requirements_path == str(local_updator / "requirements.txt")
assert isinstance(exc_info.value.__cause__, RuntimeError)
@pytest.mark.asyncio
async def test_uninstall_nonexistent_plugin(plugin_manager_pm: PluginManager):
"""Tests that uninstalling a non-existent plugin raises an exception."""
with pytest.raises(Exception):
await plugin_manager_pm.uninstall_plugin("non_existent_plugin")
async def test_ensure_plugin_requirements_wraps_pip_install_error(
plugin_manager_pm: PluginManager, local_updator: Path, monkeypatch
):
_write_requirements(local_updator)
_mock_missing_requirements(monkeypatch, {"networkx"})
async def mock_install_requirements(*args, **kwargs):
raise PipInstallError("install failed", code=2)
monkeypatch.setattr(
"astrbot.core.star.star_manager.pip_installer.install",
mock_install_requirements,
)
with pytest.raises(PluginDependencyInstallError, match="install failed") as exc_info:
await plugin_manager_pm._ensure_plugin_requirements(
str(local_updator),
TEST_PLUGIN_DIR,
)
assert isinstance(exc_info.value.__cause__, PipInstallError)
@pytest.mark.asyncio
async def test_ensure_plugin_requirements_logs_requirements_file_install_for_missing_dependencies(
plugin_manager_pm: PluginManager, local_updator: Path, monkeypatch
):
_write_requirements(local_updator)
_mock_missing_requirements(monkeypatch, {"networkx"})
logged_lines = []
async def mock_install_requirements(*args, **kwargs):
return None
monkeypatch.setattr(
"astrbot.core.star.star_manager.pip_installer.install",
mock_install_requirements,
)
monkeypatch.setattr(
"astrbot.core.star.star_manager.logger.info",
lambda line, *args: logged_lines.append(line % args if args else line),
)
await plugin_manager_pm._ensure_plugin_requirements(
str(local_updator),
TEST_PLUGIN_DIR,
)
assert any("按 requirements.txt 安装" in line for line in logged_lines)
@pytest.mark.asyncio
@pytest.mark.parametrize("dependency_install_fails", [False, True])
async def test_update_plugin_dependency_install_flow(
plugin_manager_pm: PluginManager,
local_updator: Path,
monkeypatch,
dependency_install_fails: bool,
):
mock_star = MockStar()
plugin_manager_pm.context.stars.append(mock_star)
_write_requirements(local_updator)
events = []
_mock_missing_requirements(monkeypatch, {"networkx"})
async def mock_update(plugin, proxy=""):
del proxy
events.append(("update", plugin.name))
monkeypatch.setattr(plugin_manager_pm.updator, "update", mock_update)
monkeypatch.setattr(
"astrbot.core.star.star_manager.pip_installer.install",
_build_dependency_install_mock(events, dependency_install_fails),
)
monkeypatch.setattr(plugin_manager_pm, "reload", _build_reload_mock(events))
if dependency_install_fails:
with pytest.raises(PluginDependencyInstallError, match="pip failed"):
await plugin_manager_pm.update_plugin(TEST_PLUGIN_NAME)
assert ("deps", str(local_updator / "requirements.txt")) in events
else:
await plugin_manager_pm.update_plugin(TEST_PLUGIN_NAME)
assert ("deps", str(local_updator / "requirements.txt")) in events
assert ("reload", TEST_PLUGIN_DIR) in events
@pytest.mark.asyncio
async def test_install_plugin_skips_dependency_install_when_no_requirements_missing(
plugin_manager_pm: PluginManager, monkeypatch
):
plugin_path = Path(plugin_manager_pm.plugin_store_path) / TEST_PLUGIN_DIR
events = []
_mock_missing_requirements(monkeypatch, set())
async def mock_install(repo_url: str, proxy=""):
_write_local_test_plugin(plugin_path, repo_url)
_write_requirements(plugin_path)
return str(plugin_path)
monkeypatch.setattr(plugin_manager_pm.updator, "install", mock_install)
monkeypatch.setattr(
"astrbot.core.star.star_manager.pip_installer.install",
_build_dependency_install_mock(events, False),
)
def mock_load_and_register(*args, **kwargs):
plugin_manager_pm.context.stars.append(MockStar())
return _build_load_mock(events)(*args, **kwargs)
monkeypatch.setattr(plugin_manager_pm, "load", mock_load_and_register)
await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO)
assert "deps" not in [e[0] for e in events]
assert ("load", TEST_PLUGIN_DIR) in events
@pytest.mark.asyncio
async def test_install_plugin_runs_dependency_install_when_precheck_fails(
plugin_manager_pm: PluginManager, monkeypatch
):
plugin_path = Path(plugin_manager_pm.plugin_store_path) / TEST_PLUGIN_DIR
events = []
async def mock_install(repo_url: str, proxy=""):
_write_local_test_plugin(plugin_path, repo_url)
_write_requirements(plugin_path)
return str(plugin_path)
_mock_precheck_fails(monkeypatch)
monkeypatch.setattr(plugin_manager_pm.updator, "install", mock_install)
monkeypatch.setattr(
"astrbot.core.star.star_manager.pip_installer.install",
_build_dependency_install_mock(events, False),
)
def mock_load_and_register(*args, **kwargs):
plugin_manager_pm.context.stars.append(MockStar())
return _build_load_mock(events)(*args, **kwargs)
monkeypatch.setattr(plugin_manager_pm, "load", mock_load_and_register)
await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO)
assert ("deps", str(plugin_path / "requirements.txt")) in events
assert ("load", TEST_PLUGIN_DIR) in events