diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 97b712d0b..d4dbc12b5 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -102,6 +102,10 @@ class BaseMessageComponent(BaseModel): data[k] = v return {"type": self.type.lower(), "data": data} + async def to_dict(self) -> dict: + # 默认情况下,回退到旧的同步 toDict() + return self.toDict() + class Plain(BaseMessageComponent): type: ComponentType = "Plain" @@ -118,6 +122,9 @@ class Plain(BaseMessageComponent): self.text.replace("&", "&").replace("[", "[").replace("]", "]") ) + def toDict(self): + return {"type": "text", "data": {"text": self.text.strip()}} + class Face(BaseMessageComponent): type: ComponentType = "Face" @@ -235,9 +242,6 @@ class Video(BaseMessageComponent): path: T.Optional[str] = "" def __init__(self, file: str, **_): - # for k in _.keys(): - # if k == "c" and _[k] not in [2, 3]: - # logger.warn(f"Protocol: {k}={_[k]} doesn't match values") super().__init__(file=file, **_) @staticmethod @@ -295,6 +299,25 @@ class Video(BaseMessageComponent): return f"{callback_host}/api/file/{token}" + async def to_dict(self): + """需要和 toDict 区分开,toDict 是同步方法""" + url_or_path = self.file + if url_or_path.startswith("http"): + payload_file = url_or_path + elif callback_host := astrbot_config.get("callback_api_base"): + callback_host = str(callback_host).removesuffix("/") + token = await file_token_service.register_file(url_or_path) + payload_file = f"{callback_host}/api/file/{token}" + logger.debug(f"Generated video file callback link: {payload_file}") + else: + payload_file = url_or_path + return { + "type": "video", + "data": { + "file": payload_file, + }, + } + class At(BaseMessageComponent): type: ComponentType = "At" @@ -304,6 +327,12 @@ class At(BaseMessageComponent): def __init__(self, **_): super().__init__(**_) + def toDict(self): + return { + "type": "at", + "data": {"qq": str(self.qq)}, + } + class AtAll(At): qq: str = "all" @@ -559,27 +588,47 @@ class Node(BaseMessageComponent): id: T.Optional[int] = 0 # 忽略 name: T.Optional[str] = "" # qq昵称 uin: T.Optional[str] = "0" # qq号 - content: T.Optional[T.Union[str, list, dict]] = "" # 子消息段列表 + content: T.Optional[list[BaseMessageComponent]] = [] seq: T.Optional[T.Union[str, list]] = "" # 忽略 time: T.Optional[int] = 0 # 忽略 - def __init__(self, content: T.Union[str, list, dict, "Node", T.List["Node"]], **_): - if isinstance(content, list): - _content = None - if all(isinstance(item, Node) for item in content): - _content = [node.toDict() for node in content] - else: - _content = "" - for chain in content: - _content += chain.toString() - content = _content - elif isinstance(content, Node): - content = content.toDict() + def __init__(self, content: list[BaseMessageComponent], **_): + if isinstance(content, Node): + # back + content = [content] super().__init__(content=content, **_) - def toString(self): - # logger.warn("Protocol: node doesn't support stringify") - return "" + async def to_dict(self): + data_content = [] + for comp in self.content: + if isinstance(comp, (Image, Record)): + # For Image and Record segments, we convert them to base64 + bs64 = await comp.convert_to_base64() + data_content.append( + { + "type": comp.type.lower(), + "data": {"file": f"base64://{bs64}"}, + } + ) + elif isinstance(comp, File): + # For File segments, we need to handle the file differently + d = await comp.to_dict() + data_content.append(d) + elif isinstance(comp, (Node, Nodes)): + # For Node segments, we recursively convert them to dict + d = await comp.to_dict() + data_content.append(d) + else: + d = comp.toDict() + data_content.append(d) + return { + "type": "node", + "data": { + "user_id": str(self.uin), + "nickname": self.name, + "content": data_content, + }, + } class Nodes(BaseMessageComponent): @@ -590,12 +639,20 @@ class Nodes(BaseMessageComponent): super().__init__(nodes=nodes, **_) def toDict(self): + """Deprecated. Use to_dict instead""" ret = { "messages": [], } for node in self.nodes: d = node.toDict() - d["data"]["uin"] = str(node.uin) # 转为字符串 + ret["messages"].append(d) + return ret + + async def to_dict(self): + """将 Nodes 转换为字典格式,适用于 OneBot JSON 格式""" + ret = {"messages": []} + for node in self.nodes: + d = await node.to_dict() ret["messages"].append(d) return ret @@ -768,6 +825,26 @@ class File(BaseMessageComponent): return f"{callback_host}/api/file/{token}" + async def to_dict(self): + """需要和 toDict 区分开,toDict 是同步方法""" + url_or_path = await self.get_file(allow_return_url=True) + if url_or_path.startswith("http"): + payload_file = url_or_path + elif callback_host := astrbot_config.get("callback_api_base"): + callback_host = str(callback_host).removesuffix("/") + token = await file_token_service.register_file(url_or_path) + payload_file = f"{callback_host}/api/file/{token}" + logger.debug(f"Generated file callback link: {payload_file}") + else: + payload_file = url_or_path + return { + "type": "file", + "data": { + "name": self.name, + "file": payload_file, + }, + } + class WechatEmoji(BaseMessageComponent): type: ComponentType = "WechatEmoji" diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index 1bd70cd14..c990e9d8d 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -29,9 +29,7 @@ class RespondStage(Stage): Comp.Image: lambda comp: bool(comp.file), # 图片 Comp.Reply: lambda comp: bool(comp.id) and comp.sender_id is not None, # 回复 Comp.Poke: lambda comp: comp.id != 0 and comp.qq != 0, # 戳一戳 - Comp.Node: lambda comp: bool(comp.name) - and comp.uin != 0 - and bool(comp.content), # 一个转发节点 + Comp.Node: lambda comp: bool(comp.content), # 转发节点 Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点 Comp.File: lambda comp: bool(comp.file_ or comp.url), } diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py index 068a8bf3c..b5539b7eb 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py @@ -3,9 +3,17 @@ import re from typing import AsyncGenerator, Dict, List from aiocqhttp import CQHttp from astrbot.api.event import AstrMessageEvent, MessageChain -from astrbot.api.message_components import At, Image, Node, Nodes, Plain, Record, File +from astrbot.api.message_components import ( + Image, + Node, + Nodes, + Plain, + Record, + Video, + File, + BaseMessageComponent, +) from astrbot.api.platform import Group, MessageMember -from astrbot.core import file_token_service, astrbot_config, logger class AiocqhttpMessageEvent(AstrMessageEvent): @@ -15,28 +23,38 @@ class AiocqhttpMessageEvent(AstrMessageEvent): super().__init__(message_str, message_obj, platform_meta, session_id) self.bot = bot + @staticmethod + async def _from_segment_to_dict(segment: BaseMessageComponent) -> dict: + """修复部分字段""" + if isinstance(segment, (Image, Record)): + # For Image and Record segments, we convert them to base64 + bs64 = await segment.convert_to_base64() + return { + "type": segment.type.lower(), + "data": { + "file": f"base64://{bs64}", + }, + } + elif isinstance(segment, File): + # For File segments, we need to handle the file differently + d = await segment.to_dict() + return d + elif isinstance(segment, Video): + d = await segment.to_dict() + return d + else: + # For other segments, we simply convert them to a dict by calling toDict + return segment.toDict() + @staticmethod async def _parse_onebot_json(message_chain: MessageChain): """解析成 OneBot json 格式""" ret = [] for segment in message_chain.chain: - d = segment.toDict() if isinstance(segment, Plain): - d["type"] = "text" - d["data"]["text"] = segment.text.strip() - # 如果是空文本或者只带换行符的文本,不发送 - if not d["data"]["text"]: + if not segment.text.strip(): continue - elif isinstance(segment, (Image, Record)): - # convert to base64 - bs64 = await segment.convert_to_base64() - d["data"] = { - "file": f"base64://{bs64}", - } - elif isinstance(segment, At): - d["data"] = { - "qq": str(segment.qq), # 转换为字符串 - } + d = await AiocqhttpMessageEvent._from_segment_to_dict(segment) ret.append(d) return ret @@ -54,7 +72,8 @@ class AiocqhttpMessageEvent(AstrMessageEvent): nodes = Nodes([seg]) seg = nodes - payload = seg.toDict() + payload = await seg.to_dict() + if self.get_group_id(): payload["group_id"] = self.get_group_id() await self.bot.call_action("send_group_forward_msg", **payload) @@ -64,21 +83,7 @@ class AiocqhttpMessageEvent(AstrMessageEvent): "send_private_forward_msg", **payload ) elif isinstance(seg, File): - d = seg.toDict() - url_or_path = await seg.get_file(allow_return_url=True) - if url_or_path.startswith("http"): - payload_file = url_or_path - elif callback_host := astrbot_config.get("callback_api_base"): - callback_host = str(callback_host).removesuffix("/") - token = await file_token_service.register_file(url_or_path) - payload_file = f"{callback_host}/api/file/{token}" - logger.debug(f"Generated file callback link: {payload_file}") - else: - payload_file = url_or_path - d["data"] = { - "name": seg.name, - "file": payload_file, - } + d = await AiocqhttpMessageEvent._from_segment_to_dict(seg) await self.bot.send( self.message_obj.raw_message, [d],