refactor: revise LLM message schema and fix the reload logic when using dataclass-based LLM Tool registration (#3234)

* refactor: llm message schema

* feat: implement MCPTool and local LLM tools with enhanced context handling

* refactor: reorganize imports and enhance docstrings for clarity

* refactor: enhance ContentPart validation and add message pair handling in ConversationManager

* chore: ruff format

* refactor: remove debug print statement from payloads in ProviderOpenAIOfficial

* Update astrbot/core/agent/tool.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update astrbot/core/agent/message.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update astrbot/core/agent/message.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update astrbot/core/agent/tool.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update astrbot/core/pipeline/process_stage/method/llm_request.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update astrbot/core/agent/message.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* refactor: enhance documentation and import mcp in tool.py; update call method return type

* fix: 修复以数据类的方式注册 tool 时的插件重载机制问题

* refactor: change role attributes to use Literal types for message segments

* fix: add support for 'decorator_handler' method in call_local_llm_tool

* fix: handle None prompt in text_chat method and ensure context is properly formatted

---------

Co-authored-by: LIghtJUNction <lightjunction.me@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Soulter
2025-11-02 18:12:20 +08:00
committed by GitHub
parent 94bf3b8195
commit 50144ddcae
20 changed files with 641 additions and 205 deletions
-2
View File
@@ -1,4 +1,3 @@
from dataclasses import dataclass
from typing import Generic
import mcp
@@ -9,7 +8,6 @@ from astrbot.core.provider.entities import LLMResponse
from .run_context import ContextWrapper, TContext
@dataclass
class BaseAgentRunHooks(Generic[TContext]):
async def on_agent_begin(self, run_context: ContextWrapper[TContext]): ...
async def on_tool_start(
+36
View File
@@ -2,10 +2,15 @@ import asyncio
import logging
from contextlib import AsyncExitStack
from datetime import timedelta
from typing import Generic
from astrbot import logger
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.utils.log_pipe import LogPipe
from .run_context import TContext
from .tool import FunctionTool
try:
import mcp
from mcp.client.sse import sse_client
@@ -221,3 +226,34 @@ class MCPClient:
"""Clean up resources"""
await self.exit_stack.aclose()
self.running_event.set() # Set the running event to indicate cleanup is done
class MCPTool(FunctionTool, Generic[TContext]):
"""A function tool that calls an MCP service."""
def __init__(
self, mcp_tool: mcp.Tool, mcp_client: MCPClient, mcp_server_name: str, **kwargs
):
super().__init__(
name=mcp_tool.name,
description=mcp_tool.description or "",
parameters=mcp_tool.inputSchema,
)
self.mcp_tool = mcp_tool
self.mcp_client = mcp_client
self.mcp_server_name = mcp_server_name
async def call(
self, context: ContextWrapper[TContext], **kwargs
) -> mcp.types.CallToolResult:
session = self.mcp_client.session
if not session:
raise ValueError("MCP session is not available for MCP function tools.")
res = await session.call_tool(
name=self.mcp_tool.name,
arguments=kwargs,
read_timeout_seconds=timedelta(
seconds=context.tool_call_timeout,
),
)
return res
+168
View File
@@ -0,0 +1,168 @@
# Inspired by MoonshotAI/kosong, credits to MoonshotAI/kosong authors for the original implementation.
# License: Apache License 2.0
from typing import Any, ClassVar, Literal, cast
from pydantic import BaseModel, GetCoreSchemaHandler
from pydantic_core import core_schema
class ContentPart(BaseModel):
"""A part of the content in a message."""
__content_part_registry: ClassVar[dict[str, type["ContentPart"]]] = {}
type: str
def __init_subclass__(cls, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
invalid_subclass_error_msg = f"ContentPart subclass {cls.__name__} must have a `type` field of type `str`"
type_value = getattr(cls, "type", None)
if type_value is None or not isinstance(type_value, str):
raise ValueError(invalid_subclass_error_msg)
cls.__content_part_registry[type_value] = cls
@classmethod
def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
# If we're dealing with the base ContentPart class, use custom validation
if cls.__name__ == "ContentPart":
def validate_content_part(value: Any) -> Any:
# if it's already an instance of a ContentPart subclass, return it
if hasattr(value, "__class__") and issubclass(value.__class__, cls):
return value
# if it's a dict with a type field, dispatch to the appropriate subclass
if isinstance(value, dict) and "type" in value:
type_value: Any | None = cast(dict[str, Any], value).get("type")
if not isinstance(type_value, str):
raise ValueError(f"Cannot validate {value} as ContentPart")
target_class = cls.__content_part_registry[type_value]
return target_class.model_validate(value)
raise ValueError(f"Cannot validate {value} as ContentPart")
return core_schema.no_info_plain_validator_function(validate_content_part)
# for subclasses, use the default schema
return handler(source_type)
class TextPart(ContentPart):
"""
>>> TextPart(text="Hello, world!").model_dump()
{'type': 'text', 'text': 'Hello, world!'}
"""
type: str = "text"
text: str
class ImageURLPart(ContentPart):
"""
>>> ImageURLPart(image_url="http://example.com/image.jpg").model_dump()
{'type': 'image_url', 'image_url': 'http://example.com/image.jpg'}
"""
class ImageURL(BaseModel):
url: str
"""The URL of the image, can be data URI scheme like `data:image/png;base64,...`."""
id: str | None = None
"""The ID of the image, to allow LLMs to distinguish different images."""
type: str = "image_url"
image_url: str
class AudioURLPart(ContentPart):
"""
>>> AudioURLPart(audio_url=AudioURLPart.AudioURL(url="https://example.com/audio.mp3")).model_dump()
{'type': 'audio_url', 'audio_url': {'url': 'https://example.com/audio.mp3', 'id': None}}
"""
class AudioURL(BaseModel):
url: str
"""The URL of the audio, can be data URI scheme like `data:audio/aac;base64,...`."""
id: str | None = None
"""The ID of the audio, to allow LLMs to distinguish different audios."""
type: str = "audio_url"
audio_url: AudioURL
class ToolCall(BaseModel):
"""
A tool call requested by the assistant.
>>> ToolCall(
... id="123",
... function=ToolCall.FunctionBody(
... name="function",
... arguments="{}"
... ),
... ).model_dump()
{'type': 'function', 'id': '123', 'function': {'name': 'function', 'arguments': '{}'}}
"""
class FunctionBody(BaseModel):
name: str
arguments: str | None
type: Literal["function"] = "function"
id: str
"""The ID of the tool call."""
function: FunctionBody
"""The function body of the tool call."""
class ToolCallPart(BaseModel):
"""A part of the tool call."""
arguments_part: str | None = None
"""A part of the arguments of the tool call."""
class Message(BaseModel):
"""A message in a conversation."""
role: Literal[
"system",
"user",
"assistant",
"tool",
]
content: str | list[ContentPart]
"""The content of the message."""
class AssistantMessageSegment(Message):
"""A message segment from the assistant."""
role: Literal["assistant"] = "assistant"
tool_calls: list[ToolCall] | list[dict] | None = None
class ToolCallMessageSegment(Message):
"""A message segment representing a tool call."""
role: Literal["tool"] = "tool"
tool_call_id: str
class UserMessageSegment(Message):
"""A message segment from the user."""
role: Literal["user"] = "user"
class SystemMessageSegment(Message):
"""A message segment from the system."""
role: Literal["system"] = "system"
+1 -3
View File
@@ -3,8 +3,6 @@ from typing import Any, Generic
from typing_extensions import TypeVar
from astrbot.core.platform.astr_message_event import AstrMessageEvent
TContext = TypeVar("TContext", default=Any)
@@ -13,7 +11,7 @@ class ContextWrapper(Generic[TContext]):
"""A context for running an agent, which can be used to pass additional data or state."""
context: TContext
event: AstrMessageEvent
tool_call_timeout: int = 60 # Default tool call timeout in seconds
NoContext = ContextWrapper[None]
@@ -16,15 +16,14 @@ from astrbot.core.message.message_event_result import (
MessageChain,
)
from astrbot.core.provider.entities import (
AssistantMessageSegment,
LLMResponse,
ProviderRequest,
ToolCallMessageSegment,
ToolCallsResult,
)
from astrbot.core.provider.provider import Provider
from ..hooks import BaseAgentRunHooks
from ..message import AssistantMessageSegment, ToolCallMessageSegment
from ..response import AgentResponseData
from ..run_context import ContextWrapper, TContext
from ..tool_executor import BaseFunctionToolExecutor
@@ -171,8 +170,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
# 将结果添加到上下文中
tool_calls_result = ToolCallsResult(
tool_calls_info=AssistantMessageSegment(
role="assistant",
tool_calls=llm_resp.to_openai_tool_calls(),
tool_calls=llm_resp.to_openai_to_calls_model(),
content=llm_resp.completion_text,
),
tool_calls_result=tool_call_result_blocks,
@@ -238,7 +236,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
else:
# 如果没有 handler(如 MCP 工具),使用所有参数
valid_params = func_tool_args
logger.warning(f"工具 {func_tool_name} 没有 handler,使用所有参数")
try:
await self.agent_hooks.on_tool_start(
@@ -319,13 +316,11 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
elif resp is None:
# Tool 直接请求发送消息给用户
# 这里我们将直接结束 Agent Loop。
# 发送消息逻辑在 ToolExecutor 中处理了。
logger.warning(
f"{func_tool_name} 没有没有返回值或者将结果直接发送给用户,此工具调用不会被记录到历史中。"
)
self._transition_state(AgentState.DONE)
if res := self.run_context.event.get_result():
if res.chain:
yield MessageChain(
chain=res.chain,
type="tool_direct_result",
)
else:
# 不应该出现其他类型
logger.warning(
@@ -341,8 +336,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
)
except Exception as e:
logger.error(f"Error in on_tool_end hook: {e}", exc_info=True)
self.run_context.event.clear_result()
except Exception as e:
logger.warning(traceback.format_exc())
tool_call_result_blocks.append(
+51 -27
View File
@@ -1,52 +1,76 @@
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Any, Literal
from typing import Any, Generic
import jsonschema
import mcp
from deprecated import deprecated
from pydantic import model_validator
from pydantic.dataclasses import dataclass
from .mcp_client import MCPClient
from .run_context import ContextWrapper, TContext
ParametersType = dict[str, Any]
@dataclass
class FunctionTool:
"""A class representing a function tool that can be used in function calling."""
class ToolSchema:
"""A class representing the schema of a tool for function calling."""
name: str
parameters: dict | None = None
description: str | None = None
handler: Callable[..., Awaitable[Any]] | None = None
"""处理函数, 当 origin 为 mcp 时,这个为空"""
handler_module_path: str | None = None
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
"""The name of the tool."""
必须要保留这个字段, handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
description: str
"""The description of the tool."""
parameters: ParametersType
"""The parameters of the tool, in JSON Schema format."""
@model_validator(mode="after")
def validate_parameters(self) -> "ToolSchema":
jsonschema.validate(
self.parameters, jsonschema.Draft202012Validator.META_SCHEMA
)
return self
@dataclass
class FunctionTool(ToolSchema, Generic[TContext]):
"""A callable tool, for function calling."""
handler: Callable[..., Awaitable[Any]] | None = None
"""a callable that implements the tool's functionality. It should be an async function."""
handler_module_path: str | None = None
"""
The module path of the handler function. This is empty when the origin is mcp.
This field must be retained, as the handler will be wrapped in functools.partial during initialization,
causing the handler's __module__ to be functools
"""
active: bool = True
"""是否激活"""
origin: Literal["local", "mcp"] = "local"
"""函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务"""
# MCP 相关字段
mcp_server_name: str | None = None
"""MCP 服务名称,当 origin 为 mcp 时有效"""
mcp_client: MCPClient | None = None
"""MCP 客户端,当 origin 为 mcp 时有效"""
"""
Whether the tool is active. This field is a special field for AstrBot.
You can ignore it when integrating with other frameworks.
"""
def __repr__(self):
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})"
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description})"
def __dict__(self) -> dict[str, Any]:
"""将 FunctionTool 转换为字典格式"""
return {
"name": self.name,
"parameters": self.parameters,
"description": self.description,
"active": self.active,
"origin": self.origin,
"mcp_server_name": self.mcp_server_name,
}
async def call(
self, context: ContextWrapper[TContext], **kwargs
) -> str | mcp.types.CallToolResult:
"""Run the tool with the given arguments. The handler field has priority."""
raise NotImplementedError(
"FunctionTool.call() must be implemented by subclasses or set a handler."
)
class ToolSet:
"""A set of function tools that can be used in function calling.
@@ -225,7 +249,7 @@ class ToolSet:
tools = []
for tool in self.tools:
d = {
d: dict[str, Any] = {
"name": tool.name,
"description": tool.description,
}
+2 -1
View File
@@ -1,5 +1,6 @@
from dataclasses import dataclass
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider import Provider
from astrbot.core.provider.entities import ProviderRequest
@@ -10,4 +11,4 @@ class AstrAgentContext:
first_provider_request: ProviderRequest
curr_provider_request: ProviderRequest
streaming: bool
tool_call_timeout: int = 60 # Default tool call timeout in seconds
event: AstrMessageEvent
+36
View File
@@ -8,6 +8,7 @@ import json
from collections.abc import Awaitable, Callable
from astrbot.core import sp
from astrbot.core.agent.message import AssistantMessageSegment, UserMessageSegment
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Conversation, ConversationV2
@@ -319,6 +320,41 @@ class ConversationManager:
persona_id=persona_id,
)
async def add_message_pair(
self,
cid: str,
user_message: UserMessageSegment | dict,
assistant_message: AssistantMessageSegment | dict,
) -> None:
"""Add a user-assistant message pair to the conversation history.
Args:
cid (str): Conversation ID
user_message (UserMessageSegment | dict): OpenAI-format user message object or dict
assistant_message (AssistantMessageSegment | dict): OpenAI-format assistant message object or dict
Raises:
Exception: If the conversation with the given ID is not found
"""
conv = await self.db.get_conversation_by_id(cid=cid)
if not conv:
raise Exception(f"Conversation with id {cid} not found")
history = conv.content or []
if isinstance(user_message, UserMessageSegment):
user_msg_dict = user_message.model_dump()
else:
user_msg_dict = user_message
if isinstance(assistant_message, AssistantMessageSegment):
assistant_msg_dict = assistant_message.model_dump()
else:
assistant_msg_dict = assistant_message
history.append(user_msg_dict)
history.append(assistant_msg_dict)
await self.db.update_conversation(
cid=cid,
content=history,
)
async def get_human_readable_context(
self,
unified_msg_origin: str,
+2 -1
View File
@@ -3,7 +3,7 @@ from dataclasses import dataclass
from astrbot.core.config import AstrBotConfig
from astrbot.core.star import PluginManager
from .context_utils import call_event_hook, call_handler
from .context_utils import call_event_hook, call_handler, call_local_llm_tool
@dataclass
@@ -15,3 +15,4 @@ class PipelineContext:
astrbot_config_id: str
call_handler = call_handler
call_event_hook = call_event_hook
call_local_llm_tool = call_local_llm_tool
+65
View File
@@ -3,6 +3,8 @@ import traceback
import typing as T
from astrbot import logger
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.message.message_event_result import CommandResult, MessageEventResult
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.star import star_map
@@ -105,3 +107,66 @@ async def call_event_hook(
return True
return event.is_stopped()
async def call_local_llm_tool(
context: ContextWrapper[AstrAgentContext],
handler: T.Callable[..., T.Awaitable[T.Any]],
method_name: str,
*args,
**kwargs,
) -> T.AsyncGenerator[T.Any, None]:
"""执行本地 LLM 工具的处理函数并处理其返回结果"""
ready_to_call = None # 一个协程或者异步生成器
trace_ = None
event = context.context.event
try:
if method_name == "run" or method_name == "decorator_handler":
ready_to_call = handler(event, *args, **kwargs)
elif method_name == "call":
ready_to_call = handler(context, *args, **kwargs)
else:
raise ValueError(f"未知的方法名: {method_name}")
except ValueError as e:
logger.error(f"调用本地 LLM 工具时出错: {e}", exc_info=True)
except TypeError:
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
except Exception as e:
trace_ = traceback.format_exc()
logger.error(f"调用本地 LLM 工具时出错: {e}\n{trace_}")
if not ready_to_call:
return
if inspect.isasyncgen(ready_to_call):
_has_yielded = False
try:
async for ret in ready_to_call:
# 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
# 返回值只能是 MessageEventResult 或者 None(无返回值)
_has_yielded = True
if isinstance(ret, (MessageEventResult, CommandResult)):
# 如果返回值是 MessageEventResult, 设置结果并继续
event.set_result(ret)
yield
else:
# 如果返回值是 None, 则不设置结果并继续
# 继续执行后续阶段
yield ret
if not _has_yielded:
# 如果这个异步生成器没有执行到 yield 分支
yield
except Exception as e:
logger.error(f"Previous Error: {trace_}")
raise e
elif inspect.iscoroutine(ready_to_call):
# 如果只是一个协程, 直接执行
ret = await ready_to_call
if isinstance(ret, (MessageEventResult, CommandResult)):
event.set_result(ret)
yield
else:
yield ret
@@ -5,11 +5,14 @@ import copy
import json
import traceback
from collections.abc import AsyncGenerator
from datetime import timedelta
from typing import Any
from mcp.types import CallToolResult
from astrbot.core import logger
from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.mcp_client import MCPTool
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
from astrbot.core.agent.tool import FunctionTool, ToolSet
@@ -33,7 +36,7 @@ from astrbot.core.star.session_llm_manager import SessionServiceManager
from astrbot.core.star.star_handler import EventType, star_map
from astrbot.core.utils.metrics import Metric
from ...context import PipelineContext, call_event_hook, call_handler
from ...context import PipelineContext, call_event_hook, call_local_llm_tool
from ..stage import Stage
from ..utils import inject_kb_context
@@ -65,17 +68,15 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
yield r
return
if tool.origin == "local":
async for r in cls._execute_local(tool, run_context, **tool_args):
yield r
return
elif tool.origin == "mcp":
elif isinstance(tool, MCPTool):
async for r in cls._execute_mcp(tool, run_context, **tool_args):
yield r
return
raise Exception(f"Unknown function origin: {tool.origin}")
else:
async for r in cls._execute_local(tool, run_context, **tool_args):
yield r
return
@classmethod
async def _execute_handoff(
@@ -113,10 +114,13 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
first_provider_request=run_context.context.first_provider_request,
curr_provider_request=request,
streaming=run_context.context.streaming,
event=run_context.context.event,
)
event = run_context.context.event
logger.debug(f"正在将任务委托给 Agent: {tool.agent.name}, input: {input_}")
await run_context.event.send(
await event.send(
MessageChain().message("✨ 正在将任务委托给 Agent: " + tool.agent.name),
)
@@ -125,7 +129,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
request=request,
run_context=AgentContextWrapper(
context=astr_agent_ctx,
event=run_context.event,
tool_call_timeout=run_context.tool_call_timeout,
),
tool_executor=FunctionToolExecutor(),
agent_hooks=tool.agent.run_hooks or BaseAgentRunHooks[AstrAgentContext](),
@@ -175,25 +179,46 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
if not run_context.event:
event = run_context.context.event
if not event:
raise ValueError("Event must be provided for local function tools.")
# 检查 tool 下有没有 run 方法
if not tool.handler and not hasattr(tool, "run"):
raise ValueError("Tool must have a valid handler or 'run' method.")
awaitable = tool.handler or tool.run
is_override_call = False
for ty in type(tool).mro():
if "call" in ty.__dict__ and ty.__dict__["call"] is not FunctionTool.call:
logger.debug(f"Found call in: {ty}")
is_override_call = True
break
wrapper = call_handler(
event=run_context.event,
# 检查 tool 下有没有 run 方法
if not tool.handler and not hasattr(tool, "run") and not is_override_call:
raise ValueError("Tool must have a valid handler or override 'run' method.")
awaitable = None
method_name = ""
if tool.handler:
awaitable = tool.handler
method_name = "decorator_handler"
elif is_override_call:
awaitable = tool.call
method_name = "call"
elif hasattr(tool, "run"):
awaitable = getattr(tool, "run")
method_name = "run"
if awaitable is None:
raise ValueError("Tool must have a valid handler or override 'run' method.")
wrapper = call_local_llm_tool(
context=run_context,
handler=awaitable,
method_name=method_name,
**tool_args,
)
# async for resp in wrapper:
while True:
try:
resp = await asyncio.wait_for(
anext(wrapper),
timeout=run_context.context.tool_call_timeout,
timeout=run_context.tool_call_timeout,
)
if resp is not None:
if isinstance(resp, mcp.types.CallToolResult):
@@ -208,10 +233,24 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
# NOTE: Tool 在这里直接请求发送消息给用户
# TODO: 是否需要判断 event.get_result() 是否为空?
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
if res := run_context.context.event.get_result():
if res.chain:
try:
await event.send(
MessageChain(
chain=res.chain,
type="tool_direct_result",
)
)
except Exception as e:
logger.error(
f"Tool 直接发送消息失败: {e}",
exc_info=True,
)
yield None
except asyncio.TimeoutError:
raise Exception(
f"tool {tool.name} execution timeout after {run_context.context.tool_call_timeout} seconds.",
f"tool {tool.name} execution timeout after {run_context.tool_call_timeout} seconds.",
)
except StopAsyncIteration:
break
@@ -223,19 +262,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
if not tool.mcp_client:
raise ValueError("MCP client is not available for MCP function tools.")
session = tool.mcp_client.session
if not session:
raise ValueError("MCP session is not available for MCP function tools.")
res = await session.call_tool(
name=tool.name,
arguments=tool_args,
read_timeout_seconds=timedelta(
seconds=run_context.context.tool_call_timeout,
),
)
res = await tool.call(run_context, **tool_args)
if not res:
return
yield res
@@ -245,11 +272,20 @@ class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
async def on_agent_done(self, run_context, llm_response):
# 执行事件钩子
await call_event_hook(
run_context.event,
run_context.context.event,
EventType.OnLLMResponseEvent,
llm_response,
)
async def on_tool_end(
self,
run_context: ContextWrapper[AstrAgentContext],
tool: FunctionTool[Any],
tool_args: dict | None,
tool_result: CallToolResult | None,
):
run_context.context.event.clear_result()
MAIN_AGENT_HOOKS = MainAgentHooks()
@@ -260,7 +296,7 @@ async def run_agent(
show_tool_use: bool = True,
) -> AsyncGenerator[MessageChain, None]:
step_idx = 0
astr_event = agent_runner.run_context.event
astr_event = agent_runner.run_context.context.event
while step_idx < max_step:
step_idx += 1
try:
@@ -513,12 +549,15 @@ class LLMRequestSubStage(Stage):
first_provider_request=req,
curr_provider_request=req,
streaming=self.streaming_response,
tool_call_timeout=self.tool_call_timeout,
event=event,
)
await agent_runner.reset(
provider=provider,
request=req,
run_context=AgentContextWrapper(context=astr_agent_ctx, event=event),
run_context=AgentContextWrapper(
context=astr_agent_ctx,
tool_call_timeout=self.tool_call_timeout,
),
tool_executor=FunctionToolExecutor(),
agent_hooks=MAIN_AGENT_HOOKS,
streaming=self.streaming_response,
+36 -52
View File
@@ -4,15 +4,17 @@ import json
from dataclasses import dataclass, field
from typing import Any
from anthropic.types import Message
from anthropic.types import Message as AnthropicMessage
from google.genai.types import GenerateContentResponse
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
)
import astrbot.core.message.components as Comp
from astrbot import logger
from astrbot.core.agent.message import (
AssistantMessageSegment,
ToolCall,
ToolCallMessageSegment,
)
from astrbot.core.agent.tool import ToolSet
from astrbot.core.db.po import Conversation
from astrbot.core.message.message_event_result import MessageChain
@@ -32,9 +34,9 @@ class ProviderMetaData:
type: str
"""提供商适配器名称,如 openai, ollama"""
desc: str = ""
"""提供商适配器描述."""
"""提供商适配器描述"""
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
cls_type: type | None = None
cls_type: Any = None
default_config_tmpl: dict | None = None
"""平台的默认配置模板"""
@@ -42,44 +44,6 @@ class ProviderMetaData:
"""显示在 WebUI 配置页中的提供商名称,如空则是 type"""
@dataclass
class ToolCallMessageSegment:
"""OpenAI 格式的上下文中 role 为 tool 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
tool_call_id: str
content: str
role: str = "tool"
def to_dict(self):
return {
"tool_call_id": self.tool_call_id,
"content": self.content,
"role": self.role,
}
@dataclass
class AssistantMessageSegment:
"""OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
content: str | None = None
tool_calls: list[ChatCompletionMessageToolCall | dict] = field(default_factory=list)
role: str = "assistant"
def to_dict(self):
ret: dict[str, str | list[dict]] = {
"role": self.role,
}
if self.content:
ret["content"] = self.content
if self.tool_calls:
tool_calls_dict = [
tc if isinstance(tc, dict) else tc.to_dict() for tc in self.tool_calls
]
ret["tool_calls"] = tool_calls_dict
return ret
@dataclass
class ToolCallsResult:
"""工具调用结果"""
@@ -91,8 +55,8 @@ class ToolCallsResult:
def to_openai_messages(self) -> list[dict]:
ret = [
self.tool_calls_info.to_dict(),
*[item.to_dict() for item in self.tool_calls_result],
self.tool_calls_info.model_dump(),
*[item.model_dump() for item in self.tool_calls_result],
]
return ret
@@ -108,16 +72,16 @@ class ProviderRequest:
func_tool: ToolSet | None = None
"""可用的函数工具"""
contexts: list[dict] = field(default_factory=list)
"""上下文。格式与 openai 的上下文格式一致:
"""
OpenAI 格式上下文列表。
参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
"""
system_prompt: str = ""
"""系统提示词"""
conversation: Conversation | None = None
"""关联的对话对象"""
tool_calls_result: list[ToolCallsResult] | ToolCallsResult | None = None
"""附加的上次请求后工具调用的结果。参考: https://platform.openai.com/docs/guides/function-calling#handling-function-calls"""
model: str | None = None
"""模型名称,为 None 时使用提供商的默认模型"""
@@ -227,7 +191,9 @@ class LLMResponse:
tools_call_ids: list[str] = field(default_factory=list)
"""工具调用 ID"""
raw_completion: ChatCompletion | GenerateContentResponse | Message | None = None
raw_completion: (
ChatCompletion | GenerateContentResponse | AnthropicMessage | None
) = None
_new_record: dict[str, Any] | None = None
_completion_text: str = ""
@@ -243,7 +209,10 @@ class LLMResponse:
tools_call_args: list[dict[str, Any]] | None = None,
tools_call_name: list[str] | None = None,
tools_call_ids: list[str] | None = None,
raw_completion: ChatCompletion | None = None,
raw_completion: ChatCompletion
| GenerateContentResponse
| AnthropicMessage
| None = None,
_new_record: dict[str, Any] | None = None,
is_chunk: bool = False,
):
@@ -294,7 +263,7 @@ class LLMResponse:
self._completion_text = value
def to_openai_tool_calls(self) -> list[dict]:
"""将工具调用信息转换为 OpenAI 格式"""
"""Convert to OpenAI tool calls format. Deprecated, use to_openai_to_calls_model instead."""
ret = []
for idx, tool_call_arg in enumerate(self.tools_call_args):
ret.append(
@@ -309,6 +278,21 @@ class LLMResponse:
)
return ret
def to_openai_to_calls_model(self) -> list[ToolCall]:
"""The same as to_openai_tool_calls but return pydantic model."""
ret = []
for idx, tool_call_arg in enumerate(self.tools_call_args):
ret.append(
ToolCall(
id=self.tools_call_ids[idx],
function=ToolCall.FunctionBody(
name=self.tools_call_name[idx],
arguments=json.dumps(tool_call_arg),
),
),
)
return ret
@dataclass
class RerankResult:
+10 -11
View File
@@ -10,7 +10,7 @@ import aiohttp
from astrbot import logger
from astrbot.core import sp
from astrbot.core.agent.mcp_client import MCPClient
from astrbot.core.agent.mcp_client import MCPClient, MCPTool
from astrbot.core.agent.tool import FunctionTool, ToolSet
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
@@ -254,18 +254,15 @@ class FunctionToolManager:
self.func_list = [
f
for f in self.func_list
if not (f.origin == "mcp" and f.mcp_server_name == name)
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
]
# 将 MCP 工具转换为 FuncTool 并添加到 func_list
for tool in mcp_client.tools:
func_tool = FuncTool(
name=tool.name,
parameters=tool.inputSchema,
description=tool.description,
origin="mcp",
mcp_server_name=name,
func_tool = MCPTool(
mcp_tool=tool,
mcp_client=mcp_client,
mcp_server_name=name,
)
self.func_list.append(func_tool)
@@ -284,7 +281,7 @@ class FunctionToolManager:
self.func_list = [
f
for f in self.func_list
if not (f.origin == "mcp" and f.mcp_server_name == name)
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
]
logger.info(f"已关闭 MCP 服务 {name}")
@@ -374,7 +371,7 @@ class FunctionToolManager:
self.func_list = [
f
for f in self.func_list
if f.origin != "mcp" or f.mcp_server_name != name
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
]
else:
running_events = [
@@ -388,7 +385,9 @@ class FunctionToolManager:
finally:
self.mcp_client_event.clear()
self.mcp_client_dict.clear()
self.func_list = [f for f in self.func_list if f.origin != "mcp"]
self.func_list = [
f for f in self.func_list if not isinstance(f, MCPTool)
]
def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list:
"""获得 OpenAI API 风格的**已经激活**的工具描述"""
+37 -13
View File
@@ -3,6 +3,7 @@ import asyncio
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from astrbot.core.agent.message import Message
from astrbot.core.agent.tool import ToolSet
from astrbot.core.db.po import Personality
from astrbot.core.provider.entities import (
@@ -23,24 +24,28 @@ class ProviderMeta:
class AbstractProvider(abc.ABC):
"""Provider Abstract Class"""
def __init__(self, provider_config: dict) -> None:
super().__init__()
self.model_name = ""
self.provider_config = provider_config
def set_model(self, model_name: str):
"""设置当前使用的模型名称"""
"""Set the current model name"""
self.model_name = model_name
def get_model(self) -> str:
"""获得当前使用的模型名称"""
"""Get the current model name"""
return self.model_name
def meta(self) -> ProviderMeta:
"""获取 Provider 的元数据"""
"""Get the provider metadata"""
provider_type_name = self.provider_config["type"]
meta_data = provider_cls_map.get(provider_type_name)
provider_type = meta_data.provider_type if meta_data else None
if provider_type is None:
raise ValueError(f"Cannot find provider type: {provider_type_name}")
return ProviderMeta(
id=self.provider_config["id"],
model=self.get_model(),
@@ -50,6 +55,8 @@ class AbstractProvider(abc.ABC):
class Provider(AbstractProvider):
"""Chat Provider"""
def __init__(
self,
provider_config: dict,
@@ -84,11 +91,11 @@ class Provider(AbstractProvider):
@abc.abstractmethod
async def text_chat(
self,
prompt: str,
prompt: str | None = None,
session_id: str | None = None,
image_urls: list[str] | None = None,
func_tool: ToolSet | None = None,
contexts: list | None = None,
contexts: list[Message] | list[dict] | None = None,
system_prompt: str | None = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
model: str | None = None,
@@ -97,11 +104,11 @@ class Provider(AbstractProvider):
"""获得 LLM 的文本对话结果。会使用当前的模型进行对话。
Args:
prompt: 提示词
prompt: 提示词,和 contexts 二选一使用,如果都指定,则会将 prompt(以及可能的 image_urls) 作为最新的一条记录添加到 contexts 中
session_id: 会话 ID(此属性已经被废弃)
image_urls: 图片 URL 列表
tools: Function-calling 工具
contexts: 上下文
tools: tool set
contexts: 上下文,和 prompt 二选一使用
tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
kwargs: 其他参数
@@ -114,11 +121,11 @@ class Provider(AbstractProvider):
async def text_chat_stream(
self,
prompt: str,
prompt: str | None = None,
session_id: str | None = None,
image_urls: list[str] | None = None,
func_tool: ToolSet | None = None,
contexts: list | None = None,
contexts: list[Message] | list[dict] | None = None,
system_prompt: str | None = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
model: str | None = None,
@@ -127,11 +134,11 @@ class Provider(AbstractProvider):
"""获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。
Args:
prompt: 提示词
prompt: 提示词,和 contexts 二选一使用,如果都指定,则会将 prompt(以及可能的 image_urls) 作为最新的一条记录添加到 contexts 中
session_id: 会话 ID(此属性已经被废弃)
image_urls: 图片 URL 列表
tools: Function-calling 工具
contexts: 上下文
tools: tool set
contexts: 上下文,和 prompt 二选一使用
tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
kwargs: 其他参数
@@ -140,6 +147,7 @@ class Provider(AbstractProvider):
- 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
"""
...
async def pop_record(self, context: list):
"""弹出 context 第一条非系统提示词对话记录"""
@@ -156,6 +164,22 @@ class Provider(AbstractProvider):
for idx in reversed(indexs_to_pop):
context.pop(idx)
def _ensure_message_to_dicts(
self,
messages: list[dict] | list[Message] | None,
) -> list[dict]:
"""Convert a list of Message objects to a list of dictionaries."""
if not messages:
return []
dicts: list[dict] = []
for message in messages:
if isinstance(message, Message):
dicts.append(message.model_dump())
else:
dicts.append(message)
return dicts
class STTProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
@@ -243,7 +243,7 @@ class ProviderAnthropic(Provider):
async def text_chat(
self,
prompt,
prompt=None,
session_id=None,
image_urls=None,
func_tool=None,
@@ -255,8 +255,13 @@ class ProviderAnthropic(Provider):
) -> LLMResponse:
if contexts is None:
contexts = []
new_record = await self.assemble_context(prompt, image_urls)
context_query = [*contexts, new_record]
new_record = None
if prompt is not None:
new_record = await self.assemble_context(prompt, image_urls)
context_query = self._ensure_message_to_dicts(contexts)
if new_record:
context_query.append(new_record)
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
@@ -306,8 +311,12 @@ class ProviderAnthropic(Provider):
):
if contexts is None:
contexts = []
new_record = await self.assemble_context(prompt, image_urls)
context_query = [*contexts, new_record]
new_record = None
if prompt is not None:
new_record = await self.assemble_context(prompt, image_urls)
context_query = self._ensure_message_to_dicts(contexts)
if new_record:
context_query.append(new_record)
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
@@ -331,6 +331,7 @@ class ProviderCoze(Provider):
},
)
contexts = self._ensure_message_to_dicts(contexts)
if not self.auto_save_history and contexts:
# 如果关闭了自动保存历史,传入上下文
for ctx in contexts:
+14 -7
View File
@@ -572,7 +572,7 @@ class ProviderGoogleGenAI(Provider):
async def text_chat(
self,
prompt: str,
prompt=None,
session_id=None,
image_urls=None,
func_tool=None,
@@ -584,8 +584,12 @@ class ProviderGoogleGenAI(Provider):
) -> LLMResponse:
if contexts is None:
contexts = []
new_record = await self.assemble_context(prompt, image_urls)
context_query = [*contexts, new_record]
new_record = None
if prompt is not None:
new_record = await self.assemble_context(prompt, image_urls)
context_query = self._ensure_message_to_dicts(contexts)
if new_record:
context_query.append(new_record)
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
@@ -621,7 +625,7 @@ class ProviderGoogleGenAI(Provider):
async def text_chat_stream(
self,
prompt,
prompt=None,
session_id=None,
image_urls=None,
func_tool=None,
@@ -633,8 +637,12 @@ class ProviderGoogleGenAI(Provider):
) -> AsyncGenerator[LLMResponse, None]:
if contexts is None:
contexts = []
new_record = await self.assemble_context(prompt, image_urls)
context_query = [*contexts, new_record]
new_record = None
if prompt is not None:
new_record = await self.assemble_context(prompt, image_urls)
context_query = self._ensure_message_to_dicts(contexts)
if new_record:
context_query.append(new_record)
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
@@ -726,7 +734,6 @@ class ProviderGoogleGenAI(Provider):
with open(image_url, "rb") as f:
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
return "data:image/jpeg;base64," + image_bs64
return ""
async def terminate(self):
logger.info("Google GenAI 适配器已终止。")
+19 -12
View File
@@ -14,9 +14,10 @@ from openai.types.chat.chat_completion import ChatCompletion
import astrbot.core.message.components as Comp
from astrbot import logger
from astrbot.api.provider import Provider
from astrbot.core.agent.message import Message
from astrbot.core.agent.tool import ToolSet
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
from astrbot.core.provider.func_tool_manager import ToolSet
from astrbot.core.utils.io import download_image_by_url
from ..register import register_provider_adapter
@@ -102,7 +103,7 @@ class ProviderOpenAIOfficial(Provider):
except NotFoundError as e:
raise Exception(f"获取模型列表失败:{e}")
async def _query(self, payloads: dict, tools: ToolSet) -> LLMResponse:
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
if tools:
model = payloads.get("model", "").lower()
omit_empty_param_field = "gemini" in model
@@ -153,7 +154,7 @@ class ProviderOpenAIOfficial(Provider):
async def _query_stream(
self,
payloads: dict,
tools: ToolSet,
tools: ToolSet | None,
) -> AsyncGenerator[LLMResponse, None]:
"""流式查询API,逐步返回结果"""
if tools:
@@ -212,7 +213,9 @@ class ProviderOpenAIOfficial(Provider):
yield llm_response
async def parse_openai_completion(self, completion: ChatCompletion, tools: ToolSet):
async def parse_openai_completion(
self, completion: ChatCompletion, tools: ToolSet | None
) -> LLMResponse:
"""解析 OpenAI 的 ChatCompletion 响应"""
llm_response = LLMResponse("assistant")
@@ -225,7 +228,7 @@ class ProviderOpenAIOfficial(Provider):
completion_text = str(choice.message.content).strip()
llm_response.result_chain = MessageChain().message(completion_text)
if choice.message.tool_calls:
if choice.message.tool_calls and tools is not None:
# tools call (function calling)
args_ls = []
func_name_ls = []
@@ -267,9 +270,9 @@ class ProviderOpenAIOfficial(Provider):
async def _prepare_chat_payload(
self,
prompt: str,
prompt: str | None,
image_urls: list[str] | None = None,
contexts: list | None = None,
contexts: list[dict] | list[Message] | None = None,
system_prompt: str | None = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
model: str | None = None,
@@ -278,8 +281,12 @@ class ProviderOpenAIOfficial(Provider):
"""准备聊天所需的有效载荷和上下文"""
if contexts is None:
contexts = []
new_record = await self.assemble_context(prompt, image_urls)
context_query = [*contexts, new_record]
new_record = None
if prompt is not None:
new_record = await self.assemble_context(prompt, image_urls)
context_query = self._ensure_message_to_dicts(contexts)
if new_record:
context_query.append(new_record)
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
@@ -310,7 +317,7 @@ class ProviderOpenAIOfficial(Provider):
e: Exception,
payloads: dict,
context_query: list,
func_tool: ToolSet,
func_tool: ToolSet | None,
chosen_key: str,
available_api_keys: list[str],
retry_cnt: int,
@@ -390,7 +397,7 @@ class ProviderOpenAIOfficial(Provider):
async def text_chat(
self,
prompt,
prompt=None,
session_id=None,
image_urls=None,
func_tool=None,
@@ -459,7 +466,7 @@ class ProviderOpenAIOfficial(Provider):
async def text_chat_stream(
self,
prompt: str,
prompt=None,
session_id=None,
image_urls=None,
func_tool=None,
+41 -16
View File
@@ -1,3 +1,4 @@
import logging
from asyncio import Queue
from collections.abc import Awaitable, Callable
from typing import Any
@@ -35,6 +36,8 @@ from .filter.regex import RegexFilter
from .star import StarMetadata, star_map, star_registry
from .star_handler import EventType, StarHandlerMetadata, star_handlers_registry
logger = logging.getLogger("astrbot")
class Context:
"""暴露给插件的接口上下文。"""
@@ -255,9 +258,44 @@ class Context:
def add_llm_tools(self, *tools: FunctionTool) -> None:
"""添加 LLM 工具。"""
tool_name = {tool.name for tool in self.provider_manager.llm_tools.func_list}
module_path = ""
for tool in tools:
if not module_path:
_parts = []
module_part = tool.__module__.split(".")
flags = ["packages", "plugins"]
for i, part in enumerate(module_part):
_parts.append(part)
if part in flags and i + 1 < len(module_part):
_parts.append(module_part[i + 1])
break
tool.handler_module_path = ".".join(_parts)
module_path = tool.handler_module_path
else:
tool.handler_module_path = module_path
logger.info(
f"plugin(module_path {module_path}) added LLM tool: {tool.name}"
)
if tool.name in tool_name:
logger.warning("替换已存在的 LLM 工具: " + tool.name)
self.provider_manager.llm_tools.remove_func(tool.name)
self.provider_manager.llm_tools.func_list.append(tool)
def register_web_api(
self,
route: str,
view_handler: Awaitable,
methods: list,
desc: str,
):
for idx, api in enumerate(self.registered_web_apis):
if api[0] == route and methods == api[2]:
self.registered_web_apis[idx] = (route, view_handler, methods, desc)
return
self.registered_web_apis.append((route, view_handler, methods, desc))
"""
以下的方法已经不推荐使用请从 AstrBot 文档查看更好的注册方式
"""
@@ -269,7 +307,7 @@ class Context:
desc: str,
func_obj: Callable[..., Awaitable[Any]],
) -> None:
"""为函数调用(function-calling / tools-use)添加工具。
"""[DEPRECATED]为函数调用(function-calling / tools-use)添加工具。
@param name: 函数名
@param func_args: 函数参数列表格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
@@ -291,7 +329,7 @@ class Context:
self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj)
def unregister_llm_tool(self, name: str) -> None:
"""删除一个函数调用工具。如果再要启用,需要重新注册。"""
"""[DEPRECATED]删除一个函数调用工具。如果再要启用,需要重新注册。"""
self.provider_manager.llm_tools.remove_func(name)
def register_commands(
@@ -333,18 +371,5 @@ class Context:
star_handlers_registry.append(md)
def register_task(self, task: Awaitable, desc: str):
"""注册一个异步任务。"""
"""[DEPRECATED]注册一个异步任务。"""
self._register_tasks.append(task)
def register_web_api(
self,
route: str,
view_handler: Awaitable,
methods: list,
desc: str,
):
for idx, api in enumerate(self.registered_web_apis):
if api[0] == route and methods == api[2]:
self.registered_web_apis[idx] = (route, view_handler, methods, desc)
return
self.registered_web_apis.append((route, view_handler, methods, desc))
+25 -4
View File
@@ -751,6 +751,19 @@ class PluginManager:
]:
del star_handlers_registry.star_handlers_map[k]
# llm_tools 中移除该插件的工具函数绑定
to_remove = []
for func_tool in llm_tools.func_list:
mp = func_tool.handler_module_path
if (
mp
and mp.startswith(plugin_module_path)
and not mp.endswith(("packages", "data.plugins"))
):
to_remove.append(func_tool)
for func_tool in to_remove:
llm_tools.func_list.remove(func_tool)
if plugin is None:
return
@@ -795,7 +808,13 @@ class PluginManager:
# 禁用插件启用的 llm_tool
for func_tool in llm_tools.func_list:
if func_tool.handler_module_path == plugin.module_path:
mp = func_tool.handler_module_path
if (
plugin.module_path
and mp
and plugin.module_path.startswith(mp)
and not mp.endswith(("packages", "data.plugins"))
):
func_tool.active = False
if func_tool.name not in inactivated_llm_tools:
inactivated_llm_tools.append(func_tool.name)
@@ -838,8 +857,12 @@ class PluginManager:
# 启用插件启用的 llm_tool
for func_tool in llm_tools.func_list:
mp = func_tool.handler_module_path
if (
func_tool.handler_module_path == plugin.module_path
plugin.module_path
and mp
and plugin.module_path.startswith(mp)
and not mp.endswith(("packages", "data.plugins"))
and func_tool.name in inactivated_llm_tools
):
inactivated_llm_tools.remove(func_tool.name)
@@ -848,8 +871,6 @@ class PluginManager:
await self.reload(plugin_name)
# plugin.activated = True
async def install_plugin_from_file(self, zip_file_path: str):
dir_name = os.path.basename(zip_file_path).replace(".zip", "")
dir_name = dir_name.removesuffix("-master").removesuffix("-main").lower()