fix: 修复 /alter_cmd 指令无法控制指令组、子指令组和子指令组下子指令的问题 (#2873)
* fix: revert changes in command_group.py at 782c036 to fix command group permission check
* fix: 不传递 GroupCommand handler
* perf: alter_cmd 指令支持对子指令、指令组进行配置
* chore: remove test commands and subcommands from test_group
* chore: add cache for complete command names list in CommandFilter and CommandGroupFilter
---------
Co-authored-by: Dt8333 <25431943+Dt8333@users.noreply.github.com>
Co-authored-by: Soulter <905617992@qq.com>
This commit is contained in:
@@ -5,6 +5,7 @@ from astrbot.core.message.components import At, AtAll, Reply
|
||||
from astrbot.core.message.message_event_result import MessageChain, MessageEventResult
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
||||
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||
from astrbot.core.star.session_plugin_manager import SessionPluginManager
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
||||
@@ -170,11 +171,15 @@ class WakingCheckStage(Stage):
|
||||
is_wake = True
|
||||
event.is_wake = True
|
||||
|
||||
activated_handlers.append(handler)
|
||||
if "parsed_params" in event.get_extra():
|
||||
handlers_parsed_params[handler.handler_full_name] = event.get_extra(
|
||||
"parsed_params"
|
||||
)
|
||||
is_group_cmd_handler = any(
|
||||
isinstance(f, CommandGroupFilter) for f in handler.event_filters
|
||||
)
|
||||
if not is_group_cmd_handler:
|
||||
activated_handlers.append(handler)
|
||||
if "parsed_params" in event.get_extra(default={}):
|
||||
handlers_parsed_params[handler.handler_full_name] = (
|
||||
event.get_extra("parsed_params")
|
||||
)
|
||||
|
||||
event._extras.pop("parsed_params", None)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import re
|
||||
import hashlib
|
||||
import uuid
|
||||
|
||||
from typing import List, Union, Optional, AsyncGenerator
|
||||
from typing import List, Union, Optional, AsyncGenerator, TypeVar, Any
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.db.po import Conversation
|
||||
@@ -26,6 +26,8 @@ from .astrbot_message import AstrBotMessage, Group
|
||||
from .platform_metadata import PlatformMetadata
|
||||
from .message_session import MessageSession, MessageSesion # noqa
|
||||
|
||||
_VT = TypeVar("_VT")
|
||||
|
||||
|
||||
class AstrMessageEvent(abc.ABC):
|
||||
def __init__(
|
||||
@@ -49,7 +51,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
"""是否唤醒(是否通过 WakingStage)"""
|
||||
self.is_at_or_wake_command = False
|
||||
"""是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)"""
|
||||
self._extras = {}
|
||||
self._extras: dict[str, Any] = {}
|
||||
self.session = MessageSesion(
|
||||
platform_name=platform_meta.id,
|
||||
message_type=message_obj.type,
|
||||
@@ -57,7 +59,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
)
|
||||
self.unified_msg_origin = str(self.session)
|
||||
"""统一的消息来源字符串。格式为 platform_name:message_type:session_id"""
|
||||
self._result: MessageEventResult = None
|
||||
self._result: MessageEventResult | None = None
|
||||
"""消息事件的结果"""
|
||||
|
||||
self._has_send_oper = False
|
||||
@@ -173,13 +175,15 @@ class AstrMessageEvent(abc.ABC):
|
||||
"""
|
||||
self._extras[key] = value
|
||||
|
||||
def get_extra(self, key=None):
|
||||
def get_extra(
|
||||
self, key: str | None = None, default: _VT = None
|
||||
) -> dict[str, Any] | _VT:
|
||||
"""
|
||||
获取额外的信息。
|
||||
"""
|
||||
if key is None:
|
||||
return self._extras
|
||||
return self._extras.get(key, None)
|
||||
return self._extras.get(key, default)
|
||||
|
||||
def clear_extra(self):
|
||||
"""
|
||||
|
||||
@@ -32,6 +32,9 @@ class CommandFilter(HandlerFilter):
|
||||
self.init_handler_md(handler_md)
|
||||
self.custom_filter_list: List[CustomFilter] = []
|
||||
|
||||
# Cache for complete command names list
|
||||
self._cmpl_cmd_names: list | None = None
|
||||
|
||||
def print_types(self):
|
||||
result = ""
|
||||
for k, v in self.handler_params.items():
|
||||
@@ -136,6 +139,28 @@ class CommandFilter(HandlerFilter):
|
||||
)
|
||||
return result
|
||||
|
||||
def get_complete_command_names(self):
|
||||
if self._cmpl_cmd_names is not None:
|
||||
return self._cmpl_cmd_names
|
||||
self._cmpl_cmd_names = [
|
||||
f"{parent} {cmd}" if parent else cmd
|
||||
for cmd in [self.command_name] + list(self.alias)
|
||||
for parent in self.parent_command_names or [""]
|
||||
]
|
||||
return self._cmpl_cmd_names
|
||||
|
||||
def startswith(self, message_str: str) -> bool:
|
||||
for full_cmd in self.get_complete_command_names():
|
||||
if message_str.startswith(f"{full_cmd} ") or message_str == full_cmd:
|
||||
return True
|
||||
return False
|
||||
|
||||
def equals(self, message_str: str) -> bool:
|
||||
for full_cmd in self.get_complete_command_names():
|
||||
if message_str == full_cmd:
|
||||
return True
|
||||
return False
|
||||
|
||||
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
||||
if not event.is_at_or_wake_command:
|
||||
return False
|
||||
@@ -145,19 +170,7 @@ class CommandFilter(HandlerFilter):
|
||||
|
||||
# 检查是否以指令开头
|
||||
message_str = re.sub(r"\s+", " ", event.get_message_str().strip())
|
||||
candidates = [self.command_name] + list(self.alias)
|
||||
ok = False
|
||||
for candidate in candidates:
|
||||
for parent_command_name in self.parent_command_names:
|
||||
if parent_command_name:
|
||||
_full = f"{parent_command_name} {candidate}"
|
||||
else:
|
||||
_full = candidate
|
||||
if message_str.startswith(f"{_full} ") or message_str == _full:
|
||||
message_str = message_str[len(_full) :].strip()
|
||||
ok = True
|
||||
break
|
||||
if not ok:
|
||||
if not self.startswith(message_str):
|
||||
return False
|
||||
|
||||
# 分割为列表
|
||||
|
||||
@@ -22,6 +22,9 @@ class CommandGroupFilter(HandlerFilter):
|
||||
self.custom_filter_list: List[CustomFilter] = []
|
||||
self.parent_group = parent_group
|
||||
|
||||
# Cache for complete command names list
|
||||
self._cmpl_cmd_names: list | None = None
|
||||
|
||||
def add_sub_command_filter(
|
||||
self, sub_command_filter: Union[CommandFilter, CommandGroupFilter]
|
||||
):
|
||||
@@ -34,6 +37,9 @@ class CommandGroupFilter(HandlerFilter):
|
||||
"""遍历父节点获取完整的指令名。
|
||||
|
||||
新版本 v3.4.29 采用预编译指令,不再从指令组递归遍历子指令,因此这个方法是返回包括别名在内的整个指令名列表。"""
|
||||
if self._cmpl_cmd_names is not None:
|
||||
return self._cmpl_cmd_names
|
||||
|
||||
parent_cmd_names = (
|
||||
self.parent_group.get_complete_command_names() if self.parent_group else []
|
||||
)
|
||||
@@ -47,6 +53,7 @@ class CommandGroupFilter(HandlerFilter):
|
||||
for parent_cmd_name in parent_cmd_names:
|
||||
for candidate in candidates:
|
||||
result.append(parent_cmd_name + " " + candidate)
|
||||
self._cmpl_cmd_names = result
|
||||
return result
|
||||
|
||||
# 以树的形式打印出来
|
||||
@@ -97,6 +104,12 @@ class CommandGroupFilter(HandlerFilter):
|
||||
return False
|
||||
return True
|
||||
|
||||
def startswith(self, message_str: str) -> bool:
|
||||
return message_str.startswith(tuple(self.get_complete_command_names()))
|
||||
|
||||
def equals(self, message_str: str) -> bool:
|
||||
return message_str in self.get_complete_command_names()
|
||||
|
||||
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
||||
if not event.is_at_or_wake_command:
|
||||
return False
|
||||
@@ -105,8 +118,7 @@ class CommandGroupFilter(HandlerFilter):
|
||||
if not self.custom_filter_ok(event, cfg):
|
||||
return False
|
||||
|
||||
complete_command_names = self.get_complete_command_names()
|
||||
if event.message_str.strip() in complete_command_names:
|
||||
if self.equals(event.message_str.strip()):
|
||||
tree = (
|
||||
self.group_name
|
||||
+ "\n"
|
||||
@@ -116,6 +128,4 @@ class CommandGroupFilter(HandlerFilter):
|
||||
f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n" + tree
|
||||
)
|
||||
|
||||
# complete_command_names = [name + " " for name in complete_command_names]
|
||||
# return event.message_str.startswith(tuple(complete_command_names))
|
||||
return False
|
||||
return self.startswith(event.message_str)
|
||||
|
||||
+17
-13
@@ -1348,22 +1348,22 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
logger.error(f"ltm: {e}")
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("alter_cmd")
|
||||
@filter.command("alter_cmd", alias={"alter"})
|
||||
async def alter_cmd(self, event: AstrMessageEvent):
|
||||
# token = event.message_str.split(" ")
|
||||
token = self.parse_commands(event.message_str)
|
||||
if token.len < 2:
|
||||
if token.len < 3:
|
||||
yield event.plain_result(
|
||||
"可设置所有其他指令是否需要管理员权限。\n格式: /alter_cmd <cmd_name> <admin/member>\n 例如: /alter_cmd provider admin 将 provider 设置为管理员指令\n /alter_cmd reset config 打开reset权限配置"
|
||||
"该指令用于设置指令或指令组的权限。\n"
|
||||
"格式: /alter_cmd <cmd_name> <admin/member>\n"
|
||||
"例1: /alter_cmd c1 admin 将 c1 设为管理员指令\n"
|
||||
"例2: /alter_cmd g1 c1 admin 将 g1 指令组的 c1 子指令设为管理员指令\n"
|
||||
"/alter_cmd reset config 打开 reset 权限配置"
|
||||
)
|
||||
return
|
||||
|
||||
cmd_name = token.get(1)
|
||||
cmd_type = token.get(2)
|
||||
cmd_name = " ".join(token.tokens[1:-1])
|
||||
cmd_type = token.get(-1)
|
||||
|
||||
# ============================
|
||||
# 对reset权限进行特殊处理
|
||||
# ============================
|
||||
if cmd_name == "reset" and cmd_type == "config":
|
||||
alter_cmd_cfg = await sp.global_get("alter_cmd", {})
|
||||
plugin_ = alter_cmd_cfg.get("astrbot", {})
|
||||
@@ -1413,16 +1413,18 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
|
||||
# 查找指令
|
||||
found_command = None
|
||||
cmd_group = False
|
||||
for handler in star_handlers_registry:
|
||||
assert isinstance(handler, StarHandlerMetadata)
|
||||
for filter_ in handler.event_filters:
|
||||
if isinstance(filter_, CommandFilter):
|
||||
if filter_.command_name == cmd_name:
|
||||
if filter_.equals(cmd_name):
|
||||
found_command = handler
|
||||
break
|
||||
elif isinstance(filter_, CommandGroupFilter):
|
||||
if cmd_name == filter_.group_name:
|
||||
if filter_.equals(cmd_name):
|
||||
found_command = handler
|
||||
cmd_group = True
|
||||
break
|
||||
|
||||
if not found_command:
|
||||
@@ -1459,8 +1461,10 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
else filter.PermissionType.MEMBER
|
||||
),
|
||||
)
|
||||
|
||||
yield event.plain_result(f"已将 {cmd_name} 设置为 {cmd_type} 指令")
|
||||
cmd_group_str = "指令组" if cmd_group else "指令"
|
||||
yield event.plain_result(
|
||||
f"已将「{cmd_name}」{cmd_group_str} 的权限级别调整为 {cmd_type}。"
|
||||
)
|
||||
|
||||
async def update_reset_permission(self, scene_key: str, perm_type: str):
|
||||
"""更新reset命令在特定场景下的权限设置
|
||||
|
||||
Reference in New Issue
Block a user