diff --git a/astrbot/bootstrap.py b/astrbot/bootstrap.py index 0facf2037..af64e0dd6 100644 --- a/astrbot/bootstrap.py +++ b/astrbot/bootstrap.py @@ -78,6 +78,7 @@ class AstrBotBootstrap(): self.context.updator = self.updator self.context.plugin_updator = self.plugin_manager.updator self.context.message_handler = self.message_handler + self.context.command_manager = self.command_manager # load plugins, plugins' commands. self.load_plugins() diff --git a/model/command/manager.py b/model/command/manager.py index 3f7608c71..f6e90290c 100644 --- a/model/command/manager.py +++ b/model/command/manager.py @@ -21,6 +21,7 @@ class CommandMetadata(): plugin_metadata: PluginMetadata handler: callable use_regex: bool = False + ignore_prefix: bool = False description: str = "" class CommandManager(): @@ -35,6 +36,7 @@ class CommandManager(): priority: int, handler: callable, use_regex: bool = False, + ignore_prefix: bool = False, plugin_metadata: PluginMetadata = None, ): ''' @@ -53,6 +55,7 @@ class CommandManager(): plugin_metadata=plugin_metadata, handler=handler, use_regex=use_regex, + ignore_prefix=ignore_prefix, description=description ) if plugin_metadata: @@ -75,9 +78,23 @@ class CommandManager(): priority=request.priority, handler=request.handler, use_regex=request.use_regex, + ignore_prefix=request.ignore_prefix, plugin_metadata=plugin.metadata) self.plugin_commands_waitlist = [] - + + async def check_command_ignore_prefix(self, message_str: str) -> bool: + for _, command in self.commands: + command_metadata = self.commands_handler[command] + if command_metadata.ignore_prefix: + trig = False + if self.commands_handler[command].use_regex: + trig = self.command_parser.regex_match(message_str, command) + else: + trig = message_str.startswith(command) + if trig: + return True + return False + async def scan_command(self, message_event: AstrMessageEvent, context: Context) -> CommandResult: message_str = message_event.message_str for _, command in self.commands: @@ -89,6 +106,8 @@ class CommandManager(): if trig: logger.info(f"触发 {command} 指令。") command_result = await self.execute_handler(command, message_event, context) + if not command_result: + continue if command_result.hit: return command_result diff --git a/model/platform/__init__.py b/model/platform/__init__.py index 7eadf4270..1acac0fb5 100644 --- a/model/platform/__init__.py +++ b/model/platform/__init__.py @@ -3,6 +3,7 @@ from typing import Union, Any, List from nakuru.entities.components import Plain, At, Image, BaseMessageComponent from type.astrbot_message import AstrBotMessage from type.command import CommandResult +from type.astrbot_message import MessageType class Platform(): @@ -30,6 +31,13 @@ class Platform(): 发送消息(主动) ''' pass + + @abc.abstractmethod + async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult): + ''' + 发送消息(主动) + ''' + pass def parse_message_outline(self, message: AstrBotMessage) -> str: ''' diff --git a/model/platform/qq_aiocqhttp.py b/model/platform/qq_aiocqhttp.py index 0788a5b2e..9fc73ee52 100644 --- a/model/platform/qq_aiocqhttp.py +++ b/model/platform/qq_aiocqhttp.py @@ -103,12 +103,16 @@ class AIOCQHTTP(Platform): await asyncio.sleep(1) def pre_check(self, message: AstrBotMessage) -> bool: - # if message chain contains Plain components or At components which points to self_id, return True + # if message chain contains Plain components or + # At components which points to self_id, return True if message.type == MessageType.FRIEND_MESSAGE: return True for comp in message.message: if isinstance(comp, At) and str(comp.qq) == message.self_id: return True + # check commands which ignore prefix + if self.context.command_manager.check_command_ignore_prefix(message.message_str): + return True # check nicks if self.check_nick(message.message_str): return True @@ -129,14 +133,28 @@ class AIOCQHTTP(Platform): else: role = 'member' + # parse unified message origin + unified_msg_origin = None + assert isinstance(message.raw_message, Event) + if message.type == MessageType.GROUP_MESSAGE: + unified_msg_origin = f"aiocqhttp:{message.type.value}:{message.raw_message.group_id}" + elif message.type == MessageType.FRIEND_MESSAGE: + unified_msg_origin = f"aiocqhttp:{message.type.value}:{message.sender.user_id}" + + logger.debug(f"unified_msg_origin: {unified_msg_origin}") + # construct astrbot message event - ame = AstrMessageEvent.from_astrbot_message(message, self.context, "aiocqhttp", message.session_id, role) + ame = AstrMessageEvent.from_astrbot_message(message, + self.context, + "aiocqhttp", + message.session_id, + role, unified_msg_origin) # transfer control to message handler message_result = await self.message_handler.handle(ame) if not message_result: return - await self.reply_msg(message, message_result.result_message) + await self.reply_msg(message, message_result.result_message, message_result.use_t2i) if message_result.callback: message_result.callback() @@ -147,7 +165,8 @@ class AIOCQHTTP(Platform): async def reply_msg(self, message: AstrBotMessage, - result_message: list): + result_message: list, + use_t2i: bool = None): """ 回复用户唤醒机器人的消息。(被动回复) """ @@ -160,7 +179,7 @@ class AIOCQHTTP(Platform): res = [Plain(text=res), ] # if image mode, put all Plain texts into a new picture. - if self.context.base_config.get("qq_pic_mode", False) and isinstance(res, list): + if use_t2i or (use_t2i == None and self.context.base_config.get("qq_pic_mode", False)) and isinstance(res, list): rendered_images = await self.convert_to_t2i_chain(res) if rendered_images: try: @@ -223,4 +242,12 @@ class AIOCQHTTP(Platform): ''' - await self._reply(target, result_message.message_chain) \ No newline at end of file + await self._reply(target, result_message.message_chain) + + async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult): + if message_type == MessageType.GROUP_MESSAGE: + await self.send_msg({'group_id': int(target)}, result_message) + elif message_type == MessageType.FRIEND_MESSAGE: + await self.send_msg({'user_id': int(target)}, result_message) + else: + raise Exception("aiocqhttp: 无法识别的消息类型。") \ No newline at end of file diff --git a/model/platform/qq_nakuru.py b/model/platform/qq_nakuru.py index d4052094d..f3311a8b7 100644 --- a/model/platform/qq_nakuru.py +++ b/model/platform/qq_nakuru.py @@ -78,6 +78,9 @@ class QQGOCQ(Platform): for comp in message.message: if isinstance(comp, At) and str(comp.qq) == message.self_id: return True + # check commands which ignore prefix + if self.context.command_manager.check_command_ignore_prefix(message.message_str): + return True # check nicks if self.check_nick(message.message_str): return True @@ -118,14 +121,34 @@ class QQGOCQ(Platform): else: role = 'member' + # parse unified message origin + unified_msg_origin = None + if message.type == MessageType.GROUP_MESSAGE: + assert isinstance(message.raw_message, GroupMessage) + unified_msg_origin = f"nakuru:{message.type.value}:{message.raw_message.group_id}" + elif message.type == MessageType.FRIEND_MESSAGE: + assert isinstance(message.raw_message, FriendMessage) + unified_msg_origin = f"nakuru:{message.type.value}:{message.sender.user_id}" + elif message.type == MessageType.GUILD_MESSAGE: + assert isinstance(message.raw_message, GuildMessage) + unified_msg_origin = f"nakuru:{message.type.value}:{message.raw_message.channel_id}" + + logger.debug(f"unified_msg_origin: {unified_msg_origin}") + + # construct astrbot message event - ame = AstrMessageEvent.from_astrbot_message(message, self.context, "gocq", session_id, role) + ame = AstrMessageEvent.from_astrbot_message(message, + self.context, + "nakuru", + session_id, + role, + unified_msg_origin) # transfer control to message handler message_result = await self.message_handler.handle(ame) if not message_result: return - await self.reply_msg(message, message_result.result_message) + await self.reply_msg(message, message_result.result_message, message_result.use_t2i) if message_result.callback: message_result.callback() @@ -135,7 +158,8 @@ class QQGOCQ(Platform): async def reply_msg(self, message: AstrBotMessage, - result_message: List[BaseMessageComponent]): + result_message: List[BaseMessageComponent], + use_t2i: bool = None): """ 回复用户唤醒机器人的消息。(被动回复) """ @@ -152,7 +176,7 @@ class QQGOCQ(Platform): res = [Plain(text=res), ] # if image mode, put all Plain texts into a new picture. - if self.context.base_config.get("qq_pic_mode", False) and isinstance(res, list): + if use_t2i or (use_t2i == None and self.context.base_config.get("qq_pic_mode", False)) and isinstance(res, list): rendered_images = await self.convert_to_t2i_chain(res) if rendered_images: try: @@ -213,6 +237,23 @@ class QQGOCQ(Platform): guild_id 不是频道号。 ''' await self._reply(target, result_message.message_chain) + + async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult): + ''' + 以主动的方式给用户、群或者频道发送一条消息。 + + `message_type` 为 MessageType 枚举类型。 + + - 要发给 QQ 下的某个用户,请使用 MessageType.FRIEND_MESSAGE; + - 要发给某个群聊,请使用 MessageType.GROUP_MESSAGE; + - 要发给某个频道,请使用 MessageType.GUILD_MESSAGE。 + ''' + if message_type == MessageType.FRIEND_MESSAGE: + await self.send_msg({"user_id": int(target)}, result_message) + elif message_type == MessageType.GROUP_MESSAGE: + await self.send_msg({"group_id": int(target)}, result_message) + elif message_type == MessageType.GUILD_MESSAGE: + await self.send_msg({"channel_id": int(target)}, result_message) def convert_message(self, message: Union[GroupMessage, FriendMessage, GuildMessage]) -> AstrBotMessage: abm = AstrBotMessage() @@ -233,7 +274,7 @@ class QQGOCQ(Platform): str(message.sender.user_id), str(message.sender.nickname) ) - abm.tag = "gocq" + abm.tag = "nakuru" abm.message = message.message return abm diff --git a/model/platform/qq_official.py b/model/platform/qq_official.py index 5ca3e301b..dc28453b8 100644 --- a/model/platform/qq_official.py +++ b/model/platform/qq_official.py @@ -222,7 +222,7 @@ class QQOfficial(Platform): if not message_result: return - ret = await self.reply_msg(message, message_result.result_message) + ret = await self.reply_msg(message, message_result.result_message, message_result.use_t2i) if message_result.callback: message_result.callback() @@ -234,7 +234,8 @@ class QQOfficial(Platform): async def reply_msg(self, message: AstrBotMessage, - result_message: List[BaseMessageComponent]): + result_message: List[BaseMessageComponent], + use_t2i: bool = None): ''' 回复频道消息 ''' @@ -249,7 +250,7 @@ class QQOfficial(Platform): msg_ref = None rendered_images = [] - if self.context.base_config.get("qq_pic_mode", False) and isinstance(result_message, list): + if use_t2i or (use_t2i == None and self.context.base_config.get("qq_pic_mode", False)) and isinstance(res, list): rendered_images = await self.convert_to_t2i_chain(result_message) if isinstance(result_message, list): @@ -388,6 +389,9 @@ class QQOfficial(Platform): if image_path: payload['file_image'] = image_path await self._reply(**payload) + + async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult): + raise NotImplementedError("qqofficial 不支持此方法。") def wait_for_message(self, channel_id: int) -> AstrBotMessage: ''' diff --git a/model/plugin/command.py b/model/plugin/command.py index 3321d52c7..1e4d8fab9 100644 --- a/model/plugin/command.py +++ b/model/plugin/command.py @@ -15,12 +15,13 @@ class CommandRegisterRequest(): handler: Callable use_regex: bool = False plugin_name: str = None + ignore_prefix: bool = False class PluginCommandBridge(): def __init__(self, cached_plugins: RegisteredPlugins): self.plugin_commands_waitlist: List[CommandRegisterRequest] = [] self.cached_plugins = cached_plugins - def register_command(self, plugin_name, command_name, description, priority, handler, use_regex=False): - self.plugin_commands_waitlist.append(CommandRegisterRequest(command_name, description, priority, handler, use_regex, plugin_name)) + def register_command(self, plugin_name, command_name, description, priority, handler, use_regex=False, ignore_prefix=False): + self.plugin_commands_waitlist.append(CommandRegisterRequest(command_name, description, priority, handler, use_regex, plugin_name, ignore_prefix)) \ No newline at end of file diff --git a/type/command.py b/type/command.py index a8723063f..ac9ca0d50 100644 --- a/type/command.py +++ b/type/command.py @@ -2,7 +2,6 @@ from typing import Union, List, Callable from dataclasses import dataclass from nakuru.entities.components import Plain, Image - @dataclass class CommandItem(): ''' @@ -19,12 +18,17 @@ class CommandResult(): 用于在Command中返回多个值 ''' - def __init__(self, hit: bool = True, success: bool = True, message_chain: list = [], command_name: str = "unknown_command") -> None: + def __init__(self, + hit: bool = True, + success: bool = True, + message_chain: list = [], + command_name: str = "unknown_command", + use_t2i: bool = None) -> None: self.hit = hit self.success = success self.message_chain = message_chain self.command_name = command_name - self.is_use_t2i = None # default + self.is_use_t2i = use_t2i def message(self, message: str): ''' @@ -63,14 +67,12 @@ class CommandResult(): self.message_chain = [Image.fromFileSystem(path), ] return self - # def use_t2i(self, use_t2i: bool): - # ''' - # 设置是否使用文本转图片服务。如果不设置,则跟随用户的设置。 - - # CommandResult().use_t2i(False) - # ''' - # self.is_use_t2i = use_t2i - # return self + def use_t2i(self, use_t2i: bool): + ''' + 设置是否使用文本转图片服务。如果不设置,则跟随用户的设置。 + ''' + self.is_use_t2i = use_t2i + return self def _result_tuple(self): return (self.success, self.message_chain, self.command_name) diff --git a/type/config.py b/type/config.py index e13bc9262..6d7fa437c 100644 --- a/type/config.py +++ b/type/config.py @@ -1,4 +1,4 @@ -VERSION = '3.3.7' +VERSION = '3.3.8' DEFAULT_CONFIG = { "qqbot": { diff --git a/type/message_event.py b/type/message_event.py index 222ac91e0..5ff8a7007 100644 --- a/type/message_event.py +++ b/type/message_event.py @@ -2,7 +2,14 @@ from typing import List, Union, Optional from dataclasses import dataclass from type.register import RegisteredPlatform from type.types import Context -from type.astrbot_message import AstrBotMessage +from type.astrbot_message import AstrBotMessage, MessageType + +@dataclass +class MessageResult(): + result_message: Union[str, list] + is_command_call: Optional[bool] = False + use_t2i: Optional[bool] = None # None 为跟随用户设置 + callback: Optional[callable] = None class AstrMessageEvent(): @@ -12,7 +19,8 @@ class AstrMessageEvent(): platform: RegisteredPlatform, role: str, context: Context, - session_id: str = None): + session_id: str = None, + unified_msg_origin: str = None): ''' AstrBot 消息事件。 @@ -22,6 +30,7 @@ class AstrMessageEvent(): `role`: 角色,`admin` or `member` `context`: 全局对象 `session_id`: 会话id + `unified_msg_origin`: 统一消息来源 ''' self.context = context self.message_str = message_str @@ -29,24 +38,21 @@ class AstrMessageEvent(): self.platform = platform self.role = role self.session_id = session_id + self.unified_msg_origin = unified_msg_origin def from_astrbot_message(message: AstrBotMessage, context: Context, platform_name: str, session_id: str, - role: str = "member"): + role: str = "member", + unified_msg_origin: str = None): ame = AstrMessageEvent(message.message_str, message, context.find_platform(platform_name), role, context, - session_id) + session_id, + unified_msg_origin) return ame -@dataclass -class MessageResult(): - result_message: Union[str, list] - is_command_call: Optional[bool] = False - use_t2i: Optional[bool] = None # None 为跟随用户设置 - callback: Optional[callable] = None diff --git a/type/types.py b/type/types.py index 195e2b1a6..0e3fee1cd 100644 --- a/type/types.py +++ b/type/types.py @@ -8,6 +8,8 @@ from util.t2i.renderer import TextToImageRenderer from util.updator.astrbot_updator import AstrBotUpdator from util.image_uploader import ImageUploader from util.updator.plugin_updator import PluginUpdator +from type.command import CommandResult +from type.astrbot_message import MessageType from model.plugin.command import PluginCommandBridge from model.provider.provider import Provider @@ -40,6 +42,8 @@ class Context: self.image_uploader = ImageUploader() self.message_handler = None # see astrbot/message/handler.py self.ext_tasks: List[Task] = [] + + self.command_manager = None # useless self.reply_prefix = "" @@ -50,7 +54,8 @@ class Context: description: str, priority: int, handler: callable, - use_regex: bool = False): + use_regex: bool = False, + ignore_prefix: bool = False): ''' 注册插件指令。 @@ -60,8 +65,19 @@ class Context: @param priority: 优先级越高,越先被处理。合理的优先级应该在 1-10 之间。 @param handler: 指令处理函数。函数参数:message: AstrMessageEvent, context: Context @param use_regex: 是否使用正则表达式匹配指令名。 + @param ignore_prefix: 是否忽略前缀。默认为 False。设置为 True 后,将不会检查用户设置的前缀。 + + .. Example:: + + ignore_prefix = False 时,用户输入 "/help" 时,会被识别为 "help" 指令。如果 ignore_prefix = True,则用户输入 "help" 也会被识别为 "help" 指令。 ''' - self.plugin_command_bridge.register_command(plugin_name, command_name, description, priority, handler, use_regex) + self.plugin_command_bridge.register_command(plugin_name, + command_name, + description, + priority, + handler, + use_regex, + ignore_prefix) def register_task(self, coro: Awaitable, task_name: str): ''' @@ -87,3 +103,18 @@ class Context: return platform raise ValueError("couldn't find the platform you specified") + + async def send_message(self, unified_msg_origin: str, message: CommandResult): + ''' + 发送消息。 + + `unified_msg_origin`: 统一消息来源 + `message`: 消息内容 + ''' + l = unified_msg_origin.split(":") + if len(l) != 3: + raise ValueError("Invalid unified_msg_origin") + platform_name, message_type, id = l + platform = self.find_platform(platform_name) + await platform.platform_instance.send_msg_new(MessageType(message_type), id, message) + \ No newline at end of file