feat: 支持并完善服务提供商默认配置模板接口

This commit is contained in:
Soulter
2025-01-13 02:05:57 +08:00
parent 8315cf5818
commit 7b23d76559
6 changed files with 45 additions and 10 deletions
+8 -3
View File
@@ -1,7 +1,12 @@
from dataclasses import dataclass
@dataclass
class PlatformMetadata():
name: str # 平台的名称
description: str # 平台的描述
name: str
'''平台的名称'''
description: str
'''平台的描述'''
default_config_tmpl: dict = None # 平台的默认配置模板
default_config_tmpl: dict = None
'''平台的默认配置模板'''
adapter_display_name: str = None
'''显示在 WebUI 配置页中的平台名称,如空则是 name'''
+8 -2
View File
@@ -7,7 +7,12 @@ platform_registry: List[PlatformMetadata] = []
platform_cls_map: Dict[str, Type] = {}
'''维护了平台适配器名称和适配器类的映射'''
def register_platform_adapter(adapter_name: str, desc: str, default_config_tmpl: dict = None):
def register_platform_adapter(
adapter_name: str,
desc: str,
default_config_tmpl: dict = None,
adapter_display_name: str = None
):
'''用于注册平台适配器的带参装饰器。
default_config_tmpl 指定了平台适配器的默认配置模板。用户填写好后将会作为 platform_config 传入你的 Platform 类的实现类。
@@ -26,7 +31,8 @@ def register_platform_adapter(adapter_name: str, desc: str, default_config_tmpl:
pm = PlatformMetadata(
name=adapter_name,
description=desc,
default_config_tmpl=default_config_tmpl
default_config_tmpl=default_config_tmpl,
adapter_display_name=adapter_display_name
)
platform_registry.append(pm)
platform_cls_map[adapter_name] = cls
+5
View File
@@ -17,6 +17,11 @@ class ProviderMetaData():
'''提供商适配器描述.'''
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
cls_type: Type = None
default_config_tmpl: dict = None
'''平台的默认配置模板'''
provider_display_name: str = None
'''显示在 WebUI 配置页中的提供商名称,如空则是 type'''
@dataclass
class ProviderRequest():
+14 -3
View File
@@ -13,22 +13,33 @@ llm_tools = FuncCall()
def register_provider_adapter(
provider_type_name: str,
desc: str,
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
provider_type: ProviderType = ProviderType.CHAT_COMPLETION,
default_config_tmpl: dict = None,
provider_display_name: str = None
):
'''用于注册平台适配器的带参装饰器'''
def decorator(cls):
if provider_type_name in provider_cls_map:
raise ValueError(f"检测到大模型提供商适配器 {provider_type_name} 已经注册,可能发生了大模型提供商适配器类型命名冲突。")
# 添加必备选项
if default_config_tmpl:
if 'type' not in default_config_tmpl:
default_config_tmpl['type'] = provider_type_name
if 'enable' not in default_config_tmpl:
default_config_tmpl['enable'] = False
pm = ProviderMetaData(
type=provider_type_name,
desc=desc,
provider_type=provider_type,
cls_type=cls
cls_type=cls,
default_config_tmpl=default_config_tmpl,
provider_display_name=provider_display_name
)
provider_registry.append(pm)
provider_cls_map[provider_type_name] = pm
logger.debug(f"Provider {provider_type_name} 已注册")
logger.debug(f"服务提供商 Provider {provider_type_name} 已注册")
return cls
return decorator
+8
View File
@@ -8,6 +8,7 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.star.config import update_config
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.platform.register import platform_registry
from astrbot.core.provider.register import provider_registry
def try_cast(value: str, type_: str):
if type_ == "int" and value.isdigit():
@@ -123,11 +124,18 @@ class ConfigRoute(Route):
async def _get_astrbot_config(self):
config = self.config
# 平台适配器的默认配置模板注入
platform_default_tmpl = CONFIG_METADATA_2['platform_group']['metadata']['platform']['config_template']
for platform in platform_registry:
if platform.default_config_tmpl:
platform_default_tmpl[platform.name] = platform.default_config_tmpl
# 服务提供商的默认配置模板注入
provider_default_tmpl = CONFIG_METADATA_2['provider_group']['metadata']['provider']['config_template']
for provider in provider_registry:
if provider.default_config_tmpl:
provider_default_tmpl[provider.type] = provider.default_config_tmpl
return {
"metadata": CONFIG_METADATA_2,
"config": config
+2 -2
View File
@@ -150,7 +150,7 @@ class Main(star.Star):
return f"{self.config['sandbox']['docker_mirror']}/{self.config['sandbox']['image']}"
return self.config["sandbox"]["image"]
async def _save_config(self):
def _save_config(self):
with open(PATH, "w") as f:
json.dump(self.config, f)
@@ -207,7 +207,7 @@ class Main(star.Star):
""")
else:
self.config["sandbox"]["docker_mirror"] = url
await self._save_config()
self._save_config()
yield event.plain_result("设置 Docker 镜像地址成功。")
@pi.command("repull")