486 lines
14 KiB
Python
486 lines
14 KiB
Python
import importlib.metadata as importlib_metadata
|
|
import logging
|
|
import os
|
|
import re
|
|
import shlex
|
|
import sys
|
|
from collections.abc import Iterable, Iterator, Sequence
|
|
from dataclasses import dataclass
|
|
|
|
from packaging.requirements import InvalidRequirement, Requirement
|
|
from packaging.specifiers import SpecifierSet
|
|
from packaging.version import InvalidVersion, Version
|
|
|
|
from astrbot.core.utils.astrbot_path import get_astrbot_site_packages_path
|
|
from astrbot.core.utils.runtime_env import is_packaged_desktop_runtime
|
|
|
|
# Module-level logger shared by every helper below; it is looked up under the
# fixed name "astrbot" rather than this module's own name.
logger = logging.getLogger("astrbot")
|
|
|
|
|
|
class RequirementsPrecheckFailed(Exception):
    """Signal that the requirements pre-check could not be carried out."""
|
|
|
|
|
|
@dataclass(frozen=True)
class ParsedPackageInput:
    """Immutable result of parsing free-form package install input."""

    # Install arguments in input order: whole requirement lines, or the
    # shell-split tokens of lines that are not plain requirements.
    specs: tuple[str, ...]
    # Canonicalized distribution names recognised in the input.
    requirement_names: frozenset[str]
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class MissingRequirementsPlan:
|
|
missing_names: frozenset[str]
|
|
install_lines: tuple[str, ...]
|
|
fallback_reason: str | None = None
|
|
|
|
|
|
def canonicalize_distribution_name(name: str) -> str:
    """Return *name* in canonical form.

    Runs of ``-``, ``_`` and ``.`` collapse to a single ``-``, leading and
    trailing dashes are removed, and the result is lower-cased.
    """
    collapsed = re.sub(r"[-_.]+", "-", name)
    trimmed = collapsed.strip("-")
    return trimmed.lower()
|
|
|
|
|
|
def strip_inline_requirement_comment(raw_input: str) -> str:
    """Return *raw_input* without its trailing comment, whitespace-trimmed.

    A line whose first non-blank character is ``#`` yields an empty string.
    Otherwise only a ``#`` preceded by a space or tab starts a comment, so
    fragments such as ``#egg=name`` glued to a URL are preserved.
    """
    if raw_input.lstrip().startswith("#"):
        return ""
    before_comment, *_ = re.split(r"[ \t]+#", raw_input, maxsplit=1)
    return before_comment.strip()
|
|
|
|
|
|
def _specifier_contains_version(specifier: SpecifierSet, version: str) -> bool:
    """Return whether *version* satisfies *specifier*, pre-releases included.

    A version string that cannot be parsed is treated as not satisfying.
    """
    try:
        candidate = Version(version)
    except InvalidVersion:
        return False
    else:
        return specifier.contains(candidate, prereleases=True)
|
|
|
|
|
|
def _looks_like_local_path_reference(token: str) -> bool:
    """Return whether *token* is written like a filesystem path.

    After trimming whitespace, a token qualifies when it is exactly ``.`` or
    ``..`` or starts with one of the relative, absolute or home-relative
    prefixes listed below (POSIX and backslash forms).
    """
    stripped = token.strip()
    if stripped in {".", ".."}:
        return True
    path_prefixes = ("./", "../", "/", "~/", ".\\", "..\\", "\\")
    # An empty string matches none of the prefixes, so it falls out as False.
    return stripped.startswith(path_prefixes)
|
|
|
|
|
|
def looks_like_direct_reference(token: str) -> bool:
    """Return whether *token* points at a location rather than a package name.

    True for local-path-looking tokens, tokens starting with ``git+``, and
    tokens containing ``://``. Blank tokens are never direct references.
    """
    stripped = token.strip()
    if not stripped:
        return False
    if _looks_like_local_path_reference(stripped):
        return True
    if stripped.startswith("git+"):
        return True
    return "://" in stripped
