diff --git a/.gitignore b/.gitignore index 953723dff..f9c750715 100644 --- a/.gitignore +++ b/.gitignore @@ -61,4 +61,5 @@ GenieData/ .codex/ .opencode/ .kilocode/ -.serena \ No newline at end of file +.serena +.worktrees/ diff --git a/README.md b/README.md index d9b3cf019..8ad697131 100644 --- a/README.md +++ b/README.md @@ -234,7 +234,8 @@ pre-commit install - Group 7: 743746109 - Group 8: 1030353265 -- Developer Group: 975206796 +- Developer Group(Chit-chat): 975206796 +- Developer Group(Formal): 1039761811 ### Discord Server diff --git a/README_fr.md b/README_fr.md index 179e42005..e406d32b2 100644 --- a/README_fr.md +++ b/README_fr.md @@ -222,6 +222,7 @@ pre-commit install - Groupe 5 : 822130018 - Groupe 6 : 753075035 - Groupe développeurs : 975206796 +- Groupe développeurs (officiel) : 1039761811 ### Serveur Discord diff --git a/README_ja.md b/README_ja.md index b8a7bba04..7aa146c13 100644 --- a/README_ja.md +++ b/README_ja.md @@ -223,6 +223,7 @@ pre-commit install - 5群: 822130018 - 6群: 753075035 - 開発者群: 975206796 +- 開発者群(正式): 1039761811 ### Discord サーバー diff --git a/README_ru.md b/README_ru.md index 45cebf627..35da14acb 100644 --- a/README_ru.md +++ b/README_ru.md @@ -222,6 +222,7 @@ pre-commit install - Группа 5: 822130018 - Группа 6: 753075035 - Группа разработчиков: 975206796 +- Группа разработчиков (официальная): 1039761811 ### Сервер Discord diff --git a/README_zh-TW.md b/README_zh-TW.md index e42b228e5..1ace852b8 100644 --- a/README_zh-TW.md +++ b/README_zh-TW.md @@ -225,7 +225,8 @@ pre-commit install - 6 群:753075035 - 7 群:743746109 - 8 群:1030353265 -- 開發者群:975206796 +- 開發者群(闲聊吹水):975206796 +- 開發者群(正式):1039761811 ### Discord 群組 diff --git a/README_zh.md b/README_zh.md index 1c45ac5a9..e13d9b4e5 100644 --- a/README_zh.md +++ b/README_zh.md @@ -226,7 +226,8 @@ pre-commit install - 6 群:753075035 - 7 群:743746109 - 8 群:1030353265 -- 开发者群:975206796 +- 开发者群(偏闲聊吹水):975206796 +- 开发者群(正式):1039761811 ### Discord 频道 diff --git a/astrbot/core/__init__.py b/astrbot/core/__init__.py index 5c015e96e..51690ede2 100644 --- a/astrbot/core/__init__.py +++ b/astrbot/core/__init__.py @@ -4,7 +4,21 @@ from astrbot.core.config import AstrBotConfig from astrbot.core.config.default import DB_PATH from astrbot.core.db.sqlite import SQLiteDatabase from astrbot.core.file_token_service import FileTokenService -from astrbot.core.utils.pip_installer import PipInstaller +from astrbot.core.utils.pip_installer import ( + DependencyConflictError as DependencyConflictError, +) +from astrbot.core.utils.pip_installer import ( + PipInstaller, +) +from astrbot.core.utils.requirements_utils import ( + RequirementsPrecheckFailed as RequirementsPrecheckFailed, +) +from astrbot.core.utils.requirements_utils import ( + find_missing_requirements as find_missing_requirements, +) +from astrbot.core.utils.requirements_utils import ( + find_missing_requirements_or_raise as find_missing_requirements_or_raise, +) from astrbot.core.utils.shared_preferences import SharedPreferences from astrbot.core.utils.t2i.renderer import HtmlRenderer diff --git a/astrbot/core/computer/booters/local.py b/astrbot/core/computer/booters/local.py index a80ef0da2..cf7d2e079 100644 --- a/astrbot/core/computer/booters/local.py +++ b/astrbot/core/computer/booters/local.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import locale import os import shutil import subprocess @@ -52,6 +53,31 @@ def _ensure_safe_path(path: str) -> str: return abs_path +def _decode_shell_output(output: bytes | None) -> str: + if output is None: + return "" + + preferred = locale.getpreferredencoding(False) or "utf-8" + try: + return output.decode("utf-8") + except (LookupError, UnicodeDecodeError): + pass + + if os.name == "nt": + for encoding in ("mbcs", "cp936", "gbk", "gb18030"): + try: + return output.decode(encoding) + except (LookupError, UnicodeDecodeError): + continue + + try: + return output.decode(preferred) + except (LookupError, UnicodeDecodeError): + pass + + return output.decode("utf-8", errors="replace") + + @dataclass class LocalShellComponent(ShellComponent): async def exec( @@ -72,28 +98,32 @@ class LocalShellComponent(ShellComponent): run_env.update({str(k): str(v) for k, v in env.items()}) working_dir = _ensure_safe_path(cwd) if cwd else get_astrbot_root() if background: - proc = subprocess.Popen( + # `command` is intentionally executed through the current shell so + # local computer-use behavior matches existing tool semantics. + # Safety relies on `_is_safe_command()` and the allowed-root checks. + proc = subprocess.Popen( # noqa: S602 # nosemgrep: python.lang.security.audit.dangerous-subprocess-use-audit command, shell=shell, cwd=working_dir, env=run_env, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, ) return {"pid": proc.pid, "stdout": "", "stderr": "", "exit_code": None} - result = subprocess.run( + # `command` is intentionally executed through the current shell so + # local computer-use behavior matches existing tool semantics. + # Safety relies on `_is_safe_command()` and the allowed-root checks. + result = subprocess.run( # noqa: S602 # nosemgrep: python.lang.security.audit.dangerous-subprocess-use-audit command, shell=shell, cwd=working_dir, env=run_env, timeout=timeout, capture_output=True, - text=True, ) return { - "stdout": result.stdout, - "stderr": result.stderr, + "stdout": _decode_shell_output(result.stdout), + "stderr": _decode_shell_output(result.stderr), "exit_code": result.returncode, } diff --git a/astrbot/core/computer/computer_client.py b/astrbot/core/computer/computer_client.py index 029f3bfa9..48c957574 100644 --- a/astrbot/core/computer/computer_client.py +++ b/astrbot/core/computer/computer_client.py @@ -434,6 +434,12 @@ async def get_booter( ) -> ComputerBooter: config = context.get_config(umo=session_id) + runtime = config.get("provider_settings", {}).get("computer_use_runtime", "local") + if runtime == "local": + return get_local_booter() + elif runtime == "none": + raise RuntimeError("Sandbox runtime is disabled by configuration.") + sandbox_cfg = config.get("provider_settings", {}).get("sandbox", {}) booter_type = sandbox_cfg.get("booter", "shipyard_neo") diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 2e32073a9..ba656ff53 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -219,6 +219,9 @@ DEFAULT_CONFIG = { "telegram": { "pre_ack_emoji": {"enable": False, "emojis": ["✍️"]}, }, + "discord": { + "pre_ack_emoji": {"enable": False, "emojis": ["🤔"]}, + }, }, "wake_prefix": ["/"], "log_level": "INFO", diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py index 2b417f45f..d1fd0e187 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py @@ -18,7 +18,7 @@ from botpy.types.message import MarkdownPayload, Media from astrbot.api import logger from astrbot.api.event import AstrMessageEvent, MessageChain -from astrbot.api.message_components import Image, Plain, Record +from astrbot.api.message_components import File, Image, Plain, Record, Video from astrbot.api.platform import AstrBotMessage, PlatformMetadata from astrbot.core.utils.astrbot_path import get_astrbot_temp_path from astrbot.core.utils.io import download_image_by_url, file_to_base64 @@ -47,6 +47,10 @@ _patch_qq_botpy_formdata() class QQOfficialMessageEvent(AstrMessageEvent): MARKDOWN_NOT_ALLOWED_ERROR = "不允许发送原生 markdown" + IMAGE_FILE_TYPE = 1 + VIDEO_FILE_TYPE = 2 + VOICE_FILE_TYPE = 3 + FILE_FILE_TYPE = 4 def __init__( self, @@ -126,6 +130,9 @@ class QQOfficialMessageEvent(AstrMessageEvent): image_base64, image_path, record_file_path, + video_file_source, + file_source, + file_name, ) = await QQOfficialMessageEvent._parse_to_qqofficial(self.send_buffer) if ( @@ -133,6 +140,8 @@ class QQOfficialMessageEvent(AstrMessageEvent): and not image_base64 and not image_path and not record_file_path + and not video_file_source + and not file_source ): return None @@ -157,7 +166,7 @@ class QQOfficialMessageEvent(AstrMessageEvent): if image_base64: media = await self.upload_group_and_c2c_image( image_base64, - 1, + self.IMAGE_FILE_TYPE, group_openid=source.group_openid, ) payload["media"] = media @@ -165,15 +174,39 @@ class QQOfficialMessageEvent(AstrMessageEvent): payload.pop("markdown", None) payload["content"] = plain_text or None if record_file_path: # group record msg - media = await self.upload_group_and_c2c_record( + media = await self.upload_group_and_c2c_media( record_file_path, - 3, + self.VOICE_FILE_TYPE, group_openid=source.group_openid, ) - payload["media"] = media - payload["msg_type"] = 7 - payload.pop("markdown", None) - payload["content"] = plain_text or None + if media: + payload["media"] = media + payload["msg_type"] = 7 + payload.pop("markdown", None) + payload["content"] = plain_text or None + if video_file_source: + media = await self.upload_group_and_c2c_media( + video_file_source, + self.VIDEO_FILE_TYPE, + group_openid=source.group_openid, + ) + if media: + payload["media"] = media + payload["msg_type"] = 7 + payload.pop("markdown", None) + payload["content"] = plain_text or None + if file_source: + media = await self.upload_group_and_c2c_media( + file_source, + self.FILE_FILE_TYPE, + file_name=file_name, + group_openid=source.group_openid, + ) + if media: + payload["media"] = media + payload["msg_type"] = 7 + payload.pop("markdown", None) + payload["content"] = plain_text or None ret = await self._send_with_markdown_fallback( send_func=lambda retry_payload: self.bot.api.post_group_message( group_openid=source.group_openid, # type: ignore @@ -187,7 +220,7 @@ class QQOfficialMessageEvent(AstrMessageEvent): if image_base64: media = await self.upload_group_and_c2c_image( image_base64, - 1, + self.IMAGE_FILE_TYPE, openid=source.author.user_openid, ) payload["media"] = media @@ -195,15 +228,39 @@ class QQOfficialMessageEvent(AstrMessageEvent): payload.pop("markdown", None) payload["content"] = plain_text or None if record_file_path: # c2c record - media = await self.upload_group_and_c2c_record( + media = await self.upload_group_and_c2c_media( record_file_path, - 3, + self.VOICE_FILE_TYPE, openid=source.author.user_openid, ) - payload["media"] = media - payload["msg_type"] = 7 - payload.pop("markdown", None) - payload["content"] = plain_text or None + if media: + payload["media"] = media + payload["msg_type"] = 7 + payload.pop("markdown", None) + payload["content"] = plain_text or None + if video_file_source: + media = await self.upload_group_and_c2c_media( + video_file_source, + self.VIDEO_FILE_TYPE, + openid=source.author.user_openid, + ) + if media: + payload["media"] = media + payload["msg_type"] = 7 + payload.pop("markdown", None) + payload["content"] = plain_text or None + if file_source: + media = await self.upload_group_and_c2c_media( + file_source, + self.FILE_FILE_TYPE, + file_name=file_name, + openid=source.author.user_openid, + ) + if media: + payload["media"] = media + payload["msg_type"] = 7 + payload.pop("markdown", None) + payload["content"] = plain_text or None if stream: ret = await self._send_with_markdown_fallback( send_func=lambda retry_payload: self.post_c2c_message( @@ -327,16 +384,19 @@ class QQOfficialMessageEvent(AstrMessageEvent): ttl=result.get("ttl", 0), ) - async def upload_group_and_c2c_record( + async def upload_group_and_c2c_media( self, file_source: str, file_type: int, srv_send_msg: bool = False, + file_name: str | None = None, **kwargs, ) -> Media | None: """上传媒体文件""" # 构建基础payload payload = {"file_type": file_type, "srv_send_msg": srv_send_msg} + if file_name: + payload["file_name"] = file_name # 处理文件数据 if os.path.exists(file_source): @@ -416,6 +476,9 @@ class QQOfficialMessageEvent(AstrMessageEvent): image_base64 = None # only one img supported image_file_path = None record_file_path = None + video_file_source = None + file_source = None + file_name = None for i in message.chain: if isinstance(i, Plain): plain_text += i.text @@ -454,6 +517,30 @@ class QQOfficialMessageEvent(AstrMessageEvent): except Exception as e: logger.error(f"处理语音时出错: {e}") record_file_path = None + elif isinstance(i, Video) and not video_file_source: + if i.file.startswith("file:///"): + video_file_source = i.file[8:] + else: + video_file_source = i.file + elif isinstance(i, File) and not file_source: + file_name = i.name + if i.file_: + file_path = i.file_ + if file_path.startswith("file:///"): + file_path = file_path[8:] + elif file_path.startswith("file://"): + file_path = file_path[7:] + file_source = file_path + elif i.url: + file_source = i.url else: logger.debug(f"qq_official 忽略 {i.type}") - return plain_text, image_base64, image_file_path, record_file_path + return ( + plain_text, + image_base64, + image_file_path, + record_file_path, + video_file_source, + file_source, + file_name, + ) diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py index 603bc8f58..88d4a2128 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py @@ -3,8 +3,10 @@ from __future__ import annotations import asyncio import logging import os +import random import time -from typing import cast +from types import SimpleNamespace +from typing import Any, cast import botpy import botpy.message @@ -12,7 +14,7 @@ from botpy import Client from astrbot import logger from astrbot.api.event import MessageChain -from astrbot.api.message_components import At, File, Image, Plain +from astrbot.api.message_components import At, File, Image, Plain, Record, Video from astrbot.api.platform import ( AstrBotMessage, MessageMember, @@ -46,6 +48,7 @@ class botClient(Client): ) abm.group_id = cast(str, message.group_openid) abm.session_id = abm.group_id + self.platform.remember_session_scene(abm.session_id, "group") self._commit(abm) # 收到频道消息 @@ -56,6 +59,7 @@ class botClient(Client): ) abm.group_id = message.channel_id abm.session_id = abm.group_id + self.platform.remember_session_scene(abm.session_id, "channel") self._commit(abm) # 收到私聊消息 @@ -67,6 +71,7 @@ class botClient(Client): MessageType.FRIEND_MESSAGE, ) abm.session_id = abm.sender.user_id + self.platform.remember_session_scene(abm.session_id, "friend") self._commit(abm) # 收到 C2C 消息 @@ -76,9 +81,11 @@ class botClient(Client): MessageType.FRIEND_MESSAGE, ) abm.session_id = abm.sender.user_id + self.platform.remember_session_scene(abm.session_id, "friend") self._commit(abm) def _commit(self, abm: AstrBotMessage) -> None: + self.platform.remember_session_message_id(abm.session_id, abm.message_id) self.platform.commit_event( QQOfficialMessageEvent( abm.message_str, @@ -124,6 +131,9 @@ class QQOfficialPlatformAdapter(Platform): self.client.set_platform(self) + self._session_last_message_id: dict[str, str] = {} + self._session_scene: dict[str, str] = {} + self.test_mode = os.environ.get("TEST_MODE", "off") == "on" async def send_by_session( @@ -131,14 +141,185 @@ class QQOfficialPlatformAdapter(Platform): session: MessageSesion, message_chain: MessageChain, ) -> None: - raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session") + await self._send_by_session_common(session, message_chain) + + async def _send_by_session_common( + self, + session: MessageSesion, + message_chain: MessageChain, + ) -> None: + ( + plain_text, + image_base64, + image_path, + record_file_path, + video_file_source, + file_source, + file_name, + ) = await QQOfficialMessageEvent._parse_to_qqofficial(message_chain) + if ( + not plain_text + and not image_path + and not image_base64 + and not record_file_path + and not video_file_source + and not file_source + ): + return + + msg_id = self._session_last_message_id.get(session.session_id) + if not msg_id: + logger.warning( + "[QQOfficial] No cached msg_id for session: %s, skip send_by_session", + session.session_id, + ) + return + + payload: dict[str, Any] = {"content": plain_text, "msg_id": msg_id} + ret: Any = None + send_helper = SimpleNamespace(bot=self.client) + + if session.message_type == MessageType.GROUP_MESSAGE: + scene = self._session_scene.get(session.session_id) + if scene == "group": + payload["msg_seq"] = random.randint(1, 10000) + if image_base64: + media = await QQOfficialMessageEvent.upload_group_and_c2c_image( + send_helper, # type: ignore + image_base64, + QQOfficialMessageEvent.IMAGE_FILE_TYPE, + group_openid=session.session_id, + ) + payload["media"] = media + payload["msg_type"] = 7 + if record_file_path: + media = await QQOfficialMessageEvent.upload_group_and_c2c_media( + send_helper, # type: ignore + record_file_path, + QQOfficialMessageEvent.VOICE_FILE_TYPE, + group_openid=session.session_id, + ) + if media: + payload["media"] = media + payload["msg_type"] = 7 + if video_file_source: + media = await QQOfficialMessageEvent.upload_group_and_c2c_media( + send_helper, # type: ignore + video_file_source, + QQOfficialMessageEvent.VIDEO_FILE_TYPE, + group_openid=session.session_id, + ) + if media: + payload["media"] = media + payload["msg_type"] = 7 + if file_source: + media = await QQOfficialMessageEvent.upload_group_and_c2c_media( + send_helper, # type: ignore + file_source, + QQOfficialMessageEvent.FILE_FILE_TYPE, + file_name=file_name, + group_openid=session.session_id, + ) + if media: + payload["media"] = media + payload["msg_type"] = 7 + ret = await self.client.api.post_group_message( + group_openid=session.session_id, + **payload, + ) + else: + if image_path: + payload["file_image"] = image_path + ret = await self.client.api.post_message( + channel_id=session.session_id, + **payload, + ) + + elif session.message_type == MessageType.FRIEND_MESSAGE: + payload["msg_seq"] = random.randint(1, 10000) + if image_base64: + media = await QQOfficialMessageEvent.upload_group_and_c2c_image( + send_helper, # type: ignore + image_base64, + QQOfficialMessageEvent.IMAGE_FILE_TYPE, + openid=session.session_id, + ) + payload["media"] = media + payload["msg_type"] = 7 + if record_file_path: + media = await QQOfficialMessageEvent.upload_group_and_c2c_media( + send_helper, # type: ignore + record_file_path, + QQOfficialMessageEvent.VOICE_FILE_TYPE, + openid=session.session_id, + ) + if media: + payload["media"] = media + payload["msg_type"] = 7 + if video_file_source: + media = await QQOfficialMessageEvent.upload_group_and_c2c_media( + send_helper, # type: ignore + video_file_source, + QQOfficialMessageEvent.VIDEO_FILE_TYPE, + openid=session.session_id, + ) + if media: + payload["media"] = media + payload["msg_type"] = 7 + if file_source: + media = await QQOfficialMessageEvent.upload_group_and_c2c_media( + send_helper, # type: ignore + file_source, + QQOfficialMessageEvent.FILE_FILE_TYPE, + file_name=file_name, + openid=session.session_id, + ) + if media: + payload["media"] = media + payload["msg_type"] = 7 + + ret = await QQOfficialMessageEvent.post_c2c_message( + send_helper, # type: ignore + openid=session.session_id, + **payload, + ) + else: + logger.warning( + "[QQOfficial] Unsupported message type for send_by_session: %s", + session.message_type, + ) + return + + sent_message_id = self._extract_message_id(ret) + if sent_message_id: + self.remember_session_message_id(session.session_id, sent_message_id) + await super().send_by_session(session, message_chain) + + def remember_session_message_id(self, session_id: str, message_id: str) -> None: + if not session_id or not message_id: + return + self._session_last_message_id[session_id] = message_id + + def remember_session_scene(self, session_id: str, scene: str) -> None: + if not session_id or not scene: + return + self._session_scene[session_id] = scene + + def _extract_message_id(self, ret: Any) -> str | None: + if isinstance(ret, dict): + message_id = ret.get("id") + return str(message_id) if message_id else None + message_id = getattr(ret, "id", None) + if message_id: + return str(message_id) + return None def meta(self) -> PlatformMetadata: return PlatformMetadata( name="qq_official", description="QQ 机器人官方 API 适配器", id=cast(str, self.config.get("id")), - support_proactive_message=False, + support_proactive_message=True, ) @staticmethod @@ -158,7 +339,10 @@ class QQOfficialPlatformAdapter(Platform): return for attachment in attachments: - content_type = cast(str, getattr(attachment, "content_type", "") or "") + content_type = cast( + str, + getattr(attachment, "content_type", "") or "", + ).lower() url = QQOfficialPlatformAdapter._normalize_attachment_url( cast(str | None, getattr(attachment, "url", None)) ) @@ -174,7 +358,32 @@ class QQOfficialPlatformAdapter(Platform): or getattr(attachment, "name", None) or "attachment", ) - msg.append(File(name=filename, file=url, url=url)) + ext = os.path.splitext(filename)[1].lower() + image_exts = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"} + audio_exts = { + ".mp3", + ".wav", + ".ogg", + ".m4a", + ".amr", + ".silk", + } + video_exts = { + ".mp4", + ".mov", + ".avi", + ".mkv", + ".webm", + } + + if content_type.startswith("audio") or ext in audio_exts: + msg.append(Record.fromURL(url)) + elif content_type.startswith("video") or ext in video_exts: + msg.append(Video.fromURL(url)) + elif content_type.startswith("image") or ext in image_exts: + msg.append(Image.fromURL(url)) + else: + msg.append(File(name=filename, file=url, url=url)) @staticmethod def _parse_from_qqofficial( diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py index 6aae6b9ce..4c73fdf38 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py @@ -1,7 +1,5 @@ import asyncio import logging -import random -from types import SimpleNamespace from typing import Any, cast import botpy @@ -15,7 +13,6 @@ from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.core.utils.webhook_utils import log_webhook_info from ...register import register_platform_adapter -from ..qqofficial.qqofficial_message_event import QQOfficialMessageEvent from ..qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter from .qo_webhook_event import QQOfficialWebhookMessageEvent from .qo_webhook_server import QQOfficialWebhook @@ -123,95 +120,11 @@ class QQOfficialWebhookPlatformAdapter(Platform): session: MessageSesion, message_chain: MessageChain, ) -> None: - ( - plain_text, - image_base64, - image_path, - record_file_path, - ) = await QQOfficialMessageEvent._parse_to_qqofficial(message_chain) - if not plain_text and not image_path: - return - - msg_id = self._session_last_message_id.get(session.session_id) - if not msg_id: - logger.warning( - "[QQOfficialWebhook] No cached msg_id for session: %s, skip send_by_session", - session.session_id, - ) - return - - payload: dict[str, Any] = {"content": plain_text, "msg_id": msg_id} - ret: Any = None - send_helper = SimpleNamespace(bot=self.client) - if session.message_type == MessageType.GROUP_MESSAGE: - scene = self._session_scene.get(session.session_id) - if scene == "group": - payload["msg_seq"] = random.randint(1, 10000) - if image_base64: - media = await QQOfficialMessageEvent.upload_group_and_c2c_image( - send_helper, # type: ignore - image_base64, - 1, - group_openid=session.session_id, - ) - payload["media"] = media - payload["msg_type"] = 7 - if record_file_path: - media = await QQOfficialMessageEvent.upload_group_and_c2c_record( - send_helper, # type: ignore - record_file_path, - 3, - group_openid=session.session_id, - ) - payload["media"] = media - payload["msg_type"] = 7 - ret = await self.client.api.post_group_message( - group_openid=session.session_id, - **payload, - ) - else: - if image_path: - payload["file_image"] = image_path - ret = await self.client.api.post_message( - channel_id=session.session_id, - **payload, - ) - elif session.message_type == MessageType.FRIEND_MESSAGE: - payload["msg_seq"] = random.randint(1, 10000) - if image_base64: - media = await QQOfficialMessageEvent.upload_group_and_c2c_image( - send_helper, # type: ignore - image_base64, - 1, - openid=session.session_id, - ) - payload["media"] = media - payload["msg_type"] = 7 - if record_file_path: - media = await QQOfficialMessageEvent.upload_group_and_c2c_record( - send_helper, # type: ignore - record_file_path, - 3, - openid=session.session_id, - ) - payload["media"] = media - payload["msg_type"] = 7 - ret = await QQOfficialMessageEvent.post_c2c_message( - send_helper, # type: ignore - openid=session.session_id, - **payload, - ) - else: - logger.warning( - "[QQOfficialWebhook] Unsupported message type for send_by_session: %s", - session.message_type, - ) - return - - sent_message_id = self._extract_message_id(ret) - if sent_message_id: - self.remember_session_message_id(session.session_id, sent_message_id) - await super().send_by_session(session, message_chain) + await QQOfficialPlatformAdapter._send_by_session_common( + cast(Any, self), + session, + message_chain, + ) def remember_session_message_id(self, session_id: str, message_id: str) -> None: if not session_id or not message_id: diff --git a/astrbot/core/platform/sources/telegram/tg_event.py b/astrbot/core/platform/sources/telegram/tg_event.py index 43e58960e..e75fb9214 100644 --- a/astrbot/core/platform/sources/telegram/tg_event.py +++ b/astrbot/core/platform/sources/telegram/tg_event.py @@ -278,7 +278,6 @@ class TelegramPlatformEvent(AstrMessageEvent): try: md_text = telegramify_markdown.markdownify( chunk, - normalize_whitespace=False, ) await client.send_message( text=md_text, @@ -456,7 +455,6 @@ class TelegramPlatformEvent(AstrMessageEvent): try: markdown_text = telegramify_markdown.markdownify( delta, - normalize_whitespace=False, ) await self.client.send_message( text=markdown_text, @@ -537,7 +535,6 @@ class TelegramPlatformEvent(AstrMessageEvent): try: md = telegramify_markdown.markdownify( draft_text, - normalize_whitespace=False, ) await self._send_message_draft( user_name, @@ -695,7 +692,6 @@ class TelegramPlatformEvent(AstrMessageEvent): try: markdown_text = telegramify_markdown.markdownify( delta, - normalize_whitespace=False, ) await self.client.edit_message_text( text=markdown_text, diff --git a/astrbot/core/skills/skill_manager.py b/astrbot/core/skills/skill_manager.py index 626e2752f..9bbdb5aee 100644 --- a/astrbot/core/skills/skill_manager.py +++ b/astrbot/core/skills/skill_manager.py @@ -3,6 +3,7 @@ from __future__ import annotations import json import os import re +import shlex import shutil import tempfile import zipfile @@ -79,7 +80,59 @@ def _parse_frontmatter_description(text: str) -> str: # Regex for sanitizing paths used in prompt examples — only allow # safe path characters to prevent prompt injection via crafted skill paths. -_SAFE_PATH_RE = re.compile(r"[^A-Za-z0-9_./ -]") +_SAFE_PATH_RE = re.compile(r"[^\w./ ,()'\-]", re.UNICODE) +_WINDOWS_DRIVE_PATH_RE = re.compile(r"^[A-Za-z]:(?:/|\\)") +_WINDOWS_UNC_PATH_RE = re.compile(r"^(//|\\\\)[^/\\]+[/\\][^/\\]+") +_CONTROL_CHARS_RE = re.compile(r"[\x00-\x1F\x7F]") + + +def _is_windows_prompt_path(path: str) -> bool: + if os.name != "nt": + return False + return bool(_WINDOWS_DRIVE_PATH_RE.match(path) or _WINDOWS_UNC_PATH_RE.match(path)) + + +def _sanitize_prompt_path_for_prompt(path: str) -> str: + if not path: + return "" + + if _WINDOWS_DRIVE_PATH_RE.match(path) or _WINDOWS_UNC_PATH_RE.match(path): + path = path.replace("\\", "/") + + drive_prefix = "" + if _WINDOWS_DRIVE_PATH_RE.match(path): + drive_prefix = path[:2] + path = path[2:] + + path = path.replace("`", "") + path = _CONTROL_CHARS_RE.sub("", path) + sanitized = _SAFE_PATH_RE.sub("", path) + return f"{drive_prefix}{sanitized}" + + +def _sanitize_prompt_description(description: str) -> str: + description = description.replace("`", "") + description = _CONTROL_CHARS_RE.sub(" ", description) + description = " ".join(description.split()) + return description + + +def _sanitize_skill_display_name(name: str) -> str: + if _SKILL_NAME_RE.fullmatch(name): + return name + return "" + + +def _build_skill_read_command_example(path: str) -> str: + if path == "//SKILL.md": + return f"cat {path}" + if _is_windows_prompt_path(path): + command = "type" + path_arg = f'"{path}"' + else: + command = "cat" + path_arg = shlex.quote(path) + return f"{command} {path_arg}" def build_skills_prompt(skills: list[SkillInfo]) -> str: @@ -92,16 +145,37 @@ def build_skills_prompt(skills: list[SkillInfo]) -> str: skills_lines: list[str] = [] example_path = "" for skill in skills: + display_name = _sanitize_skill_display_name(skill.name) + description = skill.description or "No description" + if skill.source_type == "sandbox_only": + description = _sanitize_prompt_description(description) + if not description: + description = "Read SKILL.md for details." + + if skill.source_type == "sandbox_only": + rendered_path = ( + f"{str(SANDBOX_WORKSPACE_ROOT)}/{str(SANDBOX_SKILLS_ROOT)}/" + f"{display_name}/SKILL.md" + ) + else: + rendered_path = _sanitize_prompt_path_for_prompt(skill.path) + if not rendered_path: + rendered_path = "//SKILL.md" + skills_lines.append( - f"- **{skill.name}**: {description}\n File: `{skill.path}`" + f"- **{display_name}**: {description}\n File: `{rendered_path}`" ) if not example_path: - example_path = skill.path + example_path = rendered_path skills_block = "\n".join(skills_lines) # Sanitize example_path — it may originate from sandbox cache (untrusted) - example_path = _SAFE_PATH_RE.sub("", example_path) if example_path else "" - example_path = example_path or "//SKILL.md" + if example_path == "//SKILL.md": + example_path = "//SKILL.md" + else: + example_path = _sanitize_prompt_path_for_prompt(example_path) + example_path = example_path or "//SKILL.md" + example_command = _build_skill_read_command_example(example_path) return ( "## Skills\n\n" @@ -119,8 +193,9 @@ def build_skills_prompt(skills: list[SkillInfo]) -> str: "*Never silently skip a matching skill* — either use it or briefly " "explain why you chose not to.\n" "3. **Mandatory grounding** — Before executing any skill you MUST " - "first read its `SKILL.md` by running a shell command with the " - f"**absolute path** shown above (e.g. `cat {example_path}`). " + "first read its `SKILL.md` by running a shell command compatible " + "with the current runtime shell and using the **absolute path** " + f"shown above (e.g. `{example_command}`). " "Never rely on memory or assumptions about a skill's content.\n" "4. **Progressive disclosure** — Load only what is directly " "referenced from `SKILL.md`:\n" diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index b812698f2..cf000c5a4 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -14,7 +14,12 @@ import yaml from packaging.specifiers import InvalidSpecifier, SpecifierSet from packaging.version import InvalidVersion, Version -from astrbot.core import logger, pip_installer, sp +from astrbot.core import ( + DependencyConflictError, + logger, + pip_installer, + sp, +) from astrbot.core.agent.handoff import FunctionTool, HandoffTool from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.config.default import VERSION @@ -27,6 +32,10 @@ from astrbot.core.utils.astrbot_path import ( ) from astrbot.core.utils.io import remove_dir from astrbot.core.utils.metrics import Metric +from astrbot.core.utils.requirements_utils import ( + RequirementsPrecheckFailed, + find_missing_requirements_or_raise, +) from . import StarMetadata from .command_management import sync_command_configs @@ -48,6 +57,49 @@ class PluginVersionIncompatibleError(Exception): """Raised when plugin astrbot_version is incompatible with current AstrBot.""" +class PluginDependencyInstallError(Exception): + """Raised when plugin dependency installation fails.""" + + def __init__( + self, + *, + plugin_label: str, + requirements_path: str, + error: Exception, + ) -> None: + message = f"插件 {plugin_label} 依赖安装失败: {error!s}" + super().__init__(message) + self.plugin_label = plugin_label + self.requirements_path = requirements_path + self.error = error + + +async def _install_requirements_with_precheck( + *, + plugin_label: str, + requirements_path: str, +) -> None: + try: + missing = find_missing_requirements_or_raise(requirements_path) + except RequirementsPrecheckFailed: + logger.info( + f"正在安装插件 {plugin_label} 的依赖库(预检查失败,回退到完整安装): " + f"{requirements_path}" + ) + await pip_installer.install(requirements_path=requirements_path) + return + + if not missing: + logger.info(f"插件 {plugin_label} 的依赖已满足,跳过安装。") + return + + logger.info( + f"检测到插件 {plugin_label} 缺失依赖,正在按 requirements.txt 安装: " + f"{requirements_path} -> {sorted(missing)}" + ) + await pip_installer.install(requirements_path=requirements_path) + + class PluginManager: def __init__(self, context: Context, config: AstrBotConfig) -> None: from .star_tools import StarTools @@ -198,15 +250,37 @@ class PluginManager: to_update.append(p.root_dir_name) for p in to_update: plugin_path = os.path.join(plugin_dir, p) - if os.path.exists(os.path.join(plugin_path, "requirements.txt")): - pth = os.path.join(plugin_path, "requirements.txt") - logger.info(f"正在安装插件 {p} 所需的依赖库: {pth}") - try: - await pip_installer.install(requirements_path=pth) - except Exception as e: - logger.error(f"更新插件 {p} 的依赖失败。Code: {e!s}") + await self._ensure_plugin_requirements(plugin_path, p) return True + async def _ensure_plugin_requirements( + self, + plugin_dir_path: str, + plugin_label: str, + ) -> None: + requirements_path = os.path.join(plugin_dir_path, "requirements.txt") + if not os.path.exists(requirements_path): + return + + try: + await _install_requirements_with_precheck( + plugin_label=plugin_label, + requirements_path=requirements_path, + ) + except asyncio.CancelledError: + raise + except DependencyConflictError as e: + logger.error(f"插件 {plugin_label} 依赖冲突: {e!s}") + raise + except Exception as e: + dependency_error = PluginDependencyInstallError( + plugin_label=plugin_label, + requirements_path=requirements_path, + error=e, + ) + logger.exception(str(dependency_error)) + raise dependency_error from e + async def _import_plugin_with_dependency_recovery( self, path: str, @@ -422,7 +496,7 @@ class PluginManager: root_dir_name: str, plugin_dir_path: str, reserved: bool, - error: Exception | str, + error: BaseException | str, error_trace: str, ) -> dict: record: dict = { @@ -495,6 +569,9 @@ class PluginManager: self._cleanup_plugin_state(dir_name) + plugin_path = os.path.join(self.plugin_store_path, dir_name) + await self._ensure_plugin_requirements(plugin_path, dir_name) + success, error = await self.load(specified_dir_name=dir_name) if success: self.failed_plugin_dict.pop(dir_name, None) @@ -1078,6 +1155,10 @@ class PluginManager: # reload the plugin dir_name = os.path.basename(plugin_path) + await self._ensure_plugin_requirements( + plugin_path, + dir_name, + ) success, error_message = await self.load( specified_dir_name=dir_name, ignore_version_check=ignore_version_check, @@ -1317,6 +1398,12 @@ class PluginManager: raise Exception("该插件是 AstrBot 保留插件,无法更新。") await self.updator.update(plugin, proxy=proxy) + if plugin.root_dir_name: + plugin_dir_path = os.path.join(self.plugin_store_path, plugin.root_dir_name) + await self._ensure_plugin_requirements( + plugin_dir_path, + plugin_name, + ) await self.reload(plugin_name) async def turn_off_plugin(self, plugin_name: str) -> None: @@ -1488,6 +1575,7 @@ class PluginManager: os.remove(zip_file_path) except BaseException as e: logger.warning(f"删除插件压缩包失败: {e!s}") + await self._ensure_plugin_requirements(desti_dir, dir_name) # await self.reload() success, error_message = await self.load( specified_dir_name=dir_name, diff --git a/astrbot/core/tools/send_message.py b/astrbot/core/tools/send_message.py index f7dd4fc9d..a1e7c589a 100644 --- a/astrbot/core/tools/send_message.py +++ b/astrbot/core/tools/send_message.py @@ -42,7 +42,7 @@ class SendMessageToUserTool(FunctionTool[AstrAgentContext]): "type": "string", "description": ( "Component type. One of: " - "plain, image, record, file, mention_user" + "plain, image, record, video, file, mention_user. Record is voice message." ), }, "text": { @@ -160,6 +160,19 @@ class SendMessageToUserTool(FunctionTool[AstrAgentContext]): components.append(Comp.Record.fromURL(url=url)) else: return f"error: messages[{idx}] must include path or url for record component." + elif msg_type == "video": + path = msg.get("path") + url = msg.get("url") + if path: + ( + local_path, + file_from_sandbox, + ) = await self._resolve_path_from_sandbox(context, path) + components.append(Comp.Video.fromFileSystem(path=local_path)) + elif url: + components.append(Comp.Video.fromURL(url=url)) + else: + return f"error: messages[{idx}] must include path or url for video component." elif msg_type == "file": path = msg.get("path") url = msg.get("url") diff --git a/astrbot/core/utils/core_constraints.py b/astrbot/core/utils/core_constraints.py new file mode 100644 index 000000000..b43f00122 --- /dev/null +++ b/astrbot/core/utils/core_constraints.py @@ -0,0 +1,121 @@ +import contextlib +import functools +import importlib.metadata as importlib_metadata +import logging +import os +from collections.abc import Iterator + +from packaging.requirements import Requirement + +from astrbot.core.utils.requirements_utils import ( + canonicalize_distribution_name, + collect_installed_distribution_versions, + get_requirement_check_paths, +) + +logger = logging.getLogger("astrbot") + + +def _resolve_core_dist_name(core_dist_name: str | None) -> str | None: + if core_dist_name: + try: + importlib_metadata.distribution(core_dist_name) + return core_dist_name + except importlib_metadata.PackageNotFoundError: + return None + + try: + importlib_metadata.distribution("AstrBot") + return "AstrBot" + except importlib_metadata.PackageNotFoundError: + pass + + if not __package__: + return None + + top_pkg = __package__.split(".")[0] + for dist in importlib_metadata.distributions(): + try: + top_level = dist.read_text("top_level.txt") or "" + except Exception: + continue + if top_pkg in top_level.splitlines(): + if "Name" in dist.metadata: + return dist.metadata["Name"] + + return None + + +@functools.cache +def _get_core_constraints(core_dist_name: str | None) -> tuple[str, ...]: + try: + resolved_core_dist_name = _resolve_core_dist_name(core_dist_name) + except Exception as exc: + logger.warning("解析核心分发名称失败: %s", exc) + return () + + if not resolved_core_dist_name: + return () + + try: + dist = importlib_metadata.distribution(resolved_core_dist_name) + except importlib_metadata.PackageNotFoundError: + return () + except Exception as exc: + logger.warning("读取核心分发元数据失败 (%s): %s", resolved_core_dist_name, exc) + return () + + if not dist or not dist.requires: + return () + + installed = collect_installed_distribution_versions(get_requirement_check_paths()) + if not installed: + return () + + constraints: list[str] = [] + for req_str in dist.requires: + try: + req = Requirement(req_str) + if req.marker and not req.marker.evaluate(): + continue + name = canonicalize_distribution_name(req.name) + if name in installed: + constraints.append(f"{name}=={installed[name]}") + except Exception: + continue + + return tuple(constraints) + + +class CoreConstraintsProvider: + def __init__(self, core_dist_name: str | None) -> None: + self._core_dist_name = core_dist_name + + @contextlib.contextmanager + def constraints_file(self) -> Iterator[str | None]: + constraints = _get_core_constraints(self._core_dist_name) + if not constraints: + yield None + return + + path: str | None = None + try: + import tempfile + + with tempfile.NamedTemporaryFile( + mode="w", suffix="_constraints.txt", delete=False, encoding="utf-8" + ) as f: + f.write("\n".join(constraints)) + path = f.name + logger.info("已启用核心依赖版本保护 (%d 个约束)", len(constraints)) + except Exception as exc: + logger.warning("创建临时约束文件失败: %s", exc) + yield None + return + + try: + yield path + finally: + if path and os.path.exists(path): + with contextlib.suppress(Exception): + os.remove(path) diff --git a/astrbot/core/utils/pip_installer.py b/astrbot/core/utils/pip_installer.py index 562a0ed30..97e9653d6 100644 --- a/astrbot/core/utils/pip_installer.py +++ b/astrbot/core/utils/pip_installer.py @@ -7,21 +7,71 @@ import io import logging import os import re +import shlex import sys import threading from collections import deque +from dataclasses import dataclass +from urllib.parse import urlparse from astrbot.core.utils.astrbot_path import get_astrbot_site_packages_path +from astrbot.core.utils.core_constraints import CoreConstraintsProvider +from astrbot.core.utils.requirements_utils import ( + canonicalize_distribution_name as _canonicalize_distribution_name, +) +from astrbot.core.utils.requirements_utils import ( + extract_requirement_name, + extract_requirement_names, + parse_package_install_input, +) from astrbot.core.utils.runtime_env import is_packaged_desktop_runtime logger = logging.getLogger("astrbot") _DISTLIB_FINDER_PATCH_ATTEMPTED = False _SITE_PACKAGES_IMPORT_LOCK = threading.RLock() +_PIP_FAILURE_PATTERNS = { + "error_prefix": re.compile(r"^\s*error:", re.IGNORECASE), + "user_requested": re.compile(r"\bthe user requested\b", re.IGNORECASE), + "resolution_impossible": re.compile(r"\bresolutionimpossible\b", re.IGNORECASE), + "cannot_install": re.compile(r"\bcannot install\b", re.IGNORECASE), + "conflict": re.compile(r"\bconflict(?:ing|s)?\b", re.IGNORECASE), + "constraint": re.compile(r"\(constraint\)", re.IGNORECASE), + "dependency_detail": re.compile(r"\bdepends on\b", re.IGNORECASE), +} +_SENSITIVE_PIP_VALUE_KEYS = frozenset( + {"password", "passwd", "pass", "api_token", "token", "auth_token"} +) +_MAX_PIP_OUTPUT_LINES = 200 -def _canonicalize_distribution_name(name: str) -> str: - return re.sub(r"[-_.]+", "-", name).strip("-").lower() +class DependencyConflictError(Exception): + """Raised when pip encounters a dependency conflict.""" + + def __init__( + self, message: str, errors: list[str], *, is_core_conflict: bool + ) -> None: + super().__init__(message) + self.errors = errors + self.is_core_conflict = is_core_conflict + + +class PipInstallError(Exception): + """Raised when pip install fails without a classified dependency conflict.""" + + def __init__(self, message: str, *, code: int) -> None: + super().__init__(message) + self.code = code + + +@dataclass +class PipConflictContext: + relevant_lines: list[str] + requested_lines: list[str] + dependency_detail_lines: list[str] + constraint_lines: list[str] + has_strong_conflict_signal: bool + has_contextual_conflict_signal: bool def _get_pip_main(): @@ -41,11 +91,12 @@ def _get_pip_main(): return pip_main -def _run_pip_main_with_output(pip_main, args: list[str]) -> tuple[int, str]: - stream = io.StringIO() - with contextlib.redirect_stdout(stream), contextlib.redirect_stderr(stream): - result_code = pip_main(args) - return result_code, stream.getvalue() +def _prepend_sys_path(path: str) -> None: + normalized_target = os.path.realpath(path) + sys.path[:] = [ + item for item in sys.path if os.path.realpath(item) != normalized_target + ] + sys.path.insert(0, normalized_target) def _cleanup_added_root_handlers(original_handlers: list[logging.Handler]) -> None: @@ -59,76 +110,258 @@ def _cleanup_added_root_handlers(original_handlers: list[logging.Handler]) -> No handler.close() -def _prepend_sys_path(path: str) -> None: - normalized_target = os.path.realpath(path) - sys.path[:] = [ - item for item in sys.path if os.path.realpath(item) != normalized_target - ] - sys.path.insert(0, normalized_target) +def _get_trusted_host_for_index_url(index_url: str) -> str | None: + parsed = urlparse(index_url if "://" in index_url else f"//{index_url}") + host = parsed.hostname + if host == "mirrors.aliyun.com": + return host + return None -def _module_exists_in_site_packages(module_name: str, site_packages_path: str) -> bool: - base_path = os.path.join(site_packages_path, *module_name.split(".")) - package_init = os.path.join(base_path, "__init__.py") - module_file = f"{base_path}.py" - return os.path.isfile(package_init) or os.path.isfile(module_file) +def _normalize_sensitive_pip_key(raw_key: str) -> str: + return raw_key.lstrip("-").replace("-", "_").lower() -def _is_module_loaded_from_site_packages( - module_name: str, - site_packages_path: str, -) -> bool: - module = sys.modules.get(module_name) - if module is None: - try: - module = importlib.import_module(module_name) - except Exception: - return False +def _is_sensitive_pip_value_key(raw_key: str) -> bool: + return _normalize_sensitive_pip_key(raw_key) in _SENSITIVE_PIP_VALUE_KEYS - module_file = getattr(module, "__file__", None) - if not module_file: - return False - module_path = os.path.realpath(module_file) - site_packages_real = os.path.realpath(site_packages_path) - try: - return ( - os.path.commonpath([module_path, site_packages_real]) == site_packages_real +def _redact_url_credentials(raw_value: str) -> str: + """Redact URL credentials and known inline secret values for safe logging.""" + parsed = urlparse(raw_value) + if parsed.netloc and "@" in parsed.netloc: + hostname = parsed.hostname or "" + port = f":{parsed.port}" if parsed.port else "" + return parsed._replace(netloc=f"@{hostname}{port}").geturl() + + if raw_value.startswith("--"): + option, separator, _ = raw_value.partition("=") + if separator and _is_sensitive_pip_value_key(option): + return f"{option}=****" + return raw_value + + key, separator, _ = raw_value.partition("=") + if separator and _is_sensitive_pip_value_key(key): + return f"{key}=****" + + return raw_value + + +def _redact_pip_args_for_logging(args: list[str]) -> list[str]: + redacted_args: list[str] = [] + redact_next_value = False + + for arg in args: + if redact_next_value: + redacted_args.append("****") + redact_next_value = False + continue + + if arg.startswith("--") and "=" in arg: + option, value = arg.split("=", 1) + if _is_sensitive_pip_value_key(option): + redacted_args.append(f"{option}=****") + else: + redacted_args.append(f"{option}={_redact_url_credentials(value)}") + continue + + if arg.startswith("-i") and arg != "-i": + redacted_args.append(f"-i{_redact_url_credentials(arg[2:])}") + continue + + if _is_sensitive_pip_value_key(arg): + redacted_args.append(arg) + redact_next_value = True + continue + + redacted_args.append(_redact_url_credentials(arg)) + + return redacted_args + + +def _package_specs_override_index(package_specs: list[str]) -> bool: + for index, spec in enumerate(package_specs): + if spec == "--no-index": + return True + if spec in {"-i", "--index-url"}: + if index + 1 < len(package_specs): + return True + continue + if spec.startswith("--index-url="): + return True + if spec.startswith("-i") and spec != "-i": + return True + return False + + +class _StreamingLogWriter(io.TextIOBase): + def __init__(self, log_func, *, max_lines: int | None = None) -> None: + self._log_func = log_func + self._lines = deque(maxlen=max_lines or _MAX_PIP_OUTPUT_LINES) + self._buffer = "" + + def write(self, text: str) -> int: + if not text: + return 0 + + self._buffer += text.replace("\r\n", "\n").replace("\r", "\n") + while "\n" in self._buffer: + raw_line, self._buffer = self._buffer.split("\n", 1) + line = raw_line.rstrip("\r\n") + self._log_func(line) + self._lines.append(line) + return len(text) + + def flush(self) -> None: + line = self._buffer.rstrip("\r\n") + if line: + self._log_func(line) + self._lines.append(line) + self._buffer = "" + + @property + def lines(self) -> list[str]: + return list(self._lines) + + +def _run_pip_main_streaming(pip_main, args: list[str]) -> tuple[int, list[str]]: + stream = _StreamingLogWriter(logger.info, max_lines=_MAX_PIP_OUTPUT_LINES) + with ( + contextlib.redirect_stdout(stream), + contextlib.redirect_stderr(stream), + ): + result_code = pip_main(args) + stream.flush() + return result_code, stream.lines + + +def _matches_pip_failure_pattern(line: str, *pattern_names: str) -> bool: + names = pattern_names or tuple(_PIP_FAILURE_PATTERNS) + return any(_PIP_FAILURE_PATTERNS[name].search(line) for name in names) + + +def _normalize_conflict_detail_line(line: str) -> str: + stripped = line.strip() + if _matches_pip_failure_pattern(stripped, "user_requested"): + return re.sub( + r"^\s*The user requested\s+", + "", + stripped, + flags=re.IGNORECASE, ) - except ValueError: - return False + return stripped -def _extract_requirement_name(raw_requirement: str) -> str | None: - line = raw_requirement.split("#", 1)[0].strip() - if not line: - return None - if line.startswith(("-r", "--requirement", "-c", "--constraint")): - return None - if line.startswith("-"): +def _build_pip_conflict_context(output_lines: list[str]) -> PipConflictContext | None: + matched_indices = [ + index + for index, line in enumerate(output_lines) + if _matches_pip_failure_pattern(line) + ] + if matched_indices: + relevant_index_set: set[int] = set() + for index in matched_indices: + start = max(0, index - 1) + end = min(len(output_lines), index + 2) + relevant_index_set.update(range(start, end)) + relevant_output_lines = [ + line + for index, line in enumerate(output_lines) + if index in relevant_index_set + ] + else: + relevant_output_lines = output_lines[-5:] + + if not relevant_output_lines: return None - egg_match = re.search(r"#egg=([A-Za-z0-9_.-]+)", raw_requirement) - if egg_match: - return _canonicalize_distribution_name(egg_match.group(1)) + dependency_detail_lines = [ + line.strip() + for line in relevant_output_lines + if _matches_pip_failure_pattern(line, "dependency_detail") + ] + requested_lines = [ + line.strip() + for line in relevant_output_lines + if _matches_pip_failure_pattern(line, "user_requested") + and not _matches_pip_failure_pattern(line, "constraint") + ] + if not requested_lines: + requested_lines = [ + line + for line in dependency_detail_lines + if not _matches_pip_failure_pattern(line, "constraint") + ] + constraint_lines = [ + line.strip() + for line in relevant_output_lines + if _matches_pip_failure_pattern(line, "constraint") + ] - candidate = re.split(r"[<>=!~;\s\[]", line, maxsplit=1)[0].strip() - if not candidate: + has_strong_conflict_signal = any( + _matches_pip_failure_pattern( + line, + "resolution_impossible", + "cannot_install", + ) + for line in relevant_output_lines + ) + + has_contextual_conflict_signal = any( + _matches_pip_failure_pattern(line, "conflict") for line in relevant_output_lines + ) and bool(dependency_detail_lines or requested_lines or constraint_lines) + + return PipConflictContext( + relevant_lines=relevant_output_lines, + requested_lines=requested_lines, + dependency_detail_lines=dependency_detail_lines, + constraint_lines=constraint_lines, + has_strong_conflict_signal=has_strong_conflict_signal, + has_contextual_conflict_signal=has_contextual_conflict_signal, + ) + + +def _classify_pip_failure(output_lines: list[str]) -> DependencyConflictError | None: + context = _build_pip_conflict_context(output_lines) + if context is None: return None - return _canonicalize_distribution_name(candidate) + if ( + not context.has_strong_conflict_signal + and not context.has_contextual_conflict_signal + and not (context.requested_lines and context.constraint_lines) + ): + return None -def _extract_requirement_names(requirements_path: str) -> set[str]: - names: set[str] = set() - try: - with open(requirements_path, encoding="utf-8") as requirements_file: - for line in requirements_file: - requirement_name = _extract_requirement_name(line) - if requirement_name: - names.add(requirement_name) - except Exception as exc: - logger.warning("读取依赖文件失败,跳过冲突检测: %s", exc) - return names + is_core_conflict = bool(context.constraint_lines) + + detail = "" + if context.constraint_lines and context.requested_lines: + detail = ( + " 冲突详情: " + f"{_normalize_conflict_detail_line(context.requested_lines[0])} vs " + f"{_normalize_conflict_detail_line(context.constraint_lines[0])}。" + ) + elif len(context.dependency_detail_lines) >= 2: + detail = ( + " 冲突详情: " + f"{_normalize_conflict_detail_line(context.dependency_detail_lines[0])} vs " + f"{_normalize_conflict_detail_line(context.dependency_detail_lines[1])}。" + ) + + if is_core_conflict: + message = ( + f"检测到核心依赖版本保护冲突。{detail}插件要求的依赖版本与 AstrBot 核心不兼容," + "为了系统稳定,已阻止该降级行为。请联系插件作者或调整 requirements.txt。" + ) + else: + message = f"检测到依赖冲突。{detail}" + + return DependencyConflictError( + message, + context.relevant_lines, + is_core_conflict=is_core_conflict, + ) def _extract_top_level_modules( @@ -155,7 +388,11 @@ def _collect_candidate_modules( by_name: dict[str, list[importlib_metadata.Distribution]] = {} try: for distribution in importlib_metadata.distributions(path=[site_packages_path]): - distribution_name = distribution.metadata.get("Name") + distribution_name = ( + distribution.metadata["Name"] + if "Name" in distribution.metadata + else None + ) if not distribution_name: continue canonical_name = _canonicalize_distribution_name(distribution_name) @@ -173,7 +410,7 @@ def _collect_candidate_modules( for distribution in by_name.get(requirement_name, []): for dependency_line in distribution.requires or []: - dependency_name = _extract_requirement_name(dependency_line) + dependency_name = extract_requirement_name(dependency_line) if not dependency_name: continue if dependency_name in expanded_requirement_names: @@ -230,6 +467,38 @@ def _ensure_preferred_modules( raise RuntimeError(conflict_message) +def _module_exists_in_site_packages(module_name: str, site_packages_path: str) -> bool: + base_path = os.path.join(site_packages_path, *module_name.split(".")) + package_init = os.path.join(base_path, "__init__.py") + module_file = f"{base_path}.py" + return os.path.isfile(package_init) or os.path.isfile(module_file) + + +def _is_module_loaded_from_site_packages( + module_name: str, + site_packages_path: str, +) -> bool: + module = sys.modules.get(module_name) + if module is None: + try: + module = importlib.import_module(module_name) + except Exception: + return False + + module_file = getattr(module, "__file__", None) + if not module_file: + return False + + module_path = os.path.realpath(module_file) + site_packages_real = os.path.realpath(site_packages_path) + try: + return ( + os.path.commonpath([module_path, site_packages_real]) == site_packages_real + ) + except ValueError: + return False + + def _prefer_module_from_site_packages( module_name: str, site_packages_path: str ) -> bool: @@ -531,9 +800,63 @@ def _patch_distlib_finder_for_frozen_runtime() -> None: class PipInstaller: - def __init__(self, pip_install_arg: str, pypi_index_url: str | None = None) -> None: + def __init__( + self, + pip_install_arg: str, + pypi_index_url: str | None = None, + core_dist_name: str | None = "AstrBot", + ) -> None: self.pip_install_arg = pip_install_arg self.pypi_index_url = pypi_index_url + self.core_dist_name = core_dist_name + self._core_constraints = CoreConstraintsProvider(core_dist_name) + + def _build_pip_args( + self, + package_name: str | None, + requirements_path: str | None, + mirror: str | None, + ) -> tuple[list[str], set[str]]: + args: list[str] = [] + requested_requirements: set[str] = set() + normalized_requirements_path = ( + requirements_path.strip() if requirements_path else "" + ) + + if package_name and normalized_requirements_path: + raise ValueError( + "package_name and requirements_path cannot be used together" + ) + + if package_name: + parsed_package = parse_package_install_input(package_name) + if parsed_package.specs: + args = ["install", *parsed_package.specs] + requested_requirements = set(parsed_package.requirement_names) + elif normalized_requirements_path: + args = ["install", "-r", normalized_requirements_path] + requested_requirements = extract_requirement_names( + normalized_requirements_path + ) + + if not args: + return [], requested_requirements + + pip_install_args = ( + shlex.split(self.pip_install_arg) if self.pip_install_arg else [] + ) + + if not _package_specs_override_index([*args[1:], *pip_install_args]): + index_url = mirror or self.pypi_index_url or "https://pypi.org/simple" + trusted_host = _get_trusted_host_for_index_url(index_url) + if trusted_host: + args.extend(["--trusted-host", trusted_host]) + args.extend(["-i", index_url]) + + if pip_install_args: + args.extend(pip_install_args) + + return args, requested_requirements async def install( self, @@ -541,36 +864,37 @@ class PipInstaller: requirements_path: str | None = None, mirror: str | None = None, ) -> None: - args = ["install"] - requested_requirements: set[str] = set() - if package_name: - args.append(package_name) - requirement_name = _extract_requirement_name(package_name) - if requirement_name: - requested_requirements.add(requirement_name) - elif requirements_path: - args.extend(["-r", requirements_path]) - requested_requirements = _extract_requirement_names(requirements_path) - - index_url = mirror or self.pypi_index_url or "https://pypi.org/simple" - args.extend(["--trusted-host", "mirrors.aliyun.com", "-i", index_url]) + args, requested_requirements = self._build_pip_args( + package_name, requirements_path, mirror + ) + if not args: + logger.info("Pip 包管理器跳过安装:未提供有效的包名或 requirements 文件。") + return target_site_packages = None if is_packaged_desktop_runtime(): target_site_packages = get_astrbot_site_packages_path() os.makedirs(target_site_packages, exist_ok=True) _prepend_sys_path(target_site_packages) - args.extend(["--target", target_site_packages]) - args.extend(["--upgrade", "--force-reinstall"]) + args.extend( + [ + "--target", + target_site_packages, + "--upgrade", + "--upgrade-strategy", + "only-if-needed", + ] + ) - if self.pip_install_arg: - args.extend(self.pip_install_arg.split()) + with self._core_constraints.constraints_file() as constraints_file_path: + if constraints_file_path: + args.extend(["-c", constraints_file_path]) - logger.info(f"Pip 包管理器: pip {' '.join(args)}") - result_code = await self._run_pip_in_process(args) - - if result_code != 0: - raise Exception(f"安装失败,错误码:{result_code}") + logger.info( + "Pip 包管理器 argv: %s", + ["pip", *_redact_pip_args_for_logging(args)], + ) + await self._run_pip_with_classification(args) if target_site_packages: _prepend_sys_path(target_site_packages) @@ -589,7 +913,7 @@ class PipInstaller: if not os.path.isdir(target_site_packages): return - requested_requirements = _extract_requirement_names(requirements_path) + requested_requirements = extract_requirement_names(requirements_path) if not requested_requirements: return @@ -605,13 +929,21 @@ class PipInstaller: _patch_distlib_finder_for_frozen_runtime() original_handlers = list(logging.getLogger().handlers) - result_code, output = await asyncio.to_thread( - _run_pip_main_with_output, pip_main, args - ) - for line in output.splitlines(): - line = line.strip() - if line: - logger.info(line) + try: + result_code, output_lines = await asyncio.to_thread( + _run_pip_main_streaming, pip_main, args + ) + finally: + _cleanup_added_root_handlers(original_handlers) + + if result_code != 0: + conflict = _classify_pip_failure(output_lines) + if conflict: + raise conflict - _cleanup_added_root_handlers(original_handlers) return result_code + + async def _run_pip_with_classification(self, args: list[str]) -> None: + result_code = await self._run_pip_in_process(args) + if result_code != 0: + raise PipInstallError(f"安装失败,错误码:{result_code}", code=result_code) diff --git a/astrbot/core/utils/requirements_utils.py b/astrbot/core/utils/requirements_utils.py new file mode 100644 index 000000000..7f3827256 --- /dev/null +++ b/astrbot/core/utils/requirements_utils.py @@ -0,0 +1,408 @@ +import importlib.metadata as importlib_metadata +import logging +import os +import re +import shlex +import sys +from collections.abc import Iterable, Iterator +from dataclasses import dataclass + +from packaging.requirements import InvalidRequirement, Requirement +from packaging.specifiers import SpecifierSet +from packaging.version import InvalidVersion, Version + +from astrbot.core.utils.astrbot_path import get_astrbot_site_packages_path +from astrbot.core.utils.runtime_env import is_packaged_desktop_runtime + +logger = logging.getLogger("astrbot") + + +class RequirementsPrecheckFailed(Exception): + """Raised when the pre-check of requirements fails.""" + + pass + + +@dataclass(frozen=True) +class ParsedPackageInput: + specs: tuple[str, ...] + requirement_names: frozenset[str] + + +def canonicalize_distribution_name(name: str) -> str: + return re.sub(r"[-_.]+", "-", name).strip("-").lower() + + +def strip_inline_requirement_comment(raw_input: str) -> str: + if raw_input.lstrip().startswith("#"): + return "" + return re.split(r"[ \t]+#", raw_input, maxsplit=1)[0].strip() + + +def _specifier_contains_version(specifier: SpecifierSet, version: str) -> bool: + try: + parsed_version = Version(version) + except InvalidVersion: + return False + return specifier.contains(parsed_version, prereleases=True) + + +def _looks_like_local_path_reference(token: str) -> bool: + candidate = token.strip() + if not candidate: + return False + return candidate in {".", ".."} or candidate.startswith( + ("./", "../", "/", "~/", ".\\", "..\\", "\\") + ) + + +def looks_like_direct_reference(token: str) -> bool: + candidate = token.strip() + if not candidate: + return False + return ( + _looks_like_local_path_reference(candidate) + or candidate.startswith("git+") + or "://" in candidate + ) + + +def extract_requirement_name(raw_requirement: str) -> str | None: + line = raw_requirement.split("#", 1)[0].strip() + if not line: + return None + if line.startswith(("-r", "--requirement", "-c", "--constraint")): + return None + + egg_match = re.search(r"#egg=([A-Za-z0-9_.-]+)", raw_requirement) + if egg_match: + return canonicalize_distribution_name(egg_match.group(1)) + + if line.startswith("-"): + return None + + candidate = re.split(r"[<>=!~;\s\[]", line, maxsplit=1)[0].strip() + if not candidate: + return None + return canonicalize_distribution_name(candidate) + + +def _parse_editable_or_direct_name(target: str) -> str | None: + name = extract_requirement_name(target) + if not name: + return None + if "#egg=" in target or not looks_like_direct_reference(target): + return name + return None + + +def _parse_requirement_name_and_spec( + line: str, +) -> tuple[str | None, SpecifierSet | None]: + if line.startswith(("-c", "--constraint")): + return None, None + + try: + req = Requirement(line) + except InvalidRequirement: + tokens = shlex.split(line) + if not tokens: + return None, None + + editable_target: str | None = None + if tokens[0] in {"-e", "--editable"} and len(tokens) > 1: + editable_target = tokens[1] + elif tokens[0].startswith("--editable="): + editable_target = tokens[0].split("=", 1)[1] + + if editable_target: + name = _parse_editable_or_direct_name(editable_target) + return (name, None) if name else (None, None) + + name = _parse_editable_or_direct_name(line) + return (name, None) if name else (None, None) + + if req.marker and not req.marker.evaluate(): + return None, None + + return canonicalize_distribution_name(req.name), (req.specifier or None) + + +def _parse_requirement_line( + line: str, +) -> tuple[str, SpecifierSet | None] | None: + name, specifier = _parse_requirement_name_and_spec(line) + return (name, specifier) if name else None + + +def _extract_requirement_names_from_package_tokens(tokens: list[str]) -> frozenset[str]: + requirement_names: set[str] = set() + skip_next_for: str | None = None + + for token in tokens: + if skip_next_for: + if skip_next_for == "editable": + name = _parse_editable_or_direct_name(token) + if name: + requirement_names.add(name) + skip_next_for = None + continue + + if token in {"-e", "--editable"}: + skip_next_for = "editable" + continue + + if token in { + "-i", + "--index-url", + "--extra-index-url", + "-f", + "--find-links", + "--trusted-host", + "-r", + "--requirement", + "-c", + "--constraint", + }: + skip_next_for = "option-value" + continue + + if token.startswith(("--editable=",)): + editable_target = token.split("=", 1)[1] + name = _parse_editable_or_direct_name(editable_target) + if name: + requirement_names.add(name) + continue + + if token.startswith( + ( + "--index-url=", + "--extra-index-url=", + "--find-links=", + "--trusted-host=", + "--requirement=", + "--constraint=", + ) + ): + continue + + if ( + (token.startswith("-i") and token != "-i") + or (token.startswith("-f") and token != "-f") + or token == "--no-index" + ): + continue + + if token.startswith("-"): + continue + + name, _ = _parse_requirement_name_and_spec(token) + if name: + requirement_names.add(name) + + return frozenset(requirement_names) + + +def parse_package_install_input(raw_input: str) -> ParsedPackageInput: + specs: list[str] = [] + requirement_names: set[str] = set() + normalized = raw_input.strip() + if not normalized: + return ParsedPackageInput(specs=(), requirement_names=frozenset()) + + for raw_line in normalized.splitlines(): + line = strip_inline_requirement_comment(raw_line) + if not line: + continue + + try: + Requirement(line) + except InvalidRequirement: + tokens = shlex.split(line) + if not tokens: + continue + specs.extend(tokens) + requirement_names.update( + _extract_requirement_names_from_package_tokens(tokens) + ) + continue + + specs.append(line) + name, _ = _parse_requirement_name_and_spec(line) + if name: + requirement_names.add(name) + + return ParsedPackageInput( + specs=tuple(specs), + requirement_names=frozenset(requirement_names), + ) + + +def _iter_requirement_lines( + requirements_path: str, + _visited: set[str] | None = None, +) -> Iterator[str]: + visited = _visited or set() + resolved_path = os.path.realpath(requirements_path) + if resolved_path in visited: + logger.warning( + "检测到循环依赖的 requirements 包含: %s,将跳过该文件", resolved_path + ) + return + visited.add(resolved_path) + + with open(resolved_path, encoding="utf-8") as f: + for raw_line in f: + line = strip_inline_requirement_comment(raw_line) + if not line: + continue + + tokens = shlex.split(line) + if not tokens: + continue + + nested: str | None = None + if tokens[0] in {"-r", "--requirement"} and len(tokens) > 1: + nested = tokens[1] + elif tokens[0].startswith("--requirement="): + nested = tokens[0].split("=", 1)[1] + + if nested: + if not os.path.isabs(nested): + nested = os.path.join(os.path.dirname(resolved_path), nested) + yield from _iter_requirement_lines(nested, _visited=visited) + continue + + yield line + + +def iter_requirements( + requirements_path: str | None = None, + lines: Iterable[str] | None = None, +) -> Iterator[tuple[str, SpecifierSet | None]]: + if lines is None: + if requirements_path is None: + raise ValueError("Either requirements_path or lines must be provided") + lines = _iter_requirement_lines(requirements_path) + + for line in lines: + parsed = _parse_requirement_line(line) + if parsed is not None: + yield parsed + + +def extract_requirement_names(requirements_path: str) -> set[str]: + try: + return { + name for name, _ in iter_requirements(requirements_path=requirements_path) + } + except Exception as exc: + logger.warning("读取依赖文件失败,跳过冲突检测: %s", exc) + return set() + + +def get_requirement_check_paths() -> list[str]: + paths = list(sys.path) + if is_packaged_desktop_runtime(): + target_site_packages = get_astrbot_site_packages_path() + if os.path.isdir(target_site_packages): + paths.insert(0, target_site_packages) + return paths + + +def _canonical_distribution_identity(distribution) -> tuple[str | None, str | None]: + distribution_name = ( + distribution.metadata["Name"] if "Name" in distribution.metadata else None + ) + if not distribution_name: + return None, None + return canonicalize_distribution_name(distribution_name), distribution.version + + +def collect_installed_distribution_versions(paths: list[str]) -> dict[str, str] | None: + installed: dict[str, str] = {} + try: + for distribution in importlib_metadata.distributions(path=paths): + distribution_name, version = _canonical_distribution_identity(distribution) + if not distribution_name or not version: + continue + installed.setdefault(distribution_name, version) + except Exception as exc: + logger.warning("读取已安装依赖失败,跳过缺失依赖预检查: %s", exc) + return None + return installed + + +def _load_requirement_lines_for_precheck( + requirements_path: str, +) -> tuple[bool, list[str] | None]: + try: + requirement_lines = list(_iter_requirement_lines(requirements_path)) + except Exception as exc: + logger.warning( + "预检查缺失依赖失败,将回退到完整安装: %s (%s)", + requirements_path, + exc, + ) + return False, None + + fallback_line = next( + ( + line + for line in requirement_lines + if ( + ( + line.startswith(("-e ", "--editable ", "--editable=")) + and "#egg=" not in line + ) + or ( + _parse_requirement_line(line) is None + and looks_like_direct_reference(line) + ) + ) + ), + None, + ) + if fallback_line is not None: + logger.warning( + "预检查缺失依赖失败,将回退到完整安装: unresolved direct reference in %s: %s", + requirements_path, + fallback_line, + ) + return False, None + + return True, requirement_lines + + +def find_missing_requirements(requirements_path: str) -> set[str] | None: + can_precheck, requirement_lines = _load_requirement_lines_for_precheck( + requirements_path + ) + if not can_precheck or requirement_lines is None: + return None + + required = list(iter_requirements(lines=requirement_lines)) + if not required: + return set() + + installed = collect_installed_distribution_versions(get_requirement_check_paths()) + if installed is None: + return None + + missing: set[str] = set() + for name, specifier in required: + installed_version = installed.get(name) + if not installed_version: + missing.add(name) + continue + if specifier and not _specifier_contains_version(specifier, installed_version): + missing.add(name) + + return missing + + +def find_missing_requirements_or_raise(requirements_path: str) -> set[str]: + missing = find_missing_requirements(requirements_path) + if missing is None: + raise RequirementsPrecheckFailed(f"预检查失败: {requirements_path}") + return missing diff --git a/dashboard/src/components/extension/componentPanel/components/CommandFilters.vue b/dashboard/src/components/extension/componentPanel/components/CommandFilters.vue index c4b212803..15a6be180 100644 --- a/dashboard/src/components/extension/componentPanel/components/CommandFilters.vue +++ b/dashboard/src/components/extension/componentPanel/components/CommandFilters.vue @@ -1,6 +1,7 @@