fix: anthropic api error when using tools
This commit is contained in:
@@ -2,12 +2,16 @@ import abc
|
||||
import typing as T
|
||||
from dataclasses import dataclass
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
from ....message.message_event_result import MessageChain
|
||||
|
||||
|
||||
class AgentResponseData(T.TypedDict):
|
||||
chain: MessageChain
|
||||
|
||||
@dataclass
|
||||
class AgentResponse:
|
||||
type: str
|
||||
data: dict
|
||||
data: AgentResponseData
|
||||
|
||||
|
||||
class BaseAgentRunner:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import sys
|
||||
import traceback
|
||||
import typing as T
|
||||
from .base import BaseAgentRunner, AgentResponse
|
||||
from .base import BaseAgentRunner, AgentResponse, AgentResponseData
|
||||
from ...context import PipelineContext
|
||||
from astrbot.core.provider.provider import Provider
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
@@ -70,18 +70,16 @@ class ToolLoopAgent(BaseAgentRunner):
|
||||
if llm_response.result_chain:
|
||||
yield AgentResponse(
|
||||
type="streaming_delta",
|
||||
data={
|
||||
"chain": llm_response.result_chain.chain,
|
||||
},
|
||||
data=AgentResponseData(chain=llm_response.result_chain),
|
||||
)
|
||||
else:
|
||||
yield AgentResponse(
|
||||
type="streaming_delta",
|
||||
data={
|
||||
"chain": MessageChain().message(
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(
|
||||
llm_response.completion_text
|
||||
),
|
||||
},
|
||||
)
|
||||
),
|
||||
)
|
||||
else:
|
||||
llm_resp_result = llm_response
|
||||
@@ -105,11 +103,11 @@ class ToolLoopAgent(BaseAgentRunner):
|
||||
self.is_done = True
|
||||
yield AgentResponse(
|
||||
type="err",
|
||||
data={
|
||||
"chain": MessageChain().message(
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(
|
||||
f"LLM 响应错误: {llm_resp.completion_text or '未知错误'}"
|
||||
),
|
||||
},
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
if not llm_resp.tools_call_name:
|
||||
@@ -121,16 +119,14 @@ class ToolLoopAgent(BaseAgentRunner):
|
||||
if llm_resp.result_chain:
|
||||
yield AgentResponse(
|
||||
type="llm_result",
|
||||
data={
|
||||
"chain": llm_resp.result_chain.chain,
|
||||
},
|
||||
data=AgentResponseData(chain=llm_resp.result_chain),
|
||||
)
|
||||
elif llm_resp.completion_text:
|
||||
yield AgentResponse(
|
||||
type="llm_result",
|
||||
data={
|
||||
"chain": MessageChain().message(llm_resp.completion_text),
|
||||
},
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(llm_resp.completion_text)
|
||||
),
|
||||
)
|
||||
|
||||
# 如果有工具调用,还需处理工具调用
|
||||
@@ -142,14 +138,14 @@ class ToolLoopAgent(BaseAgentRunner):
|
||||
elif isinstance(result, MessageChain):
|
||||
yield AgentResponse(
|
||||
type="tool_call_result",
|
||||
data={
|
||||
"chain": result.chain,
|
||||
},
|
||||
data=AgentResponseData(chain=result),
|
||||
)
|
||||
# 将结果添加到上下文中
|
||||
tool_calls_result = ToolCallsResult(
|
||||
tool_calls_info=AssistantMessageSegment(
|
||||
role="assistant", tool_calls=llm_resp.to_openai_tool_calls()
|
||||
role="assistant",
|
||||
tool_calls=llm_resp.to_openai_tool_calls(),
|
||||
content=llm_resp.completion_text,
|
||||
),
|
||||
tool_calls_result=tool_call_result_blocks,
|
||||
)
|
||||
|
||||
@@ -163,14 +163,14 @@ class LLMRequestSubStage(Stage):
|
||||
)
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=resp.data["chain"],
|
||||
chain=resp.data["chain"].chain,
|
||||
result_content_type=content_typ,
|
||||
)
|
||||
)
|
||||
yield
|
||||
event.clear_result()
|
||||
else:
|
||||
yield resp.data["chain"]
|
||||
yield resp.data["chain"].chain
|
||||
if tool_loop_agent.done():
|
||||
break
|
||||
except Exception as e:
|
||||
|
||||
@@ -58,7 +58,7 @@ class AssistantMessageSegment:
|
||||
"""OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
|
||||
|
||||
content: str = None
|
||||
tool_calls: List[ChatCompletionMessageToolCall | Dict] = None
|
||||
tool_calls: List[ChatCompletionMessageToolCall | Dict] = field(default_factory=list)
|
||||
role: str = "assistant"
|
||||
|
||||
def to_dict(self):
|
||||
@@ -67,7 +67,7 @@ class AssistantMessageSegment:
|
||||
}
|
||||
if self.content:
|
||||
ret["content"] = self.content
|
||||
elif self.tool_calls:
|
||||
if self.tool_calls:
|
||||
ret["tool_calls"] = self.tool_calls
|
||||
return ret
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import TypedDict, AsyncGenerator
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
||||
from dataclasses import dataclass
|
||||
from deprecated import deprecated
|
||||
|
||||
|
||||
class Personality(TypedDict):
|
||||
|
||||
@@ -32,9 +32,9 @@ class ProviderAnthropic(Provider):
|
||||
default_persona,
|
||||
)
|
||||
|
||||
self.chosen_api_key = None
|
||||
self.chosen_api_key: str = ""
|
||||
self.api_keys: List = provider_config.get("key", [])
|
||||
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
|
||||
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else ""
|
||||
self.base_url = provider_config.get("api_base", "https://api.anthropic.com")
|
||||
self.timeout = provider_config.get("timeout", 120)
|
||||
if isinstance(self.timeout, str):
|
||||
@@ -97,6 +97,9 @@ class ProviderAnthropic(Provider):
|
||||
)
|
||||
else:
|
||||
new_messages.append(message)
|
||||
|
||||
logger.debug(f"message: {messages}")
|
||||
logger.debug(f"new message: {new_messages}")
|
||||
|
||||
return system_prompt, new_messages
|
||||
|
||||
@@ -368,3 +371,17 @@ class ProviderAnthropic(Provider):
|
||||
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
|
||||
return "data:image/jpeg;base64," + image_bs64
|
||||
return ""
|
||||
|
||||
def get_current_key(self) -> str:
|
||||
return self.chosen_api_key
|
||||
|
||||
async def get_models(self) -> List[str]:
|
||||
models_str = []
|
||||
models = await self.client.models.list()
|
||||
models = sorted(models.data, key=lambda x: x.id)
|
||||
for model in models:
|
||||
models_str.append(model.id)
|
||||
return models_str
|
||||
|
||||
def set_key(self, key: str):
|
||||
self.chosen_api_key = key
|
||||
|
||||
Reference in New Issue
Block a user