feat: 增加Gemini TTS API实现
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "3.5.17"
|
||||
@@ -12,7 +13,7 @@ DB_PATH = os.path.join(get_astrbot_data_path(), "data_v3.db")
|
||||
DEFAULT_CONFIG = {
|
||||
"config_version": 2,
|
||||
"platform_settings": {
|
||||
"plugin_enable":[],
|
||||
"plugin_enable": [],
|
||||
"unique_session": False,
|
||||
"rate_limit": {
|
||||
"time": 60,
|
||||
@@ -976,6 +977,18 @@ CONFIG_METADATA_2 = {
|
||||
"api_base": "https://openspeech.bytedance.com/api/v1/tts",
|
||||
"timeout": 20,
|
||||
},
|
||||
"Gemini TTS": {
|
||||
"id": "gemini_tts",
|
||||
"type": "gemini_tts",
|
||||
"provider_type": "text_to_speech",
|
||||
"enable": False,
|
||||
"gemini_tts_api_key": "",
|
||||
"gemini_tts_api_base": "",
|
||||
"gemini_tts_timeout": 20,
|
||||
"gemini_tts_model": "gemini-2.5-flash-preview-tts",
|
||||
"gemini_tts_prefix": "",
|
||||
"gemini_tts_voice_name": "Leda",
|
||||
},
|
||||
"OpenAI Embedding": {
|
||||
"id": "openai_embedding",
|
||||
"type": "openai_embedding",
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
import traceback
|
||||
import asyncio
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from .provider import Provider, STTProvider, TTSProvider, Personality
|
||||
from .entities import ProviderType
|
||||
import traceback
|
||||
from typing import List
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from .register import provider_cls_map, llm_tools
|
||||
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.db import BaseDatabase
|
||||
|
||||
from .entities import ProviderType
|
||||
from .provider import Personality, Provider, STTProvider, TTSProvider
|
||||
from .register import llm_tools, provider_cls_map
|
||||
|
||||
|
||||
class ProviderManager:
|
||||
@@ -38,13 +40,11 @@ class ProviderManager:
|
||||
begin_dialogs = []
|
||||
user_turn = True
|
||||
for dialog in begin_dialogs:
|
||||
bd_processed.append(
|
||||
{
|
||||
"role": "user" if user_turn else "assistant",
|
||||
"content": dialog,
|
||||
"_no_save": None, # 不持久化到 db
|
||||
}
|
||||
)
|
||||
bd_processed.append({
|
||||
"role": "user" if user_turn else "assistant",
|
||||
"content": dialog,
|
||||
"_no_save": None, # 不持久化到 db
|
||||
})
|
||||
user_turn = not user_turn
|
||||
if mood_imitation_dialogs:
|
||||
if len(mood_imitation_dialogs) % 2 != 0:
|
||||
@@ -253,6 +253,10 @@ class ProviderManager:
|
||||
from .sources.volcengine_tts import (
|
||||
ProviderVolcengineTTS as ProviderVolcengineTTS,
|
||||
)
|
||||
case "gemini_tts":
|
||||
from .sources.gemini_tts_source import (
|
||||
ProviderGeminiTTSAPI as ProviderGeminiTTSAPI,
|
||||
)
|
||||
case "openai_embedding":
|
||||
from .sources.openai_embedding_source import (
|
||||
OpenAIEmbeddingProvider as OpenAIEmbeddingProvider,
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
import os
|
||||
import uuid
|
||||
import wave
|
||||
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
from ..entities import ProviderType
|
||||
from ..provider import TTSProvider
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
|
||||
@register_provider_adapter(
    "gemini_tts", "Gemini TTS API", provider_type=ProviderType.TEXT_TO_SPEECH
)
class ProviderGeminiTTSAPI(TTSProvider):
    """Text-to-speech provider that synthesizes audio through the Gemini API.

    Configuration keys read from ``provider_config``:

    * ``gemini_tts_api_key``    - API key handed to ``genai.Client``.
    * ``gemini_tts_api_base``   - optional base URL override; one trailing
      ``/`` is removed before use.
    * ``gemini_tts_timeout``    - request timeout (default ``20``); it is
      multiplied by 1000 before being passed to ``types.HttpOptions``.
    * ``gemini_tts_model``      - model name
      (default ``gemini-2.5-flash-preview-tts``).
    * ``gemini_tts_prefix``     - optional text prepended to every prompt as
      ``"<prefix>: <text>"``.
    * ``gemini_tts_voice_name`` - prebuilt voice name (default ``Leda``).
    """

    # WAV header parameters used when wrapping the returned audio bytes.
    # NOTE(review): these mirror the PCM format in Google's speech-generation
    # example (mono, 16-bit, 24 kHz) - confirm if the model/output changes.
    _WAV_CHANNELS = 1
    _WAV_SAMPLE_WIDTH = 2
    _WAV_FRAME_RATE = 24000

    def __init__(
        self,
        provider_config: dict,
        provider_settings: dict,
    ) -> None:
        """Build the async Gemini client and cache the synthesis settings.

        Args:
            provider_config: Provider-specific settings (see class docstring).
            provider_settings: Global provider settings, forwarded unchanged
                to ``TTSProvider.__init__``.
        """
        super().__init__(provider_config, provider_settings)
        api_key: str = provider_config.get("gemini_tts_api_key", "")
        api_base: str | None = provider_config.get("gemini_tts_api_base")
        timeout: int = int(provider_config.get("gemini_tts_timeout", 20))
        # The configured value is scaled by 1000 for HttpOptions
        # (presumably seconds -> milliseconds; see the google-genai docs).
        http_options = types.HttpOptions(timeout=timeout * 1000)

        if api_base:
            # Drop a single trailing slash so the base URL is normalized.
            http_options.base_url = api_base.removesuffix("/")

        # ``.aio`` is the client surface whose calls are awaited below.
        self.client = genai.Client(api_key=api_key, http_options=http_options).aio
        self.model: str = provider_config.get(
            "gemini_tts_model", "gemini-2.5-flash-preview-tts"
        )
        self.prefix: str | None = provider_config.get(
            "gemini_tts_prefix",
        )
        self.voice_name: str = provider_config.get("gemini_tts_voice_name", "Leda")

    @staticmethod
    def _extract_audio_data(response) -> bytes:
        """Return the inline audio bytes of the first part of the first candidate.

        Every level of the response is checked for emptiness before it is
        dereferenced, which also narrows the optional fields for the type
        checker.

        Raises:
            Exception: If any level is missing or the audio payload is empty.
        """
        candidates = response.candidates
        content = candidates[0].content if candidates else None
        parts = content.parts if content else None
        inline_data = parts[0].inline_data if parts else None
        audio = inline_data.data if inline_data else None
        if not audio:
            raise Exception("No audio content returned from Gemini TTS API.")
        return audio

    async def get_audio(self, text: str) -> str:
        """Synthesize ``text`` and return the path of the generated WAV file.

        The file is written to ``<astrbot data path>/temp`` under a unique
        ``gemini_tts_<uuid4>.wav`` name.

        Args:
            text: Text to speak. When a prefix is configured the prompt sent
                to the model is ``"<prefix>: <text>"``.

        Returns:
            Path of the WAV file that was written.

        Raises:
            Exception: If the API returns no audio, or the file is missing
                after writing.
        """
        temp_dir = os.path.join(get_astrbot_data_path(), "temp")
        # Make sure the target directory exists; wave.open() would otherwise
        # fail with FileNotFoundError on a fresh data directory.
        os.makedirs(temp_dir, exist_ok=True)
        path = os.path.join(temp_dir, f"gemini_tts_{uuid.uuid4()}.wav")

        prompt = f"{self.prefix}: {text}" if self.prefix else text

        response = await self.client.models.generate_content(
            model=self.model,
            contents=prompt,
            config=types.GenerateContentConfig(
                response_modalities=["AUDIO"],
                speech_config=types.SpeechConfig(
                    voice_config=types.VoiceConfig(
                        prebuilt_voice_config=types.PrebuiltVoiceConfig(
                            voice_name=self.voice_name,
                        )
                    )
                ),
            ),
        )

        audio = self._extract_audio_data(response)

        # Wrap the returned audio bytes in a WAV container.
        with wave.open(path, "wb") as wf:
            wf.setnchannels(self._WAV_CHANNELS)
            wf.setsampwidth(self._WAV_SAMPLE_WIDTH)
            wf.setframerate(self._WAV_FRAME_RATE)
            wf.writeframes(audio)

        if not os.path.exists(path):
            raise Exception(f"Failed to save audio to {path}.")

        return path
|
||||
Reference in New Issue
Block a user