Compare commits
7 Commits
v4.19.5
...
feat/sendtool
| Author | SHA1 | Date | |
|---|---|---|---|
| 8faaa4b2be | |||
| 7f5bd942b3 | |||
| e254caf82d | |||
| 7efcd242d6 | |||
| 5d811d3949 | |||
| 8e6aaee10c | |||
| 6da59cfb07 |
@@ -61,3 +61,5 @@ GenieData/
|
|||||||
.codex/
|
.codex/
|
||||||
.opencode/
|
.opencode/
|
||||||
.kilocode/
|
.kilocode/
|
||||||
|
.worktrees/
|
||||||
|
docs/plans/
|
||||||
|
|||||||
@@ -234,7 +234,8 @@ pre-commit install
|
|||||||
- Group 7: 743746109
|
- Group 7: 743746109
|
||||||
- Group 8: 1030353265
|
- Group 8: 1030353265
|
||||||
|
|
||||||
- Developer Group: 975206796
|
- Developer Group(Chit-chat): 975206796
|
||||||
|
- Developer Group(Formal): 1039761811
|
||||||
|
|
||||||
### Discord Server
|
### Discord Server
|
||||||
|
|
||||||
|
|||||||
@@ -222,6 +222,7 @@ pre-commit install
|
|||||||
- Groupe 5 : 822130018
|
- Groupe 5 : 822130018
|
||||||
- Groupe 6 : 753075035
|
- Groupe 6 : 753075035
|
||||||
- Groupe développeurs : 975206796
|
- Groupe développeurs : 975206796
|
||||||
|
- Groupe développeurs (officiel) : 1039761811
|
||||||
|
|
||||||
### Serveur Discord
|
### Serveur Discord
|
||||||
|
|
||||||
|
|||||||
@@ -223,6 +223,7 @@ pre-commit install
|
|||||||
- 5群: 822130018
|
- 5群: 822130018
|
||||||
- 6群: 753075035
|
- 6群: 753075035
|
||||||
- 開発者群: 975206796
|
- 開発者群: 975206796
|
||||||
|
- 開発者群(正式): 1039761811
|
||||||
|
|
||||||
### Discord サーバー
|
### Discord サーバー
|
||||||
|
|
||||||
|
|||||||
@@ -222,6 +222,7 @@ pre-commit install
|
|||||||
- Группа 5: 822130018
|
- Группа 5: 822130018
|
||||||
- Группа 6: 753075035
|
- Группа 6: 753075035
|
||||||
- Группа разработчиков: 975206796
|
- Группа разработчиков: 975206796
|
||||||
|
- Группа разработчиков (официальная): 1039761811
|
||||||
|
|
||||||
### Сервер Discord
|
### Сервер Discord
|
||||||
|
|
||||||
|
|||||||
+2
-1
@@ -225,7 +225,8 @@ pre-commit install
|
|||||||
- 6 群:753075035
|
- 6 群:753075035
|
||||||
- 7 群:743746109
|
- 7 群:743746109
|
||||||
- 8 群:1030353265
|
- 8 群:1030353265
|
||||||
- 開發者群:975206796
|
- 開發者群(闲聊吹水):975206796
|
||||||
|
- 開發者群(正式):1039761811
|
||||||
|
|
||||||
### Discord 群組
|
### Discord 群組
|
||||||
|
|
||||||
|
|||||||
+2
-1
@@ -226,7 +226,8 @@ pre-commit install
|
|||||||
- 6 群:753075035
|
- 6 群:753075035
|
||||||
- 7 群:743746109
|
- 7 群:743746109
|
||||||
- 8 群:1030353265
|
- 8 群:1030353265
|
||||||
- 开发者群:975206796
|
- 开发者群(偏闲聊吹水):975206796
|
||||||
|
- 开发者群(正式):1039761811
|
||||||
|
|
||||||
### Discord 频道
|
### Discord 频道
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,21 @@ from astrbot.core.config import AstrBotConfig
|
|||||||
from astrbot.core.config.default import DB_PATH
|
from astrbot.core.config.default import DB_PATH
|
||||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||||
from astrbot.core.file_token_service import FileTokenService
|
from astrbot.core.file_token_service import FileTokenService
|
||||||
from astrbot.core.utils.pip_installer import PipInstaller
|
from astrbot.core.utils.pip_installer import (
|
||||||
|
DependencyConflictError as DependencyConflictError,
|
||||||
|
)
|
||||||
|
from astrbot.core.utils.pip_installer import (
|
||||||
|
PipInstaller,
|
||||||
|
)
|
||||||
|
from astrbot.core.utils.requirements_utils import (
|
||||||
|
RequirementsPrecheckFailed as RequirementsPrecheckFailed,
|
||||||
|
)
|
||||||
|
from astrbot.core.utils.requirements_utils import (
|
||||||
|
find_missing_requirements as find_missing_requirements,
|
||||||
|
)
|
||||||
|
from astrbot.core.utils.requirements_utils import (
|
||||||
|
find_missing_requirements_or_raise as find_missing_requirements_or_raise,
|
||||||
|
)
|
||||||
from astrbot.core.utils.shared_preferences import SharedPreferences
|
from astrbot.core.utils.shared_preferences import SharedPreferences
|
||||||
from astrbot.core.utils.t2i.renderer import HtmlRenderer
|
from astrbot.core.utils.t2i.renderer import HtmlRenderer
|
||||||
|
|
||||||
|
|||||||
@@ -204,7 +204,7 @@ class SendMessageToUserTool(FunctionTool[AstrAgentContext]):
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
"description": (
|
"description": (
|
||||||
"Component type. One of: "
|
"Component type. One of: "
|
||||||
"plain, image, record, file, mention_user"
|
"plain, image, record, video, file, mention_user. Record is voice message."
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
"text": {
|
"text": {
|
||||||
@@ -320,6 +320,19 @@ class SendMessageToUserTool(FunctionTool[AstrAgentContext]):
|
|||||||
components.append(Comp.Record.fromURL(url=url))
|
components.append(Comp.Record.fromURL(url=url))
|
||||||
else:
|
else:
|
||||||
return f"error: messages[{idx}] must include path or url for record component."
|
return f"error: messages[{idx}] must include path or url for record component."
|
||||||
|
elif msg_type == "video":
|
||||||
|
path = msg.get("path")
|
||||||
|
url = msg.get("url")
|
||||||
|
if path:
|
||||||
|
(
|
||||||
|
local_path,
|
||||||
|
file_from_sandbox,
|
||||||
|
) = await self._resolve_path_from_sandbox(context, path)
|
||||||
|
components.append(Comp.Video.fromFileSystem(path=local_path))
|
||||||
|
elif url:
|
||||||
|
components.append(Comp.Video.fromURL(url=url))
|
||||||
|
else:
|
||||||
|
return f"error: messages[{idx}] must include path or url for video component."
|
||||||
elif msg_type == "file":
|
elif msg_type == "file":
|
||||||
path = msg.get("path")
|
path = msg.get("path")
|
||||||
url = msg.get("url")
|
url = msg.get("url")
|
||||||
|
|||||||
@@ -422,6 +422,12 @@ async def get_booter(
|
|||||||
) -> ComputerBooter:
|
) -> ComputerBooter:
|
||||||
config = context.get_config(umo=session_id)
|
config = context.get_config(umo=session_id)
|
||||||
|
|
||||||
|
runtime = config.get("provider_settings", {}).get("computer_use_runtime", "local")
|
||||||
|
if runtime == "local":
|
||||||
|
return get_local_booter()
|
||||||
|
elif runtime == "none":
|
||||||
|
raise RuntimeError("Sandbox runtime is disabled by configuration.")
|
||||||
|
|
||||||
sandbox_cfg = config.get("provider_settings", {}).get("sandbox", {})
|
sandbox_cfg = config.get("provider_settings", {}).get("sandbox", {})
|
||||||
booter_type = sandbox_cfg.get("booter", "shipyard_neo")
|
booter_type = sandbox_cfg.get("booter", "shipyard_neo")
|
||||||
|
|
||||||
|
|||||||
@@ -219,6 +219,9 @@ DEFAULT_CONFIG = {
|
|||||||
"telegram": {
|
"telegram": {
|
||||||
"pre_ack_emoji": {"enable": False, "emojis": ["✍️"]},
|
"pre_ack_emoji": {"enable": False, "emojis": ["✍️"]},
|
||||||
},
|
},
|
||||||
|
"discord": {
|
||||||
|
"pre_ack_emoji": {"enable": False, "emojis": ["🤔"]},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"wake_prefix": ["/"],
|
"wake_prefix": ["/"],
|
||||||
"log_level": "INFO",
|
"log_level": "INFO",
|
||||||
|
|||||||
@@ -14,7 +14,12 @@ import yaml
|
|||||||
from packaging.specifiers import InvalidSpecifier, SpecifierSet
|
from packaging.specifiers import InvalidSpecifier, SpecifierSet
|
||||||
from packaging.version import InvalidVersion, Version
|
from packaging.version import InvalidVersion, Version
|
||||||
|
|
||||||
from astrbot.core import logger, pip_installer, sp
|
from astrbot.core import (
|
||||||
|
DependencyConflictError,
|
||||||
|
logger,
|
||||||
|
pip_installer,
|
||||||
|
sp,
|
||||||
|
)
|
||||||
from astrbot.core.agent.handoff import FunctionTool, HandoffTool
|
from astrbot.core.agent.handoff import FunctionTool, HandoffTool
|
||||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||||
from astrbot.core.config.default import VERSION
|
from astrbot.core.config.default import VERSION
|
||||||
@@ -27,6 +32,10 @@ from astrbot.core.utils.astrbot_path import (
|
|||||||
)
|
)
|
||||||
from astrbot.core.utils.io import remove_dir
|
from astrbot.core.utils.io import remove_dir
|
||||||
from astrbot.core.utils.metrics import Metric
|
from astrbot.core.utils.metrics import Metric
|
||||||
|
from astrbot.core.utils.requirements_utils import (
|
||||||
|
RequirementsPrecheckFailed,
|
||||||
|
find_missing_requirements_or_raise,
|
||||||
|
)
|
||||||
|
|
||||||
from . import StarMetadata
|
from . import StarMetadata
|
||||||
from .command_management import sync_command_configs
|
from .command_management import sync_command_configs
|
||||||
@@ -48,6 +57,49 @@ class PluginVersionIncompatibleError(Exception):
|
|||||||
"""Raised when plugin astrbot_version is incompatible with current AstrBot."""
|
"""Raised when plugin astrbot_version is incompatible with current AstrBot."""
|
||||||
|
|
||||||
|
|
||||||
|
class PluginDependencyInstallError(Exception):
|
||||||
|
"""Raised when plugin dependency installation fails."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
plugin_label: str,
|
||||||
|
requirements_path: str,
|
||||||
|
error: Exception,
|
||||||
|
) -> None:
|
||||||
|
message = f"插件 {plugin_label} 依赖安装失败: {error!s}"
|
||||||
|
super().__init__(message)
|
||||||
|
self.plugin_label = plugin_label
|
||||||
|
self.requirements_path = requirements_path
|
||||||
|
self.error = error
|
||||||
|
|
||||||
|
|
||||||
|
async def _install_requirements_with_precheck(
|
||||||
|
*,
|
||||||
|
plugin_label: str,
|
||||||
|
requirements_path: str,
|
||||||
|
) -> None:
|
||||||
|
try:
|
||||||
|
missing = find_missing_requirements_or_raise(requirements_path)
|
||||||
|
except RequirementsPrecheckFailed:
|
||||||
|
logger.info(
|
||||||
|
f"正在安装插件 {plugin_label} 的依赖库(预检查失败,回退到完整安装): "
|
||||||
|
f"{requirements_path}"
|
||||||
|
)
|
||||||
|
await pip_installer.install(requirements_path=requirements_path)
|
||||||
|
return
|
||||||
|
|
||||||
|
if not missing:
|
||||||
|
logger.info(f"插件 {plugin_label} 的依赖已满足,跳过安装。")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"检测到插件 {plugin_label} 缺失依赖,正在按 requirements.txt 安装: "
|
||||||
|
f"{requirements_path} -> {sorted(missing)}"
|
||||||
|
)
|
||||||
|
await pip_installer.install(requirements_path=requirements_path)
|
||||||
|
|
||||||
|
|
||||||
class PluginManager:
|
class PluginManager:
|
||||||
def __init__(self, context: Context, config: AstrBotConfig) -> None:
|
def __init__(self, context: Context, config: AstrBotConfig) -> None:
|
||||||
from .star_tools import StarTools
|
from .star_tools import StarTools
|
||||||
@@ -198,15 +250,37 @@ class PluginManager:
|
|||||||
to_update.append(p.root_dir_name)
|
to_update.append(p.root_dir_name)
|
||||||
for p in to_update:
|
for p in to_update:
|
||||||
plugin_path = os.path.join(plugin_dir, p)
|
plugin_path = os.path.join(plugin_dir, p)
|
||||||
if os.path.exists(os.path.join(plugin_path, "requirements.txt")):
|
await self._ensure_plugin_requirements(plugin_path, p)
|
||||||
pth = os.path.join(plugin_path, "requirements.txt")
|
|
||||||
logger.info(f"正在安装插件 {p} 所需的依赖库: {pth}")
|
|
||||||
try:
|
|
||||||
await pip_installer.install(requirements_path=pth)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"更新插件 {p} 的依赖失败。Code: {e!s}")
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
async def _ensure_plugin_requirements(
|
||||||
|
self,
|
||||||
|
plugin_dir_path: str,
|
||||||
|
plugin_label: str,
|
||||||
|
) -> None:
|
||||||
|
requirements_path = os.path.join(plugin_dir_path, "requirements.txt")
|
||||||
|
if not os.path.exists(requirements_path):
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
await _install_requirements_with_precheck(
|
||||||
|
plugin_label=plugin_label,
|
||||||
|
requirements_path=requirements_path,
|
||||||
|
)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
except DependencyConflictError as e:
|
||||||
|
logger.error(f"插件 {plugin_label} 依赖冲突: {e!s}")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
dependency_error = PluginDependencyInstallError(
|
||||||
|
plugin_label=plugin_label,
|
||||||
|
requirements_path=requirements_path,
|
||||||
|
error=e,
|
||||||
|
)
|
||||||
|
logger.exception(str(dependency_error))
|
||||||
|
raise dependency_error from e
|
||||||
|
|
||||||
async def _import_plugin_with_dependency_recovery(
|
async def _import_plugin_with_dependency_recovery(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
@@ -422,7 +496,7 @@ class PluginManager:
|
|||||||
root_dir_name: str,
|
root_dir_name: str,
|
||||||
plugin_dir_path: str,
|
plugin_dir_path: str,
|
||||||
reserved: bool,
|
reserved: bool,
|
||||||
error: Exception | str,
|
error: BaseException | str,
|
||||||
error_trace: str,
|
error_trace: str,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
record: dict = {
|
record: dict = {
|
||||||
@@ -495,6 +569,9 @@ class PluginManager:
|
|||||||
|
|
||||||
self._cleanup_plugin_state(dir_name)
|
self._cleanup_plugin_state(dir_name)
|
||||||
|
|
||||||
|
plugin_path = os.path.join(self.plugin_store_path, dir_name)
|
||||||
|
await self._ensure_plugin_requirements(plugin_path, dir_name)
|
||||||
|
|
||||||
success, error = await self.load(specified_dir_name=dir_name)
|
success, error = await self.load(specified_dir_name=dir_name)
|
||||||
if success:
|
if success:
|
||||||
self.failed_plugin_dict.pop(dir_name, None)
|
self.failed_plugin_dict.pop(dir_name, None)
|
||||||
@@ -1078,6 +1155,10 @@ class PluginManager:
|
|||||||
|
|
||||||
# reload the plugin
|
# reload the plugin
|
||||||
dir_name = os.path.basename(plugin_path)
|
dir_name = os.path.basename(plugin_path)
|
||||||
|
await self._ensure_plugin_requirements(
|
||||||
|
plugin_path,
|
||||||
|
dir_name,
|
||||||
|
)
|
||||||
success, error_message = await self.load(
|
success, error_message = await self.load(
|
||||||
specified_dir_name=dir_name,
|
specified_dir_name=dir_name,
|
||||||
ignore_version_check=ignore_version_check,
|
ignore_version_check=ignore_version_check,
|
||||||
@@ -1317,6 +1398,12 @@ class PluginManager:
|
|||||||
raise Exception("该插件是 AstrBot 保留插件,无法更新。")
|
raise Exception("该插件是 AstrBot 保留插件,无法更新。")
|
||||||
|
|
||||||
await self.updator.update(plugin, proxy=proxy)
|
await self.updator.update(plugin, proxy=proxy)
|
||||||
|
if plugin.root_dir_name:
|
||||||
|
plugin_dir_path = os.path.join(self.plugin_store_path, plugin.root_dir_name)
|
||||||
|
await self._ensure_plugin_requirements(
|
||||||
|
plugin_dir_path,
|
||||||
|
plugin_name,
|
||||||
|
)
|
||||||
await self.reload(plugin_name)
|
await self.reload(plugin_name)
|
||||||
|
|
||||||
async def turn_off_plugin(self, plugin_name: str) -> None:
|
async def turn_off_plugin(self, plugin_name: str) -> None:
|
||||||
@@ -1488,6 +1575,7 @@ class PluginManager:
|
|||||||
os.remove(zip_file_path)
|
os.remove(zip_file_path)
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
logger.warning(f"删除插件压缩包失败: {e!s}")
|
logger.warning(f"删除插件压缩包失败: {e!s}")
|
||||||
|
await self._ensure_plugin_requirements(desti_dir, dir_name)
|
||||||
# await self.reload()
|
# await self.reload()
|
||||||
success, error_message = await self.load(
|
success, error_message = await self.load(
|
||||||
specified_dir_name=dir_name,
|
specified_dir_name=dir_name,
|
||||||
|
|||||||
@@ -0,0 +1,121 @@
|
|||||||
|
import contextlib
|
||||||
|
import functools
|
||||||
|
import importlib.metadata as importlib_metadata
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from collections.abc import Iterator
|
||||||
|
|
||||||
|
from packaging.requirements import Requirement
|
||||||
|
|
||||||
|
from astrbot.core.utils.requirements_utils import (
|
||||||
|
canonicalize_distribution_name,
|
||||||
|
collect_installed_distribution_versions,
|
||||||
|
get_requirement_check_paths,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger("astrbot")
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_core_dist_name(core_dist_name: str | None) -> str | None:
|
||||||
|
if core_dist_name:
|
||||||
|
try:
|
||||||
|
importlib_metadata.distribution(core_dist_name)
|
||||||
|
return core_dist_name
|
||||||
|
except importlib_metadata.PackageNotFoundError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
importlib_metadata.distribution("AstrBot")
|
||||||
|
return "AstrBot"
|
||||||
|
except importlib_metadata.PackageNotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if not __package__:
|
||||||
|
return None
|
||||||
|
|
||||||
|
top_pkg = __package__.split(".")[0]
|
||||||
|
for dist in importlib_metadata.distributions():
|
||||||
|
try:
|
||||||
|
top_level = dist.read_text("top_level.txt") or ""
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
if top_pkg in top_level.splitlines():
|
||||||
|
if "Name" in dist.metadata:
|
||||||
|
return dist.metadata["Name"]
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
|
||||||
|
def _get_core_constraints(core_dist_name: str | None) -> tuple[str, ...]:
|
||||||
|
try:
|
||||||
|
resolved_core_dist_name = _resolve_core_dist_name(core_dist_name)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("解析核心分发名称失败: %s", exc)
|
||||||
|
return ()
|
||||||
|
|
||||||
|
if not resolved_core_dist_name:
|
||||||
|
return ()
|
||||||
|
|
||||||
|
try:
|
||||||
|
dist = importlib_metadata.distribution(resolved_core_dist_name)
|
||||||
|
except importlib_metadata.PackageNotFoundError:
|
||||||
|
return ()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("读取核心分发元数据失败 (%s): %s", resolved_core_dist_name, exc)
|
||||||
|
return ()
|
||||||
|
|
||||||
|
if not dist or not dist.requires:
|
||||||
|
return ()
|
||||||
|
|
||||||
|
installed = collect_installed_distribution_versions(get_requirement_check_paths())
|
||||||
|
if not installed:
|
||||||
|
return ()
|
||||||
|
|
||||||
|
constraints: list[str] = []
|
||||||
|
for req_str in dist.requires:
|
||||||
|
try:
|
||||||
|
req = Requirement(req_str)
|
||||||
|
if req.marker and not req.marker.evaluate():
|
||||||
|
continue
|
||||||
|
name = canonicalize_distribution_name(req.name)
|
||||||
|
if name in installed:
|
||||||
|
constraints.append(f"{name}=={installed[name]}")
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
return tuple(constraints)
|
||||||
|
|
||||||
|
|
||||||
|
class CoreConstraintsProvider:
|
||||||
|
def __init__(self, core_dist_name: str | None) -> None:
|
||||||
|
self._core_dist_name = core_dist_name
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def constraints_file(self) -> Iterator[str | None]:
|
||||||
|
constraints = _get_core_constraints(self._core_dist_name)
|
||||||
|
if not constraints:
|
||||||
|
yield None
|
||||||
|
return
|
||||||
|
|
||||||
|
path: str | None = None
|
||||||
|
try:
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(
|
||||||
|
mode="w", suffix="_constraints.txt", delete=False, encoding="utf-8"
|
||||||
|
) as f:
|
||||||
|
f.write("\n".join(constraints))
|
||||||
|
path = f.name
|
||||||
|
logger.info("已启用核心依赖版本保护 (%d 个约束)", len(constraints))
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("创建临时约束文件失败: %s", exc)
|
||||||
|
yield None
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield path
|
||||||
|
finally:
|
||||||
|
if path and os.path.exists(path):
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
os.remove(path)
|
||||||
@@ -7,21 +7,71 @@ import io
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import shlex
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from astrbot.core.utils.astrbot_path import get_astrbot_site_packages_path
|
from astrbot.core.utils.astrbot_path import get_astrbot_site_packages_path
|
||||||
|
from astrbot.core.utils.core_constraints import CoreConstraintsProvider
|
||||||
|
from astrbot.core.utils.requirements_utils import (
|
||||||
|
canonicalize_distribution_name as _canonicalize_distribution_name,
|
||||||
|
)
|
||||||
|
from astrbot.core.utils.requirements_utils import (
|
||||||
|
extract_requirement_name,
|
||||||
|
extract_requirement_names,
|
||||||
|
parse_package_install_input,
|
||||||
|
)
|
||||||
from astrbot.core.utils.runtime_env import is_packaged_desktop_runtime
|
from astrbot.core.utils.runtime_env import is_packaged_desktop_runtime
|
||||||
|
|
||||||
logger = logging.getLogger("astrbot")
|
logger = logging.getLogger("astrbot")
|
||||||
|
|
||||||
_DISTLIB_FINDER_PATCH_ATTEMPTED = False
|
_DISTLIB_FINDER_PATCH_ATTEMPTED = False
|
||||||
_SITE_PACKAGES_IMPORT_LOCK = threading.RLock()
|
_SITE_PACKAGES_IMPORT_LOCK = threading.RLock()
|
||||||
|
_PIP_FAILURE_PATTERNS = {
|
||||||
|
"error_prefix": re.compile(r"^\s*error:", re.IGNORECASE),
|
||||||
|
"user_requested": re.compile(r"\bthe user requested\b", re.IGNORECASE),
|
||||||
|
"resolution_impossible": re.compile(r"\bresolutionimpossible\b", re.IGNORECASE),
|
||||||
|
"cannot_install": re.compile(r"\bcannot install\b", re.IGNORECASE),
|
||||||
|
"conflict": re.compile(r"\bconflict(?:ing|s)?\b", re.IGNORECASE),
|
||||||
|
"constraint": re.compile(r"\(constraint\)", re.IGNORECASE),
|
||||||
|
"dependency_detail": re.compile(r"\bdepends on\b", re.IGNORECASE),
|
||||||
|
}
|
||||||
|
_SENSITIVE_PIP_VALUE_KEYS = frozenset(
|
||||||
|
{"password", "passwd", "pass", "api_token", "token", "auth_token"}
|
||||||
|
)
|
||||||
|
_MAX_PIP_OUTPUT_LINES = 200
|
||||||
|
|
||||||
|
|
||||||
def _canonicalize_distribution_name(name: str) -> str:
|
class DependencyConflictError(Exception):
|
||||||
return re.sub(r"[-_.]+", "-", name).strip("-").lower()
|
"""Raised when pip encounters a dependency conflict."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, message: str, errors: list[str], *, is_core_conflict: bool
|
||||||
|
) -> None:
|
||||||
|
super().__init__(message)
|
||||||
|
self.errors = errors
|
||||||
|
self.is_core_conflict = is_core_conflict
|
||||||
|
|
||||||
|
|
||||||
|
class PipInstallError(Exception):
|
||||||
|
"""Raised when pip install fails without a classified dependency conflict."""
|
||||||
|
|
||||||
|
def __init__(self, message: str, *, code: int) -> None:
|
||||||
|
super().__init__(message)
|
||||||
|
self.code = code
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PipConflictContext:
|
||||||
|
relevant_lines: list[str]
|
||||||
|
requested_lines: list[str]
|
||||||
|
dependency_detail_lines: list[str]
|
||||||
|
constraint_lines: list[str]
|
||||||
|
has_strong_conflict_signal: bool
|
||||||
|
has_contextual_conflict_signal: bool
|
||||||
|
|
||||||
|
|
||||||
def _get_pip_main():
|
def _get_pip_main():
|
||||||
@@ -41,11 +91,12 @@ def _get_pip_main():
|
|||||||
return pip_main
|
return pip_main
|
||||||
|
|
||||||
|
|
||||||
def _run_pip_main_with_output(pip_main, args: list[str]) -> tuple[int, str]:
|
def _prepend_sys_path(path: str) -> None:
|
||||||
stream = io.StringIO()
|
normalized_target = os.path.realpath(path)
|
||||||
with contextlib.redirect_stdout(stream), contextlib.redirect_stderr(stream):
|
sys.path[:] = [
|
||||||
result_code = pip_main(args)
|
item for item in sys.path if os.path.realpath(item) != normalized_target
|
||||||
return result_code, stream.getvalue()
|
]
|
||||||
|
sys.path.insert(0, normalized_target)
|
||||||
|
|
||||||
|
|
||||||
def _cleanup_added_root_handlers(original_handlers: list[logging.Handler]) -> None:
|
def _cleanup_added_root_handlers(original_handlers: list[logging.Handler]) -> None:
|
||||||
@@ -59,76 +110,258 @@ def _cleanup_added_root_handlers(original_handlers: list[logging.Handler]) -> No
|
|||||||
handler.close()
|
handler.close()
|
||||||
|
|
||||||
|
|
||||||
def _prepend_sys_path(path: str) -> None:
|
def _get_trusted_host_for_index_url(index_url: str) -> str | None:
|
||||||
normalized_target = os.path.realpath(path)
|
parsed = urlparse(index_url if "://" in index_url else f"//{index_url}")
|
||||||
sys.path[:] = [
|
host = parsed.hostname
|
||||||
item for item in sys.path if os.path.realpath(item) != normalized_target
|
if host == "mirrors.aliyun.com":
|
||||||
]
|
return host
|
||||||
sys.path.insert(0, normalized_target)
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _module_exists_in_site_packages(module_name: str, site_packages_path: str) -> bool:
|
def _normalize_sensitive_pip_key(raw_key: str) -> str:
|
||||||
base_path = os.path.join(site_packages_path, *module_name.split("."))
|
return raw_key.lstrip("-").replace("-", "_").lower()
|
||||||
package_init = os.path.join(base_path, "__init__.py")
|
|
||||||
module_file = f"{base_path}.py"
|
|
||||||
return os.path.isfile(package_init) or os.path.isfile(module_file)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_module_loaded_from_site_packages(
|
def _is_sensitive_pip_value_key(raw_key: str) -> bool:
|
||||||
module_name: str,
|
return _normalize_sensitive_pip_key(raw_key) in _SENSITIVE_PIP_VALUE_KEYS
|
||||||
site_packages_path: str,
|
|
||||||
) -> bool:
|
|
||||||
module = sys.modules.get(module_name)
|
|
||||||
if module is None:
|
|
||||||
try:
|
|
||||||
module = importlib.import_module(module_name)
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
module_file = getattr(module, "__file__", None)
|
|
||||||
if not module_file:
|
|
||||||
return False
|
|
||||||
|
|
||||||
module_path = os.path.realpath(module_file)
|
def _redact_url_credentials(raw_value: str) -> str:
|
||||||
site_packages_real = os.path.realpath(site_packages_path)
|
"""Redact URL credentials and known inline secret values for safe logging."""
|
||||||
try:
|
parsed = urlparse(raw_value)
|
||||||
return (
|
if parsed.netloc and "@" in parsed.netloc:
|
||||||
os.path.commonpath([module_path, site_packages_real]) == site_packages_real
|
hostname = parsed.hostname or ""
|
||||||
|
port = f":{parsed.port}" if parsed.port else ""
|
||||||
|
return parsed._replace(netloc=f"<redacted>@{hostname}{port}").geturl()
|
||||||
|
|
||||||
|
if raw_value.startswith("--"):
|
||||||
|
option, separator, _ = raw_value.partition("=")
|
||||||
|
if separator and _is_sensitive_pip_value_key(option):
|
||||||
|
return f"{option}=****"
|
||||||
|
return raw_value
|
||||||
|
|
||||||
|
key, separator, _ = raw_value.partition("=")
|
||||||
|
if separator and _is_sensitive_pip_value_key(key):
|
||||||
|
return f"{key}=****"
|
||||||
|
|
||||||
|
return raw_value
|
||||||
|
|
||||||
|
|
||||||
|
def _redact_pip_args_for_logging(args: list[str]) -> list[str]:
|
||||||
|
redacted_args: list[str] = []
|
||||||
|
redact_next_value = False
|
||||||
|
|
||||||
|
for arg in args:
|
||||||
|
if redact_next_value:
|
||||||
|
redacted_args.append("****")
|
||||||
|
redact_next_value = False
|
||||||
|
continue
|
||||||
|
|
||||||
|
if arg.startswith("--") and "=" in arg:
|
||||||
|
option, value = arg.split("=", 1)
|
||||||
|
if _is_sensitive_pip_value_key(option):
|
||||||
|
redacted_args.append(f"{option}=****")
|
||||||
|
else:
|
||||||
|
redacted_args.append(f"{option}={_redact_url_credentials(value)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if arg.startswith("-i") and arg != "-i":
|
||||||
|
redacted_args.append(f"-i{_redact_url_credentials(arg[2:])}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if _is_sensitive_pip_value_key(arg):
|
||||||
|
redacted_args.append(arg)
|
||||||
|
redact_next_value = True
|
||||||
|
continue
|
||||||
|
|
||||||
|
redacted_args.append(_redact_url_credentials(arg))
|
||||||
|
|
||||||
|
return redacted_args
|
||||||
|
|
||||||
|
|
||||||
|
def _package_specs_override_index(package_specs: list[str]) -> bool:
|
||||||
|
for index, spec in enumerate(package_specs):
|
||||||
|
if spec == "--no-index":
|
||||||
|
return True
|
||||||
|
if spec in {"-i", "--index-url"}:
|
||||||
|
if index + 1 < len(package_specs):
|
||||||
|
return True
|
||||||
|
continue
|
||||||
|
if spec.startswith("--index-url="):
|
||||||
|
return True
|
||||||
|
if spec.startswith("-i") and spec != "-i":
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class _StreamingLogWriter(io.TextIOBase):
|
||||||
|
def __init__(self, log_func, *, max_lines: int | None = None) -> None:
|
||||||
|
self._log_func = log_func
|
||||||
|
self._lines = deque(maxlen=max_lines or _MAX_PIP_OUTPUT_LINES)
|
||||||
|
self._buffer = ""
|
||||||
|
|
||||||
|
def write(self, text: str) -> int:
|
||||||
|
if not text:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
self._buffer += text.replace("\r\n", "\n").replace("\r", "\n")
|
||||||
|
while "\n" in self._buffer:
|
||||||
|
raw_line, self._buffer = self._buffer.split("\n", 1)
|
||||||
|
line = raw_line.rstrip("\r\n")
|
||||||
|
self._log_func(line)
|
||||||
|
self._lines.append(line)
|
||||||
|
return len(text)
|
||||||
|
|
||||||
|
def flush(self) -> None:
|
||||||
|
line = self._buffer.rstrip("\r\n")
|
||||||
|
if line:
|
||||||
|
self._log_func(line)
|
||||||
|
self._lines.append(line)
|
||||||
|
self._buffer = ""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lines(self) -> list[str]:
|
||||||
|
return list(self._lines)
|
||||||
|
|
||||||
|
|
||||||
|
def _run_pip_main_streaming(pip_main, args: list[str]) -> tuple[int, list[str]]:
|
||||||
|
stream = _StreamingLogWriter(logger.info, max_lines=_MAX_PIP_OUTPUT_LINES)
|
||||||
|
with (
|
||||||
|
contextlib.redirect_stdout(stream),
|
||||||
|
contextlib.redirect_stderr(stream),
|
||||||
|
):
|
||||||
|
result_code = pip_main(args)
|
||||||
|
stream.flush()
|
||||||
|
return result_code, stream.lines
|
||||||
|
|
||||||
|
|
||||||
|
def _matches_pip_failure_pattern(line: str, *pattern_names: str) -> bool:
|
||||||
|
names = pattern_names or tuple(_PIP_FAILURE_PATTERNS)
|
||||||
|
return any(_PIP_FAILURE_PATTERNS[name].search(line) for name in names)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_conflict_detail_line(line: str) -> str:
|
||||||
|
stripped = line.strip()
|
||||||
|
if _matches_pip_failure_pattern(stripped, "user_requested"):
|
||||||
|
return re.sub(
|
||||||
|
r"^\s*The user requested\s+",
|
||||||
|
"",
|
||||||
|
stripped,
|
||||||
|
flags=re.IGNORECASE,
|
||||||
)
|
)
|
||||||
except ValueError:
|
return stripped
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_requirement_name(raw_requirement: str) -> str | None:
|
def _build_pip_conflict_context(output_lines: list[str]) -> PipConflictContext | None:
|
||||||
line = raw_requirement.split("#", 1)[0].strip()
|
matched_indices = [
|
||||||
if not line:
|
index
|
||||||
return None
|
for index, line in enumerate(output_lines)
|
||||||
if line.startswith(("-r", "--requirement", "-c", "--constraint")):
|
if _matches_pip_failure_pattern(line)
|
||||||
return None
|
]
|
||||||
if line.startswith("-"):
|
if matched_indices:
|
||||||
|
relevant_index_set: set[int] = set()
|
||||||
|
for index in matched_indices:
|
||||||
|
start = max(0, index - 1)
|
||||||
|
end = min(len(output_lines), index + 2)
|
||||||
|
relevant_index_set.update(range(start, end))
|
||||||
|
relevant_output_lines = [
|
||||||
|
line
|
||||||
|
for index, line in enumerate(output_lines)
|
||||||
|
if index in relevant_index_set
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
relevant_output_lines = output_lines[-5:]
|
||||||
|
|
||||||
|
if not relevant_output_lines:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
egg_match = re.search(r"#egg=([A-Za-z0-9_.-]+)", raw_requirement)
|
dependency_detail_lines = [
|
||||||
if egg_match:
|
line.strip()
|
||||||
return _canonicalize_distribution_name(egg_match.group(1))
|
for line in relevant_output_lines
|
||||||
|
if _matches_pip_failure_pattern(line, "dependency_detail")
|
||||||
|
]
|
||||||
|
requested_lines = [
|
||||||
|
line.strip()
|
||||||
|
for line in relevant_output_lines
|
||||||
|
if _matches_pip_failure_pattern(line, "user_requested")
|
||||||
|
and not _matches_pip_failure_pattern(line, "constraint")
|
||||||
|
]
|
||||||
|
if not requested_lines:
|
||||||
|
requested_lines = [
|
||||||
|
line
|
||||||
|
for line in dependency_detail_lines
|
||||||
|
if not _matches_pip_failure_pattern(line, "constraint")
|
||||||
|
]
|
||||||
|
constraint_lines = [
|
||||||
|
line.strip()
|
||||||
|
for line in relevant_output_lines
|
||||||
|
if _matches_pip_failure_pattern(line, "constraint")
|
||||||
|
]
|
||||||
|
|
||||||
candidate = re.split(r"[<>=!~;\s\[]", line, maxsplit=1)[0].strip()
|
has_strong_conflict_signal = any(
|
||||||
if not candidate:
|
_matches_pip_failure_pattern(
|
||||||
|
line,
|
||||||
|
"resolution_impossible",
|
||||||
|
"cannot_install",
|
||||||
|
)
|
||||||
|
for line in relevant_output_lines
|
||||||
|
)
|
||||||
|
|
||||||
|
has_contextual_conflict_signal = any(
|
||||||
|
_matches_pip_failure_pattern(line, "conflict") for line in relevant_output_lines
|
||||||
|
) and bool(dependency_detail_lines or requested_lines or constraint_lines)
|
||||||
|
|
||||||
|
return PipConflictContext(
|
||||||
|
relevant_lines=relevant_output_lines,
|
||||||
|
requested_lines=requested_lines,
|
||||||
|
dependency_detail_lines=dependency_detail_lines,
|
||||||
|
constraint_lines=constraint_lines,
|
||||||
|
has_strong_conflict_signal=has_strong_conflict_signal,
|
||||||
|
has_contextual_conflict_signal=has_contextual_conflict_signal,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _classify_pip_failure(output_lines: list[str]) -> DependencyConflictError | None:
|
||||||
|
context = _build_pip_conflict_context(output_lines)
|
||||||
|
if context is None:
|
||||||
return None
|
return None
|
||||||
return _canonicalize_distribution_name(candidate)
|
|
||||||
|
|
||||||
|
if (
|
||||||
|
not context.has_strong_conflict_signal
|
||||||
|
and not context.has_contextual_conflict_signal
|
||||||
|
and not (context.requested_lines and context.constraint_lines)
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
|
||||||
def _extract_requirement_names(requirements_path: str) -> set[str]:
|
is_core_conflict = bool(context.constraint_lines)
|
||||||
names: set[str] = set()
|
|
||||||
try:
|
detail = ""
|
||||||
with open(requirements_path, encoding="utf-8") as requirements_file:
|
if context.constraint_lines and context.requested_lines:
|
||||||
for line in requirements_file:
|
detail = (
|
||||||
requirement_name = _extract_requirement_name(line)
|
" 冲突详情: "
|
||||||
if requirement_name:
|
f"{_normalize_conflict_detail_line(context.requested_lines[0])} vs "
|
||||||
names.add(requirement_name)
|
f"{_normalize_conflict_detail_line(context.constraint_lines[0])}。"
|
||||||
except Exception as exc:
|
)
|
||||||
logger.warning("读取依赖文件失败,跳过冲突检测: %s", exc)
|
elif len(context.dependency_detail_lines) >= 2:
|
||||||
return names
|
detail = (
|
||||||
|
" 冲突详情: "
|
||||||
|
f"{_normalize_conflict_detail_line(context.dependency_detail_lines[0])} vs "
|
||||||
|
f"{_normalize_conflict_detail_line(context.dependency_detail_lines[1])}。"
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_core_conflict:
|
||||||
|
message = (
|
||||||
|
f"检测到核心依赖版本保护冲突。{detail}插件要求的依赖版本与 AstrBot 核心不兼容,"
|
||||||
|
"为了系统稳定,已阻止该降级行为。请联系插件作者或调整 requirements.txt。"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
message = f"检测到依赖冲突。{detail}"
|
||||||
|
|
||||||
|
return DependencyConflictError(
|
||||||
|
message,
|
||||||
|
context.relevant_lines,
|
||||||
|
is_core_conflict=is_core_conflict,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _extract_top_level_modules(
|
def _extract_top_level_modules(
|
||||||
@@ -155,7 +388,11 @@ def _collect_candidate_modules(
|
|||||||
by_name: dict[str, list[importlib_metadata.Distribution]] = {}
|
by_name: dict[str, list[importlib_metadata.Distribution]] = {}
|
||||||
try:
|
try:
|
||||||
for distribution in importlib_metadata.distributions(path=[site_packages_path]):
|
for distribution in importlib_metadata.distributions(path=[site_packages_path]):
|
||||||
distribution_name = distribution.metadata.get("Name")
|
distribution_name = (
|
||||||
|
distribution.metadata["Name"]
|
||||||
|
if "Name" in distribution.metadata
|
||||||
|
else None
|
||||||
|
)
|
||||||
if not distribution_name:
|
if not distribution_name:
|
||||||
continue
|
continue
|
||||||
canonical_name = _canonicalize_distribution_name(distribution_name)
|
canonical_name = _canonicalize_distribution_name(distribution_name)
|
||||||
@@ -173,7 +410,7 @@ def _collect_candidate_modules(
|
|||||||
|
|
||||||
for distribution in by_name.get(requirement_name, []):
|
for distribution in by_name.get(requirement_name, []):
|
||||||
for dependency_line in distribution.requires or []:
|
for dependency_line in distribution.requires or []:
|
||||||
dependency_name = _extract_requirement_name(dependency_line)
|
dependency_name = extract_requirement_name(dependency_line)
|
||||||
if not dependency_name:
|
if not dependency_name:
|
||||||
continue
|
continue
|
||||||
if dependency_name in expanded_requirement_names:
|
if dependency_name in expanded_requirement_names:
|
||||||
@@ -230,6 +467,38 @@ def _ensure_preferred_modules(
|
|||||||
raise RuntimeError(conflict_message)
|
raise RuntimeError(conflict_message)
|
||||||
|
|
||||||
|
|
||||||
|
def _module_exists_in_site_packages(module_name: str, site_packages_path: str) -> bool:
|
||||||
|
base_path = os.path.join(site_packages_path, *module_name.split("."))
|
||||||
|
package_init = os.path.join(base_path, "__init__.py")
|
||||||
|
module_file = f"{base_path}.py"
|
||||||
|
return os.path.isfile(package_init) or os.path.isfile(module_file)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_module_loaded_from_site_packages(
|
||||||
|
module_name: str,
|
||||||
|
site_packages_path: str,
|
||||||
|
) -> bool:
|
||||||
|
module = sys.modules.get(module_name)
|
||||||
|
if module is None:
|
||||||
|
try:
|
||||||
|
module = importlib.import_module(module_name)
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
module_file = getattr(module, "__file__", None)
|
||||||
|
if not module_file:
|
||||||
|
return False
|
||||||
|
|
||||||
|
module_path = os.path.realpath(module_file)
|
||||||
|
site_packages_real = os.path.realpath(site_packages_path)
|
||||||
|
try:
|
||||||
|
return (
|
||||||
|
os.path.commonpath([module_path, site_packages_real]) == site_packages_real
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _prefer_module_from_site_packages(
|
def _prefer_module_from_site_packages(
|
||||||
module_name: str, site_packages_path: str
|
module_name: str, site_packages_path: str
|
||||||
) -> bool:
|
) -> bool:
|
||||||
@@ -531,9 +800,63 @@ def _patch_distlib_finder_for_frozen_runtime() -> None:
|
|||||||
|
|
||||||
|
|
||||||
class PipInstaller:
|
class PipInstaller:
|
||||||
def __init__(self, pip_install_arg: str, pypi_index_url: str | None = None) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
pip_install_arg: str,
|
||||||
|
pypi_index_url: str | None = None,
|
||||||
|
core_dist_name: str | None = "AstrBot",
|
||||||
|
) -> None:
|
||||||
self.pip_install_arg = pip_install_arg
|
self.pip_install_arg = pip_install_arg
|
||||||
self.pypi_index_url = pypi_index_url
|
self.pypi_index_url = pypi_index_url
|
||||||
|
self.core_dist_name = core_dist_name
|
||||||
|
self._core_constraints = CoreConstraintsProvider(core_dist_name)
|
||||||
|
|
||||||
|
def _build_pip_args(
|
||||||
|
self,
|
||||||
|
package_name: str | None,
|
||||||
|
requirements_path: str | None,
|
||||||
|
mirror: str | None,
|
||||||
|
) -> tuple[list[str], set[str]]:
|
||||||
|
args: list[str] = []
|
||||||
|
requested_requirements: set[str] = set()
|
||||||
|
normalized_requirements_path = (
|
||||||
|
requirements_path.strip() if requirements_path else ""
|
||||||
|
)
|
||||||
|
|
||||||
|
if package_name and normalized_requirements_path:
|
||||||
|
raise ValueError(
|
||||||
|
"package_name and requirements_path cannot be used together"
|
||||||
|
)
|
||||||
|
|
||||||
|
if package_name:
|
||||||
|
parsed_package = parse_package_install_input(package_name)
|
||||||
|
if parsed_package.specs:
|
||||||
|
args = ["install", *parsed_package.specs]
|
||||||
|
requested_requirements = set(parsed_package.requirement_names)
|
||||||
|
elif normalized_requirements_path:
|
||||||
|
args = ["install", "-r", normalized_requirements_path]
|
||||||
|
requested_requirements = extract_requirement_names(
|
||||||
|
normalized_requirements_path
|
||||||
|
)
|
||||||
|
|
||||||
|
if not args:
|
||||||
|
return [], requested_requirements
|
||||||
|
|
||||||
|
pip_install_args = (
|
||||||
|
shlex.split(self.pip_install_arg) if self.pip_install_arg else []
|
||||||
|
)
|
||||||
|
|
||||||
|
if not _package_specs_override_index([*args[1:], *pip_install_args]):
|
||||||
|
index_url = mirror or self.pypi_index_url or "https://pypi.org/simple"
|
||||||
|
trusted_host = _get_trusted_host_for_index_url(index_url)
|
||||||
|
if trusted_host:
|
||||||
|
args.extend(["--trusted-host", trusted_host])
|
||||||
|
args.extend(["-i", index_url])
|
||||||
|
|
||||||
|
if pip_install_args:
|
||||||
|
args.extend(pip_install_args)
|
||||||
|
|
||||||
|
return args, requested_requirements
|
||||||
|
|
||||||
async def install(
|
async def install(
|
||||||
self,
|
self,
|
||||||
@@ -541,36 +864,37 @@ class PipInstaller:
|
|||||||
requirements_path: str | None = None,
|
requirements_path: str | None = None,
|
||||||
mirror: str | None = None,
|
mirror: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
args = ["install"]
|
args, requested_requirements = self._build_pip_args(
|
||||||
requested_requirements: set[str] = set()
|
package_name, requirements_path, mirror
|
||||||
if package_name:
|
)
|
||||||
args.append(package_name)
|
if not args:
|
||||||
requirement_name = _extract_requirement_name(package_name)
|
logger.info("Pip 包管理器跳过安装:未提供有效的包名或 requirements 文件。")
|
||||||
if requirement_name:
|
return
|
||||||
requested_requirements.add(requirement_name)
|
|
||||||
elif requirements_path:
|
|
||||||
args.extend(["-r", requirements_path])
|
|
||||||
requested_requirements = _extract_requirement_names(requirements_path)
|
|
||||||
|
|
||||||
index_url = mirror or self.pypi_index_url or "https://pypi.org/simple"
|
|
||||||
args.extend(["--trusted-host", "mirrors.aliyun.com", "-i", index_url])
|
|
||||||
|
|
||||||
target_site_packages = None
|
target_site_packages = None
|
||||||
if is_packaged_desktop_runtime():
|
if is_packaged_desktop_runtime():
|
||||||
target_site_packages = get_astrbot_site_packages_path()
|
target_site_packages = get_astrbot_site_packages_path()
|
||||||
os.makedirs(target_site_packages, exist_ok=True)
|
os.makedirs(target_site_packages, exist_ok=True)
|
||||||
_prepend_sys_path(target_site_packages)
|
_prepend_sys_path(target_site_packages)
|
||||||
args.extend(["--target", target_site_packages])
|
args.extend(
|
||||||
args.extend(["--upgrade", "--force-reinstall"])
|
[
|
||||||
|
"--target",
|
||||||
|
target_site_packages,
|
||||||
|
"--upgrade",
|
||||||
|
"--upgrade-strategy",
|
||||||
|
"only-if-needed",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
if self.pip_install_arg:
|
with self._core_constraints.constraints_file() as constraints_file_path:
|
||||||
args.extend(self.pip_install_arg.split())
|
if constraints_file_path:
|
||||||
|
args.extend(["-c", constraints_file_path])
|
||||||
|
|
||||||
logger.info(f"Pip 包管理器: pip {' '.join(args)}")
|
logger.info(
|
||||||
result_code = await self._run_pip_in_process(args)
|
"Pip 包管理器 argv: %s",
|
||||||
|
["pip", *_redact_pip_args_for_logging(args)],
|
||||||
if result_code != 0:
|
)
|
||||||
raise Exception(f"安装失败,错误码:{result_code}")
|
await self._run_pip_with_classification(args)
|
||||||
|
|
||||||
if target_site_packages:
|
if target_site_packages:
|
||||||
_prepend_sys_path(target_site_packages)
|
_prepend_sys_path(target_site_packages)
|
||||||
@@ -589,7 +913,7 @@ class PipInstaller:
|
|||||||
if not os.path.isdir(target_site_packages):
|
if not os.path.isdir(target_site_packages):
|
||||||
return
|
return
|
||||||
|
|
||||||
requested_requirements = _extract_requirement_names(requirements_path)
|
requested_requirements = extract_requirement_names(requirements_path)
|
||||||
if not requested_requirements:
|
if not requested_requirements:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -605,13 +929,21 @@ class PipInstaller:
|
|||||||
_patch_distlib_finder_for_frozen_runtime()
|
_patch_distlib_finder_for_frozen_runtime()
|
||||||
|
|
||||||
original_handlers = list(logging.getLogger().handlers)
|
original_handlers = list(logging.getLogger().handlers)
|
||||||
result_code, output = await asyncio.to_thread(
|
try:
|
||||||
_run_pip_main_with_output, pip_main, args
|
result_code, output_lines = await asyncio.to_thread(
|
||||||
)
|
_run_pip_main_streaming, pip_main, args
|
||||||
for line in output.splitlines():
|
)
|
||||||
line = line.strip()
|
finally:
|
||||||
if line:
|
_cleanup_added_root_handlers(original_handlers)
|
||||||
logger.info(line)
|
|
||||||
|
if result_code != 0:
|
||||||
|
conflict = _classify_pip_failure(output_lines)
|
||||||
|
if conflict:
|
||||||
|
raise conflict
|
||||||
|
|
||||||
_cleanup_added_root_handlers(original_handlers)
|
|
||||||
return result_code
|
return result_code
|
||||||
|
|
||||||
|
async def _run_pip_with_classification(self, args: list[str]) -> None:
|
||||||
|
result_code = await self._run_pip_in_process(args)
|
||||||
|
if result_code != 0:
|
||||||
|
raise PipInstallError(f"安装失败,错误码:{result_code}", code=result_code)
|
||||||
|
|||||||
@@ -0,0 +1,408 @@
|
|||||||
|
import importlib.metadata as importlib_metadata
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import shlex
|
||||||
|
import sys
|
||||||
|
from collections.abc import Iterable, Iterator
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from packaging.requirements import InvalidRequirement, Requirement
|
||||||
|
from packaging.specifiers import SpecifierSet
|
||||||
|
from packaging.version import InvalidVersion, Version
|
||||||
|
|
||||||
|
from astrbot.core.utils.astrbot_path import get_astrbot_site_packages_path
|
||||||
|
from astrbot.core.utils.runtime_env import is_packaged_desktop_runtime
|
||||||
|
|
||||||
|
logger = logging.getLogger("astrbot")
|
||||||
|
|
||||||
|
|
||||||
|
class RequirementsPrecheckFailed(Exception):
|
||||||
|
"""Raised when the pre-check of requirements fails."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ParsedPackageInput:
|
||||||
|
specs: tuple[str, ...]
|
||||||
|
requirement_names: frozenset[str]
|
||||||
|
|
||||||
|
|
||||||
|
def canonicalize_distribution_name(name: str) -> str:
|
||||||
|
return re.sub(r"[-_.]+", "-", name).strip("-").lower()
|
||||||
|
|
||||||
|
|
||||||
|
def strip_inline_requirement_comment(raw_input: str) -> str:
|
||||||
|
if raw_input.lstrip().startswith("#"):
|
||||||
|
return ""
|
||||||
|
return re.split(r"[ \t]+#", raw_input, maxsplit=1)[0].strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _specifier_contains_version(specifier: SpecifierSet, version: str) -> bool:
|
||||||
|
try:
|
||||||
|
parsed_version = Version(version)
|
||||||
|
except InvalidVersion:
|
||||||
|
return False
|
||||||
|
return specifier.contains(parsed_version, prereleases=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _looks_like_local_path_reference(token: str) -> bool:
|
||||||
|
candidate = token.strip()
|
||||||
|
if not candidate:
|
||||||
|
return False
|
||||||
|
return candidate in {".", ".."} or candidate.startswith(
|
||||||
|
("./", "../", "/", "~/", ".\\", "..\\", "\\")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def looks_like_direct_reference(token: str) -> bool:
|
||||||
|
candidate = token.strip()
|
||||||
|
if not candidate:
|
||||||
|
return False
|
||||||
|
return (
|
||||||
|
_looks_like_local_path_reference(candidate)
|
||||||
|
or candidate.startswith("git+")
|
||||||
|
or "://" in candidate
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_requirement_name(raw_requirement: str) -> str | None:
|
||||||
|
line = raw_requirement.split("#", 1)[0].strip()
|
||||||
|
if not line:
|
||||||
|
return None
|
||||||
|
if line.startswith(("-r", "--requirement", "-c", "--constraint")):
|
||||||
|
return None
|
||||||
|
|
||||||
|
egg_match = re.search(r"#egg=([A-Za-z0-9_.-]+)", raw_requirement)
|
||||||
|
if egg_match:
|
||||||
|
return canonicalize_distribution_name(egg_match.group(1))
|
||||||
|
|
||||||
|
if line.startswith("-"):
|
||||||
|
return None
|
||||||
|
|
||||||
|
candidate = re.split(r"[<>=!~;\s\[]", line, maxsplit=1)[0].strip()
|
||||||
|
if not candidate:
|
||||||
|
return None
|
||||||
|
return canonicalize_distribution_name(candidate)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_editable_or_direct_name(target: str) -> str | None:
|
||||||
|
name = extract_requirement_name(target)
|
||||||
|
if not name:
|
||||||
|
return None
|
||||||
|
if "#egg=" in target or not looks_like_direct_reference(target):
|
||||||
|
return name
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_requirement_name_and_spec(
|
||||||
|
line: str,
|
||||||
|
) -> tuple[str | None, SpecifierSet | None]:
|
||||||
|
if line.startswith(("-c", "--constraint")):
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
try:
|
||||||
|
req = Requirement(line)
|
||||||
|
except InvalidRequirement:
|
||||||
|
tokens = shlex.split(line)
|
||||||
|
if not tokens:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
editable_target: str | None = None
|
||||||
|
if tokens[0] in {"-e", "--editable"} and len(tokens) > 1:
|
||||||
|
editable_target = tokens[1]
|
||||||
|
elif tokens[0].startswith("--editable="):
|
||||||
|
editable_target = tokens[0].split("=", 1)[1]
|
||||||
|
|
||||||
|
if editable_target:
|
||||||
|
name = _parse_editable_or_direct_name(editable_target)
|
||||||
|
return (name, None) if name else (None, None)
|
||||||
|
|
||||||
|
name = _parse_editable_or_direct_name(line)
|
||||||
|
return (name, None) if name else (None, None)
|
||||||
|
|
||||||
|
if req.marker and not req.marker.evaluate():
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
return canonicalize_distribution_name(req.name), (req.specifier or None)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_requirement_line(
|
||||||
|
line: str,
|
||||||
|
) -> tuple[str, SpecifierSet | None] | None:
|
||||||
|
name, specifier = _parse_requirement_name_and_spec(line)
|
||||||
|
return (name, specifier) if name else None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_requirement_names_from_package_tokens(tokens: list[str]) -> frozenset[str]:
|
||||||
|
requirement_names: set[str] = set()
|
||||||
|
skip_next_for: str | None = None
|
||||||
|
|
||||||
|
for token in tokens:
|
||||||
|
if skip_next_for:
|
||||||
|
if skip_next_for == "editable":
|
||||||
|
name = _parse_editable_or_direct_name(token)
|
||||||
|
if name:
|
||||||
|
requirement_names.add(name)
|
||||||
|
skip_next_for = None
|
||||||
|
continue
|
||||||
|
|
||||||
|
if token in {"-e", "--editable"}:
|
||||||
|
skip_next_for = "editable"
|
||||||
|
continue
|
||||||
|
|
||||||
|
if token in {
|
||||||
|
"-i",
|
||||||
|
"--index-url",
|
||||||
|
"--extra-index-url",
|
||||||
|
"-f",
|
||||||
|
"--find-links",
|
||||||
|
"--trusted-host",
|
||||||
|
"-r",
|
||||||
|
"--requirement",
|
||||||
|
"-c",
|
||||||
|
"--constraint",
|
||||||
|
}:
|
||||||
|
skip_next_for = "option-value"
|
||||||
|
continue
|
||||||
|
|
||||||
|
if token.startswith(("--editable=",)):
|
||||||
|
editable_target = token.split("=", 1)[1]
|
||||||
|
name = _parse_editable_or_direct_name(editable_target)
|
||||||
|
if name:
|
||||||
|
requirement_names.add(name)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if token.startswith(
|
||||||
|
(
|
||||||
|
"--index-url=",
|
||||||
|
"--extra-index-url=",
|
||||||
|
"--find-links=",
|
||||||
|
"--trusted-host=",
|
||||||
|
"--requirement=",
|
||||||
|
"--constraint=",
|
||||||
|
)
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if (
|
||||||
|
(token.startswith("-i") and token != "-i")
|
||||||
|
or (token.startswith("-f") and token != "-f")
|
||||||
|
or token == "--no-index"
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if token.startswith("-"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
name, _ = _parse_requirement_name_and_spec(token)
|
||||||
|
if name:
|
||||||
|
requirement_names.add(name)
|
||||||
|
|
||||||
|
return frozenset(requirement_names)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_package_install_input(raw_input: str) -> ParsedPackageInput:
|
||||||
|
specs: list[str] = []
|
||||||
|
requirement_names: set[str] = set()
|
||||||
|
normalized = raw_input.strip()
|
||||||
|
if not normalized:
|
||||||
|
return ParsedPackageInput(specs=(), requirement_names=frozenset())
|
||||||
|
|
||||||
|
for raw_line in normalized.splitlines():
|
||||||
|
line = strip_inline_requirement_comment(raw_line)
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
Requirement(line)
|
||||||
|
except InvalidRequirement:
|
||||||
|
tokens = shlex.split(line)
|
||||||
|
if not tokens:
|
||||||
|
continue
|
||||||
|
specs.extend(tokens)
|
||||||
|
requirement_names.update(
|
||||||
|
_extract_requirement_names_from_package_tokens(tokens)
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
specs.append(line)
|
||||||
|
name, _ = _parse_requirement_name_and_spec(line)
|
||||||
|
if name:
|
||||||
|
requirement_names.add(name)
|
||||||
|
|
||||||
|
return ParsedPackageInput(
|
||||||
|
specs=tuple(specs),
|
||||||
|
requirement_names=frozenset(requirement_names),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _iter_requirement_lines(
|
||||||
|
requirements_path: str,
|
||||||
|
_visited: set[str] | None = None,
|
||||||
|
) -> Iterator[str]:
|
||||||
|
visited = _visited or set()
|
||||||
|
resolved_path = os.path.realpath(requirements_path)
|
||||||
|
if resolved_path in visited:
|
||||||
|
logger.warning(
|
||||||
|
"检测到循环依赖的 requirements 包含: %s,将跳过该文件", resolved_path
|
||||||
|
)
|
||||||
|
return
|
||||||
|
visited.add(resolved_path)
|
||||||
|
|
||||||
|
with open(resolved_path, encoding="utf-8") as f:
|
||||||
|
for raw_line in f:
|
||||||
|
line = strip_inline_requirement_comment(raw_line)
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
|
||||||
|
tokens = shlex.split(line)
|
||||||
|
if not tokens:
|
||||||
|
continue
|
||||||
|
|
||||||
|
nested: str | None = None
|
||||||
|
if tokens[0] in {"-r", "--requirement"} and len(tokens) > 1:
|
||||||
|
nested = tokens[1]
|
||||||
|
elif tokens[0].startswith("--requirement="):
|
||||||
|
nested = tokens[0].split("=", 1)[1]
|
||||||
|
|
||||||
|
if nested:
|
||||||
|
if not os.path.isabs(nested):
|
||||||
|
nested = os.path.join(os.path.dirname(resolved_path), nested)
|
||||||
|
yield from _iter_requirement_lines(nested, _visited=visited)
|
||||||
|
continue
|
||||||
|
|
||||||
|
yield line
|
||||||
|
|
||||||
|
|
||||||
|
def iter_requirements(
|
||||||
|
requirements_path: str | None = None,
|
||||||
|
lines: Iterable[str] | None = None,
|
||||||
|
) -> Iterator[tuple[str, SpecifierSet | None]]:
|
||||||
|
if lines is None:
|
||||||
|
if requirements_path is None:
|
||||||
|
raise ValueError("Either requirements_path or lines must be provided")
|
||||||
|
lines = _iter_requirement_lines(requirements_path)
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
parsed = _parse_requirement_line(line)
|
||||||
|
if parsed is not None:
|
||||||
|
yield parsed
|
||||||
|
|
||||||
|
|
||||||
|
def extract_requirement_names(requirements_path: str) -> set[str]:
|
||||||
|
try:
|
||||||
|
return {
|
||||||
|
name for name, _ in iter_requirements(requirements_path=requirements_path)
|
||||||
|
}
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("读取依赖文件失败,跳过冲突检测: %s", exc)
|
||||||
|
return set()
|
||||||
|
|
||||||
|
|
||||||
|
def get_requirement_check_paths() -> list[str]:
|
||||||
|
paths = list(sys.path)
|
||||||
|
if is_packaged_desktop_runtime():
|
||||||
|
target_site_packages = get_astrbot_site_packages_path()
|
||||||
|
if os.path.isdir(target_site_packages):
|
||||||
|
paths.insert(0, target_site_packages)
|
||||||
|
return paths
|
||||||
|
|
||||||
|
|
||||||
|
def _canonical_distribution_identity(distribution) -> tuple[str | None, str | None]:
|
||||||
|
distribution_name = (
|
||||||
|
distribution.metadata["Name"] if "Name" in distribution.metadata else None
|
||||||
|
)
|
||||||
|
if not distribution_name:
|
||||||
|
return None, None
|
||||||
|
return canonicalize_distribution_name(distribution_name), distribution.version
|
||||||
|
|
||||||
|
|
||||||
|
def collect_installed_distribution_versions(paths: list[str]) -> dict[str, str] | None:
|
||||||
|
installed: dict[str, str] = {}
|
||||||
|
try:
|
||||||
|
for distribution in importlib_metadata.distributions(path=paths):
|
||||||
|
distribution_name, version = _canonical_distribution_identity(distribution)
|
||||||
|
if not distribution_name or not version:
|
||||||
|
continue
|
||||||
|
installed.setdefault(distribution_name, version)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("读取已安装依赖失败,跳过缺失依赖预检查: %s", exc)
|
||||||
|
return None
|
||||||
|
return installed
|
||||||
|
|
||||||
|
|
||||||
|
def _load_requirement_lines_for_precheck(
|
||||||
|
requirements_path: str,
|
||||||
|
) -> tuple[bool, list[str] | None]:
|
||||||
|
try:
|
||||||
|
requirement_lines = list(_iter_requirement_lines(requirements_path))
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"预检查缺失依赖失败,将回退到完整安装: %s (%s)",
|
||||||
|
requirements_path,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
fallback_line = next(
|
||||||
|
(
|
||||||
|
line
|
||||||
|
for line in requirement_lines
|
||||||
|
if (
|
||||||
|
(
|
||||||
|
line.startswith(("-e ", "--editable ", "--editable="))
|
||||||
|
and "#egg=" not in line
|
||||||
|
)
|
||||||
|
or (
|
||||||
|
_parse_requirement_line(line) is None
|
||||||
|
and looks_like_direct_reference(line)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if fallback_line is not None:
|
||||||
|
logger.warning(
|
||||||
|
"预检查缺失依赖失败,将回退到完整安装: unresolved direct reference in %s: %s",
|
||||||
|
requirements_path,
|
||||||
|
fallback_line,
|
||||||
|
)
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
return True, requirement_lines
|
||||||
|
|
||||||
|
|
||||||
|
def find_missing_requirements(requirements_path: str) -> set[str] | None:
|
||||||
|
can_precheck, requirement_lines = _load_requirement_lines_for_precheck(
|
||||||
|
requirements_path
|
||||||
|
)
|
||||||
|
if not can_precheck or requirement_lines is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
required = list(iter_requirements(lines=requirement_lines))
|
||||||
|
if not required:
|
||||||
|
return set()
|
||||||
|
|
||||||
|
installed = collect_installed_distribution_versions(get_requirement_check_paths())
|
||||||
|
if installed is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
missing: set[str] = set()
|
||||||
|
for name, specifier in required:
|
||||||
|
installed_version = installed.get(name)
|
||||||
|
if not installed_version:
|
||||||
|
missing.add(name)
|
||||||
|
continue
|
||||||
|
if specifier and not _specifier_contains_version(specifier, installed_version):
|
||||||
|
missing.add(name)
|
||||||
|
|
||||||
|
return missing
|
||||||
|
|
||||||
|
|
||||||
|
def find_missing_requirements_or_raise(requirements_path: str) -> set[str]:
|
||||||
|
missing = find_missing_requirements(requirements_path)
|
||||||
|
if missing is None:
|
||||||
|
raise RequirementsPrecheckFailed(f"预检查失败: {requirements_path}")
|
||||||
|
return missing
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { computed } from 'vue';
|
import { computed } from 'vue';
|
||||||
import { useModuleI18n } from '@/i18n/composables';
|
import { useModuleI18n } from '@/i18n/composables';
|
||||||
|
import { normalizeTextInput } from '@/utils/inputValue';
|
||||||
|
|
||||||
const { tm } = useModuleI18n('features/command');
|
const { tm } = useModuleI18n('features/command');
|
||||||
|
|
||||||
@@ -52,6 +53,7 @@ const statusItems = [
|
|||||||
{ title: tm('filters.disabled'), value: 'disabled' },
|
{ title: tm('filters.disabled'), value: 'disabled' },
|
||||||
{ title: tm('filters.conflict'), value: 'conflict' }
|
{ title: tm('filters.conflict'), value: 'conflict' }
|
||||||
];
|
];
|
||||||
|
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<template>
|
<template>
|
||||||
@@ -108,10 +110,11 @@ const statusItems = [
|
|||||||
<div style="min-width: 200px; max-width: 350px; flex: 1; border: 1px solid #B9B9B9; border-radius: 16px;">
|
<div style="min-width: 200px; max-width: 350px; flex: 1; border: 1px solid #B9B9B9; border-radius: 16px;">
|
||||||
<v-text-field
|
<v-text-field
|
||||||
:model-value="searchQuery"
|
:model-value="searchQuery"
|
||||||
@update:model-value="emit('update:searchQuery', $event)"
|
@update:model-value="emit('update:searchQuery', normalizeTextInput($event))"
|
||||||
density="compact"
|
density="compact"
|
||||||
:label="tm('search.placeholder')"
|
:label="tm('search.placeholder')"
|
||||||
prepend-inner-icon="mdi-magnify"
|
prepend-inner-icon="mdi-magnify"
|
||||||
|
clearable
|
||||||
variant="solo-filled"
|
variant="solo-filled"
|
||||||
flat
|
flat
|
||||||
hide-details
|
hide-details
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
*/
|
*/
|
||||||
import { ref, computed, type Ref } from 'vue';
|
import { ref, computed, type Ref } from 'vue';
|
||||||
import type { CommandItem, FilterState } from '../types';
|
import type { CommandItem, FilterState } from '../types';
|
||||||
|
import { normalizeTextInput } from '@/utils/inputValue';
|
||||||
|
|
||||||
export function useCommandFilters(commands: Ref<CommandItem[]>) {
|
export function useCommandFilters(commands: Ref<CommandItem[]>) {
|
||||||
// 过滤状态
|
// 过滤状态
|
||||||
@@ -95,7 +96,7 @@ export function useCommandFilters(commands: Ref<CommandItem[]>) {
|
|||||||
* 过滤后的指令列表(支持层级结构)
|
* 过滤后的指令列表(支持层级结构)
|
||||||
*/
|
*/
|
||||||
const filteredCommands = computed(() => {
|
const filteredCommands = computed(() => {
|
||||||
const query = searchQuery.value.toLowerCase();
|
const query = normalizeTextInput(searchQuery.value).toLowerCase();
|
||||||
const conflictCmds: CommandItem[] = [];
|
const conflictCmds: CommandItem[] = [];
|
||||||
const normalCmds: CommandItem[] = [];
|
const normalCmds: CommandItem[] = [];
|
||||||
|
|
||||||
@@ -184,4 +185,3 @@ export function useCommandFilters(commands: Ref<CommandItem[]>) {
|
|||||||
isGroupExpanded
|
isGroupExpanded
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
import { computed, onActivated, onMounted, ref, watch} from 'vue';
|
import { computed, onActivated, onMounted, ref, watch} from 'vue';
|
||||||
import axios from 'axios';
|
import axios from 'axios';
|
||||||
import { useModuleI18n } from '@/i18n/composables';
|
import { useModuleI18n } from '@/i18n/composables';
|
||||||
|
import { normalizeTextInput } from '@/utils/inputValue';
|
||||||
|
|
||||||
// Composables
|
// Composables
|
||||||
import { useComponentData } from './composables/useComponentData';
|
import { useComponentData } from './composables/useComponentData';
|
||||||
@@ -83,7 +84,7 @@ const {
|
|||||||
} = useCommandActions(toast, () => fetchCommands(tm('messages.loadFailed')));
|
} = useCommandActions(toast, () => fetchCommands(tm('messages.loadFailed')));
|
||||||
|
|
||||||
const filteredTools = computed(() => {
|
const filteredTools = computed(() => {
|
||||||
const query = toolSearch.value.trim().toLowerCase();
|
const query = normalizeTextInput(toolSearch.value).trim().toLowerCase();
|
||||||
if (!query) return tools.value;
|
if (!query) return tools.value;
|
||||||
return tools.value.filter(tool =>
|
return tools.value.filter(tool =>
|
||||||
tool.name?.toLowerCase().includes(query) ||
|
tool.name?.toLowerCase().includes(query) ||
|
||||||
@@ -253,7 +254,8 @@ watch(viewMode, async (mode) => {
|
|||||||
<div class="d-flex flex-wrap align-center ga-3 mb-4">
|
<div class="d-flex flex-wrap align-center ga-3 mb-4">
|
||||||
<div style="min-width: 240px; max-width: 380px; flex: 1;">
|
<div style="min-width: 240px; max-width: 380px; flex: 1;">
|
||||||
<v-text-field
|
<v-text-field
|
||||||
v-model="toolSearch"
|
:model-value="toolSearch"
|
||||||
|
@update:model-value="toolSearch = normalizeTextInput($event)"
|
||||||
prepend-inner-icon="mdi-magnify"
|
prepend-inner-icon="mdi-magnify"
|
||||||
:label="tmTool('functionTools.search')"
|
:label="tmTool('functionTools.search')"
|
||||||
variant="outlined"
|
variant="outlined"
|
||||||
|
|||||||
@@ -7,6 +7,7 @@
|
|||||||
v-model="modelSearchProxy"
|
v-model="modelSearchProxy"
|
||||||
density="compact"
|
density="compact"
|
||||||
prepend-inner-icon="mdi-magnify"
|
prepend-inner-icon="mdi-magnify"
|
||||||
|
clearable
|
||||||
hide-details
|
hide-details
|
||||||
variant="solo-filled"
|
variant="solo-filled"
|
||||||
flat
|
flat
|
||||||
@@ -161,6 +162,7 @@
|
|||||||
|
|
||||||
<script setup>
|
<script setup>
|
||||||
import { computed } from 'vue'
|
import { computed } from 'vue'
|
||||||
|
import { normalizeTextInput } from '@/utils/inputValue'
|
||||||
|
|
||||||
const props = defineProps({
|
const props = defineProps({
|
||||||
entries: {
|
entries: {
|
||||||
@@ -222,7 +224,7 @@ const emit = defineEmits([
|
|||||||
|
|
||||||
const modelSearchProxy = computed({
|
const modelSearchProxy = computed({
|
||||||
get: () => props.modelSearch,
|
get: () => props.modelSearch,
|
||||||
set: (val) => emit('update:modelSearch', val)
|
set: (val) => emit('update:modelSearch', normalizeTextInput(val))
|
||||||
})
|
})
|
||||||
|
|
||||||
const isProviderTesting = (providerId) => props.testingProviders.includes(providerId)
|
const isProviderTesting = (providerId) => props.testingProviders.includes(providerId)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import { ref, computed, onMounted, nextTick, watch } from 'vue'
|
|||||||
import axios from 'axios'
|
import axios from 'axios'
|
||||||
import { getProviderIcon } from '@/utils/providerUtils'
|
import { getProviderIcon } from '@/utils/providerUtils'
|
||||||
import { askForConfirmation as askForConfirmationDialog, useConfirmDialog } from '@/utils/confirmDialog'
|
import { askForConfirmation as askForConfirmationDialog, useConfirmDialog } from '@/utils/confirmDialog'
|
||||||
|
import { normalizeTextInput } from '@/utils/inputValue'
|
||||||
|
|
||||||
export interface UseProviderSourcesOptions {
|
export interface UseProviderSourcesOptions {
|
||||||
defaultTab?: string
|
defaultTab?: string
|
||||||
@@ -157,7 +158,7 @@ export function useProviderSources(options: UseProviderSourcesOptions) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
const filteredMergedModelEntries = computed(() => {
|
const filteredMergedModelEntries = computed(() => {
|
||||||
const term = modelSearch.value.trim().toLowerCase()
|
const term = normalizeTextInput(modelSearch.value).trim().toLowerCase()
|
||||||
if (!term) return mergedModelEntries.value
|
if (!term) return mergedModelEntries.value
|
||||||
|
|
||||||
return mergedModelEntries.value.filter((entry: any) => {
|
return mergedModelEntries.value.filter((entry: any) => {
|
||||||
|
|||||||
@@ -873,7 +873,8 @@
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
"regex": {
|
"regex": {
|
||||||
"description": "Segmentation Regular Expression"
|
"description": "Segmentation Regular Expression",
|
||||||
|
"hint": "Used to identify split points with a regular expression. Prefer patterns that match separators."
|
||||||
},
|
},
|
||||||
"split_words": {
|
"split_words": {
|
||||||
"description": "Split Word List",
|
"description": "Split Word List",
|
||||||
|
|||||||
@@ -876,7 +876,8 @@
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
"regex": {
|
"regex": {
|
||||||
"description": "分段正则表达式"
|
"description": "分段正则表达式",
|
||||||
|
"hint": "用于按正则规则识别分段点。建议使用能匹配分隔符的表达式。"
|
||||||
},
|
},
|
||||||
"split_words": {
|
"split_words": {
|
||||||
"description": "分段词列表",
|
"description": "分段词列表",
|
||||||
|
|||||||
@@ -0,0 +1,2 @@
|
|||||||
|
export const normalizeTextInput = (value: unknown): string =>
|
||||||
|
typeof value === 'string' ? value : '';
|
||||||
@@ -13,9 +13,11 @@
|
|||||||
</v-select>
|
</v-select>
|
||||||
<v-text-field
|
<v-text-field
|
||||||
class="config-search-input"
|
class="config-search-input"
|
||||||
v-model="configSearchKeyword"
|
:model-value="configSearchKeyword"
|
||||||
|
@update:model-value="onConfigSearchInput"
|
||||||
prepend-inner-icon="mdi-magnify"
|
prepend-inner-icon="mdi-magnify"
|
||||||
:label="tm('search.placeholder')"
|
:label="tm('search.placeholder')"
|
||||||
|
clearable
|
||||||
hide-details
|
hide-details
|
||||||
density="compact"
|
density="compact"
|
||||||
rounded="md"
|
rounded="md"
|
||||||
@@ -211,6 +213,7 @@ import {
|
|||||||
useConfirmDialog
|
useConfirmDialog
|
||||||
} from '@/utils/confirmDialog';
|
} from '@/utils/confirmDialog';
|
||||||
import UnsavedChangesConfirmDialog from '@/components/config/UnsavedChangesConfirmDialog.vue';
|
import UnsavedChangesConfirmDialog from '@/components/config/UnsavedChangesConfirmDialog.vue';
|
||||||
|
import { normalizeTextInput } from '@/utils/inputValue';
|
||||||
|
|
||||||
export default {
|
export default {
|
||||||
name: 'ConfigPage',
|
name: 'ConfigPage',
|
||||||
@@ -419,6 +422,9 @@ export default {
|
|||||||
|
|
||||||
},
|
},
|
||||||
methods: {
|
methods: {
|
||||||
|
onConfigSearchInput(value) {
|
||||||
|
this.configSearchKeyword = normalizeTextInput(value);
|
||||||
|
},
|
||||||
extractConfigTypeFromHash(hash) {
|
extractConfigTypeFromHash(hash) {
|
||||||
const rawHash = String(hash || '');
|
const rawHash = String(hash || '');
|
||||||
const lastHashIndex = rawHash.lastIndexOf('#');
|
const lastHashIndex = rawHash.lastIndexOf('#');
|
||||||
|
|||||||
@@ -353,10 +353,11 @@
|
|||||||
<v-window-item value="search">
|
<v-window-item value="search">
|
||||||
<div class="search-container pa-4">
|
<div class="search-container pa-4">
|
||||||
<v-form @submit.prevent="searchKnowledgeBase" class="d-flex align-center">
|
<v-form @submit.prevent="searchKnowledgeBase" class="d-flex align-center">
|
||||||
<v-text-field v-model="searchQuery" :label="tm('search.queryLabel')"
|
<v-text-field :model-value="searchQuery"
|
||||||
|
@update:model-value="onSearchQueryInput" :label="tm('search.queryLabel')"
|
||||||
append-icon="mdi-magnify" variant="outlined" class="flex-grow-1 me-2"
|
append-icon="mdi-magnify" variant="outlined" class="flex-grow-1 me-2"
|
||||||
@click:append="searchKnowledgeBase" @keyup.enter="searchKnowledgeBase"
|
@click:append="searchKnowledgeBase" @keyup.enter="searchKnowledgeBase"
|
||||||
:placeholder="tm('search.queryPlaceholder')" hide-details></v-text-field>
|
:placeholder="tm('search.queryPlaceholder')" hide-details clearable></v-text-field>
|
||||||
|
|
||||||
<v-select v-model="topK" :items="[3, 5, 10, 20]"
|
<v-select v-model="topK" :items="[3, 5, 10, 20]"
|
||||||
:label="tm('search.resultCountLabel')" variant="outlined"
|
:label="tm('search.resultCountLabel')" variant="outlined"
|
||||||
@@ -434,6 +435,7 @@
|
|||||||
import axios from 'axios';
|
import axios from 'axios';
|
||||||
import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
|
import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
|
||||||
import { useModuleI18n } from '@/i18n/composables';
|
import { useModuleI18n } from '@/i18n/composables';
|
||||||
|
import { normalizeTextInput } from '@/utils/inputValue';
|
||||||
|
|
||||||
export default {
|
export default {
|
||||||
name: 'KnowledgeBase',
|
name: 'KnowledgeBase',
|
||||||
@@ -580,6 +582,9 @@ export default {
|
|||||||
this.getProviderList();
|
this.getProviderList();
|
||||||
},
|
},
|
||||||
methods: {
|
methods: {
|
||||||
|
onSearchQueryInput(value) {
|
||||||
|
this.searchQuery = normalizeTextInput(value);
|
||||||
|
},
|
||||||
getSelectedGitHubProxy() {
|
getSelectedGitHubProxy() {
|
||||||
if (typeof window === "undefined" || !window.localStorage) return "";
|
if (typeof window === "undefined" || !window.localStorage) return "";
|
||||||
return localStorage.getItem("githubProxyRadioValue") === "1"
|
return localStorage.getItem("githubProxyRadioValue") === "1"
|
||||||
@@ -903,7 +908,8 @@ export default {
|
|||||||
},
|
},
|
||||||
|
|
||||||
searchKnowledgeBase() {
|
searchKnowledgeBase() {
|
||||||
if (!this.searchQuery.trim()) {
|
const query = normalizeTextInput(this.searchQuery).trim();
|
||||||
|
if (!query) {
|
||||||
this.showSnackbar(this.tm('messages.pleaseEnterSearchContent'), 'warning');
|
this.showSnackbar(this.tm('messages.pleaseEnterSearchContent'), 'warning');
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -914,7 +920,7 @@ export default {
|
|||||||
axios.get(`/api/plug/alkaid/kb/collection/search`, {
|
axios.get(`/api/plug/alkaid/kb/collection/search`, {
|
||||||
params: {
|
params: {
|
||||||
collection_name: this.currentKB.collection_name,
|
collection_name: this.currentKB.collection_name,
|
||||||
query: this.searchQuery,
|
query,
|
||||||
top_k: this.topK
|
top_k: this.topK
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -37,10 +37,12 @@
|
|||||||
<h3>{{ tm('search.title') }}</h3>
|
<h3>{{ tm('search.title') }}</h3>
|
||||||
<v-card variant="outlined" class="mt-2 pa-3">
|
<v-card variant="outlined" class="mt-2 pa-3">
|
||||||
<div>
|
<div>
|
||||||
<v-text-field v-model="searchMemoryUserId" :label="tm('search.userIdLabel')" variant="outlined" density="compact" hide-details
|
<v-text-field :model-value="searchMemoryUserId"
|
||||||
class="mb-2"></v-text-field>
|
@update:model-value="onSearchMemoryUserIdInput" :label="tm('search.userIdLabel')" variant="outlined" density="compact" hide-details
|
||||||
<v-text-field v-model="searchQuery" :label="tm('search.queryLabel')" variant="outlined" density="compact" hide-details
|
class="mb-2" clearable></v-text-field>
|
||||||
@keyup.enter="searchMemory" class="mb-2"></v-text-field>
|
<v-text-field :model-value="searchQuery"
|
||||||
|
@update:model-value="onSearchQueryInput" :label="tm('search.queryLabel')" variant="outlined" density="compact" hide-details
|
||||||
|
@keyup.enter="searchMemory" class="mb-2" clearable></v-text-field>
|
||||||
<v-btn color="info" @click="searchMemory" :loading="isSearching" variant="tonal">
|
<v-btn color="info" @click="searchMemory" :loading="isSearching" variant="tonal">
|
||||||
<v-icon start>mdi-text-search</v-icon>
|
<v-icon start>mdi-text-search</v-icon>
|
||||||
{{ tm('search.searchButton') }}
|
{{ tm('search.searchButton') }}
|
||||||
@@ -254,6 +256,7 @@
|
|||||||
import axios from 'axios';
|
import axios from 'axios';
|
||||||
// import * as d3 from "d3"; // npm install d3
|
// import * as d3 from "d3"; // npm install d3
|
||||||
import { useModuleI18n } from '@/i18n/composables';
|
import { useModuleI18n } from '@/i18n/composables';
|
||||||
|
import { normalizeTextInput } from '@/utils/inputValue';
|
||||||
|
|
||||||
export default {
|
export default {
|
||||||
name: 'LongTermMemory',
|
name: 'LongTermMemory',
|
||||||
@@ -336,9 +339,16 @@ export default {
|
|||||||
this.searchResults = [];
|
this.searchResults = [];
|
||||||
},
|
},
|
||||||
methods: {
|
methods: {
|
||||||
|
onSearchMemoryUserIdInput(value) {
|
||||||
|
this.searchMemoryUserId = normalizeTextInput(value);
|
||||||
|
},
|
||||||
|
onSearchQueryInput(value) {
|
||||||
|
this.searchQuery = normalizeTextInput(value);
|
||||||
|
},
|
||||||
// 添加搜索记忆方法
|
// 添加搜索记忆方法
|
||||||
searchMemory() {
|
searchMemory() {
|
||||||
if (!this.searchQuery.trim()) {
|
const query = normalizeTextInput(this.searchQuery).trim();
|
||||||
|
if (!query) {
|
||||||
this.$toast.warning(this.tm('messages.searchQueryRequired'));
|
this.$toast.warning(this.tm('messages.searchQueryRequired'));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -349,12 +359,13 @@ export default {
|
|||||||
|
|
||||||
// 构建查询参数
|
// 构建查询参数
|
||||||
const params = {
|
const params = {
|
||||||
query: this.searchQuery
|
query
|
||||||
};
|
};
|
||||||
|
|
||||||
// 如果有选择用户ID,也加入查询参数
|
// 如果有选择用户ID,也加入查询参数
|
||||||
if (this.searchMemoryUserId) {
|
const normalizedUserId = normalizeTextInput(this.searchMemoryUserId).trim();
|
||||||
params.user_id = this.searchMemoryUserId;
|
if (normalizedUserId) {
|
||||||
|
params.user_id = normalizedUserId;
|
||||||
}
|
}
|
||||||
|
|
||||||
axios.get('/api/plug/alkaid/ltm/graph/search', { params })
|
axios.get('/api/plug/alkaid/ltm/graph/search', { params })
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import PluginSortControl from "@/components/extension/PluginSortControl.vue";
|
|||||||
import ExtensionCard from "@/components/shared/ExtensionCard.vue";
|
import ExtensionCard from "@/components/shared/ExtensionCard.vue";
|
||||||
import StyledMenu from "@/components/shared/StyledMenu.vue";
|
import StyledMenu from "@/components/shared/StyledMenu.vue";
|
||||||
import defaultPluginIcon from "@/assets/images/plugin_icon.png";
|
import defaultPluginIcon from "@/assets/images/plugin_icon.png";
|
||||||
|
import { normalizeTextInput } from "@/utils/inputValue";
|
||||||
|
|
||||||
const props = defineProps({
|
const props = defineProps({
|
||||||
state: {
|
state: {
|
||||||
@@ -164,10 +165,12 @@ const {
|
|||||||
|
|
||||||
<div class="d-flex align-center flex-wrap ml-auto" style="gap: 8px">
|
<div class="d-flex align-center flex-wrap ml-auto" style="gap: 8px">
|
||||||
<v-text-field
|
<v-text-field
|
||||||
v-model="pluginSearch"
|
:model-value="pluginSearch"
|
||||||
|
@update:model-value="pluginSearch = normalizeTextInput($event)"
|
||||||
density="compact"
|
density="compact"
|
||||||
:label="tm('search.placeholder')"
|
:label="tm('search.placeholder')"
|
||||||
prepend-inner-icon="mdi-magnify"
|
prepend-inner-icon="mdi-magnify"
|
||||||
|
clearable
|
||||||
variant="solo-filled"
|
variant="solo-filled"
|
||||||
flat
|
flat
|
||||||
hide-details
|
hide-details
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import MarketPluginCard from "@/components/extension/MarketPluginCard.vue";
|
|||||||
import PluginSortControl from "@/components/extension/PluginSortControl.vue";
|
import PluginSortControl from "@/components/extension/PluginSortControl.vue";
|
||||||
import defaultPluginIcon from "@/assets/images/plugin_icon.png";
|
import defaultPluginIcon from "@/assets/images/plugin_icon.png";
|
||||||
import { computed } from "vue";
|
import { computed } from "vue";
|
||||||
|
import { normalizeTextInput } from "@/utils/inputValue";
|
||||||
|
|
||||||
const props = defineProps({
|
const props = defineProps({
|
||||||
state: {
|
state: {
|
||||||
@@ -212,11 +213,13 @@ const marketSortItems = computed(() => [
|
|||||||
</div>
|
</div>
|
||||||
|
|
||||||
<v-text-field
|
<v-text-field
|
||||||
v-model="marketSearch"
|
:model-value="marketSearch"
|
||||||
|
@update:model-value="marketSearch = normalizeTextInput($event)"
|
||||||
class="ml-auto"
|
class="ml-auto"
|
||||||
density="compact"
|
density="compact"
|
||||||
:label="tm('search.marketPlaceholder')"
|
:label="tm('search.marketPlaceholder')"
|
||||||
prepend-inner-icon="mdi-magnify"
|
prepend-inner-icon="mdi-magnify"
|
||||||
|
clearable
|
||||||
variant="solo-filled"
|
variant="solo-filled"
|
||||||
flat
|
flat
|
||||||
hide-details
|
hide-details
|
||||||
|
|||||||
@@ -245,7 +245,7 @@ export default defineConfig({
|
|||||||
next: '下一篇'
|
next: '下一篇'
|
||||||
},
|
},
|
||||||
editLink: {
|
editLink: {
|
||||||
pattern: 'https://github.com/AstrBotdevs/AstrBot-docs/edit/v4/:path',
|
pattern: 'https://github.com/AstrBotdevs/AstrBot/edit/master/docs/:path',
|
||||||
text: '发现文档有问题?在 GitHub 上编辑此页',
|
text: '发现文档有问题?在 GitHub 上编辑此页',
|
||||||
},
|
},
|
||||||
logo: '/logo_prod.png',
|
logo: '/logo_prod.png',
|
||||||
@@ -484,7 +484,7 @@ export default defineConfig({
|
|||||||
next: 'Next'
|
next: 'Next'
|
||||||
},
|
},
|
||||||
editLink: {
|
editLink: {
|
||||||
pattern: 'https://github.com/AstrBotdevs/AstrBot-docs/edit/v4/:path',
|
pattern: 'https://github.com/AstrBotdevs/AstrBot/edit/master/docs/:path',
|
||||||
text: 'Edit this page on GitHub',
|
text: 'Edit this page on GitHub',
|
||||||
},
|
},
|
||||||
logo: '/logo_prod.png',
|
logo: '/logo_prod.png',
|
||||||
|
|||||||
@@ -14,8 +14,6 @@ Welcome to submit Issues or Pull Requests:
|
|||||||
|
|
||||||
- [AstrBotDevs/AstrBot](https://github.com/AstrBotDevs/AstrBot)
|
- [AstrBotDevs/AstrBot](https://github.com/AstrBotDevs/AstrBot)
|
||||||
|
|
||||||
- [AstrBotDevs/AstrBot-Docs](https://github.com/AstrBotDevs/AstrBot-docs)
|
|
||||||
|
|
||||||
### Tencent QQ Groups
|
### Tencent QQ Groups
|
||||||
|
|
||||||
> - All groups are available to join. If you find that the group size is below the limit, please feel free to join.
|
> - All groups are available to join. If you find that the group size is below the limit, please feel free to join.
|
||||||
|
|||||||
@@ -128,6 +128,9 @@ The default AstrBot configuration is as follows:
|
|||||||
"telegram": {
|
"telegram": {
|
||||||
"pre_ack_emoji": {"enable": False, "emojis": ["✍️"]},
|
"pre_ack_emoji": {"enable": False, "emojis": ["✍️"]},
|
||||||
},
|
},
|
||||||
|
"discord": {
|
||||||
|
"pre_ack_emoji": {"enable": False, "emojis": ["🤔"]},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"wake_prefix": ["/"],
|
"wake_prefix": ["/"],
|
||||||
"log_level": "INFO",
|
"log_level": "INFO",
|
||||||
@@ -511,6 +514,11 @@ When enabled, AstrBot sends a pre-reply emoji before requesting the LLM to infor
|
|||||||
- `enable`: Whether to enable pre-reply emojis for Telegram messages. Default is `false`.
|
- `enable`: Whether to enable pre-reply emojis for Telegram messages. Default is `false`.
|
||||||
- `emojis`: List of pre-reply emojis. Default is `["✍️"]`. Telegram only supports a fixed set of reactions; refer to [reactions.txt](https://gist.github.com/Soulter/3f22c8e5f9c7e152e967e8bc28c97fc9).
|
- `emojis`: List of pre-reply emojis. Default is `["✍️"]`. Telegram only supports a fixed set of reactions; refer to [reactions.txt](https://gist.github.com/Soulter/3f22c8e5f9c7e152e967e8bc28c97fc9).
|
||||||
|
|
||||||
|
##### discord
|
||||||
|
|
||||||
|
- `enable`: Whether to enable pre-reply emojis for Discord messages. Default is `false`.
|
||||||
|
- `emojis`: List of pre-reply emojis. Default is `["🤔"]`. Refer to [Discord Reaction FAQ](https://support.discord.com/hc/en-us/articles/12102061808663-Reactions-and-Super-Reactions-FAQ).
|
||||||
|
|
||||||
### `wake_prefix`
|
### `wake_prefix`
|
||||||
|
|
||||||
Wake prefix. Default is `/`. When a message starts with `/`, AstrBot is awakened.
|
Wake prefix. Default is `/`. When a message starts with `/`, AstrBot is awakened.
|
||||||
|
|||||||
@@ -29,8 +29,6 @@ https://discord.gg/PxgzhmxJ
|
|||||||
|
|
||||||
- [AstrBotDevs/AstrBot](https://github.com/AstrBotDevs/AstrBot)
|
- [AstrBotDevs/AstrBot](https://github.com/AstrBotDevs/AstrBot)
|
||||||
|
|
||||||
- [AstrBotDevs/AstrBot-Docs](https://github.com/AstrBotDevs/AstrBot-docs)
|
|
||||||
|
|
||||||
## 成为 AstrBot 组织成员
|
## 成为 AstrBot 组织成员
|
||||||
|
|
||||||
欢迎加入我们!
|
欢迎加入我们!
|
||||||
|
|||||||
@@ -128,6 +128,9 @@ AstrBot 默认配置如下:
|
|||||||
"telegram": {
|
"telegram": {
|
||||||
"pre_ack_emoji": {"enable": False, "emojis": ["✍️"]},
|
"pre_ack_emoji": {"enable": False, "emojis": ["✍️"]},
|
||||||
},
|
},
|
||||||
|
"discord": {
|
||||||
|
"pre_ack_emoji": {"enable": False, "emojis": ["🤔"]},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"wake_prefix": ["/"],
|
"wake_prefix": ["/"],
|
||||||
"log_level": "INFO",
|
"log_level": "INFO",
|
||||||
@@ -506,11 +509,16 @@ AstrBot WebUI 配置。
|
|||||||
- `enable`: 是否启用飞书消息预回复表情。默认为 `false`。
|
- `enable`: 是否启用飞书消息预回复表情。默认为 `false`。
|
||||||
- `emojis`: 预回复的表情列表。默认为 `["Typing"]`。表情枚举名参考:[表情文案说明](https://open.feishu.cn/document/server-docs/im-v1/message-reaction/emojis-introduce)
|
- `emojis`: 预回复的表情列表。默认为 `["Typing"]`。表情枚举名参考:[表情文案说明](https://open.feishu.cn/document/server-docs/im-v1/message-reaction/emojis-introduce)
|
||||||
|
|
||||||
#### telegram
|
##### telegram
|
||||||
|
|
||||||
- `enable`: 是否启用 Telegram 消息预回复表情。默认为 `false`。
|
- `enable`: 是否启用 Telegram 消息预回复表情。默认为 `false`。
|
||||||
- `emojis`: 预回复的表情列表。默认为 `["✍️"]`。Telegram 仅支持固定反应集合,参考:[reactions.txt](https://gist.github.com/Soulter/3f22c8e5f9c7e152e967e8bc28c97fc9)
|
- `emojis`: 预回复的表情列表。默认为 `["✍️"]`。Telegram 仅支持固定反应集合,参考:[reactions.txt](https://gist.github.com/Soulter/3f22c8e5f9c7e152e967e8bc28c97fc9)
|
||||||
|
|
||||||
|
##### discord
|
||||||
|
|
||||||
|
- `enable`: 是否启用 Discord 消息预回复表情。默认为 `false`。
|
||||||
|
- `emojis`: 预回复的表情列表。默认为 `["🤔"]`。Discord反应支持参考:[Discord Reaction FAQ](https://support.discord.com/hc/en-us/articles/12102061808663-Reactions-and-Super-Reactions-FAQ)
|
||||||
|
|
||||||
### `wake_prefix`
|
### `wake_prefix`
|
||||||
|
|
||||||
唤醒前缀。默认为 `/`。当消息以 `/` 开头时,AstrBot 会被唤醒。
|
唤醒前缀。默认为 `/`。当消息以 `/` 开头时,AstrBot 会被唤醒。
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
|||||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||||
from astrbot.core.star.star import star_registry
|
from astrbot.core.star.star import star_registry
|
||||||
from astrbot.core.star.star_handler import star_handlers_registry
|
from astrbot.core.star.star_handler import star_handlers_registry
|
||||||
|
from astrbot.core.utils.pip_installer import PipInstallError
|
||||||
from astrbot.dashboard.routes.plugin import PluginRoute
|
from astrbot.dashboard.routes.plugin import PluginRoute
|
||||||
from astrbot.dashboard.server import AstrBotDashboard
|
from astrbot.dashboard.server import AstrBotDashboard
|
||||||
from tests.fixtures.helpers import (
|
from tests.fixtures.helpers import (
|
||||||
@@ -359,6 +360,35 @@ async def test_do_update(
|
|||||||
assert os.path.exists(release_path)
|
assert os.path.exists(release_path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_install_pip_package_returns_pip_install_error_message(
|
||||||
|
app: Quart,
|
||||||
|
authenticated_header: dict,
|
||||||
|
monkeypatch,
|
||||||
|
):
|
||||||
|
test_client = app.test_client()
|
||||||
|
|
||||||
|
async def mock_pip_install(*args, **kwargs):
|
||||||
|
del args, kwargs
|
||||||
|
raise PipInstallError("install failed", code=2)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"astrbot.dashboard.routes.update.pip_installer.install",
|
||||||
|
mock_pip_install,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await test_client.post(
|
||||||
|
"/api/update/pip-install",
|
||||||
|
headers=authenticated_header,
|
||||||
|
json={"package": "demo-package"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = await response.get_json()
|
||||||
|
assert data["status"] == "error"
|
||||||
|
assert data["message"] == "install failed"
|
||||||
|
|
||||||
|
|
||||||
class _FakeNeoSkills:
|
class _FakeNeoSkills:
|
||||||
async def list_candidates(self, **kwargs):
|
async def list_candidates(self, **kwargs):
|
||||||
_ = kwargs
|
_ = kwargs
|
||||||
|
|||||||
@@ -0,0 +1,266 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from astrbot.core.utils import core_constraints as core_constraints_module
|
||||||
|
from astrbot.core.utils import requirements_utils
|
||||||
|
from astrbot.core.utils.core_constraints import CoreConstraintsProvider
|
||||||
|
|
||||||
|
|
||||||
|
def test_requirements_utils_parse_package_install_input_collects_specs_and_names():
|
||||||
|
parsed = requirements_utils.parse_package_install_input(
|
||||||
|
"--index-url https://example.com/simple demo-package\nanother-package>=1.0\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert parsed.specs == (
|
||||||
|
"--index-url",
|
||||||
|
"https://example.com/simple",
|
||||||
|
"demo-package",
|
||||||
|
"another-package>=1.0",
|
||||||
|
)
|
||||||
|
assert parsed.requirement_names == {"demo-package", "another-package"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_core_constraints_provider_writes_constraints_file_from_fallback_distribution(
|
||||||
|
monkeypatch,
|
||||||
|
):
|
||||||
|
class FakeFallbackDistribution:
|
||||||
|
metadata = {"Name": "AstrBot-App"}
|
||||||
|
requires = ["shared-lib>=1.0"]
|
||||||
|
|
||||||
|
def read_text(self, name):
|
||||||
|
if name == "top_level.txt":
|
||||||
|
return "astrbot\n"
|
||||||
|
return ""
|
||||||
|
|
||||||
|
fake_distribution = FakeFallbackDistribution()
|
||||||
|
|
||||||
|
def mock_distribution(name):
|
||||||
|
if name == "AstrBot":
|
||||||
|
raise core_constraints_module.importlib_metadata.PackageNotFoundError
|
||||||
|
if name == "AstrBot-App":
|
||||||
|
return fake_distribution
|
||||||
|
raise core_constraints_module.importlib_metadata.PackageNotFoundError
|
||||||
|
|
||||||
|
def mock_distributions(path=None):
|
||||||
|
del path
|
||||||
|
return [fake_distribution]
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
core_constraints_module.importlib_metadata,
|
||||||
|
"distribution",
|
||||||
|
mock_distribution,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
core_constraints_module.importlib_metadata,
|
||||||
|
"distributions",
|
||||||
|
mock_distributions,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
core_constraints_module,
|
||||||
|
"collect_installed_distribution_versions",
|
||||||
|
lambda paths: {"shared-lib": "2.0"},
|
||||||
|
)
|
||||||
|
|
||||||
|
core_constraints_module._get_core_constraints.cache_clear()
|
||||||
|
try:
|
||||||
|
provider = CoreConstraintsProvider(None)
|
||||||
|
with provider.constraints_file() as constraints_path:
|
||||||
|
assert constraints_path is not None
|
||||||
|
assert (
|
||||||
|
Path(constraints_path).read_text(encoding="utf-8") == "shared-lib==2.0"
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
core_constraints_module._get_core_constraints.cache_clear()
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_core_dist_name_skips_distribution_without_name(monkeypatch):
|
||||||
|
class NamelessDistribution:
|
||||||
|
metadata = {}
|
||||||
|
|
||||||
|
def read_text(self, name):
|
||||||
|
if name == "top_level.txt":
|
||||||
|
return "astrbot\n"
|
||||||
|
return ""
|
||||||
|
|
||||||
|
class NamedDistribution:
|
||||||
|
metadata = {"Name": "AstrBot-App"}
|
||||||
|
|
||||||
|
def read_text(self, name):
|
||||||
|
if name == "top_level.txt":
|
||||||
|
return "astrbot\n"
|
||||||
|
return ""
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
core_constraints_module.importlib_metadata,
|
||||||
|
"distribution",
|
||||||
|
lambda name: (_ for _ in ()).throw(
|
||||||
|
core_constraints_module.importlib_metadata.PackageNotFoundError
|
||||||
|
),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
core_constraints_module.importlib_metadata,
|
||||||
|
"distributions",
|
||||||
|
lambda: [NamelessDistribution(), NamedDistribution()],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert core_constraints_module._resolve_core_dist_name(None) == "AstrBot-App"
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_missing_requirements_returns_none_when_precheck_gate_fails(
|
||||||
|
monkeypatch,
|
||||||
|
tmp_path,
|
||||||
|
):
|
||||||
|
requirements_path = tmp_path / "requirements.txt"
|
||||||
|
requirements_path.write_text("demo-package\n", encoding="utf-8")
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
requirements_utils,
|
||||||
|
"_load_requirement_lines_for_precheck",
|
||||||
|
lambda path: (False, None),
|
||||||
|
)
|
||||||
|
|
||||||
|
missing = requirements_utils.find_missing_requirements(str(requirements_path))
|
||||||
|
|
||||||
|
assert missing is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_package_install_input_tracks_only_named_direct_references():
|
||||||
|
named = requirements_utils.parse_package_install_input(
|
||||||
|
"git+https://example.com/demo.git#egg=demo-package"
|
||||||
|
)
|
||||||
|
unnamed = requirements_utils.parse_package_install_input(
|
||||||
|
"git+https://example.com/demo.git"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert named.requirement_names == {"demo-package"}
|
||||||
|
assert unnamed.requirement_names == set()
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_missing_requirements_or_raise_uses_requirements_exception(tmp_path):
|
||||||
|
requirements_path = tmp_path / "requirements.txt"
|
||||||
|
requirements_path.write_text("-e ../sharedlib\n", encoding="utf-8")
|
||||||
|
|
||||||
|
with pytest.raises(requirements_utils.RequirementsPrecheckFailed):
|
||||||
|
requirements_utils.find_missing_requirements_or_raise(str(requirements_path))
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_missing_requirements_logs_path_and_reason_on_precheck_fallback(
|
||||||
|
monkeypatch,
|
||||||
|
tmp_path,
|
||||||
|
):
|
||||||
|
requirements_path = tmp_path / "requirements.txt"
|
||||||
|
requirements_path.write_text("git+https://example.com/demo.git\n", encoding="utf-8")
|
||||||
|
warning_logs = []
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"astrbot.core.utils.requirements_utils.logger.warning",
|
||||||
|
lambda line, *args: warning_logs.append(line % args if args else line),
|
||||||
|
)
|
||||||
|
|
||||||
|
missing = requirements_utils.find_missing_requirements(str(requirements_path))
|
||||||
|
|
||||||
|
assert missing is None
|
||||||
|
assert any(str(requirements_path) in log for log in warning_logs)
|
||||||
|
assert any("direct reference" in log for log in warning_logs)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_requirement_lines_for_precheck_uses_parse_requirement_line_result(
|
||||||
|
monkeypatch,
|
||||||
|
tmp_path,
|
||||||
|
):
|
||||||
|
requirements_path = tmp_path / "requirements.txt"
|
||||||
|
requirements_path.write_text("git+https://example.com/demo.git\n", encoding="utf-8")
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
requirements_utils,
|
||||||
|
"_parse_requirement_line",
|
||||||
|
lambda line: ("demo-package", None) if line.startswith("git+") else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
can_precheck, requirement_lines = (
|
||||||
|
requirements_utils._load_requirement_lines_for_precheck(str(requirements_path))
|
||||||
|
)
|
||||||
|
|
||||||
|
assert can_precheck is True
|
||||||
|
assert requirement_lines == ["git+https://example.com/demo.git"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_collect_installed_distribution_versions_skips_nameless_distribution(
|
||||||
|
monkeypatch,
|
||||||
|
):
|
||||||
|
class NamelessDistribution:
|
||||||
|
metadata = {}
|
||||||
|
version = "1.0"
|
||||||
|
|
||||||
|
class NamedDistribution:
|
||||||
|
metadata = {"Name": "demo-package"}
|
||||||
|
version = "2.0"
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
requirements_utils.importlib_metadata,
|
||||||
|
"distributions",
|
||||||
|
lambda path: [NamelessDistribution(), NamedDistribution()],
|
||||||
|
)
|
||||||
|
|
||||||
|
installed = requirements_utils.collect_installed_distribution_versions(
|
||||||
|
["/tmp/test"]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert installed == {"demo-package": "2.0"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_core_constraints_logs_resolution_step_context(monkeypatch):
|
||||||
|
warning_logs = []
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
core_constraints_module,
|
||||||
|
"_resolve_core_dist_name",
|
||||||
|
lambda core_dist_name: (_ for _ in ()).throw(RuntimeError("boom")),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"astrbot.core.utils.core_constraints.logger.warning",
|
||||||
|
lambda line, *args: warning_logs.append(line % args if args else line),
|
||||||
|
)
|
||||||
|
|
||||||
|
core_constraints_module._get_core_constraints.cache_clear()
|
||||||
|
try:
|
||||||
|
constraints = core_constraints_module._get_core_constraints(None)
|
||||||
|
finally:
|
||||||
|
core_constraints_module._get_core_constraints.cache_clear()
|
||||||
|
|
||||||
|
assert constraints == ()
|
||||||
|
assert any("解析核心分发名称失败" in log for log in warning_logs)
|
||||||
|
|
||||||
|
|
||||||
|
def test_iter_requirements_supports_direct_line_input():
|
||||||
|
parsed = list(
|
||||||
|
requirements_utils.iter_requirements(
|
||||||
|
lines=["demo-package>=1.0", 'other-package; sys_platform == "win32"']
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert parsed == [
|
||||||
|
("demo-package", requirements_utils.Requirement("demo-package>=1.0").specifier)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_requirement_name_and_spec_preserves_direct_reference_rules():
|
||||||
|
named = requirements_utils._parse_requirement_name_and_spec(
|
||||||
|
"git+https://example.com/demo.git#egg=demo-package"
|
||||||
|
)
|
||||||
|
unnamed = requirements_utils._parse_requirement_name_and_spec(
|
||||||
|
"git+https://example.com/demo.git"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert named == ("demo-package", None)
|
||||||
|
assert unnamed == (None, None)
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_requirement_name_and_spec_handles_plain_requirement_token():
|
||||||
|
parsed = requirements_utils._parse_requirement_name_and_spec("demo-package>=1.0")
|
||||||
|
|
||||||
|
assert parsed == (
|
||||||
|
"demo-package",
|
||||||
|
requirements_utils.Requirement("demo-package>=1.0").specifier,
|
||||||
|
)
|
||||||
+1304
-1
File diff suppressed because it is too large
Load Diff
+426
-189
@@ -1,235 +1,472 @@
|
|||||||
import sys
|
import asyncio
|
||||||
from asyncio import Queue
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import MagicMock
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import yaml
|
||||||
|
|
||||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
from astrbot.core.star.star_manager import PluginDependencyInstallError, PluginManager
|
||||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
from astrbot.core.utils.pip_installer import PipInstallError
|
||||||
from astrbot.core.star.context import Context
|
|
||||||
from astrbot.core.star.star import star_map, star_registry
|
|
||||||
from astrbot.core.star.star_handler import star_handlers_registry
|
|
||||||
from astrbot.core.star.star_manager import PluginManager
|
|
||||||
|
|
||||||
|
# --- Test Data & Helpers ---
|
||||||
|
|
||||||
def _clear_module_cache() -> None:
|
|
||||||
"""Clear module cache for data module tree to ensure test isolation."""
|
|
||||||
modules_to_remove = [
|
|
||||||
key for key in sys.modules if key == "data" or key.startswith("data.")
|
|
||||||
]
|
|
||||||
for key in modules_to_remove:
|
|
||||||
del sys.modules[key]
|
|
||||||
|
|
||||||
|
|
||||||
def _clear_registry(plugin_name: str) -> None:
|
|
||||||
"""Clear plugin from global registries."""
|
|
||||||
# Clear star_registry (list)
|
|
||||||
star_registry[:] = [md for md in star_registry if md.name != plugin_name]
|
|
||||||
# Clear star_map (dict)
|
|
||||||
keys_to_remove = [
|
|
||||||
key for key, md in star_map.items() if md.name == plugin_name
|
|
||||||
]
|
|
||||||
for key in keys_to_remove:
|
|
||||||
del star_map[key]
|
|
||||||
# Clear star_handlers_registry (StarHandlerRegistry)
|
|
||||||
for handler in list(star_handlers_registry):
|
|
||||||
if plugin_name in (handler.handler_module_path or ""):
|
|
||||||
star_handlers_registry.remove(handler)
|
|
||||||
|
|
||||||
TEST_PLUGIN_REPO = "https://github.com/Soulter/helloworld"
|
|
||||||
TEST_PLUGIN_DIR = "helloworld"
|
|
||||||
TEST_PLUGIN_NAME = "helloworld"
|
TEST_PLUGIN_NAME = "helloworld"
|
||||||
|
TEST_PLUGIN_REPO = "https://github.com/AstrBotDevs/astrbot_plugin_helloworld"
|
||||||
|
TEST_PLUGIN_DIR = "helloworld"
|
||||||
|
|
||||||
|
|
||||||
def _write_local_test_plugin(plugin_dir: Path, repo_url: str) -> None:
|
class MockStar:
|
||||||
plugin_dir.mkdir(parents=True, exist_ok=True)
|
def __init__(self):
|
||||||
(plugin_dir / "metadata.yaml").write_text(
|
self.root_dir_name = TEST_PLUGIN_DIR
|
||||||
"\n".join(
|
self.name = TEST_PLUGIN_NAME
|
||||||
[
|
self.repo = TEST_PLUGIN_REPO
|
||||||
f"name: {TEST_PLUGIN_NAME}",
|
self.reserved = False
|
||||||
"author: AstrBot Team",
|
self.info = {"repo": TEST_PLUGIN_REPO, "readme": ""}
|
||||||
"desc: Local test plugin",
|
|
||||||
"version: 1.0.0",
|
|
||||||
f"repo: {repo_url}",
|
def _write_local_test_plugin(plugin_path: Path, repo_url: str):
|
||||||
],
|
"""Creates a minimal valid plugin structure."""
|
||||||
)
|
plugin_path.mkdir(parents=True, exist_ok=True)
|
||||||
+ "\n",
|
metadata = {
|
||||||
encoding="utf-8",
|
"name": TEST_PLUGIN_NAME,
|
||||||
)
|
"repo": repo_url,
|
||||||
(plugin_dir / "main.py").write_text(
|
"version": "1.0.0",
|
||||||
"\n".join(
|
"author": "AstrBot Team",
|
||||||
[
|
"desc": "Local test plugin",
|
||||||
"from astrbot.api import star",
|
}
|
||||||
"",
|
with open(plugin_path / "info.yaml", "w", encoding="utf-8") as f:
|
||||||
"class Main(star.Star):",
|
yaml.dump(metadata, f)
|
||||||
" pass",
|
with open(plugin_path / "main.py", "w", encoding="utf-8") as f:
|
||||||
"",
|
f.write("from astrbot.api.star import Star, Context, StarManager\n")
|
||||||
],
|
f.write("@StarManager.register\n")
|
||||||
),
|
f.write("class HelloWorld(Star):\n")
|
||||||
encoding="utf-8",
|
f.write(" def __init__(self, context: Context): ...\n")
|
||||||
|
|
||||||
|
|
||||||
|
def _write_requirements(plugin_path: Path):
|
||||||
|
"""Creates a requirements.txt file."""
|
||||||
|
with open(plugin_path / "requirements.txt", "w", encoding="utf-8") as f:
|
||||||
|
f.write("networkx\n")
|
||||||
|
|
||||||
|
|
||||||
|
def _clear_module_cache():
|
||||||
|
"""Clear test-specific modules from sys.modules to allow reloading."""
|
||||||
|
import sys
|
||||||
|
|
||||||
|
to_del = [m for m in sys.modules if m.startswith("data.plugins.helloworld")]
|
||||||
|
for m in to_del:
|
||||||
|
del sys.modules[m]
|
||||||
|
|
||||||
|
|
||||||
|
def _build_load_mock(events):
|
||||||
|
async def mock_load(specified_dir_name=None, ignore_version_check=False):
|
||||||
|
del ignore_version_check
|
||||||
|
events.append(("load", specified_dir_name or TEST_PLUGIN_DIR))
|
||||||
|
return True, ""
|
||||||
|
|
||||||
|
return mock_load
|
||||||
|
|
||||||
|
|
||||||
|
def _build_reload_mock(events):
|
||||||
|
async def mock_reload(specified_dir_name=None):
|
||||||
|
events.append(("reload", specified_dir_name or TEST_PLUGIN_DIR))
|
||||||
|
return True, ""
|
||||||
|
|
||||||
|
return mock_reload
|
||||||
|
|
||||||
|
|
||||||
|
def _build_dependency_install_mock(events, fail: bool):
|
||||||
|
async def mock_install_requirements(
|
||||||
|
*, requirements_path: str = None, package_name: str = None, **kwargs
|
||||||
|
):
|
||||||
|
del kwargs
|
||||||
|
if requirements_path:
|
||||||
|
events.append(("deps", str(requirements_path)))
|
||||||
|
if package_name:
|
||||||
|
events.append(("deps_pkg", package_name))
|
||||||
|
if fail:
|
||||||
|
raise Exception("pip failed")
|
||||||
|
|
||||||
|
return mock_install_requirements
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_missing_requirements(monkeypatch, missing: set[str]):
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"astrbot.core.star.star_manager.find_missing_requirements_or_raise",
|
||||||
|
lambda requirements_path: missing,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture
|
def _mock_precheck_fails(monkeypatch):
|
||||||
async def plugin_manager_pm(tmp_path, monkeypatch):
|
from astrbot.core import RequirementsPrecheckFailed
|
||||||
|
|
||||||
|
def mock_fail(requirements_path):
|
||||||
|
raise RequirementsPrecheckFailed("mock precheck failure")
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"astrbot.core.star.star_manager.find_missing_requirements_or_raise",
|
||||||
|
mock_fail,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Fixtures ---
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def plugin_manager_pm(tmp_path, monkeypatch):
|
||||||
"""Provides a fully isolated PluginManager instance for testing."""
|
"""Provides a fully isolated PluginManager instance for testing."""
|
||||||
# Clear module cache before setup to ensure isolation
|
# Clear module cache before setup to ensure isolation
|
||||||
_clear_module_cache()
|
_clear_module_cache()
|
||||||
|
|
||||||
test_root = tmp_path / "astrbot_root"
|
plugin_dir = tmp_path / "astrbot_root" / "data" / "plugins"
|
||||||
data_dir = test_root / "data"
|
plugin_dir.mkdir(parents=True, exist_ok=True)
|
||||||
plugin_dir = data_dir / "plugins"
|
|
||||||
config_dir = data_dir / "config"
|
|
||||||
temp_dir = data_dir / "temp"
|
|
||||||
for path in (plugin_dir, config_dir, temp_dir):
|
|
||||||
path.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Ensure `import data.plugins.<plugin>.main` resolves to this temp root.
|
class MockContext:
|
||||||
(data_dir / "__init__.py").write_text("", encoding="utf-8")
|
def __init__(self):
|
||||||
(plugin_dir / "__init__.py").write_text("", encoding="utf-8")
|
self.stars = []
|
||||||
|
|
||||||
# Use monkeypatch for both env var and sys.path to ensure proper cleanup
|
def get_all_stars(self):
|
||||||
monkeypatch.setenv("ASTRBOT_ROOT", str(test_root))
|
return self.stars
|
||||||
monkeypatch.syspath_prepend(str(test_root))
|
|
||||||
|
|
||||||
# Create fresh, isolated instances for the context
|
def get_registered_star(self, name):
|
||||||
event_queue = Queue()
|
for s in self.stars:
|
||||||
config = AstrBotConfig()
|
if s.root_dir_name == name or s.name == name:
|
||||||
db = SQLiteDatabase(str(data_dir / "test_db.db"))
|
return s
|
||||||
config.plugin_store_path = str(plugin_dir)
|
return None
|
||||||
|
|
||||||
provider_manager = MagicMock()
|
mock_context = MockContext()
|
||||||
platform_manager = MagicMock()
|
mock_config = {}
|
||||||
conversation_manager = MagicMock()
|
pm = PluginManager(mock_context, mock_config)
|
||||||
message_history_manager = MagicMock()
|
|
||||||
persona_manager = MagicMock()
|
|
||||||
persona_manager.personas_v3 = []
|
|
||||||
astrbot_config_mgr = MagicMock()
|
|
||||||
knowledge_base_manager = MagicMock()
|
|
||||||
cron_manager = MagicMock()
|
|
||||||
|
|
||||||
star_context = Context(
|
# Patch paths to use tmp_path
|
||||||
event_queue=event_queue,
|
monkeypatch.setattr(pm, "plugin_store_path", str(plugin_dir))
|
||||||
config=config,
|
monkeypatch.setattr(
|
||||||
db=db,
|
"astrbot.core.star.star_manager.get_astrbot_plugin_path",
|
||||||
provider_manager=provider_manager,
|
lambda: str(plugin_dir),
|
||||||
platform_manager=platform_manager,
|
|
||||||
conversation_manager=conversation_manager,
|
|
||||||
message_history_manager=message_history_manager,
|
|
||||||
persona_manager=persona_manager,
|
|
||||||
astrbot_config_mgr=astrbot_config_mgr,
|
|
||||||
knowledge_base_manager=knowledge_base_manager,
|
|
||||||
cron_manager=cron_manager,
|
|
||||||
subagent_orchestrator=None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
manager = PluginManager(star_context, config)
|
return pm
|
||||||
try:
|
|
||||||
yield manager
|
|
||||||
finally:
|
|
||||||
# Cleanup global registries and module cache
|
|
||||||
_clear_registry(TEST_PLUGIN_NAME)
|
|
||||||
_clear_module_cache()
|
|
||||||
await db.engine.dispose()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def local_updator(plugin_manager_pm: PluginManager, monkeypatch):
|
def local_updator(plugin_manager_pm):
|
||||||
plugin_path = Path(plugin_manager_pm.plugin_store_path) / TEST_PLUGIN_DIR
|
"""Helper to setup a local plugin directory simulating a download."""
|
||||||
|
path = Path(plugin_manager_pm.plugin_store_path) / TEST_PLUGIN_DIR
|
||||||
|
_write_local_test_plugin(path, TEST_PLUGIN_REPO)
|
||||||
|
return path
|
||||||
|
|
||||||
async def mock_install(repo_url: str, proxy=""): # noqa: ARG001
|
|
||||||
if repo_url != TEST_PLUGIN_REPO:
|
# --- Tests ---
|
||||||
raise Exception("Repo not found")
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize("dependency_install_fails", [False, True])
|
||||||
|
async def test_install_plugin_dependency_install_flow(
|
||||||
|
plugin_manager_pm: PluginManager, monkeypatch, dependency_install_fails: bool
|
||||||
|
):
|
||||||
|
plugin_path = Path(plugin_manager_pm.plugin_store_path) / TEST_PLUGIN_DIR
|
||||||
|
events = []
|
||||||
|
_mock_missing_requirements(monkeypatch, {"networkx"})
|
||||||
|
|
||||||
|
async def mock_install(repo_url: str, proxy=""):
|
||||||
|
assert repo_url == TEST_PLUGIN_REPO
|
||||||
_write_local_test_plugin(plugin_path, repo_url)
|
_write_local_test_plugin(plugin_path, repo_url)
|
||||||
|
_write_requirements(plugin_path)
|
||||||
return str(plugin_path)
|
return str(plugin_path)
|
||||||
|
|
||||||
async def mock_update(plugin, proxy=""): # noqa: ARG001
|
|
||||||
if plugin.name != TEST_PLUGIN_NAME:
|
|
||||||
raise Exception("Plugin not found")
|
|
||||||
if not plugin_path.exists():
|
|
||||||
raise Exception("Plugin path missing")
|
|
||||||
(plugin_path / ".updated").write_text("ok", encoding="utf-8")
|
|
||||||
|
|
||||||
monkeypatch.setattr(plugin_manager_pm.updator, "install", mock_install)
|
monkeypatch.setattr(plugin_manager_pm.updator, "install", mock_install)
|
||||||
monkeypatch.setattr(plugin_manager_pm.updator, "update", mock_update)
|
monkeypatch.setattr(
|
||||||
return plugin_path
|
"astrbot.core.star.star_manager.pip_installer.install",
|
||||||
|
_build_dependency_install_mock(events, dependency_install_fails),
|
||||||
|
)
|
||||||
|
|
||||||
|
def mock_load_and_register(*args, **kwargs):
|
||||||
|
plugin_manager_pm.context.stars.append(MockStar())
|
||||||
|
return _build_load_mock(events)(*args, **kwargs)
|
||||||
|
|
||||||
|
monkeypatch.setattr(plugin_manager_pm, "load", mock_load_and_register)
|
||||||
|
|
||||||
|
if dependency_install_fails:
|
||||||
|
with pytest.raises(PluginDependencyInstallError, match="pip failed"):
|
||||||
|
await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO)
|
||||||
|
assert events == [("deps", str(plugin_path / "requirements.txt"))]
|
||||||
|
else:
|
||||||
|
await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO)
|
||||||
|
assert events == [
|
||||||
|
("deps", str(plugin_path / "requirements.txt")),
|
||||||
|
("load", TEST_PLUGIN_DIR),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_plugin_manager_initialization(plugin_manager_pm: PluginManager):
|
@pytest.mark.parametrize("dependency_install_fails", [False, True])
|
||||||
assert plugin_manager_pm is not None
|
async def test_install_plugin_from_file_dependency_install_flow(
|
||||||
assert plugin_manager_pm.context is not None
|
plugin_manager_pm: PluginManager,
|
||||||
assert plugin_manager_pm.config is not None
|
monkeypatch,
|
||||||
|
tmp_path,
|
||||||
|
dependency_install_fails: bool,
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_plugin_manager_reload(plugin_manager_pm: PluginManager):
|
|
||||||
success, err_message = await plugin_manager_pm.reload()
|
|
||||||
assert success is True
|
|
||||||
assert err_message is None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_install_plugin(plugin_manager_pm: PluginManager, local_updator: Path):
|
|
||||||
"""Tests successful plugin installation without external network."""
|
|
||||||
plugin_info = await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO)
|
|
||||||
assert plugin_info is not None
|
|
||||||
assert plugin_info["name"] == TEST_PLUGIN_NAME
|
|
||||||
assert local_updator.exists()
|
|
||||||
assert any(md.name == TEST_PLUGIN_NAME for md in star_registry)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_install_nonexistent_plugin(
|
|
||||||
plugin_manager_pm: PluginManager, local_updator
|
|
||||||
):
|
):
|
||||||
"""Tests that installing a non-existent plugin raises an exception."""
|
zip_file_path = tmp_path / f"{TEST_PLUGIN_DIR}.zip"
|
||||||
with pytest.raises(Exception):
|
zip_file_path.write_text("placeholder", encoding="utf-8")
|
||||||
await plugin_manager_pm.install_plugin(
|
events = []
|
||||||
"https://github.com/Soulter/non_existent_repo"
|
_mock_missing_requirements(monkeypatch, {"networkx"})
|
||||||
|
|
||||||
|
def mock_unzip_file(zip_path: str, target_dir: str) -> None:
|
||||||
|
assert zip_path == str(zip_file_path)
|
||||||
|
plugin_path = Path(target_dir)
|
||||||
|
_write_local_test_plugin(plugin_path, TEST_PLUGIN_REPO)
|
||||||
|
_write_requirements(plugin_path)
|
||||||
|
|
||||||
|
monkeypatch.setattr(plugin_manager_pm.updator, "unzip_file", mock_unzip_file)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"astrbot.core.star.star_manager.pip_installer.install",
|
||||||
|
_build_dependency_install_mock(events, dependency_install_fails),
|
||||||
|
)
|
||||||
|
|
||||||
|
def mock_load_and_register(*args, **kwargs):
|
||||||
|
plugin_manager_pm.context.stars.append(MockStar())
|
||||||
|
return _build_load_mock(events)(*args, **kwargs)
|
||||||
|
|
||||||
|
monkeypatch.setattr(plugin_manager_pm, "load", mock_load_and_register)
|
||||||
|
|
||||||
|
if dependency_install_fails:
|
||||||
|
with pytest.raises(PluginDependencyInstallError, match="pip failed"):
|
||||||
|
await plugin_manager_pm.install_plugin_from_file(str(zip_file_path))
|
||||||
|
assert any(e[0] == "deps" for e in events)
|
||||||
|
else:
|
||||||
|
await plugin_manager_pm.install_plugin_from_file(str(zip_file_path))
|
||||||
|
assert any(e[0] == "deps" for e in events)
|
||||||
|
assert ("load", TEST_PLUGIN_DIR) in events
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize("dependency_install_fails", [False, True])
|
||||||
|
async def test_reload_failed_plugin_dependency_install_flow(
|
||||||
|
plugin_manager_pm: PluginManager,
|
||||||
|
local_updator: Path,
|
||||||
|
monkeypatch,
|
||||||
|
dependency_install_fails: bool,
|
||||||
|
):
|
||||||
|
_write_requirements(local_updator)
|
||||||
|
plugin_manager_pm.failed_plugin_dict[TEST_PLUGIN_DIR] = {"error": "init fail"}
|
||||||
|
events = []
|
||||||
|
_mock_missing_requirements(monkeypatch, {"networkx"})
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"astrbot.core.star.star_manager.pip_installer.install",
|
||||||
|
_build_dependency_install_mock(events, dependency_install_fails),
|
||||||
|
)
|
||||||
|
|
||||||
|
def mock_load_and_register(*args, **kwargs):
|
||||||
|
plugin_manager_pm.context.stars.append(MockStar())
|
||||||
|
return _build_load_mock(events)(*args, **kwargs)
|
||||||
|
|
||||||
|
monkeypatch.setattr(plugin_manager_pm, "load", mock_load_and_register)
|
||||||
|
|
||||||
|
if dependency_install_fails:
|
||||||
|
with pytest.raises(PluginDependencyInstallError, match="pip failed"):
|
||||||
|
await plugin_manager_pm.reload_failed_plugin(TEST_PLUGIN_DIR)
|
||||||
|
assert events == [("deps", str(local_updator / "requirements.txt"))]
|
||||||
|
else:
|
||||||
|
await plugin_manager_pm.reload_failed_plugin(TEST_PLUGIN_DIR)
|
||||||
|
assert events == [
|
||||||
|
("deps", str(local_updator / "requirements.txt")),
|
||||||
|
("load", TEST_PLUGIN_DIR),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ensure_plugin_requirements_reraises_cancelled_error(
|
||||||
|
plugin_manager_pm: PluginManager, local_updator: Path, monkeypatch
|
||||||
|
):
|
||||||
|
_write_requirements(local_updator)
|
||||||
|
_mock_missing_requirements(monkeypatch, {"networkx"})
|
||||||
|
|
||||||
|
async def mock_install_requirements(*args, **kwargs):
|
||||||
|
raise asyncio.CancelledError()
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"astrbot.core.star.star_manager.pip_installer.install",
|
||||||
|
mock_install_requirements,
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
await plugin_manager_pm._ensure_plugin_requirements(
|
||||||
|
str(local_updator),
|
||||||
|
TEST_PLUGIN_DIR,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_update_plugin(plugin_manager_pm: PluginManager, local_updator: Path):
|
async def test_ensure_plugin_requirements_wraps_generic_dependency_install_failure(
|
||||||
"""Tests updating an existing plugin without external network."""
|
plugin_manager_pm: PluginManager, local_updator: Path, monkeypatch
|
||||||
plugin_info = await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO)
|
|
||||||
assert plugin_info is not None
|
|
||||||
plugin_name = plugin_info["name"]
|
|
||||||
await plugin_manager_pm.update_plugin(plugin_name)
|
|
||||||
assert (local_updator / ".updated").exists()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_nonexistent_plugin(
|
|
||||||
plugin_manager_pm: PluginManager, local_updator
|
|
||||||
):
|
):
|
||||||
"""Tests that updating a non-existent plugin raises an exception."""
|
_write_requirements(local_updator)
|
||||||
with pytest.raises(Exception):
|
_mock_missing_requirements(monkeypatch, {"networkx"})
|
||||||
await plugin_manager_pm.update_plugin("non_existent_plugin")
|
|
||||||
|
|
||||||
|
async def mock_install_requirements(*args, **kwargs):
|
||||||
|
raise RuntimeError("pip failed")
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
monkeypatch.setattr(
|
||||||
async def test_uninstall_plugin(plugin_manager_pm: PluginManager, local_updator: Path):
|
"astrbot.core.star.star_manager.pip_installer.install",
|
||||||
"""Tests successful plugin uninstallation."""
|
mock_install_requirements,
|
||||||
plugin_info = await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO)
|
|
||||||
assert plugin_info is not None
|
|
||||||
plugin_name = plugin_info["name"]
|
|
||||||
assert local_updator.exists()
|
|
||||||
|
|
||||||
await plugin_manager_pm.uninstall_plugin(plugin_name)
|
|
||||||
|
|
||||||
assert not local_updator.exists()
|
|
||||||
assert not any(md.name == TEST_PLUGIN_NAME for md in star_registry)
|
|
||||||
assert not any(
|
|
||||||
TEST_PLUGIN_NAME in md.handler_module_path for md in star_handlers_registry
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with pytest.raises(PluginDependencyInstallError, match="pip failed") as exc_info:
|
||||||
|
await plugin_manager_pm._ensure_plugin_requirements(
|
||||||
|
str(local_updator),
|
||||||
|
TEST_PLUGIN_DIR,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert exc_info.value.plugin_label == TEST_PLUGIN_DIR
|
||||||
|
assert exc_info.value.requirements_path == str(local_updator / "requirements.txt")
|
||||||
|
assert isinstance(exc_info.value.__cause__, RuntimeError)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_uninstall_nonexistent_plugin(plugin_manager_pm: PluginManager):
|
async def test_ensure_plugin_requirements_wraps_pip_install_error(
|
||||||
"""Tests that uninstalling a non-existent plugin raises an exception."""
|
plugin_manager_pm: PluginManager, local_updator: Path, monkeypatch
|
||||||
with pytest.raises(Exception):
|
):
|
||||||
await plugin_manager_pm.uninstall_plugin("non_existent_plugin")
|
_write_requirements(local_updator)
|
||||||
|
_mock_missing_requirements(monkeypatch, {"networkx"})
|
||||||
|
|
||||||
|
async def mock_install_requirements(*args, **kwargs):
|
||||||
|
raise PipInstallError("install failed", code=2)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"astrbot.core.star.star_manager.pip_installer.install",
|
||||||
|
mock_install_requirements,
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(PluginDependencyInstallError, match="install failed") as exc_info:
|
||||||
|
await plugin_manager_pm._ensure_plugin_requirements(
|
||||||
|
str(local_updator),
|
||||||
|
TEST_PLUGIN_DIR,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(exc_info.value.__cause__, PipInstallError)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ensure_plugin_requirements_logs_requirements_file_install_for_missing_dependencies(
|
||||||
|
plugin_manager_pm: PluginManager, local_updator: Path, monkeypatch
|
||||||
|
):
|
||||||
|
_write_requirements(local_updator)
|
||||||
|
_mock_missing_requirements(monkeypatch, {"networkx"})
|
||||||
|
logged_lines = []
|
||||||
|
|
||||||
|
async def mock_install_requirements(*args, **kwargs):
|
||||||
|
return None
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"astrbot.core.star.star_manager.pip_installer.install",
|
||||||
|
mock_install_requirements,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"astrbot.core.star.star_manager.logger.info",
|
||||||
|
lambda line, *args: logged_lines.append(line % args if args else line),
|
||||||
|
)
|
||||||
|
|
||||||
|
await plugin_manager_pm._ensure_plugin_requirements(
|
||||||
|
str(local_updator),
|
||||||
|
TEST_PLUGIN_DIR,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert any("按 requirements.txt 安装" in line for line in logged_lines)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize("dependency_install_fails", [False, True])
|
||||||
|
async def test_update_plugin_dependency_install_flow(
|
||||||
|
plugin_manager_pm: PluginManager,
|
||||||
|
local_updator: Path,
|
||||||
|
monkeypatch,
|
||||||
|
dependency_install_fails: bool,
|
||||||
|
):
|
||||||
|
mock_star = MockStar()
|
||||||
|
plugin_manager_pm.context.stars.append(mock_star)
|
||||||
|
|
||||||
|
_write_requirements(local_updator)
|
||||||
|
events = []
|
||||||
|
_mock_missing_requirements(monkeypatch, {"networkx"})
|
||||||
|
|
||||||
|
async def mock_update(plugin, proxy=""):
|
||||||
|
del proxy
|
||||||
|
events.append(("update", plugin.name))
|
||||||
|
|
||||||
|
monkeypatch.setattr(plugin_manager_pm.updator, "update", mock_update)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"astrbot.core.star.star_manager.pip_installer.install",
|
||||||
|
_build_dependency_install_mock(events, dependency_install_fails),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(plugin_manager_pm, "reload", _build_reload_mock(events))
|
||||||
|
|
||||||
|
if dependency_install_fails:
|
||||||
|
with pytest.raises(PluginDependencyInstallError, match="pip failed"):
|
||||||
|
await plugin_manager_pm.update_plugin(TEST_PLUGIN_NAME)
|
||||||
|
assert ("deps", str(local_updator / "requirements.txt")) in events
|
||||||
|
else:
|
||||||
|
await plugin_manager_pm.update_plugin(TEST_PLUGIN_NAME)
|
||||||
|
assert ("deps", str(local_updator / "requirements.txt")) in events
|
||||||
|
assert ("reload", TEST_PLUGIN_DIR) in events
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_install_plugin_skips_dependency_install_when_no_requirements_missing(
|
||||||
|
plugin_manager_pm: PluginManager, monkeypatch
|
||||||
|
):
|
||||||
|
plugin_path = Path(plugin_manager_pm.plugin_store_path) / TEST_PLUGIN_DIR
|
||||||
|
events = []
|
||||||
|
_mock_missing_requirements(monkeypatch, set())
|
||||||
|
|
||||||
|
async def mock_install(repo_url: str, proxy=""):
|
||||||
|
_write_local_test_plugin(plugin_path, repo_url)
|
||||||
|
_write_requirements(plugin_path)
|
||||||
|
return str(plugin_path)
|
||||||
|
|
||||||
|
monkeypatch.setattr(plugin_manager_pm.updator, "install", mock_install)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"astrbot.core.star.star_manager.pip_installer.install",
|
||||||
|
_build_dependency_install_mock(events, False),
|
||||||
|
)
|
||||||
|
|
||||||
|
def mock_load_and_register(*args, **kwargs):
|
||||||
|
plugin_manager_pm.context.stars.append(MockStar())
|
||||||
|
return _build_load_mock(events)(*args, **kwargs)
|
||||||
|
|
||||||
|
monkeypatch.setattr(plugin_manager_pm, "load", mock_load_and_register)
|
||||||
|
|
||||||
|
await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO)
|
||||||
|
|
||||||
|
assert "deps" not in [e[0] for e in events]
|
||||||
|
assert ("load", TEST_PLUGIN_DIR) in events
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_install_plugin_runs_dependency_install_when_precheck_fails(
|
||||||
|
plugin_manager_pm: PluginManager, monkeypatch
|
||||||
|
):
|
||||||
|
plugin_path = Path(plugin_manager_pm.plugin_store_path) / TEST_PLUGIN_DIR
|
||||||
|
events = []
|
||||||
|
|
||||||
|
async def mock_install(repo_url: str, proxy=""):
|
||||||
|
_write_local_test_plugin(plugin_path, repo_url)
|
||||||
|
_write_requirements(plugin_path)
|
||||||
|
return str(plugin_path)
|
||||||
|
|
||||||
|
_mock_precheck_fails(monkeypatch)
|
||||||
|
monkeypatch.setattr(plugin_manager_pm.updator, "install", mock_install)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"astrbot.core.star.star_manager.pip_installer.install",
|
||||||
|
_build_dependency_install_mock(events, False),
|
||||||
|
)
|
||||||
|
|
||||||
|
def mock_load_and_register(*args, **kwargs):
|
||||||
|
plugin_manager_pm.context.stars.append(MockStar())
|
||||||
|
return _build_load_mock(events)(*args, **kwargs)
|
||||||
|
|
||||||
|
monkeypatch.setattr(plugin_manager_pm, "load", mock_load_and_register)
|
||||||
|
|
||||||
|
await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO)
|
||||||
|
|
||||||
|
assert ("deps", str(plugin_path / "requirements.txt")) in events
|
||||||
|
assert ("load", TEST_PLUGIN_DIR) in events
|
||||||
|
|||||||
Reference in New Issue
Block a user