diff --git a/astrbot/core/platform/sources/discord/client.py b/astrbot/core/platform/sources/discord/client.py new file mode 100644 index 000000000..f151c1d15 --- /dev/null +++ b/astrbot/core/platform/sources/discord/client.py @@ -0,0 +1,116 @@ +import discord +from astrbot import logger + + +# Discord Bot客户端 +class DiscordBotClient(discord.Bot): + """Discord客户端封装""" + + def __init__(self, token: str, proxy: str = None): + self.token = token + self.proxy = proxy + + # 设置Intent权限,遵循权限最小化原则 + intents = discord.Intents.default() + intents.message_content = True # 订阅消息内容事件 (Privileged) + intents.members = True # 订阅成员事件 (Privileged) + + # 初始化Bot + super().__init__(intents=intents, proxy=proxy) + + # 回调函数 + self.on_message_received = None + + async def on_ready(self): + """当机器人成功连接并准备就绪时触发""" + logger.info(f"[Discord] 已作为 {self.user} (ID: {self.user.id}) 登录") + logger.info("[Discord] 客户端已准备就绪。") + + def _create_message_data(self, message: discord.Message) -> dict: + """从 discord.Message 创建数据字典""" + is_mentioned = self.user in message.mentions + return { + "message": message, + "bot_id": str(self.user.id), + "content": message.content, + "username": message.author.display_name, + "userid": str(message.author.id), + "message_id": str(message.id), + "channel_id": str(message.channel.id), + "guild_id": str(message.guild.id) if message.guild else None, + "type": "message", + "is_mentioned": is_mentioned, + "clean_content": message.clean_content, + } + + def _create_interaction_data(self, interaction: discord.Interaction) -> dict: + """从 discord.Interaction 创建数据字典""" + return { + "interaction": interaction, + "bot_id": str(self.user.id), + "content": self._extract_interaction_content(interaction), + "username": interaction.user.display_name, + "userid": str(interaction.user.id), + "message_id": str(interaction.id), + "channel_id": str(interaction.channel_id) + if interaction.channel_id + else None, + "guild_id": str(interaction.guild_id) if interaction.guild_id else None, + "type": "interaction", + } + + async def on_message(self, message: discord.Message): + """当接收到消息时触发""" + if message.author.bot: + return + + logger.debug( + f"[Discord] 收到原始消息 from {message.author.name}: {message.content}" + ) + + if self.on_message_received: + message_data = self._create_message_data(message) + await self.on_message_received(message_data) + + async def on_interaction(self, interaction: discord.Interaction): + """当接收到交互(按钮点击等)时触发""" + logger.debug( + f"[Discord] 收到交互 from {interaction.user.name}: {interaction.data}" + ) + + if self.on_message_received: + interaction_data = self._create_interaction_data(interaction) + await self.on_message_received(interaction_data) + + def _extract_interaction_content(self, interaction: discord.Interaction) -> str: + """从交互中提取内容""" + interaction_type = interaction.type + interaction_data = getattr(interaction, "data", {}) + + if not interaction_data: + return "" + + if interaction_type == discord.InteractionType.application_command: + command_name = interaction_data.get("name", "") + if options := interaction_data.get("options", []): + params = " ".join( + [f"{opt['name']}:{opt.get('value', '')}" for opt in options] + ) + return f"/{command_name} {params}" + return f"/{command_name}" + + elif interaction_type == discord.InteractionType.component: + custom_id = interaction_data.get("custom_id", "") + component_type = interaction_data.get("component_type", "") + return f"component:{custom_id}:{component_type}" + + return str(interaction_data) + + async def start_polling(self): + """开始轮询消息,这是个阻塞方法""" + await self.start(self.token) + + async def close(self): + """关闭客户端""" + if not self.is_closed(): + await super().close() diff --git a/astrbot/core/platform/sources/discord/components.py b/astrbot/core/platform/sources/discord/components.py new file mode 100644 index 000000000..996f79574 --- /dev/null +++ b/astrbot/core/platform/sources/discord/components.py @@ -0,0 +1,125 @@ +import discord +from typing import List +from astrbot.api.message_components import BaseMessageComponent + + +# Discord专用组件 +class DiscordEmbed(BaseMessageComponent): + """Discord Embed消息组件""" + + type: str = "discord_embed" + + def __init__( + self, + title: str = None, + description: str = None, + color: int = None, + url: str = None, + thumbnail: str = None, + image: str = None, + footer: str = None, + fields: List[dict] = None, + ): + self.title = title + self.description = description + self.color = color + self.url = url + self.thumbnail = thumbnail + self.image = image + self.footer = footer + self.fields = fields or [] + + def to_discord_embed(self) -> discord.Embed: + """转换为Discord Embed对象""" + embed = discord.Embed() + + if self.title: + embed.title = self.title + if self.description: + embed.description = self.description + if self.color: + embed.color = self.color + if self.url: + embed.url = self.url + if self.thumbnail: + embed.set_thumbnail(url=self.thumbnail) + if self.image: + embed.set_image(url=self.image) + if self.footer: + embed.set_footer(text=self.footer) + + for field in self.fields: + embed.add_field( + name=field.get("name", ""), + value=field.get("value", ""), + inline=field.get("inline", False), + ) + + return embed + + +class DiscordButton(BaseMessageComponent): + """Discord按钮组件""" + + type: str = "discord_button" + + def __init__( + self, + label: str, + custom_id: str = None, + style: str = "primary", + emoji: str = None, + url: str = None, + disabled: bool = False, + ): + self.label = label + self.custom_id = custom_id + self.style = style + self.emoji = emoji + self.url = url + self.disabled = disabled + + +class DiscordView(BaseMessageComponent): + """Discord视图组件,包含按钮和选择菜单""" + + type: str = "discord_view" + + def __init__( + self, components: List[BaseMessageComponent] = None, timeout: float = None + ): + self.components = components or [] + self.timeout = timeout + + def to_discord_view(self) -> discord.ui.View: + """转换为Discord View对象""" + view = discord.ui.View(timeout=self.timeout) + + for component in self.components: + if isinstance(component, DiscordButton): + button_style = getattr( + discord.ButtonStyle, component.style, discord.ButtonStyle.primary + ) + + if component.url: + # URL按钮 + button = discord.ui.Button( + label=component.label, + style=discord.ButtonStyle.link, + url=component.url, + emoji=component.emoji, + disabled=component.disabled, + ) + else: + # 普通按钮 + button = discord.ui.Button( + label=component.label, + style=button_style, + custom_id=component.custom_id, + emoji=component.emoji, + disabled=component.disabled, + ) + + view.add_item(button) + + return view diff --git a/astrbot/core/platform/sources/discord/discord_platform_adapter.py b/astrbot/core/platform/sources/discord/discord_platform_adapter.py index 9e4844970..3fab55e26 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_adapter.py +++ b/astrbot/core/platform/sources/discord/discord_platform_adapter.py @@ -1,7 +1,5 @@ import asyncio import discord -from typing import List - from astrbot.api.platform import ( Platform, AstrBotMessage, @@ -10,10 +8,11 @@ from astrbot.api.platform import ( MessageType, ) from astrbot.api.event import MessageChain -from astrbot.api.message_components import Plain, Image, File, BaseMessageComponent +from astrbot.api.message_components import Plain, Image, File from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.api.platform import register_platform_adapter from astrbot import logger +from .client import DiscordBotClient try: from .discord_platform_event import DiscordPlatformEvent @@ -22,241 +21,6 @@ except ImportError: from discord_platform_event import DiscordPlatformEvent -# Discord专用组件 -class DiscordEmbed(BaseMessageComponent): - """Discord Embed消息组件""" - - type: str = "discord_embed" - - def __init__( - self, - title: str = None, - description: str = None, - color: int = None, - url: str = None, - thumbnail: str = None, - image: str = None, - footer: str = None, - fields: List[dict] = None, - ): - self.title = title - self.description = description - self.color = color - self.url = url - self.thumbnail = thumbnail - self.image = image - self.footer = footer - self.fields = fields or [] - - def to_discord_embed(self) -> discord.Embed: - """转换为Discord Embed对象""" - embed = discord.Embed() - - if self.title: - embed.title = self.title - if self.description: - embed.description = self.description - if self.color: - embed.color = self.color - if self.url: - embed.url = self.url - if self.thumbnail: - embed.set_thumbnail(url=self.thumbnail) - if self.image: - embed.set_image(url=self.image) - if self.footer: - embed.set_footer(text=self.footer) - - for field in self.fields: - embed.add_field( - name=field.get("name", ""), - value=field.get("value", ""), - inline=field.get("inline", False), - ) - - return embed - - -class DiscordButton(BaseMessageComponent): - """Discord按钮组件""" - - type: str = "discord_button" - - def __init__( - self, - label: str, - custom_id: str = None, - style: str = "primary", - emoji: str = None, - url: str = None, - disabled: bool = False, - ): - self.label = label - self.custom_id = custom_id - self.style = style - self.emoji = emoji - self.url = url - self.disabled = disabled - - -class DiscordView(BaseMessageComponent): - """Discord视图组件,包含按钮和选择菜单""" - - type: str = "discord_view" - - def __init__( - self, components: List[BaseMessageComponent] = None, timeout: float = None - ): - self.components = components or [] - self.timeout = timeout - - def to_discord_view(self) -> discord.ui.View: - """转换为Discord View对象""" - view = discord.ui.View(timeout=self.timeout) - - for component in self.components: - if isinstance(component, DiscordButton): - button_style = getattr( - discord.ButtonStyle, component.style, discord.ButtonStyle.primary - ) - - if component.url: - # URL按钮 - button = discord.ui.Button( - label=component.label, - style=discord.ButtonStyle.link, - url=component.url, - emoji=component.emoji, - disabled=component.disabled, - ) - else: - # 普通按钮 - button = discord.ui.Button( - label=component.label, - style=button_style, - custom_id=component.custom_id, - emoji=component.emoji, - disabled=component.disabled, - ) - - view.add_item(button) - - return view - - -# Discord Bot客户端 -class DiscordBotClient(discord.Bot): - """Discord客户端封装""" - - def __init__(self, token: str, proxy: str = None): - self.token = token - self.proxy = proxy - - # 设置Intent权限,为了最大兼容性使用all() - intents = discord.Intents.all() - - # 初始化Bot - super().__init__(intents=intents, proxy=proxy) - - # 回调函数 - self.on_message_received = None - - async def on_ready(self): - """当机器人成功连接并准备就绪时触发""" - logger.info(f"[Discord] 已作为 {self.user} (ID: {self.user.id}) 登录") - logger.info("[Discord] 客户端已准备就绪。") - - def _create_message_data(self, message: discord.Message) -> dict: - """从 discord.Message 创建数据字典""" - is_mentioned = self.user in message.mentions - return { - "message": message, - "bot_id": str(self.user.id), - "content": message.content, - "username": message.author.display_name, - "userid": str(message.author.id), - "message_id": str(message.id), - "channel_id": str(message.channel.id), - "guild_id": str(message.guild.id) if message.guild else None, - "type": "message", - "is_mentioned": is_mentioned, - "clean_content": message.clean_content, - } - - def _create_interaction_data(self, interaction: discord.Interaction) -> dict: - """从 discord.Interaction 创建数据字典""" - return { - "interaction": interaction, - "bot_id": str(self.user.id), - "content": self._extract_interaction_content(interaction), - "username": interaction.user.display_name, - "userid": str(interaction.user.id), - "message_id": str(interaction.id), - "channel_id": str(interaction.channel_id) - if interaction.channel_id - else None, - "guild_id": str(interaction.guild_id) if interaction.guild_id else None, - "type": "interaction", - } - - async def on_message(self, message: discord.Message): - """当接收到消息时触发""" - if message.author.bot: - return - - logger.debug( - f"[Discord] 收到原始消息 from {message.author.name}: {message.content}" - ) - - if self.on_message_received: - message_data = self._create_message_data(message) - await self.on_message_received(message_data) - - async def on_interaction(self, interaction: discord.Interaction): - """当接收到交互(按钮点击等)时触发""" - logger.debug( - f"[Discord] 收到交互 from {interaction.user.name}: {interaction.data}" - ) - - if self.on_message_received: - interaction_data = self._create_interaction_data(interaction) - await self.on_message_received(interaction_data) - - def _extract_interaction_content(self, interaction: discord.Interaction) -> str: - """从交互中提取内容""" - interaction_type = interaction.type - interaction_data = getattr(interaction, "data", {}) - - if not interaction_data: - return "" - - if interaction_type == discord.InteractionType.application_command: - command_name = interaction_data.get("name", "") - options = interaction_data.get("options", []) - if options: - params = " ".join( - [f"{opt['name']}:{opt.get('value', '')}" for opt in options] - ) - return f"/{command_name} {params}" - return f"/{command_name}" - - elif interaction_type == discord.InteractionType.component: - custom_id = interaction_data.get("custom_id", "") - component_type = interaction_data.get("component_type", "") - return f"component:{custom_id}:{component_type}" - - return str(interaction_data) - - async def start_polling(self): - """开始轮询消息,这是个阻塞方法""" - await self.start(self.token) - - async def close(self): - """关闭客户端""" - if not self.is_closed(): - await super().close() - - # 注册平台适配器 @register_platform_adapter("discord", "Discord 适配器 (基于 Pycord)") class DiscordPlatformAdapter(Platform): @@ -299,6 +63,8 @@ class DiscordPlatformAdapter(Platform): # 初始化回调函数 async def on_received(message_data): logger.debug(f"[Discord] 收到消息: {message_data}") + if self.client_self_id is None: + self.client_self_id = message_data.get("bot_id") abm = await self.convert_message(data=message_data) await self.handle_msg(abm) diff --git a/astrbot/core/platform/sources/discord/discord_platform_event.py b/astrbot/core/platform/sources/discord/discord_platform_event.py index 3cce9e1f3..653b84dae 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_event.py +++ b/astrbot/core/platform/sources/discord/discord_platform_event.py @@ -67,10 +67,9 @@ class DiscordPlatformEvent(AstrMessageEvent): """获取当前事件对应的频道对象""" try: channel_id = int(self.session_id) - channel = self.client.get_channel( + return self.client.get_channel( channel_id ) or await self.client.fetch_channel(channel_id) - return channel except (ValueError, discord.errors.NotFound, discord.errors.Forbidden): logger.error(f"[Discord] 无法获取频道 {self.session_id}") return None @@ -81,9 +80,9 @@ class DiscordPlatformEvent(AstrMessageEvent): ) -> tuple[str, list[discord.File], Optional[discord.ui.View], list[discord.Embed]]: """将 MessageChain 解析为 Discord 发送所需的内容""" try: - from .discord_platform_adapter import DiscordEmbed, DiscordView + from .components import DiscordEmbed, DiscordView except ImportError: - from discord_platform_adapter import DiscordEmbed, DiscordView + from components import DiscordEmbed, DiscordView plain_text_parts = [] files = [] @@ -164,9 +163,9 @@ class DiscordPlatformEvent(AstrMessageEvent): # 合并文本内容 content = "\n".join(plain_text_parts) - if len(content) > 20000: - logger.warning("[Discord] 消息内容超过20000字符,将被截断。") - content = content[:20000] + if len(content) > 2000: + logger.warning("[Discord] 消息内容超过2000字符,将被截断。") + content = content[:2000] return content, files, view, embeds