Merge branch 'master' into branch-1

This commit is contained in:
Soulter
2025-06-22 10:28:44 +08:00
committed by GitHub
15 changed files with 3622 additions and 1427 deletions
+205
View File
@@ -225,8 +225,48 @@ CONFIG_METADATA_2 = {
"telegram_command_auto_refresh": True,
"telegram_command_register_interval": 300,
},
"discord":{
"id": "discord",
"type": "discord",
"enable": False,
"discord_token": "",
"discord_proxy": "",
},
"slack": {
"id": "slack",
"type": "slack",
"enable": False,
"bot_token": "",
"app_token": "",
"signing_secret": "",
"slack_connection_mode": "socket", # webhook, socket
"slack_webhook_host": "0.0.0.0",
"slack_webhook_port": 6197,
"slack_webhook_path": "/astrbot-slack-webhook/callback",
},
},
"items": {
"slack_connection_mode": {
"description": "Slack Connection Mode",
"type": "string",
"options": ["webhook", "socket"],
"hint": "The connection mode for Slack. `webhook` uses a webhook server, `socket` uses Slack's Socket Mode.",
},
"slack_webhook_host": {
"description": "Slack Webhook Host",
"type": "string",
"hint": "Only valid when Slack connection mode is `webhook`.",
},
"slack_webhook_port": {
"description": "Slack Webhook Port",
"type": "int",
"hint": "Only valid when Slack connection mode is `webhook`.",
},
"slack_webhook_path": {
"description": "Slack Webhook Path",
"type": "string",
"hint": "Only valid when Slack connection mode is `webhook`.",
},
"active_send_mode": {
"description": "是否换用主动发送接口",
"type": "bool",
@@ -324,6 +364,16 @@ CONFIG_METADATA_2 = {
"hint": "请务必填对,否则 @ 机器人将无法唤醒,只能通过前缀唤醒。",
"obvious_hint": True,
},
"discord_token":{
"description": "Discord Bot Token",
"type": "string",
"hint": "在此处填入你的Discord Bot Token"
},
"discord_proxy":{
"description": "Discord 代理地址",
"type": "string",
"hint": "可选的代理地址:http://ip:port"
},
},
},
"platform_settings": {
@@ -800,6 +850,37 @@ CONFIG_METADATA_2 = {
"edge-tts-voice": "zh-CN-XiaoxiaoNeural",
"timeout": 20,
},
"GSV TTS(本地加载)": {
"id": "gsv_tts",
"enable": False,
"type": "gsv_tts_selfhost",
"provider_type": "text_to_speech",
"api_base": "http://127.0.0.1:9880",
"gpt_weights_path": "",
"sovits_weights_path": "",
"timeout": 60,
"gsv_default_parms": {
"gsv_ref_audio_path": "",
"gsv_prompt_text": "",
"gsv_prompt_lang": "zh",
"gsv_aux_ref_audio_paths": "",
"gsv_text_lang": "zh",
"gsv_top_k": 5,
"gsv_top_p": 1.0,
"gsv_temperature": 1.0,
"gsv_text_split_method": "cut3",
"gsv_batch_size": 1,
"gsv_batch_threshold": 0.75,
"gsv_split_bucket": True,
"gsv_speed_factor": 1,
"gsv_fragment_interval": 0.3,
"gsv_streaming_mode": False,
"gsv_seed": -1,
"gsv_parallel_infer": True,
"gsv_repetition_penalty": 1.35,
"gsv_media_type": "wav",
},
},
"GSVI TTS(API)": {
"id": "gsvi_tts",
"type": "gsvi_tts_api",
@@ -901,6 +982,130 @@ CONFIG_METADATA_2 = {
},
},
"items": {
"gpt_weights_path": {
"description": "GPT模型文件路径",
"type": "string",
"hint": "即“.ckpt”后缀的文件,请使用绝对路径,路径两端不要带双引号,不填则默认用GPT_SoVITS内置的SoVITS模型(建议直接在GPT_SoVITS中改默认模型)",
"obvious_hint": True,
},
"sovits_weights_path": {
"description": "SoVITS模型文件路径",
"type": "string",
"hint": "即“.pth”后缀的文件,请使用绝对路径,路径两端不要带双引号,不填则默认用GPT_SoVITS内置的SoVITS模型(建议直接在GPT_SoVITS中改默认模型)",
"obvious_hint": True,
},
"gsv_default_parms": {
"description": "GPT_SoVITS默认参数",
"hint": "参考音频文件路径、参考音频文本必填,其他参数根据个人爱好自行填写",
"type": "object",
"items": {
"gsv_ref_audio_path": {
"description": "参考音频文件路径",
"type": "string",
"hint": "必填!请使用绝对路径!路径两端不要带双引号!",
"obvious_hint": True,
},
"gsv_prompt_text": {
"description": "参考音频文本",
"type": "string",
"hint": "必填!请填写参考音频讲述的文本",
"obvious_hint": True,
},
"gsv_prompt_lang": {
"description": "参考音频文本语言",
"type": "string",
"hint": "请填写参考音频讲述的文本的语言,默认为中文",
},
"gsv_aux_ref_audio_paths": {
"description": "辅助参考音频文件路径",
"type": "string",
"hint": "辅助参考音频文件,可不填",
},
"gsv_text_lang": {
"description": "文本语言",
"type": "string",
"hint": "默认为中文",
},
"gsv_top_k": {
"description": "生成语音的多样性",
"type": "int",
"hint": "",
},
"gsv_top_p": {
"description": "核采样的阈值",
"type": "float",
"hint": "",
},
"gsv_temperature": {
"description": "生成语音的随机性",
"type": "float",
"hint": "",
},
"gsv_text_split_method": {
"description": "切分文本的方法",
"type": "string",
"hint": "可选值: `cut0`:不切分 `cut1`:四句一切 `cut2`50字一切 `cut3`:按中文句号切 `cut4`:按英文句号切 `cut5`:按标点符号切",
"options": [
"cut0",
"cut1",
"cut2",
"cut3",
"cut4",
"cut5",
],
},
"gsv_batch_size": {
"description": "批处理大小",
"type": "int",
"hint": "",
},
"gsv_batch_threshold": {
"description": "批处理阈值",
"type": "float",
"hint": "",
},
"gsv_split_bucket": {
"description": "将文本分割成桶以便并行处理",
"type": "bool",
"hint": "",
},
"gsv_speed_factor": {
"description": "语音播放速度",
"type": "float",
"hint": "1为原始语速",
},
"gsv_fragment_interval": {
"description": "语音片段之间的间隔时间",
"type": "float",
"hint": "",
},
"gsv_streaming_mode": {
"description": "启用流模式",
"type": "bool",
"hint": "",
},
"gsv_seed": {
"description": "随机种子",
"type": "int",
"hint": "用于结果的可重复性",
},
"gsv_parallel_infer": {
"description": "并行执行推理",
"type": "bool",
"hint": "",
},
"gsv_repetition_penalty": {
"description": "重复惩罚因子",
"type": "float",
"hint": "",
},
"gsv_media_type": {
"description": "输出媒体的类型",
"type": "string",
"hint": "建议用wav",
},
},
},
"embedding_dimensions": {
"description": "嵌入维度",
"type": "int",
+9 -1
View File
@@ -77,7 +77,15 @@ class PlatformManager:
case "wecom":
from .sources.wecom.wecom_adapter import WecomPlatformAdapter # noqa: F401
case "weixin_official_account":
from .sources.weixin_official_account.weixin_offacc_adapter import WeixinOfficialAccountPlatformAdapter # noqa
from .sources.weixin_official_account.weixin_offacc_adapter import (
WeixinOfficialAccountPlatformAdapter, # noqa
)
case "discord":
from .sources.discord.discord_platform_adapter import (
DiscordPlatformAdapter, # noqa: F401
)
case "slack":
from .sources.slack.slack_adapter import SlackAdapter # noqa: F401
except (ImportError, ModuleNotFoundError) as e:
logger.error(
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。"
@@ -0,0 +1,116 @@
import discord
from astrbot import logger
# Discord Bot客户端
class DiscordBotClient(discord.Bot):
"""Discord客户端封装"""
def __init__(self, token: str, proxy: str = None):
self.token = token
self.proxy = proxy
# 设置Intent权限,遵循权限最小化原则
intents = discord.Intents.default()
intents.message_content = True # 订阅消息内容事件 (Privileged)
intents.members = True # 订阅成员事件 (Privileged)
# 初始化Bot
super().__init__(intents=intents, proxy=proxy)
# 回调函数
self.on_message_received = None
async def on_ready(self):
"""当机器人成功连接并准备就绪时触发"""
logger.info(f"[Discord] 已作为 {self.user} (ID: {self.user.id}) 登录")
logger.info("[Discord] 客户端已准备就绪。")
def _create_message_data(self, message: discord.Message) -> dict:
"""从 discord.Message 创建数据字典"""
is_mentioned = self.user in message.mentions
return {
"message": message,
"bot_id": str(self.user.id),
"content": message.content,
"username": message.author.display_name,
"userid": str(message.author.id),
"message_id": str(message.id),
"channel_id": str(message.channel.id),
"guild_id": str(message.guild.id) if message.guild else None,
"type": "message",
"is_mentioned": is_mentioned,
"clean_content": message.clean_content,
}
def _create_interaction_data(self, interaction: discord.Interaction) -> dict:
"""从 discord.Interaction 创建数据字典"""
return {
"interaction": interaction,
"bot_id": str(self.user.id),
"content": self._extract_interaction_content(interaction),
"username": interaction.user.display_name,
"userid": str(interaction.user.id),
"message_id": str(interaction.id),
"channel_id": str(interaction.channel_id)
if interaction.channel_id
else None,
"guild_id": str(interaction.guild_id) if interaction.guild_id else None,
"type": "interaction",
}
async def on_message(self, message: discord.Message):
"""当接收到消息时触发"""
if message.author.bot:
return
logger.debug(
f"[Discord] 收到原始消息 from {message.author.name}: {message.content}"
)
if self.on_message_received:
message_data = self._create_message_data(message)
await self.on_message_received(message_data)
async def on_interaction(self, interaction: discord.Interaction):
"""当接收到交互(按钮点击等)时触发"""
logger.debug(
f"[Discord] 收到交互 from {interaction.user.name}: {interaction.data}"
)
if self.on_message_received:
interaction_data = self._create_interaction_data(interaction)
await self.on_message_received(interaction_data)
def _extract_interaction_content(self, interaction: discord.Interaction) -> str:
"""从交互中提取内容"""
interaction_type = interaction.type
interaction_data = getattr(interaction, "data", {})
if not interaction_data:
return ""
if interaction_type == discord.InteractionType.application_command:
command_name = interaction_data.get("name", "")
if options := interaction_data.get("options", []):
params = " ".join(
[f"{opt['name']}:{opt.get('value', '')}" for opt in options]
)
return f"/{command_name} {params}"
return f"/{command_name}"
elif interaction_type == discord.InteractionType.component:
custom_id = interaction_data.get("custom_id", "")
component_type = interaction_data.get("component_type", "")
return f"component:{custom_id}:{component_type}"
return str(interaction_data)
async def start_polling(self):
"""开始轮询消息,这是个阻塞方法"""
await self.start(self.token)
async def close(self):
"""关闭客户端"""
if not self.is_closed():
await super().close()
@@ -0,0 +1,125 @@
import discord
from typing import List
from astrbot.api.message_components import BaseMessageComponent
# Discord专用组件
class DiscordEmbed(BaseMessageComponent):
"""Discord Embed消息组件"""
type: str = "discord_embed"
def __init__(
self,
title: str = None,
description: str = None,
color: int = None,
url: str = None,
thumbnail: str = None,
image: str = None,
footer: str = None,
fields: List[dict] = None,
):
self.title = title
self.description = description
self.color = color
self.url = url
self.thumbnail = thumbnail
self.image = image
self.footer = footer
self.fields = fields or []
def to_discord_embed(self) -> discord.Embed:
"""转换为Discord Embed对象"""
embed = discord.Embed()
if self.title:
embed.title = self.title
if self.description:
embed.description = self.description
if self.color:
embed.color = self.color
if self.url:
embed.url = self.url
if self.thumbnail:
embed.set_thumbnail(url=self.thumbnail)
if self.image:
embed.set_image(url=self.image)
if self.footer:
embed.set_footer(text=self.footer)
for field in self.fields:
embed.add_field(
name=field.get("name", ""),
value=field.get("value", ""),
inline=field.get("inline", False),
)
return embed
class DiscordButton(BaseMessageComponent):
"""Discord按钮组件"""
type: str = "discord_button"
def __init__(
self,
label: str,
custom_id: str = None,
style: str = "primary",
emoji: str = None,
url: str = None,
disabled: bool = False,
):
self.label = label
self.custom_id = custom_id
self.style = style
self.emoji = emoji
self.url = url
self.disabled = disabled
class DiscordView(BaseMessageComponent):
"""Discord视图组件,包含按钮和选择菜单"""
type: str = "discord_view"
def __init__(
self, components: List[BaseMessageComponent] = None, timeout: float = None
):
self.components = components or []
self.timeout = timeout
def to_discord_view(self) -> discord.ui.View:
"""转换为Discord View对象"""
view = discord.ui.View(timeout=self.timeout)
for component in self.components:
if isinstance(component, DiscordButton):
button_style = getattr(
discord.ButtonStyle, component.style, discord.ButtonStyle.primary
)
if component.url:
# URL按钮
button = discord.ui.Button(
label=component.label,
style=discord.ButtonStyle.link,
url=component.url,
emoji=component.emoji,
disabled=component.disabled,
)
else:
# 普通按钮
button = discord.ui.Button(
label=component.label,
style=button_style,
custom_id=component.custom_id,
emoji=component.emoji,
disabled=component.disabled,
)
view.add_item(button)
return view
@@ -0,0 +1,220 @@
import asyncio
import discord
from astrbot.api.platform import (
Platform,
AstrBotMessage,
MessageMember,
PlatformMetadata,
MessageType,
)
from astrbot.api.event import MessageChain
from astrbot.api.message_components import Plain, Image, File
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.api.platform import register_platform_adapter
from astrbot import logger
from .client import DiscordBotClient
from .discord_platform_event import DiscordPlatformEvent
# 注册平台适配器
@register_platform_adapter("discord", "Discord 适配器 (基于 Pycord)")
class DiscordPlatformAdapter(Platform):
def __init__(
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
) -> None:
super().__init__(event_queue)
self.config = platform_config
self.settings = platform_settings
self.client_self_id = None
self.registered_handlers = []
async def send_by_session(
self, session: MessageSesion, message_chain: MessageChain
):
"""通过会话发送消息"""
# 创建临时事件对象来发送消息
temp_event = DiscordPlatformEvent(
message_str="",
message_obj=None,
platform_meta=self.meta(),
session_id=session.session_id,
client=self.client,
)
await temp_event.send(message_chain)
await super().send_by_session(session, message_chain)
def meta(self) -> PlatformMetadata:
"""返回平台元数据"""
return PlatformMetadata(
"discord",
"Discord 适配器",
id=self.config.get("id"),
default_config_tmpl=self.config,
)
async def run(self):
"""主要运行逻辑"""
# 初始化回调函数
async def on_received(message_data):
logger.debug(f"[Discord] 收到消息: {message_data}")
if self.client_self_id is None:
self.client_self_id = message_data.get("bot_id")
abm = await self.convert_message(data=message_data)
await self.handle_msg(abm)
# 初始化 Discord 客户端
token = str(self.config.get("discord_token"))
if not token:
logger.error("[Discord] Bot Token 未配置。请在配置文件中正确设置 token。")
return
proxy = self.config.get("discord_proxy") or None
self.client = DiscordBotClient(token, proxy)
self.client.on_message_received = on_received
try:
await self.client.start_polling()
except discord.errors.LoginFailure:
logger.error("[Discord] 登录失败。请检查你的 Bot Token 是否正确。")
except discord.errors.ConnectionClosed:
logger.warning("[Discord] 与 Discord 的连接已关闭。")
except Exception as e:
logger.error(f"[Discord] 适配器运行时发生意外错误: {e}", exc_info=True)
def _determine_message_type(
self, channel, guild_id=None
) -> tuple[MessageType, str]:
"""判断消息类型和群组ID"""
if guild_id is None and (
isinstance(channel, discord.DMChannel)
or getattr(channel, "guild", None) is None
):
return MessageType.FRIEND_MESSAGE, ""
gid = guild_id or getattr(channel, "guild", None).id
return MessageType.GROUP_MESSAGE, str(gid)
def _convert_interaction_to_abm(self, data: dict) -> AstrBotMessage:
"""将交互事件转换为 AstrBotMessage"""
interaction: discord.Interaction = data["interaction"]
abm = AstrBotMessage()
abm.type, abm.group_id = self._determine_message_type(
interaction.channel, interaction.guild_id
)
# 对于交互事件,message_str 通常没有意义,且可能导致被闲聊等通用插件错误响应。
# 将其清空,以确保只有专门的指令处理器会响应。
abm.message_str = ""
abm.sender = MessageMember(
user_id=str(interaction.user.id), nickname=interaction.user.display_name
)
abm.message = [Plain(text=data["content"])]
abm.raw_message = interaction
abm.self_id = self.client_self_id
abm.session_id = (
str(interaction.channel_id)
if interaction.channel_id
else str(interaction.user.id)
)
abm.message_id = str(interaction.id)
return abm
def _convert_message_to_abm(self, data: dict) -> AstrBotMessage:
"""将普通消息转换为 AstrBotMessage"""
message: discord.Message = data["message"]
is_mentioned = data.get("is_mentioned", False)
content = message.content
# 如果机器人被@,移除@部分
if (
is_mentioned
and self.client
and self.client.user
and self.client.user in message.mentions
):
# 构建机器人的@字符串,格式为 <@USER_ID> 或 <@!USER_ID>
mention_str = f"<@{self.client.user.id}>"
mention_str_nickname = (
f"<@!{self.client.user.id}>" # 有些客户端会使用带!的格式
)
if content.startswith(mention_str):
content = content[len(mention_str) :].lstrip()
elif content.startswith(mention_str_nickname):
content = content[len(mention_str_nickname) :].lstrip()
abm = AstrBotMessage()
abm.type, abm.group_id = self._determine_message_type(message.channel)
abm.message_str = content
abm.sender = MessageMember(
user_id=str(message.author.id), nickname=message.author.display_name
)
message_chain = []
if abm.message_str:
message_chain.append(Plain(text=abm.message_str))
if message.attachments:
for attachment in message.attachments:
if attachment.content_type and attachment.content_type.startswith(
"image/"
):
message_chain.append(
Image(file=attachment.url, filename=attachment.filename)
)
else:
message_chain.append(
File(name=attachment.filename, url=attachment.url)
)
abm.message = message_chain
abm.raw_message = message
abm.self_id = self.client_self_id
abm.session_id = str(message.channel.id)
abm.message_id = str(message.id)
return abm
async def convert_message(self, data: dict) -> AstrBotMessage:
"""将平台消息转换成 AstrBotMessage"""
if data.get("type") in ["interaction", "slash_command"]:
return self._convert_interaction_to_abm(data)
else:
return self._convert_message_to_abm(data)
async def handle_msg(self, message: AstrBotMessage):
"""处理消息"""
message_event = DiscordPlatformEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
client=self.client,
)
# 如果是被@的消息,设置为唤醒状态
if (
self.client
and self.client.user
and hasattr(message.raw_message, "mentions")
and self.client.user in message.raw_message.mentions
):
message_event.is_wake = True
message_event.is_at_or_wake_command = True
self.commit_event(message_event)
async def terminate(self):
"""终止适配器"""
logger.info("[Discord] 正在终止适配器...")
if self.client and hasattr(self.client, "close"):
await self.client.close()
logger.info("[Discord] 适配器已终止。")
def register_handler(self, handler_info):
"""注册处理器信息"""
self.registered_handlers.append(handler_info)
@@ -0,0 +1,285 @@
import asyncio
import discord
import base64
from io import BytesIO
from pathlib import Path
from typing import Optional
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image, File, BaseMessageComponent
from astrbot import logger
from .client import DiscordBotClient
from .components import DiscordEmbed, DiscordView
# 自定义Discord视图组件(兼容旧版本)
class DiscordViewComponent(BaseMessageComponent):
type: str = "discord_view"
def __init__(self, view: discord.ui.View):
self.view = view
class DiscordPlatformEvent(AstrMessageEvent):
def __init__(
self,
message_str: str,
message_obj: AstrBotMessage,
platform_meta: PlatformMetadata,
session_id: str,
client: DiscordBotClient,
):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.client = client
async def send(self, message: MessageChain):
"""发送消息到Discord平台"""
try:
channel = await self._get_channel()
if not channel:
logger.error(f"[Discord] 无法获取频道 {self.session_id}")
return
# 解析消息链
content, files, view, embeds = await self._parse_to_discord(message)
# Discord 不允许发送完全空的消息
if not content and not files and not view and not embeds:
logger.debug("[Discord] 尝试发送空消息,已忽略。")
return
# 发送消息
await channel.send(
content=content or None,
files=files or None,
view=view or None,
embeds=embeds or None,
)
except discord.errors.HTTPException as e:
logger.error(f"[Discord] 发送消息失败: {e.status} {e.code} - {e.text}")
except Exception as e:
logger.error(f"[Discord] 发送消息时发生未知错误: {e}", exc_info=True)
await super().send(message)
async def _get_channel(self) -> Optional[discord.abc.Messageable]:
"""获取当前事件对应的频道对象"""
try:
channel_id = int(self.session_id)
return self.client.get_channel(
channel_id
) or await self.client.fetch_channel(channel_id)
except (ValueError, discord.errors.NotFound, discord.errors.Forbidden):
logger.error(f"[Discord] 无法获取频道 {self.session_id}")
return None
async def _parse_to_discord(
self,
message: MessageChain,
) -> tuple[str, list[discord.File], Optional[discord.ui.View], list[discord.Embed]]:
"""将 MessageChain 解析为 Discord 发送所需的内容"""
plain_text_parts = []
files = []
view = None
embeds = []
for i in message.chain: # 遍历消息链
if isinstance(i, Plain): # 如果是文字类型的
plain_text_parts.append(i.text)
elif isinstance(i, Image):
logger.debug(f"[Discord] 开始处理 Image 组件: {i}")
try:
filename = getattr(i, "filename", None)
file_content = getattr(i, "file", None)
if not file_content:
logger.warning(f"[Discord] Image 组件没有 file 属性: {i}")
continue
discord_file = None
# 1. URL
if file_content.startswith("http"):
logger.debug(f"[Discord] 处理 URL 图片: {file_content}")
embed = discord.Embed().set_image(url=file_content)
embeds.append(embed)
continue
# 2. File URI
elif file_content.startswith("file:///"):
logger.debug(f"[Discord] 处理 File URI: {file_content}")
path = Path(file_content[8:])
if await asyncio.to_thread(path.exists):
file_bytes = await asyncio.to_thread(path.read_bytes)
discord_file = discord.File(
BytesIO(file_bytes), filename=filename or path.name
)
else:
logger.warning(f"[Discord] 图片文件不存在: {path}")
# 3. Base64 URI
elif file_content.startswith("base64://"):
logger.debug("[Discord] 处理 Base64 URI")
b64_data = file_content.split("base64://", 1)[1]
missing_padding = len(b64_data) % 4
if missing_padding:
b64_data += "=" * (4 - missing_padding)
img_bytes = base64.b64decode(b64_data)
discord_file = discord.File(
BytesIO(img_bytes), filename=filename or "image.png"
)
# 4. 裸 Base64 或本地路径
else:
try:
logger.debug("[Discord] 尝试作为裸 Base64 处理")
b64_data = file_content
missing_padding = len(b64_data) % 4
if missing_padding:
b64_data += "=" * (4 - missing_padding)
img_bytes = base64.b64decode(b64_data)
discord_file = discord.File(
BytesIO(img_bytes), filename=filename or "image.png"
)
except (ValueError, TypeError, base64.binascii.Error):
logger.debug(
f"[Discord] 裸 Base64 解码失败,作为本地路径处理: {file_content}"
)
path = Path(file_content)
if await asyncio.to_thread(path.exists):
file_bytes = await asyncio.to_thread(path.read_bytes)
discord_file = discord.File(
BytesIO(file_bytes), filename=filename or path.name
)
else:
logger.warning(f"[Discord] 图片文件不存在: {path}")
if discord_file:
files.append(discord_file)
except Exception:
# 使用 getattr 来安全地访问 i.file,以防 i 本身就是问题
file_info = getattr(i, "file", "未知")
logger.error(
f"[Discord] 处理图片时发生未知严重错误: {file_info}",
exc_info=True,
)
elif isinstance(i, File):
try:
file_path_str = await i.get_file()
if file_path_str:
path = Path(file_path_str)
if await asyncio.to_thread(path.exists):
file_bytes = await asyncio.to_thread(path.read_bytes)
files.append(
discord.File(BytesIO(file_bytes), filename=i.name)
)
else:
logger.warning(
f"[Discord] 获取文件失败,路径不存在: {file_path_str}"
)
else:
logger.warning(f"[Discord] 获取文件失败: {i.name}")
except Exception as e:
logger.warning(f"[Discord] 处理文件失败: {i.name}, 错误: {e}")
elif isinstance(i, DiscordEmbed):
# Discord Embed消息
embeds.append(i.to_discord_embed())
elif isinstance(i, DiscordView):
# Discord视图组件(按钮、选择菜单等)
view = i.to_discord_view()
elif isinstance(i, DiscordViewComponent):
# 如果消息链中包含Discord视图组件(兼容旧版本)
if isinstance(i.view, discord.ui.View):
view = i.view
else:
logger.debug(f"[Discord] 忽略了不支持的消息组件: {i.type}")
# 合并文本内容
content = "\n".join(plain_text_parts)
if len(content) > 2000:
logger.warning("[Discord] 消息内容超过2000字符,将被截断。")
content = content[:2000]
return content, files, view, embeds
async def reply(self, message: MessageChain):
"""回复消息(如果原消息存在)"""
try:
if hasattr(self.message_obj, "raw_message") and hasattr(
self.message_obj.raw_message, "reply"
):
# 解析消息链
content, files, view, embeds = await self._parse_to_discord(message)
# 使用Discord的回复功能
await self.message_obj.raw_message.reply(
content=content or None,
files=files or None,
view=view or None,
embeds=embeds or None,
)
else:
# 如果无法回复,使用普通发送
await self.send(message)
except Exception as e:
logger.error(f"[Discord] 回复消息失败: {e}")
# 回退到普通发送
await self.send(message)
async def react(self, emoji: str):
"""对原消息添加反应"""
try:
if hasattr(self.message_obj, "raw_message") and hasattr(
self.message_obj.raw_message, "add_reaction"
):
await self.message_obj.raw_message.add_reaction(emoji)
except Exception as e:
logger.error(f"[Discord] 添加反应失败: {e}")
def is_slash_command(self) -> bool:
"""判断是否为斜杠命令"""
return (
hasattr(self.message_obj, "raw_message")
and hasattr(self.message_obj.raw_message, "type")
and self.message_obj.raw_message.type
== discord.InteractionType.application_command
)
def is_button_interaction(self) -> bool:
"""判断是否为按钮交互"""
return (
hasattr(self.message_obj, "raw_message")
and hasattr(self.message_obj.raw_message, "type")
and self.message_obj.raw_message.type == discord.InteractionType.component
)
def get_interaction_custom_id(self) -> str:
"""获取交互组件的custom_id"""
if self.is_button_interaction():
try:
return self.message_obj.raw_message.data.get("custom_id", "")
except Exception:
pass
return ""
def is_mentioned(self) -> bool:
"""判断机器人是否被@"""
if hasattr(self.message_obj, "raw_message") and hasattr(
self.message_obj.raw_message, "mentions"
):
return any(
mention.id == int(self.message_obj.self_id)
for mention in self.message_obj.raw_message.mentions
)
return False
def get_mention_clean_content(self) -> str:
"""获取去除@后的清洁内容"""
if hasattr(self.message_obj, "raw_message") and hasattr(
self.message_obj.raw_message, "clean_content"
):
return self.message_obj.raw_message.clean_content
return self.message_str
@@ -0,0 +1,162 @@
import json
import hmac
import hashlib
import asyncio
import logging
from typing import Callable, Optional
from quart import Quart, request, Response
from slack_sdk.web.async_client import AsyncWebClient
from slack_sdk.socket_mode.aiohttp import SocketModeClient
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from astrbot.api import logger
class SlackWebhookClient:
"""Slack Webhook 模式客户端,使用 Quart 作为 Web 服务器"""
def __init__(
self,
web_client: AsyncWebClient,
signing_secret: str,
host: str = "0.0.0.0",
port: int = 3000,
path: str = "/slack/events",
event_handler: Optional[Callable] = None,
):
self.web_client = web_client
self.signing_secret = signing_secret
self.host = host
self.port = port
self.path = path
self.event_handler = event_handler
self.app = Quart(__name__)
self._setup_routes()
# 禁用 Quart 的默认日志输出
logging.getLogger("quart.app").setLevel(logging.WARNING)
logging.getLogger("quart.serving").setLevel(logging.WARNING)
self.shutdown_event = asyncio.Event()
def _setup_routes(self):
"""设置路由"""
@self.app.route(self.path, methods=["POST"])
async def slack_events():
"""处理 Slack 事件"""
try:
# 获取请求体和头部
body = await request.get_data()
event_data = json.loads(body.decode("utf-8"))
# Verify Slack request signature
timestamp = request.headers.get("X-Slack-Request-Timestamp")
signature = request.headers.get("X-Slack-Signature")
if not timestamp or not signature:
return Response("Missing headers", status=400)
# Calculate the HMAC signature
sig_basestring = f"v0:{timestamp}:{body.decode('utf-8')}"
my_signature = (
"v0="
+ hmac.new(
self.signing_secret.encode("utf-8"),
sig_basestring.encode("utf-8"),
hashlib.sha256,
).hexdigest()
)
# Verify the signature
if not hmac.compare_digest(my_signature, signature):
logger.warning("Slack request signature verification failed")
return Response("Invalid signature", status=400)
logger.info(f"Received Slack event: {event_data}")
# 处理 URL 验证事件
if event_data.get("type") == "url_verification":
return {"challenge": event_data.get("challenge")}
# 处理事件
if self.event_handler and event_data.get("type") == "event_callback":
await self.event_handler(event_data)
return Response("", status=200)
except Exception as e:
logger.error(f"处理 Slack 事件时出错: {e}")
return Response("Internal Server Error", status=500)
@self.app.route("/health", methods=["GET"])
async def health_check():
"""健康检查端点"""
return {"status": "ok", "service": "slack-webhook"}
async def start(self):
"""启动 Webhook 服务器"""
logger.info(
f"Slack Webhook 服务器启动中,监听 {self.host}:{self.port}{self.path}..."
)
await self.app.run_task(
host=self.host,
port=self.port,
debug=False,
shutdown_trigger=self.shutdown_trigger,
)
async def shutdown_trigger(self):
await self.shutdown_event.wait()
async def stop(self):
"""停止 Webhook 服务器"""
self.shutdown_event.set()
logger.info("Slack Webhook 服务器已停止")
class SlackSocketClient:
"""Slack Socket 模式客户端"""
def __init__(
self,
web_client: AsyncWebClient,
app_token: str,
event_handler: Optional[Callable] = None,
):
self.web_client = web_client
self.app_token = app_token
self.event_handler = event_handler
self.socket_client = None
async def _handle_events(self, _: SocketModeClient, req: SocketModeRequest):
"""处理 Socket Mode 事件"""
try:
# 确认收到事件
response = SocketModeResponse(envelope_id=req.envelope_id)
await self.socket_client.send_socket_mode_response(response)
# 处理事件
if self.event_handler:
await self.event_handler(req)
except Exception as e:
logger.error(f"处理 Socket Mode 事件时出错: {e}")
async def start(self):
"""启动 Socket Mode 连接"""
self.socket_client = SocketModeClient(
app_token=self.app_token,
logger=logger,
web_client=self.web_client,
)
# 注册事件处理器
self.socket_client.socket_mode_request_listeners.append(self._handle_events)
logger.info("Slack Socket Mode 客户端启动中...")
await self.socket_client.connect()
async def stop(self):
"""停止 Socket Mode 连接"""
if self.socket_client:
await self.socket_client.disconnect()
await self.socket_client.close()
logger.info("Slack Socket Mode 客户端已停止")
@@ -0,0 +1,396 @@
import time
import asyncio
import uuid
import aiohttp
import re
import base64
from typing import Awaitable, Any
from slack_sdk.web.async_client import AsyncWebClient
from slack_sdk.socket_mode.request import SocketModeRequest
from astrbot.api.platform import (
Platform,
AstrBotMessage,
MessageMember,
MessageType,
PlatformMetadata,
)
from astrbot.api.event import MessageChain
from .slack_event import SlackMessageEvent
from .client import SlackWebhookClient, SlackSocketClient
from astrbot.api.message_components import * # noqa: F403
from astrbot.api import logger
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
@register_platform_adapter(
"slack", "适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。"
)
class SlackAdapter(Platform):
def __init__(
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
) -> None:
super().__init__(event_queue)
self.config = platform_config
self.settings = platform_settings
self.unique_session = platform_settings.get("unique_session", False)
self.bot_token = platform_config.get("bot_token")
self.app_token = platform_config.get("app_token")
self.signing_secret = platform_config.get("signing_secret")
self.connection_mode = platform_config.get("slack_connection_mode", "socket")
self.webhook_host = platform_config.get("slack_webhook_host", "0.0.0.0")
self.webhook_port = platform_config.get("slack_webhook_port", 3000)
self.webhook_path = platform_config.get(
"slack_webhook_path", "/astrbot-slack-webhook/callback"
)
if not self.bot_token:
raise ValueError("Slack bot_token 是必需的")
if self.connection_mode == "socket" and not self.app_token:
raise ValueError("Socket Mode 需要 app_token")
if self.connection_mode == "webhook" and not self.signing_secret:
raise ValueError("Webhook Mode 需要 signing_secret")
self.metadata = PlatformMetadata(
name="slack",
description="适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。",
id=self.config.get("id"),
)
# 初始化 Slack Web Client
self.web_client = AsyncWebClient(token=self.bot_token, logger=logger)
self.socket_client = None
self.webhook_client = None
self.bot_self_id = None
async def send_by_session(
self, session: MessageSesion, message_chain: MessageChain
):
blocks, text = SlackMessageEvent._parse_slack_blocks(
message_chain=message_chain, web_client=self.web_client
)
try:
if session.message_type == MessageType.GROUP_MESSAGE:
# 发送到频道
channel_id = (
session.session_id.split("_")[-1]
if "_" in session.session_id
else session.session_id
)
await self.web_client.chat_postMessage(
channel=channel_id,
text=text,
blocks=blocks if blocks else None,
)
else:
# 发送私信
await self.web_client.chat_postMessage(
channel=session.session_id,
text=text,
blocks=blocks if blocks else None,
)
except Exception as e:
logger.error(f"Slack 发送消息失败: {e}")
await super().send_by_session(session, message_chain)
async def convert_message(self, event: dict) -> AstrBotMessage:
logger.debug(f"[slack] RawMessage {event}")
abm = AstrBotMessage()
abm.self_id = self.bot_self_id
# 获取用户信息
user_id = event.get("user", "")
try:
user_info = await self.web_client.users_info(user=user_id)
user_data = user_info["user"]
user_name = user_data.get("real_name") or user_data.get("name", user_id)
except Exception:
user_name = user_id
abm.sender = MessageMember(user_id=user_id, nickname=user_name)
# 判断消息类型
channel_id = event.get("channel", "")
try:
channel_info = await self.web_client.conversations_info(channel=channel_id)
is_im = channel_info["channel"]["is_im"]
if is_im:
abm.type = MessageType.FRIEND_MESSAGE
else:
abm.type = MessageType.GROUP_MESSAGE
abm.group_id = channel_id
except Exception:
# 默认作为群组消息处理
abm.type = MessageType.GROUP_MESSAGE
abm.group_id = channel_id
# 设置会话ID
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
abm.session_id = f"{user_id}_{channel_id}"
else:
abm.session_id = (
channel_id if abm.type == MessageType.GROUP_MESSAGE else user_id
)
abm.message_id = event.get("client_msg_id", uuid.uuid4().hex)
abm.timestamp = int(float(event.get("ts", time.time())))
# 处理消息内容
message_text = event.get("text", "")
abm.message_str = message_text
abm.message = []
# 优先使用 blocks 字段解析消息
if "blocks" in event and event["blocks"]:
abm.message = self._parse_blocks(event["blocks"])
# 更新 message_str
abm.message_str = ""
for component in abm.message:
if isinstance(component, Plain):
abm.message_str += component.text
elif message_text:
# 处理传统的文本消息
if "<@" in message_text:
mentions = re.findall(r"<@([^>]+)>", message_text)
for mention in mentions:
try:
mentioned_user = await self.web_client.users_info(user=mention)
user_data = mentioned_user["user"]
user_name = user_data.get("real_name") or user_data.get(
"name", mention
)
abm.message.append(At(qq=mention, name=user_name))
except Exception:
abm.message.append(At(qq=mention, name=""))
# 清理消息文本中的@标记
if clean_text := re.sub(r"<@[^>]+>", "", message_text).strip():
abm.message.append(Plain(text=clean_text))
else:
abm.message.append(Plain(text=message_text))
# 处理文件附件
if "files" in event:
for file_info in event["files"]:
file_name = file_info.get("name", "unknown")
file_url = file_info.get("url_private", "")
if file_info.get("mimetype", "").startswith("image/"):
file_url = await self.get_file_base64(file_url)
abm.message.append(Image.fromBase64(base64=file_url))
else:
# TODO: 下载鉴权
abm.message.append(
File(name=file_name, file=file_url, url=file_url)
)
abm.raw_message = event
return abm
def _parse_blocks(self, blocks: list) -> list:
"""解析 Slack blocks 格式的消息内容"""
message_components = []
for block in blocks:
block_type = block.get("type", "")
if block_type == "rich_text":
# 处理富文本块
elements = block.get("elements", [])
for element in elements:
if element.get("type") == "rich_text_section":
# 处理富文本段落
section_elements = element.get("elements", [])
text_content = ""
for section_element in section_elements:
element_type = section_element.get("type", "")
if element_type == "text":
# 普通文本
text_content += section_element.get("text", "")
elif element_type == "user":
# @用户提及
user_id = section_element.get("user_id", "")
if user_id:
# 将之前的文本内容先添加到组件中
if text_content.strip():
message_components.append(
Plain(text=text_content)
)
text_content = ""
# 添加@提及组件
message_components.append(At(qq=user_id, name=""))
elif element_type == "channel":
# #频道提及
channel_id = section_element.get("channel_id", "")
text_content += f"#{channel_id}"
elif element_type == "link":
# 链接
url = section_element.get("url", "")
link_text = section_element.get("text", url)
text_content += f"[{link_text}]({url})"
elif element_type == "emoji":
# 表情符号
emoji_name = section_element.get("name", "")
text_content += f":{emoji_name}:"
if text_content.strip():
message_components.append(Plain(text=text_content))
elif element.get("type") == "rich_text_list":
# 处理列表
list_items = element.get("elements", [])
list_text = ""
for item in list_items:
if item.get("type") == "rich_text_section":
item_elements = item.get("elements", [])
item_text = ""
for item_element in item_elements:
if item_element.get("type") == "text":
item_text += item_element.get("text", "")
list_text += f"{item_text}\n"
if list_text.strip():
message_components.append(Plain(text=list_text.strip()))
elif block_type == "section":
# 处理段落块
if "text" in block:
text_obj = block["text"]
if text_obj.get("type") == "mrkdwn":
text_content = text_obj.get("text", "")
message_components.append(Plain(text=text_content))
return message_components
async def _handle_socket_event(self, req: SocketModeRequest):
"""处理 Socket Mode 事件"""
if req.type == "events_api":
# 事件 API
event = req.payload.get("event", {})
# 忽略机器人自己的消息和消息编辑
if event.get("subtype") in [
"bot_message",
"message_changed",
"message_deleted",
]:
return
if event.get("bot_id"):
return
if event.get("type") in ["message", "app_mention"]:
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)
async def get_bot_user_id(self):
auth_info = await self.web_client.auth_test()
return auth_info.get("user_id")
async def get_file_base64(self, url: str) -> str:
"""下载 Slack 文件并返回 Base64 编码的内容"""
headers = {"Authorization": f"Bearer {self.bot_token}"}
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=headers) as resp:
if resp.status == 200:
content = await resp.read()
base64_content = base64.b64encode(content).decode("utf-8")
return base64_content
else:
logger.error(f"Failed to download slack file: {resp.status} {await resp.text()}")
raise Exception(f"下载文件失败: {resp.status}")
async def run(self) -> Awaitable[Any]:
self.bot_self_id = await self.get_bot_user_id()
logger.info(f"Slack auth test OK. Bot ID: {self.bot_self_id}")
if self.connection_mode == "socket":
if not self.app_token:
raise ValueError("Socket Mode 需要 app_token")
# 创建 Socket 客户端
self.socket_client = SlackSocketClient(
self.web_client, self.app_token, self._handle_socket_event
)
logger.info("Slack 适配器 (Socket Mode) 启动中...")
await self.socket_client.start()
elif self.connection_mode == "webhook":
if not self.signing_secret:
raise ValueError("Webhook Mode 需要 signing_secret")
# 创建 Webhook 客户端
self.webhook_client = SlackWebhookClient(
self.web_client,
self.signing_secret,
self.webhook_host,
self.webhook_port,
self.webhook_path,
self._handle_webhook_event,
)
logger.info(
f"Slack 适配器 (Webhook Mode) 启动中,监听 {self.webhook_host}:{self.webhook_port}{self.webhook_path}..."
)
await self.webhook_client.start()
else:
raise ValueError(
f"不支持的连接模式: {self.connection_mode},请使用 'socket''webhook'"
)
async def _handle_webhook_event(self, event_data: dict):
"""处理 Webhook 事件"""
event = event_data.get("event", {})
# 忽略机器人自己的消息和消息编辑
if event.get("subtype") in [
"bot_message",
"message_changed",
"message_deleted",
]:
return
if event.get("bot_id"):
return
if event.get("type") in ["message", "app_mention"]:
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)
async def terminate(self):
if self.socket_client:
await self.socket_client.stop()
if self.webhook_client:
await self.webhook_client.stop()
logger.info("Slack 适配器已被优雅地关闭")
def meta(self) -> PlatformMetadata:
return self.metadata
async def handle_msg(self, message: AstrBotMessage):
message_event = SlackMessageEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
web_client=self.web_client,
)
self.commit_event(message_event)
def get_client(self):
return self.web_client
@@ -0,0 +1,237 @@
import asyncio
import re
from typing import AsyncGenerator
from slack_sdk.web.async_client import AsyncWebClient
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import (
Image,
Plain,
File,
BaseMessageComponent,
)
from astrbot.api.platform import Group, MessageMember
from astrbot.api import logger
class SlackMessageEvent(AstrMessageEvent):
def __init__(
self,
message_str,
message_obj,
platform_meta,
session_id,
web_client: AsyncWebClient,
):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.web_client = web_client
@staticmethod
async def _from_segment_to_slack_block(
segment: BaseMessageComponent, web_client: AsyncWebClient
) -> dict:
"""将消息段转换为 Slack 块格式"""
if isinstance(segment, Plain):
return {"type": "section", "text": {"type": "mrkdwn", "text": segment.text}}
elif isinstance(segment, Image):
# upload file
url = segment.url or segment.file
if url.startswith("http"):
return {
"type": "image",
"image_url": url,
"alt_text": "图片",
}
path = await segment.convert_to_file_path()
response = await web_client.files_upload_v2(
file=path,
filename="image.jpg",
)
if not response["ok"]:
logger.error(f"Slack file upload failed: {response['error']}")
return {
"type": "section",
"text": {"type": "mrkdwn", "text": "图片上传失败"},
}
image_url = response["files"][0]["url_private"]
logger.debug(f"Slack file upload response: {response}")
return {
"type": "image",
"slack_file": {
"url": image_url,
},
"alt_text": "图片",
}
elif isinstance(segment, File):
# upload file
url = segment.url or segment.file
response = await web_client.files_upload_v2(
file=url,
filename=segment.name or "file",
)
if not response["ok"]:
logger.error(f"Slack file upload failed: {response['error']}")
return {
"type": "section",
"text": {"type": "mrkdwn", "text": "文件上传失败"},
}
file_url = response["files"][0]["permalink"]
return {"type": "section", "text": {"type": "mrkdwn", "text": f"文件: <{file_url}|{segment.name or '文件'}>"}}
else:
return {"type": "section", "text": {"type": "mrkdwn", "text": str(segment)}}
@staticmethod
async def _parse_slack_blocks(
message_chain: MessageChain, web_client: AsyncWebClient
):
"""解析成 Slack 块格式"""
blocks = []
text_content = ""
for segment in message_chain.chain:
if isinstance(segment, Plain):
text_content += segment.text
else:
# 如果有文本内容,先添加文本块
if text_content.strip():
blocks.append(
{
"type": "section",
"text": {"type": "mrkdwn", "text": text_content},
}
)
text_content = ""
# 添加其他类型的块
block = await SlackMessageEvent._from_segment_to_slack_block(
segment, web_client
)
blocks.append(block)
# 如果最后还有文本内容
if text_content.strip():
blocks.append(
{"type": "section", "text": {"type": "mrkdwn", "text": text_content}}
)
return blocks, "" if blocks else text_content
async def send(self, message: MessageChain):
blocks, text = await SlackMessageEvent._parse_slack_blocks(
message, self.web_client
)
try:
if self.get_group_id():
# 发送到频道
await self.web_client.chat_postMessage(
channel=self.get_group_id(),
text=text,
blocks=blocks or None,
)
else:
# 发送私信
await self.web_client.chat_postMessage(
channel=self.get_sender_id(),
text=text,
blocks=blocks or None,
)
except Exception:
# 如果块发送失败,尝试只发送文本
fallback_text = ""
for segment in message.chain:
if isinstance(segment, Plain):
fallback_text += segment.text
elif isinstance(segment, File):
fallback_text += f" [文件: {segment.name}] "
elif isinstance(segment, Image):
fallback_text += " [图片] "
if self.get_group_id():
await self.web_client.chat_postMessage(
channel=self.get_group_id(), text=fallback_text
)
else:
await self.web_client.chat_postMessage(
channel=self.get_sender_id(), text=fallback_text
)
await super().send(message)
async def send_streaming(
self, generator: AsyncGenerator, use_fallback: bool = False
):
if not use_fallback:
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)
buffer = ""
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
async for chain in generator:
if isinstance(chain, MessageChain):
for comp in chain.chain:
if isinstance(comp, Plain):
buffer += comp.text
if any(p in buffer for p in "。?!~…"):
buffer = await self.process_buffer(buffer, pattern)
else:
await self.send(MessageChain(chain=[comp]))
await asyncio.sleep(1.5) # 限速
if buffer.strip():
await self.send(MessageChain([Plain(buffer)]))
return await super().send_streaming(generator, use_fallback)
async def get_group(self, group_id=None, **kwargs):
if group_id:
channel_id = group_id
elif self.get_group_id():
channel_id = self.get_group_id()
else:
return None
try:
# 获取频道信息
channel_info = await self.web_client.conversations_info(channel=channel_id)
# 获取频道成员
members_response = await self.web_client.conversations_members(
channel=channel_id
)
members = []
for member_id in members_response["members"]:
try:
user_info = await self.web_client.users_info(user=member_id)
user_data = user_info["user"]
members.append(
MessageMember(
user_id=member_id,
nickname=user_data.get("real_name")
or user_data.get("name", member_id),
)
)
except Exception:
# 如果获取用户信息失败,使用默认信息
members.append(MessageMember(user_id=member_id, nickname=member_id))
channel_data = channel_info["channel"]
return Group(
group_id=channel_id,
group_name=channel_data.get("name", ""),
group_avatar="",
group_admins=[], # Slack 的管理员信息需要特殊权限获取
group_owner=channel_data.get("creator", ""),
members=members,
)
except Exception:
return None
+4
View File
@@ -225,6 +225,10 @@ class ProviderManager:
from .sources.edge_tts_source import (
ProviderEdgeTTS as ProviderEdgeTTS,
)
case "gsv_tts_selfhost":
from .sources.gsv_selfhosted_source import (
ProviderGSVTTS as ProviderGSVTTS,
)
case "gsvi_tts_api":
from .sources.gsvi_tts_source import (
ProviderGSVITTS as ProviderGSVITTS,
@@ -0,0 +1,148 @@
import asyncio
import os
import uuid
import aiohttp
from ..provider import TTSProvider
from ..entities import ProviderType
from ..register import register_provider_adapter
from astrbot import logger
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
@register_provider_adapter(
provider_type_name="gsv_tts_selfhost",
desc="GPT-SoVITS TTS(本地加载)",
provider_type=ProviderType.TEXT_TO_SPEECH,
)
class ProviderGSVTTS(TTSProvider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
self.api_base = provider_config.get("api_base", "http://127.0.0.1:9880").rstrip(
"/"
)
self.gpt_weights_path: str = provider_config.get("gpt_weights_path", "")
self.sovits_weights_path: str = provider_config.get("sovits_weights_path", "")
# TTS 请求的默认参数,移除前缀gsv_
self.default_params: dict = {
key.removeprefix("gsv_"): str(value).lower()
for key, value in provider_config.get("gsv_default_parms", {}).items()
}
self.timeout = provider_config.get("timeout", 60)
self._session: aiohttp.ClientSession | None = None
async def initialize(self):
"""异步初始化:在 ProviderManager 中被调用"""
self._session = aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=self.timeout)
)
try:
await self._set_model_weights()
logger.info("[GSV TTS] 初始化完成")
except Exception as e:
logger.error(f"[GSV TTS] 初始化失败:{e}")
raise
def get_session(self) -> aiohttp.ClientSession:
if not self._session or self._session.closed:
raise RuntimeError(
"[GSV TTS] Provider HTTP session is not ready or closed."
)
return self._session
async def _make_request(
self, endpoint: str, params=None, retries: int = 3
) -> bytes | None:
"""发起请求"""
for attempt in range(retries):
logger.debug(f"[GSV TTS] 请求地址:{endpoint},参数:{params}")
try:
async with self.get_session().get(endpoint, params=params) as response:
if response.status != 200:
error_text = await response.text()
raise Exception(
f"[GSV TTS] Request to {endpoint} failed with status {response.status}: {error_text}"
)
return await response.read()
except Exception as e:
if attempt < retries - 1:
logger.warning(
f"[GSV TTS] 请求 {endpoint}{attempt + 1} 次失败:{e},重试中..."
)
await asyncio.sleep(1)
else:
logger.error(f"[GSV TTS] 请求 {endpoint} 最终失败:{e}")
raise
async def _set_model_weights(self):
"""设置模型路径"""
try:
if self.gpt_weights_path:
await self._make_request(
f"{self.api_base}/set_gpt_weights",
{"weights_path": self.gpt_weights_path},
)
logger.info(f"[GSV TTS] 成功设置 GPT 模型路径:{self.gpt_weights_path}")
else:
logger.info("[GSV TTS] GPT 模型路径未配置,将使用内置 GPT 模型")
if self.sovits_weights_path:
await self._make_request(
f"{self.api_base}/set_sovits_weights",
{"weights_path": self.sovits_weights_path},
)
logger.info(
f"[GSV TTS] 成功设置 SoVITS 模型路径:{self.sovits_weights_path}"
)
else:
logger.info("[GSV TTS] SoVITS 模型路径未配置,将使用内置 SoVITS 模型")
except aiohttp.ClientError as e:
logger.error(f"[GSV TTS] 设置模型路径时发生网络错误:{e}")
except Exception as e:
logger.error(f"[GSV TTS] 设置模型路径时发生未知错误:{e}")
async def get_audio(self, text: str) -> str:
"""实现 TTS 核心方法,根据文本内容自动切换情绪"""
if not text.strip():
raise ValueError("[GSV TTS] TTS 文本不能为空")
endpoint = f"{self.api_base}/tts"
params = self.build_synthesis_params(text)
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(temp_dir, exist_ok=True)
path = os.path.join(temp_dir, f"gsv_tts_{uuid.uuid4().hex}.wav")
logger.debug(f"[GSV TTS] 正在调用语音合成接口,参数:{params}")
result = await self._make_request(endpoint, params)
if isinstance(result, bytes):
with open(path, "wb") as f:
f.write(result)
return path
else:
raise Exception(f"[GSV TTS] 合成失败,输入文本:{text},错误信息:{result}")
def build_synthesis_params(self, text: str) -> dict:
"""
构建语音合成所需的参数字典。
当前仅包含默认参数 + 文本,未来可在此基础上动态添加如情绪、角色等语义控制字段。
"""
params = self.default_params.copy()
params["text"] = text
# TODO: 在此处添加情绪分析,例如 params["emotion"] = detect_emotion(text)
return params
async def terminate(self):
"""终止释放资源:在 ProviderManager 中被调用"""
if self._session and not self._session.closed:
await self._session.close()
logger.info("[GSV TTS] Session 已关闭")
+80 -26
View File
@@ -306,6 +306,24 @@
</v-snackbar>
<WaitingForRestart ref="wfr"></WaitingForRestart>
<!-- Key为空的确认对话框 -->
<v-dialog v-model="showKeyConfirm" max-width="450" persistent>
<v-card>
<v-card-title class="text-h6 bg-error d-flex align-center">
<v-icon start class="me-2">mdi-alert-circle-outline</v-icon>
确认保存
</v-card-title>
<v-card-text class="py-4 text-body-1 text-medium-emphasis">
您没有填写 API Key确定要保存吗这可能会导致该服务提供商无法正常工作
</v-card-text>
<v-card-actions>
<v-spacer></v-spacer>
<v-btn color="grey" variant="text" @click="handleKeyConfirm(false)">取消</v-btn>
<v-btn color="error" variant="flat" @click="handleKeyConfirm(true)">确定</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
</div>
</template>
@@ -341,6 +359,10 @@ export default {
sessionSeparationEnabled: false,
sessionSettingLoading: false,
// Key确认对话框
showKeyConfirm: false,
keyConfirmResolve: null,
newSelectedProviderName: '',
newSelectedProviderConfig: {},
updatingMode: false,
@@ -388,6 +410,16 @@ export default {
}
},
watch: {
showKeyConfirm(newValue) {
// 当对话框关闭时,如果 Promise 还在等待,则拒绝它以防止内存泄漏
if (!newValue && this.keyConfirmResolve) {
this.keyConfirmResolve(false);
this.keyConfirmResolve = null;
}
}
},
computed: {
// 翻译消息的计算属性
messages() {
@@ -591,32 +623,40 @@ export default {
this.updatingMode = true;
},
newProvider() {
async newProvider() {
// 检查 key 是否为空
if (
'key' in this.newSelectedProviderConfig &&
(!this.newSelectedProviderConfig.key || this.newSelectedProviderConfig.key.length === 0)
) {
const confirmed = await this.confirmEmptyKey();
if (!confirmed) {
return; // 如果用户取消,则中止保存
}
}
this.loading = true;
if (this.updatingMode) {
axios.post('/api/config/provider/update', {
id: this.newSelectedProviderName,
config: this.newSelectedProviderConfig
}).then((res) => {
this.loading = false;
this.showProviderCfg = false;
this.getConfig();
this.showSuccess(res.data.message || this.messages.success.update);
}).catch((err) => {
this.loading = false;
this.showError(err.response?.data?.message || err.message);
});
this.updatingMode = false;
} else {
axios.post('/api/config/provider/new', this.newSelectedProviderConfig).then((res) => {
this.loading = false;
this.showProviderCfg = false;
this.getConfig();
this.showSuccess(res.data.message || this.messages.success.add);
}).catch((err) => {
this.loading = false;
this.showError(err.response?.data?.message || err.message);
});
const wasUpdating = this.updatingMode;
try {
if (wasUpdating) {
const res = await axios.post('/api/config/provider/update', {
id: this.newSelectedProviderName,
config: this.newSelectedProviderConfig
});
this.showSuccess(res.data.message || "更新成功!");
} else {
const res = await axios.post('/api/config/provider/new', this.newSelectedProviderConfig);
this.showSuccess(res.data.message || "添加成功!");
}
this.showProviderCfg = false;
this.getConfig();
} catch (err) {
this.showError(err.response?.data?.message || err.message);
} finally {
this.loading = false;
if (wasUpdating) {
this.updatingMode = false;
}
}
},
@@ -698,7 +738,21 @@ export default {
this.loadingStatus = false;
this.showError(err.response?.data?.message || err.message);
});
}
},
confirmEmptyKey() {
this.showKeyConfirm = true;
return new Promise((resolve) => {
this.keyConfirmResolve = resolve;
});
},
handleKeyConfirm(confirmed) {
if (this.keyConfirmResolve) {
this.keyConfirmResolve(confirmed);
}
this.showKeyConfirm = false;
},
}
}
</script>
+2
View File
@@ -33,6 +33,7 @@ dependencies = [
"pillow>=11.2.1",
"pip>=25.1.1",
"psutil>=5.8.0",
"py-cord[speed]>=2.6.1",
"pydantic~=2.10.3",
"pydub>=0.25.1",
"pyjwt>=2.10.1",
@@ -41,6 +42,7 @@ dependencies = [
"quart>=0.20.0",
"readability-lxml>=0.8.4.1",
"silk-python>=0.2.6",
"slack-sdk>=3.35.0",
"telegramify-markdown>=0.5.1",
"watchfiles>=1.0.5",
"websockets>=15.0.1",
+3 -1
View File
@@ -37,4 +37,6 @@ watchfiles
websockets
faiss-cpu
aiosqlite
nh3
nh3
py-cord[speed]>=2.6.1
slack-sdk
Generated
+1630 -1399
View File
File diff suppressed because it is too large Load Diff