fix: anthropic api error when using tools

This commit is contained in:
Soulter
2025-06-29 15:33:08 +08:00
parent a55a07c5ff
commit 20b760529e
6 changed files with 47 additions and 29 deletions
@@ -2,12 +2,16 @@ import abc
import typing as T
from dataclasses import dataclass
from astrbot.core.provider.entities import LLMResponse
from ....message.message_event_result import MessageChain
class AgentResponseData(T.TypedDict):
chain: MessageChain
@dataclass
class AgentResponse:
type: str
data: dict
data: AgentResponseData
class BaseAgentRunner:
@@ -1,7 +1,7 @@
import sys
import traceback
import typing as T
from .base import BaseAgentRunner, AgentResponse
from .base import BaseAgentRunner, AgentResponse, AgentResponseData
from ...context import PipelineContext
from astrbot.core.provider.provider import Provider
from astrbot.core.platform.astr_message_event import AstrMessageEvent
@@ -70,18 +70,16 @@ class ToolLoopAgent(BaseAgentRunner):
if llm_response.result_chain:
yield AgentResponse(
type="streaming_delta",
data={
"chain": llm_response.result_chain.chain,
},
data=AgentResponseData(chain=llm_response.result_chain),
)
else:
yield AgentResponse(
type="streaming_delta",
data={
"chain": MessageChain().message(
data=AgentResponseData(
chain=MessageChain().message(
llm_response.completion_text
),
},
)
),
)
else:
llm_resp_result = llm_response
@@ -105,11 +103,11 @@ class ToolLoopAgent(BaseAgentRunner):
self.is_done = True
yield AgentResponse(
type="err",
data={
"chain": MessageChain().message(
data=AgentResponseData(
chain=MessageChain().message(
f"LLM 响应错误: {llm_resp.completion_text or '未知错误'}"
),
},
)
),
)
if not llm_resp.tools_call_name:
@@ -121,16 +119,14 @@ class ToolLoopAgent(BaseAgentRunner):
if llm_resp.result_chain:
yield AgentResponse(
type="llm_result",
data={
"chain": llm_resp.result_chain.chain,
},
data=AgentResponseData(chain=llm_resp.result_chain),
)
elif llm_resp.completion_text:
yield AgentResponse(
type="llm_result",
data={
"chain": MessageChain().message(llm_resp.completion_text),
},
data=AgentResponseData(
chain=MessageChain().message(llm_resp.completion_text)
),
)
# 如果有工具调用,还需处理工具调用
@@ -142,14 +138,14 @@ class ToolLoopAgent(BaseAgentRunner):
elif isinstance(result, MessageChain):
yield AgentResponse(
type="tool_call_result",
data={
"chain": result.chain,
},
data=AgentResponseData(chain=result),
)
# 将结果添加到上下文中
tool_calls_result = ToolCallsResult(
tool_calls_info=AssistantMessageSegment(
role="assistant", tool_calls=llm_resp.to_openai_tool_calls()
role="assistant",
tool_calls=llm_resp.to_openai_tool_calls(),
content=llm_resp.completion_text,
),
tool_calls_result=tool_call_result_blocks,
)
@@ -163,14 +163,14 @@ class LLMRequestSubStage(Stage):
)
event.set_result(
MessageEventResult(
chain=resp.data["chain"],
chain=resp.data["chain"].chain,
result_content_type=content_typ,
)
)
yield
event.clear_result()
else:
yield resp.data["chain"]
yield resp.data["chain"].chain
if tool_loop_agent.done():
break
except Exception as e:
+2 -2
View File
@@ -58,7 +58,7 @@ class AssistantMessageSegment:
"""OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
content: str = None
tool_calls: List[ChatCompletionMessageToolCall | Dict] = None
tool_calls: List[ChatCompletionMessageToolCall | Dict] = field(default_factory=list)
role: str = "assistant"
def to_dict(self):
@@ -67,7 +67,7 @@ class AssistantMessageSegment:
}
if self.content:
ret["content"] = self.content
elif self.tool_calls:
if self.tool_calls:
ret["tool_calls"] = self.tool_calls
return ret
+1
View File
@@ -4,6 +4,7 @@ from typing import TypedDict, AsyncGenerator
from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
from dataclasses import dataclass
from deprecated import deprecated
class Personality(TypedDict):
@@ -32,9 +32,9 @@ class ProviderAnthropic(Provider):
default_persona,
)
self.chosen_api_key = None
self.chosen_api_key: str = ""
self.api_keys: List = provider_config.get("key", [])
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else ""
self.base_url = provider_config.get("api_base", "https://api.anthropic.com")
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
@@ -97,6 +97,9 @@ class ProviderAnthropic(Provider):
)
else:
new_messages.append(message)
logger.debug(f"message: {messages}")
logger.debug(f"new message: {new_messages}")
return system_prompt, new_messages
@@ -368,3 +371,17 @@ class ProviderAnthropic(Provider):
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
return "data:image/jpeg;base64," + image_bs64
return ""
def get_current_key(self) -> str:
return self.chosen_api_key
async def get_models(self) -> List[str]:
models_str = []
models = await self.client.models.list()
models = sorted(models.data, key=lambda x: x.id)
for model in models:
models_str.append(model.id)
return models_str
def set_key(self, key: str):
self.chosen_api_key = key