fix: resolve pipeline and star import cycles (#5353)
* fix: resolve pipeline and star import cycles - Add bootstrap.py and stage_order.py to break circular dependencies - Export Context, PluginManager, StarTools from star module - Update pipeline __init__ to defer imports - Split pipeline initialization into separate bootstrap module Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix: add logging for get_config() failure in Star class * fix: reorder logger initialization in base.py --------- Co-authored-by: whatevertogo <whatevertogo@users.noreply.github.com> Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,30 +1,60 @@
|
|||||||
|
"""Pipeline package exports.
|
||||||
|
|
||||||
|
This module intentionally avoids eager imports of all pipeline stage modules to
|
||||||
|
prevent import-time cycles. Stage classes remain available via lazy attribute
|
||||||
|
resolution for backward compatibility.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from importlib import import_module
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from astrbot.core.message.message_event_result import (
|
from astrbot.core.message.message_event_result import (
|
||||||
EventResultType,
|
EventResultType,
|
||||||
MessageEventResult,
|
MessageEventResult,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .content_safety_check.stage import ContentSafetyCheckStage
|
from .stage_order import STAGES_ORDER
|
||||||
from .preprocess_stage.stage import PreProcessStage
|
|
||||||
from .process_stage.stage import ProcessStage
|
|
||||||
from .rate_limit_check.stage import RateLimitStage
|
|
||||||
from .respond.stage import RespondStage
|
|
||||||
from .result_decorate.stage import ResultDecorateStage
|
|
||||||
from .session_status_check.stage import SessionStatusCheckStage
|
|
||||||
from .waking_check.stage import WakingCheckStage
|
|
||||||
from .whitelist_check.stage import WhitelistCheckStage
|
|
||||||
|
|
||||||
# 管道阶段顺序
|
_LAZY_EXPORTS = {
|
||||||
STAGES_ORDER = [
|
"ContentSafetyCheckStage": (
|
||||||
"WakingCheckStage", # 检查是否需要唤醒
|
"astrbot.core.pipeline.content_safety_check.stage",
|
||||||
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
|
"ContentSafetyCheckStage",
|
||||||
"SessionStatusCheckStage", # 检查会话是否整体启用
|
),
|
||||||
"RateLimitStage", # 检查会话是否超过频率限制
|
"PreProcessStage": (
|
||||||
"ContentSafetyCheckStage", # 检查内容安全
|
"astrbot.core.pipeline.preprocess_stage.stage",
|
||||||
"PreProcessStage", # 预处理
|
"PreProcessStage",
|
||||||
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
|
),
|
||||||
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
|
"ProcessStage": (
|
||||||
"RespondStage", # 发送消息
|
"astrbot.core.pipeline.process_stage.stage",
|
||||||
]
|
"ProcessStage",
|
||||||
|
),
|
||||||
|
"RateLimitStage": (
|
||||||
|
"astrbot.core.pipeline.rate_limit_check.stage",
|
||||||
|
"RateLimitStage",
|
||||||
|
),
|
||||||
|
"RespondStage": (
|
||||||
|
"astrbot.core.pipeline.respond.stage",
|
||||||
|
"RespondStage",
|
||||||
|
),
|
||||||
|
"ResultDecorateStage": (
|
||||||
|
"astrbot.core.pipeline.result_decorate.stage",
|
||||||
|
"ResultDecorateStage",
|
||||||
|
),
|
||||||
|
"SessionStatusCheckStage": (
|
||||||
|
"astrbot.core.pipeline.session_status_check.stage",
|
||||||
|
"SessionStatusCheckStage",
|
||||||
|
),
|
||||||
|
"WakingCheckStage": (
|
||||||
|
"astrbot.core.pipeline.waking_check.stage",
|
||||||
|
"WakingCheckStage",
|
||||||
|
),
|
||||||
|
"WhitelistCheckStage": (
|
||||||
|
"astrbot.core.pipeline.whitelist_check.stage",
|
||||||
|
"WhitelistCheckStage",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ContentSafetyCheckStage",
|
"ContentSafetyCheckStage",
|
||||||
@@ -36,6 +66,21 @@ __all__ = [
|
|||||||
"RespondStage",
|
"RespondStage",
|
||||||
"ResultDecorateStage",
|
"ResultDecorateStage",
|
||||||
"SessionStatusCheckStage",
|
"SessionStatusCheckStage",
|
||||||
|
"STAGES_ORDER",
|
||||||
"WakingCheckStage",
|
"WakingCheckStage",
|
||||||
"WhitelistCheckStage",
|
"WhitelistCheckStage",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(name: str) -> Any:
|
||||||
|
if name not in _LAZY_EXPORTS:
|
||||||
|
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||||
|
module_path, attr_name = _LAZY_EXPORTS[name]
|
||||||
|
module = import_module(module_path)
|
||||||
|
value = getattr(module, attr_name)
|
||||||
|
globals()[name] = value
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def __dir__() -> list[str]:
|
||||||
|
return sorted(set(globals()) | set(__all__))
|
||||||
|
|||||||
@@ -0,0 +1,52 @@
|
|||||||
|
"""Pipeline bootstrap utilities."""
|
||||||
|
|
||||||
|
from importlib import import_module
|
||||||
|
|
||||||
|
from .stage import registered_stages
|
||||||
|
|
||||||
|
_BUILTIN_STAGE_MODULES = (
|
||||||
|
"astrbot.core.pipeline.waking_check.stage",
|
||||||
|
"astrbot.core.pipeline.whitelist_check.stage",
|
||||||
|
"astrbot.core.pipeline.session_status_check.stage",
|
||||||
|
"astrbot.core.pipeline.rate_limit_check.stage",
|
||||||
|
"astrbot.core.pipeline.content_safety_check.stage",
|
||||||
|
"astrbot.core.pipeline.preprocess_stage.stage",
|
||||||
|
"astrbot.core.pipeline.process_stage.stage",
|
||||||
|
"astrbot.core.pipeline.result_decorate.stage",
|
||||||
|
"astrbot.core.pipeline.respond.stage",
|
||||||
|
)
|
||||||
|
|
||||||
|
_EXPECTED_STAGE_NAMES = {
|
||||||
|
"WakingCheckStage",
|
||||||
|
"WhitelistCheckStage",
|
||||||
|
"SessionStatusCheckStage",
|
||||||
|
"RateLimitStage",
|
||||||
|
"ContentSafetyCheckStage",
|
||||||
|
"PreProcessStage",
|
||||||
|
"ProcessStage",
|
||||||
|
"ResultDecorateStage",
|
||||||
|
"RespondStage",
|
||||||
|
}
|
||||||
|
|
||||||
|
_builtin_stages_registered = False
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_builtin_stages_registered() -> None:
|
||||||
|
"""Ensure built-in pipeline stages are imported and registered."""
|
||||||
|
global _builtin_stages_registered
|
||||||
|
|
||||||
|
if _builtin_stages_registered:
|
||||||
|
return
|
||||||
|
|
||||||
|
stage_names = {stage_cls.__name__ for stage_cls in registered_stages}
|
||||||
|
if _EXPECTED_STAGE_NAMES.issubset(stage_names):
|
||||||
|
_builtin_stages_registered = True
|
||||||
|
return
|
||||||
|
|
||||||
|
for module_path in _BUILTIN_STAGE_MODULES:
|
||||||
|
import_module(module_path)
|
||||||
|
|
||||||
|
_builtin_stages_registered = True
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["ensure_builtin_stages_registered"]
|
||||||
@@ -1,7 +1,9 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from astrbot.core.config import AstrBotConfig
|
from astrbot.core.config import AstrBotConfig
|
||||||
from astrbot.core.star import PluginManager
|
|
||||||
|
|
||||||
from .context_utils import call_event_hook, call_handler
|
from .context_utils import call_event_hook, call_handler
|
||||||
|
|
||||||
@@ -11,7 +13,7 @@ class PipelineContext:
|
|||||||
"""上下文对象,包含管道执行所需的上下文信息"""
|
"""上下文对象,包含管道执行所需的上下文信息"""
|
||||||
|
|
||||||
astrbot_config: AstrBotConfig # AstrBot 配置对象
|
astrbot_config: AstrBotConfig # AstrBot 配置对象
|
||||||
plugin_manager: PluginManager # 插件管理器对象
|
plugin_manager: Any # 插件管理器对象
|
||||||
astrbot_config_id: str
|
astrbot_config_id: str
|
||||||
call_handler = call_handler
|
call_handler = call_handler
|
||||||
call_event_hook = call_event_hook
|
call_event_hook = call_event_hook
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from astrbot.core.message.message_event_result import (
|
|||||||
MessageEventResult,
|
MessageEventResult,
|
||||||
ResultContentType,
|
ResultContentType,
|
||||||
)
|
)
|
||||||
|
from astrbot.core.pipeline.stage import Stage
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
from astrbot.core.provider.entities import (
|
from astrbot.core.provider.entities import (
|
||||||
LLMResponse,
|
LLMResponse,
|
||||||
@@ -30,7 +31,6 @@ from astrbot.core.utils.session_lock import session_lock_manager
|
|||||||
|
|
||||||
from .....astr_agent_run_util import run_agent, run_live_agent
|
from .....astr_agent_run_util import run_agent, run_live_agent
|
||||||
from ....context import PipelineContext, call_event_hook
|
from ....context import PipelineContext, call_event_hook
|
||||||
from ...stage import Stage
|
|
||||||
|
|
||||||
|
|
||||||
class InternalAgentSubStage(Stage):
|
class InternalAgentSubStage(Stage):
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from astrbot.core.agent.runners.dashscope.dashscope_agent_runner import (
|
|||||||
DashscopeAgentRunner,
|
DashscopeAgentRunner,
|
||||||
)
|
)
|
||||||
from astrbot.core.agent.runners.dify.dify_agent_runner import DifyAgentRunner
|
from astrbot.core.agent.runners.dify.dify_agent_runner import DifyAgentRunner
|
||||||
|
from astrbot.core.astr_agent_hooks import MAIN_AGENT_HOOKS
|
||||||
from astrbot.core.message.components import Image
|
from astrbot.core.message.components import Image
|
||||||
from astrbot.core.message.message_event_result import (
|
from astrbot.core.message.message_event_result import (
|
||||||
MessageChain,
|
MessageChain,
|
||||||
@@ -17,6 +18,7 @@ from astrbot.core.message.message_event_result import (
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from astrbot.core.agent.runners.base import BaseAgentRunner
|
from astrbot.core.agent.runners.base import BaseAgentRunner
|
||||||
|
from astrbot.core.pipeline.stage import Stage
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
from astrbot.core.provider.entities import (
|
from astrbot.core.provider.entities import (
|
||||||
ProviderRequest,
|
ProviderRequest,
|
||||||
@@ -25,9 +27,7 @@ from astrbot.core.star.star_handler import EventType
|
|||||||
from astrbot.core.utils.metrics import Metric
|
from astrbot.core.utils.metrics import Metric
|
||||||
|
|
||||||
from .....astr_agent_context import AgentContextWrapper, AstrAgentContext
|
from .....astr_agent_context import AgentContextWrapper, AstrAgentContext
|
||||||
from .....astr_agent_hooks import MAIN_AGENT_HOOKS
|
|
||||||
from ....context import PipelineContext, call_event_hook
|
from ....context import PipelineContext, call_event_hook
|
||||||
from ...stage import Stage
|
|
||||||
|
|
||||||
AGENT_RUNNER_TYPE_KEY = {
|
AGENT_RUNNER_TYPE_KEY = {
|
||||||
"dify": "dify_agent_runner_provider_id",
|
"dify": "dify_agent_runner_provider_id",
|
||||||
|
|||||||
@@ -8,15 +8,17 @@ from astrbot.core.platform.sources.wecom_ai_bot.wecomai_event import (
|
|||||||
)
|
)
|
||||||
from astrbot.core.utils.active_event_registry import active_event_registry
|
from astrbot.core.utils.active_event_registry import active_event_registry
|
||||||
|
|
||||||
from . import STAGES_ORDER
|
from .bootstrap import ensure_builtin_stages_registered
|
||||||
from .context import PipelineContext
|
from .context import PipelineContext
|
||||||
from .stage import registered_stages
|
from .stage import registered_stages
|
||||||
|
from .stage_order import STAGES_ORDER
|
||||||
|
|
||||||
|
|
||||||
class PipelineScheduler:
|
class PipelineScheduler:
|
||||||
"""管道调度器,负责调度各个阶段的执行"""
|
"""管道调度器,负责调度各个阶段的执行"""
|
||||||
|
|
||||||
def __init__(self, context: PipelineContext) -> None:
|
def __init__(self, context: PipelineContext) -> None:
|
||||||
|
ensure_builtin_stages_registered()
|
||||||
registered_stages.sort(
|
registered_stages.sort(
|
||||||
key=lambda x: STAGES_ORDER.index(x.__name__),
|
key=lambda x: STAGES_ORDER.index(x.__name__),
|
||||||
) # 按照顺序排序
|
) # 按照顺序排序
|
||||||
|
|||||||
@@ -0,0 +1,15 @@
|
|||||||
|
"""Pipeline stage execution order."""
|
||||||
|
|
||||||
|
STAGES_ORDER = [
|
||||||
|
"WakingCheckStage", # 检查是否需要唤醒
|
||||||
|
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
|
||||||
|
"SessionStatusCheckStage", # 检查会话是否整体启用
|
||||||
|
"RateLimitStage", # 检查会话是否超过频率限制
|
||||||
|
"ContentSafetyCheckStage", # 检查内容安全
|
||||||
|
"PreProcessStage", # 预处理
|
||||||
|
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
|
||||||
|
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
|
||||||
|
"RespondStage", # 发送消息
|
||||||
|
]
|
||||||
|
|
||||||
|
__all__ = ["STAGES_ORDER"]
|
||||||
@@ -1,68 +1,19 @@
|
|||||||
from astrbot.core import html_renderer
|
# 兼容导出: Provider 从 provider 模块重新导出
|
||||||
from astrbot.core.provider import Provider
|
from astrbot.core.provider import Provider
|
||||||
from astrbot.core.star.star_tools import StarTools
|
|
||||||
from astrbot.core.utils.command_parser import CommandParserMixin
|
|
||||||
from astrbot.core.utils.plugin_kv_store import PluginKVStoreMixin
|
|
||||||
|
|
||||||
|
from .base import Star
|
||||||
from .context import Context
|
from .context import Context
|
||||||
from .star import StarMetadata, star_map, star_registry
|
from .star import StarMetadata, star_map, star_registry
|
||||||
from .star_manager import PluginManager
|
from .star_manager import PluginManager
|
||||||
|
from .star_tools import StarTools
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
class Star(CommandParserMixin, PluginKVStoreMixin):
|
"Context",
|
||||||
"""所有插件(Star)的父类,所有插件都应该继承于这个类"""
|
"PluginManager",
|
||||||
|
"Provider",
|
||||||
author: str
|
"Star",
|
||||||
name: str
|
"StarMetadata",
|
||||||
|
"StarTools",
|
||||||
def __init__(self, context: Context, config: dict | None = None) -> None:
|
"star_map",
|
||||||
StarTools.initialize(context)
|
"star_registry",
|
||||||
self.context = context
|
]
|
||||||
|
|
||||||
def __init_subclass__(cls, **kwargs):
|
|
||||||
super().__init_subclass__(**kwargs)
|
|
||||||
if not star_map.get(cls.__module__):
|
|
||||||
metadata = StarMetadata(
|
|
||||||
star_cls_type=cls,
|
|
||||||
module_path=cls.__module__,
|
|
||||||
)
|
|
||||||
star_map[cls.__module__] = metadata
|
|
||||||
star_registry.append(metadata)
|
|
||||||
else:
|
|
||||||
star_map[cls.__module__].star_cls_type = cls
|
|
||||||
star_map[cls.__module__].module_path = cls.__module__
|
|
||||||
|
|
||||||
async def text_to_image(self, text: str, return_url=True) -> str:
|
|
||||||
"""将文本转换为图片"""
|
|
||||||
return await html_renderer.render_t2i(
|
|
||||||
text,
|
|
||||||
return_url=return_url,
|
|
||||||
template_name=self.context._config.get("t2i_active_template"),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def html_render(
|
|
||||||
self,
|
|
||||||
tmpl: str,
|
|
||||||
data: dict,
|
|
||||||
return_url=True,
|
|
||||||
options: dict | None = None,
|
|
||||||
) -> str:
|
|
||||||
"""渲染 HTML"""
|
|
||||||
return await html_renderer.render_custom_template(
|
|
||||||
tmpl,
|
|
||||||
data,
|
|
||||||
return_url=return_url,
|
|
||||||
options=options,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
|
||||||
"""当插件被激活时会调用这个方法"""
|
|
||||||
|
|
||||||
async def terminate(self) -> None:
|
|
||||||
"""当插件被禁用、重载插件时会调用这个方法"""
|
|
||||||
|
|
||||||
def __del__(self) -> None:
|
|
||||||
"""[Deprecated] 当插件被禁用、重载插件时会调用这个方法"""
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["Context", "PluginManager", "Provider", "Star", "StarMetadata", "StarTools"]
|
|
||||||
|
|||||||
@@ -0,0 +1,87 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Protocol
|
||||||
|
|
||||||
|
from astrbot.core import html_renderer
|
||||||
|
from astrbot.core.utils.command_parser import CommandParserMixin
|
||||||
|
from astrbot.core.utils.plugin_kv_store import PluginKVStoreMixin
|
||||||
|
|
||||||
|
from .star import StarMetadata, star_map, star_registry
|
||||||
|
|
||||||
|
logger = logging.getLogger("astrbot")
|
||||||
|
|
||||||
|
|
||||||
|
class Star(CommandParserMixin, PluginKVStoreMixin):
|
||||||
|
"""所有插件(Star)的父类,所有插件都应该继承于这个类"""
|
||||||
|
|
||||||
|
author: str
|
||||||
|
name: str
|
||||||
|
|
||||||
|
class _ContextLike(Protocol):
|
||||||
|
def get_config(self, umo: str | None = None) -> Any: ...
|
||||||
|
|
||||||
|
def __init__(self, context: _ContextLike, config: dict | None = None) -> None:
|
||||||
|
self.context = context
|
||||||
|
|
||||||
|
def _get_context_config(self) -> Any:
|
||||||
|
get_config = getattr(self.context, "get_config", None)
|
||||||
|
if callable(get_config):
|
||||||
|
try:
|
||||||
|
return get_config()
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"get_config() failed: {e}")
|
||||||
|
return None
|
||||||
|
return getattr(self.context, "_config", None)
|
||||||
|
|
||||||
|
def __init_subclass__(cls, **kwargs):
|
||||||
|
super().__init_subclass__(**kwargs)
|
||||||
|
if not star_map.get(cls.__module__):
|
||||||
|
metadata = StarMetadata(
|
||||||
|
star_cls_type=cls,
|
||||||
|
module_path=cls.__module__,
|
||||||
|
)
|
||||||
|
star_map[cls.__module__] = metadata
|
||||||
|
star_registry.append(metadata)
|
||||||
|
else:
|
||||||
|
star_map[cls.__module__].star_cls_type = cls
|
||||||
|
star_map[cls.__module__].module_path = cls.__module__
|
||||||
|
|
||||||
|
async def text_to_image(self, text: str, return_url=True) -> str:
|
||||||
|
"""将文本转换为图片"""
|
||||||
|
config_obj = self._get_context_config()
|
||||||
|
template_name = None
|
||||||
|
if hasattr(config_obj, "get"):
|
||||||
|
try:
|
||||||
|
template_name = config_obj.get("t2i_active_template")
|
||||||
|
except Exception:
|
||||||
|
template_name = None
|
||||||
|
return await html_renderer.render_t2i(
|
||||||
|
text,
|
||||||
|
return_url=return_url,
|
||||||
|
template_name=template_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def html_render(
|
||||||
|
self,
|
||||||
|
tmpl: str,
|
||||||
|
data: dict,
|
||||||
|
return_url=True,
|
||||||
|
options: dict | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""渲染 HTML"""
|
||||||
|
return await html_renderer.render_custom_template(
|
||||||
|
tmpl,
|
||||||
|
data,
|
||||||
|
return_url=return_url,
|
||||||
|
options=options,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
"""当插件被激活时会调用这个方法"""
|
||||||
|
|
||||||
|
async def terminate(self) -> None:
|
||||||
|
"""当插件被禁用、重载插件时会调用这个方法"""
|
||||||
|
|
||||||
|
def __del__(self) -> None:
|
||||||
|
"""[Deprecated] 当插件被禁用、重载插件时会调用这个方法"""
|
||||||
@@ -1,7 +1,9 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from asyncio import Queue
|
from asyncio import Queue
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING, Any, Protocol
|
||||||
|
|
||||||
from deprecated import deprecated
|
from deprecated import deprecated
|
||||||
|
|
||||||
@@ -12,14 +14,12 @@ from astrbot.core.agent.tool import ToolSet
|
|||||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||||
from astrbot.core.conversation_mgr import ConversationManager
|
from astrbot.core.conversation_mgr import ConversationManager
|
||||||
from astrbot.core.cron.manager import CronJobManager
|
|
||||||
from astrbot.core.db import BaseDatabase
|
from astrbot.core.db import BaseDatabase
|
||||||
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
||||||
from astrbot.core.message.message_event_result import MessageChain
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
from astrbot.core.persona_mgr import PersonaManager
|
from astrbot.core.persona_mgr import PersonaManager
|
||||||
from astrbot.core.platform import Platform
|
from astrbot.core.platform import Platform
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent, MessageSesion
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent, MessageSesion
|
||||||
from astrbot.core.platform.manager import PlatformManager
|
|
||||||
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
|
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
|
||||||
from astrbot.core.provider.entities import LLMResponse, ProviderRequest, ProviderType
|
from astrbot.core.provider.entities import LLMResponse, ProviderRequest, ProviderType
|
||||||
from astrbot.core.provider.func_tool_manager import FunctionTool, FunctionToolManager
|
from astrbot.core.provider.func_tool_manager import FunctionTool, FunctionToolManager
|
||||||
@@ -45,6 +45,15 @@ from .star_handler import EventType, StarHandlerMetadata, star_handlers_registry
|
|||||||
|
|
||||||
logger = logging.getLogger("astrbot")
|
logger = logging.getLogger("astrbot")
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from astrbot.core.cron.manager import CronJobManager
|
||||||
|
else:
|
||||||
|
CronJobManager = Any
|
||||||
|
|
||||||
|
|
||||||
|
class PlatformManagerProtocol(Protocol):
|
||||||
|
platform_insts: list[Platform]
|
||||||
|
|
||||||
|
|
||||||
class Context:
|
class Context:
|
||||||
"""暴露给插件的接口上下文。"""
|
"""暴露给插件的接口上下文。"""
|
||||||
@@ -61,7 +70,7 @@ class Context:
|
|||||||
config: AstrBotConfig,
|
config: AstrBotConfig,
|
||||||
db: BaseDatabase,
|
db: BaseDatabase,
|
||||||
provider_manager: ProviderManager,
|
provider_manager: ProviderManager,
|
||||||
platform_manager: PlatformManager,
|
platform_manager: PlatformManagerProtocol,
|
||||||
conversation_manager: ConversationManager,
|
conversation_manager: ConversationManager,
|
||||||
message_history_manager: PlatformMessageHistoryManager,
|
message_history_manager: PlatformMessageHistoryManager,
|
||||||
persona_manager: PersonaManager,
|
persona_manager: PersonaManager,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from astrbot.core.star import StarMetadata, star_map
|
from astrbot.core.star.star import StarMetadata, star_map
|
||||||
|
|
||||||
_warned_register_star = False
|
_warned_register_star = False
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from astrbot.core.agent.agent import Agent
|
|||||||
from astrbot.core.agent.handoff import HandoffTool
|
from astrbot.core.agent.handoff import HandoffTool
|
||||||
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
||||||
from astrbot.core.agent.tool import FunctionTool
|
from astrbot.core.agent.tool import FunctionTool
|
||||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
|
||||||
from astrbot.core.message.message_event_result import MessageEventResult
|
from astrbot.core.message.message_event_result import MessageEventResult
|
||||||
from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES
|
from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES
|
||||||
from astrbot.core.provider.register import llm_tools
|
from astrbot.core.provider.register import llm_tools
|
||||||
@@ -617,7 +616,7 @@ class RegisteringAgent:
|
|||||||
kwargs["registering_agent"] = self
|
kwargs["registering_agent"] = self
|
||||||
return register_llm_tool(*args, **kwargs)
|
return register_llm_tool(*args, **kwargs)
|
||||||
|
|
||||||
def __init__(self, agent: Agent[AstrAgentContext]) -> None:
|
def __init__(self, agent: Agent[Any]) -> None:
|
||||||
self._agent = agent
|
self._agent = agent
|
||||||
|
|
||||||
|
|
||||||
@@ -625,7 +624,7 @@ def register_agent(
|
|||||||
name: str,
|
name: str,
|
||||||
instruction: str,
|
instruction: str,
|
||||||
tools: list[str | FunctionTool] | None = None,
|
tools: list[str | FunctionTool] | None = None,
|
||||||
run_hooks: BaseAgentRunHooks[AstrAgentContext] | None = None,
|
run_hooks: BaseAgentRunHooks[Any] | None = None,
|
||||||
):
|
):
|
||||||
"""注册一个 Agent
|
"""注册一个 Agent
|
||||||
|
|
||||||
@@ -639,12 +638,12 @@ def register_agent(
|
|||||||
tools_ = tools or []
|
tools_ = tools or []
|
||||||
|
|
||||||
def decorator(awaitable: Callable[..., Awaitable[Any]]):
|
def decorator(awaitable: Callable[..., Awaitable[Any]]):
|
||||||
AstrAgent = Agent[AstrAgentContext]
|
AstrAgent = Agent[Any]
|
||||||
agent = AstrAgent(
|
agent = AstrAgent(
|
||||||
name=name,
|
name=name,
|
||||||
instructions=instruction,
|
instructions=instruction,
|
||||||
tools=tools_,
|
tools=tools_,
|
||||||
run_hooks=run_hooks or BaseAgentRunHooks[AstrAgentContext](),
|
run_hooks=run_hooks or BaseAgentRunHooks[Any](),
|
||||||
)
|
)
|
||||||
handoff_tool = HandoffTool(agent=agent)
|
handoff_tool = HandoffTool(agent=agent)
|
||||||
handoff_tool.handler = awaitable
|
handoff_tool.handler = awaitable
|
||||||
|
|||||||
@@ -49,10 +49,13 @@ class PluginVersionIncompatibleError(Exception):
|
|||||||
|
|
||||||
class PluginManager:
|
class PluginManager:
|
||||||
def __init__(self, context: Context, config: AstrBotConfig) -> None:
|
def __init__(self, context: Context, config: AstrBotConfig) -> None:
|
||||||
|
from .star_tools import StarTools
|
||||||
|
|
||||||
self.updator = PluginUpdator()
|
self.updator = PluginUpdator()
|
||||||
|
|
||||||
self.context = context
|
self.context = context
|
||||||
self.context._star_manager = self # type: ignore
|
self.context._star_manager = self # type: ignore
|
||||||
|
StarTools.initialize(context)
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.plugin_store_path = get_astrbot_plugin_path()
|
self.plugin_store_path = get_astrbot_plugin_path()
|
||||||
|
|||||||
Reference in New Issue
Block a user