Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d87d586c0a | |||
| 410789311a | |||
| 6da59cfb07 |
@@ -61,3 +61,5 @@ GenieData/
|
||||
.codex/
|
||||
.opencode/
|
||||
.kilocode/
|
||||
.worktrees/
|
||||
docs/plans/
|
||||
|
||||
@@ -4,7 +4,21 @@ from astrbot.core.config import AstrBotConfig
|
||||
from astrbot.core.config.default import DB_PATH
|
||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
from astrbot.core.file_token_service import FileTokenService
|
||||
from astrbot.core.utils.pip_installer import PipInstaller
|
||||
from astrbot.core.utils.pip_installer import (
|
||||
DependencyConflictError as DependencyConflictError,
|
||||
)
|
||||
from astrbot.core.utils.pip_installer import (
|
||||
PipInstaller,
|
||||
)
|
||||
from astrbot.core.utils.requirements_utils import (
|
||||
RequirementsPrecheckFailed as RequirementsPrecheckFailed,
|
||||
)
|
||||
from astrbot.core.utils.requirements_utils import (
|
||||
find_missing_requirements as find_missing_requirements,
|
||||
)
|
||||
from astrbot.core.utils.requirements_utils import (
|
||||
find_missing_requirements_or_raise as find_missing_requirements_or_raise,
|
||||
)
|
||||
from astrbot.core.utils.shared_preferences import SharedPreferences
|
||||
from astrbot.core.utils.t2i.renderer import HtmlRenderer
|
||||
|
||||
|
||||
@@ -14,7 +14,12 @@ import yaml
|
||||
from packaging.specifiers import InvalidSpecifier, SpecifierSet
|
||||
from packaging.version import InvalidVersion, Version
|
||||
|
||||
from astrbot.core import logger, pip_installer, sp
|
||||
from astrbot.core import (
|
||||
DependencyConflictError,
|
||||
logger,
|
||||
pip_installer,
|
||||
sp,
|
||||
)
|
||||
from astrbot.core.agent.handoff import FunctionTool, HandoffTool
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.config.default import VERSION
|
||||
@@ -27,6 +32,10 @@ from astrbot.core.utils.astrbot_path import (
|
||||
)
|
||||
from astrbot.core.utils.io import remove_dir
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.utils.requirements_utils import (
|
||||
RequirementsPrecheckFailed,
|
||||
find_missing_requirements_or_raise,
|
||||
)
|
||||
|
||||
from . import StarMetadata
|
||||
from .command_management import sync_command_configs
|
||||
@@ -48,6 +57,49 @@ class PluginVersionIncompatibleError(Exception):
|
||||
"""Raised when plugin astrbot_version is incompatible with current AstrBot."""
|
||||
|
||||
|
||||
class PluginDependencyInstallError(Exception):
|
||||
"""Raised when plugin dependency installation fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
plugin_label: str,
|
||||
requirements_path: str,
|
||||
error: Exception,
|
||||
) -> None:
|
||||
message = f"插件 {plugin_label} 依赖安装失败: {error!s}"
|
||||
super().__init__(message)
|
||||
self.plugin_label = plugin_label
|
||||
self.requirements_path = requirements_path
|
||||
self.error = error
|
||||
|
||||
|
||||
async def _install_requirements_with_precheck(
|
||||
*,
|
||||
plugin_label: str,
|
||||
requirements_path: str,
|
||||
) -> None:
|
||||
try:
|
||||
missing = find_missing_requirements_or_raise(requirements_path)
|
||||
except RequirementsPrecheckFailed:
|
||||
logger.info(
|
||||
f"正在安装插件 {plugin_label} 的依赖库(预检查失败,回退到完整安装): "
|
||||
f"{requirements_path}"
|
||||
)
|
||||
await pip_installer.install(requirements_path=requirements_path)
|
||||
return
|
||||
|
||||
if not missing:
|
||||
logger.info(f"插件 {plugin_label} 的依赖已满足,跳过安装。")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"检测到插件 {plugin_label} 缺失依赖,正在按 requirements.txt 安装: "
|
||||
f"{requirements_path} -> {sorted(missing)}"
|
||||
)
|
||||
await pip_installer.install(requirements_path=requirements_path)
|
||||
|
||||
|
||||
class PluginManager:
|
||||
def __init__(self, context: Context, config: AstrBotConfig) -> None:
|
||||
from .star_tools import StarTools
|
||||
@@ -198,15 +250,37 @@ class PluginManager:
|
||||
to_update.append(p.root_dir_name)
|
||||
for p in to_update:
|
||||
plugin_path = os.path.join(plugin_dir, p)
|
||||
if os.path.exists(os.path.join(plugin_path, "requirements.txt")):
|
||||
pth = os.path.join(plugin_path, "requirements.txt")
|
||||
logger.info(f"正在安装插件 {p} 所需的依赖库: {pth}")
|
||||
try:
|
||||
await pip_installer.install(requirements_path=pth)
|
||||
except Exception as e:
|
||||
logger.error(f"更新插件 {p} 的依赖失败。Code: {e!s}")
|
||||
await self._ensure_plugin_requirements(plugin_path, p)
|
||||
return True
|
||||
|
||||
async def _ensure_plugin_requirements(
|
||||
self,
|
||||
plugin_dir_path: str,
|
||||
plugin_label: str,
|
||||
) -> None:
|
||||
requirements_path = os.path.join(plugin_dir_path, "requirements.txt")
|
||||
if not os.path.exists(requirements_path):
|
||||
return
|
||||
|
||||
try:
|
||||
await _install_requirements_with_precheck(
|
||||
plugin_label=plugin_label,
|
||||
requirements_path=requirements_path,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except DependencyConflictError as e:
|
||||
logger.error(f"插件 {plugin_label} 依赖冲突: {e!s}")
|
||||
raise
|
||||
except Exception as e:
|
||||
dependency_error = PluginDependencyInstallError(
|
||||
plugin_label=plugin_label,
|
||||
requirements_path=requirements_path,
|
||||
error=e,
|
||||
)
|
||||
logger.exception(str(dependency_error))
|
||||
raise dependency_error from e
|
||||
|
||||
async def _import_plugin_with_dependency_recovery(
|
||||
self,
|
||||
path: str,
|
||||
@@ -422,7 +496,7 @@ class PluginManager:
|
||||
root_dir_name: str,
|
||||
plugin_dir_path: str,
|
||||
reserved: bool,
|
||||
error: Exception | str,
|
||||
error: BaseException | str,
|
||||
error_trace: str,
|
||||
) -> dict:
|
||||
record: dict = {
|
||||
@@ -495,6 +569,9 @@ class PluginManager:
|
||||
|
||||
self._cleanup_plugin_state(dir_name)
|
||||
|
||||
plugin_path = os.path.join(self.plugin_store_path, dir_name)
|
||||
await self._ensure_plugin_requirements(plugin_path, dir_name)
|
||||
|
||||
success, error = await self.load(specified_dir_name=dir_name)
|
||||
if success:
|
||||
self.failed_plugin_dict.pop(dir_name, None)
|
||||
@@ -1078,6 +1155,10 @@ class PluginManager:
|
||||
|
||||
# reload the plugin
|
||||
dir_name = os.path.basename(plugin_path)
|
||||
await self._ensure_plugin_requirements(
|
||||
plugin_path,
|
||||
dir_name,
|
||||
)
|
||||
success, error_message = await self.load(
|
||||
specified_dir_name=dir_name,
|
||||
ignore_version_check=ignore_version_check,
|
||||
@@ -1317,6 +1398,12 @@ class PluginManager:
|
||||
raise Exception("该插件是 AstrBot 保留插件,无法更新。")
|
||||
|
||||
await self.updator.update(plugin, proxy=proxy)
|
||||
if plugin.root_dir_name:
|
||||
plugin_dir_path = os.path.join(self.plugin_store_path, plugin.root_dir_name)
|
||||
await self._ensure_plugin_requirements(
|
||||
plugin_dir_path,
|
||||
plugin_name,
|
||||
)
|
||||
await self.reload(plugin_name)
|
||||
|
||||
async def turn_off_plugin(self, plugin_name: str) -> None:
|
||||
@@ -1488,6 +1575,7 @@ class PluginManager:
|
||||
os.remove(zip_file_path)
|
||||
except BaseException as e:
|
||||
logger.warning(f"删除插件压缩包失败: {e!s}")
|
||||
await self._ensure_plugin_requirements(desti_dir, dir_name)
|
||||
# await self.reload()
|
||||
success, error_message = await self.load(
|
||||
specified_dir_name=dir_name,
|
||||
|
||||
@@ -0,0 +1,121 @@
|
||||
import contextlib
|
||||
import functools
|
||||
import importlib.metadata as importlib_metadata
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Iterator
|
||||
|
||||
from packaging.requirements import Requirement
|
||||
|
||||
from astrbot.core.utils.requirements_utils import (
|
||||
canonicalize_distribution_name,
|
||||
collect_installed_distribution_versions,
|
||||
get_requirement_check_paths,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
|
||||
def _resolve_core_dist_name(core_dist_name: str | None) -> str | None:
|
||||
if core_dist_name:
|
||||
try:
|
||||
importlib_metadata.distribution(core_dist_name)
|
||||
return core_dist_name
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
return None
|
||||
|
||||
try:
|
||||
importlib_metadata.distribution("AstrBot")
|
||||
return "AstrBot"
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
pass
|
||||
|
||||
if not __package__:
|
||||
return None
|
||||
|
||||
top_pkg = __package__.split(".")[0]
|
||||
for dist in importlib_metadata.distributions():
|
||||
try:
|
||||
top_level = dist.read_text("top_level.txt") or ""
|
||||
except Exception:
|
||||
continue
|
||||
if top_pkg in top_level.splitlines():
|
||||
if "Name" in dist.metadata:
|
||||
return dist.metadata["Name"]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@functools.cache
|
||||
def _get_core_constraints(core_dist_name: str | None) -> tuple[str, ...]:
|
||||
try:
|
||||
resolved_core_dist_name = _resolve_core_dist_name(core_dist_name)
|
||||
except Exception as exc:
|
||||
logger.warning("解析核心分发名称失败: %s", exc)
|
||||
return ()
|
||||
|
||||
if not resolved_core_dist_name:
|
||||
return ()
|
||||
|
||||
try:
|
||||
dist = importlib_metadata.distribution(resolved_core_dist_name)
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
return ()
|
||||
except Exception as exc:
|
||||
logger.warning("读取核心分发元数据失败 (%s): %s", resolved_core_dist_name, exc)
|
||||
return ()
|
||||
|
||||
if not dist or not dist.requires:
|
||||
return ()
|
||||
|
||||
installed = collect_installed_distribution_versions(get_requirement_check_paths())
|
||||
if not installed:
|
||||
return ()
|
||||
|
||||
constraints: list[str] = []
|
||||
for req_str in dist.requires:
|
||||
try:
|
||||
req = Requirement(req_str)
|
||||
if req.marker and not req.marker.evaluate():
|
||||
continue
|
||||
name = canonicalize_distribution_name(req.name)
|
||||
if name in installed:
|
||||
constraints.append(f"{name}=={installed[name]}")
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return tuple(constraints)
|
||||
|
||||
|
||||
class CoreConstraintsProvider:
|
||||
def __init__(self, core_dist_name: str | None) -> None:
|
||||
self._core_dist_name = core_dist_name
|
||||
|
||||
@contextlib.contextmanager
|
||||
def constraints_file(self) -> Iterator[str | None]:
|
||||
constraints = _get_core_constraints(self._core_dist_name)
|
||||
if not constraints:
|
||||
yield None
|
||||
return
|
||||
|
||||
path: str | None = None
|
||||
try:
|
||||
import tempfile
|
||||
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix="_constraints.txt", delete=False, encoding="utf-8"
|
||||
) as f:
|
||||
f.write("\n".join(constraints))
|
||||
path = f.name
|
||||
logger.info("已启用核心依赖版本保护 (%d 个约束)", len(constraints))
|
||||
except Exception as exc:
|
||||
logger.warning("创建临时约束文件失败: %s", exc)
|
||||
yield None
|
||||
return
|
||||
|
||||
try:
|
||||
yield path
|
||||
finally:
|
||||
if path and os.path.exists(path):
|
||||
with contextlib.suppress(Exception):
|
||||
os.remove(path)
|
||||
@@ -7,21 +7,71 @@ import io
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
import sys
|
||||
import threading
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_site_packages_path
|
||||
from astrbot.core.utils.core_constraints import CoreConstraintsProvider
|
||||
from astrbot.core.utils.requirements_utils import (
|
||||
canonicalize_distribution_name as _canonicalize_distribution_name,
|
||||
)
|
||||
from astrbot.core.utils.requirements_utils import (
|
||||
extract_requirement_name,
|
||||
extract_requirement_names,
|
||||
parse_package_install_input,
|
||||
)
|
||||
from astrbot.core.utils.runtime_env import is_packaged_desktop_runtime
|
||||
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
_DISTLIB_FINDER_PATCH_ATTEMPTED = False
|
||||
_SITE_PACKAGES_IMPORT_LOCK = threading.RLock()
|
||||
_PIP_FAILURE_PATTERNS = {
|
||||
"error_prefix": re.compile(r"^\s*error:", re.IGNORECASE),
|
||||
"user_requested": re.compile(r"\bthe user requested\b", re.IGNORECASE),
|
||||
"resolution_impossible": re.compile(r"\bresolutionimpossible\b", re.IGNORECASE),
|
||||
"cannot_install": re.compile(r"\bcannot install\b", re.IGNORECASE),
|
||||
"conflict": re.compile(r"\bconflict(?:ing|s)?\b", re.IGNORECASE),
|
||||
"constraint": re.compile(r"\(constraint\)", re.IGNORECASE),
|
||||
"dependency_detail": re.compile(r"\bdepends on\b", re.IGNORECASE),
|
||||
}
|
||||
_SENSITIVE_PIP_VALUE_KEYS = frozenset(
|
||||
{"password", "passwd", "pass", "api_token", "token", "auth_token"}
|
||||
)
|
||||
_MAX_PIP_OUTPUT_LINES = 200
|
||||
|
||||
|
||||
def _canonicalize_distribution_name(name: str) -> str:
|
||||
return re.sub(r"[-_.]+", "-", name).strip("-").lower()
|
||||
class DependencyConflictError(Exception):
|
||||
"""Raised when pip encounters a dependency conflict."""
|
||||
|
||||
def __init__(
|
||||
self, message: str, errors: list[str], *, is_core_conflict: bool
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
self.errors = errors
|
||||
self.is_core_conflict = is_core_conflict
|
||||
|
||||
|
||||
class PipInstallError(Exception):
|
||||
"""Raised when pip install fails without a classified dependency conflict."""
|
||||
|
||||
def __init__(self, message: str, *, code: int) -> None:
|
||||
super().__init__(message)
|
||||
self.code = code
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipConflictContext:
|
||||
relevant_lines: list[str]
|
||||
requested_lines: list[str]
|
||||
dependency_detail_lines: list[str]
|
||||
constraint_lines: list[str]
|
||||
has_strong_conflict_signal: bool
|
||||
has_contextual_conflict_signal: bool
|
||||
|
||||
|
||||
def _get_pip_main():
|
||||
@@ -41,11 +91,12 @@ def _get_pip_main():
|
||||
return pip_main
|
||||
|
||||
|
||||
def _run_pip_main_with_output(pip_main, args: list[str]) -> tuple[int, str]:
|
||||
stream = io.StringIO()
|
||||
with contextlib.redirect_stdout(stream), contextlib.redirect_stderr(stream):
|
||||
result_code = pip_main(args)
|
||||
return result_code, stream.getvalue()
|
||||
def _prepend_sys_path(path: str) -> None:
|
||||
normalized_target = os.path.realpath(path)
|
||||
sys.path[:] = [
|
||||
item for item in sys.path if os.path.realpath(item) != normalized_target
|
||||
]
|
||||
sys.path.insert(0, normalized_target)
|
||||
|
||||
|
||||
def _cleanup_added_root_handlers(original_handlers: list[logging.Handler]) -> None:
|
||||
@@ -59,76 +110,258 @@ def _cleanup_added_root_handlers(original_handlers: list[logging.Handler]) -> No
|
||||
handler.close()
|
||||
|
||||
|
||||
def _prepend_sys_path(path: str) -> None:
|
||||
normalized_target = os.path.realpath(path)
|
||||
sys.path[:] = [
|
||||
item for item in sys.path if os.path.realpath(item) != normalized_target
|
||||
]
|
||||
sys.path.insert(0, normalized_target)
|
||||
def _get_trusted_host_for_index_url(index_url: str) -> str | None:
|
||||
parsed = urlparse(index_url if "://" in index_url else f"//{index_url}")
|
||||
host = parsed.hostname
|
||||
if host == "mirrors.aliyun.com":
|
||||
return host
|
||||
return None
|
||||
|
||||
|
||||
def _module_exists_in_site_packages(module_name: str, site_packages_path: str) -> bool:
|
||||
base_path = os.path.join(site_packages_path, *module_name.split("."))
|
||||
package_init = os.path.join(base_path, "__init__.py")
|
||||
module_file = f"{base_path}.py"
|
||||
return os.path.isfile(package_init) or os.path.isfile(module_file)
|
||||
def _normalize_sensitive_pip_key(raw_key: str) -> str:
|
||||
return raw_key.lstrip("-").replace("-", "_").lower()
|
||||
|
||||
|
||||
def _is_module_loaded_from_site_packages(
|
||||
module_name: str,
|
||||
site_packages_path: str,
|
||||
) -> bool:
|
||||
module = sys.modules.get(module_name)
|
||||
if module is None:
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
except Exception:
|
||||
return False
|
||||
def _is_sensitive_pip_value_key(raw_key: str) -> bool:
|
||||
return _normalize_sensitive_pip_key(raw_key) in _SENSITIVE_PIP_VALUE_KEYS
|
||||
|
||||
module_file = getattr(module, "__file__", None)
|
||||
if not module_file:
|
||||
return False
|
||||
|
||||
module_path = os.path.realpath(module_file)
|
||||
site_packages_real = os.path.realpath(site_packages_path)
|
||||
try:
|
||||
return (
|
||||
os.path.commonpath([module_path, site_packages_real]) == site_packages_real
|
||||
def _redact_url_credentials(raw_value: str) -> str:
|
||||
"""Redact URL credentials and known inline secret values for safe logging."""
|
||||
parsed = urlparse(raw_value)
|
||||
if parsed.netloc and "@" in parsed.netloc:
|
||||
hostname = parsed.hostname or ""
|
||||
port = f":{parsed.port}" if parsed.port else ""
|
||||
return parsed._replace(netloc=f"<redacted>@{hostname}{port}").geturl()
|
||||
|
||||
if raw_value.startswith("--"):
|
||||
option, separator, _ = raw_value.partition("=")
|
||||
if separator and _is_sensitive_pip_value_key(option):
|
||||
return f"{option}=****"
|
||||
return raw_value
|
||||
|
||||
key, separator, _ = raw_value.partition("=")
|
||||
if separator and _is_sensitive_pip_value_key(key):
|
||||
return f"{key}=****"
|
||||
|
||||
return raw_value
|
||||
|
||||
|
||||
def _redact_pip_args_for_logging(args: list[str]) -> list[str]:
|
||||
redacted_args: list[str] = []
|
||||
redact_next_value = False
|
||||
|
||||
for arg in args:
|
||||
if redact_next_value:
|
||||
redacted_args.append("****")
|
||||
redact_next_value = False
|
||||
continue
|
||||
|
||||
if arg.startswith("--") and "=" in arg:
|
||||
option, value = arg.split("=", 1)
|
||||
if _is_sensitive_pip_value_key(option):
|
||||
redacted_args.append(f"{option}=****")
|
||||
else:
|
||||
redacted_args.append(f"{option}={_redact_url_credentials(value)}")
|
||||
continue
|
||||
|
||||
if arg.startswith("-i") and arg != "-i":
|
||||
redacted_args.append(f"-i{_redact_url_credentials(arg[2:])}")
|
||||
continue
|
||||
|
||||
if _is_sensitive_pip_value_key(arg):
|
||||
redacted_args.append(arg)
|
||||
redact_next_value = True
|
||||
continue
|
||||
|
||||
redacted_args.append(_redact_url_credentials(arg))
|
||||
|
||||
return redacted_args
|
||||
|
||||
|
||||
def _package_specs_override_index(package_specs: list[str]) -> bool:
|
||||
for index, spec in enumerate(package_specs):
|
||||
if spec == "--no-index":
|
||||
return True
|
||||
if spec in {"-i", "--index-url"}:
|
||||
if index + 1 < len(package_specs):
|
||||
return True
|
||||
continue
|
||||
if spec.startswith("--index-url="):
|
||||
return True
|
||||
if spec.startswith("-i") and spec != "-i":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class _StreamingLogWriter(io.TextIOBase):
|
||||
def __init__(self, log_func, *, max_lines: int | None = None) -> None:
|
||||
self._log_func = log_func
|
||||
self._lines = deque(maxlen=max_lines or _MAX_PIP_OUTPUT_LINES)
|
||||
self._buffer = ""
|
||||
|
||||
def write(self, text: str) -> int:
|
||||
if not text:
|
||||
return 0
|
||||
|
||||
self._buffer += text.replace("\r\n", "\n").replace("\r", "\n")
|
||||
while "\n" in self._buffer:
|
||||
raw_line, self._buffer = self._buffer.split("\n", 1)
|
||||
line = raw_line.rstrip("\r\n")
|
||||
self._log_func(line)
|
||||
self._lines.append(line)
|
||||
return len(text)
|
||||
|
||||
def flush(self) -> None:
|
||||
line = self._buffer.rstrip("\r\n")
|
||||
if line:
|
||||
self._log_func(line)
|
||||
self._lines.append(line)
|
||||
self._buffer = ""
|
||||
|
||||
@property
|
||||
def lines(self) -> list[str]:
|
||||
return list(self._lines)
|
||||
|
||||
|
||||
def _run_pip_main_streaming(pip_main, args: list[str]) -> tuple[int, list[str]]:
|
||||
stream = _StreamingLogWriter(logger.info, max_lines=_MAX_PIP_OUTPUT_LINES)
|
||||
with (
|
||||
contextlib.redirect_stdout(stream),
|
||||
contextlib.redirect_stderr(stream),
|
||||
):
|
||||
result_code = pip_main(args)
|
||||
stream.flush()
|
||||
return result_code, stream.lines
|
||||
|
||||
|
||||
def _matches_pip_failure_pattern(line: str, *pattern_names: str) -> bool:
|
||||
names = pattern_names or tuple(_PIP_FAILURE_PATTERNS)
|
||||
return any(_PIP_FAILURE_PATTERNS[name].search(line) for name in names)
|
||||
|
||||
|
||||
def _normalize_conflict_detail_line(line: str) -> str:
|
||||
stripped = line.strip()
|
||||
if _matches_pip_failure_pattern(stripped, "user_requested"):
|
||||
return re.sub(
|
||||
r"^\s*The user requested\s+",
|
||||
"",
|
||||
stripped,
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
except ValueError:
|
||||
return False
|
||||
return stripped
|
||||
|
||||
|
||||
def _extract_requirement_name(raw_requirement: str) -> str | None:
|
||||
line = raw_requirement.split("#", 1)[0].strip()
|
||||
if not line:
|
||||
return None
|
||||
if line.startswith(("-r", "--requirement", "-c", "--constraint")):
|
||||
return None
|
||||
if line.startswith("-"):
|
||||
def _build_pip_conflict_context(output_lines: list[str]) -> PipConflictContext | None:
|
||||
matched_indices = [
|
||||
index
|
||||
for index, line in enumerate(output_lines)
|
||||
if _matches_pip_failure_pattern(line)
|
||||
]
|
||||
if matched_indices:
|
||||
relevant_index_set: set[int] = set()
|
||||
for index in matched_indices:
|
||||
start = max(0, index - 1)
|
||||
end = min(len(output_lines), index + 2)
|
||||
relevant_index_set.update(range(start, end))
|
||||
relevant_output_lines = [
|
||||
line
|
||||
for index, line in enumerate(output_lines)
|
||||
if index in relevant_index_set
|
||||
]
|
||||
else:
|
||||
relevant_output_lines = output_lines[-5:]
|
||||
|
||||
if not relevant_output_lines:
|
||||
return None
|
||||
|
||||
egg_match = re.search(r"#egg=([A-Za-z0-9_.-]+)", raw_requirement)
|
||||
if egg_match:
|
||||
return _canonicalize_distribution_name(egg_match.group(1))
|
||||
dependency_detail_lines = [
|
||||
line.strip()
|
||||
for line in relevant_output_lines
|
||||
if _matches_pip_failure_pattern(line, "dependency_detail")
|
||||
]
|
||||
requested_lines = [
|
||||
line.strip()
|
||||
for line in relevant_output_lines
|
||||
if _matches_pip_failure_pattern(line, "user_requested")
|
||||
and not _matches_pip_failure_pattern(line, "constraint")
|
||||
]
|
||||
if not requested_lines:
|
||||
requested_lines = [
|
||||
line
|
||||
for line in dependency_detail_lines
|
||||
if not _matches_pip_failure_pattern(line, "constraint")
|
||||
]
|
||||
constraint_lines = [
|
||||
line.strip()
|
||||
for line in relevant_output_lines
|
||||
if _matches_pip_failure_pattern(line, "constraint")
|
||||
]
|
||||
|
||||
candidate = re.split(r"[<>=!~;\s\[]", line, maxsplit=1)[0].strip()
|
||||
if not candidate:
|
||||
has_strong_conflict_signal = any(
|
||||
_matches_pip_failure_pattern(
|
||||
line,
|
||||
"resolution_impossible",
|
||||
"cannot_install",
|
||||
)
|
||||
for line in relevant_output_lines
|
||||
)
|
||||
|
||||
has_contextual_conflict_signal = any(
|
||||
_matches_pip_failure_pattern(line, "conflict") for line in relevant_output_lines
|
||||
) and bool(dependency_detail_lines or requested_lines or constraint_lines)
|
||||
|
||||
return PipConflictContext(
|
||||
relevant_lines=relevant_output_lines,
|
||||
requested_lines=requested_lines,
|
||||
dependency_detail_lines=dependency_detail_lines,
|
||||
constraint_lines=constraint_lines,
|
||||
has_strong_conflict_signal=has_strong_conflict_signal,
|
||||
has_contextual_conflict_signal=has_contextual_conflict_signal,
|
||||
)
|
||||
|
||||
|
||||
def _classify_pip_failure(output_lines: list[str]) -> DependencyConflictError | None:
|
||||
context = _build_pip_conflict_context(output_lines)
|
||||
if context is None:
|
||||
return None
|
||||
return _canonicalize_distribution_name(candidate)
|
||||
|
||||
if (
|
||||
not context.has_strong_conflict_signal
|
||||
and not context.has_contextual_conflict_signal
|
||||
and not (context.requested_lines and context.constraint_lines)
|
||||
):
|
||||
return None
|
||||
|
||||
def _extract_requirement_names(requirements_path: str) -> set[str]:
|
||||
names: set[str] = set()
|
||||
try:
|
||||
with open(requirements_path, encoding="utf-8") as requirements_file:
|
||||
for line in requirements_file:
|
||||
requirement_name = _extract_requirement_name(line)
|
||||
if requirement_name:
|
||||
names.add(requirement_name)
|
||||
except Exception as exc:
|
||||
logger.warning("读取依赖文件失败,跳过冲突检测: %s", exc)
|
||||
return names
|
||||
is_core_conflict = bool(context.constraint_lines)
|
||||
|
||||
detail = ""
|
||||
if context.constraint_lines and context.requested_lines:
|
||||
detail = (
|
||||
" 冲突详情: "
|
||||
f"{_normalize_conflict_detail_line(context.requested_lines[0])} vs "
|
||||
f"{_normalize_conflict_detail_line(context.constraint_lines[0])}。"
|
||||
)
|
||||
elif len(context.dependency_detail_lines) >= 2:
|
||||
detail = (
|
||||
" 冲突详情: "
|
||||
f"{_normalize_conflict_detail_line(context.dependency_detail_lines[0])} vs "
|
||||
f"{_normalize_conflict_detail_line(context.dependency_detail_lines[1])}。"
|
||||
)
|
||||
|
||||
if is_core_conflict:
|
||||
message = (
|
||||
f"检测到核心依赖版本保护冲突。{detail}插件要求的依赖版本与 AstrBot 核心不兼容,"
|
||||
"为了系统稳定,已阻止该降级行为。请联系插件作者或调整 requirements.txt。"
|
||||
)
|
||||
else:
|
||||
message = f"检测到依赖冲突。{detail}"
|
||||
|
||||
return DependencyConflictError(
|
||||
message,
|
||||
context.relevant_lines,
|
||||
is_core_conflict=is_core_conflict,
|
||||
)
|
||||
|
||||
|
||||
def _extract_top_level_modules(
|
||||
@@ -155,7 +388,11 @@ def _collect_candidate_modules(
|
||||
by_name: dict[str, list[importlib_metadata.Distribution]] = {}
|
||||
try:
|
||||
for distribution in importlib_metadata.distributions(path=[site_packages_path]):
|
||||
distribution_name = distribution.metadata.get("Name")
|
||||
distribution_name = (
|
||||
distribution.metadata["Name"]
|
||||
if "Name" in distribution.metadata
|
||||
else None
|
||||
)
|
||||
if not distribution_name:
|
||||
continue
|
||||
canonical_name = _canonicalize_distribution_name(distribution_name)
|
||||
@@ -173,7 +410,7 @@ def _collect_candidate_modules(
|
||||
|
||||
for distribution in by_name.get(requirement_name, []):
|
||||
for dependency_line in distribution.requires or []:
|
||||
dependency_name = _extract_requirement_name(dependency_line)
|
||||
dependency_name = extract_requirement_name(dependency_line)
|
||||
if not dependency_name:
|
||||
continue
|
||||
if dependency_name in expanded_requirement_names:
|
||||
@@ -230,6 +467,38 @@ def _ensure_preferred_modules(
|
||||
raise RuntimeError(conflict_message)
|
||||
|
||||
|
||||
def _module_exists_in_site_packages(module_name: str, site_packages_path: str) -> bool:
|
||||
base_path = os.path.join(site_packages_path, *module_name.split("."))
|
||||
package_init = os.path.join(base_path, "__init__.py")
|
||||
module_file = f"{base_path}.py"
|
||||
return os.path.isfile(package_init) or os.path.isfile(module_file)
|
||||
|
||||
|
||||
def _is_module_loaded_from_site_packages(
|
||||
module_name: str,
|
||||
site_packages_path: str,
|
||||
) -> bool:
|
||||
module = sys.modules.get(module_name)
|
||||
if module is None:
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
module_file = getattr(module, "__file__", None)
|
||||
if not module_file:
|
||||
return False
|
||||
|
||||
module_path = os.path.realpath(module_file)
|
||||
site_packages_real = os.path.realpath(site_packages_path)
|
||||
try:
|
||||
return (
|
||||
os.path.commonpath([module_path, site_packages_real]) == site_packages_real
|
||||
)
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def _prefer_module_from_site_packages(
|
||||
module_name: str, site_packages_path: str
|
||||
) -> bool:
|
||||
@@ -531,9 +800,63 @@ def _patch_distlib_finder_for_frozen_runtime() -> None:
|
||||
|
||||
|
||||
class PipInstaller:
|
||||
def __init__(self, pip_install_arg: str, pypi_index_url: str | None = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
pip_install_arg: str,
|
||||
pypi_index_url: str | None = None,
|
||||
core_dist_name: str | None = "AstrBot",
|
||||
) -> None:
|
||||
self.pip_install_arg = pip_install_arg
|
||||
self.pypi_index_url = pypi_index_url
|
||||
self.core_dist_name = core_dist_name
|
||||
self._core_constraints = CoreConstraintsProvider(core_dist_name)
|
||||
|
||||
def _build_pip_args(
|
||||
self,
|
||||
package_name: str | None,
|
||||
requirements_path: str | None,
|
||||
mirror: str | None,
|
||||
) -> tuple[list[str], set[str]]:
|
||||
args: list[str] = []
|
||||
requested_requirements: set[str] = set()
|
||||
normalized_requirements_path = (
|
||||
requirements_path.strip() if requirements_path else ""
|
||||
)
|
||||
|
||||
if package_name and normalized_requirements_path:
|
||||
raise ValueError(
|
||||
"package_name and requirements_path cannot be used together"
|
||||
)
|
||||
|
||||
if package_name:
|
||||
parsed_package = parse_package_install_input(package_name)
|
||||
if parsed_package.specs:
|
||||
args = ["install", *parsed_package.specs]
|
||||
requested_requirements = set(parsed_package.requirement_names)
|
||||
elif normalized_requirements_path:
|
||||
args = ["install", "-r", normalized_requirements_path]
|
||||
requested_requirements = extract_requirement_names(
|
||||
normalized_requirements_path
|
||||
)
|
||||
|
||||
if not args:
|
||||
return [], requested_requirements
|
||||
|
||||
pip_install_args = (
|
||||
shlex.split(self.pip_install_arg) if self.pip_install_arg else []
|
||||
)
|
||||
|
||||
if not _package_specs_override_index([*args[1:], *pip_install_args]):
|
||||
index_url = mirror or self.pypi_index_url or "https://pypi.org/simple"
|
||||
trusted_host = _get_trusted_host_for_index_url(index_url)
|
||||
if trusted_host:
|
||||
args.extend(["--trusted-host", trusted_host])
|
||||
args.extend(["-i", index_url])
|
||||
|
||||
if pip_install_args:
|
||||
args.extend(pip_install_args)
|
||||
|
||||
return args, requested_requirements
|
||||
|
||||
async def install(
|
||||
self,
|
||||
@@ -541,36 +864,37 @@ class PipInstaller:
|
||||
requirements_path: str | None = None,
|
||||
mirror: str | None = None,
|
||||
) -> None:
|
||||
args = ["install"]
|
||||
requested_requirements: set[str] = set()
|
||||
if package_name:
|
||||
args.append(package_name)
|
||||
requirement_name = _extract_requirement_name(package_name)
|
||||
if requirement_name:
|
||||
requested_requirements.add(requirement_name)
|
||||
elif requirements_path:
|
||||
args.extend(["-r", requirements_path])
|
||||
requested_requirements = _extract_requirement_names(requirements_path)
|
||||
|
||||
index_url = mirror or self.pypi_index_url or "https://pypi.org/simple"
|
||||
args.extend(["--trusted-host", "mirrors.aliyun.com", "-i", index_url])
|
||||
args, requested_requirements = self._build_pip_args(
|
||||
package_name, requirements_path, mirror
|
||||
)
|
||||
if not args:
|
||||
logger.info("Pip 包管理器跳过安装:未提供有效的包名或 requirements 文件。")
|
||||
return
|
||||
|
||||
target_site_packages = None
|
||||
if is_packaged_desktop_runtime():
|
||||
target_site_packages = get_astrbot_site_packages_path()
|
||||
os.makedirs(target_site_packages, exist_ok=True)
|
||||
_prepend_sys_path(target_site_packages)
|
||||
args.extend(["--target", target_site_packages])
|
||||
args.extend(["--upgrade", "--force-reinstall"])
|
||||
args.extend(
|
||||
[
|
||||
"--target",
|
||||
target_site_packages,
|
||||
"--upgrade",
|
||||
"--upgrade-strategy",
|
||||
"only-if-needed",
|
||||
]
|
||||
)
|
||||
|
||||
if self.pip_install_arg:
|
||||
args.extend(self.pip_install_arg.split())
|
||||
with self._core_constraints.constraints_file() as constraints_file_path:
|
||||
if constraints_file_path:
|
||||
args.extend(["-c", constraints_file_path])
|
||||
|
||||
logger.info(f"Pip 包管理器: pip {' '.join(args)}")
|
||||
result_code = await self._run_pip_in_process(args)
|
||||
|
||||
if result_code != 0:
|
||||
raise Exception(f"安装失败,错误码:{result_code}")
|
||||
logger.info(
|
||||
"Pip 包管理器 argv: %s",
|
||||
["pip", *_redact_pip_args_for_logging(args)],
|
||||
)
|
||||
await self._run_pip_with_classification(args)
|
||||
|
||||
if target_site_packages:
|
||||
_prepend_sys_path(target_site_packages)
|
||||
@@ -589,7 +913,7 @@ class PipInstaller:
|
||||
if not os.path.isdir(target_site_packages):
|
||||
return
|
||||
|
||||
requested_requirements = _extract_requirement_names(requirements_path)
|
||||
requested_requirements = extract_requirement_names(requirements_path)
|
||||
if not requested_requirements:
|
||||
return
|
||||
|
||||
@@ -605,13 +929,21 @@ class PipInstaller:
|
||||
_patch_distlib_finder_for_frozen_runtime()
|
||||
|
||||
original_handlers = list(logging.getLogger().handlers)
|
||||
result_code, output = await asyncio.to_thread(
|
||||
_run_pip_main_with_output, pip_main, args
|
||||
)
|
||||
for line in output.splitlines():
|
||||
line = line.strip()
|
||||
if line:
|
||||
logger.info(line)
|
||||
try:
|
||||
result_code, output_lines = await asyncio.to_thread(
|
||||
_run_pip_main_streaming, pip_main, args
|
||||
)
|
||||
finally:
|
||||
_cleanup_added_root_handlers(original_handlers)
|
||||
|
||||
if result_code != 0:
|
||||
conflict = _classify_pip_failure(output_lines)
|
||||
if conflict:
|
||||
raise conflict
|
||||
|
||||
_cleanup_added_root_handlers(original_handlers)
|
||||
return result_code
|
||||
|
||||
async def _run_pip_with_classification(self, args: list[str]) -> None:
|
||||
result_code = await self._run_pip_in_process(args)
|
||||
if result_code != 0:
|
||||
raise PipInstallError(f"安装失败,错误码:{result_code}", code=result_code)
|
||||
|
||||
@@ -0,0 +1,408 @@
|
||||
import importlib.metadata as importlib_metadata
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
import sys
|
||||
from collections.abc import Iterable, Iterator
|
||||
from dataclasses import dataclass
|
||||
|
||||
from packaging.requirements import InvalidRequirement, Requirement
|
||||
from packaging.specifiers import SpecifierSet
|
||||
from packaging.version import InvalidVersion, Version
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_site_packages_path
|
||||
from astrbot.core.utils.runtime_env import is_packaged_desktop_runtime
|
||||
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
|
||||
class RequirementsPrecheckFailed(Exception):
|
||||
"""Raised when the pre-check of requirements fails."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ParsedPackageInput:
|
||||
specs: tuple[str, ...]
|
||||
requirement_names: frozenset[str]
|
||||
|
||||
|
||||
def canonicalize_distribution_name(name: str) -> str:
|
||||
return re.sub(r"[-_.]+", "-", name).strip("-").lower()
|
||||
|
||||
|
||||
def strip_inline_requirement_comment(raw_input: str) -> str:
|
||||
if raw_input.lstrip().startswith("#"):
|
||||
return ""
|
||||
return re.split(r"[ \t]+#", raw_input, maxsplit=1)[0].strip()
|
||||
|
||||
|
||||
def _specifier_contains_version(specifier: SpecifierSet, version: str) -> bool:
|
||||
try:
|
||||
parsed_version = Version(version)
|
||||
except InvalidVersion:
|
||||
return False
|
||||
return specifier.contains(parsed_version, prereleases=True)
|
||||
|
||||
|
||||
def _looks_like_local_path_reference(token: str) -> bool:
|
||||
candidate = token.strip()
|
||||
if not candidate:
|
||||
return False
|
||||
return candidate in {".", ".."} or candidate.startswith(
|
||||
("./", "../", "/", "~/", ".\\", "..\\", "\\")
|
||||
)
|
||||
|
||||
|
||||
def looks_like_direct_reference(token: str) -> bool:
|
||||
candidate = token.strip()
|
||||
if not candidate:
|
||||
return False
|
||||
return (
|
||||
_looks_like_local_path_reference(candidate)
|
||||
or candidate.startswith("git+")
|
||||
or "://" in candidate
|
||||
)
|
||||
|
||||
|
||||
def extract_requirement_name(raw_requirement: str) -> str | None:
|
||||
line = raw_requirement.split("#", 1)[0].strip()
|
||||
if not line:
|
||||
return None
|
||||
if line.startswith(("-r", "--requirement", "-c", "--constraint")):
|
||||
return None
|
||||
|
||||
egg_match = re.search(r"#egg=([A-Za-z0-9_.-]+)", raw_requirement)
|
||||
if egg_match:
|
||||
return canonicalize_distribution_name(egg_match.group(1))
|
||||
|
||||
if line.startswith("-"):
|
||||
return None
|
||||
|
||||
candidate = re.split(r"[<>=!~;\s\[]", line, maxsplit=1)[0].strip()
|
||||
if not candidate:
|
||||
return None
|
||||
return canonicalize_distribution_name(candidate)
|
||||
|
||||
|
||||
def _parse_editable_or_direct_name(target: str) -> str | None:
|
||||
name = extract_requirement_name(target)
|
||||
if not name:
|
||||
return None
|
||||
if "#egg=" in target or not looks_like_direct_reference(target):
|
||||
return name
|
||||
return None
|
||||
|
||||
|
||||
def _parse_requirement_name_and_spec(
|
||||
line: str,
|
||||
) -> tuple[str | None, SpecifierSet | None]:
|
||||
if line.startswith(("-c", "--constraint")):
|
||||
return None, None
|
||||
|
||||
try:
|
||||
req = Requirement(line)
|
||||
except InvalidRequirement:
|
||||
tokens = shlex.split(line)
|
||||
if not tokens:
|
||||
return None, None
|
||||
|
||||
editable_target: str | None = None
|
||||
if tokens[0] in {"-e", "--editable"} and len(tokens) > 1:
|
||||
editable_target = tokens[1]
|
||||
elif tokens[0].startswith("--editable="):
|
||||
editable_target = tokens[0].split("=", 1)[1]
|
||||
|
||||
if editable_target:
|
||||
name = _parse_editable_or_direct_name(editable_target)
|
||||
return (name, None) if name else (None, None)
|
||||
|
||||
name = _parse_editable_or_direct_name(line)
|
||||
return (name, None) if name else (None, None)
|
||||
|
||||
if req.marker and not req.marker.evaluate():
|
||||
return None, None
|
||||
|
||||
return canonicalize_distribution_name(req.name), (req.specifier or None)
|
||||
|
||||
|
||||
def _parse_requirement_line(
|
||||
line: str,
|
||||
) -> tuple[str, SpecifierSet | None] | None:
|
||||
name, specifier = _parse_requirement_name_and_spec(line)
|
||||
return (name, specifier) if name else None
|
||||
|
||||
|
||||
def _extract_requirement_names_from_package_tokens(tokens: list[str]) -> frozenset[str]:
|
||||
requirement_names: set[str] = set()
|
||||
skip_next_for: str | None = None
|
||||
|
||||
for token in tokens:
|
||||
if skip_next_for:
|
||||
if skip_next_for == "editable":
|
||||
name = _parse_editable_or_direct_name(token)
|
||||
if name:
|
||||
requirement_names.add(name)
|
||||
skip_next_for = None
|
||||
continue
|
||||
|
||||
if token in {"-e", "--editable"}:
|
||||
skip_next_for = "editable"
|
||||
continue
|
||||
|
||||
if token in {
|
||||
"-i",
|
||||
"--index-url",
|
||||
"--extra-index-url",
|
||||
"-f",
|
||||
"--find-links",
|
||||
"--trusted-host",
|
||||
"-r",
|
||||
"--requirement",
|
||||
"-c",
|
||||
"--constraint",
|
||||
}:
|
||||
skip_next_for = "option-value"
|
||||
continue
|
||||
|
||||
if token.startswith(("--editable=",)):
|
||||
editable_target = token.split("=", 1)[1]
|
||||
name = _parse_editable_or_direct_name(editable_target)
|
||||
if name:
|
||||
requirement_names.add(name)
|
||||
continue
|
||||
|
||||
if token.startswith(
|
||||
(
|
||||
"--index-url=",
|
||||
"--extra-index-url=",
|
||||
"--find-links=",
|
||||
"--trusted-host=",
|
||||
"--requirement=",
|
||||
"--constraint=",
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
(token.startswith("-i") and token != "-i")
|
||||
or (token.startswith("-f") and token != "-f")
|
||||
or token == "--no-index"
|
||||
):
|
||||
continue
|
||||
|
||||
if token.startswith("-"):
|
||||
continue
|
||||
|
||||
name, _ = _parse_requirement_name_and_spec(token)
|
||||
if name:
|
||||
requirement_names.add(name)
|
||||
|
||||
return frozenset(requirement_names)
|
||||
|
||||
|
||||
def parse_package_install_input(raw_input: str) -> ParsedPackageInput:
|
||||
specs: list[str] = []
|
||||
requirement_names: set[str] = set()
|
||||
normalized = raw_input.strip()
|
||||
if not normalized:
|
||||
return ParsedPackageInput(specs=(), requirement_names=frozenset())
|
||||
|
||||
for raw_line in normalized.splitlines():
|
||||
line = strip_inline_requirement_comment(raw_line)
|
||||
if not line:
|
||||
continue
|
||||
|
||||
try:
|
||||
Requirement(line)
|
||||
except InvalidRequirement:
|
||||
tokens = shlex.split(line)
|
||||
if not tokens:
|
||||
continue
|
||||
specs.extend(tokens)
|
||||
requirement_names.update(
|
||||
_extract_requirement_names_from_package_tokens(tokens)
|
||||
)
|
||||
continue
|
||||
|
||||
specs.append(line)
|
||||
name, _ = _parse_requirement_name_and_spec(line)
|
||||
if name:
|
||||
requirement_names.add(name)
|
||||
|
||||
return ParsedPackageInput(
|
||||
specs=tuple(specs),
|
||||
requirement_names=frozenset(requirement_names),
|
||||
)
|
||||
|
||||
|
||||
def _iter_requirement_lines(
|
||||
requirements_path: str,
|
||||
_visited: set[str] | None = None,
|
||||
) -> Iterator[str]:
|
||||
visited = _visited or set()
|
||||
resolved_path = os.path.realpath(requirements_path)
|
||||
if resolved_path in visited:
|
||||
logger.warning(
|
||||
"检测到循环依赖的 requirements 包含: %s,将跳过该文件", resolved_path
|
||||
)
|
||||
return
|
||||
visited.add(resolved_path)
|
||||
|
||||
with open(resolved_path, encoding="utf-8") as f:
|
||||
for raw_line in f:
|
||||
line = strip_inline_requirement_comment(raw_line)
|
||||
if not line:
|
||||
continue
|
||||
|
||||
tokens = shlex.split(line)
|
||||
if not tokens:
|
||||
continue
|
||||
|
||||
nested: str | None = None
|
||||
if tokens[0] in {"-r", "--requirement"} and len(tokens) > 1:
|
||||
nested = tokens[1]
|
||||
elif tokens[0].startswith("--requirement="):
|
||||
nested = tokens[0].split("=", 1)[1]
|
||||
|
||||
if nested:
|
||||
if not os.path.isabs(nested):
|
||||
nested = os.path.join(os.path.dirname(resolved_path), nested)
|
||||
yield from _iter_requirement_lines(nested, _visited=visited)
|
||||
continue
|
||||
|
||||
yield line
|
||||
|
||||
|
||||
def iter_requirements(
|
||||
requirements_path: str | None = None,
|
||||
lines: Iterable[str] | None = None,
|
||||
) -> Iterator[tuple[str, SpecifierSet | None]]:
|
||||
if lines is None:
|
||||
if requirements_path is None:
|
||||
raise ValueError("Either requirements_path or lines must be provided")
|
||||
lines = _iter_requirement_lines(requirements_path)
|
||||
|
||||
for line in lines:
|
||||
parsed = _parse_requirement_line(line)
|
||||
if parsed is not None:
|
||||
yield parsed
|
||||
|
||||
|
||||
def extract_requirement_names(requirements_path: str) -> set[str]:
|
||||
try:
|
||||
return {
|
||||
name for name, _ in iter_requirements(requirements_path=requirements_path)
|
||||
}
|
||||
except Exception as exc:
|
||||
logger.warning("读取依赖文件失败,跳过冲突检测: %s", exc)
|
||||
return set()
|
||||
|
||||
|
||||
def get_requirement_check_paths() -> list[str]:
|
||||
paths = list(sys.path)
|
||||
if is_packaged_desktop_runtime():
|
||||
target_site_packages = get_astrbot_site_packages_path()
|
||||
if os.path.isdir(target_site_packages):
|
||||
paths.insert(0, target_site_packages)
|
||||
return paths
|
||||
|
||||
|
||||
def _canonical_distribution_identity(distribution) -> tuple[str | None, str | None]:
|
||||
distribution_name = (
|
||||
distribution.metadata["Name"] if "Name" in distribution.metadata else None
|
||||
)
|
||||
if not distribution_name:
|
||||
return None, None
|
||||
return canonicalize_distribution_name(distribution_name), distribution.version
|
||||
|
||||
|
||||
def collect_installed_distribution_versions(paths: list[str]) -> dict[str, str] | None:
|
||||
installed: dict[str, str] = {}
|
||||
try:
|
||||
for distribution in importlib_metadata.distributions(path=paths):
|
||||
distribution_name, version = _canonical_distribution_identity(distribution)
|
||||
if not distribution_name or not version:
|
||||
continue
|
||||
installed.setdefault(distribution_name, version)
|
||||
except Exception as exc:
|
||||
logger.warning("读取已安装依赖失败,跳过缺失依赖预检查: %s", exc)
|
||||
return None
|
||||
return installed
|
||||
|
||||
|
||||
def _load_requirement_lines_for_precheck(
|
||||
requirements_path: str,
|
||||
) -> tuple[bool, list[str] | None]:
|
||||
try:
|
||||
requirement_lines = list(_iter_requirement_lines(requirements_path))
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"预检查缺失依赖失败,将回退到完整安装: %s (%s)",
|
||||
requirements_path,
|
||||
exc,
|
||||
)
|
||||
return False, None
|
||||
|
||||
fallback_line = next(
|
||||
(
|
||||
line
|
||||
for line in requirement_lines
|
||||
if (
|
||||
(
|
||||
line.startswith(("-e ", "--editable ", "--editable="))
|
||||
and "#egg=" not in line
|
||||
)
|
||||
or (
|
||||
_parse_requirement_line(line) is None
|
||||
and looks_like_direct_reference(line)
|
||||
)
|
||||
)
|
||||
),
|
||||
None,
|
||||
)
|
||||
if fallback_line is not None:
|
||||
logger.warning(
|
||||
"预检查缺失依赖失败,将回退到完整安装: unresolved direct reference in %s: %s",
|
||||
requirements_path,
|
||||
fallback_line,
|
||||
)
|
||||
return False, None
|
||||
|
||||
return True, requirement_lines
|
||||
|
||||
|
||||
def find_missing_requirements(requirements_path: str) -> set[str] | None:
|
||||
can_precheck, requirement_lines = _load_requirement_lines_for_precheck(
|
||||
requirements_path
|
||||
)
|
||||
if not can_precheck or requirement_lines is None:
|
||||
return None
|
||||
|
||||
required = list(iter_requirements(lines=requirement_lines))
|
||||
if not required:
|
||||
return set()
|
||||
|
||||
installed = collect_installed_distribution_versions(get_requirement_check_paths())
|
||||
if installed is None:
|
||||
return None
|
||||
|
||||
missing: set[str] = set()
|
||||
for name, specifier in required:
|
||||
installed_version = installed.get(name)
|
||||
if not installed_version:
|
||||
missing.add(name)
|
||||
continue
|
||||
if specifier and not _specifier_contains_version(specifier, installed_version):
|
||||
missing.add(name)
|
||||
|
||||
return missing
|
||||
|
||||
|
||||
def find_missing_requirements_or_raise(requirements_path: str) -> set[str]:
|
||||
missing = find_missing_requirements(requirements_path)
|
||||
if missing is None:
|
||||
raise RequirementsPrecheckFailed(f"预检查失败: {requirements_path}")
|
||||
return missing
|
||||
@@ -43,6 +43,7 @@ class SessionManagementRoute(Route):
|
||||
"/session/group/create": ("POST", self.create_group),
|
||||
"/session/group/update": ("POST", self.update_group),
|
||||
"/session/group/delete": ("POST", self.delete_group),
|
||||
"/session/group/update-config": ("POST", self.update_group_config),
|
||||
}
|
||||
self.conv_mgr = core_lifecycle.conversation_manager
|
||||
self.core_lifecycle = core_lifecycle
|
||||
@@ -145,9 +146,20 @@ class SessionManagementRoute(Route):
|
||||
page=page, page_size=page_size, search=search
|
||||
)
|
||||
|
||||
# 构建规则列表
|
||||
# 收集属于有配置分组的 UMO,避免重复显示
|
||||
grouped_umos = set()
|
||||
groups = self._get_groups()
|
||||
for group_data in groups.values():
|
||||
if group_data.get("config"):
|
||||
grouped_umos.update(group_data.get("umos", []))
|
||||
|
||||
# 构建规则列表(排除已被分组管理的 UMO)
|
||||
rules_list = []
|
||||
filtered_count = 0
|
||||
for umo, rules in umo_rules.items():
|
||||
if umo in grouped_umos:
|
||||
filtered_count += 1
|
||||
continue
|
||||
rule_info = {
|
||||
"umo": umo,
|
||||
"rules": rules,
|
||||
@@ -159,6 +171,7 @@ class SessionManagementRoute(Route):
|
||||
rule_info["message_type"] = parts[1]
|
||||
rule_info["session_id"] = parts[2]
|
||||
rules_list.append(rule_info)
|
||||
total -= filtered_count
|
||||
|
||||
# 获取可用的 providers 和 personas
|
||||
provider_manager = self.core_lifecycle.provider_manager
|
||||
@@ -240,6 +253,7 @@ class SessionManagementRoute(Route):
|
||||
"available_plugins": available_plugins,
|
||||
"available_kbs": available_kbs,
|
||||
"available_rule_keys": AVAILABLE_SESSION_RULE_KEYS,
|
||||
"group_rules": self._get_group_rules(),
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
@@ -793,6 +807,51 @@ class SessionManagementRoute(Route):
|
||||
"""保存分组"""
|
||||
sp.put("session_groups", groups)
|
||||
|
||||
def _get_group_rules(self) -> list:
|
||||
"""获取有配置的分组列表,用于在规则列表中显示"""
|
||||
groups = self._get_groups()
|
||||
group_rules = []
|
||||
for group_id, group_data in groups.items():
|
||||
config = group_data.get("config", {})
|
||||
if config: # 只返回有配置的分组
|
||||
group_rules.append(
|
||||
{
|
||||
"group_id": group_id,
|
||||
"name": group_data.get("name", ""),
|
||||
"umo_count": len(group_data.get("umos", [])),
|
||||
"config": config,
|
||||
}
|
||||
)
|
||||
return group_rules
|
||||
|
||||
async def _sync_group_config_to_umos(
|
||||
self, config: dict, umos: list[str]
|
||||
) -> tuple[int, list[str]]:
|
||||
"""将分组配置同步到指定的 UMO 列表
|
||||
|
||||
Returns:
|
||||
(success_count, failed_umos)
|
||||
"""
|
||||
success_count = 0
|
||||
failed_umos = []
|
||||
for umo in umos:
|
||||
try:
|
||||
for rule_key, rule_value in config.items():
|
||||
if rule_key not in AVAILABLE_SESSION_RULE_KEYS:
|
||||
continue
|
||||
if rule_value is None:
|
||||
continue
|
||||
if rule_key == "session_plugin_config":
|
||||
# session_plugin_config 需要包裹 umo key
|
||||
await sp.session_put(umo, rule_key, {umo: rule_value})
|
||||
else:
|
||||
await sp.session_put(umo, rule_key, rule_value)
|
||||
success_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"同步配置到 {umo} 失败: {e!s}")
|
||||
failed_umos.append(umo)
|
||||
return success_count, failed_umos
|
||||
|
||||
async def list_groups(self):
|
||||
"""获取所有分组列表"""
|
||||
try:
|
||||
@@ -806,6 +865,7 @@ class SessionManagementRoute(Route):
|
||||
"name": group_data.get("name", ""),
|
||||
"umos": group_data.get("umos", []),
|
||||
"umo_count": len(group_data.get("umos", [])),
|
||||
"config": group_data.get("config", {}),
|
||||
}
|
||||
)
|
||||
return Response().ok({"groups": groups_list}).__dict__
|
||||
@@ -875,6 +935,7 @@ class SessionManagementRoute(Route):
|
||||
return Response().error(f"分组 '{group_id}' 不存在").__dict__
|
||||
|
||||
group = groups[group_id]
|
||||
old_umos = set(group.get("umos", []))
|
||||
|
||||
# 更新名称
|
||||
if name is not None:
|
||||
@@ -883,6 +944,7 @@ class SessionManagementRoute(Route):
|
||||
# 直接设置 umos 列表
|
||||
if umos is not None:
|
||||
group["umos"] = umos
|
||||
new_umos = set(umos)
|
||||
else:
|
||||
# 增量更新
|
||||
current_umos = set(group.get("umos", []))
|
||||
@@ -891,9 +953,21 @@ class SessionManagementRoute(Route):
|
||||
if remove_umos:
|
||||
current_umos.difference_update(remove_umos)
|
||||
group["umos"] = list(current_umos)
|
||||
new_umos = current_umos
|
||||
|
||||
self._save_groups(groups)
|
||||
|
||||
# 自动同步分组配置给新加入的成员
|
||||
group_config = group.get("config", {})
|
||||
newly_added = new_umos - old_umos
|
||||
if group_config and newly_added:
|
||||
sync_count, _ = await self._sync_group_config_to_umos(
|
||||
group_config, list(newly_added)
|
||||
)
|
||||
logger.info(
|
||||
f"自动同步分组 '{group['name']}' 配置到 {sync_count} 个新成员"
|
||||
)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
@@ -936,3 +1010,81 @@ class SessionManagementRoute(Route):
|
||||
except Exception as e:
|
||||
logger.error(f"删除分组失败: {e!s}")
|
||||
return Response().error(f"删除分组失败: {e!s}").__dict__
|
||||
|
||||
async def update_group_config(self):
|
||||
"""更新分组的配置,并同步到所有成员 UMO
|
||||
|
||||
请求体:
|
||||
{
|
||||
"group_id": "分组ID",
|
||||
"config": {
|
||||
"session_service_config": {...},
|
||||
"session_plugin_config": {...},
|
||||
"kb_config": {...},
|
||||
"provider_perf_chat_completion": ...,
|
||||
"provider_perf_speech_to_text": ...,
|
||||
"provider_perf_text_to_speech": ...
|
||||
}
|
||||
}
|
||||
"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
group_id = data.get("group_id")
|
||||
config = data.get("config", {})
|
||||
|
||||
if not group_id:
|
||||
return Response().error("缺少必要参数: group_id").__dict__
|
||||
|
||||
groups = self._get_groups()
|
||||
|
||||
if group_id not in groups:
|
||||
return Response().error(f"分组 '{group_id}' 不存在").__dict__
|
||||
|
||||
group = groups[group_id]
|
||||
|
||||
# 保存配置到分组
|
||||
group["config"] = config
|
||||
self._save_groups(groups)
|
||||
|
||||
# 同步到所有成员 UMO
|
||||
umos = group.get("umos", [])
|
||||
|
||||
if not config:
|
||||
# 空配置 → 清除成员上的所有分组下发规则
|
||||
success_count = 0
|
||||
failed_umos = []
|
||||
for umo in umos:
|
||||
try:
|
||||
for rule_key in AVAILABLE_SESSION_RULE_KEYS:
|
||||
try:
|
||||
await sp.session_remove(umo, rule_key)
|
||||
except Exception:
|
||||
pass
|
||||
success_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"清除 {umo} 规则失败: {e!s}")
|
||||
failed_umos.append(umo)
|
||||
else:
|
||||
success_count, failed_umos = await self._sync_group_config_to_umos(
|
||||
config, umos
|
||||
)
|
||||
|
||||
msg = f"分组 '{group['name']}' 配置已保存并同步到 {success_count}/{len(umos)} 个会话"
|
||||
if failed_umos:
|
||||
msg += f",{len(failed_umos)} 个失败"
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": msg,
|
||||
"success_count": success_count,
|
||||
"failed_count": len(failed_umos),
|
||||
"failed_umos": failed_umos,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"更新分组配置失败: {e!s}")
|
||||
return Response().error(f"更新分组配置失败: {e!s}").__dict__
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
<template>
|
||||
<template>
|
||||
<div class="session-management-page">
|
||||
<v-container fluid class="pa-0">
|
||||
<v-card flat>
|
||||
@@ -35,7 +35,16 @@
|
||||
<!-- UMO 信息 -->
|
||||
<template v-slot:item.umo_info="{ item }">
|
||||
<div>
|
||||
<div class="d-flex align-center">
|
||||
<div class="d-flex align-center" v-if="item.isGroup">
|
||||
<v-chip size="x-small" color="deep-purple" variant="flat" class="mr-2">
|
||||
分组
|
||||
</v-chip>
|
||||
<span class="font-weight-medium">{{ item.groupName }}</span>
|
||||
<v-chip size="x-small" variant="outlined" class="ml-2">
|
||||
{{ item.umo_count }} 个会话
|
||||
</v-chip>
|
||||
</div>
|
||||
<div class="d-flex align-center" v-else>
|
||||
<v-chip size="x-small" :color="getPlatformColor(item.platform)" class="mr-2">
|
||||
{{ item.platform || 'unknown' }}
|
||||
</v-chip>
|
||||
@@ -282,14 +291,24 @@
|
||||
{{ tm('addRule.description') }}
|
||||
</v-alert>
|
||||
|
||||
<v-autocomplete v-model="selectedNewUmo" :items="availableUmos" :loading="loadingUmos"
|
||||
<v-radio-group v-model="addRuleTargetType" inline hide-details class="mb-4">
|
||||
<v-radio label="单个会话" value="session"></v-radio>
|
||||
<v-radio label="分组" value="group" :disabled="groups.length === 0"></v-radio>
|
||||
</v-radio-group>
|
||||
|
||||
<v-autocomplete v-if="addRuleTargetType === 'session'" v-model="selectedNewUmo" :items="availableUmos" :loading="loadingUmos"
|
||||
:label="tm('addRule.selectUmo')" variant="outlined" clearable :no-data-text="tm('addRule.noUmos')" />
|
||||
|
||||
<v-select v-if="addRuleTargetType === 'group'" v-model="selectedGroup" :items="groupSelectOptions"
|
||||
item-title="label" item-value="value" return-object
|
||||
label="选择分组" variant="outlined" clearable
|
||||
:no-data-text="'暂无分组,请先创建分组'" />
|
||||
</v-card-text>
|
||||
|
||||
<v-card-actions class="px-4 pb-4">
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn variant="text" @click="addRuleDialog = false">{{ tm('buttons.cancel') }}</v-btn>
|
||||
<v-btn color="primary" variant="tonal" @click="createNewRule" :disabled="!selectedNewUmo">
|
||||
<v-btn color="primary" variant="tonal" @click="createNewRule" :disabled="addRuleTargetType === 'session' ? !selectedNewUmo : !selectedGroup">
|
||||
{{ tm('buttons.next') }}
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
@@ -334,12 +353,7 @@
|
||||
</v-col>
|
||||
</v-row>
|
||||
|
||||
<div class="d-flex justify-end mt-4">
|
||||
<v-btn color="primary" variant="tonal" size="small" @click="saveServiceConfig" :loading="saving"
|
||||
prepend-icon="mdi-content-save">
|
||||
{{ tm('buttons.save') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
|
||||
<!-- Provider Config Section -->
|
||||
<div class="d-flex align-center mb-4 mt-4">
|
||||
@@ -364,12 +378,7 @@
|
||||
</v-col>
|
||||
</v-row>
|
||||
|
||||
<div class="d-flex justify-end mt-4">
|
||||
<v-btn color="primary" variant="tonal" size="small" @click="saveProviderConfig" :loading="saving"
|
||||
prepend-icon="mdi-content-save">
|
||||
{{ tm('buttons.save') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
|
||||
<!-- Persona Config Section -->
|
||||
<div class="d-flex align-center mb-4 mt-4">
|
||||
@@ -389,12 +398,7 @@
|
||||
</v-col>
|
||||
</v-row>
|
||||
|
||||
<div class="d-flex justify-end mt-4">
|
||||
<v-btn color="primary" variant="tonal" size="small" @click="saveServiceConfig" :loading="saving"
|
||||
prepend-icon="mdi-content-save">
|
||||
{{ tm('buttons.save') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
|
||||
<!-- Plugin Config Section -->
|
||||
<div class="d-flex align-center mb-4 mt-4">
|
||||
@@ -414,12 +418,7 @@
|
||||
</v-col>
|
||||
</v-row>
|
||||
|
||||
<div class="d-flex justify-end mt-4">
|
||||
<v-btn color="primary" variant="tonal" size="small" @click="savePluginConfig" :loading="saving"
|
||||
prepend-icon="mdi-content-save">
|
||||
{{ tm('buttons.save') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
|
||||
<!-- KB Config Section -->
|
||||
<div class="d-flex align-center mb-4 mt-4">
|
||||
@@ -442,14 +441,17 @@
|
||||
</v-col>
|
||||
</v-row>
|
||||
|
||||
<div class="d-flex justify-end mt-4">
|
||||
<v-btn color="primary" variant="tonal" size="small" @click="saveKbConfig" :loading="saving"
|
||||
prepend-icon="mdi-content-save">
|
||||
{{ tm('buttons.save') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</v-card-text>
|
||||
<v-card-actions class="px-6 pb-4">
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn variant="text" @click="closeRuleEditor">{{ tm('buttons.cancel') }}</v-btn>
|
||||
<v-btn color="primary" variant="tonal" @click="saveAllConfigs" :loading="saving"
|
||||
prepend-icon="mdi-content-save">
|
||||
{{ tm('buttons.save') }}
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
|
||||
@@ -567,6 +569,8 @@ export default {
|
||||
addRuleDialog: false,
|
||||
availableUmos: [],
|
||||
selectedNewUmo: null,
|
||||
addRuleTargetType: 'session',
|
||||
selectedGroup: null,
|
||||
|
||||
// 规则编辑
|
||||
ruleDialog: false,
|
||||
@@ -729,6 +733,13 @@ export default {
|
||||
return options
|
||||
},
|
||||
|
||||
groupSelectOptions() {
|
||||
return this.groups.map(g => ({
|
||||
label: `${g.name} (${g.umo_count} 个会话)`,
|
||||
value: g,
|
||||
}))
|
||||
},
|
||||
|
||||
groupOptions() {
|
||||
return this.groups.map(g => ({
|
||||
label: `${g.name} (${g.umo_count} 个会话)`,
|
||||
@@ -811,7 +822,7 @@ export default {
|
||||
})
|
||||
if (response.data.status === 'ok') {
|
||||
const data = response.data.data
|
||||
this.rulesList = data.rules
|
||||
this.rulesList = data.rules || []
|
||||
this.totalItems = data.total
|
||||
this.availablePersonas = data.available_personas
|
||||
this.availableChatProviders = data.available_chat_providers
|
||||
@@ -819,6 +830,20 @@ export default {
|
||||
this.availableTtsProviders = data.available_tts_providers
|
||||
this.availablePlugins = data.available_plugins || []
|
||||
this.availableKbs = data.available_kbs || []
|
||||
|
||||
// 合并分组规则到列表中
|
||||
const groupRules = data.group_rules || []
|
||||
for (const gr of groupRules) {
|
||||
this.rulesList.unshift({
|
||||
umo: `[\u5206\u7ec4] ${gr.name}`,
|
||||
isGroup: true,
|
||||
groupId: gr.group_id,
|
||||
groupName: gr.name,
|
||||
umo_count: gr.umo_count,
|
||||
rules: gr.config || {},
|
||||
})
|
||||
}
|
||||
this.totalItems += groupRules.length
|
||||
} else {
|
||||
this.showError(response.data.message || this.tm('messages.loadError'))
|
||||
}
|
||||
@@ -872,10 +897,89 @@ export default {
|
||||
async openAddRuleDialog() {
|
||||
this.addRuleDialog = true
|
||||
this.selectedNewUmo = null
|
||||
this.addRuleTargetType = 'session'
|
||||
this.selectedGroup = null
|
||||
await this.loadUmos()
|
||||
},
|
||||
|
||||
async saveAllConfigs() {
|
||||
if (!this.selectedUmo) return
|
||||
|
||||
// 分组模式:调用分组配置 API
|
||||
if (this.selectedUmo.isGroup) {
|
||||
this.saving = true
|
||||
try {
|
||||
const config = {
|
||||
session_service_config: { ...this.serviceConfig },
|
||||
provider_perf_chat_completion: this.providerConfig.chat_completion || null,
|
||||
provider_perf_speech_to_text: this.providerConfig.speech_to_text || null,
|
||||
provider_perf_text_to_speech: this.providerConfig.text_to_speech || null,
|
||||
session_plugin_config: { ...this.pluginConfig },
|
||||
kb_config: { ...this.kbConfig },
|
||||
}
|
||||
// 清理空值
|
||||
if (!config.session_service_config.custom_name) delete config.session_service_config.custom_name
|
||||
if (config.session_service_config.persona_id === null) delete config.session_service_config.persona_id
|
||||
|
||||
const response = await axios.post('/api/session/group/update-config', {
|
||||
group_id: this.selectedUmo.groupId,
|
||||
config: config
|
||||
})
|
||||
if (response.data.status === 'ok') {
|
||||
this.showSuccess(response.data.data?.message || '分组配置已保存并同步')
|
||||
await this.loadData()
|
||||
} else {
|
||||
this.showError(response.data.message || this.tm('messages.saveError'))
|
||||
}
|
||||
} catch (error) {
|
||||
this.showError(error.response?.data?.message || this.tm('messages.saveError'))
|
||||
} finally {
|
||||
this.saving = false
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 单个会话模式
|
||||
this.saving = true
|
||||
this._batchSaving = true
|
||||
try {
|
||||
await this.saveServiceConfig()
|
||||
await this.saveProviderConfig()
|
||||
await this.savePluginConfig()
|
||||
await this.saveKbConfig()
|
||||
this.showSuccess(this.tm('messages.saveSuccess'))
|
||||
} catch (error) {
|
||||
this.showError(error.response?.data?.message || this.tm('messages.saveError'))
|
||||
} finally {
|
||||
this._batchSaving = false
|
||||
this.saving = false
|
||||
}
|
||||
},
|
||||
|
||||
createNewRule() {
|
||||
if (this.addRuleTargetType === 'group') {
|
||||
// 分组模式
|
||||
if (!this.selectedGroup) return
|
||||
const group = this.selectedGroup.value || this.selectedGroup
|
||||
if (!group.umos || group.umos.length === 0) {
|
||||
this.showError('该分组没有成员会话')
|
||||
return
|
||||
}
|
||||
// 创建一个特殊的规则项,标记为分组
|
||||
const newItem = {
|
||||
umo: `[分组] ${group.name}`,
|
||||
isGroup: true,
|
||||
groupId: group.id,
|
||||
groupName: group.name,
|
||||
groupUmos: group.umos,
|
||||
rules: {},
|
||||
}
|
||||
this.addRuleDialog = false
|
||||
this.openRuleEditor(newItem)
|
||||
return
|
||||
}
|
||||
|
||||
// 单个会话模式(原逻辑)
|
||||
if (!this.selectedNewUmo) return
|
||||
|
||||
// 创建一个新的规则项并打开编辑器
|
||||
@@ -943,13 +1047,37 @@ export default {
|
||||
async saveServiceConfig() {
|
||||
if (!this.selectedUmo) return
|
||||
|
||||
this.saving = true
|
||||
if (!this._batchSaving) this.saving = true
|
||||
try {
|
||||
const config = { ...this.serviceConfig }
|
||||
// 清理空值
|
||||
if (!config.custom_name) delete config.custom_name
|
||||
if (config.persona_id === null) delete config.persona_id
|
||||
|
||||
// 分组模式:批量下发给所有成员
|
||||
if (this.selectedUmo.isGroup) {
|
||||
const umos = this.selectedUmo.groupUmos
|
||||
let successCount = 0
|
||||
for (const umo of umos) {
|
||||
try {
|
||||
await axios.post('/api/session/update-rule', {
|
||||
umo: umo,
|
||||
rule_key: 'session_service_config',
|
||||
rule_value: config
|
||||
})
|
||||
successCount++
|
||||
} catch (e) {
|
||||
console.error(`更新 ${umo} 失败:`, e)
|
||||
}
|
||||
}
|
||||
if (!this._batchSaving) {
|
||||
this.showSuccess(`已更新 ${successCount}/${umos.length} 个会话的服务配置`)
|
||||
await this.loadData()
|
||||
this.saving = false
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
const response = await axios.post('/api/session/update-rule', {
|
||||
umo: this.selectedUmo.umo,
|
||||
rule_key: 'session_service_config',
|
||||
@@ -957,7 +1085,7 @@ export default {
|
||||
})
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
this.showSuccess(this.tm('messages.saveSuccess'))
|
||||
if (!this._batchSaving) this.showSuccess(this.tm('messages.saveSuccess'))
|
||||
this.editingRules.session_service_config = config
|
||||
|
||||
// 更新或添加到列表
|
||||
@@ -980,17 +1108,45 @@ export default {
|
||||
} catch (error) {
|
||||
this.showError(error.response?.data?.message || this.tm('messages.saveError'))
|
||||
}
|
||||
this.saving = false
|
||||
if (!this._batchSaving) this.saving = false
|
||||
},
|
||||
|
||||
async saveProviderConfig() {
|
||||
if (!this.selectedUmo) return
|
||||
|
||||
this.saving = true
|
||||
if (!this._batchSaving) this.saving = true
|
||||
try {
|
||||
const providerTypes = ['chat_completion', 'speech_to_text', 'text_to_speech']
|
||||
|
||||
// 分组模式:批量下发给所有成员
|
||||
if (this.selectedUmo.isGroup) {
|
||||
const umos = this.selectedUmo.groupUmos
|
||||
let successCount = 0
|
||||
for (const umo of umos) {
|
||||
try {
|
||||
const tasks = []
|
||||
for (const type of providerTypes) {
|
||||
const value = this.providerConfig[type]
|
||||
if (value) {
|
||||
tasks.push(axios.post('/api/session/update-rule', { umo, rule_key: `provider_perf_${type}`, rule_value: value }))
|
||||
}
|
||||
}
|
||||
if (tasks.length > 0) await Promise.all(tasks)
|
||||
successCount++
|
||||
} catch (e) {
|
||||
console.error(`更新 ${umo} Provider 失败:`, e)
|
||||
}
|
||||
}
|
||||
if (!this._batchSaving) {
|
||||
this.showSuccess(`已更新 ${successCount}/${umos.length} 个会话的 Provider 配置`)
|
||||
await this.loadData()
|
||||
this.saving = false
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
const updateTasks = []
|
||||
const deleteTasks = []
|
||||
const providerTypes = ['chat_completion', 'speech_to_text', 'text_to_speech']
|
||||
|
||||
for (const type of providerTypes) {
|
||||
const value = this.providerConfig[type]
|
||||
@@ -1017,7 +1173,7 @@ export default {
|
||||
const allTasks = [...updateTasks, ...deleteTasks]
|
||||
if (allTasks.length > 0) {
|
||||
await Promise.all(allTasks)
|
||||
this.showSuccess(this.tm('messages.saveSuccess'))
|
||||
if (!this._batchSaving) this.showSuccess(this.tm('messages.saveSuccess'))
|
||||
|
||||
// 更新或添加到列表
|
||||
let item = this.rulesList.find(u => u.umo === this.selectedUmo.umo)
|
||||
@@ -1042,24 +1198,48 @@ export default {
|
||||
}
|
||||
}
|
||||
} else {
|
||||
this.showSuccess(this.tm('messages.noChanges'))
|
||||
if (!this._batchSaving) this.showSuccess(this.tm('messages.noChanges'))
|
||||
}
|
||||
} catch (error) {
|
||||
this.showError(error.response?.data?.message || this.tm('messages.saveError'))
|
||||
}
|
||||
this.saving = false
|
||||
if (!this._batchSaving) this.saving = false
|
||||
},
|
||||
|
||||
async savePluginConfig() {
|
||||
if (!this.selectedUmo) return
|
||||
|
||||
this.saving = true
|
||||
if (!this._batchSaving) this.saving = true
|
||||
try {
|
||||
const config = {
|
||||
enabled_plugins: this.pluginConfig.enabled_plugins,
|
||||
disabled_plugins: this.pluginConfig.disabled_plugins,
|
||||
}
|
||||
|
||||
// 分组模式:批量下发给所有成员
|
||||
if (this.selectedUmo.isGroup) {
|
||||
const umos = this.selectedUmo.groupUmos
|
||||
let successCount = 0
|
||||
for (const umo of umos) {
|
||||
try {
|
||||
if (config.enabled_plugins.length === 0 && config.disabled_plugins.length === 0) {
|
||||
await axios.post('/api/session/delete-rule', { umo, rule_key: 'session_plugin_config' })
|
||||
} else {
|
||||
await axios.post('/api/session/update-rule', { umo, rule_key: 'session_plugin_config', rule_value: config })
|
||||
}
|
||||
successCount++
|
||||
} catch (e) {
|
||||
console.error(`更新 ${umo} 插件配置失败:`, e)
|
||||
}
|
||||
}
|
||||
if (!this._batchSaving) {
|
||||
this.showSuccess(`已更新 ${successCount}/${umos.length} 个会话的插件配置`)
|
||||
await this.loadData()
|
||||
this.saving = false
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 如果两个列表都为空,删除配置
|
||||
if (config.enabled_plugins.length === 0 && config.disabled_plugins.length === 0) {
|
||||
if (this.editingRules.session_plugin_config) {
|
||||
@@ -1071,7 +1251,7 @@ export default {
|
||||
let item = this.rulesList.find(u => u.umo === this.selectedUmo.umo)
|
||||
if (item) delete item.rules.session_plugin_config
|
||||
}
|
||||
this.showSuccess(this.tm('messages.saveSuccess'))
|
||||
if (!this._batchSaving) this.showSuccess(this.tm('messages.saveSuccess'))
|
||||
} else {
|
||||
const response = await axios.post('/api/session/update-rule', {
|
||||
umo: this.selectedUmo.umo,
|
||||
@@ -1080,7 +1260,7 @@ export default {
|
||||
})
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
this.showSuccess(this.tm('messages.saveSuccess'))
|
||||
if (!this._batchSaving) this.showSuccess(this.tm('messages.saveSuccess'))
|
||||
this.editingRules.session_plugin_config = config
|
||||
|
||||
let item = this.rulesList.find(u => u.umo === this.selectedUmo.umo)
|
||||
@@ -1102,13 +1282,13 @@ export default {
|
||||
} catch (error) {
|
||||
this.showError(error.response?.data?.message || this.tm('messages.saveError'))
|
||||
}
|
||||
this.saving = false
|
||||
if (!this._batchSaving) this.saving = false
|
||||
},
|
||||
|
||||
async saveKbConfig() {
|
||||
if (!this.selectedUmo) return
|
||||
|
||||
this.saving = true
|
||||
if (!this._batchSaving) this.saving = true
|
||||
try {
|
||||
const config = {
|
||||
kb_ids: this.kbConfig.kb_ids,
|
||||
@@ -1116,6 +1296,30 @@ export default {
|
||||
enable_rerank: this.kbConfig.enable_rerank,
|
||||
}
|
||||
|
||||
// 分组模式:批量下发给所有成员
|
||||
if (this.selectedUmo.isGroup) {
|
||||
const umos = this.selectedUmo.groupUmos
|
||||
let successCount = 0
|
||||
for (const umo of umos) {
|
||||
try {
|
||||
if (config.kb_ids.length === 0) {
|
||||
await axios.post('/api/session/delete-rule', { umo, rule_key: 'kb_config' })
|
||||
} else {
|
||||
await axios.post('/api/session/update-rule', { umo, rule_key: 'kb_config', rule_value: config })
|
||||
}
|
||||
successCount++
|
||||
} catch (e) {
|
||||
console.error(`更新 ${umo} 知识库配置失败:`, e)
|
||||
}
|
||||
}
|
||||
if (!this._batchSaving) {
|
||||
this.showSuccess(`已更新 ${successCount}/${umos.length} 个会话的知识库配置`)
|
||||
await this.loadData()
|
||||
this.saving = false
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 如果 kb_ids 为空,删除配置
|
||||
if (config.kb_ids.length === 0) {
|
||||
if (this.editingRules.kb_config) {
|
||||
@@ -1127,7 +1331,7 @@ export default {
|
||||
let item = this.rulesList.find(u => u.umo === this.selectedUmo.umo)
|
||||
if (item) delete item.rules.kb_config
|
||||
}
|
||||
this.showSuccess(this.tm('messages.saveSuccess'))
|
||||
if (!this._batchSaving) this.showSuccess(this.tm('messages.saveSuccess'))
|
||||
} else {
|
||||
const response = await axios.post('/api/session/update-rule', {
|
||||
umo: this.selectedUmo.umo,
|
||||
@@ -1136,7 +1340,7 @@ export default {
|
||||
})
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
this.showSuccess(this.tm('messages.saveSuccess'))
|
||||
if (!this._batchSaving) this.showSuccess(this.tm('messages.saveSuccess'))
|
||||
this.editingRules.kb_config = config
|
||||
|
||||
let item = this.rulesList.find(u => u.umo === this.selectedUmo.umo)
|
||||
@@ -1158,7 +1362,7 @@ export default {
|
||||
} catch (error) {
|
||||
this.showError(error.response?.data?.message || this.tm('messages.saveError'))
|
||||
}
|
||||
this.saving = false
|
||||
if (!this._batchSaving) this.saving = false
|
||||
},
|
||||
|
||||
confirmDeleteRules(item) {
|
||||
@@ -1171,6 +1375,24 @@ export default {
|
||||
|
||||
this.deleting = true
|
||||
try {
|
||||
// 分组规则:清空分组配置
|
||||
if (this.deleteTarget.isGroup) {
|
||||
const response = await axios.post('/api/session/group/update-config', {
|
||||
group_id: this.deleteTarget.groupId,
|
||||
config: {}
|
||||
})
|
||||
if (response.data.status === 'ok') {
|
||||
this.showSuccess('分组配置已清除')
|
||||
this.deleteDialog = false
|
||||
this.deleteTarget = null
|
||||
await this.loadData()
|
||||
} else {
|
||||
this.showError(response.data.message || this.tm('messages.deleteError'))
|
||||
}
|
||||
this.deleting = false
|
||||
return
|
||||
}
|
||||
|
||||
const response = await axios.post('/api/session/delete-rule', {
|
||||
umo: this.deleteTarget.umo
|
||||
})
|
||||
|
||||
@@ -16,6 +16,7 @@ from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
from astrbot.core.star.star import star_registry
|
||||
from astrbot.core.star.star_handler import star_handlers_registry
|
||||
from astrbot.core.utils.pip_installer import PipInstallError
|
||||
from astrbot.dashboard.routes.plugin import PluginRoute
|
||||
from astrbot.dashboard.server import AstrBotDashboard
|
||||
from tests.fixtures.helpers import (
|
||||
@@ -359,6 +360,35 @@ async def test_do_update(
|
||||
assert os.path.exists(release_path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_install_pip_package_returns_pip_install_error_message(
|
||||
app: Quart,
|
||||
authenticated_header: dict,
|
||||
monkeypatch,
|
||||
):
|
||||
test_client = app.test_client()
|
||||
|
||||
async def mock_pip_install(*args, **kwargs):
|
||||
del args, kwargs
|
||||
raise PipInstallError("install failed", code=2)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"astrbot.dashboard.routes.update.pip_installer.install",
|
||||
mock_pip_install,
|
||||
)
|
||||
|
||||
response = await test_client.post(
|
||||
"/api/update/pip-install",
|
||||
headers=authenticated_header,
|
||||
json={"package": "demo-package"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "error"
|
||||
assert data["message"] == "install failed"
|
||||
|
||||
|
||||
class _FakeNeoSkills:
|
||||
async def list_candidates(self, **kwargs):
|
||||
_ = kwargs
|
||||
|
||||
@@ -0,0 +1,266 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from astrbot.core.utils import core_constraints as core_constraints_module
|
||||
from astrbot.core.utils import requirements_utils
|
||||
from astrbot.core.utils.core_constraints import CoreConstraintsProvider
|
||||
|
||||
|
||||
def test_requirements_utils_parse_package_install_input_collects_specs_and_names():
|
||||
parsed = requirements_utils.parse_package_install_input(
|
||||
"--index-url https://example.com/simple demo-package\nanother-package>=1.0\n"
|
||||
)
|
||||
|
||||
assert parsed.specs == (
|
||||
"--index-url",
|
||||
"https://example.com/simple",
|
||||
"demo-package",
|
||||
"another-package>=1.0",
|
||||
)
|
||||
assert parsed.requirement_names == {"demo-package", "another-package"}
|
||||
|
||||
|
||||
def test_core_constraints_provider_writes_constraints_file_from_fallback_distribution(
|
||||
monkeypatch,
|
||||
):
|
||||
class FakeFallbackDistribution:
|
||||
metadata = {"Name": "AstrBot-App"}
|
||||
requires = ["shared-lib>=1.0"]
|
||||
|
||||
def read_text(self, name):
|
||||
if name == "top_level.txt":
|
||||
return "astrbot\n"
|
||||
return ""
|
||||
|
||||
fake_distribution = FakeFallbackDistribution()
|
||||
|
||||
def mock_distribution(name):
|
||||
if name == "AstrBot":
|
||||
raise core_constraints_module.importlib_metadata.PackageNotFoundError
|
||||
if name == "AstrBot-App":
|
||||
return fake_distribution
|
||||
raise core_constraints_module.importlib_metadata.PackageNotFoundError
|
||||
|
||||
def mock_distributions(path=None):
|
||||
del path
|
||||
return [fake_distribution]
|
||||
|
||||
monkeypatch.setattr(
|
||||
core_constraints_module.importlib_metadata,
|
||||
"distribution",
|
||||
mock_distribution,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
core_constraints_module.importlib_metadata,
|
||||
"distributions",
|
||||
mock_distributions,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
core_constraints_module,
|
||||
"collect_installed_distribution_versions",
|
||||
lambda paths: {"shared-lib": "2.0"},
|
||||
)
|
||||
|
||||
core_constraints_module._get_core_constraints.cache_clear()
|
||||
try:
|
||||
provider = CoreConstraintsProvider(None)
|
||||
with provider.constraints_file() as constraints_path:
|
||||
assert constraints_path is not None
|
||||
assert (
|
||||
Path(constraints_path).read_text(encoding="utf-8") == "shared-lib==2.0"
|
||||
)
|
||||
finally:
|
||||
core_constraints_module._get_core_constraints.cache_clear()
|
||||
|
||||
|
||||
def test_resolve_core_dist_name_skips_distribution_without_name(monkeypatch):
|
||||
class NamelessDistribution:
|
||||
metadata = {}
|
||||
|
||||
def read_text(self, name):
|
||||
if name == "top_level.txt":
|
||||
return "astrbot\n"
|
||||
return ""
|
||||
|
||||
class NamedDistribution:
|
||||
metadata = {"Name": "AstrBot-App"}
|
||||
|
||||
def read_text(self, name):
|
||||
if name == "top_level.txt":
|
||||
return "astrbot\n"
|
||||
return ""
|
||||
|
||||
monkeypatch.setattr(
|
||||
core_constraints_module.importlib_metadata,
|
||||
"distribution",
|
||||
lambda name: (_ for _ in ()).throw(
|
||||
core_constraints_module.importlib_metadata.PackageNotFoundError
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
core_constraints_module.importlib_metadata,
|
||||
"distributions",
|
||||
lambda: [NamelessDistribution(), NamedDistribution()],
|
||||
)
|
||||
|
||||
assert core_constraints_module._resolve_core_dist_name(None) == "AstrBot-App"
|
||||
|
||||
|
||||
def test_find_missing_requirements_returns_none_when_precheck_gate_fails(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
):
|
||||
requirements_path = tmp_path / "requirements.txt"
|
||||
requirements_path.write_text("demo-package\n", encoding="utf-8")
|
||||
|
||||
monkeypatch.setattr(
|
||||
requirements_utils,
|
||||
"_load_requirement_lines_for_precheck",
|
||||
lambda path: (False, None),
|
||||
)
|
||||
|
||||
missing = requirements_utils.find_missing_requirements(str(requirements_path))
|
||||
|
||||
assert missing is None
|
||||
|
||||
|
||||
def test_parse_package_install_input_tracks_only_named_direct_references():
|
||||
named = requirements_utils.parse_package_install_input(
|
||||
"git+https://example.com/demo.git#egg=demo-package"
|
||||
)
|
||||
unnamed = requirements_utils.parse_package_install_input(
|
||||
"git+https://example.com/demo.git"
|
||||
)
|
||||
|
||||
assert named.requirement_names == {"demo-package"}
|
||||
assert unnamed.requirement_names == set()
|
||||
|
||||
|
||||
def test_find_missing_requirements_or_raise_uses_requirements_exception(tmp_path):
|
||||
requirements_path = tmp_path / "requirements.txt"
|
||||
requirements_path.write_text("-e ../sharedlib\n", encoding="utf-8")
|
||||
|
||||
with pytest.raises(requirements_utils.RequirementsPrecheckFailed):
|
||||
requirements_utils.find_missing_requirements_or_raise(str(requirements_path))
|
||||
|
||||
|
||||
def test_find_missing_requirements_logs_path_and_reason_on_precheck_fallback(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
):
|
||||
requirements_path = tmp_path / "requirements.txt"
|
||||
requirements_path.write_text("git+https://example.com/demo.git\n", encoding="utf-8")
|
||||
warning_logs = []
|
||||
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.utils.requirements_utils.logger.warning",
|
||||
lambda line, *args: warning_logs.append(line % args if args else line),
|
||||
)
|
||||
|
||||
missing = requirements_utils.find_missing_requirements(str(requirements_path))
|
||||
|
||||
assert missing is None
|
||||
assert any(str(requirements_path) in log for log in warning_logs)
|
||||
assert any("direct reference" in log for log in warning_logs)
|
||||
|
||||
|
||||
def test_load_requirement_lines_for_precheck_uses_parse_requirement_line_result(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
):
|
||||
requirements_path = tmp_path / "requirements.txt"
|
||||
requirements_path.write_text("git+https://example.com/demo.git\n", encoding="utf-8")
|
||||
|
||||
monkeypatch.setattr(
|
||||
requirements_utils,
|
||||
"_parse_requirement_line",
|
||||
lambda line: ("demo-package", None) if line.startswith("git+") else None,
|
||||
)
|
||||
|
||||
can_precheck, requirement_lines = (
|
||||
requirements_utils._load_requirement_lines_for_precheck(str(requirements_path))
|
||||
)
|
||||
|
||||
assert can_precheck is True
|
||||
assert requirement_lines == ["git+https://example.com/demo.git"]
|
||||
|
||||
|
||||
def test_collect_installed_distribution_versions_skips_nameless_distribution(
|
||||
monkeypatch,
|
||||
):
|
||||
class NamelessDistribution:
|
||||
metadata = {}
|
||||
version = "1.0"
|
||||
|
||||
class NamedDistribution:
|
||||
metadata = {"Name": "demo-package"}
|
||||
version = "2.0"
|
||||
|
||||
monkeypatch.setattr(
|
||||
requirements_utils.importlib_metadata,
|
||||
"distributions",
|
||||
lambda path: [NamelessDistribution(), NamedDistribution()],
|
||||
)
|
||||
|
||||
installed = requirements_utils.collect_installed_distribution_versions(
|
||||
["/tmp/test"]
|
||||
)
|
||||
|
||||
assert installed == {"demo-package": "2.0"}
|
||||
|
||||
|
||||
def test_get_core_constraints_logs_resolution_step_context(monkeypatch):
|
||||
warning_logs = []
|
||||
|
||||
monkeypatch.setattr(
|
||||
core_constraints_module,
|
||||
"_resolve_core_dist_name",
|
||||
lambda core_dist_name: (_ for _ in ()).throw(RuntimeError("boom")),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.utils.core_constraints.logger.warning",
|
||||
lambda line, *args: warning_logs.append(line % args if args else line),
|
||||
)
|
||||
|
||||
core_constraints_module._get_core_constraints.cache_clear()
|
||||
try:
|
||||
constraints = core_constraints_module._get_core_constraints(None)
|
||||
finally:
|
||||
core_constraints_module._get_core_constraints.cache_clear()
|
||||
|
||||
assert constraints == ()
|
||||
assert any("解析核心分发名称失败" in log for log in warning_logs)
|
||||
|
||||
|
||||
def test_iter_requirements_supports_direct_line_input():
|
||||
parsed = list(
|
||||
requirements_utils.iter_requirements(
|
||||
lines=["demo-package>=1.0", 'other-package; sys_platform == "win32"']
|
||||
)
|
||||
)
|
||||
|
||||
assert parsed == [
|
||||
("demo-package", requirements_utils.Requirement("demo-package>=1.0").specifier)
|
||||
]
|
||||
|
||||
|
||||
def test_parse_requirement_name_and_spec_preserves_direct_reference_rules():
|
||||
named = requirements_utils._parse_requirement_name_and_spec(
|
||||
"git+https://example.com/demo.git#egg=demo-package"
|
||||
)
|
||||
unnamed = requirements_utils._parse_requirement_name_and_spec(
|
||||
"git+https://example.com/demo.git"
|
||||
)
|
||||
|
||||
assert named == ("demo-package", None)
|
||||
assert unnamed == (None, None)
|
||||
|
||||
|
||||
def test_parse_requirement_name_and_spec_handles_plain_requirement_token():
|
||||
parsed = requirements_utils._parse_requirement_name_and_spec("demo-package>=1.0")
|
||||
|
||||
assert parsed == (
|
||||
"demo-package",
|
||||
requirements_utils.Requirement("demo-package>=1.0").specifier,
|
||||
)
|
||||
+1304
-1
File diff suppressed because it is too large
Load Diff
+426
-189
@@ -1,235 +1,472 @@
|
||||
import sys
|
||||
from asyncio import Queue
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import yaml
|
||||
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
from astrbot.core.star.context import Context
|
||||
from astrbot.core.star.star import star_map, star_registry
|
||||
from astrbot.core.star.star_handler import star_handlers_registry
|
||||
from astrbot.core.star.star_manager import PluginManager
|
||||
from astrbot.core.star.star_manager import PluginDependencyInstallError, PluginManager
|
||||
from astrbot.core.utils.pip_installer import PipInstallError
|
||||
|
||||
# --- Test Data & Helpers ---
|
||||
|
||||
def _clear_module_cache() -> None:
|
||||
"""Clear module cache for data module tree to ensure test isolation."""
|
||||
modules_to_remove = [
|
||||
key for key in sys.modules if key == "data" or key.startswith("data.")
|
||||
]
|
||||
for key in modules_to_remove:
|
||||
del sys.modules[key]
|
||||
|
||||
|
||||
def _clear_registry(plugin_name: str) -> None:
|
||||
"""Clear plugin from global registries."""
|
||||
# Clear star_registry (list)
|
||||
star_registry[:] = [md for md in star_registry if md.name != plugin_name]
|
||||
# Clear star_map (dict)
|
||||
keys_to_remove = [
|
||||
key for key, md in star_map.items() if md.name == plugin_name
|
||||
]
|
||||
for key in keys_to_remove:
|
||||
del star_map[key]
|
||||
# Clear star_handlers_registry (StarHandlerRegistry)
|
||||
for handler in list(star_handlers_registry):
|
||||
if plugin_name in (handler.handler_module_path or ""):
|
||||
star_handlers_registry.remove(handler)
|
||||
|
||||
TEST_PLUGIN_REPO = "https://github.com/Soulter/helloworld"
|
||||
TEST_PLUGIN_DIR = "helloworld"
|
||||
TEST_PLUGIN_NAME = "helloworld"
|
||||
TEST_PLUGIN_REPO = "https://github.com/AstrBotDevs/astrbot_plugin_helloworld"
|
||||
TEST_PLUGIN_DIR = "helloworld"
|
||||
|
||||
|
||||
def _write_local_test_plugin(plugin_dir: Path, repo_url: str) -> None:
|
||||
plugin_dir.mkdir(parents=True, exist_ok=True)
|
||||
(plugin_dir / "metadata.yaml").write_text(
|
||||
"\n".join(
|
||||
[
|
||||
f"name: {TEST_PLUGIN_NAME}",
|
||||
"author: AstrBot Team",
|
||||
"desc: Local test plugin",
|
||||
"version: 1.0.0",
|
||||
f"repo: {repo_url}",
|
||||
],
|
||||
)
|
||||
+ "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
(plugin_dir / "main.py").write_text(
|
||||
"\n".join(
|
||||
[
|
||||
"from astrbot.api import star",
|
||||
"",
|
||||
"class Main(star.Star):",
|
||||
" pass",
|
||||
"",
|
||||
],
|
||||
),
|
||||
encoding="utf-8",
|
||||
class MockStar:
|
||||
def __init__(self):
|
||||
self.root_dir_name = TEST_PLUGIN_DIR
|
||||
self.name = TEST_PLUGIN_NAME
|
||||
self.repo = TEST_PLUGIN_REPO
|
||||
self.reserved = False
|
||||
self.info = {"repo": TEST_PLUGIN_REPO, "readme": ""}
|
||||
|
||||
|
||||
def _write_local_test_plugin(plugin_path: Path, repo_url: str):
|
||||
"""Creates a minimal valid plugin structure."""
|
||||
plugin_path.mkdir(parents=True, exist_ok=True)
|
||||
metadata = {
|
||||
"name": TEST_PLUGIN_NAME,
|
||||
"repo": repo_url,
|
||||
"version": "1.0.0",
|
||||
"author": "AstrBot Team",
|
||||
"desc": "Local test plugin",
|
||||
}
|
||||
with open(plugin_path / "info.yaml", "w", encoding="utf-8") as f:
|
||||
yaml.dump(metadata, f)
|
||||
with open(plugin_path / "main.py", "w", encoding="utf-8") as f:
|
||||
f.write("from astrbot.api.star import Star, Context, StarManager\n")
|
||||
f.write("@StarManager.register\n")
|
||||
f.write("class HelloWorld(Star):\n")
|
||||
f.write(" def __init__(self, context: Context): ...\n")
|
||||
|
||||
|
||||
def _write_requirements(plugin_path: Path):
|
||||
"""Creates a requirements.txt file."""
|
||||
with open(plugin_path / "requirements.txt", "w", encoding="utf-8") as f:
|
||||
f.write("networkx\n")
|
||||
|
||||
|
||||
def _clear_module_cache():
|
||||
"""Clear test-specific modules from sys.modules to allow reloading."""
|
||||
import sys
|
||||
|
||||
to_del = [m for m in sys.modules if m.startswith("data.plugins.helloworld")]
|
||||
for m in to_del:
|
||||
del sys.modules[m]
|
||||
|
||||
|
||||
def _build_load_mock(events):
|
||||
async def mock_load(specified_dir_name=None, ignore_version_check=False):
|
||||
del ignore_version_check
|
||||
events.append(("load", specified_dir_name or TEST_PLUGIN_DIR))
|
||||
return True, ""
|
||||
|
||||
return mock_load
|
||||
|
||||
|
||||
def _build_reload_mock(events):
|
||||
async def mock_reload(specified_dir_name=None):
|
||||
events.append(("reload", specified_dir_name or TEST_PLUGIN_DIR))
|
||||
return True, ""
|
||||
|
||||
return mock_reload
|
||||
|
||||
|
||||
def _build_dependency_install_mock(events, fail: bool):
|
||||
async def mock_install_requirements(
|
||||
*, requirements_path: str = None, package_name: str = None, **kwargs
|
||||
):
|
||||
del kwargs
|
||||
if requirements_path:
|
||||
events.append(("deps", str(requirements_path)))
|
||||
if package_name:
|
||||
events.append(("deps_pkg", package_name))
|
||||
if fail:
|
||||
raise Exception("pip failed")
|
||||
|
||||
return mock_install_requirements
|
||||
|
||||
|
||||
def _mock_missing_requirements(monkeypatch, missing: set[str]):
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.star.star_manager.find_missing_requirements_or_raise",
|
||||
lambda requirements_path: missing,
|
||||
)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def plugin_manager_pm(tmp_path, monkeypatch):
|
||||
def _mock_precheck_fails(monkeypatch):
|
||||
from astrbot.core import RequirementsPrecheckFailed
|
||||
|
||||
def mock_fail(requirements_path):
|
||||
raise RequirementsPrecheckFailed("mock precheck failure")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.star.star_manager.find_missing_requirements_or_raise",
|
||||
mock_fail,
|
||||
)
|
||||
|
||||
|
||||
# --- Fixtures ---
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def plugin_manager_pm(tmp_path, monkeypatch):
|
||||
"""Provides a fully isolated PluginManager instance for testing."""
|
||||
# Clear module cache before setup to ensure isolation
|
||||
_clear_module_cache()
|
||||
|
||||
test_root = tmp_path / "astrbot_root"
|
||||
data_dir = test_root / "data"
|
||||
plugin_dir = data_dir / "plugins"
|
||||
config_dir = data_dir / "config"
|
||||
temp_dir = data_dir / "temp"
|
||||
for path in (plugin_dir, config_dir, temp_dir):
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
plugin_dir = tmp_path / "astrbot_root" / "data" / "plugins"
|
||||
plugin_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Ensure `import data.plugins.<plugin>.main` resolves to this temp root.
|
||||
(data_dir / "__init__.py").write_text("", encoding="utf-8")
|
||||
(plugin_dir / "__init__.py").write_text("", encoding="utf-8")
|
||||
class MockContext:
|
||||
def __init__(self):
|
||||
self.stars = []
|
||||
|
||||
# Use monkeypatch for both env var and sys.path to ensure proper cleanup
|
||||
monkeypatch.setenv("ASTRBOT_ROOT", str(test_root))
|
||||
monkeypatch.syspath_prepend(str(test_root))
|
||||
def get_all_stars(self):
|
||||
return self.stars
|
||||
|
||||
# Create fresh, isolated instances for the context
|
||||
event_queue = Queue()
|
||||
config = AstrBotConfig()
|
||||
db = SQLiteDatabase(str(data_dir / "test_db.db"))
|
||||
config.plugin_store_path = str(plugin_dir)
|
||||
def get_registered_star(self, name):
|
||||
for s in self.stars:
|
||||
if s.root_dir_name == name or s.name == name:
|
||||
return s
|
||||
return None
|
||||
|
||||
provider_manager = MagicMock()
|
||||
platform_manager = MagicMock()
|
||||
conversation_manager = MagicMock()
|
||||
message_history_manager = MagicMock()
|
||||
persona_manager = MagicMock()
|
||||
persona_manager.personas_v3 = []
|
||||
astrbot_config_mgr = MagicMock()
|
||||
knowledge_base_manager = MagicMock()
|
||||
cron_manager = MagicMock()
|
||||
mock_context = MockContext()
|
||||
mock_config = {}
|
||||
pm = PluginManager(mock_context, mock_config)
|
||||
|
||||
star_context = Context(
|
||||
event_queue=event_queue,
|
||||
config=config,
|
||||
db=db,
|
||||
provider_manager=provider_manager,
|
||||
platform_manager=platform_manager,
|
||||
conversation_manager=conversation_manager,
|
||||
message_history_manager=message_history_manager,
|
||||
persona_manager=persona_manager,
|
||||
astrbot_config_mgr=astrbot_config_mgr,
|
||||
knowledge_base_manager=knowledge_base_manager,
|
||||
cron_manager=cron_manager,
|
||||
subagent_orchestrator=None,
|
||||
# Patch paths to use tmp_path
|
||||
monkeypatch.setattr(pm, "plugin_store_path", str(plugin_dir))
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.star.star_manager.get_astrbot_plugin_path",
|
||||
lambda: str(plugin_dir),
|
||||
)
|
||||
|
||||
manager = PluginManager(star_context, config)
|
||||
try:
|
||||
yield manager
|
||||
finally:
|
||||
# Cleanup global registries and module cache
|
||||
_clear_registry(TEST_PLUGIN_NAME)
|
||||
_clear_module_cache()
|
||||
await db.engine.dispose()
|
||||
return pm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def local_updator(plugin_manager_pm: PluginManager, monkeypatch):
|
||||
plugin_path = Path(plugin_manager_pm.plugin_store_path) / TEST_PLUGIN_DIR
|
||||
def local_updator(plugin_manager_pm):
|
||||
"""Helper to setup a local plugin directory simulating a download."""
|
||||
path = Path(plugin_manager_pm.plugin_store_path) / TEST_PLUGIN_DIR
|
||||
_write_local_test_plugin(path, TEST_PLUGIN_REPO)
|
||||
return path
|
||||
|
||||
async def mock_install(repo_url: str, proxy=""): # noqa: ARG001
|
||||
if repo_url != TEST_PLUGIN_REPO:
|
||||
raise Exception("Repo not found")
|
||||
|
||||
# --- Tests ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("dependency_install_fails", [False, True])
|
||||
async def test_install_plugin_dependency_install_flow(
|
||||
plugin_manager_pm: PluginManager, monkeypatch, dependency_install_fails: bool
|
||||
):
|
||||
plugin_path = Path(plugin_manager_pm.plugin_store_path) / TEST_PLUGIN_DIR
|
||||
events = []
|
||||
_mock_missing_requirements(monkeypatch, {"networkx"})
|
||||
|
||||
async def mock_install(repo_url: str, proxy=""):
|
||||
assert repo_url == TEST_PLUGIN_REPO
|
||||
_write_local_test_plugin(plugin_path, repo_url)
|
||||
_write_requirements(plugin_path)
|
||||
return str(plugin_path)
|
||||
|
||||
async def mock_update(plugin, proxy=""): # noqa: ARG001
|
||||
if plugin.name != TEST_PLUGIN_NAME:
|
||||
raise Exception("Plugin not found")
|
||||
if not plugin_path.exists():
|
||||
raise Exception("Plugin path missing")
|
||||
(plugin_path / ".updated").write_text("ok", encoding="utf-8")
|
||||
|
||||
monkeypatch.setattr(plugin_manager_pm.updator, "install", mock_install)
|
||||
monkeypatch.setattr(plugin_manager_pm.updator, "update", mock_update)
|
||||
return plugin_path
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.star.star_manager.pip_installer.install",
|
||||
_build_dependency_install_mock(events, dependency_install_fails),
|
||||
)
|
||||
|
||||
def mock_load_and_register(*args, **kwargs):
|
||||
plugin_manager_pm.context.stars.append(MockStar())
|
||||
return _build_load_mock(events)(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(plugin_manager_pm, "load", mock_load_and_register)
|
||||
|
||||
if dependency_install_fails:
|
||||
with pytest.raises(PluginDependencyInstallError, match="pip failed"):
|
||||
await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO)
|
||||
assert events == [("deps", str(plugin_path / "requirements.txt"))]
|
||||
else:
|
||||
await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO)
|
||||
assert events == [
|
||||
("deps", str(plugin_path / "requirements.txt")),
|
||||
("load", TEST_PLUGIN_DIR),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugin_manager_initialization(plugin_manager_pm: PluginManager):
|
||||
assert plugin_manager_pm is not None
|
||||
assert plugin_manager_pm.context is not None
|
||||
assert plugin_manager_pm.config is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugin_manager_reload(plugin_manager_pm: PluginManager):
|
||||
success, err_message = await plugin_manager_pm.reload()
|
||||
assert success is True
|
||||
assert err_message is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_install_plugin(plugin_manager_pm: PluginManager, local_updator: Path):
|
||||
"""Tests successful plugin installation without external network."""
|
||||
plugin_info = await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO)
|
||||
assert plugin_info is not None
|
||||
assert plugin_info["name"] == TEST_PLUGIN_NAME
|
||||
assert local_updator.exists()
|
||||
assert any(md.name == TEST_PLUGIN_NAME for md in star_registry)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_install_nonexistent_plugin(
|
||||
plugin_manager_pm: PluginManager, local_updator
|
||||
@pytest.mark.parametrize("dependency_install_fails", [False, True])
|
||||
async def test_install_plugin_from_file_dependency_install_flow(
|
||||
plugin_manager_pm: PluginManager,
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
dependency_install_fails: bool,
|
||||
):
|
||||
"""Tests that installing a non-existent plugin raises an exception."""
|
||||
with pytest.raises(Exception):
|
||||
await plugin_manager_pm.install_plugin(
|
||||
"https://github.com/Soulter/non_existent_repo"
|
||||
zip_file_path = tmp_path / f"{TEST_PLUGIN_DIR}.zip"
|
||||
zip_file_path.write_text("placeholder", encoding="utf-8")
|
||||
events = []
|
||||
_mock_missing_requirements(monkeypatch, {"networkx"})
|
||||
|
||||
def mock_unzip_file(zip_path: str, target_dir: str) -> None:
|
||||
assert zip_path == str(zip_file_path)
|
||||
plugin_path = Path(target_dir)
|
||||
_write_local_test_plugin(plugin_path, TEST_PLUGIN_REPO)
|
||||
_write_requirements(plugin_path)
|
||||
|
||||
monkeypatch.setattr(plugin_manager_pm.updator, "unzip_file", mock_unzip_file)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.star.star_manager.pip_installer.install",
|
||||
_build_dependency_install_mock(events, dependency_install_fails),
|
||||
)
|
||||
|
||||
def mock_load_and_register(*args, **kwargs):
|
||||
plugin_manager_pm.context.stars.append(MockStar())
|
||||
return _build_load_mock(events)(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(plugin_manager_pm, "load", mock_load_and_register)
|
||||
|
||||
if dependency_install_fails:
|
||||
with pytest.raises(PluginDependencyInstallError, match="pip failed"):
|
||||
await plugin_manager_pm.install_plugin_from_file(str(zip_file_path))
|
||||
assert any(e[0] == "deps" for e in events)
|
||||
else:
|
||||
await plugin_manager_pm.install_plugin_from_file(str(zip_file_path))
|
||||
assert any(e[0] == "deps" for e in events)
|
||||
assert ("load", TEST_PLUGIN_DIR) in events
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("dependency_install_fails", [False, True])
|
||||
async def test_reload_failed_plugin_dependency_install_flow(
|
||||
plugin_manager_pm: PluginManager,
|
||||
local_updator: Path,
|
||||
monkeypatch,
|
||||
dependency_install_fails: bool,
|
||||
):
|
||||
_write_requirements(local_updator)
|
||||
plugin_manager_pm.failed_plugin_dict[TEST_PLUGIN_DIR] = {"error": "init fail"}
|
||||
events = []
|
||||
_mock_missing_requirements(monkeypatch, {"networkx"})
|
||||
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.star.star_manager.pip_installer.install",
|
||||
_build_dependency_install_mock(events, dependency_install_fails),
|
||||
)
|
||||
|
||||
def mock_load_and_register(*args, **kwargs):
|
||||
plugin_manager_pm.context.stars.append(MockStar())
|
||||
return _build_load_mock(events)(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(plugin_manager_pm, "load", mock_load_and_register)
|
||||
|
||||
if dependency_install_fails:
|
||||
with pytest.raises(PluginDependencyInstallError, match="pip failed"):
|
||||
await plugin_manager_pm.reload_failed_plugin(TEST_PLUGIN_DIR)
|
||||
assert events == [("deps", str(local_updator / "requirements.txt"))]
|
||||
else:
|
||||
await plugin_manager_pm.reload_failed_plugin(TEST_PLUGIN_DIR)
|
||||
assert events == [
|
||||
("deps", str(local_updator / "requirements.txt")),
|
||||
("load", TEST_PLUGIN_DIR),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_plugin_requirements_reraises_cancelled_error(
|
||||
plugin_manager_pm: PluginManager, local_updator: Path, monkeypatch
|
||||
):
|
||||
_write_requirements(local_updator)
|
||||
_mock_missing_requirements(monkeypatch, {"networkx"})
|
||||
|
||||
async def mock_install_requirements(*args, **kwargs):
|
||||
raise asyncio.CancelledError()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.star.star_manager.pip_installer.install",
|
||||
mock_install_requirements,
|
||||
)
|
||||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await plugin_manager_pm._ensure_plugin_requirements(
|
||||
str(local_updator),
|
||||
TEST_PLUGIN_DIR,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_plugin(plugin_manager_pm: PluginManager, local_updator: Path):
|
||||
"""Tests updating an existing plugin without external network."""
|
||||
plugin_info = await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO)
|
||||
assert plugin_info is not None
|
||||
plugin_name = plugin_info["name"]
|
||||
await plugin_manager_pm.update_plugin(plugin_name)
|
||||
assert (local_updator / ".updated").exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_nonexistent_plugin(
|
||||
plugin_manager_pm: PluginManager, local_updator
|
||||
async def test_ensure_plugin_requirements_wraps_generic_dependency_install_failure(
|
||||
plugin_manager_pm: PluginManager, local_updator: Path, monkeypatch
|
||||
):
|
||||
"""Tests that updating a non-existent plugin raises an exception."""
|
||||
with pytest.raises(Exception):
|
||||
await plugin_manager_pm.update_plugin("non_existent_plugin")
|
||||
_write_requirements(local_updator)
|
||||
_mock_missing_requirements(monkeypatch, {"networkx"})
|
||||
|
||||
async def mock_install_requirements(*args, **kwargs):
|
||||
raise RuntimeError("pip failed")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uninstall_plugin(plugin_manager_pm: PluginManager, local_updator: Path):
|
||||
"""Tests successful plugin uninstallation."""
|
||||
plugin_info = await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO)
|
||||
assert plugin_info is not None
|
||||
plugin_name = plugin_info["name"]
|
||||
assert local_updator.exists()
|
||||
|
||||
await plugin_manager_pm.uninstall_plugin(plugin_name)
|
||||
|
||||
assert not local_updator.exists()
|
||||
assert not any(md.name == TEST_PLUGIN_NAME for md in star_registry)
|
||||
assert not any(
|
||||
TEST_PLUGIN_NAME in md.handler_module_path for md in star_handlers_registry
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.star.star_manager.pip_installer.install",
|
||||
mock_install_requirements,
|
||||
)
|
||||
|
||||
with pytest.raises(PluginDependencyInstallError, match="pip failed") as exc_info:
|
||||
await plugin_manager_pm._ensure_plugin_requirements(
|
||||
str(local_updator),
|
||||
TEST_PLUGIN_DIR,
|
||||
)
|
||||
|
||||
assert exc_info.value.plugin_label == TEST_PLUGIN_DIR
|
||||
assert exc_info.value.requirements_path == str(local_updator / "requirements.txt")
|
||||
assert isinstance(exc_info.value.__cause__, RuntimeError)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uninstall_nonexistent_plugin(plugin_manager_pm: PluginManager):
|
||||
"""Tests that uninstalling a non-existent plugin raises an exception."""
|
||||
with pytest.raises(Exception):
|
||||
await plugin_manager_pm.uninstall_plugin("non_existent_plugin")
|
||||
async def test_ensure_plugin_requirements_wraps_pip_install_error(
|
||||
plugin_manager_pm: PluginManager, local_updator: Path, monkeypatch
|
||||
):
|
||||
_write_requirements(local_updator)
|
||||
_mock_missing_requirements(monkeypatch, {"networkx"})
|
||||
|
||||
async def mock_install_requirements(*args, **kwargs):
|
||||
raise PipInstallError("install failed", code=2)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.star.star_manager.pip_installer.install",
|
||||
mock_install_requirements,
|
||||
)
|
||||
|
||||
with pytest.raises(PluginDependencyInstallError, match="install failed") as exc_info:
|
||||
await plugin_manager_pm._ensure_plugin_requirements(
|
||||
str(local_updator),
|
||||
TEST_PLUGIN_DIR,
|
||||
)
|
||||
|
||||
assert isinstance(exc_info.value.__cause__, PipInstallError)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_plugin_requirements_logs_requirements_file_install_for_missing_dependencies(
|
||||
plugin_manager_pm: PluginManager, local_updator: Path, monkeypatch
|
||||
):
|
||||
_write_requirements(local_updator)
|
||||
_mock_missing_requirements(monkeypatch, {"networkx"})
|
||||
logged_lines = []
|
||||
|
||||
async def mock_install_requirements(*args, **kwargs):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.star.star_manager.pip_installer.install",
|
||||
mock_install_requirements,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.star.star_manager.logger.info",
|
||||
lambda line, *args: logged_lines.append(line % args if args else line),
|
||||
)
|
||||
|
||||
await plugin_manager_pm._ensure_plugin_requirements(
|
||||
str(local_updator),
|
||||
TEST_PLUGIN_DIR,
|
||||
)
|
||||
|
||||
assert any("按 requirements.txt 安装" in line for line in logged_lines)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("dependency_install_fails", [False, True])
|
||||
async def test_update_plugin_dependency_install_flow(
|
||||
plugin_manager_pm: PluginManager,
|
||||
local_updator: Path,
|
||||
monkeypatch,
|
||||
dependency_install_fails: bool,
|
||||
):
|
||||
mock_star = MockStar()
|
||||
plugin_manager_pm.context.stars.append(mock_star)
|
||||
|
||||
_write_requirements(local_updator)
|
||||
events = []
|
||||
_mock_missing_requirements(monkeypatch, {"networkx"})
|
||||
|
||||
async def mock_update(plugin, proxy=""):
|
||||
del proxy
|
||||
events.append(("update", plugin.name))
|
||||
|
||||
monkeypatch.setattr(plugin_manager_pm.updator, "update", mock_update)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.star.star_manager.pip_installer.install",
|
||||
_build_dependency_install_mock(events, dependency_install_fails),
|
||||
)
|
||||
monkeypatch.setattr(plugin_manager_pm, "reload", _build_reload_mock(events))
|
||||
|
||||
if dependency_install_fails:
|
||||
with pytest.raises(PluginDependencyInstallError, match="pip failed"):
|
||||
await plugin_manager_pm.update_plugin(TEST_PLUGIN_NAME)
|
||||
assert ("deps", str(local_updator / "requirements.txt")) in events
|
||||
else:
|
||||
await plugin_manager_pm.update_plugin(TEST_PLUGIN_NAME)
|
||||
assert ("deps", str(local_updator / "requirements.txt")) in events
|
||||
assert ("reload", TEST_PLUGIN_DIR) in events
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_install_plugin_skips_dependency_install_when_no_requirements_missing(
|
||||
plugin_manager_pm: PluginManager, monkeypatch
|
||||
):
|
||||
plugin_path = Path(plugin_manager_pm.plugin_store_path) / TEST_PLUGIN_DIR
|
||||
events = []
|
||||
_mock_missing_requirements(monkeypatch, set())
|
||||
|
||||
async def mock_install(repo_url: str, proxy=""):
|
||||
_write_local_test_plugin(plugin_path, repo_url)
|
||||
_write_requirements(plugin_path)
|
||||
return str(plugin_path)
|
||||
|
||||
monkeypatch.setattr(plugin_manager_pm.updator, "install", mock_install)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.star.star_manager.pip_installer.install",
|
||||
_build_dependency_install_mock(events, False),
|
||||
)
|
||||
|
||||
def mock_load_and_register(*args, **kwargs):
|
||||
plugin_manager_pm.context.stars.append(MockStar())
|
||||
return _build_load_mock(events)(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(plugin_manager_pm, "load", mock_load_and_register)
|
||||
|
||||
await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO)
|
||||
|
||||
assert "deps" not in [e[0] for e in events]
|
||||
assert ("load", TEST_PLUGIN_DIR) in events
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_install_plugin_runs_dependency_install_when_precheck_fails(
|
||||
plugin_manager_pm: PluginManager, monkeypatch
|
||||
):
|
||||
plugin_path = Path(plugin_manager_pm.plugin_store_path) / TEST_PLUGIN_DIR
|
||||
events = []
|
||||
|
||||
async def mock_install(repo_url: str, proxy=""):
|
||||
_write_local_test_plugin(plugin_path, repo_url)
|
||||
_write_requirements(plugin_path)
|
||||
return str(plugin_path)
|
||||
|
||||
_mock_precheck_fails(monkeypatch)
|
||||
monkeypatch.setattr(plugin_manager_pm.updator, "install", mock_install)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.star.star_manager.pip_installer.install",
|
||||
_build_dependency_install_mock(events, False),
|
||||
)
|
||||
|
||||
def mock_load_and_register(*args, **kwargs):
|
||||
plugin_manager_pm.context.stars.append(MockStar())
|
||||
return _build_load_mock(events)(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(plugin_manager_pm, "load", mock_load_and_register)
|
||||
|
||||
await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO)
|
||||
|
||||
assert ("deps", str(plugin_path / "requirements.txt")) in events
|
||||
assert ("load", TEST_PLUGIN_DIR) in events
|
||||
|
||||
Reference in New Issue
Block a user