Files
AstrBot/astrbot/core/star/register/star_handler.py
T
Dt8333 f624971613 chore: fix bunches of type checking errors (#3213)
* chore(core.utils): 🚨 修正错误Lint

* chore(core.provider): 🚨 修复基类错误Lint

* chore(core.utils): 补全session_get()的重载

* chore(core.provider): 🚨 修正实现错误Lint

* chore(core.platform): 🚨 修正platform基类和webchat的错误Lint

* chore(core.platform): 修正错误实现Lint

* fix(core.provider): 修复循环调用和错误assert

* chore(core.platform): 修复部分实现Lint

* chore(core.provider): 补充Dify.text_chat_stream的参数类型

* chore(core.pipeline): 🚨 修复错误Lint

* fix(core.slack): 补充遗漏导入

* chore(core.utils): 修复错误的session_get声明

* chore(core.platform): 移除Lark adapter import中的wildcard

* chore(core.db): 修复声明和部分逻辑

* chore(core.db): 添加typings,使faiss参数能被正确识别。

* chore(core): 修复声明

* chore(core): 修改声明

* chore: 补充faiss声明

* chore(dashboard): 修改实现,减少报错

* chore(package): 修改部分声明与实现,减少报错

* chore(core): 添加Handler的overload,以去除部分assert同时通过类型检查

* chore(core.pipeline): 修改Pipeline Scheduler的execute,将判断属性改为判断类型,通过静态类型检查

* chore(core.config): 添加类型标注,通过类型检查

* chore(core.message): 为File._download_file添加检查,通过类型检查

* fix: 将断言改为条件判断以实现优雅关闭的容错性

* refactor: 移除 discord 客户端中的 assert,改用 if None 判断并抛出异常

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: DiscordPlatformAdapter 对 self.client.user 为 None 做日志并返回,移除断言

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 增强 Lark 相关空值/异常检查并完善日志输出

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 将断言替换为条件检查并加入日志与错误处理

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* chore: 移除LLM生成的无用注释

* refactor: 使用 File.get_file 替换下载逻辑并移除 assert,提供默认 filename

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: Slack Socket 未初始化抛出运行时异常,图片 URL 判空改为非空判断

* refactor: 将 WeChatPadProAdapter 的断言改为空值判断并添加日志

* refactor: 使用 isinstance 替代断言实现类型判断,便于静态检查

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 去除cast,直接使用字段与字典访问,修正端口解析

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 使用 match-case 重构 ProviderManager 加载并通过类型检查抛出 TypeError

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: group_name_display 时若 group 对象为空则记录错误并返回

* fix: 将 _get_current_persona_id 的 assert 替换成 if guard 并返回 None

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 优化插件目录存在性检查及图片URL非空验证,更新JSON排序配置

* fix: 将 datetime_str 的 assert 替换为显式检查并抛出异常

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 移除 cast,改为运行时检查并在找不到调度器时跳过

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 移除 cast,改用 isinstance 检查 FaissVecDB 并警告

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 删除 typing.cast 导入,并在获取文件绝对路径前校验 file_

* refactor: 移除 typing.cast,简化内容安全检查调用

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 将 PlatformMetadata.id 设为必填并在注册时传入 id,移除 cast

* refactor: 移除 cast,改用 HasInitialize 与 isinstance 进行初始化

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 为 ProviderManager.initialize 增加ID类型判断,避免 None 导致 get 失败

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 为 OTTSProvider 与 AzureNativeProvider 引入 _client 与 client 属性改进上下文管理

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 为 Whisper 自托管源添加模型未初始化校验并直接调用 transcribe

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 移除未使用的 cast 导入并简化 platform_name 赋值

* refactor: 引入 cast 并对 id 使用 cast(str, ...) 提升类型安全

* fix: 将 _id_to_sid 返回改为 str,空值返回空串;对 id 与 message_id 使用 cast

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 重构 Discord 处理逻辑:强制 类型转换、优先斜杠指令并优化提及判断

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 统一对 id 获取执行 cast,并在微信消息解析失败时抛错

* Revert "fix: 去除cast,直接使用字段与字典访问,修正端口解析"

This reverts commit 1cbfdf9d1b.

* fix: 百炼 Rerank 会话关闭时返回空结果;初始化 request.prompt 避免空值拼接

* fix: 统一处理搜索结果链接为字符串,新增 _get_url 助手并适配 Bing/Sogo

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 调整 call_handler 泛型、Discord 通道注解及 FishAudioTTS API 请求类型

* refactor: 使用 col(...) 替代列引用并对结果进行 CursorResult 强转

* chore: ruff format

---------

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
Co-authored-by: Soulter <905617992@qq.com>
2025-12-09 14:13:47 +08:00

558 lines
18 KiB
Python

from __future__ import annotations
import re
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any
import docstring_parser
from astrbot.core import logger
from astrbot.core.agent.agent import Agent
from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.tool import FunctionTool
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES
from astrbot.core.provider.register import llm_tools
from ..filter.command import CommandFilter
from ..filter.command_group import CommandGroupFilter
from ..filter.custom_filter import CustomFilterAnd, CustomFilterOr
from ..filter.event_message_type import EventMessageType, EventMessageTypeFilter
from ..filter.permission import PermissionType, PermissionTypeFilter
from ..filter.platform_adapter_type import (
PlatformAdapterType,
PlatformAdapterTypeFilter,
)
from ..filter.regex import RegexFilter
from ..star_handler import EventType, StarHandlerMetadata, star_handlers_registry
def get_handler_full_name(
awaitable: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
) -> str:
"""获取 Handler 的全名"""
return f"{awaitable.__module__}_{awaitable.__name__}"
def get_handler_or_create(
handler: Callable[
...,
Awaitable[MessageEventResult | str | None]
| AsyncGenerator[MessageEventResult | str | None],
],
event_type: EventType,
dont_add=False,
**kwargs,
) -> StarHandlerMetadata:
"""获取 Handler 或者创建一个新的 Handler"""
handler_full_name = get_handler_full_name(handler)
md = star_handlers_registry.get_handler_by_full_name(handler_full_name)
if md:
return md
md = StarHandlerMetadata(
event_type=event_type,
handler_full_name=handler_full_name,
handler_name=handler.__name__,
handler_module_path=handler.__module__,
handler=handler,
event_filters=[],
)
# 插件handler的附加额外信息
if handler.__doc__:
md.desc = handler.__doc__.strip()
if "desc" in kwargs:
md.desc = kwargs["desc"]
del kwargs["desc"]
md.extras_configs = kwargs
if not dont_add:
star_handlers_registry.append(md)
return md
def register_command(
command_name: str | None = None,
sub_command: str | None = None,
alias: set | None = None,
**kwargs,
):
"""注册一个 Command."""
new_command = None
add_to_event_filters = False
if isinstance(command_name, RegisteringCommandable):
# 子指令
if sub_command is not None:
parent_command_names = (
command_name.parent_group.get_complete_command_names()
)
new_command = CommandFilter(
sub_command,
alias,
None,
parent_command_names=parent_command_names,
)
command_name.parent_group.add_sub_command_filter(new_command)
else:
logger.warning(
f"注册指令{command_name} 的子指令时未提供 sub_command 参数。",
)
# 裸指令
elif command_name is None:
logger.warning("注册裸指令时未提供 command_name 参数。")
else:
new_command = CommandFilter(command_name, alias, None)
add_to_event_filters = True
def decorator(awaitable):
if not add_to_event_filters:
kwargs["sub_command"] = (
True # 打一个标记,表示这是一个子指令,再 wakingstage 阶段这个 handler 将会直接被跳过(其父指令会接管)
)
handler_md = get_handler_or_create(
awaitable,
EventType.AdapterMessageEvent,
**kwargs,
)
if new_command:
new_command.init_handler_md(handler_md)
handler_md.event_filters.append(new_command)
return awaitable
return decorator
def register_custom_filter(custom_type_filter, *args, **kwargs):
"""注册一个自定义的 CustomFilter
Args:
custom_type_filter: 在裸指令时为CustomFilter对象
在指令组时为父指令的RegisteringCommandable对象,即self或者command_group的返回
raise_error: 如果没有权限,是否抛出错误到消息平台,并且停止事件传播。默认为 True
"""
add_to_event_filters = False
raise_error = True
# 判断是否是指令组,指令组则添加到指令组的CommandGroupFilter对象中在waking_check的时候一起判断
if isinstance(custom_type_filter, RegisteringCommandable):
# 子指令, 此时函数为RegisteringCommandable对象的方法,首位参数为RegisteringCommandable对象的self。
parent_register_commandable = custom_type_filter
custom_filter = args[0]
if len(args) > 1:
raise_error = args[1]
else:
# 裸指令
add_to_event_filters = True
custom_filter = custom_type_filter
if args:
raise_error = args[0]
if not isinstance(custom_filter, (CustomFilterAnd, CustomFilterOr)):
custom_filter = custom_filter(raise_error)
def decorator(awaitable):
# 裸指令,子指令与指令组的区分,指令组会因为标记跳过wake。
if (
not add_to_event_filters and isinstance(awaitable, RegisteringCommandable)
) or (add_to_event_filters and isinstance(awaitable, RegisteringCommandable)):
# 指令组 与 根指令组,添加到本层的grouphandle中一起判断
awaitable.parent_group.add_custom_filter(custom_filter)
else:
handler_md = get_handler_or_create(
awaitable,
EventType.AdapterMessageEvent,
**kwargs,
)
if not add_to_event_filters and not isinstance(
awaitable,
RegisteringCommandable,
):
# 底层子指令
handle_full_name = get_handler_full_name(awaitable)
for (
sub_handle
) in parent_register_commandable.parent_group.sub_command_filters:
if isinstance(sub_handle, CommandGroupFilter):
continue
# 所有符合fullname一致的子指令handle添加自定义过滤器。
# 不确定是否会有多个子指令有一样的fullname,比如一个方法添加多个command装饰器?
sub_handle_md = sub_handle.get_handler_md()
if (
sub_handle_md
and sub_handle_md.handler_full_name == handle_full_name
):
sub_handle.add_custom_filter(custom_filter)
else:
# 裸指令
# 确保运行时是可调用的 handler,针对类型检查器添加忽略
assert isinstance(awaitable, Callable)
handler_md = get_handler_or_create(
awaitable,
EventType.AdapterMessageEvent,
**kwargs,
)
handler_md.event_filters.append(custom_filter)
return awaitable
return decorator
def register_command_group(
command_group_name: str | None = None,
sub_command: str | None = None,
alias: set | None = None,
**kwargs,
):
"""注册一个 CommandGroup"""
new_group = None
if isinstance(command_group_name, RegisteringCommandable):
# 子指令组
if sub_command is None:
logger.warning(f"{command_group_name} 指令组的子指令组 sub_command 未指定")
else:
new_group = CommandGroupFilter(
sub_command,
alias,
parent_group=command_group_name.parent_group,
)
command_group_name.parent_group.add_sub_command_filter(new_group)
# 根指令组
elif command_group_name is None:
logger.warning("根指令组的名称未指定")
else:
new_group = CommandGroupFilter(command_group_name, alias)
def decorator(obj):
if new_group:
handler_md = get_handler_or_create(
obj,
EventType.AdapterMessageEvent,
**kwargs,
)
handler_md.event_filters.append(new_group)
return RegisteringCommandable(new_group)
raise ValueError("注册指令组失败。")
return decorator
class RegisteringCommandable:
"""用于指令组级联注册"""
group: Callable[..., Callable[..., RegisteringCommandable]] = register_command_group
command: Callable[..., Callable[..., None]] = register_command
custom_filter: Callable[..., Callable[..., Any]] = register_custom_filter
def __init__(self, parent_group: CommandGroupFilter):
self.parent_group = parent_group
def register_event_message_type(event_message_type: EventMessageType, **kwargs):
"""注册一个 EventMessageType"""
def decorator(awaitable):
handler_md = get_handler_or_create(
awaitable,
EventType.AdapterMessageEvent,
**kwargs,
)
handler_md.event_filters.append(EventMessageTypeFilter(event_message_type))
return awaitable
return decorator
def register_platform_adapter_type(
platform_adapter_type: PlatformAdapterType,
**kwargs,
):
"""注册一个 PlatformAdapterType"""
def decorator(awaitable):
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent)
handler_md.event_filters.append(
PlatformAdapterTypeFilter(platform_adapter_type),
)
return awaitable
return decorator
def register_regex(regex: str, **kwargs):
"""注册一个 Regex"""
def decorator(awaitable):
handler_md = get_handler_or_create(
awaitable,
EventType.AdapterMessageEvent,
**kwargs,
)
handler_md.event_filters.append(RegexFilter(regex))
return awaitable
return decorator
def register_permission_type(permission_type: PermissionType, raise_error: bool = True):
"""注册一个 PermissionType
Args:
permission_type: PermissionType
raise_error: 如果没有权限,是否抛出错误到消息平台,并且停止事件传播。默认为 True
"""
def decorator(awaitable):
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent)
handler_md.event_filters.append(
PermissionTypeFilter(permission_type, raise_error),
)
return awaitable
return decorator
def register_on_astrbot_loaded(**kwargs):
"""当 AstrBot 加载完成时"""
def decorator(awaitable):
_ = get_handler_or_create(awaitable, EventType.OnAstrBotLoadedEvent, **kwargs)
return awaitable
return decorator
def register_on_platform_loaded(**kwargs):
"""当平台加载完成时"""
def decorator(awaitable):
_ = get_handler_or_create(awaitable, EventType.OnPlatformLoadedEvent, **kwargs)
return awaitable
return decorator
def register_on_llm_request(**kwargs):
"""当有 LLM 请求时的事件
Examples:
```py
from astrbot.api.provider import ProviderRequest
@on_llm_request()
async def test(self, event: AstrMessageEvent, request: ProviderRequest) -> None:
request.system_prompt += "你是一个猫娘..."
```
请务必接收两个参数:event, request
"""
def decorator(awaitable):
_ = get_handler_or_create(awaitable, EventType.OnLLMRequestEvent, **kwargs)
return awaitable
return decorator
def register_on_llm_response(**kwargs):
"""当有 LLM 请求后的事件
Examples:
```py
from astrbot.api.provider import LLMResponse
@on_llm_response()
async def test(self, event: AstrMessageEvent, response: LLMResponse) -> None:
...
```
请务必接收两个参数:event, request
"""
def decorator(awaitable):
_ = get_handler_or_create(awaitable, EventType.OnLLMResponseEvent, **kwargs)
return awaitable
return decorator
def register_llm_tool(name: str | None = None, **kwargs):
"""为函数调用(function-calling / tools-use)添加工具。
请务必按照以下格式编写一个工具(包括函数注释,AstrBot 会尝试解析该函数注释)
```
@llm_tool(name="get_weather") # 如果 name 不填,将使用函数名
async def get_weather(event: AstrMessageEvent, location: str):
\'\'\'获取天气信息。
Args:
location(string): 地点
\'\'\'
# 处理逻辑
```
可接受的参数类型有:string, number, object, array, boolean。
返回值:
- 返回 str:结果会被加入下一次 LLM 请求的 prompt 中,用于让 LLM 总结工具返回的结果
- 返回 None:结果不会被加入下一次 LLM 请求的 prompt 中。
可以使用 yield 发送消息、终止事件。
发送消息:请参考文档。
终止事件:
```
event.stop_event()
yield
```
"""
name_ = name
registering_agent = None
if kwargs.get("registering_agent"):
registering_agent = kwargs["registering_agent"]
def decorator(
awaitable: Callable[
...,
AsyncGenerator[MessageEventResult | str | None]
| Awaitable[MessageEventResult | str | None],
],
):
llm_tool_name = name_ if name_ else awaitable.__name__
func_doc = awaitable.__doc__ or ""
docstring = docstring_parser.parse(func_doc)
args = []
for arg in docstring.params:
sub_type_name = None
type_name = arg.type_name
if not type_name:
raise ValueError(
f"LLM 函数工具 {awaitable.__module__}_{llm_tool_name} 的参数 {arg.arg_name} 缺少类型注释。",
)
# parse type_name to handle cases like "list[string]"
match = re.match(r"(\w+)\[(\w+)\]", type_name)
if match:
type_name = match.group(1)
sub_type_name = match.group(2)
type_name = PY_TO_JSON_TYPE.get(type_name, type_name)
if sub_type_name:
sub_type_name = PY_TO_JSON_TYPE.get(sub_type_name, sub_type_name)
if type_name not in SUPPORTED_TYPES or (
sub_type_name and sub_type_name not in SUPPORTED_TYPES
):
raise ValueError(
f"LLM 函数工具 {awaitable.__module__}_{llm_tool_name} 不支持的参数类型:{arg.type_name}",
)
arg_json_schema = {
"type": type_name,
"name": arg.arg_name,
"description": arg.description,
}
if sub_type_name:
if type_name == "array":
arg_json_schema["items"] = {"type": sub_type_name}
args.append(arg_json_schema)
if not registering_agent:
doc_desc = docstring.description.strip() if docstring.description else ""
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
llm_tools.add_func(llm_tool_name, args, doc_desc, md.handler)
else:
assert isinstance(registering_agent, RegisteringAgent)
# print(f"Registering tool {llm_tool_name} for agent", registering_agent._agent.name)
if registering_agent._agent.tools is None:
registering_agent._agent.tools = []
desc = docstring.description.strip() if docstring.description else ""
tool = llm_tools.spec_to_func(llm_tool_name, args, desc, awaitable)
registering_agent._agent.tools.append(tool)
return awaitable
return decorator
class RegisteringAgent:
"""用于 Agent 注册"""
def llm_tool(self, *args, **kwargs):
kwargs["registering_agent"] = self
return register_llm_tool(*args, **kwargs)
def __init__(self, agent: Agent[AstrAgentContext]):
self._agent = agent
def register_agent(
name: str,
instruction: str,
tools: list[str | FunctionTool] | None = None,
run_hooks: BaseAgentRunHooks[AstrAgentContext] | None = None,
):
"""注册一个 Agent
Args:
name: Agent 的名称
instruction: Agent 的指令
tools: Agent 使用的工具列表
run_hooks: Agent 运行时的钩子函数
"""
tools_ = tools or []
def decorator(awaitable: Callable[..., Awaitable[Any]]):
AstrAgent = Agent[AstrAgentContext]
agent = AstrAgent(
name=name,
instructions=instruction,
tools=tools_,
run_hooks=run_hooks or BaseAgentRunHooks[AstrAgentContext](),
)
handoff_tool = HandoffTool(agent=agent)
handoff_tool.handler = awaitable
llm_tools.func_list.append(handoff_tool)
return RegisteringAgent(agent)
return decorator
def register_on_decorating_result(**kwargs):
"""在发送消息前的事件"""
def decorator(awaitable):
_ = get_handler_or_create(
awaitable,
EventType.OnDecoratingResultEvent,
**kwargs,
)
return awaitable
return decorator
def register_after_message_sent(**kwargs):
"""在消息发送后的事件"""
def decorator(awaitable):
_ = get_handler_or_create(
awaitable,
EventType.OnAfterMessageSentEvent,
**kwargs,
)
return awaitable
return decorator