Files
AstrBot/astrbot/core/platform/manager.py
T
shuiping233 866e546b59 feat: integrates KOOK platform adapter (#5658)
* feat: 将kook适配器插件并入astrbot官方适配器目录中

* refactor: 重命名函数名为 _warp_message

* refactor: 使用Protocol替换Union类型

* bugfix: 修复base64前缀处理问题

* refactor: 抛出的错误不再包含"[kook]"

* refactor: 添加读取本地文件时的路径安全检查

* refactor: 卡片消息解析失败时会打印错误信息

* refactor: 添加处理接收卡片消息内的图片url时的安全校验

* refactor: 安全处理ws需要重连的情况

* Revert "refactor: 使用Protocol替换Union类型"

This reverts commit 58e0dceeb20c3d7dddb16f623fd3bbdcfa632173.

* feat: 添加获取机器人名称的实现

* refactor: 让send_by_session发送主动消息时正确传入当前消息链的文本消息内容

* refactor: 统一处理适配器配置相关内容,处理仪表盘处传入的配置,并添加仪表盘的kook适配器配置页面的i18n文本

* unittest: 添加kook适配器的单元测试,虽然没覆盖多少单测

* unittest: TEST_DATA_DIR用更安全的路径

* refactor: KookConfig使用了更好的默认值处理方式

* refactor: 移除kook_adapter 的config字段重复定义

* refactor: 隐藏获取kook gateway时url里的token,防止把token打印出来

* refactor: KookConfig.pretty_jsons使用*来屏蔽token内容

* bugfix: 修复主动发送消息时,调用了父方法`send_by_session`可能导致指标被重复上传的bug

* refactor: 优化upload_asset的路径处理报错

* bugfix: 修复kook ws心跳间隔可能会出现负数时间的bug

* refactor: KookClient移到KookPlatformAdapter.__init__里初始化

* bugfix: 修复处理base64 url 多替换了/而报错的bug

* refactor: kook适配器上传文件失败时,会抛出错误

* chore: 移除一条注释

* refactor: 移除没用的return

* refactor: 即使消息链中有消息发送失败了,也尽可能将其他消息发送出去,并把报错信息也发送出去

* refactor: 增强上传任务失败时的错误处理,使其发生错误时尽力而为发送其余消息

* refactor: 发送到消息频道的报错消息加了个⚠️,小巧思这块?

* refactor: 咱们在写适配器啊,要什么小巧思呢,小巧思给上游插件开发弄不好么)

* refactor: enhance Kook adapter with kmarkdown parsing and improve file URL handling

* refactor: extract card message parsing logic into a separate method

* feat: add kook_bot_nickname configuration to ignore messages from specific nicknames

* refactor: remove commented-out code and clean up file upload error handling

* fix: remove redundant prefix handling for file URLs in asset upload

---------

Co-authored-by: Soulter <905617992@qq.com>
2026-03-03 15:08:16 +08:00

341 lines
13 KiB
Python

import asyncio
import traceback
from asyncio import Queue
from dataclasses import dataclass
from astrbot.core import logger
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map
from astrbot.core.utils.webhook_utils import ensure_platform_webhook_config
from .platform import Platform, PlatformStatus
from .register import platform_cls_map
from .sources.webchat.webchat_adapter import WebChatAdapter
# Pairs the two asyncio tasks that PlatformManager creates for one platform
# adapter, so both can be cancelled and awaited together on shutdown.
@dataclass
class PlatformTasks:
    # Task created from the adapter's run() coroutine.
    run: asyncio.Task
    # Task created from PlatformManager._task_wrapper, which awaits `run`.
    wrapper: asyncio.Task
