fix: 修复插件指令注解为联合类型时处理异常的问题 (#2925)

* fix: 修复插件指令注解为联合类型时处理异常的问题

* fix: 修复参数类型检查以支持 typing.Union

* Update astrbot/core/star/filter/command.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update astrbot/core/star/filter/command.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix: 修复参数类型检查以支持 typing.Union 的处理逻辑

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Soulter
2025-10-01 21:46:49 +08:00
committed by GitHub
parent afe007ca0b
commit f8a4b54165
+32 -2
View File
@@ -1,5 +1,7 @@
import re
import inspect
import types
import typing
from typing import List, Any, Type, Dict
from . import HandlerFilter
from astrbot.core.platform.astr_message_event import AstrMessageEvent
@@ -14,6 +16,18 @@ class GreedyStr(str):
pass
def unwrap_optional(annotation) -> tuple:
"""去掉 Optional[T] / Union[T, None] / T|None,返回 T"""
args = typing.get_args(annotation)
non_none_args = [a for a in args if a is not type(None)]
if len(non_none_args) == 1:
return (non_none_args[0],)
elif len(non_none_args) > 1:
return tuple(non_none_args)
else:
return ()
# 标准指令受到 wake_prefix 的制约。
class CommandFilter(HandlerFilter):
"""标准指令过滤器"""
@@ -40,6 +54,8 @@ class CommandFilter(HandlerFilter):
for k, v in self.handler_params.items():
if isinstance(v, type):
result += f"{k}({v.__name__}),"
elif isinstance(v, types.UnionType) or typing.get_origin(v) is typing.Union:
result += f"{k}({v}),"
else:
result += f"{k}({type(v).__name__})={v},"
result = result.rstrip(",")
@@ -95,7 +111,8 @@ class CommandFilter(HandlerFilter):
# 没有 GreedyStr 的情况
if i >= len(params):
if (
isinstance(param_type_or_default_val, Type)
isinstance(param_type_or_default_val, (Type, types.UnionType))
or typing.get_origin(param_type_or_default_val) is typing.Union
or param_type_or_default_val is inspect.Parameter.empty
):
# 是类型
@@ -132,7 +149,20 @@ class CommandFilter(HandlerFilter):
elif isinstance(param_type_or_default_val, float):
result[param_name] = float(params[i])
else:
result[param_name] = param_type_or_default_val(params[i])
origin = typing.get_origin(param_type_or_default_val)
if origin in (typing.Union, types.UnionType):
# 注解是联合类型
# NOTE: 目前没有处理联合类型嵌套相关的注解写法
nn_types = unwrap_optional(param_type_or_default_val)
if len(nn_types) == 1:
# 只有一个非 NoneType 类型
result[param_name] = nn_types[0](params[i])
else:
# 没有或者有多个非 NoneType 类型,这里我们暂时直接赋值为原始值。
# NOTE: 目前还没有做类型校验
result[param_name] = params[i]
else:
result[param_name] = param_type_or_default_val(params[i])
except ValueError:
raise ValueError(
f"参数 {param_name} 类型错误。完整参数: {self.print_types()}"