feat: middleware

This commit is contained in:
Soulter
2024-09-11 16:47:44 +08:00
parent 6db8c38c58
commit a93e6ff01a
6 changed files with 47 additions and 6 deletions
+12 -3
View File
@@ -1,7 +1,11 @@
flag_not_support = False
try:
from util.plugin_dev.api.v1.bot import Context, AstrMessageEvent, CommandResult
from util.plugin_dev.api.v1.config import *
from util.plugin_dev.api.v1 import (
Context,
CommandResult,
AstrMessageEvent,
Middleware,
)
except ImportError:
flag_not_support = True
print("导入接口失败。请升级到 AstrBot 最新版本。")
@@ -21,12 +25,17 @@ class HelloWorldPlugin:
def __init__(self, context: Context) -> None:
self.context = context
self.context.register_commands("helloworld", "helloworld", "内置测试指令。", 1, self.helloworld)
self.context.register_middleware("audio_to_text", self.audio_to_text)
async def audio_to_text(self, message: AstrMessageEvent, context: Context):
print(message)
"""
指令处理函数。
- 需要接收两个参数:message: AstrMessageEvent, context: Context
- 返回 CommandResult 对象
"""
def helloworld(self, message: AstrMessageEvent, context: Context):
async def helloworld(self, message: AstrMessageEvent, context: Context):
return CommandResult().message("Hello, World!")
+9 -2
View File
@@ -154,12 +154,19 @@ class MessageHandler():
is_command_call=True,
use_t2i=cmd_res.is_use_t2i
)
# next is the LLM part
# middlewares
for middleware in self.context.middlewares:
try:
await middleware.func(message)
except BaseException as e:
logger.error(f"中间件 {middleware.origin}/{middleware.name} 处理消息时发生异常:{e},跳过。")
logger.error(traceback.format_exc())
if message.only_command:
return
# next is the LLM part
# check if the message is a llm-wake-up command
if self.llm_wake_prefix and not msg_plain.startswith(self.llm_wake_prefix):
logger.debug(f"消息 `{msg_plain}` 没有以 LLM 唤醒前缀 `{self.llm_wake_prefix}` 开头,忽略。")
+8
View File
@@ -0,0 +1,8 @@
from dataclasses import dataclass
@dataclass
class Middleware():
name: str = ""
description: str = ""
func: callable = None
origin: str # 注册来源
+10
View File
@@ -9,6 +9,7 @@ from util.updator.astrbot_updator import AstrBotUpdator
from util.image_uploader import ImageUploader
from util.updator.plugin_updator import PluginUpdator
from type.command import CommandResult
from type.middleware import Middleware
from type.astrbot_message import MessageType
from model.plugin.command import PluginCommandBridge
from model.provider.provider import Provider
@@ -43,6 +44,7 @@ class Context:
self.image_uploader = ImageUploader()
self.message_handler = None # see astrbot/message/handler.py
self.ext_tasks: List[Task] = []
self.middlewares: List[Middleware] = []
self.command_manager = None
self.running = True
@@ -115,6 +117,14 @@ class Context:
删除一个函数调用工具。
'''
self.message_handler.llm_tools.remove_func(tool_name)
def register_middleware(self, middleware: Middleware):
'''
注册一个中间件。所有的消息事件都会经过中间件处理,然后再进入 LLM 聊天模块。
在 AstrBot 中,会对到来的消息事件首先检查指令,然后再检查中间件。触发指令后将不会进入 LLM 聊天模块,而中间件会。
'''
self.middlewares.append(middleware)
def find_platform(self, platform_name: str) -> RegisteredPlatform:
for platform in self.platforms:
+7
View File
@@ -0,0 +1,7 @@
from .bot import *
from .config import *
from .llm import *
from .message import *
from .platform import *
from .register import *
from .types import *
+1 -1
View File
@@ -3,5 +3,5 @@
'''
from type.plugin import PluginType
from type.middleware import Middleware
from nakuru.entities.components import Image, Plain, At, Node, BaseMessageComponent