refactor: improve message segment handle

This commit is contained in:
Soulter
2025-05-28 12:53:00 +08:00
parent 6b067fa6a7
commit 1e99797df8
2 changed files with 104 additions and 50 deletions
+70 -17
View File
@@ -118,6 +118,9 @@ class Plain(BaseMessageComponent):
self.text.replace("&", "&").replace("[", "[").replace("]", "]")
)
def toDict(self):
return {"type": "text", "data": {"text": self.text.strip()}}
class Face(BaseMessageComponent):
type: ComponentType = "Face"
@@ -304,6 +307,12 @@ class At(BaseMessageComponent):
def __init__(self, **_):
super().__init__(**_)
def toDict(self):
return {
"type": "at",
"data": {"qq": str(self.qq)},
}
class AtAll(At):
qq: str = "all"
@@ -559,27 +568,43 @@ class Node(BaseMessageComponent):
id: T.Optional[int] = 0 # 忽略
name: T.Optional[str] = "" # qq昵称
uin: T.Optional[str] = "0" # qq号
content: T.Optional[T.Union[str, list, dict]] = "" # 子消息段列表
content: T.Optional[list[BaseMessageComponent]] = None
seq: T.Optional[T.Union[str, list]] = "" # 忽略
time: T.Optional[int] = 0 # 忽略
def __init__(self, content: T.Union[str, list, dict, "Node", T.List["Node"]], **_):
if isinstance(content, list):
_content = None
if all(isinstance(item, Node) for item in content):
_content = [node.toDict() for node in content]
else:
_content = ""
for chain in content:
_content += chain.toString()
content = _content
elif isinstance(content, Node):
content = content.toDict()
def __init__(self, content: list[BaseMessageComponent] | "Node", **_):
if isinstance(content, Node):
content = [content]
super().__init__(content=content, **_)
def toString(self):
# logger.warn("Protocol: node doesn't support stringify")
return ""
async def to_dict(self):
data_content = []
for comp in self.content:
if isinstance(comp, (Image, Record)):
# For Image and Record segments, we convert them to base64
bs64 = await comp.convert_to_base64()
data_content.append(
{
"type": comp.type.lower(),
"data": {"file": f"base64://{bs64}"},
}
)
elif isinstance(comp, File):
# For File segments, we need to handle the file differently
d = await comp.to_dict()
data_content.append(d)
elif isinstance(comp, (Node, Nodes)):
# For Node segments, we recursively convert them to dict
d = await comp.to_dict()
data_content.append(d)
return {
"type": "node",
"data": {
"user_id": str(self.uin),
"nickname": self.name,
"content": data_content,
},
}
class Nodes(BaseMessageComponent):
@@ -590,15 +615,23 @@ class Nodes(BaseMessageComponent):
super().__init__(nodes=nodes, **_)
def toDict(self):
"""Deprecated. Use to_dict instead"""
ret = {
"messages": [],
}
for node in self.nodes:
d = node.toDict()
d["data"]["uin"] = str(node.uin) # 转为字符串
ret["messages"].append(d)
return ret
async def to_dict(self):
"""将 Nodes 转换为字典格式,适用于 OneBot JSON 格式"""
ret = {
"type": "nodes",
"data": {"nodes": [await node.to_dict() for node in self.nodes]},
}
return ret
class Xml(BaseMessageComponent):
type: ComponentType = "Xml"
@@ -768,6 +801,26 @@ class File(BaseMessageComponent):
return f"{callback_host}/api/file/{token}"
async def to_dict(self):
"""需要和 toDict 区分开,toDict 是同步方法"""
url_or_path = await self.get_file(allow_return_url=True)
if url_or_path.startswith("http"):
payload_file = url_or_path
elif callback_host := astrbot_config.get("callback_api_base"):
callback_host = str(callback_host).removesuffix("/")
token = await file_token_service.register_file(url_or_path)
payload_file = f"{callback_host}/api/file/{token}"
logger.debug(f"Generated file callback link: {payload_file}")
else:
payload_file = url_or_path
return {
"type": "file",
"data": {
"name": self.name,
"file": payload_file,
},
}
class WechatEmoji(BaseMessageComponent):
type: ComponentType = "WechatEmoji"
@@ -3,9 +3,16 @@ import re
from typing import AsyncGenerator, Dict, List
from aiocqhttp import CQHttp
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import At, Image, Node, Nodes, Plain, Record, File
from astrbot.api.message_components import (
Image,
Node,
Nodes,
Plain,
Record,
File,
BaseMessageComponent,
)
from astrbot.api.platform import Group, MessageMember
from astrbot.core import file_token_service, astrbot_config, logger
class AiocqhttpMessageEvent(AstrMessageEvent):
@@ -15,28 +22,35 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.bot = bot
@staticmethod
async def _from_segment_to_dict(segment: BaseMessageComponent) -> dict:
"""修复部分字段"""
if isinstance(segment, (Image, Record)):
# For Image and Record segments, we convert them to base64
bs64 = await segment.convert_to_base64()
return {
"type": segment.type.lower(),
"data": {
"file": f"base64://{bs64}",
},
}
elif isinstance(segment, File):
# For File segments, we need to handle the file differently
d = await segment.to_dict()
return d
else:
# For other segments, we simply convert them to a dict by calling toDict
return segment.toDict()
@staticmethod
async def _parse_onebot_json(message_chain: MessageChain):
"""解析成 OneBot json 格式"""
ret = []
for segment in message_chain.chain:
d = segment.toDict()
if isinstance(segment, Plain):
d["type"] = "text"
d["data"]["text"] = segment.text.strip()
# 如果是空文本或者只带换行符的文本,不发送
if not d["data"]["text"]:
if not segment.text.strip():
continue
elif isinstance(segment, (Image, Record)):
# convert to base64
bs64 = await segment.convert_to_base64()
d["data"] = {
"file": f"base64://{bs64}",
}
elif isinstance(segment, At):
d["data"] = {
"qq": str(segment.qq), # 转换为字符串
}
d = await AiocqhttpMessageEvent._from_segment_to_dict(segment)
ret.append(d)
return ret
@@ -54,7 +68,8 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
nodes = Nodes([seg])
seg = nodes
payload = seg.toDict()
payload = await seg.to_dict()
if self.get_group_id():
payload["group_id"] = self.get_group_id()
await self.bot.call_action("send_group_forward_msg", **payload)
@@ -64,21 +79,7 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
"send_private_forward_msg", **payload
)
elif isinstance(seg, File):
d = seg.toDict()
url_or_path = await seg.get_file(allow_return_url=True)
if url_or_path.startswith("http"):
payload_file = url_or_path
elif callback_host := astrbot_config.get("callback_api_base"):
callback_host = str(callback_host).removesuffix("/")
token = await file_token_service.register_file(url_or_path)
payload_file = f"{callback_host}/api/file/{token}"
logger.debug(f"Generated file callback link: {payload_file}")
else:
payload_file = url_or_path
d["data"] = {
"name": seg.name,
"file": payload_file,
}
d = await AiocqhttpMessageEvent._from_segment_to_dict(seg)
await self.bot.send(
self.message_obj.raw_message,
[d],