|
|
|
|
|
|
def extract_requirement_name(raw_requirement: str) -> str | None:
    """Best-effort extraction of a canonical distribution name from a line.

    Returns None for blank or comment-only input, for ``-r``/``-c`` style
    include lines, and for other option lines. A ``#egg=<name>`` fragment
    takes precedence over the leading text of the line.
    """
    # Everything from the first "#" onwards is ignored for the leading-text checks.
    line, _, _ = raw_requirement.partition("#")
    line = line.strip()
    if not line or line.startswith(("-r", "--requirement", "-c", "--constraint")):
        return None

    # The egg fragment lives after "#", so search the untouched input.
    egg = re.search(r"#egg=([A-Za-z0-9_.-]+)", raw_requirement)
    if egg is not None:
        return canonicalize_distribution_name(egg.group(1))

    if line.startswith("-"):
        return None

    # The name ends at the first specifier, marker, extras or whitespace character.
    head = re.split(r"[<>=!~;\s\[]", line, maxsplit=1)[0].strip()
    return canonicalize_distribution_name(head) if head else None
|
|
|
|
|
|
def _parse_editable_or_direct_name(target: str) -> str | None:
    """Return the distribution name for *target* when it can be trusted.

    A name is trusted when *target* carries an explicit ``#egg=`` fragment or
    does not look like a direct reference; otherwise None is returned.
    """
    name = extract_requirement_name(target)
    if name and ("#egg=" in target or not looks_like_direct_reference(target)):
        return name
    return None
|
|
|
|
|
|
def _parse_requirement_name_and_spec(
    line: str,
) -> tuple[str | None, SpecifierSet | None]:
    """Parse one requirement line into ``(canonical name, specifier)``.

    Returns ``(None, None)`` for constraint lines, for requirements whose
    environment marker evaluates false, and for lines from which no name can
    be derived. The specifier is None when the requirement has no version
    constraint or when the line had to be parsed heuristically.
    """
    if line.startswith(("-c", "--constraint")):
        return None, None

    try:
        requirement = Requirement(line)
    except InvalidRequirement:
        # Not a plain requirement: fall back to shell-style tokenisation.
        tokens = shlex.split(line)
        if not tokens:
            return None, None

        first = tokens[0]
        editable_target: str | None = None
        if first in {"-e", "--editable"} and len(tokens) > 1:
            editable_target = tokens[1]
        elif first.startswith("--editable="):
            editable_target = first.split("=", 1)[1]

        # Without a non-empty editable target, the whole line is examined.
        name = _parse_editable_or_direct_name(editable_target or line)
        return (name or None), None
    else:
        marker = requirement.marker
        if marker and not marker.evaluate():
            return None, None
        canonical = canonicalize_distribution_name(requirement.name)
        return canonical, (requirement.specifier or None)
|
|
|
|
|
|
def _parse_requirement_line(
    line: str,
) -> tuple[str, SpecifierSet | None] | None:
    """Return ``(name, specifier)`` for *line*, or None when no name was found."""
    name, specifier = _parse_requirement_name_and_spec(line)
    if not name:
        return None
    return name, specifier
|
|
|
|
|
|
def _extract_requirement_names_from_package_tokens(tokens: list[str]) -> frozenset[str]:
    """Collect canonical requirement names from pip-style command tokens.

    Options and their values are skipped; the target of ``-e``/``--editable``
    (separate or ``=``-joined) and every non-option token are inspected for a
    distribution name.
    """
    # Options whose value arrives as the following, separate token.
    options_taking_value = {
        "-i",
        "--index-url",
        "--extra-index-url",
        "-f",
        "--find-links",
        "--trusted-host",
        "-r",
        "--requirement",
        "-c",
        "--constraint",
    }

    found: set[str] = set()
    pending: str | None = None

    for token in tokens:
        if pending:
            # This token is the value of the previous option.
            if pending == "editable":
                editable_name = _parse_editable_or_direct_name(token)
                if editable_name:
                    found.add(editable_name)
            pending = None
        elif token in {"-e", "--editable"}:
            pending = "editable"
        elif token in options_taking_value:
            pending = "option-value"
        elif token.startswith("--editable="):
            editable_name = _parse_editable_or_direct_name(token.split("=", 1)[1])
            if editable_name:
                found.add(editable_name)
        elif token.startswith("-"):
            # Every other option form (``--index-url=...``, glued ``-i<url>`` /
            # ``-f<url>``, ``--no-index``, unknown flags) carries no name.
            continue
        else:
            name, _ = _parse_requirement_name_and_spec(token)
            if name:
                found.add(name)

    return frozenset(found)
