From 09d1f96603b2224e4a5ab92a3130fd9d7550bdb7 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Fri, 26 Sep 2025 14:16:50 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=20/alter=5Fcmd=20?= =?UTF-8?q?=E6=8C=87=E4=BB=A4=E6=97=A0=E6=B3=95=E6=8E=A7=E5=88=B6=E6=8C=87?= =?UTF-8?q?=E4=BB=A4=E7=BB=84=E3=80=81=E5=AD=90=E6=8C=87=E4=BB=A4=E7=BB=84?= =?UTF-8?q?=E5=92=8C=E5=AD=90=E6=8C=87=E4=BB=A4=E7=BB=84=E4=B8=8B=E5=AD=90?= =?UTF-8?q?=E6=8C=87=E4=BB=A4=E7=9A=84=E9=97=AE=E9=A2=98=20(#2873)=20*=20f?= =?UTF-8?q?ix:=20revert=20changes=20in=20command=5Fgroup.py=20at=20782c036?= =?UTF-8?q?=20to=20fix=20command=20group=20permission=20check?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: 不传递 GroupCommand handler * perf: alter_cmd 指令支持对子指令、指令组进行配置 * chore: remove test commands and subcommands from test_group * chore: add cache for complete command names list in CommandFilter and CommandGroupFilter --------- Co-authored-by: Dt8333 <25431943+Dt8333@users.noreply.github.com> Co-authored-by: Soulter <905617992@qq.com> --- astrbot/core/pipeline/waking_check/stage.py | 15 +++++--- astrbot/core/platform/astr_message_event.py | 14 +++++--- astrbot/core/star/filter/command.py | 39 ++++++++++++++------- astrbot/core/star/filter/command_group.py | 20 ++++++++--- packages/astrbot/main.py | 30 +++++++++------- 5 files changed, 77 insertions(+), 41 deletions(-) diff --git a/astrbot/core/pipeline/waking_check/stage.py b/astrbot/core/pipeline/waking_check/stage.py index 63bc8b52d..de6ad5e35 100644 --- a/astrbot/core/pipeline/waking_check/stage.py +++ b/astrbot/core/pipeline/waking_check/stage.py @@ -5,6 +5,7 @@ from astrbot.core.message.components import At, AtAll, Reply from astrbot.core.message.message_event_result import MessageChain, MessageEventResult from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.star.filter.permission import PermissionTypeFilter +from astrbot.core.star.filter.command_group import CommandGroupFilter from astrbot.core.star.session_plugin_manager import SessionPluginManager from astrbot.core.star.star import star_map from astrbot.core.star.star_handler import EventType, star_handlers_registry @@ -170,11 +171,15 @@ class WakingCheckStage(Stage): is_wake = True event.is_wake = True - activated_handlers.append(handler) - if "parsed_params" in event.get_extra(): - handlers_parsed_params[handler.handler_full_name] = event.get_extra( - "parsed_params" - ) + is_group_cmd_handler = any( + isinstance(f, CommandGroupFilter) for f in handler.event_filters + ) + if not is_group_cmd_handler: + activated_handlers.append(handler) + if "parsed_params" in event.get_extra(default={}): + handlers_parsed_params[handler.handler_full_name] = ( + event.get_extra("parsed_params") + ) event._extras.pop("parsed_params", None) diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index c38dfddef..52a882433 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -4,7 +4,7 @@ import re import hashlib import uuid -from typing import List, Union, Optional, AsyncGenerator +from typing import List, Union, Optional, AsyncGenerator, TypeVar, Any from astrbot import logger from astrbot.core.db.po import Conversation @@ -26,6 +26,8 @@ from .astrbot_message import AstrBotMessage, Group from .platform_metadata import PlatformMetadata from .message_session import MessageSession, MessageSesion # noqa +_VT = TypeVar("_VT") + class AstrMessageEvent(abc.ABC): def __init__( @@ -49,7 +51,7 @@ class AstrMessageEvent(abc.ABC): """是否唤醒(是否通过 WakingStage)""" self.is_at_or_wake_command = False """是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)""" - self._extras = {} + self._extras: dict[str, Any] = {} self.session = MessageSesion( platform_name=platform_meta.id, message_type=message_obj.type, @@ -57,7 +59,7 @@ class AstrMessageEvent(abc.ABC): ) self.unified_msg_origin = str(self.session) """统一的消息来源字符串。格式为 platform_name:message_type:session_id""" - self._result: MessageEventResult = None + self._result: MessageEventResult | None = None """消息事件的结果""" self._has_send_oper = False @@ -173,13 +175,15 @@ class AstrMessageEvent(abc.ABC): """ self._extras[key] = value - def get_extra(self, key=None): + def get_extra( + self, key: str | None = None, default: _VT = None + ) -> dict[str, Any] | _VT: """ 获取额外的信息。 """ if key is None: return self._extras - return self._extras.get(key, None) + return self._extras.get(key, default) def clear_extra(self): """ diff --git a/astrbot/core/star/filter/command.py b/astrbot/core/star/filter/command.py index 9ceed54a9..d6b34d832 100755 --- a/astrbot/core/star/filter/command.py +++ b/astrbot/core/star/filter/command.py @@ -32,6 +32,9 @@ class CommandFilter(HandlerFilter): self.init_handler_md(handler_md) self.custom_filter_list: List[CustomFilter] = [] + # Cache for complete command names list + self._cmpl_cmd_names: list | None = None + def print_types(self): result = "" for k, v in self.handler_params.items(): @@ -136,6 +139,28 @@ class CommandFilter(HandlerFilter): ) return result + def get_complete_command_names(self): + if self._cmpl_cmd_names is not None: + return self._cmpl_cmd_names + self._cmpl_cmd_names = [ + f"{parent} {cmd}" if parent else cmd + for cmd in [self.command_name] + list(self.alias) + for parent in self.parent_command_names or [""] + ] + return self._cmpl_cmd_names + + def startswith(self, message_str: str) -> bool: + for full_cmd in self.get_complete_command_names(): + if message_str.startswith(f"{full_cmd} ") or message_str == full_cmd: + return True + return False + + def equals(self, message_str: str) -> bool: + for full_cmd in self.get_complete_command_names(): + if message_str == full_cmd: + return True + return False + def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: if not event.is_at_or_wake_command: return False @@ -145,19 +170,7 @@ class CommandFilter(HandlerFilter): # 检查是否以指令开头 message_str = re.sub(r"\s+", " ", event.get_message_str().strip()) - candidates = [self.command_name] + list(self.alias) - ok = False - for candidate in candidates: - for parent_command_name in self.parent_command_names: - if parent_command_name: - _full = f"{parent_command_name} {candidate}" - else: - _full = candidate - if message_str.startswith(f"{_full} ") or message_str == _full: - message_str = message_str[len(_full) :].strip() - ok = True - break - if not ok: + if not self.startswith(message_str): return False # 分割为列表 diff --git a/astrbot/core/star/filter/command_group.py b/astrbot/core/star/filter/command_group.py index 0b8cd6e86..e01fa2c58 100755 --- a/astrbot/core/star/filter/command_group.py +++ b/astrbot/core/star/filter/command_group.py @@ -22,6 +22,9 @@ class CommandGroupFilter(HandlerFilter): self.custom_filter_list: List[CustomFilter] = [] self.parent_group = parent_group + # Cache for complete command names list + self._cmpl_cmd_names: list | None = None + def add_sub_command_filter( self, sub_command_filter: Union[CommandFilter, CommandGroupFilter] ): @@ -34,6 +37,9 @@ class CommandGroupFilter(HandlerFilter): """遍历父节点获取完整的指令名。 新版本 v3.4.29 采用预编译指令,不再从指令组递归遍历子指令,因此这个方法是返回包括别名在内的整个指令名列表。""" + if self._cmpl_cmd_names is not None: + return self._cmpl_cmd_names + parent_cmd_names = ( self.parent_group.get_complete_command_names() if self.parent_group else [] ) @@ -47,6 +53,7 @@ class CommandGroupFilter(HandlerFilter): for parent_cmd_name in parent_cmd_names: for candidate in candidates: result.append(parent_cmd_name + " " + candidate) + self._cmpl_cmd_names = result return result # 以树的形式打印出来 @@ -97,6 +104,12 @@ class CommandGroupFilter(HandlerFilter): return False return True + def startswith(self, message_str: str) -> bool: + return message_str.startswith(tuple(self.get_complete_command_names())) + + def equals(self, message_str: str) -> bool: + return message_str in self.get_complete_command_names() + def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: if not event.is_at_or_wake_command: return False @@ -105,8 +118,7 @@ class CommandGroupFilter(HandlerFilter): if not self.custom_filter_ok(event, cfg): return False - complete_command_names = self.get_complete_command_names() - if event.message_str.strip() in complete_command_names: + if self.equals(event.message_str.strip()): tree = ( self.group_name + "\n" @@ -116,6 +128,4 @@ class CommandGroupFilter(HandlerFilter): f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n" + tree ) - # complete_command_names = [name + " " for name in complete_command_names] - # return event.message_str.startswith(tuple(complete_command_names)) - return False + return self.startswith(event.message_str) diff --git a/packages/astrbot/main.py b/packages/astrbot/main.py index 5f92170df..943363e07 100644 --- a/packages/astrbot/main.py +++ b/packages/astrbot/main.py @@ -1348,22 +1348,22 @@ UID: {user_id} 此 ID 可用于设置管理员。 logger.error(f"ltm: {e}") @filter.permission_type(filter.PermissionType.ADMIN) - @filter.command("alter_cmd") + @filter.command("alter_cmd", alias={"alter"}) async def alter_cmd(self, event: AstrMessageEvent): - # token = event.message_str.split(" ") token = self.parse_commands(event.message_str) - if token.len < 2: + if token.len < 3: yield event.plain_result( - "可设置所有其他指令是否需要管理员权限。\n格式: /alter_cmd \n 例如: /alter_cmd provider admin 将 provider 设置为管理员指令\n /alter_cmd reset config 打开reset权限配置" + "该指令用于设置指令或指令组的权限。\n" + "格式: /alter_cmd \n" + "例1: /alter_cmd c1 admin 将 c1 设为管理员指令\n" + "例2: /alter_cmd g1 c1 admin 将 g1 指令组的 c1 子指令设为管理员指令\n" + "/alter_cmd reset config 打开 reset 权限配置" ) return - cmd_name = token.get(1) - cmd_type = token.get(2) + cmd_name = " ".join(token.tokens[1:-1]) + cmd_type = token.get(-1) - # ============================ - # 对reset权限进行特殊处理 - # ============================ if cmd_name == "reset" and cmd_type == "config": alter_cmd_cfg = await sp.global_get("alter_cmd", {}) plugin_ = alter_cmd_cfg.get("astrbot", {}) @@ -1413,16 +1413,18 @@ UID: {user_id} 此 ID 可用于设置管理员。 # 查找指令 found_command = None + cmd_group = False for handler in star_handlers_registry: assert isinstance(handler, StarHandlerMetadata) for filter_ in handler.event_filters: if isinstance(filter_, CommandFilter): - if filter_.command_name == cmd_name: + if filter_.equals(cmd_name): found_command = handler break elif isinstance(filter_, CommandGroupFilter): - if cmd_name == filter_.group_name: + if filter_.equals(cmd_name): found_command = handler + cmd_group = True break if not found_command: @@ -1459,8 +1461,10 @@ UID: {user_id} 此 ID 可用于设置管理员。 else filter.PermissionType.MEMBER ), ) - - yield event.plain_result(f"已将 {cmd_name} 设置为 {cmd_type} 指令") + cmd_group_str = "指令组" if cmd_group else "指令" + yield event.plain_result( + f"已将「{cmd_name}」{cmd_group_str} 的权限级别调整为 {cmd_type}。" + ) async def update_reset_permission(self, scene_key: str, perm_type: str): """更新reset命令在特定场景下的权限设置