From 20b760529e975e1277462eb79732c1c6dfcfb9a2 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sun, 29 Jun 2025 15:33:08 +0800 Subject: [PATCH] fix: anthropic api error when using tools --- .../process_stage/agent_runner/base.py | 6 ++- .../agent_runner/tool_loop_agent.py | 40 +++++++++---------- .../process_stage/method/llm_request.py | 4 +- astrbot/core/provider/entities.py | 4 +- astrbot/core/provider/provider.py | 1 + .../core/provider/sources/anthropic_source.py | 21 +++++++++- 6 files changed, 47 insertions(+), 29 deletions(-) diff --git a/astrbot/core/pipeline/process_stage/agent_runner/base.py b/astrbot/core/pipeline/process_stage/agent_runner/base.py index fd3c7d4ae..d694d767a 100644 --- a/astrbot/core/pipeline/process_stage/agent_runner/base.py +++ b/astrbot/core/pipeline/process_stage/agent_runner/base.py @@ -2,12 +2,16 @@ import abc import typing as T from dataclasses import dataclass from astrbot.core.provider.entities import LLMResponse +from ....message.message_event_result import MessageChain +class AgentResponseData(T.TypedDict): + chain: MessageChain + @dataclass class AgentResponse: type: str - data: dict + data: AgentResponseData class BaseAgentRunner: diff --git a/astrbot/core/pipeline/process_stage/agent_runner/tool_loop_agent.py b/astrbot/core/pipeline/process_stage/agent_runner/tool_loop_agent.py index d26701d6b..d6d3add83 100644 --- a/astrbot/core/pipeline/process_stage/agent_runner/tool_loop_agent.py +++ b/astrbot/core/pipeline/process_stage/agent_runner/tool_loop_agent.py @@ -1,7 +1,7 @@ import sys import traceback import typing as T -from .base import BaseAgentRunner, AgentResponse +from .base import BaseAgentRunner, AgentResponse, AgentResponseData from ...context import PipelineContext from astrbot.core.provider.provider import Provider from astrbot.core.platform.astr_message_event import AstrMessageEvent @@ -70,18 +70,16 @@ class ToolLoopAgent(BaseAgentRunner): if llm_response.result_chain: yield AgentResponse( type="streaming_delta", - data={ - "chain": llm_response.result_chain.chain, - }, + data=AgentResponseData(chain=llm_response.result_chain), ) else: yield AgentResponse( type="streaming_delta", - data={ - "chain": MessageChain().message( + data=AgentResponseData( + chain=MessageChain().message( llm_response.completion_text - ), - }, + ) + ), ) else: llm_resp_result = llm_response @@ -105,11 +103,11 @@ class ToolLoopAgent(BaseAgentRunner): self.is_done = True yield AgentResponse( type="err", - data={ - "chain": MessageChain().message( + data=AgentResponseData( + chain=MessageChain().message( f"LLM 响应错误: {llm_resp.completion_text or '未知错误'}" - ), - }, + ) + ), ) if not llm_resp.tools_call_name: @@ -121,16 +119,14 @@ class ToolLoopAgent(BaseAgentRunner): if llm_resp.result_chain: yield AgentResponse( type="llm_result", - data={ - "chain": llm_resp.result_chain.chain, - }, + data=AgentResponseData(chain=llm_resp.result_chain), ) elif llm_resp.completion_text: yield AgentResponse( type="llm_result", - data={ - "chain": MessageChain().message(llm_resp.completion_text), - }, + data=AgentResponseData( + chain=MessageChain().message(llm_resp.completion_text) + ), ) # 如果有工具调用,还需处理工具调用 @@ -142,14 +138,14 @@ class ToolLoopAgent(BaseAgentRunner): elif isinstance(result, MessageChain): yield AgentResponse( type="tool_call_result", - data={ - "chain": result.chain, - }, + data=AgentResponseData(chain=result), ) # 将结果添加到上下文中 tool_calls_result = ToolCallsResult( tool_calls_info=AssistantMessageSegment( - role="assistant", tool_calls=llm_resp.to_openai_tool_calls() + role="assistant", + tool_calls=llm_resp.to_openai_tool_calls(), + content=llm_resp.completion_text, ), tool_calls_result=tool_call_result_blocks, ) diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 7d5fa8135..26d0877d1 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -163,14 +163,14 @@ class LLMRequestSubStage(Stage): ) event.set_result( MessageEventResult( - chain=resp.data["chain"], + chain=resp.data["chain"].chain, result_content_type=content_typ, ) ) yield event.clear_result() else: - yield resp.data["chain"] + yield resp.data["chain"].chain if tool_loop_agent.done(): break except Exception as e: diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index 7dd96e8d0..abb01960c 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -58,7 +58,7 @@ class AssistantMessageSegment: """OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling""" content: str = None - tool_calls: List[ChatCompletionMessageToolCall | Dict] = None + tool_calls: List[ChatCompletionMessageToolCall | Dict] = field(default_factory=list) role: str = "assistant" def to_dict(self): @@ -67,7 +67,7 @@ class AssistantMessageSegment: } if self.content: ret["content"] = self.content - elif self.tool_calls: + if self.tool_calls: ret["tool_calls"] = self.tool_calls return ret diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 1ecca3537..cfe5748da 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -4,6 +4,7 @@ from typing import TypedDict, AsyncGenerator from astrbot.core.provider.func_tool_manager import FuncCall from astrbot.core.provider.entities import LLMResponse, ToolCallsResult from dataclasses import dataclass +from deprecated import deprecated class Personality(TypedDict): diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index 18b4cfead..83ba58cc5 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -32,9 +32,9 @@ class ProviderAnthropic(Provider): default_persona, ) - self.chosen_api_key = None + self.chosen_api_key: str = "" self.api_keys: List = provider_config.get("key", []) - self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None + self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else "" self.base_url = provider_config.get("api_base", "https://api.anthropic.com") self.timeout = provider_config.get("timeout", 120) if isinstance(self.timeout, str): @@ -97,6 +97,9 @@ class ProviderAnthropic(Provider): ) else: new_messages.append(message) + + logger.debug(f"message: {messages}") + logger.debug(f"new message: {new_messages}") return system_prompt, new_messages @@ -368,3 +371,17 @@ class ProviderAnthropic(Provider): image_bs64 = base64.b64encode(f.read()).decode("utf-8") return "data:image/jpeg;base64," + image_bs64 return "" + + def get_current_key(self) -> str: + return self.chosen_api_key + + async def get_models(self) -> List[str]: + models_str = [] + models = await self.client.models.list() + models = sorted(models.data, key=lambda x: x.id) + for model in models: + models_str.append(model.id) + return models_str + + def set_key(self, key: str): + self.chosen_api_key = key