From 8288d5e51f369d72181a77e855e016f6dc4783e9 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Fri, 30 May 2025 18:07:52 +0800 Subject: [PATCH 1/4] feat: embedding provider --- astrbot/core/config/default.py | 26 ++++++++ astrbot/core/provider/entities.py | 5 +- astrbot/core/provider/manager.py | 14 +++++ astrbot/core/provider/provider.py | 5 ++ .../sources/openai_embedding_source.py | 42 +++++++++++++ astrbot/core/star/context.py | 7 +-- astrbot/dashboard/routes/config.py | 12 ++++ dashboard/src/views/ProviderPage.vue | 63 +++++++++++++++++-- dashboard/src/views/alkaid/KnowledgeBase.vue | 51 +++++++++++++-- 9 files changed, 210 insertions(+), 15 deletions(-) create mode 100644 astrbot/core/provider/sources/openai_embedding_source.py diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 19960269a..dce1da52c 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -862,8 +862,34 @@ CONFIG_METADATA_2 = { "api_base": "https://openspeech.bytedance.com/api/v1/tts", "timeout": 20, }, + "OpenAI Embedding": { + "id": "openai_embedding", + "type": "openai_embedding", + "provider_type": "embedding", + "enable": True, + "embedding_api_key": "", + "embedding_api_base": "", + "embedding_model": "", + "embedding_dimensions": 1536, + "timeout": 20, + }, }, "items": { + "embedding_dimensions": { + "description": "嵌入维度", + "type": "int", + "hint": "嵌入向量的维度。根据模型不同,可能需要调整,请参考具体模型的文档。此配置项请务必填写正确,否则将导致向量数据库无法正常工作。", + }, + "embedding_model": { + "description": "嵌入模型", + "type": "string", + "hint": "嵌入模型名称。", + }, + "embedding_api_key": { + "description": "API Key", + "type": "string", + "hint": "API Key", + }, "volcengine_cluster": { "type": "string", "description": "火山引擎集群", diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index 6ad67da55..e01e46cf9 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -19,6 +19,7 @@ class ProviderType(enum.Enum): CHAT_COMPLETION = "chat_completion" SPEECH_TO_TEXT = "speech_to_text" TEXT_TO_SPEECH = "text_to_speech" + EMBEDDING = "embedding" @dataclass @@ -155,7 +156,9 @@ class ProviderRequest: if self.image_urls: user_content = { "role": "user", - "content": [{"type": "text", "text": self.prompt if self.prompt else "[图片]"}], + "content": [ + {"type": "text", "text": self.prompt if self.prompt else "[图片]"} + ], } for image_url in self.image_urls: if image_url.startswith("http"): diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 78337ce95..edfd9f581 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -98,6 +98,8 @@ class ProviderManager: """加载的 Speech To Text Provider 的实例""" self.tts_provider_insts: List[TTSProvider] = [] """加载的 Text To Speech Provider 的实例""" + self.embedding_provider_insts: List[Provider] = [] + """加载的 Embedding Provider 的实例""" self.inst_map = {} """Provider 实例映射. key: provider_id, value: Provider 实例""" self.llm_tools = llm_tools @@ -211,6 +213,10 @@ class ProviderManager: from .sources.volcengine_tts import ( ProviderVolcengineTTS as ProviderVolcengineTTS, ) + case "openai_embedding": + from .sources.openai_embedding_source import ( + OpenAIEmbeddingProvider as OpenAIEmbeddingProvider, + ) except (ImportError, ModuleNotFoundError) as e: logger.critical( f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。" @@ -290,6 +296,14 @@ class ProviderManager: if not self.curr_provider_inst: self.curr_provider_inst = inst + elif provider_metadata.provider_type == ProviderType.EMBEDDING: + inst = provider_metadata.cls_type( + provider_config, self.provider_settings + ) + if getattr(inst, "initialize", None): + await inst.initialize() + self.embedding_provider_insts.append(inst) + self.inst_map[provider_config["id"]] = inst except Exception as e: logger.error(traceback.format_exc()) diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 7019113c7..c285ebd42 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -192,6 +192,11 @@ class EmbeddingProvider(AbstractProvider): """获取文本的向量""" ... + @abc.abstractmethod + async def get_embeddings(self, text: list[str]) -> list[list[float]]: + """批量获取文本的向量""" + ... + @abc.abstractmethod def get_dim(self) -> int: """获取向量的维度""" diff --git a/astrbot/core/provider/sources/openai_embedding_source.py b/astrbot/core/provider/sources/openai_embedding_source.py new file mode 100644 index 000000000..2d339e57e --- /dev/null +++ b/astrbot/core/provider/sources/openai_embedding_source.py @@ -0,0 +1,42 @@ +from openai import AsyncOpenAI +from ..provider import EmbeddingProvider +from ..register import register_provider_adapter +from ..entities import ProviderType + + +@register_provider_adapter( + "openai_embedding", + "OpenAI API Embedding 提供商适配器", + provider_type=ProviderType.EMBEDDING, +) +class OpenAIEmbeddingProvider(EmbeddingProvider): + def __init__(self, provider_config: dict, provider_settings: dict) -> None: + super().__init__(provider_config, provider_settings) + self.provider_config = provider_config + self.provider_settings = provider_settings + self.client = AsyncOpenAI( + api_key=provider_config.get("embedding_api_key"), + base_url=provider_config.get( + "embedding_api_base", "https://api.openai.com/v1" + ), + ) + self.model = provider_config.get("embedding_model", "text-embedding-3-small") + self.dimension = provider_config.get("embedding_dimensions", 1536) + + async def get_embedding(self, text: str) -> list[float]: + """ + 获取文本的嵌入 + """ + embedding = await self.client.embeddings.create(input=text, model=self.model) + return embedding.data[0].embedding + + async def get_embeddings(self, texts: list[str]) -> list[list[float]]: + """ + 批量获取文本的嵌入 + """ + embeddings = await self.client.embeddings.create(input=texts, model=self.model) + return [item.embedding for item in embeddings.data] + + def get_dim(self) -> int: + """获取向量的维度""" + return self.dimension diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 996b8ae5e..880b0c72c 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -125,11 +125,8 @@ class Context: self.provider_manager.provider_insts.append(provider) def get_provider_by_id(self, provider_id: str) -> Provider: - """通过 ID 获取用于文本生成任务的 LLM Provider(Chat_Completion 类型)。""" - for provider in self.provider_manager.provider_insts: - if provider.meta().id == provider_id: - return provider - return None + """通过 ID 获取对应的 LLM Provider(Chat_Completion 类型)。""" + return self.provider_manager.inst_map.get(provider_id) def get_all_providers(self) -> List[Provider]: """获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。""" diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index b2677de10..2d214b77c 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -164,6 +164,7 @@ class ConfigRoute(Route): "/config/provider/update": ("POST", self.post_update_provider), "/config/provider/delete": ("POST", self.post_delete_provider), "/config/llmtools": ("GET", self.get_llm_tools), + "/config/provider/list": ("GET", self.get_provider_config_list), } self.register_routes() @@ -175,6 +176,17 @@ class ConfigRoute(Route): return Response().ok(await self._get_astrbot_config()).__dict__ return Response().ok(await self._get_plugin_config(plugin_name)).__dict__ + async def get_provider_config_list(self): + provider_type = request.args.get("provider_type", None) + if not provider_type: + return Response().error("缺少参数 provider_type").__dict__ + provider_list = [] + astrbot_config = self.core_lifecycle.astrbot_config + for provider in astrbot_config["provider"]: + if provider.get("provider_type", None) == provider_type: + provider_list.append(provider) + return Response().ok(provider_list).__dict__ + async def post_astrbot_configs(self): post_configs = await request.json try: diff --git a/dashboard/src/views/ProviderPage.vue b/dashboard/src/views/ProviderPage.vue index 8bad17e4d..53e7c8bc7 100644 --- a/dashboard/src/views/ProviderPage.vue +++ b/dashboard/src/views/ProviderPage.vue @@ -27,13 +27,39 @@ + + + + + mdi-filter-variant + 全部 + + + mdi-message-text + 基本对话 + + + mdi-microphone-message + 语音转文字 + + + mdi-volume-high + 文字转语音 + + + mdi-code-json + Embedding + + + + mdi-volume-high 文字转语音 + + mdi-code-json + Embedding + - @@ -225,6 +255,21 @@ export default { // 新增提供商对话框相关 showAddProviderDialog: false, activeProviderTab: 'chat_completion', + + // 添加提供商类型分类 + activeProviderTypeTab: 'all', + } + }, + + computed: { + // 根据选择的标签过滤提供商列表 + filteredProviders() { + if (!this.config_data.provider || this.activeProviderTypeTab === 'all') { + return this.config_data.provider || []; + } + + return this.config_data.provider.filter(provider => + provider.provider_type === this.activeProviderTypeTab); } }, @@ -243,6 +288,15 @@ export default { }); }, + // 获取空列表文本 + getEmptyText() { + if (this.activeProviderTypeTab === 'all') { + return "暂无服务提供商,点击 新增服务提供商 添加"; + } else { + return `暂无${this.getTabTypeName(this.activeProviderTypeTab)}类型的服务提供商,点击 新增服务提供商 添加`; + } + }, + // 按提供商类型获取模板列表 getTemplatesByType(type) { const templates = this.metadata['provider_group']?.metadata?.provider?.config_template || {}; @@ -294,7 +348,8 @@ export default { const names = { 'chat_completion': '基本对话', 'speech_to_text': '语音转文本', - 'text_to_speech': '文本转语音' + 'text_to_speech': '文本转语音', + 'embedding': 'Embedding' }; return names[tabType] || tabType; }, diff --git a/dashboard/src/views/alkaid/KnowledgeBase.vue b/dashboard/src/views/alkaid/KnowledgeBase.vue index 5b678fac1..c3f0b4244 100644 --- a/dashboard/src/views/alkaid/KnowledgeBase.vue +++ b/dashboard/src/views/alkaid/KnowledgeBase.vue @@ -72,6 +72,10 @@ + + + @@ -256,7 +260,8 @@ export default { newKB: { name: '', emoji: '🙂', - description: '' + description: '', + embedding_provider_id: '' }, snackbar: { show: false, @@ -306,13 +311,21 @@ export default { deleteTarget: { collection_name: '' }, - deleting: false + deleting: false, + embeddingProviderConfigs: [] } }, mounted() { this.checkPlugin(); + this.getEmbeddingProviderList(); }, methods: { + embeddingModelProps(providerConfig) { + return { + title: providerConfig.embedding_model, + subtitle: `提供商 ID: ${providerConfig.id}`, + } + }, checkPlugin() { axios.get('/api/plugin/get?name=astrbot_plugin_knowledge_base') .then(response => { @@ -365,10 +378,15 @@ export default { }, createCollection(name, emoji, description) { + // 如果 this.newKB.embedding_provider_id 是 Object + if (typeof this.newKB.embedding_provider_id === 'object') { + this.newKB.embedding_provider_id = this.newKB.embedding_provider_id.id || ''; + } axios.post('/api/plug/alkaid/kb/create_collection', { collection_name: name, emoji: emoji, - description: description + description: description, + embedding_provider_id: this.newKB.embedding_provider_id || '' }) .then(response => { if (response.data.status === 'ok') { @@ -394,7 +412,8 @@ export default { this.createCollection( this.newKB.name, this.newKB.emoji || '🙂', - this.newKB.description + this.newKB.description, + this.newKB.embedding_provider_id || '' ); }, @@ -402,7 +421,8 @@ export default { this.newKB = { name: '', emoji: '🙂', - description: '' + description: '', + embedding_provider: '' }; }, @@ -582,6 +602,27 @@ export default { this.deleting = false; }); }, + + getEmbeddingProviderList() { + axios.get('/api/config/provider/list', { + params: { + provider_type: 'embedding' + } + }) + .then(response => { + if (response.data.status === 'ok') { + this.embeddingProviderConfigs = response.data.data || []; + } else { + this.showSnackbar(response.data.message || '获取嵌入模型列表失败', 'error'); + return []; + } + }) + .catch(error => { + console.error('Error fetching embedding providers:', error); + this.showSnackbar('获取嵌入模型列表失败', 'error'); + return []; + }); + } } } From 3c13b5049d94e3e984a8e5b16a261141f538f219 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Fri, 30 May 2025 23:00:37 +0800 Subject: [PATCH 2/4] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E7=9F=A5?= =?UTF-8?q?=E8=AF=86=E5=BA=93=E7=9A=84=E5=88=86=E7=89=87=E3=80=81=E9=87=8D?= =?UTF-8?q?=E5=8F=A0=E8=AE=BE=E7=BD=AE=E7=AD=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dashboard/src/views/alkaid/KnowledgeBase.vue | 116 +++++++++++++++++-- 1 file changed, 108 insertions(+), 8 deletions(-) diff --git a/dashboard/src/views/alkaid/KnowledgeBase.vue b/dashboard/src/views/alkaid/KnowledgeBase.vue index c3f0b4244..e31644e47 100644 --- a/dashboard/src/views/alkaid/KnowledgeBase.vue +++ b/dashboard/src/views/alkaid/KnowledgeBase.vue @@ -5,8 +5,8 @@

还没有安装知识库插件

- + 立即安装
@@ -49,9 +49,9 @@
Tips: 在聊天页面通过 /kb 指令了解如何使用!
- + - + @@ -73,9 +73,11 @@ - + + + Tips: 一旦选择了一个知识库的嵌入模型,请不要再修改该提供商的模型或者向量维度信息,否则将严重影响该知识库的召回率甚至报错。 @@ -122,6 +124,18 @@ +
+ + mdi-database + 嵌入模型: {{ currentKB._embedding_provider_config.embedding_model }} + + + mdi-vector-point + 向量维度: {{ currentKB._embedding_provider_config.embedding_dimensions }} + + 💡 使用方式: 在聊天页中输入 “/kb use {{ currentKB.collection_name }}” +
+ 上传文件 @@ -144,6 +158,54 @@

拖放文件到这里或点击上传

+ + + + mdi-puzzle-outline + 分片设置 + + + + 分片长度决定每块文本的大小,重叠长度决定相邻文本块之间的重叠程度。
+ 较小的分片更精确但会增加数量,适当的重叠可提高检索准确性。 +
+
+
+ +
+ + + +
+
+
+
@@ -301,6 +363,8 @@ export default { }, activeTab: 'upload', selectedFile: null, + chunkSize: null, + overlap: null, uploading: false, searchQuery: '', searchResults: [], @@ -323,7 +387,7 @@ export default { embeddingModelProps(providerConfig) { return { title: providerConfig.embedding_model, - subtitle: `提供商 ID: ${providerConfig.id}`, + subtitle: `提供商 ID: ${providerConfig.id} | 嵌入模型维度: ${providerConfig.embedding_dimensions}`, } }, checkPlugin() { @@ -439,6 +503,9 @@ export default { this.searchQuery = ''; this.searchResults = []; this.searchPerformed = false; + // 重置分片长度和重叠长度参数 + this.chunkSize = null; + this.overlap = null; }, triggerFileInput() { @@ -492,6 +559,15 @@ export default { const formData = new FormData(); formData.append('file', this.selectedFile); formData.append('collection_name', this.currentKB.collection_name); + + // 添加可选的分片长度和重叠长度参数 + if (this.chunkSize && this.chunkSize > 0) { + formData.append('chunk_size', this.chunkSize); + } + + if (this.overlap && this.overlap >= 0) { + formData.append('chunk_overlap', this.overlap); + } axios.post('/api/plug/alkaid/kb/collection/add_file', formData, { headers: { @@ -500,7 +576,7 @@ export default { }) .then(response => { if (response.data.status === 'ok') { - this.showSnackbar('文件上传成功'); + this.showSnackbar('操作成功: ' + response.data.message); this.selectedFile = null; // 刷新知识库列表,获取更新的数量 @@ -792,4 +868,28 @@ export default { .kb-card:hover .kb-actions { opacity: 1; } + +.chunk-settings-card { + border: 1px solid rgba(92, 107, 192, 0.2) !important; + transition: all 0.3s ease; +} + +.chunk-settings-card:hover { + border-color: rgba(92, 107, 192, 0.4) !important; + box-shadow: 0 2px 8px rgba(0, 0, 0, 0.07) !important; +} + +.chunk-field :deep(.v-field__input) { + padding-top: 8px; + padding-bottom: 8px; +} + +.chunk-field :deep(.v-field__prepend-inner) { + padding-right: 8px; + opacity: 0.7; +} + +.chunk-field:focus-within :deep(.v-field__prepend-inner) { + opacity: 1; +} From 40c27d87f5989e94c48ee772d3e32e07ed9ae4bb Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Fri, 30 May 2025 23:18:19 +0800 Subject: [PATCH 3/4] feat: knowledge-base --- .../full/vertical-sidebar/sidebarItem.ts | 10 +- dashboard/src/views/alkaid/KnowledgeBase.vue | 8 +- dashboard/src/views/alkaid/LongTermMemory.vue | 111 +++++++----------- 3 files changed, 53 insertions(+), 76 deletions(-) diff --git a/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts b/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts index f541d1d91..e8f49c741 100644 --- a/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts +++ b/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts @@ -65,11 +65,11 @@ const sidebarItem: menu[] = [ icon: 'mdi-console', to: '/console' }, - // { - // title: 'Alkaid', - // icon: 'mdi-test-tube', - // to: '/alkaid' - // }, + { + title: 'Alkaid', + icon: 'mdi-test-tube', + to: '/alkaid' + }, { title: '关于', icon: 'mdi-information', diff --git a/dashboard/src/views/alkaid/KnowledgeBase.vue b/dashboard/src/views/alkaid/KnowledgeBase.vue index e31644e47..29889f41f 100644 --- a/dashboard/src/views/alkaid/KnowledgeBase.vue +++ b/dashboard/src/views/alkaid/KnowledgeBase.vue @@ -18,7 +18,9 @@
-

知识库列表

+

知识库列表 + mdi-information-outline +

创建知识库 @@ -698,6 +700,10 @@ export default { this.showSnackbar('获取嵌入模型列表失败', 'error'); return []; }); + }, + + openUrl(url) { + window.open(url, '_blank'); } } } diff --git a/dashboard/src/views/alkaid/LongTermMemory.vue b/dashboard/src/views/alkaid/LongTermMemory.vue index 45d687fec..e534c70d8 100644 --- a/dashboard/src/views/alkaid/LongTermMemory.vue +++ b/dashboard/src/views/alkaid/LongTermMemory.vue @@ -1,6 +1,11 @@