From 9294b448315584da011b7986e98792902c835e96 Mon Sep 17 00:00:00 2001 From: whatevertogo <1879483647@qq.com> Date: Tue, 24 Feb 2026 13:53:29 +0800 Subject: [PATCH] fix: resolve pipeline and star import cycles (#5353) * fix: resolve pipeline and star import cycles - Add bootstrap.py and stage_order.py to break circular dependencies - Export Context, PluginManager, StarTools from star module - Update pipeline __init__ to defer imports - Split pipeline initialization into separate bootstrap module Co-Authored-By: Claude Sonnet 4.6 * fix: add logging for get_config() failure in Star class * fix: reorder logger initialization in base.py --------- Co-authored-by: whatevertogo Co-authored-by: Claude Sonnet 4.6 --- astrbot/core/pipeline/__init__.py | 87 ++++++++++++++----- astrbot/core/pipeline/bootstrap.py | 52 +++++++++++ astrbot/core/pipeline/context.py | 6 +- .../method/agent_sub_stages/internal.py | 2 +- .../method/agent_sub_stages/third_party.py | 4 +- astrbot/core/pipeline/scheduler.py | 4 +- astrbot/core/pipeline/stage_order.py | 15 ++++ astrbot/core/star/__init__.py | 75 +++------------- astrbot/core/star/base.py | 87 +++++++++++++++++++ astrbot/core/star/context.py | 17 +++- astrbot/core/star/register/star.py | 2 +- astrbot/core/star/register/star_handler.py | 9 +- astrbot/core/star/star_manager.py | 3 + 13 files changed, 264 insertions(+), 99 deletions(-) create mode 100644 astrbot/core/pipeline/bootstrap.py create mode 100644 astrbot/core/pipeline/stage_order.py create mode 100644 astrbot/core/star/base.py diff --git a/astrbot/core/pipeline/__init__.py b/astrbot/core/pipeline/__init__.py index 75fef84d3..0363d4692 100644 --- a/astrbot/core/pipeline/__init__.py +++ b/astrbot/core/pipeline/__init__.py @@ -1,30 +1,60 @@ +"""Pipeline package exports. + +This module intentionally avoids eager imports of all pipeline stage modules to +prevent import-time cycles. Stage classes remain available via lazy attribute +resolution for backward compatibility. +""" + +from __future__ import annotations + +from importlib import import_module +from typing import Any + from astrbot.core.message.message_event_result import ( EventResultType, MessageEventResult, ) -from .content_safety_check.stage import ContentSafetyCheckStage -from .preprocess_stage.stage import PreProcessStage -from .process_stage.stage import ProcessStage -from .rate_limit_check.stage import RateLimitStage -from .respond.stage import RespondStage -from .result_decorate.stage import ResultDecorateStage -from .session_status_check.stage import SessionStatusCheckStage -from .waking_check.stage import WakingCheckStage -from .whitelist_check.stage import WhitelistCheckStage +from .stage_order import STAGES_ORDER -# 管道阶段顺序 -STAGES_ORDER = [ - "WakingCheckStage", # 检查是否需要唤醒 - "WhitelistCheckStage", # 检查是否在群聊/私聊白名单 - "SessionStatusCheckStage", # 检查会话是否整体启用 - "RateLimitStage", # 检查会话是否超过频率限制 - "ContentSafetyCheckStage", # 检查内容安全 - "PreProcessStage", # 预处理 - "ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用 - "ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等 - "RespondStage", # 发送消息 -] +_LAZY_EXPORTS = { + "ContentSafetyCheckStage": ( + "astrbot.core.pipeline.content_safety_check.stage", + "ContentSafetyCheckStage", + ), + "PreProcessStage": ( + "astrbot.core.pipeline.preprocess_stage.stage", + "PreProcessStage", + ), + "ProcessStage": ( + "astrbot.core.pipeline.process_stage.stage", + "ProcessStage", + ), + "RateLimitStage": ( + "astrbot.core.pipeline.rate_limit_check.stage", + "RateLimitStage", + ), + "RespondStage": ( + "astrbot.core.pipeline.respond.stage", + "RespondStage", + ), + "ResultDecorateStage": ( + "astrbot.core.pipeline.result_decorate.stage", + "ResultDecorateStage", + ), + "SessionStatusCheckStage": ( + "astrbot.core.pipeline.session_status_check.stage", + "SessionStatusCheckStage", + ), + "WakingCheckStage": ( + "astrbot.core.pipeline.waking_check.stage", + "WakingCheckStage", + ), + "WhitelistCheckStage": ( + "astrbot.core.pipeline.whitelist_check.stage", + "WhitelistCheckStage", + ), +} __all__ = [ "ContentSafetyCheckStage", @@ -36,6 +66,21 @@ __all__ = [ "RespondStage", "ResultDecorateStage", "SessionStatusCheckStage", + "STAGES_ORDER", "WakingCheckStage", "WhitelistCheckStage", ] + + +def __getattr__(name: str) -> Any: + if name not in _LAZY_EXPORTS: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + module_path, attr_name = _LAZY_EXPORTS[name] + module = import_module(module_path) + value = getattr(module, attr_name) + globals()[name] = value + return value + + +def __dir__() -> list[str]: + return sorted(set(globals()) | set(__all__)) diff --git a/astrbot/core/pipeline/bootstrap.py b/astrbot/core/pipeline/bootstrap.py new file mode 100644 index 000000000..4bb7ceadb --- /dev/null +++ b/astrbot/core/pipeline/bootstrap.py @@ -0,0 +1,52 @@ +"""Pipeline bootstrap utilities.""" + +from importlib import import_module + +from .stage import registered_stages + +_BUILTIN_STAGE_MODULES = ( + "astrbot.core.pipeline.waking_check.stage", + "astrbot.core.pipeline.whitelist_check.stage", + "astrbot.core.pipeline.session_status_check.stage", + "astrbot.core.pipeline.rate_limit_check.stage", + "astrbot.core.pipeline.content_safety_check.stage", + "astrbot.core.pipeline.preprocess_stage.stage", + "astrbot.core.pipeline.process_stage.stage", + "astrbot.core.pipeline.result_decorate.stage", + "astrbot.core.pipeline.respond.stage", +) + +_EXPECTED_STAGE_NAMES = { + "WakingCheckStage", + "WhitelistCheckStage", + "SessionStatusCheckStage", + "RateLimitStage", + "ContentSafetyCheckStage", + "PreProcessStage", + "ProcessStage", + "ResultDecorateStage", + "RespondStage", +} + +_builtin_stages_registered = False + + +def ensure_builtin_stages_registered() -> None: + """Ensure built-in pipeline stages are imported and registered.""" + global _builtin_stages_registered + + if _builtin_stages_registered: + return + + stage_names = {stage_cls.__name__ for stage_cls in registered_stages} + if _EXPECTED_STAGE_NAMES.issubset(stage_names): + _builtin_stages_registered = True + return + + for module_path in _BUILTIN_STAGE_MODULES: + import_module(module_path) + + _builtin_stages_registered = True + + +__all__ = ["ensure_builtin_stages_registered"] diff --git a/astrbot/core/pipeline/context.py b/astrbot/core/pipeline/context.py index a6cd567e0..963f4bdac 100644 --- a/astrbot/core/pipeline/context.py +++ b/astrbot/core/pipeline/context.py @@ -1,7 +1,9 @@ +from __future__ import annotations + from dataclasses import dataclass +from typing import Any from astrbot.core.config import AstrBotConfig -from astrbot.core.star import PluginManager from .context_utils import call_event_hook, call_handler @@ -11,7 +13,7 @@ class PipelineContext: """上下文对象,包含管道执行所需的上下文信息""" astrbot_config: AstrBotConfig # AstrBot 配置对象 - plugin_manager: PluginManager # 插件管理器对象 + plugin_manager: Any # 插件管理器对象 astrbot_config_id: str call_handler = call_handler call_event_hook = call_event_hook diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index 91e7f0f5a..98cf77fcc 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -19,6 +19,7 @@ from astrbot.core.message.message_event_result import ( MessageEventResult, ResultContentType, ) +from astrbot.core.pipeline.stage import Stage from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.provider.entities import ( LLMResponse, @@ -30,7 +31,6 @@ from astrbot.core.utils.session_lock import session_lock_manager from .....astr_agent_run_util import run_agent, run_live_agent from ....context import PipelineContext, call_event_hook -from ...stage import Stage class InternalAgentSubStage(Stage): diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py index b590bd77e..7fb5cee82 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py @@ -8,6 +8,7 @@ from astrbot.core.agent.runners.dashscope.dashscope_agent_runner import ( DashscopeAgentRunner, ) from astrbot.core.agent.runners.dify.dify_agent_runner import DifyAgentRunner +from astrbot.core.astr_agent_hooks import MAIN_AGENT_HOOKS from astrbot.core.message.components import Image from astrbot.core.message.message_event_result import ( MessageChain, @@ -17,6 +18,7 @@ from astrbot.core.message.message_event_result import ( if TYPE_CHECKING: from astrbot.core.agent.runners.base import BaseAgentRunner +from astrbot.core.pipeline.stage import Stage from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.provider.entities import ( ProviderRequest, @@ -25,9 +27,7 @@ from astrbot.core.star.star_handler import EventType from astrbot.core.utils.metrics import Metric from .....astr_agent_context import AgentContextWrapper, AstrAgentContext -from .....astr_agent_hooks import MAIN_AGENT_HOOKS from ....context import PipelineContext, call_event_hook -from ...stage import Stage AGENT_RUNNER_TYPE_KEY = { "dify": "dify_agent_runner_provider_id", diff --git a/astrbot/core/pipeline/scheduler.py b/astrbot/core/pipeline/scheduler.py index c4a65077a..ffb9c5c99 100644 --- a/astrbot/core/pipeline/scheduler.py +++ b/astrbot/core/pipeline/scheduler.py @@ -8,15 +8,17 @@ from astrbot.core.platform.sources.wecom_ai_bot.wecomai_event import ( ) from astrbot.core.utils.active_event_registry import active_event_registry -from . import STAGES_ORDER +from .bootstrap import ensure_builtin_stages_registered from .context import PipelineContext from .stage import registered_stages +from .stage_order import STAGES_ORDER class PipelineScheduler: """管道调度器,负责调度各个阶段的执行""" def __init__(self, context: PipelineContext) -> None: + ensure_builtin_stages_registered() registered_stages.sort( key=lambda x: STAGES_ORDER.index(x.__name__), ) # 按照顺序排序 diff --git a/astrbot/core/pipeline/stage_order.py b/astrbot/core/pipeline/stage_order.py new file mode 100644 index 000000000..f99f57264 --- /dev/null +++ b/astrbot/core/pipeline/stage_order.py @@ -0,0 +1,15 @@ +"""Pipeline stage execution order.""" + +STAGES_ORDER = [ + "WakingCheckStage", # 检查是否需要唤醒 + "WhitelistCheckStage", # 检查是否在群聊/私聊白名单 + "SessionStatusCheckStage", # 检查会话是否整体启用 + "RateLimitStage", # 检查会话是否超过频率限制 + "ContentSafetyCheckStage", # 检查内容安全 + "PreProcessStage", # 预处理 + "ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用 + "ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等 + "RespondStage", # 发送消息 +] + +__all__ = ["STAGES_ORDER"] diff --git a/astrbot/core/star/__init__.py b/astrbot/core/star/__init__.py index 2bf86872e..796e0bd68 100644 --- a/astrbot/core/star/__init__.py +++ b/astrbot/core/star/__init__.py @@ -1,68 +1,19 @@ -from astrbot.core import html_renderer +# 兼容导出: Provider 从 provider 模块重新导出 from astrbot.core.provider import Provider -from astrbot.core.star.star_tools import StarTools -from astrbot.core.utils.command_parser import CommandParserMixin -from astrbot.core.utils.plugin_kv_store import PluginKVStoreMixin +from .base import Star from .context import Context from .star import StarMetadata, star_map, star_registry from .star_manager import PluginManager +from .star_tools import StarTools - -class Star(CommandParserMixin, PluginKVStoreMixin): - """所有插件(Star)的父类,所有插件都应该继承于这个类""" - - author: str - name: str - - def __init__(self, context: Context, config: dict | None = None) -> None: - StarTools.initialize(context) - self.context = context - - def __init_subclass__(cls, **kwargs): - super().__init_subclass__(**kwargs) - if not star_map.get(cls.__module__): - metadata = StarMetadata( - star_cls_type=cls, - module_path=cls.__module__, - ) - star_map[cls.__module__] = metadata - star_registry.append(metadata) - else: - star_map[cls.__module__].star_cls_type = cls - star_map[cls.__module__].module_path = cls.__module__ - - async def text_to_image(self, text: str, return_url=True) -> str: - """将文本转换为图片""" - return await html_renderer.render_t2i( - text, - return_url=return_url, - template_name=self.context._config.get("t2i_active_template"), - ) - - async def html_render( - self, - tmpl: str, - data: dict, - return_url=True, - options: dict | None = None, - ) -> str: - """渲染 HTML""" - return await html_renderer.render_custom_template( - tmpl, - data, - return_url=return_url, - options=options, - ) - - async def initialize(self) -> None: - """当插件被激活时会调用这个方法""" - - async def terminate(self) -> None: - """当插件被禁用、重载插件时会调用这个方法""" - - def __del__(self) -> None: - """[Deprecated] 当插件被禁用、重载插件时会调用这个方法""" - - -__all__ = ["Context", "PluginManager", "Provider", "Star", "StarMetadata", "StarTools"] +__all__ = [ + "Context", + "PluginManager", + "Provider", + "Star", + "StarMetadata", + "StarTools", + "star_map", + "star_registry", +] diff --git a/astrbot/core/star/base.py b/astrbot/core/star/base.py new file mode 100644 index 000000000..dd3ae3f0e --- /dev/null +++ b/astrbot/core/star/base.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +import logging +from typing import Any, Protocol + +from astrbot.core import html_renderer +from astrbot.core.utils.command_parser import CommandParserMixin +from astrbot.core.utils.plugin_kv_store import PluginKVStoreMixin + +from .star import StarMetadata, star_map, star_registry + +logger = logging.getLogger("astrbot") + + +class Star(CommandParserMixin, PluginKVStoreMixin): + """所有插件(Star)的父类,所有插件都应该继承于这个类""" + + author: str + name: str + + class _ContextLike(Protocol): + def get_config(self, umo: str | None = None) -> Any: ... + + def __init__(self, context: _ContextLike, config: dict | None = None) -> None: + self.context = context + + def _get_context_config(self) -> Any: + get_config = getattr(self.context, "get_config", None) + if callable(get_config): + try: + return get_config() + except Exception as e: + logger.debug(f"get_config() failed: {e}") + return None + return getattr(self.context, "_config", None) + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + if not star_map.get(cls.__module__): + metadata = StarMetadata( + star_cls_type=cls, + module_path=cls.__module__, + ) + star_map[cls.__module__] = metadata + star_registry.append(metadata) + else: + star_map[cls.__module__].star_cls_type = cls + star_map[cls.__module__].module_path = cls.__module__ + + async def text_to_image(self, text: str, return_url=True) -> str: + """将文本转换为图片""" + config_obj = self._get_context_config() + template_name = None + if hasattr(config_obj, "get"): + try: + template_name = config_obj.get("t2i_active_template") + except Exception: + template_name = None + return await html_renderer.render_t2i( + text, + return_url=return_url, + template_name=template_name, + ) + + async def html_render( + self, + tmpl: str, + data: dict, + return_url=True, + options: dict | None = None, + ) -> str: + """渲染 HTML""" + return await html_renderer.render_custom_template( + tmpl, + data, + return_url=return_url, + options=options, + ) + + async def initialize(self) -> None: + """当插件被激活时会调用这个方法""" + + async def terminate(self) -> None: + """当插件被禁用、重载插件时会调用这个方法""" + + def __del__(self) -> None: + """[Deprecated] 当插件被禁用、重载插件时会调用这个方法""" diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 6a74580f6..ef8c60e5f 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import logging from asyncio import Queue from collections.abc import Awaitable, Callable -from typing import Any +from typing import TYPE_CHECKING, Any, Protocol from deprecated import deprecated @@ -12,14 +14,12 @@ from astrbot.core.agent.tool import ToolSet from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.conversation_mgr import ConversationManager -from astrbot.core.cron.manager import CronJobManager from astrbot.core.db import BaseDatabase from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager from astrbot.core.message.message_event_result import MessageChain from astrbot.core.persona_mgr import PersonaManager from astrbot.core.platform import Platform from astrbot.core.platform.astr_message_event import AstrMessageEvent, MessageSesion -from astrbot.core.platform.manager import PlatformManager from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager from astrbot.core.provider.entities import LLMResponse, ProviderRequest, ProviderType from astrbot.core.provider.func_tool_manager import FunctionTool, FunctionToolManager @@ -45,6 +45,15 @@ from .star_handler import EventType, StarHandlerMetadata, star_handlers_registry logger = logging.getLogger("astrbot") +if TYPE_CHECKING: + from astrbot.core.cron.manager import CronJobManager +else: + CronJobManager = Any + + +class PlatformManagerProtocol(Protocol): + platform_insts: list[Platform] + class Context: """暴露给插件的接口上下文。""" @@ -61,7 +70,7 @@ class Context: config: AstrBotConfig, db: BaseDatabase, provider_manager: ProviderManager, - platform_manager: PlatformManager, + platform_manager: PlatformManagerProtocol, conversation_manager: ConversationManager, message_history_manager: PlatformMessageHistoryManager, persona_manager: PersonaManager, diff --git a/astrbot/core/star/register/star.py b/astrbot/core/star/register/star.py index 617cd5ff7..c1a0ce10c 100644 --- a/astrbot/core/star/register/star.py +++ b/astrbot/core/star/register/star.py @@ -1,6 +1,6 @@ import warnings -from astrbot.core.star import StarMetadata, star_map +from astrbot.core.star.star import StarMetadata, star_map _warned_register_star = False diff --git a/astrbot/core/star/register/star_handler.py b/astrbot/core/star/register/star_handler.py index 87b9b9998..1385b5056 100644 --- a/astrbot/core/star/register/star_handler.py +++ b/astrbot/core/star/register/star_handler.py @@ -11,7 +11,6 @@ from astrbot.core.agent.agent import Agent from astrbot.core.agent.handoff import HandoffTool from astrbot.core.agent.hooks import BaseAgentRunHooks from astrbot.core.agent.tool import FunctionTool -from astrbot.core.astr_agent_context import AstrAgentContext from astrbot.core.message.message_event_result import MessageEventResult from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES from astrbot.core.provider.register import llm_tools @@ -617,7 +616,7 @@ class RegisteringAgent: kwargs["registering_agent"] = self return register_llm_tool(*args, **kwargs) - def __init__(self, agent: Agent[AstrAgentContext]) -> None: + def __init__(self, agent: Agent[Any]) -> None: self._agent = agent @@ -625,7 +624,7 @@ def register_agent( name: str, instruction: str, tools: list[str | FunctionTool] | None = None, - run_hooks: BaseAgentRunHooks[AstrAgentContext] | None = None, + run_hooks: BaseAgentRunHooks[Any] | None = None, ): """注册一个 Agent @@ -639,12 +638,12 @@ def register_agent( tools_ = tools or [] def decorator(awaitable: Callable[..., Awaitable[Any]]): - AstrAgent = Agent[AstrAgentContext] + AstrAgent = Agent[Any] agent = AstrAgent( name=name, instructions=instruction, tools=tools_, - run_hooks=run_hooks or BaseAgentRunHooks[AstrAgentContext](), + run_hooks=run_hooks or BaseAgentRunHooks[Any](), ) handoff_tool = HandoffTool(agent=agent) handoff_tool.handler = awaitable diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index 0f61c0274..815b306aa 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -49,10 +49,13 @@ class PluginVersionIncompatibleError(Exception): class PluginManager: def __init__(self, context: Context, config: AstrBotConfig) -> None: + from .star_tools import StarTools + self.updator = PluginUpdator() self.context = context self.context._star_manager = self # type: ignore + StarTools.initialize(context) self.config = config self.plugin_store_path = get_astrbot_plugin_path()