AstrBot/astrbot/core/star/context.py
Dt8333 f624971613 chore: fix bunches of type checking errors (#3213)
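* chore(core.utils): 🚨 fix Lint errors

* chore(core.provider): 🚨 fix base class Lint errors

* chore(core.utils): complete the overloads of session_get()

* chore(core.provider): 🚨 fix Lint errors in implementations

* chore(core.platform): 🚨 fix Lint errors in the platform base class and webchat

* chore(core.platform): fix Lint errors in implementations

* fix(core.provider): fix circular call and incorrect assert

* chore(core.platform): fix Lint errors in some implementations

* chore(core.provider): add parameter types for Dify.text_chat_stream

* chore(core.pipeline): 🚨 fix Lint errors

* fix(core.slack): add missing import

* chore(core.utils): fix incorrect session_get declaration

* chore(core.platform): remove wildcard from the Lark adapter import

* chore(core.db): fix declarations and some logic

* chore(core.db): add typings so that faiss parameters are recognized correctly

* chore(core): fix declarations

* chore(core): modify declarations

* chore: add faiss declarations

* chore(dashboard): modify implementation to reduce errors

* chore(package): modify some declarations and implementations to reduce errors

* chore(core): add overloads for Handler to remove some asserts while still passing type checking

* chore(core.pipeline): change Pipeline Scheduler's execute to check the type instead of the attribute, so it passes static type checking

* chore(core.config): add type annotations to pass type checking

* chore(core.message): add a check to File._download_file to pass type checking

* fix: replace assertions with conditional checks for fault tolerance during graceful shutdown

* refactor: remove asserts from the discord client; check `if None` and raise an exception instead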

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

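* fix: DiscordPlatformAdapter logs and returns when self.client.user is None; remove assertion

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: strengthen null/exception checks for Lark and improve log output

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: replace assertions with conditional checks and add logging and error handling

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* chore: remove useless LLM-generated comments

* refactor: use File.get_file in place of the download logic, remove assert, and provide a default filename

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: raise a runtime exception when the Slack Socket is uninitialized; change the image URL empty check to a non-empty check

* refactor: replace assertions in WeChatPadProAdapter with null checks and add logging

* refactor: use isinstance instead of assertions for type checks, to aid static checking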

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: remove cast, access fields and dict entries directly, and fix port parsing

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: rewrite ProviderManager loading with match-case and raise TypeError via type checks

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: for group_name_display, log an error and return if the group object is empty

* fix: replace the assert in _get_current_persona_id with an if guard and return None

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: improve the plugin directory existence check and the image URL non-empty validation; update the JSON sorting configuration

* fix: replace the assert on datetime_str with an explicit check that raises an exception

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: remove cast; use a runtime check and skip when the scheduler is not found

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: remove cast; use isinstance to check for FaissVecDB and warn

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: remove the typing.cast import and validate file_ before getting the file's absolute path

* refactor: remove typing.cast and simplify the content safety check call

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: make PlatformMetadata.id required and pass id at registration; remove cast

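* refactor: remove cast; use HasInitialize and isinstance for initialization

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: add an ID type check to ProviderManager.initialize to avoid get failing on None

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: introduce _client and a client property for OTTSProvider and AzureNativeProvider to improve context management

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: add an uninitialized-model check for the self-hosted Whisper source and call transcribe directly

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: remove the unused cast import and simplify the platform_name assignment

* refactor: introduce cast and use cast(str, ...) on id to improve type safety

* fix: change the _id_to_sid return type to str, returning an empty string for empty values; use cast on id and message_id

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: restructure Discord handling: force type conversion, prioritize slash commands, and improve mention detection

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: apply cast consistently when getting id, and raise an error when WeChat message parsing fails

* Revert "fix: remove cast, access fields and dict entries directly, and fix port parsing"

This reverts commit 1cbfdf9d1b.

* fix: Bailian Rerank returns an empty result when the session is closed; initialize request.prompt to avoid concatenating a null value

* fix: handle search result links uniformly as strings; add the _get_url helper and adapt Bing/Sogo

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: adjust call_handler generics, Discord channel annotations, and the FishAudioTTS API request type

* refactor: use col(...) in place of column references and cast results to CursorResult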

* chore: ruff format

---------

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
Co-authored-by: Soulter <905617992@qq.com>
2025-12-09 14:13:47 +08:00

528 lines
20 KiB
Python

import logging
from asyncio import Queue
from collections.abc import Awaitable, Callable
from typing import Any
from deprecated import deprecated
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.message import Message
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
from astrbot.core.agent.tool import ToolSet
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.conversation_mgr import ConversationManager
from astrbot.core.db import BaseDatabase
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.persona_mgr import PersonaManager
from astrbot.core.platform import Platform
from astrbot.core.platform.astr_message_event import AstrMessageEvent, MessageSesion
from astrbot.core.platform.manager import PlatformManager
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
from astrbot.core.provider.entities import LLMResponse, ProviderRequest, ProviderType
from astrbot.core.provider.func_tool_manager import FunctionTool, FunctionToolManager
from astrbot.core.provider.manager import ProviderManager
from astrbot.core.provider.provider import (
    EmbeddingProvider,
    Provider,
    RerankProvider,
    STTProvider,
    TTSProvider,
)
from astrbot.core.star.filter.platform_adapter_type import (
    ADAPTER_NAME_2_TYPE,
    PlatformAdapterType,
)

from ..exceptions import ProviderNotFoundError
from .filter.command import CommandFilter
from .filter.regex import RegexFilter
from .star import StarMetadata, star_map, star_registry
from .star_handler import EventType, StarHandlerMetadata, star_handlers_registry

logger = logging.getLogger("astrbot")


class Context:
    """Interface context exposed to plugins."""

    registered_web_apis: list = []

    # back compatibility
    _register_tasks: list[Awaitable] = []
    _star_manager = None

    def __init__(
        self,
        event_queue: Queue,
        config: AstrBotConfig,
        db: BaseDatabase,
        provider_manager: ProviderManager,
        platform_manager: PlatformManager,
        conversation_manager: ConversationManager,
        message_history_manager: PlatformMessageHistoryManager,
        persona_manager: PersonaManager,
        astrbot_config_mgr: AstrBotConfigManager,
        knowledge_base_manager: KnowledgeBaseManager,
    ):
        self._event_queue = event_queue
        """Event queue. Message platforms pass message events through it."""
        self._config = config
        """AstrBot default configuration."""
        self._db = db
        """AstrBot database."""
        self.provider_manager = provider_manager
        self.platform_manager = platform_manager
        self.conversation_manager = conversation_manager
        self.message_history_manager = message_history_manager
        self.persona_manager = persona_manager
        self.astrbot_config_mgr = astrbot_config_mgr
        self.kb_manager = knowledge_base_manager

    async def llm_generate(
        self,
        *,
        chat_provider_id: str,
        prompt: str | None = None,
        image_urls: list[str] | None = None,
        tools: ToolSet | None = None,
        system_prompt: str | None = None,
        contexts: list[Message] | None = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """Call the LLM to generate a response.

        This method does not execute tool calls automatically. To use tool calls, use `tool_loop_agent()` instead.

        .. versionadded:: 4.5.7 (sdk)

        Args:
            chat_provider_id: The chat provider ID to use.
            prompt: The prompt to send to the LLM. If both `contexts` and `prompt` are provided, `prompt` is appended as the last user message.
            image_urls: List of image URLs to include in the prompt. If both `contexts` and `prompt` are provided, `image_urls` are appended to the last user message.
            tools: ToolSet of tools available to the LLM.
            system_prompt: System prompt to guide the LLM's behavior. If provided, it is always inserted as the first system message in the context.
            contexts: Context messages for the LLM.
            **kwargs: Additional keyword arguments for LLM generation, OpenAI compatible.

        Raises:
            ProviderNotFoundError: If the specified chat provider ID is not found.
            Exception: For other errors during LLM generation.
        """
        prov = await self.provider_manager.get_provider_by_id(chat_provider_id)
        if not prov or not isinstance(prov, Provider):
            raise ProviderNotFoundError(f"Provider {chat_provider_id} not found")
        llm_resp = await prov.text_chat(
            prompt=prompt,
            image_urls=image_urls,
            func_tool=tools,
            contexts=contexts,
            system_prompt=system_prompt,
            **kwargs,
        )
        return llm_resp

    async def tool_loop_agent(
        self,
        *,
        event: AstrMessageEvent,
        chat_provider_id: str,
        prompt: str | None = None,
        image_urls: list[str] | None = None,
        tools: ToolSet | None = None,
        system_prompt: str | None = None,
        contexts: list[Message] | None = None,
        max_steps: int = 30,
        tool_call_timeout: int = 60,
        **kwargs: Any,
    ) -> LLMResponse:
        """Run an agent loop that allows the LLM to call tools iteratively until a final answer is produced.

        If you do not pass the `agent_context` keyword argument, the method creates a new agent context.

        .. versionadded:: 4.5.7 (sdk)

        Args:
            event: The message event used to build the agent context when `agent_context` is not passed.
            chat_provider_id: The chat provider ID to use.
            prompt: The prompt to send to the LLM. If both `contexts` and `prompt` are provided, `prompt` is appended as the last user message.
            image_urls: List of image URLs to include in the prompt. If both `contexts` and `prompt` are provided, `image_urls` are appended to the last user message.
            tools: ToolSet of tools available to the LLM.
            system_prompt: System prompt to guide the LLM's behavior. If provided, it is always inserted as the first system message in the context.
            contexts: Context messages for the LLM.
            max_steps: Maximum number of tool calls before stopping the loop.
            tool_call_timeout: Tool call timeout, passed to the agent run context.
            **kwargs: Additional keyword arguments. For now they are not passed to the LLM directly, but can include:
                agent_hooks: BaseAgentRunHooks[AstrAgentContext] - hooks to run during agent execution
                agent_context: AstrAgentContext - context to use for the agent
                stream: bool - forwarded to the agent runner as `streaming` (defaults to False)

        Returns:
            The final LLMResponse after tool calls are completed.

        Raises:
            ProviderNotFoundError: If the specified chat provider ID is not found.
            Exception: For other errors during LLM generation.
        """
        # Import here to avoid circular imports
        from astrbot.core.astr_agent_context import (
            AgentContextWrapper,
            AstrAgentContext,
        )
        from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor

        prov = await self.provider_manager.get_provider_by_id(chat_provider_id)
        if not prov or not isinstance(prov, Provider):
            raise ProviderNotFoundError(f"Provider {chat_provider_id} not found")

        agent_hooks = kwargs.get("agent_hooks") or BaseAgentRunHooks[AstrAgentContext]()
        agent_context = kwargs.get("agent_context")

        # Normalize Message objects to plain dicts; pass other entries through unchanged.
        context_ = []
        for msg in contexts or []:
            if isinstance(msg, Message):
                context_.append(msg.model_dump())
            else:
                context_.append(msg)

        request = ProviderRequest(
            prompt=prompt,
            image_urls=image_urls or [],
            func_tool=tools,
            contexts=context_,
            system_prompt=system_prompt or "",
        )
        if agent_context is None:
            agent_context = AstrAgentContext(
                context=self,
                event=event,
            )
        agent_runner = ToolLoopAgentRunner()
        tool_executor = FunctionToolExecutor()
        await agent_runner.reset(
            provider=prov,
            request=request,
            run_context=AgentContextWrapper(
                context=agent_context,
                tool_call_timeout=tool_call_timeout,
            ),
            tool_executor=tool_executor,
            agent_hooks=agent_hooks,
            streaming=kwargs.get("stream", False),
        )
        # Drain the runner; only the final response is needed here.
        async for _ in agent_runner.step_until_done(max_steps):
            pass
        llm_resp = agent_runner.get_final_llm_resp()
        if not llm_resp:
            raise Exception("Agent did not produce a final LLM response")
        return llm_resp

    async def get_current_chat_provider_id(self, umo: str) -> str:
        """Get the ID of the currently used chat provider.

        Args:
            umo(str): unified_message_origin value. If provided and the user has enabled provider session isolation, the provider preferred by that session is used.

        Raises:
            ProviderNotFoundError: If the specified chat provider is not found.
        """
        prov = self.get_using_provider(umo)
        if not prov:
            raise ProviderNotFoundError("Provider not found")
        return prov.meta().id

    def get_registered_star(self, star_name: str) -> StarMetadata | None:
        """Get a plugin's Metadata by plugin name."""
        for star in star_registry:
            if star.name == star_name:
                return star
        return None

    def get_all_stars(self) -> list[StarMetadata]:
        """Get the list of Metadata for all currently loaded plugins."""
        return star_registry

    def get_llm_tool_manager(self) -> FunctionToolManager:
        """Get the LLM Tool Manager, which manages all registered function-calling tools."""
        return self.provider_manager.llm_tools

    def activate_llm_tool(self, name: str) -> bool:
        """Activate an already registered function-calling tool. Registered tools are active by default.

        Returns:
            False if the tool is not found.
        """
        return self.provider_manager.llm_tools.activate_llm_tool(name, star_map)

    def deactivate_llm_tool(self, name: str) -> bool:
        """Deactivate an already registered function-calling tool.

        Returns:
            False if the tool is not found.
        """
        return self.provider_manager.llm_tools.deactivate_llm_tool(name)

    def get_provider_by_id(
        self,
        provider_id: str,
    ) -> (
        Provider | TTSProvider | STTProvider | EmbeddingProvider | RerankProvider | None
    ):
        """Get the Provider that corresponds to the given ID."""
        prov = self.provider_manager.inst_map.get(provider_id)
        return prov

    def get_all_providers(self) -> list[Provider]:
        """Get all LLM Providers used for text generation tasks (Chat_Completion type)."""
        return self.provider_manager.provider_insts

    def get_all_tts_providers(self) -> list[TTSProvider]:
        """Get all Providers used for TTS tasks."""
        return self.provider_manager.tts_provider_insts

    def get_all_stt_providers(self) -> list[STTProvider]:
        """Get all Providers used for STT tasks."""
        return self.provider_manager.stt_provider_insts

    def get_all_embedding_providers(self) -> list[EmbeddingProvider]:
        """Get all Providers used for Embedding tasks."""
        return self.provider_manager.embedding_provider_insts

    def get_using_provider(self, umo: str | None = None) -> Provider:
        """Get the LLM Provider currently in use for text generation tasks (Chat_Completion type). It is switched with the /provider command.

        Args:
            umo(str): unified_message_origin value. If provided and the user has enabled provider session isolation, the provider preferred by that session is used.
        """
        prov = self.provider_manager.get_using_provider(
            provider_type=ProviderType.CHAT_COMPLETION,
            umo=umo,
        )
        if not isinstance(prov, Provider):
            raise ValueError("The returned Provider is not of type Provider")
        return prov

    def get_using_tts_provider(self, umo: str | None = None) -> TTSProvider | None:
        """Get the Provider currently in use for TTS tasks.

        Args:
            umo(str): unified_message_origin value. If provided, the provider preferred by that session is used.
        """
        prov = self.provider_manager.get_using_provider(
            provider_type=ProviderType.TEXT_TO_SPEECH,
            umo=umo,
        )
        if prov and not isinstance(prov, TTSProvider):
            raise ValueError("The returned Provider is not of type TTSProvider")
        return prov

    def get_using_stt_provider(self, umo: str | None = None) -> STTProvider | None:
        """Get the Provider currently in use for STT tasks.

        Args:
            umo(str): unified_message_origin value. If provided, the provider preferred by that session is used.
        """
        prov = self.provider_manager.get_using_provider(
            provider_type=ProviderType.SPEECH_TO_TEXT,
            umo=umo,
        )
        if prov and not isinstance(prov, STTProvider):
            raise ValueError("The returned Provider is not of type STTProvider")
        return prov

    def get_config(self, umo: str | None = None) -> AstrBotConfig:
        """Get the AstrBot configuration. Without `umo`, the default configuration is returned."""
        if not umo:
            # using default config
            return self._config
        return self.astrbot_config_mgr.get_conf(umo)

    async def send_message(
        self,
        session: str | MessageSesion,
        message_chain: MessageChain,
    ) -> bool:
        """Proactively send a message to the given session (unified_msg_origin).

        @param session: The message session. Obtain it from event.session or event.unified_msg_origin.
        @param message_chain: The message chain.
        @return: Whether a matching platform was found.

        When session is a string, it is parsed into a MessageSesion object; if parsing fails, a ValueError is raised.

        NOTE: qq_official (the official QQ API platform) does not support this method.
        """
        if isinstance(session, str):
            try:
                session = MessageSesion.from_str(session)
            except Exception as e:
                # Chain the original error so the parse failure stays visible.
                raise ValueError("Invalid session string: " + str(e)) from e

        for platform in self.platform_manager.platform_insts:
            if platform.meta().id == session.platform_name:
                await platform.send_by_session(session, message_chain)
                return True
        return False

    def add_llm_tools(self, *tools: FunctionTool) -> None:
        """Add LLM tools."""
        tool_name = {tool.name for tool in self.provider_manager.llm_tools.func_list}
        module_path = ""
        for tool in tools:
            if not module_path:
                # Derive the plugin module path from the first tool: keep every
                # segment up to and including the one after "packages" or "plugins".
                _parts = []
                module_part = tool.__module__.split(".")
                flags = ["packages", "plugins"]
                for i, part in enumerate(module_part):
                    _parts.append(part)
                    if part in flags and i + 1 < len(module_part):
                        _parts.append(module_part[i + 1])
                        break
                tool.handler_module_path = ".".join(_parts)
                module_path = tool.handler_module_path
            else:
                # Later tools reuse the path derived from the first tool.
                tool.handler_module_path = module_path
            logger.info(
                f"plugin(module_path {module_path}) added LLM tool: {tool.name}"
            )

            if tool.name in tool_name:
                logger.warning("Replacing existing LLM tool: " + tool.name)
                self.provider_manager.llm_tools.remove_func(tool.name)
            self.provider_manager.llm_tools.func_list.append(tool)

    def register_web_api(
        self,
        route: str,
        view_handler: Awaitable,
        methods: list,
        desc: str,
    ):
        # Replace an existing entry with the same route and methods; otherwise append.
        for idx, api in enumerate(self.registered_web_apis):
            if api[0] == route and methods == api[2]:
                self.registered_web_apis[idx] = (route, view_handler, methods, desc)
                return
        self.registered_web_apis.append((route, view_handler, methods, desc))

    """
    The methods below are no longer recommended. See the AstrBot documentation for better ways to register.
    """

    def get_event_queue(self) -> Queue:
        """Get the event queue."""
        return self._event_queue

    @deprecated(version="4.0.0", reason="Use get_platform_inst instead")
    def get_platform(self, platform_type: PlatformAdapterType | str) -> Platform | None:
        """Get the platform adapter of the specified type.

        This method is deprecated; use get_platform_inst instead. (>= AstrBot v4.0.0)
        """
        for platform in self.platform_manager.platform_insts:
            name = platform.meta().name
            if isinstance(platform_type, str):
                if name == platform_type:
                    return platform
            elif (
                name in ADAPTER_NAME_2_TYPE
                and ADAPTER_NAME_2_TYPE[name] & platform_type
            ):
                return platform
        return None

    def get_platform_inst(self, platform_id: str) -> Platform | None:
        """Get the platform adapter instance with the specified ID.

        Args:
            platform_id (str): Unique identifier of the platform adapter. You can get it via event.get_platform_id().

        Returns:
            Platform: The platform adapter instance, or None if not found.
        """
        for platform in self.platform_manager.platform_insts:
            if platform.meta().id == platform_id:
                return platform
        return None

    def get_db(self) -> BaseDatabase:
        """Get the AstrBot database."""
        return self._db

    def register_provider(self, provider: Provider):
        """Register an LLM Provider (Chat_Completion type)."""
        self.provider_manager.provider_insts.append(provider)

    def register_llm_tool(
        self,
        name: str,
        func_args: list,
        desc: str,
        func_obj: Callable[..., Awaitable[Any]],
    ) -> None:
        """[DEPRECATED] Add a tool for function calling (function-calling / tools-use).

        @param name: Function name.
        @param func_args: List of function arguments, in the form [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
        @param desc: Function description.
        @param func_obj: Async handler function.

        The async handler function receives additional keyword arguments: event: AstrMessageEvent, context: Context.
        """
        md = StarHandlerMetadata(
            event_type=EventType.OnLLMRequestEvent,
            handler_full_name=func_obj.__module__ + "_" + func_obj.__name__,
            handler_name=func_obj.__name__,
            handler_module_path=func_obj.__module__,
            handler=func_obj,
            event_filters=[],
            desc=desc,
        )
        star_handlers_registry.append(md)
        self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj)

    def unregister_llm_tool(self, name: str) -> None:
        """[DEPRECATED] Remove a function-calling tool. To enable it again, it must be re-registered."""
        self.provider_manager.llm_tools.remove_func(name)

    def register_commands(
        self,
        star_name: str,
        command_name: str,
        desc: str,
        priority: int,
        awaitable: Callable[..., Awaitable[Any]],
        use_regex=False,
        ignore_prefix=False,
    ):
        """Register a command.

        [Deprecated] Registering commands with decorators is recommended. This method will be removed in a future version.

        @param star_name: Plugin (Star) name.
        @param command_name: Command name.
        @param desc: Command description.
        @param priority: Priority, 1-10.
        @param awaitable: Async handler function.
        @param use_regex: If true, command_name is matched as a regular expression (RegexFilter) instead of as a command (CommandFilter).
        """
        md = StarHandlerMetadata(
            event_type=EventType.AdapterMessageEvent,
            handler_full_name=awaitable.__module__ + "_" + awaitable.__name__,
            handler_name=awaitable.__name__,
            handler_module_path=awaitable.__module__,
            handler=awaitable,
            event_filters=[],
            desc=desc,
        )
        if use_regex:
            md.event_filters.append(RegexFilter(regex=command_name))
        else:
            md.event_filters.append(
                CommandFilter(command_name=command_name, handler_md=md),
            )
        star_handlers_registry.append(md)

    def register_task(self, task: Awaitable, desc: str):
        """[DEPRECATED] Register an async task."""
        self._register_tasks.append(task)
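
The module-path derivation performed inline by `add_llm_tools` can be easier to follow in isolation. The sketch below is not part of the file: `derive_plugin_module_path` is a hypothetical helper name, and the dotted module names mentioned afterwards are made-up examples. It only mirrors the loop in `add_llm_tools`, which keeps every segment of `tool.__module__` up to and including the one that follows `packages` or `plugins`.

```python
def derive_plugin_module_path(module_name: str) -> str:
    """Hypothetical helper mirroring the path logic in Context.add_llm_tools.

    Keeps each dotted segment until a "packages" or "plugins" segment is
    seen, then also keeps the segment right after it (the plugin name)
    and stops. If neither flag appears, the full module name is returned.
    """
    flags = ("packages", "plugins")
    segments = module_name.split(".")
    parts: list[str] = []
    for i, part in enumerate(segments):
        parts.append(part)
        # Only stop when there is a following segment to take as the plugin name.
        if part in flags and i + 1 < len(segments):
            parts.append(segments[i + 1])
            break
    return ".".join(parts)
```

Under these assumptions, a tool defined in a module named `data.plugins.my_plugin.tools.search` would be attributed to `data.plugins.my_plugin`, while a module name that contains neither flag is returned unchanged.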