diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index b2ce1bfe0..a08a94543 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -3,6 +3,7 @@ """ import os + from astrbot.core.utils.astrbot_path import get_astrbot_data_path VERSION = "3.5.17" @@ -12,7 +13,7 @@ DB_PATH = os.path.join(get_astrbot_data_path(), "data_v3.db") DEFAULT_CONFIG = { "config_version": 2, "platform_settings": { - "plugin_enable":[], + "plugin_enable": [], "unique_session": False, "rate_limit": { "time": 60, @@ -976,6 +977,18 @@ CONFIG_METADATA_2 = { "api_base": "https://openspeech.bytedance.com/api/v1/tts", "timeout": 20, }, + "Gemini TTS": { + "id": "gemini_tts", + "type": "gemini_tts", + "provider_type": "text_to_speech", + "enable": False, + "gemini_tts_api_key": "", + "gemini_tts_api_base": "", + "gemini_tts_timeout": 20, + "gemini_tts_model": "gemini-2.5-flash-preview-tts", + "gemini_tts_prefix": "", + "gemini_tts_voice_name": "Leda", + }, "OpenAI Embedding": { "id": "openai_embedding", "type": "openai_embedding", diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 382f469fe..5886b8083 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -1,12 +1,14 @@ -import traceback import asyncio -from astrbot.core.config.astrbot_config import AstrBotConfig -from .provider import Provider, STTProvider, TTSProvider, Personality -from .entities import ProviderType +import traceback from typing import List -from astrbot.core.db import BaseDatabase -from .register import provider_cls_map, llm_tools + from astrbot.core import logger, sp +from astrbot.core.config.astrbot_config import AstrBotConfig +from astrbot.core.db import BaseDatabase + +from .entities import ProviderType +from .provider import Personality, Provider, STTProvider, TTSProvider +from .register import llm_tools, provider_cls_map class ProviderManager: @@ -38,13 +40,11 @@ class ProviderManager: begin_dialogs = [] 
import os
import uuid
import wave

from google import genai
from google.genai import types

from astrbot.core.utils.astrbot_path import get_astrbot_data_path

from ..entities import ProviderType
from ..provider import TTSProvider
from ..register import register_provider_adapter


@register_provider_adapter(
    "gemini_tts", "Gemini TTS API", provider_type=ProviderType.TEXT_TO_SPEECH
)
class ProviderGeminiTTSAPI(TTSProvider):
    """Text-to-speech provider that synthesizes audio via Gemini ``generate_content``.

    Configuration keys read from ``provider_config``:

    * ``gemini_tts_api_key`` -- API key passed to ``genai.Client`` (default ``""``).
    * ``gemini_tts_api_base`` -- optional base URL override; trailing slashes
      are stripped and an empty value leaves the SDK default untouched.
    * ``gemini_tts_timeout`` -- request timeout, converted with ``int()``
      (default ``20``).
    * ``gemini_tts_model`` -- model name (default
      ``"gemini-2.5-flash-preview-tts"``).
    * ``gemini_tts_prefix`` -- optional text prepended to every prompt as
      ``"<prefix>: <text>"``.
    * ``gemini_tts_voice_name`` -- prebuilt voice name (default ``"Leda"``).
    """

    # Parameters written into the WAV header around the returned audio bytes.
    # NOTE(review): presumably matches the raw PCM format Gemini TTS returns
    # (mono, 16-bit, 24 kHz) -- confirm against the Gemini API documentation.
    _CHANNELS = 1
    _SAMPLE_WIDTH = 2
    _FRAME_RATE = 24000

    def __init__(
        self,
        provider_config: dict,
        provider_settings: dict,
    ) -> None:
        """Build the async Gemini client and cache the synthesis settings.

        Args:
            provider_config: Provider-specific settings (see class docstring).
            provider_settings: Global provider settings, forwarded unchanged
                to ``TTSProvider.__init__``.
        """
        super().__init__(provider_config, provider_settings)
        api_key: str = provider_config.get("gemini_tts_api_key", "")
        timeout: int = int(provider_config.get("gemini_tts_timeout", 20))
        # NOTE(review): the configured value is multiplied by 1000, presumably
        # seconds -> milliseconds as expected by HttpOptions; confirm.
        http_options = types.HttpOptions(timeout=timeout * 1000)

        # Strip every trailing slash, and only override the SDK default when
        # something is left (a bare "/" must not become an empty base URL).
        api_base: str = (provider_config.get("gemini_tts_api_base") or "").rstrip("/")
        if api_base:
            http_options.base_url = api_base

        # ``.aio`` is the async client surface; get_audio() awaits its calls.
        self.client = genai.Client(api_key=api_key, http_options=http_options).aio
        self.model: str = provider_config.get(
            "gemini_tts_model", "gemini-2.5-flash-preview-tts"
        )
        self.prefix: str | None = provider_config.get("gemini_tts_prefix")
        self.voice_name: str = provider_config.get("gemini_tts_voice_name", "Leda")

    async def get_audio(self, text: str) -> str:
        """Synthesize ``text`` and return the path of the generated WAV file.

        The file is written to ``<astrbot data path>/temp`` under a unique
        ``gemini_tts_<uuid4>.wav`` name.

        Args:
            text: The text to speak. When a prefix is configured the prompt
                sent to the model is ``"<prefix>: <text>"``.

        Returns:
            Path of the WAV file that was written.

        Raises:
            Exception: If the response carries no audio data, or the output
                file is missing after writing.
        """
        temp_dir = os.path.join(get_astrbot_data_path(), "temp")
        # wave.open() does not create parent directories; make sure the temp
        # directory exists before spending an API call.
        os.makedirs(temp_dir, exist_ok=True)
        path = os.path.join(temp_dir, f"gemini_tts_{uuid.uuid4()}.wav")

        prompt = f"{self.prefix}: {text}" if self.prefix else text
        response = await self.client.models.generate_content(
            model=self.model,
            contents=prompt,
            config=types.GenerateContentConfig(
                response_modalities=["AUDIO"],
                speech_config=types.SpeechConfig(
                    voice_config=types.VoiceConfig(
                        prebuilt_voice_config=types.PrebuiltVoiceConfig(
                            voice_name=self.voice_name,
                        )
                    )
                ),
            ),
        )

        audio = self._extract_audio(response)

        with wave.open(path, "wb") as wf:
            wf.setnchannels(self._CHANNELS)
            wf.setsampwidth(self._SAMPLE_WIDTH)
            wf.setframerate(self._FRAME_RATE)
            wf.writeframes(audio)

        # Defensive check kept from the original implementation.
        if not os.path.exists(path):
            raise Exception(f"Failed to save audio to {path}.")

        return path

    @staticmethod
    def _extract_audio(response: "types.GenerateContentResponse") -> bytes:
        """Return the inline audio bytes of the first part of the first candidate.

        Every level of the response is checked step by step, so a missing or
        empty level results in one clear error and the value is read only once.

        Raises:
            Exception: If any level is missing or the audio data is empty.
        """
        candidates = response.candidates
        content = candidates[0].content if candidates else None
        parts = content.parts if content else None
        inline_data = parts[0].inline_data if parts else None
        data = inline_data.data if inline_data else None
        if not data:
            raise Exception("No audio content returned from Gemini TTS API.")
        return data