class PlatformManager:
def __init__(self, config: AstrBotConfig, event_queue: Queue) -> None:
self.platform_insts: list[Platform] = []
"""加载的 Platform 的实例"""
self._inst_map: dict[str, dict] = {}
self._platform_tasks: dict[str, PlatformTasks] = {}
self.astrbot_config = config
self.platforms_config = config["platform"]
self.settings = config["platform_settings"]
"""NOTE: 这里是 default 的配置文件,以保证最大的兼容性;
这个配置中的 unique_session 需要特殊处理,
约定整个项目中对 unique_session 的引用都从 default 的配置中获取"""
self.event_queue = event_queue
def _is_valid_platform_id(self, platform_id: str | None) -> bool:
if not platform_id:
return False
return ":" not in platform_id and "!" not in platform_id
def _sanitize_platform_id(self, platform_id: str | None) -> tuple[str | None, bool]:
if not platform_id:
return platform_id, False
sanitized = platform_id.replace(":", "_").replace("!", "_")
return sanitized, sanitized != platform_id
def _start_platform_task(self, task_name: str, inst: Platform) -> None:
run_task = asyncio.create_task(inst.run(), name=task_name)
wrapper_task = asyncio.create_task(
self._task_wrapper(run_task, platform=inst),
name=f"{task_name}_wrapper",
)
self._platform_tasks[inst.client_self_id] = PlatformTasks(
run=run_task,
wrapper=wrapper_task,
)
async def _stop_platform_task(self, client_id: str) -> None:
tasks = self._platform_tasks.pop(client_id, None)
if not tasks:
return
for task in (tasks.run, tasks.wrapper):
if not task.done():
task.cancel()
await asyncio.gather(tasks.run, tasks.wrapper, return_exceptions=True)
async def _terminate_inst_and_tasks(self, inst: Platform) -> None:
client_id = inst.client_self_id
try:
if getattr(inst, "terminate", None):
try:
await inst.terminate()
except asyncio.CancelledError:
raise
except Exception as e:
logger.error(
"终止平台适配器失败: client_id=%s, error=%s",
client_id,
e,
)
logger.error(traceback.format_exc())
finally:
await self._stop_platform_task(client_id)
async def initialize(self) -> None:
"""初始化所有平台适配器"""
for platform in self.platforms_config:
try:
if ensure_platform_webhook_config(platform):
self.astrbot_config.save_config()
await self.load_platform(platform)
except Exception as e:
logger.error(f"初始化 {platform} 平台适配器失败: {e}")
# 网页聊天
webchat_inst = WebChatAdapter({}, self.settings, self.event_queue)
self.platform_insts.append(webchat_inst)
self._start_platform_task("webchat", webchat_inst)
async def load_platform(self, platform_config: dict) -> None:
"""实例化一个平台"""
# 动态导入
try:
if not platform_config["enable"]:
return
platform_id = platform_config.get("id")
if not self._is_valid_platform_id(platform_id):
sanitized_id, changed = self._sanitize_platform_id(platform_id)
if sanitized_id and changed:
logger.warning(
"平台 ID %r 包含非法字符 ':''!',已替换为 %r",
platform_id,
sanitized_id,
)
platform_config["id"] = sanitized_id
self.astrbot_config.save_config()
else:
logger.error(
f"平台 ID {platform_id!r} 不能为空,跳过加载该平台适配器。",
)
return
logger.info(
f"载入 {platform_config['type']}({platform_config['id']}) 平台适配器 ...",
)
match platform_config["type"]:
case "aiocqhttp":
from .sources.aiocqhttp.aiocqhttp_platform_adapter import (
AiocqhttpAdapter, # noqa: F401
)
case "qq_official":
from .sources.qqofficial.qqofficial_platform_adapter import (
QQOfficialPlatformAdapter, # noqa: F401
)
case "qq_official_webhook":
from .sources.qqofficial_webhook.qo_webhook_adapter import (
QQOfficialWebhookPlatformAdapter, # noqa: F401
)
case "lark":
from .sources.lark.lark_adapter import (
LarkPlatformAdapter, # noqa: F401
)
case "dingtalk":
from .sources.dingtalk.dingtalk_adapter import (
DingtalkPlatformAdapter, # noqa: F401
)
case "telegram":
from .sources.telegram.tg_adapter import (
TelegramPlatformAdapter, # noqa: F401
)
case "wecom":
from .sources.wecom.wecom_adapter import (
WecomPlatformAdapter, # noqa: F401
)
case "wecom_ai_bot":
from .sources.wecom_ai_bot.wecomai_adapter import (
WecomAIBotAdapter, # noqa: F401
)
case "weixin_official_account":
from .sources.weixin_official_account.weixin_offacc_adapter import (
WeixinOfficialAccountPlatformAdapter, # noqa: F401
)
case "discord":
from .sources.discord.discord_platform_adapter import (
DiscordPlatformAdapter, # noqa: F401
)
case "misskey":
from .sources.misskey.misskey_adapter import (
MisskeyPlatformAdapter, # noqa: F401
)
case "slack":
from .sources.slack.slack_adapter import SlackAdapter # noqa: F401
case "satori":
from .sources.satori.satori_adapter import (
SatoriPlatformAdapter, # noqa: F401
)
case "line":
from .sources.line.line_adapter import (
LinePlatformAdapter, # noqa: F401
)
case "kook":
from .sources.kook.kook_adapter import (
KookPlatformAdapter, # noqa: F401
)
except (ImportError, ModuleNotFoundError) as e:
logger.error(
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->平台日志->安装Pip库 中安装依赖库。",
)
except Exception as e:
logger.error(f"加载平台适配器 {platform_config['type']} 失败,原因:{e}")
if platform_config["type"] not in platform_cls_map:
logger.error(
f"未找到适用于 {platform_config['type']}({platform_config['id']}) 平台适配器,请检查是否已经安装或者名称填写错误",
)
return
cls_type = platform_cls_map[platform_config["type"]]
inst: Platform = cls_type(platform_config, self.settings, self.event_queue)
self._inst_map[platform_config["id"]] = {
"inst": inst,
"client_id": inst.client_self_id,
}
self.platform_insts.append(inst)
self._start_platform_task(
f"platform_{platform_config['type']}_{platform_config['id']}",
inst,
)
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnPlatformLoadedEvent,
)
for handler in handlers:
try:
logger.info(
f"hook(on_platform_loaded) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}",
)
await handler.handler()
except Exception:
logger.error(traceback.format_exc())
async def _task_wrapper(
self, task: asyncio.Task, platform: Platform | None = None
) -> None:
# 设置平台状态为运行中
if platform:
platform.status = PlatformStatus.RUNNING
try:
await task
except asyncio.CancelledError:
if platform:
platform.status = PlatformStatus.STOPPED
except Exception as e:
error_msg = str(e)
tb_str = traceback.format_exc()
logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}")
for line in tb_str.split("\n"):
logger.error(f"| {line}")
logger.error("-------")
# 记录错误到平台实例
if platform:
platform.record_error(error_msg, tb_str)
async def reload(self, platform_config: dict) -> None:
await self.terminate_platform(platform_config["id"])
if platform_config["enable"]:
await self.load_platform(platform_config)
# 和配置文件保持同步
config_ids = [provider["id"] for provider in self.platforms_config]
for key in list(self._inst_map.keys()):
if key not in config_ids:
await self.terminate_platform(key)
async def terminate_platform(self, platform_id: str) -> None:
if platform_id in self._inst_map:
logger.info(f"正在尝试终止 {platform_id} 平台适配器 ...")
# client_id = self._inst_map.pop(platform_id, None)
info = self._inst_map.pop(platform_id)
client_id = info["client_id"]
inst: Platform = info["inst"]
try:
self.platform_insts.remove(
next(
inst
for inst in self.platform_insts
if inst.client_self_id == client_id
),
)
except Exception:
logger.warning(f"可能未完全移除 {platform_id} 平台适配器")
await self._terminate_inst_and_tasks(inst)
async def terminate(self) -> None:
terminated_client_ids: set[str] = set()
for platform_id in list(self._inst_map.keys()):
info = self._inst_map.get(platform_id)
if info:
terminated_client_ids.add(info["client_id"])
await self.terminate_platform(platform_id)
for inst in list(self.platform_insts):
client_id = inst.client_self_id
if client_id in terminated_client_ids:
continue
await self._terminate_inst_and_tasks(inst)
self.platform_insts.clear()
self._inst_map.clear()
self._platform_tasks.clear()
def get_insts(self):
return self.platform_insts
def get_all_stats(self) -> dict:
"""获取所有平台的统计信息
Returns:
包含所有平台统计信息的字典
"""
stats_list = []
total_errors = 0
running_count = 0
error_count = 0
for inst in self.platform_insts:
try:
stat = inst.get_stats()
stats_list.append(stat)
total_errors += stat.get("error_count", 0)
if stat.get("status") == PlatformStatus.RUNNING.value:
running_count += 1
elif stat.get("status") == PlatformStatus.ERROR.value:
error_count += 1
except Exception as e:
# 如果获取统计信息失败,记录基本信息
logger.warning(f"获取平台统计信息失败: {e}")
stats_list.append(
{
"id": getattr(inst, "config", {}).get("id", "unknown"),
"type": "unknown",
"status": "unknown",
"error_count": 0,
"last_error": None,
}
)
return {
"platforms": stats_list,
"summary": {
"total": len(stats_list),
"running": running_count,
"error": error_count,
"total_errors": total_errors,
},
}