Merge pull request #251 from Soulter/feat-tts

适配 OpenAI TTS API,并支持 Napcat,Gewechat,Lagrange 的语音输出
This commit is contained in:
Soulter
2025-01-25 19:51:22 +08:00
committed by GitHub
16 changed files with 281 additions and 96 deletions
+29
View File
@@ -40,6 +40,10 @@ DEFAULT_CONFIG = {
"enable": False,
"provider_id": "",
},
"provider_tts_settings": {
"enable": False,
"provider_id": "",
},
"content_safety": {
"internal_keywords": {"enable": True, "extra_keywords": []},
"baidu_aip": {"enable": False, "app_id": "", "api_key": "", "secret_key": ""},
@@ -371,6 +375,14 @@ CONFIG_METADATA_2 = {
"type": "openai_whisper_selfhost",
"model": "tiny",
},
"openai_tts(API)": {
"id": "openai_tts",
"type": "openai_tts_api",
"enable": False,
"api_key": "",
"api_base": "",
"model": "tts-1",
},
},
"items": {
"whisper_hint": {
@@ -570,6 +582,23 @@ CONFIG_METADATA_2 = {
},
},
},
"provider_tts_settings": {
"description": "文本转语音(TTS)",
"type": "object",
"items": {
"enable": {
"description": "启用文本转语音(TTS)",
"type": "bool",
"hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 openai_tts。",
"obvious_hint": True,
},
"provider_id": {
"description": "提供商 ID,不填则默认第一个TTS提供商",
"type": "string",
"hint": "文本转语音提供商 ID。如果不填写将使用载入的第一个提供商。",
},
},
},
},
},
"misc_config_group": {
-1
View File
@@ -7,7 +7,6 @@ from .event_bus import EventBus
from . import astrbot_config
from asyncio import Queue
from typing import List
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
from astrbot.core.star import PluginManager
from astrbot.core.platform.manager import PlatformManager
+1 -1
View File
@@ -123,7 +123,7 @@ class Record(BaseMessageComponent):
proxy: T.Optional[bool] = True
timeout: T.Optional[int] = 0
# 额外
path: T.Optional[str] # 用这个
path: T.Optional[str]
def __init__(self, file: T.Optional[str], **_):
for k in _.keys():
+6 -1
View File
@@ -139,7 +139,7 @@ class MessageEventResult(MessageChain):
'''
return self.result_type == EventResultType.STOP
def set_result_content_type(self, typ: EventResultType) -> 'MessageEventResult':
def set_result_content_type(self, typ: ResultContentType) -> 'MessageEventResult':
'''设置事件处理的结果类型。
Args:
@@ -148,5 +148,10 @@ class MessageEventResult(MessageChain):
self.result_content_type = typ
return self
def is_llm_result(self) -> bool:
'''是否为 LLM 结果。
'''
return self.result_content_type == ResultContentType.LLM_RESULT
CommandResult = MessageEventResult
+23 -2
View File
@@ -1,11 +1,12 @@
import time
import traceback
from typing import Union, AsyncGenerator
from ..stage import register_stage
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.message_type import MessageType
from astrbot.core import logger
from astrbot.core.message.components import Plain, Image, At, Reply
from astrbot.core.message.components import Plain, Image, At, Reply, Record
from astrbot.core import html_renderer
from astrbot.core.star.star_handler import star_handlers_registry, EventType
@@ -16,6 +17,7 @@ class ResultDecorateStage:
self.reply_prefix = ctx.astrbot_config['platform_settings']['reply_prefix']
self.reply_with_mention = ctx.astrbot_config['platform_settings']['reply_with_mention']
self.reply_with_quote = ctx.astrbot_config['platform_settings']['reply_with_quote']
self.use_tts = ctx.astrbot_config['provider_tts_settings']['enable']
self.t2i = ctx.astrbot_config['t2i']
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
@@ -32,9 +34,28 @@ class ResultDecorateStage:
# 回复前缀
if self.reply_prefix:
result.chain.insert(0, Plain(self.reply_prefix))
# TTS
if self.use_tts and result.is_llm_result():
tts_provider = self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst
plain_str = ""
for comp in result.chain:
if isinstance(comp, Plain):
plain_str += " " + comp.text
else:
break
if plain_str:
try:
audio_path = await tts_provider.get_audio(plain_str)
logger.info("TTS 结果: " + audio_path)
if audio_path:
result.chain = [Record(file=audio_path, url=audio_path)]
except BaseException:
traceback.print_exc()
logger.error("TTS 失败,使用文本发送。")
# 文本转图片
if (result.use_t2i_ is None and self.t2i) or result.use_t2i_:
elif (result.use_t2i_ is None and self.t2i) or result.use_t2i_:
plain_str = ""
for comp in result.chain:
if isinstance(comp, Plain):
@@ -3,7 +3,7 @@ import random
import asyncio
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import Plain, Image
from astrbot.api.message_components import Plain, Image, Record
from aiocqhttp import CQHttp
from astrbot.core.utils.io import file_to_base64, download_image_by_url
@@ -20,16 +20,18 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
d = segment.toDict()
if isinstance(segment, Plain):
d['type'] = 'text'
if isinstance(segment, Image):
if isinstance(segment, (Image, Record)):
# convert to base64
if segment.file and segment.file.startswith("file:///"):
image_base64 = file_to_base64(segment.file[8:])
bs64_data = file_to_base64(segment.file[8:])
image_file_path = segment.file[8:]
elif segment.file and segment.file.startswith("http"):
image_file_path = await download_image_by_url(segment.file)
image_base64 = file_to_base64(image_file_path)
bs64_data = file_to_base64(image_file_path)
else:
bs64_data = file_to_base64(segment.file)
d['data'] = {
'file': image_base64,
'file': bs64_data,
}
ret.append(d)
return ret
@@ -163,7 +163,7 @@ class SimpleGewechatClient():
return quart.jsonify({"r": "AstrBot ACK"})
async def handle_file(self, file_id):
file_path = f"data/temp/{file_id}.jpg"
file_path = f"data/temp/{file_id}"
return await quart.send_file(file_path)
async def _set_callback_url(self):
@@ -185,11 +185,7 @@ class SimpleGewechatClient():
logger.info(f"将在 {self.callback_url} 上接收 gewechat 下发的消息。如果一直没收到消息请先尝试重启 AstrBot。")
async def start_polling(self):
# 设置回调
threading.Thread(target=asyncio.run, args=(self._set_callback_url(),)).start()
await self.server.run_task(
host=self.host,
port=self.port,
@@ -327,4 +323,23 @@ class SimpleGewechatClient():
json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送图片结果: {json_blob}")
logger.debug(f"发送图片结果: {json_blob}")
async def post_voice(self, to_wxid, voice_url: str, voice_duration: int):
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"voiceUrl": voice_url,
"voiceDuration": voice_duration
}
logger.debug(f"发送语音: {payload}")
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postVoice",
headers=self.headers,
json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送语音结果: {json_blob}")
@@ -1,13 +1,24 @@
import random
import asyncio
import wave
import uuid
import os
from astrbot.core.utils.io import save_temp_img, download_image_by_url
from astrbot.core.utils.io import save_temp_img, download_image_by_url, download_file
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image
from astrbot.api.message_components import Plain, Image, Record
from .client import SimpleGewechatClient
def get_wav_duration(file_path):
with wave.open(file_path, 'rb') as wav_file:
file_size = os.path.getsize(file_path)
n_channels, sampwidth, framerate, n_frames = wav_file.getparams()[:4]
if n_frames == 2147483647:
duration = (file_size - 44) / (n_channels * sampwidth * framerate)
else:
duration = n_frames / float(framerate)
return duration
class GewechatPlatformEvent(AstrMessageEvent):
def __init__(
self,
@@ -39,18 +50,53 @@ class GewechatPlatformEvent(AstrMessageEvent):
img_url = comp.file
img_path = ""
if img_url.startswith("file:///"):
with open(comp.file[8:], "rb") as f:
img_path = save_temp_img(f.read())
img_path = img_url[8:]
elif comp.file and comp.file.startswith("http"):
img_path = await download_image_by_url(comp.file)
else:
img_path = img_url
if not img_path:
logger.error("无法获取到图片路径。")
return
# 检查 record_path 是否在 data/temp 目录中, record_path 可能是绝对路径
temp_directory = os.path.abspath('data/temp')
img_path = os.path.abspath(img_path)
if os.path.commonpath([temp_directory, img_path]) != temp_directory:
with open(img_path, "rb") as f:
img_path = save_temp_img(f.read())
file_id = os.path.basename(img_path).split(".")[0]
file_id = os.path.basename(img_path)
img_url = f"{self.client.file_server_url}/{file_id}"
logger.debug(f"gewe callback img url: {img_url}")
await self.client.post_image(to_wxid, img_url)
elif isinstance(comp, Record):
# 默认已经存在 data/temp 中
record_url = comp.file
record_path = ""
if record_url.startswith("file:///"):
record_path = record_url[8:]
elif record_url.startswith("http"):
await download_file(record_url, f"data/temp/{uuid.uuid4()}.wav")
else:
record_path = record_url
silk_path = f"data/temp/{uuid.uuid4()}.silk"
duration = await wav_to_tencent_silk(record_path, silk_path)
print(f"duration: {duration}, {silk_path}")
# 检查 record_path 是否在 data/temp 目录中, record_path 可能是绝对路径
# temp_directory = os.path.abspath('data/temp')
# record_path = os.path.abspath(record_path)
# if os.path.commonpath([temp_directory, record_path]) != temp_directory:
# with open(record_path, "rb") as f:
# record_path = f"data/temp/{uuid.uuid4()}.wav"
# with open(record_path, "wb") as f2:
# f2.write(f.read())
if duration == 0:
duration = get_wav_duration(record_path)
file_id = os.path.basename(silk_path)
record_url = f"{self.client.file_server_url}/{file_id}"
await self.client.post_voice(to_wxid, record_url, duration*1000)
await super().send(message)
@@ -80,4 +80,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
elif i.file and i.file.startswith("http"):
image_file_path = await download_image_by_url(i.file)
image_base64 = file_to_base64(image_file_path).replace("base64://", "")
else:
image_base64 = file_to_base64(i.file).replace("base64://", "")
image_file_path = i.file
return plain_text, image_base64, image_file_path
@@ -32,6 +32,10 @@ class WebChatMessageEvent(AstrMessageEvent):
f.write(f2.read())
elif comp.file and comp.file.startswith("http"):
await download_image_by_url(comp.file, path=path)
else:
with open(path, "wb") as f:
with open(comp.file, "rb") as f2:
f.write(f2.read())
web_chat_back_queue.put_nowait((f"[IMAGE]{filename}", cid))
web_chat_back_queue.put_nowait(None)
await super().send(message)
+30 -3
View File
@@ -1,6 +1,6 @@
import traceback
from astrbot.core.config.astrbot_config import AstrBotConfig
from .provider import Provider, STTProvider, Personality
from .provider import Provider, STTProvider, TTSProvider, Personality
from .entites import ProviderType
from typing import List
from astrbot.core.db import BaseDatabase
@@ -64,11 +64,15 @@ class ProviderManager():
'''加载的 Provider 的实例'''
self.stt_provider_insts: List[STTProvider] = []
'''加载的 Speech To Text Provider 的实例'''
self.tts_provider_insts: List[TTSProvider] = []
'''加载的 Text To Speech Provider 的实例'''
self.llm_tools = llm_tools
self.curr_provider_inst: Provider = None
'''当前使用的 Provider 实例'''
self.curr_stt_provider_inst: STTProvider = None
'''当前使用的 Speech To Text Provider 实例'''
self.curr_tts_provider_inst: TTSProvider = None
'''当前使用的 Text To Speech Provider 实例'''
self.loaded_ids = defaultdict(bool)
self.db_helper = db_helper
@@ -103,6 +107,8 @@ class ProviderManager():
from .sources.whisper_api_source import ProviderOpenAIWhisperAPI # noqa: F401
case "openai_whisper_selfhost":
from .sources.whisper_selfhosted_source import ProviderOpenAIWhisperSelfHost # noqa: F401
case "openai_tts_api":
from .sources.openai_tts_api_source import ProviderOpenAITTSAPI # noqa: F401
except (ImportError, ModuleNotFoundError) as e:
logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。")
continue
@@ -119,8 +125,10 @@ class ProviderManager():
continue
selected_provider_id = sp.get("curr_provider")
selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
selected_tts_provider_id = self.provider_settings.get("provider_id")
provider_enabled = self.provider_settings.get("enable", False)
stt_enabled = self.provider_stt_settings.get("enable", False)
tts_enabled = self.provider_settings.get("enable", False)
provider_metadata = provider_cls_map[provider_config['type']]
logger.info(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器 ...")
@@ -138,6 +146,18 @@ class ProviderManager():
if selected_stt_provider_id == provider_config['id'] and stt_enabled:
self.curr_stt_provider_inst = inst
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。")
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
# TTS 任务
inst = provider_metadata.cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.tts_provider_insts.append(inst)
if selected_tts_provider_id == provider_config['id'] and tts_enabled:
self.curr_tts_provider_inst = inst
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。")
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
# 文本生成任务
@@ -167,11 +187,18 @@ class ProviderManager():
if len(self.stt_provider_insts) > 0 and not self.curr_stt_provider_inst and stt_enabled:
self.curr_stt_provider_inst = self.stt_provider_insts[0]
if len(self.tts_provider_insts) > 0 and not self.curr_tts_provider_inst and tts_enabled:
self.curr_tts_provider_inst = self.tts_provider_insts[0]
if not self.curr_provider_inst:
logger.warning("未启用任何用于 文本生成 的提供商适配器。")
if self.provider_stt_settings.get("enable"):
if not self.curr_stt_provider_inst:
if stt_enabled and not self.curr_stt_provider_inst:
logger.warning("未启用任何用于 语音转文本 的提供商适配器。")
if tts_enabled and not self.curr_tts_provider_inst:
logger.warning("未启用任何用于 文本转语音 的提供商适配器。")
def get_insts(self):
return self.provider_insts
+39 -37
View File
@@ -24,9 +24,32 @@ class ProviderMeta():
id: str
model: str
type: str
class AbstractProvider(abc.ABC):
def __init__(self, provider_config: dict) -> None:
super().__init__()
self.model_name = ""
self.provider_config = provider_config
def set_model(self, model_name: str):
'''设置当前使用的模型名称'''
self.model_name = model_name
def get_model(self) -> str:
'''获得当前使用的模型名称'''
return self.model_name
def meta(self) -> ProviderMeta:
'''获取 Provider 的元数据'''
return ProviderMeta(
id=self.provider_config['id'],
model=self.get_model(),
type=self.provider_config['type']
)
class Provider(abc.ABC):
class Provider(AbstractProvider):
def __init__(
self,
provider_config: dict,
@@ -35,14 +58,11 @@ class Provider(abc.ABC):
db_helper: BaseDatabase = None,
default_persona: Personality = None
) -> None:
self.model_name = ""
'''当前使用的模型名称'''
super().__init__(provider_config)
self.session_memory = defaultdict(list)
'''维护了 session_id 的上下文,**不包含 system 指令**。'''
self.provider_config = provider_config
self.provider_settings = provider_settings
self.curr_personality: Personality = default_persona
@@ -58,14 +78,6 @@ class Provider(abc.ABC):
self.session_memory[history.session_id] = json.loads(history.content)
except BaseException as e:
logger.warning(f"读取 LLM 对话历史记录 失败:{e}。仍可正常使用。")
def set_model(self, model_name: str):
'''设置当前使用的模型名称'''
self.model_name = model_name
def get_model(self) -> str:
'''获得当前使用的模型名称'''
return self.model_name
@abc.abstractmethod
def get_current_key(self) -> str:
@@ -133,17 +145,11 @@ class Provider(abc.ABC):
'''重置某一个 session_id 的上下文'''
raise NotImplementedError()
def meta(self) -> ProviderMeta:
'''获取 Provider 的元数据'''
return ProviderMeta(
id=self.provider_config['id'],
model=self.get_model(),
type=self.provider_config['type']
)
class STTProvider():
class STTProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config)
self.provider_config = provider_config
self.provider_settings = provider_settings
@@ -151,19 +157,15 @@ class STTProvider():
async def get_text(self, audio_url: str) -> str:
'''获取音频的文本'''
raise NotImplementedError()
class TTSProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config)
self.provider_config = provider_config
self.provider_settings = provider_settings
def set_model(self, model_name: str):
'''设置当前使用的模型名称'''
self.model_name = model_name
def get_model(self) -> str:
'''获取当前使用的模型'''
return self.provider_config.get("model", "")
def meta(self) -> ProviderMeta:
'''获取 Provider 的元数据'''
return ProviderMeta(
id=self.provider_config['id'],
model=self.get_model(),
type=self.provider_config['type']
)
@abc.abstractmethod
async def get_audio(self, text: str) -> str:
'''获取文本的音频,返回音频文件路径'''
raise NotImplementedError()
@@ -1,4 +1,3 @@
import traceback
import base64
import json
@@ -110,7 +109,7 @@ class ProviderOpenAIOfficial(Provider):
)
assert isinstance(completion, ChatCompletion)
logger.debug(f"completion: {completion.usage}")
logger.debug(f"completion: {completion}")
if len(completion.choices) == 0:
raise Exception("API 返回的 completion 为空。")
@@ -0,0 +1,40 @@
import uuid
import os
from openai import AsyncOpenAI, NOT_GIVEN
from ..provider import TTSProvider
from ..entites import ProviderType
from ..register import register_provider_adapter
@register_provider_adapter("openai_tts_api", "OpenAI TTS API", provider_type=ProviderType.TEXT_TO_SPEECH)
class ProviderOpenAITTSAPI(TTSProvider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
self.chosen_api_key = provider_config.get("api_key", "")
self.voice = provider_config.get("voice", "alloy")
self.client = AsyncOpenAI(
api_key=self.chosen_api_key,
base_url=provider_config.get("api_base", None),
timeout=provider_config.get("timeout", NOT_GIVEN),
)
self.set_model(provider_config.get("model", None))
async def get_audio(self, text: str) -> str:
path = f'data/temp/openai_tts_api_{uuid.uuid4()}.wav'
async with self.client.audio.speech.with_streaming_response.create(
model=self.model_name,
voice=self.voice,
response_format='wav',
input=text
) as response:
with open(path, 'wb') as f:
async for chunk in response.iter_bytes(chunk_size=1024):
f.write(chunk)
return path
+19 -26
View File
@@ -2,36 +2,29 @@ import wave
from io import BytesIO
async def tencent_silk_to_wav(silk_path: str, output_path: str) -> str:
import pysilk
import pilk
with open(silk_path, "rb") as f:
input_data = f.read()
if input_data.startswith(b'\x02'):
input_data = input_data[1:]
input_io = BytesIO(input_data)
output_io = BytesIO()
pysilk.decode(input_io, output_io, 24000)
output_io.seek(0)
with wave.open(output_path, 'wb') as wav:
wav.setnchannels(1)
wav.setsampwidth(2)
wav.setframerate(24000)
wav.writeframes(output_io.read())
pcm_path = f"{output_path}.pcm"
pilk.decode(silk_path, pcm_path)
with open(pcm_path, "rb") as pcm:
with wave.open(output_path, 'wb') as wav:
wav.setnchannels(1)
wav.setsampwidth(2)
wav.setframerate(24000)
wav.writeframes(pcm.read())
return output_path
async def wav_to_tencent_silk(wav_path: str) -> BytesIO:
import pysilk
async def wav_to_tencent_silk(wav_path: str, output_path: str) -> int:
'''返回 duration'''
import pilk
# wav to pcm
with wave.open(wav_path, 'rb') as wav:
wav_data = wav.readframes(wav.getnframes())
wav_data = BytesIO(wav_data)
output_io = BytesIO()
pysilk.encode(wav_data, output_io, 24000)
output_io.seek(0)
pcm_path = f"{wav_path}.pcm"
with open(pcm_path, "wb") as f:
f.write(wav.readframes(wav.getnframes()))
# 在首字节添加 \x02
silk_data = output_io.read()
silk_data_with_prefix = b'\x02' + silk_data
return BytesIO(silk_data_with_prefix)
return pilk.encode(pcm_path, output_path, pcm_rate=24000, tencent=True)
+1 -1
View File
@@ -17,4 +17,4 @@ pyjwt
apscheduler
docstring_parser
aiodocker
silk-python
pilk