6d76d55452
* fix: ensure tool call/response pairing in context truncation * refactor: simplify fix_messages to single-pass state machine
183 lines
6.0 KiB
Python
183 lines
6.0 KiB
Python
from ..message import Message
|
|
|
|
|
|
class ContextTruncator:
    """Context truncator.

    Shortens a chat history while keeping it a well-formed
    OpenAI-Chat-Completions-style message sequence:

    * the leading run of ``system`` messages is always preserved;
    * the truncated history is cut so that it starts at a ``user`` message
      when one is available;
    * ``tool`` / ``assistant(tool_calls)`` pairing is repaired via
      :meth:`fix_messages`.
    """

    def _has_tool_calls(self, message: "Message") -> bool:
        """Return True if ``message`` is an assistant message with a non-empty ``tool_calls``."""
        return message.role == "assistant" and bool(message.tool_calls)

    def _split_system_prefix(
        self, messages: "list[Message]"
    ) -> "tuple[list[Message], list[Message]]":
        """Split ``messages`` into its leading system messages and the remainder.

        If every message is a system message, all of them go into the first
        list and the remainder is empty. Truncation therefore never removes
        system prompts.
        """
        # Fall back to len(messages) (not 0) so an all-system history is
        # classified as "all system" rather than "all non-system".
        split_at = next(
            (i for i, msg in enumerate(messages) if msg.role != "system"),
            len(messages),
        )
        return messages[:split_at], messages[split_at:]

    def _first_user_index(self, messages: "list[Message]") -> "int | None":
        """Return the index of the first ``user`` message, or None if there is none."""
        return next(
            (i for i, msg in enumerate(messages) if msg.role == "user"),
            None,
        )

    def fix_messages(self, messages: "list[Message]") -> "list[Message]":
        """Repair the pairing between tool calls and tool responses.

        In the returned list:

        1. every ``tool`` message is preceded by an ``assistant`` message that
           carries ``tool_calls``;
        2. every ``assistant`` message with ``tool_calls`` is followed by at
           least one ``tool`` message.

        Orphan ``tool`` messages are dropped. ``assistant(tool_calls)``
        messages with no following ``tool`` response are dropped too. All
        other messages are kept in their original order. An empty input is
        returned unchanged.

        The OpenAI Chat Completions API requires this pairing, and Gemini
        enforces it strictly.

        NOTE(review): responses are paired by position only. ``tool_call_id``
        values are not checked against the assistant's ``tool_calls``, so an
        assistant whose calls are only partially answered is still kept.
        Confirm whether id matching is needed.
        """
        if not messages:
            return messages

        fixed_messages: list[Message] = []
        # The current assistant(tool_calls) message awaiting tool responses.
        pending_assistant: Message | None = None
        # Tool responses collected for ``pending_assistant``.
        pending_tools: list[Message] = []

        def flush_pending_if_valid() -> None:
            # Emit the pending chain only if it got at least one response.
            # Either way, reset the pending state.
            nonlocal pending_assistant, pending_tools
            if pending_assistant is not None and pending_tools:
                fixed_messages.append(pending_assistant)
                fixed_messages.extend(pending_tools)
            pending_assistant = None
            pending_tools = []

        for msg in messages:
            if msg.role == "tool":
                # Record a tool response only while an assistant(tool_calls)
                # is pending. Otherwise it is an orphan and is dropped.
                if pending_assistant is not None:
                    pending_tools.append(msg)
                continue

            if self._has_tool_calls(msg):
                # A new assistant(tool_calls) closes the previous chain.
                flush_pending_if_valid()
                pending_assistant = msg
                continue

            # Any other message closes the pending chain, then is kept as-is.
            flush_pending_if_valid()
            fixed_messages.append(msg)

        # Close the final chain, if any.
        flush_pending_if_valid()

        return fixed_messages

    def truncate_by_turns(
        self,
        messages: "list[Message]",
        keep_most_recent_turns: int,
        drop_turns: int = 1,
    ) -> "list[Message]":
        """Truncate the context so that it does not exceed a number of turns.

        A turn is counted as two messages (one user plus one assistant
        message). The result is passed through :meth:`fix_messages` so it
        stays a valid OpenAI-style context. Leading system messages are
        always kept.

        Args:
            messages: The context list.
            keep_most_recent_turns: Maximum number of recent turns to keep.
                ``-1`` disables truncation.
            drop_turns: Number of turns to drop in one go once the limit is
                exceeded. The method keeps
                ``keep_most_recent_turns - drop_turns + 1`` turns, so
                ``drop_turns=1`` keeps exactly the limit.

        Returns:
            The truncated context list. If no truncation is needed, the
            original list object is returned.
        """
        if keep_most_recent_turns == -1:
            return messages

        system_messages, non_system_messages = self._split_system_prefix(messages)

        if len(non_system_messages) // 2 <= keep_most_recent_turns:
            return messages

        num_to_keep = keep_most_recent_turns - drop_turns + 1
        if num_to_keep <= 0:
            truncated_contexts = []
        else:
            truncated_contexts = non_system_messages[-num_to_keep * 2 :]

        # Start the kept history at the first user message so the context
        # format stays valid. If there is no user message, the slice is kept
        # as-is.
        index = self._first_user_index(truncated_contexts)
        if index is not None and index > 0:
            truncated_contexts = truncated_contexts[index:]

        return self.fix_messages(system_messages + truncated_contexts)

    def truncate_by_dropping_oldest_turns(
        self,
        messages: "list[Message]",
        drop_turns: int = 1,
    ) -> "list[Message]":
        """Drop the oldest ``drop_turns`` turns (two messages per turn).

        Leading system messages are always kept. After dropping, the
        remainder is cut to start at its first user message. If no user
        message remains, the whole non-system remainder is discarded. A
        non-positive ``drop_turns`` returns the input unchanged.
        """
        if drop_turns <= 0:
            return messages

        system_messages, non_system_messages = self._split_system_prefix(messages)

        if len(non_system_messages) // 2 <= drop_turns:
            truncated_non_system = []
        else:
            truncated_non_system = non_system_messages[drop_turns * 2 :]

        index = self._first_user_index(truncated_non_system)
        truncated_non_system = (
            [] if index is None else truncated_non_system[index:]
        )

        return self.fix_messages(system_messages + truncated_non_system)

    def truncate_by_halving(
        self,
        messages: "list[Message]",
    ) -> "list[Message]":
        """Halving strategy: delete the older half of the non-system messages.

        Lists with two or fewer messages are returned unchanged, as are lists
        with fewer than two non-system messages. Leading system messages are
        always kept. The remainder is cut to start at its first user message
        when one exists; otherwise it is kept as-is.
        """
        if len(messages) <= 2:
            return messages

        system_messages, non_system_messages = self._split_system_prefix(messages)

        messages_to_delete = len(non_system_messages) // 2
        if messages_to_delete == 0:
            return messages

        truncated_non_system = non_system_messages[messages_to_delete:]

        index = self._first_user_index(truncated_non_system)
        if index is not None:
            truncated_non_system = truncated_non_system[index:]

        return self.fix_messages(system_messages + truncated_non_system)
|