From 72f917d611315e9d41ab5f3a1ef35a15130ba283 Mon Sep 17 00:00:00 2001 From: anka <1350989414@qq.com> Date: Mon, 7 Apr 2025 17:31:57 +0800 Subject: [PATCH 1/7] =?UTF-8?q?fix:=20gemini=E5=8F=AA=E5=9C=A8content?= =?UTF-8?q?=E4=B8=8D=E4=B8=BA=E7=A9=BA=E7=9A=84=E6=97=B6=E5=80=99=E5=8A=A0?= =?UTF-8?q?=E5=85=A5=E4=B8=8A=E4=B8=8B=E6=96=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/provider/sources/gemini_source.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 9f5f7c3c1..7ae418938 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -146,12 +146,10 @@ class ProviderGoogleGenAI(Provider): for message in payloads["messages"]: if message["role"] == "user": if isinstance(message["content"], str): - if not message["content"]: - message["content"] = "" - - google_genai_conversation.append( - {"role": "user", "parts": [{"text": message["content"]}]} - ) + if message["content"]: + google_genai_conversation.append( + {"role": "user", "parts": [{"text": message["content"]}]} + ) elif isinstance(message["content"], list): # images parts = [] @@ -175,11 +173,10 @@ class ProviderGoogleGenAI(Provider): elif message["role"] == "assistant": if "content" in message: - if not message["content"]: - message["content"] = "" - google_genai_conversation.append( - {"role": "model", "parts": [{"text": message["content"]}]} - ) + if message["content"]: + google_genai_conversation.append( + {"role": "model", "parts": [{"text": message["content"]}]} + ) elif "tool_calls" in message: # tool calls in the last turn parts = [] From b9a983f8e0555641f94b6b4aef4ac85f6c45ee46 Mon Sep 17 00:00:00 2001 From: anka <1350989414@qq.com> Date: Mon, 7 Apr 2025 17:45:35 +0800 Subject: [PATCH 2/7] =?UTF-8?q?fix:=20=E4=B8=BA=E5=87=BD=E6=95=B0=E8=B0=83?= =?UTF-8?q?=E7=94=A8=E5=8E=86=E5=8F=B2=E8=AE=B0=E5=BD=95=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E6=A0=87=E8=AE=B0,=20=E4=B8=8D=E8=AF=BB=E5=8F=96=E5=85=A5?= =?UTF-8?q?=E4=B8=8A=E4=B8=8B=E6=96=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../process_stage/method/llm_request.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 674a7fd79..987e82570 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -58,12 +58,16 @@ class LLMRequestSubStage(Stage): if event.get_extra("provider_request"): req = event.get_extra("provider_request") - assert isinstance(req, ProviderRequest), ( - "provider_request 必须是 ProviderRequest 类型。" - ) + assert isinstance( + req, ProviderRequest + ), "provider_request 必须是 ProviderRequest 类型。" if req.conversation: - req.contexts = json.loads(req.conversation.history) + all_contexts = json.loads(req.conversation.history) + # 对函数工具调用做过滤 + req.contexts = [ + msg for msg in all_contexts if "_tool_call_history" not in msg + ] else: req = ProviderRequest(prompt="", image_urls=[]) if self.provider_wake_prefix: @@ -312,9 +316,12 @@ class LLMRequestSubStage(Stage): contexts = req.contexts contexts.append(await req.assemble_context()) - # tool calls result + # 记录并标记函数调用结果 if req.tool_calls_result: - contexts.extend(req.tool_calls_result.to_openai_messages()) + tool_calls_messages = req.tool_calls_result.to_openai_messages() + for message in tool_calls_messages: + message["_tool_call_history"] = True + contexts.extend(tool_calls_messages) contexts.append( {"role": "assistant", "content": llm_response.completion_text} From d88420dd03da6c04cca56e15f21fc57660de1f0e Mon Sep 17 00:00:00 2001 From: anka <1350989414@qq.com> Date: Mon, 7 Apr 2025 17:55:12 +0800 Subject: [PATCH 3/7] =?UTF-8?q?fix:=20=E4=BF=AE=E6=94=B9=E8=8E=B7=E5=8F=96?= =?UTF-8?q?=E4=BA=BA=E7=B1=BB=E5=8F=AF=E8=AF=BB=E7=9A=84=E4=B8=8A=E4=B8=8B?= =?UTF-8?q?=E6=96=87=E7=9A=84=E9=80=BB=E8=BE=91,=20=E5=8C=BA=E5=88=86?= =?UTF-8?q?=E5=87=BD=E6=95=B0=E8=B0=83=E7=94=A8(=E6=97=A0contents)?= =?UTF-8?q?=E5=92=8C=E4=B8=80=E8=88=AC=E6=B6=88=E6=81=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/conversation_mgr.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/astrbot/core/conversation_mgr.py b/astrbot/core/conversation_mgr.py index c506fa8f1..b0f5c136d 100644 --- a/astrbot/core/conversation_mgr.py +++ b/astrbot/core/conversation_mgr.py @@ -175,7 +175,15 @@ class ConversationManager: if record["role"] == "user": temp_contexts.append(f"User: {record['content']}") elif record["role"] == "assistant": - temp_contexts.append(f"Assistant: {record['content']}") + if "content" in record and record["content"]: + temp_contexts.append(f"Assistant: {record['content']}") + elif "tool_calls" in record: + tool_calls_str = json.dumps( + record["tool_calls"], ensure_ascii=False + ) + temp_contexts.append(f"Assistant: [函数调用] {tool_calls_str}") + else: + temp_contexts.append("Assistant: [未知的内容]") contexts.insert(0, temp_contexts) temp_contexts = [] From 5a001871470ba0906eab5457ec1547204d7a5012 Mon Sep 17 00:00:00 2001 From: anka <1350989414@qq.com> Date: Mon, 7 Apr 2025 18:14:30 +0800 Subject: [PATCH 4/7] =?UTF-8?q?fix:=20=E5=AF=B9=E5=8E=86=E5=8F=B2=E8=AE=B0?= =?UTF-8?q?=E5=BD=95=E7=9A=84toolcall=E9=AA=8C=E8=AF=81=E6=98=AF=E5=90=A6?= =?UTF-8?q?=E6=88=90=E5=AF=B9,=20=E5=8F=82=E8=80=83:=20https://github.com/?= =?UTF-8?q?run-llama/llama=5Findex/issues/13715=20https://github.com/run-l?= =?UTF-8?q?lama/llama=5Findex/pull/16214?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../process_stage/method/llm_request.py | 63 +++++++++++++++++-- 1 file changed, 57 insertions(+), 6 deletions(-) diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 987e82570..99b09b5a1 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -64,10 +64,48 @@ class LLMRequestSubStage(Stage): if req.conversation: all_contexts = json.loads(req.conversation.history) - # 对函数工具调用做过滤 - req.contexts = [ - msg for msg in all_contexts if "_tool_call_history" not in msg - ] + req.contexts = [] + i = 0 + while i < len(all_contexts): + current_msg = all_contexts[i] + # 普通消息 + if "_tool_call_history" not in current_msg: + req.contexts.append(current_msg) + i += 1 + continue + + # 工具调用消息, 必须成对出现 + if ( + current_msg.get("role") == "assistant" + and "tool_calls" in current_msg + ): + # 寻找tool响应 + assistant_msg = current_msg.copy() + # 移除标记 + if "_tool_call_history" in assistant_msg: + del assistant_msg["_tool_call_history"] + + related_tools = [] + j = i + 1 + while ( + j < len(all_contexts) + and all_contexts[j].get("role") == "tool" + and "_tool_call_history" in all_contexts[j] + ): + tool_msg = all_contexts[j].copy() + del tool_msg["_tool_call_history"] + related_tools.append(tool_msg) + j += 1 + + # 只添加成对的tool_call和tool响应 + if related_tools: + req.contexts.append(assistant_msg) + req.contexts.extend(related_tools) + # 已处理的消息跳过 + i = j + else: + i += 1 + else: req = ProviderRequest(prompt="", image_urls=[]) if self.provider_wake_prefix: @@ -313,15 +351,28 @@ class LLMRequestSubStage(Stage): if llm_response.role == "assistant": # 文本回复 - contexts = req.contexts + contexts = req.contexts.copy() contexts.append(await req.assemble_context()) # 记录并标记函数调用结果 if req.tool_calls_result: tool_calls_messages = req.tool_calls_result.to_openai_messages() + + # 对顺序的验证 + assistant_msgs = [] + tool_msgs = [] + for message in tool_calls_messages: message["_tool_call_history"] = True - contexts.extend(tool_calls_messages) + + if message.get("role") == "assistant": + assistant_msgs.append(message) + elif message.get("role") == "tool": + tool_msgs.append(message) + + # 先添加assistant再添加tool + contexts.extend(assistant_msgs) + contexts.extend(tool_msgs) contexts.append( {"role": "assistant", "content": llm_response.completion_text} From 7cd1eeac309417149035a99d9e9595ec2aa3a27d Mon Sep 17 00:00:00 2001 From: anka <1350989414@qq.com> Date: Tue, 8 Apr 2025 15:57:38 +0000 Subject: [PATCH 5/7] =?UTF-8?q?fix:=20=E7=9B=B4=E6=8E=A5=E6=8A=8A=E7=A9=BA?= =?UTF-8?q?=E5=AD=97=E7=AC=A6=E4=B8=B2=E6=94=B9=E4=B8=BA"=20"=E4=B8=80?= =?UTF-8?q?=E6=9D=A1=E6=B6=88=E6=81=AF=E7=9A=84content=E6=98=AF=E7=A9=BA?= =?UTF-8?q?=E5=AD=97=E7=AC=A6=E4=B8=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/provider/sources/gemini_source.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 7ae418938..11f3f7eaa 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -146,10 +146,12 @@ class ProviderGoogleGenAI(Provider): for message in payloads["messages"]: if message["role"] == "user": if isinstance(message["content"], str): - if message["content"]: - google_genai_conversation.append( - {"role": "user", "parts": [{"text": message["content"]}]} - ) + if not message["content"]: + message["content"] = " " + + google_genai_conversation.append( + {"role": "user", "parts": [{"text": message["content"]}]} + ) elif isinstance(message["content"], list): # images parts = [] @@ -173,10 +175,11 @@ class ProviderGoogleGenAI(Provider): elif message["role"] == "assistant": if "content" in message: - if message["content"]: - google_genai_conversation.append( - {"role": "model", "parts": [{"text": message["content"]}]} - ) + if not message["content"]: + message["content"] = " " + google_genai_conversation.append( + {"role": "model", "parts": [{"text": message["content"]}]} + ) elif "tool_calls" in message: # tool calls in the last turn parts = [] From 87c3aff4ce3f37e67ffc21ad49c315b2c303dc0d Mon Sep 17 00:00:00 2001 From: anka <1350989414@qq.com> Date: Thu, 10 Apr 2025 11:25:03 +0800 Subject: [PATCH 6/7] =?UTF-8?q?perf:=20=E7=AE=80=E5=8C=96llm=5Frequest?= =?UTF-8?q?=E5=B7=A5=E5=85=B7=E8=B0=83=E7=94=A8=E6=B6=88=E6=81=AF=E6=88=90?= =?UTF-8?q?=E5=AF=B9=E9=AA=8C=E8=AF=81=E9=80=BB=E8=BE=91,=20=E5=90=88?= =?UTF-8?q?=E5=B9=B6=E4=B8=A4=E5=A4=84=E9=AA=8C=E8=AF=81=E9=80=BB=E8=BE=91?= =?UTF-8?q?=E5=88=B0=E4=B8=80=E4=B8=AA=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../process_stage/method/llm_request.py | 118 ++++++++++-------- 1 file changed, 65 insertions(+), 53 deletions(-) diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 8a7062fd7..756f10b66 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -67,48 +67,10 @@ class LLMRequestSubStage(Stage): ), "provider_request 必须是 ProviderRequest 类型。" if req.conversation: - all_contexts = json.loads(req.conversation.history) - req.contexts = [] - i = 0 - while i < len(all_contexts): - current_msg = all_contexts[i] - # 普通消息 - if "_tool_call_history" not in current_msg: - req.contexts.append(current_msg) - i += 1 - continue - - # 工具调用消息, 必须成对出现 - if ( - current_msg.get("role") == "assistant" - and "tool_calls" in current_msg - ): - # 寻找tool响应 - assistant_msg = current_msg.copy() - # 移除标记 - if "_tool_call_history" in assistant_msg: - del assistant_msg["_tool_call_history"] - - related_tools = [] - j = i + 1 - while ( - j < len(all_contexts) - and all_contexts[j].get("role") == "tool" - and "_tool_call_history" in all_contexts[j] - ): - tool_msg = all_contexts[j].copy() - del tool_msg["_tool_call_history"] - related_tools.append(tool_msg) - j += 1 - - # 只添加成对的tool_call和tool响应 - if related_tools: - req.contexts.append(assistant_msg) - req.contexts.extend(related_tools) - # 已处理的消息跳过 - i = j - else: - i += 1 + all_contexts = json.load(req.conversation.history) + req.contexts = self._process_tool_message_pairs( + all_contexts, remove_tags=True + ) else: req = ProviderRequest(prompt="", image_urls=[]) @@ -485,21 +447,15 @@ class LLMRequestSubStage(Stage): if req.tool_calls_result: tool_calls_messages = req.tool_calls_result.to_openai_messages() - # 对顺序的验证 - assistant_msgs = [] - tool_msgs = [] - + # 添加标记 for message in tool_calls_messages: message["_tool_call_history"] = True - if message.get("role") == "assistant": - assistant_msgs.append(message) - elif message.get("role") == "tool": - tool_msgs.append(message) + processed_tool_messages = self._process_tool_message_pairs( + tool_calls_messages, remove_tags=False + ) - # 先添加assistant再添加tool - contexts.extend(assistant_msgs) - contexts.extend(tool_msgs) + contexts.extend(processed_tool_messages) contexts.append( {"role": "assistant", "content": llm_response.completion_text} @@ -510,3 +466,59 @@ class LLMRequestSubStage(Stage): await self.conv_manager.update_conversation( event.unified_msg_origin, req.conversation.cid, history=contexts_to_save ) + + def _process_tool_message_pairs(self, messages, remove_tags=True): + """处理工具调用消息,确保assistant和tool消息成对出现 + + Args: + messages (list): 消息列表 + remove_tags (bool): 是否移除_tool_call_history标记 + + Returns: + list: 处理后的消息列表,保证了assistant和对应tool消息的成对出现 + """ + result = [] + i = 0 + + while i < len(messages): + current_msg = messages[i] + + # 普通消息直接添加 + if "_tool_call_history" not in current_msg: + result.append(current_msg.copy() if remove_tags else current_msg) + i += 1 + continue + + # 工具调用消息成对处理 + if current_msg.get("role") == "assistant" and "tool_calls" in current_msg: + assistant_msg = current_msg.copy() + + if remove_tags and "_tool_call_history" in assistant_msg: + del assistant_msg["_tool_call_history"] + + related_tools = [] + j = i + 1 + while ( + j < len(messages) + and messages[j].get("role") == "tool" + and "_tool_call_history" in messages[j] + ): + tool_msg = messages[j].copy() + + if remove_tags: + del tool_msg["_tool_call_history"] + + related_tools.append(tool_msg) + j += 1 + + # 成对的时候添加到结果 + if related_tools: + result.append(assistant_msg) + result.extend(related_tools) + + i = j # 跳过已处理 + else: + # 单独的tool消息 + i += 1 + + return result From bdf25976a34508afb39363e9d4d664fd5c08d4b5 Mon Sep 17 00:00:00 2001 From: anka <1350989414@qq.com> Date: Thu, 10 Apr 2025 11:28:47 +0800 Subject: [PATCH 7/7] =?UTF-8?q?fix:=20=E5=B0=91=E6=89=93=E4=B8=80=E4=B8=AA?= =?UTF-8?q?=E5=AD=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/pipeline/process_stage/method/llm_request.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 756f10b66..b7d279996 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -67,7 +67,7 @@ class LLMRequestSubStage(Stage): ), "provider_request 必须是 ProviderRequest 类型。" if req.conversation: - all_contexts = json.load(req.conversation.history) + all_contexts = json.loads(req.conversation.history) req.contexts = self._process_tool_message_pairs( all_contexts, remove_tags=True )