fix: resolve pipeline and star import cycles (#5353)

* fix: resolve pipeline and star import cycles

- Add bootstrap.py and stage_order.py to break circular dependencies
- Export Context, PluginManager, StarTools from star module
- Update pipeline __init__ to defer imports
- Split pipeline initialization into separate bootstrap module

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix: add logging for get_config() failure in Star class

* fix: reorder logger initialization in base.py

---------

Co-authored-by: whatevertogo <whatevertogo@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
whatevertogo
2026-02-24 13:53:29 +08:00
committed by GitHub
parent 80fd51119b
commit 9294b44831
13 changed files with 264 additions and 99 deletions
+66 -21
View File
@@ -1,30 +1,60 @@
"""Pipeline package exports.
This module intentionally avoids eager imports of all pipeline stage modules to
prevent import-time cycles. Stage classes remain available via lazy attribute
resolution for backward compatibility.
"""
from __future__ import annotations
from importlib import import_module
from typing import Any
from astrbot.core.message.message_event_result import ( from astrbot.core.message.message_event_result import (
EventResultType, EventResultType,
MessageEventResult, MessageEventResult,
) )
from .content_safety_check.stage import ContentSafetyCheckStage from .stage_order import STAGES_ORDER
from .preprocess_stage.stage import PreProcessStage
from .process_stage.stage import ProcessStage
from .rate_limit_check.stage import RateLimitStage
from .respond.stage import RespondStage
from .result_decorate.stage import ResultDecorateStage
from .session_status_check.stage import SessionStatusCheckStage
from .waking_check.stage import WakingCheckStage
from .whitelist_check.stage import WhitelistCheckStage
# 管道阶段顺序 _LAZY_EXPORTS = {
STAGES_ORDER = [ "ContentSafetyCheckStage": (
"WakingCheckStage", # 检查是否需要唤醒 "astrbot.core.pipeline.content_safety_check.stage",
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单 "ContentSafetyCheckStage",
"SessionStatusCheckStage", # 检查会话是否整体启用 ),
"RateLimitStage", # 检查会话是否超过频率限制 "PreProcessStage": (
"ContentSafetyCheckStage", # 检查内容安全 "astrbot.core.pipeline.preprocess_stage.stage",
"PreProcessStage", # 预处理 "PreProcessStage",
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用 ),
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等 "ProcessStage": (
"RespondStage", # 发送消息 "astrbot.core.pipeline.process_stage.stage",
] "ProcessStage",
),
"RateLimitStage": (
"astrbot.core.pipeline.rate_limit_check.stage",
"RateLimitStage",
),
"RespondStage": (
"astrbot.core.pipeline.respond.stage",
"RespondStage",
),
"ResultDecorateStage": (
"astrbot.core.pipeline.result_decorate.stage",
"ResultDecorateStage",
),
"SessionStatusCheckStage": (
"astrbot.core.pipeline.session_status_check.stage",
"SessionStatusCheckStage",
),
"WakingCheckStage": (
"astrbot.core.pipeline.waking_check.stage",
"WakingCheckStage",
),
"WhitelistCheckStage": (
"astrbot.core.pipeline.whitelist_check.stage",
"WhitelistCheckStage",
),
}
__all__ = [ __all__ = [
"ContentSafetyCheckStage", "ContentSafetyCheckStage",
@@ -36,6 +66,21 @@ __all__ = [
"RespondStage", "RespondStage",
"ResultDecorateStage", "ResultDecorateStage",
"SessionStatusCheckStage", "SessionStatusCheckStage",
"STAGES_ORDER",
"WakingCheckStage", "WakingCheckStage",
"WhitelistCheckStage", "WhitelistCheckStage",
] ]
def __getattr__(name: str) -> Any:
if name not in _LAZY_EXPORTS:
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
module_path, attr_name = _LAZY_EXPORTS[name]
module = import_module(module_path)
value = getattr(module, attr_name)
globals()[name] = value
return value
def __dir__() -> list[str]:
return sorted(set(globals()) | set(__all__))
+52
View File
@@ -0,0 +1,52 @@
"""Pipeline bootstrap utilities."""
from importlib import import_module
from .stage import registered_stages
_BUILTIN_STAGE_MODULES = (
"astrbot.core.pipeline.waking_check.stage",
"astrbot.core.pipeline.whitelist_check.stage",
"astrbot.core.pipeline.session_status_check.stage",
"astrbot.core.pipeline.rate_limit_check.stage",
"astrbot.core.pipeline.content_safety_check.stage",
"astrbot.core.pipeline.preprocess_stage.stage",
"astrbot.core.pipeline.process_stage.stage",
"astrbot.core.pipeline.result_decorate.stage",
"astrbot.core.pipeline.respond.stage",
)
_EXPECTED_STAGE_NAMES = {
"WakingCheckStage",
"WhitelistCheckStage",
"SessionStatusCheckStage",
"RateLimitStage",
"ContentSafetyCheckStage",
"PreProcessStage",
"ProcessStage",
"ResultDecorateStage",
"RespondStage",
}
_builtin_stages_registered = False
def ensure_builtin_stages_registered() -> None:
"""Ensure built-in pipeline stages are imported and registered."""
global _builtin_stages_registered
if _builtin_stages_registered:
return
stage_names = {stage_cls.__name__ for stage_cls in registered_stages}
if _EXPECTED_STAGE_NAMES.issubset(stage_names):
_builtin_stages_registered = True
return
for module_path in _BUILTIN_STAGE_MODULES:
import_module(module_path)
_builtin_stages_registered = True
__all__ = ["ensure_builtin_stages_registered"]
+4 -2
View File
@@ -1,7 +1,9 @@
from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any
from astrbot.core.config import AstrBotConfig from astrbot.core.config import AstrBotConfig
from astrbot.core.star import PluginManager
from .context_utils import call_event_hook, call_handler from .context_utils import call_event_hook, call_handler
@@ -11,7 +13,7 @@ class PipelineContext:
"""上下文对象,包含管道执行所需的上下文信息""" """上下文对象,包含管道执行所需的上下文信息"""
astrbot_config: AstrBotConfig # AstrBot 配置对象 astrbot_config: AstrBotConfig # AstrBot 配置对象
plugin_manager: PluginManager # 插件管理器对象 plugin_manager: Any # 插件管理器对象
astrbot_config_id: str astrbot_config_id: str
call_handler = call_handler call_handler = call_handler
call_event_hook = call_event_hook call_event_hook = call_event_hook
@@ -19,6 +19,7 @@ from astrbot.core.message.message_event_result import (
MessageEventResult, MessageEventResult,
ResultContentType, ResultContentType,
) )
from astrbot.core.pipeline.stage import Stage
from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider.entities import ( from astrbot.core.provider.entities import (
LLMResponse, LLMResponse,
@@ -30,7 +31,6 @@ from astrbot.core.utils.session_lock import session_lock_manager
from .....astr_agent_run_util import run_agent, run_live_agent from .....astr_agent_run_util import run_agent, run_live_agent
from ....context import PipelineContext, call_event_hook from ....context import PipelineContext, call_event_hook
from ...stage import Stage
class InternalAgentSubStage(Stage): class InternalAgentSubStage(Stage):
@@ -8,6 +8,7 @@ from astrbot.core.agent.runners.dashscope.dashscope_agent_runner import (
DashscopeAgentRunner, DashscopeAgentRunner,
) )
from astrbot.core.agent.runners.dify.dify_agent_runner import DifyAgentRunner from astrbot.core.agent.runners.dify.dify_agent_runner import DifyAgentRunner
from astrbot.core.astr_agent_hooks import MAIN_AGENT_HOOKS
from astrbot.core.message.components import Image from astrbot.core.message.components import Image
from astrbot.core.message.message_event_result import ( from astrbot.core.message.message_event_result import (
MessageChain, MessageChain,
@@ -17,6 +18,7 @@ from astrbot.core.message.message_event_result import (
if TYPE_CHECKING: if TYPE_CHECKING:
from astrbot.core.agent.runners.base import BaseAgentRunner from astrbot.core.agent.runners.base import BaseAgentRunner
from astrbot.core.pipeline.stage import Stage
from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider.entities import ( from astrbot.core.provider.entities import (
ProviderRequest, ProviderRequest,
@@ -25,9 +27,7 @@ from astrbot.core.star.star_handler import EventType
from astrbot.core.utils.metrics import Metric from astrbot.core.utils.metrics import Metric
from .....astr_agent_context import AgentContextWrapper, AstrAgentContext from .....astr_agent_context import AgentContextWrapper, AstrAgentContext
from .....astr_agent_hooks import MAIN_AGENT_HOOKS
from ....context import PipelineContext, call_event_hook from ....context import PipelineContext, call_event_hook
from ...stage import Stage
AGENT_RUNNER_TYPE_KEY = { AGENT_RUNNER_TYPE_KEY = {
"dify": "dify_agent_runner_provider_id", "dify": "dify_agent_runner_provider_id",
+3 -1
View File
@@ -8,15 +8,17 @@ from astrbot.core.platform.sources.wecom_ai_bot.wecomai_event import (
) )
from astrbot.core.utils.active_event_registry import active_event_registry from astrbot.core.utils.active_event_registry import active_event_registry
from . import STAGES_ORDER from .bootstrap import ensure_builtin_stages_registered
from .context import PipelineContext from .context import PipelineContext
from .stage import registered_stages from .stage import registered_stages
from .stage_order import STAGES_ORDER
class PipelineScheduler: class PipelineScheduler:
"""管道调度器,负责调度各个阶段的执行""" """管道调度器,负责调度各个阶段的执行"""
def __init__(self, context: PipelineContext) -> None: def __init__(self, context: PipelineContext) -> None:
ensure_builtin_stages_registered()
registered_stages.sort( registered_stages.sort(
key=lambda x: STAGES_ORDER.index(x.__name__), key=lambda x: STAGES_ORDER.index(x.__name__),
) # 按照顺序排序 ) # 按照顺序排序
+15
View File
@@ -0,0 +1,15 @@
"""Pipeline stage execution order."""
STAGES_ORDER = [
"WakingCheckStage", # 检查是否需要唤醒
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
"SessionStatusCheckStage", # 检查会话是否整体启用
"RateLimitStage", # 检查会话是否超过频率限制
"ContentSafetyCheckStage", # 检查内容安全
"PreProcessStage", # 预处理
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
"RespondStage", # 发送消息
]
__all__ = ["STAGES_ORDER"]
+13 -62
View File
@@ -1,68 +1,19 @@
from astrbot.core import html_renderer # 兼容导出: Provider 从 provider 模块重新导出
from astrbot.core.provider import Provider from astrbot.core.provider import Provider
from astrbot.core.star.star_tools import StarTools
from astrbot.core.utils.command_parser import CommandParserMixin
from astrbot.core.utils.plugin_kv_store import PluginKVStoreMixin
from .base import Star
from .context import Context from .context import Context
from .star import StarMetadata, star_map, star_registry from .star import StarMetadata, star_map, star_registry
from .star_manager import PluginManager from .star_manager import PluginManager
from .star_tools import StarTools
__all__ = [
class Star(CommandParserMixin, PluginKVStoreMixin): "Context",
"""所有插件(Star)的父类,所有插件都应该继承于这个类""" "PluginManager",
"Provider",
author: str "Star",
name: str "StarMetadata",
"StarTools",
def __init__(self, context: Context, config: dict | None = None) -> None: "star_map",
StarTools.initialize(context) "star_registry",
self.context = context ]
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if not star_map.get(cls.__module__):
metadata = StarMetadata(
star_cls_type=cls,
module_path=cls.__module__,
)
star_map[cls.__module__] = metadata
star_registry.append(metadata)
else:
star_map[cls.__module__].star_cls_type = cls
star_map[cls.__module__].module_path = cls.__module__
async def text_to_image(self, text: str, return_url=True) -> str:
"""将文本转换为图片"""
return await html_renderer.render_t2i(
text,
return_url=return_url,
template_name=self.context._config.get("t2i_active_template"),
)
async def html_render(
self,
tmpl: str,
data: dict,
return_url=True,
options: dict | None = None,
) -> str:
"""渲染 HTML"""
return await html_renderer.render_custom_template(
tmpl,
data,
return_url=return_url,
options=options,
)
async def initialize(self) -> None:
"""当插件被激活时会调用这个方法"""
async def terminate(self) -> None:
"""当插件被禁用、重载插件时会调用这个方法"""
def __del__(self) -> None:
"""[Deprecated] 当插件被禁用、重载插件时会调用这个方法"""
__all__ = ["Context", "PluginManager", "Provider", "Star", "StarMetadata", "StarTools"]
+87
View File
@@ -0,0 +1,87 @@
from __future__ import annotations
import logging
from typing import Any, Protocol
from astrbot.core import html_renderer
from astrbot.core.utils.command_parser import CommandParserMixin
from astrbot.core.utils.plugin_kv_store import PluginKVStoreMixin
from .star import StarMetadata, star_map, star_registry
logger = logging.getLogger("astrbot")
class Star(CommandParserMixin, PluginKVStoreMixin):
"""所有插件(Star)的父类,所有插件都应该继承于这个类"""
author: str
name: str
class _ContextLike(Protocol):
def get_config(self, umo: str | None = None) -> Any: ...
def __init__(self, context: _ContextLike, config: dict | None = None) -> None:
self.context = context
def _get_context_config(self) -> Any:
get_config = getattr(self.context, "get_config", None)
if callable(get_config):
try:
return get_config()
except Exception as e:
logger.debug(f"get_config() failed: {e}")
return None
return getattr(self.context, "_config", None)
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if not star_map.get(cls.__module__):
metadata = StarMetadata(
star_cls_type=cls,
module_path=cls.__module__,
)
star_map[cls.__module__] = metadata
star_registry.append(metadata)
else:
star_map[cls.__module__].star_cls_type = cls
star_map[cls.__module__].module_path = cls.__module__
async def text_to_image(self, text: str, return_url=True) -> str:
"""将文本转换为图片"""
config_obj = self._get_context_config()
template_name = None
if hasattr(config_obj, "get"):
try:
template_name = config_obj.get("t2i_active_template")
except Exception:
template_name = None
return await html_renderer.render_t2i(
text,
return_url=return_url,
template_name=template_name,
)
async def html_render(
self,
tmpl: str,
data: dict,
return_url=True,
options: dict | None = None,
) -> str:
"""渲染 HTML"""
return await html_renderer.render_custom_template(
tmpl,
data,
return_url=return_url,
options=options,
)
async def initialize(self) -> None:
"""当插件被激活时会调用这个方法"""
async def terminate(self) -> None:
"""当插件被禁用、重载插件时会调用这个方法"""
def __del__(self) -> None:
"""[Deprecated] 当插件被禁用、重载插件时会调用这个方法"""
+13 -4
View File
@@ -1,7 +1,9 @@
from __future__ import annotations
import logging import logging
from asyncio import Queue from asyncio import Queue
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from typing import Any from typing import TYPE_CHECKING, Any, Protocol
from deprecated import deprecated from deprecated import deprecated
@@ -12,14 +14,12 @@ from astrbot.core.agent.tool import ToolSet
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.conversation_mgr import ConversationManager from astrbot.core.conversation_mgr import ConversationManager
from astrbot.core.cron.manager import CronJobManager
from astrbot.core.db import BaseDatabase from astrbot.core.db import BaseDatabase
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
from astrbot.core.message.message_event_result import MessageChain from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.persona_mgr import PersonaManager from astrbot.core.persona_mgr import PersonaManager
from astrbot.core.platform import Platform from astrbot.core.platform import Platform
from astrbot.core.platform.astr_message_event import AstrMessageEvent, MessageSesion from astrbot.core.platform.astr_message_event import AstrMessageEvent, MessageSesion
from astrbot.core.platform.manager import PlatformManager
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
from astrbot.core.provider.entities import LLMResponse, ProviderRequest, ProviderType from astrbot.core.provider.entities import LLMResponse, ProviderRequest, ProviderType
from astrbot.core.provider.func_tool_manager import FunctionTool, FunctionToolManager from astrbot.core.provider.func_tool_manager import FunctionTool, FunctionToolManager
@@ -45,6 +45,15 @@ from .star_handler import EventType, StarHandlerMetadata, star_handlers_registry
logger = logging.getLogger("astrbot") logger = logging.getLogger("astrbot")
if TYPE_CHECKING:
from astrbot.core.cron.manager import CronJobManager
else:
CronJobManager = Any
class PlatformManagerProtocol(Protocol):
platform_insts: list[Platform]
class Context: class Context:
"""暴露给插件的接口上下文。""" """暴露给插件的接口上下文。"""
@@ -61,7 +70,7 @@ class Context:
config: AstrBotConfig, config: AstrBotConfig,
db: BaseDatabase, db: BaseDatabase,
provider_manager: ProviderManager, provider_manager: ProviderManager,
platform_manager: PlatformManager, platform_manager: PlatformManagerProtocol,
conversation_manager: ConversationManager, conversation_manager: ConversationManager,
message_history_manager: PlatformMessageHistoryManager, message_history_manager: PlatformMessageHistoryManager,
persona_manager: PersonaManager, persona_manager: PersonaManager,
+1 -1
View File
@@ -1,6 +1,6 @@
import warnings import warnings
from astrbot.core.star import StarMetadata, star_map from astrbot.core.star.star import StarMetadata, star_map
_warned_register_star = False _warned_register_star = False
+4 -5
View File
@@ -11,7 +11,6 @@ from astrbot.core.agent.agent import Agent
from astrbot.core.agent.handoff import HandoffTool from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.agent.hooks import BaseAgentRunHooks from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.tool import FunctionTool from astrbot.core.agent.tool import FunctionTool
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.message.message_event_result import MessageEventResult from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES
from astrbot.core.provider.register import llm_tools from astrbot.core.provider.register import llm_tools
@@ -617,7 +616,7 @@ class RegisteringAgent:
kwargs["registering_agent"] = self kwargs["registering_agent"] = self
return register_llm_tool(*args, **kwargs) return register_llm_tool(*args, **kwargs)
def __init__(self, agent: Agent[AstrAgentContext]) -> None: def __init__(self, agent: Agent[Any]) -> None:
self._agent = agent self._agent = agent
@@ -625,7 +624,7 @@ def register_agent(
name: str, name: str,
instruction: str, instruction: str,
tools: list[str | FunctionTool] | None = None, tools: list[str | FunctionTool] | None = None,
run_hooks: BaseAgentRunHooks[AstrAgentContext] | None = None, run_hooks: BaseAgentRunHooks[Any] | None = None,
): ):
"""注册一个 Agent """注册一个 Agent
@@ -639,12 +638,12 @@ def register_agent(
tools_ = tools or [] tools_ = tools or []
def decorator(awaitable: Callable[..., Awaitable[Any]]): def decorator(awaitable: Callable[..., Awaitable[Any]]):
AstrAgent = Agent[AstrAgentContext] AstrAgent = Agent[Any]
agent = AstrAgent( agent = AstrAgent(
name=name, name=name,
instructions=instruction, instructions=instruction,
tools=tools_, tools=tools_,
run_hooks=run_hooks or BaseAgentRunHooks[AstrAgentContext](), run_hooks=run_hooks or BaseAgentRunHooks[Any](),
) )
handoff_tool = HandoffTool(agent=agent) handoff_tool = HandoffTool(agent=agent)
handoff_tool.handler = awaitable handoff_tool.handler = awaitable
+3
View File
@@ -49,10 +49,13 @@ class PluginVersionIncompatibleError(Exception):
class PluginManager: class PluginManager:
def __init__(self, context: Context, config: AstrBotConfig) -> None: def __init__(self, context: Context, config: AstrBotConfig) -> None:
from .star_tools import StarTools
self.updator = PluginUpdator() self.updator = PluginUpdator()
self.context = context self.context = context
self.context._star_manager = self # type: ignore self.context._star_manager = self # type: ignore
StarTools.initialize(context)
self.config = config self.config = config
self.plugin_store_path = get_astrbot_plugin_path() self.plugin_store_path = get_astrbot_plugin_path()