Merge branch 'master' into branch-1
This commit is contained in:
@@ -225,8 +225,48 @@ CONFIG_METADATA_2 = {
|
||||
"telegram_command_auto_refresh": True,
|
||||
"telegram_command_register_interval": 300,
|
||||
},
|
||||
"discord":{
|
||||
"id": "discord",
|
||||
"type": "discord",
|
||||
"enable": False,
|
||||
"discord_token": "",
|
||||
"discord_proxy": "",
|
||||
},
|
||||
"slack": {
|
||||
"id": "slack",
|
||||
"type": "slack",
|
||||
"enable": False,
|
||||
"bot_token": "",
|
||||
"app_token": "",
|
||||
"signing_secret": "",
|
||||
"slack_connection_mode": "socket", # webhook, socket
|
||||
"slack_webhook_host": "0.0.0.0",
|
||||
"slack_webhook_port": 6197,
|
||||
"slack_webhook_path": "/astrbot-slack-webhook/callback",
|
||||
},
|
||||
},
|
||||
"items": {
|
||||
"slack_connection_mode": {
|
||||
"description": "Slack Connection Mode",
|
||||
"type": "string",
|
||||
"options": ["webhook", "socket"],
|
||||
"hint": "The connection mode for Slack. `webhook` uses a webhook server, `socket` uses Slack's Socket Mode.",
|
||||
},
|
||||
"slack_webhook_host": {
|
||||
"description": "Slack Webhook Host",
|
||||
"type": "string",
|
||||
"hint": "Only valid when Slack connection mode is `webhook`.",
|
||||
},
|
||||
"slack_webhook_port": {
|
||||
"description": "Slack Webhook Port",
|
||||
"type": "int",
|
||||
"hint": "Only valid when Slack connection mode is `webhook`.",
|
||||
},
|
||||
"slack_webhook_path": {
|
||||
"description": "Slack Webhook Path",
|
||||
"type": "string",
|
||||
"hint": "Only valid when Slack connection mode is `webhook`.",
|
||||
},
|
||||
"active_send_mode": {
|
||||
"description": "是否换用主动发送接口",
|
||||
"type": "bool",
|
||||
@@ -324,6 +364,16 @@ CONFIG_METADATA_2 = {
|
||||
"hint": "请务必填对,否则 @ 机器人将无法唤醒,只能通过前缀唤醒。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"discord_token":{
|
||||
"description": "Discord Bot Token",
|
||||
"type": "string",
|
||||
"hint": "在此处填入你的Discord Bot Token"
|
||||
},
|
||||
"discord_proxy":{
|
||||
"description": "Discord 代理地址",
|
||||
"type": "string",
|
||||
"hint": "可选的代理地址:http://ip:port"
|
||||
},
|
||||
},
|
||||
},
|
||||
"platform_settings": {
|
||||
@@ -800,6 +850,37 @@ CONFIG_METADATA_2 = {
|
||||
"edge-tts-voice": "zh-CN-XiaoxiaoNeural",
|
||||
"timeout": 20,
|
||||
},
|
||||
"GSV TTS(本地加载)": {
|
||||
"id": "gsv_tts",
|
||||
"enable": False,
|
||||
"type": "gsv_tts_selfhost",
|
||||
"provider_type": "text_to_speech",
|
||||
"api_base": "http://127.0.0.1:9880",
|
||||
"gpt_weights_path": "",
|
||||
"sovits_weights_path": "",
|
||||
"timeout": 60,
|
||||
"gsv_default_parms": {
|
||||
"gsv_ref_audio_path": "",
|
||||
"gsv_prompt_text": "",
|
||||
"gsv_prompt_lang": "zh",
|
||||
"gsv_aux_ref_audio_paths": "",
|
||||
"gsv_text_lang": "zh",
|
||||
"gsv_top_k": 5,
|
||||
"gsv_top_p": 1.0,
|
||||
"gsv_temperature": 1.0,
|
||||
"gsv_text_split_method": "cut3",
|
||||
"gsv_batch_size": 1,
|
||||
"gsv_batch_threshold": 0.75,
|
||||
"gsv_split_bucket": True,
|
||||
"gsv_speed_factor": 1,
|
||||
"gsv_fragment_interval": 0.3,
|
||||
"gsv_streaming_mode": False,
|
||||
"gsv_seed": -1,
|
||||
"gsv_parallel_infer": True,
|
||||
"gsv_repetition_penalty": 1.35,
|
||||
"gsv_media_type": "wav",
|
||||
},
|
||||
},
|
||||
"GSVI TTS(API)": {
|
||||
"id": "gsvi_tts",
|
||||
"type": "gsvi_tts_api",
|
||||
@@ -901,6 +982,130 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
},
|
||||
"items": {
|
||||
"gpt_weights_path": {
|
||||
"description": "GPT模型文件路径",
|
||||
"type": "string",
|
||||
"hint": "即“.ckpt”后缀的文件,请使用绝对路径,路径两端不要带双引号,不填则默认用GPT_SoVITS内置的SoVITS模型(建议直接在GPT_SoVITS中改默认模型)",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"sovits_weights_path": {
|
||||
"description": "SoVITS模型文件路径",
|
||||
"type": "string",
|
||||
"hint": "即“.pth”后缀的文件,请使用绝对路径,路径两端不要带双引号,不填则默认用GPT_SoVITS内置的SoVITS模型(建议直接在GPT_SoVITS中改默认模型)",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"gsv_default_parms": {
|
||||
"description": "GPT_SoVITS默认参数",
|
||||
"hint": "参考音频文件路径、参考音频文本必填,其他参数根据个人爱好自行填写",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"gsv_ref_audio_path": {
|
||||
"description": "参考音频文件路径",
|
||||
"type": "string",
|
||||
"hint": "必填!请使用绝对路径!路径两端不要带双引号!",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"gsv_prompt_text": {
|
||||
"description": "参考音频文本",
|
||||
"type": "string",
|
||||
"hint": "必填!请填写参考音频讲述的文本",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"gsv_prompt_lang": {
|
||||
"description": "参考音频文本语言",
|
||||
"type": "string",
|
||||
"hint": "请填写参考音频讲述的文本的语言,默认为中文",
|
||||
},
|
||||
"gsv_aux_ref_audio_paths": {
|
||||
"description": "辅助参考音频文件路径",
|
||||
"type": "string",
|
||||
"hint": "辅助参考音频文件,可不填",
|
||||
},
|
||||
"gsv_text_lang": {
|
||||
"description": "文本语言",
|
||||
"type": "string",
|
||||
"hint": "默认为中文",
|
||||
},
|
||||
"gsv_top_k": {
|
||||
"description": "生成语音的多样性",
|
||||
"type": "int",
|
||||
"hint": "",
|
||||
},
|
||||
"gsv_top_p": {
|
||||
"description": "核采样的阈值",
|
||||
"type": "float",
|
||||
"hint": "",
|
||||
},
|
||||
"gsv_temperature": {
|
||||
"description": "生成语音的随机性",
|
||||
"type": "float",
|
||||
"hint": "",
|
||||
},
|
||||
"gsv_text_split_method": {
|
||||
"description": "切分文本的方法",
|
||||
"type": "string",
|
||||
"hint": "可选值: `cut0`:不切分 `cut1`:四句一切 `cut2`:50字一切 `cut3`:按中文句号切 `cut4`:按英文句号切 `cut5`:按标点符号切",
|
||||
"options": [
|
||||
"cut0",
|
||||
"cut1",
|
||||
"cut2",
|
||||
"cut3",
|
||||
"cut4",
|
||||
"cut5",
|
||||
],
|
||||
},
|
||||
"gsv_batch_size": {
|
||||
"description": "批处理大小",
|
||||
"type": "int",
|
||||
"hint": "",
|
||||
},
|
||||
"gsv_batch_threshold": {
|
||||
"description": "批处理阈值",
|
||||
"type": "float",
|
||||
"hint": "",
|
||||
},
|
||||
"gsv_split_bucket": {
|
||||
"description": "将文本分割成桶以便并行处理",
|
||||
"type": "bool",
|
||||
"hint": "",
|
||||
},
|
||||
"gsv_speed_factor": {
|
||||
"description": "语音播放速度",
|
||||
"type": "float",
|
||||
"hint": "1为原始语速",
|
||||
},
|
||||
"gsv_fragment_interval": {
|
||||
"description": "语音片段之间的间隔时间",
|
||||
"type": "float",
|
||||
"hint": "",
|
||||
},
|
||||
"gsv_streaming_mode": {
|
||||
"description": "启用流模式",
|
||||
"type": "bool",
|
||||
"hint": "",
|
||||
},
|
||||
"gsv_seed": {
|
||||
"description": "随机种子",
|
||||
"type": "int",
|
||||
"hint": "用于结果的可重复性",
|
||||
},
|
||||
"gsv_parallel_infer": {
|
||||
"description": "并行执行推理",
|
||||
"type": "bool",
|
||||
"hint": "",
|
||||
},
|
||||
"gsv_repetition_penalty": {
|
||||
"description": "重复惩罚因子",
|
||||
"type": "float",
|
||||
"hint": "",
|
||||
},
|
||||
"gsv_media_type": {
|
||||
"description": "输出媒体的类型",
|
||||
"type": "string",
|
||||
"hint": "建议用wav",
|
||||
},
|
||||
},
|
||||
},
|
||||
"embedding_dimensions": {
|
||||
"description": "嵌入维度",
|
||||
"type": "int",
|
||||
|
||||
@@ -77,7 +77,15 @@ class PlatformManager:
|
||||
case "wecom":
|
||||
from .sources.wecom.wecom_adapter import WecomPlatformAdapter # noqa: F401
|
||||
case "weixin_official_account":
|
||||
from .sources.weixin_official_account.weixin_offacc_adapter import WeixinOfficialAccountPlatformAdapter # noqa
|
||||
from .sources.weixin_official_account.weixin_offacc_adapter import (
|
||||
WeixinOfficialAccountPlatformAdapter, # noqa
|
||||
)
|
||||
case "discord":
|
||||
from .sources.discord.discord_platform_adapter import (
|
||||
DiscordPlatformAdapter, # noqa: F401
|
||||
)
|
||||
case "slack":
|
||||
from .sources.slack.slack_adapter import SlackAdapter # noqa: F401
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.error(
|
||||
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。"
|
||||
|
||||
@@ -0,0 +1,116 @@
|
||||
import discord
|
||||
from astrbot import logger
|
||||
|
||||
|
||||
# Discord Bot客户端
|
||||
class DiscordBotClient(discord.Bot):
|
||||
"""Discord客户端封装"""
|
||||
|
||||
def __init__(self, token: str, proxy: str = None):
|
||||
self.token = token
|
||||
self.proxy = proxy
|
||||
|
||||
# 设置Intent权限,遵循权限最小化原则
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True # 订阅消息内容事件 (Privileged)
|
||||
intents.members = True # 订阅成员事件 (Privileged)
|
||||
|
||||
# 初始化Bot
|
||||
super().__init__(intents=intents, proxy=proxy)
|
||||
|
||||
# 回调函数
|
||||
self.on_message_received = None
|
||||
|
||||
async def on_ready(self):
|
||||
"""当机器人成功连接并准备就绪时触发"""
|
||||
logger.info(f"[Discord] 已作为 {self.user} (ID: {self.user.id}) 登录")
|
||||
logger.info("[Discord] 客户端已准备就绪。")
|
||||
|
||||
def _create_message_data(self, message: discord.Message) -> dict:
|
||||
"""从 discord.Message 创建数据字典"""
|
||||
is_mentioned = self.user in message.mentions
|
||||
return {
|
||||
"message": message,
|
||||
"bot_id": str(self.user.id),
|
||||
"content": message.content,
|
||||
"username": message.author.display_name,
|
||||
"userid": str(message.author.id),
|
||||
"message_id": str(message.id),
|
||||
"channel_id": str(message.channel.id),
|
||||
"guild_id": str(message.guild.id) if message.guild else None,
|
||||
"type": "message",
|
||||
"is_mentioned": is_mentioned,
|
||||
"clean_content": message.clean_content,
|
||||
}
|
||||
|
||||
def _create_interaction_data(self, interaction: discord.Interaction) -> dict:
|
||||
"""从 discord.Interaction 创建数据字典"""
|
||||
return {
|
||||
"interaction": interaction,
|
||||
"bot_id": str(self.user.id),
|
||||
"content": self._extract_interaction_content(interaction),
|
||||
"username": interaction.user.display_name,
|
||||
"userid": str(interaction.user.id),
|
||||
"message_id": str(interaction.id),
|
||||
"channel_id": str(interaction.channel_id)
|
||||
if interaction.channel_id
|
||||
else None,
|
||||
"guild_id": str(interaction.guild_id) if interaction.guild_id else None,
|
||||
"type": "interaction",
|
||||
}
|
||||
|
||||
async def on_message(self, message: discord.Message):
|
||||
"""当接收到消息时触发"""
|
||||
if message.author.bot:
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"[Discord] 收到原始消息 from {message.author.name}: {message.content}"
|
||||
)
|
||||
|
||||
if self.on_message_received:
|
||||
message_data = self._create_message_data(message)
|
||||
await self.on_message_received(message_data)
|
||||
|
||||
async def on_interaction(self, interaction: discord.Interaction):
|
||||
"""当接收到交互(按钮点击等)时触发"""
|
||||
logger.debug(
|
||||
f"[Discord] 收到交互 from {interaction.user.name}: {interaction.data}"
|
||||
)
|
||||
|
||||
if self.on_message_received:
|
||||
interaction_data = self._create_interaction_data(interaction)
|
||||
await self.on_message_received(interaction_data)
|
||||
|
||||
def _extract_interaction_content(self, interaction: discord.Interaction) -> str:
|
||||
"""从交互中提取内容"""
|
||||
interaction_type = interaction.type
|
||||
interaction_data = getattr(interaction, "data", {})
|
||||
|
||||
if not interaction_data:
|
||||
return ""
|
||||
|
||||
if interaction_type == discord.InteractionType.application_command:
|
||||
command_name = interaction_data.get("name", "")
|
||||
if options := interaction_data.get("options", []):
|
||||
params = " ".join(
|
||||
[f"{opt['name']}:{opt.get('value', '')}" for opt in options]
|
||||
)
|
||||
return f"/{command_name} {params}"
|
||||
return f"/{command_name}"
|
||||
|
||||
elif interaction_type == discord.InteractionType.component:
|
||||
custom_id = interaction_data.get("custom_id", "")
|
||||
component_type = interaction_data.get("component_type", "")
|
||||
return f"component:{custom_id}:{component_type}"
|
||||
|
||||
return str(interaction_data)
|
||||
|
||||
async def start_polling(self):
|
||||
"""开始轮询消息,这是个阻塞方法"""
|
||||
await self.start(self.token)
|
||||
|
||||
async def close(self):
|
||||
"""关闭客户端"""
|
||||
if not self.is_closed():
|
||||
await super().close()
|
||||
@@ -0,0 +1,125 @@
|
||||
import discord
|
||||
from typing import List
|
||||
from astrbot.api.message_components import BaseMessageComponent
|
||||
|
||||
|
||||
# Discord专用组件
|
||||
class DiscordEmbed(BaseMessageComponent):
|
||||
"""Discord Embed消息组件"""
|
||||
|
||||
type: str = "discord_embed"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
title: str = None,
|
||||
description: str = None,
|
||||
color: int = None,
|
||||
url: str = None,
|
||||
thumbnail: str = None,
|
||||
image: str = None,
|
||||
footer: str = None,
|
||||
fields: List[dict] = None,
|
||||
):
|
||||
self.title = title
|
||||
self.description = description
|
||||
self.color = color
|
||||
self.url = url
|
||||
self.thumbnail = thumbnail
|
||||
self.image = image
|
||||
self.footer = footer
|
||||
self.fields = fields or []
|
||||
|
||||
def to_discord_embed(self) -> discord.Embed:
|
||||
"""转换为Discord Embed对象"""
|
||||
embed = discord.Embed()
|
||||
|
||||
if self.title:
|
||||
embed.title = self.title
|
||||
if self.description:
|
||||
embed.description = self.description
|
||||
if self.color:
|
||||
embed.color = self.color
|
||||
if self.url:
|
||||
embed.url = self.url
|
||||
if self.thumbnail:
|
||||
embed.set_thumbnail(url=self.thumbnail)
|
||||
if self.image:
|
||||
embed.set_image(url=self.image)
|
||||
if self.footer:
|
||||
embed.set_footer(text=self.footer)
|
||||
|
||||
for field in self.fields:
|
||||
embed.add_field(
|
||||
name=field.get("name", ""),
|
||||
value=field.get("value", ""),
|
||||
inline=field.get("inline", False),
|
||||
)
|
||||
|
||||
return embed
|
||||
|
||||
|
||||
class DiscordButton(BaseMessageComponent):
|
||||
"""Discord按钮组件"""
|
||||
|
||||
type: str = "discord_button"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
label: str,
|
||||
custom_id: str = None,
|
||||
style: str = "primary",
|
||||
emoji: str = None,
|
||||
url: str = None,
|
||||
disabled: bool = False,
|
||||
):
|
||||
self.label = label
|
||||
self.custom_id = custom_id
|
||||
self.style = style
|
||||
self.emoji = emoji
|
||||
self.url = url
|
||||
self.disabled = disabled
|
||||
|
||||
|
||||
class DiscordView(BaseMessageComponent):
|
||||
"""Discord视图组件,包含按钮和选择菜单"""
|
||||
|
||||
type: str = "discord_view"
|
||||
|
||||
def __init__(
|
||||
self, components: List[BaseMessageComponent] = None, timeout: float = None
|
||||
):
|
||||
self.components = components or []
|
||||
self.timeout = timeout
|
||||
|
||||
def to_discord_view(self) -> discord.ui.View:
|
||||
"""转换为Discord View对象"""
|
||||
view = discord.ui.View(timeout=self.timeout)
|
||||
|
||||
for component in self.components:
|
||||
if isinstance(component, DiscordButton):
|
||||
button_style = getattr(
|
||||
discord.ButtonStyle, component.style, discord.ButtonStyle.primary
|
||||
)
|
||||
|
||||
if component.url:
|
||||
# URL按钮
|
||||
button = discord.ui.Button(
|
||||
label=component.label,
|
||||
style=discord.ButtonStyle.link,
|
||||
url=component.url,
|
||||
emoji=component.emoji,
|
||||
disabled=component.disabled,
|
||||
)
|
||||
else:
|
||||
# 普通按钮
|
||||
button = discord.ui.Button(
|
||||
label=component.label,
|
||||
style=button_style,
|
||||
custom_id=component.custom_id,
|
||||
emoji=component.emoji,
|
||||
disabled=component.disabled,
|
||||
)
|
||||
|
||||
view.add_item(button)
|
||||
|
||||
return view
|
||||
@@ -0,0 +1,220 @@
|
||||
import asyncio
|
||||
import discord
|
||||
from astrbot.api.platform import (
|
||||
Platform,
|
||||
AstrBotMessage,
|
||||
MessageMember,
|
||||
PlatformMetadata,
|
||||
MessageType,
|
||||
)
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.api.message_components import Plain, Image, File
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.api.platform import register_platform_adapter
|
||||
from astrbot import logger
|
||||
from .client import DiscordBotClient
|
||||
from .discord_platform_event import DiscordPlatformEvent
|
||||
|
||||
|
||||
# 注册平台适配器
|
||||
@register_platform_adapter("discord", "Discord 适配器 (基于 Pycord)")
|
||||
class DiscordPlatformAdapter(Platform):
|
||||
def __init__(
|
||||
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
self.config = platform_config
|
||||
self.settings = platform_settings
|
||||
self.client_self_id = None
|
||||
self.registered_handlers = []
|
||||
|
||||
async def send_by_session(
|
||||
self, session: MessageSesion, message_chain: MessageChain
|
||||
):
|
||||
"""通过会话发送消息"""
|
||||
# 创建临时事件对象来发送消息
|
||||
temp_event = DiscordPlatformEvent(
|
||||
message_str="",
|
||||
message_obj=None,
|
||||
platform_meta=self.meta(),
|
||||
session_id=session.session_id,
|
||||
client=self.client,
|
||||
)
|
||||
await temp_event.send(message_chain)
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
"""返回平台元数据"""
|
||||
return PlatformMetadata(
|
||||
"discord",
|
||||
"Discord 适配器",
|
||||
id=self.config.get("id"),
|
||||
default_config_tmpl=self.config,
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
"""主要运行逻辑"""
|
||||
|
||||
# 初始化回调函数
|
||||
async def on_received(message_data):
|
||||
logger.debug(f"[Discord] 收到消息: {message_data}")
|
||||
if self.client_self_id is None:
|
||||
self.client_self_id = message_data.get("bot_id")
|
||||
abm = await self.convert_message(data=message_data)
|
||||
await self.handle_msg(abm)
|
||||
|
||||
# 初始化 Discord 客户端
|
||||
token = str(self.config.get("discord_token"))
|
||||
if not token:
|
||||
logger.error("[Discord] Bot Token 未配置。请在配置文件中正确设置 token。")
|
||||
return
|
||||
|
||||
proxy = self.config.get("discord_proxy") or None
|
||||
self.client = DiscordBotClient(token, proxy)
|
||||
self.client.on_message_received = on_received
|
||||
|
||||
try:
|
||||
await self.client.start_polling()
|
||||
except discord.errors.LoginFailure:
|
||||
logger.error("[Discord] 登录失败。请检查你的 Bot Token 是否正确。")
|
||||
except discord.errors.ConnectionClosed:
|
||||
logger.warning("[Discord] 与 Discord 的连接已关闭。")
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] 适配器运行时发生意外错误: {e}", exc_info=True)
|
||||
|
||||
def _determine_message_type(
|
||||
self, channel, guild_id=None
|
||||
) -> tuple[MessageType, str]:
|
||||
"""判断消息类型和群组ID"""
|
||||
if guild_id is None and (
|
||||
isinstance(channel, discord.DMChannel)
|
||||
or getattr(channel, "guild", None) is None
|
||||
):
|
||||
return MessageType.FRIEND_MESSAGE, ""
|
||||
|
||||
gid = guild_id or getattr(channel, "guild", None).id
|
||||
return MessageType.GROUP_MESSAGE, str(gid)
|
||||
|
||||
def _convert_interaction_to_abm(self, data: dict) -> AstrBotMessage:
|
||||
"""将交互事件转换为 AstrBotMessage"""
|
||||
interaction: discord.Interaction = data["interaction"]
|
||||
abm = AstrBotMessage()
|
||||
|
||||
abm.type, abm.group_id = self._determine_message_type(
|
||||
interaction.channel, interaction.guild_id
|
||||
)
|
||||
|
||||
# 对于交互事件,message_str 通常没有意义,且可能导致被闲聊等通用插件错误响应。
|
||||
# 将其清空,以确保只有专门的指令处理器会响应。
|
||||
abm.message_str = ""
|
||||
abm.sender = MessageMember(
|
||||
user_id=str(interaction.user.id), nickname=interaction.user.display_name
|
||||
)
|
||||
abm.message = [Plain(text=data["content"])]
|
||||
abm.raw_message = interaction
|
||||
abm.self_id = self.client_self_id
|
||||
abm.session_id = (
|
||||
str(interaction.channel_id)
|
||||
if interaction.channel_id
|
||||
else str(interaction.user.id)
|
||||
)
|
||||
abm.message_id = str(interaction.id)
|
||||
return abm
|
||||
|
||||
def _convert_message_to_abm(self, data: dict) -> AstrBotMessage:
|
||||
"""将普通消息转换为 AstrBotMessage"""
|
||||
message: discord.Message = data["message"]
|
||||
is_mentioned = data.get("is_mentioned", False)
|
||||
|
||||
content = message.content
|
||||
|
||||
# 如果机器人被@,移除@部分
|
||||
if (
|
||||
is_mentioned
|
||||
and self.client
|
||||
and self.client.user
|
||||
and self.client.user in message.mentions
|
||||
):
|
||||
# 构建机器人的@字符串,格式为 <@USER_ID> 或 <@!USER_ID>
|
||||
mention_str = f"<@{self.client.user.id}>"
|
||||
mention_str_nickname = (
|
||||
f"<@!{self.client.user.id}>" # 有些客户端会使用带!的格式
|
||||
)
|
||||
|
||||
if content.startswith(mention_str):
|
||||
content = content[len(mention_str) :].lstrip()
|
||||
elif content.startswith(mention_str_nickname):
|
||||
content = content[len(mention_str_nickname) :].lstrip()
|
||||
|
||||
abm = AstrBotMessage()
|
||||
|
||||
abm.type, abm.group_id = self._determine_message_type(message.channel)
|
||||
|
||||
abm.message_str = content
|
||||
abm.sender = MessageMember(
|
||||
user_id=str(message.author.id), nickname=message.author.display_name
|
||||
)
|
||||
|
||||
message_chain = []
|
||||
if abm.message_str:
|
||||
message_chain.append(Plain(text=abm.message_str))
|
||||
|
||||
if message.attachments:
|
||||
for attachment in message.attachments:
|
||||
if attachment.content_type and attachment.content_type.startswith(
|
||||
"image/"
|
||||
):
|
||||
message_chain.append(
|
||||
Image(file=attachment.url, filename=attachment.filename)
|
||||
)
|
||||
else:
|
||||
message_chain.append(
|
||||
File(name=attachment.filename, url=attachment.url)
|
||||
)
|
||||
|
||||
abm.message = message_chain
|
||||
abm.raw_message = message
|
||||
abm.self_id = self.client_self_id
|
||||
abm.session_id = str(message.channel.id)
|
||||
abm.message_id = str(message.id)
|
||||
return abm
|
||||
|
||||
async def convert_message(self, data: dict) -> AstrBotMessage:
|
||||
"""将平台消息转换成 AstrBotMessage"""
|
||||
if data.get("type") in ["interaction", "slash_command"]:
|
||||
return self._convert_interaction_to_abm(data)
|
||||
else:
|
||||
return self._convert_message_to_abm(data)
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
"""处理消息"""
|
||||
message_event = DiscordPlatformEvent(
|
||||
message_str=message.message_str,
|
||||
message_obj=message,
|
||||
platform_meta=self.meta(),
|
||||
session_id=message.session_id,
|
||||
client=self.client,
|
||||
)
|
||||
|
||||
# 如果是被@的消息,设置为唤醒状态
|
||||
if (
|
||||
self.client
|
||||
and self.client.user
|
||||
and hasattr(message.raw_message, "mentions")
|
||||
and self.client.user in message.raw_message.mentions
|
||||
):
|
||||
message_event.is_wake = True
|
||||
message_event.is_at_or_wake_command = True
|
||||
|
||||
self.commit_event(message_event)
|
||||
|
||||
async def terminate(self):
|
||||
"""终止适配器"""
|
||||
logger.info("[Discord] 正在终止适配器...")
|
||||
if self.client and hasattr(self.client, "close"):
|
||||
await self.client.close()
|
||||
logger.info("[Discord] 适配器已终止。")
|
||||
|
||||
def register_handler(self, handler_info):
|
||||
"""注册处理器信息"""
|
||||
self.registered_handlers.append(handler_info)
|
||||
@@ -0,0 +1,285 @@
|
||||
import asyncio
|
||||
import discord
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||
from astrbot.api.message_components import Plain, Image, File, BaseMessageComponent
|
||||
from astrbot import logger
|
||||
from .client import DiscordBotClient
|
||||
from .components import DiscordEmbed, DiscordView
|
||||
|
||||
|
||||
# 自定义Discord视图组件(兼容旧版本)
|
||||
class DiscordViewComponent(BaseMessageComponent):
|
||||
type: str = "discord_view"
|
||||
|
||||
def __init__(self, view: discord.ui.View):
|
||||
self.view = view
|
||||
|
||||
|
||||
class DiscordPlatformEvent(AstrMessageEvent):
|
||||
def __init__(
|
||||
self,
|
||||
message_str: str,
|
||||
message_obj: AstrBotMessage,
|
||||
platform_meta: PlatformMetadata,
|
||||
session_id: str,
|
||||
client: DiscordBotClient,
|
||||
):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.client = client
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
"""发送消息到Discord平台"""
|
||||
try:
|
||||
channel = await self._get_channel()
|
||||
if not channel:
|
||||
logger.error(f"[Discord] 无法获取频道 {self.session_id}")
|
||||
return
|
||||
|
||||
# 解析消息链
|
||||
content, files, view, embeds = await self._parse_to_discord(message)
|
||||
|
||||
# Discord 不允许发送完全空的消息
|
||||
if not content and not files and not view and not embeds:
|
||||
logger.debug("[Discord] 尝试发送空消息,已忽略。")
|
||||
return
|
||||
|
||||
# 发送消息
|
||||
await channel.send(
|
||||
content=content or None,
|
||||
files=files or None,
|
||||
view=view or None,
|
||||
embeds=embeds or None,
|
||||
)
|
||||
|
||||
except discord.errors.HTTPException as e:
|
||||
logger.error(f"[Discord] 发送消息失败: {e.status} {e.code} - {e.text}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] 发送消息时发生未知错误: {e}", exc_info=True)
|
||||
|
||||
await super().send(message)
|
||||
|
||||
async def _get_channel(self) -> Optional[discord.abc.Messageable]:
|
||||
"""获取当前事件对应的频道对象"""
|
||||
try:
|
||||
channel_id = int(self.session_id)
|
||||
return self.client.get_channel(
|
||||
channel_id
|
||||
) or await self.client.fetch_channel(channel_id)
|
||||
except (ValueError, discord.errors.NotFound, discord.errors.Forbidden):
|
||||
logger.error(f"[Discord] 无法获取频道 {self.session_id}")
|
||||
return None
|
||||
|
||||
async def _parse_to_discord(
|
||||
self,
|
||||
message: MessageChain,
|
||||
) -> tuple[str, list[discord.File], Optional[discord.ui.View], list[discord.Embed]]:
|
||||
"""将 MessageChain 解析为 Discord 发送所需的内容"""
|
||||
plain_text_parts = []
|
||||
files = []
|
||||
view = None
|
||||
embeds = []
|
||||
|
||||
for i in message.chain: # 遍历消息链
|
||||
if isinstance(i, Plain): # 如果是文字类型的
|
||||
plain_text_parts.append(i.text)
|
||||
elif isinstance(i, Image):
|
||||
logger.debug(f"[Discord] 开始处理 Image 组件: {i}")
|
||||
try:
|
||||
filename = getattr(i, "filename", None)
|
||||
file_content = getattr(i, "file", None)
|
||||
|
||||
if not file_content:
|
||||
logger.warning(f"[Discord] Image 组件没有 file 属性: {i}")
|
||||
continue
|
||||
|
||||
discord_file = None
|
||||
|
||||
# 1. URL
|
||||
if file_content.startswith("http"):
|
||||
logger.debug(f"[Discord] 处理 URL 图片: {file_content}")
|
||||
embed = discord.Embed().set_image(url=file_content)
|
||||
embeds.append(embed)
|
||||
continue
|
||||
|
||||
# 2. File URI
|
||||
elif file_content.startswith("file:///"):
|
||||
logger.debug(f"[Discord] 处理 File URI: {file_content}")
|
||||
path = Path(file_content[8:])
|
||||
if await asyncio.to_thread(path.exists):
|
||||
file_bytes = await asyncio.to_thread(path.read_bytes)
|
||||
discord_file = discord.File(
|
||||
BytesIO(file_bytes), filename=filename or path.name
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[Discord] 图片文件不存在: {path}")
|
||||
|
||||
# 3. Base64 URI
|
||||
elif file_content.startswith("base64://"):
|
||||
logger.debug("[Discord] 处理 Base64 URI")
|
||||
b64_data = file_content.split("base64://", 1)[1]
|
||||
missing_padding = len(b64_data) % 4
|
||||
if missing_padding:
|
||||
b64_data += "=" * (4 - missing_padding)
|
||||
img_bytes = base64.b64decode(b64_data)
|
||||
discord_file = discord.File(
|
||||
BytesIO(img_bytes), filename=filename or "image.png"
|
||||
)
|
||||
|
||||
# 4. 裸 Base64 或本地路径
|
||||
else:
|
||||
try:
|
||||
logger.debug("[Discord] 尝试作为裸 Base64 处理")
|
||||
b64_data = file_content
|
||||
missing_padding = len(b64_data) % 4
|
||||
if missing_padding:
|
||||
b64_data += "=" * (4 - missing_padding)
|
||||
img_bytes = base64.b64decode(b64_data)
|
||||
discord_file = discord.File(
|
||||
BytesIO(img_bytes), filename=filename or "image.png"
|
||||
)
|
||||
except (ValueError, TypeError, base64.binascii.Error):
|
||||
logger.debug(
|
||||
f"[Discord] 裸 Base64 解码失败,作为本地路径处理: {file_content}"
|
||||
)
|
||||
path = Path(file_content)
|
||||
if await asyncio.to_thread(path.exists):
|
||||
file_bytes = await asyncio.to_thread(path.read_bytes)
|
||||
discord_file = discord.File(
|
||||
BytesIO(file_bytes), filename=filename or path.name
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[Discord] 图片文件不存在: {path}")
|
||||
|
||||
if discord_file:
|
||||
files.append(discord_file)
|
||||
|
||||
except Exception:
|
||||
# 使用 getattr 来安全地访问 i.file,以防 i 本身就是问题
|
||||
file_info = getattr(i, "file", "未知")
|
||||
logger.error(
|
||||
f"[Discord] 处理图片时发生未知严重错误: {file_info}",
|
||||
exc_info=True,
|
||||
)
|
||||
elif isinstance(i, File):
|
||||
try:
|
||||
file_path_str = await i.get_file()
|
||||
if file_path_str:
|
||||
path = Path(file_path_str)
|
||||
if await asyncio.to_thread(path.exists):
|
||||
file_bytes = await asyncio.to_thread(path.read_bytes)
|
||||
files.append(
|
||||
discord.File(BytesIO(file_bytes), filename=i.name)
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[Discord] 获取文件失败,路径不存在: {file_path_str}"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[Discord] 获取文件失败: {i.name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[Discord] 处理文件失败: {i.name}, 错误: {e}")
|
||||
elif isinstance(i, DiscordEmbed):
|
||||
# Discord Embed消息
|
||||
embeds.append(i.to_discord_embed())
|
||||
elif isinstance(i, DiscordView):
|
||||
# Discord视图组件(按钮、选择菜单等)
|
||||
view = i.to_discord_view()
|
||||
elif isinstance(i, DiscordViewComponent):
|
||||
# 如果消息链中包含Discord视图组件(兼容旧版本)
|
||||
if isinstance(i.view, discord.ui.View):
|
||||
view = i.view
|
||||
else:
|
||||
logger.debug(f"[Discord] 忽略了不支持的消息组件: {i.type}")
|
||||
|
||||
# 合并文本内容
|
||||
content = "\n".join(plain_text_parts)
|
||||
if len(content) > 2000:
|
||||
logger.warning("[Discord] 消息内容超过2000字符,将被截断。")
|
||||
content = content[:2000]
|
||||
|
||||
return content, files, view, embeds
|
||||
|
||||
async def reply(self, message: MessageChain):
|
||||
"""回复消息(如果原消息存在)"""
|
||||
try:
|
||||
if hasattr(self.message_obj, "raw_message") and hasattr(
|
||||
self.message_obj.raw_message, "reply"
|
||||
):
|
||||
# 解析消息链
|
||||
content, files, view, embeds = await self._parse_to_discord(message)
|
||||
|
||||
# 使用Discord的回复功能
|
||||
await self.message_obj.raw_message.reply(
|
||||
content=content or None,
|
||||
files=files or None,
|
||||
view=view or None,
|
||||
embeds=embeds or None,
|
||||
)
|
||||
else:
|
||||
# 如果无法回复,使用普通发送
|
||||
await self.send(message)
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] 回复消息失败: {e}")
|
||||
# 回退到普通发送
|
||||
await self.send(message)
|
||||
|
||||
async def react(self, emoji: str):
|
||||
"""对原消息添加反应"""
|
||||
try:
|
||||
if hasattr(self.message_obj, "raw_message") and hasattr(
|
||||
self.message_obj.raw_message, "add_reaction"
|
||||
):
|
||||
await self.message_obj.raw_message.add_reaction(emoji)
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] 添加反应失败: {e}")
|
||||
|
||||
def is_slash_command(self) -> bool:
|
||||
"""判断是否为斜杠命令"""
|
||||
return (
|
||||
hasattr(self.message_obj, "raw_message")
|
||||
and hasattr(self.message_obj.raw_message, "type")
|
||||
and self.message_obj.raw_message.type
|
||||
== discord.InteractionType.application_command
|
||||
)
|
||||
|
||||
def is_button_interaction(self) -> bool:
|
||||
"""判断是否为按钮交互"""
|
||||
return (
|
||||
hasattr(self.message_obj, "raw_message")
|
||||
and hasattr(self.message_obj.raw_message, "type")
|
||||
and self.message_obj.raw_message.type == discord.InteractionType.component
|
||||
)
|
||||
|
||||
def get_interaction_custom_id(self) -> str:
|
||||
"""获取交互组件的custom_id"""
|
||||
if self.is_button_interaction():
|
||||
try:
|
||||
return self.message_obj.raw_message.data.get("custom_id", "")
|
||||
except Exception:
|
||||
pass
|
||||
return ""
|
||||
|
||||
def is_mentioned(self) -> bool:
|
||||
"""判断机器人是否被@"""
|
||||
if hasattr(self.message_obj, "raw_message") and hasattr(
|
||||
self.message_obj.raw_message, "mentions"
|
||||
):
|
||||
return any(
|
||||
mention.id == int(self.message_obj.self_id)
|
||||
for mention in self.message_obj.raw_message.mentions
|
||||
)
|
||||
return False
|
||||
|
||||
def get_mention_clean_content(self) -> str:
|
||||
"""获取去除@后的清洁内容"""
|
||||
if hasattr(self.message_obj, "raw_message") and hasattr(
|
||||
self.message_obj.raw_message, "clean_content"
|
||||
):
|
||||
return self.message_obj.raw_message.clean_content
|
||||
return self.message_str
|
||||
@@ -0,0 +1,162 @@
|
||||
import json
|
||||
import hmac
|
||||
import hashlib
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Callable, Optional
|
||||
from quart import Quart, request, Response
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
from slack_sdk.socket_mode.aiohttp import SocketModeClient
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from slack_sdk.socket_mode.response import SocketModeResponse
|
||||
from astrbot.api import logger
|
||||
|
||||
|
||||
class SlackWebhookClient:
|
||||
"""Slack Webhook 模式客户端,使用 Quart 作为 Web 服务器"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
web_client: AsyncWebClient,
|
||||
signing_secret: str,
|
||||
host: str = "0.0.0.0",
|
||||
port: int = 3000,
|
||||
path: str = "/slack/events",
|
||||
event_handler: Optional[Callable] = None,
|
||||
):
|
||||
self.web_client = web_client
|
||||
self.signing_secret = signing_secret
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.path = path
|
||||
self.event_handler = event_handler
|
||||
|
||||
self.app = Quart(__name__)
|
||||
self._setup_routes()
|
||||
|
||||
# 禁用 Quart 的默认日志输出
|
||||
logging.getLogger("quart.app").setLevel(logging.WARNING)
|
||||
logging.getLogger("quart.serving").setLevel(logging.WARNING)
|
||||
|
||||
self.shutdown_event = asyncio.Event()
|
||||
|
||||
def _setup_routes(self):
|
||||
"""设置路由"""
|
||||
|
||||
@self.app.route(self.path, methods=["POST"])
|
||||
async def slack_events():
|
||||
"""处理 Slack 事件"""
|
||||
try:
|
||||
# 获取请求体和头部
|
||||
body = await request.get_data()
|
||||
event_data = json.loads(body.decode("utf-8"))
|
||||
|
||||
# Verify Slack request signature
|
||||
timestamp = request.headers.get("X-Slack-Request-Timestamp")
|
||||
signature = request.headers.get("X-Slack-Signature")
|
||||
if not timestamp or not signature:
|
||||
return Response("Missing headers", status=400)
|
||||
# Calculate the HMAC signature
|
||||
sig_basestring = f"v0:{timestamp}:{body.decode('utf-8')}"
|
||||
my_signature = (
|
||||
"v0="
|
||||
+ hmac.new(
|
||||
self.signing_secret.encode("utf-8"),
|
||||
sig_basestring.encode("utf-8"),
|
||||
hashlib.sha256,
|
||||
).hexdigest()
|
||||
)
|
||||
# Verify the signature
|
||||
if not hmac.compare_digest(my_signature, signature):
|
||||
logger.warning("Slack request signature verification failed")
|
||||
return Response("Invalid signature", status=400)
|
||||
logger.info(f"Received Slack event: {event_data}")
|
||||
|
||||
# 处理 URL 验证事件
|
||||
if event_data.get("type") == "url_verification":
|
||||
return {"challenge": event_data.get("challenge")}
|
||||
# 处理事件
|
||||
if self.event_handler and event_data.get("type") == "event_callback":
|
||||
await self.event_handler(event_data)
|
||||
|
||||
return Response("", status=200)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理 Slack 事件时出错: {e}")
|
||||
return Response("Internal Server Error", status=500)
|
||||
|
||||
@self.app.route("/health", methods=["GET"])
|
||||
async def health_check():
|
||||
"""健康检查端点"""
|
||||
return {"status": "ok", "service": "slack-webhook"}
|
||||
|
||||
async def start(self):
|
||||
"""启动 Webhook 服务器"""
|
||||
logger.info(
|
||||
f"Slack Webhook 服务器启动中,监听 {self.host}:{self.port}{self.path}..."
|
||||
)
|
||||
|
||||
await self.app.run_task(
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
debug=False,
|
||||
shutdown_trigger=self.shutdown_trigger,
|
||||
)
|
||||
|
||||
async def shutdown_trigger(self):
|
||||
await self.shutdown_event.wait()
|
||||
|
||||
async def stop(self):
|
||||
"""停止 Webhook 服务器"""
|
||||
self.shutdown_event.set()
|
||||
logger.info("Slack Webhook 服务器已停止")
|
||||
|
||||
|
||||
class SlackSocketClient:
|
||||
"""Slack Socket 模式客户端"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
web_client: AsyncWebClient,
|
||||
app_token: str,
|
||||
event_handler: Optional[Callable] = None,
|
||||
):
|
||||
self.web_client = web_client
|
||||
self.app_token = app_token
|
||||
self.event_handler = event_handler
|
||||
self.socket_client = None
|
||||
|
||||
async def _handle_events(self, _: SocketModeClient, req: SocketModeRequest):
|
||||
"""处理 Socket Mode 事件"""
|
||||
try:
|
||||
# 确认收到事件
|
||||
response = SocketModeResponse(envelope_id=req.envelope_id)
|
||||
await self.socket_client.send_socket_mode_response(response)
|
||||
|
||||
# 处理事件
|
||||
if self.event_handler:
|
||||
await self.event_handler(req)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理 Socket Mode 事件时出错: {e}")
|
||||
|
||||
async def start(self):
|
||||
"""启动 Socket Mode 连接"""
|
||||
self.socket_client = SocketModeClient(
|
||||
app_token=self.app_token,
|
||||
logger=logger,
|
||||
web_client=self.web_client,
|
||||
)
|
||||
|
||||
# 注册事件处理器
|
||||
self.socket_client.socket_mode_request_listeners.append(self._handle_events)
|
||||
|
||||
logger.info("Slack Socket Mode 客户端启动中...")
|
||||
await self.socket_client.connect()
|
||||
|
||||
async def stop(self):
|
||||
"""停止 Socket Mode 连接"""
|
||||
if self.socket_client:
|
||||
await self.socket_client.disconnect()
|
||||
await self.socket_client.close()
|
||||
logger.info("Slack Socket Mode 客户端已停止")
|
||||
@@ -0,0 +1,396 @@
|
||||
import time
|
||||
import asyncio
|
||||
import uuid
|
||||
import aiohttp
|
||||
import re
|
||||
import base64
|
||||
from typing import Awaitable, Any
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from astrbot.api.platform import (
|
||||
Platform,
|
||||
AstrBotMessage,
|
||||
MessageMember,
|
||||
MessageType,
|
||||
PlatformMetadata,
|
||||
)
|
||||
from astrbot.api.event import MessageChain
|
||||
from .slack_event import SlackMessageEvent
|
||||
from .client import SlackWebhookClient, SlackSocketClient
|
||||
from astrbot.api.message_components import * # noqa: F403
|
||||
from astrbot.api import logger
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from ...register import register_platform_adapter
|
||||
|
||||
|
||||
@register_platform_adapter(
|
||||
"slack", "适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。"
|
||||
)
|
||||
class SlackAdapter(Platform):
|
||||
def __init__(
|
||||
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
|
||||
self.config = platform_config
|
||||
self.settings = platform_settings
|
||||
self.unique_session = platform_settings.get("unique_session", False)
|
||||
|
||||
self.bot_token = platform_config.get("bot_token")
|
||||
self.app_token = platform_config.get("app_token")
|
||||
self.signing_secret = platform_config.get("signing_secret")
|
||||
self.connection_mode = platform_config.get("slack_connection_mode", "socket")
|
||||
self.webhook_host = platform_config.get("slack_webhook_host", "0.0.0.0")
|
||||
self.webhook_port = platform_config.get("slack_webhook_port", 3000)
|
||||
self.webhook_path = platform_config.get(
|
||||
"slack_webhook_path", "/astrbot-slack-webhook/callback"
|
||||
)
|
||||
|
||||
if not self.bot_token:
|
||||
raise ValueError("Slack bot_token 是必需的")
|
||||
|
||||
if self.connection_mode == "socket" and not self.app_token:
|
||||
raise ValueError("Socket Mode 需要 app_token")
|
||||
|
||||
if self.connection_mode == "webhook" and not self.signing_secret:
|
||||
raise ValueError("Webhook Mode 需要 signing_secret")
|
||||
|
||||
self.metadata = PlatformMetadata(
|
||||
name="slack",
|
||||
description="适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。",
|
||||
id=self.config.get("id"),
|
||||
)
|
||||
|
||||
# 初始化 Slack Web Client
|
||||
self.web_client = AsyncWebClient(token=self.bot_token, logger=logger)
|
||||
self.socket_client = None
|
||||
self.webhook_client = None
|
||||
|
||||
self.bot_self_id = None
|
||||
|
||||
async def send_by_session(
|
||||
self, session: MessageSesion, message_chain: MessageChain
|
||||
):
|
||||
blocks, text = SlackMessageEvent._parse_slack_blocks(
|
||||
message_chain=message_chain, web_client=self.web_client
|
||||
)
|
||||
|
||||
try:
|
||||
if session.message_type == MessageType.GROUP_MESSAGE:
|
||||
# 发送到频道
|
||||
channel_id = (
|
||||
session.session_id.split("_")[-1]
|
||||
if "_" in session.session_id
|
||||
else session.session_id
|
||||
)
|
||||
await self.web_client.chat_postMessage(
|
||||
channel=channel_id,
|
||||
text=text,
|
||||
blocks=blocks if blocks else None,
|
||||
)
|
||||
else:
|
||||
# 发送私信
|
||||
await self.web_client.chat_postMessage(
|
||||
channel=session.session_id,
|
||||
text=text,
|
||||
blocks=blocks if blocks else None,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Slack 发送消息失败: {e}")
|
||||
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
async def convert_message(self, event: dict) -> AstrBotMessage:
|
||||
logger.debug(f"[slack] RawMessage {event}")
|
||||
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = self.bot_self_id
|
||||
|
||||
# 获取用户信息
|
||||
user_id = event.get("user", "")
|
||||
try:
|
||||
user_info = await self.web_client.users_info(user=user_id)
|
||||
user_data = user_info["user"]
|
||||
user_name = user_data.get("real_name") or user_data.get("name", user_id)
|
||||
except Exception:
|
||||
user_name = user_id
|
||||
|
||||
abm.sender = MessageMember(user_id=user_id, nickname=user_name)
|
||||
|
||||
# 判断消息类型
|
||||
channel_id = event.get("channel", "")
|
||||
try:
|
||||
channel_info = await self.web_client.conversations_info(channel=channel_id)
|
||||
is_im = channel_info["channel"]["is_im"]
|
||||
|
||||
if is_im:
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
else:
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
abm.group_id = channel_id
|
||||
except Exception:
|
||||
# 默认作为群组消息处理
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
abm.group_id = channel_id
|
||||
|
||||
# 设置会话ID
|
||||
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = f"{user_id}_{channel_id}"
|
||||
else:
|
||||
abm.session_id = (
|
||||
channel_id if abm.type == MessageType.GROUP_MESSAGE else user_id
|
||||
)
|
||||
|
||||
abm.message_id = event.get("client_msg_id", uuid.uuid4().hex)
|
||||
abm.timestamp = int(float(event.get("ts", time.time())))
|
||||
|
||||
# 处理消息内容
|
||||
message_text = event.get("text", "")
|
||||
abm.message_str = message_text
|
||||
abm.message = []
|
||||
|
||||
# 优先使用 blocks 字段解析消息
|
||||
if "blocks" in event and event["blocks"]:
|
||||
abm.message = self._parse_blocks(event["blocks"])
|
||||
# 更新 message_str
|
||||
abm.message_str = ""
|
||||
for component in abm.message:
|
||||
if isinstance(component, Plain):
|
||||
abm.message_str += component.text
|
||||
elif message_text:
|
||||
# 处理传统的文本消息
|
||||
if "<@" in message_text:
|
||||
mentions = re.findall(r"<@([^>]+)>", message_text)
|
||||
for mention in mentions:
|
||||
try:
|
||||
mentioned_user = await self.web_client.users_info(user=mention)
|
||||
user_data = mentioned_user["user"]
|
||||
user_name = user_data.get("real_name") or user_data.get(
|
||||
"name", mention
|
||||
)
|
||||
abm.message.append(At(qq=mention, name=user_name))
|
||||
except Exception:
|
||||
abm.message.append(At(qq=mention, name=""))
|
||||
|
||||
# 清理消息文本中的@标记
|
||||
if clean_text := re.sub(r"<@[^>]+>", "", message_text).strip():
|
||||
abm.message.append(Plain(text=clean_text))
|
||||
else:
|
||||
abm.message.append(Plain(text=message_text))
|
||||
|
||||
# 处理文件附件
|
||||
if "files" in event:
|
||||
for file_info in event["files"]:
|
||||
file_name = file_info.get("name", "unknown")
|
||||
file_url = file_info.get("url_private", "")
|
||||
if file_info.get("mimetype", "").startswith("image/"):
|
||||
file_url = await self.get_file_base64(file_url)
|
||||
abm.message.append(Image.fromBase64(base64=file_url))
|
||||
else:
|
||||
# TODO: 下载鉴权
|
||||
abm.message.append(
|
||||
File(name=file_name, file=file_url, url=file_url)
|
||||
)
|
||||
|
||||
abm.raw_message = event
|
||||
return abm
|
||||
|
||||
def _parse_blocks(self, blocks: list) -> list:
|
||||
"""解析 Slack blocks 格式的消息内容"""
|
||||
message_components = []
|
||||
|
||||
for block in blocks:
|
||||
block_type = block.get("type", "")
|
||||
|
||||
if block_type == "rich_text":
|
||||
# 处理富文本块
|
||||
elements = block.get("elements", [])
|
||||
for element in elements:
|
||||
if element.get("type") == "rich_text_section":
|
||||
# 处理富文本段落
|
||||
section_elements = element.get("elements", [])
|
||||
text_content = ""
|
||||
|
||||
for section_element in section_elements:
|
||||
element_type = section_element.get("type", "")
|
||||
|
||||
if element_type == "text":
|
||||
# 普通文本
|
||||
text_content += section_element.get("text", "")
|
||||
elif element_type == "user":
|
||||
# @用户提及
|
||||
user_id = section_element.get("user_id", "")
|
||||
if user_id:
|
||||
# 将之前的文本内容先添加到组件中
|
||||
if text_content.strip():
|
||||
message_components.append(
|
||||
Plain(text=text_content)
|
||||
)
|
||||
text_content = ""
|
||||
# 添加@提及组件
|
||||
message_components.append(At(qq=user_id, name=""))
|
||||
elif element_type == "channel":
|
||||
# #频道提及
|
||||
channel_id = section_element.get("channel_id", "")
|
||||
text_content += f"#{channel_id}"
|
||||
elif element_type == "link":
|
||||
# 链接
|
||||
url = section_element.get("url", "")
|
||||
link_text = section_element.get("text", url)
|
||||
text_content += f"[{link_text}]({url})"
|
||||
elif element_type == "emoji":
|
||||
# 表情符号
|
||||
emoji_name = section_element.get("name", "")
|
||||
text_content += f":{emoji_name}:"
|
||||
|
||||
if text_content.strip():
|
||||
message_components.append(Plain(text=text_content))
|
||||
|
||||
elif element.get("type") == "rich_text_list":
|
||||
# 处理列表
|
||||
list_items = element.get("elements", [])
|
||||
list_text = ""
|
||||
for item in list_items:
|
||||
if item.get("type") == "rich_text_section":
|
||||
item_elements = item.get("elements", [])
|
||||
item_text = ""
|
||||
for item_element in item_elements:
|
||||
if item_element.get("type") == "text":
|
||||
item_text += item_element.get("text", "")
|
||||
list_text += f"• {item_text}\n"
|
||||
|
||||
if list_text.strip():
|
||||
message_components.append(Plain(text=list_text.strip()))
|
||||
|
||||
elif block_type == "section":
|
||||
# 处理段落块
|
||||
if "text" in block:
|
||||
text_obj = block["text"]
|
||||
if text_obj.get("type") == "mrkdwn":
|
||||
text_content = text_obj.get("text", "")
|
||||
message_components.append(Plain(text=text_content))
|
||||
|
||||
return message_components
|
||||
|
||||
async def _handle_socket_event(self, req: SocketModeRequest):
|
||||
"""处理 Socket Mode 事件"""
|
||||
if req.type == "events_api":
|
||||
# 事件 API
|
||||
event = req.payload.get("event", {})
|
||||
|
||||
# 忽略机器人自己的消息和消息编辑
|
||||
if event.get("subtype") in [
|
||||
"bot_message",
|
||||
"message_changed",
|
||||
"message_deleted",
|
||||
]:
|
||||
return
|
||||
|
||||
if event.get("bot_id"):
|
||||
return
|
||||
|
||||
if event.get("type") in ["message", "app_mention"]:
|
||||
abm = await self.convert_message(event)
|
||||
if abm:
|
||||
await self.handle_msg(abm)
|
||||
|
||||
async def get_bot_user_id(self):
|
||||
auth_info = await self.web_client.auth_test()
|
||||
return auth_info.get("user_id")
|
||||
|
||||
async def get_file_base64(self, url: str) -> str:
|
||||
"""下载 Slack 文件并返回 Base64 编码的内容"""
|
||||
headers = {"Authorization": f"Bearer {self.bot_token}"}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, headers=headers) as resp:
|
||||
if resp.status == 200:
|
||||
content = await resp.read()
|
||||
base64_content = base64.b64encode(content).decode("utf-8")
|
||||
return base64_content
|
||||
else:
|
||||
logger.error(f"Failed to download slack file: {resp.status} {await resp.text()}")
|
||||
raise Exception(f"下载文件失败: {resp.status}")
|
||||
|
||||
async def run(self) -> Awaitable[Any]:
|
||||
self.bot_self_id = await self.get_bot_user_id()
|
||||
logger.info(f"Slack auth test OK. Bot ID: {self.bot_self_id}")
|
||||
|
||||
if self.connection_mode == "socket":
|
||||
if not self.app_token:
|
||||
raise ValueError("Socket Mode 需要 app_token")
|
||||
|
||||
# 创建 Socket 客户端
|
||||
self.socket_client = SlackSocketClient(
|
||||
self.web_client, self.app_token, self._handle_socket_event
|
||||
)
|
||||
|
||||
logger.info("Slack 适配器 (Socket Mode) 启动中...")
|
||||
await self.socket_client.start()
|
||||
|
||||
elif self.connection_mode == "webhook":
|
||||
if not self.signing_secret:
|
||||
raise ValueError("Webhook Mode 需要 signing_secret")
|
||||
|
||||
# 创建 Webhook 客户端
|
||||
self.webhook_client = SlackWebhookClient(
|
||||
self.web_client,
|
||||
self.signing_secret,
|
||||
self.webhook_host,
|
||||
self.webhook_port,
|
||||
self.webhook_path,
|
||||
self._handle_webhook_event,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Slack 适配器 (Webhook Mode) 启动中,监听 {self.webhook_host}:{self.webhook_port}{self.webhook_path}..."
|
||||
)
|
||||
await self.webhook_client.start()
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"不支持的连接模式: {self.connection_mode},请使用 'socket' 或 'webhook'"
|
||||
)
|
||||
|
||||
async def _handle_webhook_event(self, event_data: dict):
|
||||
"""处理 Webhook 事件"""
|
||||
event = event_data.get("event", {})
|
||||
|
||||
# 忽略机器人自己的消息和消息编辑
|
||||
if event.get("subtype") in [
|
||||
"bot_message",
|
||||
"message_changed",
|
||||
"message_deleted",
|
||||
]:
|
||||
return
|
||||
|
||||
if event.get("bot_id"):
|
||||
return
|
||||
|
||||
if event.get("type") in ["message", "app_mention"]:
|
||||
abm = await self.convert_message(event)
|
||||
if abm:
|
||||
await self.handle_msg(abm)
|
||||
|
||||
async def terminate(self):
|
||||
if self.socket_client:
|
||||
await self.socket_client.stop()
|
||||
if self.webhook_client:
|
||||
await self.webhook_client.stop()
|
||||
logger.info("Slack 适配器已被优雅地关闭")
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return self.metadata
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
message_event = SlackMessageEvent(
|
||||
message_str=message.message_str,
|
||||
message_obj=message,
|
||||
platform_meta=self.meta(),
|
||||
session_id=message.session_id,
|
||||
web_client=self.web_client,
|
||||
)
|
||||
|
||||
self.commit_event(message_event)
|
||||
|
||||
def get_client(self):
|
||||
return self.web_client
|
||||
@@ -0,0 +1,237 @@
|
||||
import asyncio
|
||||
import re
|
||||
from typing import AsyncGenerator
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import (
|
||||
Image,
|
||||
Plain,
|
||||
File,
|
||||
BaseMessageComponent,
|
||||
)
|
||||
from astrbot.api.platform import Group, MessageMember
|
||||
from astrbot.api import logger
|
||||
|
||||
|
||||
class SlackMessageEvent(AstrMessageEvent):
|
||||
def __init__(
|
||||
self,
|
||||
message_str,
|
||||
message_obj,
|
||||
platform_meta,
|
||||
session_id,
|
||||
web_client: AsyncWebClient,
|
||||
):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.web_client = web_client
|
||||
|
||||
@staticmethod
|
||||
async def _from_segment_to_slack_block(
|
||||
segment: BaseMessageComponent, web_client: AsyncWebClient
|
||||
) -> dict:
|
||||
"""将消息段转换为 Slack 块格式"""
|
||||
if isinstance(segment, Plain):
|
||||
return {"type": "section", "text": {"type": "mrkdwn", "text": segment.text}}
|
||||
elif isinstance(segment, Image):
|
||||
# upload file
|
||||
url = segment.url or segment.file
|
||||
if url.startswith("http"):
|
||||
return {
|
||||
"type": "image",
|
||||
"image_url": url,
|
||||
"alt_text": "图片",
|
||||
}
|
||||
path = await segment.convert_to_file_path()
|
||||
response = await web_client.files_upload_v2(
|
||||
file=path,
|
||||
filename="image.jpg",
|
||||
)
|
||||
if not response["ok"]:
|
||||
logger.error(f"Slack file upload failed: {response['error']}")
|
||||
return {
|
||||
"type": "section",
|
||||
"text": {"type": "mrkdwn", "text": "图片上传失败"},
|
||||
}
|
||||
image_url = response["files"][0]["url_private"]
|
||||
logger.debug(f"Slack file upload response: {response}")
|
||||
return {
|
||||
"type": "image",
|
||||
"slack_file": {
|
||||
"url": image_url,
|
||||
},
|
||||
"alt_text": "图片",
|
||||
}
|
||||
elif isinstance(segment, File):
|
||||
# upload file
|
||||
url = segment.url or segment.file
|
||||
response = await web_client.files_upload_v2(
|
||||
file=url,
|
||||
filename=segment.name or "file",
|
||||
)
|
||||
if not response["ok"]:
|
||||
logger.error(f"Slack file upload failed: {response['error']}")
|
||||
return {
|
||||
"type": "section",
|
||||
"text": {"type": "mrkdwn", "text": "文件上传失败"},
|
||||
}
|
||||
file_url = response["files"][0]["permalink"]
|
||||
return {"type": "section", "text": {"type": "mrkdwn", "text": f"文件: <{file_url}|{segment.name or '文件'}>"}}
|
||||
else:
|
||||
return {"type": "section", "text": {"type": "mrkdwn", "text": str(segment)}}
|
||||
|
||||
@staticmethod
|
||||
async def _parse_slack_blocks(
|
||||
message_chain: MessageChain, web_client: AsyncWebClient
|
||||
):
|
||||
"""解析成 Slack 块格式"""
|
||||
blocks = []
|
||||
text_content = ""
|
||||
|
||||
for segment in message_chain.chain:
|
||||
if isinstance(segment, Plain):
|
||||
text_content += segment.text
|
||||
else:
|
||||
# 如果有文本内容,先添加文本块
|
||||
if text_content.strip():
|
||||
blocks.append(
|
||||
{
|
||||
"type": "section",
|
||||
"text": {"type": "mrkdwn", "text": text_content},
|
||||
}
|
||||
)
|
||||
text_content = ""
|
||||
|
||||
# 添加其他类型的块
|
||||
block = await SlackMessageEvent._from_segment_to_slack_block(
|
||||
segment, web_client
|
||||
)
|
||||
blocks.append(block)
|
||||
|
||||
# 如果最后还有文本内容
|
||||
if text_content.strip():
|
||||
blocks.append(
|
||||
{"type": "section", "text": {"type": "mrkdwn", "text": text_content}}
|
||||
)
|
||||
|
||||
return blocks, "" if blocks else text_content
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
blocks, text = await SlackMessageEvent._parse_slack_blocks(
|
||||
message, self.web_client
|
||||
)
|
||||
|
||||
try:
|
||||
if self.get_group_id():
|
||||
# 发送到频道
|
||||
await self.web_client.chat_postMessage(
|
||||
channel=self.get_group_id(),
|
||||
text=text,
|
||||
blocks=blocks or None,
|
||||
)
|
||||
else:
|
||||
# 发送私信
|
||||
await self.web_client.chat_postMessage(
|
||||
channel=self.get_sender_id(),
|
||||
text=text,
|
||||
blocks=blocks or None,
|
||||
)
|
||||
except Exception:
|
||||
# 如果块发送失败,尝试只发送文本
|
||||
fallback_text = ""
|
||||
for segment in message.chain:
|
||||
if isinstance(segment, Plain):
|
||||
fallback_text += segment.text
|
||||
elif isinstance(segment, File):
|
||||
fallback_text += f" [文件: {segment.name}] "
|
||||
elif isinstance(segment, Image):
|
||||
fallback_text += " [图片] "
|
||||
|
||||
if self.get_group_id():
|
||||
await self.web_client.chat_postMessage(
|
||||
channel=self.get_group_id(), text=fallback_text
|
||||
)
|
||||
else:
|
||||
await self.web_client.chat_postMessage(
|
||||
channel=self.get_sender_id(), text=fallback_text
|
||||
)
|
||||
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(
|
||||
self, generator: AsyncGenerator, use_fallback: bool = False
|
||||
):
|
||||
if not use_fallback:
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
buffer = ""
|
||||
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
|
||||
|
||||
async for chain in generator:
|
||||
if isinstance(chain, MessageChain):
|
||||
for comp in chain.chain:
|
||||
if isinstance(comp, Plain):
|
||||
buffer += comp.text
|
||||
if any(p in buffer for p in "。?!~…"):
|
||||
buffer = await self.process_buffer(buffer, pattern)
|
||||
else:
|
||||
await self.send(MessageChain(chain=[comp]))
|
||||
await asyncio.sleep(1.5) # 限速
|
||||
|
||||
if buffer.strip():
|
||||
await self.send(MessageChain([Plain(buffer)]))
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
async def get_group(self, group_id=None, **kwargs):
|
||||
if group_id:
|
||||
channel_id = group_id
|
||||
elif self.get_group_id():
|
||||
channel_id = self.get_group_id()
|
||||
else:
|
||||
return None
|
||||
|
||||
try:
|
||||
# 获取频道信息
|
||||
channel_info = await self.web_client.conversations_info(channel=channel_id)
|
||||
|
||||
# 获取频道成员
|
||||
members_response = await self.web_client.conversations_members(
|
||||
channel=channel_id
|
||||
)
|
||||
|
||||
members = []
|
||||
for member_id in members_response["members"]:
|
||||
try:
|
||||
user_info = await self.web_client.users_info(user=member_id)
|
||||
user_data = user_info["user"]
|
||||
members.append(
|
||||
MessageMember(
|
||||
user_id=member_id,
|
||||
nickname=user_data.get("real_name")
|
||||
or user_data.get("name", member_id),
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
# 如果获取用户信息失败,使用默认信息
|
||||
members.append(MessageMember(user_id=member_id, nickname=member_id))
|
||||
|
||||
channel_data = channel_info["channel"]
|
||||
return Group(
|
||||
group_id=channel_id,
|
||||
group_name=channel_data.get("name", ""),
|
||||
group_avatar="",
|
||||
group_admins=[], # Slack 的管理员信息需要特殊权限获取
|
||||
group_owner=channel_data.get("creator", ""),
|
||||
members=members,
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
@@ -225,6 +225,10 @@ class ProviderManager:
|
||||
from .sources.edge_tts_source import (
|
||||
ProviderEdgeTTS as ProviderEdgeTTS,
|
||||
)
|
||||
case "gsv_tts_selfhost":
|
||||
from .sources.gsv_selfhosted_source import (
|
||||
ProviderGSVTTS as ProviderGSVTTS,
|
||||
)
|
||||
case "gsvi_tts_api":
|
||||
from .sources.gsvi_tts_source import (
|
||||
ProviderGSVITTS as ProviderGSVITTS,
|
||||
|
||||
@@ -0,0 +1,148 @@
|
||||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
|
||||
import aiohttp
|
||||
from ..provider import TTSProvider
|
||||
from ..entities import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
provider_type_name="gsv_tts_selfhost",
|
||||
desc="GPT-SoVITS TTS(本地加载)",
|
||||
provider_type=ProviderType.TEXT_TO_SPEECH,
|
||||
)
|
||||
class ProviderGSVTTS(TTSProvider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
|
||||
self.api_base = provider_config.get("api_base", "http://127.0.0.1:9880").rstrip(
|
||||
"/"
|
||||
)
|
||||
self.gpt_weights_path: str = provider_config.get("gpt_weights_path", "")
|
||||
self.sovits_weights_path: str = provider_config.get("sovits_weights_path", "")
|
||||
|
||||
# TTS 请求的默认参数,移除前缀gsv_
|
||||
self.default_params: dict = {
|
||||
key.removeprefix("gsv_"): str(value).lower()
|
||||
for key, value in provider_config.get("gsv_default_parms", {}).items()
|
||||
}
|
||||
self.timeout = provider_config.get("timeout", 60)
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
|
||||
async def initialize(self):
|
||||
"""异步初始化:在 ProviderManager 中被调用"""
|
||||
self._session = aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=self.timeout)
|
||||
)
|
||||
try:
|
||||
await self._set_model_weights()
|
||||
logger.info("[GSV TTS] 初始化完成")
|
||||
except Exception as e:
|
||||
logger.error(f"[GSV TTS] 初始化失败:{e}")
|
||||
raise
|
||||
|
||||
def get_session(self) -> aiohttp.ClientSession:
|
||||
if not self._session or self._session.closed:
|
||||
raise RuntimeError(
|
||||
"[GSV TTS] Provider HTTP session is not ready or closed."
|
||||
)
|
||||
return self._session
|
||||
|
||||
async def _make_request(
|
||||
self, endpoint: str, params=None, retries: int = 3
|
||||
) -> bytes | None:
|
||||
"""发起请求"""
|
||||
for attempt in range(retries):
|
||||
logger.debug(f"[GSV TTS] 请求地址:{endpoint},参数:{params}")
|
||||
try:
|
||||
async with self.get_session().get(endpoint, params=params) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise Exception(
|
||||
f"[GSV TTS] Request to {endpoint} failed with status {response.status}: {error_text}"
|
||||
)
|
||||
return await response.read()
|
||||
except Exception as e:
|
||||
if attempt < retries - 1:
|
||||
logger.warning(
|
||||
f"[GSV TTS] 请求 {endpoint} 第 {attempt + 1} 次失败:{e},重试中..."
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
else:
|
||||
logger.error(f"[GSV TTS] 请求 {endpoint} 最终失败:{e}")
|
||||
raise
|
||||
|
||||
async def _set_model_weights(self):
|
||||
"""设置模型路径"""
|
||||
try:
|
||||
if self.gpt_weights_path:
|
||||
await self._make_request(
|
||||
f"{self.api_base}/set_gpt_weights",
|
||||
{"weights_path": self.gpt_weights_path},
|
||||
)
|
||||
logger.info(f"[GSV TTS] 成功设置 GPT 模型路径:{self.gpt_weights_path}")
|
||||
else:
|
||||
logger.info("[GSV TTS] GPT 模型路径未配置,将使用内置 GPT 模型")
|
||||
|
||||
if self.sovits_weights_path:
|
||||
await self._make_request(
|
||||
f"{self.api_base}/set_sovits_weights",
|
||||
{"weights_path": self.sovits_weights_path},
|
||||
)
|
||||
logger.info(
|
||||
f"[GSV TTS] 成功设置 SoVITS 模型路径:{self.sovits_weights_path}"
|
||||
)
|
||||
else:
|
||||
logger.info("[GSV TTS] SoVITS 模型路径未配置,将使用内置 SoVITS 模型")
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"[GSV TTS] 设置模型路径时发生网络错误:{e}")
|
||||
except Exception as e:
|
||||
logger.error(f"[GSV TTS] 设置模型路径时发生未知错误:{e}")
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
"""实现 TTS 核心方法,根据文本内容自动切换情绪"""
|
||||
if not text.strip():
|
||||
raise ValueError("[GSV TTS] TTS 文本不能为空")
|
||||
|
||||
endpoint = f"{self.api_base}/tts"
|
||||
|
||||
params = self.build_synthesis_params(text)
|
||||
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
path = os.path.join(temp_dir, f"gsv_tts_{uuid.uuid4().hex}.wav")
|
||||
|
||||
logger.debug(f"[GSV TTS] 正在调用语音合成接口,参数:{params}")
|
||||
|
||||
result = await self._make_request(endpoint, params)
|
||||
if isinstance(result, bytes):
|
||||
with open(path, "wb") as f:
|
||||
f.write(result)
|
||||
return path
|
||||
else:
|
||||
raise Exception(f"[GSV TTS] 合成失败,输入文本:{text},错误信息:{result}")
|
||||
|
||||
def build_synthesis_params(self, text: str) -> dict:
|
||||
"""
|
||||
构建语音合成所需的参数字典。
|
||||
|
||||
当前仅包含默认参数 + 文本,未来可在此基础上动态添加如情绪、角色等语义控制字段。
|
||||
"""
|
||||
params = self.default_params.copy()
|
||||
params["text"] = text
|
||||
# TODO: 在此处添加情绪分析,例如 params["emotion"] = detect_emotion(text)
|
||||
return params
|
||||
|
||||
async def terminate(self):
|
||||
"""终止释放资源:在 ProviderManager 中被调用"""
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
logger.info("[GSV TTS] Session 已关闭")
|
||||
@@ -306,6 +306,24 @@
|
||||
</v-snackbar>
|
||||
|
||||
<WaitingForRestart ref="wfr"></WaitingForRestart>
|
||||
|
||||
<!-- Key为空的确认对话框 -->
|
||||
<v-dialog v-model="showKeyConfirm" max-width="450" persistent>
|
||||
<v-card>
|
||||
<v-card-title class="text-h6 bg-error d-flex align-center">
|
||||
<v-icon start class="me-2">mdi-alert-circle-outline</v-icon>
|
||||
确认保存
|
||||
</v-card-title>
|
||||
<v-card-text class="py-4 text-body-1 text-medium-emphasis">
|
||||
您没有填写 API Key,确定要保存吗?这可能会导致该服务提供商无法正常工作。
|
||||
</v-card-text>
|
||||
<v-card-actions>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn color="grey" variant="text" @click="handleKeyConfirm(false)">取消</v-btn>
|
||||
<v-btn color="error" variant="flat" @click="handleKeyConfirm(true)">确定</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
@@ -341,6 +359,10 @@ export default {
|
||||
sessionSeparationEnabled: false,
|
||||
sessionSettingLoading: false,
|
||||
|
||||
// Key确认对话框
|
||||
showKeyConfirm: false,
|
||||
keyConfirmResolve: null,
|
||||
|
||||
newSelectedProviderName: '',
|
||||
newSelectedProviderConfig: {},
|
||||
updatingMode: false,
|
||||
@@ -388,6 +410,16 @@ export default {
|
||||
}
|
||||
},
|
||||
|
||||
watch: {
|
||||
showKeyConfirm(newValue) {
|
||||
// 当对话框关闭时,如果 Promise 还在等待,则拒绝它以防止内存泄漏
|
||||
if (!newValue && this.keyConfirmResolve) {
|
||||
this.keyConfirmResolve(false);
|
||||
this.keyConfirmResolve = null;
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
computed: {
|
||||
// 翻译消息的计算属性
|
||||
messages() {
|
||||
@@ -591,32 +623,40 @@ export default {
|
||||
this.updatingMode = true;
|
||||
},
|
||||
|
||||
newProvider() {
|
||||
async newProvider() {
|
||||
// 检查 key 是否为空
|
||||
if (
|
||||
'key' in this.newSelectedProviderConfig &&
|
||||
(!this.newSelectedProviderConfig.key || this.newSelectedProviderConfig.key.length === 0)
|
||||
) {
|
||||
const confirmed = await this.confirmEmptyKey();
|
||||
if (!confirmed) {
|
||||
return; // 如果用户取消,则中止保存
|
||||
}
|
||||
}
|
||||
|
||||
this.loading = true;
|
||||
if (this.updatingMode) {
|
||||
axios.post('/api/config/provider/update', {
|
||||
id: this.newSelectedProviderName,
|
||||
config: this.newSelectedProviderConfig
|
||||
}).then((res) => {
|
||||
this.loading = false;
|
||||
this.showProviderCfg = false;
|
||||
this.getConfig();
|
||||
this.showSuccess(res.data.message || this.messages.success.update);
|
||||
}).catch((err) => {
|
||||
this.loading = false;
|
||||
this.showError(err.response?.data?.message || err.message);
|
||||
});
|
||||
this.updatingMode = false;
|
||||
} else {
|
||||
axios.post('/api/config/provider/new', this.newSelectedProviderConfig).then((res) => {
|
||||
this.loading = false;
|
||||
this.showProviderCfg = false;
|
||||
this.getConfig();
|
||||
this.showSuccess(res.data.message || this.messages.success.add);
|
||||
}).catch((err) => {
|
||||
this.loading = false;
|
||||
this.showError(err.response?.data?.message || err.message);
|
||||
});
|
||||
const wasUpdating = this.updatingMode;
|
||||
try {
|
||||
if (wasUpdating) {
|
||||
const res = await axios.post('/api/config/provider/update', {
|
||||
id: this.newSelectedProviderName,
|
||||
config: this.newSelectedProviderConfig
|
||||
});
|
||||
this.showSuccess(res.data.message || "更新成功!");
|
||||
} else {
|
||||
const res = await axios.post('/api/config/provider/new', this.newSelectedProviderConfig);
|
||||
this.showSuccess(res.data.message || "添加成功!");
|
||||
}
|
||||
this.showProviderCfg = false;
|
||||
this.getConfig();
|
||||
} catch (err) {
|
||||
this.showError(err.response?.data?.message || err.message);
|
||||
} finally {
|
||||
this.loading = false;
|
||||
if (wasUpdating) {
|
||||
this.updatingMode = false;
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
@@ -698,7 +738,21 @@ export default {
|
||||
this.loadingStatus = false;
|
||||
this.showError(err.response?.data?.message || err.message);
|
||||
});
|
||||
}
|
||||
},
|
||||
|
||||
confirmEmptyKey() {
|
||||
this.showKeyConfirm = true;
|
||||
return new Promise((resolve) => {
|
||||
this.keyConfirmResolve = resolve;
|
||||
});
|
||||
},
|
||||
|
||||
handleKeyConfirm(confirmed) {
|
||||
if (this.keyConfirmResolve) {
|
||||
this.keyConfirmResolve(confirmed);
|
||||
}
|
||||
this.showKeyConfirm = false;
|
||||
},
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -33,6 +33,7 @@ dependencies = [
|
||||
"pillow>=11.2.1",
|
||||
"pip>=25.1.1",
|
||||
"psutil>=5.8.0",
|
||||
"py-cord[speed]>=2.6.1",
|
||||
"pydantic~=2.10.3",
|
||||
"pydub>=0.25.1",
|
||||
"pyjwt>=2.10.1",
|
||||
@@ -41,6 +42,7 @@ dependencies = [
|
||||
"quart>=0.20.0",
|
||||
"readability-lxml>=0.8.4.1",
|
||||
"silk-python>=0.2.6",
|
||||
"slack-sdk>=3.35.0",
|
||||
"telegramify-markdown>=0.5.1",
|
||||
"watchfiles>=1.0.5",
|
||||
"websockets>=15.0.1",
|
||||
|
||||
+3
-1
@@ -37,4 +37,6 @@ watchfiles
|
||||
websockets
|
||||
faiss-cpu
|
||||
aiosqlite
|
||||
nh3
|
||||
nh3
|
||||
py-cord[speed]>=2.6.1
|
||||
slack-sdk
|
||||
Reference in New Issue
Block a user