refactor: enhance provider management with resource locking and CRUD operations

This commit is contained in:
Soulter
2025-12-17 17:08:52 +08:00
parent f52f375154
commit 591a228431
2 changed files with 49 additions and 43 deletions
+43 -2
View File
@@ -33,7 +33,7 @@ class ProviderManager:
persona_mgr: PersonaManager,
):
self.reload_lock = asyncio.Lock()
self.delete_lock = asyncio.Lock()
self.resource_lock = asyncio.Lock()
self.persona_mgr = persona_mgr
self.acm = acm
config = acm.confs["default"]
@@ -614,7 +614,8 @@ class ProviderManager:
async def delete_provider(
self, provider_id: str | None = None, provider_source_id: str | None = None
):
async with self.delete_lock:
"""Delete provider and/or provider source from config and terminate the instances. Config will be saved after deletion."""
async with self.resource_lock:
# delete from config
target_prov_ids = []
if provider_id:
@@ -632,6 +633,46 @@ class ProviderManager:
config.save_config()
logger.info(f"Provider {target_prov_ids} 已从配置中删除。")
async def update_provider(self, origin_provider_id: str, new_config: dict):
"""Update provider config and reload the instance. Config will be saved after update."""
async with self.resource_lock:
npid = new_config.get("id", None)
if not npid:
raise ValueError("New provider config must have an 'id' field")
config = self.acm.default_conf
for provider in config["provider"]:
if (
provider.get("id", None) == npid
and provider.get("id", None) != origin_provider_id
):
raise ValueError(f"Provider ID {npid} already exists")
# update config
for idx, provider in enumerate(config["provider"]):
if provider.get("id", None) == origin_provider_id:
config["provider"][idx] = new_config
break
else:
raise ValueError(f"Provider ID {origin_provider_id} not found")
config.save_config()
# reload instance
await self.reload(new_config)
async def create_provider(self, new_config: dict):
"""Add new provider config and load the instance. Config will be saved after addition."""
async with self.resource_lock:
npid = new_config.get("id", None)
if not npid:
raise ValueError("New provider config must have an 'id' field")
config = self.acm.default_conf
for provider in config["provider"]:
if provider.get("id", None) == npid:
raise ValueError(f"Provider ID {npid} already exists")
# add to config
config["provider"].append(new_config)
config.save_config()
# load instance
await self.load_provider(new_config)
async def terminate(self):
for provider_inst in self.provider_insts:
if hasattr(provider_inst, "terminate"):
+6 -41
View File
@@ -828,27 +828,13 @@ class ConfigRoute(Route):
async def post_new_provider(self):
new_provider_config = await request.json
# check id uniqueness
npid = new_provider_config.get("id", None)
if not npid:
return Response().error("服务提供商配置缺少 id 字段").__dict__
for provider in self.config["provider"]:
if provider.get("id", None) == npid:
return (
Response()
.error(f"provider with ID '{npid}' already exists")
.__dict__
)
self.config["provider"].append(new_provider_config)
try:
save_config(self.config, self.config, is_core=True)
await self.core_lifecycle.provider_manager.load_provider(
new_provider_config,
await self.core_lifecycle.provider_manager.create_provider(
new_provider_config
)
except Exception as e:
return Response().error(str(e)).__dict__
return Response().ok(None, "新增服务提供商配置成功~").__dict__
return Response().ok(None, "新增服务提供商配置成功").__dict__
async def post_update_platform(self):
update_platform_config = await request.json
@@ -884,31 +870,10 @@ class ConfigRoute(Route):
if not origin_provider_id or not new_config:
return Response().error("参数错误").__dict__
# check id uniqueness
npid = new_config.get("id", None)
if not npid:
return Response().error("服务提供商配置缺少 id 字段").__dict__
for provider in self.config["provider"]:
if (
provider.get("id", None) == npid
and provider.get("id", None) != origin_provider_id
):
return (
Response()
.error(f"provider with ID '{npid}' already exists")
.__dict__
)
for i, provider in enumerate(self.config["provider"]):
if provider["id"] == origin_provider_id:
self.config["provider"][i] = new_config
break
else:
return Response().error("未找到对应服务提供商").__dict__
try:
save_config(self.config, self.config, is_core=True)
await self.core_lifecycle.provider_manager.reload(new_config)
await self.core_lifecycle.provider_manager.update_provider(
origin_provider_id, new_config
)
except Exception as e:
return Response().error(str(e)).__dict__
return Response().ok(None, "更新成功,已经实时生效~").__dict__