Merge pull request #202 from Soulter/feat-middleware

支持插件注册消息中间件
This commit is contained in:
Soulter
2024-09-18 13:29:48 +08:00
committed by GitHub
7 changed files with 47 additions and 10 deletions
+8 -4
View File
@@ -1,7 +1,11 @@
flag_not_support = False
try:
from util.plugin_dev.api.v1.bot import Context, AstrMessageEvent, CommandResult
from util.plugin_dev.api.v1.config import *
from util.plugin_dev.api.v1 import (
Context,
CommandResult,
AstrMessageEvent,
Middleware,
)
except ImportError:
flag_not_support = True
print("导入接口失败。请升级到 AstrBot 最新版本。")
@@ -21,12 +25,12 @@ class HelloWorldPlugin:
def __init__(self, context: Context) -> None:
self.context = context
self.context.register_commands("helloworld", "helloworld", "内置测试指令。", 1, self.helloworld)
"""
指令处理函数。
- 需要接收两个参数:message: AstrMessageEvent, context: Context
- 返回 CommandResult 对象
"""
def helloworld(self, message: AstrMessageEvent, context: Context):
async def helloworld(self, message: AstrMessageEvent, context: Context):
return CommandResult().message("Hello, World!")
+10 -2
View File
@@ -154,12 +154,20 @@ class MessageHandler():
is_command_call=True,
use_t2i=cmd_res.is_use_t2i
)
# next is the LLM part
# middlewares
for middleware in self.context.middlewares:
try:
logger.info(f"执行中间件 {middleware.origin}/{middleware.name}...")
await middleware.func(message, self.context)
except BaseException as e:
logger.error(f"中间件 {middleware.origin}/{middleware.name} 处理消息时发生异常:{e},跳过。")
logger.error(traceback.format_exc())
if message.only_command:
return
# next is the LLM part
# check if the message is a llm-wake-up command
if self.llm_wake_prefix and not msg_plain.startswith(self.llm_wake_prefix):
logger.debug(f"消息 `{msg_plain}` 没有以 LLM 唤醒前缀 `{self.llm_wake_prefix}` 开头,忽略。")
+3 -3
View File
@@ -181,18 +181,18 @@ class InternalCommandHandler:
except BaseException as e:
logger.warning("An error occurred while fetching astrbot notice. Never mind, it's not important.")
msg = "# Help Center\n## 指令列表\n"
msg = "# 帮助中心\n## 指令\n"
for key, value in self.manager.commands_handler.items():
if value.plugin_metadata:
msg += f"- `{key}` ({value.plugin_metadata.plugin_name}): {value.description}\n"
else: msg += f"- `{key}`: {value.description}\n"
# plugins
if context.cached_plugins != None:
if context.cached_plugins:
plugin_list_info = ""
for plugin in context.cached_plugins:
plugin_list_info += f"- `{plugin.metadata.plugin_name}` {plugin.metadata.desc}\n"
if plugin_list_info.strip() != "":
msg += "\n## 插件列表\n> 使用plugin v 插件名 查看插件帮助\n"
msg += "\n## 插件\n> 使用plugin v 插件名 查看插件帮助\n"
msg += plugin_list_info
msg += notice
+8
View File
@@ -0,0 +1,8 @@
from dataclasses import dataclass
@dataclass
class Middleware():
name: str = ""
description: str = ""
origin: str = "" # 注册来源
func: callable = None
+10
View File
@@ -9,6 +9,7 @@ from util.updator.astrbot_updator import AstrBotUpdator
from util.image_uploader import ImageUploader
from util.updator.plugin_updator import PluginUpdator
from type.command import CommandResult
from type.middleware import Middleware
from type.astrbot_message import MessageType
from model.plugin.command import PluginCommandBridge
from model.provider.provider import Provider
@@ -43,6 +44,7 @@ class Context:
self.image_uploader = ImageUploader()
self.message_handler = None # see astrbot/message/handler.py
self.ext_tasks: List[Task] = []
self.middlewares: List[Middleware] = []
self.command_manager = None
self.running = True
@@ -115,6 +117,14 @@ class Context:
删除一个函数调用工具。
'''
self.message_handler.llm_tools.remove_func(tool_name)
def register_middleware(self, middleware: Middleware):
'''
注册一个中间件。所有的消息事件都会经过中间件处理,然后再进入 LLM 聊天模块。
在 AstrBot 中,会对到来的消息事件首先检查指令,然后再检查中间件。触发指令后将不会进入 LLM 聊天模块,而中间件会。
'''
self.middlewares.append(middleware)
def find_platform(self, platform_name: str) -> RegisteredPlatform:
for platform in self.platforms:
+7
View File
@@ -0,0 +1,7 @@
from .bot import *
from .config import *
from .llm import *
from .message import *
from .platform import *
from .register import *
from .types import *
+1 -1
View File
@@ -3,5 +3,5 @@
'''
from type.plugin import PluginType
from type.middleware import Middleware
from nakuru.entities.components import Image, Plain, At, Node, BaseMessageComponent