feat: add method to register LLM tools in Context class
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
from asyncio import Queue
|
||||
from typing import List, Union
|
||||
|
||||
from astrbot.core.provider.provider import (
|
||||
Provider,
|
||||
@@ -11,7 +10,7 @@ from astrbot.core.provider.provider import (
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.provider.func_tool_manager import FunctionToolManager
|
||||
from astrbot.core.provider.func_tool_manager import FunctionToolManager, FunctionTool
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
@@ -25,7 +24,8 @@ from .star import star_registry, StarMetadata, star_map
|
||||
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
|
||||
from .filter.command import CommandFilter
|
||||
from .filter.regex import RegexFilter
|
||||
from typing import Awaitable, Any, Callable
|
||||
from typing import Any
|
||||
from collections.abc import Awaitable, Callable
|
||||
from astrbot.core.conversation_mgr import ConversationManager
|
||||
from astrbot.core.star.filter.platform_adapter_type import (
|
||||
PlatformAdapterType,
|
||||
@@ -42,7 +42,7 @@ class Context:
|
||||
registered_web_apis: list = []
|
||||
|
||||
# back compatibility
|
||||
_register_tasks: List[Awaitable] = []
|
||||
_register_tasks: list[Awaitable] = []
|
||||
_star_manager = None
|
||||
|
||||
def __init__(
|
||||
@@ -78,7 +78,7 @@ class Context:
|
||||
if star.name == star_name:
|
||||
return star
|
||||
|
||||
def get_all_stars(self) -> List[StarMetadata]:
|
||||
def get_all_stars(self) -> list[StarMetadata]:
|
||||
"""获取当前载入的所有插件 Metadata 的列表"""
|
||||
return star_registry
|
||||
|
||||
@@ -116,19 +116,19 @@ class Context:
|
||||
prov = self.provider_manager.inst_map.get(provider_id)
|
||||
return prov
|
||||
|
||||
def get_all_providers(self) -> List[Provider]:
|
||||
def get_all_providers(self) -> list[Provider]:
|
||||
"""获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。"""
|
||||
return self.provider_manager.provider_insts
|
||||
|
||||
def get_all_tts_providers(self) -> List[TTSProvider]:
|
||||
def get_all_tts_providers(self) -> list[TTSProvider]:
|
||||
"""获取所有用于 TTS 任务的 Provider。"""
|
||||
return self.provider_manager.tts_provider_insts
|
||||
|
||||
def get_all_stt_providers(self) -> List[STTProvider]:
|
||||
def get_all_stt_providers(self) -> list[STTProvider]:
|
||||
"""获取所有用于 STT 任务的 Provider。"""
|
||||
return self.provider_manager.stt_provider_insts
|
||||
|
||||
def get_all_embedding_providers(self) -> List[EmbeddingProvider]:
|
||||
def get_all_embedding_providers(self) -> list[EmbeddingProvider]:
|
||||
"""获取所有用于 Embedding 任务的 Provider。"""
|
||||
return self.provider_manager.embedding_provider_insts
|
||||
|
||||
@@ -196,9 +196,7 @@ class Context:
|
||||
return self._event_queue
|
||||
|
||||
@deprecated(version="4.0.0", reason="Use get_platform_inst instead")
|
||||
def get_platform(
|
||||
self, platform_type: Union[PlatformAdapterType, str]
|
||||
) -> Platform | None:
|
||||
def get_platform(self, platform_type: PlatformAdapterType | str) -> Platform | None:
|
||||
"""
|
||||
获取指定类型的平台适配器。
|
||||
|
||||
@@ -231,7 +229,7 @@ class Context:
|
||||
return platform
|
||||
|
||||
async def send_message(
|
||||
self, session: Union[str, MessageSesion], message_chain: MessageChain
|
||||
self, session: str | MessageSesion, message_chain: MessageChain
|
||||
) -> bool:
|
||||
"""
|
||||
根据 session(unified_msg_origin) 主动发送消息。
|
||||
@@ -258,6 +256,11 @@ class Context:
|
||||
return True
|
||||
return False
|
||||
|
||||
def add_llm_tools(self, *tools: FunctionTool) -> None:
|
||||
"""添加一个 LLM 工具。"""
|
||||
for tool in tools:
|
||||
self.provider_manager.llm_tools.func_list.append(tool)
|
||||
|
||||
"""
|
||||
以下的方法已经不推荐使用。请从 AstrBot 文档查看更好的注册方式。
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user