diff --git a/astrbot/core/agent/hooks.py b/astrbot/core/agent/hooks.py index 949ebd3fe..d834240b7 100644 --- a/astrbot/core/agent/hooks.py +++ b/astrbot/core/agent/hooks.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from typing import Generic import mcp @@ -9,7 +8,6 @@ from astrbot.core.provider.entities import LLMResponse from .run_context import ContextWrapper, TContext -@dataclass class BaseAgentRunHooks(Generic[TContext]): async def on_agent_begin(self, run_context: ContextWrapper[TContext]): ... async def on_tool_start( diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index 303973a0d..05980b212 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -2,10 +2,15 @@ import asyncio import logging from contextlib import AsyncExitStack from datetime import timedelta +from typing import Generic from astrbot import logger +from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.utils.log_pipe import LogPipe +from .run_context import TContext +from .tool import FunctionTool + try: import mcp from mcp.client.sse import sse_client @@ -221,3 +226,34 @@ class MCPClient: """Clean up resources""" await self.exit_stack.aclose() self.running_event.set() # Set the running event to indicate cleanup is done + + +class MCPTool(FunctionTool, Generic[TContext]): + """A function tool that calls an MCP service.""" + + def __init__( + self, mcp_tool: mcp.Tool, mcp_client: MCPClient, mcp_server_name: str, **kwargs + ): + super().__init__( + name=mcp_tool.name, + description=mcp_tool.description or "", + parameters=mcp_tool.inputSchema, + ) + self.mcp_tool = mcp_tool + self.mcp_client = mcp_client + self.mcp_server_name = mcp_server_name + + async def call( + self, context: ContextWrapper[TContext], **kwargs + ) -> mcp.types.CallToolResult: + session = self.mcp_client.session + if not session: + raise ValueError("MCP session is not available for MCP function tools.") + res = await session.call_tool( + name=self.mcp_tool.name, + arguments=kwargs, + read_timeout_seconds=timedelta( + seconds=context.tool_call_timeout, + ), + ) + return res diff --git a/astrbot/core/agent/message.py b/astrbot/core/agent/message.py new file mode 100644 index 000000000..11128c0f6 --- /dev/null +++ b/astrbot/core/agent/message.py @@ -0,0 +1,168 @@ +# Inspired by MoonshotAI/kosong, credits to MoonshotAI/kosong authors for the original implementation. +# License: Apache License 2.0 + +from typing import Any, ClassVar, Literal, cast + +from pydantic import BaseModel, GetCoreSchemaHandler +from pydantic_core import core_schema + + +class ContentPart(BaseModel): + """A part of the content in a message.""" + + __content_part_registry: ClassVar[dict[str, type["ContentPart"]]] = {} + + type: str + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + + invalid_subclass_error_msg = f"ContentPart subclass {cls.__name__} must have a `type` field of type `str`" + + type_value = getattr(cls, "type", None) + if type_value is None or not isinstance(type_value, str): + raise ValueError(invalid_subclass_error_msg) + + cls.__content_part_registry[type_value] = cls + + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + # If we're dealing with the base ContentPart class, use custom validation + if cls.__name__ == "ContentPart": + + def validate_content_part(value: Any) -> Any: + # if it's already an instance of a ContentPart subclass, return it + if hasattr(value, "__class__") and issubclass(value.__class__, cls): + return value + + # if it's a dict with a type field, dispatch to the appropriate subclass + if isinstance(value, dict) and "type" in value: + type_value: Any | None = cast(dict[str, Any], value).get("type") + if not isinstance(type_value, str): + raise ValueError(f"Cannot validate {value} as ContentPart") + target_class = cls.__content_part_registry[type_value] + return target_class.model_validate(value) + + raise ValueError(f"Cannot validate {value} as ContentPart") + + return core_schema.no_info_plain_validator_function(validate_content_part) + + # for subclasses, use the default schema + return handler(source_type) + + +class TextPart(ContentPart): + """ + >>> TextPart(text="Hello, world!").model_dump() + {'type': 'text', 'text': 'Hello, world!'} + """ + + type: str = "text" + text: str + + +class ImageURLPart(ContentPart): + """ + >>> ImageURLPart(image_url="http://example.com/image.jpg").model_dump() + {'type': 'image_url', 'image_url': 'http://example.com/image.jpg'} + """ + + class ImageURL(BaseModel): + url: str + """The URL of the image, can be data URI scheme like `data:image/png;base64,...`.""" + id: str | None = None + """The ID of the image, to allow LLMs to distinguish different images.""" + + type: str = "image_url" + image_url: str + + +class AudioURLPart(ContentPart): + """ + >>> AudioURLPart(audio_url=AudioURLPart.AudioURL(url="https://example.com/audio.mp3")).model_dump() + {'type': 'audio_url', 'audio_url': {'url': 'https://example.com/audio.mp3', 'id': None}} + """ + + class AudioURL(BaseModel): + url: str + """The URL of the audio, can be data URI scheme like `data:audio/aac;base64,...`.""" + id: str | None = None + """The ID of the audio, to allow LLMs to distinguish different audios.""" + + type: str = "audio_url" + audio_url: AudioURL + + +class ToolCall(BaseModel): + """ + A tool call requested by the assistant. + + >>> ToolCall( + ... id="123", + ... function=ToolCall.FunctionBody( + ... name="function", + ... arguments="{}" + ... ), + ... ).model_dump() + {'type': 'function', 'id': '123', 'function': {'name': 'function', 'arguments': '{}'}} + """ + + class FunctionBody(BaseModel): + name: str + arguments: str | None + + type: Literal["function"] = "function" + + id: str + """The ID of the tool call.""" + function: FunctionBody + """The function body of the tool call.""" + + +class ToolCallPart(BaseModel): + """A part of the tool call.""" + + arguments_part: str | None = None + """A part of the arguments of the tool call.""" + + +class Message(BaseModel): + """A message in a conversation.""" + + role: Literal[ + "system", + "user", + "assistant", + "tool", + ] + + content: str | list[ContentPart] + """The content of the message.""" + + +class AssistantMessageSegment(Message): + """A message segment from the assistant.""" + + role: Literal["assistant"] = "assistant" + tool_calls: list[ToolCall] | list[dict] | None = None + + +class ToolCallMessageSegment(Message): + """A message segment representing a tool call.""" + + role: Literal["tool"] = "tool" + tool_call_id: str + + +class UserMessageSegment(Message): + """A message segment from the user.""" + + role: Literal["user"] = "user" + + +class SystemMessageSegment(Message): + """A message segment from the system.""" + + role: Literal["system"] = "system" diff --git a/astrbot/core/agent/run_context.py b/astrbot/core/agent/run_context.py index 634735ccc..395817679 100644 --- a/astrbot/core/agent/run_context.py +++ b/astrbot/core/agent/run_context.py @@ -3,8 +3,6 @@ from typing import Any, Generic from typing_extensions import TypeVar -from astrbot.core.platform.astr_message_event import AstrMessageEvent - TContext = TypeVar("TContext", default=Any) @@ -13,7 +11,7 @@ class ContextWrapper(Generic[TContext]): """A context for running an agent, which can be used to pass additional data or state.""" context: TContext - event: AstrMessageEvent + tool_call_timeout: int = 60 # Default tool call timeout in seconds NoContext = ContextWrapper[None] diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index cb89fb612..23071d446 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -16,15 +16,14 @@ from astrbot.core.message.message_event_result import ( MessageChain, ) from astrbot.core.provider.entities import ( - AssistantMessageSegment, LLMResponse, ProviderRequest, - ToolCallMessageSegment, ToolCallsResult, ) from astrbot.core.provider.provider import Provider from ..hooks import BaseAgentRunHooks +from ..message import AssistantMessageSegment, ToolCallMessageSegment from ..response import AgentResponseData from ..run_context import ContextWrapper, TContext from ..tool_executor import BaseFunctionToolExecutor @@ -171,8 +170,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): # 将结果添加到上下文中 tool_calls_result = ToolCallsResult( tool_calls_info=AssistantMessageSegment( - role="assistant", - tool_calls=llm_resp.to_openai_tool_calls(), + tool_calls=llm_resp.to_openai_to_calls_model(), content=llm_resp.completion_text, ), tool_calls_result=tool_call_result_blocks, @@ -238,7 +236,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): else: # 如果没有 handler(如 MCP 工具),使用所有参数 valid_params = func_tool_args - logger.warning(f"工具 {func_tool_name} 没有 handler,使用所有参数") try: await self.agent_hooks.on_tool_start( @@ -319,13 +316,11 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): elif resp is None: # Tool 直接请求发送消息给用户 # 这里我们将直接结束 Agent Loop。 + # 发送消息逻辑在 ToolExecutor 中处理了。 + logger.warning( + f"{func_tool_name} 没有没有返回值或者将结果直接发送给用户,此工具调用不会被记录到历史中。" + ) self._transition_state(AgentState.DONE) - if res := self.run_context.event.get_result(): - if res.chain: - yield MessageChain( - chain=res.chain, - type="tool_direct_result", - ) else: # 不应该出现其他类型 logger.warning( @@ -341,8 +336,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): ) except Exception as e: logger.error(f"Error in on_tool_end hook: {e}", exc_info=True) - - self.run_context.event.clear_result() except Exception as e: logger.warning(traceback.format_exc()) tool_call_result_blocks.append( diff --git a/astrbot/core/agent/tool.py b/astrbot/core/agent/tool.py index 3c36def63..e9738dc0f 100644 --- a/astrbot/core/agent/tool.py +++ b/astrbot/core/agent/tool.py @@ -1,52 +1,76 @@ from collections.abc import Awaitable, Callable -from dataclasses import dataclass -from typing import Any, Literal +from typing import Any, Generic +import jsonschema +import mcp from deprecated import deprecated +from pydantic import model_validator +from pydantic.dataclasses import dataclass -from .mcp_client import MCPClient +from .run_context import ContextWrapper, TContext + +ParametersType = dict[str, Any] @dataclass -class FunctionTool: - """A class representing a function tool that can be used in function calling.""" +class ToolSchema: + """A class representing the schema of a tool for function calling.""" name: str - parameters: dict | None = None - description: str | None = None - handler: Callable[..., Awaitable[Any]] | None = None - """处理函数, 当 origin 为 mcp 时,这个为空""" - handler_module_path: str | None = None - """处理函数的模块路径,当 origin 为 mcp 时,这个为空 + """The name of the tool.""" - 必须要保留这个字段, handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools + description: str + """The description of the tool.""" + + parameters: ParametersType + """The parameters of the tool, in JSON Schema format.""" + + @model_validator(mode="after") + def validate_parameters(self) -> "ToolSchema": + jsonschema.validate( + self.parameters, jsonschema.Draft202012Validator.META_SCHEMA + ) + return self + + +@dataclass +class FunctionTool(ToolSchema, Generic[TContext]): + """A callable tool, for function calling.""" + + handler: Callable[..., Awaitable[Any]] | None = None + """a callable that implements the tool's functionality. It should be an async function.""" + + handler_module_path: str | None = None + """ + The module path of the handler function. This is empty when the origin is mcp. + This field must be retained, as the handler will be wrapped in functools.partial during initialization, + causing the handler's __module__ to be functools """ active: bool = True - """是否激活""" - - origin: Literal["local", "mcp"] = "local" - """函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务""" - - # MCP 相关字段 - mcp_server_name: str | None = None - """MCP 服务名称,当 origin 为 mcp 时有效""" - mcp_client: MCPClient | None = None - """MCP 客户端,当 origin 为 mcp 时有效""" + """ + Whether the tool is active. This field is a special field for AstrBot. + You can ignore it when integrating with other frameworks. + """ def __repr__(self): - return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})" + return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description})" def __dict__(self) -> dict[str, Any]: - """将 FunctionTool 转换为字典格式""" return { "name": self.name, "parameters": self.parameters, "description": self.description, "active": self.active, - "origin": self.origin, - "mcp_server_name": self.mcp_server_name, } + async def call( + self, context: ContextWrapper[TContext], **kwargs + ) -> str | mcp.types.CallToolResult: + """Run the tool with the given arguments. The handler field has priority.""" + raise NotImplementedError( + "FunctionTool.call() must be implemented by subclasses or set a handler." + ) + class ToolSet: """A set of function tools that can be used in function calling. @@ -225,7 +249,7 @@ class ToolSet: tools = [] for tool in self.tools: - d = { + d: dict[str, Any] = { "name": tool.name, "description": tool.description, } diff --git a/astrbot/core/astr_agent_context.py b/astrbot/core/astr_agent_context.py index e21ddb9c6..28b242253 100644 --- a/astrbot/core/astr_agent_context.py +++ b/astrbot/core/astr_agent_context.py @@ -1,5 +1,6 @@ from dataclasses import dataclass +from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.provider import Provider from astrbot.core.provider.entities import ProviderRequest @@ -10,4 +11,4 @@ class AstrAgentContext: first_provider_request: ProviderRequest curr_provider_request: ProviderRequest streaming: bool - tool_call_timeout: int = 60 # Default tool call timeout in seconds + event: AstrMessageEvent diff --git a/astrbot/core/conversation_mgr.py b/astrbot/core/conversation_mgr.py index 2be406100..287fe03c4 100644 --- a/astrbot/core/conversation_mgr.py +++ b/astrbot/core/conversation_mgr.py @@ -8,6 +8,7 @@ import json from collections.abc import Awaitable, Callable from astrbot.core import sp +from astrbot.core.agent.message import AssistantMessageSegment, UserMessageSegment from astrbot.core.db import BaseDatabase from astrbot.core.db.po import Conversation, ConversationV2 @@ -319,6 +320,41 @@ class ConversationManager: persona_id=persona_id, ) + async def add_message_pair( + self, + cid: str, + user_message: UserMessageSegment | dict, + assistant_message: AssistantMessageSegment | dict, + ) -> None: + """Add a user-assistant message pair to the conversation history. + + Args: + cid (str): Conversation ID + user_message (UserMessageSegment | dict): OpenAI-format user message object or dict + assistant_message (AssistantMessageSegment | dict): OpenAI-format assistant message object or dict + + Raises: + Exception: If the conversation with the given ID is not found + """ + conv = await self.db.get_conversation_by_id(cid=cid) + if not conv: + raise Exception(f"Conversation with id {cid} not found") + history = conv.content or [] + if isinstance(user_message, UserMessageSegment): + user_msg_dict = user_message.model_dump() + else: + user_msg_dict = user_message + if isinstance(assistant_message, AssistantMessageSegment): + assistant_msg_dict = assistant_message.model_dump() + else: + assistant_msg_dict = assistant_message + history.append(user_msg_dict) + history.append(assistant_msg_dict) + await self.db.update_conversation( + cid=cid, + content=history, + ) + async def get_human_readable_context( self, unified_msg_origin: str, diff --git a/astrbot/core/pipeline/context.py b/astrbot/core/pipeline/context.py index a6cd567e0..44186764e 100644 --- a/astrbot/core/pipeline/context.py +++ b/astrbot/core/pipeline/context.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from astrbot.core.config import AstrBotConfig from astrbot.core.star import PluginManager -from .context_utils import call_event_hook, call_handler +from .context_utils import call_event_hook, call_handler, call_local_llm_tool @dataclass @@ -15,3 +15,4 @@ class PipelineContext: astrbot_config_id: str call_handler = call_handler call_event_hook = call_event_hook + call_local_llm_tool = call_local_llm_tool diff --git a/astrbot/core/pipeline/context_utils.py b/astrbot/core/pipeline/context_utils.py index 73d28c5d1..371816b6e 100644 --- a/astrbot/core/pipeline/context_utils.py +++ b/astrbot/core/pipeline/context_utils.py @@ -3,6 +3,8 @@ import traceback import typing as T from astrbot import logger +from astrbot.core.agent.run_context import ContextWrapper +from astrbot.core.astr_agent_context import AstrAgentContext from astrbot.core.message.message_event_result import CommandResult, MessageEventResult from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.star.star import star_map @@ -105,3 +107,66 @@ async def call_event_hook( return True return event.is_stopped() + + +async def call_local_llm_tool( + context: ContextWrapper[AstrAgentContext], + handler: T.Callable[..., T.Awaitable[T.Any]], + method_name: str, + *args, + **kwargs, +) -> T.AsyncGenerator[T.Any, None]: + """执行本地 LLM 工具的处理函数并处理其返回结果""" + ready_to_call = None # 一个协程或者异步生成器 + + trace_ = None + + event = context.context.event + + try: + if method_name == "run" or method_name == "decorator_handler": + ready_to_call = handler(event, *args, **kwargs) + elif method_name == "call": + ready_to_call = handler(context, *args, **kwargs) + else: + raise ValueError(f"未知的方法名: {method_name}") + except ValueError as e: + logger.error(f"调用本地 LLM 工具时出错: {e}", exc_info=True) + except TypeError: + logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True) + except Exception as e: + trace_ = traceback.format_exc() + logger.error(f"调用本地 LLM 工具时出错: {e}\n{trace_}") + + if not ready_to_call: + return + + if inspect.isasyncgen(ready_to_call): + _has_yielded = False + try: + async for ret in ready_to_call: + # 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码 + # 返回值只能是 MessageEventResult 或者 None(无返回值) + _has_yielded = True + if isinstance(ret, (MessageEventResult, CommandResult)): + # 如果返回值是 MessageEventResult, 设置结果并继续 + event.set_result(ret) + yield + else: + # 如果返回值是 None, 则不设置结果并继续 + # 继续执行后续阶段 + yield ret + if not _has_yielded: + # 如果这个异步生成器没有执行到 yield 分支 + yield + except Exception as e: + logger.error(f"Previous Error: {trace_}") + raise e + elif inspect.iscoroutine(ready_to_call): + # 如果只是一个协程, 直接执行 + ret = await ready_to_call + if isinstance(ret, (MessageEventResult, CommandResult)): + event.set_result(ret) + yield + else: + yield ret diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index d1cffc43f..03352cc40 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -5,11 +5,14 @@ import copy import json import traceback from collections.abc import AsyncGenerator -from datetime import timedelta +from typing import Any + +from mcp.types import CallToolResult from astrbot.core import logger from astrbot.core.agent.handoff import HandoffTool from astrbot.core.agent.hooks import BaseAgentRunHooks +from astrbot.core.agent.mcp_client import MCPTool from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner from astrbot.core.agent.tool import FunctionTool, ToolSet @@ -33,7 +36,7 @@ from astrbot.core.star.session_llm_manager import SessionServiceManager from astrbot.core.star.star_handler import EventType, star_map from astrbot.core.utils.metrics import Metric -from ...context import PipelineContext, call_event_hook, call_handler +from ...context import PipelineContext, call_event_hook, call_local_llm_tool from ..stage import Stage from ..utils import inject_kb_context @@ -65,17 +68,15 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): yield r return - if tool.origin == "local": - async for r in cls._execute_local(tool, run_context, **tool_args): - yield r - return - - elif tool.origin == "mcp": + elif isinstance(tool, MCPTool): async for r in cls._execute_mcp(tool, run_context, **tool_args): yield r return - raise Exception(f"Unknown function origin: {tool.origin}") + else: + async for r in cls._execute_local(tool, run_context, **tool_args): + yield r + return @classmethod async def _execute_handoff( @@ -113,10 +114,13 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): first_provider_request=run_context.context.first_provider_request, curr_provider_request=request, streaming=run_context.context.streaming, + event=run_context.context.event, ) + event = run_context.context.event + logger.debug(f"正在将任务委托给 Agent: {tool.agent.name}, input: {input_}") - await run_context.event.send( + await event.send( MessageChain().message("✨ 正在将任务委托给 Agent: " + tool.agent.name), ) @@ -125,7 +129,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): request=request, run_context=AgentContextWrapper( context=astr_agent_ctx, - event=run_context.event, + tool_call_timeout=run_context.tool_call_timeout, ), tool_executor=FunctionToolExecutor(), agent_hooks=tool.agent.run_hooks or BaseAgentRunHooks[AstrAgentContext](), @@ -175,25 +179,46 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): run_context: ContextWrapper[AstrAgentContext], **tool_args, ): - if not run_context.event: + event = run_context.context.event + if not event: raise ValueError("Event must be provided for local function tools.") - # 检查 tool 下有没有 run 方法 - if not tool.handler and not hasattr(tool, "run"): - raise ValueError("Tool must have a valid handler or 'run' method.") - awaitable = tool.handler or tool.run + is_override_call = False + for ty in type(tool).mro(): + if "call" in ty.__dict__ and ty.__dict__["call"] is not FunctionTool.call: + logger.debug(f"Found call in: {ty}") + is_override_call = True + break - wrapper = call_handler( - event=run_context.event, + # 检查 tool 下有没有 run 方法 + if not tool.handler and not hasattr(tool, "run") and not is_override_call: + raise ValueError("Tool must have a valid handler or override 'run' method.") + + awaitable = None + method_name = "" + if tool.handler: + awaitable = tool.handler + method_name = "decorator_handler" + elif is_override_call: + awaitable = tool.call + method_name = "call" + elif hasattr(tool, "run"): + awaitable = getattr(tool, "run") + method_name = "run" + if awaitable is None: + raise ValueError("Tool must have a valid handler or override 'run' method.") + + wrapper = call_local_llm_tool( + context=run_context, handler=awaitable, + method_name=method_name, **tool_args, ) - # async for resp in wrapper: while True: try: resp = await asyncio.wait_for( anext(wrapper), - timeout=run_context.context.tool_call_timeout, + timeout=run_context.tool_call_timeout, ) if resp is not None: if isinstance(resp, mcp.types.CallToolResult): @@ -208,10 +233,24 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): # NOTE: Tool 在这里直接请求发送消息给用户 # TODO: 是否需要判断 event.get_result() 是否为空? # 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容" + if res := run_context.context.event.get_result(): + if res.chain: + try: + await event.send( + MessageChain( + chain=res.chain, + type="tool_direct_result", + ) + ) + except Exception as e: + logger.error( + f"Tool 直接发送消息失败: {e}", + exc_info=True, + ) yield None except asyncio.TimeoutError: raise Exception( - f"tool {tool.name} execution timeout after {run_context.context.tool_call_timeout} seconds.", + f"tool {tool.name} execution timeout after {run_context.tool_call_timeout} seconds.", ) except StopAsyncIteration: break @@ -223,19 +262,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): run_context: ContextWrapper[AstrAgentContext], **tool_args, ): - if not tool.mcp_client: - raise ValueError("MCP client is not available for MCP function tools.") - - session = tool.mcp_client.session - if not session: - raise ValueError("MCP session is not available for MCP function tools.") - res = await session.call_tool( - name=tool.name, - arguments=tool_args, - read_timeout_seconds=timedelta( - seconds=run_context.context.tool_call_timeout, - ), - ) + res = await tool.call(run_context, **tool_args) if not res: return yield res @@ -245,11 +272,20 @@ class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]): async def on_agent_done(self, run_context, llm_response): # 执行事件钩子 await call_event_hook( - run_context.event, + run_context.context.event, EventType.OnLLMResponseEvent, llm_response, ) + async def on_tool_end( + self, + run_context: ContextWrapper[AstrAgentContext], + tool: FunctionTool[Any], + tool_args: dict | None, + tool_result: CallToolResult | None, + ): + run_context.context.event.clear_result() + MAIN_AGENT_HOOKS = MainAgentHooks() @@ -260,7 +296,7 @@ async def run_agent( show_tool_use: bool = True, ) -> AsyncGenerator[MessageChain, None]: step_idx = 0 - astr_event = agent_runner.run_context.event + astr_event = agent_runner.run_context.context.event while step_idx < max_step: step_idx += 1 try: @@ -513,12 +549,15 @@ class LLMRequestSubStage(Stage): first_provider_request=req, curr_provider_request=req, streaming=self.streaming_response, - tool_call_timeout=self.tool_call_timeout, + event=event, ) await agent_runner.reset( provider=provider, request=req, - run_context=AgentContextWrapper(context=astr_agent_ctx, event=event), + run_context=AgentContextWrapper( + context=astr_agent_ctx, + tool_call_timeout=self.tool_call_timeout, + ), tool_executor=FunctionToolExecutor(), agent_hooks=MAIN_AGENT_HOOKS, streaming=self.streaming_response, diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index 28dc63f72..2f1e84419 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -4,15 +4,17 @@ import json from dataclasses import dataclass, field from typing import Any -from anthropic.types import Message +from anthropic.types import Message as AnthropicMessage from google.genai.types import GenerateContentResponse from openai.types.chat.chat_completion import ChatCompletion -from openai.types.chat.chat_completion_message_tool_call import ( - ChatCompletionMessageToolCall, -) import astrbot.core.message.components as Comp from astrbot import logger +from astrbot.core.agent.message import ( + AssistantMessageSegment, + ToolCall, + ToolCallMessageSegment, +) from astrbot.core.agent.tool import ToolSet from astrbot.core.db.po import Conversation from astrbot.core.message.message_event_result import MessageChain @@ -32,9 +34,9 @@ class ProviderMetaData: type: str """提供商适配器名称,如 openai, ollama""" desc: str = "" - """提供商适配器描述.""" + """提供商适配器描述""" provider_type: ProviderType = ProviderType.CHAT_COMPLETION - cls_type: type | None = None + cls_type: Any = None default_config_tmpl: dict | None = None """平台的默认配置模板""" @@ -42,44 +44,6 @@ class ProviderMetaData: """显示在 WebUI 配置页中的提供商名称,如空则是 type""" -@dataclass -class ToolCallMessageSegment: - """OpenAI 格式的上下文中 role 为 tool 的消息段。参考: https://platform.openai.com/docs/guides/function-calling""" - - tool_call_id: str - content: str - role: str = "tool" - - def to_dict(self): - return { - "tool_call_id": self.tool_call_id, - "content": self.content, - "role": self.role, - } - - -@dataclass -class AssistantMessageSegment: - """OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling""" - - content: str | None = None - tool_calls: list[ChatCompletionMessageToolCall | dict] = field(default_factory=list) - role: str = "assistant" - - def to_dict(self): - ret: dict[str, str | list[dict]] = { - "role": self.role, - } - if self.content: - ret["content"] = self.content - if self.tool_calls: - tool_calls_dict = [ - tc if isinstance(tc, dict) else tc.to_dict() for tc in self.tool_calls - ] - ret["tool_calls"] = tool_calls_dict - return ret - - @dataclass class ToolCallsResult: """工具调用结果""" @@ -91,8 +55,8 @@ class ToolCallsResult: def to_openai_messages(self) -> list[dict]: ret = [ - self.tool_calls_info.to_dict(), - *[item.to_dict() for item in self.tool_calls_result], + self.tool_calls_info.model_dump(), + *[item.model_dump() for item in self.tool_calls_result], ] return ret @@ -108,16 +72,16 @@ class ProviderRequest: func_tool: ToolSet | None = None """可用的函数工具""" contexts: list[dict] = field(default_factory=list) - """上下文。格式与 openai 的上下文格式一致: + """ + OpenAI 格式上下文列表。 参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages """ system_prompt: str = "" """系统提示词""" conversation: Conversation | None = None - + """关联的对话对象""" tool_calls_result: list[ToolCallsResult] | ToolCallsResult | None = None """附加的上次请求后工具调用的结果。参考: https://platform.openai.com/docs/guides/function-calling#handling-function-calls""" - model: str | None = None """模型名称,为 None 时使用提供商的默认模型""" @@ -227,7 +191,9 @@ class LLMResponse: tools_call_ids: list[str] = field(default_factory=list) """工具调用 ID""" - raw_completion: ChatCompletion | GenerateContentResponse | Message | None = None + raw_completion: ( + ChatCompletion | GenerateContentResponse | AnthropicMessage | None + ) = None _new_record: dict[str, Any] | None = None _completion_text: str = "" @@ -243,7 +209,10 @@ class LLMResponse: tools_call_args: list[dict[str, Any]] | None = None, tools_call_name: list[str] | None = None, tools_call_ids: list[str] | None = None, - raw_completion: ChatCompletion | None = None, + raw_completion: ChatCompletion + | GenerateContentResponse + | AnthropicMessage + | None = None, _new_record: dict[str, Any] | None = None, is_chunk: bool = False, ): @@ -294,7 +263,7 @@ class LLMResponse: self._completion_text = value def to_openai_tool_calls(self) -> list[dict]: - """将工具调用信息转换为 OpenAI 格式""" + """Convert to OpenAI tool calls format. Deprecated, use to_openai_to_calls_model instead.""" ret = [] for idx, tool_call_arg in enumerate(self.tools_call_args): ret.append( @@ -309,6 +278,21 @@ class LLMResponse: ) return ret + def to_openai_to_calls_model(self) -> list[ToolCall]: + """The same as to_openai_tool_calls but return pydantic model.""" + ret = [] + for idx, tool_call_arg in enumerate(self.tools_call_args): + ret.append( + ToolCall( + id=self.tools_call_ids[idx], + function=ToolCall.FunctionBody( + name=self.tools_call_name[idx], + arguments=json.dumps(tool_call_arg), + ), + ), + ) + return ret + @dataclass class RerankResult: diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index b3ef1ed5c..36aad2ae9 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -10,7 +10,7 @@ import aiohttp from astrbot import logger from astrbot.core import sp -from astrbot.core.agent.mcp_client import MCPClient +from astrbot.core.agent.mcp_client import MCPClient, MCPTool from astrbot.core.agent.tool import FunctionTool, ToolSet from astrbot.core.utils.astrbot_path import get_astrbot_data_path @@ -254,18 +254,15 @@ class FunctionToolManager: self.func_list = [ f for f in self.func_list - if not (f.origin == "mcp" and f.mcp_server_name == name) + if not (isinstance(f, MCPTool) and f.mcp_server_name == name) ] # 将 MCP 工具转换为 FuncTool 并添加到 func_list for tool in mcp_client.tools: - func_tool = FuncTool( - name=tool.name, - parameters=tool.inputSchema, - description=tool.description, - origin="mcp", - mcp_server_name=name, + func_tool = MCPTool( + mcp_tool=tool, mcp_client=mcp_client, + mcp_server_name=name, ) self.func_list.append(func_tool) @@ -284,7 +281,7 @@ class FunctionToolManager: self.func_list = [ f for f in self.func_list - if not (f.origin == "mcp" and f.mcp_server_name == name) + if not (isinstance(f, MCPTool) and f.mcp_server_name == name) ] logger.info(f"已关闭 MCP 服务 {name}") @@ -374,7 +371,7 @@ class FunctionToolManager: self.func_list = [ f for f in self.func_list - if f.origin != "mcp" or f.mcp_server_name != name + if not (isinstance(f, MCPTool) and f.mcp_server_name == name) ] else: running_events = [ @@ -388,7 +385,9 @@ class FunctionToolManager: finally: self.mcp_client_event.clear() self.mcp_client_dict.clear() - self.func_list = [f for f in self.func_list if f.origin != "mcp"] + self.func_list = [ + f for f in self.func_list if not isinstance(f, MCPTool) + ] def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list: """获得 OpenAI API 风格的**已经激活**的工具描述""" diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index a46b9cf9f..7ab8f00ba 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -3,6 +3,7 @@ import asyncio from collections.abc import AsyncGenerator from dataclasses import dataclass +from astrbot.core.agent.message import Message from astrbot.core.agent.tool import ToolSet from astrbot.core.db.po import Personality from astrbot.core.provider.entities import ( @@ -23,24 +24,28 @@ class ProviderMeta: class AbstractProvider(abc.ABC): + """Provider Abstract Class""" + def __init__(self, provider_config: dict) -> None: super().__init__() self.model_name = "" self.provider_config = provider_config def set_model(self, model_name: str): - """设置当前使用的模型名称""" + """Set the current model name""" self.model_name = model_name def get_model(self) -> str: - """获得当前使用的模型名称""" + """Get the current model name""" return self.model_name def meta(self) -> ProviderMeta: - """获取 Provider 的元数据""" + """Get the provider metadata""" provider_type_name = self.provider_config["type"] meta_data = provider_cls_map.get(provider_type_name) provider_type = meta_data.provider_type if meta_data else None + if provider_type is None: + raise ValueError(f"Cannot find provider type: {provider_type_name}") return ProviderMeta( id=self.provider_config["id"], model=self.get_model(), @@ -50,6 +55,8 @@ class AbstractProvider(abc.ABC): class Provider(AbstractProvider): + """Chat Provider""" + def __init__( self, provider_config: dict, @@ -84,11 +91,11 @@ class Provider(AbstractProvider): @abc.abstractmethod async def text_chat( self, - prompt: str, + prompt: str | None = None, session_id: str | None = None, image_urls: list[str] | None = None, func_tool: ToolSet | None = None, - contexts: list | None = None, + contexts: list[Message] | list[dict] | None = None, system_prompt: str | None = None, tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None, model: str | None = None, @@ -97,11 +104,11 @@ class Provider(AbstractProvider): """获得 LLM 的文本对话结果。会使用当前的模型进行对话。 Args: - prompt: 提示词 + prompt: 提示词,和 contexts 二选一使用,如果都指定,则会将 prompt(以及可能的 image_urls) 作为最新的一条记录添加到 contexts 中 session_id: 会话 ID(此属性已经被废弃) image_urls: 图片 URL 列表 - tools: Function-calling 工具 - contexts: 上下文 + tools: tool set + contexts: 上下文,和 prompt 二选一使用 tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling kwargs: 其他参数 @@ -114,11 +121,11 @@ class Provider(AbstractProvider): async def text_chat_stream( self, - prompt: str, + prompt: str | None = None, session_id: str | None = None, image_urls: list[str] | None = None, func_tool: ToolSet | None = None, - contexts: list | None = None, + contexts: list[Message] | list[dict] | None = None, system_prompt: str | None = None, tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None, model: str | None = None, @@ -127,11 +134,11 @@ class Provider(AbstractProvider): """获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。 Args: - prompt: 提示词 + prompt: 提示词,和 contexts 二选一使用,如果都指定,则会将 prompt(以及可能的 image_urls) 作为最新的一条记录添加到 contexts 中 session_id: 会话 ID(此属性已经被废弃) image_urls: 图片 URL 列表 - tools: Function-calling 工具 - contexts: 上下文 + tools: tool set + contexts: 上下文,和 prompt 二选一使用 tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling kwargs: 其他参数 @@ -140,6 +147,7 @@ class Provider(AbstractProvider): - 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。 """ + ... async def pop_record(self, context: list): """弹出 context 第一条非系统提示词对话记录""" @@ -156,6 +164,22 @@ class Provider(AbstractProvider): for idx in reversed(indexs_to_pop): context.pop(idx) + def _ensure_message_to_dicts( + self, + messages: list[dict] | list[Message] | None, + ) -> list[dict]: + """Convert a list of Message objects to a list of dictionaries.""" + if not messages: + return [] + dicts: list[dict] = [] + for message in messages: + if isinstance(message, Message): + dicts.append(message.model_dump()) + else: + dicts.append(message) + + return dicts + class STTProvider(AbstractProvider): def __init__(self, provider_config: dict, provider_settings: dict) -> None: diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index 6f292f076..77c85cef4 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -243,7 +243,7 @@ class ProviderAnthropic(Provider): async def text_chat( self, - prompt, + prompt=None, session_id=None, image_urls=None, func_tool=None, @@ -255,8 +255,13 @@ class ProviderAnthropic(Provider): ) -> LLMResponse: if contexts is None: contexts = [] - new_record = await self.assemble_context(prompt, image_urls) - context_query = [*contexts, new_record] + new_record = None + if prompt is not None: + new_record = await self.assemble_context(prompt, image_urls) + context_query = self._ensure_message_to_dicts(contexts) + if new_record: + context_query.append(new_record) + if system_prompt: context_query.insert(0, {"role": "system", "content": system_prompt}) @@ -306,8 +311,12 @@ class ProviderAnthropic(Provider): ): if contexts is None: contexts = [] - new_record = await self.assemble_context(prompt, image_urls) - context_query = [*contexts, new_record] + new_record = None + if prompt is not None: + new_record = await self.assemble_context(prompt, image_urls) + context_query = self._ensure_message_to_dicts(contexts) + if new_record: + context_query.append(new_record) if system_prompt: context_query.insert(0, {"role": "system", "content": system_prompt}) diff --git a/astrbot/core/provider/sources/coze_source.py b/astrbot/core/provider/sources/coze_source.py index caee65020..23a8b3b76 100644 --- a/astrbot/core/provider/sources/coze_source.py +++ b/astrbot/core/provider/sources/coze_source.py @@ -331,6 +331,7 @@ class ProviderCoze(Provider): }, ) + contexts = self._ensure_message_to_dicts(contexts) if not self.auto_save_history and contexts: # 如果关闭了自动保存历史,传入上下文 for ctx in contexts: diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index f9eef2e92..c3c9253a5 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -572,7 +572,7 @@ class ProviderGoogleGenAI(Provider): async def text_chat( self, - prompt: str, + prompt=None, session_id=None, image_urls=None, func_tool=None, @@ -584,8 +584,12 @@ class ProviderGoogleGenAI(Provider): ) -> LLMResponse: if contexts is None: contexts = [] - new_record = await self.assemble_context(prompt, image_urls) - context_query = [*contexts, new_record] + new_record = None + if prompt is not None: + new_record = await self.assemble_context(prompt, image_urls) + context_query = self._ensure_message_to_dicts(contexts) + if new_record: + context_query.append(new_record) if system_prompt: context_query.insert(0, {"role": "system", "content": system_prompt}) @@ -621,7 +625,7 @@ class ProviderGoogleGenAI(Provider): async def text_chat_stream( self, - prompt, + prompt=None, session_id=None, image_urls=None, func_tool=None, @@ -633,8 +637,12 @@ class ProviderGoogleGenAI(Provider): ) -> AsyncGenerator[LLMResponse, None]: if contexts is None: contexts = [] - new_record = await self.assemble_context(prompt, image_urls) - context_query = [*contexts, new_record] + new_record = None + if prompt is not None: + new_record = await self.assemble_context(prompt, image_urls) + context_query = self._ensure_message_to_dicts(contexts) + if new_record: + context_query.append(new_record) if system_prompt: context_query.insert(0, {"role": "system", "content": system_prompt}) @@ -726,7 +734,6 @@ class ProviderGoogleGenAI(Provider): with open(image_url, "rb") as f: image_bs64 = base64.b64encode(f.read()).decode("utf-8") return "data:image/jpeg;base64," + image_bs64 - return "" async def terminate(self): logger.info("Google GenAI 适配器已终止。") diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 1020075af..076afc40f 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -14,9 +14,10 @@ from openai.types.chat.chat_completion import ChatCompletion import astrbot.core.message.components as Comp from astrbot import logger from astrbot.api.provider import Provider +from astrbot.core.agent.message import Message +from astrbot.core.agent.tool import ToolSet from astrbot.core.message.message_event_result import MessageChain from astrbot.core.provider.entities import LLMResponse, ToolCallsResult -from astrbot.core.provider.func_tool_manager import ToolSet from astrbot.core.utils.io import download_image_by_url from ..register import register_provider_adapter @@ -102,7 +103,7 @@ class ProviderOpenAIOfficial(Provider): except NotFoundError as e: raise Exception(f"获取模型列表失败:{e}") - async def _query(self, payloads: dict, tools: ToolSet) -> LLMResponse: + async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: if tools: model = payloads.get("model", "").lower() omit_empty_param_field = "gemini" in model @@ -153,7 +154,7 @@ class ProviderOpenAIOfficial(Provider): async def _query_stream( self, payloads: dict, - tools: ToolSet, + tools: ToolSet | None, ) -> AsyncGenerator[LLMResponse, None]: """流式查询API,逐步返回结果""" if tools: @@ -212,7 +213,9 @@ class ProviderOpenAIOfficial(Provider): yield llm_response - async def parse_openai_completion(self, completion: ChatCompletion, tools: ToolSet): + async def parse_openai_completion( + self, completion: ChatCompletion, tools: ToolSet | None + ) -> LLMResponse: """解析 OpenAI 的 ChatCompletion 响应""" llm_response = LLMResponse("assistant") @@ -225,7 +228,7 @@ class ProviderOpenAIOfficial(Provider): completion_text = str(choice.message.content).strip() llm_response.result_chain = MessageChain().message(completion_text) - if choice.message.tool_calls: + if choice.message.tool_calls and tools is not None: # tools call (function calling) args_ls = [] func_name_ls = [] @@ -267,9 +270,9 @@ class ProviderOpenAIOfficial(Provider): async def _prepare_chat_payload( self, - prompt: str, + prompt: str | None, image_urls: list[str] | None = None, - contexts: list | None = None, + contexts: list[dict] | list[Message] | None = None, system_prompt: str | None = None, tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None, model: str | None = None, @@ -278,8 +281,12 @@ class ProviderOpenAIOfficial(Provider): """准备聊天所需的有效载荷和上下文""" if contexts is None: contexts = [] - new_record = await self.assemble_context(prompt, image_urls) - context_query = [*contexts, new_record] + new_record = None + if prompt is not None: + new_record = await self.assemble_context(prompt, image_urls) + context_query = self._ensure_message_to_dicts(contexts) + if new_record: + context_query.append(new_record) if system_prompt: context_query.insert(0, {"role": "system", "content": system_prompt}) @@ -310,7 +317,7 @@ class ProviderOpenAIOfficial(Provider): e: Exception, payloads: dict, context_query: list, - func_tool: ToolSet, + func_tool: ToolSet | None, chosen_key: str, available_api_keys: list[str], retry_cnt: int, @@ -390,7 +397,7 @@ class ProviderOpenAIOfficial(Provider): async def text_chat( self, - prompt, + prompt=None, session_id=None, image_urls=None, func_tool=None, @@ -459,7 +466,7 @@ class ProviderOpenAIOfficial(Provider): async def text_chat_stream( self, - prompt: str, + prompt=None, session_id=None, image_urls=None, func_tool=None, diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 620e7e907..1a5bc53d9 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -1,3 +1,4 @@ +import logging from asyncio import Queue from collections.abc import Awaitable, Callable from typing import Any @@ -35,6 +36,8 @@ from .filter.regex import RegexFilter from .star import StarMetadata, star_map, star_registry from .star_handler import EventType, StarHandlerMetadata, star_handlers_registry +logger = logging.getLogger("astrbot") + class Context: """暴露给插件的接口上下文。""" @@ -255,9 +258,44 @@ class Context: def add_llm_tools(self, *tools: FunctionTool) -> None: """添加 LLM 工具。""" + tool_name = {tool.name for tool in self.provider_manager.llm_tools.func_list} + module_path = "" for tool in tools: + if not module_path: + _parts = [] + module_part = tool.__module__.split(".") + flags = ["packages", "plugins"] + for i, part in enumerate(module_part): + _parts.append(part) + if part in flags and i + 1 < len(module_part): + _parts.append(module_part[i + 1]) + break + tool.handler_module_path = ".".join(_parts) + module_path = tool.handler_module_path + else: + tool.handler_module_path = module_path + logger.info( + f"plugin(module_path {module_path}) added LLM tool: {tool.name}" + ) + + if tool.name in tool_name: + logger.warning("替换已存在的 LLM 工具: " + tool.name) + self.provider_manager.llm_tools.remove_func(tool.name) self.provider_manager.llm_tools.func_list.append(tool) + def register_web_api( + self, + route: str, + view_handler: Awaitable, + methods: list, + desc: str, + ): + for idx, api in enumerate(self.registered_web_apis): + if api[0] == route and methods == api[2]: + self.registered_web_apis[idx] = (route, view_handler, methods, desc) + return + self.registered_web_apis.append((route, view_handler, methods, desc)) + """ 以下的方法已经不推荐使用。请从 AstrBot 文档查看更好的注册方式。 """ @@ -269,7 +307,7 @@ class Context: desc: str, func_obj: Callable[..., Awaitable[Any]], ) -> None: - """为函数调用(function-calling / tools-use)添加工具。 + """[DEPRECATED]为函数调用(function-calling / tools-use)添加工具。 @param name: 函数名 @param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...] @@ -291,7 +329,7 @@ class Context: self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj) def unregister_llm_tool(self, name: str) -> None: - """删除一个函数调用工具。如果再要启用,需要重新注册。""" + """[DEPRECATED]删除一个函数调用工具。如果再要启用,需要重新注册。""" self.provider_manager.llm_tools.remove_func(name) def register_commands( @@ -333,18 +371,5 @@ class Context: star_handlers_registry.append(md) def register_task(self, task: Awaitable, desc: str): - """注册一个异步任务。""" + """[DEPRECATED]注册一个异步任务。""" self._register_tasks.append(task) - - def register_web_api( - self, - route: str, - view_handler: Awaitable, - methods: list, - desc: str, - ): - for idx, api in enumerate(self.registered_web_apis): - if api[0] == route and methods == api[2]: - self.registered_web_apis[idx] = (route, view_handler, methods, desc) - return - self.registered_web_apis.append((route, view_handler, methods, desc)) diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index ef50917fe..7a49807f6 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -751,6 +751,19 @@ class PluginManager: ]: del star_handlers_registry.star_handlers_map[k] + # llm_tools 中移除该插件的工具函数绑定 + to_remove = [] + for func_tool in llm_tools.func_list: + mp = func_tool.handler_module_path + if ( + mp + and mp.startswith(plugin_module_path) + and not mp.endswith(("packages", "data.plugins")) + ): + to_remove.append(func_tool) + for func_tool in to_remove: + llm_tools.func_list.remove(func_tool) + if plugin is None: return @@ -795,7 +808,13 @@ class PluginManager: # 禁用插件启用的 llm_tool for func_tool in llm_tools.func_list: - if func_tool.handler_module_path == plugin.module_path: + mp = func_tool.handler_module_path + if ( + plugin.module_path + and mp + and plugin.module_path.startswith(mp) + and not mp.endswith(("packages", "data.plugins")) + ): func_tool.active = False if func_tool.name not in inactivated_llm_tools: inactivated_llm_tools.append(func_tool.name) @@ -838,8 +857,12 @@ class PluginManager: # 启用插件启用的 llm_tool for func_tool in llm_tools.func_list: + mp = func_tool.handler_module_path if ( - func_tool.handler_module_path == plugin.module_path + plugin.module_path + and mp + and plugin.module_path.startswith(mp) + and not mp.endswith(("packages", "data.plugins")) and func_tool.name in inactivated_llm_tools ): inactivated_llm_tools.remove(func_tool.name) @@ -848,8 +871,6 @@ class PluginManager: await self.reload(plugin_name) - # plugin.activated = True - async def install_plugin_from_file(self, zip_file_path: str): dir_name = os.path.basename(zip_file_path).replace(".zip", "") dir_name = dir_name.removesuffix("-master").removesuffix("-main").lower()