From d0b10b9195cb86689edc5c0a2e18da4a40600689 Mon Sep 17 00:00:00 2001 From: lxfight <1686540385@qq.com> Date: Fri, 20 Jun 2025 21:22:04 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=20Discord=20?= =?UTF-8?q?=E5=B9=B3=E5=8F=B0=E9=80=82=E9=85=8D=E5=99=A8=E5=8F=8A=E7=9B=B8?= =?UTF-8?q?=E5=85=B3=E7=BB=84=E4=BB=B6=EF=BC=8C=E6=94=AF=E6=8C=81=20Discor?= =?UTF-8?q?d=20Bot=20=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 添加了一个新的依赖 py-cord[speed] - 添加了针对 Discord 平台的 Discord Bot 适配器 --- astrbot/core/config/default.py | 17 + astrbot/core/platform/manager.py | 8 +- .../discord/discord_platform_adapter.py | 444 ++++++++++++++++++ .../sources/discord/discord_platform_event.py | 250 ++++++++++ pyproject.toml | 1 + 5 files changed, 719 insertions(+), 1 deletion(-) create mode 100644 astrbot/core/platform/sources/discord/discord_platform_adapter.py create mode 100644 astrbot/core/platform/sources/discord/discord_platform_event.py diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 32dd1b454..bedaaad7f 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -225,6 +225,13 @@ CONFIG_METADATA_2 = { "telegram_command_auto_refresh": True, "telegram_command_register_interval": 300, }, + "discord":{ + "id": "discord", + "type": "discord", + "enable": False, + "discord_token": "在此处填入你的Discord Bot Token", + "discord_proxy": "", + } }, "items": { "active_send_mode": { @@ -324,6 +331,16 @@ CONFIG_METADATA_2 = { "hint": "请务必填对,否则 @ 机器人将无法唤醒,只能通过前缀唤醒。", "obvious_hint": True, }, + "discord_token":{ + "description": "Discord Bot Token", + "type": "string", + "hint": "在此处填入你的Discord Bot Token" + }, + "discord_proxy":{ + "description": "Discord 代理地址", + "type": "string", + "hint": "可选的代理地址:http://ip:port" + }, }, }, "platform_settings": { diff --git a/astrbot/core/platform/manager.py b/astrbot/core/platform/manager.py index 494900564..0cec2de20 100644 --- a/astrbot/core/platform/manager.py +++ b/astrbot/core/platform/manager.py @@ -77,7 +77,13 @@ class PlatformManager: case "wecom": from .sources.wecom.wecom_adapter import WecomPlatformAdapter # noqa: F401 case "weixin_official_account": - from .sources.weixin_official_account.weixin_offacc_adapter import WeixinOfficialAccountPlatformAdapter # noqa + from .sources.weixin_official_account.weixin_offacc_adapter import ( + WeixinOfficialAccountPlatformAdapter, # noqa + ) + case "discord": + from .sources.discord.discord_platform_adapter import ( + DiscordPlatformAdapter, # noqa: F401 + ) except (ImportError, ModuleNotFoundError) as e: logger.error( f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。" diff --git a/astrbot/core/platform/sources/discord/discord_platform_adapter.py b/astrbot/core/platform/sources/discord/discord_platform_adapter.py new file mode 100644 index 000000000..9e4844970 --- /dev/null +++ b/astrbot/core/platform/sources/discord/discord_platform_adapter.py @@ -0,0 +1,444 @@ +import asyncio +import discord +from typing import List + +from astrbot.api.platform import ( + Platform, + AstrBotMessage, + MessageMember, + PlatformMetadata, + MessageType, +) +from astrbot.api.event import MessageChain +from astrbot.api.message_components import Plain, Image, File, BaseMessageComponent +from astrbot.core.platform.astr_message_event import MessageSesion +from astrbot.api.platform import register_platform_adapter +from astrbot import logger + +try: + from .discord_platform_event import DiscordPlatformEvent +except ImportError: + # 如果相对导入失败,尝试绝对导入 + from discord_platform_event import DiscordPlatformEvent + + +# Discord专用组件 +class DiscordEmbed(BaseMessageComponent): + """Discord Embed消息组件""" + + type: str = "discord_embed" + + def __init__( + self, + title: str = None, + description: str = None, + color: int = None, + url: str = None, + thumbnail: str = None, + image: str = None, + footer: str = None, + fields: List[dict] = None, + ): + self.title = title + self.description = description + self.color = color + self.url = url + self.thumbnail = thumbnail + self.image = image + self.footer = footer + self.fields = fields or [] + + def to_discord_embed(self) -> discord.Embed: + """转换为Discord Embed对象""" + embed = discord.Embed() + + if self.title: + embed.title = self.title + if self.description: + embed.description = self.description + if self.color: + embed.color = self.color + if self.url: + embed.url = self.url + if self.thumbnail: + embed.set_thumbnail(url=self.thumbnail) + if self.image: + embed.set_image(url=self.image) + if self.footer: + embed.set_footer(text=self.footer) + + for field in self.fields: + embed.add_field( + name=field.get("name", ""), + value=field.get("value", ""), + inline=field.get("inline", False), + ) + + return embed + + +class DiscordButton(BaseMessageComponent): + """Discord按钮组件""" + + type: str = "discord_button" + + def __init__( + self, + label: str, + custom_id: str = None, + style: str = "primary", + emoji: str = None, + url: str = None, + disabled: bool = False, + ): + self.label = label + self.custom_id = custom_id + self.style = style + self.emoji = emoji + self.url = url + self.disabled = disabled + + +class DiscordView(BaseMessageComponent): + """Discord视图组件,包含按钮和选择菜单""" + + type: str = "discord_view" + + def __init__( + self, components: List[BaseMessageComponent] = None, timeout: float = None + ): + self.components = components or [] + self.timeout = timeout + + def to_discord_view(self) -> discord.ui.View: + """转换为Discord View对象""" + view = discord.ui.View(timeout=self.timeout) + + for component in self.components: + if isinstance(component, DiscordButton): + button_style = getattr( + discord.ButtonStyle, component.style, discord.ButtonStyle.primary + ) + + if component.url: + # URL按钮 + button = discord.ui.Button( + label=component.label, + style=discord.ButtonStyle.link, + url=component.url, + emoji=component.emoji, + disabled=component.disabled, + ) + else: + # 普通按钮 + button = discord.ui.Button( + label=component.label, + style=button_style, + custom_id=component.custom_id, + emoji=component.emoji, + disabled=component.disabled, + ) + + view.add_item(button) + + return view + + +# Discord Bot客户端 +class DiscordBotClient(discord.Bot): + """Discord客户端封装""" + + def __init__(self, token: str, proxy: str = None): + self.token = token + self.proxy = proxy + + # 设置Intent权限,为了最大兼容性使用all() + intents = discord.Intents.all() + + # 初始化Bot + super().__init__(intents=intents, proxy=proxy) + + # 回调函数 + self.on_message_received = None + + async def on_ready(self): + """当机器人成功连接并准备就绪时触发""" + logger.info(f"[Discord] 已作为 {self.user} (ID: {self.user.id}) 登录") + logger.info("[Discord] 客户端已准备就绪。") + + def _create_message_data(self, message: discord.Message) -> dict: + """从 discord.Message 创建数据字典""" + is_mentioned = self.user in message.mentions + return { + "message": message, + "bot_id": str(self.user.id), + "content": message.content, + "username": message.author.display_name, + "userid": str(message.author.id), + "message_id": str(message.id), + "channel_id": str(message.channel.id), + "guild_id": str(message.guild.id) if message.guild else None, + "type": "message", + "is_mentioned": is_mentioned, + "clean_content": message.clean_content, + } + + def _create_interaction_data(self, interaction: discord.Interaction) -> dict: + """从 discord.Interaction 创建数据字典""" + return { + "interaction": interaction, + "bot_id": str(self.user.id), + "content": self._extract_interaction_content(interaction), + "username": interaction.user.display_name, + "userid": str(interaction.user.id), + "message_id": str(interaction.id), + "channel_id": str(interaction.channel_id) + if interaction.channel_id + else None, + "guild_id": str(interaction.guild_id) if interaction.guild_id else None, + "type": "interaction", + } + + async def on_message(self, message: discord.Message): + """当接收到消息时触发""" + if message.author.bot: + return + + logger.debug( + f"[Discord] 收到原始消息 from {message.author.name}: {message.content}" + ) + + if self.on_message_received: + message_data = self._create_message_data(message) + await self.on_message_received(message_data) + + async def on_interaction(self, interaction: discord.Interaction): + """当接收到交互(按钮点击等)时触发""" + logger.debug( + f"[Discord] 收到交互 from {interaction.user.name}: {interaction.data}" + ) + + if self.on_message_received: + interaction_data = self._create_interaction_data(interaction) + await self.on_message_received(interaction_data) + + def _extract_interaction_content(self, interaction: discord.Interaction) -> str: + """从交互中提取内容""" + interaction_type = interaction.type + interaction_data = getattr(interaction, "data", {}) + + if not interaction_data: + return "" + + if interaction_type == discord.InteractionType.application_command: + command_name = interaction_data.get("name", "") + options = interaction_data.get("options", []) + if options: + params = " ".join( + [f"{opt['name']}:{opt.get('value', '')}" for opt in options] + ) + return f"/{command_name} {params}" + return f"/{command_name}" + + elif interaction_type == discord.InteractionType.component: + custom_id = interaction_data.get("custom_id", "") + component_type = interaction_data.get("component_type", "") + return f"component:{custom_id}:{component_type}" + + return str(interaction_data) + + async def start_polling(self): + """开始轮询消息,这是个阻塞方法""" + await self.start(self.token) + + async def close(self): + """关闭客户端""" + if not self.is_closed(): + await super().close() + + +# 注册平台适配器 +@register_platform_adapter("discord", "Discord 适配器 (基于 Pycord)") +class DiscordPlatformAdapter(Platform): + def __init__( + self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue + ) -> None: + super().__init__(event_queue) + self.config = platform_config + self.settings = platform_settings + self.client_self_id = None + self.registered_handlers = [] + + async def send_by_session( + self, session: MessageSesion, message_chain: MessageChain + ): + """通过会话发送消息""" + # 创建临时事件对象来发送消息 + temp_event = DiscordPlatformEvent( + message_str="", + message_obj=None, + platform_meta=self.meta(), + session_id=session.session_id, + client=self.client, + ) + await temp_event.send(message_chain) + await super().send_by_session(session, message_chain) + + def meta(self) -> PlatformMetadata: + """返回平台元数据""" + return PlatformMetadata( + "discord", + "Discord 适配器 (基于 Pycord)", + id=self.config.get("id"), + default_config_tmpl=self.config, + ) + + async def run(self): + """主要运行逻辑""" + + # 初始化回调函数 + async def on_received(message_data): + logger.debug(f"[Discord] 收到消息: {message_data}") + abm = await self.convert_message(data=message_data) + await self.handle_msg(abm) + + # 初始化 Discord 客户端 + token = str(self.config.get("discord_token")) + if not token or "在此处" in token: + logger.error("[Discord] Bot Token 未配置。请在配置文件中正确设置 token。") + return + + proxy = self.config.get("discord_proxy") or None + self.client = DiscordBotClient(token, proxy) + self.client.on_message_received = on_received + + # 注册已登记的命令处理器 + self._register_handlers() + + try: + await self.client.start_polling() + except discord.errors.LoginFailure: + logger.error("[Discord] 登录失败。请检查你的 Bot Token 是否正确。") + except discord.errors.ConnectionClosed: + logger.warning("[Discord] 与 Discord 的连接已关闭。") + except Exception as e: + logger.error(f"[Discord] 适配器运行时发生意外错误: {e}", exc_info=True) + + def _register_handlers(self): + """注册命令处理器""" + # 这里可以扫描插件中使用装饰器的方法并注册 + # 由于AstrBot的插件系统,这部分需要在插件加载时处理 + pass + + def _determine_message_type( + self, channel, guild_id=None + ) -> tuple[MessageType, str]: + """判断消息类型和群组ID""" + if guild_id is None and ( + isinstance(channel, discord.DMChannel) + or getattr(channel, "guild", None) is None + ): + return MessageType.FRIEND_MESSAGE, "" + + gid = guild_id or getattr(channel, "guild", None).id + return MessageType.GROUP_MESSAGE, str(gid) + + def _convert_interaction_to_abm(self, data: dict) -> AstrBotMessage: + """将交互事件转换为 AstrBotMessage""" + interaction = data["interaction"] + abm = AstrBotMessage() + + abm.type, abm.group_id = self._determine_message_type( + interaction.channel, interaction.guild_id + ) + + abm.message_str = data["content"] + abm.sender = MessageMember( + user_id=str(interaction.user.id), nickname=interaction.user.display_name + ) + abm.message = [Plain(text=data["content"])] + abm.raw_message = interaction + abm.self_id = self.client_self_id + abm.session_id = ( + str(interaction.channel_id) + if interaction.channel_id + else str(interaction.user.id) + ) + abm.message_id = str(interaction.id) + return abm + + def _convert_message_to_abm(self, data: dict) -> AstrBotMessage: + """将普通消息转换为 AstrBotMessage""" + message = data["message"] + is_mentioned = data.get("is_mentioned", False) + clean_content = data.get("clean_content", message.content) + abm = AstrBotMessage() + + abm.type, abm.group_id = self._determine_message_type(message.channel) + + abm.message_str = clean_content if is_mentioned else message.content + abm.sender = MessageMember( + user_id=str(message.author.id), nickname=message.author.display_name + ) + + message_chain = [] + if abm.message_str: + message_chain.append(Plain(text=abm.message_str)) + + if message.attachments: + for attachment in message.attachments: + if attachment.content_type and attachment.content_type.startswith( + "image/" + ): + message_chain.append( + Image(file=attachment.url, filename=attachment.filename) + ) + else: + message_chain.append( + File(name=attachment.filename, url=attachment.url) + ) + + abm.message = message_chain + abm.raw_message = message + abm.self_id = self.client_self_id + abm.session_id = str(message.channel.id) + abm.message_id = str(message.id) + return abm + + async def convert_message(self, data: dict) -> AstrBotMessage: + """将平台消息转换成 AstrBotMessage""" + if data.get("type") in ["interaction", "slash_command"]: + return self._convert_interaction_to_abm(data) + else: + return self._convert_message_to_abm(data) + + async def handle_msg(self, message: AstrBotMessage): + """处理消息""" + message_event = DiscordPlatformEvent( + message_str=message.message_str, + message_obj=message, + platform_meta=self.meta(), + session_id=message.session_id, + client=self.client, + ) + + # 如果是被@的消息,设置为唤醒状态 + if ( + hasattr(message.raw_message, "mentions") + and self.client.user in message.raw_message.mentions + ): + message_event.is_wake = True + message_event.is_at_or_wake_command = True + + self.commit_event(message_event) + + async def terminate(self): + """终止适配器""" + logger.info("[Discord] 正在终止适配器...") + if self.client and hasattr(self.client, "close"): + await self.client.close() + logger.info("[Discord] 适配器已终止。") + + def register_handler(self, handler_info): + """注册处理器信息""" + self.registered_handlers.append(handler_info) diff --git a/astrbot/core/platform/sources/discord/discord_platform_event.py b/astrbot/core/platform/sources/discord/discord_platform_event.py new file mode 100644 index 000000000..3cce9e1f3 --- /dev/null +++ b/astrbot/core/platform/sources/discord/discord_platform_event.py @@ -0,0 +1,250 @@ +import asyncio +import discord +import base64 +from io import BytesIO +from pathlib import Path +from typing import Optional + +from astrbot.api.event import AstrMessageEvent, MessageChain +from astrbot.api.platform import AstrBotMessage, PlatformMetadata +from astrbot.api.message_components import Plain, Image, File, BaseMessageComponent +from astrbot.core.utils.io import download_image_by_url +from astrbot import logger + + +# 自定义Discord视图组件(兼容旧版本) +class DiscordViewComponent(BaseMessageComponent): + type: str = "discord_view" + + def __init__(self, view: discord.ui.View): + self.view = view + + +class DiscordPlatformEvent(AstrMessageEvent): + def __init__( + self, + message_str: str, + message_obj: AstrBotMessage, + platform_meta: PlatformMetadata, + session_id: str, + client, + ): + super().__init__(message_str, message_obj, platform_meta, session_id) + self.client = client + + async def send(self, message: MessageChain): + """发送消息到Discord平台""" + try: + channel = await self._get_channel() + if not channel: + logger.error(f"[Discord] 无法获取频道 {self.session_id}") + return + + # 解析消息链 + content, files, view, embeds = await self._parse_to_discord(message) + + # Discord 不允许发送完全空的消息 + if not content and not files and not view and not embeds: + logger.debug("[Discord] 尝试发送空消息,已忽略。") + return + + # 发送消息 + await channel.send( + content=content or None, + files=files or None, + view=view or None, + embeds=embeds or None, + ) + + except discord.errors.HTTPException as e: + logger.error(f"[Discord] 发送消息失败: {e.status} {e.code} - {e.text}") + except Exception as e: + logger.error(f"[Discord] 发送消息时发生未知错误: {e}", exc_info=True) + + await super().send(message) + + async def _get_channel(self) -> Optional[discord.abc.Messageable]: + """获取当前事件对应的频道对象""" + try: + channel_id = int(self.session_id) + channel = self.client.get_channel( + channel_id + ) or await self.client.fetch_channel(channel_id) + return channel + except (ValueError, discord.errors.NotFound, discord.errors.Forbidden): + logger.error(f"[Discord] 无法获取频道 {self.session_id}") + return None + + async def _parse_to_discord( + self, + message: MessageChain, + ) -> tuple[str, list[discord.File], Optional[discord.ui.View], list[discord.Embed]]: + """将 MessageChain 解析为 Discord 发送所需的内容""" + try: + from .discord_platform_adapter import DiscordEmbed, DiscordView + except ImportError: + from discord_platform_adapter import DiscordEmbed, DiscordView + + plain_text_parts = [] + files = [] + view = None + embeds = [] + + for i in message.chain: # 遍历消息链 + if isinstance(i, Plain): # 如果是文字类型的 + plain_text_parts.append(i.text) + elif isinstance(i, Image): # 如果是图片类型的 + try: + discord_file = None + # 优先使用组件指定的filename,否则从路径推断,最后使用默认值 + filename = i.filename + + async def process_local_path(p_str: str) -> Optional[discord.File]: + nonlocal filename + path = Path(p_str) + if not await asyncio.to_thread(path.exists): + logger.warning(f"[Discord] 图片文件不存在: {p_str}") + return None + + if not filename: # 如果没有指定filename,则从路径推断 + filename = path.name + + file_bytes = await asyncio.to_thread(path.read_bytes) + return discord.File(BytesIO(file_bytes), filename=filename) + + if i.file.startswith("file:///"): + discord_file = await process_local_path(i.file[8:]) + elif i.file.startswith("http"): + downloaded_path_str = await download_image_by_url(i.file) + if downloaded_path_str: + discord_file = await process_local_path(downloaded_path_str) + elif i.file.startswith("base64://"): + img_bytes = base64.b64decode(i.file.split("base64://")[1]) + discord_file = discord.File( + BytesIO(img_bytes), filename=filename or "image.png" + ) + else: # Treat as a local path + discord_file = await process_local_path(i.file) + + if discord_file: + files.append(discord_file) + + except Exception as e: + logger.warning(f"[Discord] 处理图片失败: {i.file}, 错误: {e}") + elif isinstance(i, File): + try: + file_path_str = await i.get_file() + if file_path_str: + path = Path(file_path_str) + if await asyncio.to_thread(path.exists): + file_bytes = await asyncio.to_thread(path.read_bytes) + files.append( + discord.File(BytesIO(file_bytes), filename=i.name) + ) + else: + logger.warning( + f"[Discord] 获取文件失败,路径不存在: {file_path_str}" + ) + else: + logger.warning(f"[Discord] 获取文件失败: {i.name}") + except Exception as e: + logger.warning(f"[Discord] 处理文件失败: {i.name}, 错误: {e}") + elif isinstance(i, DiscordEmbed): + # Discord Embed消息 + embeds.append(i.to_discord_embed()) + elif isinstance(i, DiscordView): + # Discord视图组件(按钮、选择菜单等) + view = i.to_discord_view() + elif isinstance(i, DiscordViewComponent): + # 如果消息链中包含Discord视图组件(兼容旧版本) + if isinstance(i.view, discord.ui.View): + view = i.view + else: + logger.debug(f"[Discord] 忽略了不支持的消息组件: {i.type}") + + # 合并文本内容 + content = "\n".join(plain_text_parts) + if len(content) > 20000: + logger.warning("[Discord] 消息内容超过20000字符,将被截断。") + content = content[:20000] + + return content, files, view, embeds + + async def reply(self, message: MessageChain): + """回复消息(如果原消息存在)""" + try: + if hasattr(self.message_obj, "raw_message") and hasattr( + self.message_obj.raw_message, "reply" + ): + # 解析消息链 + content, files, view, embeds = await self._parse_to_discord(message) + + # 使用Discord的回复功能 + await self.message_obj.raw_message.reply( + content=content or None, + files=files or None, + view=view or None, + embeds=embeds or None, + ) + else: + # 如果无法回复,使用普通发送 + await self.send(message) + except Exception as e: + logger.error(f"[Discord] 回复消息失败: {e}") + # 回退到普通发送 + await self.send(message) + + async def react(self, emoji: str): + """对原消息添加反应""" + try: + if hasattr(self.message_obj, "raw_message") and hasattr( + self.message_obj.raw_message, "add_reaction" + ): + await self.message_obj.raw_message.add_reaction(emoji) + except Exception as e: + logger.error(f"[Discord] 添加反应失败: {e}") + + def is_slash_command(self) -> bool: + """判断是否为斜杠命令""" + return ( + hasattr(self.message_obj, "raw_message") + and hasattr(self.message_obj.raw_message, "type") + and self.message_obj.raw_message.type + == discord.InteractionType.application_command + ) + + def is_button_interaction(self) -> bool: + """判断是否为按钮交互""" + return ( + hasattr(self.message_obj, "raw_message") + and hasattr(self.message_obj.raw_message, "type") + and self.message_obj.raw_message.type == discord.InteractionType.component + ) + + def get_interaction_custom_id(self) -> str: + """获取交互组件的custom_id""" + if self.is_button_interaction(): + try: + return self.message_obj.raw_message.data.get("custom_id", "") + except Exception: + pass + return "" + + def is_mentioned(self) -> bool: + """判断机器人是否被@""" + if hasattr(self.message_obj, "raw_message") and hasattr( + self.message_obj.raw_message, "mentions" + ): + return any( + mention.id == int(self.message_obj.self_id) + for mention in self.message_obj.raw_message.mentions + ) + return False + + def get_mention_clean_content(self) -> str: + """获取去除@后的清洁内容""" + if hasattr(self.message_obj, "raw_message") and hasattr( + self.message_obj.raw_message, "clean_content" + ): + return self.message_obj.raw_message.clean_content + return self.message_str diff --git a/pyproject.toml b/pyproject.toml index 9e903af3d..f15c6ccdb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ dependencies = [ "pillow>=11.2.1", "pip>=25.1.1", "psutil>=5.8.0", + "py-cord[speed]>=2.6.1", "pydantic~=2.10.3", "pydub>=0.25.1", "pyjwt>=2.10.1",