fix: add Whisper API support for QQ AMR audio files
* fix: Whisper API对QQ语音amr文件的支持 * Update whisper_api_source.py * fix: cleanup temporary files in Whisper API --------- Co-authored-by: Soulter <905617992@qq.com>
This commit is contained in:
@@ -6,7 +6,10 @@ from openai import NOT_GIVEN, AsyncOpenAI
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.io import download_file
|
||||
from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
|
||||
from astrbot.core.utils.tencent_record_helper import (
|
||||
convert_to_pcm_wav,
|
||||
tencent_silk_to_wav,
|
||||
)
|
||||
|
||||
from ..entities import ProviderType
|
||||
from ..provider import STTProvider
|
||||
@@ -35,18 +38,28 @@ class ProviderOpenAIWhisperAPI(STTProvider):
|
||||
|
||||
self.set_model(provider_config.get("model"))
|
||||
|
||||
async def _is_silk_file(self, file_path):
|
||||
async def _get_audio_format(self, file_path):
|
||||
# 定义要检测的头部字节
|
||||
silk_header = b"SILK"
|
||||
with open(file_path, "rb") as f:
|
||||
file_header = f.read(8)
|
||||
amr_header = b"#!AMR"
|
||||
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
file_header = f.read(8)
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
if silk_header in file_header:
|
||||
return True
|
||||
return False
|
||||
return "silk"
|
||||
|
||||
if amr_header in file_header:
|
||||
return "amr"
|
||||
return None
|
||||
|
||||
async def get_text(self, audio_url: str) -> str:
|
||||
"""Only supports mp3, mp4, mpeg, m4a, wav, webm"""
|
||||
is_tencent = False
|
||||
output_path = None
|
||||
|
||||
if audio_url.startswith("http"):
|
||||
if "multimedia.nt.qq.com.cn" in audio_url:
|
||||
@@ -62,16 +75,35 @@ class ProviderOpenAIWhisperAPI(STTProvider):
|
||||
raise FileNotFoundError(f"文件不存在: {audio_url}")
|
||||
|
||||
if audio_url.endswith(".amr") or audio_url.endswith(".silk") or is_tencent:
|
||||
is_silk = await self._is_silk_file(audio_url)
|
||||
if is_silk:
|
||||
logger.info("Converting silk file to wav ...")
|
||||
file_format = await self._get_audio_format(audio_url)
|
||||
|
||||
# 判断是否需要转换
|
||||
if file_format in ["silk", "amr"]:
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
output_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".wav")
|
||||
await tencent_silk_to_wav(audio_url, output_path)
|
||||
|
||||
if file_format == "silk":
|
||||
logger.info(
|
||||
"Converting silk file to wav using tencent_silk_to_wav..."
|
||||
)
|
||||
await tencent_silk_to_wav(audio_url, output_path)
|
||||
elif file_format == "amr":
|
||||
logger.info(
|
||||
"Converting amr file to wav using convert_to_pcm_wav..."
|
||||
)
|
||||
await convert_to_pcm_wav(audio_url, output_path)
|
||||
|
||||
audio_url = output_path
|
||||
|
||||
result = await self.client.audio.transcriptions.create(
|
||||
model=self.model_name,
|
||||
file=("audio.wav", open(audio_url, "rb")),
|
||||
)
|
||||
|
||||
# remove temp file
|
||||
if output_path and os.path.exists(output_path):
|
||||
try:
|
||||
os.remove(audio_url)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to remove temp file {audio_url}: {e}")
|
||||
return result.text
|
||||
|
||||
Reference in New Issue
Block a user