|
|
|
|
|
|
def parse_package_install_input(raw_input: str) -> ParsedPackageInput:
    """Parse multi-line package install input into specs and requirement names.

    Each line has its comment stripped. A line that is a valid requirement is
    kept whole as one spec; any other line is shell-split and its tokens are
    added as individual specs. Blank input yields an empty result.
    """
    collected_specs: list[str] = []
    names: set[str] = set()

    for raw_line in raw_input.strip().splitlines():
        line = strip_inline_requirement_comment(raw_line)
        if not line:
            continue

        try:
            Requirement(line)
        except InvalidRequirement:
            # Treat the line as command-line style arguments.
            tokens = shlex.split(line)
            if tokens:
                collected_specs.extend(tokens)
                names |= _extract_requirement_names_from_package_tokens(tokens)
        else:
            collected_specs.append(line)
            name, _ = _parse_requirement_name_and_spec(line)
            if name:
                names.add(name)

    return ParsedPackageInput(
        specs=tuple(collected_specs),
        requirement_names=frozenset(names),
    )
|
|
|
|
|
|
def _iter_requirement_lines(
    requirements_path: str,
    _visited: set[str] | None = None,
) -> Iterator[str]:
    """Yield the effective lines of a requirements file, expanding includes.

    Comments and blank lines are dropped. ``-r <file>``, ``--requirement <file>``
    and ``--requirement=<file>`` lines are replaced by the lines of the
    referenced file; relative include paths are resolved against the directory
    of the including file. A file that was already visited is skipped with a
    warning so include cycles terminate.

    Args:
        requirements_path: Path of the requirements file to read.
        _visited: Set of resolved paths already read. It is updated in place
            and shared with nested includes.

    Yields:
        Each remaining line with its comment stripped.

    Errors from opening the file or from ``shlex.split`` are not caught here.
    """
    # Test against None, not truthiness: an empty set supplied by the caller is
    # falsy and would otherwise be replaced, so the caller's set would never
    # receive the visited paths.
    visited = set() if _visited is None else _visited
    resolved_path = os.path.realpath(requirements_path)
    if resolved_path in visited:
        logger.warning(
            "检测到循环依赖的 requirements 包含: %s,将跳过该文件", resolved_path
        )
        return
    visited.add(resolved_path)

    with open(resolved_path, encoding="utf-8") as f:
        for raw_line in f:
            line = strip_inline_requirement_comment(raw_line)
            if not line:
                continue

            tokens = shlex.split(line)
            if not tokens:
                continue

            nested: str | None = None
            if tokens[0] in {"-r", "--requirement"} and len(tokens) > 1:
                nested = tokens[1]
            elif tokens[0].startswith("--requirement="):
                nested = tokens[0].split("=", 1)[1]

            if nested:
                if not os.path.isabs(nested):
                    # Includes are relative to the including file, not the CWD.
                    nested = os.path.join(os.path.dirname(resolved_path), nested)
                yield from _iter_requirement_lines(nested, _visited=visited)
                continue

            yield line
