feat: 支持插件handler优先级

This commit is contained in:
Soulter
2025-02-06 12:35:43 +08:00
parent 7d4c07e4f6
commit 461f1bb07c
4 changed files with 82 additions and 36 deletions
+35 -19
View File
@@ -17,7 +17,12 @@ def get_handler_full_name(awaitable: Awaitable) -> str:
'''获取 Handler 的全名'''
return f"{awaitable.__module__}_{awaitable.__name__}"
def get_handler_or_create(handler: Awaitable, event_type: EventType, dont_add = False, **kwargs) -> StarHandlerMetadata:
def get_handler_or_create(
handler: Awaitable,
event_type: EventType,
dont_add = False,
**kwargs
) -> StarHandlerMetadata:
'''获取 Handler 或者创建一个新的 Handler'''
handler_full_name = get_handler_full_name(handler)
md = star_handlers_registry.get_handler_by_full_name(handler_full_name)
@@ -30,18 +35,27 @@ def get_handler_or_create(handler: Awaitable, event_type: EventType, dont_add =
handler_name=handler.__name__,
handler_module_path=handler.__module__,
handler=handler,
event_filters=[],
event_filters=[]
)
# 插件handler的附加额外信息
if 'desc' in kwargs:
md.desc = kwargs['desc']
del kwargs['desc']
md.extras_configs = kwargs
if handler.__doc__:
md.desc = handler.__doc__.strip()
if not dont_add:
star_handlers_registry.append(md)
return md
def register_command(command_name: str = None, *args):
def register_command(command_name: str = None, *args, **kwargs):
'''注册一个 Command.
'''
# print("command: ", command_name, args, kwargs)
new_command = None
add_to_event_filters = False
if isinstance(command_name, RegisteringCommandable):
@@ -54,7 +68,7 @@ def register_command(command_name: str = None, *args):
add_to_event_filters = True
def decorator(awaitable):
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent)
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, **kwargs)
new_command.init_handler_md(handler_md)
if add_to_event_filters:
# 裸指令
@@ -64,10 +78,12 @@ def register_command(command_name: str = None, *args):
return decorator
def register_command_group(command_group_name: str = None, *args):
def register_command_group(command_group_name: str = None, *args, **kwargs):
'''注册一个 CommandGroup
'''
# print("commandgroup: ", command_group_name,args, kwargs)
new_group = None
add_to_event_filters = False
if isinstance(command_group_name, RegisteringCommandable):
@@ -82,7 +98,7 @@ def register_command_group(command_group_name: str = None, *args):
def decorator(obj):
if add_to_event_filters:
# 根指令组
handler_md = get_handler_or_create(obj, EventType.AdapterMessageEvent)
handler_md = get_handler_or_create(obj, EventType.AdapterMessageEvent, **kwargs)
handler_md.event_filters.append(new_group)
return RegisteringCommandable(new_group)
@@ -97,16 +113,16 @@ class RegisteringCommandable():
def __init__(self, parent_group: CommandGroupFilter):
self.parent_group = parent_group
def register_event_message_type(event_message_type: EventMessageType):
def register_event_message_type(event_message_type: EventMessageType, **kwargs):
'''注册一个 EventMessageType'''
def decorator(awaitable):
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent)
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, kwargs)
handler_md.event_filters.append(EventMessageTypeFilter(event_message_type))
return awaitable
return decorator
def register_platform_adapter_type(platform_adapter_type: PlatformAdapterType):
def register_platform_adapter_type(platform_adapter_type: PlatformAdapterType, **kwargs):
'''注册一个 PlatformAdapterType'''
def decorator(awaitable):
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent)
@@ -115,10 +131,10 @@ def register_platform_adapter_type(platform_adapter_type: PlatformAdapterType):
return decorator
def register_regex(regex: str):
def register_regex(regex: str, **kwargs):
'''注册一个 Regex'''
def decorator(awaitable):
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent)
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, **kwargs)
handler_md.event_filters.append(RegexFilter(regex))
return awaitable
@@ -138,7 +154,7 @@ def register_permission_type(permission_type: PermissionType, raise_error: bool
return decorator
def register_on_llm_request():
def register_on_llm_request(**kwargs):
'''当有 LLM 请求时的事件
Examples:
@@ -153,12 +169,12 @@ def register_on_llm_request():
请务必接收两个参数:event, request
'''
def decorator(awaitable):
_ = get_handler_or_create(awaitable, EventType.OnLLMRequestEvent)
_ = get_handler_or_create(awaitable, EventType.OnLLMRequestEvent, **kwargs)
return awaitable
return decorator
def register_on_llm_response():
def register_on_llm_response(**kwargs):
'''当有 LLM 请求后的事件
Examples:
@@ -173,7 +189,7 @@ def register_on_llm_response():
请务必接收两个参数:event, request
'''
def decorator(awaitable):
_ = get_handler_or_create(awaitable, EventType.OnLLMResponseEvent)
_ = get_handler_or_create(awaitable, EventType.OnLLMResponseEvent, **kwargs)
return awaitable
return decorator
@@ -219,18 +235,18 @@ def register_llm_tool(name: str = None):
return decorator
def register_on_decorating_result():
def register_on_decorating_result(**kwargs):
'''在发送消息前的事件'''
def decorator(awaitable):
_ = get_handler_or_create(awaitable, EventType.OnDecoratingResultEvent)
_ = get_handler_or_create(awaitable, EventType.OnDecoratingResultEvent, **kwargs)
return awaitable
return decorator
def register_after_message_sent():
def register_after_message_sent(**kwargs):
'''在消息发送后的事件'''
def decorator(awaitable):
_ = get_handler_or_create(awaitable, EventType.OnAfterMessageSentEvent)
_ = get_handler_or_create(awaitable, EventType.OnAfterMessageSentEvent, **kwargs)
return awaitable
return decorator
+47 -15
View File
@@ -1,34 +1,41 @@
from __future__ import annotations
import enum
from dataclasses import dataclass
import heapq
from dataclasses import dataclass, field
from typing import Awaitable, List, Dict, TypeVar, Generic
from .filter import HandlerFilter
from .star import star_map
T = TypeVar('T', bound='StarHandlerMetadata')
class StarHandlerRegistry(Generic[T], List[T]):
class StarHandlerRegistry(Generic[T]):
'''用于存储所有的 Star Handler'''
star_handlers_map: Dict[str, StarHandlerMetadata] = {}
'''用于快速查找。key 是 handler_full_name'''
_handlers = []
def append(self, handler: StarHandlerMetadata):
'''添加一个 Handler'''
super().append(handler)
if 'priority' not in handler.extras_configs:
handler.extras_configs['priority'] = 0
heapq.heappush(self._handlers, (-handler.extras_configs['priority'], handler))
self.star_handlers_map[handler.handler_full_name] = handler
def get_handlers_by_event_type(self, event_type: EventType, only_activated = True) -> List[StarHandlerMetadata]:
def _print_handlers(self):
'''打印所有的 Handler'''
for _, handler in self._handlers:
print(handler.handler_full_name)
def get_handlers_by_event_type(self, event_type: EventType, only_activated=True) -> List[StarHandlerMetadata]:
'''通过事件类型获取 Handler'''
if only_activated:
return [
handler
for handler in self
if handler.event_type == event_type and
star_map[handler.handler_module_path] and
star_map[handler.handler_module_path].activated
]
else:
return [handler for handler in self if handler.event_type == event_type]
handlers = [
handler
for _, handler in self._handlers
if handler.event_type == event_type and
(not only_activated or (star_map[handler.handler_module_path] and star_map[handler.handler_module_path].activated))
]
return handlers
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata:
'''通过 Handler 的全名获取 Handler'''
@@ -36,7 +43,25 @@ class StarHandlerRegistry(Generic[T], List[T]):
def get_handlers_by_module_name(self, module_name: str) -> List[StarHandlerMetadata]:
'''通过模块名获取 Handler'''
return [handler for handler in self if handler.handler_module_path == module_name]
return [handler for _, handler in self._handlers if handler.handler_module_path == module_name]
def clear(self):
'''清空所有的 Handler'''
self.star_handlers_map.clear()
self._handlers.clear()
def remove(self, handler: StarHandlerMetadata):
'''删除一个 Handler'''
self._handlers.remove(handler)
del self.star_handlers_map[handler.handler_full_name]
def __iter__(self):
'''使 StarHandlerRegistry 支持迭代'''
return (handler for _, handler in self._handlers)
def __len__(self):
'''返回 Handler 的数量'''
return len(self._handlers)
star_handlers_registry = StarHandlerRegistry()
@@ -76,3 +101,10 @@ class StarHandlerMetadata():
desc: str = ""
'''Handler 的描述信息'''
extras_configs: dict = field(default_factory=dict)
'''插件注册的一些其他的信息, 如 priority 等'''
def __lt__(self, other: StarHandlerMetadata):
'''定义小于运算符以支持优先队列'''
return self.extras_configs.get('priority', 0) < other.extras_configs.get('priority', 0)
-1
View File
@@ -146,7 +146,6 @@ class PluginManager:
smd.star_cls.__del__()
star_handlers_registry.clear()
star_handlers_registry.star_handlers_map.clear()
star_map.clear()
star_registry.clear()
for key in list(sys.modules.keys()):
-1
View File
@@ -30,7 +30,6 @@ async def test_plugin_manager_reload(plugin_manager_pm: PluginManager):
success, err_message = await plugin_manager_pm.reload()
assert success is True
assert err_message is None
assert len(star_handlers_registry) > 0 # package
@pytest.mark.asyncio
async def test_plugin_crud(plugin_manager_pm: PluginManager):