diff --git a/astrbot/core/star/register/star_handler.py b/astrbot/core/star/register/star_handler.py index 6b8c6b878..5d7489d35 100644 --- a/astrbot/core/star/register/star_handler.py +++ b/astrbot/core/star/register/star_handler.py @@ -17,7 +17,12 @@ def get_handler_full_name(awaitable: Awaitable) -> str: '''获取 Handler 的全名''' return f"{awaitable.__module__}_{awaitable.__name__}" -def get_handler_or_create(handler: Awaitable, event_type: EventType, dont_add = False, **kwargs) -> StarHandlerMetadata: +def get_handler_or_create( + handler: Awaitable, + event_type: EventType, + dont_add = False, + **kwargs +) -> StarHandlerMetadata: '''获取 Handler 或者创建一个新的 Handler''' handler_full_name = get_handler_full_name(handler) md = star_handlers_registry.get_handler_by_full_name(handler_full_name) @@ -30,18 +35,27 @@ def get_handler_or_create(handler: Awaitable, event_type: EventType, dont_add = handler_name=handler.__name__, handler_module_path=handler.__module__, handler=handler, - event_filters=[], + event_filters=[] ) + + # 插件handler的附加额外信息 + if 'desc' in kwargs: + md.desc = kwargs['desc'] + del kwargs['desc'] + md.extras_configs = kwargs + if handler.__doc__: md.desc = handler.__doc__.strip() if not dont_add: star_handlers_registry.append(md) return md -def register_command(command_name: str = None, *args): +def register_command(command_name: str = None, *args, **kwargs): '''注册一个 Command. ''' + # print("command: ", command_name, args, kwargs) + new_command = None add_to_event_filters = False if isinstance(command_name, RegisteringCommandable): @@ -54,7 +68,7 @@ def register_command(command_name: str = None, *args): add_to_event_filters = True def decorator(awaitable): - handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent) + handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, **kwargs) new_command.init_handler_md(handler_md) if add_to_event_filters: # 裸指令 @@ -64,10 +78,12 @@ def register_command(command_name: str = None, *args): return decorator -def register_command_group(command_group_name: str = None, *args): +def register_command_group(command_group_name: str = None, *args, **kwargs): '''注册一个 CommandGroup ''' + # print("commandgroup: ", command_group_name,args, kwargs) + new_group = None add_to_event_filters = False if isinstance(command_group_name, RegisteringCommandable): @@ -82,7 +98,7 @@ def register_command_group(command_group_name: str = None, *args): def decorator(obj): if add_to_event_filters: # 根指令组 - handler_md = get_handler_or_create(obj, EventType.AdapterMessageEvent) + handler_md = get_handler_or_create(obj, EventType.AdapterMessageEvent, **kwargs) handler_md.event_filters.append(new_group) return RegisteringCommandable(new_group) @@ -97,16 +113,16 @@ class RegisteringCommandable(): def __init__(self, parent_group: CommandGroupFilter): self.parent_group = parent_group -def register_event_message_type(event_message_type: EventMessageType): +def register_event_message_type(event_message_type: EventMessageType, **kwargs): '''注册一个 EventMessageType''' def decorator(awaitable): - handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent) + handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, kwargs) handler_md.event_filters.append(EventMessageTypeFilter(event_message_type)) return awaitable return decorator -def register_platform_adapter_type(platform_adapter_type: PlatformAdapterType): +def register_platform_adapter_type(platform_adapter_type: PlatformAdapterType, **kwargs): '''注册一个 PlatformAdapterType''' def decorator(awaitable): handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent) @@ -115,10 +131,10 @@ def register_platform_adapter_type(platform_adapter_type: PlatformAdapterType): return decorator -def register_regex(regex: str): +def register_regex(regex: str, **kwargs): '''注册一个 Regex''' def decorator(awaitable): - handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent) + handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, **kwargs) handler_md.event_filters.append(RegexFilter(regex)) return awaitable @@ -138,7 +154,7 @@ def register_permission_type(permission_type: PermissionType, raise_error: bool return decorator -def register_on_llm_request(): +def register_on_llm_request(**kwargs): '''当有 LLM 请求时的事件 Examples: @@ -153,12 +169,12 @@ def register_on_llm_request(): 请务必接收两个参数:event, request ''' def decorator(awaitable): - _ = get_handler_or_create(awaitable, EventType.OnLLMRequestEvent) + _ = get_handler_or_create(awaitable, EventType.OnLLMRequestEvent, **kwargs) return awaitable return decorator -def register_on_llm_response(): +def register_on_llm_response(**kwargs): '''当有 LLM 请求后的事件 Examples: @@ -173,7 +189,7 @@ def register_on_llm_response(): 请务必接收两个参数:event, request ''' def decorator(awaitable): - _ = get_handler_or_create(awaitable, EventType.OnLLMResponseEvent) + _ = get_handler_or_create(awaitable, EventType.OnLLMResponseEvent, **kwargs) return awaitable return decorator @@ -219,18 +235,18 @@ def register_llm_tool(name: str = None): return decorator -def register_on_decorating_result(): +def register_on_decorating_result(**kwargs): '''在发送消息前的事件''' def decorator(awaitable): - _ = get_handler_or_create(awaitable, EventType.OnDecoratingResultEvent) + _ = get_handler_or_create(awaitable, EventType.OnDecoratingResultEvent, **kwargs) return awaitable return decorator -def register_after_message_sent(): +def register_after_message_sent(**kwargs): '''在消息发送后的事件''' def decorator(awaitable): - _ = get_handler_or_create(awaitable, EventType.OnAfterMessageSentEvent) + _ = get_handler_or_create(awaitable, EventType.OnAfterMessageSentEvent, **kwargs) return awaitable return decorator \ No newline at end of file diff --git a/astrbot/core/star/star_handler.py b/astrbot/core/star/star_handler.py index 378038327..389935a37 100644 --- a/astrbot/core/star/star_handler.py +++ b/astrbot/core/star/star_handler.py @@ -1,34 +1,41 @@ from __future__ import annotations import enum -from dataclasses import dataclass +import heapq +from dataclasses import dataclass, field from typing import Awaitable, List, Dict, TypeVar, Generic from .filter import HandlerFilter from .star import star_map T = TypeVar('T', bound='StarHandlerMetadata') -class StarHandlerRegistry(Generic[T], List[T]): +class StarHandlerRegistry(Generic[T]): '''用于存储所有的 Star Handler''' star_handlers_map: Dict[str, StarHandlerMetadata] = {} '''用于快速查找。key 是 handler_full_name''' + _handlers = [] def append(self, handler: StarHandlerMetadata): '''添加一个 Handler''' - super().append(handler) + if 'priority' not in handler.extras_configs: + handler.extras_configs['priority'] = 0 + + heapq.heappush(self._handlers, (-handler.extras_configs['priority'], handler)) self.star_handlers_map[handler.handler_full_name] = handler - def get_handlers_by_event_type(self, event_type: EventType, only_activated = True) -> List[StarHandlerMetadata]: + def _print_handlers(self): + '''打印所有的 Handler''' + for _, handler in self._handlers: + print(handler.handler_full_name) + + def get_handlers_by_event_type(self, event_type: EventType, only_activated=True) -> List[StarHandlerMetadata]: '''通过事件类型获取 Handler''' - if only_activated: - return [ - handler - for handler in self - if handler.event_type == event_type and - star_map[handler.handler_module_path] and - star_map[handler.handler_module_path].activated - ] - else: - return [handler for handler in self if handler.event_type == event_type] + handlers = [ + handler + for _, handler in self._handlers + if handler.event_type == event_type and + (not only_activated or (star_map[handler.handler_module_path] and star_map[handler.handler_module_path].activated)) + ] + return handlers def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata: '''通过 Handler 的全名获取 Handler''' @@ -36,7 +43,25 @@ class StarHandlerRegistry(Generic[T], List[T]): def get_handlers_by_module_name(self, module_name: str) -> List[StarHandlerMetadata]: '''通过模块名获取 Handler''' - return [handler for handler in self if handler.handler_module_path == module_name] + return [handler for _, handler in self._handlers if handler.handler_module_path == module_name] + + def clear(self): + '''清空所有的 Handler''' + self.star_handlers_map.clear() + self._handlers.clear() + + def remove(self, handler: StarHandlerMetadata): + '''删除一个 Handler''' + self._handlers.remove(handler) + del self.star_handlers_map[handler.handler_full_name] + + def __iter__(self): + '''使 StarHandlerRegistry 支持迭代''' + return (handler for _, handler in self._handlers) + + def __len__(self): + '''返回 Handler 的数量''' + return len(self._handlers) star_handlers_registry = StarHandlerRegistry() @@ -76,3 +101,10 @@ class StarHandlerMetadata(): desc: str = "" '''Handler 的描述信息''' + + extras_configs: dict = field(default_factory=dict) + '''插件注册的一些其他的信息, 如 priority 等''' + + def __lt__(self, other: StarHandlerMetadata): + '''定义小于运算符以支持优先队列''' + return self.extras_configs.get('priority', 0) < other.extras_configs.get('priority', 0) \ No newline at end of file diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index a7a694fce..e06908605 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -146,7 +146,6 @@ class PluginManager: smd.star_cls.__del__() star_handlers_registry.clear() - star_handlers_registry.star_handlers_map.clear() star_map.clear() star_registry.clear() for key in list(sys.modules.keys()): diff --git a/tests/test_plugin_manager.py b/tests/test_plugin_manager.py index 1d336a57a..86a99c6c1 100644 --- a/tests/test_plugin_manager.py +++ b/tests/test_plugin_manager.py @@ -30,7 +30,6 @@ async def test_plugin_manager_reload(plugin_manager_pm: PluginManager): success, err_message = await plugin_manager_pm.reload() assert success is True assert err_message is None - assert len(star_handlers_registry) > 0 # package @pytest.mark.asyncio async def test_plugin_crud(plugin_manager_pm: PluginManager):