diff --git a/astrbot/core/agent/tool.py b/astrbot/core/agent/tool.py index 45226991c..b60088609 100644 --- a/astrbot/core/agent/tool.py +++ b/astrbot/core/agent/tool.py @@ -4,7 +4,7 @@ from typing import Any, Generic import jsonschema import mcp from deprecated import deprecated -from pydantic import model_validator +from pydantic import Field, model_validator from pydantic.dataclasses import dataclass from .run_context import ContextWrapper, TContext @@ -63,6 +63,7 @@ class FunctionTool(ToolSchema, Generic[TContext]): ) +@dataclass class ToolSet: """A set of function tools that can be used in function calling. @@ -70,8 +71,7 @@ class ToolSet: convert the tools to different API formats (OpenAI, Anthropic, Google GenAI). """ - def __init__(self, tools: list[FunctionTool] | None = None): - self.tools: list[FunctionTool] = tools or [] + tools: list[FunctionTool] = Field(default_factory=list) def empty(self) -> bool: """Check if the tool set is empty.""" diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 5918b2029..21c1ad8fd 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -259,10 +259,6 @@ class Context: """ return self.provider_manager.llm_tools.deactivate_llm_tool(name) - def register_provider(self, provider: Provider): - """注册一个 LLM Provider(Chat_Completion 类型)。""" - self.provider_manager.provider_insts.append(provider) - def get_provider_by_id( self, provider_id: str, @@ -341,45 +337,6 @@ class Context: return self._config return self.astrbot_config_mgr.get_conf(umo) - def get_db(self) -> BaseDatabase: - """获取 AstrBot 数据库。""" - return self._db - - def get_event_queue(self) -> Queue: - """获取事件队列。""" - return self._event_queue - - @deprecated(version="4.0.0", reason="Use get_platform_inst instead") - def get_platform(self, platform_type: PlatformAdapterType | str) -> Platform | None: - """获取指定类型的平台适配器。 - - 该方法已经过时,请使用 get_platform_inst 方法。(>= AstrBot v4.0.0) - """ - for platform in self.platform_manager.platform_insts: - name = platform.meta().name - if isinstance(platform_type, str): - if name == platform_type: - return platform - elif ( - name in ADAPTER_NAME_2_TYPE - and ADAPTER_NAME_2_TYPE[name] & platform_type - ): - return platform - - def get_platform_inst(self, platform_id: str) -> Platform | None: - """获取指定 ID 的平台适配器实例。 - - Args: - platform_id (str): 平台适配器的唯一标识符。你可以通过 event.get_platform_id() 获取。 - - Returns: - Platform: 平台适配器实例,如果未找到则返回 None。 - - """ - for platform in self.platform_manager.platform_insts: - if platform.meta().id == platform_id: - return platform - async def send_message( self, session: str | MessageSesion, @@ -452,6 +409,49 @@ class Context: 以下的方法已经不推荐使用。请从 AstrBot 文档查看更好的注册方式。 """ + def get_event_queue(self) -> Queue: + """获取事件队列。""" + return self._event_queue + + @deprecated(version="4.0.0", reason="Use get_platform_inst instead") + def get_platform(self, platform_type: PlatformAdapterType | str) -> Platform | None: + """获取指定类型的平台适配器。 + + 该方法已经过时,请使用 get_platform_inst 方法。(>= AstrBot v4.0.0) + """ + for platform in self.platform_manager.platform_insts: + name = platform.meta().name + if isinstance(platform_type, str): + if name == platform_type: + return platform + elif ( + name in ADAPTER_NAME_2_TYPE + and ADAPTER_NAME_2_TYPE[name] & platform_type + ): + return platform + + def get_platform_inst(self, platform_id: str) -> Platform | None: + """获取指定 ID 的平台适配器实例。 + + Args: + platform_id (str): 平台适配器的唯一标识符。你可以通过 event.get_platform_id() 获取。 + + Returns: + Platform: 平台适配器实例,如果未找到则返回 None。 + + """ + for platform in self.platform_manager.platform_insts: + if platform.meta().id == platform_id: + return platform + + def get_db(self) -> BaseDatabase: + """获取 AstrBot 数据库。""" + return self._db + + def register_provider(self, provider: Provider): + """注册一个 LLM Provider(Chat_Completion 类型)。""" + self.provider_manager.provider_insts.append(provider) + def register_llm_tool( self, name: str,