From 33618c4a6be53acdb48e14dd2964d40e2997755b Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sat, 25 Oct 2025 16:39:11 +0800 Subject: [PATCH] feat: add dynamic embedding dimension retrieval for providers and enhance error handling --- astrbot/core/config/default.py | 1 + astrbot/core/provider/manager.py | 23 +++-- .../sources/gemini_embedding_source.py | 3 +- .../sources/openai_embedding_source.py | 3 +- astrbot/dashboard/routes/config.py | 63 ++++++++++++- .../src/components/shared/AstrBotConfig.vue | 50 ++++++++++ dashboard/src/views/ProviderPage.vue | 93 ++++++++----------- 7 files changed, 166 insertions(+), 70 deletions(-) diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index c0d162161..ee9ab9458 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -1417,6 +1417,7 @@ CONFIG_METADATA_2 = { "description": "嵌入维度", "type": "int", "hint": "嵌入向量的维度。根据模型不同,可能需要调整,请参考具体模型的文档。此配置项请务必填写正确,否则将导致向量数据库无法正常工作。", + "_special": "get_embedding_dim", }, "embedding_model": { "description": "嵌入模型", diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 6666b33e6..edb9b767a 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -1,6 +1,5 @@ import asyncio import traceback -from typing import List from astrbot.core import logger, sp from astrbot.core.astrbot_config_mgr import AstrBotConfigManager @@ -28,7 +27,7 @@ class ProviderManager: self.persona_mgr = persona_mgr self.acm = acm config = acm.confs["default"] - self.providers_config: List = config["provider"] + self.providers_config: list = config["provider"] self.provider_settings: dict = config["provider_settings"] self.provider_stt_settings: dict = config.get("provider_stt_settings", {}) self.provider_tts_settings: dict = config.get("provider_tts_settings", {}) @@ -36,15 +35,15 @@ class ProviderManager: # 人格相关属性,v4.0.0 版本后被废弃,推荐使用 PersonaManager self.default_persona_name = persona_mgr.default_persona - self.provider_insts: List[Provider] = [] + self.provider_insts: list[Provider] = [] """加载的 Provider 的实例""" - self.stt_provider_insts: List[STTProvider] = [] + self.stt_provider_insts: list[STTProvider] = [] """加载的 Speech To Text Provider 的实例""" - self.tts_provider_insts: List[TTSProvider] = [] + self.tts_provider_insts: list[TTSProvider] = [] """加载的 Text To Speech Provider 的实例""" - self.embedding_provider_insts: List[EmbeddingProvider] = [] + self.embedding_provider_insts: list[EmbeddingProvider] = [] """加载的 Embedding Provider 的实例""" - self.rerank_provider_insts: List[RerankProvider] = [] + self.rerank_provider_insts: list[RerankProvider] = [] """加载的 Rerank Provider 的实例""" self.inst_map: dict[ str, @@ -175,7 +174,11 @@ class ProviderManager: async def initialize(self): # 逐个初始化提供商 for provider_config in self.providers_config: - await self.load_provider(provider_config) + try: + await self.load_provider(provider_config) + except Exception as e: + logger.error(traceback.format_exc()) + logger.error(e) # 设置默认提供商 selected_provider_id = sp.get( @@ -404,10 +407,12 @@ class ProviderManager: self.inst_map[provider_config["id"]] = inst except Exception as e: - logger.error(traceback.format_exc()) logger.error( f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}" ) + raise Exception( + f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}" + ) async def reload(self, provider_config: dict): await self.terminate_provider(provider_config["id"]) diff --git a/astrbot/core/provider/sources/gemini_embedding_source.py b/astrbot/core/provider/sources/gemini_embedding_source.py index baccf52a2..562d11353 100644 --- a/astrbot/core/provider/sources/gemini_embedding_source.py +++ b/astrbot/core/provider/sources/gemini_embedding_source.py @@ -32,7 +32,6 @@ class GeminiEmbeddingProvider(EmbeddingProvider): self.model = provider_config.get( "embedding_model", "gemini-embedding-exp-03-07" ) - self.dimension = provider_config.get("embedding_dimensions", 768) async def get_embedding(self, text: str) -> list[float]: """ @@ -60,4 +59,4 @@ class GeminiEmbeddingProvider(EmbeddingProvider): def get_dim(self) -> int: """获取向量的维度""" - return self.dimension + return self.provider_config.get("embedding_dimensions", 768) diff --git a/astrbot/core/provider/sources/openai_embedding_source.py b/astrbot/core/provider/sources/openai_embedding_source.py index 79b2e83b2..e6f692a35 100644 --- a/astrbot/core/provider/sources/openai_embedding_source.py +++ b/astrbot/core/provider/sources/openai_embedding_source.py @@ -22,7 +22,6 @@ class OpenAIEmbeddingProvider(EmbeddingProvider): timeout=int(provider_config.get("timeout", 20)), ) self.model = provider_config.get("embedding_model", "text-embedding-3-small") - self.dimension = provider_config.get("embedding_dimensions", 1024) async def get_embedding(self, text: str) -> list[float]: """ @@ -40,4 +39,4 @@ class OpenAIEmbeddingProvider(EmbeddingProvider): def get_dim(self) -> int: """获取向量的维度""" - return self.dimension + return self.provider_config.get("embedding_dimensions", 1024) diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 01b9b2432..998240c99 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -1,4 +1,3 @@ -import typing import traceback import os import inspect @@ -45,9 +44,7 @@ def try_cast(value: str, type_: str): return None -def validate_config( - data, schema: dict, is_core: bool -) -> typing.Tuple[typing.List[str], typing.Dict]: +def validate_config(data, schema: dict, is_core: bool) -> tuple[list[str], dict]: errors = [] def validate(data: dict, metadata: dict = schema, path=""): @@ -178,6 +175,7 @@ class ConfigRoute(Route): "/config/provider/check_one": ("GET", self.check_one_provider_status), "/config/provider/list": ("GET", self.get_provider_config_list), "/config/provider/model_list": ("GET", self.get_provider_model_list), + "/config/provider/get_embedding_dim": ("POST", self.get_embedding_dim), } self.register_routes() @@ -601,6 +599,61 @@ class ConfigRoute(Route): logger.error(traceback.format_exc()) return Response().error(str(e)).__dict__ + async def get_embedding_dim(self): + """获取嵌入模型的维度""" + post_data = await request.json + provider_config = post_data.get("provider_config", None) + if not provider_config: + return Response().error("缺少参数 provider_config").__dict__ + + try: + # 动态导入 EmbeddingProvider + from astrbot.core.provider.provider import EmbeddingProvider + from astrbot.core.provider.register import provider_cls_map + + # 获取 provider 类型 + provider_type = provider_config.get("type", None) + if not provider_type: + return Response().error("provider_config 缺少 type 字段").__dict__ + + # 获取对应的 provider 类 + if provider_type not in provider_cls_map: + return ( + Response() + .error(f"未找到适用于 {provider_type} 的提供商适配器") + .__dict__ + ) + + provider_metadata = provider_cls_map[provider_type] + cls_type = provider_metadata.cls_type + + if not cls_type: + return Response().error(f"无法找到 {provider_type} 的类").__dict__ + + # 实例化 provider + inst = cls_type(provider_config, {}) + + # 检查是否是 EmbeddingProvider + if not isinstance(inst, EmbeddingProvider): + return Response().error("提供商不是 EmbeddingProvider 类型").__dict__ + + # 初始化 + if getattr(inst, "initialize", None): + await inst.initialize() + + # 获取嵌入向量维度 + vec = await inst.get_embedding("echo") + dim = len(vec) + + logger.info( + f"检测到 {provider_config.get('id', 'unknown')} 的嵌入向量维度为 {dim}" + ) + + return Response().ok({"embedding_dimensions": dim}).__dict__ + except Exception as e: + logger.error(traceback.format_exc()) + return Response().error(f"获取嵌入维度失败: {str(e)}").__dict__ + async def get_platform_list(self): """获取所有平台的列表""" platform_list = [] @@ -797,7 +850,7 @@ class ConfigRoute(Route): logger.warning( f"Failed to import required modules for platform {platform.name}: {e}" ) - except (OSError, IOError) as e: + except OSError as e: logger.warning(f"File system error for platform {platform.name} logo: {e}") except Exception as e: logger.warning( diff --git a/dashboard/src/components/shared/AstrBotConfig.vue b/dashboard/src/components/shared/AstrBotConfig.vue index e45327eb8..d6c6fee9c 100644 --- a/dashboard/src/components/shared/AstrBotConfig.vue +++ b/dashboard/src/components/shared/AstrBotConfig.vue @@ -7,6 +7,8 @@ import ProviderSelector from './ProviderSelector.vue' import PersonaSelector from './PersonaSelector.vue' import KnowledgeBaseSelector from './KnowledgeBaseSelector.vue' import { useI18n } from '@/i18n/composables' +import axios from 'axios' +import { useToast } from '@/utils/toast' const props = defineProps({ metadata: { @@ -40,6 +42,7 @@ const currentEditingKey = ref('') const currentEditingLanguage = ref('json') const currentEditingTheme = ref('vs-light') let currentEditingKeyIterable = null +const loadingEmbeddingDim = ref(false) function openEditorDialog(key, value, theme, language) { currentEditingKey.value = key @@ -49,10 +52,34 @@ function openEditorDialog(key, value, theme, language) { dialog.value = true } + function saveEditedContent() { dialog.value = false } +async function getEmbeddingDimensions(providerConfig) { + if (loadingEmbeddingDim.value) return + + loadingEmbeddingDim.value = true + try { + const response = await axios.post('/api/config/provider/get_embedding_dim', { + provider_config: providerConfig + }) + + if (response.data.status != "error" && response.data.data?.embedding_dimensions) { + console.log(response.data.data.embedding_dimensions) + providerConfig.embedding_dimensions = response.data.data.embedding_dimensions + useToast().success("获取成功: " + response.data.data.embedding_dimensions) + } else { + useToast().error(response.data.message) + } + } catch (error) { + console.error('Error getting embedding dimensions:', error) + } finally { + loadingEmbeddingDim.value = false + } +} + function getValueBySelector(obj, selector) { const keys = selector.split('.') let current = obj @@ -184,6 +211,29 @@ function hasVisibleItemsAfter(items, currentIndex) { v-model="iterable[key]" /> + +