From e2c26c292ddd27df72d96df40d5de2df60e0daf0 Mon Sep 17 00:00:00 2001 From: Raven95676 Date: Wed, 23 Apr 2025 19:55:15 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=A4=84=E7=90=86MCP=E8=BF=94=E5=9B=9E?= =?UTF-8?q?ImageContent=E3=80=81EmbeddedResource=E7=9A=84=E6=83=85?= =?UTF-8?q?=E5=86=B5=EF=BC=8C=E6=8F=90=E4=BE=9B=E7=AE=80=E5=8D=95fallback?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../process_stage/method/llm_request.py | 93 ++++++++++++++++--- 1 file changed, 80 insertions(+), 13 deletions(-) diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index fd70275d8..28745f2c5 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -26,6 +26,13 @@ from astrbot.core.provider.entities import ( ) from astrbot.core.star.star_handler import star_handlers_registry, EventType from astrbot.core.star.star import star_map +from mcp.types import ( + TextContent, + ImageContent, + EmbeddedResource, + TextResourceContents, + BlobResourceContents, +) class LLMRequestSubStage(Stage): @@ -66,9 +73,9 @@ class LLMRequestSubStage(Stage): if event.get_extra("provider_request"): req = event.get_extra("provider_request") - assert isinstance( - req, ProviderRequest - ), "provider_request 必须是 ProviderRequest 类型。" + assert isinstance(req, ProviderRequest), ( + "provider_request 必须是 ProviderRequest 类型。" + ) if req.conversation: all_contexts = json.loads(req.conversation.history) @@ -149,7 +156,14 @@ class LLMRequestSubStage(Stage): -(self.max_context_length - self.dequeue_context_length + 1) * 2 : ] # 找到第一个role 为 user 的索引,确保上下文格式正确 - index = next((i for i, item in enumerate(req.contexts) if item.get("role") == "user"), None) + index = next( + ( + i + for i, item in enumerate(req.contexts) + if item.get("role") == "user" + ), + None, + ) if index is not None and index > 0: req.contexts = req.contexts[index:] @@ -265,6 +279,12 @@ class LLMRequestSubStage(Stage): event.set_extra("tool_call_result", None) yield + # 暂时直接发出去 + if img_b64 := event.get_extra("tool_call_img_respond"): + await event.send(MessageChain(chain=[Image.fromBase64(img_b64)])) + event.set_extra("tool_call_img_respond", None) + yield + async def _handle_llm_response( self, event: AstrMessageEvent, @@ -375,21 +395,68 @@ class LLMRequestSubStage(Stage): client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name] res = await client.session.call_tool(func_tool.name, func_tool_args) if res: - # TODO content的类型可能包括list[TextContent | ImageContent | EmbeddedResource],这里只处理了TextContent。 - tool_call_result.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content=res.content[0].text, + # TODO 仅对ImageContent | EmbeddedResource进行了简单的Fallback + if isinstance(res.content[0], TextContent): + tool_call_result.append( + ToolCallMessageSegment( + role="tool", + tool_call_id=func_tool_id, + content=res.content[0].text, + ) ) - ) + elif isinstance(res.content[0], ImageContent): + tool_call_result.append( + ToolCallMessageSegment( + role="tool", + tool_call_id=func_tool_id, + content="返回了图片(已直接发送给用户)", + ) + ) + event.set_extra( + "tool_call_img_respond", + res.content[0].data, + ) + elif isinstance(res.content[0], EmbeddedResource): + resource = res.content[0].resource + if isinstance(resource, TextResourceContents): + tool_call_result.append( + ToolCallMessageSegment( + role="tool", + tool_call_id=func_tool_id, + content=resource.text, + ) + ) + elif ( + isinstance(resource, BlobResourceContents) + and resource.mimeType + and resource.mimeType.startswith("image/") + ): + tool_call_result.append( + ToolCallMessageSegment( + role="tool", + tool_call_id=func_tool_id, + content="返回了图片(已直接发送给用户)", + ) + ) + event.set_extra( + "tool_call_img_respond", + res.content[0].data, + ) + else: + tool_call_result.append( + ToolCallMessageSegment( + role="tool", + tool_call_id=func_tool_id, + content="返回的数据类型不受支持", + ) + ) else: # 获取处理器,过滤掉平台不兼容的处理器 platform_id = event.get_platform_id() star_md = star_map.get(func_tool.handler_module_path) if ( - star_md and - platform_id in star_md.supported_platforms + star_md + and platform_id in star_md.supported_platforms and not star_md.supported_platforms[platform_id] ): logger.debug(