From 1e99797df803194202b69bcb14ae033245ff046e Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Wed, 28 May 2025 12:53:00 +0800 Subject: [PATCH] refactor: improve message segment handle --- astrbot/core/message/components.py | 87 +++++++++++++++---- .../aiocqhttp/aiocqhttp_message_event.py | 67 +++++++------- 2 files changed, 104 insertions(+), 50 deletions(-) diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 97b712d0b..f917160f4 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -118,6 +118,9 @@ class Plain(BaseMessageComponent): self.text.replace("&", "&").replace("[", "[").replace("]", "]") ) + def toDict(self): + return {"type": "text", "data": {"text": self.text.strip()}} + class Face(BaseMessageComponent): type: ComponentType = "Face" @@ -304,6 +307,12 @@ class At(BaseMessageComponent): def __init__(self, **_): super().__init__(**_) + def toDict(self): + return { + "type": "at", + "data": {"qq": str(self.qq)}, + } + class AtAll(At): qq: str = "all" @@ -559,27 +568,43 @@ class Node(BaseMessageComponent): id: T.Optional[int] = 0 # 忽略 name: T.Optional[str] = "" # qq昵称 uin: T.Optional[str] = "0" # qq号 - content: T.Optional[T.Union[str, list, dict]] = "" # 子消息段列表 + content: T.Optional[list[BaseMessageComponent]] = None seq: T.Optional[T.Union[str, list]] = "" # 忽略 time: T.Optional[int] = 0 # 忽略 - def __init__(self, content: T.Union[str, list, dict, "Node", T.List["Node"]], **_): - if isinstance(content, list): - _content = None - if all(isinstance(item, Node) for item in content): - _content = [node.toDict() for node in content] - else: - _content = "" - for chain in content: - _content += chain.toString() - content = _content - elif isinstance(content, Node): - content = content.toDict() + def __init__(self, content: list[BaseMessageComponent] | "Node", **_): + if isinstance(content, Node): + content = [content] super().__init__(content=content, **_) - def toString(self): - # logger.warn("Protocol: node doesn't support stringify") - return "" + async def to_dict(self): + data_content = [] + for comp in self.content: + if isinstance(comp, (Image, Record)): + # For Image and Record segments, we convert them to base64 + bs64 = await comp.convert_to_base64() + data_content.append( + { + "type": comp.type.lower(), + "data": {"file": f"base64://{bs64}"}, + } + ) + elif isinstance(comp, File): + # For File segments, we need to handle the file differently + d = await comp.to_dict() + data_content.append(d) + elif isinstance(comp, (Node, Nodes)): + # For Node segments, we recursively convert them to dict + d = await comp.to_dict() + data_content.append(d) + return { + "type": "node", + "data": { + "user_id": str(self.uin), + "nickname": self.name, + "content": data_content, + }, + } class Nodes(BaseMessageComponent): @@ -590,15 +615,23 @@ class Nodes(BaseMessageComponent): super().__init__(nodes=nodes, **_) def toDict(self): + """Deprecated. Use to_dict instead""" ret = { "messages": [], } for node in self.nodes: d = node.toDict() - d["data"]["uin"] = str(node.uin) # 转为字符串 ret["messages"].append(d) return ret + async def to_dict(self): + """将 Nodes 转换为字典格式,适用于 OneBot JSON 格式""" + ret = { + "type": "nodes", + "data": {"nodes": [await node.to_dict() for node in self.nodes]}, + } + return ret + class Xml(BaseMessageComponent): type: ComponentType = "Xml" @@ -768,6 +801,26 @@ class File(BaseMessageComponent): return f"{callback_host}/api/file/{token}" + async def to_dict(self): + """需要和 toDict 区分开,toDict 是同步方法""" + url_or_path = await self.get_file(allow_return_url=True) + if url_or_path.startswith("http"): + payload_file = url_or_path + elif callback_host := astrbot_config.get("callback_api_base"): + callback_host = str(callback_host).removesuffix("/") + token = await file_token_service.register_file(url_or_path) + payload_file = f"{callback_host}/api/file/{token}" + logger.debug(f"Generated file callback link: {payload_file}") + else: + payload_file = url_or_path + return { + "type": "file", + "data": { + "name": self.name, + "file": payload_file, + }, + } + class WechatEmoji(BaseMessageComponent): type: ComponentType = "WechatEmoji" diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py index 068a8bf3c..af749a554 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py @@ -3,9 +3,16 @@ import re from typing import AsyncGenerator, Dict, List from aiocqhttp import CQHttp from astrbot.api.event import AstrMessageEvent, MessageChain -from astrbot.api.message_components import At, Image, Node, Nodes, Plain, Record, File +from astrbot.api.message_components import ( + Image, + Node, + Nodes, + Plain, + Record, + File, + BaseMessageComponent, +) from astrbot.api.platform import Group, MessageMember -from astrbot.core import file_token_service, astrbot_config, logger class AiocqhttpMessageEvent(AstrMessageEvent): @@ -15,28 +22,35 @@ class AiocqhttpMessageEvent(AstrMessageEvent): super().__init__(message_str, message_obj, platform_meta, session_id) self.bot = bot + @staticmethod + async def _from_segment_to_dict(segment: BaseMessageComponent) -> dict: + """修复部分字段""" + if isinstance(segment, (Image, Record)): + # For Image and Record segments, we convert them to base64 + bs64 = await segment.convert_to_base64() + return { + "type": segment.type.lower(), + "data": { + "file": f"base64://{bs64}", + }, + } + elif isinstance(segment, File): + # For File segments, we need to handle the file differently + d = await segment.to_dict() + return d + else: + # For other segments, we simply convert them to a dict by calling toDict + return segment.toDict() + @staticmethod async def _parse_onebot_json(message_chain: MessageChain): """解析成 OneBot json 格式""" ret = [] for segment in message_chain.chain: - d = segment.toDict() if isinstance(segment, Plain): - d["type"] = "text" - d["data"]["text"] = segment.text.strip() - # 如果是空文本或者只带换行符的文本,不发送 - if not d["data"]["text"]: + if not segment.text.strip(): continue - elif isinstance(segment, (Image, Record)): - # convert to base64 - bs64 = await segment.convert_to_base64() - d["data"] = { - "file": f"base64://{bs64}", - } - elif isinstance(segment, At): - d["data"] = { - "qq": str(segment.qq), # 转换为字符串 - } + d = await AiocqhttpMessageEvent._from_segment_to_dict(segment) ret.append(d) return ret @@ -54,7 +68,8 @@ class AiocqhttpMessageEvent(AstrMessageEvent): nodes = Nodes([seg]) seg = nodes - payload = seg.toDict() + payload = await seg.to_dict() + if self.get_group_id(): payload["group_id"] = self.get_group_id() await self.bot.call_action("send_group_forward_msg", **payload) @@ -64,21 +79,7 @@ class AiocqhttpMessageEvent(AstrMessageEvent): "send_private_forward_msg", **payload ) elif isinstance(seg, File): - d = seg.toDict() - url_or_path = await seg.get_file(allow_return_url=True) - if url_or_path.startswith("http"): - payload_file = url_or_path - elif callback_host := astrbot_config.get("callback_api_base"): - callback_host = str(callback_host).removesuffix("/") - token = await file_token_service.register_file(url_or_path) - payload_file = f"{callback_host}/api/file/{token}" - logger.debug(f"Generated file callback link: {payload_file}") - else: - payload_file = url_or_path - d["data"] = { - "name": seg.name, - "file": payload_file, - } + d = await AiocqhttpMessageEvent._from_segment_to_dict(seg) await self.bot.send( self.message_obj.raw_message, [d],