feat: 允许 LLM 预览工具返回的图片并自主决定是否发送 (#4895)

* feat: 允许 LLM 预览工具返回的图片并自主决定是否发送

* 复用 send_message_to_user 替代独立的图片发送工具

* feat: implement _HandleFunctionToolsResult class for improved tool response handling

* docs: add path handling guidelines to AGENTS.md

---------

Co-authored-by: Soulter <905617992@qq.com>
This commit is contained in:
Gao Jinzhe
2026-02-08 13:16:16 +08:00
committed by GitHub
parent 4e0b5063c6
commit 952023db30
3 changed files with 301 additions and 40 deletions
+1
View File
@@ -26,6 +26,7 @@ Runs on `http://localhost:3000` by default.
3. After finishing, use `ruff format .` and `ruff check .` to format and check the code. 3. After finishing, use `ruff format .` and `ruff check .` to format and check the code.
4. When committing, ensure to use conventional commits messages, such as `feat: add new agent for data analysis` or `fix: resolve bug in provider manager`. 4. When committing, ensure to use conventional commits messages, such as `feat: add new agent for data analysis` or `fix: resolve bug in provider manager`.
5. Use English for all new comments. 5. Use English for all new comments.
6. For path handling, use `pathlib.Path` instead of string paths, and use `astrbot.core.utils.astrbot_path` to get the AstrBot data and temp directories.
## PR instructions ## PR instructions
@@ -3,6 +3,7 @@ import sys
import time import time
import traceback import traceback
import typing as T import typing as T
from dataclasses import dataclass
from mcp.types import ( from mcp.types import (
BlobResourceContents, BlobResourceContents,
@@ -14,8 +15,9 @@ from mcp.types import (
) )
from astrbot import logger from astrbot import logger
from astrbot.core.agent.message import TextPart, ThinkPart from astrbot.core.agent.message import ImageURLPart, TextPart, ThinkPart
from astrbot.core.agent.tool import ToolSet from astrbot.core.agent.tool import ToolSet
from astrbot.core.agent.tool_image_cache import tool_image_cache
from astrbot.core.message.components import Json from astrbot.core.message.components import Json
from astrbot.core.message.message_event_result import ( from astrbot.core.message.message_event_result import (
MessageChain, MessageChain,
@@ -44,6 +46,28 @@ else:
from typing_extensions import override from typing_extensions import override
@dataclass(slots=True)
class _HandleFunctionToolsResult:
    """Tagged result yielded while handling function tool calls.

    Exactly one payload field is populated by each factory below; ``kind``
    names which one, so the consumer can dispatch on ``kind`` instead of
    inspecting the payload's runtime type.
    """

    # Discriminator naming the payload field that carries the data.
    kind: T.Literal["message_chain", "tool_call_result_blocks", "cached_image"]
    # Set when kind == "message_chain".
    message_chain: MessageChain | None = None
    # Set when kind == "tool_call_result_blocks".
    tool_call_result_blocks: list[ToolCallMessageSegment] | None = None
    # Set when kind == "cached_image"; holds whatever the image cache returned.
    cached_image: T.Any = None

    @classmethod
    def from_message_chain(cls, chain: MessageChain) -> "_HandleFunctionToolsResult":
        """Wrap a message chain in a ``message_chain`` result."""
        result = cls(kind="message_chain")
        result.message_chain = chain
        return result

    @classmethod
    def from_tool_call_result_blocks(
        cls, blocks: list[ToolCallMessageSegment]
    ) -> "_HandleFunctionToolsResult":
        """Wrap the collected tool call segments in a ``tool_call_result_blocks`` result."""
        result = cls(kind="tool_call_result_blocks")
        result.tool_call_result_blocks = blocks
        return result

    @classmethod
    def from_cached_image(cls, image: T.Any) -> "_HandleFunctionToolsResult":
        """Wrap a cached image record in a ``cached_image`` result."""
        result = cls(kind="cached_image")
        result.cached_image = image
        return result
class ToolLoopAgentRunner(BaseAgentRunner[TContext]): class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
@override @override
async def reset( async def reset(
@@ -286,20 +310,27 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
llm_resp, _ = await self._resolve_tool_exec(llm_resp) llm_resp, _ = await self._resolve_tool_exec(llm_resp)
tool_call_result_blocks = [] tool_call_result_blocks = []
cached_images = [] # Collect cached images for LLM visibility
async for result in self._handle_function_tools(self.req, llm_resp): async for result in self._handle_function_tools(self.req, llm_resp):
if isinstance(result, list): if result.kind == "tool_call_result_blocks":
tool_call_result_blocks = result if result.tool_call_result_blocks is not None:
elif isinstance(result, MessageChain): tool_call_result_blocks = result.tool_call_result_blocks
if result.type is None: elif result.kind == "cached_image":
if result.cached_image is not None:
# Collect cached image info
cached_images.append(result.cached_image)
elif result.kind == "message_chain":
chain = result.message_chain
if chain is None or chain.type is None:
# should not happen # should not happen
continue continue
if result.type == "tool_direct_result": if chain.type == "tool_direct_result":
ar_type = "tool_call_result" ar_type = "tool_call_result"
else: else:
ar_type = result.type ar_type = chain.type
yield AgentResponse( yield AgentResponse(
type=ar_type, type=ar_type,
data=AgentResponseData(chain=result), data=AgentResponseData(chain=chain),
) )
# 将结果添加到上下文中 # 将结果添加到上下文中
@@ -327,6 +358,41 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
tool_calls_result.to_openai_messages_model() tool_calls_result.to_openai_messages_model()
) )
# If there are cached images and the model supports image input,
# append a user message with images so LLM can see them
if cached_images:
modalities = self.provider.provider_config.get("modalities", [])
supports_image = "image" in modalities
if supports_image:
# Build user message with images for LLM to review
image_parts = []
for cached_img in cached_images:
img_data = tool_image_cache.get_image_base64_by_path(
cached_img.file_path, cached_img.mime_type
)
if img_data:
base64_data, mime_type = img_data
image_parts.append(
TextPart(
text=f"[Image from tool '{cached_img.tool_name}', path='{cached_img.file_path}']"
)
)
image_parts.append(
ImageURLPart(
image_url=ImageURLPart.ImageURL(
url=f"data:{mime_type};base64,{base64_data}",
id=cached_img.file_path,
)
)
)
if image_parts:
self.run_context.messages.append(
Message(role="user", content=image_parts)
)
logger.debug(
f"Appended {len(cached_images)} cached image(s) to context for LLM review"
)
self.req.append_tool_calls_result(tool_calls_result) self.req.append_tool_calls_result(tool_calls_result)
async def step_until_done( async def step_until_done(
@@ -362,7 +428,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
self, self,
req: ProviderRequest, req: ProviderRequest,
llm_response: LLMResponse, llm_response: LLMResponse,
) -> T.AsyncGenerator[MessageChain | list[ToolCallMessageSegment], None]: ) -> T.AsyncGenerator[_HandleFunctionToolsResult, None]:
"""处理函数工具调用。""" """处理函数工具调用。"""
tool_call_result_blocks: list[ToolCallMessageSegment] = [] tool_call_result_blocks: list[ToolCallMessageSegment] = []
logger.info(f"Agent 使用工具: {llm_response.tools_call_name}") logger.info(f"Agent 使用工具: {llm_response.tools_call_name}")
@@ -373,18 +439,20 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
llm_response.tools_call_args, llm_response.tools_call_args,
llm_response.tools_call_ids, llm_response.tools_call_ids,
): ):
yield MessageChain( yield _HandleFunctionToolsResult.from_message_chain(
type="tool_call", MessageChain(
chain=[ type="tool_call",
Json( chain=[
data={ Json(
"id": func_tool_id, data={
"name": func_tool_name, "id": func_tool_id,
"args": func_tool_args, "name": func_tool_name,
"ts": time.time(), "args": func_tool_args,
} "ts": time.time(),
) }
], )
],
)
) )
try: try:
if not req.func_tool: if not req.func_tool:
@@ -470,15 +538,28 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
), ),
) )
elif isinstance(res.content[0], ImageContent): elif isinstance(res.content[0], ImageContent):
# Cache the image instead of sending directly
cached_img = tool_image_cache.save_image(
base64_data=res.content[0].data,
tool_call_id=func_tool_id,
tool_name=func_tool_name,
index=0,
mime_type=res.content[0].mimeType or "image/png",
)
tool_call_result_blocks.append( tool_call_result_blocks.append(
ToolCallMessageSegment( ToolCallMessageSegment(
role="tool", role="tool",
tool_call_id=func_tool_id, tool_call_id=func_tool_id,
content="The tool has successfully returned an image and sent directly to the user. You can describe it in your next response.", content=(
f"Image returned and cached at path='{cached_img.file_path}'. "
f"Review the image below. Use send_message_to_user to send it to the user if satisfied, "
f"with type='image' and path='{cached_img.file_path}'."
),
), ),
) )
yield MessageChain(type="tool_direct_result").base64_image( # Yield image info for LLM visibility (will be handled in step())
res.content[0].data, yield _HandleFunctionToolsResult.from_cached_image(
cached_img
) )
elif isinstance(res.content[0], EmbeddedResource): elif isinstance(res.content[0], EmbeddedResource):
resource = res.content[0].resource resource = res.content[0].resource
@@ -495,16 +576,29 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
and resource.mimeType and resource.mimeType
and resource.mimeType.startswith("image/") and resource.mimeType.startswith("image/")
): ):
# Cache the image instead of sending directly
cached_img = tool_image_cache.save_image(
base64_data=resource.blob,
tool_call_id=func_tool_id,
tool_name=func_tool_name,
index=0,
mime_type=resource.mimeType,
)
tool_call_result_blocks.append( tool_call_result_blocks.append(
ToolCallMessageSegment( ToolCallMessageSegment(
role="tool", role="tool",
tool_call_id=func_tool_id, tool_call_id=func_tool_id,
content="The tool has successfully returned an image and sent directly to the user. You can describe it in your next response.", content=(
f"Image returned and cached at path='{cached_img.file_path}'. "
f"Review the image below. Use send_message_to_user to send it to the user if satisfied, "
f"with type='image' and path='{cached_img.file_path}'."
),
), ),
) )
yield MessageChain( # Yield image info for LLM visibility
type="tool_direct_result", yield _HandleFunctionToolsResult.from_cached_image(
).base64_image(resource.blob) cached_img
)
else: else:
tool_call_result_blocks.append( tool_call_result_blocks.append(
ToolCallMessageSegment( ToolCallMessageSegment(
@@ -565,23 +659,27 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
# yield the last tool call result # yield the last tool call result
if tool_call_result_blocks: if tool_call_result_blocks:
last_tcr_content = str(tool_call_result_blocks[-1].content) last_tcr_content = str(tool_call_result_blocks[-1].content)
yield MessageChain( yield _HandleFunctionToolsResult.from_message_chain(
type="tool_call_result", MessageChain(
chain=[ type="tool_call_result",
Json( chain=[
data={ Json(
"id": func_tool_id, data={
"ts": time.time(), "id": func_tool_id,
"result": last_tcr_content, "ts": time.time(),
} "result": last_tcr_content,
) }
], )
],
)
) )
logger.info(f"Tool `{func_tool_name}` Result: {last_tcr_content}") logger.info(f"Tool `{func_tool_name}` Result: {last_tcr_content}")
# 处理函数调用响应 # 处理函数调用响应
if tool_call_result_blocks: if tool_call_result_blocks:
yield tool_call_result_blocks yield _HandleFunctionToolsResult.from_tool_call_result_blocks(
tool_call_result_blocks
)
def _build_tool_requery_context( def _build_tool_requery_context(
self, tool_names: list[str] self, tool_names: list[str]
+162
View File
@@ -0,0 +1,162 @@
"""Tool image cache module for storing and retrieving images returned by tools.
This module allows LLM to review images before deciding whether to send them to users.
"""
import base64
import os
import time
from dataclasses import dataclass, field
from typing import ClassVar
from astrbot import logger
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
@dataclass
class CachedImage:
    """Record describing one image produced by a tool call and saved to disk.

    Attributes:
        tool_call_id: ID of the tool call that produced the image.
        tool_name: Name of the tool that produced the image.
        file_path: Location of the stored image file.
        mime_type: MIME type of the image.
        created_at: Timestamp taken with ``time.time`` when the record is
            created, unless a value is passed explicitly.
    """

    # Origin of the image.
    tool_call_id: str
    tool_name: str

    # Where and how the image is stored.
    file_path: str
    mime_type: str

    # Defaults to "now" at construction time.
    created_at: float = field(default_factory=time.time)
class ToolImageCache:
"""Manages cached images from tool calls.
Images are stored in data/temp/tool_images/ and can be retrieved by file path.
"""
_instance: ClassVar["ToolImageCache | None"] = None
CACHE_DIR_NAME: ClassVar[str] = "tool_images"
# Cache expiry time in seconds (1 hour)
CACHE_EXPIRY: ClassVar[int] = 3600
def __new__(cls) -> "ToolImageCache":
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self) -> None:
if self._initialized:
return
self._initialized = True
self._cache_dir = os.path.join(get_astrbot_temp_path(), self.CACHE_DIR_NAME)
os.makedirs(self._cache_dir, exist_ok=True)
logger.debug(f"ToolImageCache initialized, cache dir: {self._cache_dir}")
def _get_file_extension(self, mime_type: str) -> str:
"""Get file extension from MIME type."""
mime_to_ext = {
"image/png": ".png",
"image/jpeg": ".jpg",
"image/jpg": ".jpg",
"image/gif": ".gif",
"image/webp": ".webp",
"image/bmp": ".bmp",
"image/svg+xml": ".svg",
}
return mime_to_ext.get(mime_type.lower(), ".png")
def save_image(
self,
base64_data: str,
tool_call_id: str,
tool_name: str,
index: int = 0,
mime_type: str = "image/png",
) -> CachedImage:
"""Save an image to cache and return the cached image info.
Args:
base64_data: Base64 encoded image data.
tool_call_id: The tool call ID that produced this image.
tool_name: The name of the tool that produced this image.
index: The index of the image (for multiple images from same tool call).
mime_type: The MIME type of the image.
Returns:
CachedImage object with file path.
"""
ext = self._get_file_extension(mime_type)
file_name = f"{tool_call_id}_{index}{ext}"
file_path = os.path.join(self._cache_dir, file_name)
# Decode and save the image
try:
image_bytes = base64.b64decode(base64_data)
with open(file_path, "wb") as f:
f.write(image_bytes)
logger.debug(f"Saved tool image to: {file_path}")
except Exception as e:
logger.error(f"Failed to save tool image: {e}")
raise
return CachedImage(
tool_call_id=tool_call_id,
tool_name=tool_name,
file_path=file_path,
mime_type=mime_type,
)
def get_image_base64_by_path(
self, file_path: str, mime_type: str = "image/png"
) -> tuple[str, str] | None:
"""Read an image file and return its base64 encoded data.
Args:
file_path: The file path of the cached image.
mime_type: The MIME type of the image.
Returns:
Tuple of (base64_data, mime_type) if found, None otherwise.
"""
if not os.path.exists(file_path):
return None
try:
with open(file_path, "rb") as f:
image_bytes = f.read()
base64_data = base64.b64encode(image_bytes).decode("utf-8")
return base64_data, mime_type
except Exception as e:
logger.error(f"Failed to read cached image {file_path}: {e}")
return None
def cleanup_expired(self) -> int:
"""Clean up expired cached images.
Returns:
Number of images cleaned up.
"""
now = time.time()
cleaned = 0
try:
for file_name in os.listdir(self._cache_dir):
file_path = os.path.join(self._cache_dir, file_name)
if os.path.isfile(file_path):
file_age = now - os.path.getmtime(file_path)
if file_age > self.CACHE_EXPIRY:
os.remove(file_path)
cleaned += 1
except Exception as e:
logger.warning(f"Error during cache cleanup: {e}")
if cleaned:
logger.info(f"Cleaned up {cleaned} expired cached images")
return cleaned
# Module-level shared instance. ToolImageCache.__new__ hands back the same
# object on every construction, so this name and any later ToolImageCache()
# call refer to one cache.
tool_image_cache = ToolImageCache()