fix: ensure max_tokens is set and validate tool_calls type in ProviderAnthropic (#4212)
This commit is contained in:
@@ -68,7 +68,7 @@ class ProviderAnthropic(Provider):
|
||||
blocks = []
|
||||
if isinstance(message["content"], str):
|
||||
blocks.append({"type": "text", "text": message["content"]})
|
||||
if "tool_calls" in message:
|
||||
if "tool_calls" in message and isinstance(message["tool_calls"], list):
|
||||
for tool_call in message["tool_calls"]:
|
||||
blocks.append( # noqa: PERF401
|
||||
{
|
||||
@@ -132,6 +132,9 @@ class ProviderAnthropic(Provider):
|
||||
|
||||
extra_body = self.provider_config.get("custom_extra_body", {})
|
||||
|
||||
if "max_tokens" not in payloads:
|
||||
payloads["max_tokens"] = 1024
|
||||
|
||||
completion = await self.client.messages.create(
|
||||
**payloads, stream=False, extra_body=extra_body
|
||||
)
|
||||
@@ -181,6 +184,9 @@ class ProviderAnthropic(Provider):
|
||||
usage = TokenUsage()
|
||||
extra_body = self.provider_config.get("custom_extra_body", {})
|
||||
|
||||
if "max_tokens" not in payloads:
|
||||
payloads["max_tokens"] = 1024
|
||||
|
||||
async with self.client.messages.stream(
|
||||
**payloads, extra_body=extra_body
|
||||
) as stream:
|
||||
@@ -342,11 +348,11 @@ class ProviderAnthropic(Provider):
|
||||
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt,
|
||||
prompt=None,
|
||||
session_id=None,
|
||||
image_urls=...,
|
||||
image_urls=None,
|
||||
func_tool=None,
|
||||
contexts=...,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
|
||||
Reference in New Issue
Block a user