diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index 1ea36ff7a..dfffa9cf9 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -51,6 +51,7 @@ from astrbot.core.tools.cron_tools import ( ) from astrbot.core.utils.file_extract import extract_file_moonshotai from astrbot.core.utils.llm_metadata import LLM_METADATAS +from typing import Coroutine @dataclass(slots=True) @@ -114,6 +115,7 @@ class MainAgentBuildResult: agent_runner: AgentRunner provider_request: ProviderRequest provider: Provider + reset_coro: Coroutine | None = None def _select_provider( @@ -837,8 +839,12 @@ async def build_main_agent( config: MainAgentBuildConfig, provider: Provider | None = None, req: ProviderRequest | None = None, + apply_reset: bool = True, ) -> MainAgentBuildResult | None: - """构建主对话代理(Main Agent),并且自动 reset。""" + """构建主对话代理(Main Agent),并且自动 reset。 + + If apply_reset is False, will not call reset on the agent runner. + """ provider = provider or _select_provider(event, plugin_context) if provider is None: logger.info("未找到任何对话模型(提供商),跳过 LLM 请求处理。") @@ -955,7 +961,7 @@ async def build_main_agent( if action_type == "live": req.system_prompt += f"\n{LIVE_MODE_SYSTEM_PROMPT}\n" - await agent_runner.reset( + reset_coro = agent_runner.reset( provider=provider, request=req, run_context=AgentContextWrapper( @@ -973,8 +979,12 @@ async def build_main_agent( tool_schema_mode=config.tool_schema_mode, ) + if apply_reset: + await reset_coro + return MainAgentBuildResult( agent_runner=agent_runner, provider_request=req, provider=provider, + reset_coro=reset_coro if not apply_reset else None, ) diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index 8fa39f8e8..b598e3aa2 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -164,6 +164,7 @@ class InternalAgentSubStage(Stage): event=event, plugin_context=self.ctx.plugin_manager.context, config=build_cfg, + apply_reset=False, ) if build_result is None: @@ -172,6 +173,7 @@ class InternalAgentSubStage(Stage): agent_runner = build_result.agent_runner req = build_result.provider_request provider = build_result.provider + reset_coro = build_result.reset_coro api_base = provider.provider_config.get("api_base", "") for host in decoded_blocked: @@ -190,6 +192,10 @@ class InternalAgentSubStage(Stage): if await call_event_hook(event, EventType.OnLLMRequestEvent, req): return + # apply reset + if reset_coro: + await reset_coro + action_type = event.get_extra("action_type") event.trace.record(