feat: 添加 Discord 斜杠指令注册功能及相关配置项

feat: 添加 Activity 设置项
fix: 修复 At Reply 未处理的问题
This commit is contained in:
HakimYu
2025-06-22 16:29:02 +08:00
parent f9c3e4cdb0
commit ac4f3d8907
5 changed files with 290 additions and 100 deletions
+17 -1
View File
@@ -225,12 +225,15 @@ CONFIG_METADATA_2 = {
"telegram_command_auto_refresh": True,
"telegram_command_register_interval": 300,
},
"discord":{
"discord": {
"id": "discord",
"type": "discord",
"enable": False,
"discord_token": "",
"discord_proxy": "",
"discord_command_register": True,
"discord_guild_id_for_debug": "",
"discord_activity_name": "",
},
"slack": {
"id": "slack",
@@ -374,6 +377,19 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "可选的代理地址:http://ip:port"
},
"discord_command_register": {
"description": "是否自动将插件指令注册为 Discord 斜杠指令",
"type": "bool",
},
"discord_activity_name": {
"description": "Discord 活动名称",
"type": "string",
"hint": "可选的 Discord 活动名称。留空则不设置活动。",
},
"discord_guild_id_for_debug": {
"description": "【开发用】指定一个服务器(Guild)ID。在此服务器注册的指令会立刻生效,便于调试。留空则注册为全局指令。",
"type": "string",
},
},
},
"platform_settings": {
@@ -1,5 +1,11 @@
import discord
from astrbot import logger
import sys
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
# Discord Bot客户端
@@ -20,12 +26,23 @@ class DiscordBotClient(discord.Bot):
# 回调函数
self.on_message_received = None
self.on_ready_once_callback = None
self._ready_once_fired = False
@override
async def on_ready(self):
"""当机器人成功连接并准备就绪时触发"""
logger.info(f"[Discord] 已作为 {self.user} (ID: {self.user.id}) 登录")
logger.info("[Discord] 客户端已准备就绪。")
if self.on_ready_once_callback and not self._ready_once_fired:
self._ready_once_fired = True
try:
await self.on_ready_once_callback()
except Exception as e:
logger.error(
f"[Discord] on_ready_once_callback 执行失败: {e}", exc_info=True)
def _create_message_data(self, message: discord.Message) -> dict:
"""从 discord.Message 创建数据字典"""
is_mentioned = self.user in message.mentions
@@ -59,6 +76,7 @@ class DiscordBotClient(discord.Bot):
"type": "interaction",
}
@override
async def on_message(self, message: discord.Message):
"""当接收到消息时触发"""
if message.author.bot:
@@ -72,15 +90,6 @@ class DiscordBotClient(discord.Bot):
message_data = self._create_message_data(message)
await self.on_message_received(message_data)
async def on_interaction(self, interaction: discord.Interaction):
"""当接收到交互(按钮点击等)时触发"""
logger.debug(
f"[Discord] 收到交互 from {interaction.user.name}: {interaction.data}"
)
if self.on_message_received:
interaction_data = self._create_interaction_data(interaction)
await self.on_message_received(interaction_data)
def _extract_interaction_content(self, interaction: discord.Interaction) -> str:
"""从交互中提取内容"""
@@ -110,6 +119,7 @@ class DiscordBotClient(discord.Bot):
"""开始轮询消息,这是个阻塞方法"""
await self.start(self.token)
@override
async def close(self):
"""关闭客户端"""
if not self.is_closed():
@@ -79,6 +79,13 @@ class DiscordButton(BaseMessageComponent):
self.url = url
self.disabled = disabled
class DiscordReference(BaseMessageComponent):
"""Discord引用组件"""
type: str = "discord_reference"
def __init__(self, message_id: str, channel_id: str):
self.message_id = message_id
self.channel_id = channel_id
class DiscordView(BaseMessageComponent):
"""Discord视图组件,包含按钮和选择菜单"""
@@ -91,6 +98,7 @@ class DiscordView(BaseMessageComponent):
self.components = components or []
self.timeout = timeout
def to_discord_view(self) -> discord.ui.View:
"""转换为Discord View对象"""
view = discord.ui.View(timeout=self.timeout)
@@ -14,6 +14,19 @@ from astrbot.api.platform import register_platform_adapter
from astrbot import logger
from .client import DiscordBotClient
from .discord_platform_event import DiscordPlatformEvent
import sys
from functools import partial
from typing import Any, Dict, List, Tuple, Type
from astrbot.core.star.filter.command import CommandFilter, GreedyStr
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import StarHandlerMetadata, star_handlers_registry
import re
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
# 注册平台适配器
@@ -27,7 +40,13 @@ class DiscordPlatformAdapter(Platform):
self.settings = platform_settings
self.client_self_id = None
self.registered_handlers = []
# 指令注册相关
self.enable_command_register = self.config.get(
"discord_command_register", True)
self.guild_id = self.config.get("discord_guild_id_for_debug", None)
self.activity_name = self.config.get("discord_activity_name", None)
@override
async def send_by_session(
self, session: MessageSesion, message_chain: MessageChain
):
@@ -43,6 +62,7 @@ class DiscordPlatformAdapter(Platform):
await temp_event.send(message_chain)
await super().send_by_session(session, message_chain)
@override
def meta(self) -> PlatformMetadata:
"""返回平台元数据"""
return PlatformMetadata(
@@ -52,6 +72,7 @@ class DiscordPlatformAdapter(Platform):
default_config_tmpl=self.config,
)
@override
async def run(self):
"""主要运行逻辑"""
@@ -73,6 +94,14 @@ class DiscordPlatformAdapter(Platform):
self.client = DiscordBotClient(token, proxy)
self.client.on_message_received = on_received
async def callback():
if self.enable_command_register:
await self._collect_and_register_commands()
if self.activity_name:
await self.client.change_presence(status=discord.Status.online, activity=discord.CustomActivity(name=self.activity_name))
self.client.on_ready_once_callback = callback
try:
await self.client.start_polling()
except discord.errors.LoginFailure:
@@ -95,32 +124,6 @@ class DiscordPlatformAdapter(Platform):
gid = guild_id or getattr(channel, "guild", None).id
return MessageType.GROUP_MESSAGE, str(gid)
def _convert_interaction_to_abm(self, data: dict) -> AstrBotMessage:
"""将交互事件转换为 AstrBotMessage"""
interaction: discord.Interaction = data["interaction"]
abm = AstrBotMessage()
abm.type, abm.group_id = self._determine_message_type(
interaction.channel, interaction.guild_id
)
# 对于交互事件,message_str 通常没有意义,且可能导致被闲聊等通用插件错误响应。
# 将其清空,以确保只有专门的指令处理器会响应。
abm.message_str = ""
abm.sender = MessageMember(
user_id=str(interaction.user.id), nickname=interaction.user.display_name
)
abm.message = [Plain(text=data["content"])]
abm.raw_message = interaction
abm.self_id = self.client_self_id
abm.session_id = (
str(interaction.channel_id)
if interaction.channel_id
else str(interaction.user.id)
)
abm.message_id = str(interaction.id)
return abm
def _convert_message_to_abm(self, data: dict) -> AstrBotMessage:
"""将普通消息转换为 AstrBotMessage"""
message: discord.Message = data["message"]
@@ -142,9 +145,9 @@ class DiscordPlatformAdapter(Platform):
)
if content.startswith(mention_str):
content = content[len(mention_str) :].lstrip()
content = content[len(mention_str):].lstrip()
elif content.startswith(mention_str_nickname):
content = content[len(mention_str_nickname) :].lstrip()
content = content[len(mention_str_nickname):].lstrip()
abm = AstrBotMessage()
@@ -181,12 +184,10 @@ class DiscordPlatformAdapter(Platform):
async def convert_message(self, data: dict) -> AstrBotMessage:
"""将平台消息转换成 AstrBotMessage"""
if data.get("type") in ["interaction", "slash_command"]:
return self._convert_interaction_to_abm(data)
else:
return self._convert_message_to_abm(data)
# 由于 on_interaction 已被禁用,我们只处理普通消息
return self._convert_message_to_abm(data)
async def handle_msg(self, message: AstrBotMessage):
async def handle_msg(self, message: AstrBotMessage, followup_webhook=None):
"""处理消息"""
message_event = DiscordPlatformEvent(
message_str=message.message_str,
@@ -194,23 +195,43 @@ class DiscordPlatformAdapter(Platform):
platform_meta=self.meta(),
session_id=message.session_id,
client=self.client,
interaction_followup_webhook=followup_webhook,
)
# 如果是被@的消息,设置为唤醒状态
if (
# 检查是否为斜杠指令
is_slash_command = message_event.interaction_followup_webhook is not None
# 检查是否被@
is_mention = (
self.client
and self.client.user
and hasattr(message.raw_message, "mentions")
and self.client.user in message.raw_message.mentions
):
)
# 如果是斜杠指令或被@的消息,设置为唤醒状态
if is_slash_command or is_mention:
message_event.is_wake = True
message_event.is_at_or_wake_command = True
self.commit_event(message_event)
@override
async def terminate(self):
"""终止适配器"""
logger.info("[Discord] 正在终止适配器...")
# 清理指令
if self.enable_command_register and self.client:
logger.info("[Discord] 正在清理已注册的斜杠指令...")
try:
# 传入空的列表来清除所有全局指令
# 如果指定了 guild_id,则只清除该服务器的指令
await self.client.sync_commands(commands=[], guild_ids=[self.guild_id] if self.guild_id else None)
logger.info("[Discord] 指令清理完成。")
except Exception as e:
logger.error(f"[Discord] 清理指令时发生错误: {e}", exc_info=True)
if self.client and hasattr(self.client, "close"):
await self.client.close()
logger.info("[Discord] 适配器已终止。")
@@ -218,3 +239,132 @@ class DiscordPlatformAdapter(Platform):
def register_handler(self, handler_info):
"""注册处理器信息"""
self.registered_handlers.append(handler_info)
async def _collect_and_register_commands(self):
"""收集所有指令并注册到Discord"""
logger.info("[Discord] 开始收集并注册斜杠指令...")
registered_commands = []
for handler_md in star_handlers_registry:
if not star_map[handler_md.handler_module_path].activated:
continue
for event_filter in handler_md.event_filters:
cmd_info = self._extract_command_info(event_filter, handler_md)
if not cmd_info:
continue
cmd_name, description, cmd_filter_instance = cmd_info
# 创建动态回调
callback = self._create_dynamic_callback(cmd_name)
# 创建一个通用的参数选项来接收所有文本输入
options = [
discord.Option(
name="params",
description="指令的所有参数",
type=discord.SlashCommandOptionType.string,
required=False,
)
]
# 创建SlashCommand
slash_command = discord.SlashCommand(
name=cmd_name,
description=description,
func=callback,
options=options,
guild_ids=[self.guild_id] if self.guild_id else None,
)
self.client.add_application_command(slash_command)
registered_commands.append(cmd_name)
if registered_commands:
logger.info(
f"[Discord] 准备同步 {len(registered_commands)} 个指令: {', '.join(registered_commands)}")
else:
logger.info("[Discord] 没有发现可注册的指令。")
# 使用 Pycord 的方法同步指令
# 注意:这可能需要一些时间,并且有频率限制
await self.client.sync_commands()
logger.info("[Discord] 指令同步完成。")
def _create_dynamic_callback(self, cmd_name: str):
"""为每个指令动态创建一个异步回调函数"""
async def dynamic_callback(ctx: discord.ApplicationContext, params: str = None):
# 将平台特定的前缀'/'剥离,以适配通用的CommandFilter
logger.debug(f"[Discord] 回调函数触发: {cmd_name}")
logger.debug(f"[Discord] 回调函数参数: {ctx}")
logger.debug(f"[Discord] 回调函数参数: {params}")
message_str_for_filter = cmd_name
if params:
message_str_for_filter += " " + params
logger.debug(
f"[Discord] 斜杠指令 '{cmd_name}' 被触发。 "
f"原始参数: '{params}'. "
f"构建的指令字符串: '{message_str_for_filter}'"
)
# 尝试立即响应,防止超时
followup_webhook = None
try:
await ctx.defer()
followup_webhook = ctx.followup
except Exception as e:
logger.warning(f"[Discord] 指令 '{cmd_name}' defer 失败: {e}")
# 2. 构建 AstrBotMessage
abm = AstrBotMessage()
abm.type, abm.group_id = self._determine_message_type(
ctx.channel, ctx.guild_id
)
abm.message_str = message_str_for_filter
abm.sender = MessageMember(
user_id=str(ctx.author.id), nickname=ctx.author.display_name
)
abm.message = [Plain(text=message_str_for_filter)]
abm.raw_message = ctx.interaction
abm.self_id = self.client_self_id
abm.session_id = str(ctx.channel_id)
abm.message_id = str(ctx.interaction.id)
# 3. 将消息和 webhook 分别交给 handle_msg 处理
await self.handle_msg(abm, followup_webhook)
return dynamic_callback
@staticmethod
def _extract_command_info(
event_filter: Any, handler_metadata: StarHandlerMetadata
) -> Tuple[str, str, CommandFilter] | None:
"""从事件过滤器中提取指令信息"""
cmd_name = None
is_group = False
cmd_filter_instance = None
if isinstance(event_filter, CommandFilter):
# 暂不支持子指令注册为斜杠指令
if event_filter.parent_command_names and event_filter.parent_command_names != [""]:
return None
cmd_name = event_filter.command_name
cmd_filter_instance = event_filter
elif isinstance(event_filter, CommandGroupFilter):
# 暂不支持指令组直接注册为斜杠指令,因为它们没有 handle 方法
return None
if not cmd_name:
return None
# Discord 斜杠指令名称规范
if not re.match(r"^[a-z0-9_-]{1,32}$", cmd_name):
logger.debug(f"[Discord] 跳过不符合规范的指令: {cmd_name}")
return None
description = handler_metadata.desc or f"指令: {cmd_name}"
if len(description) > 100:
description = description[:97] + "..."
return cmd_name, description, cmd_filter_instance
@@ -3,15 +3,27 @@ import discord
import base64
from io import BytesIO
from pathlib import Path
from typing import Optional
from typing import Optional, List
import sys
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image, File, BaseMessageComponent
from astrbot.api.platform import AstrBotMessage, PlatformMetadata, At
from astrbot.api.message_components import (
Plain,
Image,
File,
BaseMessageComponent,
Reply,
)
from astrbot import logger
from .client import DiscordBotClient
from .components import DiscordEmbed, DiscordView
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
# 自定义Discord视图组件(兼容旧版本)
class DiscordViewComponent(BaseMessageComponent):
@@ -29,36 +41,52 @@ class DiscordPlatformEvent(AstrMessageEvent):
platform_meta: PlatformMetadata,
session_id: str,
client: DiscordBotClient,
interaction_followup_webhook: Optional[discord.Webhook] = None,
):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.client = client
self.interaction_followup_webhook = interaction_followup_webhook
@override
async def send(self, message: MessageChain):
"""发送消息到Discord平台"""
# 解析消息链为 Discord 所需的对象
try:
channel = await self._get_channel()
if not channel:
logger.error(f"[Discord] 无法获取频道 {self.session_id}")
return
content, files, view, embeds, reference_message_id = await self._parse_to_discord(message)
except Exception as e:
logger.error(f"[Discord] 解析消息链时失败: {e}", exc_info=True)
return
# 解析消息链
content, files, view, embeds = await self._parse_to_discord(message)
kwargs = {}
if content:
kwargs["content"] = content
if files:
kwargs["files"] = files
if view:
kwargs["view"] = view
if embeds:
kwargs["embeds"] = embeds
if reference_message_id and not self.interaction_followup_webhook:
kwargs["reference"] = self.client.get_message(int(reference_message_id))
if not kwargs:
logger.debug("[Discord] 尝试发送空消息,已忽略。")
return
# Discord 不允许发送完全空的消息
if not content and not files and not view and not embeds:
logger.debug("[Discord] 尝试发送空消息,已忽略。")
return
# 根据上下文执行发送/回复操作
try:
# -- 斜杠指令/交互上下文 --
if self.interaction_followup_webhook:
await self.interaction_followup_webhook.send(**kwargs)
# 发送消息
await channel.send(
content=content or None,
files=files or None,
view=view or None,
embeds=embeds or None,
)
# -- 常规消息上下文 --
else:
channel = await self._get_channel()
if not channel:
return
else:
await channel.send(**kwargs)
except discord.errors.HTTPException as e:
logger.error(f"[Discord] 发送消息失败: {e.status} {e.code} - {e.text}")
except Exception as e:
logger.error(f"[Discord] 发送消息时发生未知错误: {e}", exc_info=True)
@@ -80,14 +108,18 @@ class DiscordPlatformEvent(AstrMessageEvent):
message: MessageChain,
) -> tuple[str, list[discord.File], Optional[discord.ui.View], list[discord.Embed]]:
"""将 MessageChain 解析为 Discord 发送所需的内容"""
plain_text_parts = []
content = ""
files = []
view = None
embeds = []
reference_message_id = None
for i in message.chain: # 遍历消息链
if isinstance(i, Plain): # 如果是文字类型的
plain_text_parts.append(i.text)
content += i.text
elif isinstance(i, Reply):
reference_message_id = i.id
elif isinstance(i, At):
content += f"<@{i.qq}>"
elif isinstance(i, Image):
logger.debug(f"[Discord] 开始处理 Image 组件: {i}")
try:
@@ -174,7 +206,8 @@ class DiscordPlatformEvent(AstrMessageEvent):
if await asyncio.to_thread(path.exists):
file_bytes = await asyncio.to_thread(path.read_bytes)
files.append(
discord.File(BytesIO(file_bytes), filename=i.name)
discord.File(BytesIO(file_bytes),
filename=i.name)
)
else:
logger.warning(
@@ -197,37 +230,10 @@ class DiscordPlatformEvent(AstrMessageEvent):
else:
logger.debug(f"[Discord] 忽略了不支持的消息组件: {i.type}")
# 合并文本内容
content = "\n".join(plain_text_parts)
if len(content) > 2000:
logger.warning("[Discord] 消息内容超过2000字符,将被截断。")
content = content[:2000]
return content, files, view, embeds
async def reply(self, message: MessageChain):
"""回复消息(如果原消息存在)"""
try:
if hasattr(self.message_obj, "raw_message") and hasattr(
self.message_obj.raw_message, "reply"
):
# 解析消息链
content, files, view, embeds = await self._parse_to_discord(message)
# 使用Discord的回复功能
await self.message_obj.raw_message.reply(
content=content or None,
files=files or None,
view=view or None,
embeds=embeds or None,
)
else:
# 如果无法回复,使用普通发送
await self.send(message)
except Exception as e:
logger.error(f"[Discord] 回复消息失败: {e}")
# 回退到普通发送
await self.send(message)
return content, files, view, embeds, reference_message_id
async def react(self, emoji: str):
"""对原消息添加反应"""