fix: ensure tool call/response pairing in context truncation (#5417)

* fix: ensure tool call/response pairing in context truncation

* refactor: simplify fix_messages to single-pass state machine
This commit is contained in:
Luna_Dol
2026-02-25 15:21:30 +08:00
committed by GitHub
parent c7d318304b
commit 6d76d55452
+53 -12
View File
@@ -4,19 +4,60 @@ from ..message import Message
class ContextTruncator:
"""Context truncator."""
def _has_tool_calls(self, message: Message) -> bool:
"""Check if a message contains tool calls."""
return (
message.role == "assistant"
and message.tool_calls is not None
and len(message.tool_calls) > 0
)
def fix_messages(self, messages: list[Message]) -> list[Message]:
fixed_messages = []
for message in messages:
if message.role == "tool":
# tool block 前面必须要有 user 和 assistant block
if len(fixed_messages) < 2:
# 这种情况可能是上下文被截断导致的
# 我们直接将之前的上下文都清空
fixed_messages = []
else:
fixed_messages.append(message)
else:
fixed_messages.append(message)
"""修复消息列表,确保 tool call 和 tool response 的配对关系有效。
此方法确保:
1. 每个 `tool` 消息前面都有一个包含 tool_calls 的 `assistant` 消息
2. 每个包含 tool_calls 的 `assistant` 消息后面都有对应的 `tool` 响应
这是 OpenAI Chat Completions API 规范的要求(Gemini 对此执行严格检查)。
"""
if not messages:
return messages
fixed_messages: list[Message] = []
pending_assistant: Message | None = None
pending_tools: list[Message] = []
def flush_pending_if_valid() -> None:
nonlocal pending_assistant, pending_tools
if pending_assistant is not None and pending_tools:
fixed_messages.append(pending_assistant)
fixed_messages.extend(pending_tools)
pending_assistant = None
pending_tools = []
for msg in messages:
if msg.role == "tool":
# 只有在有挂起的 assistant(tool_calls) 时才记录 tool 响应
if pending_assistant is not None:
pending_tools.append(msg)
# else: 孤立的 tool 消息,直接忽略
continue
if self._has_tool_calls(msg):
# 遇到新的 assistant(tool_calls) 前,先处理旧的 pending 链
flush_pending_if_valid()
pending_assistant = msg
continue
# 非 tool,且不含 tool_calls 的消息
# 先结束任何 pending 链,再正常追加
flush_pending_if_valid()
fixed_messages.append(msg)
# 结束时处理最后一个 pending 链
flush_pending_if_valid()
return fixed_messages
def truncate_by_turns(