feat: embedding provider

This commit is contained in:
Soulter
2025-05-30 18:07:52 +08:00
parent ebb6665f64
commit 8288d5e51f
9 changed files with 210 additions and 15 deletions
+26
View File
@@ -862,8 +862,34 @@ CONFIG_METADATA_2 = {
"api_base": "https://openspeech.bytedance.com/api/v1/tts",
"timeout": 20,
},
"OpenAI Embedding": {
"id": "openai_embedding",
"type": "openai_embedding",
"provider_type": "embedding",
"enable": True,
"embedding_api_key": "",
"embedding_api_base": "",
"embedding_model": "",
"embedding_dimensions": 1536,
"timeout": 20,
},
},
"items": {
"embedding_dimensions": {
"description": "嵌入维度",
"type": "int",
"hint": "嵌入向量的维度。根据模型不同,可能需要调整,请参考具体模型的文档。此配置项请务必填写正确,否则将导致向量数据库无法正常工作。",
},
"embedding_model": {
"description": "嵌入模型",
"type": "string",
"hint": "嵌入模型名称。",
},
"embedding_api_key": {
"description": "API Key",
"type": "string",
"hint": "API Key",
},
"volcengine_cluster": {
"type": "string",
"description": "火山引擎集群",
+4 -1
View File
@@ -19,6 +19,7 @@ class ProviderType(enum.Enum):
CHAT_COMPLETION = "chat_completion"
SPEECH_TO_TEXT = "speech_to_text"
TEXT_TO_SPEECH = "text_to_speech"
EMBEDDING = "embedding"
@dataclass
@@ -155,7 +156,9 @@ class ProviderRequest:
if self.image_urls:
user_content = {
"role": "user",
"content": [{"type": "text", "text": self.prompt if self.prompt else "[图片]"}],
"content": [
{"type": "text", "text": self.prompt if self.prompt else "[图片]"}
],
}
for image_url in self.image_urls:
if image_url.startswith("http"):
+14
View File
@@ -98,6 +98,8 @@ class ProviderManager:
"""加载的 Speech To Text Provider 的实例"""
self.tts_provider_insts: List[TTSProvider] = []
"""加载的 Text To Speech Provider 的实例"""
self.embedding_provider_insts: List[Provider] = []
"""加载的 Embedding Provider 的实例"""
self.inst_map = {}
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
self.llm_tools = llm_tools
@@ -211,6 +213,10 @@ class ProviderManager:
from .sources.volcengine_tts import (
ProviderVolcengineTTS as ProviderVolcengineTTS,
)
case "openai_embedding":
from .sources.openai_embedding_source import (
OpenAIEmbeddingProvider as OpenAIEmbeddingProvider,
)
except (ImportError, ModuleNotFoundError) as e:
logger.critical(
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。"
@@ -290,6 +296,14 @@ class ProviderManager:
if not self.curr_provider_inst:
self.curr_provider_inst = inst
elif provider_metadata.provider_type == ProviderType.EMBEDDING:
inst = provider_metadata.cls_type(
provider_config, self.provider_settings
)
if getattr(inst, "initialize", None):
await inst.initialize()
self.embedding_provider_insts.append(inst)
self.inst_map[provider_config["id"]] = inst
except Exception as e:
logger.error(traceback.format_exc())
+5
View File
@@ -192,6 +192,11 @@ class EmbeddingProvider(AbstractProvider):
"""获取文本的向量"""
...
@abc.abstractmethod
async def get_embeddings(self, text: list[str]) -> list[list[float]]:
"""批量获取文本的向量"""
...
@abc.abstractmethod
def get_dim(self) -> int:
"""获取向量的维度"""
@@ -0,0 +1,42 @@
from openai import AsyncOpenAI
from ..provider import EmbeddingProvider
from ..register import register_provider_adapter
from ..entities import ProviderType
@register_provider_adapter(
"openai_embedding",
"OpenAI API Embedding 提供商适配器",
provider_type=ProviderType.EMBEDDING,
)
class OpenAIEmbeddingProvider(EmbeddingProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config, provider_settings)
self.provider_config = provider_config
self.provider_settings = provider_settings
self.client = AsyncOpenAI(
api_key=provider_config.get("embedding_api_key"),
base_url=provider_config.get(
"embedding_api_base", "https://api.openai.com/v1"
),
)
self.model = provider_config.get("embedding_model", "text-embedding-3-small")
self.dimension = provider_config.get("embedding_dimensions", 1536)
async def get_embedding(self, text: str) -> list[float]:
"""
获取文本的嵌入
"""
embedding = await self.client.embeddings.create(input=text, model=self.model)
return embedding.data[0].embedding
async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
"""
批量获取文本的嵌入
"""
embeddings = await self.client.embeddings.create(input=texts, model=self.model)
return [item.embedding for item in embeddings.data]
def get_dim(self) -> int:
"""获取向量的维度"""
return self.dimension
+2 -5
View File
@@ -125,11 +125,8 @@ class Context:
self.provider_manager.provider_insts.append(provider)
def get_provider_by_id(self, provider_id: str) -> Provider:
"""通过 ID 获取用于文本生成任务的 LLM Provider(Chat_Completion 类型)。"""
for provider in self.provider_manager.provider_insts:
if provider.meta().id == provider_id:
return provider
return None
"""通过 ID 获取对应的 LLM Provider(Chat_Completion 类型)。"""
return self.provider_manager.inst_map.get(provider_id)
def get_all_providers(self) -> List[Provider]:
"""获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。"""
+12
View File
@@ -164,6 +164,7 @@ class ConfigRoute(Route):
"/config/provider/update": ("POST", self.post_update_provider),
"/config/provider/delete": ("POST", self.post_delete_provider),
"/config/llmtools": ("GET", self.get_llm_tools),
"/config/provider/list": ("GET", self.get_provider_config_list),
}
self.register_routes()
@@ -175,6 +176,17 @@ class ConfigRoute(Route):
return Response().ok(await self._get_astrbot_config()).__dict__
return Response().ok(await self._get_plugin_config(plugin_name)).__dict__
async def get_provider_config_list(self):
provider_type = request.args.get("provider_type", None)
if not provider_type:
return Response().error("缺少参数 provider_type").__dict__
provider_list = []
astrbot_config = self.core_lifecycle.astrbot_config
for provider in astrbot_config["provider"]:
if provider.get("provider_type", None) == provider_type:
provider_list.append(provider)
return Response().ok(provider_list).__dict__
async def post_astrbot_configs(self):
post_configs = await request.json
try:
+59 -4
View File
@@ -27,13 +27,39 @@
<v-divider></v-divider>
<!-- 添加分类标签页 -->
<v-card-text class="px-4 pt-3 pb-0">
<v-tabs v-model="activeProviderTypeTab" bg-color="transparent">
<v-tab value="all" class="font-weight-medium px-3">
<v-icon start>mdi-filter-variant</v-icon>
全部
</v-tab>
<v-tab value="chat_completion" class="font-weight-medium px-3">
<v-icon start>mdi-message-text</v-icon>
基本对话
</v-tab>
<v-tab value="speech_to_text" class="font-weight-medium px-3">
<v-icon start>mdi-microphone-message</v-icon>
语音转文字
</v-tab>
<v-tab value="text_to_speech" class="font-weight-medium px-3">
<v-icon start>mdi-volume-high</v-icon>
文字转语音
</v-tab>
<v-tab value="embedding" class="font-weight-medium px-3">
<v-icon start>mdi-code-json</v-icon>
Embedding
</v-tab>
</v-tabs>
</v-card-text>
<v-card-text class="px-4 py-3">
<item-card-grid
:items="config_data.provider || []"
:items="filteredProviders"
title-field="id"
enabled-field="enable"
empty-icon="mdi-api-off"
empty-text="暂无服务提供商点击 新增服务提供商 添加"
:empty-text="getEmptyText()"
@toggle-enabled="providerStatusChange"
@delete="deleteProvider"
@edit="configExistingProvider"
@@ -109,10 +135,14 @@
<v-icon start>mdi-volume-high</v-icon>
文字转语音
</v-tab>
<v-tab value="embedding" class="font-weight-medium px-3">
<v-icon start>mdi-code-json</v-icon>
Embedding
</v-tab>
</v-tabs>
<v-window v-model="activeProviderTab" class="mt-4">
<v-window-item v-for="tabType in ['chat_completion', 'speech_to_text', 'text_to_speech']"
<v-window-item v-for="tabType in ['chat_completion', 'speech_to_text', 'text_to_speech', 'embedding']"
:key="tabType"
:value="tabType">
<v-row class="mt-1">
@@ -225,6 +255,21 @@ export default {
// 新增提供商对话框相关
showAddProviderDialog: false,
activeProviderTab: 'chat_completion',
// 添加提供商类型分类
activeProviderTypeTab: 'all',
}
},
computed: {
// 根据选择的标签过滤提供商列表
filteredProviders() {
if (!this.config_data.provider || this.activeProviderTypeTab === 'all') {
return this.config_data.provider || [];
}
return this.config_data.provider.filter(provider =>
provider.provider_type === this.activeProviderTypeTab);
}
},
@@ -243,6 +288,15 @@ export default {
});
},
// 获取空列表文本
getEmptyText() {
if (this.activeProviderTypeTab === 'all') {
return "暂无服务提供商,点击 新增服务提供商 添加";
} else {
return `暂无${this.getTabTypeName(this.activeProviderTypeTab)}类型的服务提供商,点击 新增服务提供商 添加`;
}
},
// 按提供商类型获取模板列表
getTemplatesByType(type) {
const templates = this.metadata['provider_group']?.metadata?.provider?.config_template || {};
@@ -294,7 +348,8 @@ export default {
const names = {
'chat_completion': '基本对话',
'speech_to_text': '语音转文本',
'text_to_speech': '文本转语音'
'text_to_speech': '文本转语音',
'embedding': 'Embedding'
};
return names[tabType] || tabType;
},
+46 -5
View File
@@ -72,6 +72,10 @@
<v-textarea v-model="newKB.description" label="描述" variant="outlined" placeholder="知识库的简短描述..."
rows="3"></v-textarea>
<v-select v-model="newKB.embedding_provider_id" :items="embeddingProviderConfigs" :item-props="embeddingModelProps" label="Embedding(嵌入)模型"
variant="outlined" class="mt-2">
</v-select>
</v-form>
</v-card-text>
<v-card-actions>
@@ -256,7 +260,8 @@ export default {
newKB: {
name: '',
emoji: '🙂',
description: ''
description: '',
embedding_provider_id: ''
},
snackbar: {
show: false,
@@ -306,13 +311,21 @@ export default {
deleteTarget: {
collection_name: ''
},
deleting: false
deleting: false,
embeddingProviderConfigs: []
}
},
mounted() {
this.checkPlugin();
this.getEmbeddingProviderList();
},
methods: {
embeddingModelProps(providerConfig) {
return {
title: providerConfig.embedding_model,
subtitle: `提供商 ID: ${providerConfig.id}`,
}
},
checkPlugin() {
axios.get('/api/plugin/get?name=astrbot_plugin_knowledge_base')
.then(response => {
@@ -365,10 +378,15 @@ export default {
},
createCollection(name, emoji, description) {
// 如果 this.newKB.embedding_provider_id 是 Object
if (typeof this.newKB.embedding_provider_id === 'object') {
this.newKB.embedding_provider_id = this.newKB.embedding_provider_id.id || '';
}
axios.post('/api/plug/alkaid/kb/create_collection', {
collection_name: name,
emoji: emoji,
description: description
description: description,
embedding_provider_id: this.newKB.embedding_provider_id || ''
})
.then(response => {
if (response.data.status === 'ok') {
@@ -394,7 +412,8 @@ export default {
this.createCollection(
this.newKB.name,
this.newKB.emoji || '🙂',
this.newKB.description
this.newKB.description,
this.newKB.embedding_provider_id || ''
);
},
@@ -402,7 +421,8 @@ export default {
this.newKB = {
name: '',
emoji: '🙂',
description: ''
description: '',
embedding_provider: ''
};
},
@@ -582,6 +602,27 @@ export default {
this.deleting = false;
});
},
getEmbeddingProviderList() {
axios.get('/api/config/provider/list', {
params: {
provider_type: 'embedding'
}
})
.then(response => {
if (response.data.status === 'ok') {
this.embeddingProviderConfigs = response.data.data || [];
} else {
this.showSnackbar(response.data.message || '获取嵌入模型列表失败', 'error');
return [];
}
})
.catch(error => {
console.error('Error fetching embedding providers:', error);
this.showSnackbar('获取嵌入模型列表失败', 'error');
return [];
});
}
}
}
</script>