refactor: revise LLM message schema and fix the reload logic when using dataclass-based LLM Tool registration (#3234)
* refactor: llm message schema * feat: implement MCPTool and local LLM tools with enhanced context handling * refactor: reorganize imports and enhance docstrings for clarity * refactor: enhance ContentPart validation and add message pair handling in ConversationManager * chore: ruff format * refactor: remove debug print statement from payloads in ProviderOpenAIOfficial * Update astrbot/core/agent/tool.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update astrbot/core/agent/message.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update astrbot/core/agent/message.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update astrbot/core/agent/tool.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update astrbot/core/pipeline/process_stage/method/llm_request.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update astrbot/core/agent/message.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * refactor: enhance documentation and import mcp in tool.py; update call method return type * fix: 修复以数据类的方式注册 tool 时的插件重载机制问题 * refactor: change role attributes to use Literal types for message segments * fix: add support for 'decorator_handler' method in call_local_llm_tool * fix: handle None prompt in text_chat method and ensure context is properly formatted --------- Co-authored-by: LIghtJUNction <lightjunction.me@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -14,9 +14,10 @@ from openai.types.chat.chat_completion import ChatCompletion
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot import logger
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
|
||||
from ..register import register_provider_adapter
|
||||
@@ -102,7 +103,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
except NotFoundError as e:
|
||||
raise Exception(f"获取模型列表失败:{e}")
|
||||
|
||||
async def _query(self, payloads: dict, tools: ToolSet) -> LLMResponse:
|
||||
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
|
||||
if tools:
|
||||
model = payloads.get("model", "").lower()
|
||||
omit_empty_param_field = "gemini" in model
|
||||
@@ -153,7 +154,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
async def _query_stream(
|
||||
self,
|
||||
payloads: dict,
|
||||
tools: ToolSet,
|
||||
tools: ToolSet | None,
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
"""流式查询API,逐步返回结果"""
|
||||
if tools:
|
||||
@@ -212,7 +213,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
yield llm_response
|
||||
|
||||
async def parse_openai_completion(self, completion: ChatCompletion, tools: ToolSet):
|
||||
async def parse_openai_completion(
|
||||
self, completion: ChatCompletion, tools: ToolSet | None
|
||||
) -> LLMResponse:
|
||||
"""解析 OpenAI 的 ChatCompletion 响应"""
|
||||
llm_response = LLMResponse("assistant")
|
||||
|
||||
@@ -225,7 +228,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
completion_text = str(choice.message.content).strip()
|
||||
llm_response.result_chain = MessageChain().message(completion_text)
|
||||
|
||||
if choice.message.tool_calls:
|
||||
if choice.message.tool_calls and tools is not None:
|
||||
# tools call (function calling)
|
||||
args_ls = []
|
||||
func_name_ls = []
|
||||
@@ -267,9 +270,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
async def _prepare_chat_payload(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt: str | None,
|
||||
image_urls: list[str] | None = None,
|
||||
contexts: list | None = None,
|
||||
contexts: list[dict] | list[Message] | None = None,
|
||||
system_prompt: str | None = None,
|
||||
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
|
||||
model: str | None = None,
|
||||
@@ -278,8 +281,12 @@ class ProviderOpenAIOfficial(Provider):
|
||||
"""准备聊天所需的有效载荷和上下文"""
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = [*contexts, new_record]
|
||||
new_record = None
|
||||
if prompt is not None:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = self._ensure_message_to_dicts(contexts)
|
||||
if new_record:
|
||||
context_query.append(new_record)
|
||||
if system_prompt:
|
||||
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||
|
||||
@@ -310,7 +317,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
e: Exception,
|
||||
payloads: dict,
|
||||
context_query: list,
|
||||
func_tool: ToolSet,
|
||||
func_tool: ToolSet | None,
|
||||
chosen_key: str,
|
||||
available_api_keys: list[str],
|
||||
retry_cnt: int,
|
||||
@@ -390,7 +397,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt,
|
||||
prompt=None,
|
||||
session_id=None,
|
||||
image_urls=None,
|
||||
func_tool=None,
|
||||
@@ -459,7 +466,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt=None,
|
||||
session_id=None,
|
||||
image_urls=None,
|
||||
func_tool=None,
|
||||
|
||||
Reference in New Issue
Block a user