feat: supports 小黑盒语音机器人
This commit is contained in:
@@ -180,6 +180,10 @@ class PlatformManager:
|
||||
from .sources.line.line_adapter import (
|
||||
LinePlatformAdapter, # noqa: F401
|
||||
)
|
||||
case "heihe":
|
||||
from .sources.heihe.heihe_adapter import (
|
||||
HeihePlatformAdapter, # noqa: F401
|
||||
)
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.error(
|
||||
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->平台日志->安装Pip库 中安装依赖库。",
|
||||
|
||||
@@ -0,0 +1,467 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
import websockets
|
||||
from websockets.asyncio.client import ClientConnection, connect
|
||||
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.api.message_components import At, Image, Plain
|
||||
from astrbot.api.platform import (
|
||||
AstrBotMessage,
|
||||
Group,
|
||||
MessageMember,
|
||||
MessageType,
|
||||
Platform,
|
||||
PlatformMetadata,
|
||||
)
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
|
||||
from ...register import register_platform_adapter
|
||||
from .heihe_event import HeiheMessageEvent
|
||||
|
||||
HEIHE_CONFIG_METADATA = {
|
||||
"heihe_ws_url": {
|
||||
"description": "Heihe WebSocket URL",
|
||||
"type": "string",
|
||||
"hint": "黑盒 WebSocket 地址,示例:wss://chat.xiaoheihe.cn/chatroom/api/v2/websocket。",
|
||||
},
|
||||
"heihe_token": {
|
||||
"description": "Bot Token",
|
||||
"type": "string",
|
||||
"hint": "黑盒 Bot Token。适配器会放入 Authorization: Bearer <token>。",
|
||||
},
|
||||
"heihe_heartbeat_interval": {
|
||||
"description": "Heartbeat Interval (seconds)",
|
||||
"type": "int",
|
||||
"hint": "发送心跳包间隔。<=0 表示关闭主动心跳。",
|
||||
},
|
||||
"heihe_reconnect_delay": {
|
||||
"description": "Reconnect Delay (seconds)",
|
||||
"type": "int",
|
||||
"hint": "WebSocket 断开后的重连等待时间。",
|
||||
},
|
||||
}
|
||||
|
||||
HEIHE_I18N_RESOURCES = {
|
||||
"zh-CN": {
|
||||
"heihe_ws_url": {
|
||||
"description": "黑盒 WebSocket 地址",
|
||||
"hint": "示例:wss://chat.xiaoheihe.cn/chatroom/api/v2/websocket。",
|
||||
},
|
||||
"heihe_token": {
|
||||
"description": "机器人 Token",
|
||||
"hint": "会以 Authorization: Bearer <token> 方式发送。",
|
||||
},
|
||||
"heihe_heartbeat_interval": {
|
||||
"description": "心跳间隔(秒)",
|
||||
"hint": "设置 <=0 将关闭主动心跳。",
|
||||
},
|
||||
"heihe_reconnect_delay": {
|
||||
"description": "重连间隔(秒)",
|
||||
"hint": "WebSocket 断开后的重连等待时间。",
|
||||
},
|
||||
},
|
||||
"en-US": {
|
||||
"heihe_ws_url": {
|
||||
"description": "Heihe WebSocket URL",
|
||||
"hint": "Example: wss://chat.xiaoheihe.cn/chatroom/api/v2/websocket.",
|
||||
},
|
||||
"heihe_token": {
|
||||
"description": "Bot Token",
|
||||
"hint": "Sent as Authorization: Bearer <token>.",
|
||||
},
|
||||
"heihe_heartbeat_interval": {
|
||||
"description": "Heartbeat Interval (seconds)",
|
||||
"hint": "Set <=0 to disable active heartbeat.",
|
||||
},
|
||||
"heihe_reconnect_delay": {
|
||||
"description": "Reconnect Delay (seconds)",
|
||||
"hint": "Delay before reconnecting after disconnect.",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@register_platform_adapter(
|
||||
"heihe",
|
||||
"黑盒机器人(WebSocket)适配器",
|
||||
support_streaming_message=False,
|
||||
default_config_tmpl={
|
||||
"id": "heihe",
|
||||
"type": "heihe",
|
||||
"enable": False,
|
||||
"heihe_ws_url": "wss://chat.xiaoheihe.cn/chatroom/api/v2/websocket",
|
||||
"heihe_token": "",
|
||||
"heihe_bot_id": "",
|
||||
"heihe_auto_reconnect": True,
|
||||
"heihe_heartbeat_interval": 20,
|
||||
"heihe_reconnect_delay": 5,
|
||||
"heihe_ignore_self_message": True,
|
||||
},
|
||||
config_metadata=HEIHE_CONFIG_METADATA,
|
||||
i18n_resources=HEIHE_I18N_RESOURCES,
|
||||
)
|
||||
class HeihePlatformAdapter(Platform):
|
||||
def __init__(
|
||||
self,
|
||||
platform_config: dict,
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(platform_config, event_queue)
|
||||
self.settings = platform_settings
|
||||
|
||||
self.ws_url = str(platform_config.get("heihe_ws_url", "")).strip()
|
||||
self.token = str(platform_config.get("heihe_token", "")).strip()
|
||||
self.bot_id = str(platform_config.get("heihe_bot_id", "")).strip()
|
||||
self.auto_reconnect = bool(platform_config.get("heihe_auto_reconnect", True))
|
||||
self.heartbeat_interval = int(
|
||||
cast(int, platform_config.get("heihe_heartbeat_interval", 20)),
|
||||
)
|
||||
self.reconnect_delay = int(
|
||||
cast(int, platform_config.get("heihe_reconnect_delay", 5)),
|
||||
)
|
||||
self.ignore_self_message = bool(
|
||||
platform_config.get("heihe_ignore_self_message", True),
|
||||
)
|
||||
|
||||
if not self.ws_url:
|
||||
raise ValueError("heihe_ws_url 不能为空。")
|
||||
|
||||
self.metadata = PlatformMetadata(
|
||||
name="heihe",
|
||||
description="黑盒机器人(WebSocket)适配器",
|
||||
id=cast(str, self.config.get("id", "heihe")),
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
self.ws: ClientConnection | None = None
|
||||
self.running = False
|
||||
self.heartbeat_task: asyncio.Task | None = None
|
||||
self._last_heartbeat_ts = 0
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return self.metadata
|
||||
|
||||
async def run(self) -> None:
|
||||
self.running = True
|
||||
while self.running:
|
||||
try:
|
||||
await self._connect_and_loop()
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
logger.warning("[heihe] websocket disconnected: %s", e)
|
||||
except Exception as e:
|
||||
logger.error("[heihe] websocket failed: %s", e)
|
||||
|
||||
if not self.running:
|
||||
break
|
||||
if not self.auto_reconnect:
|
||||
break
|
||||
await asyncio.sleep(max(1, self.reconnect_delay))
|
||||
|
||||
async def terminate(self) -> None:
|
||||
self.running = False
|
||||
if self.heartbeat_task:
|
||||
self.heartbeat_task.cancel()
|
||||
try:
|
||||
await self.heartbeat_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if self.ws:
|
||||
try:
|
||||
await self.ws.close()
|
||||
except Exception:
|
||||
pass
|
||||
self.ws = None
|
||||
|
||||
async def send_by_session(
|
||||
self,
|
||||
session: MessageSesion,
|
||||
message_chain: MessageChain,
|
||||
) -> None:
|
||||
await HeiheMessageEvent.send_with_adapter(
|
||||
self,
|
||||
message_chain,
|
||||
session.session_id,
|
||||
)
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
async def send_payload(self, payload: Mapping[str, Any]) -> None:
|
||||
if not self.ws:
|
||||
raise RuntimeError("[heihe] websocket not connected")
|
||||
if self.ws.close_code is not None:
|
||||
raise RuntimeError("[heihe] websocket already closed")
|
||||
|
||||
body = dict(payload)
|
||||
body.setdefault("timestamp", int(time.time()))
|
||||
await self.ws.send(json.dumps(body, ensure_ascii=False))
|
||||
|
||||
async def _connect_and_loop(self) -> None:
|
||||
logger.info("[heihe] connecting websocket: %s", self.ws_url)
|
||||
|
||||
headers: dict[str, str] = {}
|
||||
if self.token:
|
||||
headers["Authorization"] = f"Bearer {self.token}"
|
||||
headers["X-Token"] = self.token
|
||||
|
||||
websocket = await connect(
|
||||
self.ws_url,
|
||||
additional_headers=headers,
|
||||
max_size=10 * 1024 * 1024,
|
||||
ping_interval=None,
|
||||
)
|
||||
self.ws = websocket
|
||||
logger.info("[heihe] websocket connected")
|
||||
|
||||
if self.heartbeat_interval > 0:
|
||||
self.heartbeat_task = asyncio.create_task(self._heartbeat_loop())
|
||||
|
||||
try:
|
||||
async for raw in websocket:
|
||||
await self._handle_incoming(raw)
|
||||
finally:
|
||||
if self.heartbeat_task:
|
||||
self.heartbeat_task.cancel()
|
||||
try:
|
||||
await self.heartbeat_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self.heartbeat_task = None
|
||||
if self.ws:
|
||||
try:
|
||||
await self.ws.close()
|
||||
except Exception:
|
||||
pass
|
||||
self.ws = None
|
||||
|
||||
async def _heartbeat_loop(self) -> None:
|
||||
try:
|
||||
while self.running and self.ws and self.ws.close_code is None:
|
||||
await asyncio.sleep(self.heartbeat_interval)
|
||||
self._last_heartbeat_ts = int(time.time())
|
||||
await self.send_payload(
|
||||
{
|
||||
"type": "ping",
|
||||
"ping": self._last_heartbeat_ts,
|
||||
},
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning("[heihe] heartbeat error: %s", e)
|
||||
|
||||
async def _handle_incoming(self, raw: Any) -> None:
|
||||
if isinstance(raw, bytes):
|
||||
try:
|
||||
raw = raw.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
return
|
||||
if not isinstance(raw, str):
|
||||
return
|
||||
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("[heihe] skip non-json frame: %s", raw[:200])
|
||||
return
|
||||
|
||||
if isinstance(data, list):
|
||||
for item in data:
|
||||
if isinstance(item, dict):
|
||||
await self._handle_packet(item)
|
||||
return
|
||||
if isinstance(data, dict):
|
||||
await self._handle_packet(data)
|
||||
|
||||
async def _handle_packet(self, packet: dict[str, Any]) -> None:
|
||||
if "ping" in packet:
|
||||
await self.send_payload({"type": "pong", "pong": packet.get("ping")})
|
||||
return
|
||||
if str(packet.get("type", "")).lower() == "ping":
|
||||
await self.send_payload({"type": "pong", "pong": packet.get("ping")})
|
||||
return
|
||||
|
||||
event_type = str(
|
||||
packet.get("event")
|
||||
or packet.get("event_type")
|
||||
or packet.get("type")
|
||||
or packet.get("topic")
|
||||
or "",
|
||||
).lower()
|
||||
payload_obj = packet.get("data")
|
||||
payload = payload_obj if isinstance(payload_obj, dict) else packet
|
||||
|
||||
if not self._is_message_event(event_type, payload):
|
||||
return
|
||||
|
||||
abm = self._convert_message(payload, packet)
|
||||
if not abm:
|
||||
return
|
||||
await self.handle_msg(abm)
|
||||
|
||||
@staticmethod
|
||||
def _is_message_event(event_type: str, payload: Mapping[str, Any]) -> bool:
|
||||
if "message" in event_type:
|
||||
return True
|
||||
keys = payload.keys()
|
||||
return "content" in keys or "text" in keys or "message" in keys
|
||||
|
||||
def _convert_message(
|
||||
self,
|
||||
payload: Mapping[str, Any],
|
||||
raw_packet: Mapping[str, Any],
|
||||
) -> AstrBotMessage | None:
|
||||
message_obj = payload.get("message")
|
||||
message = message_obj if isinstance(message_obj, Mapping) else payload
|
||||
|
||||
sender_data_obj = (
|
||||
payload.get("sender") or payload.get("author") or payload.get("user") or {}
|
||||
)
|
||||
sender_data = sender_data_obj if isinstance(sender_data_obj, Mapping) else {}
|
||||
sender_id = str(
|
||||
sender_data.get("id")
|
||||
or sender_data.get("user_id")
|
||||
or payload.get("sender_id")
|
||||
or payload.get("user_id")
|
||||
or "",
|
||||
).strip()
|
||||
sender_name = str(
|
||||
sender_data.get("nickname")
|
||||
or sender_data.get("name")
|
||||
or sender_data.get("username")
|
||||
or sender_id
|
||||
or "unknown",
|
||||
)
|
||||
|
||||
self_id = str(
|
||||
payload.get("self_id")
|
||||
or payload.get("bot_id")
|
||||
or self.bot_id
|
||||
or self.meta().id,
|
||||
)
|
||||
if self.ignore_self_message and sender_id and self_id and sender_id == self_id:
|
||||
return None
|
||||
|
||||
channel_id = str(
|
||||
payload.get("channel_id")
|
||||
or payload.get("room_id")
|
||||
or payload.get("chat_id")
|
||||
or payload.get("session_id")
|
||||
or "",
|
||||
).strip()
|
||||
guild_id = str(
|
||||
payload.get("guild_id")
|
||||
or payload.get("server_id")
|
||||
or payload.get("group_id")
|
||||
or "",
|
||||
).strip()
|
||||
is_private = bool(payload.get("is_private", False))
|
||||
if str(payload.get("message_type", "")).lower() in {"private", "friend", "dm"}:
|
||||
is_private = True
|
||||
|
||||
session_id = channel_id or sender_id
|
||||
if not session_id:
|
||||
return None
|
||||
|
||||
text = str(message.get("content") or message.get("text") or "").strip()
|
||||
components = self._build_components(text, payload)
|
||||
if not components:
|
||||
return None
|
||||
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = self_id
|
||||
abm.message_id = str(
|
||||
message.get("id")
|
||||
or message.get("message_id")
|
||||
or payload.get("message_id")
|
||||
or payload.get("msg_id")
|
||||
or uuid.uuid4().hex
|
||||
)
|
||||
timestamp_raw = (
|
||||
payload.get("timestamp")
|
||||
or payload.get("time")
|
||||
or message.get("timestamp")
|
||||
or message.get("time")
|
||||
)
|
||||
abm.timestamp = int(time.time())
|
||||
if isinstance(timestamp_raw, int):
|
||||
abm.timestamp = (
|
||||
timestamp_raw // 1000
|
||||
if timestamp_raw > 1_000_000_000_000
|
||||
else timestamp_raw
|
||||
)
|
||||
|
||||
if not is_private and (channel_id or guild_id):
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
abm.group = Group(
|
||||
group_id=guild_id or channel_id, group_name=guild_id or ""
|
||||
)
|
||||
else:
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
|
||||
abm.session_id = session_id
|
||||
abm.sender = MessageMember(user_id=sender_id or "unknown", nickname=sender_name)
|
||||
abm.message = components
|
||||
abm.message_str = self._build_message_str(components)
|
||||
abm.raw_message = dict(raw_packet)
|
||||
return abm
|
||||
|
||||
@staticmethod
|
||||
def _build_components(text: str, payload: Mapping[str, Any]) -> list:
|
||||
components: list = []
|
||||
if text:
|
||||
components.append(Plain(text=text))
|
||||
|
||||
mentions_obj = payload.get("mentions")
|
||||
if isinstance(mentions_obj, list):
|
||||
for mention in mentions_obj:
|
||||
if not isinstance(mention, Mapping):
|
||||
continue
|
||||
user_id = str(mention.get("user_id") or mention.get("id") or "").strip()
|
||||
name = str(mention.get("name") or mention.get("nickname") or "").strip()
|
||||
if user_id or name:
|
||||
components.append(At(qq=user_id, name=name))
|
||||
|
||||
attachments_obj = payload.get("attachments")
|
||||
if isinstance(attachments_obj, list):
|
||||
for item in attachments_obj:
|
||||
if not isinstance(item, Mapping):
|
||||
continue
|
||||
url = str(item.get("url") or item.get("file_url") or "").strip()
|
||||
if not url:
|
||||
continue
|
||||
kind = str(item.get("type") or item.get("media_type") or "").lower()
|
||||
if "image" in kind:
|
||||
components.append(Image.fromURL(url))
|
||||
else:
|
||||
components.append(Plain(text=f"[{kind or 'file'}] {url}"))
|
||||
return components
|
||||
|
||||
@staticmethod
|
||||
def _build_message_str(components: list) -> str:
|
||||
parts: list[str] = []
|
||||
for comp in components:
|
||||
if isinstance(comp, Plain):
|
||||
parts.append(comp.text)
|
||||
elif isinstance(comp, At):
|
||||
parts.append(f"@{comp.name or comp.qq}")
|
||||
elif isinstance(comp, Image):
|
||||
parts.append("[image]")
|
||||
else:
|
||||
parts.append(f"[{comp.type}]")
|
||||
return " ".join(i for i in parts if i).strip()
|
||||
|
||||
async def handle_msg(self, abm: AstrBotMessage) -> None:
|
||||
event = HeiheMessageEvent(
|
||||
message_str=abm.message_str,
|
||||
message_obj=abm,
|
||||
platform_meta=self.meta(),
|
||||
session_id=abm.session_id,
|
||||
adapter=self,
|
||||
)
|
||||
self.commit_event(event)
|
||||
@@ -0,0 +1,108 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import At, Image, Plain, Reply
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .heihe_adapter import HeihePlatformAdapter
|
||||
|
||||
|
||||
class HeiheMessageEvent(AstrMessageEvent):
|
||||
def __init__(
|
||||
self,
|
||||
message_str: str,
|
||||
message_obj,
|
||||
platform_meta,
|
||||
session_id: str,
|
||||
adapter: "HeihePlatformAdapter",
|
||||
) -> None:
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.adapter = adapter
|
||||
|
||||
@classmethod
|
||||
async def send_with_adapter(
|
||||
cls,
|
||||
adapter: "HeihePlatformAdapter",
|
||||
message: MessageChain,
|
||||
session_id: str,
|
||||
) -> None:
|
||||
payload = await cls._build_send_payload(message, session_id)
|
||||
await adapter.send_payload(payload)
|
||||
|
||||
async def send(self, message: MessageChain) -> None:
|
||||
await self.send_with_adapter(self.adapter, message, self.session_id)
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(
|
||||
self,
|
||||
generator: AsyncGenerator,
|
||||
use_fallback: bool = False,
|
||||
):
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return None
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
@classmethod
|
||||
async def _build_send_payload(
|
||||
cls,
|
||||
message: MessageChain,
|
||||
session_id: str,
|
||||
) -> dict[str, Any]:
|
||||
text_parts: list[str] = []
|
||||
segments: list[dict[str, Any]] = []
|
||||
|
||||
for component in message.chain:
|
||||
if isinstance(component, Plain):
|
||||
if component.text:
|
||||
text_parts.append(component.text)
|
||||
segments.append({"type": "text", "text": component.text})
|
||||
continue
|
||||
|
||||
if isinstance(component, At):
|
||||
at_name = str(component.name or component.qq or "").strip()
|
||||
if at_name:
|
||||
text_parts.append(f"@{at_name}")
|
||||
segments.append(
|
||||
{
|
||||
"type": "mention",
|
||||
"user_id": str(component.qq or ""),
|
||||
"name": at_name,
|
||||
},
|
||||
)
|
||||
continue
|
||||
|
||||
if isinstance(component, Reply):
|
||||
if component.id:
|
||||
segments.append({"type": "reply", "message_id": component.id})
|
||||
continue
|
||||
|
||||
if isinstance(component, Image):
|
||||
image_url = ""
|
||||
try:
|
||||
image_url = await component.register_to_file_service()
|
||||
except Exception as e:
|
||||
logger.debug("[heihe] image upload fallback failed: %s", e)
|
||||
|
||||
if image_url:
|
||||
segments.append({"type": "image", "url": image_url})
|
||||
text_parts.append("[image]")
|
||||
continue
|
||||
|
||||
content = "".join(text_parts).strip()
|
||||
payload: dict[str, Any] = {
|
||||
"action": "send_message",
|
||||
"channel_id": session_id,
|
||||
"content": content,
|
||||
"segments": segments,
|
||||
}
|
||||
return payload
|
||||
Reference in New Issue
Block a user