import heapq from asyncio import Queue from . import RegisteredPlugin, PluginMetadata from typing import List, Dict, Awaitable, Union from dataclasses import dataclass from core.platform import Platform from core.db import BaseDatabase from core.config.astrbot_config import AstrBotConfig from core.utils.func_call import FuncCall from core.platform.astr_message_event import MessageSesion from core.message_event_result import MessageChain @dataclass class CommandMetadata(): ''' 显式指令 ''' plugin_name: str plugin_metadata: PluginMetadata handler: Awaitable use_regex: bool = False ignore_prefix: bool = False description: str = "" @dataclass class EventListenerMetadata(): ''' 事件监听器 ''' plugin_name: str plugin_metadata: PluginMetadata handler: Awaitable description: str = "" after_commands: bool = False class Context: ''' 暴露给插件的接口上下文,用于注册指令、事件监听器、消息平台、模型提供商等。 ''' # 事件队列。消息平台通过事件队列传递消息事件。 _event_queue: Queue = None # AstrBot 配置信息 _config: AstrBotConfig = None # AstrBot 数据库 _db: BaseDatabase = None # 维护了注册的插件的信息 registered_plugins: List[RegisteredPlugin] = [] # 维护了插件注册的指令的信息的名字列表,用于优先级排序 registered_commands: List[str] = [] # 维护了插件注册的指令的信息 commands_handler: Dict[str, CommandMetadata] = {} # 维护了插件注册的中间件的名字列表,用于优先级排序 registered_listeners: List[str] = [] # 维护了插件注册的中间件的信息 listeners_handler: Dict[str, EventListenerMetadata] = {} # 维护了注册的平台的信息 registered_platforms: List[Platform] = [] # 维护了 LLM Tools 信息 llm_tools: FuncCall = FuncCall() def __init__(self, event_queue: Queue, config: AstrBotConfig, db: BaseDatabase): self._event_queue = event_queue self._config = config self._db = db def get_registered_plugin(self, plugin_name: str) -> RegisteredPlugin: for plugin in self.registered_plugins: if plugin.metadata.plugin_name == plugin_name: return plugin return None def register_listener(self, plugin_name: str, name: str, handler: Awaitable, description: str = None, after_commands: bool = False): ''' 注册一个事件监听器。 after_commands: 是否在指令处理后执行。 ''' if name in self.registered_listeners: raise ValueError(f"Middleware {name} already exists.") self.registered_listeners.append(name) self.listeners_handler[name] = EventListenerMetadata( plugin_name=plugin_name, plugin_metadata=None, handler=handler, description=description, after_commands=after_commands ) def register_commands(self, plugin_name: str, command_name: str, description: str, priority: int, handler: Awaitable, use_regex: bool = False, ignore_prefix: bool = False): ''' 注册插件指令。 @param plugin_name: 插件名,注意需要和你的 metadata 中的一致。 @param command_name: 指令名,如 "help"。不需要带前缀。 @param description: 指令描述。 @param priority: 优先级越高,越先被处理。合理的优先级应该在 1-10 之间。 @param handler: 指令处理函数。函数参数:message: AstrMessageEvent, context: Context @param use_regex: 是否使用正则表达式匹配指令名。 @param ignore_prefix: 是否忽略前缀。默认为 False。设置为 True 后,将不会检查用户设置的前缀。 .. Example:: ignore_prefix = False 时,用户输入 "/help" 时,会被识别为 "help" 指令。如果 ignore_prefix = True,则用户输入 "help" 也会被识别为 "help" 指令。 ''' for command in self.registered_commands: if command_name in command[1]: raise ValueError(f"Command {command_name} already exists.") if not handler: raise ValueError(f"Handler of {command_name} is None.") heapq.heappush(self.registered_commands, (-priority, command_name)) self.commands_handler[command_name] = CommandMetadata( plugin_name=plugin_name, plugin_metadata=None, handler=handler, use_regex=use_regex, ignore_prefix=ignore_prefix, description=description ) heapq.heapify(self.registered_commands) def register_platform(self, platform: Platform): ''' 注册一个消息平台。 ''' self.registered_platforms.append(platform) def register_llm_tool(self, name: str, func_args: list, desc: str, func_obj: Awaitable) -> None: ''' 为函数调用(function-calling / tools-use)添加工具。 @param name: 函数名 @param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...] @param desc: 函数描述 @param func_obj: 异步处理函数。 异步处理函数会接收到额外的的关键词参数:event: AstrMessageEvent, context: Context。 ''' self.llm_tools.add_func(name, func_args, desc, func_obj) def unregister_llm_tool(self, name: str) -> None: ''' 删除一个函数调用工具。 ''' self.llm_tools.remove_func(name) def get_config(self) -> AstrBotConfig: ''' 获取 AstrBot 配置信息。 ''' return self._config def get_db(self) -> BaseDatabase: ''' 获取 AstrBot 数据库。 ''' return self._db def get_event_queue(self) -> Queue: ''' 获取事件队列。 ''' return self._event_queue async def send_message(self, session: Union[str, MessageSesion], message_chain: MessageChain) -> bool: ''' 根据 session(unified_msg_origin) 发送消息。 @param session: 消息会话。通过 event.session 或者 event.unified_msg_origin 获取。 @param message_chain: 消息链。 @return: 是否找到匹配的平台。 当 session 为字符串时,会尝试解析为 MessageSesion 对象,如果解析失败,会抛出 ValueError 异常。 ''' if isinstance(session, str): try: session = MessageSesion.from_str(session) except BaseException as e: raise ValueError("不合法的 session 字符串: " + str(e)) for platform in self.registered_platforms: if platform.meta().name == session.platform_name: await platform.send_by_session(session, message_chain) return True return False