fix: 修复 /alter_cmd 指令无法控制指令组、子指令组和子指令组下子指令的问题 (#2873)

* fix: revert changes in command_group.py at 782c036 to fix command group permission check

* fix: 不传递 GroupCommand handler

* perf: alter_cmd 指令支持对子指令、指令组进行配置

* chore: remove test commands and subcommands from test_group

* chore: add cache for complete command names list in CommandFilter and CommandGroupFilter

---------

Co-authored-by: Dt8333 <25431943+Dt8333@users.noreply.github.com>
Co-authored-by: Soulter <905617992@qq.com>
This commit is contained in:
Soulter
2025-09-26 14:16:50 +08:00
parent 26aa18d980
commit 09d1f96603
5 changed files with 77 additions and 41 deletions
+10 -5
View File
@@ -5,6 +5,7 @@ from astrbot.core.message.components import At, AtAll, Reply
from astrbot.core.message.message_event_result import MessageChain, MessageEventResult
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.filter.permission import PermissionTypeFilter
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.session_plugin_manager import SessionPluginManager
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import EventType, star_handlers_registry
@@ -170,11 +171,15 @@ class WakingCheckStage(Stage):
is_wake = True
event.is_wake = True
activated_handlers.append(handler)
if "parsed_params" in event.get_extra():
handlers_parsed_params[handler.handler_full_name] = event.get_extra(
"parsed_params"
)
is_group_cmd_handler = any(
isinstance(f, CommandGroupFilter) for f in handler.event_filters
)
if not is_group_cmd_handler:
activated_handlers.append(handler)
if "parsed_params" in event.get_extra(default={}):
handlers_parsed_params[handler.handler_full_name] = (
event.get_extra("parsed_params")
)
event._extras.pop("parsed_params", None)
+9 -5
View File
@@ -4,7 +4,7 @@ import re
import hashlib
import uuid
from typing import List, Union, Optional, AsyncGenerator
from typing import List, Union, Optional, AsyncGenerator, TypeVar, Any
from astrbot import logger
from astrbot.core.db.po import Conversation
@@ -26,6 +26,8 @@ from .astrbot_message import AstrBotMessage, Group
from .platform_metadata import PlatformMetadata
from .message_session import MessageSession, MessageSesion # noqa
_VT = TypeVar("_VT")
class AstrMessageEvent(abc.ABC):
def __init__(
@@ -49,7 +51,7 @@ class AstrMessageEvent(abc.ABC):
"""是否唤醒(是否通过 WakingStage)"""
self.is_at_or_wake_command = False
"""是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)"""
self._extras = {}
self._extras: dict[str, Any] = {}
self.session = MessageSesion(
platform_name=platform_meta.id,
message_type=message_obj.type,
@@ -57,7 +59,7 @@ class AstrMessageEvent(abc.ABC):
)
self.unified_msg_origin = str(self.session)
"""统一的消息来源字符串。格式为 platform_name:message_type:session_id"""
self._result: MessageEventResult = None
self._result: MessageEventResult | None = None
"""消息事件的结果"""
self._has_send_oper = False
@@ -173,13 +175,15 @@ class AstrMessageEvent(abc.ABC):
"""
self._extras[key] = value
def get_extra(self, key=None):
def get_extra(
self, key: str | None = None, default: _VT = None
) -> dict[str, Any] | _VT:
"""
获取额外的信息。
"""
if key is None:
return self._extras
return self._extras.get(key, None)
return self._extras.get(key, default)
def clear_extra(self):
"""
+26 -13
View File
@@ -32,6 +32,9 @@ class CommandFilter(HandlerFilter):
self.init_handler_md(handler_md)
self.custom_filter_list: List[CustomFilter] = []
# Cache for complete command names list
self._cmpl_cmd_names: list | None = None
def print_types(self):
result = ""
for k, v in self.handler_params.items():
@@ -136,6 +139,28 @@ class CommandFilter(HandlerFilter):
)
return result
def get_complete_command_names(self):
if self._cmpl_cmd_names is not None:
return self._cmpl_cmd_names
self._cmpl_cmd_names = [
f"{parent} {cmd}" if parent else cmd
for cmd in [self.command_name] + list(self.alias)
for parent in self.parent_command_names or [""]
]
return self._cmpl_cmd_names
def startswith(self, message_str: str) -> bool:
for full_cmd in self.get_complete_command_names():
if message_str.startswith(f"{full_cmd} ") or message_str == full_cmd:
return True
return False
def equals(self, message_str: str) -> bool:
for full_cmd in self.get_complete_command_names():
if message_str == full_cmd:
return True
return False
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
if not event.is_at_or_wake_command:
return False
@@ -145,19 +170,7 @@ class CommandFilter(HandlerFilter):
# 检查是否以指令开头
message_str = re.sub(r"\s+", " ", event.get_message_str().strip())
candidates = [self.command_name] + list(self.alias)
ok = False
for candidate in candidates:
for parent_command_name in self.parent_command_names:
if parent_command_name:
_full = f"{parent_command_name} {candidate}"
else:
_full = candidate
if message_str.startswith(f"{_full} ") or message_str == _full:
message_str = message_str[len(_full) :].strip()
ok = True
break
if not ok:
if not self.startswith(message_str):
return False
# 分割为列表
+15 -5
View File
@@ -22,6 +22,9 @@ class CommandGroupFilter(HandlerFilter):
self.custom_filter_list: List[CustomFilter] = []
self.parent_group = parent_group
# Cache for complete command names list
self._cmpl_cmd_names: list | None = None
def add_sub_command_filter(
self, sub_command_filter: Union[CommandFilter, CommandGroupFilter]
):
@@ -34,6 +37,9 @@ class CommandGroupFilter(HandlerFilter):
"""遍历父节点获取完整的指令名。
新版本 v3.4.29 采用预编译指令,不再从指令组递归遍历子指令,因此这个方法是返回包括别名在内的整个指令名列表。"""
if self._cmpl_cmd_names is not None:
return self._cmpl_cmd_names
parent_cmd_names = (
self.parent_group.get_complete_command_names() if self.parent_group else []
)
@@ -47,6 +53,7 @@ class CommandGroupFilter(HandlerFilter):
for parent_cmd_name in parent_cmd_names:
for candidate in candidates:
result.append(parent_cmd_name + " " + candidate)
self._cmpl_cmd_names = result
return result
# 以树的形式打印出来
@@ -97,6 +104,12 @@ class CommandGroupFilter(HandlerFilter):
return False
return True
def startswith(self, message_str: str) -> bool:
return message_str.startswith(tuple(self.get_complete_command_names()))
def equals(self, message_str: str) -> bool:
return message_str in self.get_complete_command_names()
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
if not event.is_at_or_wake_command:
return False
@@ -105,8 +118,7 @@ class CommandGroupFilter(HandlerFilter):
if not self.custom_filter_ok(event, cfg):
return False
complete_command_names = self.get_complete_command_names()
if event.message_str.strip() in complete_command_names:
if self.equals(event.message_str.strip()):
tree = (
self.group_name
+ "\n"
@@ -116,6 +128,4 @@ class CommandGroupFilter(HandlerFilter):
f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n" + tree
)
# complete_command_names = [name + " " for name in complete_command_names]
# return event.message_str.startswith(tuple(complete_command_names))
return False
return self.startswith(event.message_str)
+17 -13
View File
@@ -1348,22 +1348,22 @@ UID: {user_id} 此 ID 可用于设置管理员。
logger.error(f"ltm: {e}")
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("alter_cmd")
@filter.command("alter_cmd", alias={"alter"})
async def alter_cmd(self, event: AstrMessageEvent):
# token = event.message_str.split(" ")
token = self.parse_commands(event.message_str)
if token.len < 2:
if token.len < 3:
yield event.plain_result(
"可设置所有其他指令是否需要管理员权限。\n格式: /alter_cmd <cmd_name> <admin/member>\n 例如: /alter_cmd provider admin 将 provider 设置为管理员指令\n /alter_cmd reset config 打开reset权限配置"
"该指令用于设置指令或指令组的权限。\n"
"格式: /alter_cmd <cmd_name> <admin/member>\n"
"例1: /alter_cmd c1 admin 将 c1 设为管理员指令\n"
"例2: /alter_cmd g1 c1 admin 将 g1 指令组的 c1 子指令设为管理员指令\n"
"/alter_cmd reset config 打开 reset 权限配置"
)
return
cmd_name = token.get(1)
cmd_type = token.get(2)
cmd_name = " ".join(token.tokens[1:-1])
cmd_type = token.get(-1)
# ============================
# 对reset权限进行特殊处理
# ============================
if cmd_name == "reset" and cmd_type == "config":
alter_cmd_cfg = await sp.global_get("alter_cmd", {})
plugin_ = alter_cmd_cfg.get("astrbot", {})
@@ -1413,16 +1413,18 @@ UID: {user_id} 此 ID 可用于设置管理员。
# 查找指令
found_command = None
cmd_group = False
for handler in star_handlers_registry:
assert isinstance(handler, StarHandlerMetadata)
for filter_ in handler.event_filters:
if isinstance(filter_, CommandFilter):
if filter_.command_name == cmd_name:
if filter_.equals(cmd_name):
found_command = handler
break
elif isinstance(filter_, CommandGroupFilter):
if cmd_name == filter_.group_name:
if filter_.equals(cmd_name):
found_command = handler
cmd_group = True
break
if not found_command:
@@ -1459,8 +1461,10 @@ UID: {user_id} 此 ID 可用于设置管理员。
else filter.PermissionType.MEMBER
),
)
yield event.plain_result(f"已将 {cmd_name} 设置为 {cmd_type} 指令")
cmd_group_str = "指令组" if cmd_group else "指令"
yield event.plain_result(
f"已将「{cmd_name}{cmd_group_str} 的权限级别调整为 {cmd_type}"
)
async def update_reset_permission(self, scene_key: str, perm_type: str):
"""更新reset命令在特定场景下的权限设置