perf: 简化llm_request工具调用消息成对验证逻辑, 合并两处验证逻辑到一个函数
This commit is contained in:
@@ -67,48 +67,10 @@ class LLMRequestSubStage(Stage):
|
||||
), "provider_request 必须是 ProviderRequest 类型。"
|
||||
|
||||
if req.conversation:
|
||||
all_contexts = json.loads(req.conversation.history)
|
||||
req.contexts = []
|
||||
i = 0
|
||||
while i < len(all_contexts):
|
||||
current_msg = all_contexts[i]
|
||||
# 普通消息
|
||||
if "_tool_call_history" not in current_msg:
|
||||
req.contexts.append(current_msg)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# 工具调用消息, 必须成对出现
|
||||
if (
|
||||
current_msg.get("role") == "assistant"
|
||||
and "tool_calls" in current_msg
|
||||
):
|
||||
# 寻找tool响应
|
||||
assistant_msg = current_msg.copy()
|
||||
# 移除标记
|
||||
if "_tool_call_history" in assistant_msg:
|
||||
del assistant_msg["_tool_call_history"]
|
||||
|
||||
related_tools = []
|
||||
j = i + 1
|
||||
while (
|
||||
j < len(all_contexts)
|
||||
and all_contexts[j].get("role") == "tool"
|
||||
and "_tool_call_history" in all_contexts[j]
|
||||
):
|
||||
tool_msg = all_contexts[j].copy()
|
||||
del tool_msg["_tool_call_history"]
|
||||
related_tools.append(tool_msg)
|
||||
j += 1
|
||||
|
||||
# 只添加成对的tool_call和tool响应
|
||||
if related_tools:
|
||||
req.contexts.append(assistant_msg)
|
||||
req.contexts.extend(related_tools)
|
||||
# 已处理的消息跳过
|
||||
i = j
|
||||
else:
|
||||
i += 1
|
||||
all_contexts = json.load(req.conversation.history)
|
||||
req.contexts = self._process_tool_message_pairs(
|
||||
all_contexts, remove_tags=True
|
||||
)
|
||||
|
||||
else:
|
||||
req = ProviderRequest(prompt="", image_urls=[])
|
||||
@@ -485,21 +447,15 @@ class LLMRequestSubStage(Stage):
|
||||
if req.tool_calls_result:
|
||||
tool_calls_messages = req.tool_calls_result.to_openai_messages()
|
||||
|
||||
# 对顺序的验证
|
||||
assistant_msgs = []
|
||||
tool_msgs = []
|
||||
|
||||
# 添加标记
|
||||
for message in tool_calls_messages:
|
||||
message["_tool_call_history"] = True
|
||||
|
||||
if message.get("role") == "assistant":
|
||||
assistant_msgs.append(message)
|
||||
elif message.get("role") == "tool":
|
||||
tool_msgs.append(message)
|
||||
processed_tool_messages = self._process_tool_message_pairs(
|
||||
tool_calls_messages, remove_tags=False
|
||||
)
|
||||
|
||||
# 先添加assistant再添加tool
|
||||
contexts.extend(assistant_msgs)
|
||||
contexts.extend(tool_msgs)
|
||||
contexts.extend(processed_tool_messages)
|
||||
|
||||
contexts.append(
|
||||
{"role": "assistant", "content": llm_response.completion_text}
|
||||
@@ -510,3 +466,59 @@ class LLMRequestSubStage(Stage):
|
||||
await self.conv_manager.update_conversation(
|
||||
event.unified_msg_origin, req.conversation.cid, history=contexts_to_save
|
||||
)
|
||||
|
||||
def _process_tool_message_pairs(self, messages, remove_tags=True):
|
||||
"""处理工具调用消息,确保assistant和tool消息成对出现
|
||||
|
||||
Args:
|
||||
messages (list): 消息列表
|
||||
remove_tags (bool): 是否移除_tool_call_history标记
|
||||
|
||||
Returns:
|
||||
list: 处理后的消息列表,保证了assistant和对应tool消息的成对出现
|
||||
"""
|
||||
result = []
|
||||
i = 0
|
||||
|
||||
while i < len(messages):
|
||||
current_msg = messages[i]
|
||||
|
||||
# 普通消息直接添加
|
||||
if "_tool_call_history" not in current_msg:
|
||||
result.append(current_msg.copy() if remove_tags else current_msg)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# 工具调用消息成对处理
|
||||
if current_msg.get("role") == "assistant" and "tool_calls" in current_msg:
|
||||
assistant_msg = current_msg.copy()
|
||||
|
||||
if remove_tags and "_tool_call_history" in assistant_msg:
|
||||
del assistant_msg["_tool_call_history"]
|
||||
|
||||
related_tools = []
|
||||
j = i + 1
|
||||
while (
|
||||
j < len(messages)
|
||||
and messages[j].get("role") == "tool"
|
||||
and "_tool_call_history" in messages[j]
|
||||
):
|
||||
tool_msg = messages[j].copy()
|
||||
|
||||
if remove_tags:
|
||||
del tool_msg["_tool_call_history"]
|
||||
|
||||
related_tools.append(tool_msg)
|
||||
j += 1
|
||||
|
||||
# 成对的时候添加到结果
|
||||
if related_tools:
|
||||
result.append(assistant_msg)
|
||||
result.extend(related_tools)
|
||||
|
||||
i = j # 跳过已处理
|
||||
else:
|
||||
# 单独的tool消息
|
||||
i += 1
|
||||
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user