From a6dc458212f7fa8562b38d531e8626a52fe0e9f7 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sun, 23 Nov 2025 23:03:56 +0800 Subject: [PATCH] feat(third-party-agent): implement streaming response handling and enhance agent execution flow --- .../method/agent_sub_stages/third_party.py | 108 +++++++++++++++--- 1 file changed, 91 insertions(+), 17 deletions(-) diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py index d4d0709e7..a9fb67dc8 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py @@ -1,5 +1,6 @@ import asyncio from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING from astrbot.core import logger from astrbot.core.agent.runners.coze_agent_runner import CozeAgentRunner @@ -7,9 +8,13 @@ from astrbot.core.agent.runners.dashscope_agent_runner import DashscopeAgentRunn from astrbot.core.agent.runners.dify_agent_runner import DifyAgentRunner from astrbot.core.message.components import Image from astrbot.core.message.message_event_result import ( + MessageChain, MessageEventResult, ResultContentType, ) + +if TYPE_CHECKING: + from astrbot.core.agent.runners.base import BaseAgentRunner from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.provider.entities import ( ProviderRequest, @@ -29,6 +34,32 @@ AGENT_RUNNER_TYPE_KEY = { } +async def run_third_party_agent( + runner: "BaseAgentRunner", + stream_to_general: bool = False, +) -> AsyncGenerator[MessageChain | None, None]: + """ + 运行第三方 agent runner 并转换响应格式 + 类似于 run_agent 函数,但专门处理第三方 agent runner + """ + try: + async for resp in runner.step_until_done(max_step=30): # type: ignore[misc] + if resp.type == "streaming_delta": + if stream_to_general: + continue + yield resp.data["chain"] + elif resp.type == "llm_result": + if stream_to_general: + yield resp.data["chain"] + except Exception as e: + logger.error(f"Third party agent runner error: {e}") + err_msg = ( + f"\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n" + f"错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n" + ) + yield MessageChain().message(err_msg) + + class ThirdPartyAgentSubStage(Stage): async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx @@ -38,10 +69,11 @@ class ThirdPartyAgentSubStage(Stage): AGENT_RUNNER_TYPE_KEY.get(self.runner_type, ""), "", ) - self.prov_cfg: dict = next( - (p for p in self.conf["provider"] if p["id"] == self.prov_id), - {}, - ) + settings = ctx.astrbot_config["provider_settings"] + self.streaming_response: bool = settings["streaming_response"] + self.unsupported_streaming_strategy: str = settings[ + "unsupported_streaming_strategy" + ] async def process( self, event: AstrMessageEvent, provider_wake_prefix: str @@ -52,6 +84,11 @@ class ThirdPartyAgentSubStage(Stage): provider_wake_prefix ): return + + self.prov_cfg: dict = next( + (p for p in self.conf["provider"] if p["id"] == self.prov_id), + {}, + ) if not self.prov_id or not self.prov_cfg: logger.error( "Third Party Agent Runner provider ID is not configured properly." @@ -90,6 +127,15 @@ class ThirdPartyAgentSubStage(Stage): event=event, ) + streaming_response = self.streaming_response + if (enable_streaming := event.get_extra("enable_streaming")) is not None: + streaming_response = bool(enable_streaming) + + stream_to_general = ( + self.unsupported_streaming_strategy == "turn_off" + and not event.platform_meta.support_streaming_message + ) + await runner.reset( request=req, run_context=AgentContextWrapper( @@ -98,24 +144,52 @@ class ThirdPartyAgentSubStage(Stage): ), agent_hooks=MAIN_AGENT_HOOKS, provider_config=self.prov_cfg, + streaming=streaming_response, ) - async for _ in runner.step_until_done(): - pass + if streaming_response and not stream_to_general: + # 流式响应 + event.set_result( + MessageEventResult() + .set_result_content_type(ResultContentType.STREAMING_RESULT) + .set_async_stream( + run_third_party_agent( + runner, + stream_to_general=False, + ), + ), + ) + yield + if runner.done(): + final_resp = runner.get_final_llm_resp() + if final_resp and final_resp.result_chain: + event.set_result( + MessageEventResult( + chain=final_resp.result_chain.chain or [], + result_content_type=ResultContentType.STREAMING_FINISH, + ), + ) + else: + # 非流式响应或转换为普通响应 + async for _ in run_third_party_agent( + runner, + stream_to_general=stream_to_general, + ): + yield - final_resp = runner.get_final_llm_resp() + final_resp = runner.get_final_llm_resp() - if not final_resp or not final_resp.result_chain: - logger.warning("Agent Runner 未返回最终结果。") - return + if not final_resp or not final_resp.result_chain: + logger.warning("Agent Runner 未返回最终结果。") + return - event.set_result( - MessageEventResult( - chain=final_resp.result_chain.chain or [], - result_content_type=ResultContentType.LLM_RESULT, - ), - ) - yield + event.set_result( + MessageEventResult( + chain=final_resp.result_chain.chain or [], + result_content_type=ResultContentType.LLM_RESULT, + ), + ) + yield asyncio.create_task( Metric.upload(