diff --git a/.gitignore b/.gitignore index 486b9628d..004481c61 100644 --- a/.gitignore +++ b/.gitignore @@ -61,3 +61,5 @@ GenieData/ .codex/ .opencode/ .kilocode/ +.worktrees/ +docs/plans/ diff --git a/astrbot/core/__init__.py b/astrbot/core/__init__.py index 5c015e96e..51690ede2 100644 --- a/astrbot/core/__init__.py +++ b/astrbot/core/__init__.py @@ -4,7 +4,21 @@ from astrbot.core.config import AstrBotConfig from astrbot.core.config.default import DB_PATH from astrbot.core.db.sqlite import SQLiteDatabase from astrbot.core.file_token_service import FileTokenService -from astrbot.core.utils.pip_installer import PipInstaller +from astrbot.core.utils.pip_installer import ( + DependencyConflictError as DependencyConflictError, +) +from astrbot.core.utils.pip_installer import ( + PipInstaller, +) +from astrbot.core.utils.requirements_utils import ( + RequirementsPrecheckFailed as RequirementsPrecheckFailed, +) +from astrbot.core.utils.requirements_utils import ( + find_missing_requirements as find_missing_requirements, +) +from astrbot.core.utils.requirements_utils import ( + find_missing_requirements_or_raise as find_missing_requirements_or_raise, +) from astrbot.core.utils.shared_preferences import SharedPreferences from astrbot.core.utils.t2i.renderer import HtmlRenderer diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index b812698f2..cf000c5a4 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -14,7 +14,12 @@ import yaml from packaging.specifiers import InvalidSpecifier, SpecifierSet from packaging.version import InvalidVersion, Version -from astrbot.core import logger, pip_installer, sp +from astrbot.core import ( + DependencyConflictError, + logger, + pip_installer, + sp, +) from astrbot.core.agent.handoff import FunctionTool, HandoffTool from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.config.default import VERSION @@ -27,6 +32,10 @@ from astrbot.core.utils.astrbot_path import ( ) from astrbot.core.utils.io import remove_dir from astrbot.core.utils.metrics import Metric +from astrbot.core.utils.requirements_utils import ( + RequirementsPrecheckFailed, + find_missing_requirements_or_raise, +) from . import StarMetadata from .command_management import sync_command_configs @@ -48,6 +57,49 @@ class PluginVersionIncompatibleError(Exception): """Raised when plugin astrbot_version is incompatible with current AstrBot.""" +class PluginDependencyInstallError(Exception): + """Raised when plugin dependency installation fails.""" + + def __init__( + self, + *, + plugin_label: str, + requirements_path: str, + error: Exception, + ) -> None: + message = f"插件 {plugin_label} 依赖安装失败: {error!s}" + super().__init__(message) + self.plugin_label = plugin_label + self.requirements_path = requirements_path + self.error = error + + +async def _install_requirements_with_precheck( + *, + plugin_label: str, + requirements_path: str, +) -> None: + try: + missing = find_missing_requirements_or_raise(requirements_path) + except RequirementsPrecheckFailed: + logger.info( + f"正在安装插件 {plugin_label} 的依赖库(预检查失败,回退到完整安装): " + f"{requirements_path}" + ) + await pip_installer.install(requirements_path=requirements_path) + return + + if not missing: + logger.info(f"插件 {plugin_label} 的依赖已满足,跳过安装。") + return + + logger.info( + f"检测到插件 {plugin_label} 缺失依赖,正在按 requirements.txt 安装: " + f"{requirements_path} -> {sorted(missing)}" + ) + await pip_installer.install(requirements_path=requirements_path) + + class PluginManager: def __init__(self, context: Context, config: AstrBotConfig) -> None: from .star_tools import StarTools @@ -198,15 +250,37 @@ class PluginManager: to_update.append(p.root_dir_name) for p in to_update: plugin_path = os.path.join(plugin_dir, p) - if os.path.exists(os.path.join(plugin_path, "requirements.txt")): - pth = os.path.join(plugin_path, "requirements.txt") - logger.info(f"正在安装插件 {p} 所需的依赖库: {pth}") - try: - await pip_installer.install(requirements_path=pth) - except Exception as e: - logger.error(f"更新插件 {p} 的依赖失败。Code: {e!s}") + await self._ensure_plugin_requirements(plugin_path, p) return True + async def _ensure_plugin_requirements( + self, + plugin_dir_path: str, + plugin_label: str, + ) -> None: + requirements_path = os.path.join(plugin_dir_path, "requirements.txt") + if not os.path.exists(requirements_path): + return + + try: + await _install_requirements_with_precheck( + plugin_label=plugin_label, + requirements_path=requirements_path, + ) + except asyncio.CancelledError: + raise + except DependencyConflictError as e: + logger.error(f"插件 {plugin_label} 依赖冲突: {e!s}") + raise + except Exception as e: + dependency_error = PluginDependencyInstallError( + plugin_label=plugin_label, + requirements_path=requirements_path, + error=e, + ) + logger.exception(str(dependency_error)) + raise dependency_error from e + async def _import_plugin_with_dependency_recovery( self, path: str, @@ -422,7 +496,7 @@ class PluginManager: root_dir_name: str, plugin_dir_path: str, reserved: bool, - error: Exception | str, + error: BaseException | str, error_trace: str, ) -> dict: record: dict = { @@ -495,6 +569,9 @@ class PluginManager: self._cleanup_plugin_state(dir_name) + plugin_path = os.path.join(self.plugin_store_path, dir_name) + await self._ensure_plugin_requirements(plugin_path, dir_name) + success, error = await self.load(specified_dir_name=dir_name) if success: self.failed_plugin_dict.pop(dir_name, None) @@ -1078,6 +1155,10 @@ class PluginManager: # reload the plugin dir_name = os.path.basename(plugin_path) + await self._ensure_plugin_requirements( + plugin_path, + dir_name, + ) success, error_message = await self.load( specified_dir_name=dir_name, ignore_version_check=ignore_version_check, @@ -1317,6 +1398,12 @@ class PluginManager: raise Exception("该插件是 AstrBot 保留插件,无法更新。") await self.updator.update(plugin, proxy=proxy) + if plugin.root_dir_name: + plugin_dir_path = os.path.join(self.plugin_store_path, plugin.root_dir_name) + await self._ensure_plugin_requirements( + plugin_dir_path, + plugin_name, + ) await self.reload(plugin_name) async def turn_off_plugin(self, plugin_name: str) -> None: @@ -1488,6 +1575,7 @@ class PluginManager: os.remove(zip_file_path) except BaseException as e: logger.warning(f"删除插件压缩包失败: {e!s}") + await self._ensure_plugin_requirements(desti_dir, dir_name) # await self.reload() success, error_message = await self.load( specified_dir_name=dir_name, diff --git a/astrbot/core/utils/core_constraints.py b/astrbot/core/utils/core_constraints.py new file mode 100644 index 000000000..b43f00122 --- /dev/null +++ b/astrbot/core/utils/core_constraints.py @@ -0,0 +1,121 @@ +import contextlib +import functools +import importlib.metadata as importlib_metadata +import logging +import os +from collections.abc import Iterator + +from packaging.requirements import Requirement + +from astrbot.core.utils.requirements_utils import ( + canonicalize_distribution_name, + collect_installed_distribution_versions, + get_requirement_check_paths, +) + +logger = logging.getLogger("astrbot") + + +def _resolve_core_dist_name(core_dist_name: str | None) -> str | None: + if core_dist_name: + try: + importlib_metadata.distribution(core_dist_name) + return core_dist_name + except importlib_metadata.PackageNotFoundError: + return None + + try: + importlib_metadata.distribution("AstrBot") + return "AstrBot" + except importlib_metadata.PackageNotFoundError: + pass + + if not __package__: + return None + + top_pkg = __package__.split(".")[0] + for dist in importlib_metadata.distributions(): + try: + top_level = dist.read_text("top_level.txt") or "" + except Exception: + continue + if top_pkg in top_level.splitlines(): + if "Name" in dist.metadata: + return dist.metadata["Name"] + + return None + + +@functools.cache +def _get_core_constraints(core_dist_name: str | None) -> tuple[str, ...]: + try: + resolved_core_dist_name = _resolve_core_dist_name(core_dist_name) + except Exception as exc: + logger.warning("解析核心分发名称失败: %s", exc) + return () + + if not resolved_core_dist_name: + return () + + try: + dist = importlib_metadata.distribution(resolved_core_dist_name) + except importlib_metadata.PackageNotFoundError: + return () + except Exception as exc: + logger.warning("读取核心分发元数据失败 (%s): %s", resolved_core_dist_name, exc) + return () + + if not dist or not dist.requires: + return () + + installed = collect_installed_distribution_versions(get_requirement_check_paths()) + if not installed: + return () + + constraints: list[str] = [] + for req_str in dist.requires: + try: + req = Requirement(req_str) + if req.marker and not req.marker.evaluate(): + continue + name = canonicalize_distribution_name(req.name) + if name in installed: + constraints.append(f"{name}=={installed[name]}") + except Exception: + continue + + return tuple(constraints) + + +class CoreConstraintsProvider: + def __init__(self, core_dist_name: str | None) -> None: + self._core_dist_name = core_dist_name + + @contextlib.contextmanager + def constraints_file(self) -> Iterator[str | None]: + constraints = _get_core_constraints(self._core_dist_name) + if not constraints: + yield None + return + + path: str | None = None + try: + import tempfile + + with tempfile.NamedTemporaryFile( + mode="w", suffix="_constraints.txt", delete=False, encoding="utf-8" + ) as f: + f.write("\n".join(constraints)) + path = f.name + logger.info("已启用核心依赖版本保护 (%d 个约束)", len(constraints)) + except Exception as exc: + logger.warning("创建临时约束文件失败: %s", exc) + yield None + return + + try: + yield path + finally: + if path and os.path.exists(path): + with contextlib.suppress(Exception): + os.remove(path) diff --git a/astrbot/core/utils/pip_installer.py b/astrbot/core/utils/pip_installer.py index 562a0ed30..97e9653d6 100644 --- a/astrbot/core/utils/pip_installer.py +++ b/astrbot/core/utils/pip_installer.py @@ -7,21 +7,71 @@ import io import logging import os import re +import shlex import sys import threading from collections import deque +from dataclasses import dataclass +from urllib.parse import urlparse from astrbot.core.utils.astrbot_path import get_astrbot_site_packages_path +from astrbot.core.utils.core_constraints import CoreConstraintsProvider +from astrbot.core.utils.requirements_utils import ( + canonicalize_distribution_name as _canonicalize_distribution_name, +) +from astrbot.core.utils.requirements_utils import ( + extract_requirement_name, + extract_requirement_names, + parse_package_install_input, +) from astrbot.core.utils.runtime_env import is_packaged_desktop_runtime logger = logging.getLogger("astrbot") _DISTLIB_FINDER_PATCH_ATTEMPTED = False _SITE_PACKAGES_IMPORT_LOCK = threading.RLock() +_PIP_FAILURE_PATTERNS = { + "error_prefix": re.compile(r"^\s*error:", re.IGNORECASE), + "user_requested": re.compile(r"\bthe user requested\b", re.IGNORECASE), + "resolution_impossible": re.compile(r"\bresolutionimpossible\b", re.IGNORECASE), + "cannot_install": re.compile(r"\bcannot install\b", re.IGNORECASE), + "conflict": re.compile(r"\bconflict(?:ing|s)?\b", re.IGNORECASE), + "constraint": re.compile(r"\(constraint\)", re.IGNORECASE), + "dependency_detail": re.compile(r"\bdepends on\b", re.IGNORECASE), +} +_SENSITIVE_PIP_VALUE_KEYS = frozenset( + {"password", "passwd", "pass", "api_token", "token", "auth_token"} +) +_MAX_PIP_OUTPUT_LINES = 200 -def _canonicalize_distribution_name(name: str) -> str: - return re.sub(r"[-_.]+", "-", name).strip("-").lower() +class DependencyConflictError(Exception): + """Raised when pip encounters a dependency conflict.""" + + def __init__( + self, message: str, errors: list[str], *, is_core_conflict: bool + ) -> None: + super().__init__(message) + self.errors = errors + self.is_core_conflict = is_core_conflict + + +class PipInstallError(Exception): + """Raised when pip install fails without a classified dependency conflict.""" + + def __init__(self, message: str, *, code: int) -> None: + super().__init__(message) + self.code = code + + +@dataclass +class PipConflictContext: + relevant_lines: list[str] + requested_lines: list[str] + dependency_detail_lines: list[str] + constraint_lines: list[str] + has_strong_conflict_signal: bool + has_contextual_conflict_signal: bool def _get_pip_main(): @@ -41,11 +91,12 @@ def _get_pip_main(): return pip_main -def _run_pip_main_with_output(pip_main, args: list[str]) -> tuple[int, str]: - stream = io.StringIO() - with contextlib.redirect_stdout(stream), contextlib.redirect_stderr(stream): - result_code = pip_main(args) - return result_code, stream.getvalue() +def _prepend_sys_path(path: str) -> None: + normalized_target = os.path.realpath(path) + sys.path[:] = [ + item for item in sys.path if os.path.realpath(item) != normalized_target + ] + sys.path.insert(0, normalized_target) def _cleanup_added_root_handlers(original_handlers: list[logging.Handler]) -> None: @@ -59,76 +110,258 @@ def _cleanup_added_root_handlers(original_handlers: list[logging.Handler]) -> No handler.close() -def _prepend_sys_path(path: str) -> None: - normalized_target = os.path.realpath(path) - sys.path[:] = [ - item for item in sys.path if os.path.realpath(item) != normalized_target - ] - sys.path.insert(0, normalized_target) +def _get_trusted_host_for_index_url(index_url: str) -> str | None: + parsed = urlparse(index_url if "://" in index_url else f"//{index_url}") + host = parsed.hostname + if host == "mirrors.aliyun.com": + return host + return None -def _module_exists_in_site_packages(module_name: str, site_packages_path: str) -> bool: - base_path = os.path.join(site_packages_path, *module_name.split(".")) - package_init = os.path.join(base_path, "__init__.py") - module_file = f"{base_path}.py" - return os.path.isfile(package_init) or os.path.isfile(module_file) +def _normalize_sensitive_pip_key(raw_key: str) -> str: + return raw_key.lstrip("-").replace("-", "_").lower() -def _is_module_loaded_from_site_packages( - module_name: str, - site_packages_path: str, -) -> bool: - module = sys.modules.get(module_name) - if module is None: - try: - module = importlib.import_module(module_name) - except Exception: - return False +def _is_sensitive_pip_value_key(raw_key: str) -> bool: + return _normalize_sensitive_pip_key(raw_key) in _SENSITIVE_PIP_VALUE_KEYS - module_file = getattr(module, "__file__", None) - if not module_file: - return False - module_path = os.path.realpath(module_file) - site_packages_real = os.path.realpath(site_packages_path) - try: - return ( - os.path.commonpath([module_path, site_packages_real]) == site_packages_real +def _redact_url_credentials(raw_value: str) -> str: + """Redact URL credentials and known inline secret values for safe logging.""" + parsed = urlparse(raw_value) + if parsed.netloc and "@" in parsed.netloc: + hostname = parsed.hostname or "" + port = f":{parsed.port}" if parsed.port else "" + return parsed._replace(netloc=f"@{hostname}{port}").geturl() + + if raw_value.startswith("--"): + option, separator, _ = raw_value.partition("=") + if separator and _is_sensitive_pip_value_key(option): + return f"{option}=****" + return raw_value + + key, separator, _ = raw_value.partition("=") + if separator and _is_sensitive_pip_value_key(key): + return f"{key}=****" + + return raw_value + + +def _redact_pip_args_for_logging(args: list[str]) -> list[str]: + redacted_args: list[str] = [] + redact_next_value = False + + for arg in args: + if redact_next_value: + redacted_args.append("****") + redact_next_value = False + continue + + if arg.startswith("--") and "=" in arg: + option, value = arg.split("=", 1) + if _is_sensitive_pip_value_key(option): + redacted_args.append(f"{option}=****") + else: + redacted_args.append(f"{option}={_redact_url_credentials(value)}") + continue + + if arg.startswith("-i") and arg != "-i": + redacted_args.append(f"-i{_redact_url_credentials(arg[2:])}") + continue + + if _is_sensitive_pip_value_key(arg): + redacted_args.append(arg) + redact_next_value = True + continue + + redacted_args.append(_redact_url_credentials(arg)) + + return redacted_args + + +def _package_specs_override_index(package_specs: list[str]) -> bool: + for index, spec in enumerate(package_specs): + if spec == "--no-index": + return True + if spec in {"-i", "--index-url"}: + if index + 1 < len(package_specs): + return True + continue + if spec.startswith("--index-url="): + return True + if spec.startswith("-i") and spec != "-i": + return True + return False + + +class _StreamingLogWriter(io.TextIOBase): + def __init__(self, log_func, *, max_lines: int | None = None) -> None: + self._log_func = log_func + self._lines = deque(maxlen=max_lines or _MAX_PIP_OUTPUT_LINES) + self._buffer = "" + + def write(self, text: str) -> int: + if not text: + return 0 + + self._buffer += text.replace("\r\n", "\n").replace("\r", "\n") + while "\n" in self._buffer: + raw_line, self._buffer = self._buffer.split("\n", 1) + line = raw_line.rstrip("\r\n") + self._log_func(line) + self._lines.append(line) + return len(text) + + def flush(self) -> None: + line = self._buffer.rstrip("\r\n") + if line: + self._log_func(line) + self._lines.append(line) + self._buffer = "" + + @property + def lines(self) -> list[str]: + return list(self._lines) + + +def _run_pip_main_streaming(pip_main, args: list[str]) -> tuple[int, list[str]]: + stream = _StreamingLogWriter(logger.info, max_lines=_MAX_PIP_OUTPUT_LINES) + with ( + contextlib.redirect_stdout(stream), + contextlib.redirect_stderr(stream), + ): + result_code = pip_main(args) + stream.flush() + return result_code, stream.lines + + +def _matches_pip_failure_pattern(line: str, *pattern_names: str) -> bool: + names = pattern_names or tuple(_PIP_FAILURE_PATTERNS) + return any(_PIP_FAILURE_PATTERNS[name].search(line) for name in names) + + +def _normalize_conflict_detail_line(line: str) -> str: + stripped = line.strip() + if _matches_pip_failure_pattern(stripped, "user_requested"): + return re.sub( + r"^\s*The user requested\s+", + "", + stripped, + flags=re.IGNORECASE, ) - except ValueError: - return False + return stripped -def _extract_requirement_name(raw_requirement: str) -> str | None: - line = raw_requirement.split("#", 1)[0].strip() - if not line: - return None - if line.startswith(("-r", "--requirement", "-c", "--constraint")): - return None - if line.startswith("-"): +def _build_pip_conflict_context(output_lines: list[str]) -> PipConflictContext | None: + matched_indices = [ + index + for index, line in enumerate(output_lines) + if _matches_pip_failure_pattern(line) + ] + if matched_indices: + relevant_index_set: set[int] = set() + for index in matched_indices: + start = max(0, index - 1) + end = min(len(output_lines), index + 2) + relevant_index_set.update(range(start, end)) + relevant_output_lines = [ + line + for index, line in enumerate(output_lines) + if index in relevant_index_set + ] + else: + relevant_output_lines = output_lines[-5:] + + if not relevant_output_lines: return None - egg_match = re.search(r"#egg=([A-Za-z0-9_.-]+)", raw_requirement) - if egg_match: - return _canonicalize_distribution_name(egg_match.group(1)) + dependency_detail_lines = [ + line.strip() + for line in relevant_output_lines + if _matches_pip_failure_pattern(line, "dependency_detail") + ] + requested_lines = [ + line.strip() + for line in relevant_output_lines + if _matches_pip_failure_pattern(line, "user_requested") + and not _matches_pip_failure_pattern(line, "constraint") + ] + if not requested_lines: + requested_lines = [ + line + for line in dependency_detail_lines + if not _matches_pip_failure_pattern(line, "constraint") + ] + constraint_lines = [ + line.strip() + for line in relevant_output_lines + if _matches_pip_failure_pattern(line, "constraint") + ] - candidate = re.split(r"[<>=!~;\s\[]", line, maxsplit=1)[0].strip() - if not candidate: + has_strong_conflict_signal = any( + _matches_pip_failure_pattern( + line, + "resolution_impossible", + "cannot_install", + ) + for line in relevant_output_lines + ) + + has_contextual_conflict_signal = any( + _matches_pip_failure_pattern(line, "conflict") for line in relevant_output_lines + ) and bool(dependency_detail_lines or requested_lines or constraint_lines) + + return PipConflictContext( + relevant_lines=relevant_output_lines, + requested_lines=requested_lines, + dependency_detail_lines=dependency_detail_lines, + constraint_lines=constraint_lines, + has_strong_conflict_signal=has_strong_conflict_signal, + has_contextual_conflict_signal=has_contextual_conflict_signal, + ) + + +def _classify_pip_failure(output_lines: list[str]) -> DependencyConflictError | None: + context = _build_pip_conflict_context(output_lines) + if context is None: return None - return _canonicalize_distribution_name(candidate) + if ( + not context.has_strong_conflict_signal + and not context.has_contextual_conflict_signal + and not (context.requested_lines and context.constraint_lines) + ): + return None -def _extract_requirement_names(requirements_path: str) -> set[str]: - names: set[str] = set() - try: - with open(requirements_path, encoding="utf-8") as requirements_file: - for line in requirements_file: - requirement_name = _extract_requirement_name(line) - if requirement_name: - names.add(requirement_name) - except Exception as exc: - logger.warning("读取依赖文件失败,跳过冲突检测: %s", exc) - return names + is_core_conflict = bool(context.constraint_lines) + + detail = "" + if context.constraint_lines and context.requested_lines: + detail = ( + " 冲突详情: " + f"{_normalize_conflict_detail_line(context.requested_lines[0])} vs " + f"{_normalize_conflict_detail_line(context.constraint_lines[0])}。" + ) + elif len(context.dependency_detail_lines) >= 2: + detail = ( + " 冲突详情: " + f"{_normalize_conflict_detail_line(context.dependency_detail_lines[0])} vs " + f"{_normalize_conflict_detail_line(context.dependency_detail_lines[1])}。" + ) + + if is_core_conflict: + message = ( + f"检测到核心依赖版本保护冲突。{detail}插件要求的依赖版本与 AstrBot 核心不兼容," + "为了系统稳定,已阻止该降级行为。请联系插件作者或调整 requirements.txt。" + ) + else: + message = f"检测到依赖冲突。{detail}" + + return DependencyConflictError( + message, + context.relevant_lines, + is_core_conflict=is_core_conflict, + ) def _extract_top_level_modules( @@ -155,7 +388,11 @@ def _collect_candidate_modules( by_name: dict[str, list[importlib_metadata.Distribution]] = {} try: for distribution in importlib_metadata.distributions(path=[site_packages_path]): - distribution_name = distribution.metadata.get("Name") + distribution_name = ( + distribution.metadata["Name"] + if "Name" in distribution.metadata + else None + ) if not distribution_name: continue canonical_name = _canonicalize_distribution_name(distribution_name) @@ -173,7 +410,7 @@ def _collect_candidate_modules( for distribution in by_name.get(requirement_name, []): for dependency_line in distribution.requires or []: - dependency_name = _extract_requirement_name(dependency_line) + dependency_name = extract_requirement_name(dependency_line) if not dependency_name: continue if dependency_name in expanded_requirement_names: @@ -230,6 +467,38 @@ def _ensure_preferred_modules( raise RuntimeError(conflict_message) +def _module_exists_in_site_packages(module_name: str, site_packages_path: str) -> bool: + base_path = os.path.join(site_packages_path, *module_name.split(".")) + package_init = os.path.join(base_path, "__init__.py") + module_file = f"{base_path}.py" + return os.path.isfile(package_init) or os.path.isfile(module_file) + + +def _is_module_loaded_from_site_packages( + module_name: str, + site_packages_path: str, +) -> bool: + module = sys.modules.get(module_name) + if module is None: + try: + module = importlib.import_module(module_name) + except Exception: + return False + + module_file = getattr(module, "__file__", None) + if not module_file: + return False + + module_path = os.path.realpath(module_file) + site_packages_real = os.path.realpath(site_packages_path) + try: + return ( + os.path.commonpath([module_path, site_packages_real]) == site_packages_real + ) + except ValueError: + return False + + def _prefer_module_from_site_packages( module_name: str, site_packages_path: str ) -> bool: @@ -531,9 +800,63 @@ def _patch_distlib_finder_for_frozen_runtime() -> None: class PipInstaller: - def __init__(self, pip_install_arg: str, pypi_index_url: str | None = None) -> None: + def __init__( + self, + pip_install_arg: str, + pypi_index_url: str | None = None, + core_dist_name: str | None = "AstrBot", + ) -> None: self.pip_install_arg = pip_install_arg self.pypi_index_url = pypi_index_url + self.core_dist_name = core_dist_name + self._core_constraints = CoreConstraintsProvider(core_dist_name) + + def _build_pip_args( + self, + package_name: str | None, + requirements_path: str | None, + mirror: str | None, + ) -> tuple[list[str], set[str]]: + args: list[str] = [] + requested_requirements: set[str] = set() + normalized_requirements_path = ( + requirements_path.strip() if requirements_path else "" + ) + + if package_name and normalized_requirements_path: + raise ValueError( + "package_name and requirements_path cannot be used together" + ) + + if package_name: + parsed_package = parse_package_install_input(package_name) + if parsed_package.specs: + args = ["install", *parsed_package.specs] + requested_requirements = set(parsed_package.requirement_names) + elif normalized_requirements_path: + args = ["install", "-r", normalized_requirements_path] + requested_requirements = extract_requirement_names( + normalized_requirements_path + ) + + if not args: + return [], requested_requirements + + pip_install_args = ( + shlex.split(self.pip_install_arg) if self.pip_install_arg else [] + ) + + if not _package_specs_override_index([*args[1:], *pip_install_args]): + index_url = mirror or self.pypi_index_url or "https://pypi.org/simple" + trusted_host = _get_trusted_host_for_index_url(index_url) + if trusted_host: + args.extend(["--trusted-host", trusted_host]) + args.extend(["-i", index_url]) + + if pip_install_args: + args.extend(pip_install_args) + + return args, requested_requirements async def install( self, @@ -541,36 +864,37 @@ class PipInstaller: requirements_path: str | None = None, mirror: str | None = None, ) -> None: - args = ["install"] - requested_requirements: set[str] = set() - if package_name: - args.append(package_name) - requirement_name = _extract_requirement_name(package_name) - if requirement_name: - requested_requirements.add(requirement_name) - elif requirements_path: - args.extend(["-r", requirements_path]) - requested_requirements = _extract_requirement_names(requirements_path) - - index_url = mirror or self.pypi_index_url or "https://pypi.org/simple" - args.extend(["--trusted-host", "mirrors.aliyun.com", "-i", index_url]) + args, requested_requirements = self._build_pip_args( + package_name, requirements_path, mirror + ) + if not args: + logger.info("Pip 包管理器跳过安装:未提供有效的包名或 requirements 文件。") + return target_site_packages = None if is_packaged_desktop_runtime(): target_site_packages = get_astrbot_site_packages_path() os.makedirs(target_site_packages, exist_ok=True) _prepend_sys_path(target_site_packages) - args.extend(["--target", target_site_packages]) - args.extend(["--upgrade", "--force-reinstall"]) + args.extend( + [ + "--target", + target_site_packages, + "--upgrade", + "--upgrade-strategy", + "only-if-needed", + ] + ) - if self.pip_install_arg: - args.extend(self.pip_install_arg.split()) + with self._core_constraints.constraints_file() as constraints_file_path: + if constraints_file_path: + args.extend(["-c", constraints_file_path]) - logger.info(f"Pip 包管理器: pip {' '.join(args)}") - result_code = await self._run_pip_in_process(args) - - if result_code != 0: - raise Exception(f"安装失败,错误码:{result_code}") + logger.info( + "Pip 包管理器 argv: %s", + ["pip", *_redact_pip_args_for_logging(args)], + ) + await self._run_pip_with_classification(args) if target_site_packages: _prepend_sys_path(target_site_packages) @@ -589,7 +913,7 @@ class PipInstaller: if not os.path.isdir(target_site_packages): return - requested_requirements = _extract_requirement_names(requirements_path) + requested_requirements = extract_requirement_names(requirements_path) if not requested_requirements: return @@ -605,13 +929,21 @@ class PipInstaller: _patch_distlib_finder_for_frozen_runtime() original_handlers = list(logging.getLogger().handlers) - result_code, output = await asyncio.to_thread( - _run_pip_main_with_output, pip_main, args - ) - for line in output.splitlines(): - line = line.strip() - if line: - logger.info(line) + try: + result_code, output_lines = await asyncio.to_thread( + _run_pip_main_streaming, pip_main, args + ) + finally: + _cleanup_added_root_handlers(original_handlers) + + if result_code != 0: + conflict = _classify_pip_failure(output_lines) + if conflict: + raise conflict - _cleanup_added_root_handlers(original_handlers) return result_code + + async def _run_pip_with_classification(self, args: list[str]) -> None: + result_code = await self._run_pip_in_process(args) + if result_code != 0: + raise PipInstallError(f"安装失败,错误码:{result_code}", code=result_code) diff --git a/astrbot/core/utils/requirements_utils.py b/astrbot/core/utils/requirements_utils.py new file mode 100644 index 000000000..7f3827256 --- /dev/null +++ b/astrbot/core/utils/requirements_utils.py @@ -0,0 +1,408 @@ +import importlib.metadata as importlib_metadata +import logging +import os +import re +import shlex +import sys +from collections.abc import Iterable, Iterator +from dataclasses import dataclass + +from packaging.requirements import InvalidRequirement, Requirement +from packaging.specifiers import SpecifierSet +from packaging.version import InvalidVersion, Version + +from astrbot.core.utils.astrbot_path import get_astrbot_site_packages_path +from astrbot.core.utils.runtime_env import is_packaged_desktop_runtime + +logger = logging.getLogger("astrbot") + + +class RequirementsPrecheckFailed(Exception): + """Raised when the pre-check of requirements fails.""" + + pass + + +@dataclass(frozen=True) +class ParsedPackageInput: + specs: tuple[str, ...] + requirement_names: frozenset[str] + + +def canonicalize_distribution_name(name: str) -> str: + return re.sub(r"[-_.]+", "-", name).strip("-").lower() + + +def strip_inline_requirement_comment(raw_input: str) -> str: + if raw_input.lstrip().startswith("#"): + return "" + return re.split(r"[ \t]+#", raw_input, maxsplit=1)[0].strip() + + +def _specifier_contains_version(specifier: SpecifierSet, version: str) -> bool: + try: + parsed_version = Version(version) + except InvalidVersion: + return False + return specifier.contains(parsed_version, prereleases=True) + + +def _looks_like_local_path_reference(token: str) -> bool: + candidate = token.strip() + if not candidate: + return False + return candidate in {".", ".."} or candidate.startswith( + ("./", "../", "/", "~/", ".\\", "..\\", "\\") + ) + + +def looks_like_direct_reference(token: str) -> bool: + candidate = token.strip() + if not candidate: + return False + return ( + _looks_like_local_path_reference(candidate) + or candidate.startswith("git+") + or "://" in candidate + ) + + +def extract_requirement_name(raw_requirement: str) -> str | None: + line = raw_requirement.split("#", 1)[0].strip() + if not line: + return None + if line.startswith(("-r", "--requirement", "-c", "--constraint")): + return None + + egg_match = re.search(r"#egg=([A-Za-z0-9_.-]+)", raw_requirement) + if egg_match: + return canonicalize_distribution_name(egg_match.group(1)) + + if line.startswith("-"): + return None + + candidate = re.split(r"[<>=!~;\s\[]", line, maxsplit=1)[0].strip() + if not candidate: + return None + return canonicalize_distribution_name(candidate) + + +def _parse_editable_or_direct_name(target: str) -> str | None: + name = extract_requirement_name(target) + if not name: + return None + if "#egg=" in target or not looks_like_direct_reference(target): + return name + return None + + +def _parse_requirement_name_and_spec( + line: str, +) -> tuple[str | None, SpecifierSet | None]: + if line.startswith(("-c", "--constraint")): + return None, None + + try: + req = Requirement(line) + except InvalidRequirement: + tokens = shlex.split(line) + if not tokens: + return None, None + + editable_target: str | None = None + if tokens[0] in {"-e", "--editable"} and len(tokens) > 1: + editable_target = tokens[1] + elif tokens[0].startswith("--editable="): + editable_target = tokens[0].split("=", 1)[1] + + if editable_target: + name = _parse_editable_or_direct_name(editable_target) + return (name, None) if name else (None, None) + + name = _parse_editable_or_direct_name(line) + return (name, None) if name else (None, None) + + if req.marker and not req.marker.evaluate(): + return None, None + + return canonicalize_distribution_name(req.name), (req.specifier or None) + + +def _parse_requirement_line( + line: str, +) -> tuple[str, SpecifierSet | None] | None: + name, specifier = _parse_requirement_name_and_spec(line) + return (name, specifier) if name else None + + +def _extract_requirement_names_from_package_tokens(tokens: list[str]) -> frozenset[str]: + requirement_names: set[str] = set() + skip_next_for: str | None = None + + for token in tokens: + if skip_next_for: + if skip_next_for == "editable": + name = _parse_editable_or_direct_name(token) + if name: + requirement_names.add(name) + skip_next_for = None + continue + + if token in {"-e", "--editable"}: + skip_next_for = "editable" + continue + + if token in { + "-i", + "--index-url", + "--extra-index-url", + "-f", + "--find-links", + "--trusted-host", + "-r", + "--requirement", + "-c", + "--constraint", + }: + skip_next_for = "option-value" + continue + + if token.startswith(("--editable=",)): + editable_target = token.split("=", 1)[1] + name = _parse_editable_or_direct_name(editable_target) + if name: + requirement_names.add(name) + continue + + if token.startswith( + ( + "--index-url=", + "--extra-index-url=", + "--find-links=", + "--trusted-host=", + "--requirement=", + "--constraint=", + ) + ): + continue + + if ( + (token.startswith("-i") and token != "-i") + or (token.startswith("-f") and token != "-f") + or token == "--no-index" + ): + continue + + if token.startswith("-"): + continue + + name, _ = _parse_requirement_name_and_spec(token) + if name: + requirement_names.add(name) + + return frozenset(requirement_names) + + +def parse_package_install_input(raw_input: str) -> ParsedPackageInput: + specs: list[str] = [] + requirement_names: set[str] = set() + normalized = raw_input.strip() + if not normalized: + return ParsedPackageInput(specs=(), requirement_names=frozenset()) + + for raw_line in normalized.splitlines(): + line = strip_inline_requirement_comment(raw_line) + if not line: + continue + + try: + Requirement(line) + except InvalidRequirement: + tokens = shlex.split(line) + if not tokens: + continue + specs.extend(tokens) + requirement_names.update( + _extract_requirement_names_from_package_tokens(tokens) + ) + continue + + specs.append(line) + name, _ = _parse_requirement_name_and_spec(line) + if name: + requirement_names.add(name) + + return ParsedPackageInput( + specs=tuple(specs), + requirement_names=frozenset(requirement_names), + ) + + +def _iter_requirement_lines( + requirements_path: str, + _visited: set[str] | None = None, +) -> Iterator[str]: + visited = _visited or set() + resolved_path = os.path.realpath(requirements_path) + if resolved_path in visited: + logger.warning( + "检测到循环依赖的 requirements 包含: %s,将跳过该文件", resolved_path + ) + return + visited.add(resolved_path) + + with open(resolved_path, encoding="utf-8") as f: + for raw_line in f: + line = strip_inline_requirement_comment(raw_line) + if not line: + continue + + tokens = shlex.split(line) + if not tokens: + continue + + nested: str | None = None + if tokens[0] in {"-r", "--requirement"} and len(tokens) > 1: + nested = tokens[1] + elif tokens[0].startswith("--requirement="): + nested = tokens[0].split("=", 1)[1] + + if nested: + if not os.path.isabs(nested): + nested = os.path.join(os.path.dirname(resolved_path), nested) + yield from _iter_requirement_lines(nested, _visited=visited) + continue + + yield line + + +def iter_requirements( + requirements_path: str | None = None, + lines: Iterable[str] | None = None, +) -> Iterator[tuple[str, SpecifierSet | None]]: + if lines is None: + if requirements_path is None: + raise ValueError("Either requirements_path or lines must be provided") + lines = _iter_requirement_lines(requirements_path) + + for line in lines: + parsed = _parse_requirement_line(line) + if parsed is not None: + yield parsed + + +def extract_requirement_names(requirements_path: str) -> set[str]: + try: + return { + name for name, _ in iter_requirements(requirements_path=requirements_path) + } + except Exception as exc: + logger.warning("读取依赖文件失败,跳过冲突检测: %s", exc) + return set() + + +def get_requirement_check_paths() -> list[str]: + paths = list(sys.path) + if is_packaged_desktop_runtime(): + target_site_packages = get_astrbot_site_packages_path() + if os.path.isdir(target_site_packages): + paths.insert(0, target_site_packages) + return paths + + +def _canonical_distribution_identity(distribution) -> tuple[str | None, str | None]: + distribution_name = ( + distribution.metadata["Name"] if "Name" in distribution.metadata else None + ) + if not distribution_name: + return None, None + return canonicalize_distribution_name(distribution_name), distribution.version + + +def collect_installed_distribution_versions(paths: list[str]) -> dict[str, str] | None: + installed: dict[str, str] = {} + try: + for distribution in importlib_metadata.distributions(path=paths): + distribution_name, version = _canonical_distribution_identity(distribution) + if not distribution_name or not version: + continue + installed.setdefault(distribution_name, version) + except Exception as exc: + logger.warning("读取已安装依赖失败,跳过缺失依赖预检查: %s", exc) + return None + return installed + + +def _load_requirement_lines_for_precheck( + requirements_path: str, +) -> tuple[bool, list[str] | None]: + try: + requirement_lines = list(_iter_requirement_lines(requirements_path)) + except Exception as exc: + logger.warning( + "预检查缺失依赖失败,将回退到完整安装: %s (%s)", + requirements_path, + exc, + ) + return False, None + + fallback_line = next( + ( + line + for line in requirement_lines + if ( + ( + line.startswith(("-e ", "--editable ", "--editable=")) + and "#egg=" not in line + ) + or ( + _parse_requirement_line(line) is None + and looks_like_direct_reference(line) + ) + ) + ), + None, + ) + if fallback_line is not None: + logger.warning( + "预检查缺失依赖失败,将回退到完整安装: unresolved direct reference in %s: %s", + requirements_path, + fallback_line, + ) + return False, None + + return True, requirement_lines + + +def find_missing_requirements(requirements_path: str) -> set[str] | None: + can_precheck, requirement_lines = _load_requirement_lines_for_precheck( + requirements_path + ) + if not can_precheck or requirement_lines is None: + return None + + required = list(iter_requirements(lines=requirement_lines)) + if not required: + return set() + + installed = collect_installed_distribution_versions(get_requirement_check_paths()) + if installed is None: + return None + + missing: set[str] = set() + for name, specifier in required: + installed_version = installed.get(name) + if not installed_version: + missing.add(name) + continue + if specifier and not _specifier_contains_version(specifier, installed_version): + missing.add(name) + + return missing + + +def find_missing_requirements_or_raise(requirements_path: str) -> set[str]: + missing = find_missing_requirements(requirements_path) + if missing is None: + raise RequirementsPrecheckFailed(f"预检查失败: {requirements_path}") + return missing diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py index ce28316af..6c575910a 100644 --- a/tests/test_dashboard.py +++ b/tests/test_dashboard.py @@ -16,6 +16,7 @@ from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.db.sqlite import SQLiteDatabase from astrbot.core.star.star import star_registry from astrbot.core.star.star_handler import star_handlers_registry +from astrbot.core.utils.pip_installer import PipInstallError from astrbot.dashboard.routes.plugin import PluginRoute from astrbot.dashboard.server import AstrBotDashboard from tests.fixtures.helpers import ( @@ -359,6 +360,35 @@ async def test_do_update( assert os.path.exists(release_path) +@pytest.mark.asyncio +async def test_install_pip_package_returns_pip_install_error_message( + app: Quart, + authenticated_header: dict, + monkeypatch, +): + test_client = app.test_client() + + async def mock_pip_install(*args, **kwargs): + del args, kwargs + raise PipInstallError("install failed", code=2) + + monkeypatch.setattr( + "astrbot.dashboard.routes.update.pip_installer.install", + mock_pip_install, + ) + + response = await test_client.post( + "/api/update/pip-install", + headers=authenticated_header, + json={"package": "demo-package"}, + ) + + assert response.status_code == 200 + data = await response.get_json() + assert data["status"] == "error" + assert data["message"] == "install failed" + + class _FakeNeoSkills: async def list_candidates(self, **kwargs): _ = kwargs diff --git a/tests/test_pip_helper_modules.py b/tests/test_pip_helper_modules.py new file mode 100644 index 000000000..dcb5cdb21 --- /dev/null +++ b/tests/test_pip_helper_modules.py @@ -0,0 +1,266 @@ +from pathlib import Path + +import pytest + +from astrbot.core.utils import core_constraints as core_constraints_module +from astrbot.core.utils import requirements_utils +from astrbot.core.utils.core_constraints import CoreConstraintsProvider + + +def test_requirements_utils_parse_package_install_input_collects_specs_and_names(): + parsed = requirements_utils.parse_package_install_input( + "--index-url https://example.com/simple demo-package\nanother-package>=1.0\n" + ) + + assert parsed.specs == ( + "--index-url", + "https://example.com/simple", + "demo-package", + "another-package>=1.0", + ) + assert parsed.requirement_names == {"demo-package", "another-package"} + + +def test_core_constraints_provider_writes_constraints_file_from_fallback_distribution( + monkeypatch, +): + class FakeFallbackDistribution: + metadata = {"Name": "AstrBot-App"} + requires = ["shared-lib>=1.0"] + + def read_text(self, name): + if name == "top_level.txt": + return "astrbot\n" + return "" + + fake_distribution = FakeFallbackDistribution() + + def mock_distribution(name): + if name == "AstrBot": + raise core_constraints_module.importlib_metadata.PackageNotFoundError + if name == "AstrBot-App": + return fake_distribution + raise core_constraints_module.importlib_metadata.PackageNotFoundError + + def mock_distributions(path=None): + del path + return [fake_distribution] + + monkeypatch.setattr( + core_constraints_module.importlib_metadata, + "distribution", + mock_distribution, + ) + monkeypatch.setattr( + core_constraints_module.importlib_metadata, + "distributions", + mock_distributions, + ) + monkeypatch.setattr( + core_constraints_module, + "collect_installed_distribution_versions", + lambda paths: {"shared-lib": "2.0"}, + ) + + core_constraints_module._get_core_constraints.cache_clear() + try: + provider = CoreConstraintsProvider(None) + with provider.constraints_file() as constraints_path: + assert constraints_path is not None + assert ( + Path(constraints_path).read_text(encoding="utf-8") == "shared-lib==2.0" + ) + finally: + core_constraints_module._get_core_constraints.cache_clear() + + +def test_resolve_core_dist_name_skips_distribution_without_name(monkeypatch): + class NamelessDistribution: + metadata = {} + + def read_text(self, name): + if name == "top_level.txt": + return "astrbot\n" + return "" + + class NamedDistribution: + metadata = {"Name": "AstrBot-App"} + + def read_text(self, name): + if name == "top_level.txt": + return "astrbot\n" + return "" + + monkeypatch.setattr( + core_constraints_module.importlib_metadata, + "distribution", + lambda name: (_ for _ in ()).throw( + core_constraints_module.importlib_metadata.PackageNotFoundError + ), + ) + monkeypatch.setattr( + core_constraints_module.importlib_metadata, + "distributions", + lambda: [NamelessDistribution(), NamedDistribution()], + ) + + assert core_constraints_module._resolve_core_dist_name(None) == "AstrBot-App" + + +def test_find_missing_requirements_returns_none_when_precheck_gate_fails( + monkeypatch, + tmp_path, +): + requirements_path = tmp_path / "requirements.txt" + requirements_path.write_text("demo-package\n", encoding="utf-8") + + monkeypatch.setattr( + requirements_utils, + "_load_requirement_lines_for_precheck", + lambda path: (False, None), + ) + + missing = requirements_utils.find_missing_requirements(str(requirements_path)) + + assert missing is None + + +def test_parse_package_install_input_tracks_only_named_direct_references(): + named = requirements_utils.parse_package_install_input( + "git+https://example.com/demo.git#egg=demo-package" + ) + unnamed = requirements_utils.parse_package_install_input( + "git+https://example.com/demo.git" + ) + + assert named.requirement_names == {"demo-package"} + assert unnamed.requirement_names == set() + + +def test_find_missing_requirements_or_raise_uses_requirements_exception(tmp_path): + requirements_path = tmp_path / "requirements.txt" + requirements_path.write_text("-e ../sharedlib\n", encoding="utf-8") + + with pytest.raises(requirements_utils.RequirementsPrecheckFailed): + requirements_utils.find_missing_requirements_or_raise(str(requirements_path)) + + +def test_find_missing_requirements_logs_path_and_reason_on_precheck_fallback( + monkeypatch, + tmp_path, +): + requirements_path = tmp_path / "requirements.txt" + requirements_path.write_text("git+https://example.com/demo.git\n", encoding="utf-8") + warning_logs = [] + + monkeypatch.setattr( + "astrbot.core.utils.requirements_utils.logger.warning", + lambda line, *args: warning_logs.append(line % args if args else line), + ) + + missing = requirements_utils.find_missing_requirements(str(requirements_path)) + + assert missing is None + assert any(str(requirements_path) in log for log in warning_logs) + assert any("direct reference" in log for log in warning_logs) + + +def test_load_requirement_lines_for_precheck_uses_parse_requirement_line_result( + monkeypatch, + tmp_path, +): + requirements_path = tmp_path / "requirements.txt" + requirements_path.write_text("git+https://example.com/demo.git\n", encoding="utf-8") + + monkeypatch.setattr( + requirements_utils, + "_parse_requirement_line", + lambda line: ("demo-package", None) if line.startswith("git+") else None, + ) + + can_precheck, requirement_lines = ( + requirements_utils._load_requirement_lines_for_precheck(str(requirements_path)) + ) + + assert can_precheck is True + assert requirement_lines == ["git+https://example.com/demo.git"] + + +def test_collect_installed_distribution_versions_skips_nameless_distribution( + monkeypatch, +): + class NamelessDistribution: + metadata = {} + version = "1.0" + + class NamedDistribution: + metadata = {"Name": "demo-package"} + version = "2.0" + + monkeypatch.setattr( + requirements_utils.importlib_metadata, + "distributions", + lambda path: [NamelessDistribution(), NamedDistribution()], + ) + + installed = requirements_utils.collect_installed_distribution_versions( + ["/tmp/test"] + ) + + assert installed == {"demo-package": "2.0"} + + +def test_get_core_constraints_logs_resolution_step_context(monkeypatch): + warning_logs = [] + + monkeypatch.setattr( + core_constraints_module, + "_resolve_core_dist_name", + lambda core_dist_name: (_ for _ in ()).throw(RuntimeError("boom")), + ) + monkeypatch.setattr( + "astrbot.core.utils.core_constraints.logger.warning", + lambda line, *args: warning_logs.append(line % args if args else line), + ) + + core_constraints_module._get_core_constraints.cache_clear() + try: + constraints = core_constraints_module._get_core_constraints(None) + finally: + core_constraints_module._get_core_constraints.cache_clear() + + assert constraints == () + assert any("解析核心分发名称失败" in log for log in warning_logs) + + +def test_iter_requirements_supports_direct_line_input(): + parsed = list( + requirements_utils.iter_requirements( + lines=["demo-package>=1.0", 'other-package; sys_platform == "win32"'] + ) + ) + + assert parsed == [ + ("demo-package", requirements_utils.Requirement("demo-package>=1.0").specifier) + ] + + +def test_parse_requirement_name_and_spec_preserves_direct_reference_rules(): + named = requirements_utils._parse_requirement_name_and_spec( + "git+https://example.com/demo.git#egg=demo-package" + ) + unnamed = requirements_utils._parse_requirement_name_and_spec( + "git+https://example.com/demo.git" + ) + + assert named == ("demo-package", None) + assert unnamed == (None, None) + + +def test_parse_requirement_name_and_spec_handles_plain_requirement_token(): + parsed = requirements_utils._parse_requirement_name_and_spec("demo-package>=1.0") + + assert parsed == ( + "demo-package", + requirements_utils.Requirement("demo-package>=1.0").specifier, + ) diff --git a/tests/test_pip_installer.py b/tests/test_pip_installer.py index deaf48088..d61507f4c 100644 --- a/tests/test_pip_installer.py +++ b/tests/test_pip_installer.py @@ -1,17 +1,36 @@ +import asyncio +import threading from unittest.mock import AsyncMock import pytest +from astrbot.core.utils import core_constraints as core_constraints_module +from astrbot.core.utils import pip_installer as pip_installer_module +from astrbot.core.utils import requirements_utils from astrbot.core.utils.pip_installer import PipInstaller +def _make_run_pip_mock( + code: int = 0, + output_lines: list[str] | None = None, + conflict=None, +): + del output_lines, conflict + + async def run_pip(*args, **kwargs): + del args, kwargs + return code + + return AsyncMock(side_effect=run_pip) + + @pytest.mark.asyncio async def test_install_targets_site_packages_for_desktop_client(monkeypatch, tmp_path): monkeypatch.setenv("ASTRBOT_DESKTOP_CLIENT", "1") monkeypatch.delattr("sys.frozen", raising=False) site_packages_path = tmp_path / "site-packages" - run_pip = AsyncMock(return_value=0) + run_pip = _make_run_pip_mock() prepend_sys_path_calls = [] ensure_preferred_calls = [] @@ -39,3 +58,1287 @@ async def test_install_targets_site_packages_for_desktop_client(monkeypatch, tmp assert str(site_packages_path) in recorded_args assert prepend_sys_path_calls == [str(site_packages_path), str(site_packages_path)] assert ensure_preferred_calls == [(str(site_packages_path), {"demo-package"})] + + +@pytest.mark.asyncio +async def test_run_pip_in_process_streams_output_lines(monkeypatch): + logged_lines = [] + first_line_seen = asyncio.Event() + unblock_pip = threading.Event() + + def fake_pip_main(args): + del args + print("Collecting demo-package") + unblock_pip.wait(timeout=1) + print("Downloading demo-package.whl") + return 0 + + loop = asyncio.get_running_loop() + + def record_log(line, *args): + message = line % args if args else line + logged_lines.append(message) + if message == "Collecting demo-package": + loop.call_soon_threadsafe(first_line_seen.set) + + monkeypatch.setattr( + "astrbot.core.utils.pip_installer._get_pip_main", + lambda: fake_pip_main, + ) + monkeypatch.setattr( + "astrbot.core.utils.pip_installer.logger.info", + record_log, + ) + + installer = PipInstaller("") + task = asyncio.create_task( + installer._run_pip_in_process(["install", "demo-package"]) + ) + + await asyncio.wait_for(first_line_seen.wait(), timeout=1) + unblock_pip.set() + result = await task + + assert result == 0 + assert logged_lines[-2:] == [ + "Collecting demo-package", + "Downloading demo-package.whl", + ] + + +@pytest.mark.asyncio +async def test_run_pip_in_process_preserves_shared_stream_order(monkeypatch): + logged_lines = [] + + def fake_pip_main(args): + del args + import sys + + sys.stdout.write("out") + sys.stderr.write("err\n") + sys.stdout.write(" line\n") + return 0 + + monkeypatch.setattr( + "astrbot.core.utils.pip_installer._get_pip_main", + lambda: fake_pip_main, + ) + monkeypatch.setattr( + "astrbot.core.utils.pip_installer.logger.info", + lambda line, *args: logged_lines.append(line % args if args else line), + ) + + installer = PipInstaller("") + result = await installer._run_pip_in_process(["install", "demo-package"]) + + assert result == 0 + assert logged_lines[-2:] == ["outerr", " line"] + + +@pytest.mark.asyncio +async def test_run_pip_in_process_preserves_blank_lines(monkeypatch): + logged_lines = [] + + def fake_pip_main(args): + del args + print("Collecting demo-package") + print() + print("Installing collected packages") + return 0 + + monkeypatch.setattr( + "astrbot.core.utils.pip_installer._get_pip_main", + lambda: fake_pip_main, + ) + monkeypatch.setattr( + "astrbot.core.utils.pip_installer.logger.info", + lambda line, *args: logged_lines.append(line % args if args else line), + ) + + installer = PipInstaller("") + result = await installer._run_pip_in_process(["install", "demo-package"]) + + assert result == 0 + assert logged_lines[-3:] == [ + "Collecting demo-package", + "", + "Installing collected packages", + ] + + +@pytest.mark.asyncio +async def test_run_pip_in_process_preserves_trailing_blank_line_on_flush(monkeypatch): + logged_lines = [] + + def fake_pip_main(args): + del args + import sys + + sys.stdout.write("Collecting demo-package\n\n") + return 0 + + monkeypatch.setattr( + "astrbot.core.utils.pip_installer._get_pip_main", + lambda: fake_pip_main, + ) + monkeypatch.setattr( + "astrbot.core.utils.pip_installer.logger.info", + lambda line, *args: logged_lines.append(line % args if args else line), + ) + + installer = PipInstaller("") + result = await installer._run_pip_in_process(["install", "demo-package"]) + + assert result == 0 + assert logged_lines[-2:] == ["Collecting demo-package", ""] + + +@pytest.mark.asyncio +async def test_run_pip_in_process_normalizes_crlf_without_extra_blank_lines( + monkeypatch, +): + logged_lines = [] + + def fake_pip_main(args): + del args + import sys + + sys.stdout.write("Collecting demo-package\r\n") + sys.stdout.write("Installing collected packages\r\n") + return 0 + + monkeypatch.setattr( + "astrbot.core.utils.pip_installer._get_pip_main", + lambda: fake_pip_main, + ) + monkeypatch.setattr( + "astrbot.core.utils.pip_installer.logger.info", + lambda line, *args: logged_lines.append(line % args if args else line), + ) + + installer = PipInstaller("") + result = await installer._run_pip_in_process(["install", "demo-package"]) + + assert result == 0 + assert logged_lines[-2:] == [ + "Collecting demo-package", + "Installing collected packages", + ] + + +@pytest.mark.asyncio +async def test_run_pip_in_process_classifies_nonstandard_conflict_output(monkeypatch): + def fake_pip_main(args): + del args + print( + "Cannot install demo-package and astrbot-core because these package " + "versions have conflicting dependencies." + ) + print("The conflict is caused by:") + print(" demo-package depends on shared-lib>=3.0") + print(" AstrBot (constraint) depends on shared-lib==2.0") + return 1 + + monkeypatch.setattr( + "astrbot.core.utils.pip_installer._get_pip_main", + lambda: fake_pip_main, + ) + + installer = PipInstaller("") + with pytest.raises(pip_installer_module.DependencyConflictError) as exc_info: + await installer._run_pip_in_process(["install", "demo-package"]) + + assert exc_info.value.is_core_conflict is True + assert "demo-package" in str(exc_info.value) + assert "demo-package depends on shared-lib>=3.0" in str(exc_info.value) + assert "AstrBot (constraint) depends on shared-lib==2.0" in str(exc_info.value) + assert "The conflict is caused by:" in exc_info.value.errors + + +@pytest.mark.asyncio +async def test_install_raises_dedicated_pip_install_error_on_non_conflict_failure( + monkeypatch, +): + async def failing_run_pip(self, args): + del self, args + return 2 + + monkeypatch.setattr(PipInstaller, "_run_pip_in_process", failing_run_pip) + + installer = PipInstaller("") + + with pytest.raises(pip_installer_module.PipInstallError, match="错误码:2"): + await installer.install(package_name="demo-package") + + +@pytest.mark.asyncio +async def test_run_pip_with_classification_raises_install_error_on_non_conflict_failure( + monkeypatch, +): + async def failing_run_pip(self, args): + del self, args + return 3 + + monkeypatch.setattr(PipInstaller, "_run_pip_in_process", failing_run_pip) + + installer = PipInstaller("") + + with pytest.raises(pip_installer_module.PipInstallError, match="错误码:3"): + await installer._run_pip_with_classification(["install", "demo-package"]) + + +@pytest.mark.asyncio +async def test_run_pip_in_process_bounds_retained_conflict_lines(monkeypatch): + def fake_pip_main(args): + del args + for index in range(10): + print(f"noise-{index}") + print( + "Cannot install demo-package and astrbot-core because these package " + "versions have conflicting dependencies." + ) + print("The conflict is caused by:") + print(" demo-package depends on shared-lib>=3.0") + print(" AstrBot (constraint) depends on shared-lib==2.0") + return 1 + + monkeypatch.setattr( + "astrbot.core.utils.pip_installer._get_pip_main", + lambda: fake_pip_main, + ) + monkeypatch.setattr("astrbot.core.utils.pip_installer._MAX_PIP_OUTPUT_LINES", 4) + + installer = PipInstaller("") + with pytest.raises(pip_installer_module.DependencyConflictError) as exc_info: + await installer._run_pip_in_process(["install", "demo-package"]) + + assert len(exc_info.value.errors) == 4 + assert exc_info.value.errors[0].startswith("Cannot install demo-package") + assert ( + exc_info.value.errors[-1] + == " AstrBot (constraint) depends on shared-lib==2.0" + ) + + +def test_build_pip_args_rejects_package_name_and_requirements_path_together(tmp_path): + requirements_path = tmp_path / "requirements.txt" + requirements_path.write_text("demo-package\n", encoding="utf-8") + + installer = PipInstaller("") + + with pytest.raises(ValueError, match="package_name and requirements_path"): + installer._build_pip_args("requests", str(requirements_path), None) + + +def _make_fake_distribution(name: str, version: str): + class FakeDistribution: + metadata = {"Name": name} + + def __init__(self, version: str): + self.version = version + + return FakeDistribution(version) + + +def test_find_missing_requirements_honors_version_specifiers(monkeypatch, tmp_path): + requirements_path = tmp_path / "requirements.txt" + requirements_path.write_text("demo-package>=2.0\n", encoding="utf-8") + + monkeypatch.setattr( + pip_installer_module.importlib_metadata, + "distributions", + lambda path: [_make_fake_distribution("demo-package", "1.0")], + ) + + missing = requirements_utils.find_missing_requirements(str(requirements_path)) + + assert missing == {"demo-package"} + + +def test_find_missing_requirements_skips_unmatched_markers(monkeypatch, tmp_path): + requirements_path = tmp_path / "requirements.txt" + requirements_path.write_text( + 'demo-package; sys_platform == "win32"\n', + encoding="utf-8", + ) + + monkeypatch.setattr( + pip_installer_module.importlib_metadata, + "distributions", + lambda path: [], + ) + + missing = requirements_utils.find_missing_requirements(str(requirements_path)) + + assert missing == set() + + +def test_find_missing_requirements_follows_nested_requirement_files( + monkeypatch, tmp_path +): + base_requirements = tmp_path / "base.txt" + base_requirements.write_text("demo-package==1.0\n", encoding="utf-8") + requirements_path = tmp_path / "requirements.txt" + requirements_path.write_text("-r base.txt\n", encoding="utf-8") + + monkeypatch.setattr( + pip_installer_module.importlib_metadata, + "distributions", + lambda path: [], + ) + + missing = requirements_utils.find_missing_requirements(str(requirements_path)) + + assert missing == {"demo-package"} + + +def test_find_missing_requirements_follows_equals_form_nested_requirements( + monkeypatch, tmp_path +): + base_requirements = tmp_path / "base.txt" + base_requirements.write_text("demo-package==1.0\n", encoding="utf-8") + requirements_path = tmp_path / "requirements.txt" + requirements_path.write_text("--requirement=base.txt\n", encoding="utf-8") + + monkeypatch.setattr( + pip_installer_module.importlib_metadata, + "distributions", + lambda path: [], + ) + + missing = requirements_utils.find_missing_requirements(str(requirements_path)) + + assert missing == {"demo-package"} + + +def test_find_missing_requirements_returns_none_when_nested_file_missing(tmp_path): + requirements_path = tmp_path / "requirements.txt" + requirements_path.write_text("-r base.txt\n", encoding="utf-8") + + missing = requirements_utils.find_missing_requirements(str(requirements_path)) + + assert missing is None + + +def test_find_missing_requirements_extracts_editable_vcs_requirement( + monkeypatch, tmp_path +): + requirements_path = tmp_path / "requirements.txt" + requirements_path.write_text( + "-e git+https://example.com/demo.git#egg=demo-package\n", + encoding="utf-8", + ) + + monkeypatch.setattr( + pip_installer_module.importlib_metadata, + "distributions", + lambda path: [], + ) + + missing = requirements_utils.find_missing_requirements(str(requirements_path)) + + assert missing == {"demo-package"} + + +def test_find_missing_requirements_prefers_first_search_path_version( + monkeypatch, tmp_path +): + requirements_path = tmp_path / "requirements.txt" + requirements_path.write_text("demo-package>=2.0\n", encoding="utf-8") + + monkeypatch.setattr( + pip_installer_module.importlib_metadata, + "distributions", + lambda path: [ + _make_fake_distribution("demo-package", "1.0"), + _make_fake_distribution("demo-package", "3.0"), + ], + ) + + missing = requirements_utils.find_missing_requirements(str(requirements_path)) + + assert missing == {"demo-package"} + + +def test_find_missing_requirements_returns_none_when_distribution_scan_fails( + monkeypatch, tmp_path +): + requirements_path = tmp_path / "requirements.txt" + requirements_path.write_text("demo-package>=2.0\n", encoding="utf-8") + + def failing_distributions(path): + del path + yield _make_fake_distribution("demo-package", "3.0") + raise RuntimeError("scan failed") + + monkeypatch.setattr( + pip_installer_module.importlib_metadata, + "distributions", + failing_distributions, + ) + + missing = requirements_utils.find_missing_requirements(str(requirements_path)) + + assert missing is None + + +def test_get_core_constraints_caches_fallback_resolution(monkeypatch): + distribution_calls = [] + distributions_calls = [] + + class FakeFallbackDistribution: + metadata = {"Name": "AstrBot-App"} + requires = ["shared-lib>=1.0"] + + def read_text(self, name): + if name == "top_level.txt": + return "astrbot\n" + return "" + + fake_distribution = FakeFallbackDistribution() + + def mock_distribution(name): + distribution_calls.append(name) + if name == "AstrBot": + raise pip_installer_module.importlib_metadata.PackageNotFoundError + if name == "AstrBot-App": + return fake_distribution + raise pip_installer_module.importlib_metadata.PackageNotFoundError + + def mock_distributions(path=None): + del path + distributions_calls.append("scan") + return [fake_distribution] + + monkeypatch.setattr( + core_constraints_module.importlib_metadata, + "distribution", + mock_distribution, + ) + monkeypatch.setattr( + core_constraints_module.importlib_metadata, + "distributions", + mock_distributions, + ) + monkeypatch.setattr( + core_constraints_module, + "collect_installed_distribution_versions", + lambda paths: {"shared-lib": "2.0"}, + ) + + core_constraints_module._get_core_constraints.cache_clear() + try: + first = core_constraints_module._get_core_constraints(None) + second = core_constraints_module._get_core_constraints(None) + finally: + core_constraints_module._get_core_constraints.cache_clear() + + assert first == ("shared-lib==2.0",) + assert second == ("shared-lib==2.0",) + assert distribution_calls == ["AstrBot", "AstrBot-App"] + assert distributions_calls == ["scan"] + + +def test_get_core_constraints_skips_distributions_with_unreadable_top_level( + monkeypatch, +): + class BrokenDistribution: + metadata = {"Name": "Broken-App"} + requires = [] + + def read_text(self, name): + if name == "top_level.txt": + raise OSError("cannot read top_level.txt") + return "" + + class FakeFallbackDistribution: + metadata = {"Name": "AstrBot-App"} + requires = ["shared-lib>=1.0"] + + def read_text(self, name): + if name == "top_level.txt": + return "astrbot\n" + return "" + + broken_distribution = BrokenDistribution() + fake_distribution = FakeFallbackDistribution() + + def mock_distribution(name): + if name == "AstrBot": + raise pip_installer_module.importlib_metadata.PackageNotFoundError + if name == "AstrBot-App": + return fake_distribution + raise pip_installer_module.importlib_metadata.PackageNotFoundError + + def mock_distributions(path=None): + del path + return [broken_distribution, fake_distribution] + + monkeypatch.setattr( + core_constraints_module.importlib_metadata, + "distribution", + mock_distribution, + ) + monkeypatch.setattr( + core_constraints_module.importlib_metadata, + "distributions", + mock_distributions, + ) + monkeypatch.setattr( + core_constraints_module, + "collect_installed_distribution_versions", + lambda paths: {"shared-lib": "2.0"}, + ) + + core_constraints_module._get_core_constraints.cache_clear() + try: + constraints = core_constraints_module._get_core_constraints(None) + finally: + core_constraints_module._get_core_constraints.cache_clear() + + assert constraints == ("shared-lib==2.0",) + + +def test_core_constraints_file_propagates_inner_conflict_without_fake_warning( + monkeypatch, +): + warning_logs = [] + conflict = pip_installer_module.DependencyConflictError( + "core conflict", + [], + is_core_conflict=True, + ) + + monkeypatch.setattr( + core_constraints_module, + "_get_core_constraints", + lambda core_dist_name: ("aiohttp==3.13.3",), + ) + monkeypatch.setattr( + "astrbot.core.utils.core_constraints.logger.warning", + lambda line, *args: warning_logs.append(line % args if args else line), + ) + + with pytest.raises( + pip_installer_module.DependencyConflictError, + match="core conflict", + ): + provider = core_constraints_module.CoreConstraintsProvider("AstrBot") + with provider.constraints_file() as constraints_path: + assert constraints_path is not None + raise conflict + + assert warning_logs == [] + + +def test_iter_requirement_lines_expands_nested_requirement_files(tmp_path): + base_requirements = tmp_path / "base.txt" + base_requirements.write_text("demo-package==1.0\n", encoding="utf-8") + requirements_path = tmp_path / "requirements.txt" + requirements_path.write_text( + "# comment\n-r base.txt\n--extra-index-url https://example.com/simple\n", + encoding="utf-8", + ) + + lines = list(requirements_utils._iter_requirement_lines(str(requirements_path))) + + assert lines == [ + "demo-package==1.0", + "--extra-index-url https://example.com/simple", + ] + + +def test_build_pip_args_extracts_requested_requirements(): + installer = PipInstaller("") + + args, requested = installer._build_pip_args( + "--index-url https://example.com/simple demo-package", + None, + None, + ) + + assert args == [ + "install", + "--index-url", + "https://example.com/simple", + "demo-package", + ] + assert requested == {"demo-package"} + + +def test_build_pip_args_appends_default_index_when_not_overridden(): + installer = PipInstaller("") + + args, requested = installer._build_pip_args("demo-package", None, None) + + assert args == ["install", "demo-package", "-i", "https://pypi.org/simple"] + assert requested == {"demo-package"} + + +@pytest.mark.asyncio +async def test_install_splits_space_separated_packages(monkeypatch): + run_pip = _make_run_pip_mock() + + monkeypatch.setattr(PipInstaller, "_run_pip_in_process", run_pip) + + installer = PipInstaller("") + await installer.install(package_name="demo-package another-package>=1.0") + + run_pip.assert_awaited_once() + recorded_args = run_pip.await_args_list[0].args[0] + + assert recorded_args[0:3] == ["install", "demo-package", "another-package>=1.0"] + + +@pytest.mark.asyncio +async def test_install_splits_three_space_separated_packages(monkeypatch): + run_pip = _make_run_pip_mock() + + monkeypatch.setattr(PipInstaller, "_run_pip_in_process", run_pip) + + installer = PipInstaller("") + await installer.install( + package_name="demo-package another-package extra-package>=1.0" + ) + + run_pip.assert_awaited_once() + recorded_args = run_pip.await_args_list[0].args[0] + + assert recorded_args[0:4] == [ + "install", + "demo-package", + "another-package", + "extra-package>=1.0", + ] + + +@pytest.mark.asyncio +async def test_install_splits_three_bare_packages(monkeypatch): + run_pip = _make_run_pip_mock() + + monkeypatch.setattr(PipInstaller, "_run_pip_in_process", run_pip) + + installer = PipInstaller("") + await installer.install(package_name="demo-package another-package extra-package") + + run_pip.assert_awaited_once() + recorded_args = run_pip.await_args_list[0].args[0] + + assert recorded_args[0:4] == [ + "install", + "demo-package", + "another-package", + "extra-package", + ] + + +@pytest.mark.asyncio +async def test_install_tracks_multiline_packages_for_desktop_client( + monkeypatch, tmp_path +): + monkeypatch.setenv("ASTRBOT_DESKTOP_CLIENT", "1") + monkeypatch.delattr("sys.frozen", raising=False) + + site_packages_path = tmp_path / "site-packages" + run_pip = _make_run_pip_mock() + ensure_preferred_calls = [] + + monkeypatch.setattr(PipInstaller, "_run_pip_in_process", run_pip) + monkeypatch.setattr( + "astrbot.core.utils.pip_installer.get_astrbot_site_packages_path", + lambda: str(site_packages_path), + ) + monkeypatch.setattr( + "astrbot.core.utils.pip_installer._prepend_sys_path", + lambda path: None, + ) + monkeypatch.setattr( + "astrbot.core.utils.pip_installer._ensure_plugin_dependencies_preferred", + lambda path, requirements: ensure_preferred_calls.append((path, requirements)), + ) + + installer = PipInstaller("") + await installer.install(package_name="demo-package\nanother-package>=1.0\n") + + run_pip.assert_awaited_once() + recorded_args = run_pip.await_args_list[0].args[0] + + assert recorded_args[0:3] == ["install", "demo-package", "another-package>=1.0"] + assert ensure_preferred_calls == [ + (str(site_packages_path), {"demo-package", "another-package"}) + ] + + +@pytest.mark.asyncio +async def test_install_splits_space_separated_packages_within_multiline_input( + monkeypatch, +): + run_pip = _make_run_pip_mock() + + monkeypatch.setattr(PipInstaller, "_run_pip_in_process", run_pip) + + installer = PipInstaller("") + await installer.install( + package_name="demo-package another-package\nextra-package\n" + ) + + run_pip.assert_awaited_once() + recorded_args = run_pip.await_args_list[0].args[0] + + assert recorded_args[0:4] == [ + "install", + "demo-package", + "another-package", + "extra-package", + ] + + +@pytest.mark.asyncio +async def test_install_keeps_single_requirement_with_marker_intact(monkeypatch): + run_pip = _make_run_pip_mock() + + monkeypatch.setattr(PipInstaller, "_run_pip_in_process", run_pip) + + installer = PipInstaller("") + await installer.install(package_name="demo-package ; python_version < '4'") + + run_pip.assert_awaited_once() + recorded_args = run_pip.await_args_list[0].args[0] + + assert recorded_args[0:2] == [ + "install", + "demo-package ; python_version < '4'", + ] + + +@pytest.mark.asyncio +async def test_install_keeps_single_requirement_with_compact_marker_intact(monkeypatch): + run_pip = _make_run_pip_mock() + + monkeypatch.setattr(PipInstaller, "_run_pip_in_process", run_pip) + + installer = PipInstaller("") + await installer.install(package_name='demo-package; python_version < "4"') + + run_pip.assert_awaited_once() + recorded_args = run_pip.await_args_list[0].args[0] + + assert recorded_args[0:2] == [ + "install", + 'demo-package; python_version < "4"', + ] + + +@pytest.mark.asyncio +async def test_install_keeps_single_requirement_with_version_range_intact(monkeypatch): + run_pip = _make_run_pip_mock() + + monkeypatch.setattr(PipInstaller, "_run_pip_in_process", run_pip) + + installer = PipInstaller("") + await installer.install(package_name="demo-package >= 1.0, < 2.0") + + run_pip.assert_awaited_once() + recorded_args = run_pip.await_args_list[0].args[0] + + assert recorded_args[0:2] == [ + "install", + "demo-package >= 1.0, < 2.0", + ] + + +@pytest.mark.asyncio +async def test_install_tracks_only_real_requirement_names_for_spaced_single_requirement( + monkeypatch, tmp_path +): + monkeypatch.setenv("ASTRBOT_DESKTOP_CLIENT", "1") + monkeypatch.delattr("sys.frozen", raising=False) + + site_packages_path = tmp_path / "site-packages" + run_pip = _make_run_pip_mock() + ensure_preferred_calls = [] + + monkeypatch.setattr(PipInstaller, "_run_pip_in_process", run_pip) + monkeypatch.setattr( + "astrbot.core.utils.pip_installer.get_astrbot_site_packages_path", + lambda: str(site_packages_path), + ) + monkeypatch.setattr( + "astrbot.core.utils.pip_installer._prepend_sys_path", + lambda path: None, + ) + monkeypatch.setattr( + "astrbot.core.utils.pip_installer._ensure_plugin_dependencies_preferred", + lambda path, requirements: ensure_preferred_calls.append((path, requirements)), + ) + + installer = PipInstaller("") + await installer.install(package_name="demo-package >= 1.0, < 2.0") + + assert ensure_preferred_calls == [(str(site_packages_path), {"demo-package"})] + + +def test_prefer_installed_dependencies_prefers_modules_for_requirements_in_desktop_runtime( + monkeypatch, tmp_path +): + monkeypatch.setenv("ASTRBOT_DESKTOP_CLIENT", "1") + monkeypatch.delattr("sys.frozen", raising=False) + + site_packages_path = tmp_path / "site-packages" + site_packages_path.mkdir() + requirements_path = tmp_path / "requirements.txt" + requirements_path.write_text("demo-package>=1.0\n", encoding="utf-8") + + prepend_calls = [] + preferred_calls = [] + + monkeypatch.setattr( + "astrbot.core.utils.pip_installer.get_astrbot_site_packages_path", + lambda: str(site_packages_path), + ) + monkeypatch.setattr( + "astrbot.core.utils.pip_installer._prepend_sys_path", + lambda path: prepend_calls.append(path), + ) + monkeypatch.setattr( + "astrbot.core.utils.pip_installer._ensure_plugin_dependencies_preferred", + lambda path, requirements: preferred_calls.append((path, requirements)), + ) + + installer = PipInstaller("") + installer.prefer_installed_dependencies(str(requirements_path)) + + assert prepend_calls == [str(site_packages_path)] + assert preferred_calls == [(str(site_packages_path), {"demo-package"})] + + +@pytest.mark.asyncio +async def test_install_multiline_input_strips_comments_and_splits_options(monkeypatch): + run_pip = _make_run_pip_mock() + + monkeypatch.setattr(PipInstaller, "_run_pip_in_process", run_pip) + + installer = PipInstaller("") + await installer.install( + package_name=( + "demo-package==1.0 # pinned\n" + "--extra-index-url https://example.com/simple\n" + "another-package\n" + ) + ) + + run_pip.assert_awaited_once() + recorded_args = run_pip.await_args_list[0].args[0] + + assert recorded_args[0:5] == [ + "install", + "demo-package==1.0", + "--extra-index-url", + "https://example.com/simple", + "another-package", + ] + + +@pytest.mark.asyncio +async def test_install_single_line_input_strips_inline_comment(monkeypatch): + run_pip = _make_run_pip_mock() + + monkeypatch.setattr(PipInstaller, "_run_pip_in_process", run_pip) + + installer = PipInstaller("") + await installer.install(package_name="requests==2.31.0 # latest") + + run_pip.assert_awaited_once() + recorded_args = run_pip.await_args_list[0].args[0] + + assert recorded_args[0:2] == ["install", "requests==2.31.0"] + + +@pytest.mark.asyncio +async def test_install_splits_single_line_editable_option_input(monkeypatch): + run_pip = _make_run_pip_mock() + + monkeypatch.setattr(PipInstaller, "_run_pip_in_process", run_pip) + + installer = PipInstaller("") + await installer.install(package_name="-e .") + + run_pip.assert_awaited_once() + recorded_args = run_pip.await_args_list[0].args[0] + + assert recorded_args[0:3] == ["install", "-e", "."] + + +@pytest.mark.asyncio +async def test_install_splits_single_line_option_with_url(monkeypatch): + run_pip = _make_run_pip_mock() + + monkeypatch.setattr(PipInstaller, "_run_pip_in_process", run_pip) + + installer = PipInstaller("") + await installer.install( + package_name="--index-url https://example.com/simple demo-package" + ) + + run_pip.assert_awaited_once() + recorded_args = run_pip.await_args_list[0].args[0] + + assert recorded_args[0:4] == [ + "install", + "--index-url", + "https://example.com/simple", + "demo-package", + ] + assert recorded_args.count("--index-url") == 1 + assert "-i" not in recorded_args + + +@pytest.mark.asyncio +async def test_install_tracks_requirement_name_for_single_line_option_input( + monkeypatch, tmp_path +): + monkeypatch.setenv("ASTRBOT_DESKTOP_CLIENT", "1") + monkeypatch.delattr("sys.frozen", raising=False) + + site_packages_path = tmp_path / "site-packages" + run_pip = _make_run_pip_mock() + ensure_preferred_calls = [] + + monkeypatch.setattr(PipInstaller, "_run_pip_in_process", run_pip) + monkeypatch.setattr( + "astrbot.core.utils.pip_installer.get_astrbot_site_packages_path", + lambda: str(site_packages_path), + ) + monkeypatch.setattr( + "astrbot.core.utils.pip_installer._prepend_sys_path", + lambda path: None, + ) + monkeypatch.setattr( + "astrbot.core.utils.pip_installer._ensure_plugin_dependencies_preferred", + lambda path, requirements: ensure_preferred_calls.append((path, requirements)), + ) + + installer = PipInstaller("") + await installer.install( + package_name="--index-url https://example.com/simple demo-package" + ) + + assert ensure_preferred_calls == [(str(site_packages_path), {"demo-package"})] + + +@pytest.mark.asyncio +async def test_install_keeps_equals_form_index_override(monkeypatch): + run_pip = _make_run_pip_mock() + + monkeypatch.setattr(PipInstaller, "_run_pip_in_process", run_pip) + + installer = PipInstaller("") + await installer.install( + package_name="--index-url=https://example.com/simple demo-package" + ) + + run_pip.assert_awaited_once() + recorded_args = run_pip.await_args_list[0].args[0] + + assert recorded_args[0:3] == [ + "install", + "--index-url=https://example.com/simple", + "demo-package", + ] + assert "-i" not in recorded_args + + +@pytest.mark.asyncio +async def test_install_keeps_short_form_index_override(monkeypatch): + run_pip = _make_run_pip_mock() + + monkeypatch.setattr(PipInstaller, "_run_pip_in_process", run_pip) + + installer = PipInstaller("") + await installer.install(package_name="-ihttps://example.com/simple demo-package") + + run_pip.assert_awaited_once() + recorded_args = run_pip.await_args_list[0].args[0] + + assert recorded_args[0:3] == [ + "install", + "-ihttps://example.com/simple", + "demo-package", + ] + assert "-i" not in recorded_args + + +@pytest.mark.asyncio +async def test_install_preserves_url_fragment_in_option_input(monkeypatch): + run_pip = _make_run_pip_mock() + + monkeypatch.setattr(PipInstaller, "_run_pip_in_process", run_pip) + + installer = PipInstaller("") + await installer.install( + package_name="--index-url https://example.com/simple#frag demo-package" + ) + + run_pip.assert_awaited_once() + recorded_args = run_pip.await_args_list[0].args[0] + + assert recorded_args[0:4] == [ + "install", + "--index-url", + "https://example.com/simple#frag", + "demo-package", + ] + assert "-i" not in recorded_args + + +def test_find_missing_requirements_returns_none_for_editable_local_path_reference( + tmp_path, +): + requirements_path = tmp_path / "requirements.txt" + requirements_path.write_text("-e ../sharedlib\n", encoding="utf-8") + + missing = requirements_utils.find_missing_requirements(str(requirements_path)) + + assert missing is None + + +@pytest.mark.parametrize( + "requirement_line", + [ + "-e sharedlib\n", + "--editable=.\\sharedlib\n", + ], +) +def test_find_missing_requirements_returns_none_for_editable_local_path_variants( + tmp_path, requirement_line +): + requirements_path = tmp_path / "requirements.txt" + requirements_path.write_text(requirement_line, encoding="utf-8") + + missing = requirements_utils.find_missing_requirements(str(requirements_path)) + + assert missing is None + + +@pytest.mark.asyncio +async def test_install_strips_inline_comment_from_option_line(monkeypatch): + run_pip = _make_run_pip_mock() + + monkeypatch.setattr(PipInstaller, "_run_pip_in_process", run_pip) + + installer = PipInstaller("") + await installer.install( + package_name=( + "--extra-index-url https://example.com/simple # mirror\ndemo-package\n" + ) + ) + + run_pip.assert_awaited_once() + recorded_args = run_pip.await_args_list[0].args[0] + + assert recorded_args[0:4] == [ + "install", + "--extra-index-url", + "https://example.com/simple", + "demo-package", + ] + + +@pytest.mark.asyncio +async def test_install_falls_back_to_raw_input_for_invalid_token_string(monkeypatch): + run_pip = _make_run_pip_mock() + + monkeypatch.setattr(PipInstaller, "_run_pip_in_process", run_pip) + + installer = PipInstaller("") + raw_input = "demo-package !!! another-package" + await installer.install(package_name=raw_input) + + run_pip.assert_awaited_once() + recorded_args = run_pip.await_args_list[0].args[0] + + assert recorded_args[0:4] == ["install", "demo-package", "!!!", "another-package"] + + +@pytest.mark.asyncio +async def test_install_ignores_whitespace_only_package_string(monkeypatch): + run_pip = _make_run_pip_mock() + + monkeypatch.setattr(PipInstaller, "_run_pip_in_process", run_pip) + + installer = PipInstaller("") + await installer.install(package_name=" ") + + run_pip.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_install_ignores_missing_package_and_requirements(monkeypatch): + run_pip = _make_run_pip_mock() + + monkeypatch.setattr(PipInstaller, "_run_pip_in_process", run_pip) + + installer = PipInstaller("") + await installer.install() + + run_pip.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_install_respects_index_override_in_pip_install_arg(monkeypatch): + run_pip = _make_run_pip_mock() + + monkeypatch.setattr(PipInstaller, "_run_pip_in_process", run_pip) + + installer = PipInstaller("--index-url https://example.com/simple") + await installer.install(package_name="demo-package") + + run_pip.assert_awaited_once() + recorded_args = run_pip.await_args_list[0].args[0] + + assert "install" in recorded_args + assert "demo-package" in recorded_args + assert "--index-url" in recorded_args + assert "https://example.com/simple" in recorded_args + # Verify that default index overrides are NOT present + assert "mirrors.aliyun.com" not in recorded_args + assert "https://pypi.org/simple" not in recorded_args + + +@pytest.mark.asyncio +async def test_install_respects_no_index_with_find_links(monkeypatch): + run_pip = _make_run_pip_mock() + + monkeypatch.setattr(PipInstaller, "_run_pip_in_process", run_pip) + + installer = PipInstaller("") + await installer.install( + package_name="--no-index --find-links /tmp/wheels demo-package" + ) + + run_pip.assert_awaited_once() + recorded_args = run_pip.await_args_list[0].args[0] + + assert recorded_args[0:5] == [ + "install", + "--no-index", + "--find-links", + "/tmp/wheels", + "demo-package", + ] + assert "-i" not in recorded_args + + +def test_redact_pip_args_for_logging_redacts_inline_url_credentials(): + redacted_args = pip_installer_module._redact_pip_args_for_logging( + [ + "install", + "--index-url=https://user:secret@example.com/simple", + "demo-package", + ] + ) + + assert redacted_args == [ + "install", + "--index-url=https://@example.com/simple", + "demo-package", + ] + + +def test_redact_pip_args_for_logging_redacts_sensitive_option_value_pairs(): + redacted_args = pip_installer_module._redact_pip_args_for_logging( + [ + "install", + "--password", + "super-secret", + "--token", + "opaque-token", + "demo-package", + ] + ) + + assert redacted_args == [ + "install", + "--password", + "****", + "--token", + "****", + "demo-package", + ] + + +def test_redact_pip_args_for_logging_redacts_inline_sensitive_values(): + redacted_args = pip_installer_module._redact_pip_args_for_logging( + [ + "install", + "--api-token=super-secret", + "password=hunter2", + "demo-package", + ] + ) + + assert redacted_args == [ + "install", + "--api-token=****", + "password=****", + "demo-package", + ] + + +@pytest.mark.asyncio +async def test_install_logs_redacted_pip_argv_when_credentials_present(monkeypatch): + run_pip = _make_run_pip_mock() + logged_lines = [] + + monkeypatch.setattr(PipInstaller, "_run_pip_in_process", run_pip) + monkeypatch.setattr( + "astrbot.core.utils.pip_installer.logger.info", + lambda line, *args: logged_lines.append(line % args if args else line), + ) + + installer = PipInstaller("") + await installer.install( + package_name="--index-url https://user:secret@example.com/simple demo-package" + ) + + argv_logs = [line for line in logged_lines if line.startswith("Pip 包管理器 argv:")] + + assert len(argv_logs) == 1 + assert "secret" not in argv_logs[0] + assert "user:" not in argv_logs[0] + assert "https://@example.com/simple" in argv_logs[0] + + +@pytest.mark.asyncio +async def test_install_does_not_add_aliyun_trusted_host_for_default_index(monkeypatch): + run_pip = _make_run_pip_mock() + + monkeypatch.setattr(PipInstaller, "_run_pip_in_process", run_pip) + + installer = PipInstaller("") + await installer.install(package_name="demo-package") + + run_pip.assert_awaited_once() + recorded_args = run_pip.await_args_list[0].args[0] + + assert "-i" in recorded_args + assert "https://pypi.org/simple" in recorded_args + assert "--trusted-host" not in recorded_args + + +@pytest.mark.asyncio +async def test_install_adds_aliyun_trusted_host_only_for_aliyun_index(monkeypatch): + run_pip = _make_run_pip_mock() + + monkeypatch.setattr(PipInstaller, "_run_pip_in_process", run_pip) + + installer = PipInstaller("", pypi_index_url="https://mirrors.aliyun.com/simple") + await installer.install(package_name="demo-package") + + run_pip.assert_awaited_once() + recorded_args = run_pip.await_args_list[0].args[0] + + assert "-i" in recorded_args + assert "https://mirrors.aliyun.com/simple" in recorded_args + trusted_host_index = recorded_args.index("--trusted-host") + assert recorded_args[trusted_host_index + 1] == "mirrors.aliyun.com" diff --git a/tests/test_plugin_manager.py b/tests/test_plugin_manager.py index b91e25c01..1b52990a5 100644 --- a/tests/test_plugin_manager.py +++ b/tests/test_plugin_manager.py @@ -1,235 +1,472 @@ -import sys -from asyncio import Queue +import asyncio from pathlib import Path -from unittest.mock import MagicMock import pytest -import pytest_asyncio +import yaml -from astrbot.core.config.astrbot_config import AstrBotConfig -from astrbot.core.db.sqlite import SQLiteDatabase -from astrbot.core.star.context import Context -from astrbot.core.star.star import star_map, star_registry -from astrbot.core.star.star_handler import star_handlers_registry -from astrbot.core.star.star_manager import PluginManager +from astrbot.core.star.star_manager import PluginDependencyInstallError, PluginManager +from astrbot.core.utils.pip_installer import PipInstallError +# --- Test Data & Helpers --- -def _clear_module_cache() -> None: - """Clear module cache for data module tree to ensure test isolation.""" - modules_to_remove = [ - key for key in sys.modules if key == "data" or key.startswith("data.") - ] - for key in modules_to_remove: - del sys.modules[key] - - -def _clear_registry(plugin_name: str) -> None: - """Clear plugin from global registries.""" - # Clear star_registry (list) - star_registry[:] = [md for md in star_registry if md.name != plugin_name] - # Clear star_map (dict) - keys_to_remove = [ - key for key, md in star_map.items() if md.name == plugin_name - ] - for key in keys_to_remove: - del star_map[key] - # Clear star_handlers_registry (StarHandlerRegistry) - for handler in list(star_handlers_registry): - if plugin_name in (handler.handler_module_path or ""): - star_handlers_registry.remove(handler) - -TEST_PLUGIN_REPO = "https://github.com/Soulter/helloworld" -TEST_PLUGIN_DIR = "helloworld" TEST_PLUGIN_NAME = "helloworld" +TEST_PLUGIN_REPO = "https://github.com/AstrBotDevs/astrbot_plugin_helloworld" +TEST_PLUGIN_DIR = "helloworld" -def _write_local_test_plugin(plugin_dir: Path, repo_url: str) -> None: - plugin_dir.mkdir(parents=True, exist_ok=True) - (plugin_dir / "metadata.yaml").write_text( - "\n".join( - [ - f"name: {TEST_PLUGIN_NAME}", - "author: AstrBot Team", - "desc: Local test plugin", - "version: 1.0.0", - f"repo: {repo_url}", - ], - ) - + "\n", - encoding="utf-8", - ) - (plugin_dir / "main.py").write_text( - "\n".join( - [ - "from astrbot.api import star", - "", - "class Main(star.Star):", - " pass", - "", - ], - ), - encoding="utf-8", +class MockStar: + def __init__(self): + self.root_dir_name = TEST_PLUGIN_DIR + self.name = TEST_PLUGIN_NAME + self.repo = TEST_PLUGIN_REPO + self.reserved = False + self.info = {"repo": TEST_PLUGIN_REPO, "readme": ""} + + +def _write_local_test_plugin(plugin_path: Path, repo_url: str): + """Creates a minimal valid plugin structure.""" + plugin_path.mkdir(parents=True, exist_ok=True) + metadata = { + "name": TEST_PLUGIN_NAME, + "repo": repo_url, + "version": "1.0.0", + "author": "AstrBot Team", + "desc": "Local test plugin", + } + with open(plugin_path / "info.yaml", "w", encoding="utf-8") as f: + yaml.dump(metadata, f) + with open(plugin_path / "main.py", "w", encoding="utf-8") as f: + f.write("from astrbot.api.star import Star, Context, StarManager\n") + f.write("@StarManager.register\n") + f.write("class HelloWorld(Star):\n") + f.write(" def __init__(self, context: Context): ...\n") + + +def _write_requirements(plugin_path: Path): + """Creates a requirements.txt file.""" + with open(plugin_path / "requirements.txt", "w", encoding="utf-8") as f: + f.write("networkx\n") + + +def _clear_module_cache(): + """Clear test-specific modules from sys.modules to allow reloading.""" + import sys + + to_del = [m for m in sys.modules if m.startswith("data.plugins.helloworld")] + for m in to_del: + del sys.modules[m] + + +def _build_load_mock(events): + async def mock_load(specified_dir_name=None, ignore_version_check=False): + del ignore_version_check + events.append(("load", specified_dir_name or TEST_PLUGIN_DIR)) + return True, "" + + return mock_load + + +def _build_reload_mock(events): + async def mock_reload(specified_dir_name=None): + events.append(("reload", specified_dir_name or TEST_PLUGIN_DIR)) + return True, "" + + return mock_reload + + +def _build_dependency_install_mock(events, fail: bool): + async def mock_install_requirements( + *, requirements_path: str = None, package_name: str = None, **kwargs + ): + del kwargs + if requirements_path: + events.append(("deps", str(requirements_path))) + if package_name: + events.append(("deps_pkg", package_name)) + if fail: + raise Exception("pip failed") + + return mock_install_requirements + + +def _mock_missing_requirements(monkeypatch, missing: set[str]): + monkeypatch.setattr( + "astrbot.core.star.star_manager.find_missing_requirements_or_raise", + lambda requirements_path: missing, ) -@pytest_asyncio.fixture -async def plugin_manager_pm(tmp_path, monkeypatch): +def _mock_precheck_fails(monkeypatch): + from astrbot.core import RequirementsPrecheckFailed + + def mock_fail(requirements_path): + raise RequirementsPrecheckFailed("mock precheck failure") + + monkeypatch.setattr( + "astrbot.core.star.star_manager.find_missing_requirements_or_raise", + mock_fail, + ) + + +# --- Fixtures --- + + +@pytest.fixture +def plugin_manager_pm(tmp_path, monkeypatch): """Provides a fully isolated PluginManager instance for testing.""" # Clear module cache before setup to ensure isolation _clear_module_cache() - test_root = tmp_path / "astrbot_root" - data_dir = test_root / "data" - plugin_dir = data_dir / "plugins" - config_dir = data_dir / "config" - temp_dir = data_dir / "temp" - for path in (plugin_dir, config_dir, temp_dir): - path.mkdir(parents=True, exist_ok=True) + plugin_dir = tmp_path / "astrbot_root" / "data" / "plugins" + plugin_dir.mkdir(parents=True, exist_ok=True) - # Ensure `import data.plugins..main` resolves to this temp root. - (data_dir / "__init__.py").write_text("", encoding="utf-8") - (plugin_dir / "__init__.py").write_text("", encoding="utf-8") + class MockContext: + def __init__(self): + self.stars = [] - # Use monkeypatch for both env var and sys.path to ensure proper cleanup - monkeypatch.setenv("ASTRBOT_ROOT", str(test_root)) - monkeypatch.syspath_prepend(str(test_root)) + def get_all_stars(self): + return self.stars - # Create fresh, isolated instances for the context - event_queue = Queue() - config = AstrBotConfig() - db = SQLiteDatabase(str(data_dir / "test_db.db")) - config.plugin_store_path = str(plugin_dir) + def get_registered_star(self, name): + for s in self.stars: + if s.root_dir_name == name or s.name == name: + return s + return None - provider_manager = MagicMock() - platform_manager = MagicMock() - conversation_manager = MagicMock() - message_history_manager = MagicMock() - persona_manager = MagicMock() - persona_manager.personas_v3 = [] - astrbot_config_mgr = MagicMock() - knowledge_base_manager = MagicMock() - cron_manager = MagicMock() + mock_context = MockContext() + mock_config = {} + pm = PluginManager(mock_context, mock_config) - star_context = Context( - event_queue=event_queue, - config=config, - db=db, - provider_manager=provider_manager, - platform_manager=platform_manager, - conversation_manager=conversation_manager, - message_history_manager=message_history_manager, - persona_manager=persona_manager, - astrbot_config_mgr=astrbot_config_mgr, - knowledge_base_manager=knowledge_base_manager, - cron_manager=cron_manager, - subagent_orchestrator=None, + # Patch paths to use tmp_path + monkeypatch.setattr(pm, "plugin_store_path", str(plugin_dir)) + monkeypatch.setattr( + "astrbot.core.star.star_manager.get_astrbot_plugin_path", + lambda: str(plugin_dir), ) - manager = PluginManager(star_context, config) - try: - yield manager - finally: - # Cleanup global registries and module cache - _clear_registry(TEST_PLUGIN_NAME) - _clear_module_cache() - await db.engine.dispose() + return pm @pytest.fixture -def local_updator(plugin_manager_pm: PluginManager, monkeypatch): - plugin_path = Path(plugin_manager_pm.plugin_store_path) / TEST_PLUGIN_DIR +def local_updator(plugin_manager_pm): + """Helper to setup a local plugin directory simulating a download.""" + path = Path(plugin_manager_pm.plugin_store_path) / TEST_PLUGIN_DIR + _write_local_test_plugin(path, TEST_PLUGIN_REPO) + return path - async def mock_install(repo_url: str, proxy=""): # noqa: ARG001 - if repo_url != TEST_PLUGIN_REPO: - raise Exception("Repo not found") + +# --- Tests --- + + +@pytest.mark.asyncio +@pytest.mark.parametrize("dependency_install_fails", [False, True]) +async def test_install_plugin_dependency_install_flow( + plugin_manager_pm: PluginManager, monkeypatch, dependency_install_fails: bool +): + plugin_path = Path(plugin_manager_pm.plugin_store_path) / TEST_PLUGIN_DIR + events = [] + _mock_missing_requirements(monkeypatch, {"networkx"}) + + async def mock_install(repo_url: str, proxy=""): + assert repo_url == TEST_PLUGIN_REPO _write_local_test_plugin(plugin_path, repo_url) + _write_requirements(plugin_path) return str(plugin_path) - async def mock_update(plugin, proxy=""): # noqa: ARG001 - if plugin.name != TEST_PLUGIN_NAME: - raise Exception("Plugin not found") - if not plugin_path.exists(): - raise Exception("Plugin path missing") - (plugin_path / ".updated").write_text("ok", encoding="utf-8") - monkeypatch.setattr(plugin_manager_pm.updator, "install", mock_install) - monkeypatch.setattr(plugin_manager_pm.updator, "update", mock_update) - return plugin_path + monkeypatch.setattr( + "astrbot.core.star.star_manager.pip_installer.install", + _build_dependency_install_mock(events, dependency_install_fails), + ) + + def mock_load_and_register(*args, **kwargs): + plugin_manager_pm.context.stars.append(MockStar()) + return _build_load_mock(events)(*args, **kwargs) + + monkeypatch.setattr(plugin_manager_pm, "load", mock_load_and_register) + + if dependency_install_fails: + with pytest.raises(PluginDependencyInstallError, match="pip failed"): + await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO) + assert events == [("deps", str(plugin_path / "requirements.txt"))] + else: + await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO) + assert events == [ + ("deps", str(plugin_path / "requirements.txt")), + ("load", TEST_PLUGIN_DIR), + ] @pytest.mark.asyncio -async def test_plugin_manager_initialization(plugin_manager_pm: PluginManager): - assert plugin_manager_pm is not None - assert plugin_manager_pm.context is not None - assert plugin_manager_pm.config is not None - - -@pytest.mark.asyncio -async def test_plugin_manager_reload(plugin_manager_pm: PluginManager): - success, err_message = await plugin_manager_pm.reload() - assert success is True - assert err_message is None - - -@pytest.mark.asyncio -async def test_install_plugin(plugin_manager_pm: PluginManager, local_updator: Path): - """Tests successful plugin installation without external network.""" - plugin_info = await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO) - assert plugin_info is not None - assert plugin_info["name"] == TEST_PLUGIN_NAME - assert local_updator.exists() - assert any(md.name == TEST_PLUGIN_NAME for md in star_registry) - - -@pytest.mark.asyncio -async def test_install_nonexistent_plugin( - plugin_manager_pm: PluginManager, local_updator +@pytest.mark.parametrize("dependency_install_fails", [False, True]) +async def test_install_plugin_from_file_dependency_install_flow( + plugin_manager_pm: PluginManager, + monkeypatch, + tmp_path, + dependency_install_fails: bool, ): - """Tests that installing a non-existent plugin raises an exception.""" - with pytest.raises(Exception): - await plugin_manager_pm.install_plugin( - "https://github.com/Soulter/non_existent_repo" + zip_file_path = tmp_path / f"{TEST_PLUGIN_DIR}.zip" + zip_file_path.write_text("placeholder", encoding="utf-8") + events = [] + _mock_missing_requirements(monkeypatch, {"networkx"}) + + def mock_unzip_file(zip_path: str, target_dir: str) -> None: + assert zip_path == str(zip_file_path) + plugin_path = Path(target_dir) + _write_local_test_plugin(plugin_path, TEST_PLUGIN_REPO) + _write_requirements(plugin_path) + + monkeypatch.setattr(plugin_manager_pm.updator, "unzip_file", mock_unzip_file) + monkeypatch.setattr( + "astrbot.core.star.star_manager.pip_installer.install", + _build_dependency_install_mock(events, dependency_install_fails), + ) + + def mock_load_and_register(*args, **kwargs): + plugin_manager_pm.context.stars.append(MockStar()) + return _build_load_mock(events)(*args, **kwargs) + + monkeypatch.setattr(plugin_manager_pm, "load", mock_load_and_register) + + if dependency_install_fails: + with pytest.raises(PluginDependencyInstallError, match="pip failed"): + await plugin_manager_pm.install_plugin_from_file(str(zip_file_path)) + assert any(e[0] == "deps" for e in events) + else: + await plugin_manager_pm.install_plugin_from_file(str(zip_file_path)) + assert any(e[0] == "deps" for e in events) + assert ("load", TEST_PLUGIN_DIR) in events + + +@pytest.mark.asyncio +@pytest.mark.parametrize("dependency_install_fails", [False, True]) +async def test_reload_failed_plugin_dependency_install_flow( + plugin_manager_pm: PluginManager, + local_updator: Path, + monkeypatch, + dependency_install_fails: bool, +): + _write_requirements(local_updator) + plugin_manager_pm.failed_plugin_dict[TEST_PLUGIN_DIR] = {"error": "init fail"} + events = [] + _mock_missing_requirements(monkeypatch, {"networkx"}) + + monkeypatch.setattr( + "astrbot.core.star.star_manager.pip_installer.install", + _build_dependency_install_mock(events, dependency_install_fails), + ) + + def mock_load_and_register(*args, **kwargs): + plugin_manager_pm.context.stars.append(MockStar()) + return _build_load_mock(events)(*args, **kwargs) + + monkeypatch.setattr(plugin_manager_pm, "load", mock_load_and_register) + + if dependency_install_fails: + with pytest.raises(PluginDependencyInstallError, match="pip failed"): + await plugin_manager_pm.reload_failed_plugin(TEST_PLUGIN_DIR) + assert events == [("deps", str(local_updator / "requirements.txt"))] + else: + await plugin_manager_pm.reload_failed_plugin(TEST_PLUGIN_DIR) + assert events == [ + ("deps", str(local_updator / "requirements.txt")), + ("load", TEST_PLUGIN_DIR), + ] + + +@pytest.mark.asyncio +async def test_ensure_plugin_requirements_reraises_cancelled_error( + plugin_manager_pm: PluginManager, local_updator: Path, monkeypatch +): + _write_requirements(local_updator) + _mock_missing_requirements(monkeypatch, {"networkx"}) + + async def mock_install_requirements(*args, **kwargs): + raise asyncio.CancelledError() + + monkeypatch.setattr( + "astrbot.core.star.star_manager.pip_installer.install", + mock_install_requirements, + ) + + with pytest.raises(asyncio.CancelledError): + await plugin_manager_pm._ensure_plugin_requirements( + str(local_updator), + TEST_PLUGIN_DIR, ) @pytest.mark.asyncio -async def test_update_plugin(plugin_manager_pm: PluginManager, local_updator: Path): - """Tests updating an existing plugin without external network.""" - plugin_info = await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO) - assert plugin_info is not None - plugin_name = plugin_info["name"] - await plugin_manager_pm.update_plugin(plugin_name) - assert (local_updator / ".updated").exists() - - -@pytest.mark.asyncio -async def test_update_nonexistent_plugin( - plugin_manager_pm: PluginManager, local_updator +async def test_ensure_plugin_requirements_wraps_generic_dependency_install_failure( + plugin_manager_pm: PluginManager, local_updator: Path, monkeypatch ): - """Tests that updating a non-existent plugin raises an exception.""" - with pytest.raises(Exception): - await plugin_manager_pm.update_plugin("non_existent_plugin") + _write_requirements(local_updator) + _mock_missing_requirements(monkeypatch, {"networkx"}) + async def mock_install_requirements(*args, **kwargs): + raise RuntimeError("pip failed") -@pytest.mark.asyncio -async def test_uninstall_plugin(plugin_manager_pm: PluginManager, local_updator: Path): - """Tests successful plugin uninstallation.""" - plugin_info = await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO) - assert plugin_info is not None - plugin_name = plugin_info["name"] - assert local_updator.exists() - - await plugin_manager_pm.uninstall_plugin(plugin_name) - - assert not local_updator.exists() - assert not any(md.name == TEST_PLUGIN_NAME for md in star_registry) - assert not any( - TEST_PLUGIN_NAME in md.handler_module_path for md in star_handlers_registry + monkeypatch.setattr( + "astrbot.core.star.star_manager.pip_installer.install", + mock_install_requirements, ) + with pytest.raises(PluginDependencyInstallError, match="pip failed") as exc_info: + await plugin_manager_pm._ensure_plugin_requirements( + str(local_updator), + TEST_PLUGIN_DIR, + ) + + assert exc_info.value.plugin_label == TEST_PLUGIN_DIR + assert exc_info.value.requirements_path == str(local_updator / "requirements.txt") + assert isinstance(exc_info.value.__cause__, RuntimeError) + @pytest.mark.asyncio -async def test_uninstall_nonexistent_plugin(plugin_manager_pm: PluginManager): - """Tests that uninstalling a non-existent plugin raises an exception.""" - with pytest.raises(Exception): - await plugin_manager_pm.uninstall_plugin("non_existent_plugin") +async def test_ensure_plugin_requirements_wraps_pip_install_error( + plugin_manager_pm: PluginManager, local_updator: Path, monkeypatch +): + _write_requirements(local_updator) + _mock_missing_requirements(monkeypatch, {"networkx"}) + + async def mock_install_requirements(*args, **kwargs): + raise PipInstallError("install failed", code=2) + + monkeypatch.setattr( + "astrbot.core.star.star_manager.pip_installer.install", + mock_install_requirements, + ) + + with pytest.raises(PluginDependencyInstallError, match="install failed") as exc_info: + await plugin_manager_pm._ensure_plugin_requirements( + str(local_updator), + TEST_PLUGIN_DIR, + ) + + assert isinstance(exc_info.value.__cause__, PipInstallError) + + +@pytest.mark.asyncio +async def test_ensure_plugin_requirements_logs_requirements_file_install_for_missing_dependencies( + plugin_manager_pm: PluginManager, local_updator: Path, monkeypatch +): + _write_requirements(local_updator) + _mock_missing_requirements(monkeypatch, {"networkx"}) + logged_lines = [] + + async def mock_install_requirements(*args, **kwargs): + return None + + monkeypatch.setattr( + "astrbot.core.star.star_manager.pip_installer.install", + mock_install_requirements, + ) + monkeypatch.setattr( + "astrbot.core.star.star_manager.logger.info", + lambda line, *args: logged_lines.append(line % args if args else line), + ) + + await plugin_manager_pm._ensure_plugin_requirements( + str(local_updator), + TEST_PLUGIN_DIR, + ) + + assert any("按 requirements.txt 安装" in line for line in logged_lines) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("dependency_install_fails", [False, True]) +async def test_update_plugin_dependency_install_flow( + plugin_manager_pm: PluginManager, + local_updator: Path, + monkeypatch, + dependency_install_fails: bool, +): + mock_star = MockStar() + plugin_manager_pm.context.stars.append(mock_star) + + _write_requirements(local_updator) + events = [] + _mock_missing_requirements(monkeypatch, {"networkx"}) + + async def mock_update(plugin, proxy=""): + del proxy + events.append(("update", plugin.name)) + + monkeypatch.setattr(plugin_manager_pm.updator, "update", mock_update) + monkeypatch.setattr( + "astrbot.core.star.star_manager.pip_installer.install", + _build_dependency_install_mock(events, dependency_install_fails), + ) + monkeypatch.setattr(plugin_manager_pm, "reload", _build_reload_mock(events)) + + if dependency_install_fails: + with pytest.raises(PluginDependencyInstallError, match="pip failed"): + await plugin_manager_pm.update_plugin(TEST_PLUGIN_NAME) + assert ("deps", str(local_updator / "requirements.txt")) in events + else: + await plugin_manager_pm.update_plugin(TEST_PLUGIN_NAME) + assert ("deps", str(local_updator / "requirements.txt")) in events + assert ("reload", TEST_PLUGIN_DIR) in events + + +@pytest.mark.asyncio +async def test_install_plugin_skips_dependency_install_when_no_requirements_missing( + plugin_manager_pm: PluginManager, monkeypatch +): + plugin_path = Path(plugin_manager_pm.plugin_store_path) / TEST_PLUGIN_DIR + events = [] + _mock_missing_requirements(monkeypatch, set()) + + async def mock_install(repo_url: str, proxy=""): + _write_local_test_plugin(plugin_path, repo_url) + _write_requirements(plugin_path) + return str(plugin_path) + + monkeypatch.setattr(plugin_manager_pm.updator, "install", mock_install) + monkeypatch.setattr( + "astrbot.core.star.star_manager.pip_installer.install", + _build_dependency_install_mock(events, False), + ) + + def mock_load_and_register(*args, **kwargs): + plugin_manager_pm.context.stars.append(MockStar()) + return _build_load_mock(events)(*args, **kwargs) + + monkeypatch.setattr(plugin_manager_pm, "load", mock_load_and_register) + + await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO) + + assert "deps" not in [e[0] for e in events] + assert ("load", TEST_PLUGIN_DIR) in events + + +@pytest.mark.asyncio +async def test_install_plugin_runs_dependency_install_when_precheck_fails( + plugin_manager_pm: PluginManager, monkeypatch +): + plugin_path = Path(plugin_manager_pm.plugin_store_path) / TEST_PLUGIN_DIR + events = [] + + async def mock_install(repo_url: str, proxy=""): + _write_local_test_plugin(plugin_path, repo_url) + _write_requirements(plugin_path) + return str(plugin_path) + + _mock_precheck_fails(monkeypatch) + monkeypatch.setattr(plugin_manager_pm.updator, "install", mock_install) + monkeypatch.setattr( + "astrbot.core.star.star_manager.pip_installer.install", + _build_dependency_install_mock(events, False), + ) + + def mock_load_and_register(*args, **kwargs): + plugin_manager_pm.context.stars.append(MockStar()) + return _build_load_mock(events)(*args, **kwargs) + + monkeypatch.setattr(plugin_manager_pm, "load", mock_load_and_register) + + await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO) + + assert ("deps", str(plugin_path / "requirements.txt")) in events + assert ("load", TEST_PLUGIN_DIR) in events