feat: 事件钩子支持 yield 方式发送消息

This commit is contained in:
Soulter
2025-02-19 15:29:10 +08:00
parent 4678222e9b
commit 782c0367d0
6 changed files with 30 additions and 30 deletions
@@ -64,12 +64,14 @@ class LLMRequestSubStage(Stage):
if not req.prompt and not req.image_urls: if not req.prompt and not req.image_urls:
return return
# 执行请求 LLM 前事件。 # 执行请求 LLM 前事件钩子
# 装饰 system_prompt 等功能 # 装饰 system_prompt 等功能
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMRequestEvent) handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMRequestEvent)
for handler in handlers: for handler in handlers:
try: try:
await handler.handler(event, req) wrapper = self._call_handler(self.ctx, event, handler.handler, req)
async for ret in wrapper:
yield ret
except BaseException: except BaseException:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
@@ -86,7 +88,9 @@ class LLMRequestSubStage(Stage):
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMResponseEvent) handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMResponseEvent)
for handler in handlers: for handler in handlers:
try: try:
await handler.handler(event, llm_response) wrapper = self._call_handler(self.ctx, event, handler.handler, llm_response)
async for ret in wrapper:
yield ret
except BaseException: except BaseException:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
+7 -2
View File
@@ -1,6 +1,7 @@
import random import random
import asyncio import asyncio
import math import math
import traceback
from typing import Union, AsyncGenerator from typing import Union, AsyncGenerator
from ..stage import register_stage, Stage from ..stage import register_stage, Stage
from ..context import PipelineContext from ..context import PipelineContext
@@ -88,7 +89,11 @@ class RespondStage(Stage):
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnAfterMessageSentEvent) handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnAfterMessageSentEvent)
for handler in handlers: for handler in handlers:
# TODO: 如何让这里的 handler 也能使用 LLM 能力。也许需要将 LLMRequestSubStage 提取出来。 try:
await handler.handler(event) wrapper = self._call_handler(self.ctx, event, handler.handler)
async for ret in wrapper:
yield ret
except BaseException:
logger.error(traceback.format_exc())
event.clear_result() event.clear_result()
@@ -59,9 +59,16 @@ class ResultDecorateStage(Stage):
async for _ in self.content_safe_check_stage.process(event, check_text=text): async for _ in self.content_safe_check_stage.process(event, check_text=text):
yield yield
# 发送消息前事件钩子
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnDecoratingResultEvent) handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnDecoratingResultEvent)
for handler in handlers: for handler in handlers:
await handler.handler(event) try:
wrapper = self._call_handler(self.ctx, event, handler.handler)
async for ret in wrapper:
yield ret
except BaseException:
logger.error(traceback.format_exc())
# 需要再获取一次。插件可能直接对 chain 进行了替换。 # 需要再获取一次。插件可能直接对 chain 进行了替换。
result = event.get_result() result = event.get_result()
+4 -3
View File
@@ -36,16 +36,17 @@ class Stage(abc.ABC):
ctx: PipelineContext, ctx: PipelineContext,
event: AstrMessageEvent, event: AstrMessageEvent,
handler: Awaitable, handler: Awaitable,
**params *args,
**kwargs,
) -> AsyncGenerator[None, None]: ) -> AsyncGenerator[None, None]:
'''调用 Handler。''' '''调用 Handler。'''
# 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性) # 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性)
ready_to_call = None ready_to_call = None
try: try:
ready_to_call = handler(event, **params) ready_to_call = handler(event, *args, **kwargs)
except TypeError as e: except TypeError as e:
# 向下兼容 # 向下兼容
ready_to_call = handler(event, ctx.plugin_manager.context, **params) ready_to_call = handler(event, ctx.plugin_manager.context, *args, **kwargs)
if isinstance(ready_to_call, AsyncGenerator): if isinstance(ready_to_call, AsyncGenerator):
async for ret in ready_to_call: async for ret in ready_to_call:
+1 -19
View File
@@ -77,15 +77,11 @@ class WakingCheckStage(Stage):
# 检查插件的 handler filter # 检查插件的 handler filter
activated_handlers = [] activated_handlers = []
handlers_parsed_params = {} # 注册了指令的 handler handlers_parsed_params = {} # 注册了指令的 handler
for handler in star_handlers_registry.get_handlers_by_event_type(EventType.AdapterMessageEvent): for handler in star_handlers_registry.get_handlers_by_event_type(EventType.AdapterMessageEvent):
# filter 需满足 AND 逻辑关系 # filter 需满足 AND 逻辑关系
passed = True passed = True
permission_not_pass = False permission_not_pass = False
# 在输入指令组正确但是子指令错误的情况下提醒用户
command_group_passed = False
command_group_tree = None
if len(handler.event_filters) == 0: if len(handler.event_filters) == 0:
continue continue
@@ -94,12 +90,6 @@ class WakingCheckStage(Stage):
if isinstance(filter, PermissionTypeFilter): if isinstance(filter, PermissionTypeFilter):
if not filter.filter(event, self.ctx.astrbot_config): if not filter.filter(event, self.ctx.astrbot_config):
permission_not_pass = True permission_not_pass = True
elif isinstance(filter, CommandGroupFilter):
if filter.filter(event, self.ctx.astrbot_config):
command_group_passed = True
command_group_tree = filter.print_cmd_tree(filter.sub_command_filters)
passed = False
break
else: else:
if not filter.filter(event, self.ctx.astrbot_config): if not filter.filter(event, self.ctx.astrbot_config):
passed = False passed = False
@@ -128,14 +118,6 @@ class WakingCheckStage(Stage):
handlers_parsed_params[handler.handler_full_name] = event.get_extra( handlers_parsed_params[handler.handler_full_name] = event.get_extra(
"parsed_params" "parsed_params"
) )
if not passed and command_group_passed:
await event.send(
MessageEventResult().message(
f"插件 {star_map[handler.handler_module_path].name} 没有该指令。指令树:\n{command_group_tree}"
)
)
event.stop_event()
return
event.clear_extra() event.clear_extra()
+3 -2
View File
@@ -97,5 +97,6 @@ class CommandGroupFilter(HandlerFilter):
tree = self.group_name + "\n" + self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg) tree = self.group_name + "\n" + self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg)
raise ValueError(f"指令组 {self.group_name} 未填写完全。这个指令组下有如下指令:\n"+tree) raise ValueError(f"指令组 {self.group_name} 未填写完全。这个指令组下有如下指令:\n"+tree)
complete_command_names = [name + " " for name in complete_command_names] # complete_command_names = [name + " " for name in complete_command_names]
return event.message_str.startswith(tuple(complete_command_names)) # return event.message_str.startswith(tuple(complete_command_names))
return False