fix: handle webchat image outputs without streaming

This commit is contained in:
邹永赫
2026-03-15 16:13:13 +09:00
parent da1565ee81
commit d7457f38d4
8 changed files with 398 additions and 7 deletions
+20
View File
@@ -87,6 +87,21 @@ def _build_tool_result_status_message(
return status_msg
def _extract_final_streaming_chain(msg_chain: MessageChain) -> MessageChain | None:
if not msg_chain.chain:
return None
final_chain: list[BaseMessageComponent] = []
for comp in msg_chain.chain:
if isinstance(comp, Plain):
continue
final_chain.append(comp)
if not final_chain:
return None
return MessageChain(chain=final_chain, type=msg_chain.type)
async def run_agent(
agent_runner: AgentRunner,
max_step: int = 30,
@@ -211,6 +226,11 @@ async def run_agent(
# display the reasoning content only when configured
continue
yield resp.data["chain"] # MessageChain
elif resp.type == "llm_result":
if final_chain := _extract_final_streaming_chain(
resp.data["chain"]
):
yield final_chain
if not stop_watcher.done():
stop_watcher.cancel()
try:
+44 -1
View File
@@ -724,6 +724,38 @@ def _sanitize_context_by_modalities(
req.contexts = sanitized_contexts
def _model_outputs_image(provider: Provider, req: ProviderRequest) -> bool:
model = req.model or provider.get_model()
if not model:
return False
model_info = LLM_METADATAS.get(model)
if not model_info:
return False
output_modalities = model_info.get("modalities", {}).get("output", [])
return "image" in output_modalities
def _should_disable_streaming_for_webchat_output(
event: AstrMessageEvent,
provider: Provider,
req: ProviderRequest,
) -> bool:
if event.get_platform_name() != "webchat":
return False
provider_cfg = provider.provider_config
provider_type = provider_cfg.get("type", "")
if provider_type == "googlegenai_chat_completion" and provider_cfg.get(
"gm_resp_image_modal", False
):
return True
if _model_outputs_image(provider, req):
return not bool(provider_cfg.get("supports_streaming_output_modalities", False))
return False
def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None:
"""根据事件中的插件设置,过滤请求中的工具列表。
@@ -1091,6 +1123,17 @@ async def build_main_agent(
if action_type == "live":
req.system_prompt += f"\n{LIVE_MODE_SYSTEM_PROMPT}\n"
streaming_response = config.streaming_response
if streaming_response and _should_disable_streaming_for_webchat_output(
event, provider, req
):
logger.info(
"Disable streaming for webchat direct media output. provider=%s model=%s",
provider.provider_config.get("id", "unknown"),
req.model or provider.get_model(),
)
streaming_response = False
reset_coro = agent_runner.reset(
provider=provider,
request=req,
@@ -1100,7 +1143,7 @@ async def build_main_agent(
),
tool_executor=FunctionToolExecutor(),
agent_hooks=MAIN_AGENT_HOOKS,
streaming=config.streaming_response,
streaming=streaming_response,
llm_compress_instruction=config.llm_compress_instruction,
llm_compress_keep_recent=config.llm_compress_keep_recent,
llm_compress_provider=_get_compress_provider(config, plugin_context),
@@ -239,6 +239,8 @@ class InternalAgentSubStage(Stage):
if reset_coro:
await reset_coro
effective_streaming_response = bool(agent_runner.streaming)
register_active_runner(event.unified_msg_origin, agent_runner)
runner_registered = True
action_type = event.get_extra("action_type")
@@ -247,7 +249,7 @@ class InternalAgentSubStage(Stage):
"astr_agent_prepare",
system_prompt=req.system_prompt,
tools=req.func_tool.names() if req.func_tool else [],
stream=streaming_response,
stream=effective_streaming_response,
chat_provider={
"id": provider.provider_config.get("id", ""),
"model": provider.get_model(),
@@ -301,7 +303,7 @@ class InternalAgentSubStage(Stage):
user_aborted=agent_runner.was_aborted(),
)
elif streaming_response and not stream_to_general:
elif effective_streaming_response and not stream_to_general:
# 流式响应
event.set_result(
MessageEventResult()
@@ -134,16 +134,14 @@ class ProviderGoogleGenAI(Provider):
system_instruction: str | None = None,
modalities: list[str] | None = None,
temperature: float = 0.7,
streaming: bool = False,
) -> types.GenerateContentConfig:
"""准备查询配置"""
if not modalities:
modalities = ["TEXT"]
# 流式输出不支持图片模态
if (
self.provider_settings.get("streaming_response", False)
and "IMAGE" in modalities
):
if streaming and "IMAGE" in modalities:
logger.warning("流式输出不支持图片模态,已自动降级为文本模态")
modalities = ["TEXT"]
@@ -538,6 +536,7 @@ class ProviderGoogleGenAI(Provider):
system_instruction,
modalities,
temperature,
streaming=False,
)
result = await self.client.models.generate_content(
model=model,
@@ -617,6 +616,7 @@ class ProviderGoogleGenAI(Provider):
payloads,
tools,
system_instruction,
streaming=True,
)
result = await self.client.models.generate_content_stream(
model=model,
+58
View File
@@ -11,6 +11,9 @@ from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
from astrbot.core.agent.tool import FunctionTool, ToolSet
from astrbot.core.astr_agent_run_util import run_agent
from astrbot.core.message.components import Image
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import LLMResponse, ProviderRequest, TokenUsage
from astrbot.core.provider.provider import Provider
@@ -127,6 +130,26 @@ class MockAbortableStreamProvider(MockProvider):
)
class MockFinalMediaStreamProvider(MockProvider):
async def text_chat_stream(self, **kwargs):
yield LLMResponse(
role="assistant",
is_chunk=True,
result_chain=MessageChain().message("draft "),
)
yield LLMResponse(
role="assistant",
is_chunk=False,
result_chain=MessageChain(
chain=[
Image.fromBase64(
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+a7d4AAAAASUVORK5CYII="
)
]
),
)
class MockHooks(BaseAgentRunHooks):
"""模拟钩子函数"""
@@ -466,6 +489,41 @@ async def test_stop_signal_returns_aborted_and_persists_partial_message(
assert runner.run_context.messages[-1].role == "assistant"
@pytest.mark.asyncio
async def test_run_agent_emits_final_media_chain_in_streaming_mode(
runner, provider_request, mock_tool_executor, mock_hooks
):
provider = MockFinalMediaStreamProvider()
mock_event = MockEvent("test:FriendMessage:stream_media", "u1")
mock_event.is_stopped = lambda: False
mock_event.get_extra = lambda *args, **kwargs: None
mock_event.set_extra = lambda *args, **kwargs: None
mock_event.get_platform_id = lambda: "webchat"
mock_event.get_platform_name = lambda: "webchat"
mock_event.send = AsyncMock()
mock_event.trace = AsyncMock()
mock_event.trace.record = lambda *args, **kwargs: None
run_context = ContextWrapper(context=MockAgentContext(mock_event))
await runner.reset(
provider=provider,
request=provider_request,
run_context=run_context,
tool_executor=mock_tool_executor,
agent_hooks=mock_hooks,
streaming=True,
)
emitted = []
async for chain in run_agent(runner, max_step=5):
if chain is not None:
emitted.append(chain)
assert any(
any(isinstance(comp, Image) for comp in chain.chain) for chain in emitted
)
@pytest.mark.asyncio
async def test_tool_result_injects_follow_up_notice(
runner, mock_provider, provider_request, mock_tool_executor, mock_hooks
+97
View File
@@ -1046,6 +1046,103 @@ class TestBuildMainAgent:
assert result is not None
assert result.provider_request == existing_req
@pytest.mark.asyncio
async def test_build_main_agent_disables_streaming_for_webchat_gemini_image_output(
self, mock_event, mock_context, mock_provider
):
"""Test Gemini image output requests force non-streaming on webchat."""
module = ama
mock_provider.provider_config = {
"id": "google_gemini",
"type": "googlegenai_chat_completion",
"gm_resp_image_modal": True,
"modalities": ["image", "tool_use"],
}
mock_provider.get_model.return_value = "gemini-3-pro-image-preview"
mock_event.get_platform_name.return_value = "webchat"
mock_context.get_provider_by_id.return_value = None
mock_context.get_using_provider.return_value = mock_provider
mock_context.get_config.return_value = {}
conv_mgr = mock_context.conversation_manager
_setup_conversation_for_build(conv_mgr)
with (
patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls,
patch("astrbot.core.astr_main_agent.AstrAgentContext"),
):
mock_runner = MagicMock()
mock_runner.reset = AsyncMock()
mock_runner_cls.return_value = mock_runner
result = await module.build_main_agent(
event=mock_event,
plugin_context=mock_context,
config=module.MainAgentBuildConfig(
tool_call_timeout=60,
streaming_response=True,
),
)
assert result is not None
assert mock_runner.reset.call_args.kwargs["streaming"] is False
@pytest.mark.asyncio
async def test_build_main_agent_disables_streaming_for_webchat_image_output_model_metadata(
self, mock_event, mock_context, mock_provider
):
"""Test image-output model metadata forces non-streaming on webchat."""
module = ama
mock_provider.provider_config = {
"id": "test-provider",
"type": "openai_chat_completion",
"modalities": ["image", "tool_use"],
}
mock_provider.get_model.return_value = "test-image-output-model"
mock_event.get_platform_name.return_value = "webchat"
mock_context.get_provider_by_id.return_value = None
mock_context.get_using_provider.return_value = mock_provider
mock_context.get_config.return_value = {}
conv_mgr = mock_context.conversation_manager
_setup_conversation_for_build(conv_mgr)
with (
patch.dict(
"astrbot.core.astr_main_agent.LLM_METADATAS",
{
"test-image-output-model": {
"id": "test-image-output-model",
"reasoning": False,
"tool_call": False,
"knowledge": "none",
"release_date": "",
"modalities": {"input": ["text"], "output": ["text", "image"]},
"open_weights": False,
"limit": {"context": 0, "output": 0},
}
},
clear=False,
),
patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls,
patch("astrbot.core.astr_main_agent.AstrAgentContext"),
):
mock_runner = MagicMock()
mock_runner.reset = AsyncMock()
mock_runner_cls.return_value = mock_runner
result = await module.build_main_agent(
event=mock_event,
plugin_context=mock_context,
config=module.MainAgentBuildConfig(
tool_call_timeout=60,
streaming_response=True,
),
)
assert result is not None
assert mock_runner.reset.call_args.kwargs["streaming"] is False
class TestHandleWebchat:
"""Tests for _handle_webchat function."""
+44
View File
@@ -0,0 +1,44 @@
from unittest.mock import patch
import pytest
from astrbot.core.provider.sources.gemini_source import ProviderGoogleGenAI
def _build_provider(provider_settings: dict) -> ProviderGoogleGenAI:
with patch.object(ProviderGoogleGenAI, "_init_client", lambda self: None):
return ProviderGoogleGenAI(
{
"key": ["test-key"],
"model": "gemini-3-pro-image-preview",
"timeout": 30,
"gm_safety_settings": {},
},
provider_settings,
)
@pytest.mark.asyncio
async def test_prepare_query_config_keeps_image_modalities_for_non_stream_requests():
provider = _build_provider({"streaming_response": True})
config = await provider._prepare_query_config(
payloads={"messages": [], "model": "gemini-3-pro-image-preview"},
modalities=["TEXT", "IMAGE"],
streaming=False,
)
assert config.response_modalities == ["TEXT", "IMAGE"]
@pytest.mark.asyncio
async def test_prepare_query_config_downgrades_image_modalities_for_stream_requests():
provider = _build_provider({"streaming_response": False})
config = await provider._prepare_query_config(
payloads={"messages": [], "model": "gemini-3-pro-image-preview"},
modalities=["TEXT", "IMAGE"],
streaming=True,
)
assert config.response_modalities == ["TEXT"]
+127
View File
@@ -0,0 +1,127 @@
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from astrbot.core.astr_main_agent import MainAgentBuildConfig, MainAgentBuildResult
from astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal import (
InternalAgentSubStage,
)
from astrbot.core.provider.entities import LLMResponse, ProviderRequest
@asynccontextmanager
async def _dummy_session_lock():
yield
async def _dummy_run_agent(*args, **kwargs):
yield None
async def _dummy_reset():
return None
@pytest.mark.asyncio
async def test_internal_stage_uses_effective_runner_streaming_flag():
stage = InternalAgentSubStage()
stage.ctx = MagicMock()
stage.ctx.plugin_manager.context = MagicMock()
stage.streaming_response = True
stage.unsupported_streaming_strategy = "realtime_segmenting"
stage.max_step = 5
stage.show_tool_use = True
stage.show_tool_call_result = False
stage.show_reasoning = False
stage.main_agent_cfg = MainAgentBuildConfig(tool_call_timeout=60)
stage._save_to_history = AsyncMock()
event = MagicMock()
event.message_str = "draw a shiba"
event.message_obj.message = []
event.unified_msg_origin = "webchat:private:test"
event.get_extra.side_effect = lambda key, default=None: {
"enable_streaming": True,
"provider_request": None,
"action_type": None,
}.get(key, default)
event.platform_meta.support_streaming_message = True
event.send_typing = AsyncMock()
event.trace = MagicMock()
event.is_stopped.return_value = False
agent_runner = MagicMock()
agent_runner.streaming = False
agent_runner.done.return_value = True
agent_runner.was_aborted.return_value = False
agent_runner.get_final_llm_resp.return_value = LLMResponse(
role="assistant", completion_text="done"
)
agent_runner.run_context.messages = []
agent_runner.stats.to_dict.return_value = {}
agent_runner.provider.get_model.return_value = "gemini-3-pro-image-preview"
agent_runner.provider.meta.return_value.type = "googlegenai_chat_completion"
provider = MagicMock()
provider.provider_config = {"id": "google_gemini", "api_base": ""}
provider.get_model.return_value = "gemini-3-pro-image-preview"
build_result = MainAgentBuildResult(
agent_runner=agent_runner,
provider_request=ProviderRequest(prompt="draw a shiba"),
provider=provider,
reset_coro=_dummy_reset(),
)
with (
patch(
"astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.build_main_agent",
AsyncMock(return_value=build_result),
),
patch(
"astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.run_agent",
_dummy_run_agent,
),
patch(
"astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.call_event_hook",
AsyncMock(return_value=False),
),
patch(
"astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.try_capture_follow_up",
return_value=None,
),
patch(
"astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.register_active_runner"
),
patch(
"astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.unregister_active_runner"
),
patch(
"astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.session_lock_manager.acquire_lock",
return_value=_dummy_session_lock(),
),
patch(
"astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.Metric.upload",
AsyncMock(return_value=None),
),
patch(
"astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.asyncio.create_task",
side_effect=lambda coro: coro.close(),
),
):
yielded = []
async for item in stage.process(event, provider_wake_prefix=""):
yielded.append(item)
assert yielded == [None]
event.trace.record.assert_any_call(
"astr_agent_prepare",
system_prompt="",
tools=[],
stream=False,
chat_provider={
"id": "google_gemini",
"model": "gemini-3-pro-image-preview",
},
)