fix: 修复 astrbot.core.star 等包下的 type checking error (#2787)

* fix: 修复 astrbot.core.star 等包下的 type checking error

* refactor: improve type checking and annotations

* chore: ruff format
This commit is contained in:
Soulter
2025-09-21 18:10:04 +08:00
committed by GitHub
parent a0ce1855ab
commit 80a86f5b1b
23 changed files with 261 additions and 130 deletions
+1 -1
View File
@@ -9,5 +9,5 @@ from .hooks import BaseAgentRunHooks
class Agent(Generic[TContext]):
name: str
instructions: str | None = None
tools: list[str, FunctionTool] | None = None
tools: list[str | FunctionTool] | None = None
run_hooks: BaseAgentRunHooks[TContext] | None = None
+3 -1
View File
@@ -92,7 +92,7 @@ class MCPClient:
self.session: Optional[mcp.ClientSession] = None
self.exit_stack = AsyncExitStack()
self.name = None
self.name: str | None = None
self.active: bool = True
self.tools: list[mcp.Tool] = []
self.server_errlogs: list[str] = []
@@ -198,6 +198,8 @@ class MCPClient:
async def list_tools_and_save(self) -> mcp.ListToolsResult:
"""List all tools from the server and save them to self.tools"""
if not self.session:
raise Exception("MCP Client is not initialized")
response = await self.session.list_tools()
self.tools = response.tools
return response
+28 -17
View File
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from deprecated import deprecated
from typing import Awaitable, Literal, Any, Optional
from typing import Awaitable, Callable, Literal, Any, Optional
from .mcp_client import MCPClient
@@ -8,10 +8,10 @@ from .mcp_client import MCPClient
class FunctionTool:
"""A class representing a function tool that can be used in function calling."""
name: str | None = None
name: str
parameters: dict | None = None
description: str | None = None
handler: Awaitable | None = None
handler: Callable[..., Awaitable[Any]] | None = None
"""处理函数, 当 origin 为 mcp 时,这个为空"""
handler_module_path: str | None = None
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
@@ -51,7 +51,7 @@ class ToolSet:
This class provides methods to add, remove, and retrieve tools, as well as
convert the tools to different API formats (OpenAI, Anthropic, Google GenAI)."""
def __init__(self, tools: list[FunctionTool] = None):
def __init__(self, tools: list[FunctionTool] | None = None):
self.tools: list[FunctionTool] = tools or []
def empty(self) -> bool:
@@ -79,7 +79,13 @@ class ToolSet:
return None
@deprecated(reason="Use add_tool() instead", version="4.0.0")
def add_func(self, name: str, func_args: list, desc: str, handler: Awaitable):
def add_func(
self,
name: str,
func_args: list,
desc: str,
handler: Callable[..., Awaitable[Any]],
):
"""Add a function tool to the set."""
params = {
"type": "object", # hard-coded here
@@ -104,7 +110,7 @@ class ToolSet:
self.remove_tool(name)
@deprecated(reason="Use get_tool() instead", version="4.0.0")
def get_func(self, name: str) -> list[FunctionTool]:
def get_func(self, name: str) -> FunctionTool | None:
"""Get all function tools."""
return self.get_tool(name)
@@ -125,7 +131,11 @@ class ToolSet:
},
}
if tool.parameters.get("properties") or not omit_empty_parameter_field:
if (
tool.parameters
and tool.parameters.get("properties")
or not omit_empty_parameter_field
):
func_def["function"]["parameters"] = tool.parameters
result.append(func_def)
@@ -135,14 +145,14 @@ class ToolSet:
"""Convert tools to Anthropic API format."""
result = []
for tool in self.tools:
input_schema = {"type": "object"}
if tool.parameters:
input_schema["properties"] = tool.parameters.get("properties", {})
input_schema["required"] = tool.parameters.get("required", [])
tool_def = {
"name": tool.name,
"description": tool.description,
"input_schema": {
"type": "object",
"properties": tool.parameters.get("properties", {}),
"required": tool.parameters.get("required", []),
},
"input_schema": input_schema,
}
result.append(tool_def)
return result
@@ -210,14 +220,15 @@ class ToolSet:
return result
tools = [
{
tools = []
for tool in self.tools:
d = {
"name": tool.name,
"description": tool.description,
"parameters": convert_schema(tool.parameters),
}
for tool in self.tools
]
if tool.parameters:
d["parameters"] = convert_schema(tool.parameters)
tools.append(d)
declarations = {}
if tools:
@@ -19,7 +19,7 @@ class ContentSafetyCheckStage(Stage):
self.strategy_selector = StrategySelector(config)
async def process(
self, event: AstrMessageEvent, check_text: str = None
self, event: AstrMessageEvent, check_text: str | None = None
) -> Union[None, AsyncGenerator[None, None]]:
"""检查内容安全"""
text = check_text if check_text else event.get_message_str()
@@ -13,7 +13,7 @@ class BaiduAipStrategy(ContentSafetyStrategy):
self.secret_key = sk
self.client = AipContentCensor(self.app_id, self.api_key, self.secret_key)
def check(self, content: str):
def check(self, content: str) -> tuple[bool, str]:
res = self.client.textCensorUserDefined(content)
if "conclusionType" not in res:
return False, ""
@@ -16,7 +16,7 @@ class KeywordsStrategy(ContentSafetyStrategy):
# json.loads(base64.b64decode(f.read()).decode("utf-8"))["keywords"]
# )
def check(self, content: str) -> bool:
def check(self, content: str) -> tuple[bool, str]:
for keyword in self.keywords:
if re.search(keyword, content):
return False, "内容安全检查不通过,匹配到敏感词。"
+4 -1
View File
@@ -10,7 +10,7 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
async def call_handler(
event: AstrMessageEvent,
handler: T.Awaitable,
handler: T.Callable[..., T.Awaitable[T.Any]],
*args,
**kwargs,
) -> T.AsyncGenerator[T.Any, None]:
@@ -36,6 +36,9 @@ async def call_handler(
except TypeError:
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
if not ready_to_call:
return
if inspect.isasyncgen(ready_to_call):
_has_yielded = False
try:
@@ -7,6 +7,7 @@ import copy
import json
import traceback
from typing import AsyncGenerator, Union
from astrbot.core.conversation_mgr import Conversation
from astrbot.core import logger
from astrbot.core.message.components import Image
from astrbot.core.message.message_event_result import (
@@ -133,6 +134,15 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
if agent_runner.done():
llm_response = agent_runner.get_final_llm_resp()
if not llm_response:
text_content = mcp.types.TextContent(
type="text",
text=f"error when deligate task to {tool.agent.name}",
)
yield mcp.types.CallToolResult(content=[text_content])
return
logger.debug(
f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}"
)
@@ -148,7 +158,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
)
yield mcp.types.CallToolResult(content=[text_content])
else:
yield mcp.types.TextContent(
text_content = mcp.types.TextContent(
type="text",
text=f"error when deligate task to {tool.agent.name}",
)
@@ -200,7 +210,11 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
):
if not tool.mcp_client:
raise ValueError("MCP client is not available for MCP function tools.")
res = await tool.mcp_client.session.call_tool(
session = tool.mcp_client.session
if not session:
raise ValueError("MCP session is not available for MCP function tools.")
res = await session.call_tool(
name=tool.name,
arguments=tool_args,
)
@@ -325,7 +339,7 @@ class LLMRequestSubStage(Stage):
return _ctx.get_using_provider(umo=event.unified_msg_origin)
async def _get_session_conv(self, event: AstrMessageEvent):
async def _get_session_conv(self, event: AstrMessageEvent) -> Conversation:
umo = event.unified_msg_origin
conv_mgr = self.conv_manager
@@ -337,6 +351,8 @@ class LLMRequestSubStage(Stage):
if not conversation:
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
conversation = await conv_mgr.get_conversation(umo, cid)
if not conversation:
raise RuntimeError("无法创建新的对话。")
return conversation
async def process(
@@ -444,7 +460,10 @@ class LLMRequestSubStage(Stage):
if event.plugins_name is not None and req.func_tool:
new_tool_set = ToolSet()
for tool in req.func_tool.tools:
plugin = star_map.get(tool.handler_module_path)
mp = tool.handler_module_path
if not mp:
continue
plugin = star_map.get(mp)
if not plugin:
continue
if plugin.name in event.plugins_name or plugin.reserved:
@@ -34,12 +34,14 @@ class StarRequestSubStage(Stage):
for handler in activated_handlers:
params = handlers_parsed_params.get(handler.handler_full_name, {})
try:
if handler.handler_module_path not in star_map:
continue
logger.debug(
f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}"
md = star_map.get(handler.handler_module_path)
if not md:
logger.warning(
f"Cannot find plugin for given handler module path: {handler.handler_module_path}"
)
continue
logger.debug(f"plugin -> {md.name} - {handler.handler_name}")
try:
wrapper = call_handler(event, handler.handler, **params)
async for ret in wrapper:
yield ret
@@ -49,7 +51,7 @@ class StarRequestSubStage(Stage):
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
if event.is_at_or_wake_command:
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
ret = f":(\n\n在调用插件 {md.name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
event.set_result(MessageEventResult().message(ret))
yield
event.clear_result()
+13 -3
View File
@@ -65,13 +65,16 @@ class AssistantMessageSegment:
role: str = "assistant"
def to_dict(self):
ret = {
ret: dict[str, str | list[dict]] = {
"role": self.role,
}
if self.content:
ret["content"] = self.content
if self.tool_calls:
ret["tool_calls"] = self.tool_calls
tool_calls_dict = [
tc if isinstance(tc, dict) else tc.to_dict() for tc in self.tool_calls
]
ret["tool_calls"] = tool_calls_dict
return ret
@@ -117,7 +120,14 @@ class ProviderRequest:
"""模型名称,为 None 时使用提供商的默认模型"""
def __repr__(self):
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})"
return (
f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, "
f"image_count={len(self.image_urls or [])}, "
f"func_tool={self.func_tool}, "
f"contexts={self._print_friendly_context()}, "
f"system_prompt={self.system_prompt}, "
f"conversation_id={self.conversation.cid if self.conversation else 'N/A'}, "
)
def __str__(self):
return self.__repr__()
+4 -4
View File
@@ -4,7 +4,7 @@ import os
import asyncio
import aiohttp
from typing import Dict, List, Awaitable
from typing import Dict, List, Awaitable, Callable, Any
from astrbot import logger
from astrbot.core import sp
@@ -109,7 +109,7 @@ class FunctionToolManager:
name: str,
func_args: list,
desc: str,
handler: Awaitable,
handler: Callable[..., Awaitable[Any]],
) -> FuncTool:
params = {
"type": "object", # hard-coded here
@@ -132,7 +132,7 @@ class FunctionToolManager:
name: str,
func_args: list,
desc: str,
handler: Awaitable,
handler: Callable[..., Awaitable[Any]],
) -> None:
"""添加函数调用工具
@@ -220,7 +220,7 @@ class FunctionToolManager:
name: str,
cfg: dict,
event: asyncio.Event,
ready_future: asyncio.Future = None,
ready_future: asyncio.Future | None = None,
) -> None:
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
try:
+35 -19
View File
@@ -38,7 +38,7 @@ class ProviderManager:
"""加载的 Text To Speech Provider 的实例"""
self.embedding_provider_insts: List[EmbeddingProvider] = []
"""加载的 Embedding Provider 的实例"""
self.inst_map: dict[str, Provider] = {}
self.inst_map: dict[str, Provider | STTProvider | TTSProvider] = {}
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
self.llm_tools = llm_tools
@@ -87,19 +87,31 @@ class ProviderManager:
)
return
# 不启用提供商会话隔离模式的情况
self.curr_provider_inst = self.inst_map[provider_id]
if provider_type == ProviderType.TEXT_TO_SPEECH:
prov = self.inst_map[provider_id]
if provider_type == ProviderType.TEXT_TO_SPEECH and isinstance(
prov, TTSProvider
):
self.curr_tts_provider_inst = prov
sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global")
elif provider_type == ProviderType.SPEECH_TO_TEXT:
elif provider_type == ProviderType.SPEECH_TO_TEXT and isinstance(
prov, STTProvider
):
self.curr_stt_provider_inst = prov
sp.put("curr_provider_stt", provider_id, scope="global", scope_id="global")
elif provider_type == ProviderType.CHAT_COMPLETION:
elif provider_type == ProviderType.CHAT_COMPLETION and isinstance(
prov, Provider
):
self.curr_provider_inst = prov
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
async def get_provider_by_id(self, provider_id: str) -> Provider | None:
"""根据提供商 ID 获取提供商实例"""
return self.inst_map.get(provider_id)
def get_using_provider(self, provider_type: ProviderType, umo=None):
def get_using_provider(
self, provider_type: ProviderType, umo=None
) -> Provider | STTProvider | TTSProvider | None:
"""获取正在使用的提供商实例。
Args:
@@ -303,12 +315,14 @@ class ProviderManager:
provider_metadata = provider_cls_map[provider_config["type"]]
try:
# 按任务实例化提供商
cls_type = provider_metadata.cls_type
if not cls_type:
logger.error(f"无法找到 {provider_metadata.type} 的类")
return
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
# STT 任务
inst = provider_metadata.cls_type(
provider_config, self.provider_settings
)
inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
@@ -327,9 +341,7 @@ class ProviderManager:
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
# TTS 任务
inst = provider_metadata.cls_type(
provider_config, self.provider_settings
)
inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
@@ -345,7 +357,7 @@ class ProviderManager:
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
# 文本生成任务
inst = provider_metadata.cls_type(
inst = cls_type(
provider_config,
self.provider_settings,
self.selected_default_persona,
@@ -370,9 +382,7 @@ class ProviderManager:
ProviderType.EMBEDDING,
ProviderType.RERANK,
]:
inst = provider_metadata.cls_type(
provider_config, self.provider_settings
)
inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.embedding_provider_insts.append(inst)
@@ -430,11 +440,17 @@ class ProviderManager:
)
if self.inst_map[provider_id] in self.provider_insts:
self.provider_insts.remove(self.inst_map[provider_id])
prov_inst = self.inst_map[provider_id]
if isinstance(prov_inst, Provider):
self.provider_insts.remove(prov_inst)
if self.inst_map[provider_id] in self.stt_provider_insts:
self.stt_provider_insts.remove(self.inst_map[provider_id])
prov_inst = self.inst_map[provider_id]
if isinstance(prov_inst, STTProvider):
self.stt_provider_insts.remove(prov_inst)
if self.inst_map[provider_id] in self.tts_provider_insts:
self.tts_provider_insts.remove(self.inst_map[provider_id])
prov_inst = self.inst_map[provider_id]
if isinstance(prov_inst, TTSProvider):
self.tts_provider_insts.remove(prov_inst)
if self.inst_map[provider_id] == self.curr_provider_inst:
self.curr_provider_inst = None
+26 -12
View File
@@ -23,7 +23,7 @@ from .star import star_registry, StarMetadata, star_map
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
from .filter.command import CommandFilter
from .filter.regex import RegexFilter
from typing import Awaitable
from typing import Awaitable, Any, Callable
from astrbot.core.conversation_mgr import ConversationManager
from astrbot.core.star.filter.platform_adapter_type import (
PlatformAdapterType,
@@ -105,7 +105,10 @@ class Context:
def get_provider_by_id(self, provider_id: str) -> Provider | None:
"""通过 ID 获取对应的 LLM Provider(Chat_Completion 类型)。"""
return self.provider_manager.inst_map.get(provider_id)
prov = self.provider_manager.inst_map.get(provider_id)
if prov and not isinstance(prov, Provider):
raise ValueError("返回的 Provider 不是 Provider 类型")
return prov
def get_all_providers(self) -> List[Provider]:
"""获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。"""
@@ -130,34 +133,43 @@ class Context:
Args:
umo(str): unified_message_origin 如果传入并且用户启用了提供商会话隔离则使用该会话偏好的提供商
"""
return self.provider_manager.get_using_provider(
prov = self.provider_manager.get_using_provider(
provider_type=ProviderType.CHAT_COMPLETION,
umo=umo,
)
if prov and not isinstance(prov, Provider):
raise ValueError("返回的 Provider 不是 Provider 类型")
return prov
def get_using_tts_provider(self, umo: str | None = None) -> TTSProvider:
def get_using_tts_provider(self, umo: str | None = None) -> TTSProvider | None:
"""
获取当前使用的用于 TTS 任务的 Provider
Args:
umo(str): unified_message_origin 如果传入则使用该会话偏好的提供商
"""
return self.provider_manager.get_using_provider(
prov = self.provider_manager.get_using_provider(
provider_type=ProviderType.TEXT_TO_SPEECH,
umo=umo,
)
if prov and not isinstance(prov, TTSProvider):
raise ValueError("返回的 Provider 不是 TTSProvider 类型")
return prov
def get_using_stt_provider(self, umo: str | None = None) -> STTProvider:
def get_using_stt_provider(self, umo: str | None = None) -> STTProvider | None:
"""
获取当前使用的用于 STT 任务的 Provider
Args:
umo(str): unified_message_origin 如果传入则使用该会话偏好的提供商
"""
return self.provider_manager.get_using_provider(
prov = self.provider_manager.get_using_provider(
provider_type=ProviderType.SPEECH_TO_TEXT,
umo=umo,
)
if prov and not isinstance(prov, STTProvider):
raise ValueError("返回的 Provider 不是 STTProvider 类型")
return prov
def get_config(self, umo: str | None = None) -> AstrBotConfig:
"""获取 AstrBot 的配置。"""
@@ -245,7 +257,11 @@ class Context:
"""
def register_llm_tool(
self, name: str, func_args: list, desc: str, func_obj: Awaitable
self,
name: str,
func_args: list,
desc: str,
func_obj: Callable[..., Awaitable[Any]],
) -> None:
"""
为函数调用function-calling / tools-use添加工具
@@ -267,9 +283,7 @@ class Context:
desc=desc,
)
star_handlers_registry.append(md)
self.provider_manager.llm_tools.add_func(
name, func_args, desc, func_obj, func_obj
)
self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj)
def unregister_llm_tool(self, name: str) -> None:
"""删除一个函数调用工具。如果再要启用,需要重新注册。"""
@@ -281,7 +295,7 @@ class Context:
command_name: str,
desc: str,
priority: int,
awaitable: Awaitable,
awaitable: Callable[..., Awaitable[Any]],
use_regex=False,
ignore_prefix=False,
):
+4 -4
View File
@@ -13,8 +13,8 @@ class CommandGroupFilter(HandlerFilter):
def __init__(
self,
group_name: str,
alias: set = None,
parent_group: CommandGroupFilter = None,
alias: set | None = None,
parent_group: CommandGroupFilter | None = None,
):
self.group_name = group_name
self.alias = alias if alias else set()
@@ -54,8 +54,8 @@ class CommandGroupFilter(HandlerFilter):
self,
sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]],
prefix: str = "",
event: AstrMessageEvent = None,
cfg: AstrBotConfig = None,
event: AstrMessageEvent | None = None,
cfg: AstrBotConfig | None = None,
) -> str:
result = ""
for sub_filter in sub_command_filters:
@@ -2,7 +2,6 @@ import enum
from . import HandlerFilter
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.config import AstrBotConfig
from typing import Union
class PlatformAdapterType(enum.Flag):
@@ -57,11 +56,14 @@ ADAPTER_NAME_2_TYPE = {
class PlatformAdapterTypeFilter(HandlerFilter):
def __init__(self, platform_adapter_type_or_str: Union[PlatformAdapterType, str]):
self.type_or_str = platform_adapter_type_or_str
def __init__(self, platform_adapter_type_or_str: PlatformAdapterType | str):
if isinstance(platform_adapter_type_or_str, str):
self.platform_type = ADAPTER_NAME_2_TYPE.get(platform_adapter_type_or_str)
else:
self.platform_type = platform_adapter_type_or_str
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
adapter_name = event.get_platform_name()
if adapter_name in ADAPTER_NAME_2_TYPE:
return ADAPTER_NAME_2_TYPE[adapter_name] & self.type_or_str
if adapter_name in ADAPTER_NAME_2_TYPE and self.platform_type is not None:
return bool(ADAPTER_NAME_2_TYPE[adapter_name] & self.platform_type)
return False
+3 -1
View File
@@ -5,7 +5,9 @@ from astrbot.core.star import StarMetadata, star_map
_warned_register_star = False
def register_star(name: str, author: str, desc: str, version: str, repo: str = None):
def register_star(
name: str, author: str, desc: str, version: str, repo: str | None = None
):
"""注册一个插件(Star)。
[DEPRECATED] 该装饰器已废弃将在未来版本中移除
+65 -36
View File
@@ -12,7 +12,7 @@ from ..filter.platform_adapter_type import (
from ..filter.permission import PermissionTypeFilter, PermissionType
from ..filter.custom_filter import CustomFilterAnd, CustomFilterOr
from ..filter.regex import RegexFilter
from typing import Awaitable
from typing import Awaitable, Any, Callable
from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES
from astrbot.core.provider.register import llm_tools
from astrbot.core.agent.agent import Agent
@@ -20,15 +20,19 @@ from astrbot.core.agent.tool import FunctionTool
from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core import logger
def get_handler_full_name(awaitable: Awaitable) -> str:
def get_handler_full_name(awaitable: Callable[..., Awaitable[Any]]) -> str:
"""获取 Handler 的全名"""
return f"{awaitable.__module__}_{awaitable.__name__}"
def get_handler_or_create(
handler: Awaitable, event_type: EventType, dont_add=False, **kwargs
handler: Callable[..., Awaitable[Any]],
event_type: EventType,
dont_add=False,
**kwargs,
) -> StarHandlerMetadata:
"""获取 Handler 或者创建一个新的 Handler"""
handler_full_name = get_handler_full_name(handler)
@@ -59,22 +63,35 @@ def get_handler_or_create(
def register_command(
command_name: str = None, sub_command: str = None, alias: set = None, **kwargs
command_name: str | None = None,
sub_command: str | None = None,
alias: set | None = None,
**kwargs,
):
"""注册一个 Command."""
new_command = None
add_to_event_filters = False
if isinstance(command_name, RegisteringCommandable):
# 子指令
parent_command_names = command_name.parent_group.get_complete_command_names()
new_command = CommandFilter(
sub_command, alias, None, parent_command_names=parent_command_names
)
command_name.parent_group.add_sub_command_filter(new_command)
if sub_command is not None:
parent_command_names = (
command_name.parent_group.get_complete_command_names()
)
new_command = CommandFilter(
sub_command, alias, None, parent_command_names=parent_command_names
)
command_name.parent_group.add_sub_command_filter(new_command)
else:
logger.warning(
f"注册指令{command_name} 的子指令时未提供 sub_command 参数。"
)
else:
# 裸指令
new_command = CommandFilter(command_name, alias, None)
add_to_event_filters = True
if command_name is None:
logger.warning("注册裸指令时未提供 command_name 参数。")
else:
new_command = CommandFilter(command_name, alias, None)
add_to_event_filters = True
def decorator(awaitable):
if not add_to_event_filters:
@@ -84,8 +101,9 @@ def register_command(
handler_md = get_handler_or_create(
awaitable, EventType.AdapterMessageEvent, **kwargs
)
new_command.init_handler_md(handler_md)
handler_md.event_filters.append(new_command)
if new_command:
new_command.init_handler_md(handler_md)
handler_md.event_filters.append(new_command)
return awaitable
return decorator
@@ -163,26 +181,38 @@ def register_custom_filter(custom_type_filter, *args, **kwargs):
def register_command_group(
command_group_name: str = None, sub_command: str = None, alias: set = None, **kwargs
command_group_name: str | None = None,
sub_command: str | None = None,
alias: set | None = None,
**kwargs,
):
"""注册一个 CommandGroup"""
new_group = None
if isinstance(command_group_name, RegisteringCommandable):
# 子指令组
new_group = CommandGroupFilter(
sub_command, alias, parent_group=command_group_name.parent_group
)
command_group_name.parent_group.add_sub_command_filter(new_group)
if sub_command is None:
logger.warning(f"{command_group_name} 指令组的子指令组 sub_command 未指定")
else:
new_group = CommandGroupFilter(
sub_command, alias, parent_group=command_group_name.parent_group
)
command_group_name.parent_group.add_sub_command_filter(new_group)
else:
# 根指令组
new_group = CommandGroupFilter(command_group_name, alias)
if command_group_name is None:
logger.warning("根指令组的名称未指定")
else:
new_group = CommandGroupFilter(command_group_name, alias)
def decorator(obj):
# 根指令组
handler_md = get_handler_or_create(obj, EventType.AdapterMessageEvent, **kwargs)
handler_md.event_filters.append(new_group)
if new_group:
handler_md = get_handler_or_create(
obj, EventType.AdapterMessageEvent, **kwargs
)
handler_md.event_filters.append(new_group)
return RegisteringCommandable(new_group)
return RegisteringCommandable(new_group)
return decorator
@@ -323,7 +353,7 @@ def register_on_llm_response(**kwargs):
return decorator
def register_llm_tool(name: str = None, **kwargs):
def register_llm_tool(name: str | None = None, **kwargs):
"""为函数调用(function-calling / tools-use)添加工具。
请务必按照以下格式编写一个工具包括函数注释AstrBot 会尝试解析该函数注释
@@ -361,9 +391,10 @@ def register_llm_tool(name: str = None, **kwargs):
if kwargs.get("registering_agent"):
registering_agent = kwargs["registering_agent"]
def decorator(awaitable: Awaitable):
def decorator(awaitable: Callable[..., Awaitable[Any]]):
llm_tool_name = name_ if name_ else awaitable.__name__
docstring = docstring_parser.parse(awaitable.__doc__)
func_doc = awaitable.__doc__ or ""
docstring = docstring_parser.parse(func_doc)
args = []
for arg in docstring.params:
if arg.type_name not in SUPPORTED_TYPES:
@@ -379,20 +410,18 @@ def register_llm_tool(name: str = None, **kwargs):
)
# print(llm_tool_name, registering_agent)
if not registering_agent:
doc_desc = docstring.description.strip() if docstring.description else ""
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
llm_tools.add_func(
llm_tool_name, args, docstring.description.strip(), md.handler
)
llm_tools.add_func(llm_tool_name, args, doc_desc, md.handler)
else:
assert isinstance(registering_agent, RegisteringAgent)
# print(f"Registering tool {llm_tool_name} for agent", registering_agent._agent.name)
if registering_agent._agent.tools is None:
registering_agent._agent.tools = []
registering_agent._agent.tools.append(
llm_tools.spec_to_func(
llm_tool_name, args, docstring.description.strip(), awaitable
)
)
desc = docstring.description.strip() if docstring.description else ""
tool = llm_tools.spec_to_func(llm_tool_name, args, desc, awaitable)
registering_agent._agent.tools.append(tool)
return awaitable
@@ -413,8 +442,8 @@ class RegisteringAgent:
def register_agent(
name: str,
instruction: str,
tools: list[str | FunctionTool] = None,
run_hooks: BaseAgentRunHooks[AstrAgentContext] = None,
tools: list[str | FunctionTool] | None = None,
run_hooks: BaseAgentRunHooks[AstrAgentContext] | None = None,
):
"""注册一个 Agent
@@ -426,7 +455,7 @@ def register_agent(
"""
tools_ = tools or []
def decorator(awaitable: Awaitable):
def decorator(awaitable: Callable[..., Awaitable[Any]]):
AstrAgent = Agent[AstrAgentContext]
agent = AstrAgent(
name=name,
@@ -140,6 +140,9 @@ class SessionPluginManager:
filtered_handlers.append(handler)
continue
if plugin.name is None:
continue
# 检查插件是否在当前会话中启用
if SessionPluginManager.is_plugin_enabled_for_session(
session_id, plugin.name
+4 -4
View File
@@ -1,7 +1,7 @@
from __future__ import annotations
import enum
from dataclasses import dataclass, field
from typing import Awaitable, List, Dict, TypeVar, Generic
from typing import Callable, Awaitable, Any, List, Dict, TypeVar, Generic
from .filter import HandlerFilter
from .star import star_map
@@ -60,7 +60,7 @@ class StarHandlerRegistry(Generic[T]):
handlers.append(handler)
return handlers
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata:
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata | None:
return self.star_handlers_map.get(full_name, None)
def get_handlers_by_module_name(
@@ -87,7 +87,7 @@ class StarHandlerRegistry(Generic[T]):
return len(self._handlers)
star_handlers_registry = StarHandlerRegistry()
star_handlers_registry = StarHandlerRegistry() # type: ignore
class EventType(enum.Enum):
@@ -123,7 +123,7 @@ class StarHandlerMetadata:
handler_module_path: str
"""Handler 所在的模块路径。"""
handler: Awaitable
handler: Callable[..., Awaitable[Any]]
"""Handler 的函数对象,应当是一个异步函数"""
event_filters: List[HandlerFilter]
+10 -4
View File
@@ -43,7 +43,7 @@ class PluginManager:
self.updator = PluginUpdator()
self.context = context
self.context._star_manager = self
self.context._star_manager = self # type: ignore
self.config = config
self.plugin_store_path = get_astrbot_plugin_path()
@@ -478,9 +478,10 @@ class PluginManager:
if isinstance(func_tool, HandoffTool):
need_apply = []
sub_tools = func_tool.agent.tools
for sub_tool in sub_tools:
if isinstance(sub_tool, FunctionTool):
need_apply.append(sub_tool)
if sub_tools:
for sub_tool in sub_tools:
if isinstance(sub_tool, FunctionTool):
need_apply.append(sub_tool)
else:
need_apply = [func_tool]
@@ -686,6 +687,9 @@ class PluginManager:
)
# 从 star_registry 和 star_map 中删除
if plugin.module_path is None or root_dir_name is None:
raise Exception(f"插件 {plugin_name} 数据不完整,无法卸载。")
await self._unbind_plugin(plugin_name, plugin.module_path)
try:
@@ -800,6 +804,8 @@ class PluginManager:
async def turn_on_plugin(self, plugin_name: str):
plugin = self.context.get_registered_star(plugin_name)
if plugin is None:
raise Exception(f"插件 {plugin_name} 不存在。")
inactivated_plugins: list = await sp.global_get("inactivated_plugins", [])
inactivated_llm_tools: list = await sp.global_get("inactivated_llm_tools", [])
if plugin.module_path in inactivated_plugins:
+6 -2
View File
@@ -22,7 +22,7 @@ import inspect
import os
import uuid
from pathlib import Path
from typing import Union, Awaitable, List, Optional, ClassVar
from typing import Union, Awaitable, Callable, Any, List, Optional, ClassVar
from astrbot.core.message.components import BaseMessageComponent
from astrbot.core.message.message_event_result import MessageChain
from astrbot.api.platform import MessageMember, AstrBotMessage, MessageType
@@ -221,7 +221,11 @@ class StarTools:
@classmethod
def register_llm_tool(
cls, name: str, func_args: list, desc: str, func_obj: Awaitable
cls,
name: str,
func_args: list,
desc: str,
func_obj: Callable[..., Awaitable[Any]],
) -> None:
"""
为函数调用function-calling/tools-use添加工具
+3
View File
@@ -32,6 +32,9 @@ class PluginUpdator(RepoZipUpdator):
if not repo_url:
raise Exception(f"插件 {plugin.name} 没有指定仓库地址。")
if not plugin.root_dir_name:
raise Exception(f"插件 {plugin.name} 的根目录名未指定。")
plugin_path = os.path.join(self.plugin_store_path, plugin.root_dir_name)
logger.info(f"正在更新插件,路径: {plugin_path},仓库地址: {repo_url}")
+8 -3
View File
@@ -178,7 +178,7 @@ class Main(star.Star):
return results
@filter.command("websearch")
async def websearch(self, event: AstrMessageEvent, oper: str = None) -> str:
async def websearch(self, event: AstrMessageEvent, oper: str | None = None):
event.set_result(
MessageEventResult().message(
"此指令已经被废弃,请在 WebUI 中开启或关闭网页搜索功能。"
@@ -210,7 +210,7 @@ class Main(star.Star):
processed_results = await asyncio.gather(*tasks, return_exceptions=True)
ret = ""
for processed_result in processed_results:
if isinstance(processed_result, Exception):
if isinstance(processed_result, BaseException):
logger.error(f"Error processing search result: {processed_result}")
continue
ret += processed_result
@@ -335,7 +335,7 @@ class Main(star.Star):
@filter.on_llm_request(priority=-10000)
async def edit_web_search_tools(
self, event: AstrMessageEvent, req: ProviderRequest
) -> str:
):
"""Get the session conversation for the given event."""
cfg = self.context.get_config(umo=event.unified_msg_origin)
prov_settings = cfg.get("provider_settings", {})
@@ -347,6 +347,9 @@ class Main(star.Star):
req.func_tool = tool_set.get_full_tool_set()
tool_set = req.func_tool
if not tool_set:
return
if not websearch_enable:
# pop tools
for tool_name in self.TOOLS:
@@ -372,3 +375,5 @@ class Main(star.Star):
tool_set.add_tool(tavily_extract_web_page)
tool_set.remove_tool("web_search")
tool_set.remove_tool("fetch_url")
print(req.func_tool)