From 53b9497c18ccea208625fd2585fcc34b04658035 Mon Sep 17 00:00:00 2001 From: anka <1350989414@qq.com> Date: Thu, 27 Mar 2025 21:32:38 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E5=A2=9E=E5=8A=A0=E9=83=A8=E5=88=86?= =?UTF-8?q?=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/pipeline/scheduler.py | 20 ++++++++++++++------ main.py | 4 ++-- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/astrbot/core/pipeline/scheduler.py b/astrbot/core/pipeline/scheduler.py index 2ed3c0d2a..c5339ac4b 100644 --- a/astrbot/core/pipeline/scheduler.py +++ b/astrbot/core/pipeline/scheduler.py @@ -25,37 +25,45 @@ class PipelineScheduler: async def _process_stages(self, event: AstrMessageEvent, from_stage=0): """依次执行各个阶段""" for i in range(from_stage, len(registered_stages)): - stage = registered_stages[i] + stage = registered_stages[i] # 获取当前要执行的阶段 # logger.debug(f"执行阶段 {stage.__class__ .__name__}") - coroutine = stage.process(event) + coroutine = stage.process( + event + ) # 调用阶段的process方法, 返回协程或者异步生成器 + if isinstance(coroutine, AsyncGenerator): + # 如果返回的是异步生成器, 实现洋葱模型的核心 async for _ in coroutine: + # 此处是前置处理完成后的暂停点(yield), 下面开始执行后续阶段 if event.is_stopped(): logger.debug( f"阶段 {stage.__class__.__name__} 已终止事件传播。" ) break + + # 递归调用, 处理所有后续阶段 await self._process_stages(event, i + 1) + + # 此处是后续所有阶段处理完毕后返回的点, 执行后置处理 if event.is_stopped(): logger.debug( f"阶段 {stage.__class__.__name__} 已终止事件传播。" ) break else: + # 如果返回的是普通协程(不含yield的async函数), 则不进入下一层(基线条件) + # 简单地等待它执行完成, 然后继续执行下一个阶段 await coroutine if event.is_stopped(): logger.debug(f"阶段 {stage.__class__.__name__} 已终止事件传播。") break - if event.is_stopped(): - logger.debug(f"阶段 {stage.__class__.__name__} 已终止事件传播。") - break - async def execute(self, event: AstrMessageEvent): """执行 pipeline""" await self._process_stages(event) + # 如果没有发送操作, 则发送一个空消息, 以便于后续的处理 if not event._has_send_oper and event.get_platform_name() == "webchat": await event.send(None) diff --git a/main.py b/main.py index 9937bb10f..c0ca48274 100644 --- a/main.py +++ b/main.py @@ -79,5 +79,5 @@ if __name__ == "__main__": # print logo logger.info(logo_tmpl) - dashboard_lifecycle = InitialLoader(db, log_broker) - asyncio.run(dashboard_lifecycle.start()) + core_lifecycle = InitialLoader(db, log_broker) + asyncio.run(core_lifecycle.start())