fix: handle webchat image outputs without streaming
This commit is contained in:
@@ -87,6 +87,21 @@ def _build_tool_result_status_message(
|
||||
return status_msg
|
||||
|
||||
|
||||
def _extract_final_streaming_chain(msg_chain: MessageChain) -> MessageChain | None:
|
||||
if not msg_chain.chain:
|
||||
return None
|
||||
|
||||
final_chain: list[BaseMessageComponent] = []
|
||||
for comp in msg_chain.chain:
|
||||
if isinstance(comp, Plain):
|
||||
continue
|
||||
final_chain.append(comp)
|
||||
|
||||
if not final_chain:
|
||||
return None
|
||||
return MessageChain(chain=final_chain, type=msg_chain.type)
|
||||
|
||||
|
||||
async def run_agent(
|
||||
agent_runner: AgentRunner,
|
||||
max_step: int = 30,
|
||||
@@ -211,6 +226,11 @@ async def run_agent(
|
||||
# display the reasoning content only when configured
|
||||
continue
|
||||
yield resp.data["chain"] # MessageChain
|
||||
elif resp.type == "llm_result":
|
||||
if final_chain := _extract_final_streaming_chain(
|
||||
resp.data["chain"]
|
||||
):
|
||||
yield final_chain
|
||||
if not stop_watcher.done():
|
||||
stop_watcher.cancel()
|
||||
try:
|
||||
|
||||
@@ -724,6 +724,38 @@ def _sanitize_context_by_modalities(
|
||||
req.contexts = sanitized_contexts
|
||||
|
||||
|
||||
def _model_outputs_image(provider: Provider, req: ProviderRequest) -> bool:
|
||||
model = req.model or provider.get_model()
|
||||
if not model:
|
||||
return False
|
||||
model_info = LLM_METADATAS.get(model)
|
||||
if not model_info:
|
||||
return False
|
||||
output_modalities = model_info.get("modalities", {}).get("output", [])
|
||||
return "image" in output_modalities
|
||||
|
||||
|
||||
def _should_disable_streaming_for_webchat_output(
|
||||
event: AstrMessageEvent,
|
||||
provider: Provider,
|
||||
req: ProviderRequest,
|
||||
) -> bool:
|
||||
if event.get_platform_name() != "webchat":
|
||||
return False
|
||||
|
||||
provider_cfg = provider.provider_config
|
||||
provider_type = provider_cfg.get("type", "")
|
||||
if provider_type == "googlegenai_chat_completion" and provider_cfg.get(
|
||||
"gm_resp_image_modal", False
|
||||
):
|
||||
return True
|
||||
|
||||
if _model_outputs_image(provider, req):
|
||||
return not bool(provider_cfg.get("supports_streaming_output_modalities", False))
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None:
|
||||
"""根据事件中的插件设置,过滤请求中的工具列表。
|
||||
|
||||
@@ -1091,6 +1123,17 @@ async def build_main_agent(
|
||||
if action_type == "live":
|
||||
req.system_prompt += f"\n{LIVE_MODE_SYSTEM_PROMPT}\n"
|
||||
|
||||
streaming_response = config.streaming_response
|
||||
if streaming_response and _should_disable_streaming_for_webchat_output(
|
||||
event, provider, req
|
||||
):
|
||||
logger.info(
|
||||
"Disable streaming for webchat direct media output. provider=%s model=%s",
|
||||
provider.provider_config.get("id", "unknown"),
|
||||
req.model or provider.get_model(),
|
||||
)
|
||||
streaming_response = False
|
||||
|
||||
reset_coro = agent_runner.reset(
|
||||
provider=provider,
|
||||
request=req,
|
||||
@@ -1100,7 +1143,7 @@ async def build_main_agent(
|
||||
),
|
||||
tool_executor=FunctionToolExecutor(),
|
||||
agent_hooks=MAIN_AGENT_HOOKS,
|
||||
streaming=config.streaming_response,
|
||||
streaming=streaming_response,
|
||||
llm_compress_instruction=config.llm_compress_instruction,
|
||||
llm_compress_keep_recent=config.llm_compress_keep_recent,
|
||||
llm_compress_provider=_get_compress_provider(config, plugin_context),
|
||||
|
||||
@@ -239,6 +239,8 @@ class InternalAgentSubStage(Stage):
|
||||
if reset_coro:
|
||||
await reset_coro
|
||||
|
||||
effective_streaming_response = bool(agent_runner.streaming)
|
||||
|
||||
register_active_runner(event.unified_msg_origin, agent_runner)
|
||||
runner_registered = True
|
||||
action_type = event.get_extra("action_type")
|
||||
@@ -247,7 +249,7 @@ class InternalAgentSubStage(Stage):
|
||||
"astr_agent_prepare",
|
||||
system_prompt=req.system_prompt,
|
||||
tools=req.func_tool.names() if req.func_tool else [],
|
||||
stream=streaming_response,
|
||||
stream=effective_streaming_response,
|
||||
chat_provider={
|
||||
"id": provider.provider_config.get("id", ""),
|
||||
"model": provider.get_model(),
|
||||
@@ -301,7 +303,7 @@ class InternalAgentSubStage(Stage):
|
||||
user_aborted=agent_runner.was_aborted(),
|
||||
)
|
||||
|
||||
elif streaming_response and not stream_to_general:
|
||||
elif effective_streaming_response and not stream_to_general:
|
||||
# 流式响应
|
||||
event.set_result(
|
||||
MessageEventResult()
|
||||
|
||||
@@ -134,16 +134,14 @@ class ProviderGoogleGenAI(Provider):
|
||||
system_instruction: str | None = None,
|
||||
modalities: list[str] | None = None,
|
||||
temperature: float = 0.7,
|
||||
streaming: bool = False,
|
||||
) -> types.GenerateContentConfig:
|
||||
"""准备查询配置"""
|
||||
if not modalities:
|
||||
modalities = ["TEXT"]
|
||||
|
||||
# 流式输出不支持图片模态
|
||||
if (
|
||||
self.provider_settings.get("streaming_response", False)
|
||||
and "IMAGE" in modalities
|
||||
):
|
||||
if streaming and "IMAGE" in modalities:
|
||||
logger.warning("流式输出不支持图片模态,已自动降级为文本模态")
|
||||
modalities = ["TEXT"]
|
||||
|
||||
@@ -538,6 +536,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
system_instruction,
|
||||
modalities,
|
||||
temperature,
|
||||
streaming=False,
|
||||
)
|
||||
result = await self.client.models.generate_content(
|
||||
model=model,
|
||||
@@ -617,6 +616,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
payloads,
|
||||
tools,
|
||||
system_instruction,
|
||||
streaming=True,
|
||||
)
|
||||
result = await self.client.models.generate_content_stream(
|
||||
model=model,
|
||||
|
||||
@@ -11,6 +11,9 @@ from astrbot.core.agent.hooks import BaseAgentRunHooks
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
|
||||
from astrbot.core.agent.tool import FunctionTool, ToolSet
|
||||
from astrbot.core.astr_agent_run_util import run_agent
|
||||
from astrbot.core.message.components import Image
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import LLMResponse, ProviderRequest, TokenUsage
|
||||
from astrbot.core.provider.provider import Provider
|
||||
|
||||
@@ -127,6 +130,26 @@ class MockAbortableStreamProvider(MockProvider):
|
||||
)
|
||||
|
||||
|
||||
class MockFinalMediaStreamProvider(MockProvider):
|
||||
async def text_chat_stream(self, **kwargs):
|
||||
yield LLMResponse(
|
||||
role="assistant",
|
||||
is_chunk=True,
|
||||
result_chain=MessageChain().message("draft "),
|
||||
)
|
||||
yield LLMResponse(
|
||||
role="assistant",
|
||||
is_chunk=False,
|
||||
result_chain=MessageChain(
|
||||
chain=[
|
||||
Image.fromBase64(
|
||||
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+a7d4AAAAASUVORK5CYII="
|
||||
)
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class MockHooks(BaseAgentRunHooks):
|
||||
"""模拟钩子函数"""
|
||||
|
||||
@@ -466,6 +489,41 @@ async def test_stop_signal_returns_aborted_and_persists_partial_message(
|
||||
assert runner.run_context.messages[-1].role == "assistant"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_emits_final_media_chain_in_streaming_mode(
|
||||
runner, provider_request, mock_tool_executor, mock_hooks
|
||||
):
|
||||
provider = MockFinalMediaStreamProvider()
|
||||
mock_event = MockEvent("test:FriendMessage:stream_media", "u1")
|
||||
mock_event.is_stopped = lambda: False
|
||||
mock_event.get_extra = lambda *args, **kwargs: None
|
||||
mock_event.set_extra = lambda *args, **kwargs: None
|
||||
mock_event.get_platform_id = lambda: "webchat"
|
||||
mock_event.get_platform_name = lambda: "webchat"
|
||||
mock_event.send = AsyncMock()
|
||||
mock_event.trace = AsyncMock()
|
||||
mock_event.trace.record = lambda *args, **kwargs: None
|
||||
run_context = ContextWrapper(context=MockAgentContext(mock_event))
|
||||
|
||||
await runner.reset(
|
||||
provider=provider,
|
||||
request=provider_request,
|
||||
run_context=run_context,
|
||||
tool_executor=mock_tool_executor,
|
||||
agent_hooks=mock_hooks,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
emitted = []
|
||||
async for chain in run_agent(runner, max_step=5):
|
||||
if chain is not None:
|
||||
emitted.append(chain)
|
||||
|
||||
assert any(
|
||||
any(isinstance(comp, Image) for comp in chain.chain) for chain in emitted
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_result_injects_follow_up_notice(
|
||||
runner, mock_provider, provider_request, mock_tool_executor, mock_hooks
|
||||
|
||||
@@ -1046,6 +1046,103 @@ class TestBuildMainAgent:
|
||||
assert result is not None
|
||||
assert result.provider_request == existing_req
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_main_agent_disables_streaming_for_webchat_gemini_image_output(
|
||||
self, mock_event, mock_context, mock_provider
|
||||
):
|
||||
"""Test Gemini image output requests force non-streaming on webchat."""
|
||||
module = ama
|
||||
mock_provider.provider_config = {
|
||||
"id": "google_gemini",
|
||||
"type": "googlegenai_chat_completion",
|
||||
"gm_resp_image_modal": True,
|
||||
"modalities": ["image", "tool_use"],
|
||||
}
|
||||
mock_provider.get_model.return_value = "gemini-3-pro-image-preview"
|
||||
mock_event.get_platform_name.return_value = "webchat"
|
||||
mock_context.get_provider_by_id.return_value = None
|
||||
mock_context.get_using_provider.return_value = mock_provider
|
||||
mock_context.get_config.return_value = {}
|
||||
|
||||
conv_mgr = mock_context.conversation_manager
|
||||
_setup_conversation_for_build(conv_mgr)
|
||||
|
||||
with (
|
||||
patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls,
|
||||
patch("astrbot.core.astr_main_agent.AstrAgentContext"),
|
||||
):
|
||||
mock_runner = MagicMock()
|
||||
mock_runner.reset = AsyncMock()
|
||||
mock_runner_cls.return_value = mock_runner
|
||||
|
||||
result = await module.build_main_agent(
|
||||
event=mock_event,
|
||||
plugin_context=mock_context,
|
||||
config=module.MainAgentBuildConfig(
|
||||
tool_call_timeout=60,
|
||||
streaming_response=True,
|
||||
),
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert mock_runner.reset.call_args.kwargs["streaming"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_main_agent_disables_streaming_for_webchat_image_output_model_metadata(
|
||||
self, mock_event, mock_context, mock_provider
|
||||
):
|
||||
"""Test image-output model metadata forces non-streaming on webchat."""
|
||||
module = ama
|
||||
mock_provider.provider_config = {
|
||||
"id": "test-provider",
|
||||
"type": "openai_chat_completion",
|
||||
"modalities": ["image", "tool_use"],
|
||||
}
|
||||
mock_provider.get_model.return_value = "test-image-output-model"
|
||||
mock_event.get_platform_name.return_value = "webchat"
|
||||
mock_context.get_provider_by_id.return_value = None
|
||||
mock_context.get_using_provider.return_value = mock_provider
|
||||
mock_context.get_config.return_value = {}
|
||||
|
||||
conv_mgr = mock_context.conversation_manager
|
||||
_setup_conversation_for_build(conv_mgr)
|
||||
|
||||
with (
|
||||
patch.dict(
|
||||
"astrbot.core.astr_main_agent.LLM_METADATAS",
|
||||
{
|
||||
"test-image-output-model": {
|
||||
"id": "test-image-output-model",
|
||||
"reasoning": False,
|
||||
"tool_call": False,
|
||||
"knowledge": "none",
|
||||
"release_date": "",
|
||||
"modalities": {"input": ["text"], "output": ["text", "image"]},
|
||||
"open_weights": False,
|
||||
"limit": {"context": 0, "output": 0},
|
||||
}
|
||||
},
|
||||
clear=False,
|
||||
),
|
||||
patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls,
|
||||
patch("astrbot.core.astr_main_agent.AstrAgentContext"),
|
||||
):
|
||||
mock_runner = MagicMock()
|
||||
mock_runner.reset = AsyncMock()
|
||||
mock_runner_cls.return_value = mock_runner
|
||||
|
||||
result = await module.build_main_agent(
|
||||
event=mock_event,
|
||||
plugin_context=mock_context,
|
||||
config=module.MainAgentBuildConfig(
|
||||
tool_call_timeout=60,
|
||||
streaming_response=True,
|
||||
),
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert mock_runner.reset.call_args.kwargs["streaming"] is False
|
||||
|
||||
|
||||
class TestHandleWebchat:
|
||||
"""Tests for _handle_webchat function."""
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from astrbot.core.provider.sources.gemini_source import ProviderGoogleGenAI
|
||||
|
||||
|
||||
def _build_provider(provider_settings: dict) -> ProviderGoogleGenAI:
|
||||
with patch.object(ProviderGoogleGenAI, "_init_client", lambda self: None):
|
||||
return ProviderGoogleGenAI(
|
||||
{
|
||||
"key": ["test-key"],
|
||||
"model": "gemini-3-pro-image-preview",
|
||||
"timeout": 30,
|
||||
"gm_safety_settings": {},
|
||||
},
|
||||
provider_settings,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_query_config_keeps_image_modalities_for_non_stream_requests():
|
||||
provider = _build_provider({"streaming_response": True})
|
||||
|
||||
config = await provider._prepare_query_config(
|
||||
payloads={"messages": [], "model": "gemini-3-pro-image-preview"},
|
||||
modalities=["TEXT", "IMAGE"],
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
assert config.response_modalities == ["TEXT", "IMAGE"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_query_config_downgrades_image_modalities_for_stream_requests():
|
||||
provider = _build_provider({"streaming_response": False})
|
||||
|
||||
config = await provider._prepare_query_config(
|
||||
payloads={"messages": [], "model": "gemini-3-pro-image-preview"},
|
||||
modalities=["TEXT", "IMAGE"],
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
assert config.response_modalities == ["TEXT"]
|
||||
@@ -0,0 +1,127 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from astrbot.core.astr_main_agent import MainAgentBuildConfig, MainAgentBuildResult
|
||||
from astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal import (
|
||||
InternalAgentSubStage,
|
||||
)
|
||||
from astrbot.core.provider.entities import LLMResponse, ProviderRequest
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _dummy_session_lock():
|
||||
yield
|
||||
|
||||
|
||||
async def _dummy_run_agent(*args, **kwargs):
|
||||
yield None
|
||||
|
||||
|
||||
async def _dummy_reset():
|
||||
return None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_internal_stage_uses_effective_runner_streaming_flag():
|
||||
stage = InternalAgentSubStage()
|
||||
stage.ctx = MagicMock()
|
||||
stage.ctx.plugin_manager.context = MagicMock()
|
||||
stage.streaming_response = True
|
||||
stage.unsupported_streaming_strategy = "realtime_segmenting"
|
||||
stage.max_step = 5
|
||||
stage.show_tool_use = True
|
||||
stage.show_tool_call_result = False
|
||||
stage.show_reasoning = False
|
||||
stage.main_agent_cfg = MainAgentBuildConfig(tool_call_timeout=60)
|
||||
stage._save_to_history = AsyncMock()
|
||||
|
||||
event = MagicMock()
|
||||
event.message_str = "draw a shiba"
|
||||
event.message_obj.message = []
|
||||
event.unified_msg_origin = "webchat:private:test"
|
||||
event.get_extra.side_effect = lambda key, default=None: {
|
||||
"enable_streaming": True,
|
||||
"provider_request": None,
|
||||
"action_type": None,
|
||||
}.get(key, default)
|
||||
event.platform_meta.support_streaming_message = True
|
||||
event.send_typing = AsyncMock()
|
||||
event.trace = MagicMock()
|
||||
event.is_stopped.return_value = False
|
||||
|
||||
agent_runner = MagicMock()
|
||||
agent_runner.streaming = False
|
||||
agent_runner.done.return_value = True
|
||||
agent_runner.was_aborted.return_value = False
|
||||
agent_runner.get_final_llm_resp.return_value = LLMResponse(
|
||||
role="assistant", completion_text="done"
|
||||
)
|
||||
agent_runner.run_context.messages = []
|
||||
agent_runner.stats.to_dict.return_value = {}
|
||||
agent_runner.provider.get_model.return_value = "gemini-3-pro-image-preview"
|
||||
agent_runner.provider.meta.return_value.type = "googlegenai_chat_completion"
|
||||
|
||||
provider = MagicMock()
|
||||
provider.provider_config = {"id": "google_gemini", "api_base": ""}
|
||||
provider.get_model.return_value = "gemini-3-pro-image-preview"
|
||||
|
||||
build_result = MainAgentBuildResult(
|
||||
agent_runner=agent_runner,
|
||||
provider_request=ProviderRequest(prompt="draw a shiba"),
|
||||
provider=provider,
|
||||
reset_coro=_dummy_reset(),
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.build_main_agent",
|
||||
AsyncMock(return_value=build_result),
|
||||
),
|
||||
patch(
|
||||
"astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.run_agent",
|
||||
_dummy_run_agent,
|
||||
),
|
||||
patch(
|
||||
"astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.call_event_hook",
|
||||
AsyncMock(return_value=False),
|
||||
),
|
||||
patch(
|
||||
"astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.try_capture_follow_up",
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.register_active_runner"
|
||||
),
|
||||
patch(
|
||||
"astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.unregister_active_runner"
|
||||
),
|
||||
patch(
|
||||
"astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.session_lock_manager.acquire_lock",
|
||||
return_value=_dummy_session_lock(),
|
||||
),
|
||||
patch(
|
||||
"astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.Metric.upload",
|
||||
AsyncMock(return_value=None),
|
||||
),
|
||||
patch(
|
||||
"astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.asyncio.create_task",
|
||||
side_effect=lambda coro: coro.close(),
|
||||
),
|
||||
):
|
||||
yielded = []
|
||||
async for item in stage.process(event, provider_wake_prefix=""):
|
||||
yielded.append(item)
|
||||
|
||||
assert yielded == [None]
|
||||
event.trace.record.assert_any_call(
|
||||
"astr_agent_prepare",
|
||||
system_prompt="",
|
||||
tools=[],
|
||||
stream=False,
|
||||
chat_provider={
|
||||
"id": "google_gemini",
|
||||
"model": "gemini-3-pro-image-preview",
|
||||
},
|
||||
)
|
||||
Reference in New Issue
Block a user