From 591a2284315fdc2b8fa39344d52082a8b6e583a7 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Wed, 17 Dec 2025 17:08:52 +0800 Subject: [PATCH] refactor: enhance provider management with resource locking and CRUD operations --- astrbot/core/provider/manager.py | 45 ++++++++++++++++++++++++++-- astrbot/dashboard/routes/config.py | 47 ++++-------------------------- 2 files changed, 49 insertions(+), 43 deletions(-) diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 92a63852c..b23fdf4f5 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -33,7 +33,7 @@ class ProviderManager: persona_mgr: PersonaManager, ): self.reload_lock = asyncio.Lock() - self.delete_lock = asyncio.Lock() + self.resource_lock = asyncio.Lock() self.persona_mgr = persona_mgr self.acm = acm config = acm.confs["default"] @@ -614,7 +614,8 @@ class ProviderManager: async def delete_provider( self, provider_id: str | None = None, provider_source_id: str | None = None ): - async with self.delete_lock: + """Delete provider and/or provider source from config and terminate the instances. Config will be saved after deletion.""" + async with self.resource_lock: # delete from config target_prov_ids = [] if provider_id: @@ -632,6 +633,46 @@ class ProviderManager: config.save_config() logger.info(f"Provider {target_prov_ids} 已从配置中删除。") + async def update_provider(self, origin_provider_id: str, new_config: dict): + """Update provider config and reload the instance. Config will be saved after update.""" + async with self.resource_lock: + npid = new_config.get("id", None) + if not npid: + raise ValueError("New provider config must have an 'id' field") + config = self.acm.default_conf + for provider in config["provider"]: + if ( + provider.get("id", None) == npid + and provider.get("id", None) != origin_provider_id + ): + raise ValueError(f"Provider ID {npid} already exists") + # update config + for idx, provider in enumerate(config["provider"]): + if provider.get("id", None) == origin_provider_id: + config["provider"][idx] = new_config + break + else: + raise ValueError(f"Provider ID {origin_provider_id} not found") + config.save_config() + # reload instance + await self.reload(new_config) + + async def create_provider(self, new_config: dict): + """Add new provider config and load the instance. Config will be saved after addition.""" + async with self.resource_lock: + npid = new_config.get("id", None) + if not npid: + raise ValueError("New provider config must have an 'id' field") + config = self.acm.default_conf + for provider in config["provider"]: + if provider.get("id", None) == npid: + raise ValueError(f"Provider ID {npid} already exists") + # add to config + config["provider"].append(new_config) + config.save_config() + # load instance + await self.load_provider(new_config) + async def terminate(self): for provider_inst in self.provider_insts: if hasattr(provider_inst, "terminate"): diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 59a7f509f..606e3df5e 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -828,27 +828,13 @@ class ConfigRoute(Route): async def post_new_provider(self): new_provider_config = await request.json - # check id uniqueness - npid = new_provider_config.get("id", None) - if not npid: - return Response().error("服务提供商配置缺少 id 字段").__dict__ - for provider in self.config["provider"]: - if provider.get("id", None) == npid: - return ( - Response() - .error(f"provider with ID '{npid}' already exists") - .__dict__ - ) - - self.config["provider"].append(new_provider_config) try: - save_config(self.config, self.config, is_core=True) - await self.core_lifecycle.provider_manager.load_provider( - new_provider_config, + await self.core_lifecycle.provider_manager.create_provider( + new_provider_config ) except Exception as e: return Response().error(str(e)).__dict__ - return Response().ok(None, "新增服务提供商配置成功~").__dict__ + return Response().ok(None, "新增服务提供商配置成功").__dict__ async def post_update_platform(self): update_platform_config = await request.json @@ -884,31 +870,10 @@ class ConfigRoute(Route): if not origin_provider_id or not new_config: return Response().error("参数错误").__dict__ - # check id uniqueness - npid = new_config.get("id", None) - if not npid: - return Response().error("服务提供商配置缺少 id 字段").__dict__ - for provider in self.config["provider"]: - if ( - provider.get("id", None) == npid - and provider.get("id", None) != origin_provider_id - ): - return ( - Response() - .error(f"provider with ID '{npid}' already exists") - .__dict__ - ) - - for i, provider in enumerate(self.config["provider"]): - if provider["id"] == origin_provider_id: - self.config["provider"][i] = new_config - break - else: - return Response().error("未找到对应服务提供商").__dict__ - try: - save_config(self.config, self.config, is_core=True) - await self.core_lifecycle.provider_manager.reload(new_config) + await self.core_lifecycle.provider_manager.update_provider( + origin_provider_id, new_config + ) except Exception as e: return Response().error(str(e)).__dict__ return Response().ok(None, "更新成功,已经实时生效~").__dict__