|
|
|
|
|
|
def iter_requirements(
    requirements_path: str | None = None,
    lines: Iterable[str] | None = None,
) -> Iterator[tuple[str, SpecifierSet | None]]:
    """Yield ``(name, specifier)`` for every parseable requirement.

    Requirements come from *lines* when given, otherwise from the file at
    *requirements_path*. Lines that yield no name are skipped.

    Raises:
        ValueError: If neither argument is provided. Because this is a
            generator, the error surfaces on first iteration.
    """
    source = lines
    if source is None:
        if requirements_path is None:
            raise ValueError("Either requirements_path or lines must be provided")
        source = _iter_requirement_lines(requirements_path)

    for raw in source:
        if (parsed := _parse_requirement_line(raw)) is not None:
            yield parsed
|
|
|
|
|
|
def extract_requirement_names(requirements_path: str) -> set[str]:
    """Return the canonical requirement names listed in a requirements file.

    Best-effort: any failure while reading or parsing is logged as a warning
    and an empty set is returned.
    """
    names: set[str] = set()
    try:
        for name, _ in iter_requirements(requirements_path=requirements_path):
            names.add(name)
    except Exception as exc:
        logger.warning("读取依赖文件失败,跳过冲突检测: %s", exc)
        return set()
    return names
|
|
|
|
|
|
def get_requirement_check_paths() -> list[str]:
    """Return the directories searched for installed distributions.

    The result is a copy of ``sys.path``. In a packaged desktop runtime the
    AstrBot site-packages directory is placed first, provided it exists.
    """
    search_paths = [*sys.path]
    if not is_packaged_desktop_runtime():
        return search_paths

    bundled_site_packages = get_astrbot_site_packages_path()
    if os.path.isdir(bundled_site_packages):
        search_paths.insert(0, bundled_site_packages)
    return search_paths
|
|
|
|
|
|
def _canonical_distribution_identity(distribution) -> tuple[str | None, str | None]:
    """Return ``(canonical name, version)`` for an installed distribution.

    Yields ``(None, None)`` when the metadata has no usable ``Name`` entry.
    """
    metadata = distribution.metadata
    raw_name = metadata["Name"] if "Name" in metadata else None
    if not raw_name:
        return None, None
    return canonicalize_distribution_name(raw_name), distribution.version
|
|
|
|
|
|
def collect_installed_distribution_versions(paths: list[str]) -> dict[str, str] | None:
    """Map canonical distribution names to versions found under *paths*.

    When a name occurs more than once, the first occurrence wins.
    Distributions without a name or version are ignored. Returns None, after
    logging a warning, if enumeration fails for any reason.
    """
    versions_by_name: dict[str, str] = {}
    try:
        for dist in importlib_metadata.distributions(path=paths):
            name, version = _canonical_distribution_identity(dist)
            if name and version:
                versions_by_name.setdefault(name, version)
    except Exception as exc:
        logger.warning("读取已安装依赖失败,跳过缺失依赖预检查: %s", exc)
        return None
    return versions_by_name
|
|
|
|
|
|
def _load_requirement_lines_for_precheck(
    requirements_path: str,
) -> tuple[bool, list[str] | None]:
    """Load requirement lines and decide whether a pre-check is possible.

    Returns ``(True, lines)`` when every line can be reasoned about, and
    ``(False, None)`` when the file cannot be read or contains a line that
    forces a full install: an editable line without ``#egg=``, or an
    unparseable line that looks like a direct reference.
    """
    try:
        requirement_lines = list(_iter_requirement_lines(requirements_path))
    except Exception as exc:
        logger.warning(
            "预检查缺失依赖失败,将回退到完整安装: %s (%s)",
            requirements_path,
            exc,
        )
        return False, None

    for line in requirement_lines:
        unnamed_editable = (
            line.startswith(("-e ", "--editable ", "--editable="))
            and "#egg=" not in line
        )
        if unnamed_editable or (
            _parse_requirement_line(line) is None and looks_like_direct_reference(line)
        ):
            # The first offending line is enough to abandon the pre-check.
            logger.info(
                "缺失依赖预检查发现无法安全裁剪的 option/direct-reference 行,将回退到完整安装: %s (%s)",
                requirements_path,
                line,
            )
            return False, None

    return True, requirement_lines
