From f8a4b54165d3df02332f0093684897e1cb6e4359 Mon Sep 17 00:00:00 2001 From: Soulter <37870767+Soulter@users.noreply.github.com> Date: Wed, 1 Oct 2025 21:46:49 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E6=8F=92=E4=BB=B6?= =?UTF-8?q?=E6=8C=87=E4=BB=A4=E6=B3=A8=E8=A7=A3=E4=B8=BA=E8=81=94=E5=90=88?= =?UTF-8?q?=E7=B1=BB=E5=9E=8B=E6=97=B6=E5=A4=84=E7=90=86=E5=BC=82=E5=B8=B8?= =?UTF-8?q?=E7=9A=84=E9=97=AE=E9=A2=98=20(#2925)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: 修复插件指令注解为联合类型时处理异常的问题 * fix: 修复参数类型检查以支持 typing.Union * Update astrbot/core/star/filter/command.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update astrbot/core/star/filter/command.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix: 修复参数类型检查以支持 typing.Union 的处理逻辑 --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- astrbot/core/star/filter/command.py | 34 +++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/astrbot/core/star/filter/command.py b/astrbot/core/star/filter/command.py index d8a1eb22e..3d67cb750 100755 --- a/astrbot/core/star/filter/command.py +++ b/astrbot/core/star/filter/command.py @@ -1,5 +1,7 @@ import re import inspect +import types +import typing from typing import List, Any, Type, Dict from . import HandlerFilter from astrbot.core.platform.astr_message_event import AstrMessageEvent @@ -14,6 +16,18 @@ class GreedyStr(str): pass +def unwrap_optional(annotation) -> tuple: + """去掉 Optional[T] / Union[T, None] / T|None,返回 T""" + args = typing.get_args(annotation) + non_none_args = [a for a in args if a is not type(None)] + if len(non_none_args) == 1: + return (non_none_args[0],) + elif len(non_none_args) > 1: + return tuple(non_none_args) + else: + return () + + # 标准指令受到 wake_prefix 的制约。 class CommandFilter(HandlerFilter): """标准指令过滤器""" @@ -40,6 +54,8 @@ class CommandFilter(HandlerFilter): for k, v in self.handler_params.items(): if isinstance(v, type): result += f"{k}({v.__name__})," + elif isinstance(v, types.UnionType) or typing.get_origin(v) is typing.Union: + result += f"{k}({v})," else: result += f"{k}({type(v).__name__})={v}," result = result.rstrip(",") @@ -95,7 +111,8 @@ class CommandFilter(HandlerFilter): # 没有 GreedyStr 的情况 if i >= len(params): if ( - isinstance(param_type_or_default_val, Type) + isinstance(param_type_or_default_val, (Type, types.UnionType)) + or typing.get_origin(param_type_or_default_val) is typing.Union or param_type_or_default_val is inspect.Parameter.empty ): # 是类型 @@ -132,7 +149,20 @@ class CommandFilter(HandlerFilter): elif isinstance(param_type_or_default_val, float): result[param_name] = float(params[i]) else: - result[param_name] = param_type_or_default_val(params[i]) + origin = typing.get_origin(param_type_or_default_val) + if origin in (typing.Union, types.UnionType): + # 注解是联合类型 + # NOTE: 目前没有处理联合类型嵌套相关的注解写法 + nn_types = unwrap_optional(param_type_or_default_val) + if len(nn_types) == 1: + # 只有一个非 NoneType 类型 + result[param_name] = nn_types[0](params[i]) + else: + # 没有或者有多个非 NoneType 类型,这里我们暂时直接赋值为原始值。 + # NOTE: 目前还没有做类型校验 + result[param_name] = params[i] + else: + result[param_name] = param_type_or_default_val(params[i]) except ValueError: raise ValueError( f"参数 {param_name} 类型错误。完整参数: {self.print_types()}"