feat: 注册指令支持忽略指令前缀;快捷主动回复

This commit is contained in:
Soulter
2024-08-10 02:35:54 -04:00
parent 0f470cf96f
commit 9db43ac5e6
11 changed files with 181 additions and 41 deletions
+1
View File
@@ -78,6 +78,7 @@ class AstrBotBootstrap():
self.context.updator = self.updator
self.context.plugin_updator = self.plugin_manager.updator
self.context.message_handler = self.message_handler
self.context.command_manager = self.command_manager
# load plugins, plugins' commands.
self.load_plugins()
+20 -1
View File
@@ -21,6 +21,7 @@ class CommandMetadata():
plugin_metadata: PluginMetadata
handler: callable
use_regex: bool = False
ignore_prefix: bool = False
description: str = ""
class CommandManager():
@@ -35,6 +36,7 @@ class CommandManager():
priority: int,
handler: callable,
use_regex: bool = False,
ignore_prefix: bool = False,
plugin_metadata: PluginMetadata = None,
):
'''
@@ -53,6 +55,7 @@ class CommandManager():
plugin_metadata=plugin_metadata,
handler=handler,
use_regex=use_regex,
ignore_prefix=ignore_prefix,
description=description
)
if plugin_metadata:
@@ -75,9 +78,23 @@ class CommandManager():
priority=request.priority,
handler=request.handler,
use_regex=request.use_regex,
ignore_prefix=request.ignore_prefix,
plugin_metadata=plugin.metadata)
self.plugin_commands_waitlist = []
async def check_command_ignore_prefix(self, message_str: str) -> bool:
for _, command in self.commands:
command_metadata = self.commands_handler[command]
if command_metadata.ignore_prefix:
trig = False
if self.commands_handler[command].use_regex:
trig = self.command_parser.regex_match(message_str, command)
else:
trig = message_str.startswith(command)
if trig:
return True
return False
async def scan_command(self, message_event: AstrMessageEvent, context: Context) -> CommandResult:
message_str = message_event.message_str
for _, command in self.commands:
@@ -89,6 +106,8 @@ class CommandManager():
if trig:
logger.info(f"触发 {command} 指令。")
command_result = await self.execute_handler(command, message_event, context)
if not command_result:
continue
if command_result.hit:
return command_result
+8
View File
@@ -3,6 +3,7 @@ from typing import Union, Any, List
from nakuru.entities.components import Plain, At, Image, BaseMessageComponent
from type.astrbot_message import AstrBotMessage
from type.command import CommandResult
from type.astrbot_message import MessageType
class Platform():
@@ -30,6 +31,13 @@ class Platform():
发送消息(主动)
'''
pass
@abc.abstractmethod
async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult):
'''
发送消息(主动)
'''
pass
def parse_message_outline(self, message: AstrBotMessage) -> str:
'''
+33 -6
View File
@@ -103,12 +103,16 @@ class AIOCQHTTP(Platform):
await asyncio.sleep(1)
def pre_check(self, message: AstrBotMessage) -> bool:
# if message chain contains Plain components or At components which points to self_id, return True
# if message chain contains Plain components or
# At components which points to self_id, return True
if message.type == MessageType.FRIEND_MESSAGE:
return True
for comp in message.message:
if isinstance(comp, At) and str(comp.qq) == message.self_id:
return True
# check commands which ignore prefix
if self.context.command_manager.check_command_ignore_prefix(message.message_str):
return True
# check nicks
if self.check_nick(message.message_str):
return True
@@ -129,14 +133,28 @@ class AIOCQHTTP(Platform):
else:
role = 'member'
# parse unified message origin
unified_msg_origin = None
assert isinstance(message.raw_message, Event)
if message.type == MessageType.GROUP_MESSAGE:
unified_msg_origin = f"aiocqhttp:{message.type.value}:{message.raw_message.group_id}"
elif message.type == MessageType.FRIEND_MESSAGE:
unified_msg_origin = f"aiocqhttp:{message.type.value}:{message.sender.user_id}"
logger.debug(f"unified_msg_origin: {unified_msg_origin}")
# construct astrbot message event
ame = AstrMessageEvent.from_astrbot_message(message, self.context, "aiocqhttp", message.session_id, role)
ame = AstrMessageEvent.from_astrbot_message(message,
self.context,
"aiocqhttp",
message.session_id,
role, unified_msg_origin)
# transfer control to message handler
message_result = await self.message_handler.handle(ame)
if not message_result: return
await self.reply_msg(message, message_result.result_message)
await self.reply_msg(message, message_result.result_message, message_result.use_t2i)
if message_result.callback:
message_result.callback()
@@ -147,7 +165,8 @@ class AIOCQHTTP(Platform):
async def reply_msg(self,
message: AstrBotMessage,
result_message: list):
result_message: list,
use_t2i: bool = None):
"""
回复用户唤醒机器人的消息。(被动回复)
"""
@@ -160,7 +179,7 @@ class AIOCQHTTP(Platform):
res = [Plain(text=res), ]
# if image mode, put all Plain texts into a new picture.
if self.context.base_config.get("qq_pic_mode", False) and isinstance(res, list):
if use_t2i or (use_t2i == None and self.context.base_config.get("qq_pic_mode", False)) and isinstance(res, list):
rendered_images = await self.convert_to_t2i_chain(res)
if rendered_images:
try:
@@ -223,4 +242,12 @@ class AIOCQHTTP(Platform):
'''
await self._reply(target, result_message.message_chain)
await self._reply(target, result_message.message_chain)
async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult):
if message_type == MessageType.GROUP_MESSAGE:
await self.send_msg({'group_id': int(target)}, result_message)
elif message_type == MessageType.FRIEND_MESSAGE:
await self.send_msg({'user_id': int(target)}, result_message)
else:
raise Exception("aiocqhttp: 无法识别的消息类型。")
+46 -5
View File
@@ -78,6 +78,9 @@ class QQGOCQ(Platform):
for comp in message.message:
if isinstance(comp, At) and str(comp.qq) == message.self_id:
return True
# check commands which ignore prefix
if self.context.command_manager.check_command_ignore_prefix(message.message_str):
return True
# check nicks
if self.check_nick(message.message_str):
return True
@@ -118,14 +121,34 @@ class QQGOCQ(Platform):
else:
role = 'member'
# parse unified message origin
unified_msg_origin = None
if message.type == MessageType.GROUP_MESSAGE:
assert isinstance(message.raw_message, GroupMessage)
unified_msg_origin = f"nakuru:{message.type.value}:{message.raw_message.group_id}"
elif message.type == MessageType.FRIEND_MESSAGE:
assert isinstance(message.raw_message, FriendMessage)
unified_msg_origin = f"nakuru:{message.type.value}:{message.sender.user_id}"
elif message.type == MessageType.GUILD_MESSAGE:
assert isinstance(message.raw_message, GuildMessage)
unified_msg_origin = f"nakuru:{message.type.value}:{message.raw_message.channel_id}"
logger.debug(f"unified_msg_origin: {unified_msg_origin}")
# construct astrbot message event
ame = AstrMessageEvent.from_astrbot_message(message, self.context, "gocq", session_id, role)
ame = AstrMessageEvent.from_astrbot_message(message,
self.context,
"nakuru",
session_id,
role,
unified_msg_origin)
# transfer control to message handler
message_result = await self.message_handler.handle(ame)
if not message_result: return
await self.reply_msg(message, message_result.result_message)
await self.reply_msg(message, message_result.result_message, message_result.use_t2i)
if message_result.callback:
message_result.callback()
@@ -135,7 +158,8 @@ class QQGOCQ(Platform):
async def reply_msg(self,
message: AstrBotMessage,
result_message: List[BaseMessageComponent]):
result_message: List[BaseMessageComponent],
use_t2i: bool = None):
"""
回复用户唤醒机器人的消息。(被动回复)
"""
@@ -152,7 +176,7 @@ class QQGOCQ(Platform):
res = [Plain(text=res), ]
# if image mode, put all Plain texts into a new picture.
if self.context.base_config.get("qq_pic_mode", False) and isinstance(res, list):
if use_t2i or (use_t2i == None and self.context.base_config.get("qq_pic_mode", False)) and isinstance(res, list):
rendered_images = await self.convert_to_t2i_chain(res)
if rendered_images:
try:
@@ -213,6 +237,23 @@ class QQGOCQ(Platform):
guild_id 不是频道号。
'''
await self._reply(target, result_message.message_chain)
async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult):
'''
以主动的方式给用户、群或者频道发送一条消息。
`message_type` 为 MessageType 枚举类型。
- 要发给 QQ 下的某个用户,请使用 MessageType.FRIEND_MESSAGE
- 要发给某个群聊,请使用 MessageType.GROUP_MESSAGE
- 要发给某个频道,请使用 MessageType.GUILD_MESSAGE。
'''
if message_type == MessageType.FRIEND_MESSAGE:
await self.send_msg({"user_id": int(target)}, result_message)
elif message_type == MessageType.GROUP_MESSAGE:
await self.send_msg({"group_id": int(target)}, result_message)
elif message_type == MessageType.GUILD_MESSAGE:
await self.send_msg({"channel_id": int(target)}, result_message)
def convert_message(self, message: Union[GroupMessage, FriendMessage, GuildMessage]) -> AstrBotMessage:
abm = AstrBotMessage()
@@ -233,7 +274,7 @@ class QQGOCQ(Platform):
str(message.sender.user_id),
str(message.sender.nickname)
)
abm.tag = "gocq"
abm.tag = "nakuru"
abm.message = message.message
return abm
+7 -3
View File
@@ -222,7 +222,7 @@ class QQOfficial(Platform):
if not message_result:
return
ret = await self.reply_msg(message, message_result.result_message)
ret = await self.reply_msg(message, message_result.result_message, message_result.use_t2i)
if message_result.callback:
message_result.callback()
@@ -234,7 +234,8 @@ class QQOfficial(Platform):
async def reply_msg(self,
message: AstrBotMessage,
result_message: List[BaseMessageComponent]):
result_message: List[BaseMessageComponent],
use_t2i: bool = None):
'''
回复频道消息
'''
@@ -249,7 +250,7 @@ class QQOfficial(Platform):
msg_ref = None
rendered_images = []
if self.context.base_config.get("qq_pic_mode", False) and isinstance(result_message, list):
if use_t2i or (use_t2i == None and self.context.base_config.get("qq_pic_mode", False)) and isinstance(res, list):
rendered_images = await self.convert_to_t2i_chain(result_message)
if isinstance(result_message, list):
@@ -388,6 +389,9 @@ class QQOfficial(Platform):
if image_path:
payload['file_image'] = image_path
await self._reply(**payload)
async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult):
raise NotImplementedError("qqofficial 不支持此方法。")
def wait_for_message(self, channel_id: int) -> AstrBotMessage:
'''
+3 -2
View File
@@ -15,12 +15,13 @@ class CommandRegisterRequest():
handler: Callable
use_regex: bool = False
plugin_name: str = None
ignore_prefix: bool = False
class PluginCommandBridge():
def __init__(self, cached_plugins: RegisteredPlugins):
self.plugin_commands_waitlist: List[CommandRegisterRequest] = []
self.cached_plugins = cached_plugins
def register_command(self, plugin_name, command_name, description, priority, handler, use_regex=False):
self.plugin_commands_waitlist.append(CommandRegisterRequest(command_name, description, priority, handler, use_regex, plugin_name))
def register_command(self, plugin_name, command_name, description, priority, handler, use_regex=False, ignore_prefix=False):
self.plugin_commands_waitlist.append(CommandRegisterRequest(command_name, description, priority, handler, use_regex, plugin_name, ignore_prefix))
+13 -11
View File
@@ -2,7 +2,6 @@ from typing import Union, List, Callable
from dataclasses import dataclass
from nakuru.entities.components import Plain, Image
@dataclass
class CommandItem():
'''
@@ -19,12 +18,17 @@ class CommandResult():
用于在Command中返回多个值
'''
def __init__(self, hit: bool = True, success: bool = True, message_chain: list = [], command_name: str = "unknown_command") -> None:
def __init__(self,
hit: bool = True,
success: bool = True,
message_chain: list = [],
command_name: str = "unknown_command",
use_t2i: bool = None) -> None:
self.hit = hit
self.success = success
self.message_chain = message_chain
self.command_name = command_name
self.is_use_t2i = None # default
self.is_use_t2i = use_t2i
def message(self, message: str):
'''
@@ -63,14 +67,12 @@ class CommandResult():
self.message_chain = [Image.fromFileSystem(path), ]
return self
# def use_t2i(self, use_t2i: bool):
# '''
# 设置是否使用文本转图片服务。如果不设置,则跟随用户的设置。
# CommandResult().use_t2i(False)
# '''
# self.is_use_t2i = use_t2i
# return self
def use_t2i(self, use_t2i: bool):
'''
设置是否使用文本转图片服务。如果不设置,则跟随用户的设置。
'''
self.is_use_t2i = use_t2i
return self
def _result_tuple(self):
return (self.success, self.message_chain, self.command_name)
+1 -1
View File
@@ -1,4 +1,4 @@
VERSION = '3.3.7'
VERSION = '3.3.8'
DEFAULT_CONFIG = {
"qqbot": {
+16 -10
View File
@@ -2,7 +2,14 @@ from typing import List, Union, Optional
from dataclasses import dataclass
from type.register import RegisteredPlatform
from type.types import Context
from type.astrbot_message import AstrBotMessage
from type.astrbot_message import AstrBotMessage, MessageType
@dataclass
class MessageResult():
result_message: Union[str, list]
is_command_call: Optional[bool] = False
use_t2i: Optional[bool] = None # None 为跟随用户设置
callback: Optional[callable] = None
class AstrMessageEvent():
@@ -12,7 +19,8 @@ class AstrMessageEvent():
platform: RegisteredPlatform,
role: str,
context: Context,
session_id: str = None):
session_id: str = None,
unified_msg_origin: str = None):
'''
AstrBot 消息事件。
@@ -22,6 +30,7 @@ class AstrMessageEvent():
`role`: 角色,`admin` or `member`
`context`: 全局对象
`session_id`: 会话id
`unified_msg_origin`: 统一消息来源
'''
self.context = context
self.message_str = message_str
@@ -29,24 +38,21 @@ class AstrMessageEvent():
self.platform = platform
self.role = role
self.session_id = session_id
self.unified_msg_origin = unified_msg_origin
def from_astrbot_message(message: AstrBotMessage,
context: Context,
platform_name: str,
session_id: str,
role: str = "member"):
role: str = "member",
unified_msg_origin: str = None):
ame = AstrMessageEvent(message.message_str,
message,
context.find_platform(platform_name),
role,
context,
session_id)
session_id,
unified_msg_origin)
return ame
@dataclass
class MessageResult():
result_message: Union[str, list]
is_command_call: Optional[bool] = False
use_t2i: Optional[bool] = None # None 为跟随用户设置
callback: Optional[callable] = None
+33 -2
View File
@@ -8,6 +8,8 @@ from util.t2i.renderer import TextToImageRenderer
from util.updator.astrbot_updator import AstrBotUpdator
from util.image_uploader import ImageUploader
from util.updator.plugin_updator import PluginUpdator
from type.command import CommandResult
from type.astrbot_message import MessageType
from model.plugin.command import PluginCommandBridge
from model.provider.provider import Provider
@@ -40,6 +42,8 @@ class Context:
self.image_uploader = ImageUploader()
self.message_handler = None # see astrbot/message/handler.py
self.ext_tasks: List[Task] = []
self.command_manager = None
# useless
self.reply_prefix = ""
@@ -50,7 +54,8 @@ class Context:
description: str,
priority: int,
handler: callable,
use_regex: bool = False):
use_regex: bool = False,
ignore_prefix: bool = False):
'''
注册插件指令。
@@ -60,8 +65,19 @@ class Context:
@param priority: 优先级越高,越先被处理。合理的优先级应该在 1-10 之间。
@param handler: 指令处理函数。函数参数:message: AstrMessageEvent, context: Context
@param use_regex: 是否使用正则表达式匹配指令名。
@param ignore_prefix: 是否忽略前缀。默认为 False。设置为 True 后,将不会检查用户设置的前缀。
.. Example::
ignore_prefix = False 时,用户输入 "/help" 时,会被识别为 "help" 指令。如果 ignore_prefix = True,则用户输入 "help" 也会被识别为 "help" 指令。
'''
self.plugin_command_bridge.register_command(plugin_name, command_name, description, priority, handler, use_regex)
self.plugin_command_bridge.register_command(plugin_name,
command_name,
description,
priority,
handler,
use_regex,
ignore_prefix)
def register_task(self, coro: Awaitable, task_name: str):
'''
@@ -87,3 +103,18 @@ class Context:
return platform
raise ValueError("couldn't find the platform you specified")
async def send_message(self, unified_msg_origin: str, message: CommandResult):
'''
发送消息。
`unified_msg_origin`: 统一消息来源
`message`: 消息内容
'''
l = unified_msg_origin.split(":")
if len(l) != 3:
raise ValueError("Invalid unified_msg_origin")
platform_name, message_type, id = l
platform = self.find_platform(platform_name)
await platform.platform_instance.send_msg_new(MessageType(message_type), id, message)