perf: add AgentState Enum for improved state management

This commit is contained in:
Raven95676
2025-06-29 17:19:53 +08:00
parent efa45e6203
commit e9b23f68fd
3 changed files with 35 additions and 20 deletions
@@ -3,6 +3,15 @@ import typing as T
from dataclasses import dataclass
from astrbot.core.provider.entities import LLMResponse
from ....message.message_event_result import MessageChain
from enum import Enum, auto
class AgentState(Enum):
"""Agent 状态枚举"""
IDLE = auto() # 初始状态
RUNNING = auto() # 运行中
DONE = auto() # 完成
ERROR = auto() # 错误状态
class AgentResponseData(T.TypedDict):
@@ -1,7 +1,7 @@
import sys
import traceback
import typing as T
from .base import BaseAgentRunner, AgentResponse, AgentResponseData
from .base import BaseAgentRunner, AgentResponse, AgentResponseData, AgentState
from ...context import PipelineContext
from astrbot.core.provider.provider import Provider
from astrbot.core.platform.astr_message_event import AstrMessageEvent
@@ -43,13 +43,22 @@ class ToolLoopAgent(BaseAgentRunner):
self.req = None
self.event = event
self.pipeline_ctx = pipeline_ctx
self._state = AgentState.IDLE
self.final_llm_resp = None
self.streaming = False
@override
async def reset(self, req: ProviderRequest, streaming: bool) -> None:
self.req = req
self.streaming = streaming
self.final_llm_resp = None
self.is_done = False
self._state = AgentState.IDLE
def _transition_state(self, new_state: AgentState) -> None:
"""转换 Agent 状态"""
if self._state != new_state:
logger.info(f"Agent state transition: {self._state} -> {new_state}")
self._state = new_state
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
"""Yields chunks *and* a final LLMResponse."""
@@ -69,7 +78,8 @@ class ToolLoopAgent(BaseAgentRunner):
if not self.req:
raise ValueError("Request is not set. Please call reset() first.")
# 执行 LLM 请求
# 开始处理,转换到运行状态
self._transition_state(AgentState.RUNNING)
llm_resp_result = None
async for llm_response in self._iter_llm_responses():
@@ -99,9 +109,9 @@ class ToolLoopAgent(BaseAgentRunner):
logger.info(f"LLMResp: {llm_resp}")
if llm_resp.role == "err":
# 如果 LLM 响应错误,直接返回错误信息
# 如果 LLM 响应错误,转换到错误状态
self.final_llm_resp = llm_resp
self.is_done = True
self._transition_state(AgentState.ERROR)
yield AgentResponse(
type="err",
data=AgentResponseData(
@@ -112,9 +122,9 @@ class ToolLoopAgent(BaseAgentRunner):
)
if not llm_resp.tools_call_name:
# 如果没有工具调用,结束 Agent Loop
# 如果没有工具调用,转换到完成状态
self.final_llm_resp = llm_resp
self.is_done = True
self._transition_state(AgentState.DONE)
# 执行事件钩子
await self.pipeline_ctx.call_event_hook(
@@ -157,8 +167,6 @@ class ToolLoopAgent(BaseAgentRunner):
)
self.req.append_tool_calls_result(tool_calls_result)
logger.info("done: %s", self.is_done)
async def _handle_function_tools(
self,
req: ProviderRequest,
@@ -251,13 +259,13 @@ class ToolLoopAgent(BaseAgentRunner):
else:
# Tool 直接请求发送消息给用户
# 这里我们将直接结束 Agent Loop。
self.is_done = True
self._transition_state(AgentState.DONE)
if res := self.event.get_result():
if res.chain:
yield MessageChain(chain=res.chain)
self.event.clear_result()
except BaseException as e:
except Exception as e:
logger.warning(traceback.format_exc())
tool_call_result_blocks.append(
ToolCallMessageSegment(
@@ -272,7 +280,8 @@ class ToolLoopAgent(BaseAgentRunner):
yield tool_call_result_blocks
def done(self) -> bool:
return self.is_done
"""检查 Agent 是否已完成工作"""
return self._state in (AgentState.DONE, AgentState.ERROR)
def get_final_llm_resp(self) -> LLMResponse | None:
return self.final_llm_resp
@@ -268,13 +268,11 @@ class LLMRequestSubStage(Stage):
cid=cid,
title=title,
)
web_chat_back_queue.put_nowait(
{
"type": "update_title",
"cid": cid,
"data": title,
}
)
web_chat_back_queue.put_nowait({
"type": "update_title",
"cid": cid,
"data": title,
})
async def _save_to_history(
self,
@@ -323,4 +321,3 @@ class LLMRequestSubStage(Stage):
else:
fixed_messages.append(message)
return fixed_messages