feat: add method to register LLM tools in Context class

This commit is contained in:
Soulter
2025-10-29 20:13:15 +08:00
parent 68aa107689
commit a2be155b8e
+16 -13
View File
@@ -1,5 +1,4 @@
from asyncio import Queue
from typing import List, Union
from astrbot.core.provider.provider import (
Provider,
@@ -11,7 +10,7 @@ from astrbot.core.provider.provider import (
from astrbot.core.provider.entities import ProviderType
from astrbot.core.db import BaseDatabase
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.provider.func_tool_manager import FunctionToolManager
from astrbot.core.provider.func_tool_manager import FunctionToolManager, FunctionTool
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.manager import ProviderManager
@@ -25,7 +24,8 @@ from .star import star_registry, StarMetadata, star_map
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
from .filter.command import CommandFilter
from .filter.regex import RegexFilter
from typing import Awaitable, Any, Callable
from typing import Any
from collections.abc import Awaitable, Callable
from astrbot.core.conversation_mgr import ConversationManager
from astrbot.core.star.filter.platform_adapter_type import (
PlatformAdapterType,
@@ -42,7 +42,7 @@ class Context:
registered_web_apis: list = []
# back compatibility
_register_tasks: List[Awaitable] = []
_register_tasks: list[Awaitable] = []
_star_manager = None
def __init__(
@@ -78,7 +78,7 @@ class Context:
if star.name == star_name:
return star
def get_all_stars(self) -> List[StarMetadata]:
def get_all_stars(self) -> list[StarMetadata]:
"""获取当前载入的所有插件 Metadata 的列表"""
return star_registry
@@ -116,19 +116,19 @@ class Context:
prov = self.provider_manager.inst_map.get(provider_id)
return prov
def get_all_providers(self) -> List[Provider]:
def get_all_providers(self) -> list[Provider]:
"""获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。"""
return self.provider_manager.provider_insts
def get_all_tts_providers(self) -> List[TTSProvider]:
def get_all_tts_providers(self) -> list[TTSProvider]:
"""获取所有用于 TTS 任务的 Provider。"""
return self.provider_manager.tts_provider_insts
def get_all_stt_providers(self) -> List[STTProvider]:
def get_all_stt_providers(self) -> list[STTProvider]:
"""获取所有用于 STT 任务的 Provider。"""
return self.provider_manager.stt_provider_insts
def get_all_embedding_providers(self) -> List[EmbeddingProvider]:
def get_all_embedding_providers(self) -> list[EmbeddingProvider]:
"""获取所有用于 Embedding 任务的 Provider。"""
return self.provider_manager.embedding_provider_insts
@@ -196,9 +196,7 @@ class Context:
return self._event_queue
@deprecated(version="4.0.0", reason="Use get_platform_inst instead")
def get_platform(
self, platform_type: Union[PlatformAdapterType, str]
) -> Platform | None:
def get_platform(self, platform_type: PlatformAdapterType | str) -> Platform | None:
"""
获取指定类型的平台适配器。
@@ -231,7 +229,7 @@ class Context:
return platform
async def send_message(
self, session: Union[str, MessageSesion], message_chain: MessageChain
self, session: str | MessageSesion, message_chain: MessageChain
) -> bool:
"""
根据 session(unified_msg_origin) 主动发送消息。
@@ -258,6 +256,11 @@ class Context:
return True
return False
def add_llm_tools(self, *tools: FunctionTool) -> None:
"""添加一个 LLM 工具。"""
for tool in tools:
self.provider_manager.llm_tools.func_list.append(tool)
"""
以下的方法已经不推荐使用。请从 AstrBot 文档查看更好的注册方式。
"""