perf: add AgentState Enum for improved state management
This commit is contained in:
@@ -3,6 +3,15 @@ import typing as T
|
||||
from dataclasses import dataclass
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
from ....message.message_event_result import MessageChain
|
||||
from enum import Enum, auto
|
||||
|
||||
|
||||
class AgentState(Enum):
|
||||
"""Agent 状态枚举"""
|
||||
IDLE = auto() # 初始状态
|
||||
RUNNING = auto() # 运行中
|
||||
DONE = auto() # 完成
|
||||
ERROR = auto() # 错误状态
|
||||
|
||||
|
||||
class AgentResponseData(T.TypedDict):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import sys
|
||||
import traceback
|
||||
import typing as T
|
||||
from .base import BaseAgentRunner, AgentResponse, AgentResponseData
|
||||
from .base import BaseAgentRunner, AgentResponse, AgentResponseData, AgentState
|
||||
from ...context import PipelineContext
|
||||
from astrbot.core.provider.provider import Provider
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
@@ -43,13 +43,22 @@ class ToolLoopAgent(BaseAgentRunner):
|
||||
self.req = None
|
||||
self.event = event
|
||||
self.pipeline_ctx = pipeline_ctx
|
||||
self._state = AgentState.IDLE
|
||||
self.final_llm_resp = None
|
||||
self.streaming = False
|
||||
|
||||
@override
|
||||
async def reset(self, req: ProviderRequest, streaming: bool) -> None:
|
||||
self.req = req
|
||||
self.streaming = streaming
|
||||
self.final_llm_resp = None
|
||||
self.is_done = False
|
||||
self._state = AgentState.IDLE
|
||||
|
||||
def _transition_state(self, new_state: AgentState) -> None:
|
||||
"""转换 Agent 状态"""
|
||||
if self._state != new_state:
|
||||
logger.info(f"Agent state transition: {self._state} -> {new_state}")
|
||||
self._state = new_state
|
||||
|
||||
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
|
||||
"""Yields chunks *and* a final LLMResponse."""
|
||||
@@ -69,7 +78,8 @@ class ToolLoopAgent(BaseAgentRunner):
|
||||
if not self.req:
|
||||
raise ValueError("Request is not set. Please call reset() first.")
|
||||
|
||||
# 执行 LLM 请求
|
||||
# 开始处理,转换到运行状态
|
||||
self._transition_state(AgentState.RUNNING)
|
||||
llm_resp_result = None
|
||||
|
||||
async for llm_response in self._iter_llm_responses():
|
||||
@@ -99,9 +109,9 @@ class ToolLoopAgent(BaseAgentRunner):
|
||||
logger.info(f"LLMResp: {llm_resp}")
|
||||
|
||||
if llm_resp.role == "err":
|
||||
# 如果 LLM 响应错误,直接返回错误信息
|
||||
# 如果 LLM 响应错误,转换到错误状态
|
||||
self.final_llm_resp = llm_resp
|
||||
self.is_done = True
|
||||
self._transition_state(AgentState.ERROR)
|
||||
yield AgentResponse(
|
||||
type="err",
|
||||
data=AgentResponseData(
|
||||
@@ -112,9 +122,9 @@ class ToolLoopAgent(BaseAgentRunner):
|
||||
)
|
||||
|
||||
if not llm_resp.tools_call_name:
|
||||
# 如果没有工具调用,结束 Agent Loop
|
||||
# 如果没有工具调用,转换到完成状态
|
||||
self.final_llm_resp = llm_resp
|
||||
self.is_done = True
|
||||
self._transition_state(AgentState.DONE)
|
||||
|
||||
# 执行事件钩子
|
||||
await self.pipeline_ctx.call_event_hook(
|
||||
@@ -157,8 +167,6 @@ class ToolLoopAgent(BaseAgentRunner):
|
||||
)
|
||||
self.req.append_tool_calls_result(tool_calls_result)
|
||||
|
||||
logger.info("done: %s", self.is_done)
|
||||
|
||||
async def _handle_function_tools(
|
||||
self,
|
||||
req: ProviderRequest,
|
||||
@@ -251,13 +259,13 @@ class ToolLoopAgent(BaseAgentRunner):
|
||||
else:
|
||||
# Tool 直接请求发送消息给用户
|
||||
# 这里我们将直接结束 Agent Loop。
|
||||
self.is_done = True
|
||||
self._transition_state(AgentState.DONE)
|
||||
if res := self.event.get_result():
|
||||
if res.chain:
|
||||
yield MessageChain(chain=res.chain)
|
||||
|
||||
self.event.clear_result()
|
||||
except BaseException as e:
|
||||
except Exception as e:
|
||||
logger.warning(traceback.format_exc())
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
@@ -272,7 +280,8 @@ class ToolLoopAgent(BaseAgentRunner):
|
||||
yield tool_call_result_blocks
|
||||
|
||||
def done(self) -> bool:
|
||||
return self.is_done
|
||||
"""检查 Agent 是否已完成工作"""
|
||||
return self._state in (AgentState.DONE, AgentState.ERROR)
|
||||
|
||||
def get_final_llm_resp(self) -> LLMResponse | None:
|
||||
return self.final_llm_resp
|
||||
|
||||
@@ -268,13 +268,11 @@ class LLMRequestSubStage(Stage):
|
||||
cid=cid,
|
||||
title=title,
|
||||
)
|
||||
web_chat_back_queue.put_nowait(
|
||||
{
|
||||
"type": "update_title",
|
||||
"cid": cid,
|
||||
"data": title,
|
||||
}
|
||||
)
|
||||
web_chat_back_queue.put_nowait({
|
||||
"type": "update_title",
|
||||
"cid": cid,
|
||||
"data": title,
|
||||
})
|
||||
|
||||
async def _save_to_history(
|
||||
self,
|
||||
@@ -323,4 +321,3 @@ class LLMRequestSubStage(Stage):
|
||||
else:
|
||||
fixed_messages.append(message)
|
||||
return fixed_messages
|
||||
|
||||
|
||||
Reference in New Issue
Block a user