Merge branch 'AstrBotDevs:master' into Astrbot_session_manage

This commit is contained in:
Gao Jinzhe
2025-06-27 17:10:08 +08:00
committed by GitHub
7 changed files with 91 additions and 52 deletions
-4
View File
@@ -389,10 +389,6 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "可选的 Discord 活动名称。留空则不设置活动。",
},
"discord_guild_id_for_debug": {
"description": "【开发用】指定一个服务器(Guild)ID。在此服务器注册的指令会立刻生效,便于调试。留空则注册为全局指令。",
"type": "string",
},
},
},
"platform_settings": {
+8 -8
View File
@@ -128,9 +128,7 @@ class RespondStage(Stage):
"streaming_segmented", False
)
logger.info(f"应用流式输出({event.get_platform_name()})")
await event._pre_send()
await event.send_streaming(result.async_stream, use_fallback)
await event._post_send()
return
elif len(result.chain) > 0:
# 检查路径映射
@@ -141,8 +139,6 @@ class RespondStage(Stage):
component.file = path_Mapping(mappings, component.file)
event.get_result().chain[idx] = component
await event._pre_send()
# 检查消息链是否为空
try:
if await self._is_empty_message_chain(result.chain):
@@ -158,9 +154,14 @@ class RespondStage(Stage):
c for c in result.chain if not isinstance(c, Comp.Record)
]
if self.enable_seg and (
(self.only_llm_result and result.is_llm_result())
or not self.only_llm_result
if (
self.enable_seg
and (
(self.only_llm_result and result.is_llm_result())
or not self.only_llm_result
)
and event.get_platform_name()
not in ["qq_official", "weixin_official_account", "dingtalk"]
):
decorated_comps = []
if self.reply_with_mention:
@@ -208,7 +209,6 @@ class RespondStage(Stage):
logger.error(traceback.format_exc())
logger.error(f"发送消息失败: {e} chain: {result.chain}")
await event._post_send()
logger.info(
f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
)
@@ -141,7 +141,11 @@ class ResultDecorateStage(Stage):
break
# 分段回复
if self.enable_segmented_reply:
if self.enable_segmented_reply and event.get_platform_name() not in [
"qq_official",
"weixin_official_account",
"dingtalk",
]:
if (
self.only_llm_result and result.is_llm_result()
) or not self.only_llm_result:
@@ -136,7 +136,6 @@ class WakingCheckStage(Stage):
f"插件 {star_map[handler.handler_module_path].name}: {e}"
)
)
await event._post_send()
event.stop_event()
passed = False
break
@@ -151,7 +150,6 @@ class WakingCheckStage(Stage):
f"您(ID: {event.get_sender_id()})的权限不足以使用此指令。通过 /sid 获取 ID 并请管理员添加。"
)
)
await event._post_send()
logger.info(
f"触发 {star_map[handler.handler_module_path].name} 时, 用户(ID={event.get_sender_id()}) 权限不足。"
)
+2 -2
View File
@@ -235,10 +235,10 @@ class AstrMessageEvent(abc.ABC):
self._has_send_oper = True
async def _pre_send(self):
"""调度器会在执行 send() 前调用该方法"""
"""调度器会在执行 send() 前调用该方法 deprecated in v3.5.18"""
async def _post_send(self):
"""调度器会在执行 send() 后调用该方法"""
"""调度器会在执行 send() 后调用该方法 deprecated in v3.5.18"""
def set_result(self, result: Union[MessageEventResult, str]):
"""设置消息事件的结果。
@@ -46,6 +46,8 @@ class DiscordPlatformAdapter(Platform):
self.enable_command_register = self.config.get("discord_command_register", True)
self.guild_id = self.config.get("discord_guild_id_for_debug", None)
self.activity_name = self.config.get("discord_activity_name", None)
self.shutdown_event = asyncio.Event()
self._polling_task = None
@override
async def send_by_session(
@@ -137,7 +139,8 @@ class DiscordPlatformAdapter(Platform):
self.client.on_ready_once_callback = callback
try:
await self.client.start_polling()
self._polling_task = asyncio.create_task(self.client.start_polling())
await self.shutdown_event.wait()
except discord.errors.LoginFailure:
logger.error("[Discord] 登录失败。请检查你的 Bot Token 是否正确。")
except discord.errors.ConnectionClosed:
@@ -162,42 +165,47 @@ class DiscordPlatformAdapter(Platform):
def _convert_message_to_abm(self, data: dict) -> AstrBotMessage:
"""将普通消息转换为 AstrBotMessage"""
message: discord.Message = data["message"]
is_mentioned = data.get("is_mentioned", False)
content = message.content
# 如果机器人被@,移除@部分
if (
is_mentioned
and self.client
and self.client.user
and self.client.user in message.mentions
):
# 构建机器人的@字符串,格式为 <@USER_ID> 或 <@!USER_ID>
# 剥离 User Mention (<@id>, <@!id>)
if self.client and self.client.user:
mention_str = f"<@{self.client.user.id}>"
mention_str_nickname = (
f"<@!{self.client.user.id}>" # 有些客户端会使用带!的格式
)
mention_str_nickname = f"<@!{self.client.user.id}>"
if content.startswith(mention_str):
content = content[len(mention_str) :].lstrip()
elif content.startswith(mention_str_nickname):
content = content[len(mention_str_nickname) :].lstrip()
abm = AstrBotMessage()
# 剥离 Role Mentionbot 拥有的任一角色被提及,<@&role_id>
if (
hasattr(message, "role_mentions")
and hasattr(message, "guild")
and message.guild
):
bot_member = (
message.guild.get_member(self.client.user.id)
if self.client and self.client.user
else None
)
if bot_member and hasattr(bot_member, "roles"):
for role in bot_member.roles:
role_mention_str = f"<@&{role.id}>"
if content.startswith(role_mention_str):
content = content[len(role_mention_str) :].lstrip()
break # 只剥离第一个匹配的角色 mention
abm = AstrBotMessage()
abm.type = self._get_message_type(message.channel)
abm.group_id = self._get_channel_id(message.channel)
abm.message_str = content
abm.sender = MessageMember(
user_id=str(message.author.id), nickname=message.author.display_name
)
message_chain = []
if abm.message_str:
message_chain.append(Plain(text=abm.message_str))
if message.attachments:
for attachment in message.attachments:
if attachment.content_type and attachment.content_type.startswith(
@@ -210,7 +218,6 @@ class DiscordPlatformAdapter(Platform):
message_chain.append(
File(name=attachment.filename, url=attachment.url)
)
abm.message = message_chain
abm.raw_message = message
abm.self_id = self.client_self_id
@@ -237,13 +244,35 @@ class DiscordPlatformAdapter(Platform):
# 检查是否为斜杠指令
is_slash_command = message_event.interaction_followup_webhook is not None
# 检查是否被@
is_mention = (
# 检查是否被@User Mention 或 Bot 拥有的 Role Mention
is_mention = False
# User Mention
if (
self.client
and self.client.user
and hasattr(message.raw_message, "mentions")
and self.client.user in message.raw_message.mentions
)
):
if self.client.user in message.raw_message.mentions:
is_mention = True
# Role MentionBot 拥有的角色被提及)
if not is_mention and hasattr(message.raw_message, "role_mentions"):
bot_member = None
if hasattr(message.raw_message, "guild") and message.raw_message.guild:
try:
bot_member = message.raw_message.guild.get_member(
self.client.user.id
)
except Exception:
bot_member = None
if bot_member and hasattr(bot_member, "roles"):
bot_roles = set(bot_member.roles)
mentioned_roles = set(message.raw_message.role_mentions)
if (
bot_roles
and mentioned_roles
and bot_roles.intersection(mentioned_roles)
):
is_mention = True
# 如果是斜杠指令或被@的消息,设置为唤醒状态
if is_slash_command or is_mention:
@@ -255,23 +284,37 @@ class DiscordPlatformAdapter(Platform):
@override
async def terminate(self):
"""终止适配器"""
logger.info("[Discord] 正在终止适配器...")
logger.info("[Discord] 正在终止适配器... (step 1: cancel polling task)")
self.shutdown_event.set()
# 优先 cancel polling_task
if self._polling_task:
self._polling_task.cancel()
try:
await asyncio.wait_for(self._polling_task, timeout=10)
except asyncio.CancelledError:
logger.info("[Discord] polling_task 已取消。")
except Exception as e:
logger.warning(f"[Discord] polling_task 取消异常: {e}")
logger.info("[Discord] 正在清理已注册的斜杠指令... (step 2)")
# 清理指令
if self.enable_command_register and self.client:
logger.info("[Discord] 正在清理已注册的斜杠指令...")
try:
# 传入空的列表来清除所有全局指令
# 如果指定了 guild_id,则只清除该服务器的指令
await self.client.sync_commands(
commands=[], guild_ids=[self.guild_id] if self.guild_id else None
await asyncio.wait_for(
self.client.sync_commands(
commands=[],
guild_ids=[self.guild_id] if self.guild_id else None,
),
timeout=10,
)
logger.info("[Discord] 指令清理完成。")
except Exception as e:
logger.error(f"[Discord] 清理指令时发生错误: {e}", exc_info=True)
logger.info("[Discord] 正在关闭 Discord 客户端... (step 3)")
if self.client and hasattr(self.client, "close"):
await self.client.close()
try:
await asyncio.wait_for(self.client.close(), timeout=10)
except Exception as e:
logger.warning(f"[Discord] 客户端关闭异常: {e}")
logger.info("[Discord] 适配器已终止。")
def register_handler(self, handler_info):
@@ -28,10 +28,8 @@ class QQOfficialMessageEvent(AstrMessageEvent):
self.send_buffer = None
async def send(self, message: MessageChain):
if not self.send_buffer:
self.send_buffer = message
else:
self.send_buffer.chain.extend(message.chain)
self.send_buffer = message
await self._post_send()
async def send_streaming(self, generator, use_fallback: bool = False):
"""流式输出仅支持消息列表私聊"""