f624971613
* chore(core.utils): 🚨 修正错误Lint
* chore(core.provider): 🚨 修复基类错误Lint
* chore(core.utils): 补全session_get()的重载
* chore(core.provider): 🚨 修正实现错误Lint
* chore(core.platform): 🚨 修正platform基类和webchat的错误Lint
* chore(core.platform): 修正错误实现Lint
* fix(core.provider): 修复循环调用和错误assert
* chore(core.platform): 修复部分实现Lint
* chore(core.provider): 补充Dify.text_chat_stream的参数类型
* chore(core.pipeline): 🚨 修复错误Lint
* fix(core.slack): 补充遗漏导入
* chore(core.utils): 修复错误的session_get声明
* chore(core.platform): 移除Lark adapter import中的wildcard
* chore(core.db): 修复声明和部分逻辑
* chore(core.db): 添加typings,使faiss参数能被正确识别。
* chore(core): 修复声明
* chore(core): 修改声明
* chore: 补充faiss声明
* chore(dashboard): 修改实现,减少报错
* chore(package): 修改部分声明与实现,减少报错
* chore(core): 添加Handler的overload,以去除部分assert同时通过类型检查
* chore(core.pipeline): 修改Pipeline Scheduler的execute,将判断属性改为判断类型,通过静态类型检查
* chore(core.config): 添加类型标注,通过类型检查
* chore(core.message): 为File._download_file添加检查,通过类型检查
* fix: 将断言改为条件判断以实现优雅关闭的容错性
* refactor: 移除 discord 客户端中的 assert,改用 if None 判断并抛出异常
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: DiscordPlatformAdapter 对 self.client.user 为 None 做日志并返回,移除断言
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: 增强 Lark 相关空值/异常检查并完善日志输出
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 将断言替换为条件检查并加入日志与错误处理
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* chore: 移除LLM生成的无用注释
* refactor: 使用 File.get_file 替换下载逻辑并移除 assert,提供默认 filename
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: Slack Socket 未初始化抛出运行时异常,图片 URL 判空改为非空判断
* refactor: 将 WeChatPadProAdapter 的断言改为空值判断并添加日志
* refactor: 使用 isinstance 替代断言实现类型判断,便于静态检查
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: 去除cast,直接使用字段与字典访问,修正端口解析
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 使用 match-case 重构 ProviderManager 加载并通过类型检查抛出 TypeError
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: group_name_display 时若 group 对象为空则记录错误并返回
* fix: 将 _get_current_persona_id 的 assert 替换成 if guard 并返回 None
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: 优化插件目录存在性检查及图片URL非空验证,更新JSON排序配置
* fix: 将 datetime_str 的 assert 替换为显式检查并抛出异常
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 移除 cast,改为运行时检查并在找不到调度器时跳过
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 移除 cast,改用 isinstance 检查 FaissVecDB 并警告
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: 删除 typing.cast 导入,并在获取文件绝对路径前校验 file_
* refactor: 移除 typing.cast,简化内容安全检查调用
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 将 PlatformMetadata.id 设为必填并在注册时传入 id,移除 cast
* refactor: 移除 cast,改用 HasInitialize 与 isinstance 进行初始化
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: 为 ProviderManager.initialize 增加ID类型判断,避免 None 导致 get 失败
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 为 OTTSProvider 与 AzureNativeProvider 引入 _client 与 client 属性改进上下文管理
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: 为 Whisper 自托管源添加模型未初始化校验并直接调用 transcribe
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 移除未使用的 cast 导入并简化 platform_name 赋值
* refactor: 引入 cast 并对 id 使用 cast(str, ...) 提升类型安全
* fix: 将 _id_to_sid 返回改为 str,空值返回空串;对 id 与 message_id 使用 cast
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 重构 Discord 处理逻辑:强制 类型转换、优先斜杠指令并优化提及判断
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: 统一对 id 获取执行 cast,并在微信消息解析失败时抛错
* Revert "fix: 去除cast,直接使用字段与字典访问,修正端口解析"
This reverts commit 1cbfdf9d1b.
* fix: 百炼 Rerank 会话关闭时返回空结果;初始化 request.prompt 避免空值拼接
* fix: 统一处理搜索结果链接为字符串,新增 _get_url 助手并适配 Bing/Sogo
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 调整 call_handler 泛型、Discord 通道注解及 FishAudioTTS API 请求类型
* refactor: 使用 col(...) 替代列引用并对结果进行 CursorResult 强转
* chore: ruff format
---------
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
Co-authored-by: Soulter <905617992@qq.com>
528 lines
20 KiB
Python
528 lines
20 KiB
Python
import logging
|
|
from asyncio import Queue
|
|
from collections.abc import Awaitable, Callable
|
|
from typing import Any
|
|
|
|
from deprecated import deprecated
|
|
|
|
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
|
from astrbot.core.agent.message import Message
|
|
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
|
|
from astrbot.core.agent.tool import ToolSet
|
|
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
|
from astrbot.core.config.astrbot_config import AstrBotConfig
|
|
from astrbot.core.conversation_mgr import ConversationManager
|
|
from astrbot.core.db import BaseDatabase
|
|
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
|
from astrbot.core.message.message_event_result import MessageChain
|
|
from astrbot.core.persona_mgr import PersonaManager
|
|
from astrbot.core.platform import Platform
|
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent, MessageSesion
|
|
from astrbot.core.platform.manager import PlatformManager
|
|
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
|
|
from astrbot.core.provider.entities import LLMResponse, ProviderRequest, ProviderType
|
|
from astrbot.core.provider.func_tool_manager import FunctionTool, FunctionToolManager
|
|
from astrbot.core.provider.manager import ProviderManager
|
|
from astrbot.core.provider.provider import (
|
|
EmbeddingProvider,
|
|
Provider,
|
|
RerankProvider,
|
|
STTProvider,
|
|
TTSProvider,
|
|
)
|
|
from astrbot.core.star.filter.platform_adapter_type import (
|
|
ADAPTER_NAME_2_TYPE,
|
|
PlatformAdapterType,
|
|
)
|
|
|
|
from ..exceptions import ProviderNotFoundError
|
|
from .filter.command import CommandFilter
|
|
from .filter.regex import RegexFilter
|
|
from .star import StarMetadata, star_map, star_registry
|
|
from .star_handler import EventType, StarHandlerMetadata, star_handlers_registry
|
|
|
|
logger = logging.getLogger("astrbot")
|
|
|
|
|
|
class Context:
|
|
"""暴露给插件的接口上下文。"""
|
|
|
|
registered_web_apis: list = []
|
|
|
|
# back compatibility
|
|
_register_tasks: list[Awaitable] = []
|
|
_star_manager = None
|
|
|
|
def __init__(
|
|
self,
|
|
event_queue: Queue,
|
|
config: AstrBotConfig,
|
|
db: BaseDatabase,
|
|
provider_manager: ProviderManager,
|
|
platform_manager: PlatformManager,
|
|
conversation_manager: ConversationManager,
|
|
message_history_manager: PlatformMessageHistoryManager,
|
|
persona_manager: PersonaManager,
|
|
astrbot_config_mgr: AstrBotConfigManager,
|
|
knowledge_base_manager: KnowledgeBaseManager,
|
|
):
|
|
self._event_queue = event_queue
|
|
"""事件队列。消息平台通过事件队列传递消息事件。"""
|
|
self._config = config
|
|
"""AstrBot 默认配置"""
|
|
self._db = db
|
|
"""AstrBot 数据库"""
|
|
self.provider_manager = provider_manager
|
|
self.platform_manager = platform_manager
|
|
self.conversation_manager = conversation_manager
|
|
self.message_history_manager = message_history_manager
|
|
self.persona_manager = persona_manager
|
|
self.astrbot_config_mgr = astrbot_config_mgr
|
|
self.kb_manager = knowledge_base_manager
|
|
|
|
async def llm_generate(
|
|
self,
|
|
*,
|
|
chat_provider_id: str,
|
|
prompt: str | None = None,
|
|
image_urls: list[str] | None = None,
|
|
tools: ToolSet | None = None,
|
|
system_prompt: str | None = None,
|
|
contexts: list[Message] | None = None,
|
|
**kwargs: Any,
|
|
) -> LLMResponse:
|
|
"""Call the LLM to generate a response. The method will not automatically execute tool calls. If you want to use tool calls, please use `tool_loop_agent()`.
|
|
|
|
.. versionadded:: 4.5.7 (sdk)
|
|
|
|
Args:
|
|
chat_provider_id: The chat provider ID to use.
|
|
prompt: The prompt to send to the LLM, if `contexts` and `prompt` are both provided, `prompt` will be appended as the last user message
|
|
image_urls: List of image URLs to include in the prompt, if `contexts` and `prompt` are both provided, `image_urls` will be appended to the last user message
|
|
tools: ToolSet of tools available to the LLM
|
|
system_prompt: System prompt to guide the LLM's behavior, if provided, it will always insert as the first system message in the context
|
|
contexts: context messages for the LLM
|
|
**kwargs: Additional keyword arguments for LLM generation, OpenAI compatible
|
|
|
|
Raises:
|
|
ChatProviderNotFoundError: If the specified chat provider ID is not found
|
|
Exception: For other errors during LLM generation
|
|
"""
|
|
prov = await self.provider_manager.get_provider_by_id(chat_provider_id)
|
|
if not prov or not isinstance(prov, Provider):
|
|
raise ProviderNotFoundError(f"Provider {chat_provider_id} not found")
|
|
llm_resp = await prov.text_chat(
|
|
prompt=prompt,
|
|
image_urls=image_urls,
|
|
func_tool=tools,
|
|
contexts=contexts,
|
|
system_prompt=system_prompt,
|
|
**kwargs,
|
|
)
|
|
return llm_resp
|
|
|
|
async def tool_loop_agent(
|
|
self,
|
|
*,
|
|
event: AstrMessageEvent,
|
|
chat_provider_id: str,
|
|
prompt: str | None = None,
|
|
image_urls: list[str] | None = None,
|
|
tools: ToolSet | None = None,
|
|
system_prompt: str | None = None,
|
|
contexts: list[Message] | None = None,
|
|
max_steps: int = 30,
|
|
tool_call_timeout: int = 60,
|
|
**kwargs: Any,
|
|
) -> LLMResponse:
|
|
"""Run an agent loop that allows the LLM to call tools iteratively until a final answer is produced.
|
|
If you do not pass the agent_context parameter, the method will recreate a new agent context.
|
|
|
|
.. versionadded:: 4.5.7 (sdk)
|
|
|
|
Args:
|
|
chat_provider_id: The chat provider ID to use.
|
|
prompt: The prompt to send to the LLM, if `contexts` and `prompt` are both provided, `prompt` will be appended as the last user message
|
|
image_urls: List of image URLs to include in the prompt, if `contexts` and `prompt` are both provided, `image_urls` will be appended to the last user message
|
|
tools: ToolSet of tools available to the LLM
|
|
system_prompt: System prompt to guide the LLM's behavior, if provided, it will always insert as the first system message in the context
|
|
contexts: context messages for the LLM
|
|
max_steps: Maximum number of tool calls before stopping the loop
|
|
**kwargs: Additional keyword arguments. The kwargs will not be passed to the LLM directly for now, but can include:
|
|
agent_hooks: BaseAgentRunHooks[AstrAgentContext] - hooks to run during agent execution
|
|
agent_context: AstrAgentContext - context to use for the agent
|
|
|
|
Returns:
|
|
The final LLMResponse after tool calls are completed.
|
|
|
|
Raises:
|
|
ChatProviderNotFoundError: If the specified chat provider ID is not found
|
|
Exception: For other errors during LLM generation
|
|
"""
|
|
# Import here to avoid circular imports
|
|
from astrbot.core.astr_agent_context import (
|
|
AgentContextWrapper,
|
|
AstrAgentContext,
|
|
)
|
|
from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor
|
|
|
|
prov = await self.provider_manager.get_provider_by_id(chat_provider_id)
|
|
if not prov or not isinstance(prov, Provider):
|
|
raise ProviderNotFoundError(f"Provider {chat_provider_id} not found")
|
|
|
|
agent_hooks = kwargs.get("agent_hooks") or BaseAgentRunHooks[AstrAgentContext]()
|
|
agent_context = kwargs.get("agent_context")
|
|
|
|
context_ = []
|
|
for msg in contexts or []:
|
|
if isinstance(msg, Message):
|
|
context_.append(msg.model_dump())
|
|
else:
|
|
context_.append(msg)
|
|
|
|
request = ProviderRequest(
|
|
prompt=prompt,
|
|
image_urls=image_urls or [],
|
|
func_tool=tools,
|
|
contexts=context_,
|
|
system_prompt=system_prompt or "",
|
|
)
|
|
if agent_context is None:
|
|
agent_context = AstrAgentContext(
|
|
context=self,
|
|
event=event,
|
|
)
|
|
agent_runner = ToolLoopAgentRunner()
|
|
tool_executor = FunctionToolExecutor()
|
|
await agent_runner.reset(
|
|
provider=prov,
|
|
request=request,
|
|
run_context=AgentContextWrapper(
|
|
context=agent_context,
|
|
tool_call_timeout=tool_call_timeout,
|
|
),
|
|
tool_executor=tool_executor,
|
|
agent_hooks=agent_hooks,
|
|
streaming=kwargs.get("stream", False),
|
|
)
|
|
async for _ in agent_runner.step_until_done(max_steps):
|
|
pass
|
|
llm_resp = agent_runner.get_final_llm_resp()
|
|
if not llm_resp:
|
|
raise Exception("Agent did not produce a final LLM response")
|
|
return llm_resp
|
|
|
|
async def get_current_chat_provider_id(self, umo: str) -> str:
|
|
"""Get the ID of the currently used chat provider.
|
|
|
|
Args:
|
|
umo(str): unified_message_origin value, if provided and user has enabled provider session isolation, the provider preferred by that session will be used.
|
|
|
|
Raises:
|
|
ProviderNotFoundError: If the specified chat provider is not found
|
|
|
|
"""
|
|
prov = self.get_using_provider(umo)
|
|
if not prov:
|
|
raise ProviderNotFoundError("Provider not found")
|
|
return prov.meta().id
|
|
|
|
def get_registered_star(self, star_name: str) -> StarMetadata | None:
|
|
"""根据插件名获取插件的 Metadata"""
|
|
for star in star_registry:
|
|
if star.name == star_name:
|
|
return star
|
|
|
|
def get_all_stars(self) -> list[StarMetadata]:
|
|
"""获取当前载入的所有插件 Metadata 的列表"""
|
|
return star_registry
|
|
|
|
def get_llm_tool_manager(self) -> FunctionToolManager:
|
|
"""获取 LLM Tool Manager,其用于管理注册的所有的 Function-calling tools"""
|
|
return self.provider_manager.llm_tools
|
|
|
|
def activate_llm_tool(self, name: str) -> bool:
|
|
"""激活一个已经注册的函数调用工具。注册的工具默认是激活状态。
|
|
|
|
Returns:
|
|
如果没找到,会返回 False
|
|
|
|
"""
|
|
return self.provider_manager.llm_tools.activate_llm_tool(name, star_map)
|
|
|
|
def deactivate_llm_tool(self, name: str) -> bool:
|
|
"""停用一个已经注册的函数调用工具。
|
|
|
|
Returns:
|
|
如果没找到,会返回 False
|
|
|
|
"""
|
|
return self.provider_manager.llm_tools.deactivate_llm_tool(name)
|
|
|
|
def get_provider_by_id(
|
|
self,
|
|
provider_id: str,
|
|
) -> (
|
|
Provider | TTSProvider | STTProvider | EmbeddingProvider | RerankProvider | None
|
|
):
|
|
"""通过 ID 获取对应的 LLM Provider。"""
|
|
prov = self.provider_manager.inst_map.get(provider_id)
|
|
return prov
|
|
|
|
def get_all_providers(self) -> list[Provider]:
|
|
"""获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。"""
|
|
return self.provider_manager.provider_insts
|
|
|
|
def get_all_tts_providers(self) -> list[TTSProvider]:
|
|
"""获取所有用于 TTS 任务的 Provider。"""
|
|
return self.provider_manager.tts_provider_insts
|
|
|
|
def get_all_stt_providers(self) -> list[STTProvider]:
|
|
"""获取所有用于 STT 任务的 Provider。"""
|
|
return self.provider_manager.stt_provider_insts
|
|
|
|
def get_all_embedding_providers(self) -> list[EmbeddingProvider]:
|
|
"""获取所有用于 Embedding 任务的 Provider。"""
|
|
return self.provider_manager.embedding_provider_insts
|
|
|
|
def get_using_provider(self, umo: str | None = None) -> Provider:
|
|
"""获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。
|
|
|
|
Args:
|
|
umo(str): unified_message_origin 值,如果传入并且用户启用了提供商会话隔离,则使用该会话偏好的提供商。
|
|
|
|
"""
|
|
prov = self.provider_manager.get_using_provider(
|
|
provider_type=ProviderType.CHAT_COMPLETION,
|
|
umo=umo,
|
|
)
|
|
if not isinstance(prov, Provider):
|
|
raise ValueError("返回的 Provider 不是 Provider 类型")
|
|
return prov
|
|
|
|
def get_using_tts_provider(self, umo: str | None = None) -> TTSProvider | None:
|
|
"""获取当前使用的用于 TTS 任务的 Provider。
|
|
|
|
Args:
|
|
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
|
|
|
|
"""
|
|
prov = self.provider_manager.get_using_provider(
|
|
provider_type=ProviderType.TEXT_TO_SPEECH,
|
|
umo=umo,
|
|
)
|
|
if prov and not isinstance(prov, TTSProvider):
|
|
raise ValueError("返回的 Provider 不是 TTSProvider 类型")
|
|
return prov
|
|
|
|
def get_using_stt_provider(self, umo: str | None = None) -> STTProvider | None:
|
|
"""获取当前使用的用于 STT 任务的 Provider。
|
|
|
|
Args:
|
|
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
|
|
|
|
"""
|
|
prov = self.provider_manager.get_using_provider(
|
|
provider_type=ProviderType.SPEECH_TO_TEXT,
|
|
umo=umo,
|
|
)
|
|
if prov and not isinstance(prov, STTProvider):
|
|
raise ValueError("返回的 Provider 不是 STTProvider 类型")
|
|
return prov
|
|
|
|
def get_config(self, umo: str | None = None) -> AstrBotConfig:
|
|
"""获取 AstrBot 的配置。"""
|
|
if not umo:
|
|
# using default config
|
|
return self._config
|
|
return self.astrbot_config_mgr.get_conf(umo)
|
|
|
|
async def send_message(
|
|
self,
|
|
session: str | MessageSesion,
|
|
message_chain: MessageChain,
|
|
) -> bool:
|
|
"""根据 session(unified_msg_origin) 主动发送消息。
|
|
|
|
@param session: 消息会话。通过 event.session 或者 event.unified_msg_origin 获取。
|
|
@param message_chain: 消息链。
|
|
|
|
@return: 是否找到匹配的平台。
|
|
|
|
当 session 为字符串时,会尝试解析为 MessageSesion 对象,如果解析失败,会抛出 ValueError 异常。
|
|
|
|
NOTE: qq_official(QQ 官方 API 平台) 不支持此方法
|
|
"""
|
|
if isinstance(session, str):
|
|
try:
|
|
session = MessageSesion.from_str(session)
|
|
except BaseException as e:
|
|
raise ValueError("不合法的 session 字符串: " + str(e))
|
|
|
|
for platform in self.platform_manager.platform_insts:
|
|
if platform.meta().id == session.platform_name:
|
|
await platform.send_by_session(session, message_chain)
|
|
return True
|
|
return False
|
|
|
|
def add_llm_tools(self, *tools: FunctionTool) -> None:
|
|
"""添加 LLM 工具。"""
|
|
tool_name = {tool.name for tool in self.provider_manager.llm_tools.func_list}
|
|
module_path = ""
|
|
for tool in tools:
|
|
if not module_path:
|
|
_parts = []
|
|
module_part = tool.__module__.split(".")
|
|
flags = ["packages", "plugins"]
|
|
for i, part in enumerate(module_part):
|
|
_parts.append(part)
|
|
if part in flags and i + 1 < len(module_part):
|
|
_parts.append(module_part[i + 1])
|
|
break
|
|
tool.handler_module_path = ".".join(_parts)
|
|
module_path = tool.handler_module_path
|
|
else:
|
|
tool.handler_module_path = module_path
|
|
logger.info(
|
|
f"plugin(module_path {module_path}) added LLM tool: {tool.name}"
|
|
)
|
|
|
|
if tool.name in tool_name:
|
|
logger.warning("替换已存在的 LLM 工具: " + tool.name)
|
|
self.provider_manager.llm_tools.remove_func(tool.name)
|
|
self.provider_manager.llm_tools.func_list.append(tool)
|
|
|
|
def register_web_api(
|
|
self,
|
|
route: str,
|
|
view_handler: Awaitable,
|
|
methods: list,
|
|
desc: str,
|
|
):
|
|
for idx, api in enumerate(self.registered_web_apis):
|
|
if api[0] == route and methods == api[2]:
|
|
self.registered_web_apis[idx] = (route, view_handler, methods, desc)
|
|
return
|
|
self.registered_web_apis.append((route, view_handler, methods, desc))
|
|
|
|
"""
|
|
以下的方法已经不推荐使用。请从 AstrBot 文档查看更好的注册方式。
|
|
"""
|
|
|
|
def get_event_queue(self) -> Queue:
|
|
"""获取事件队列。"""
|
|
return self._event_queue
|
|
|
|
@deprecated(version="4.0.0", reason="Use get_platform_inst instead")
|
|
def get_platform(self, platform_type: PlatformAdapterType | str) -> Platform | None:
|
|
"""获取指定类型的平台适配器。
|
|
|
|
该方法已经过时,请使用 get_platform_inst 方法。(>= AstrBot v4.0.0)
|
|
"""
|
|
for platform in self.platform_manager.platform_insts:
|
|
name = platform.meta().name
|
|
if isinstance(platform_type, str):
|
|
if name == platform_type:
|
|
return platform
|
|
elif (
|
|
name in ADAPTER_NAME_2_TYPE
|
|
and ADAPTER_NAME_2_TYPE[name] & platform_type
|
|
):
|
|
return platform
|
|
|
|
def get_platform_inst(self, platform_id: str) -> Platform | None:
|
|
"""获取指定 ID 的平台适配器实例。
|
|
|
|
Args:
|
|
platform_id (str): 平台适配器的唯一标识符。你可以通过 event.get_platform_id() 获取。
|
|
|
|
Returns:
|
|
Platform: 平台适配器实例,如果未找到则返回 None。
|
|
|
|
"""
|
|
for platform in self.platform_manager.platform_insts:
|
|
if platform.meta().id == platform_id:
|
|
return platform
|
|
|
|
def get_db(self) -> BaseDatabase:
|
|
"""获取 AstrBot 数据库。"""
|
|
return self._db
|
|
|
|
def register_provider(self, provider: Provider):
|
|
"""注册一个 LLM Provider(Chat_Completion 类型)。"""
|
|
self.provider_manager.provider_insts.append(provider)
|
|
|
|
def register_llm_tool(
|
|
self,
|
|
name: str,
|
|
func_args: list,
|
|
desc: str,
|
|
func_obj: Callable[..., Awaitable[Any]],
|
|
) -> None:
|
|
"""[DEPRECATED]为函数调用(function-calling / tools-use)添加工具。
|
|
|
|
@param name: 函数名
|
|
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
|
|
@param desc: 函数描述
|
|
@param func_obj: 异步处理函数。
|
|
|
|
异步处理函数会接收到额外的的关键词参数:event: AstrMessageEvent, context: Context。
|
|
"""
|
|
md = StarHandlerMetadata(
|
|
event_type=EventType.OnLLMRequestEvent,
|
|
handler_full_name=func_obj.__module__ + "_" + func_obj.__name__,
|
|
handler_name=func_obj.__name__,
|
|
handler_module_path=func_obj.__module__,
|
|
handler=func_obj,
|
|
event_filters=[],
|
|
desc=desc,
|
|
)
|
|
star_handlers_registry.append(md)
|
|
self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj)
|
|
|
|
def unregister_llm_tool(self, name: str) -> None:
|
|
"""[DEPRECATED]删除一个函数调用工具。如果再要启用,需要重新注册。"""
|
|
self.provider_manager.llm_tools.remove_func(name)
|
|
|
|
def register_commands(
|
|
self,
|
|
star_name: str,
|
|
command_name: str,
|
|
desc: str,
|
|
priority: int,
|
|
awaitable: Callable[..., Awaitable[Any]],
|
|
use_regex=False,
|
|
ignore_prefix=False,
|
|
):
|
|
"""注册一个命令。
|
|
|
|
[Deprecated] 推荐使用装饰器注册指令。该方法将在未来的版本中被移除。
|
|
|
|
@param star_name: 插件(Star)名称。
|
|
@param command_name: 命令名称。
|
|
@param desc: 命令描述。
|
|
@param priority: 优先级。1-10。
|
|
@param awaitable: 异步处理函数。
|
|
|
|
"""
|
|
md = StarHandlerMetadata(
|
|
event_type=EventType.AdapterMessageEvent,
|
|
handler_full_name=awaitable.__module__ + "_" + awaitable.__name__,
|
|
handler_name=awaitable.__name__,
|
|
handler_module_path=awaitable.__module__,
|
|
handler=awaitable,
|
|
event_filters=[],
|
|
desc=desc,
|
|
)
|
|
if use_regex:
|
|
md.event_filters.append(RegexFilter(regex=command_name))
|
|
else:
|
|
md.event_filters.append(
|
|
CommandFilter(command_name=command_name, handler_md=md),
|
|
)
|
|
star_handlers_registry.append(md)
|
|
|
|
def register_task(self, task: Awaitable, desc: str):
|
|
"""[DEPRECATED]注册一个异步任务。"""
|
|
self._register_tasks.append(task)
|