feat: 适配 OpenAI TTS API,并支持 Napcat,Gewechat,Lagrange 的语音输出
This commit is contained in:
@@ -40,6 +40,10 @@ DEFAULT_CONFIG = {
|
||||
"enable": False,
|
||||
"provider_id": "",
|
||||
},
|
||||
"provider_tts_settings": {
|
||||
"enable": False,
|
||||
"provider_id": "",
|
||||
},
|
||||
"content_safety": {
|
||||
"internal_keywords": {"enable": True, "extra_keywords": []},
|
||||
"baidu_aip": {"enable": False, "app_id": "", "api_key": "", "secret_key": ""},
|
||||
@@ -371,6 +375,14 @@ CONFIG_METADATA_2 = {
|
||||
"type": "openai_whisper_selfhost",
|
||||
"model": "tiny",
|
||||
},
|
||||
"openai_tts(API)": {
|
||||
"id": "openai_tts",
|
||||
"type": "openai_tts_api",
|
||||
"enable": False,
|
||||
"api_key": "",
|
||||
"api_base": "",
|
||||
"model": "tts-1",
|
||||
},
|
||||
},
|
||||
"items": {
|
||||
"whisper_hint": {
|
||||
@@ -570,6 +582,23 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
},
|
||||
},
|
||||
"provider_tts_settings": {
|
||||
"description": "文本转语音(TTS)",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"enable": {
|
||||
"description": "启用文本转语音(TTS)",
|
||||
"type": "bool",
|
||||
"hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 openai_tts。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"provider_id": {
|
||||
"description": "提供商 ID,不填则默认第一个TTS提供商",
|
||||
"type": "string",
|
||||
"hint": "文本转语音提供商 ID。如果不填写将使用载入的第一个提供商。",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"misc_config_group": {
|
||||
|
||||
@@ -7,7 +7,6 @@ from .event_bus import EventBus
|
||||
from . import astrbot_config
|
||||
from asyncio import Queue
|
||||
from typing import List
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
|
||||
from astrbot.core.star import PluginManager
|
||||
from astrbot.core.platform.manager import PlatformManager
|
||||
|
||||
@@ -123,7 +123,8 @@ class Record(BaseMessageComponent):
|
||||
proxy: T.Optional[bool] = True
|
||||
timeout: T.Optional[int] = 0
|
||||
# 额外
|
||||
path: T.Optional[str] # 用这个
|
||||
path: T.Optional[str]
|
||||
duration: T.Optional[int] = 0 # 毫秒
|
||||
|
||||
def __init__(self, file: T.Optional[str], **_):
|
||||
for k in _.keys():
|
||||
|
||||
@@ -139,7 +139,7 @@ class MessageEventResult(MessageChain):
|
||||
'''
|
||||
return self.result_type == EventResultType.STOP
|
||||
|
||||
def set_result_content_type(self, typ: EventResultType) -> 'MessageEventResult':
|
||||
def set_result_content_type(self, typ: ResultContentType) -> 'MessageEventResult':
|
||||
'''设置事件处理的结果类型。
|
||||
|
||||
Args:
|
||||
@@ -148,5 +148,10 @@ class MessageEventResult(MessageChain):
|
||||
self.result_content_type = typ
|
||||
return self
|
||||
|
||||
def is_llm_result(self) -> bool:
|
||||
'''是否为 LLM 结果。
|
||||
'''
|
||||
return self.result_content_type == ResultContentType.LLM_RESULT
|
||||
|
||||
|
||||
CommandResult = MessageEventResult
|
||||
@@ -1,11 +1,12 @@
|
||||
import time
|
||||
import traceback
|
||||
from typing import Union, AsyncGenerator
|
||||
from ..stage import register_stage
|
||||
from ..context import PipelineContext
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.message.components import Plain, Image, At, Reply
|
||||
from astrbot.core.message.components import Plain, Image, At, Reply, Record
|
||||
from astrbot.core import html_renderer
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
|
||||
@@ -16,6 +17,7 @@ class ResultDecorateStage:
|
||||
self.reply_prefix = ctx.astrbot_config['platform_settings']['reply_prefix']
|
||||
self.reply_with_mention = ctx.astrbot_config['platform_settings']['reply_with_mention']
|
||||
self.reply_with_quote = ctx.astrbot_config['platform_settings']['reply_with_quote']
|
||||
self.use_tts = ctx.astrbot_config['provider_tts_settings']['enable']
|
||||
self.t2i = ctx.astrbot_config['t2i']
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
@@ -32,9 +34,28 @@ class ResultDecorateStage:
|
||||
# 回复前缀
|
||||
if self.reply_prefix:
|
||||
result.chain.insert(0, Plain(self.reply_prefix))
|
||||
|
||||
# TTS
|
||||
if self.use_tts and result.is_llm_result():
|
||||
tts_provider = self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst
|
||||
plain_str = ""
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain):
|
||||
plain_str += " " + comp.text
|
||||
else:
|
||||
break
|
||||
if plain_str:
|
||||
try:
|
||||
audio_path = await tts_provider.get_audio(plain_str)
|
||||
logger.info("TTS 结果: " + audio_path)
|
||||
if audio_path:
|
||||
result.chain = [Record(file=audio_path, url=audio_path)]
|
||||
except BaseException:
|
||||
traceback.print_exc()
|
||||
logger.error("TTS 失败,使用文本发送。")
|
||||
|
||||
# 文本转图片
|
||||
if (result.use_t2i_ is None and self.t2i) or result.use_t2i_:
|
||||
elif (result.use_t2i_ is None and self.t2i) or result.use_t2i_:
|
||||
plain_str = ""
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain):
|
||||
|
||||
@@ -3,7 +3,7 @@ import random
|
||||
import asyncio
|
||||
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import Plain, Image
|
||||
from astrbot.api.message_components import Plain, Image, Record
|
||||
from aiocqhttp import CQHttp
|
||||
from astrbot.core.utils.io import file_to_base64, download_image_by_url
|
||||
|
||||
@@ -20,16 +20,18 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
d = segment.toDict()
|
||||
if isinstance(segment, Plain):
|
||||
d['type'] = 'text'
|
||||
if isinstance(segment, Image):
|
||||
if isinstance(segment, (Image, Record)):
|
||||
# convert to base64
|
||||
if segment.file and segment.file.startswith("file:///"):
|
||||
image_base64 = file_to_base64(segment.file[8:])
|
||||
bs64_data = file_to_base64(segment.file[8:])
|
||||
image_file_path = segment.file[8:]
|
||||
elif segment.file and segment.file.startswith("http"):
|
||||
image_file_path = await download_image_by_url(segment.file)
|
||||
image_base64 = file_to_base64(image_file_path)
|
||||
bs64_data = file_to_base64(image_file_path)
|
||||
else:
|
||||
bs64_data = file_to_base64(segment.file)
|
||||
d['data'] = {
|
||||
'file': image_base64,
|
||||
'file': bs64_data,
|
||||
}
|
||||
ret.append(d)
|
||||
return ret
|
||||
|
||||
@@ -163,7 +163,7 @@ class SimpleGewechatClient():
|
||||
return quart.jsonify({"r": "AstrBot ACK"})
|
||||
|
||||
async def handle_file(self, file_id):
|
||||
file_path = f"data/temp/{file_id}.jpg"
|
||||
file_path = f"data/temp/{file_id}"
|
||||
return await quart.send_file(file_path)
|
||||
|
||||
async def _set_callback_url(self):
|
||||
@@ -185,11 +185,7 @@ class SimpleGewechatClient():
|
||||
logger.info(f"将在 {self.callback_url} 上接收 gewechat 下发的消息。如果一直没收到消息请先尝试重启 AstrBot。")
|
||||
|
||||
async def start_polling(self):
|
||||
|
||||
# 设置回调
|
||||
threading.Thread(target=asyncio.run, args=(self._set_callback_url(),)).start()
|
||||
|
||||
|
||||
await self.server.run_task(
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
@@ -327,4 +323,23 @@ class SimpleGewechatClient():
|
||||
json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"发送图片结果: {json_blob}")
|
||||
logger.debug(f"发送图片结果: {json_blob}")
|
||||
|
||||
async def post_voice(self, to_wxid, voice_url: str, voice_duration: int):
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"voiceUrl": voice_url,
|
||||
"voiceDuration": voice_duration
|
||||
}
|
||||
|
||||
logger.debug(f"发送语音: {payload}")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postVoice",
|
||||
headers=self.headers,
|
||||
json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"发送语音结果: {json_blob}")
|
||||
@@ -1,13 +1,24 @@
|
||||
import random
|
||||
import asyncio
|
||||
import wave
|
||||
import uuid
|
||||
import os
|
||||
from astrbot.core.utils.io import save_temp_img, download_image_by_url
|
||||
from astrbot.core.utils.io import save_temp_img, download_image_by_url, download_file
|
||||
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||
from astrbot.api.message_components import Plain, Image
|
||||
from astrbot.api.message_components import Plain, Image, Record
|
||||
from .client import SimpleGewechatClient
|
||||
|
||||
def get_wav_duration(file_path):
|
||||
with wave.open(file_path, 'rb') as wav_file:
|
||||
file_size = os.path.getsize(file_path)
|
||||
n_channels, sampwidth, framerate, n_frames = wav_file.getparams()[:4]
|
||||
if n_frames == 2147483647:
|
||||
duration = (file_size - 44) / (n_channels * sampwidth * framerate)
|
||||
else:
|
||||
duration = n_frames / float(framerate)
|
||||
return duration
|
||||
|
||||
class GewechatPlatformEvent(AstrMessageEvent):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -39,18 +50,53 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
||||
img_url = comp.file
|
||||
img_path = ""
|
||||
if img_url.startswith("file:///"):
|
||||
with open(comp.file[8:], "rb") as f:
|
||||
img_path = save_temp_img(f.read())
|
||||
img_path = img_url[8:]
|
||||
elif comp.file and comp.file.startswith("http"):
|
||||
img_path = await download_image_by_url(comp.file)
|
||||
else:
|
||||
img_path = img_url
|
||||
|
||||
if not img_path:
|
||||
logger.error("无法获取到图片路径。")
|
||||
return
|
||||
# 检查 record_path 是否在 data/temp 目录中, record_path 可能是绝对路径
|
||||
temp_directory = os.path.abspath('data/temp')
|
||||
img_path = os.path.abspath(img_path)
|
||||
if os.path.commonpath([temp_directory, img_path]) != temp_directory:
|
||||
with open(img_path, "rb") as f:
|
||||
img_path = save_temp_img(f.read())
|
||||
|
||||
file_id = os.path.basename(img_path).split(".")[0]
|
||||
file_id = os.path.basename(img_path)
|
||||
img_url = f"{self.client.file_server_url}/{file_id}"
|
||||
logger.debug(f"gewe callback img url: {img_url}")
|
||||
await self.client.post_image(to_wxid, img_url)
|
||||
|
||||
elif isinstance(comp, Record):
|
||||
# 默认已经存在 data/temp 中
|
||||
record_url = comp.file
|
||||
record_path = ""
|
||||
|
||||
if record_url.startswith("file:///"):
|
||||
record_path = record_url[8:]
|
||||
elif record_url.startswith("http"):
|
||||
await download_file(record_url, f"data/temp/{uuid.uuid4()}.wav")
|
||||
else:
|
||||
record_path = record_url
|
||||
|
||||
silk_path = f"data/temp/{uuid.uuid4()}.silk"
|
||||
duration = await wav_to_tencent_silk(record_path, silk_path)
|
||||
|
||||
print(f"duration: {duration}, {silk_path}")
|
||||
|
||||
# 检查 record_path 是否在 data/temp 目录中, record_path 可能是绝对路径
|
||||
# temp_directory = os.path.abspath('data/temp')
|
||||
# record_path = os.path.abspath(record_path)
|
||||
# if os.path.commonpath([temp_directory, record_path]) != temp_directory:
|
||||
# with open(record_path, "rb") as f:
|
||||
# record_path = f"data/temp/{uuid.uuid4()}.wav"
|
||||
# with open(record_path, "wb") as f2:
|
||||
# f2.write(f.read())
|
||||
|
||||
if duration == 0:
|
||||
duration = get_wav_duration(record_path)
|
||||
|
||||
file_id = os.path.basename(silk_path)
|
||||
record_url = f"{self.client.file_server_url}/{file_id}"
|
||||
await self.client.post_voice(to_wxid, record_url, duration*1000)
|
||||
await super().send(message)
|
||||
@@ -80,4 +80,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
elif i.file and i.file.startswith("http"):
|
||||
image_file_path = await download_image_by_url(i.file)
|
||||
image_base64 = file_to_base64(image_file_path).replace("base64://", "")
|
||||
else:
|
||||
image_base64 = file_to_base64(i.file).replace("base64://", "")
|
||||
image_file_path = i.file
|
||||
return plain_text, image_base64, image_file_path
|
||||
@@ -32,6 +32,10 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
f.write(f2.read())
|
||||
elif comp.file and comp.file.startswith("http"):
|
||||
await download_image_by_url(comp.file, path=path)
|
||||
else:
|
||||
with open(path, "wb") as f:
|
||||
with open(comp.file, "rb") as f2:
|
||||
f.write(f2.read())
|
||||
web_chat_back_queue.put_nowait((f"[IMAGE]{filename}", cid))
|
||||
web_chat_back_queue.put_nowait(None)
|
||||
await super().send(message)
|
||||
@@ -1,6 +1,6 @@
|
||||
import traceback
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from .provider import Provider, STTProvider, Personality
|
||||
from .provider import Provider, STTProvider, TTSProvider, Personality
|
||||
from .entites import ProviderType
|
||||
from typing import List
|
||||
from astrbot.core.db import BaseDatabase
|
||||
@@ -64,11 +64,15 @@ class ProviderManager():
|
||||
'''加载的 Provider 的实例'''
|
||||
self.stt_provider_insts: List[STTProvider] = []
|
||||
'''加载的 Speech To Text Provider 的实例'''
|
||||
self.tts_provider_insts: List[TTSProvider] = []
|
||||
'''加载的 Text To Speech Provider 的实例'''
|
||||
self.llm_tools = llm_tools
|
||||
self.curr_provider_inst: Provider = None
|
||||
'''当前使用的 Provider 实例'''
|
||||
self.curr_stt_provider_inst: STTProvider = None
|
||||
'''当前使用的 Speech To Text Provider 实例'''
|
||||
self.curr_tts_provider_inst: TTSProvider = None
|
||||
'''当前使用的 Text To Speech Provider 实例'''
|
||||
self.loaded_ids = defaultdict(bool)
|
||||
self.db_helper = db_helper
|
||||
|
||||
@@ -103,6 +107,8 @@ class ProviderManager():
|
||||
from .sources.whisper_api_source import ProviderOpenAIWhisperAPI # noqa: F401
|
||||
case "openai_whisper_selfhost":
|
||||
from .sources.whisper_selfhosted_source import ProviderOpenAIWhisperSelfHost # noqa: F401
|
||||
case "openai_tts_api":
|
||||
from .sources.openai_tts_api_source import ProviderOpenAITTSAPI # noqa: F401
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。")
|
||||
continue
|
||||
@@ -119,8 +125,10 @@ class ProviderManager():
|
||||
continue
|
||||
selected_provider_id = sp.get("curr_provider")
|
||||
selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
|
||||
selected_tts_provider_id = self.provider_settings.get("provider_id")
|
||||
provider_enabled = self.provider_settings.get("enable", False)
|
||||
stt_enabled = self.provider_stt_settings.get("enable", False)
|
||||
tts_enabled = self.provider_settings.get("enable", False)
|
||||
|
||||
provider_metadata = provider_cls_map[provider_config['type']]
|
||||
logger.info(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器 ...")
|
||||
@@ -138,6 +146,18 @@ class ProviderManager():
|
||||
if selected_stt_provider_id == provider_config['id'] and stt_enabled:
|
||||
self.curr_stt_provider_inst = inst
|
||||
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。")
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
|
||||
# TTS 任务
|
||||
inst = provider_metadata.cls_type(provider_config, self.provider_settings)
|
||||
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
|
||||
self.tts_provider_insts.append(inst)
|
||||
if selected_tts_provider_id == provider_config['id'] and tts_enabled:
|
||||
self.curr_tts_provider_inst = inst
|
||||
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。")
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
|
||||
# 文本生成任务
|
||||
@@ -167,11 +187,18 @@ class ProviderManager():
|
||||
if len(self.stt_provider_insts) > 0 and not self.curr_stt_provider_inst and stt_enabled:
|
||||
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
||||
|
||||
if len(self.tts_provider_insts) > 0 and not self.curr_tts_provider_inst and tts_enabled:
|
||||
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||||
|
||||
if not self.curr_provider_inst:
|
||||
logger.warning("未启用任何用于 文本生成 的提供商适配器。")
|
||||
if self.provider_stt_settings.get("enable"):
|
||||
if not self.curr_stt_provider_inst:
|
||||
|
||||
if stt_enabled and not self.curr_stt_provider_inst:
|
||||
logger.warning("未启用任何用于 语音转文本 的提供商适配器。")
|
||||
|
||||
if tts_enabled and not self.curr_tts_provider_inst:
|
||||
logger.warning("未启用任何用于 文本转语音 的提供商适配器。")
|
||||
|
||||
|
||||
def get_insts(self):
|
||||
return self.provider_insts
|
||||
|
||||
@@ -24,9 +24,32 @@ class ProviderMeta():
|
||||
id: str
|
||||
model: str
|
||||
type: str
|
||||
|
||||
|
||||
class AbstractProvider(abc.ABC):
|
||||
def __init__(self, provider_config: dict) -> None:
|
||||
super().__init__()
|
||||
self.model_name = ""
|
||||
self.provider_config = provider_config
|
||||
|
||||
def set_model(self, model_name: str):
|
||||
'''设置当前使用的模型名称'''
|
||||
self.model_name = model_name
|
||||
|
||||
def get_model(self) -> str:
|
||||
'''获得当前使用的模型名称'''
|
||||
return self.model_name
|
||||
|
||||
def meta(self) -> ProviderMeta:
|
||||
'''获取 Provider 的元数据'''
|
||||
return ProviderMeta(
|
||||
id=self.provider_config['id'],
|
||||
model=self.get_model(),
|
||||
type=self.provider_config['type']
|
||||
)
|
||||
|
||||
|
||||
class Provider(abc.ABC):
|
||||
class Provider(AbstractProvider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
@@ -35,14 +58,11 @@ class Provider(abc.ABC):
|
||||
db_helper: BaseDatabase = None,
|
||||
default_persona: Personality = None
|
||||
) -> None:
|
||||
self.model_name = ""
|
||||
'''当前使用的模型名称'''
|
||||
super().__init__(provider_config)
|
||||
|
||||
self.session_memory = defaultdict(list)
|
||||
'''维护了 session_id 的上下文,**不包含 system 指令**。'''
|
||||
|
||||
self.provider_config = provider_config
|
||||
|
||||
self.provider_settings = provider_settings
|
||||
|
||||
self.curr_personality: Personality = default_persona
|
||||
@@ -58,14 +78,6 @@ class Provider(abc.ABC):
|
||||
self.session_memory[history.session_id] = json.loads(history.content)
|
||||
except BaseException as e:
|
||||
logger.warning(f"读取 LLM 对话历史记录 失败:{e}。仍可正常使用。")
|
||||
|
||||
def set_model(self, model_name: str):
|
||||
'''设置当前使用的模型名称'''
|
||||
self.model_name = model_name
|
||||
|
||||
def get_model(self) -> str:
|
||||
'''获得当前使用的模型名称'''
|
||||
return self.model_name
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_current_key(self) -> str:
|
||||
@@ -133,17 +145,11 @@ class Provider(abc.ABC):
|
||||
'''重置某一个 session_id 的上下文'''
|
||||
raise NotImplementedError()
|
||||
|
||||
def meta(self) -> ProviderMeta:
|
||||
'''获取 Provider 的元数据'''
|
||||
return ProviderMeta(
|
||||
id=self.provider_config['id'],
|
||||
model=self.get_model(),
|
||||
type=self.provider_config['type']
|
||||
)
|
||||
|
||||
|
||||
|
||||
class STTProvider():
|
||||
class STTProvider(AbstractProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
super().__init__(provider_config)
|
||||
self.provider_config = provider_config
|
||||
self.provider_settings = provider_settings
|
||||
|
||||
@@ -151,19 +157,15 @@ class STTProvider():
|
||||
async def get_text(self, audio_url: str) -> str:
|
||||
'''获取音频的文本'''
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class TTSProvider(AbstractProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
super().__init__(provider_config)
|
||||
self.provider_config = provider_config
|
||||
self.provider_settings = provider_settings
|
||||
|
||||
def set_model(self, model_name: str):
|
||||
'''设置当前使用的模型名称'''
|
||||
self.model_name = model_name
|
||||
|
||||
def get_model(self) -> str:
|
||||
'''获取当前使用的模型'''
|
||||
return self.provider_config.get("model", "")
|
||||
|
||||
def meta(self) -> ProviderMeta:
|
||||
'''获取 Provider 的元数据'''
|
||||
return ProviderMeta(
|
||||
id=self.provider_config['id'],
|
||||
model=self.get_model(),
|
||||
type=self.provider_config['type']
|
||||
)
|
||||
@abc.abstractmethod
|
||||
async def get_audio(self, text: str) -> str:
|
||||
'''获取文本的音频,返回音频文件路径'''
|
||||
raise NotImplementedError()
|
||||
@@ -1,4 +1,3 @@
|
||||
import traceback
|
||||
import base64
|
||||
import json
|
||||
|
||||
@@ -110,7 +109,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
)
|
||||
|
||||
assert isinstance(completion, ChatCompletion)
|
||||
logger.debug(f"completion: {completion.usage}")
|
||||
logger.debug(f"completion: {completion}")
|
||||
|
||||
if len(completion.choices) == 0:
|
||||
raise Exception("API 返回的 completion 为空。")
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
import uuid
|
||||
import os
|
||||
from openai import AsyncOpenAI, NOT_GIVEN
|
||||
from ..provider import TTSProvider
|
||||
from ..entites import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
|
||||
@register_provider_adapter("openai_tts_api", "OpenAI TTS API", provider_type=ProviderType.TEXT_TO_SPEECH)
|
||||
class ProviderOpenAITTSAPI(TTSProvider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.chosen_api_key = provider_config.get("api_key", "")
|
||||
self.voice = provider_config.get("voice", "alloy")
|
||||
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=self.chosen_api_key,
|
||||
base_url=provider_config.get("api_base", None),
|
||||
timeout=provider_config.get("timeout", NOT_GIVEN),
|
||||
)
|
||||
|
||||
self.set_model(provider_config.get("model", None))
|
||||
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
path = f'data/temp/openai_tts_api_{uuid.uuid4()}.wav'
|
||||
async with self.client.audio.speech.with_streaming_response.create(
|
||||
model=self.model_name,
|
||||
voice=self.voice,
|
||||
response_format='wav',
|
||||
input=text
|
||||
) as response:
|
||||
with open(path, 'wb') as f:
|
||||
async for chunk in response.iter_bytes(chunk_size=1024):
|
||||
f.write(chunk)
|
||||
return path
|
||||
@@ -2,36 +2,29 @@ import wave
|
||||
from io import BytesIO
|
||||
|
||||
async def tencent_silk_to_wav(silk_path: str, output_path: str) -> str:
|
||||
import pysilk
|
||||
import pilk
|
||||
|
||||
with open(silk_path, "rb") as f:
|
||||
input_data = f.read()
|
||||
if input_data.startswith(b'\x02'):
|
||||
input_data = input_data[1:]
|
||||
input_io = BytesIO(input_data)
|
||||
output_io = BytesIO()
|
||||
pysilk.decode(input_io, output_io, 24000)
|
||||
output_io.seek(0)
|
||||
with wave.open(output_path, 'wb') as wav:
|
||||
wav.setnchannels(1)
|
||||
wav.setsampwidth(2)
|
||||
wav.setframerate(24000)
|
||||
wav.writeframes(output_io.read())
|
||||
pcm_path = f"{output_path}.pcm"
|
||||
pilk.decode(silk_path, pcm_path)
|
||||
|
||||
with open(pcm_path, "rb") as pcm:
|
||||
with wave.open(output_path, 'wb') as wav:
|
||||
wav.setnchannels(1)
|
||||
wav.setsampwidth(2)
|
||||
wav.setframerate(24000)
|
||||
wav.writeframes(pcm.read())
|
||||
|
||||
return output_path
|
||||
|
||||
async def wav_to_tencent_silk(wav_path: str) -> BytesIO:
|
||||
import pysilk
|
||||
|
||||
async def wav_to_tencent_silk(wav_path: str, output_path: str) -> int:
|
||||
'''返回 duration'''
|
||||
import pilk
|
||||
|
||||
# wav to pcm
|
||||
with wave.open(wav_path, 'rb') as wav:
|
||||
wav_data = wav.readframes(wav.getnframes())
|
||||
wav_data = BytesIO(wav_data)
|
||||
output_io = BytesIO()
|
||||
pysilk.encode(wav_data, output_io, 24000)
|
||||
output_io.seek(0)
|
||||
pcm_path = f"{wav_path}.pcm"
|
||||
with open(pcm_path, "wb") as f:
|
||||
f.write(wav.readframes(wav.getnframes()))
|
||||
|
||||
# 在首字节添加 \x02
|
||||
silk_data = output_io.read()
|
||||
silk_data_with_prefix = b'\x02' + silk_data
|
||||
|
||||
return BytesIO(silk_data_with_prefix)
|
||||
return pilk.encode(pcm_path, output_path, pcm_rate=24000, tencent=True)
|
||||
+1
-1
@@ -17,4 +17,4 @@ pyjwt
|
||||
apscheduler
|
||||
docstring_parser
|
||||
aiodocker
|
||||
silk-python
|
||||
pilk
|
||||
Reference in New Issue
Block a user