From be662b913c4b01643132aa3c70fc6bd988d349a0 Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sat, 11 Jan 2025 17:19:28 +0800
Subject: [PATCH 1/6] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=20Whisper=20STT?=
=?UTF-8?q?=EF=BC=8C=E5=B9=B6=E9=80=82=E9=85=8D=20Tencent=20=E8=AF=AD?=
=?UTF-8?q?=E9=9F=B3=E6=A0=BC=E5=BC=8F?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
astrbot/api/provider/__init__.py | 4 +-
astrbot/core/config/default.py | 32 ++++++-
astrbot/core/message/components.py | 2 +-
astrbot/core/pipeline/__init__.py | 3 +
.../core/pipeline/preprocess_stage/stage.py | 54 +++++++++++
astrbot/core/provider/__init__.py | 3 +-
astrbot/core/provider/entites.py | 17 +++-
astrbot/core/provider/manager.py | 47 +++++++--
astrbot/core/provider/provider.py | 27 ++++++
astrbot/core/provider/register.py | 16 +++-
.../provider/sources/whisper_api_source.py | 95 +++++++++++++++++++
astrbot/core/star/context.py | 12 +--
dashboard/package.json | 2 -
requirements.txt | 3 +-
14 files changed, 284 insertions(+), 33 deletions(-)
create mode 100644 astrbot/core/pipeline/preprocess_stage/stage.py
create mode 100644 astrbot/core/provider/sources/whisper_api_source.py
diff --git a/astrbot/api/provider/__init__.py b/astrbot/api/provider/__init__.py
index 377f8d4b3..17e379478 100644
--- a/astrbot/api/provider/__init__.py
+++ b/astrbot/api/provider/__init__.py
@@ -1,2 +1,2 @@
-from astrbot.core.provider import Provider, Personality, ProviderMetaData
-from astrbot.core.provider.entites import ProviderRequest
\ No newline at end of file
+from astrbot.core.provider import Provider, STTProvider, Personality
+from astrbot.core.provider.entites import ProviderRequest, ProviderType, ProviderMetaData
\ No newline at end of file
diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py
index 43a010d45..983ca3cd3 100644
--- a/astrbot/core/config/default.py
+++ b/astrbot/core/config/default.py
@@ -33,6 +33,10 @@ DEFAULT_CONFIG = {
"default_personality": "如果用户寻求帮助或者打招呼,请告诉他可以用 /help 查看 AstrBot 帮助。",
"prompt_prefix": "",
},
+ "provider_stt_settings": {
+ "enable": False,
+ "provider_id": "",
+ },
"content_safety": {
"internal_keywords": {"enable": True, "extra_keywords": []},
"baidu_aip": {"enable": False, "app_id": "", "api_key": "", "secret_key": ""},
@@ -315,6 +319,14 @@ CONFIG_METADATA_2 = {
"dify_api_key": "",
"dify_api_base": "https://api.dify.ai/v1",
"dify_workflow_output_key": "",
+ },
+ "whisper(API)": {
+ "id": "whisper",
+ "type": "openai_whisper_api",
+ "enable": True,
+ "api_key": "",
+ "api_base": "",
+ "model": "whisper-1",
}
},
"items": {
@@ -416,7 +428,8 @@ CONFIG_METADATA_2 = {
"enable": {
"description": "启用大语言模型聊天",
"type": "bool",
- "hint": "是否启用大语言模型聊天。默认启用",
+ "hint": "如需切换大语言模型提供商,请使用 `/provider` 命令。",
+ "obvious_hint": True
},
"wake_prefix": {
"description": "LLM 聊天额外唤醒前缀",
@@ -450,6 +463,23 @@ CONFIG_METADATA_2 = {
},
},
},
+ "provider_stt_settings": {
+ "description": "语音转文本(STT)",
+ "type": "object",
+ "items": {
+ "enable": {
+ "description": "启用语音转文本(STT)",
+ "type": "bool",
+ "hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 whisper。",
+ "obvious_hint": True
+ },
+ "provider_id": {
+ "description": "提供商 ID,不填则默认第一个STT提供商",
+ "type": "string",
+ "hint": "语音转文本提供商 ID。如果不填写将使用载入的第一个提供商。",
+ },
+ },
+ },
},
},
"misc_config_group": {
diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py
index 50e81c86d..74d90c535 100644
--- a/astrbot/core/message/components.py
+++ b/astrbot/core/message/components.py
@@ -123,7 +123,7 @@ class Record(BaseMessageComponent):
proxy: T.Optional[bool] = True
timeout: T.Optional[int] = 0
# 额外
- path: T.Optional[str]
+ path: T.Optional[str] # 用这个
def __init__(self, file: T.Optional[str], **_):
for k in _.keys():
diff --git a/astrbot/core/pipeline/__init__.py b/astrbot/core/pipeline/__init__.py
index 108b1f134..bdba27699 100644
--- a/astrbot/core/pipeline/__init__.py
+++ b/astrbot/core/pipeline/__init__.py
@@ -3,6 +3,7 @@ from astrbot.core.message.message_event_result import MessageEventResult, EventR
from .waking_check.stage import WakingCheckStage
from .whitelist_check.stage import WhitelistCheckStage
from .content_safety_check.stage import ContentSafetyCheckStage
+from .preprocess_stage.stage import PreProcessStage
from .process_stage.stage import ProcessStage
from .result_decorate.stage import ResultDecorateStage
from .respond.stage import RespondStage
@@ -12,6 +13,7 @@ STAGES_ORDER = [
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
"RateLimitCheckStage", # 检查会话是否超过频率限制
"ContentSafetyCheckStage", # 检查内容安全
+ "PreProcessStage", # 预处理
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
"RespondStage" # 发送消息
@@ -21,6 +23,7 @@ __all__ = [
"WakingCheckStage",
"WhitelistCheckStage",
"ContentSafetyCheckStage",
+ "PreProcessStage",
"ProcessStage",
"ResultDecorateStage",
"RespondStage",
diff --git a/astrbot/core/pipeline/preprocess_stage/stage.py b/astrbot/core/pipeline/preprocess_stage/stage.py
new file mode 100644
index 000000000..a28e15485
--- /dev/null
+++ b/astrbot/core/pipeline/preprocess_stage/stage.py
@@ -0,0 +1,71 @@
+import traceback
+import asyncio
+from typing import Union, AsyncGenerator
+from ..stage import Stage, register_stage
+from ..context import PipelineContext
+from astrbot.core.platform.astr_message_event import AstrMessageEvent
+from astrbot.core import logger
+from astrbot.core.message.components import Plain, Record
+
+
+@register_stage
+class PreProcessStage(Stage):
+    '''Pipeline stage that pre-processes an event before it is handled.
+
+    When speech-to-text (STT) is enabled, every ``Record`` component that has
+    a ``path`` is transcribed and replaced by a ``Plain`` component.
+    '''
+
+    async def initialize(self, ctx: PipelineContext) -> None:
+        '''Keep the pipeline context and read the STT settings from the config.'''
+        self.ctx = ctx
+        self.config = ctx.astrbot_config
+        self.plugin_manager = ctx.plugin_manager
+
+        self.stt_settings: dict = self.config.get('provider_stt_settings', {})
+
+    async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
+        '''Pre-process the event before the processing stage runs.
+
+        Does nothing unless STT is enabled in the settings and an STT provider
+        is currently selected. Errors raised while transcribing are logged
+        instead of propagated; task cancellation still propagates.
+        '''
+        if not self.stt_settings.get('enable', False):
+            return
+
+        # TODO: make this independent (move STT handling out of this stage).
+        stt_provider = self.plugin_manager.context.provider_manager.curr_stt_provider_inst
+        if not stt_provider:
+            return
+
+        retry = 5
+        message_chain = event.get_messages()
+        for idx, component in enumerate(message_chain):
+            if not (isinstance(component, Record) and component.path):
+                continue
+
+            path = component.path
+            for i in range(retry):
+                try:
+                    result = await stt_provider.get_text(audio_url=path)
+                    if result:
+                        logger.info("语音转文本结果: " + result)
+                        message_chain[idx] = Plain(result)
+                        event.message_str += result
+                        event.message_obj.message_str += result
+                    # The call succeeded; an empty transcription is final, so
+                    # do not query the provider again.
+                    break
+                except FileNotFoundError:
+                    # napcat workaround: presumably the audio file is not on
+                    # disk yet, so wait a moment and try again.
+                    logger.warning(f"语音文件不存在: {path}, 重试中: {i + 1}/{retry}")
+                    await asyncio.sleep(0.5)
+                    continue
+                except Exception as e:
+                    # Exception rather than BaseException, so that task
+                    # cancellation (asyncio.CancelledError) keeps propagating.
+                    logger.error(traceback.format_exc())
+                    logger.error(f"语音转文本失败: {e}")
+                    break
diff --git a/astrbot/core/provider/__init__.py b/astrbot/core/provider/__init__.py
index b1dfe8732..246a74e57 100644
--- a/astrbot/core/provider/__init__.py
+++ b/astrbot/core/provider/__init__.py
@@ -1,4 +1,4 @@
-from .provider import Provider, Personality
+from .provider import Provider, Personality, STTProvider
from .entites import ProviderMetaData
@@ -6,4 +6,5 @@ __all__ = [
"Provider",
"Personality",
"ProviderMetaData",
+ "STTProvider"
]
\ No newline at end of file
diff --git a/astrbot/core/provider/entites.py b/astrbot/core/provider/entites.py
index 8dae2680d..0a733e3b9 100644
--- a/astrbot/core/provider/entites.py
+++ b/astrbot/core/provider/entites.py
@@ -1,13 +1,22 @@
+import enum
from dataclasses import dataclass, field
-from typing import List, Dict
+from typing import List, Dict, Type
from .func_tool_manager import FuncCall
+class ProviderType(enum.Enum):  # kind of task a provider adapter serves
+    CHAT_COMPLETION = "chat_completion"  # text generation (LLM chat) providers
+    SPEECH_TO_TEXT = "speech_to_text"  # speech-to-text (STT) providers
+    TEXT_TO_SPEECH = "text_to_speech"  # text-to-speech (TTS) providers
+
@dataclass
class ProviderMetaData():
- type: str # 提供商适配器名称,如 openai, ollama
- desc: str = "" # 提供商适配器描述.
-
+ type: str
+ '''提供商适配器名称,如 openai, ollama'''
+ desc: str = ""
+ '''提供商适配器描述.'''
+ provider_type: ProviderType = ProviderType.CHAT_COMPLETION
+ cls_type: Type = None
@dataclass
class ProviderRequest():
diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py
index 19075338c..4f70c33f0 100644
--- a/astrbot/core/provider/manager.py
+++ b/astrbot/core/provider/manager.py
@@ -1,6 +1,7 @@
import traceback
from astrbot.core.config.astrbot_config import AstrBotConfig
-from .provider import Provider
+from .provider import Provider, STTProvider
+from .entites import ProviderType
from typing import List
from astrbot.core.db import BaseDatabase
from collections import defaultdict
@@ -11,10 +12,17 @@ class ProviderManager():
def __init__(self, config: AstrBotConfig, db_helper: BaseDatabase):
self.providers_config: List = config['provider']
self.provider_settings: dict = config['provider_settings']
+ self.provider_stt_settings: dict = config.get('provider_stt_settings', {})
+
self.provider_insts: List[Provider] = []
'''加载的 Provider 的实例'''
+ self.stt_provider_insts: List[STTProvider] = []
+ '''加载的 Speech To Text Provider 的实例'''
self.llm_tools = llm_tools
self.curr_provider_inst: Provider = None
+ '''当前使用的 Provider 实例'''
+ self.curr_stt_provider_inst: STTProvider = None
+ '''当前使用的 Speech To Text Provider 实例'''
self.loaded_ids = defaultdict(bool)
self.db_helper = db_helper
@@ -43,6 +51,8 @@ class ProviderManager():
from .sources.dify_source import ProviderDify # noqa: F401
case "googlegenai_chat_completion":
from .sources.gemini_source import ProviderGoogleGenAI # noqa: F401
+ case "openai_whisper_api":
+ from .sources.whisper_api_source import ProviderOpenAIWhisperAPI # noqa: F401
async def initialize(self):
@@ -53,14 +63,29 @@ class ProviderManager():
logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。")
continue
selected_provider_id = sp.get("curr_provider")
- cls_type = provider_cls_map[provider_config['type']]
+ selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
+
+ provider_metadata = provider_cls_map[provider_config['type']]
logger.info(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器 ...")
try:
- inst = cls_type(provider_config, self.provider_settings, self.db_helper, self.provider_settings.get('persistant_history', True))
- self.provider_insts.append(inst)
- if selected_provider_id == provider_config['id']:
- self.curr_provider_inst = inst
- logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。")
+ # 按任务实例化提供商
+
+ if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
+ # STT 任务
+ inst = provider_metadata.cls_type(provider_config, self.provider_settings)
+ self.stt_provider_insts.append(inst)
+ if selected_stt_provider_id == provider_config['id']:
+ self.curr_stt_provider_inst = inst
+ logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。")
+
+ elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
+ # 文本生成任务
+ inst = provider_metadata.cls_type(provider_config, self.provider_settings, self.db_helper, self.provider_settings.get('persistant_history', True))
+ self.provider_insts.append(inst)
+ if selected_provider_id == provider_config['id']:
+ self.curr_provider_inst = inst
+ logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。")
+
except Exception as e:
traceback.print_exc()
logger.error(f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}")
@@ -68,8 +93,14 @@ class ProviderManager():
if len(self.provider_insts) > 0 and not self.curr_provider_inst:
self.curr_provider_inst = self.provider_insts[0]
+ if len(self.stt_provider_insts) > 0 and not self.curr_stt_provider_inst:
+ self.curr_stt_provider_inst = self.stt_provider_insts[0]
+
if not self.curr_provider_inst:
- logger.warning("未启用任何提供商适配器。")
+ logger.warning("未启用任何用于 文本生成 的提供商适配器。")
+ if self.provider_stt_settings.get("enable"):
+ if not self.curr_stt_provider_inst:
+ logger.warning("未启用任何用于 语音转文本 的提供商适配器。")
def get_insts(self):
return self.provider_insts
diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py
index 553dd78dc..30c6b02d0 100644
--- a/astrbot/core/provider/provider.py
+++ b/astrbot/core/provider/provider.py
@@ -125,6 +125,33 @@ class Provider(abc.ABC):
'''重置某一个 session_id 的上下文'''
raise NotImplementedError()
+ def meta(self) -> ProviderMeta:
+ '''获取 Provider 的元数据'''
+ return ProviderMeta(
+ id=self.provider_config['id'],
+ model=self.get_model(),
+ type=self.provider_config['type']
+ )
+
+
+class STTProvider():
+ def __init__(self, provider_config: dict, provider_settings: dict) -> None:
+ self.provider_config = provider_config
+ self.provider_settings = provider_settings
+
+ @abc.abstractmethod
+ async def get_text(self, audio_url: str) -> str:
+ '''获取音频的文本'''
+ raise NotImplementedError()
+
+ def set_model(self, model_name: str):
+ '''设置当前使用的模型名称'''
+ self.model_name = model_name
+
+ def get_model(self) -> str:
+ '''获取当前使用的模型'''
+ return self.provider_config.get("model", "")
+
def meta(self) -> ProviderMeta:
'''获取 Provider 的元数据'''
return ProviderMeta(
diff --git a/astrbot/core/provider/register.py b/astrbot/core/provider/register.py
index 00c3ad877..61e64408f 100644
--- a/astrbot/core/provider/register.py
+++ b/astrbot/core/provider/register.py
@@ -1,16 +1,20 @@
from typing import List, Dict, Type
-from .entites import ProviderMetaData
+from .entites import ProviderMetaData, ProviderType
from astrbot.core import logger
from .func_tool_manager import FuncCall
provider_registry: List[ProviderMetaData] = []
'''维护了通过装饰器注册的 Provider'''
-provider_cls_map: Dict[str, Type] = {}
-'''维护了 Provider 类型名称和 Provider 类的映射'''
+provider_cls_map: Dict[str, ProviderMetaData] = {}
+'''维护了 Provider 类型名称和 ProviderMetadata 的映射'''
llm_tools = FuncCall()
-def register_provider_adapter(provider_type_name: str, desc: str):
+def register_provider_adapter(
+ provider_type_name: str,
+ desc: str,
+ provider_type: ProviderType = ProviderType.CHAT_COMPLETION
+):
'''用于注册平台适配器的带参装饰器'''
def decorator(cls):
if provider_type_name in provider_cls_map:
@@ -19,9 +23,11 @@ def register_provider_adapter(provider_type_name: str, desc: str):
pm = ProviderMetaData(
type=provider_type_name,
desc=desc,
+ provider_type=provider_type,
+ cls_type=cls
)
provider_registry.append(pm)
- provider_cls_map[provider_type_name] = cls
+ provider_cls_map[provider_type_name] = pm
logger.debug(f"Provider {provider_type_name} 已注册")
return cls
diff --git a/astrbot/core/provider/sources/whisper_api_source.py b/astrbot/core/provider/sources/whisper_api_source.py
new file mode 100644
index 000000000..7159bf9d4
--- /dev/null
+++ b/astrbot/core/provider/sources/whisper_api_source.py
@@ -0,0 +1,129 @@
+import uuid
+import os
+import io
+from openai import AsyncOpenAI, NOT_GIVEN
+from ..provider import STTProvider
+from ..entites import ProviderType
+from astrbot.core.utils.io import download_file
+from ..register import register_provider_adapter
+from astrbot.core import logger
+
+
+@register_provider_adapter("openai_whisper_api", "OpenAI Whisper API", provider_type=ProviderType.SPEECH_TO_TEXT)
+class ProviderOpenAIWhisperAPI(STTProvider):
+    '''Speech-to-text provider that calls the transcription endpoint of an OpenAI client.'''
+
+    def __init__(
+        self,
+        provider_config: dict,
+        provider_settings: dict,
+    ) -> None:
+        '''Build the API client from ``provider_config``.
+
+        Keys read from ``provider_config``: ``api_key``, ``api_base``,
+        ``timeout`` and ``model``.
+        '''
+        super().__init__(provider_config, provider_settings)
+        self.chosen_api_key = provider_config.get("api_key", "")
+
+        self.client = AsyncOpenAI(
+            api_key=self.chosen_api_key,
+            base_url=provider_config.get("api_base", None),
+            timeout=provider_config.get("timeout", NOT_GIVEN),
+        )
+
+        self.set_model(provider_config.get("model", None))
+
+    async def _convert_audio(self, path: str) -> str:
+        '''Convert ``path`` to an mp3 file under ``data/temp`` with pyffmpeg.
+
+        Returns whatever ``FFmpeg.convert`` returns (presumably the output path).
+        '''
+        from pyffmpeg import FFmpeg
+        filename = str(uuid.uuid4()) + '.mp3'
+        ff = FFmpeg()
+        output_path = ff.convert(path, os.path.join('data/temp', filename))
+        return output_path
+
+    async def _pcm_to_wav(self, input_io: io.BytesIO, output_path: str) -> str:
+        '''Write the PCM bytes of ``input_io`` to ``output_path`` as a WAV file.
+
+        The header is set to 1 channel, 2-byte samples and a 24000 Hz frame rate.
+        Returns ``output_path``.
+        '''
+        import wave
+
+        with wave.open(output_path, 'wb') as wav:
+            wav.setnchannels(1)
+            wav.setsampwidth(2)
+            wav.setframerate(24000)
+            wav.writeframes(input_io.read())
+
+        return output_path
+
+    async def _convert_silk(self, path: str) -> str:
+        '''Decode the SILK file at ``path`` into a new WAV file under ``data/temp``.
+
+        Returns the path of the WAV file.
+        '''
+        import pysilk
+        filename = str(uuid.uuid4()) + '.wav'
+        output_path = os.path.join('data/temp', filename)
+        with open(path, "rb") as f:
+            input_data = f.read()
+        if input_data.startswith(b'\x02'):
+            # Tencent-flavoured voice data carries an extra leading 0x02 byte
+            # (per the original note); drop it before decoding.
+            input_data = input_data[1:]
+        input_io = io.BytesIO(input_data)
+        output_io = io.BytesIO()
+        pysilk.decode(input_io, output_io, 24000)
+        output_io.seek(0)
+        await self._pcm_to_wav(output_io, output_path)
+
+        return output_path
+
+    async def _is_silk_file(self, file_path):
+        '''Return True if ``b"SILK"`` occurs in the first 8 bytes of the file.'''
+        silk_header = b"SILK"
+        with open(file_path, "rb") as f:
+            file_header = f.read(8)
+
+        return silk_header in file_header
+
+    async def get_text(self, audio_url: str) -> str:
+        '''Transcribe an audio file and return the recognized text.
+
+        ``audio_url`` is a local path or an http(s) URL; a URL is downloaded to
+        ``data/temp`` first. Files ending in ``.amr`` or ``.silk`` that contain
+        a SILK header are converted to WAV before upload. According to the
+        original note, only mp3, mp4, mpeg, m4a, wav and webm are supported.
+
+        Raises:
+            FileNotFoundError: if the (downloaded) file does not exist.
+        '''
+        if audio_url.startswith("http"):
+            name = str(uuid.uuid4())
+            path = os.path.join("data/temp", name)
+            # NOTE(review): download_file's return value is not visible from
+            # here; fall back to the target path in case it returns nothing.
+            downloaded = await download_file(audio_url, path)
+            audio_url = downloaded or path
+
+        if not os.path.exists(audio_url):
+            raise FileNotFoundError(f"文件不存在: {audio_url}")
+
+        if audio_url.endswith((".amr", ".silk")):
+            is_silk = await self._is_silk_file(audio_url)
+            if is_silk:
+                logger.info("Converting silk file to wav ...")
+                audio_url = await self._convert_silk(audio_url)
+
+        # Open the file in a context manager so the handle is closed even when
+        # the request fails.
+        with open(audio_url, "rb") as audio_file:
+            result = await self.client.audio.transcriptions.create(
+                model=self.model_name,
+                file=audio_file,
+            )
+        return result.text
\ No newline at end of file
diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py
index 58810de96..b94b50a52 100644
--- a/astrbot/core/star/context.py
+++ b/astrbot/core/star/context.py
@@ -17,10 +17,6 @@ from .filter.regex import RegexFilter
from typing import Awaitable
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
-class StarCommand(TypedDict):
- full_command_name: str
- command_name: str
-
class Context:
'''
暴露给插件的接口上下文。
@@ -168,13 +164,13 @@ class Context:
def register_provider(self, provider: Provider):
'''
- 注册一个 LLM Provider。
+ 注册一个 LLM Provider(Chat_Completion 类型)。
'''
self.provider_manager.provider_insts.append(provider)
def get_provider_by_id(self, provider_id: str) -> Provider:
'''
- 通过 ID 获取 LLM Provider。
+ 通过 ID 获取 LLM Provider(Chat_Completion 类型)。
'''
for provider in self.provider_manager.provider_insts:
if provider.meta().id == provider_id:
@@ -183,13 +179,13 @@ class Context:
def get_all_providers(self) -> List[Provider]:
'''
- 获取所有 LLM Provider。
+ 获取所有 LLM Provider(Chat_Completion 类型)。
'''
return self.provider_manager.provider_insts
def get_using_provider(self) -> Provider:
'''
- 获取当前使用的 LLM Provider。
+ 获取当前使用的 LLM Provider(Chat_Completion 类型)。
通过 /provider 指令切换。
'''
diff --git a/dashboard/package.json b/dashboard/package.json
index 2888d4415..9b59cfa92 100644
--- a/dashboard/package.json
+++ b/dashboard/package.json
@@ -33,8 +33,6 @@
"vue3-apexcharts": "1.4.4",
"vue3-print-nb": "0.1.4",
"vuetify": "3.3.14",
- "xterm": "^5.3.0",
- "xterm-addon-fit": "^0.8.0",
"yup": "1.2.0"
},
"devDependencies": {
diff --git a/requirements.txt b/requirements.txt
index 965651870..432d5bb19 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,4 +16,5 @@ aiocqhttp
pyjwt
apscheduler
docstring_parser
-aiodocker
\ No newline at end of file
+aiodocker
+silk-python
\ No newline at end of file
From a09998f910b36e52ee09ffb96730f4b0f39725bb Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sat, 11 Jan 2025 18:54:40 +0800
Subject: [PATCH 2/6] =?UTF-8?q?feat:=20webchat=20=E6=94=AF=E6=8C=81?=
=?UTF-8?q?=E8=AF=AD=E9=9F=B3=E8=BE=93=E5=85=A5?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../sources/webchat/webchat_adapter.py | 10 +-
astrbot/dashboard/routes/chat.py | 42 +++++-
dashboard/src/views/ChatPage.vue | 124 ++++++++++++++----
3 files changed, 148 insertions(+), 28 deletions(-)
diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py
index 87e9d5f7b..e2a438caf 100644
--- a/astrbot/core/platform/sources/webchat/webchat_adapter.py
+++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py
@@ -5,7 +5,7 @@ import os
from typing import Awaitable, Any
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
-from astrbot.api.message_components import Plain, Image # noqa: F403
+from astrbot.api.message_components import Plain, Image, Record # noqa: F403
from astrbot.api import logger
from astrbot.core import web_chat_queue, web_chat_back_queue
from .webchat_event import WebChatMessageEvent
@@ -70,6 +70,14 @@ class WebChatAdapter(Platform):
abm.message.append(Image.fromFileSystem(os.path.join(self.imgs_dir, img)))
else:
abm.message.append(Image.fromFileSystem(os.path.join(self.imgs_dir, payload['image_url'])))
+ if payload['audio_url']:
+ if isinstance(payload['audio_url'], list):
+ for audio in payload['audio_url']:
+ path = os.path.join(self.imgs_dir, audio)
+ abm.message.append(Record(file=path, path=path))
+ else:
+ path = os.path.join(self.imgs_dir, payload['audio_url'])
+ abm.message.append(Record(file=path, path=path))
logger.debug(f"WebChatAdapter: {abm.message}")
diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py
index cef7e563f..02f40d648 100644
--- a/astrbot/dashboard/routes/chat.py
+++ b/astrbot/dashboard/routes/chat.py
@@ -17,11 +17,14 @@ class ChatRoute(Route):
'/chat/get_conversation': ('GET', self.get_conversation),
'/chat/delete_conversation': ('GET', self.delete_conversation),
'/chat/get_file': ('GET', self.get_file),
- '/chat/post_image': ('POST', self.post_image)
+ '/chat/post_image': ('POST', self.post_image),
+ '/chat/post_file': ('POST', self.post_file)
}
self.db = db
self.register_routes()
self.imgs_dir = "data/webchat/imgs"
+
+ self.supported_imgs = ['jpg', 'jpeg', 'png', 'gif', 'webp']
async def get_file(self):
filename = request.args.get('filename')
@@ -30,7 +33,13 @@ class ChatRoute(Route):
try:
with open(os.path.join(self.imgs_dir, filename), "rb") as f:
- return QuartResponse(f.read(), mimetype="image/jpeg")
+ if filename.endswith(".wav"):
+ return QuartResponse(f.read(), mimetype="audio/wav")
+ elif filename.split('.')[-1] in self.supported_imgs:
+ return QuartResponse(f.read(), mimetype="image/jpeg")
+ else:
+ return QuartResponse(f.read())
+
except FileNotFoundError:
return Response().error("File not found").__dict__
@@ -47,6 +56,25 @@ class ChatRoute(Route):
return Response().ok(data={
'filename': filename
}).__dict__
+
+    async def post_file(self):
+        '''Save the uploaded ``file`` under ``self.imgs_dir`` and return its generated filename.'''
+        post_data = await request.files
+        if 'file' not in post_data:
+            return Response().error("Missing key: file").__dict__
+
+        file = post_data['file']
+        filename = f"{str(uuid.uuid4())}"
+        # Audio uploads always get a .wav suffix; a missing content type counts as non-audio.
+        content_type = file.content_type or ''
+        if content_type.startswith('audio'):
+            filename += ".wav"
+
+        path = os.path.join(self.imgs_dir, filename)
+        await file.save(path)
+        return Response().ok(data={
+            'filename': filename
+        }).__dict__
async def chat(self):
username = g.get('username', 'guest')
@@ -61,14 +89,16 @@ class ChatRoute(Route):
message = post_data['message']
conversation_id = post_data['conversation_id']
image_url = post_data.get('image_url')
- if not message and not image_url:
- return Response().error("Message and image_url are empty").__dict__
+ audio_url = post_data.get('audio_url')
+ if not message and not image_url and not audio_url:
+ return Response().error("Message and image_url and audio_url are empty").__dict__
if not conversation_id:
return Response().error("conversation_id is empty").__dict__
await web_chat_queue.put((username, conversation_id, {
'message': message,
- 'image_url': image_url # list
+ 'image_url': image_url, # list
+ 'audio_url': audio_url
}))
async def stream():
@@ -98,6 +128,8 @@ class ChatRoute(Route):
}
if image_url:
new_his['image_url'] = image_url
+ if audio_url:
+ new_his['audio_url'] = audio_url
history.append(new_his)
for r in ret:
history.append({
diff --git a/dashboard/src/views/ChatPage.vue b/dashboard/src/views/ChatPage.vue
index f4cb9b0e5..f3c0227a5 100644
--- a/dashboard/src/views/ChatPage.vue
+++ b/dashboard/src/views/ChatPage.vue
@@ -58,13 +58,21 @@ marked.setOptions({
{{ msg.message }}
-
+
+
+
+
+
@@ -79,26 +87,28 @@ marked.setOptions({
-
-
+ style="width: 100%; max-width: 850px;">
-
+
+
+
+
+
+
+
-
+
![]()
mdi-close-circle
+
+
+ 新录音
+ mdi-close-circle
+
+
+
@@ -128,7 +146,14 @@ export default {
conversations: [],
currCid: '',
stagedImagesUrl: [],
- loadingChat: false
+ loadingChat: false,
+
+ inputFieldLabel: '聊天吧!',
+
+ isRecording: false,
+ audioChunks: [],
+ stagedAudioUrl: "",
+ mediaRecorder: null
}
},
@@ -136,10 +161,54 @@ export default {
this.getConversations();
let inputField = document.getElementById('input-field');
inputField.addEventListener('paste', this.handlePaste);
-
},
methods: {
+
+ removeAudio() {
+ this.stagedAudioUrl = null;
+ },
+
+ async startRecording() {
+ const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
+ this.mediaRecorder = new MediaRecorder(stream);
+ this.mediaRecorder.ondataavailable = (event) => {
+ this.audioChunks.push(event.data);
+ };
+ this.mediaRecorder.start();
+ this.isRecording = true;
+ this.inputFieldLabel = "录音中,请说话...";
+ },
+
+ async stopRecording() {
+ this.isRecording = false;
+ this.inputFieldLabel = "聊天吧!";
+ this.mediaRecorder.stop();
+ this.mediaRecorder.onstop = async () => {
+ const audioBlob = new Blob(this.audioChunks, { type: 'audio/wav' });
+ this.audioChunks = [];
+
+ const formData = new FormData();
+ formData.append('file', audioBlob);
+
+ try {
+ const response = await axios.post('/api/chat/post_file', formData, {
+ headers: {
+ 'Content-Type': 'multipart/form-data',
+ 'Authorization': 'Bearer ' + localStorage.getItem('token')
+ }
+ });
+
+ const audio = response.data.data.filename;
+ console.log('Audio uploaded:', audio);
+
+ this.stagedAudioUrl = `/api/chat/get_file?filename=${audio}`;
+ } catch (err) {
+ console.error('Error uploading audio:', err);
+ }
+ };
+ },
+
async handlePaste(event) {
console.log('Pasting image...');
const items = event.clipboardData.items;
@@ -198,6 +267,9 @@ export default {
message[i].image_url[j] = `/api/chat/get_file?filename=${message[i].image_url[j]}`;
}
}
+ if (message[i].audio_url) {
+ message[i].audio_url = `/api/chat/get_file?filename=${message[i].audio_url}`;
+ }
}
this.messages = message;
}).catch(err => {
@@ -250,24 +322,26 @@ export default {
this.messages.push({
type: 'user',
message: this.prompt,
- image_url: this.stagedImagesUrl
+ image_url: this.stagedImagesUrl,
+ audio_url: this.stagedAudioUrl
});
- // let bot_resp = {
- // type: 'bot',
- // message: ref('')
- // }
-
- // this.messages.push(bot_resp);
-
this.scrollToBottom();
+ // images
let image_filenames = [];
for (let i = 0; i < this.stagedImagesUrl.length; i++) {
let img = this.stagedImagesUrl[i].replace('/api/chat/get_file?filename=', '');
image_filenames.push(img);
}
+ // audio
+ let audio_filenames = [];
+ if (this.stagedAudioUrl) {
+ let audio = this.stagedAudioUrl.replace('/api/chat/get_file?filename=', '');
+ audio_filenames.push(audio);
+ }
+
this.loadingChat = true;
@@ -277,11 +351,17 @@ export default {
'Content-Type': 'application/json',
'Authorization': 'Bearer ' + localStorage.getItem('token')
},
- body: JSON.stringify({ message: this.prompt, conversation_id: this.currCid, image_url: image_filenames }) // 发送请求体
+ body: JSON.stringify({
+ message: this.prompt,
+ conversation_id: this.currCid,
+ image_url: image_filenames,
+ audio_url: audio_filenames
+ }) // 发送请求体
})
.then(response => {
this.prompt = '';
this.stagedImagesUrl = [];
+ this.stagedAudioUrl = "";
this.loadingChat = false;
From f2566c68e358fa1e84c36e782be268e255eeb87f Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sat, 11 Jan 2025 19:07:26 +0800
Subject: [PATCH 3/6] =?UTF-8?q?feat:=20=E6=8C=89=20K=20=E8=AF=AD=E9=9F=B3?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
dashboard/src/views/ChatPage.vue | 30 +++++++++++++++++++++++++++---
1 file changed, 27 insertions(+), 3 deletions(-)
diff --git a/dashboard/src/views/ChatPage.vue b/dashboard/src/views/ChatPage.vue
index f3c0227a5..e921080fc 100644
--- a/dashboard/src/views/ChatPage.vue
+++ b/dashboard/src/views/ChatPage.vue
@@ -49,6 +49,12 @@ marked.setOptions({
style="background-color: #eee; padding-left: 4px; padding-right: 4px; margin: 2px; border-radius: 4px;">/help
获取帮助 😊
+
+ 按
+ K
+ 开始语音 🎤
+
@@ -90,15 +96,21 @@ marked.setOptions({
-
-
+
+
+
+
+
+
+
@@ -161,6 +173,18 @@ export default {
this.getConversations();
let inputField = document.getElementById('input-field');
inputField.addEventListener('paste', this.handlePaste);
+ inputField.addEventListener('keydown', function (e) {
+ console.log(e);
+ if (e.keyCode == 13 && !e.shiftKey) {
+ e.preventDefault();
+ this.sendMessage();
+ }
+ }.bind(this));
+        document.addEventListener('keydown', function (e) {
+            // "K" toggles recording, but not while typing in a text field (a typed "k" must stay text).
+            const typing = !!e.target && ['INPUT', 'TEXTAREA'].includes(e.target.tagName);
+            if (e.keyCode == 75 && !typing) { this.isRecording ? this.stopRecording() : this.startRecording(); }
+        }.bind(this));
},
methods: {
From 97b58965f2650c904b7a15cc3b40f078eb008836 Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sat, 11 Jan 2025 19:31:56 +0800
Subject: [PATCH 4/6] =?UTF-8?q?feat:=20webchat=E5=8F=AF=E6=98=BE=E7=A4=BAP?=
=?UTF-8?q?rovider=E7=8A=B6=E6=80=81?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
astrbot/dashboard/routes/chat.py | 16 +++++++--
astrbot/dashboard/server.py | 2 +-
dashboard/src/views/ChatPage.vue | 62 +++++++++++++++++++++++---------
3 files changed, 60 insertions(+), 20 deletions(-)
diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py
index 02f40d648..45bcb5444 100644
--- a/astrbot/dashboard/routes/chat.py
+++ b/astrbot/dashboard/routes/chat.py
@@ -6,9 +6,11 @@ from astrbot.core import web_chat_queue, web_chat_back_queue
from quart import request, Response as QuartResponse, g
from astrbot.core.db import BaseDatabase
import asyncio
+from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
+
class ChatRoute(Route):
- def __init__(self, context: RouteContext, db: BaseDatabase) -> None:
+ def __init__(self, context: RouteContext, db: BaseDatabase, core_lifecycle: AstrBotCoreLifecycle) -> None:
super().__init__(context)
self.routes = {
'/chat/send': ('POST', self.chat),
@@ -18,13 +20,23 @@ class ChatRoute(Route):
'/chat/delete_conversation': ('GET', self.delete_conversation),
'/chat/get_file': ('GET', self.get_file),
'/chat/post_image': ('POST', self.post_image),
- '/chat/post_file': ('POST', self.post_file)
+ '/chat/post_file': ('POST', self.post_file),
+ '/chat/status': ('GET', self.status),
}
self.db = db
+ self.core_lifecycle = core_lifecycle
self.register_routes()
self.imgs_dir = "data/webchat/imgs"
self.supported_imgs = ['jpg', 'jpeg', 'png', 'gif', 'webp']
+
+    async def status(self):
+        '''Report whether an LLM provider and an STT provider are currently selected.'''
+        manager = self.core_lifecycle.provider_manager
+        return Response().ok(data={
+            'llm_enabled': manager.curr_provider_inst is not None,
+            'stt_enabled': manager.curr_stt_provider_inst is not None
+        }).__dict__
async def get_file(self):
filename = request.args.get('filename')
diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py
index 97013b423..798b212c2 100644
--- a/astrbot/dashboard/server.py
+++ b/astrbot/dashboard/server.py
@@ -31,7 +31,7 @@ class AstrBotDashboard():
self.lr = LogRoute(self.context, core_lifecycle.log_broker)
self.sfr = StaticFileRoute(self.context)
self.ar = AuthRoute(self.context)
- self.chat_route = ChatRoute(self.context, db)
+ self.chat_route = ChatRoute(self.context, db, core_lifecycle)
async def auth_middleware(self):
if not request.path.startswith("/api"):
diff --git a/dashboard/src/views/ChatPage.vue b/dashboard/src/views/ChatPage.vue
index e921080fc..cfe96e544 100644
--- a/dashboard/src/views/ChatPage.vue
+++ b/dashboard/src/views/ChatPage.vue
@@ -20,7 +20,7 @@ marked.setOptions({
:disabled="!currCid">+ 创建对话
-
@@ -31,12 +31,24 @@ marked.setOptions({
+
+
+
+ LLM
+
+
+
+ 语音转文本
+
+
+
删除此对话
+
-
+
@@ -96,8 +108,7 @@ marked.setOptions({
+ @click:clear="clearMessage" style="width: 100%; max-width: 850px;">
@@ -106,17 +117,20 @@ marked.setOptions({
-
+
-
-
+
+
-
+
-
+
@@ -129,12 +143,13 @@ marked.setOptions({
style="position: absolute; top: 0; right: 0; cursor: pointer;">mdi-close-circle
-
+
新录音
mdi-close-circle
-
+
@@ -165,16 +180,19 @@ export default {
isRecording: false,
audioChunks: [],
stagedAudioUrl: "",
- mediaRecorder: null
+ mediaRecorder: null,
+
+ status: {},
+ statusText: ''
}
},
mounted() {
+ this.checkStatus();
this.getConversations();
let inputField = document.getElementById('input-field');
inputField.addEventListener('paste', this.handlePaste);
inputField.addEventListener('keydown', function (e) {
- console.log(e);
if (e.keyCode == 13 && !e.shiftKey) {
e.preventDefault();
this.sendMessage();
@@ -193,6 +211,15 @@ export default {
this.stagedAudioUrl = null;
},
+ checkStatus() {
+ axios.get('/api/chat/status').then(response => {
+ console.log(response.data);
+ this.status = response.data.data;
+ }).catch(err => {
+ console.error(err);
+ });
+ },
+
async startRecording() {
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
this.mediaRecorder = new MediaRecorder(stream);
@@ -212,6 +239,8 @@ export default {
const audioBlob = new Blob(this.audioChunks, { type: 'audio/wav' });
this.audioChunks = [];
+ this.mediaRecorder.stream.getTracks().forEach(track => track.stop());
+
const formData = new FormData();
formData.append('file', audioBlob);
@@ -253,7 +282,6 @@ export default {
const img = response.data.data.filename;
this.stagedImagesUrl.push(`/api/chat/get_file?filename=${img}`);
- scrollToBottom();
} catch (err) {
console.error('Error uploading image:', err);
}
@@ -375,9 +403,9 @@ export default {
'Content-Type': 'application/json',
'Authorization': 'Bearer ' + localStorage.getItem('token')
},
- body: JSON.stringify({
- message: this.prompt,
- conversation_id: this.currCid,
+ body: JSON.stringify({
+ message: this.prompt,
+ conversation_id: this.currCid,
image_url: image_filenames,
audio_url: audio_filenames
}) // 发送请求体
From 0f9ab082abed5dafc507089c8e6e254867848cb8 Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sat, 11 Jan 2025 19:45:42 +0800
Subject: [PATCH 5/6] =?UTF-8?q?perf:=20=E4=BC=98=E5=8C=96webchat=EF=BC=8C?=
=?UTF-8?q?=E6=B2=A1=E6=9C=89=E7=BB=93=E6=9E=9C=E8=BF=94=E5=9B=9E=E6=97=B6?=
=?UTF-8?q?=E7=9A=84=E5=8F=8D=E9=A6=88?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
astrbot/core/__init__.py | 4 ++--
astrbot/core/pipeline/scheduler.py | 4 ++++
astrbot/core/platform/sources/webchat/webchat_event.py | 10 +++++++---
astrbot/core/provider/manager.py | 10 ++++++----
astrbot/dashboard/routes/chat.py | 6 +++++-
5 files changed, 24 insertions(+), 10 deletions(-)
diff --git a/astrbot/core/__init__.py b/astrbot/core/__init__.py
index 0ef8e039d..aac8fc117 100644
--- a/astrbot/core/__init__.py
+++ b/astrbot/core/__init__.py
@@ -20,6 +20,6 @@ if os.environ.get('TESTING', ""):
db_helper = SQLiteDatabase(DB_PATH)
sp = SharedPreferences() # 简单的偏好设置存储
pip_installer = PipInstaller(astrbot_config.get('pip_install_arg', ''))
-web_chat_queue = asyncio.Queue()
-web_chat_back_queue = asyncio.Queue()
+web_chat_queue = asyncio.Queue(maxsize=32)
+web_chat_back_queue = asyncio.Queue(maxsize=32)
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
diff --git a/astrbot/core/pipeline/scheduler.py b/astrbot/core/pipeline/scheduler.py
index 8842057b1..e18ee92be 100644
--- a/astrbot/core/pipeline/scheduler.py
+++ b/astrbot/core/pipeline/scheduler.py
@@ -41,4 +41,8 @@ class PipelineScheduler():
async def execute(self, event: AstrMessageEvent):
'''执行 pipeline'''
await self._process_stages(event)
+
+ if not event._has_send_oper and event.get_platform_name() == "webchat":
+ await event.send(None)
+
logger.debug("pipeline 执行完毕。")
\ No newline at end of file
diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py
index c988724be..4312b0cb1 100644
--- a/astrbot/core/platform/sources/webchat/webchat_event.py
+++ b/astrbot/core/platform/sources/webchat/webchat_event.py
@@ -12,9 +12,13 @@ class WebChatMessageEvent(AstrMessageEvent):
os.makedirs(self.imgs_dir, exist_ok=True)
async def send(self, message: MessageChain):
+ if not message:
+ await web_chat_back_queue.put_nowait(None)
+ return
+
for comp in message.chain:
if isinstance(comp, Plain):
- await web_chat_back_queue.put(comp.text)
+ await web_chat_back_queue.put_nowait(comp.text)
elif isinstance(comp, Image):
# save image to local
filename = str(uuid.uuid4()) + ".jpg"
@@ -26,6 +30,6 @@ class WebChatMessageEvent(AstrMessageEvent):
f.write(f2.read())
elif comp.file and comp.file.startswith("http"):
await download_image_by_url(comp.file, path=path)
- await web_chat_back_queue.put(f"[IMAGE]{filename}")
- await web_chat_back_queue.put(None)
+ await web_chat_back_queue.put_nowait(f"[IMAGE]{filename}")
+ await web_chat_back_queue.put_nowait(None)
await super().send(message)
\ No newline at end of file
diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py
index 4f70c33f0..7523960ef 100644
--- a/astrbot/core/provider/manager.py
+++ b/astrbot/core/provider/manager.py
@@ -64,6 +64,8 @@ class ProviderManager():
continue
selected_provider_id = sp.get("curr_provider")
selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
+ provider_enabled = self.provider_settings.get("enable", False)
+ stt_enabled = self.provider_stt_settings.get("enable", False)
provider_metadata = provider_cls_map[provider_config['type']]
logger.info(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器 ...")
@@ -74,7 +76,7 @@ class ProviderManager():
# STT 任务
inst = provider_metadata.cls_type(provider_config, self.provider_settings)
self.stt_provider_insts.append(inst)
- if selected_stt_provider_id == provider_config['id']:
+ if selected_stt_provider_id == provider_config['id'] and stt_enabled:
self.curr_stt_provider_inst = inst
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。")
@@ -82,7 +84,7 @@ class ProviderManager():
# 文本生成任务
inst = provider_metadata.cls_type(provider_config, self.provider_settings, self.db_helper, self.provider_settings.get('persistant_history', True))
self.provider_insts.append(inst)
- if selected_provider_id == provider_config['id']:
+ if selected_provider_id == provider_config['id'] and provider_enabled:
self.curr_provider_inst = inst
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。")
@@ -90,10 +92,10 @@ class ProviderManager():
traceback.print_exc()
logger.error(f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}")
- if len(self.provider_insts) > 0 and not self.curr_provider_inst:
+ if len(self.provider_insts) > 0 and not self.curr_provider_inst and provider_enabled:
self.curr_provider_inst = self.provider_insts[0]
- if len(self.stt_provider_insts) > 0 and not self.curr_stt_provider_inst:
+ if len(self.stt_provider_insts) > 0 and not self.curr_stt_provider_inst and stt_enabled:
self.curr_stt_provider_inst = self.stt_provider_insts[0]
if not self.curr_provider_inst:
diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py
index 45bcb5444..639c7cafa 100644
--- a/astrbot/dashboard/routes/chat.py
+++ b/astrbot/dashboard/routes/chat.py
@@ -116,7 +116,11 @@ class ChatRoute(Route):
async def stream():
ret = []
while True:
- result = await web_chat_back_queue.get()
+ try:
+ result = await asyncio.wait_for(web_chat_back_queue.get(), timeout=30) # time out after 30 seconds without a result
+ except asyncio.TimeoutError:
+ yield '[Error] 30 秒内没有返回数据,已放弃。\n'
+ return
if result is None:
break
From ba198490fa1f619c08c673dbc598dfd5a0f1aa1e Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sat, 11 Jan 2025 20:31:21 +0800
Subject: [PATCH 6/6] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E8=87=AA?=
=?UTF-8?q?=E9=83=A8=E7=BD=B2=20Whisper=20=E6=A8=A1=E5=9E=8B?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
astrbot/core/config/default.py | 15 ++-
.../core/pipeline/preprocess_stage/stage.py | 3 +-
.../platform/sources/webchat/webchat_event.py | 8 +-
astrbot/core/provider/manager.py | 46 +++++++---
.../sources/whisper_selfhosted_source.py | 92 +++++++++++++++++++
astrbot/dashboard/dashboard_lifecycle.py | 13 ++-
6 files changed, 154 insertions(+), 23 deletions(-)
create mode 100644 astrbot/core/provider/sources/whisper_selfhosted_source.py
diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py
index 983ca3cd3..f9e2a3485 100644
--- a/astrbot/core/config/default.py
+++ b/astrbot/core/config/default.py
@@ -323,13 +323,26 @@ CONFIG_METADATA_2 = {
"whisper(API)": {
"id": "whisper",
"type": "openai_whisper_api",
- "enable": True,
+ "enable": False,
"api_key": "",
"api_base": "",
"model": "whisper-1",
+ },
+ "whisper(本地加载)": {
+ "whisper_hint": "(不用修改我)",
+ "enable": False,
+ "id": "whisper",
+ "type": "openai_whisper_selfhost",
+ "model": "tiny",
}
},
"items": {
+ "whisper_hint": {
+ "description": "本地部署 Whisper 模型须知",
+ "type": "string",
+ "hint": "启用前请 pip 安装 openai-whisper 库(N卡用户大约下载 2GB,主要是 torch 和 cuda,CPU 用户大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
+ "obvious_hint": True
+ },
"id": {
"description": "ID",
"type": "string",
diff --git a/astrbot/core/pipeline/preprocess_stage/stage.py b/astrbot/core/pipeline/preprocess_stage/stage.py
index a28e15485..455fd35c8 100644
--- a/astrbot/core/pipeline/preprocess_stage/stage.py
+++ b/astrbot/core/pipeline/preprocess_stage/stage.py
@@ -43,8 +43,9 @@ class PreProcessStage(Stage):
event.message_str += result
event.message_obj.message_str += result
break
- except FileNotFoundError:
+ except FileNotFoundError as e:
# napcat workaround
+ logger.warning(e)
logger.warning(f"语音文件不存在: {path}, 重试中: {i + 1}/{retry}")
await asyncio.sleep(0.5)
continue
diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py
index 4312b0cb1..0ef57ed5f 100644
--- a/astrbot/core/platform/sources/webchat/webchat_event.py
+++ b/astrbot/core/platform/sources/webchat/webchat_event.py
@@ -13,12 +13,12 @@ class WebChatMessageEvent(AstrMessageEvent):
async def send(self, message: MessageChain):
if not message:
- await web_chat_back_queue.put_nowait(None)
+ web_chat_back_queue.put_nowait(None)
return
for comp in message.chain:
if isinstance(comp, Plain):
- await web_chat_back_queue.put_nowait(comp.text)
+ web_chat_back_queue.put_nowait(comp.text)
elif isinstance(comp, Image):
# save image to local
filename = str(uuid.uuid4()) + ".jpg"
@@ -30,6 +30,6 @@ class WebChatMessageEvent(AstrMessageEvent):
f.write(f2.read())
elif comp.file and comp.file.startswith("http"):
await download_image_by_url(comp.file, path=path)
- await web_chat_back_queue.put_nowait(f"[IMAGE]{filename}")
- await web_chat_back_queue.put_nowait(None)
+ web_chat_back_queue.put_nowait(f"[IMAGE]{filename}")
+ web_chat_back_queue.put_nowait(None)
await super().send(message)
\ No newline at end of file
diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py
index 7523960ef..3b64126f4 100644
--- a/astrbot/core/provider/manager.py
+++ b/astrbot/core/provider/manager.py
@@ -39,21 +39,29 @@ class ProviderManager():
raise ValueError(f"Provider ID 重复:{provider_cfg['id']}。")
self.loaded_ids[provider_cfg['id']] = True
- match provider_cfg['type']:
- case "openai_chat_completion":
- from .sources.openai_source import ProviderOpenAIOfficial # noqa: F401
- case "zhipu_chat_completion":
- from .sources.zhipu_source import ProviderZhipu # noqa: F401
- case "llm_tuner":
- logger.info("加载 LLM Tuner 工具 ...")
- from .sources.llmtuner_source import LLMTunerModelLoader # noqa: F401
- case "dify":
- from .sources.dify_source import ProviderDify # noqa: F401
- case "googlegenai_chat_completion":
- from .sources.gemini_source import ProviderGoogleGenAI # noqa: F401
- case "openai_whisper_api":
- from .sources.whisper_api_source import ProviderOpenAIWhisperAPI # noqa: F401
-
+ try:
+ match provider_cfg['type']:
+ case "openai_chat_completion":
+ from .sources.openai_source import ProviderOpenAIOfficial # noqa: F401
+ case "zhipu_chat_completion":
+ from .sources.zhipu_source import ProviderZhipu # noqa: F401
+ case "llm_tuner":
+ logger.info("加载 LLM Tuner 工具 ...")
+ from .sources.llmtuner_source import LLMTunerModelLoader # noqa: F401
+ case "dify":
+ from .sources.dify_source import ProviderDify # noqa: F401
+ case "googlegenai_chat_completion":
+ from .sources.gemini_source import ProviderGoogleGenAI # noqa: F401
+ case "openai_whisper_api":
+ from .sources.whisper_api_source import ProviderOpenAIWhisperAPI # noqa: F401
+ case "openai_whisper_selfhost":
+ from .sources.whisper_selfhosted_source import ProviderOpenAIWhisperSelfHost # noqa: F401
+ except (ImportError, ModuleNotFoundError) as e:
+ logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。")
+ continue
+ except Exception as e:
+ logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。未知原因")
+ continue
async def initialize(self):
for provider_config in self.providers_config:
@@ -75,6 +83,10 @@ class ProviderManager():
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
# STT 任务
inst = provider_metadata.cls_type(provider_config, self.provider_settings)
+
+ if getattr(inst, "initialize", None):
+ await inst.initialize()
+
self.stt_provider_insts.append(inst)
if selected_stt_provider_id == provider_config['id'] and stt_enabled:
self.curr_stt_provider_inst = inst
@@ -83,6 +95,10 @@ class ProviderManager():
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
# 文本生成任务
inst = provider_metadata.cls_type(provider_config, self.provider_settings, self.db_helper, self.provider_settings.get('persistant_history', True))
+
+ if getattr(inst, "initialize", None):
+ await inst.initialize()
+
self.provider_insts.append(inst)
if selected_provider_id == provider_config['id'] and provider_enabled:
self.curr_provider_inst = inst
diff --git a/astrbot/core/provider/sources/whisper_selfhosted_source.py b/astrbot/core/provider/sources/whisper_selfhosted_source.py
new file mode 100644
index 000000000..6f16d559a
--- /dev/null
+++ b/astrbot/core/provider/sources/whisper_selfhosted_source.py
@@ -0,0 +1,92 @@
+import uuid
+import os
+import io
+import asyncio
+import whisper
+from ..provider import STTProvider
+from ..entites import ProviderType
+from astrbot.core.utils.io import download_file
+from ..register import register_provider_adapter
+from astrbot.core import logger
+
+
+@register_provider_adapter("openai_whisper_selfhost", "OpenAI Whisper 模型部署", provider_type=ProviderType.SPEECH_TO_TEXT)
+class ProviderOpenAIWhisperSelfHost(STTProvider):
+ def __init__(
+ self,
+ provider_config: dict,
+ provider_settings: dict,
+ ) -> None:
+ super().__init__(provider_config, provider_settings)
+ self.set_model(provider_config.get("model", None))
+ self.model = None
+
+ async def initialize(self):
+ loop = asyncio.get_event_loop()
+ logger.info("下载或者加载 Whisper 模型中,这可能需要一些时间 ...")
+ self.model = await loop.run_in_executor(None, whisper.load_model, self.model_name)
+ logger.info("Whisper 模型加载完成。")
+
+ async def _convert_audio(self, path: str) -> str:
+ from pyffmpeg import FFmpeg
+ filename = str(uuid.uuid4()) + '.mp3'
+ ff = FFmpeg()
+ output_path = ff.convert(path, os.path.join('data/temp', filename))
+ return output_path
+
+ async def _pcm_to_wav(self, input_io: io.BytesIO, output_path: str) -> str:
+ import wave
+
+ with wave.open(output_path, 'wb') as wav:
+ wav.setnchannels(1)
+ wav.setsampwidth(2)
+ wav.setframerate(24000)
+ wav.writeframes(input_io.read())
+
+ return output_path
+
+ async def _convert_silk(self, path: str) -> str:
+ import pysilk
+ filename = str(uuid.uuid4()) + '.wav'
+ output_path = os.path.join('data/temp', filename)
+ with open(path, "rb") as f:
+ input_data = f.read()
+ if input_data.startswith(b'\x02'):
+ # Tencent voice files carry one extra leading \x02 byte before the silk data; drop it before decoding
+ input_data = input_data[1:]
+ input_io = io.BytesIO(input_data)
+ output_io = io.BytesIO()
+ pysilk.decode(input_io, output_io, 24000)
+ output_io.seek(0)
+ await self._pcm_to_wav(output_io, output_path)
+
+ return output_path
+
+ async def _is_silk_file(self, file_path):
+ silk_header = b"SILK"
+ with open(file_path, "rb") as f:
+ file_header = f.read(8)
+
+ if silk_header in file_header:
+ return True
+ else:
+ return False
+
+ async def get_text(self, audio_url: str) -> str:
+ loop = asyncio.get_event_loop()
+ if audio_url.startswith("http"):
+ name = str(uuid.uuid4())
+ path = os.path.join("data/temp", name)
+ audio_url = await download_file(audio_url, path)
+
+ if not os.path.exists(audio_url):
+ raise FileNotFoundError(f"文件不存在: {audio_url}")
+
+ if audio_url.endswith(".amr") or audio_url.endswith(".silk"):
+ is_silk = await self._is_silk_file(audio_url)
+ if is_silk:
+ logger.info("Converting silk file to wav ...")
+ audio_url = await self._convert_silk(audio_url)
+
+ result = await loop.run_in_executor(None, self.model.transcribe, audio_url)
+ return result['text']
\ No newline at end of file
diff --git a/astrbot/dashboard/dashboard_lifecycle.py b/astrbot/dashboard/dashboard_lifecycle.py
index 176ca4dc1..b363ae3a7 100644
--- a/astrbot/dashboard/dashboard_lifecycle.py
+++ b/astrbot/dashboard/dashboard_lifecycle.py
@@ -1,4 +1,5 @@
import asyncio
+import traceback
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from .server import AstrBotDashboard
@@ -13,8 +14,16 @@ class AstrBotDashBoardLifecycle:
async def start(self):
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
- await core_lifecycle.initialize()
- core_task = core_lifecycle.start()
+
+ core_task = []
+ try:
+ await core_lifecycle.initialize()
+ core_task = core_lifecycle.start()
+ except Exception as e:
+ logger.critical(f"初始化 AstrBot 失败:{e} !!!!!!!")
+ logger.critical(f"初始化 AstrBot 失败:{e} !!!!!!!")
+ logger.critical(f"初始化 AstrBot 失败:{e} !!!!!!!")
+
self.dashboard_server = AstrBotDashboard(core_lifecycle, self.db)
task = asyncio.gather(core_task, self.dashboard_server.run())