feat: 拆分Discord 适配器的部分代码,并处理一些小的问题。

- 基于最小权限原则,修改了 Bot 申请的权限范围
- 拆分了代码,使得文件结构更加清晰
This commit is contained in:
lxfight
2025-06-20 21:43:23 +08:00
parent 50c22bbadb
commit b9fab74edc
4 changed files with 251 additions and 245 deletions
@@ -0,0 +1,116 @@
import discord
from astrbot import logger
# Discord Bot客户端
class DiscordBotClient(discord.Bot):
"""Discord客户端封装"""
def __init__(self, token: str, proxy: str = None):
self.token = token
self.proxy = proxy
# 设置Intent权限,遵循权限最小化原则
intents = discord.Intents.default()
intents.message_content = True # 订阅消息内容事件 (Privileged)
intents.members = True # 订阅成员事件 (Privileged)
# 初始化Bot
super().__init__(intents=intents, proxy=proxy)
# 回调函数
self.on_message_received = None
async def on_ready(self):
"""当机器人成功连接并准备就绪时触发"""
logger.info(f"[Discord] 已作为 {self.user} (ID: {self.user.id}) 登录")
logger.info("[Discord] 客户端已准备就绪。")
def _create_message_data(self, message: discord.Message) -> dict:
"""从 discord.Message 创建数据字典"""
is_mentioned = self.user in message.mentions
return {
"message": message,
"bot_id": str(self.user.id),
"content": message.content,
"username": message.author.display_name,
"userid": str(message.author.id),
"message_id": str(message.id),
"channel_id": str(message.channel.id),
"guild_id": str(message.guild.id) if message.guild else None,
"type": "message",
"is_mentioned": is_mentioned,
"clean_content": message.clean_content,
}
def _create_interaction_data(self, interaction: discord.Interaction) -> dict:
"""从 discord.Interaction 创建数据字典"""
return {
"interaction": interaction,
"bot_id": str(self.user.id),
"content": self._extract_interaction_content(interaction),
"username": interaction.user.display_name,
"userid": str(interaction.user.id),
"message_id": str(interaction.id),
"channel_id": str(interaction.channel_id)
if interaction.channel_id
else None,
"guild_id": str(interaction.guild_id) if interaction.guild_id else None,
"type": "interaction",
}
async def on_message(self, message: discord.Message):
"""当接收到消息时触发"""
if message.author.bot:
return
logger.debug(
f"[Discord] 收到原始消息 from {message.author.name}: {message.content}"
)
if self.on_message_received:
message_data = self._create_message_data(message)
await self.on_message_received(message_data)
async def on_interaction(self, interaction: discord.Interaction):
"""当接收到交互(按钮点击等)时触发"""
logger.debug(
f"[Discord] 收到交互 from {interaction.user.name}: {interaction.data}"
)
if self.on_message_received:
interaction_data = self._create_interaction_data(interaction)
await self.on_message_received(interaction_data)
def _extract_interaction_content(self, interaction: discord.Interaction) -> str:
"""从交互中提取内容"""
interaction_type = interaction.type
interaction_data = getattr(interaction, "data", {})
if not interaction_data:
return ""
if interaction_type == discord.InteractionType.application_command:
command_name = interaction_data.get("name", "")
if options := interaction_data.get("options", []):
params = " ".join(
[f"{opt['name']}:{opt.get('value', '')}" for opt in options]
)
return f"/{command_name} {params}"
return f"/{command_name}"
elif interaction_type == discord.InteractionType.component:
custom_id = interaction_data.get("custom_id", "")
component_type = interaction_data.get("component_type", "")
return f"component:{custom_id}:{component_type}"
return str(interaction_data)
async def start_polling(self):
"""开始轮询消息,这是个阻塞方法"""
await self.start(self.token)
async def close(self):
"""关闭客户端"""
if not self.is_closed():
await super().close()
@@ -0,0 +1,125 @@
import discord
from typing import List
from astrbot.api.message_components import BaseMessageComponent
# Discord专用组件
class DiscordEmbed(BaseMessageComponent):
"""Discord Embed消息组件"""
type: str = "discord_embed"
def __init__(
self,
title: str = None,
description: str = None,
color: int = None,
url: str = None,
thumbnail: str = None,
image: str = None,
footer: str = None,
fields: List[dict] = None,
):
self.title = title
self.description = description
self.color = color
self.url = url
self.thumbnail = thumbnail
self.image = image
self.footer = footer
self.fields = fields or []
def to_discord_embed(self) -> discord.Embed:
"""转换为Discord Embed对象"""
embed = discord.Embed()
if self.title:
embed.title = self.title
if self.description:
embed.description = self.description
if self.color:
embed.color = self.color
if self.url:
embed.url = self.url
if self.thumbnail:
embed.set_thumbnail(url=self.thumbnail)
if self.image:
embed.set_image(url=self.image)
if self.footer:
embed.set_footer(text=self.footer)
for field in self.fields:
embed.add_field(
name=field.get("name", ""),
value=field.get("value", ""),
inline=field.get("inline", False),
)
return embed
class DiscordButton(BaseMessageComponent):
"""Discord按钮组件"""
type: str = "discord_button"
def __init__(
self,
label: str,
custom_id: str = None,
style: str = "primary",
emoji: str = None,
url: str = None,
disabled: bool = False,
):
self.label = label
self.custom_id = custom_id
self.style = style
self.emoji = emoji
self.url = url
self.disabled = disabled
class DiscordView(BaseMessageComponent):
"""Discord视图组件,包含按钮和选择菜单"""
type: str = "discord_view"
def __init__(
self, components: List[BaseMessageComponent] = None, timeout: float = None
):
self.components = components or []
self.timeout = timeout
def to_discord_view(self) -> discord.ui.View:
"""转换为Discord View对象"""
view = discord.ui.View(timeout=self.timeout)
for component in self.components:
if isinstance(component, DiscordButton):
button_style = getattr(
discord.ButtonStyle, component.style, discord.ButtonStyle.primary
)
if component.url:
# URL按钮
button = discord.ui.Button(
label=component.label,
style=discord.ButtonStyle.link,
url=component.url,
emoji=component.emoji,
disabled=component.disabled,
)
else:
# 普通按钮
button = discord.ui.Button(
label=component.label,
style=button_style,
custom_id=component.custom_id,
emoji=component.emoji,
disabled=component.disabled,
)
view.add_item(button)
return view
@@ -1,7 +1,5 @@
import asyncio
import discord
from typing import List
from astrbot.api.platform import (
Platform,
AstrBotMessage,
@@ -10,10 +8,11 @@ from astrbot.api.platform import (
MessageType,
)
from astrbot.api.event import MessageChain
from astrbot.api.message_components import Plain, Image, File, BaseMessageComponent
from astrbot.api.message_components import Plain, Image, File
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.api.platform import register_platform_adapter
from astrbot import logger
from .client import DiscordBotClient
try:
from .discord_platform_event import DiscordPlatformEvent
@@ -22,241 +21,6 @@ except ImportError:
from discord_platform_event import DiscordPlatformEvent
# Discord专用组件
class DiscordEmbed(BaseMessageComponent):
"""Discord Embed消息组件"""
type: str = "discord_embed"
def __init__(
self,
title: str = None,
description: str = None,
color: int = None,
url: str = None,
thumbnail: str = None,
image: str = None,
footer: str = None,
fields: List[dict] = None,
):
self.title = title
self.description = description
self.color = color
self.url = url
self.thumbnail = thumbnail
self.image = image
self.footer = footer
self.fields = fields or []
def to_discord_embed(self) -> discord.Embed:
"""转换为Discord Embed对象"""
embed = discord.Embed()
if self.title:
embed.title = self.title
if self.description:
embed.description = self.description
if self.color:
embed.color = self.color
if self.url:
embed.url = self.url
if self.thumbnail:
embed.set_thumbnail(url=self.thumbnail)
if self.image:
embed.set_image(url=self.image)
if self.footer:
embed.set_footer(text=self.footer)
for field in self.fields:
embed.add_field(
name=field.get("name", ""),
value=field.get("value", ""),
inline=field.get("inline", False),
)
return embed
class DiscordButton(BaseMessageComponent):
"""Discord按钮组件"""
type: str = "discord_button"
def __init__(
self,
label: str,
custom_id: str = None,
style: str = "primary",
emoji: str = None,
url: str = None,
disabled: bool = False,
):
self.label = label
self.custom_id = custom_id
self.style = style
self.emoji = emoji
self.url = url
self.disabled = disabled
class DiscordView(BaseMessageComponent):
"""Discord视图组件,包含按钮和选择菜单"""
type: str = "discord_view"
def __init__(
self, components: List[BaseMessageComponent] = None, timeout: float = None
):
self.components = components or []
self.timeout = timeout
def to_discord_view(self) -> discord.ui.View:
"""转换为Discord View对象"""
view = discord.ui.View(timeout=self.timeout)
for component in self.components:
if isinstance(component, DiscordButton):
button_style = getattr(
discord.ButtonStyle, component.style, discord.ButtonStyle.primary
)
if component.url:
# URL按钮
button = discord.ui.Button(
label=component.label,
style=discord.ButtonStyle.link,
url=component.url,
emoji=component.emoji,
disabled=component.disabled,
)
else:
# 普通按钮
button = discord.ui.Button(
label=component.label,
style=button_style,
custom_id=component.custom_id,
emoji=component.emoji,
disabled=component.disabled,
)
view.add_item(button)
return view
# Discord Bot客户端
class DiscordBotClient(discord.Bot):
"""Discord客户端封装"""
def __init__(self, token: str, proxy: str = None):
self.token = token
self.proxy = proxy
# 设置Intent权限,为了最大兼容性使用all()
intents = discord.Intents.all()
# 初始化Bot
super().__init__(intents=intents, proxy=proxy)
# 回调函数
self.on_message_received = None
async def on_ready(self):
"""当机器人成功连接并准备就绪时触发"""
logger.info(f"[Discord] 已作为 {self.user} (ID: {self.user.id}) 登录")
logger.info("[Discord] 客户端已准备就绪。")
def _create_message_data(self, message: discord.Message) -> dict:
"""从 discord.Message 创建数据字典"""
is_mentioned = self.user in message.mentions
return {
"message": message,
"bot_id": str(self.user.id),
"content": message.content,
"username": message.author.display_name,
"userid": str(message.author.id),
"message_id": str(message.id),
"channel_id": str(message.channel.id),
"guild_id": str(message.guild.id) if message.guild else None,
"type": "message",
"is_mentioned": is_mentioned,
"clean_content": message.clean_content,
}
def _create_interaction_data(self, interaction: discord.Interaction) -> dict:
"""从 discord.Interaction 创建数据字典"""
return {
"interaction": interaction,
"bot_id": str(self.user.id),
"content": self._extract_interaction_content(interaction),
"username": interaction.user.display_name,
"userid": str(interaction.user.id),
"message_id": str(interaction.id),
"channel_id": str(interaction.channel_id)
if interaction.channel_id
else None,
"guild_id": str(interaction.guild_id) if interaction.guild_id else None,
"type": "interaction",
}
async def on_message(self, message: discord.Message):
"""当接收到消息时触发"""
if message.author.bot:
return
logger.debug(
f"[Discord] 收到原始消息 from {message.author.name}: {message.content}"
)
if self.on_message_received:
message_data = self._create_message_data(message)
await self.on_message_received(message_data)
async def on_interaction(self, interaction: discord.Interaction):
"""当接收到交互(按钮点击等)时触发"""
logger.debug(
f"[Discord] 收到交互 from {interaction.user.name}: {interaction.data}"
)
if self.on_message_received:
interaction_data = self._create_interaction_data(interaction)
await self.on_message_received(interaction_data)
def _extract_interaction_content(self, interaction: discord.Interaction) -> str:
"""从交互中提取内容"""
interaction_type = interaction.type
interaction_data = getattr(interaction, "data", {})
if not interaction_data:
return ""
if interaction_type == discord.InteractionType.application_command:
command_name = interaction_data.get("name", "")
options = interaction_data.get("options", [])
if options:
params = " ".join(
[f"{opt['name']}:{opt.get('value', '')}" for opt in options]
)
return f"/{command_name} {params}"
return f"/{command_name}"
elif interaction_type == discord.InteractionType.component:
custom_id = interaction_data.get("custom_id", "")
component_type = interaction_data.get("component_type", "")
return f"component:{custom_id}:{component_type}"
return str(interaction_data)
async def start_polling(self):
"""开始轮询消息,这是个阻塞方法"""
await self.start(self.token)
async def close(self):
"""关闭客户端"""
if not self.is_closed():
await super().close()
# 注册平台适配器
@register_platform_adapter("discord", "Discord 适配器 (基于 Pycord)")
class DiscordPlatformAdapter(Platform):
@@ -299,6 +63,8 @@ class DiscordPlatformAdapter(Platform):
# 初始化回调函数
async def on_received(message_data):
logger.debug(f"[Discord] 收到消息: {message_data}")
if self.client_self_id is None:
self.client_self_id = message_data.get("bot_id")
abm = await self.convert_message(data=message_data)
await self.handle_msg(abm)
@@ -67,10 +67,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
"""获取当前事件对应的频道对象"""
try:
channel_id = int(self.session_id)
channel = self.client.get_channel(
return self.client.get_channel(
channel_id
) or await self.client.fetch_channel(channel_id)
return channel
except (ValueError, discord.errors.NotFound, discord.errors.Forbidden):
logger.error(f"[Discord] 无法获取频道 {self.session_id}")
return None
@@ -81,9 +80,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
) -> tuple[str, list[discord.File], Optional[discord.ui.View], list[discord.Embed]]:
"""将 MessageChain 解析为 Discord 发送所需的内容"""
try:
from .discord_platform_adapter import DiscordEmbed, DiscordView
from .components import DiscordEmbed, DiscordView
except ImportError:
from discord_platform_adapter import DiscordEmbed, DiscordView
from components import DiscordEmbed, DiscordView
plain_text_parts = []
files = []
@@ -164,9 +163,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
# 合并文本内容
content = "\n".join(plain_text_parts)
if len(content) > 20000:
logger.warning("[Discord] 消息内容超过20000字符,将被截断。")
content = content[:20000]
if len(content) > 2000:
logger.warning("[Discord] 消息内容超过2000字符,将被截断。")
content = content[:2000]
return content, files, view, embeds