diff --git a/astrbot/core/star/filter/command_group.py b/astrbot/core/star/filter/command_group.py index cf7472edf..5678b1af4 100644 --- a/astrbot/core/star/filter/command_group.py +++ b/astrbot/core/star/filter/command_group.py @@ -79,7 +79,7 @@ class CommandGroupFilter(HandlerFilter): message_str = event.get_message_str().strip() ls = re.split(r"\s+", message_str) - + if ls[0] != self.group_name: return False, None # 改写 message_str @@ -89,16 +89,16 @@ class CommandGroupFilter(HandlerFilter): parsing_command = " ".join(ls) parsing_command = parsing_command.strip() event.set_extra("parsing_command", parsing_command) - - if parsing_command == "": - # 当前还是指令组 - tree = self.group_name + "\n" + self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg) - raise ValueError(f"指令组 {self.group_name} 未填写完全。这个指令组下有如下指令:\n"+tree) # 判断当前指令组的自定义过滤器 if not self.custom_filter_ok(event, cfg): return False, None + if parsing_command == "": + # 当前还是指令组 + tree = self.group_name + "\n" + self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg) + raise ValueError(f"指令组 {self.group_name} 未填写完全。这个指令组下有如下指令:\n"+tree) + child_command_handler_md = None for sub_filter in self.sub_command_filters: if isinstance(sub_filter, CommandFilter): diff --git a/astrbot/core/star/filter/custom_filter.py b/astrbot/core/star/filter/custom_filter.py index cfe80099e..4a5fe5357 100644 --- a/astrbot/core/star/filter/custom_filter.py +++ b/astrbot/core/star/filter/custom_filter.py @@ -1,10 +1,22 @@ -from abc import abstractmethod +from abc import abstractmethod, ABCMeta +from typing import Union from . import HandlerFilter from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.config import AstrBotConfig -class CustomFilter(HandlerFilter): +class CustomFilterMeta(ABCMeta): + def __and__(cls, other): + if not issubclass(other, CustomFilter): + raise TypeError("Operands must be subclasses of CustomFilter.") + return CustomFilterAnd(cls(), other()) + + def __or__(cls, other): + if not issubclass(other, CustomFilter): + raise TypeError("Operands must be subclasses of CustomFilter.") + return CustomFilterOr(cls(), other()) + +class CustomFilter(HandlerFilter, metaclass=CustomFilterMeta): def __init__(self, raise_error: bool = True): self.raise_error = raise_error @@ -12,3 +24,31 @@ class CustomFilter(HandlerFilter): def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: ''' 一个用于重写的自定义Filter ''' raise NotImplementedError + + def __or__(self, other): + return CustomFilterOr(self, other) + + def __and__(self, other): + return CustomFilterAnd(self, other) + +class CustomFilterOr(CustomFilter): + def __init__(self, filter1: CustomFilter, filter2: CustomFilter): + super().__init__() + if not isinstance(filter1, (CustomFilter, CustomFilterAnd, CustomFilterOr)): + raise ValueError("CustomFilter lass can only operate with other CustomFilter.") + self.filter1 = filter1 + self.filter2 = filter2 + + def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: + return self.filter1.filter(event, cfg) or self.filter2.filter(event, cfg) + +class CustomFilterAnd(CustomFilter): + def __init__(self, filter1: CustomFilter, filter2: CustomFilter): + super().__init__() + if not isinstance(filter1, (CustomFilter, CustomFilterAnd, CustomFilterOr)): + raise ValueError("CustomFilter lass can only operate with other CustomFilter.") + self.filter1 = filter1 + self.filter2 = filter2 + + def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: + return self.filter1.filter(event, cfg) and self.filter2.filter(event, cfg) diff --git a/astrbot/core/star/register/star_handler.py b/astrbot/core/star/register/star_handler.py index 14b710ace..c2909e47e 100644 --- a/astrbot/core/star/register/star_handler.py +++ b/astrbot/core/star/register/star_handler.py @@ -7,7 +7,7 @@ from ..filter.command_group import CommandGroupFilter from ..filter.event_message_type import EventMessageTypeFilter, EventMessageType from ..filter.platform_adapter_type import PlatformAdapterTypeFilter, PlatformAdapterType from ..filter.permission import PermissionTypeFilter, PermissionType -from ..filter.custom_filter import CustomFilter +from ..filter.custom_filter import CustomFilter, CustomFilterAnd, CustomFilterOr from ..filter.regex import RegexFilter from typing import Awaitable from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES @@ -81,7 +81,7 @@ def register_command(command_name: str = None, *args, **kwargs): return decorator -def register_custom_filter(custom_type_filter: CustomFilter, *args, **kwargs): +def register_custom_filter(custom_type_filter, *args, **kwargs): '''注册一个自定义的 CustomFilter Args: @@ -106,11 +106,13 @@ def register_custom_filter(custom_type_filter: CustomFilter, *args, **kwargs): if args: raise_error = args[0] + if not isinstance(custom_filter, (CustomFilterAnd, CustomFilterOr)): + custom_filter = custom_filter(raise_error) def decorator(awaitable): # 裸指令,子指令与指令组的区分,指令组会因为标记跳过wake。 if not add_to_event_filters and isinstance(awaitable, RegisteringCommandable): # 指令组,添加到本层的grouphandle中一起判断 - awaitable.parent_group.add_custom_filter(custom_filter(raise_error)) + awaitable.parent_group.add_custom_filter(custom_filter) else: handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, **kwargs) @@ -122,12 +124,12 @@ def register_custom_filter(custom_type_filter: CustomFilter, *args, **kwargs): # 不确定是否会有多个子指令有一样的fullname,比如一个方法添加多个command装饰器? sub_handle_md = sub_handle.get_handler_md() if sub_handle_md and sub_handle_md.handler_full_name == handle_full_name: - sub_handle.add_custom_filter(custom_filter(raise_error)) + sub_handle.add_custom_filter(custom_filter) else: # 裸指令 handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, **kwargs) - handler_md.event_filters.append(custom_filter(raise_error)) + handler_md.event_filters.append(custom_filter) return awaitable return decorator