feat: 支持 TTS, STT 提供商的显示和快捷切换

This commit is contained in:
Soulter
2025-02-09 22:08:51 +08:00
parent ba45a2d270
commit 8da029add9
2 changed files with 104 additions and 19 deletions
+21 -1
View File
@@ -2,7 +2,7 @@ from asyncio import Queue
from typing import List, TypedDict, Union
from astrbot.core import sp
from astrbot.core.provider.provider import Provider
from astrbot.core.provider.provider import Provider, TTSProvider, STTProvider
from astrbot.core.db import BaseDatabase
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.provider.func_tool_manager import FuncCall
@@ -127,6 +127,14 @@ class Context:
'''获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。'''
return self.provider_manager.provider_insts
def get_all_tts_providers(self) -> List[TTSProvider]:
'''获取所有用于 TTS 任务的 Provider。'''
return self.provider_manager.tts_provider_insts
def get_all_stt_providers(self) -> List[STTProvider]:
'''获取所有用于 STT 任务的 Provider。'''
return self.provider_manager.stt_provider_insts
def get_using_provider(self) -> Provider:
'''
获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。
@@ -135,6 +143,18 @@ class Context:
'''
return self.provider_manager.curr_provider_inst
def get_using_tts_provider(self) -> TTSProvider:
'''
获取当前使用的用于 TTS 任务的 Provider。
'''
return self.provider_manager.curr_tts_provider_inst
def get_using_stt_provider(self) -> STTProvider:
'''
获取当前使用的用于 STT 任务的 Provider。
'''
return self.provider_manager.curr_stt_provider_inst
def get_config(self) -> AstrBotConfig:
'''获取 AstrBot 的配置。'''
return self._config
+83 -18
View File
@@ -60,6 +60,7 @@ AstrBot 指令:
[System]
/plugin: 查看插件、插件帮助
/t2i: 开关文本转图片
/tts: 开关文本转语音
/sid: 获取会话 ID
/op <admin_id>: 授权管理员(op)
/deop <admin_id>: 取消管理员(op)
@@ -84,10 +85,7 @@ AstrBot 指令:
/websearch: 网页搜索
[其他]
/set <变量名> <值>: 为会话定义变量。适用于 Dify 工作流输入
/unset <变量名>: 删除会话的变量。
提示:如要查看插件指令,请输入 /plugin 查看具体信息。
/set 变量名: 为会话定义变量(Dify 工作流输入)
{notice}"""
event.set_result(MessageEventResult().message(msg).use_t2i(False))
@@ -126,7 +124,7 @@ AstrBot 指令:
tm = self.context.get_llm_tool_manager()
for tool in tm.func_list:
self.context.deactivate_llm_tool(tool.name)
event.set_result(MessageEventResult().message(f"停用所有工具成功。"))
event.set_result(MessageEventResult().message("停用所有工具成功。"))
@filter.command("plugin")
async def plugin(self, event: AstrMessageEvent, oper1: str = None, oper2: str = None):
@@ -201,6 +199,18 @@ AstrBot 指令:
config.save_config()
event.set_result(MessageEventResult().message("已开启文本转图片模式。"))
@filter.command("tts")
async def tts(self, event: AstrMessageEvent):
config = self.context.get_config()
if config['provider_tts_settings']['enable']:
config['provider_tts_settings']['enable'] = False
config.save_config()
event.set_result(MessageEventResult().message("已关闭文本转语音。"))
return
config['provider_tts_settings']['enable'] = True
config.save_config()
event.set_result(MessageEventResult().message("已开启文本转语音。"))
@filter.command("sid")
async def sid(self, event: AstrMessageEvent):
sid = event.unified_msg_origin
@@ -246,34 +256,89 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
event.set_result(MessageEventResult().message("此 SID 不在白名单内。"))
@filter.command("provider")
async def provider(self, event: AstrMessageEvent, idx: int = None):
async def provider(self, event: AstrMessageEvent, idx: Union[str, int] = None, idx2: int = None):
'''查看或者切换 LLM Provider'''
if not self.context.get_using_provider():
event.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
return
if idx is None:
ret = "## 当前载入的 LLM 提供商\n"
if idx is None:
ret = "## 载入的 LLM 提供商\n"
for idx, llm in enumerate(self.context.get_all_providers()):
id_ = llm.meta().id
ret += f"{idx + 1}. {id_} ({llm.meta().model})"
if self.context.get_using_provider().meta().id == id_:
ret += " (当前使用)"
ret += "\n"
tts_providers = self.context.get_all_tts_providers()
if tts_providers:
ret += "\n## 载入的 TTS 提供商\n"
for idx, tts in enumerate(tts_providers):
id_ = tts.meta().id
ret += f"{idx + 1}. {id_}"
tts_using = self.context.get_using_tts_provider()
if tts_using and tts_using.meta().id == id_:
ret += " (当前使用)"
ret += "\n"
stt_providers = self.context.get_all_stt_providers()
if stt_providers:
ret += "\n## 载入的 STT 提供商\n"
for idx, stt in enumerate(stt_providers):
id_ = stt.meta().id
ret += f"{idx + 1}. {id_}"
stt_using = self.context.get_using_stt_provider()
if stt_using and stt_using.meta().id == id_:
ret += " (当前使用)"
ret += "\n"
ret += "\n使用 /provider <序号> 切换提供商。"
ret += "\n使用 /provider <序号> 切换 LLM 提供商。"
if tts_providers:
ret += "\n使用 /provider tts <序号> 切换 TTS 提供商。"
if stt_providers:
ret += "\n使用 /provider stt <切换> STT 提供商。"
event.set_result(MessageEventResult().message(ret))
else:
if idx > len(self.context.get_all_providers()) or idx < 1:
event.set_result(MessageEventResult().message("无效的序号。"))
if idx == "tts":
if idx2 is None:
event.set_result(MessageEventResult().message("请输入序号。"))
return
else:
if idx2 > len(self.context.get_all_tts_providers()) or idx2 < 1:
event.set_result(MessageEventResult().message("无效的序号。"))
provider = self.context.get_all_tts_providers()[idx2 - 1]
id_ = provider.meta().id
self.context.provider_manager.curr_tts_provider_inst = provider
sp.put("curr_provider_tts", id_)
event.set_result(MessageEventResult().message(f"成功切换到 {id_}"))
elif idx == "stt":
if idx2 is None:
event.set_result(MessageEventResult().message("请输入序号。"))
return
else:
if idx2 > len(self.context.get_all_stt_providers()) or idx2 < 1:
event.set_result(MessageEventResult().message("无效的序号。"))
provider = self.context.get_all_stt_providers()[idx2 - 1]
id_ = provider.meta().id
self.context.provider_manager.curr_stt_provider_inst = provider
sp.put("curr_provider_stt", id_)
event.set_result(MessageEventResult().message(f"成功切换到 {id_}"))
elif isinstance(idx, int):
if idx > len(self.context.get_all_providers()) or idx < 1:
event.set_result(MessageEventResult().message("无效的序号。"))
provider = self.context.get_all_providers()[idx - 1]
id_ = provider.meta().id
self.context.provider_manager.curr_provider_inst = provider
sp.put("curr_provider", id_)
provider = self.context.get_all_providers()[idx - 1]
id_ = provider.meta().id
self.context.provider_manager.curr_provider_inst = provider
sp.put("curr_provider", id_)
event.set_result(MessageEventResult().message(f"成功切换到 {id_}"))
event.set_result(MessageEventResult().message(f"成功切换到 {id_}"))
else:
event.set_result(MessageEventResult().message("无效的参数。"))
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("reset")
@@ -582,7 +647,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
sp.put("session_variables", session_vars)
yield event.plain_result(f"会话 {session_id} 变量 {key} 存储成功。")
yield event.plain_result(f"会话 {session_id} 变量 {key} 存储成功。使用 /unset 移除。")
@filter.command("unset")
async def unset_variable(self, event: AstrMessageEvent, key: str):
@@ -592,7 +657,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
session_var = session_vars.get(session_id, {})
if key not in session_var:
yield event.plain_result("没有那个变量名。")
yield event.plain_result("没有那个变量名。格式 /unset 变量名。")
else:
del session_var[key]
sp.put("session_variables", session_vars)