refactor: unify extra_user_content_parts type to ContentPart across providers and update related handling

This commit is contained in:
Soulter
2025-12-26 21:47:02 +08:00
parent 05012af627
commit 7c1dbecea5
6 changed files with 18 additions and 23 deletions
@@ -77,11 +77,11 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
"""Yields chunks *and* a final LLMResponse."""
payload = {
"contexts": self.run_context.messages,
"contexts": self.run_context.messages, # list[Message]
"func_tool": self.req.func_tool,
"model": self.req.model, # NOTE: in fact, this arg is None in most cases
"session_id": self.req.session_id,
"extra_user_content_parts": self.req.extra_user_content_parts,
"extra_user_content_parts": self.req.extra_user_content_parts, # list[ContentPart]
}
if self.streaming:
+2 -9
View File
@@ -93,9 +93,7 @@ class ProviderRequest:
"""会话 ID"""
image_urls: list[str] = field(default_factory=list)
"""图片 URL 列表"""
extra_user_content_parts: list[dict] | list[ContentPart] = field(
default_factory=list
)
extra_user_content_parts: list[ContentPart] = field(default_factory=list)
"""额外的用户消息内容部分列表,用于在用户消息后添加额外的内容块(如系统提醒、指令等)。支持 dict 或 ContentPart 对象"""
func_tool: ToolSet | None = None
"""可用的函数工具"""
@@ -184,12 +182,7 @@ class ProviderRequest:
# 2. 额外的内容块(系统提醒、指令等)
if self.extra_user_content_parts:
for part in self.extra_user_content_parts:
if hasattr(part, "model_dump"):
# ContentPart 对象,需要 model_dump
content_blocks.append(part.model_dump())
else:
# 已经是 dict
content_blocks.append(part)
content_blocks.append(part.model_dump())
# 3. 图片内容
if self.image_urls:
+2 -2
View File
@@ -4,7 +4,7 @@ import os
from collections.abc import AsyncGenerator
from typing import TypeAlias, Union
from astrbot.core.agent.message import Message
from astrbot.core.agent.message import ContentPart, Message
from astrbot.core.agent.tool import ToolSet
from astrbot.core.provider.entities import (
LLMResponse,
@@ -103,7 +103,7 @@ class Provider(AbstractProvider):
system_prompt: str | None = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
model: str | None = None,
extra_user_content_parts: list[dict] | None = None,
extra_user_content_parts: list[ContentPart] | None = None,
**kwargs,
) -> LLMResponse:
"""获得 LLM 的文本对话结果。会使用当前的模型进行对话。
@@ -11,6 +11,7 @@ from anthropic.types.usage import Usage
from astrbot import logger
from astrbot.api.provider import Provider
from astrbot.core.agent.message import ContentPart
from astrbot.core.provider.entities import LLMResponse, TokenUsage
from astrbot.core.provider.func_tool_manager import ToolSet
from astrbot.core.utils.io import download_image_by_url
@@ -398,7 +399,7 @@ class ProviderAnthropic(Provider):
self,
text: str,
image_urls: list[str] | None = None,
extra_user_content_parts: list[dict] | None = None,
extra_user_content_parts: list[ContentPart] | None = None,
):
"""组装上下文,支持文本和图片"""
content = []
@@ -417,9 +418,7 @@ class ProviderAnthropic(Provider):
if extra_user_content_parts:
# 过滤出文本块,因为 Anthropic 主要支持文本和图片
text_blocks = [
block
for block in extra_user_content_parts
if block.get("type") == "text"
block for block in extra_user_content_parts if block.type == "text"
]
content.extend(text_blocks)
@@ -13,6 +13,7 @@ from google.genai.errors import APIError
import astrbot.core.message.components as Comp
from astrbot import logger
from astrbot.api.provider import Provider
from astrbot.core.agent.message import ContentPart
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import LLMResponse, TokenUsage
from astrbot.core.provider.func_tool_manager import ToolSet
@@ -807,7 +808,7 @@ class ProviderGoogleGenAI(Provider):
self,
text: str,
image_urls: list[str] | None = None,
extra_user_content_parts: list[dict] | None = None,
extra_user_content_parts: list[ContentPart] | None = None,
):
"""组装上下文。"""
# 构建内容块列表
@@ -825,7 +826,8 @@ class ProviderGoogleGenAI(Provider):
# 2. 额外的内容块(系统提醒、指令等)
if extra_user_content_parts:
content_blocks.extend(extra_user_content_parts)
for part in extra_user_content_parts:
content_blocks.append(part.model_dump())
# 3. 图片内容
if image_urls:
@@ -17,7 +17,7 @@ from openai.types.completion_usage import CompletionUsage
import astrbot.core.message.components as Comp
from astrbot import logger
from astrbot.api.provider import Provider
from astrbot.core.agent.message import Message
from astrbot.core.agent.message import ContentPart, Message
from astrbot.core.agent.tool import ToolSet
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import LLMResponse, TokenUsage, ToolCallsResult
@@ -348,7 +348,7 @@ class ProviderOpenAIOfficial(Provider):
system_prompt: str | None = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
model: str | None = None,
extra_user_content_parts: list[dict] | None = None,
extra_user_content_parts: list[ContentPart] | None = None,
**kwargs,
) -> tuple:
"""准备聊天所需的有效载荷和上下文"""
@@ -629,7 +629,7 @@ class ProviderOpenAIOfficial(Provider):
self,
text: str,
image_urls: list[str] | None = None,
extra_user_content_parts: list[dict] | None = None,
extra_user_content_parts: list[ContentPart] | None = None,
) -> dict:
"""组装成符合 OpenAI 格式的 role 为 user 的消息段"""
# 构建内容块列表
@@ -647,7 +647,8 @@ class ProviderOpenAIOfficial(Provider):
# 2. 额外的内容块(系统提醒、指令等)
if extra_user_content_parts:
content_blocks.extend(extra_user_content_parts)
for part in extra_user_content_parts:
content_blocks.append(part.model_dump())
# 3. 图片内容
if image_urls: