From e9b23f68fd87f5a9aba0850125d965b2c7b62373 Mon Sep 17 00:00:00 2001 From: Raven95676 Date: Sun, 29 Jun 2025 17:19:53 +0800 Subject: [PATCH] perf: add AgentState Enum for improved state management --- .../process_stage/agent_runner/base.py | 9 +++++ .../agent_runner/tool_loop_agent.py | 33 ++++++++++++------- .../process_stage/method/llm_request.py | 13 +++----- 3 files changed, 35 insertions(+), 20 deletions(-) diff --git a/astrbot/core/pipeline/process_stage/agent_runner/base.py b/astrbot/core/pipeline/process_stage/agent_runner/base.py index c2fb1e79d..431a95ca6 100644 --- a/astrbot/core/pipeline/process_stage/agent_runner/base.py +++ b/astrbot/core/pipeline/process_stage/agent_runner/base.py @@ -3,6 +3,15 @@ import typing as T from dataclasses import dataclass from astrbot.core.provider.entities import LLMResponse from ....message.message_event_result import MessageChain +from enum import Enum, auto + + +class AgentState(Enum): + """Agent 状态枚举""" + IDLE = auto() # 初始状态 + RUNNING = auto() # 运行中 + DONE = auto() # 完成 + ERROR = auto() # 错误状态 class AgentResponseData(T.TypedDict): diff --git a/astrbot/core/pipeline/process_stage/agent_runner/tool_loop_agent.py b/astrbot/core/pipeline/process_stage/agent_runner/tool_loop_agent.py index 07b808925..04fda3c1c 100644 --- a/astrbot/core/pipeline/process_stage/agent_runner/tool_loop_agent.py +++ b/astrbot/core/pipeline/process_stage/agent_runner/tool_loop_agent.py @@ -1,7 +1,7 @@ import sys import traceback import typing as T -from .base import BaseAgentRunner, AgentResponse, AgentResponseData +from .base import BaseAgentRunner, AgentResponse, AgentResponseData, AgentState from ...context import PipelineContext from astrbot.core.provider.provider import Provider from astrbot.core.platform.astr_message_event import AstrMessageEvent @@ -43,13 +43,22 @@ class ToolLoopAgent(BaseAgentRunner): self.req = None self.event = event self.pipeline_ctx = pipeline_ctx + self._state = AgentState.IDLE + self.final_llm_resp = None + self.streaming = False @override async def reset(self, req: ProviderRequest, streaming: bool) -> None: self.req = req self.streaming = streaming self.final_llm_resp = None - self.is_done = False + self._state = AgentState.IDLE + + def _transition_state(self, new_state: AgentState) -> None: + """转换 Agent 状态""" + if self._state != new_state: + logger.info(f"Agent state transition: {self._state} -> {new_state}") + self._state = new_state async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]: """Yields chunks *and* a final LLMResponse.""" @@ -69,7 +78,8 @@ class ToolLoopAgent(BaseAgentRunner): if not self.req: raise ValueError("Request is not set. Please call reset() first.") - # 执行 LLM 请求 + # 开始处理,转换到运行状态 + self._transition_state(AgentState.RUNNING) llm_resp_result = None async for llm_response in self._iter_llm_responses(): @@ -99,9 +109,9 @@ class ToolLoopAgent(BaseAgentRunner): logger.info(f"LLMResp: {llm_resp}") if llm_resp.role == "err": - # 如果 LLM 响应错误,直接返回错误信息 + # 如果 LLM 响应错误,转换到错误状态 self.final_llm_resp = llm_resp - self.is_done = True + self._transition_state(AgentState.ERROR) yield AgentResponse( type="err", data=AgentResponseData( @@ -112,9 +122,9 @@ class ToolLoopAgent(BaseAgentRunner): ) if not llm_resp.tools_call_name: - # 如果没有工具调用,结束 Agent Loop + # 如果没有工具调用,转换到完成状态 self.final_llm_resp = llm_resp - self.is_done = True + self._transition_state(AgentState.DONE) # 执行事件钩子 await self.pipeline_ctx.call_event_hook( @@ -157,8 +167,6 @@ class ToolLoopAgent(BaseAgentRunner): ) self.req.append_tool_calls_result(tool_calls_result) - logger.info("done: %s", self.is_done) - async def _handle_function_tools( self, req: ProviderRequest, @@ -251,13 +259,13 @@ class ToolLoopAgent(BaseAgentRunner): else: # Tool 直接请求发送消息给用户 # 这里我们将直接结束 Agent Loop。 - self.is_done = True + self._transition_state(AgentState.DONE) if res := self.event.get_result(): if res.chain: yield MessageChain(chain=res.chain) self.event.clear_result() - except BaseException as e: + except Exception as e: logger.warning(traceback.format_exc()) tool_call_result_blocks.append( ToolCallMessageSegment( @@ -272,7 +280,8 @@ class ToolLoopAgent(BaseAgentRunner): yield tool_call_result_blocks def done(self) -> bool: - return self.is_done + """检查 Agent 是否已完成工作""" + return self._state in (AgentState.DONE, AgentState.ERROR) def get_final_llm_resp(self) -> LLMResponse | None: return self.final_llm_resp diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 1134a19e2..1ede6313e 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -268,13 +268,11 @@ class LLMRequestSubStage(Stage): cid=cid, title=title, ) - web_chat_back_queue.put_nowait( - { - "type": "update_title", - "cid": cid, - "data": title, - } - ) + web_chat_back_queue.put_nowait({ + "type": "update_title", + "cid": cid, + "data": title, + }) async def _save_to_history( self, @@ -323,4 +321,3 @@ class LLMRequestSubStage(Stage): else: fixed_messages.append(message) return fixed_messages -