|
|
|
|
|
|
def find_missing_requirements(requirements_path: str) -> set[str] | None:
    """Return names of unmet requirements in the file, or None if unknown.

    None means the pre-check could not be performed.
    """
    can_precheck, requirement_lines = _load_requirement_lines_for_precheck(
        requirements_path
    )
    if can_precheck and requirement_lines is not None:
        return find_missing_requirements_from_lines(requirement_lines)
    return None
|
|
|
|
|
|
def find_missing_requirements_from_lines(
    requirement_lines: Sequence[str],
) -> set[str] | None:
    """Return the names in *requirement_lines* not satisfied by installed packages.

    A requirement is missing when no distribution of that name is installed or
    the installed version falls outside its specifier. Returns an empty set
    when nothing is required, and None when installed packages cannot be read.
    """
    required = list(iter_requirements(lines=requirement_lines))
    if not required:
        return set()

    installed = collect_installed_distribution_versions(get_requirement_check_paths())
    if installed is None:
        return None

    missing: set[str] = set()
    for name, specifier in required:
        current = installed.get(name)
        satisfied = bool(current) and (
            not specifier or _specifier_contains_version(specifier, current)
        )
        if not satisfied:
            missing.add(name)
    return missing
|
|
|
|
|
|
def build_missing_requirements_install_lines(
    requirements_path: str,
    requirement_lines: Sequence[str],
    missing_names: set[str] | frozenset[str],
) -> tuple[str, ...] | None:
    """Select the requirement lines that install the missing names.

    Lines are kept in their original order. Returns None when an unparseable
    line is an option or looks like a direct reference, because the file then
    cannot be trimmed safely. *requirements_path* is used only for logging.
    """
    wanted = set(missing_names)
    selected: list[str] = []

    for line in requirement_lines:
        parsed = _parse_requirement_line(line)
        if parsed is not None:
            if parsed[0] in wanted:
                selected.append(line)
            continue

        # Unparseable line: harmless unless it is an option or direct reference.
        if looks_like_direct_reference(line) or line.startswith("-"):
            logger.debug(
                "缺失依赖行筛选回退到完整安装:requirements 中包含无法安全裁剪的 option/direct-reference 行: %s (%s)",
                requirements_path,
                line,
            )
            return None

    return tuple(selected)
|
|
|
|
|
|
def plan_missing_requirements_install(
    requirements_path: str,
) -> MissingRequirementsPlan | None:
    """Build an install plan covering only the missing requirements.

    Returns None whenever the pre-check cannot be completed. If names are
    missing but no requirement line maps to them, the plan has empty
    ``install_lines`` and a ``fallback_reason``.
    """
    can_precheck, requirement_lines = _load_requirement_lines_for_precheck(
        requirements_path
    )
    if not can_precheck or requirement_lines is None:
        return None

    missing = find_missing_requirements_from_lines(requirement_lines)
    if missing is None:
        return None

    install_lines = build_missing_requirements_install_lines(
        requirements_path,
        requirement_lines,
        missing,
    )
    if install_lines is None:
        return None

    fallback_reason: str | None = None
    if missing and not install_lines:
        logger.warning(
            "预检查缺失依赖成功,但无法映射到可安装 requirement 行,将回退到完整安装: %s -> %s",
            requirements_path,
            sorted(missing),
        )
        fallback_reason = "unmapped missing requirement names"

    # In the fallback case install_lines is already the empty tuple.
    return MissingRequirementsPlan(
        missing_names=frozenset(missing),
        install_lines=install_lines,
        fallback_reason=fallback_reason,
    )
|
|
|
|
|
|
def find_missing_requirements_or_raise(requirements_path: str) -> set[str]:
    """Return the missing requirement names, raising if the pre-check fails.

    Raises:
        RequirementsPrecheckFailed: If the missing set could not be determined.
    """
    missing = find_missing_requirements(requirements_path)
    if missing is not None:
        return missing
    raise RequirementsPrecheckFailed(f"预检查失败: {requirements_path}")
|