feat: add dynamic embedding dimension retrieval for providers and enhance error handling

This commit is contained in:
Soulter
2025-10-25 16:39:11 +08:00
parent f7d018cf94
commit 33618c4a6b
7 changed files with 166 additions and 70 deletions
+1
View File
@@ -1417,6 +1417,7 @@ CONFIG_METADATA_2 = {
"description": "嵌入维度",
"type": "int",
"hint": "嵌入向量的维度。根据模型不同,可能需要调整,请参考具体模型的文档。此配置项请务必填写正确,否则将导致向量数据库无法正常工作。",
"_special": "get_embedding_dim",
},
"embedding_model": {
"description": "嵌入模型",
+14 -9
View File
@@ -1,6 +1,5 @@
import asyncio
import traceback
from typing import List
from astrbot.core import logger, sp
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
@@ -28,7 +27,7 @@ class ProviderManager:
self.persona_mgr = persona_mgr
self.acm = acm
config = acm.confs["default"]
self.providers_config: List = config["provider"]
self.providers_config: list = config["provider"]
self.provider_settings: dict = config["provider_settings"]
self.provider_stt_settings: dict = config.get("provider_stt_settings", {})
self.provider_tts_settings: dict = config.get("provider_tts_settings", {})
@@ -36,15 +35,15 @@ class ProviderManager:
# 人格相关属性,v4.0.0 版本后被废弃,推荐使用 PersonaManager
self.default_persona_name = persona_mgr.default_persona
self.provider_insts: List[Provider] = []
self.provider_insts: list[Provider] = []
"""加载的 Provider 的实例"""
self.stt_provider_insts: List[STTProvider] = []
self.stt_provider_insts: list[STTProvider] = []
"""加载的 Speech To Text Provider 的实例"""
self.tts_provider_insts: List[TTSProvider] = []
self.tts_provider_insts: list[TTSProvider] = []
"""加载的 Text To Speech Provider 的实例"""
self.embedding_provider_insts: List[EmbeddingProvider] = []
self.embedding_provider_insts: list[EmbeddingProvider] = []
"""加载的 Embedding Provider 的实例"""
self.rerank_provider_insts: List[RerankProvider] = []
self.rerank_provider_insts: list[RerankProvider] = []
"""加载的 Rerank Provider 的实例"""
self.inst_map: dict[
str,
@@ -175,7 +174,11 @@ class ProviderManager:
async def initialize(self):
# 逐个初始化提供商
for provider_config in self.providers_config:
await self.load_provider(provider_config)
try:
await self.load_provider(provider_config)
except Exception as e:
logger.error(traceback.format_exc())
logger.error(e)
# 设置默认提供商
selected_provider_id = sp.get(
@@ -404,10 +407,12 @@ class ProviderManager:
self.inst_map[provider_config["id"]] = inst
except Exception as e:
logger.error(traceback.format_exc())
logger.error(
f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}"
)
raise Exception(
f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}"
)
async def reload(self, provider_config: dict):
await self.terminate_provider(provider_config["id"])
@@ -32,7 +32,6 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
self.model = provider_config.get(
"embedding_model", "gemini-embedding-exp-03-07"
)
self.dimension = provider_config.get("embedding_dimensions", 768)
async def get_embedding(self, text: str) -> list[float]:
"""
@@ -60,4 +59,4 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
def get_dim(self) -> int:
"""获取向量的维度"""
return self.dimension
return self.provider_config.get("embedding_dimensions", 768)
@@ -22,7 +22,6 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
timeout=int(provider_config.get("timeout", 20)),
)
self.model = provider_config.get("embedding_model", "text-embedding-3-small")
self.dimension = provider_config.get("embedding_dimensions", 1024)
async def get_embedding(self, text: str) -> list[float]:
"""
@@ -40,4 +39,4 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
def get_dim(self) -> int:
"""获取向量的维度"""
return self.dimension
return self.provider_config.get("embedding_dimensions", 1024)
+58 -5
View File
@@ -1,4 +1,3 @@
import typing
import traceback
import os
import inspect
@@ -45,9 +44,7 @@ def try_cast(value: str, type_: str):
return None
def validate_config(
data, schema: dict, is_core: bool
) -> typing.Tuple[typing.List[str], typing.Dict]:
def validate_config(data, schema: dict, is_core: bool) -> tuple[list[str], dict]:
errors = []
def validate(data: dict, metadata: dict = schema, path=""):
@@ -178,6 +175,7 @@ class ConfigRoute(Route):
"/config/provider/check_one": ("GET", self.check_one_provider_status),
"/config/provider/list": ("GET", self.get_provider_config_list),
"/config/provider/model_list": ("GET", self.get_provider_model_list),
"/config/provider/get_embedding_dim": ("POST", self.get_embedding_dim),
}
self.register_routes()
@@ -601,6 +599,61 @@ class ConfigRoute(Route):
logger.error(traceback.format_exc())
return Response().error(str(e)).__dict__
async def get_embedding_dim(self):
"""获取嵌入模型的维度"""
post_data = await request.json
provider_config = post_data.get("provider_config", None)
if not provider_config:
return Response().error("缺少参数 provider_config").__dict__
try:
# 动态导入 EmbeddingProvider
from astrbot.core.provider.provider import EmbeddingProvider
from astrbot.core.provider.register import provider_cls_map
# 获取 provider 类型
provider_type = provider_config.get("type", None)
if not provider_type:
return Response().error("provider_config 缺少 type 字段").__dict__
# 获取对应的 provider 类
if provider_type not in provider_cls_map:
return (
Response()
.error(f"未找到适用于 {provider_type} 的提供商适配器")
.__dict__
)
provider_metadata = provider_cls_map[provider_type]
cls_type = provider_metadata.cls_type
if not cls_type:
return Response().error(f"无法找到 {provider_type} 的类").__dict__
# 实例化 provider
inst = cls_type(provider_config, {})
# 检查是否是 EmbeddingProvider
if not isinstance(inst, EmbeddingProvider):
return Response().error("提供商不是 EmbeddingProvider 类型").__dict__
# 初始化
if getattr(inst, "initialize", None):
await inst.initialize()
# 获取嵌入向量维度
vec = await inst.get_embedding("echo")
dim = len(vec)
logger.info(
f"检测到 {provider_config.get('id', 'unknown')} 的嵌入向量维度为 {dim}"
)
return Response().ok({"embedding_dimensions": dim}).__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"获取嵌入维度失败: {str(e)}").__dict__
async def get_platform_list(self):
"""获取所有平台的列表"""
platform_list = []
@@ -797,7 +850,7 @@ class ConfigRoute(Route):
logger.warning(
f"Failed to import required modules for platform {platform.name}: {e}"
)
except (OSError, IOError) as e:
except OSError as e:
logger.warning(f"File system error for platform {platform.name} logo: {e}")
except Exception as e:
logger.warning(
@@ -7,6 +7,8 @@ import ProviderSelector from './ProviderSelector.vue'
import PersonaSelector from './PersonaSelector.vue'
import KnowledgeBaseSelector from './KnowledgeBaseSelector.vue'
import { useI18n } from '@/i18n/composables'
import axios from 'axios'
import { useToast } from '@/utils/toast'
const props = defineProps({
metadata: {
@@ -40,6 +42,7 @@ const currentEditingKey = ref('')
const currentEditingLanguage = ref('json')
const currentEditingTheme = ref('vs-light')
let currentEditingKeyIterable = null
const loadingEmbeddingDim = ref(false)
function openEditorDialog(key, value, theme, language) {
currentEditingKey.value = key
@@ -49,10 +52,34 @@ function openEditorDialog(key, value, theme, language) {
dialog.value = true
}
function saveEditedContent() {
dialog.value = false
}
async function getEmbeddingDimensions(providerConfig) {
if (loadingEmbeddingDim.value) return
loadingEmbeddingDim.value = true
try {
const response = await axios.post('/api/config/provider/get_embedding_dim', {
provider_config: providerConfig
})
if (response.data.status != "error" && response.data.data?.embedding_dimensions) {
console.log(response.data.data.embedding_dimensions)
providerConfig.embedding_dimensions = response.data.data.embedding_dimensions
useToast().success("获取成功: " + response.data.data.embedding_dimensions)
} else {
useToast().error(response.data.message)
}
} catch (error) {
console.error('Error getting embedding dimensions:', error)
} finally {
loadingEmbeddingDim.value = false
}
}
function getValueBySelector(obj, selector) {
const keys = selector.split('.')
let current = obj
@@ -184,6 +211,29 @@ function hasVisibleItemsAfter(items, currentIndex) {
v-model="iterable[key]"
/>
</div>
<!-- Numeric input with get_embedding_dim button -->
<div v-else-if="metadata[metadataKey].items[key]?._special === 'get_embedding_dim'"
class="d-flex align-center gap-2">
<v-text-field
v-model="iterable[key]"
density="compact"
variant="outlined"
class="config-field"
type="number"
hide-details
></v-text-field>
<v-btn
color="primary"
variant="tonal"
size="small"
@click="getEmbeddingDimensions(iterable)"
:loading="loadingEmbeddingDim"
class="ml-2"
>
自动检测
</v-btn>
</div>
<!-- List item with options-->
<div v-else-if="metadata[metadataKey].items[key]?.type === 'list' && metadata[metadataKey].items[key]?.options && !metadata[metadataKey].items[key]?.invisible && metadata[metadataKey].items[key]?.render_type === 'checkbox'"
class="d-flex flex-wrap gap-20">
+41 -52
View File
@@ -12,7 +12,8 @@
</p>
</div>
<div>
<v-btn color="primary" prepend-icon="mdi-plus" variant="tonal" @click="showAddProviderDialog = true" rounded="xl" size="x-large">
<v-btn color="primary" prepend-icon="mdi-plus" variant="tonal" @click="showAddProviderDialog = true"
rounded="xl" size="x-large">
{{ tm('providers.addProvider') }}
</v-btn>
</div>
@@ -56,30 +57,16 @@
<v-row v-else>
<v-col v-for="(provider, index) in filteredProviders" :key="index" cols="12" md="6" lg="4" xl="3">
<item-card
:item="provider"
title-field="id"
enabled-field="enable"
:loading="isProviderTesting(provider.id)"
@toggle-enabled="providerStatusChange"
:bglogo="getProviderIcon(provider.provider)"
@delete="deleteProvider"
@edit="configExistingProvider"
@copy="copyProvider"
:show-copy-button="true">
<template #actions="{ item }">
<v-btn
style="z-index: 100000;"
variant="tonal"
color="info"
rounded="xl"
size="small"
:loading="isProviderTesting(item.id)"
@click="testSingleProvider(item)"
>
{{ tm('availability.test') }}
</v-btn>
</template>
<item-card :item="provider" title-field="id" enabled-field="enable"
:loading="isProviderTesting(provider.id)" @toggle-enabled="providerStatusChange"
:bglogo="getProviderIcon(provider.provider)" @delete="deleteProvider" @edit="configExistingProvider"
@copy="copyProvider" :show-copy-button="true">
<template #actions="{ item }">
<v-btn style="z-index: 100000;" variant="tonal" color="info" rounded="xl" size="small"
:loading="isProviderTesting(item.id)" @click="testSingleProvider(item)">
{{ tm('availability.test') }}
</v-btn>
</template>
<template v-slot:details="{ item }">
</template>
</item-card>
@@ -115,16 +102,12 @@
<v-col v-for="status in providerStatuses" :key="status.id" cols="12" sm="6" md="4">
<v-card variant="outlined" class="status-card" :class="`status-${status.status}`">
<v-card-item>
<v-icon v-if="status.status === 'available'" color="success" class="me-2">mdi-check-circle</v-icon>
<v-icon v-else-if="status.status === 'unavailable'" color="error" class="me-2">mdi-alert-circle</v-icon>
<v-progress-circular
v-else-if="status.status === 'pending'"
indeterminate
color="primary"
size="20"
width="2"
class="me-2"
></v-progress-circular>
<v-icon v-if="status.status === 'available'" color="success"
class="me-2">mdi-check-circle</v-icon>
<v-icon v-else-if="status.status === 'unavailable'" color="error"
class="me-2">mdi-alert-circle</v-icon>
<v-progress-circular v-else-if="status.status === 'pending'" indeterminate color="primary"
size="20" width="2" class="me-2"></v-progress-circular>
<span class="font-weight-bold">{{ status.id }}</span>
@@ -165,22 +148,16 @@
</v-container>
<!-- 添加提供商对话框 -->
<AddNewProvider
v-model:show="showAddProviderDialog"
:metadata="metadata"
@select-template="selectProviderTemplate"
/>
<AddNewProvider v-model:show="showAddProviderDialog" :metadata="metadata"
@select-template="selectProviderTemplate" />
<!-- 配置对话框 -->
<v-dialog v-model="showProviderCfg" width="900" persistent>
<v-card :title="updatingMode ? tm('dialogs.config.editTitle') : tm('dialogs.config.addTitle') + ` ${newSelectedProviderName} ` + tm('dialogs.config.provider')">
<v-card
:title="updatingMode ? tm('dialogs.config.editTitle') : tm('dialogs.config.addTitle') + ` ${newSelectedProviderName} ` + tm('dialogs.config.provider')">
<v-card-text class="py-4">
<AstrBotConfig
:iterable="newSelectedProviderConfig"
:metadata="metadata['provider_group']?.metadata"
metadataKey="provider"
:is-editing="updatingMode"
/>
<AstrBotConfig :iterable="newSelectedProviderConfig" :metadata="metadata['provider_group']?.metadata"
metadataKey="provider" :is-editing="updatingMode" />
</v-card-text>
<v-divider></v-divider>
@@ -467,7 +444,7 @@ export default {
for (let key in source) {
if (source.hasOwnProperty(key)) {
if (typeof source[key] === 'object' && source[key] !== null) {
target[key] = Array.isArray(source[key]) ? [...source[key]] : {...source[key]};
target[key] = Array.isArray(source[key]) ? [...source[key]] : { ...source[key] };
} else {
target[key] = source[key];
}
@@ -528,7 +505,14 @@ export default {
id: this.newSelectedProviderName,
config: this.newSelectedProviderConfig
});
if (res.data.status === 'error') {
this.showError(res.data.message || "更新失败!");
return
}
this.showSuccess(res.data.message || "更新成功!");
if (wasUpdating) {
this.updatingMode = false;
}
} else {
// ID
const existingProvider = this.config_data.provider?.find(p => p.id === this.newSelectedProviderConfig.id);
@@ -541,17 +525,18 @@ export default {
}
const res = await axios.post('/api/config/provider/new', this.newSelectedProviderConfig);
if (res.data.status === 'error') {
this.showError(res.data.message || "添加失败!");
return
}
this.showSuccess(res.data.message || "添加成功!");
}
this.showProviderCfg = false;
this.getConfig();
} catch (err) {
this.showError(err.response?.data?.message || err.message);
} finally {
this.loading = false;
if (wasUpdating) {
this.updatingMode = false;
}
this.getConfig();
}
},
@@ -607,6 +592,10 @@ export default {
id: provider.id,
config: provider
}).then((res) => {
if (res.data.status === 'error') {
this.showError(res.data.message)
return
}
this.getConfig();
this.showSuccess(res.data.message || this.messages.success.statusUpdate);
}).catch((err) => {