From 8da029add9e26a68037e4bc64fbfa013c9b7cb19 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sun, 9 Feb 2025 22:08:51 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=20TTS,=20STT=20?= =?UTF-8?q?=E6=8F=90=E4=BE=9B=E5=95=86=E7=9A=84=E6=98=BE=E7=A4=BA=E5=92=8C?= =?UTF-8?q?=E5=BF=AB=E6=8D=B7=E5=88=87=E6=8D=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/star/context.py | 22 +++++++- packages/astrbot/main.py | 101 ++++++++++++++++++++++++++++------- 2 files changed, 104 insertions(+), 19 deletions(-) diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 472ed0e96..068cd1f75 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -2,7 +2,7 @@ from asyncio import Queue from typing import List, TypedDict, Union from astrbot.core import sp -from astrbot.core.provider.provider import Provider +from astrbot.core.provider.provider import Provider, TTSProvider, STTProvider from astrbot.core.db import BaseDatabase from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.provider.func_tool_manager import FuncCall @@ -127,6 +127,14 @@ class Context: '''获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。''' return self.provider_manager.provider_insts + def get_all_tts_providers(self) -> List[TTSProvider]: + '''获取所有用于 TTS 任务的 Provider。''' + return self.provider_manager.tts_provider_insts + + def get_all_stt_providers(self) -> List[STTProvider]: + '''获取所有用于 STT 任务的 Provider。''' + return self.provider_manager.stt_provider_insts + def get_using_provider(self) -> Provider: ''' 获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。 @@ -135,6 +143,18 @@ class Context: ''' return self.provider_manager.curr_provider_inst + def get_using_tts_provider(self) -> TTSProvider: + ''' + 获取当前使用的用于 TTS 任务的 Provider。 + ''' + return self.provider_manager.curr_tts_provider_inst + + def get_using_stt_provider(self) -> STTProvider: + ''' + 获取当前使用的用于 STT 任务的 Provider。 + ''' + return self.provider_manager.curr_stt_provider_inst + def get_config(self) -> AstrBotConfig: '''获取 AstrBot 的配置。''' return self._config diff --git a/packages/astrbot/main.py b/packages/astrbot/main.py index 6732942f7..a380ce1e6 100644 --- a/packages/astrbot/main.py +++ b/packages/astrbot/main.py @@ -60,6 +60,7 @@ AstrBot 指令: [System] /plugin: 查看插件、插件帮助 /t2i: 开关文本转图片 +/tts: 开关文本转语音 /sid: 获取会话 ID /op : 授权管理员(op) /deop : 取消管理员(op) @@ -84,10 +85,7 @@ AstrBot 指令: /websearch: 网页搜索 [其他] -/set <变量名> <值>: 为会话定义变量。适用于 Dify 工作流输入。 -/unset <变量名>: 删除会话的变量。 - -提示:如要查看插件指令,请输入 /plugin 查看具体信息。 +/set 变量名 值: 为会话定义变量(Dify 工作流输入) {notice}""" event.set_result(MessageEventResult().message(msg).use_t2i(False)) @@ -126,7 +124,7 @@ AstrBot 指令: tm = self.context.get_llm_tool_manager() for tool in tm.func_list: self.context.deactivate_llm_tool(tool.name) - event.set_result(MessageEventResult().message(f"停用所有工具成功。")) + event.set_result(MessageEventResult().message("停用所有工具成功。")) @filter.command("plugin") async def plugin(self, event: AstrMessageEvent, oper1: str = None, oper2: str = None): @@ -201,6 +199,18 @@ AstrBot 指令: config.save_config() event.set_result(MessageEventResult().message("已开启文本转图片模式。")) + @filter.command("tts") + async def tts(self, event: AstrMessageEvent): + config = self.context.get_config() + if config['provider_tts_settings']['enable']: + config['provider_tts_settings']['enable'] = False + config.save_config() + event.set_result(MessageEventResult().message("已关闭文本转语音。")) + return + config['provider_tts_settings']['enable'] = True + config.save_config() + event.set_result(MessageEventResult().message("已开启文本转语音。")) + @filter.command("sid") async def sid(self, event: AstrMessageEvent): sid = event.unified_msg_origin @@ -246,34 +256,89 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo event.set_result(MessageEventResult().message("此 SID 不在白名单内。")) @filter.command("provider") - async def provider(self, event: AstrMessageEvent, idx: int = None): + async def provider(self, event: AstrMessageEvent, idx: Union[str, int] = None, idx2: int = None): '''查看或者切换 LLM Provider''' if not self.context.get_using_provider(): event.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。")) return - if idx is None: - ret = "## 当前载入的 LLM 提供商\n" + if idx is None: + ret = "## 载入的 LLM 提供商\n" for idx, llm in enumerate(self.context.get_all_providers()): id_ = llm.meta().id ret += f"{idx + 1}. {id_} ({llm.meta().model})" if self.context.get_using_provider().meta().id == id_: ret += " (当前使用)" ret += "\n" + + tts_providers = self.context.get_all_tts_providers() + if tts_providers: + ret += "\n## 载入的 TTS 提供商\n" + for idx, tts in enumerate(tts_providers): + id_ = tts.meta().id + ret += f"{idx + 1}. {id_}" + tts_using = self.context.get_using_tts_provider() + if tts_using and tts_using.meta().id == id_: + ret += " (当前使用)" + ret += "\n" + + stt_providers = self.context.get_all_stt_providers() + if stt_providers: + ret += "\n## 载入的 STT 提供商\n" + for idx, stt in enumerate(stt_providers): + id_ = stt.meta().id + ret += f"{idx + 1}. {id_}" + stt_using = self.context.get_using_stt_provider() + if stt_using and stt_using.meta().id == id_: + ret += " (当前使用)" + ret += "\n" - ret += "\n使用 /provider <序号> 切换提供商。" + ret += "\n使用 /provider <序号> 切换 LLM 提供商。" + + if tts_providers: + ret += "\n使用 /provider tts <序号> 切换 TTS 提供商。" + if stt_providers: + ret += "\n使用 /provider stt <切换> STT 提供商。" + event.set_result(MessageEventResult().message(ret)) else: - if idx > len(self.context.get_all_providers()) or idx < 1: - event.set_result(MessageEventResult().message("无效的序号。")) + if idx == "tts": + if idx2 is None: + event.set_result(MessageEventResult().message("请输入序号。")) + return + else: + if idx2 > len(self.context.get_all_tts_providers()) or idx2 < 1: + event.set_result(MessageEventResult().message("无效的序号。")) + provider = self.context.get_all_tts_providers()[idx2 - 1] + id_ = provider.meta().id + self.context.provider_manager.curr_tts_provider_inst = provider + sp.put("curr_provider_tts", id_) + event.set_result(MessageEventResult().message(f"成功切换到 {id_}。")) + elif idx == "stt": + if idx2 is None: + event.set_result(MessageEventResult().message("请输入序号。")) + return + else: + if idx2 > len(self.context.get_all_stt_providers()) or idx2 < 1: + event.set_result(MessageEventResult().message("无效的序号。")) + provider = self.context.get_all_stt_providers()[idx2 - 1] + id_ = provider.meta().id + self.context.provider_manager.curr_stt_provider_inst = provider + sp.put("curr_provider_stt", id_) + event.set_result(MessageEventResult().message(f"成功切换到 {id_}。")) + elif isinstance(idx, int): + if idx > len(self.context.get_all_providers()) or idx < 1: + event.set_result(MessageEventResult().message("无效的序号。")) - provider = self.context.get_all_providers()[idx - 1] - id_ = provider.meta().id - self.context.provider_manager.curr_provider_inst = provider - sp.put("curr_provider", id_) + provider = self.context.get_all_providers()[idx - 1] + id_ = provider.meta().id + self.context.provider_manager.curr_provider_inst = provider + sp.put("curr_provider", id_) - event.set_result(MessageEventResult().message(f"成功切换到 {id_}。")) + event.set_result(MessageEventResult().message(f"成功切换到 {id_}。")) + else: + event.set_result(MessageEventResult().message("无效的参数。")) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("reset") @@ -582,7 +647,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo sp.put("session_variables", session_vars) - yield event.plain_result(f"会话 {session_id} 变量 {key} 存储成功。") + yield event.plain_result(f"会话 {session_id} 变量 {key} 存储成功。使用 /unset 移除。") @filter.command("unset") async def unset_variable(self, event: AstrMessageEvent, key: str): @@ -592,7 +657,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo session_var = session_vars.get(session_id, {}) if key not in session_var: - yield event.plain_result("没有那个变量名。") + yield event.plain_result("没有那个变量名。格式 /unset 变量名。") else: del session_var[key] sp.put("session_variables", session_vars)