diff --git a/astrbot/core/astr_agent_run_util.py b/astrbot/core/astr_agent_run_util.py index dd65f92e6..364cf8b4f 100644 --- a/astrbot/core/astr_agent_run_util.py +++ b/astrbot/core/astr_agent_run_util.py @@ -87,6 +87,21 @@ def _build_tool_result_status_message( return status_msg +def _extract_final_streaming_chain(msg_chain: MessageChain) -> MessageChain | None: + if not msg_chain.chain: + return None + + final_chain: list[BaseMessageComponent] = [] + for comp in msg_chain.chain: + if isinstance(comp, Plain): + continue + final_chain.append(comp) + + if not final_chain: + return None + return MessageChain(chain=final_chain, type=msg_chain.type) + + async def run_agent( agent_runner: AgentRunner, max_step: int = 30, @@ -211,6 +226,11 @@ async def run_agent( # display the reasoning content only when configured continue yield resp.data["chain"] # MessageChain + elif resp.type == "llm_result": + if final_chain := _extract_final_streaming_chain( + resp.data["chain"] + ): + yield final_chain if not stop_watcher.done(): stop_watcher.cancel() try: diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index 1a49fc37a..87b1726d6 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -724,6 +724,38 @@ def _sanitize_context_by_modalities( req.contexts = sanitized_contexts +def _model_outputs_image(provider: Provider, req: ProviderRequest) -> bool: + model = req.model or provider.get_model() + if not model: + return False + model_info = LLM_METADATAS.get(model) + if not model_info: + return False + output_modalities = model_info.get("modalities", {}).get("output", []) + return "image" in output_modalities + + +def _should_disable_streaming_for_webchat_output( + event: AstrMessageEvent, + provider: Provider, + req: ProviderRequest, +) -> bool: + if event.get_platform_name() != "webchat": + return False + + provider_cfg = provider.provider_config + provider_type = provider_cfg.get("type", "") + if provider_type == "googlegenai_chat_completion" and provider_cfg.get( + "gm_resp_image_modal", False + ): + return True + + if _model_outputs_image(provider, req): + return not bool(provider_cfg.get("supports_streaming_output_modalities", False)) + + return False + + def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None: """根据事件中的插件设置,过滤请求中的工具列表。 @@ -1091,6 +1123,17 @@ async def build_main_agent( if action_type == "live": req.system_prompt += f"\n{LIVE_MODE_SYSTEM_PROMPT}\n" + streaming_response = config.streaming_response + if streaming_response and _should_disable_streaming_for_webchat_output( + event, provider, req + ): + logger.info( + "Disable streaming for webchat direct media output. provider=%s model=%s", + provider.provider_config.get("id", "unknown"), + req.model or provider.get_model(), + ) + streaming_response = False + reset_coro = agent_runner.reset( provider=provider, request=req, @@ -1100,7 +1143,7 @@ async def build_main_agent( ), tool_executor=FunctionToolExecutor(), agent_hooks=MAIN_AGENT_HOOKS, - streaming=config.streaming_response, + streaming=streaming_response, llm_compress_instruction=config.llm_compress_instruction, llm_compress_keep_recent=config.llm_compress_keep_recent, llm_compress_provider=_get_compress_provider(config, plugin_context), diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index 572c4214d..02a525964 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -239,6 +239,8 @@ class InternalAgentSubStage(Stage): if reset_coro: await reset_coro + effective_streaming_response = bool(agent_runner.streaming) + register_active_runner(event.unified_msg_origin, agent_runner) runner_registered = True action_type = event.get_extra("action_type") @@ -247,7 +249,7 @@ class InternalAgentSubStage(Stage): "astr_agent_prepare", system_prompt=req.system_prompt, tools=req.func_tool.names() if req.func_tool else [], - stream=streaming_response, + stream=effective_streaming_response, chat_provider={ "id": provider.provider_config.get("id", ""), "model": provider.get_model(), @@ -301,7 +303,7 @@ class InternalAgentSubStage(Stage): user_aborted=agent_runner.was_aborted(), ) - elif streaming_response and not stream_to_general: + elif effective_streaming_response and not stream_to_general: # 流式响应 event.set_result( MessageEventResult() diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 9557f3dbc..553f2e3d4 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -134,16 +134,14 @@ class ProviderGoogleGenAI(Provider): system_instruction: str | None = None, modalities: list[str] | None = None, temperature: float = 0.7, + streaming: bool = False, ) -> types.GenerateContentConfig: """准备查询配置""" if not modalities: modalities = ["TEXT"] # 流式输出不支持图片模态 - if ( - self.provider_settings.get("streaming_response", False) - and "IMAGE" in modalities - ): + if streaming and "IMAGE" in modalities: logger.warning("流式输出不支持图片模态,已自动降级为文本模态") modalities = ["TEXT"] @@ -538,6 +536,7 @@ class ProviderGoogleGenAI(Provider): system_instruction, modalities, temperature, + streaming=False, ) result = await self.client.models.generate_content( model=model, @@ -617,6 +616,7 @@ class ProviderGoogleGenAI(Provider): payloads, tools, system_instruction, + streaming=True, ) result = await self.client.models.generate_content_stream( model=model, diff --git a/tests/test_tool_loop_agent_runner.py b/tests/test_tool_loop_agent_runner.py index 38c601cee..b8a189374 100644 --- a/tests/test_tool_loop_agent_runner.py +++ b/tests/test_tool_loop_agent_runner.py @@ -11,6 +11,9 @@ from astrbot.core.agent.hooks import BaseAgentRunHooks from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner from astrbot.core.agent.tool import FunctionTool, ToolSet +from astrbot.core.astr_agent_run_util import run_agent +from astrbot.core.message.components import Image +from astrbot.core.message.message_event_result import MessageChain from astrbot.core.provider.entities import LLMResponse, ProviderRequest, TokenUsage from astrbot.core.provider.provider import Provider @@ -127,6 +130,26 @@ class MockAbortableStreamProvider(MockProvider): ) +class MockFinalMediaStreamProvider(MockProvider): + async def text_chat_stream(self, **kwargs): + yield LLMResponse( + role="assistant", + is_chunk=True, + result_chain=MessageChain().message("draft "), + ) + yield LLMResponse( + role="assistant", + is_chunk=False, + result_chain=MessageChain( + chain=[ + Image.fromBase64( + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+a7d4AAAAASUVORK5CYII=" + ) + ] + ), + ) + + class MockHooks(BaseAgentRunHooks): """模拟钩子函数""" @@ -466,6 +489,41 @@ async def test_stop_signal_returns_aborted_and_persists_partial_message( assert runner.run_context.messages[-1].role == "assistant" +@pytest.mark.asyncio +async def test_run_agent_emits_final_media_chain_in_streaming_mode( + runner, provider_request, mock_tool_executor, mock_hooks +): + provider = MockFinalMediaStreamProvider() + mock_event = MockEvent("test:FriendMessage:stream_media", "u1") + mock_event.is_stopped = lambda: False + mock_event.get_extra = lambda *args, **kwargs: None + mock_event.set_extra = lambda *args, **kwargs: None + mock_event.get_platform_id = lambda: "webchat" + mock_event.get_platform_name = lambda: "webchat" + mock_event.send = AsyncMock() + mock_event.trace = AsyncMock() + mock_event.trace.record = lambda *args, **kwargs: None + run_context = ContextWrapper(context=MockAgentContext(mock_event)) + + await runner.reset( + provider=provider, + request=provider_request, + run_context=run_context, + tool_executor=mock_tool_executor, + agent_hooks=mock_hooks, + streaming=True, + ) + + emitted = [] + async for chain in run_agent(runner, max_step=5): + if chain is not None: + emitted.append(chain) + + assert any( + any(isinstance(comp, Image) for comp in chain.chain) for chain in emitted + ) + + @pytest.mark.asyncio async def test_tool_result_injects_follow_up_notice( runner, mock_provider, provider_request, mock_tool_executor, mock_hooks diff --git a/tests/unit/test_astr_main_agent.py b/tests/unit/test_astr_main_agent.py index 613185465..b8675831d 100644 --- a/tests/unit/test_astr_main_agent.py +++ b/tests/unit/test_astr_main_agent.py @@ -1046,6 +1046,103 @@ class TestBuildMainAgent: assert result is not None assert result.provider_request == existing_req + @pytest.mark.asyncio + async def test_build_main_agent_disables_streaming_for_webchat_gemini_image_output( + self, mock_event, mock_context, mock_provider + ): + """Test Gemini image output requests force non-streaming on webchat.""" + module = ama + mock_provider.provider_config = { + "id": "google_gemini", + "type": "googlegenai_chat_completion", + "gm_resp_image_modal": True, + "modalities": ["image", "tool_use"], + } + mock_provider.get_model.return_value = "gemini-3-pro-image-preview" + mock_event.get_platform_name.return_value = "webchat" + mock_context.get_provider_by_id.return_value = None + mock_context.get_using_provider.return_value = mock_provider + mock_context.get_config.return_value = {} + + conv_mgr = mock_context.conversation_manager + _setup_conversation_for_build(conv_mgr) + + with ( + patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls, + patch("astrbot.core.astr_main_agent.AstrAgentContext"), + ): + mock_runner = MagicMock() + mock_runner.reset = AsyncMock() + mock_runner_cls.return_value = mock_runner + + result = await module.build_main_agent( + event=mock_event, + plugin_context=mock_context, + config=module.MainAgentBuildConfig( + tool_call_timeout=60, + streaming_response=True, + ), + ) + + assert result is not None + assert mock_runner.reset.call_args.kwargs["streaming"] is False + + @pytest.mark.asyncio + async def test_build_main_agent_disables_streaming_for_webchat_image_output_model_metadata( + self, mock_event, mock_context, mock_provider + ): + """Test image-output model metadata forces non-streaming on webchat.""" + module = ama + mock_provider.provider_config = { + "id": "test-provider", + "type": "openai_chat_completion", + "modalities": ["image", "tool_use"], + } + mock_provider.get_model.return_value = "test-image-output-model" + mock_event.get_platform_name.return_value = "webchat" + mock_context.get_provider_by_id.return_value = None + mock_context.get_using_provider.return_value = mock_provider + mock_context.get_config.return_value = {} + + conv_mgr = mock_context.conversation_manager + _setup_conversation_for_build(conv_mgr) + + with ( + patch.dict( + "astrbot.core.astr_main_agent.LLM_METADATAS", + { + "test-image-output-model": { + "id": "test-image-output-model", + "reasoning": False, + "tool_call": False, + "knowledge": "none", + "release_date": "", + "modalities": {"input": ["text"], "output": ["text", "image"]}, + "open_weights": False, + "limit": {"context": 0, "output": 0}, + } + }, + clear=False, + ), + patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls, + patch("astrbot.core.astr_main_agent.AstrAgentContext"), + ): + mock_runner = MagicMock() + mock_runner.reset = AsyncMock() + mock_runner_cls.return_value = mock_runner + + result = await module.build_main_agent( + event=mock_event, + plugin_context=mock_context, + config=module.MainAgentBuildConfig( + tool_call_timeout=60, + streaming_response=True, + ), + ) + + assert result is not None + assert mock_runner.reset.call_args.kwargs["streaming"] is False + class TestHandleWebchat: """Tests for _handle_webchat function.""" diff --git a/tests/unit/test_gemini_source.py b/tests/unit/test_gemini_source.py new file mode 100644 index 000000000..5a0e8be57 --- /dev/null +++ b/tests/unit/test_gemini_source.py @@ -0,0 +1,44 @@ +from unittest.mock import patch + +import pytest + +from astrbot.core.provider.sources.gemini_source import ProviderGoogleGenAI + + +def _build_provider(provider_settings: dict) -> ProviderGoogleGenAI: + with patch.object(ProviderGoogleGenAI, "_init_client", lambda self: None): + return ProviderGoogleGenAI( + { + "key": ["test-key"], + "model": "gemini-3-pro-image-preview", + "timeout": 30, + "gm_safety_settings": {}, + }, + provider_settings, + ) + + +@pytest.mark.asyncio +async def test_prepare_query_config_keeps_image_modalities_for_non_stream_requests(): + provider = _build_provider({"streaming_response": True}) + + config = await provider._prepare_query_config( + payloads={"messages": [], "model": "gemini-3-pro-image-preview"}, + modalities=["TEXT", "IMAGE"], + streaming=False, + ) + + assert config.response_modalities == ["TEXT", "IMAGE"] + + +@pytest.mark.asyncio +async def test_prepare_query_config_downgrades_image_modalities_for_stream_requests(): + provider = _build_provider({"streaming_response": False}) + + config = await provider._prepare_query_config( + payloads={"messages": [], "model": "gemini-3-pro-image-preview"}, + modalities=["TEXT", "IMAGE"], + streaming=True, + ) + + assert config.response_modalities == ["TEXT"] diff --git a/tests/unit/test_internal_agent_sub_stage.py b/tests/unit/test_internal_agent_sub_stage.py new file mode 100644 index 000000000..1e6a10833 --- /dev/null +++ b/tests/unit/test_internal_agent_sub_stage.py @@ -0,0 +1,127 @@ +from contextlib import asynccontextmanager +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from astrbot.core.astr_main_agent import MainAgentBuildConfig, MainAgentBuildResult +from astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal import ( + InternalAgentSubStage, +) +from astrbot.core.provider.entities import LLMResponse, ProviderRequest + + +@asynccontextmanager +async def _dummy_session_lock(): + yield + + +async def _dummy_run_agent(*args, **kwargs): + yield None + + +async def _dummy_reset(): + return None + + +@pytest.mark.asyncio +async def test_internal_stage_uses_effective_runner_streaming_flag(): + stage = InternalAgentSubStage() + stage.ctx = MagicMock() + stage.ctx.plugin_manager.context = MagicMock() + stage.streaming_response = True + stage.unsupported_streaming_strategy = "realtime_segmenting" + stage.max_step = 5 + stage.show_tool_use = True + stage.show_tool_call_result = False + stage.show_reasoning = False + stage.main_agent_cfg = MainAgentBuildConfig(tool_call_timeout=60) + stage._save_to_history = AsyncMock() + + event = MagicMock() + event.message_str = "draw a shiba" + event.message_obj.message = [] + event.unified_msg_origin = "webchat:private:test" + event.get_extra.side_effect = lambda key, default=None: { + "enable_streaming": True, + "provider_request": None, + "action_type": None, + }.get(key, default) + event.platform_meta.support_streaming_message = True + event.send_typing = AsyncMock() + event.trace = MagicMock() + event.is_stopped.return_value = False + + agent_runner = MagicMock() + agent_runner.streaming = False + agent_runner.done.return_value = True + agent_runner.was_aborted.return_value = False + agent_runner.get_final_llm_resp.return_value = LLMResponse( + role="assistant", completion_text="done" + ) + agent_runner.run_context.messages = [] + agent_runner.stats.to_dict.return_value = {} + agent_runner.provider.get_model.return_value = "gemini-3-pro-image-preview" + agent_runner.provider.meta.return_value.type = "googlegenai_chat_completion" + + provider = MagicMock() + provider.provider_config = {"id": "google_gemini", "api_base": ""} + provider.get_model.return_value = "gemini-3-pro-image-preview" + + build_result = MainAgentBuildResult( + agent_runner=agent_runner, + provider_request=ProviderRequest(prompt="draw a shiba"), + provider=provider, + reset_coro=_dummy_reset(), + ) + + with ( + patch( + "astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.build_main_agent", + AsyncMock(return_value=build_result), + ), + patch( + "astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.run_agent", + _dummy_run_agent, + ), + patch( + "astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.call_event_hook", + AsyncMock(return_value=False), + ), + patch( + "astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.try_capture_follow_up", + return_value=None, + ), + patch( + "astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.register_active_runner" + ), + patch( + "astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.unregister_active_runner" + ), + patch( + "astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.session_lock_manager.acquire_lock", + return_value=_dummy_session_lock(), + ), + patch( + "astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.Metric.upload", + AsyncMock(return_value=None), + ), + patch( + "astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.asyncio.create_task", + side_effect=lambda coro: coro.close(), + ), + ): + yielded = [] + async for item in stage.process(event, provider_wake_prefix=""): + yielded.append(item) + + assert yielded == [None] + event.trace.record.assert_any_call( + "astr_agent_prepare", + system_prompt="", + tools=[], + stream=False, + chat_provider={ + "id": "google_gemini", + "model": "gemini-3-pro-image-preview", + }, + )