114 lines
4.3 KiB
Python
114 lines
4.3 KiB
Python
import inspect
|
|
import traceback
|
|
import typing as T
|
|
from dataclasses import dataclass
|
|
from astrbot.core.config.astrbot_config import AstrBotConfig
|
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
|
from astrbot.core.star import PluginManager
|
|
from astrbot.api import logger
|
|
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
|
from astrbot.core.star.star import star_map
|
|
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
|
|
|
|
|
|
@dataclass
|
|
class PipelineContext:
|
|
"""上下文对象,包含管道执行所需的上下文信息"""
|
|
|
|
astrbot_config: AstrBotConfig # AstrBot 配置对象
|
|
plugin_manager: PluginManager # 插件管理器对象
|
|
|
|
async def call_event_hook(
|
|
self,
|
|
event: AstrMessageEvent,
|
|
hook_type: EventType,
|
|
*args,
|
|
) -> bool:
|
|
"""调用事件钩子函数
|
|
|
|
Returns:
|
|
bool: 如果事件被终止,返回 True
|
|
"""
|
|
platform_id = event.get_platform_id()
|
|
handlers = star_handlers_registry.get_handlers_by_event_type(
|
|
hook_type, platform_id=platform_id
|
|
)
|
|
for handler in handlers:
|
|
try:
|
|
logger.debug(
|
|
f"hook(on_llm_request) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
|
)
|
|
await handler.handler(event, *args)
|
|
except BaseException:
|
|
logger.error(traceback.format_exc())
|
|
|
|
if event.is_stopped():
|
|
logger.info(
|
|
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
|
|
)
|
|
|
|
return event.is_stopped()
|
|
|
|
async def call_handler(
|
|
self,
|
|
event: AstrMessageEvent,
|
|
handler: T.Awaitable,
|
|
*args,
|
|
**kwargs,
|
|
) -> T.AsyncGenerator[None, None]:
|
|
"""执行事件处理函数并处理其返回结果
|
|
|
|
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
|
|
1. 异步生成器: 实现洋葱模型,每次 yield 都会将控制权交回上层
|
|
2. 协程: 执行一次并处理返回值
|
|
|
|
Args:
|
|
ctx (PipelineContext): 消息管道上下文对象
|
|
event (AstrMessageEvent): 事件对象
|
|
handler (Awaitable): 事件处理函数
|
|
|
|
Returns:
|
|
AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
|
|
"""
|
|
ready_to_call = None # 一个协程或者异步生成器
|
|
|
|
trace_ = None
|
|
|
|
try:
|
|
ready_to_call = handler(event, *args, **kwargs)
|
|
except TypeError as _:
|
|
# 向下兼容
|
|
trace_ = traceback.format_exc()
|
|
# 以前的 handler 会额外传入一个参数, 但是 context 对象实际上在插件实例中有一份
|
|
ready_to_call = handler(event, self.plugin_manager.context, *args, **kwargs)
|
|
|
|
if inspect.isasyncgen(ready_to_call):
|
|
_has_yielded = False
|
|
try:
|
|
async for ret in ready_to_call:
|
|
# 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
|
|
# 返回值只能是 MessageEventResult 或者 None(无返回值)
|
|
_has_yielded = True
|
|
if isinstance(ret, (MessageEventResult, CommandResult)):
|
|
# 如果返回值是 MessageEventResult, 设置结果并继续
|
|
event.set_result(ret)
|
|
yield
|
|
else:
|
|
# 如果返回值是 None, 则不设置结果并继续
|
|
# 继续执行后续阶段
|
|
yield ret
|
|
if not _has_yielded:
|
|
# 如果这个异步生成器没有执行到 yield 分支
|
|
yield
|
|
except Exception as e:
|
|
logger.error(f"Previous Error: {trace_}")
|
|
raise e
|
|
elif inspect.iscoroutine(ready_to_call):
|
|
# 如果只是一个协程, 直接执行
|
|
ret = await ready_to_call
|
|
if isinstance(ret, (MessageEventResult, CommandResult)):
|
|
event.set_result(ret)
|
|
yield
|
|
else:
|
|
yield ret
|