From 4651bd2807959f45828f67d641c01366d92cc1cc Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Wed, 17 Dec 2025 15:00:22 +0800 Subject: [PATCH] feat: implement provider deletion functionality and ensure unique provider IDs --- astrbot/core/provider/manager.py | 22 +++++++ astrbot/dashboard/routes/config.py | 66 +++++++++---------- .../src/composables/useProviderSources.ts | 19 +++++- dashboard/src/views/ProviderPage.vue | 9 ++- 4 files changed, 77 insertions(+), 39 deletions(-) diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index bb6308a4b..92a63852c 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -33,6 +33,7 @@ class ProviderManager: persona_mgr: PersonaManager, ): self.reload_lock = asyncio.Lock() + self.delete_lock = asyncio.Lock() self.persona_mgr = persona_mgr self.acm = acm config = acm.confs["default"] @@ -610,6 +611,27 @@ class ProviderManager: ) del self.inst_map[provider_id] + async def delete_provider( + self, provider_id: str | None = None, provider_source_id: str | None = None + ): + async with self.delete_lock: + # delete from config + target_prov_ids = [] + if provider_id: + target_prov_ids.append(provider_id) + else: + for prov in self.providers_config: + if prov.get("provider_source_id") == provider_source_id: + target_prov_ids.append(prov.get("id")) + config = self.acm.default_conf + for tpid in target_prov_ids: + await self.terminate_provider(tpid) + config["provider"] = [ + prov for prov in config["provider"] if prov.get("id") != tpid + ] + config.save_config() + logger.info(f"Provider {target_prov_ids} 已从配置中删除。") + async def terminate(self): for provider_inst in self.provider_insts: if hasattr(provider_inst, "terminate"): diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 3ee200b21..4d54d4b38 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -219,39 +219,20 @@ class ConfigRoute(Route): # 删除 provider_source del provider_sources[target_idx] - # 更新引用了该 provider_source 的 providers - affected_providers = [] - for provider in self.config.get("provider", []): - if provider.get("provider_source_id") == provider_source_id: - provider["provider_source_id"] = None - affected_providers.append(provider) - # 写回配置 self.config["provider_sources"] = provider_sources + # 删除引用了该 provider_source 的 providers + await self.core_lifecycle.provider_manager.delete_provider( + provider_source_id=provider_source_id + ) + try: save_config(self.config, self.config, is_core=True) except Exception as e: logger.error(traceback.format_exc()) return Response().error(str(e)).__dict__ - # 重载受影响的 providers,使新的 source 配置生效 - reload_errors = [] - prov_mgr = self.core_lifecycle.provider_manager - for provider in affected_providers: - try: - await prov_mgr.reload(provider) - except Exception as e: - logger.error(traceback.format_exc()) - reload_errors.append(f"{provider.get('id')}: {e}") - - if reload_errors: - return ( - Response() - .error("删除成功,但部分提供商重载失败: " + ", ".join(reload_errors)) - .__dict__ - ) - return Response().ok(message="删除 provider source 成功").__dict__ async def update_provider_source(self, provider_source_id: str): @@ -895,13 +876,28 @@ class ConfigRoute(Route): async def post_update_provider(self): update_provider_config = await request.json - provider_id = update_provider_config.get("id", None) + origin_provider_id = update_provider_config.get("id", None) new_config = update_provider_config.get("config", None) - if not provider_id or not new_config: + if not origin_provider_id or not new_config: return Response().error("参数错误").__dict__ + # check id uniqueness + npid = new_config.get("id", None) + if not npid: + return Response().error("服务提供商配置缺少 id 字段").__dict__ + for provider in self.config["provider"]: + if ( + provider.get("id", None) == npid + and provider.get("id", None) != origin_provider_id + ): + return ( + Response() + .error(f"provider with ID '{npid}' already exists") + .__dict__ + ) + for i, provider in enumerate(self.config["provider"]): - if provider["id"] == provider_id: + if provider["id"] == origin_provider_id: self.config["provider"][i] = new_config break else: @@ -933,18 +929,16 @@ class ConfigRoute(Route): async def post_delete_provider(self): provider_id = await request.json provider_id = provider_id.get("id", "") - for i, provider in enumerate(self.config["provider"]): - if provider["id"] == provider_id: - del self.config["provider"][i] - break - else: - return Response().error("未找到对应服务提供商").__dict__ + if not provider_id: + return Response().error("缺少参数 id").__dict__ + try: - save_config(self.config, self.config, is_core=True) - await self.core_lifecycle.provider_manager.terminate_provider(provider_id) + await self.core_lifecycle.provider_manager.delete_provider( + provider_id=provider_id + ) except Exception as e: return Response().error(str(e)).__dict__ - return Response().ok(None, "删除成功,已经实时生效~").__dict__ + return Response().ok(None, "删除成功,已经实时生效。").__dict__ async def get_llm_tools(self): """获取函数调用工具。包含了本地加载的以及 MCP 服务的工具""" diff --git a/dashboard/src/composables/useProviderSources.ts b/dashboard/src/composables/useProviderSources.ts index 1eb69a210..9e73bd8fc 100644 --- a/dashboard/src/composables/useProviderSources.ts +++ b/dashboard/src/composables/useProviderSources.ts @@ -354,6 +354,20 @@ export function useProviderSources(options: UseProviderSourcesOptions) { return sourceFields } + function generateUniqueSourceId(baseId: string) { + const existingIds = new Set(providerSources.value.map((s: any) => s.id)) + if (!existingIds.has(baseId)) return baseId + + let counter = 1 + let candidate = `${baseId}_${counter}` + while (existingIds.has(candidate)) { + counter += 1 + candidate = `${baseId}_${counter}` + } + + return candidate + } + function addProviderSource(templateKey: string) { const template = providerTemplates.value[templateKey] if (!template) { @@ -361,7 +375,7 @@ export function useProviderSources(options: UseProviderSourcesOptions) { return } - const newId = template.id + const newId = generateUniqueSourceId(template.id) const newSource = { ...extractSourceFieldsFromTemplate(template), id: newId, @@ -398,6 +412,8 @@ export function useProviderSources(options: UseProviderSourcesOptions) { showMessage(tm('providerSources.deleteSuccess')) } catch (error: any) { showMessage(error.message || tm('providerSources.deleteError'), 'error') + } finally { + await loadConfig() } } @@ -445,6 +461,7 @@ export function useProviderSources(options: UseProviderSourcesOptions) { return false } finally { savingSource.value = false + loadConfig() } } diff --git a/dashboard/src/views/ProviderPage.vue b/dashboard/src/views/ProviderPage.vue index f4121f794..8a3ad6e64 100644 --- a/dashboard/src/views/ProviderPage.vue +++ b/dashboard/src/views/ProviderPage.vue @@ -337,18 +337,21 @@ const showAddProviderDialog = ref(false) const showProviderCfg = ref(false) const newSelectedProviderName = ref('') const newSelectedProviderConfig = ref({}) +const newProviderOriginalId = ref('') const updatingMode = ref(false) const loading = ref(false) const providerStatuses = ref([]) const showAgentRunnerDialog = ref(false) const showProviderEditDialog = ref(false) const providerEditData = ref(null) +const providerEditOriginalId = ref('') const showManualModelDialog = ref(false) const savingProviders = ref([]) function openProviderEdit(provider) { providerEditData.value = JSON.parse(JSON.stringify(provider)) + providerEditOriginalId.value = provider.id showProviderEditDialog.value = true } @@ -390,6 +393,7 @@ function getEmptyText() { function selectProviderTemplate(name) { newSelectedProviderName.value = name + newProviderOriginalId.value = '' showProviderCfg.value = true updatingMode.value = false newSelectedProviderConfig.value = JSON.parse(JSON.stringify( @@ -399,6 +403,7 @@ function selectProviderTemplate(name) { function configExistingProvider(provider) { newSelectedProviderName.value = provider.id + newProviderOriginalId.value = provider.id newSelectedProviderConfig.value = {} // 比对默认配置模版,看看是否有更新 @@ -460,7 +465,7 @@ async function newProvider() { try { if (wasUpdating) { const res = await axios.post('/api/config/provider/update', { - id: newSelectedProviderName.value, + id: newProviderOriginalId.value || newSelectedProviderName.value, config: newSelectedProviderConfig.value }) if (res.data.status === 'error') { @@ -494,7 +499,7 @@ async function saveEditedProvider() { savingProviders.value.push(providerEditData.value.id) try { const res = await axios.post('/api/config/provider/update', { - id: providerEditData.value.id, + id: providerEditOriginalId.value || providerEditData.value.id, config: providerEditData.value })