From baae842210d81996b59f1b6a87849dbcac66ef97 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Fri, 24 Jan 2025 13:41:13 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20napcat=20=E4=B8=8B=E8=AF=AD=E9=9F=B3?= =?UTF-8?q?=E6=B6=88=E6=81=AF=E6=8E=A5=E6=94=B6=E5=BC=82=E5=B8=B8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/pipeline/preprocess_stage/stage.py | 9 +---- .../provider/sources/whisper_api_source.py | 35 ++---------------- .../sources/whisper_selfhosted_source.py | 37 +++---------------- astrbot/core/utils/tencent_record_helper.py | 37 +++++++++++++++++++ 4 files changed, 48 insertions(+), 70 deletions(-) create mode 100644 astrbot/core/utils/tencent_record_helper.py diff --git a/astrbot/core/pipeline/preprocess_stage/stage.py b/astrbot/core/pipeline/preprocess_stage/stage.py index b55a1a82d..64242010b 100644 --- a/astrbot/core/pipeline/preprocess_stage/stage.py +++ b/astrbot/core/pipeline/preprocess_stage/stage.py @@ -29,13 +29,8 @@ class PreProcessStage(Stage): message_chain = event.get_messages() for idx, component in enumerate(message_chain): if isinstance(component, Record) and component.url: - - path = component.url - - path.removeprefix("file:///") - + path = component.url.removeprefix("file://") retry = 5 - for i in range(retry): try: result = await stt_provider.get_text(audio_url=path) @@ -48,7 +43,7 @@ class PreProcessStage(Stage): except FileNotFoundError as e: # napcat workaround logger.warning(e) - logger.warning(f"语音文件不存在: {path}, 重试中: {i + 1}/{retry}") + logger.warning(f"重试中: {i + 1}/{retry}") await asyncio.sleep(0.5) continue except BaseException as e: diff --git a/astrbot/core/provider/sources/whisper_api_source.py b/astrbot/core/provider/sources/whisper_api_source.py index 8ff33fa9b..3190c042e 100644 --- a/astrbot/core/provider/sources/whisper_api_source.py +++ b/astrbot/core/provider/sources/whisper_api_source.py @@ -1,12 +1,12 @@ import uuid import os -import io from openai import AsyncOpenAI, NOT_GIVEN from ..provider import STTProvider from ..entites import ProviderType from astrbot.core.utils.io import download_file from ..register import register_provider_adapter from astrbot.core import logger +from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav @register_provider_adapter("openai_whisper_api", "OpenAI Whisper API", provider_type=ProviderType.SPEECH_TO_TEXT) class ProviderOpenAIWhisperAPI(STTProvider): @@ -33,34 +33,6 @@ class ProviderOpenAIWhisperAPI(STTProvider): output_path = ff.convert(path, os.path.join('data/temp', filename)) return output_path - async def _pcm_to_wav(self, input_io: io.BytesIO, output_path: str) -> str: - import wave - - with wave.open(output_path, 'wb') as wav: - wav.setnchannels(1) - wav.setsampwidth(2) - wav.setframerate(24000) - wav.writeframes(input_io.read()) - - return output_path - - async def _convert_silk(self, path: str) -> str: - import pysilk - filename = str(uuid.uuid4()) + '.wav' - output_path = os.path.join('data/temp', filename) - with open(path, "rb") as f: - input_data = f.read() - if input_data.startswith(b'\x02'): - # tencent 我爱你 - input_data = input_data[1:] - input_io = io.BytesIO(input_data) - output_io = io.BytesIO() - pysilk.decode(input_io, output_io, 24000) - output_io.seek(0) - await self._pcm_to_wav(output_io, output_path) - - return output_path - async def _is_silk_file(self, file_path): silk_header = b"SILK" with open(file_path, "rb") as f: @@ -91,8 +63,9 @@ class ProviderOpenAIWhisperAPI(STTProvider): is_silk = await self._is_silk_file(audio_url) if is_silk: logger.info("Converting silk file to wav ...") - audio_url = await self._convert_silk(audio_url) - + output_path = os.path.join('data/temp', str(uuid.uuid4()) + '.wav') + await tencent_silk_to_wav(audio_url, output_path) + audio_url = output_path result = await self.client.audio.transcriptions.create( model=self.model_name, diff --git a/astrbot/core/provider/sources/whisper_selfhosted_source.py b/astrbot/core/provider/sources/whisper_selfhosted_source.py index 94c9a2be2..6b95a57b8 100644 --- a/astrbot/core/provider/sources/whisper_selfhosted_source.py +++ b/astrbot/core/provider/sources/whisper_selfhosted_source.py @@ -1,6 +1,5 @@ import uuid import os -import io import asyncio import whisper from ..provider import STTProvider @@ -8,7 +7,7 @@ from ..entites import ProviderType from astrbot.core.utils.io import download_file from ..register import register_provider_adapter from astrbot.core import logger - +from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav @register_provider_adapter("openai_whisper_selfhost", "OpenAI Whisper 模型部署", provider_type=ProviderType.SPEECH_TO_TEXT) class ProviderOpenAIWhisperSelfHost(STTProvider): @@ -34,34 +33,6 @@ class ProviderOpenAIWhisperSelfHost(STTProvider): output_path = ff.convert(path, os.path.join('data/temp', filename)) return output_path - async def _pcm_to_wav(self, input_io: io.BytesIO, output_path: str) -> str: - import wave - - with wave.open(output_path, 'wb') as wav: - wav.setnchannels(1) - wav.setsampwidth(2) - wav.setframerate(24000) - wav.writeframes(input_io.read()) - - return output_path - - async def _convert_silk(self, path: str) -> str: - import pysilk - filename = str(uuid.uuid4()) + '.wav' - output_path = os.path.join('data/temp', filename) - with open(path, "rb") as f: - input_data = f.read() - if input_data.startswith(b'\x02'): - # tencent 我爱你 - input_data = input_data[1:] - input_io = io.BytesIO(input_data) - output_io = io.BytesIO() - pysilk.decode(input_io, output_io, 24000) - output_io.seek(0) - await self._pcm_to_wav(output_io, output_path) - - return output_path - async def _is_silk_file(self, file_path): silk_header = b"SILK" with open(file_path, "rb") as f: @@ -93,7 +64,9 @@ class ProviderOpenAIWhisperSelfHost(STTProvider): is_silk = await self._is_silk_file(audio_url) if is_silk: logger.info("Converting silk file to wav ...") - audio_url = await self._convert_silk(audio_url) - + output_path = os.path.join('data/temp', str(uuid.uuid4()) + '.wav') + await tencent_silk_to_wav(audio_url, output_path) + audio_url = output_path + result = await loop.run_in_executor(None, self.model.transcribe, audio_url) return result['text'] \ No newline at end of file diff --git a/astrbot/core/utils/tencent_record_helper.py b/astrbot/core/utils/tencent_record_helper.py new file mode 100644 index 000000000..6f75aaa86 --- /dev/null +++ b/astrbot/core/utils/tencent_record_helper.py @@ -0,0 +1,37 @@ +import wave +from io import BytesIO + +async def tencent_silk_to_wav(silk_path: str, output_path: str) -> str: + import pysilk + + with open(silk_path, "rb") as f: + input_data = f.read() + if input_data.startswith(b'\x02'): + input_data = input_data[1:] + input_io = BytesIO(input_data) + output_io = BytesIO() + pysilk.decode(input_io, output_io, 24000) + output_io.seek(0) + with wave.open(output_path, 'wb') as wav: + wav.setnchannels(1) + wav.setsampwidth(2) + wav.setframerate(24000) + wav.writeframes(output_io.read()) + + return output_path + +async def wav_to_tencent_silk(wav_path: str) -> BytesIO: + import pysilk + + with wave.open(wav_path, 'rb') as wav: + wav_data = wav.readframes(wav.getnframes()) + wav_data = BytesIO(wav_data) + output_io = BytesIO() + pysilk.encode(wav_data, output_io, 24000) + output_io.seek(0) + + # 在首字节添加 \x02 + silk_data = output_io.read() + silk_data_with_prefix = b'\x02' + silk_data + + return BytesIO(silk_data_with_prefix) \ No newline at end of file