From 637acd1a129a0227ca4d996bff0b3f6d1b749af8 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sat, 25 Jan 2025 19:46:00 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E9=80=82=E9=85=8D=20OpenAI=20TTS=20API?= =?UTF-8?q?=EF=BC=8C=E5=B9=B6=E6=94=AF=E6=8C=81=20Napcat=EF=BC=8CGewechat?= =?UTF-8?q?=EF=BC=8CLagrange=20=E7=9A=84=E8=AF=AD=E9=9F=B3=E8=BE=93?= =?UTF-8?q?=E5=87=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/config/default.py | 29 +++++++ astrbot/core/core_lifecycle.py | 1 - astrbot/core/message/components.py | 3 +- astrbot/core/message/message_event_result.py | 7 +- .../core/pipeline/result_decorate/stage.py | 25 +++++- .../aiocqhttp/aiocqhttp_message_event.py | 12 +-- .../core/platform/sources/gewechat/client.py | 27 +++++-- .../sources/gewechat/gewechat_event.py | 68 ++++++++++++++--- .../qqofficial/qqofficial_message_event.py | 3 + .../platform/sources/webchat/webchat_event.py | 4 + astrbot/core/provider/manager.py | 33 +++++++- astrbot/core/provider/provider.py | 76 ++++++++++--------- .../core/provider/sources/openai_source.py | 3 +- .../provider/sources/openai_tts_api_source.py | 40 ++++++++++ astrbot/core/utils/tencent_record_helper.py | 45 +++++------ requirements.txt | 2 +- 16 files changed, 282 insertions(+), 96 deletions(-) create mode 100644 astrbot/core/provider/sources/openai_tts_api_source.py diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index afbae30cc..8cb965988 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -40,6 +40,10 @@ DEFAULT_CONFIG = { "enable": False, "provider_id": "", }, + "provider_tts_settings": { + "enable": False, + "provider_id": "", + }, "content_safety": { "internal_keywords": {"enable": True, "extra_keywords": []}, "baidu_aip": {"enable": False, "app_id": "", "api_key": "", "secret_key": ""}, @@ -371,6 +375,14 @@ CONFIG_METADATA_2 = { "type": "openai_whisper_selfhost", "model": "tiny", }, + "openai_tts(API)": { + "id": "openai_tts", + "type": "openai_tts_api", + "enable": False, + "api_key": "", + "api_base": "", + "model": "tts-1", + }, }, "items": { "whisper_hint": { @@ -570,6 +582,23 @@ CONFIG_METADATA_2 = { }, }, }, + "provider_tts_settings": { + "description": "文本转语音(TTS)", + "type": "object", + "items": { + "enable": { + "description": "启用文本转语音(TTS)", + "type": "bool", + "hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 openai_tts。", + "obvious_hint": True, + }, + "provider_id": { + "description": "提供商 ID,不填则默认第一个TTS提供商", + "type": "string", + "hint": "文本转语音提供商 ID。如果不填写将使用载入的第一个提供商。", + }, + }, + }, }, }, "misc_config_group": { diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index 417963971..dd49c6957 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -7,7 +7,6 @@ from .event_bus import EventBus from . import astrbot_config from asyncio import Queue from typing import List -from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext from astrbot.core.star import PluginManager from astrbot.core.platform.manager import PlatformManager diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 74d90c535..d43d2d89f 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -123,7 +123,8 @@ class Record(BaseMessageComponent): proxy: T.Optional[bool] = True timeout: T.Optional[int] = 0 # 额外 - path: T.Optional[str] # 用这个 + path: T.Optional[str] + duration: T.Optional[int] = 0 # 毫秒 def __init__(self, file: T.Optional[str], **_): for k in _.keys(): diff --git a/astrbot/core/message/message_event_result.py b/astrbot/core/message/message_event_result.py index b75f0b2a6..9ecb0cd23 100644 --- a/astrbot/core/message/message_event_result.py +++ b/astrbot/core/message/message_event_result.py @@ -139,7 +139,7 @@ class MessageEventResult(MessageChain): ''' return self.result_type == EventResultType.STOP - def set_result_content_type(self, typ: EventResultType) -> 'MessageEventResult': + def set_result_content_type(self, typ: ResultContentType) -> 'MessageEventResult': '''设置事件处理的结果类型。 Args: @@ -148,5 +148,10 @@ class MessageEventResult(MessageChain): self.result_content_type = typ return self + def is_llm_result(self) -> bool: + '''是否为 LLM 结果。 + ''' + return self.result_content_type == ResultContentType.LLM_RESULT + CommandResult = MessageEventResult \ No newline at end of file diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index 6580318dd..46e7914ee 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -1,11 +1,12 @@ import time +import traceback from typing import Union, AsyncGenerator from ..stage import register_stage from ..context import PipelineContext from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.platform.message_type import MessageType from astrbot.core import logger -from astrbot.core.message.components import Plain, Image, At, Reply +from astrbot.core.message.components import Plain, Image, At, Reply, Record from astrbot.core import html_renderer from astrbot.core.star.star_handler import star_handlers_registry, EventType @@ -16,6 +17,7 @@ class ResultDecorateStage: self.reply_prefix = ctx.astrbot_config['platform_settings']['reply_prefix'] self.reply_with_mention = ctx.astrbot_config['platform_settings']['reply_with_mention'] self.reply_with_quote = ctx.astrbot_config['platform_settings']['reply_with_quote'] + self.use_tts = ctx.astrbot_config['provider_tts_settings']['enable'] self.t2i = ctx.astrbot_config['t2i'] async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]: @@ -32,9 +34,28 @@ class ResultDecorateStage: # 回复前缀 if self.reply_prefix: result.chain.insert(0, Plain(self.reply_prefix)) + + # TTS + if self.use_tts and result.is_llm_result(): + tts_provider = self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst + plain_str = "" + for comp in result.chain: + if isinstance(comp, Plain): + plain_str += " " + comp.text + else: + break + if plain_str: + try: + audio_path = await tts_provider.get_audio(plain_str) + logger.info("TTS 结果: " + audio_path) + if audio_path: + result.chain = [Record(file=audio_path, url=audio_path)] + except BaseException: + traceback.print_exc() + logger.error("TTS 失败,使用文本发送。") # 文本转图片 - if (result.use_t2i_ is None and self.t2i) or result.use_t2i_: + elif (result.use_t2i_ is None and self.t2i) or result.use_t2i_: plain_str = "" for comp in result.chain: if isinstance(comp, Plain): diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py index b218057df..aa26516c8 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py @@ -3,7 +3,7 @@ import random import asyncio from astrbot.api.event import AstrMessageEvent, MessageChain -from astrbot.api.message_components import Plain, Image +from astrbot.api.message_components import Plain, Image, Record from aiocqhttp import CQHttp from astrbot.core.utils.io import file_to_base64, download_image_by_url @@ -20,16 +20,18 @@ class AiocqhttpMessageEvent(AstrMessageEvent): d = segment.toDict() if isinstance(segment, Plain): d['type'] = 'text' - if isinstance(segment, Image): + if isinstance(segment, (Image, Record)): # convert to base64 if segment.file and segment.file.startswith("file:///"): - image_base64 = file_to_base64(segment.file[8:]) + bs64_data = file_to_base64(segment.file[8:]) image_file_path = segment.file[8:] elif segment.file and segment.file.startswith("http"): image_file_path = await download_image_by_url(segment.file) - image_base64 = file_to_base64(image_file_path) + bs64_data = file_to_base64(image_file_path) + else: + bs64_data = file_to_base64(segment.file) d['data'] = { - 'file': image_base64, + 'file': bs64_data, } ret.append(d) return ret diff --git a/astrbot/core/platform/sources/gewechat/client.py b/astrbot/core/platform/sources/gewechat/client.py index cc9eab512..044f64483 100644 --- a/astrbot/core/platform/sources/gewechat/client.py +++ b/astrbot/core/platform/sources/gewechat/client.py @@ -163,7 +163,7 @@ class SimpleGewechatClient(): return quart.jsonify({"r": "AstrBot ACK"}) async def handle_file(self, file_id): - file_path = f"data/temp/{file_id}.jpg" + file_path = f"data/temp/{file_id}" return await quart.send_file(file_path) async def _set_callback_url(self): @@ -185,11 +185,7 @@ class SimpleGewechatClient(): logger.info(f"将在 {self.callback_url} 上接收 gewechat 下发的消息。如果一直没收到消息请先尝试重启 AstrBot。") async def start_polling(self): - - # 设置回调 threading.Thread(target=asyncio.run, args=(self._set_callback_url(),)).start() - - await self.server.run_task( host=self.host, port=self.port, @@ -327,4 +323,23 @@ class SimpleGewechatClient(): json=payload ) as resp: json_blob = await resp.json() - logger.debug(f"发送图片结果: {json_blob}") \ No newline at end of file + logger.debug(f"发送图片结果: {json_blob}") + + async def post_voice(self, to_wxid, voice_url: str, voice_duration: int): + payload = { + "appId": self.appid, + "toWxid": to_wxid, + "voiceUrl": voice_url, + "voiceDuration": voice_duration + } + + logger.debug(f"发送语音: {payload}") + + async with aiohttp.ClientSession() as session: + async with session.post( + f"{self.base_url}/message/postVoice", + headers=self.headers, + json=payload + ) as resp: + json_blob = await resp.json() + logger.debug(f"发送语音结果: {json_blob}") \ No newline at end of file diff --git a/astrbot/core/platform/sources/gewechat/gewechat_event.py b/astrbot/core/platform/sources/gewechat/gewechat_event.py index 2de45a218..27048b8e5 100644 --- a/astrbot/core/platform/sources/gewechat/gewechat_event.py +++ b/astrbot/core/platform/sources/gewechat/gewechat_event.py @@ -1,13 +1,24 @@ -import random -import asyncio +import wave +import uuid import os -from astrbot.core.utils.io import save_temp_img, download_image_by_url +from astrbot.core.utils.io import save_temp_img, download_image_by_url, download_file +from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk from astrbot.api import logger from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.platform import AstrBotMessage, PlatformMetadata -from astrbot.api.message_components import Plain, Image +from astrbot.api.message_components import Plain, Image, Record from .client import SimpleGewechatClient +def get_wav_duration(file_path): + with wave.open(file_path, 'rb') as wav_file: + file_size = os.path.getsize(file_path) + n_channels, sampwidth, framerate, n_frames = wav_file.getparams()[:4] + if n_frames == 2147483647: + duration = (file_size - 44) / (n_channels * sampwidth * framerate) + else: + duration = n_frames / float(framerate) + return duration + class GewechatPlatformEvent(AstrMessageEvent): def __init__( self, @@ -39,18 +50,53 @@ class GewechatPlatformEvent(AstrMessageEvent): img_url = comp.file img_path = "" if img_url.startswith("file:///"): - with open(comp.file[8:], "rb") as f: - img_path = save_temp_img(f.read()) + img_path = img_url[8:] elif comp.file and comp.file.startswith("http"): img_path = await download_image_by_url(comp.file) + else: + img_path = img_url - if not img_path: - logger.error("无法获取到图片路径。") - return + # 检查 record_path 是否在 data/temp 目录中, record_path 可能是绝对路径 + temp_directory = os.path.abspath('data/temp') + img_path = os.path.abspath(img_path) + if os.path.commonpath([temp_directory, img_path]) != temp_directory: + with open(img_path, "rb") as f: + img_path = save_temp_img(f.read()) - file_id = os.path.basename(img_path).split(".")[0] + file_id = os.path.basename(img_path) img_url = f"{self.client.file_server_url}/{file_id}" logger.debug(f"gewe callback img url: {img_url}") await self.client.post_image(to_wxid, img_url) - + elif isinstance(comp, Record): + # 默认已经存在 data/temp 中 + record_url = comp.file + record_path = "" + + if record_url.startswith("file:///"): + record_path = record_url[8:] + elif record_url.startswith("http"): + await download_file(record_url, f"data/temp/{uuid.uuid4()}.wav") + else: + record_path = record_url + + silk_path = f"data/temp/{uuid.uuid4()}.silk" + duration = await wav_to_tencent_silk(record_path, silk_path) + + print(f"duration: {duration}, {silk_path}") + + # 检查 record_path 是否在 data/temp 目录中, record_path 可能是绝对路径 + # temp_directory = os.path.abspath('data/temp') + # record_path = os.path.abspath(record_path) + # if os.path.commonpath([temp_directory, record_path]) != temp_directory: + # with open(record_path, "rb") as f: + # record_path = f"data/temp/{uuid.uuid4()}.wav" + # with open(record_path, "wb") as f2: + # f2.write(f.read()) + + if duration == 0: + duration = get_wav_duration(record_path) + + file_id = os.path.basename(silk_path) + record_url = f"{self.client.file_server_url}/{file_id}" + await self.client.post_voice(to_wxid, record_url, duration*1000) await super().send(message) \ No newline at end of file diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py index bd077a078..50074d0f1 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py @@ -80,4 +80,7 @@ class QQOfficialMessageEvent(AstrMessageEvent): elif i.file and i.file.startswith("http"): image_file_path = await download_image_by_url(i.file) image_base64 = file_to_base64(image_file_path).replace("base64://", "") + else: + image_base64 = file_to_base64(i.file).replace("base64://", "") + image_file_path = i.file return plain_text, image_base64, image_file_path \ No newline at end of file diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index b19526000..f447a616c 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -32,6 +32,10 @@ class WebChatMessageEvent(AstrMessageEvent): f.write(f2.read()) elif comp.file and comp.file.startswith("http"): await download_image_by_url(comp.file, path=path) + else: + with open(path, "wb") as f: + with open(comp.file, "rb") as f2: + f.write(f2.read()) web_chat_back_queue.put_nowait((f"[IMAGE]{filename}", cid)) web_chat_back_queue.put_nowait(None) await super().send(message) \ No newline at end of file diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 2305f34bf..33c94108b 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -1,6 +1,6 @@ import traceback from astrbot.core.config.astrbot_config import AstrBotConfig -from .provider import Provider, STTProvider, Personality +from .provider import Provider, STTProvider, TTSProvider, Personality from .entites import ProviderType from typing import List from astrbot.core.db import BaseDatabase @@ -64,11 +64,15 @@ class ProviderManager(): '''加载的 Provider 的实例''' self.stt_provider_insts: List[STTProvider] = [] '''加载的 Speech To Text Provider 的实例''' + self.tts_provider_insts: List[TTSProvider] = [] + '''加载的 Text To Speech Provider 的实例''' self.llm_tools = llm_tools self.curr_provider_inst: Provider = None '''当前使用的 Provider 实例''' self.curr_stt_provider_inst: STTProvider = None '''当前使用的 Speech To Text Provider 实例''' + self.curr_tts_provider_inst: TTSProvider = None + '''当前使用的 Text To Speech Provider 实例''' self.loaded_ids = defaultdict(bool) self.db_helper = db_helper @@ -103,6 +107,8 @@ class ProviderManager(): from .sources.whisper_api_source import ProviderOpenAIWhisperAPI # noqa: F401 case "openai_whisper_selfhost": from .sources.whisper_selfhosted_source import ProviderOpenAIWhisperSelfHost # noqa: F401 + case "openai_tts_api": + from .sources.openai_tts_api_source import ProviderOpenAITTSAPI # noqa: F401 except (ImportError, ModuleNotFoundError) as e: logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。") continue @@ -119,8 +125,10 @@ class ProviderManager(): continue selected_provider_id = sp.get("curr_provider") selected_stt_provider_id = self.provider_stt_settings.get("provider_id") + selected_tts_provider_id = self.provider_settings.get("provider_id") provider_enabled = self.provider_settings.get("enable", False) stt_enabled = self.provider_stt_settings.get("enable", False) + tts_enabled = self.provider_settings.get("enable", False) provider_metadata = provider_cls_map[provider_config['type']] logger.info(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器 ...") @@ -138,6 +146,18 @@ class ProviderManager(): if selected_stt_provider_id == provider_config['id'] and stt_enabled: self.curr_stt_provider_inst = inst logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。") + + elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH: + # TTS 任务 + inst = provider_metadata.cls_type(provider_config, self.provider_settings) + + if getattr(inst, "initialize", None): + await inst.initialize() + + self.tts_provider_insts.append(inst) + if selected_tts_provider_id == provider_config['id'] and tts_enabled: + self.curr_tts_provider_inst = inst + logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。") elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION: # 文本生成任务 @@ -167,11 +187,18 @@ class ProviderManager(): if len(self.stt_provider_insts) > 0 and not self.curr_stt_provider_inst and stt_enabled: self.curr_stt_provider_inst = self.stt_provider_insts[0] + if len(self.tts_provider_insts) > 0 and not self.curr_tts_provider_inst and tts_enabled: + self.curr_tts_provider_inst = self.tts_provider_insts[0] + if not self.curr_provider_inst: logger.warning("未启用任何用于 文本生成 的提供商适配器。") - if self.provider_stt_settings.get("enable"): - if not self.curr_stt_provider_inst: + + if stt_enabled and not self.curr_stt_provider_inst: logger.warning("未启用任何用于 语音转文本 的提供商适配器。") + + if tts_enabled and not self.curr_tts_provider_inst: + logger.warning("未启用任何用于 文本转语音 的提供商适配器。") + def get_insts(self): return self.provider_insts diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 0c44923d0..34580c6f7 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -24,9 +24,32 @@ class ProviderMeta(): id: str model: str type: str + + +class AbstractProvider(abc.ABC): + def __init__(self, provider_config: dict) -> None: + super().__init__() + self.model_name = "" + self.provider_config = provider_config + + def set_model(self, model_name: str): + '''设置当前使用的模型名称''' + self.model_name = model_name + + def get_model(self) -> str: + '''获得当前使用的模型名称''' + return self.model_name + + def meta(self) -> ProviderMeta: + '''获取 Provider 的元数据''' + return ProviderMeta( + id=self.provider_config['id'], + model=self.get_model(), + type=self.provider_config['type'] + ) -class Provider(abc.ABC): +class Provider(AbstractProvider): def __init__( self, provider_config: dict, @@ -35,14 +58,11 @@ class Provider(abc.ABC): db_helper: BaseDatabase = None, default_persona: Personality = None ) -> None: - self.model_name = "" - '''当前使用的模型名称''' + super().__init__(provider_config) self.session_memory = defaultdict(list) '''维护了 session_id 的上下文,**不包含 system 指令**。''' - self.provider_config = provider_config - self.provider_settings = provider_settings self.curr_personality: Personality = default_persona @@ -58,14 +78,6 @@ class Provider(abc.ABC): self.session_memory[history.session_id] = json.loads(history.content) except BaseException as e: logger.warning(f"读取 LLM 对话历史记录 失败:{e}。仍可正常使用。") - - def set_model(self, model_name: str): - '''设置当前使用的模型名称''' - self.model_name = model_name - - def get_model(self) -> str: - '''获得当前使用的模型名称''' - return self.model_name @abc.abstractmethod def get_current_key(self) -> str: @@ -133,17 +145,11 @@ class Provider(abc.ABC): '''重置某一个 session_id 的上下文''' raise NotImplementedError() - def meta(self) -> ProviderMeta: - '''获取 Provider 的元数据''' - return ProviderMeta( - id=self.provider_config['id'], - model=self.get_model(), - type=self.provider_config['type'] - ) + - -class STTProvider(): +class STTProvider(AbstractProvider): def __init__(self, provider_config: dict, provider_settings: dict) -> None: + super().__init__(provider_config) self.provider_config = provider_config self.provider_settings = provider_settings @@ -151,19 +157,15 @@ class STTProvider(): async def get_text(self, audio_url: str) -> str: '''获取音频的文本''' raise NotImplementedError() + + +class TTSProvider(AbstractProvider): + def __init__(self, provider_config: dict, provider_settings: dict) -> None: + super().__init__(provider_config) + self.provider_config = provider_config + self.provider_settings = provider_settings - def set_model(self, model_name: str): - '''设置当前使用的模型名称''' - self.model_name = model_name - - def get_model(self) -> str: - '''获取当前使用的模型''' - return self.provider_config.get("model", "") - - def meta(self) -> ProviderMeta: - '''获取 Provider 的元数据''' - return ProviderMeta( - id=self.provider_config['id'], - model=self.get_model(), - type=self.provider_config['type'] - ) \ No newline at end of file + @abc.abstractmethod + async def get_audio(self, text: str) -> str: + '''获取文本的音频,返回音频文件路径''' + raise NotImplementedError() \ No newline at end of file diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index ef70ff00c..9a0e72fb2 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -1,4 +1,3 @@ -import traceback import base64 import json @@ -110,7 +109,7 @@ class ProviderOpenAIOfficial(Provider): ) assert isinstance(completion, ChatCompletion) - logger.debug(f"completion: {completion.usage}") + logger.debug(f"completion: {completion}") if len(completion.choices) == 0: raise Exception("API 返回的 completion 为空。") diff --git a/astrbot/core/provider/sources/openai_tts_api_source.py b/astrbot/core/provider/sources/openai_tts_api_source.py new file mode 100644 index 000000000..b3aa9a35d --- /dev/null +++ b/astrbot/core/provider/sources/openai_tts_api_source.py @@ -0,0 +1,40 @@ +import uuid +import os +from openai import AsyncOpenAI, NOT_GIVEN +from ..provider import TTSProvider +from ..entites import ProviderType +from ..register import register_provider_adapter + + +@register_provider_adapter("openai_tts_api", "OpenAI TTS API", provider_type=ProviderType.TEXT_TO_SPEECH) +class ProviderOpenAITTSAPI(TTSProvider): + def __init__( + self, + provider_config: dict, + provider_settings: dict, + ) -> None: + super().__init__(provider_config, provider_settings) + self.chosen_api_key = provider_config.get("api_key", "") + self.voice = provider_config.get("voice", "alloy") + + self.client = AsyncOpenAI( + api_key=self.chosen_api_key, + base_url=provider_config.get("api_base", None), + timeout=provider_config.get("timeout", NOT_GIVEN), + ) + + self.set_model(provider_config.get("model", None)) + + + async def get_audio(self, text: str) -> str: + path = f'data/temp/openai_tts_api_{uuid.uuid4()}.wav' + async with self.client.audio.speech.with_streaming_response.create( + model=self.model_name, + voice=self.voice, + response_format='wav', + input=text + ) as response: + with open(path, 'wb') as f: + async for chunk in response.iter_bytes(chunk_size=1024): + f.write(chunk) + return path \ No newline at end of file diff --git a/astrbot/core/utils/tencent_record_helper.py b/astrbot/core/utils/tencent_record_helper.py index 6f75aaa86..386b9b0ea 100644 --- a/astrbot/core/utils/tencent_record_helper.py +++ b/astrbot/core/utils/tencent_record_helper.py @@ -2,36 +2,29 @@ import wave from io import BytesIO async def tencent_silk_to_wav(silk_path: str, output_path: str) -> str: - import pysilk + import pilk with open(silk_path, "rb") as f: - input_data = f.read() - if input_data.startswith(b'\x02'): - input_data = input_data[1:] - input_io = BytesIO(input_data) - output_io = BytesIO() - pysilk.decode(input_io, output_io, 24000) - output_io.seek(0) - with wave.open(output_path, 'wb') as wav: - wav.setnchannels(1) - wav.setsampwidth(2) - wav.setframerate(24000) - wav.writeframes(output_io.read()) + pcm_path = f"{output_path}.pcm" + pilk.decode(silk_path, pcm_path) + with open(pcm_path, "rb") as pcm: + with wave.open(output_path, 'wb') as wav: + wav.setnchannels(1) + wav.setsampwidth(2) + wav.setframerate(24000) + wav.writeframes(pcm.read()) + return output_path -async def wav_to_tencent_silk(wav_path: str) -> BytesIO: - import pysilk - +async def wav_to_tencent_silk(wav_path: str, output_path: str) -> int: + '''返回 duration''' + import pilk + + # wav to pcm with wave.open(wav_path, 'rb') as wav: - wav_data = wav.readframes(wav.getnframes()) - wav_data = BytesIO(wav_data) - output_io = BytesIO() - pysilk.encode(wav_data, output_io, 24000) - output_io.seek(0) + pcm_path = f"{wav_path}.pcm" + with open(pcm_path, "wb") as f: + f.write(wav.readframes(wav.getnframes())) - # 在首字节添加 \x02 - silk_data = output_io.read() - silk_data_with_prefix = b'\x02' + silk_data - - return BytesIO(silk_data_with_prefix) \ No newline at end of file + return pilk.encode(pcm_path, output_path, pcm_rate=24000, tencent=True) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 432d5bb19..5a3e75ee1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,4 +17,4 @@ pyjwt apscheduler docstring_parser aiodocker -silk-python \ No newline at end of file +pilk \ No newline at end of file