From a2be155b8eec77ad39fac63183c05b24344e7304 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Wed, 29 Oct 2025 20:13:15 +0800 Subject: [PATCH] feat: add method to register LLM tools in Context class --- astrbot/core/star/context.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 0229f4dbb..f7480793a 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -1,5 +1,4 @@ from asyncio import Queue -from typing import List, Union from astrbot.core.provider.provider import ( Provider, @@ -11,7 +10,7 @@ from astrbot.core.provider.provider import ( from astrbot.core.provider.entities import ProviderType from astrbot.core.db import BaseDatabase from astrbot.core.config.astrbot_config import AstrBotConfig -from astrbot.core.provider.func_tool_manager import FunctionToolManager +from astrbot.core.provider.func_tool_manager import FunctionToolManager, FunctionTool from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.core.message.message_event_result import MessageChain from astrbot.core.provider.manager import ProviderManager @@ -25,7 +24,8 @@ from .star import star_registry, StarMetadata, star_map from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType from .filter.command import CommandFilter from .filter.regex import RegexFilter -from typing import Awaitable, Any, Callable +from typing import Any +from collections.abc import Awaitable, Callable from astrbot.core.conversation_mgr import ConversationManager from astrbot.core.star.filter.platform_adapter_type import ( PlatformAdapterType, @@ -42,7 +42,7 @@ class Context: registered_web_apis: list = [] # back compatibility - _register_tasks: List[Awaitable] = [] + _register_tasks: list[Awaitable] = [] _star_manager = None def __init__( @@ -78,7 +78,7 @@ class Context: if star.name == star_name: return star - def get_all_stars(self) -> List[StarMetadata]: + def get_all_stars(self) -> list[StarMetadata]: """获取当前载入的所有插件 Metadata 的列表""" return star_registry @@ -116,19 +116,19 @@ class Context: prov = self.provider_manager.inst_map.get(provider_id) return prov - def get_all_providers(self) -> List[Provider]: + def get_all_providers(self) -> list[Provider]: """获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。""" return self.provider_manager.provider_insts - def get_all_tts_providers(self) -> List[TTSProvider]: + def get_all_tts_providers(self) -> list[TTSProvider]: """获取所有用于 TTS 任务的 Provider。""" return self.provider_manager.tts_provider_insts - def get_all_stt_providers(self) -> List[STTProvider]: + def get_all_stt_providers(self) -> list[STTProvider]: """获取所有用于 STT 任务的 Provider。""" return self.provider_manager.stt_provider_insts - def get_all_embedding_providers(self) -> List[EmbeddingProvider]: + def get_all_embedding_providers(self) -> list[EmbeddingProvider]: """获取所有用于 Embedding 任务的 Provider。""" return self.provider_manager.embedding_provider_insts @@ -196,9 +196,7 @@ class Context: return self._event_queue @deprecated(version="4.0.0", reason="Use get_platform_inst instead") - def get_platform( - self, platform_type: Union[PlatformAdapterType, str] - ) -> Platform | None: + def get_platform(self, platform_type: PlatformAdapterType | str) -> Platform | None: """ 获取指定类型的平台适配器。 @@ -231,7 +229,7 @@ class Context: return platform async def send_message( - self, session: Union[str, MessageSesion], message_chain: MessageChain + self, session: str | MessageSesion, message_chain: MessageChain ) -> bool: """ 根据 session(unified_msg_origin) 主动发送消息。 @@ -258,6 +256,11 @@ class Context: return True return False + def add_llm_tools(self, *tools: FunctionTool) -> None: + """添加一个 LLM 工具。""" + for tool in tools: + self.provider_manager.llm_tools.func_list.append(tool) + """ 以下的方法已经不推荐使用。请从 AstrBot 文档查看更好的注册方式。 """