diff --git a/astrbot/core/platform/platform_metadata.py b/astrbot/core/platform/platform_metadata.py index 28120a1ce..721192355 100644 --- a/astrbot/core/platform/platform_metadata.py +++ b/astrbot/core/platform/platform_metadata.py @@ -1,7 +1,12 @@ from dataclasses import dataclass @dataclass class PlatformMetadata(): - name: str # 平台的名称 - description: str # 平台的描述 + name: str + '''平台的名称''' + description: str + '''平台的描述''' - default_config_tmpl: dict = None # 平台的默认配置模板 \ No newline at end of file + default_config_tmpl: dict = None + '''平台的默认配置模板''' + adapter_display_name: str = None + '''显示在 WebUI 配置页中的平台名称,如空则是 name''' \ No newline at end of file diff --git a/astrbot/core/platform/register.py b/astrbot/core/platform/register.py index f451f5b6a..e66f8a22f 100644 --- a/astrbot/core/platform/register.py +++ b/astrbot/core/platform/register.py @@ -7,7 +7,12 @@ platform_registry: List[PlatformMetadata] = [] platform_cls_map: Dict[str, Type] = {} '''维护了平台适配器名称和适配器类的映射''' -def register_platform_adapter(adapter_name: str, desc: str, default_config_tmpl: dict = None): +def register_platform_adapter( + adapter_name: str, + desc: str, + default_config_tmpl: dict = None, + adapter_display_name: str = None +): '''用于注册平台适配器的带参装饰器。 default_config_tmpl 指定了平台适配器的默认配置模板。用户填写好后将会作为 platform_config 传入你的 Platform 类的实现类。 @@ -26,7 +31,8 @@ def register_platform_adapter(adapter_name: str, desc: str, default_config_tmpl: pm = PlatformMetadata( name=adapter_name, description=desc, - default_config_tmpl=default_config_tmpl + default_config_tmpl=default_config_tmpl, + adapter_display_name=adapter_display_name ) platform_registry.append(pm) platform_cls_map[adapter_name] = cls diff --git a/astrbot/core/provider/entites.py b/astrbot/core/provider/entites.py index 0a733e3b9..c9af724c1 100644 --- a/astrbot/core/provider/entites.py +++ b/astrbot/core/provider/entites.py @@ -17,6 +17,11 @@ class ProviderMetaData(): '''提供商适配器描述.''' provider_type: ProviderType = ProviderType.CHAT_COMPLETION cls_type: Type = None + + default_config_tmpl: dict = None + '''平台的默认配置模板''' + provider_display_name: str = None + '''显示在 WebUI 配置页中的提供商名称,如空则是 type''' @dataclass class ProviderRequest(): diff --git a/astrbot/core/provider/register.py b/astrbot/core/provider/register.py index 61e64408f..ecc4e1ec4 100644 --- a/astrbot/core/provider/register.py +++ b/astrbot/core/provider/register.py @@ -13,22 +13,33 @@ llm_tools = FuncCall() def register_provider_adapter( provider_type_name: str, desc: str, - provider_type: ProviderType = ProviderType.CHAT_COMPLETION + provider_type: ProviderType = ProviderType.CHAT_COMPLETION, + default_config_tmpl: dict = None, + provider_display_name: str = None ): '''用于注册平台适配器的带参装饰器''' def decorator(cls): if provider_type_name in provider_cls_map: raise ValueError(f"检测到大模型提供商适配器 {provider_type_name} 已经注册,可能发生了大模型提供商适配器类型命名冲突。") + + # 添加必备选项 + if default_config_tmpl: + if 'type' not in default_config_tmpl: + default_config_tmpl['type'] = provider_type_name + if 'enable' not in default_config_tmpl: + default_config_tmpl['enable'] = False pm = ProviderMetaData( type=provider_type_name, desc=desc, provider_type=provider_type, - cls_type=cls + cls_type=cls, + default_config_tmpl=default_config_tmpl, + provider_display_name=provider_display_name ) provider_registry.append(pm) provider_cls_map[provider_type_name] = pm - logger.debug(f"Provider {provider_type_name} 已注册") + logger.debug(f"服务提供商 Provider {provider_type_name} 已注册") return cls return decorator diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 1f80e50d3..0e14b6036 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -8,6 +8,7 @@ from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.star.config import update_config from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.platform.register import platform_registry +from astrbot.core.provider.register import provider_registry def try_cast(value: str, type_: str): if type_ == "int" and value.isdigit(): @@ -123,11 +124,18 @@ class ConfigRoute(Route): async def _get_astrbot_config(self): config = self.config + # 平台适配器的默认配置模板注入 platform_default_tmpl = CONFIG_METADATA_2['platform_group']['metadata']['platform']['config_template'] for platform in platform_registry: if platform.default_config_tmpl: platform_default_tmpl[platform.name] = platform.default_config_tmpl + # 服务提供商的默认配置模板注入 + provider_default_tmpl = CONFIG_METADATA_2['provider_group']['metadata']['provider']['config_template'] + for provider in provider_registry: + if provider.default_config_tmpl: + provider_default_tmpl[provider.type] = provider.default_config_tmpl + return { "metadata": CONFIG_METADATA_2, "config": config diff --git a/packages/python_interpreter/main.py b/packages/python_interpreter/main.py index a423e76c1..5cf4d70b9 100644 --- a/packages/python_interpreter/main.py +++ b/packages/python_interpreter/main.py @@ -150,7 +150,7 @@ class Main(star.Star): return f"{self.config['sandbox']['docker_mirror']}/{self.config['sandbox']['image']}" return self.config["sandbox"]["image"] - async def _save_config(self): + def _save_config(self): with open(PATH, "w") as f: json.dump(self.config, f) @@ -207,7 +207,7 @@ class Main(star.Star): """) else: self.config["sandbox"]["docker_mirror"] = url - await self._save_config() + self._save_config() yield event.plain_result("设置 Docker 镜像地址成功。") @pi.command("repull")