208 lines
7.4 KiB
Python
208 lines
7.4 KiB
Python
import heapq
|
|
from asyncio import Queue
|
|
from . import RegisteredPlugin, PluginMetadata
|
|
from typing import List, Dict, Awaitable, Union
|
|
from dataclasses import dataclass
|
|
|
|
from core.platform import Platform
|
|
from core.db import BaseDatabase
|
|
from core.config.astrbot_config import AstrBotConfig
|
|
from core.utils.func_call import FuncCall
|
|
from core.platform.astr_message_event import MessageSesion
|
|
from core.message_event_result import MessageChain
|
|
|
|
@dataclass
|
|
class CommandMetadata():
|
|
'''
|
|
显式指令
|
|
'''
|
|
plugin_name: str
|
|
plugin_metadata: PluginMetadata
|
|
handler: Awaitable
|
|
use_regex: bool = False
|
|
ignore_prefix: bool = False
|
|
description: str = ""
|
|
|
|
@dataclass
|
|
class EventListenerMetadata():
|
|
'''
|
|
事件监听器
|
|
'''
|
|
plugin_name: str
|
|
plugin_metadata: PluginMetadata
|
|
handler: Awaitable
|
|
description: str = ""
|
|
after_commands: bool = False
|
|
|
|
|
|
class Context:
|
|
'''
|
|
暴露给插件的接口上下文,用于注册指令、事件监听器、消息平台、模型提供商等。
|
|
'''
|
|
# 事件队列。消息平台通过事件队列传递消息事件。
|
|
_event_queue: Queue = None
|
|
|
|
# AstrBot 配置信息
|
|
_config: AstrBotConfig = None
|
|
|
|
# AstrBot 数据库
|
|
_db: BaseDatabase = None
|
|
|
|
# 维护了注册的插件的信息
|
|
registered_plugins: List[RegisteredPlugin] = []
|
|
|
|
# 维护了插件注册的指令的信息的名字列表,用于优先级排序
|
|
registered_commands: List[str] = []
|
|
# 维护了插件注册的指令的信息
|
|
commands_handler: Dict[str, CommandMetadata] = {}
|
|
|
|
# 维护了插件注册的中间件的名字列表,用于优先级排序
|
|
registered_listeners: List[str] = []
|
|
# 维护了插件注册的中间件的信息
|
|
listeners_handler: Dict[str, EventListenerMetadata] = {}
|
|
|
|
# 维护了注册的平台的信息
|
|
registered_platforms: List[Platform] = []
|
|
|
|
# 维护了 LLM Tools 信息
|
|
llm_tools: FuncCall = FuncCall()
|
|
|
|
def __init__(self, event_queue: Queue, config: AstrBotConfig, db: BaseDatabase):
|
|
self._event_queue = event_queue
|
|
self._config = config
|
|
self._db = db
|
|
|
|
def get_registered_plugin(self, plugin_name: str) -> RegisteredPlugin:
|
|
for plugin in self.registered_plugins:
|
|
if plugin.metadata.plugin_name == plugin_name:
|
|
return plugin
|
|
return None
|
|
|
|
def register_listener(self,
|
|
plugin_name: str,
|
|
name: str,
|
|
handler: Awaitable,
|
|
description: str = None,
|
|
after_commands: bool = False):
|
|
'''
|
|
注册一个事件监听器。
|
|
|
|
after_commands: 是否在指令处理后执行。
|
|
'''
|
|
if name in self.registered_listeners:
|
|
raise ValueError(f"Middleware {name} already exists.")
|
|
self.registered_listeners.append(name)
|
|
self.listeners_handler[name] = EventListenerMetadata(
|
|
plugin_name=plugin_name,
|
|
plugin_metadata=None,
|
|
handler=handler,
|
|
description=description,
|
|
after_commands=after_commands
|
|
)
|
|
|
|
def register_commands(self,
|
|
plugin_name: str,
|
|
command_name: str,
|
|
description: str,
|
|
priority: int,
|
|
handler: Awaitable,
|
|
use_regex: bool = False,
|
|
ignore_prefix: bool = False):
|
|
'''
|
|
注册插件指令。
|
|
|
|
@param plugin_name: 插件名,注意需要和你的 metadata 中的一致。
|
|
@param command_name: 指令名,如 "help"。不需要带前缀。
|
|
@param description: 指令描述。
|
|
@param priority: 优先级越高,越先被处理。合理的优先级应该在 1-10 之间。
|
|
@param handler: 指令处理函数。函数参数:message: AstrMessageEvent, context: Context
|
|
@param use_regex: 是否使用正则表达式匹配指令名。
|
|
@param ignore_prefix: 是否忽略前缀。默认为 False。设置为 True 后,将不会检查用户设置的前缀。
|
|
|
|
.. Example::
|
|
|
|
ignore_prefix = False 时,用户输入 "/help" 时,会被识别为 "help" 指令。如果 ignore_prefix = True,则用户输入 "help" 也会被识别为 "help" 指令。
|
|
'''
|
|
for command in self.registered_commands:
|
|
if command_name in command[1]:
|
|
raise ValueError(f"Command {command_name} already exists.")
|
|
if not handler:
|
|
raise ValueError(f"Handler of {command_name} is None.")
|
|
|
|
heapq.heappush(self.registered_commands, (-priority, command_name))
|
|
self.commands_handler[command_name] = CommandMetadata(
|
|
plugin_name=plugin_name,
|
|
plugin_metadata=None,
|
|
handler=handler,
|
|
use_regex=use_regex,
|
|
ignore_prefix=ignore_prefix,
|
|
description=description
|
|
)
|
|
heapq.heapify(self.registered_commands)
|
|
|
|
def register_platform(self, platform: Platform):
|
|
'''
|
|
注册一个消息平台。
|
|
'''
|
|
self.registered_platforms.append(platform)
|
|
|
|
def register_llm_tool(self, name: str, func_args: list, desc: str, func_obj: Awaitable) -> None:
|
|
'''
|
|
为函数调用(function-calling / tools-use)添加工具。
|
|
|
|
@param name: 函数名
|
|
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
|
|
@param desc: 函数描述
|
|
@param func_obj: 异步处理函数。
|
|
|
|
异步处理函数会接收到额外的的关键词参数:event: AstrMessageEvent, context: Context。
|
|
'''
|
|
self.llm_tools.add_func(name, func_args, desc, func_obj)
|
|
|
|
def unregister_llm_tool(self, name: str) -> None:
|
|
'''
|
|
删除一个函数调用工具。
|
|
'''
|
|
self.llm_tools.remove_func(name)
|
|
|
|
def get_config(self) -> AstrBotConfig:
|
|
'''
|
|
获取 AstrBot 配置信息。
|
|
'''
|
|
return self._config
|
|
|
|
def get_db(self) -> BaseDatabase:
|
|
'''
|
|
获取 AstrBot 数据库。
|
|
'''
|
|
return self._db
|
|
|
|
def get_event_queue(self) -> Queue:
|
|
'''
|
|
获取事件队列。
|
|
'''
|
|
return self._event_queue
|
|
|
|
async def send_message(self, session: Union[str, MessageSesion], message_chain: MessageChain) -> bool:
|
|
'''
|
|
根据 session(unified_msg_origin) 发送消息。
|
|
|
|
@param session: 消息会话。通过 event.session 或者 event.unified_msg_origin 获取。
|
|
@param message_chain: 消息链。
|
|
|
|
@return: 是否找到匹配的平台。
|
|
|
|
当 session 为字符串时,会尝试解析为 MessageSesion 对象,如果解析失败,会抛出 ValueError 异常。
|
|
'''
|
|
|
|
if isinstance(session, str):
|
|
try:
|
|
session = MessageSesion.from_str(session)
|
|
except BaseException as e:
|
|
raise ValueError("不合法的 session 字符串: " + str(e))
|
|
|
|
for platform in self.registered_platforms:
|
|
if platform.meta().name == session.platform_name:
|
|
await platform.send_by_session(session, message_chain)
|
|
return True
|
|
return False |