diff --git a/astrbot/api/event/filter/__init__.py b/astrbot/api/event/filter/__init__.py index 69e884b89..7174be0fd 100644 --- a/astrbot/api/event/filter/__init__.py +++ b/astrbot/api/event/filter/__init__.py @@ -5,6 +5,7 @@ from astrbot.core.star.register import ( register_regex as regex, register_platform_adapter_type as platform_adapter_type, register_permission_type as permission_type, + register_custom_filter as custom_filter, register_on_llm_request as on_llm_request, register_on_llm_response as on_llm_response, register_llm_tool as llm_tool, @@ -15,6 +16,7 @@ from astrbot.core.star.register import ( from astrbot.core.star.filter.event_message_type import EventMessageTypeFilter, EventMessageType from astrbot.core.star.filter.platform_adapter_type import PlatformAdapterTypeFilter, PlatformAdapterType from astrbot.core.star.filter.permission import PermissionTypeFilter, PermissionType +from astrbot.core.star.filter.custom_filter import CustomFilter __all__ = [ 'command', @@ -28,6 +30,8 @@ __all__ = [ 'PlatformAdapterTypeFilter', 'PlatformAdapterType', 'PermissionTypeFilter', + 'CustomFilter', + 'custom_filter', 'PermissionType', 'on_llm_request', 'llm_tool', diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index c12025ef2..32afcb86a 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -32,6 +32,7 @@ class ResultDecorateStage(Stage): self.only_llm_result = ctx.astrbot_config['platform_settings']['segmented_reply']['only_llm_result'] self.regex = ctx.astrbot_config['platform_settings']['segmented_reply']['regex'] self.content_cleanup_rule = ctx.astrbot_config['platform_settings']['segmented_reply']['content_cleanup_rule'] + # exception self.content_safe_check_reply = ctx.astrbot_config['content_safety']['also_use_in_response'] diff --git a/astrbot/core/star/filter/command.py b/astrbot/core/star/filter/command.py index 3e1394f38..1715846c5 100644 --- a/astrbot/core/star/filter/command.py +++ b/astrbot/core/star/filter/command.py @@ -1,19 +1,23 @@ import re import inspect +from typing import List from . import HandlerFilter from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.config import AstrBotConfig from astrbot.core.utils.param_validation_mixin import ParameterValidationMixin +from .custom_filter import CustomFilter from ..star_handler import StarHandlerMetadata # 标准指令受到 wake_prefix 的制约。 class CommandFilter(HandlerFilter, ParameterValidationMixin): '''标准指令过滤器''' - def __init__(self, command_name: str, handler_md: StarHandlerMetadata = None): + def __init__(self, command_name: str, alias: set = None, handler_md: StarHandlerMetadata = None): self.command_name = command_name + self.alias = alias if alias else set() if handler_md: self.init_handler_md(handler_md) + self.custom_filter_list: List[CustomFilter] = [] def print_types(self): result = "" @@ -42,10 +46,22 @@ class CommandFilter(HandlerFilter, ParameterValidationMixin): def get_handler_md(self) -> StarHandlerMetadata: return self.handler_md + def add_custom_filter(self, custom_filter: CustomFilter): + self.custom_filter_list.append(custom_filter) + + def custom_filter_ok(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: + for custom_filter in self.custom_filter_list: + if not custom_filter.filter(event, cfg): + return False + return True + def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: if not event.is_at_or_wake_command: return False - + + if not self.custom_filter_ok(event, cfg): + return False + if event.get_extra("parsing_command"): message_str = event.get_extra("parsing_command").strip() else: @@ -53,7 +69,7 @@ class CommandFilter(HandlerFilter, ParameterValidationMixin): # 分割为列表(每个参数之间可能会有多个空格) ls = re.split(r"\s+", message_str) - if self.command_name != ls[0]: + if self.command_name != ls[0] and ls[0] not in self.alias: return False # if len(self.handler_params) == 0 and len(ls) > 1: # # 一定程度避免 LLM 聊天时误判为指令 diff --git a/astrbot/core/star/filter/command_group.py b/astrbot/core/star/filter/command_group.py index 0dca9916f..228615d81 100644 --- a/astrbot/core/star/filter/command_group.py +++ b/astrbot/core/star/filter/command_group.py @@ -6,36 +6,61 @@ from . import HandlerFilter from .command import CommandFilter from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.config import AstrBotConfig +from .custom_filter import CustomFilter from ..star_handler import StarHandlerMetadata # 指令组受到 wake_prefix 的制约。 class CommandGroupFilter(HandlerFilter): - def __init__(self, group_name: str): + def __init__(self, group_name: str, alias: set = None): self.group_name = group_name + self.alias = alias if alias else set() self.sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]] = [] + self.custom_filter_list: List[CustomFilter] = [] def add_sub_command_filter(self, sub_command_filter: Union[CommandFilter, CommandGroupFilter]): self.sub_command_filters.append(sub_command_filter) - + + def add_custom_filter(self, custom_filter: CustomFilter): + self.custom_filter_list.append(custom_filter) + # 以树的形式打印出来 - def print_cmd_tree(self, sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]], prefix: str = "") -> str: + def print_cmd_tree(self, + sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]], + prefix: str = "", + event: AstrMessageEvent = None, + cfg: AstrBotConfig = None, + ) -> str: result = "" for sub_filter in sub_command_filters: if isinstance(sub_filter, CommandFilter): - cmd_th = sub_filter.print_types() - result += f"{prefix}├── {sub_filter.command_name}" - if cmd_th: - result += f" ({cmd_th})" - else: - result += " (无参数指令)" - - result += "\n" + custom_filter_pass = True + if event and cfg: + custom_filter_pass = sub_filter.custom_filter_ok(event, cfg) + if custom_filter_pass: + cmd_th = sub_filter.print_types() + result += f"{prefix}├── {sub_filter.command_name}" + if cmd_th: + result += f" ({cmd_th})" + else: + result += " (无参数指令)" + result += "\n" elif isinstance(sub_filter, CommandGroupFilter): - result += f"{prefix}├── {sub_filter.group_name}" - result += "\n" - result += sub_filter.print_cmd_tree(sub_filter.sub_command_filters, prefix+"│ ") + custom_filter_pass = True + if event and cfg: + custom_filter_pass = sub_filter.custom_filter_ok(event, cfg) + if custom_filter_pass: + result += f"{prefix}├── {sub_filter.group_name}" + result += "\n" + result += sub_filter.print_cmd_tree(sub_filter.sub_command_filters, prefix+"│ ", event=event, cfg=cfg) + return result - + + def custom_filter_ok(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: + for custom_filter in self.custom_filter_list: + if not custom_filter.filter(event, cfg): + return False + return True + def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> Tuple[bool, StarHandlerMetadata]: if not event.is_at_or_wake_command: return False, None @@ -46,8 +71,8 @@ class CommandGroupFilter(HandlerFilter): message_str = event.get_message_str().strip() ls = re.split(r"\s+", message_str) - - if ls[0] != self.group_name: + + if ls[0] != self.group_name and ls[0] not in self.alias: return False, None # 改写 message_str ls = ls[1:] @@ -56,12 +81,16 @@ class CommandGroupFilter(HandlerFilter): parsing_command = " ".join(ls) parsing_command = parsing_command.strip() event.set_extra("parsing_command", parsing_command) - + + # 判断当前指令组的自定义过滤器 + if not self.custom_filter_ok(event, cfg): + return False, None + if parsing_command == "": # 当前还是指令组 - tree = self.group_name + "\n" + self.print_cmd_tree(self.sub_command_filters) + tree = self.group_name + "\n" + self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg) raise ValueError(f"指令组 {self.group_name} 未填写完全。这个指令组下有如下指令:\n"+tree) - + child_command_handler_md = None for sub_filter in self.sub_command_filters: if isinstance(sub_filter, CommandFilter): @@ -73,5 +102,5 @@ class CommandGroupFilter(HandlerFilter): if ok: child_command_handler_md = handler return True, child_command_handler_md - tree = self.group_name + "\n" + self.print_cmd_tree(self.sub_command_filters) + tree = self.group_name + "\n" + self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg) raise ValueError(f"指令组 {self.group_name} 下没有找到对应的指令。这个指令组下有如下指令:\n"+tree) diff --git a/astrbot/core/star/filter/custom_filter.py b/astrbot/core/star/filter/custom_filter.py new file mode 100644 index 000000000..5be1b8dbe --- /dev/null +++ b/astrbot/core/star/filter/custom_filter.py @@ -0,0 +1,53 @@ +from abc import abstractmethod, ABCMeta + +from . import HandlerFilter +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.config import AstrBotConfig + +class CustomFilterMeta(ABCMeta): + def __and__(cls, other): + if not issubclass(other, CustomFilter): + raise TypeError("Operands must be subclasses of CustomFilter.") + return CustomFilterAnd(cls(), other()) + + def __or__(cls, other): + if not issubclass(other, CustomFilter): + raise TypeError("Operands must be subclasses of CustomFilter.") + return CustomFilterOr(cls(), other()) + +class CustomFilter(HandlerFilter, metaclass=CustomFilterMeta): + def __init__(self, raise_error: bool = True, **kwargs): + self.raise_error = raise_error + + @abstractmethod + def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: + ''' 一个用于重写的自定义Filter ''' + raise NotImplementedError + + def __or__(self, other): + return CustomFilterOr(self, other) + + def __and__(self, other): + return CustomFilterAnd(self, other) + +class CustomFilterOr(CustomFilter): + def __init__(self, filter1: CustomFilter, filter2: CustomFilter): + super().__init__() + if not isinstance(filter1, (CustomFilter, CustomFilterAnd, CustomFilterOr)): + raise ValueError("CustomFilter lass can only operate with other CustomFilter.") + self.filter1 = filter1 + self.filter2 = filter2 + + def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: + return self.filter1.filter(event, cfg) or self.filter2.filter(event, cfg) + +class CustomFilterAnd(CustomFilter): + def __init__(self, filter1: CustomFilter, filter2: CustomFilter): + super().__init__() + if not isinstance(filter1, (CustomFilter, CustomFilterAnd, CustomFilterOr)): + raise ValueError("CustomFilter lass can only operate with other CustomFilter.") + self.filter1 = filter1 + self.filter2 = filter2 + + def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: + return self.filter1.filter(event, cfg) and self.filter2.filter(event, cfg) diff --git a/astrbot/core/star/register/__init__.py b/astrbot/core/star/register/__init__.py index 422c87b4c..ba51f7ab6 100644 --- a/astrbot/core/star/register/__init__.py +++ b/astrbot/core/star/register/__init__.py @@ -6,6 +6,7 @@ from .star_handler import ( register_platform_adapter_type, register_regex, register_permission_type, + register_custom_filter, register_on_llm_request, register_on_llm_response, register_llm_tool, @@ -21,6 +22,7 @@ __all__ = [ 'register_platform_adapter_type', 'register_regex', 'register_permission_type', + 'register_custom_filter', 'register_on_llm_request', 'register_on_llm_response', 'register_llm_tool', diff --git a/astrbot/core/star/register/star_handler.py b/astrbot/core/star/register/star_handler.py index c5e1239dd..e6bc45900 100644 --- a/astrbot/core/star/register/star_handler.py +++ b/astrbot/core/star/register/star_handler.py @@ -7,6 +7,7 @@ from ..filter.command_group import CommandGroupFilter from ..filter.event_message_type import EventMessageTypeFilter, EventMessageType from ..filter.platform_adapter_type import PlatformAdapterTypeFilter, PlatformAdapterType from ..filter.permission import PermissionTypeFilter, PermissionType +from ..filter.custom_filter import CustomFilterAnd, CustomFilterOr from ..filter.regex import RegexFilter from typing import Awaitable from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES @@ -50,7 +51,7 @@ def get_handler_or_create( star_handlers_registry.append(md) return md -def register_command(command_name: str = None, *args, **kwargs): +def register_command(command_name: str = None, sub_command: str = None, alias: set = None, **kwargs): '''注册一个 Command. ''' @@ -60,11 +61,11 @@ def register_command(command_name: str = None, *args, **kwargs): add_to_event_filters = False if isinstance(command_name, RegisteringCommandable): # 子指令 - new_command = CommandFilter(args[0], None) + new_command = CommandFilter(sub_command, alias, None) command_name.parent_group.add_sub_command_filter(new_command) else: # 裸指令 - new_command = CommandFilter(command_name, None) + new_command = CommandFilter(command_name, alias, None) add_to_event_filters = True def decorator(awaitable): @@ -80,7 +81,65 @@ def register_command(command_name: str = None, *args, **kwargs): return decorator -def register_command_group(command_group_name: str = None, *args, **kwargs): + +def register_custom_filter(custom_type_filter, *args, **kwargs): + '''注册一个自定义的 CustomFilter + + Args: + custom_type_filter: 在裸指令时为CustomFilter对象 + 在指令组时为父指令的RegisteringCommandable对象,即self或者command_group的返回 + raise_error: 如果没有权限,是否抛出错误到消息平台,并且停止事件传播。默认为 True + ''' + add_to_event_filters = False + raise_error = True + + # 判断是否是指令组,指令组则添加到指令组的CommandGroupFilter对象中在waking_check的时候一起判断 + if isinstance(custom_type_filter, RegisteringCommandable): + # 子指令, 此时函数为RegisteringCommandable对象的方法,首位参数为RegisteringCommandable对象的self。 + parent_register_commandable = custom_type_filter + custom_filter = args[0] + if len(args) > 1: + raise_error = args[1] + else: + # 裸指令 + add_to_event_filters = True + custom_filter = custom_type_filter + if args: + raise_error = args[0] + + if not isinstance(custom_filter, (CustomFilterAnd, CustomFilterOr)): + custom_filter = custom_filter(raise_error) + + def decorator(awaitable): + # 裸指令,子指令与指令组的区分,指令组会因为标记跳过wake。 + if not add_to_event_filters and isinstance(awaitable, RegisteringCommandable) or \ + (add_to_event_filters and isinstance(awaitable, RegisteringCommandable)): + # 指令组 与 根指令组,添加到本层的grouphandle中一起判断 + awaitable.parent_group.add_custom_filter(custom_filter) + else: + handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, **kwargs) + + if not add_to_event_filters and not isinstance(awaitable, RegisteringCommandable): + # 底层子指令 + handle_full_name = get_handler_full_name(awaitable) + for sub_handle in parent_register_commandable.parent_group.sub_command_filters: + # 所有符合fullname一致的子指令handle添加自定义过滤器。 + # 不确定是否会有多个子指令有一样的fullname,比如一个方法添加多个command装饰器? + sub_handle_md = sub_handle.get_handler_md() + if sub_handle_md and sub_handle_md.handler_full_name == handle_full_name: + sub_handle.add_custom_filter(custom_filter) + + else: + # 裸指令 + handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, **kwargs) + handler_md.event_filters.append(custom_filter) + + return awaitable + return decorator + +def register_command_group( + command_group_name: str = None, sub_command: str = None, alias: set = None, **kwargs +): '''注册一个 CommandGroup ''' @@ -90,11 +149,11 @@ def register_command_group(command_group_name: str = None, *args, **kwargs): add_to_event_filters = False if isinstance(command_group_name, RegisteringCommandable): # 子指令组 - new_group = CommandGroupFilter(args[0]) + new_group = CommandGroupFilter(sub_command, alias) command_group_name.parent_group.add_sub_command_filter(new_group) else: # 根指令组 - new_group = CommandGroupFilter(command_group_name) + new_group = CommandGroupFilter(command_group_name, alias) add_to_event_filters = True def decorator(obj): @@ -102,7 +161,7 @@ def register_command_group(command_group_name: str = None, *args, **kwargs): # 根指令组 handler_md = get_handler_or_create(obj, EventType.AdapterMessageEvent, **kwargs) handler_md.event_filters.append(new_group) - + return RegisteringCommandable(new_group) return decorator @@ -111,7 +170,8 @@ class RegisteringCommandable(): '''用于指令组级联注册''' group = register_command_group command = register_command - + custom_filter = register_custom_filter + def __init__(self, parent_group: CommandGroupFilter): self.parent_group = parent_group