feat: Satori适配器引用消息无法正确识别 (#2686)

* Update PlatformPage.vue

* Update PlatformPage.vue

* Update PlatformPage.vue

* Update satori_adapter.py

* Update satori_event.py

* Update default.py

* Update satori_adapter.py

* Update satori_adapter.py

* style: format code

---------

Co-authored-by: Soulter <905617992@qq.com>
This commit is contained in:
shangxue
2025-09-21 21:45:35 +08:00
committed by GitHub
parent 3a044bb71a
commit fc76665615
3 changed files with 313 additions and 38 deletions
+13 -13
View File
@@ -263,7 +263,7 @@ CONFIG_METADATA_2 = {
"type": "satori",
"enable": False,
"satori_api_base_url": "http://localhost:5140/satori/v1",
"satori_endpoint": "ws://127.0.0.1:5140/satori/v1/events",
"satori_endpoint": "ws://localhost:5140/satori/v1/events",
"satori_token": "",
"satori_auto_reconnect": True,
"satori_heartbeat_interval": 10,
@@ -272,34 +272,34 @@ CONFIG_METADATA_2 = {
},
"items": {
"satori_api_base_url": {
"description": "Satori API Base URL",
"description": "Satori API 终结点",
"type": "string",
"hint": "The base URL for the Satori API.",
"hint": "Satori API 的基础地址。",
},
"satori_endpoint": {
"description": "Satori WebSocket Endpoint",
"description": "Satori WebSocket 终结点",
"type": "string",
"hint": "The WebSocket endpoint for Satori events.",
"hint": "Satori 事件的 WebSocket 端点。",
},
"satori_token": {
"description": "Satori Token",
"description": "Satori 令牌",
"type": "string",
"hint": "The token used for authenticating with the Satori API.",
"hint": "用于 Satori API 身份验证的令牌。",
},
"satori_auto_reconnect": {
"description": "Enable Auto Reconnect",
"description": "启用自动重连",
"type": "bool",
"hint": "Whether to automatically reconnect the WebSocket on disconnection.",
"hint": "断开连接时是否自动重新连接 WebSocket。",
},
"satori_heartbeat_interval": {
"description": "Satori Heartbeat Interval",
"description": "Satori 心跳间隔",
"type": "int",
"hint": "The interval (in seconds) for sending heartbeat messages.",
"hint": "发送心跳消息的间隔(秒)。",
},
"satori_reconnect_delay": {
"description": "Satori Reconnect Delay",
"description": "Satori 重连延迟",
"type": "int",
"hint": "The delay (in seconds) before attempting to reconnect.",
"hint": "尝试重新连接前的延迟时间(秒)。",
},
"slack_connection_mode": {
"description": "Slack Connection Mode",
@@ -17,7 +17,14 @@ from astrbot.api.platform import (
register_platform_adapter,
)
from astrbot.core.platform.astr_message_event import MessageSession
from astrbot.api.message_components import Plain, Image, At, File, Record
from astrbot.api.message_components import (
Plain,
Image,
At,
File,
Record,
Reply,
)
from xml.etree import ElementTree as ET
@@ -38,12 +45,18 @@ class SatoriPlatformAdapter(Platform):
)
self.token = self.config.get("satori_token", "")
self.endpoint = self.config.get(
"satori_endpoint", "ws://127.0.0.1:5140/satori/v1/events"
"satori_endpoint", "ws://localhost:5140/satori/v1/events"
)
self.auto_reconnect = self.config.get("satori_auto_reconnect", True)
self.heartbeat_interval = self.config.get("satori_heartbeat_interval", 10)
self.reconnect_delay = self.config.get("satori_reconnect_delay", 5)
self.metadata = PlatformMetadata(
name="satori",
description="Satori 通用协议适配器",
id=self.config["id"],
)
self.ws: Optional[ClientConnection] = None
self.session: Optional[ClientSession] = None
self.sequence = 0
@@ -63,7 +76,7 @@ class SatoriPlatformAdapter(Platform):
await super().send_by_session(session, message_chain)
def meta(self) -> PlatformMetadata:
return PlatformMetadata(name="satori", description="Satori 通用协议适配器")
return self.metadata
def _is_websocket_closed(self, ws) -> bool:
"""检查WebSocket连接是否已关闭"""
@@ -312,12 +325,52 @@ class SatoriPlatformAdapter(Platform):
abm.self_id = login.get("user", {}).get("id", "")
content = message.get("content", "")
abm.message = await self.parse_satori_elements(content)
# 消息链
abm.message = []
content = message.get("content", "")
quote = message.get("quote")
content_for_parsing = content # 副本
# 提取<quote>标签
if "<quote" in content:
try:
quote_info = await self._extract_quote_element(content)
if quote_info:
quote = quote_info["quote"]
content_for_parsing = quote_info["content_without_quote"]
except Exception as e:
logger.error(f"解析<quote>标签时发生错误: {e}, 错误内容: {content}")
if quote:
# 引用消息
quote_abm = await self._convert_quote_message(quote)
if quote_abm:
sender_id = quote_abm.sender.user_id
if isinstance(sender_id, str) and sender_id.isdigit():
sender_id = int(sender_id)
elif not isinstance(sender_id, int):
sender_id = 0 # 默认值
reply_component = Reply(
id=quote_abm.message_id,
chain=quote_abm.message,
sender_id=quote_abm.sender.user_id,
sender_nickname=quote_abm.sender.nickname,
time=quote_abm.timestamp,
message_str=quote_abm.message_str,
text=quote_abm.message_str,
qq=sender_id,
)
abm.message.append(reply_component)
# 解析消息内容
content_elements = await self.parse_satori_elements(content_for_parsing)
abm.message.extend(content_elements)
# parse message_str
abm.message_str = ""
for comp in abm.message:
for comp in content_elements:
if isinstance(comp, Plain):
abm.message_str += comp.text
@@ -333,6 +386,163 @@ class SatoriPlatformAdapter(Platform):
logger.error(f"转换 Satori 消息失败: {e}")
return None
def _extract_namespace_prefixes(self, content: str) -> set:
"""提取XML内容中的命名空间前缀"""
prefixes = set()
# 查找所有标签
i = 0
while i < len(content):
# 查找开始标签
if content[i] == "<" and i + 1 < len(content) and content[i + 1] != "/":
# 找到标签结束位置
tag_end = content.find(">", i)
if tag_end != -1:
# 提取标签内容
tag_content = content[i + 1 : tag_end]
# 检查是否有命名空间前缀
if ":" in tag_content and "xmlns:" not in tag_content:
# 分割标签名
parts = tag_content.split()
if parts:
tag_name = parts[0]
if ":" in tag_name:
prefix = tag_name.split(":")[0]
# 确保是有效的命名空间前缀
if (
prefix.isalnum()
or prefix.replace("_", "").isalnum()
):
prefixes.add(prefix)
i = tag_end + 1
else:
i += 1
# 查找结束标签
elif content[i] == "<" and i + 1 < len(content) and content[i + 1] == "/":
# 找到标签结束位置
tag_end = content.find(">", i)
if tag_end != -1:
# 提取标签内容
tag_content = content[i + 2 : tag_end]
# 检查是否有命名空间前缀
if ":" in tag_content:
prefix = tag_content.split(":")[0]
# 确保是有效的命名空间前缀
if prefix.isalnum() or prefix.replace("_", "").isalnum():
prefixes.add(prefix)
i = tag_end + 1
else:
i += 1
else:
i += 1
return prefixes
async def _extract_quote_element(self, content: str) -> Optional[dict]:
"""提取<quote>标签信息"""
try:
# 处理命名空间前缀问题
processed_content = content
if ":" in content and not content.startswith("<root"):
prefixes = self._extract_namespace_prefixes(content)
# 构建命名空间声明
ns_declarations = " ".join(
[
f'xmlns:{prefix}="http://temp.uri/{prefix}"'
for prefix in prefixes
]
)
# 包装内容
processed_content = f"<root {ns_declarations}>{content}</root>"
elif not content.startswith("<root"):
processed_content = f"<root>{content}</root>"
else:
processed_content = content
root = ET.fromstring(processed_content)
# 查找<quote>标签
quote_element = None
for elem in root.iter():
tag_name = elem.tag
if "}" in tag_name:
tag_name = tag_name.split("}")[1]
if tag_name.lower() == "quote":
quote_element = elem
break
if quote_element is not None:
# 提取quote标签的属性
quote_id = quote_element.get("id", "")
# 提取<quote>标签内部的内容
inner_content = ""
if quote_element.text:
inner_content += quote_element.text
for child in quote_element:
inner_content += ET.tostring(
child, encoding="unicode", method="xml"
)
if child.tail:
inner_content += child.tail
# 构造移除了<quote>标签的内容
content_without_quote = content.replace(
ET.tostring(quote_element, encoding="unicode", method="xml"), ""
)
return {
"quote": {"id": quote_id, "content": inner_content},
"content_without_quote": content_without_quote,
}
return None
except Exception as e:
logger.error(f"提取<quote>标签时发生错误: {e}")
return None
async def _convert_quote_message(self, quote: dict) -> Optional[AstrBotMessage]:
"""转换引用消息"""
try:
quote_abm = AstrBotMessage()
quote_abm.message_id = quote.get("id", "")
# 解析引用消息的发送者
quote_author = quote.get("author", {})
if quote_author:
quote_abm.sender = MessageMember(
user_id=quote_author.get("id", ""),
nickname=quote_author.get("nick", quote_author.get("name", "")),
)
else:
# 如果没有作者信息,使用默认值
quote_abm.sender = MessageMember(
user_id=quote.get("user_id", ""),
nickname="内容",
)
# 解析引用消息内容
quote_content = quote.get("content", "")
quote_abm.message = await self.parse_satori_elements(quote_content)
quote_abm.message_str = ""
for comp in quote_abm.message:
if isinstance(comp, Plain):
quote_abm.message_str += comp.text
quote_abm.timestamp = int(quote.get("timestamp", time.time()))
# 如果没有任何内容,使用默认文本
if not quote_abm.message_str.strip():
quote_abm.message_str = "[引用消息]"
return quote_abm
except Exception as e:
logger.error(f"转换引用消息失败: {e}")
return None
async def parse_satori_elements(self, content: str) -> list:
"""解析 Satori 消息元素"""
elements = []
@@ -341,12 +551,35 @@ class SatoriPlatformAdapter(Platform):
return elements
try:
wrapped_content = f"<root>{content}</root>"
root = ET.fromstring(wrapped_content)
# 处理命名空间前缀问题
processed_content = content
if ":" in content and not content.startswith("<root"):
prefixes = self._extract_namespace_prefixes(content)
# 构建命名空间声明
ns_declarations = " ".join(
[
f'xmlns:{prefix}="http://temp.uri/{prefix}"'
for prefix in prefixes
]
)
# 包装内容
processed_content = f"<root {ns_declarations}>{content}</root>"
elif not content.startswith("<root"):
processed_content = f"<root>{content}</root>"
else:
processed_content = content
root = ET.fromstring(processed_content)
await self._parse_xml_node(root, elements)
except ET.ParseError as e:
raise ValueError(f"解析 Satori 元素时发生解析错误: {e}")
logger.error(f"解析 Satori 元素时发生解析错误: {e}, 错误内容: {content}")
# 如果解析失败,将整个内容当作纯文本
if content.strip():
elements.append(Plain(text=content))
except Exception as e:
logger.error(f"解析 Satori 元素时发生未知错误: {e}")
raise e
# 如果没有解析到任何元素,将整个内容当作纯文本
@@ -361,7 +594,12 @@ class SatoriPlatformAdapter(Platform):
elements.append(Plain(text=node.text))
for child in node:
tag_name = child.tag.lower()
# 获取标签名,去除命名空间前缀
tag_name = child.tag
if "}" in tag_name:
tag_name = tag_name.split("}")[1]
tag_name = tag_name.lower()
attrs = child.attrib
if tag_name == "at":
@@ -372,31 +610,59 @@ class SatoriPlatformAdapter(Platform):
src = attrs.get("src", "")
if not src:
continue
if src.startswith("data:image/"):
src = src.split(",")[1]
elements.append(Image.fromBase64(src))
elif src.startswith("http"):
elements.append(Image.fromURL(src))
else:
logger.error(f"未知的图片 src 格式: {str(src)[:16]}")
elements.append(Image(file=src))
elif tag_name == "file":
src = attrs.get("src", "")
name = attrs.get("name", "文件")
if src:
elements.append(File(file=src, name=name))
elements.append(File(name=name, file=src))
elif tag_name in ("audio", "record"):
src = attrs.get("src", "")
if not src:
continue
if src.startswith("data:audio/"):
src = src.split(",")[1]
elements.append(Record.fromBase64(src))
elif src.startswith("http"):
elements.append(Record.fromURL(src))
elements.append(Record(file=src))
elif tag_name == "quote":
# quote标签已经被特殊处理
pass
elif tag_name == "face":
face_id = attrs.get("id", "")
face_name = attrs.get("name", "")
face_type = attrs.get("type", "")
if face_name:
elements.append(Plain(text=f"[表情:{face_name}]"))
elif face_id and face_type:
elements.append(Plain(text=f"[表情ID:{face_id},类型:{face_type}]"))
elif face_id:
elements.append(Plain(text=f"[表情ID:{face_id}]"))
else:
logger.error(f"未知的音频 src 格式: {str(src)[:16]}")
elements.append(Plain(text="[表情]"))
elif tag_name == "ark":
# 作为纯文本添加到消息链中
data = attrs.get("data", "")
if data:
import html
decoded_data = html.unescape(data)
elements.append(Plain(text=f"[ARK卡片数据: {decoded_data}]"))
else:
elements.append(Plain(text="[ARK卡片]"))
elif tag_name == "json":
# JSON标签 视为ARK卡片消息
data = attrs.get("data", "")
if data:
import html
decoded_data = html.unescape(data)
elements.append(Plain(text=f"[ARK卡片数据: {decoded_data}]"))
else:
elements.append(Plain(text="[JSON卡片]"))
else:
# 未知标签,递归处理其内容
@@ -17,6 +17,15 @@ class SatoriPlatformEvent(AstrMessageEvent):
session_id: str,
adapter: "SatoriPlatformAdapter",
):
# 更新平台元数据
if adapter and hasattr(adapter, "logins") and adapter.logins:
current_login = adapter.logins[0]
platform_name = current_login.get("platform", "satori")
user = current_login.get("user", {})
user_id = user.get("id", "") if user else ""
if not platform_meta.id and user_id:
platform_meta.id = f"{platform_name}({user_id})"
super().__init__(message_str, message_obj, platform_meta, session_id)
self.adapter = adapter
self.platform = None