From c8f567347bafe91fb2e882f28a9b1add2a980680 Mon Sep 17 00:00:00 2001 From: Raven95676 Date: Wed, 23 Apr 2025 11:52:22 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=BF=AE=E6=94=B9=E9=87=8D=E6=8E=92?= =?UTF-8?q?=E5=BA=8F=E9=80=BB=E8=BE=91=E4=B8=BA=E5=90=88=E5=B9=B6=E8=BF=9E?= =?UTF-8?q?=E7=BB=AD=E7=9B=B8=E5=90=8C=E7=B1=BB=E5=9E=8B=E7=9A=84=E6=B6=88?= =?UTF-8?q?=E6=81=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/provider/sources/gemini_source.py | 97 +++++++++---------- 1 file changed, 44 insertions(+), 53 deletions(-) diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 8600d3e9f..d8d4c789c 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -182,11 +182,11 @@ class ProviderGoogleGenAI(Provider): def _prepare_conversation(self, payloads: Dict) -> List[types.Content]: """准备 Gemini SDK 的 Content 列表""" - def create_text_part(text: str) -> types.UserContent: + def create_text_part(text: str) -> types.Part: content_a = text if text else " " if not text: logger.warning("文本内容为空,已添加空格占位") - return types.UserContent(parts=[types.Part.from_text(text=content_a)]) + return types.Part.from_text(text=content_a) def process_image_url(image_url_dict: dict) -> types.Part: url = image_url_dict["url"] @@ -205,75 +205,66 @@ class ProviderGoogleGenAI(Provider): role, content = message["role"], message.get("content") if role == "user": - if isinstance(content, str): - gemini_contents.append(create_text_part(content)) - elif isinstance(content, list): + if isinstance(content, list): parts = [ types.Part.from_text(text=item["text"] or " ") if item["type"] == "text" else process_image_url(item["image_url"]) for item in content ] + else: + parts = [create_text_part(content)] + + if gemini_contents and isinstance(gemini_contents[-1], types.UserContent): + gemini_contents[-1].parts.extend(parts) + else: gemini_contents.append(types.UserContent(parts=parts)) elif role == "assistant": if content: - gemini_contents.append( - types.ModelContent(parts=[types.Part.from_text(text=content)]) - ) - elif "tool_calls" in message and not native_tool_enabled: - gemini_contents.extend( - [ - types.ModelContent( - parts=[ - types.Part.from_function_call( - name=tool["function"]["name"], - args=json.loads(tool["function"]["arguments"]), - ) - ] - ) - for tool in message["tool_calls"] - ] - ) + parts = [types.Part.from_text(text=content)] + if gemini_contents and isinstance(gemini_contents[-1], types.ModelContent): + gemini_contents[-1].parts.extend(parts) + else: + gemini_contents.append(types.ModelContent(parts=parts)) + elif not native_tool_enabled and "tool_calls" in message : + parts = [ + types.Part.from_function_call( + name=tool["function"]["name"], + args=json.loads(tool["function"]["arguments"]), + ) + for tool in message["tool_calls"] + ] + if gemini_contents and isinstance(gemini_contents[-1], types.ModelContent): + gemini_contents[-1].parts.extend(parts) + else: + gemini_contents.append(types.ModelContent(parts=parts)) else: logger.warning("assistant 角色的消息内容为空,已添加空格占位") - if native_tool_enabled: + if native_tool_enabled and "tool_calls" in message: logger.warning( "检测到启用Gemini原生工具,且上下文中存在函数调用,建议使用 /reset 重置上下文" ) - gemini_contents.append( - types.ModelContent(parts=[types.Part.from_text(text=" ")]) - ) + parts = [types.Part.from_text(text=" ")] + if gemini_contents and isinstance(gemini_contents[-1], types.ModelContent): + gemini_contents[-1].parts.extend(parts) + else: + gemini_contents.append(types.ModelContent(parts=parts)) elif role == "tool" and not native_tool_enabled: - gemini_contents.append( - types.UserContent( - parts=[ - types.Part.from_function_response( - name=message["tool_call_id"], - response={ - "name": message["tool_call_id"], - "content": message["content"], - }, - ) - ] + parts = [ + types.Part.from_function_response( + name=message["tool_call_id"], + response={ + "name": message["tool_call_id"], + "content": message["content"], + }, ) - ) - - # 保证偶数索引为用户消息,奇数索引为模型消息 - content_num = len(gemini_contents) - for i in range(content_num): - expected_type = types.UserContent if i % 2 == 0 else types.ModelContent - if isinstance(gemini_contents[i], expected_type): - continue - for j in range(i + 1, content_num): - if isinstance(gemini_contents[j], expected_type): - logger.debug(f"交换索引 {i} 与 {j}") - gemini_contents[i], gemini_contents[j] = ( - gemini_contents[j], - gemini_contents[i], - ) - break + ] + if gemini_contents and isinstance(gemini_contents[-1], types.UserContent): + gemini_contents[-1].parts.extend(parts) + else: + gemini_contents.append(types.UserContent(parts=parts)) return gemini_contents