feat: add dynamic embedding dimension retrieval for providers and enhance error handling
This commit is contained in:
@@ -1417,6 +1417,7 @@ CONFIG_METADATA_2 = {
|
||||
"description": "嵌入维度",
|
||||
"type": "int",
|
||||
"hint": "嵌入向量的维度。根据模型不同,可能需要调整,请参考具体模型的文档。此配置项请务必填写正确,否则将导致向量数据库无法正常工作。",
|
||||
"_special": "get_embedding_dim",
|
||||
},
|
||||
"embedding_model": {
|
||||
"description": "嵌入模型",
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
from typing import List
|
||||
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
@@ -28,7 +27,7 @@ class ProviderManager:
|
||||
self.persona_mgr = persona_mgr
|
||||
self.acm = acm
|
||||
config = acm.confs["default"]
|
||||
self.providers_config: List = config["provider"]
|
||||
self.providers_config: list = config["provider"]
|
||||
self.provider_settings: dict = config["provider_settings"]
|
||||
self.provider_stt_settings: dict = config.get("provider_stt_settings", {})
|
||||
self.provider_tts_settings: dict = config.get("provider_tts_settings", {})
|
||||
@@ -36,15 +35,15 @@ class ProviderManager:
|
||||
# 人格相关属性,v4.0.0 版本后被废弃,推荐使用 PersonaManager
|
||||
self.default_persona_name = persona_mgr.default_persona
|
||||
|
||||
self.provider_insts: List[Provider] = []
|
||||
self.provider_insts: list[Provider] = []
|
||||
"""加载的 Provider 的实例"""
|
||||
self.stt_provider_insts: List[STTProvider] = []
|
||||
self.stt_provider_insts: list[STTProvider] = []
|
||||
"""加载的 Speech To Text Provider 的实例"""
|
||||
self.tts_provider_insts: List[TTSProvider] = []
|
||||
self.tts_provider_insts: list[TTSProvider] = []
|
||||
"""加载的 Text To Speech Provider 的实例"""
|
||||
self.embedding_provider_insts: List[EmbeddingProvider] = []
|
||||
self.embedding_provider_insts: list[EmbeddingProvider] = []
|
||||
"""加载的 Embedding Provider 的实例"""
|
||||
self.rerank_provider_insts: List[RerankProvider] = []
|
||||
self.rerank_provider_insts: list[RerankProvider] = []
|
||||
"""加载的 Rerank Provider 的实例"""
|
||||
self.inst_map: dict[
|
||||
str,
|
||||
@@ -175,7 +174,11 @@ class ProviderManager:
|
||||
async def initialize(self):
|
||||
# 逐个初始化提供商
|
||||
for provider_config in self.providers_config:
|
||||
await self.load_provider(provider_config)
|
||||
try:
|
||||
await self.load_provider(provider_config)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(e)
|
||||
|
||||
# 设置默认提供商
|
||||
selected_provider_id = sp.get(
|
||||
@@ -404,10 +407,12 @@ class ProviderManager:
|
||||
|
||||
self.inst_map[provider_config["id"]] = inst
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(
|
||||
f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}"
|
||||
)
|
||||
raise Exception(
|
||||
f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}"
|
||||
)
|
||||
|
||||
async def reload(self, provider_config: dict):
|
||||
await self.terminate_provider(provider_config["id"])
|
||||
|
||||
@@ -32,7 +32,6 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
|
||||
self.model = provider_config.get(
|
||||
"embedding_model", "gemini-embedding-exp-03-07"
|
||||
)
|
||||
self.dimension = provider_config.get("embedding_dimensions", 768)
|
||||
|
||||
async def get_embedding(self, text: str) -> list[float]:
|
||||
"""
|
||||
@@ -60,4 +59,4 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
|
||||
|
||||
def get_dim(self) -> int:
|
||||
"""获取向量的维度"""
|
||||
return self.dimension
|
||||
return self.provider_config.get("embedding_dimensions", 768)
|
||||
|
||||
@@ -22,7 +22,6 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
timeout=int(provider_config.get("timeout", 20)),
|
||||
)
|
||||
self.model = provider_config.get("embedding_model", "text-embedding-3-small")
|
||||
self.dimension = provider_config.get("embedding_dimensions", 1024)
|
||||
|
||||
async def get_embedding(self, text: str) -> list[float]:
|
||||
"""
|
||||
@@ -40,4 +39,4 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
|
||||
def get_dim(self) -> int:
|
||||
"""获取向量的维度"""
|
||||
return self.dimension
|
||||
return self.provider_config.get("embedding_dimensions", 1024)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import typing
|
||||
import traceback
|
||||
import os
|
||||
import inspect
|
||||
@@ -45,9 +44,7 @@ def try_cast(value: str, type_: str):
|
||||
return None
|
||||
|
||||
|
||||
def validate_config(
|
||||
data, schema: dict, is_core: bool
|
||||
) -> typing.Tuple[typing.List[str], typing.Dict]:
|
||||
def validate_config(data, schema: dict, is_core: bool) -> tuple[list[str], dict]:
|
||||
errors = []
|
||||
|
||||
def validate(data: dict, metadata: dict = schema, path=""):
|
||||
@@ -178,6 +175,7 @@ class ConfigRoute(Route):
|
||||
"/config/provider/check_one": ("GET", self.check_one_provider_status),
|
||||
"/config/provider/list": ("GET", self.get_provider_config_list),
|
||||
"/config/provider/model_list": ("GET", self.get_provider_model_list),
|
||||
"/config/provider/get_embedding_dim": ("POST", self.get_embedding_dim),
|
||||
}
|
||||
self.register_routes()
|
||||
|
||||
@@ -601,6 +599,61 @@ class ConfigRoute(Route):
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def get_embedding_dim(self):
|
||||
"""获取嵌入模型的维度"""
|
||||
post_data = await request.json
|
||||
provider_config = post_data.get("provider_config", None)
|
||||
if not provider_config:
|
||||
return Response().error("缺少参数 provider_config").__dict__
|
||||
|
||||
try:
|
||||
# 动态导入 EmbeddingProvider
|
||||
from astrbot.core.provider.provider import EmbeddingProvider
|
||||
from astrbot.core.provider.register import provider_cls_map
|
||||
|
||||
# 获取 provider 类型
|
||||
provider_type = provider_config.get("type", None)
|
||||
if not provider_type:
|
||||
return Response().error("provider_config 缺少 type 字段").__dict__
|
||||
|
||||
# 获取对应的 provider 类
|
||||
if provider_type not in provider_cls_map:
|
||||
return (
|
||||
Response()
|
||||
.error(f"未找到适用于 {provider_type} 的提供商适配器")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
provider_metadata = provider_cls_map[provider_type]
|
||||
cls_type = provider_metadata.cls_type
|
||||
|
||||
if not cls_type:
|
||||
return Response().error(f"无法找到 {provider_type} 的类").__dict__
|
||||
|
||||
# 实例化 provider
|
||||
inst = cls_type(provider_config, {})
|
||||
|
||||
# 检查是否是 EmbeddingProvider
|
||||
if not isinstance(inst, EmbeddingProvider):
|
||||
return Response().error("提供商不是 EmbeddingProvider 类型").__dict__
|
||||
|
||||
# 初始化
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
|
||||
# 获取嵌入向量维度
|
||||
vec = await inst.get_embedding("echo")
|
||||
dim = len(vec)
|
||||
|
||||
logger.info(
|
||||
f"检测到 {provider_config.get('id', 'unknown')} 的嵌入向量维度为 {dim}"
|
||||
)
|
||||
|
||||
return Response().ok({"embedding_dimensions": dim}).__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"获取嵌入维度失败: {str(e)}").__dict__
|
||||
|
||||
async def get_platform_list(self):
|
||||
"""获取所有平台的列表"""
|
||||
platform_list = []
|
||||
@@ -797,7 +850,7 @@ class ConfigRoute(Route):
|
||||
logger.warning(
|
||||
f"Failed to import required modules for platform {platform.name}: {e}"
|
||||
)
|
||||
except (OSError, IOError) as e:
|
||||
except OSError as e:
|
||||
logger.warning(f"File system error for platform {platform.name} logo: {e}")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
|
||||
@@ -7,6 +7,8 @@ import ProviderSelector from './ProviderSelector.vue'
|
||||
import PersonaSelector from './PersonaSelector.vue'
|
||||
import KnowledgeBaseSelector from './KnowledgeBaseSelector.vue'
|
||||
import { useI18n } from '@/i18n/composables'
|
||||
import axios from 'axios'
|
||||
import { useToast } from '@/utils/toast'
|
||||
|
||||
const props = defineProps({
|
||||
metadata: {
|
||||
@@ -40,6 +42,7 @@ const currentEditingKey = ref('')
|
||||
const currentEditingLanguage = ref('json')
|
||||
const currentEditingTheme = ref('vs-light')
|
||||
let currentEditingKeyIterable = null
|
||||
const loadingEmbeddingDim = ref(false)
|
||||
|
||||
function openEditorDialog(key, value, theme, language) {
|
||||
currentEditingKey.value = key
|
||||
@@ -49,10 +52,34 @@ function openEditorDialog(key, value, theme, language) {
|
||||
dialog.value = true
|
||||
}
|
||||
|
||||
|
||||
function saveEditedContent() {
|
||||
dialog.value = false
|
||||
}
|
||||
|
||||
async function getEmbeddingDimensions(providerConfig) {
|
||||
if (loadingEmbeddingDim.value) return
|
||||
|
||||
loadingEmbeddingDim.value = true
|
||||
try {
|
||||
const response = await axios.post('/api/config/provider/get_embedding_dim', {
|
||||
provider_config: providerConfig
|
||||
})
|
||||
|
||||
if (response.data.status != "error" && response.data.data?.embedding_dimensions) {
|
||||
console.log(response.data.data.embedding_dimensions)
|
||||
providerConfig.embedding_dimensions = response.data.data.embedding_dimensions
|
||||
useToast().success("获取成功: " + response.data.data.embedding_dimensions)
|
||||
} else {
|
||||
useToast().error(response.data.message)
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error getting embedding dimensions:', error)
|
||||
} finally {
|
||||
loadingEmbeddingDim.value = false
|
||||
}
|
||||
}
|
||||
|
||||
function getValueBySelector(obj, selector) {
|
||||
const keys = selector.split('.')
|
||||
let current = obj
|
||||
@@ -184,6 +211,29 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
v-model="iterable[key]"
|
||||
/>
|
||||
</div>
|
||||
<!-- Numeric input with get_embedding_dim button -->
|
||||
<div v-else-if="metadata[metadataKey].items[key]?._special === 'get_embedding_dim'"
|
||||
class="d-flex align-center gap-2">
|
||||
<v-text-field
|
||||
v-model="iterable[key]"
|
||||
density="compact"
|
||||
variant="outlined"
|
||||
class="config-field"
|
||||
type="number"
|
||||
hide-details
|
||||
></v-text-field>
|
||||
<v-btn
|
||||
color="primary"
|
||||
variant="tonal"
|
||||
size="small"
|
||||
@click="getEmbeddingDimensions(iterable)"
|
||||
:loading="loadingEmbeddingDim"
|
||||
class="ml-2"
|
||||
>
|
||||
自动检测
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<!-- List item with options-->
|
||||
<div v-else-if="metadata[metadataKey].items[key]?.type === 'list' && metadata[metadataKey].items[key]?.options && !metadata[metadataKey].items[key]?.invisible && metadata[metadataKey].items[key]?.render_type === 'checkbox'"
|
||||
class="d-flex flex-wrap gap-20">
|
||||
|
||||
@@ -12,7 +12,8 @@
|
||||
</p>
|
||||
</div>
|
||||
<div>
|
||||
<v-btn color="primary" prepend-icon="mdi-plus" variant="tonal" @click="showAddProviderDialog = true" rounded="xl" size="x-large">
|
||||
<v-btn color="primary" prepend-icon="mdi-plus" variant="tonal" @click="showAddProviderDialog = true"
|
||||
rounded="xl" size="x-large">
|
||||
{{ tm('providers.addProvider') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
@@ -56,30 +57,16 @@
|
||||
|
||||
<v-row v-else>
|
||||
<v-col v-for="(provider, index) in filteredProviders" :key="index" cols="12" md="6" lg="4" xl="3">
|
||||
<item-card
|
||||
:item="provider"
|
||||
title-field="id"
|
||||
enabled-field="enable"
|
||||
:loading="isProviderTesting(provider.id)"
|
||||
@toggle-enabled="providerStatusChange"
|
||||
:bglogo="getProviderIcon(provider.provider)"
|
||||
@delete="deleteProvider"
|
||||
@edit="configExistingProvider"
|
||||
@copy="copyProvider"
|
||||
:show-copy-button="true">
|
||||
<template #actions="{ item }">
|
||||
<v-btn
|
||||
style="z-index: 100000;"
|
||||
variant="tonal"
|
||||
color="info"
|
||||
rounded="xl"
|
||||
size="small"
|
||||
:loading="isProviderTesting(item.id)"
|
||||
@click="testSingleProvider(item)"
|
||||
>
|
||||
{{ tm('availability.test') }}
|
||||
</v-btn>
|
||||
</template>
|
||||
<item-card :item="provider" title-field="id" enabled-field="enable"
|
||||
:loading="isProviderTesting(provider.id)" @toggle-enabled="providerStatusChange"
|
||||
:bglogo="getProviderIcon(provider.provider)" @delete="deleteProvider" @edit="configExistingProvider"
|
||||
@copy="copyProvider" :show-copy-button="true">
|
||||
<template #actions="{ item }">
|
||||
<v-btn style="z-index: 100000;" variant="tonal" color="info" rounded="xl" size="small"
|
||||
:loading="isProviderTesting(item.id)" @click="testSingleProvider(item)">
|
||||
{{ tm('availability.test') }}
|
||||
</v-btn>
|
||||
</template>
|
||||
<template v-slot:details="{ item }">
|
||||
</template>
|
||||
</item-card>
|
||||
@@ -115,16 +102,12 @@
|
||||
<v-col v-for="status in providerStatuses" :key="status.id" cols="12" sm="6" md="4">
|
||||
<v-card variant="outlined" class="status-card" :class="`status-${status.status}`">
|
||||
<v-card-item>
|
||||
<v-icon v-if="status.status === 'available'" color="success" class="me-2">mdi-check-circle</v-icon>
|
||||
<v-icon v-else-if="status.status === 'unavailable'" color="error" class="me-2">mdi-alert-circle</v-icon>
|
||||
<v-progress-circular
|
||||
v-else-if="status.status === 'pending'"
|
||||
indeterminate
|
||||
color="primary"
|
||||
size="20"
|
||||
width="2"
|
||||
class="me-2"
|
||||
></v-progress-circular>
|
||||
<v-icon v-if="status.status === 'available'" color="success"
|
||||
class="me-2">mdi-check-circle</v-icon>
|
||||
<v-icon v-else-if="status.status === 'unavailable'" color="error"
|
||||
class="me-2">mdi-alert-circle</v-icon>
|
||||
<v-progress-circular v-else-if="status.status === 'pending'" indeterminate color="primary"
|
||||
size="20" width="2" class="me-2"></v-progress-circular>
|
||||
|
||||
<span class="font-weight-bold">{{ status.id }}</span>
|
||||
|
||||
@@ -165,22 +148,16 @@
|
||||
</v-container>
|
||||
|
||||
<!-- 添加提供商对话框 -->
|
||||
<AddNewProvider
|
||||
v-model:show="showAddProviderDialog"
|
||||
:metadata="metadata"
|
||||
@select-template="selectProviderTemplate"
|
||||
/>
|
||||
<AddNewProvider v-model:show="showAddProviderDialog" :metadata="metadata"
|
||||
@select-template="selectProviderTemplate" />
|
||||
|
||||
<!-- 配置对话框 -->
|
||||
<v-dialog v-model="showProviderCfg" width="900" persistent>
|
||||
<v-card :title="updatingMode ? tm('dialogs.config.editTitle') : tm('dialogs.config.addTitle') + ` ${newSelectedProviderName} ` + tm('dialogs.config.provider')">
|
||||
<v-card
|
||||
:title="updatingMode ? tm('dialogs.config.editTitle') : tm('dialogs.config.addTitle') + ` ${newSelectedProviderName} ` + tm('dialogs.config.provider')">
|
||||
<v-card-text class="py-4">
|
||||
<AstrBotConfig
|
||||
:iterable="newSelectedProviderConfig"
|
||||
:metadata="metadata['provider_group']?.metadata"
|
||||
metadataKey="provider"
|
||||
:is-editing="updatingMode"
|
||||
/>
|
||||
<AstrBotConfig :iterable="newSelectedProviderConfig" :metadata="metadata['provider_group']?.metadata"
|
||||
metadataKey="provider" :is-editing="updatingMode" />
|
||||
</v-card-text>
|
||||
|
||||
<v-divider></v-divider>
|
||||
@@ -467,7 +444,7 @@ export default {
|
||||
for (let key in source) {
|
||||
if (source.hasOwnProperty(key)) {
|
||||
if (typeof source[key] === 'object' && source[key] !== null) {
|
||||
target[key] = Array.isArray(source[key]) ? [...source[key]] : {...source[key]};
|
||||
target[key] = Array.isArray(source[key]) ? [...source[key]] : { ...source[key] };
|
||||
} else {
|
||||
target[key] = source[key];
|
||||
}
|
||||
@@ -528,7 +505,14 @@ export default {
|
||||
id: this.newSelectedProviderName,
|
||||
config: this.newSelectedProviderConfig
|
||||
});
|
||||
if (res.data.status === 'error') {
|
||||
this.showError(res.data.message || "更新失败!");
|
||||
return
|
||||
}
|
||||
this.showSuccess(res.data.message || "更新成功!");
|
||||
if (wasUpdating) {
|
||||
this.updatingMode = false;
|
||||
}
|
||||
} else {
|
||||
// 检查 ID 是否已存在
|
||||
const existingProvider = this.config_data.provider?.find(p => p.id === this.newSelectedProviderConfig.id);
|
||||
@@ -541,17 +525,18 @@ export default {
|
||||
}
|
||||
|
||||
const res = await axios.post('/api/config/provider/new', this.newSelectedProviderConfig);
|
||||
if (res.data.status === 'error') {
|
||||
this.showError(res.data.message || "添加失败!");
|
||||
return
|
||||
}
|
||||
this.showSuccess(res.data.message || "添加成功!");
|
||||
}
|
||||
this.showProviderCfg = false;
|
||||
this.getConfig();
|
||||
} catch (err) {
|
||||
this.showError(err.response?.data?.message || err.message);
|
||||
} finally {
|
||||
this.loading = false;
|
||||
if (wasUpdating) {
|
||||
this.updatingMode = false;
|
||||
}
|
||||
this.getConfig();
|
||||
}
|
||||
},
|
||||
|
||||
@@ -607,6 +592,10 @@ export default {
|
||||
id: provider.id,
|
||||
config: provider
|
||||
}).then((res) => {
|
||||
if (res.data.status === 'error') {
|
||||
this.showError(res.data.message)
|
||||
return
|
||||
}
|
||||
this.getConfig();
|
||||
this.showSuccess(res.data.message || this.messages.success.statusUpdate);
|
||||
}).catch((err) => {
|
||||
|
||||
Reference in New Issue
Block a user