fix: resolve pipeline and star import cycles (#5353)

* fix: resolve pipeline and star import cycles

- Add bootstrap.py and stage_order.py to break circular dependencies
- Export Context, PluginManager, StarTools from star module
- Update pipeline __init__ to defer imports
- Split pipeline initialization into separate bootstrap module

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix: add logging for get_config() failure in Star class

* fix: reorder logger initialization in base.py

---------

Co-authored-by: whatevertogo <whatevertogo@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
whatevertogo
2026-02-24 13:53:29 +08:00
committed by GitHub
parent 80fd51119b
commit 9294b44831
13 changed files with 264 additions and 99 deletions
+66 -21
View File
@@ -1,30 +1,60 @@
"""Pipeline package exports.
This module intentionally avoids eager imports of all pipeline stage modules to
prevent import-time cycles. Stage classes remain available via lazy attribute
resolution for backward compatibility.
"""
from __future__ import annotations
from importlib import import_module
from typing import Any
from astrbot.core.message.message_event_result import (
EventResultType,
MessageEventResult,
)
from .content_safety_check.stage import ContentSafetyCheckStage
from .preprocess_stage.stage import PreProcessStage
from .process_stage.stage import ProcessStage
from .rate_limit_check.stage import RateLimitStage
from .respond.stage import RespondStage
from .result_decorate.stage import ResultDecorateStage
from .session_status_check.stage import SessionStatusCheckStage
from .waking_check.stage import WakingCheckStage
from .whitelist_check.stage import WhitelistCheckStage
from .stage_order import STAGES_ORDER
# 管道阶段顺序
STAGES_ORDER = [
"WakingCheckStage", # 检查是否需要唤醒
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
"SessionStatusCheckStage", # 检查会话是否整体启用
"RateLimitStage", # 检查会话是否超过频率限制
"ContentSafetyCheckStage", # 检查内容安全
"PreProcessStage", # 预处理
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
"RespondStage", # 发送消息
]
_LAZY_EXPORTS = {
"ContentSafetyCheckStage": (
"astrbot.core.pipeline.content_safety_check.stage",
"ContentSafetyCheckStage",
),
"PreProcessStage": (
"astrbot.core.pipeline.preprocess_stage.stage",
"PreProcessStage",
),
"ProcessStage": (
"astrbot.core.pipeline.process_stage.stage",
"ProcessStage",
),
"RateLimitStage": (
"astrbot.core.pipeline.rate_limit_check.stage",
"RateLimitStage",
),
"RespondStage": (
"astrbot.core.pipeline.respond.stage",
"RespondStage",
),
"ResultDecorateStage": (
"astrbot.core.pipeline.result_decorate.stage",
"ResultDecorateStage",
),
"SessionStatusCheckStage": (
"astrbot.core.pipeline.session_status_check.stage",
"SessionStatusCheckStage",
),
"WakingCheckStage": (
"astrbot.core.pipeline.waking_check.stage",
"WakingCheckStage",
),
"WhitelistCheckStage": (
"astrbot.core.pipeline.whitelist_check.stage",
"WhitelistCheckStage",
),
}
__all__ = [
"ContentSafetyCheckStage",
@@ -36,6 +66,21 @@ __all__ = [
"RespondStage",
"ResultDecorateStage",
"SessionStatusCheckStage",
"STAGES_ORDER",
"WakingCheckStage",
"WhitelistCheckStage",
]
def __getattr__(name: str) -> Any:
if name not in _LAZY_EXPORTS:
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
module_path, attr_name = _LAZY_EXPORTS[name]
module = import_module(module_path)
value = getattr(module, attr_name)
globals()[name] = value
return value
def __dir__() -> list[str]:
return sorted(set(globals()) | set(__all__))
+52
View File
@@ -0,0 +1,52 @@
"""Pipeline bootstrap utilities."""
from importlib import import_module
from .stage import registered_stages
_BUILTIN_STAGE_MODULES = (
"astrbot.core.pipeline.waking_check.stage",
"astrbot.core.pipeline.whitelist_check.stage",
"astrbot.core.pipeline.session_status_check.stage",
"astrbot.core.pipeline.rate_limit_check.stage",
"astrbot.core.pipeline.content_safety_check.stage",
"astrbot.core.pipeline.preprocess_stage.stage",
"astrbot.core.pipeline.process_stage.stage",
"astrbot.core.pipeline.result_decorate.stage",
"astrbot.core.pipeline.respond.stage",
)
_EXPECTED_STAGE_NAMES = {
"WakingCheckStage",
"WhitelistCheckStage",
"SessionStatusCheckStage",
"RateLimitStage",
"ContentSafetyCheckStage",
"PreProcessStage",
"ProcessStage",
"ResultDecorateStage",
"RespondStage",
}
_builtin_stages_registered = False
def ensure_builtin_stages_registered() -> None:
"""Ensure built-in pipeline stages are imported and registered."""
global _builtin_stages_registered
if _builtin_stages_registered:
return
stage_names = {stage_cls.__name__ for stage_cls in registered_stages}
if _EXPECTED_STAGE_NAMES.issubset(stage_names):
_builtin_stages_registered = True
return
for module_path in _BUILTIN_STAGE_MODULES:
import_module(module_path)
_builtin_stages_registered = True
__all__ = ["ensure_builtin_stages_registered"]
+4 -2
View File
@@ -1,7 +1,9 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from astrbot.core.config import AstrBotConfig
from astrbot.core.star import PluginManager
from .context_utils import call_event_hook, call_handler
@@ -11,7 +13,7 @@ class PipelineContext:
"""上下文对象,包含管道执行所需的上下文信息"""
astrbot_config: AstrBotConfig # AstrBot 配置对象
plugin_manager: PluginManager # 插件管理器对象
plugin_manager: Any # 插件管理器对象
astrbot_config_id: str
call_handler = call_handler
call_event_hook = call_event_hook
@@ -19,6 +19,7 @@ from astrbot.core.message.message_event_result import (
MessageEventResult,
ResultContentType,
)
from astrbot.core.pipeline.stage import Stage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider.entities import (
LLMResponse,
@@ -30,7 +31,6 @@ from astrbot.core.utils.session_lock import session_lock_manager
from .....astr_agent_run_util import run_agent, run_live_agent
from ....context import PipelineContext, call_event_hook
from ...stage import Stage
class InternalAgentSubStage(Stage):
@@ -8,6 +8,7 @@ from astrbot.core.agent.runners.dashscope.dashscope_agent_runner import (
DashscopeAgentRunner,
)
from astrbot.core.agent.runners.dify.dify_agent_runner import DifyAgentRunner
from astrbot.core.astr_agent_hooks import MAIN_AGENT_HOOKS
from astrbot.core.message.components import Image
from astrbot.core.message.message_event_result import (
MessageChain,
@@ -17,6 +18,7 @@ from astrbot.core.message.message_event_result import (
if TYPE_CHECKING:
from astrbot.core.agent.runners.base import BaseAgentRunner
from astrbot.core.pipeline.stage import Stage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider.entities import (
ProviderRequest,
@@ -25,9 +27,7 @@ from astrbot.core.star.star_handler import EventType
from astrbot.core.utils.metrics import Metric
from .....astr_agent_context import AgentContextWrapper, AstrAgentContext
from .....astr_agent_hooks import MAIN_AGENT_HOOKS
from ....context import PipelineContext, call_event_hook
from ...stage import Stage
AGENT_RUNNER_TYPE_KEY = {
"dify": "dify_agent_runner_provider_id",
+3 -1
View File
@@ -8,15 +8,17 @@ from astrbot.core.platform.sources.wecom_ai_bot.wecomai_event import (
)
from astrbot.core.utils.active_event_registry import active_event_registry
from . import STAGES_ORDER
from .bootstrap import ensure_builtin_stages_registered
from .context import PipelineContext
from .stage import registered_stages
from .stage_order import STAGES_ORDER
class PipelineScheduler:
"""管道调度器,负责调度各个阶段的执行"""
def __init__(self, context: PipelineContext) -> None:
ensure_builtin_stages_registered()
registered_stages.sort(
key=lambda x: STAGES_ORDER.index(x.__name__),
) # 按照顺序排序
+15
View File
@@ -0,0 +1,15 @@
"""Pipeline stage execution order."""
STAGES_ORDER = [
"WakingCheckStage", # 检查是否需要唤醒
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
"SessionStatusCheckStage", # 检查会话是否整体启用
"RateLimitStage", # 检查会话是否超过频率限制
"ContentSafetyCheckStage", # 检查内容安全
"PreProcessStage", # 预处理
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
"RespondStage", # 发送消息
]
__all__ = ["STAGES_ORDER"]
+13 -62
View File
@@ -1,68 +1,19 @@
from astrbot.core import html_renderer
# 兼容导出: Provider 从 provider 模块重新导出
from astrbot.core.provider import Provider
from astrbot.core.star.star_tools import StarTools
from astrbot.core.utils.command_parser import CommandParserMixin
from astrbot.core.utils.plugin_kv_store import PluginKVStoreMixin
from .base import Star
from .context import Context
from .star import StarMetadata, star_map, star_registry
from .star_manager import PluginManager
from .star_tools import StarTools
class Star(CommandParserMixin, PluginKVStoreMixin):
"""所有插件(Star)的父类,所有插件都应该继承于这个类"""
author: str
name: str
def __init__(self, context: Context, config: dict | None = None) -> None:
StarTools.initialize(context)
self.context = context
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if not star_map.get(cls.__module__):
metadata = StarMetadata(
star_cls_type=cls,
module_path=cls.__module__,
)
star_map[cls.__module__] = metadata
star_registry.append(metadata)
else:
star_map[cls.__module__].star_cls_type = cls
star_map[cls.__module__].module_path = cls.__module__
async def text_to_image(self, text: str, return_url=True) -> str:
"""将文本转换为图片"""
return await html_renderer.render_t2i(
text,
return_url=return_url,
template_name=self.context._config.get("t2i_active_template"),
)
async def html_render(
self,
tmpl: str,
data: dict,
return_url=True,
options: dict | None = None,
) -> str:
"""渲染 HTML"""
return await html_renderer.render_custom_template(
tmpl,
data,
return_url=return_url,
options=options,
)
async def initialize(self) -> None:
"""当插件被激活时会调用这个方法"""
async def terminate(self) -> None:
"""当插件被禁用、重载插件时会调用这个方法"""
def __del__(self) -> None:
"""[Deprecated] 当插件被禁用、重载插件时会调用这个方法"""
__all__ = ["Context", "PluginManager", "Provider", "Star", "StarMetadata", "StarTools"]
__all__ = [
"Context",
"PluginManager",
"Provider",
"Star",
"StarMetadata",
"StarTools",
"star_map",
"star_registry",
]
+87
View File
@@ -0,0 +1,87 @@
from __future__ import annotations
import logging
from typing import Any, Protocol
from astrbot.core import html_renderer
from astrbot.core.utils.command_parser import CommandParserMixin
from astrbot.core.utils.plugin_kv_store import PluginKVStoreMixin
from .star import StarMetadata, star_map, star_registry
logger = logging.getLogger("astrbot")
class Star(CommandParserMixin, PluginKVStoreMixin):
"""所有插件(Star)的父类,所有插件都应该继承于这个类"""
author: str
name: str
class _ContextLike(Protocol):
def get_config(self, umo: str | None = None) -> Any: ...
def __init__(self, context: _ContextLike, config: dict | None = None) -> None:
self.context = context
def _get_context_config(self) -> Any:
get_config = getattr(self.context, "get_config", None)
if callable(get_config):
try:
return get_config()
except Exception as e:
logger.debug(f"get_config() failed: {e}")
return None
return getattr(self.context, "_config", None)
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if not star_map.get(cls.__module__):
metadata = StarMetadata(
star_cls_type=cls,
module_path=cls.__module__,
)
star_map[cls.__module__] = metadata
star_registry.append(metadata)
else:
star_map[cls.__module__].star_cls_type = cls
star_map[cls.__module__].module_path = cls.__module__
async def text_to_image(self, text: str, return_url=True) -> str:
"""将文本转换为图片"""
config_obj = self._get_context_config()
template_name = None
if hasattr(config_obj, "get"):
try:
template_name = config_obj.get("t2i_active_template")
except Exception:
template_name = None
return await html_renderer.render_t2i(
text,
return_url=return_url,
template_name=template_name,
)
async def html_render(
self,
tmpl: str,
data: dict,
return_url=True,
options: dict | None = None,
) -> str:
"""渲染 HTML"""
return await html_renderer.render_custom_template(
tmpl,
data,
return_url=return_url,
options=options,
)
async def initialize(self) -> None:
"""当插件被激活时会调用这个方法"""
async def terminate(self) -> None:
"""当插件被禁用、重载插件时会调用这个方法"""
def __del__(self) -> None:
"""[Deprecated] 当插件被禁用、重载插件时会调用这个方法"""
+13 -4
View File
@@ -1,7 +1,9 @@
from __future__ import annotations
import logging
from asyncio import Queue
from collections.abc import Awaitable, Callable
from typing import Any
from typing import TYPE_CHECKING, Any, Protocol
from deprecated import deprecated
@@ -12,14 +14,12 @@ from astrbot.core.agent.tool import ToolSet
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.conversation_mgr import ConversationManager
from astrbot.core.cron.manager import CronJobManager
from astrbot.core.db import BaseDatabase
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.persona_mgr import PersonaManager
from astrbot.core.platform import Platform
from astrbot.core.platform.astr_message_event import AstrMessageEvent, MessageSesion
from astrbot.core.platform.manager import PlatformManager
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
from astrbot.core.provider.entities import LLMResponse, ProviderRequest, ProviderType
from astrbot.core.provider.func_tool_manager import FunctionTool, FunctionToolManager
@@ -45,6 +45,15 @@ from .star_handler import EventType, StarHandlerMetadata, star_handlers_registry
logger = logging.getLogger("astrbot")
if TYPE_CHECKING:
from astrbot.core.cron.manager import CronJobManager
else:
CronJobManager = Any
class PlatformManagerProtocol(Protocol):
platform_insts: list[Platform]
class Context:
"""暴露给插件的接口上下文。"""
@@ -61,7 +70,7 @@ class Context:
config: AstrBotConfig,
db: BaseDatabase,
provider_manager: ProviderManager,
platform_manager: PlatformManager,
platform_manager: PlatformManagerProtocol,
conversation_manager: ConversationManager,
message_history_manager: PlatformMessageHistoryManager,
persona_manager: PersonaManager,
+1 -1
View File
@@ -1,6 +1,6 @@
import warnings
from astrbot.core.star import StarMetadata, star_map
from astrbot.core.star.star import StarMetadata, star_map
_warned_register_star = False
+4 -5
View File
@@ -11,7 +11,6 @@ from astrbot.core.agent.agent import Agent
from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.tool import FunctionTool
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES
from astrbot.core.provider.register import llm_tools
@@ -617,7 +616,7 @@ class RegisteringAgent:
kwargs["registering_agent"] = self
return register_llm_tool(*args, **kwargs)
def __init__(self, agent: Agent[AstrAgentContext]) -> None:
def __init__(self, agent: Agent[Any]) -> None:
self._agent = agent
@@ -625,7 +624,7 @@ def register_agent(
name: str,
instruction: str,
tools: list[str | FunctionTool] | None = None,
run_hooks: BaseAgentRunHooks[AstrAgentContext] | None = None,
run_hooks: BaseAgentRunHooks[Any] | None = None,
):
"""注册一个 Agent
@@ -639,12 +638,12 @@ def register_agent(
tools_ = tools or []
def decorator(awaitable: Callable[..., Awaitable[Any]]):
AstrAgent = Agent[AstrAgentContext]
AstrAgent = Agent[Any]
agent = AstrAgent(
name=name,
instructions=instruction,
tools=tools_,
run_hooks=run_hooks or BaseAgentRunHooks[AstrAgentContext](),
run_hooks=run_hooks or BaseAgentRunHooks[Any](),
)
handoff_tool = HandoffTool(agent=agent)
handoff_tool.handler = awaitable
+3
View File
@@ -49,10 +49,13 @@ class PluginVersionIncompatibleError(Exception):
class PluginManager:
def __init__(self, context: Context, config: AstrBotConfig) -> None:
from .star_tools import StarTools
self.updator = PluginUpdator()
self.context = context
self.context._star_manager = self # type: ignore
StarTools.initialize(context)
self.config = config
self.plugin_store_path = get_astrbot_plugin_path()