From 7b23d765594999c6d26b2f3207e8a6376c32d519 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Mon, 13 Jan 2025 02:05:57 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E5=B9=B6=E5=AE=8C?= =?UTF-8?q?=E5=96=84=E6=9C=8D=E5=8A=A1=E6=8F=90=E4=BE=9B=E5=95=86=E9=BB=98?= =?UTF-8?q?=E8=AE=A4=E9=85=8D=E7=BD=AE=E6=A8=A1=E6=9D=BF=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/platform/platform_metadata.py | 11 ++++++++--- astrbot/core/platform/register.py | 10 ++++++++-- astrbot/core/provider/entites.py | 5 +++++ astrbot/core/provider/register.py | 17 ++++++++++++++--- astrbot/dashboard/routes/config.py | 8 ++++++++ packages/python_interpreter/main.py | 4 ++-- 6 files changed, 45 insertions(+), 10 deletions(-) diff --git a/astrbot/core/platform/platform_metadata.py b/astrbot/core/platform/platform_metadata.py index 28120a1ce..721192355 100644 --- a/astrbot/core/platform/platform_metadata.py +++ b/astrbot/core/platform/platform_metadata.py @@ -1,7 +1,12 @@ from dataclasses import dataclass @dataclass class PlatformMetadata(): - name: str # 平台的名称 - description: str # 平台的描述 + name: str + '''平台的名称''' + description: str + '''平台的描述''' - default_config_tmpl: dict = None # 平台的默认配置模板 \ No newline at end of file + default_config_tmpl: dict = None + '''平台的默认配置模板''' + adapter_display_name: str = None + '''显示在 WebUI 配置页中的平台名称,如空则是 name''' \ No newline at end of file diff --git a/astrbot/core/platform/register.py b/astrbot/core/platform/register.py index f451f5b6a..e66f8a22f 100644 --- a/astrbot/core/platform/register.py +++ b/astrbot/core/platform/register.py @@ -7,7 +7,12 @@ platform_registry: List[PlatformMetadata] = [] platform_cls_map: Dict[str, Type] = {} '''维护了平台适配器名称和适配器类的映射''' -def register_platform_adapter(adapter_name: str, desc: str, default_config_tmpl: dict = None): +def register_platform_adapter( + adapter_name: str, + desc: str, + default_config_tmpl: dict = None, + adapter_display_name: str = None +): '''用于注册平台适配器的带参装饰器。 default_config_tmpl 指定了平台适配器的默认配置模板。用户填写好后将会作为 platform_config 传入你的 Platform 类的实现类。 @@ -26,7 +31,8 @@ def register_platform_adapter(adapter_name: str, desc: str, default_config_tmpl: pm = PlatformMetadata( name=adapter_name, description=desc, - default_config_tmpl=default_config_tmpl + default_config_tmpl=default_config_tmpl, + adapter_display_name=adapter_display_name ) platform_registry.append(pm) platform_cls_map[adapter_name] = cls diff --git a/astrbot/core/provider/entites.py b/astrbot/core/provider/entites.py index 0a733e3b9..c9af724c1 100644 --- a/astrbot/core/provider/entites.py +++ b/astrbot/core/provider/entites.py @@ -17,6 +17,11 @@ class ProviderMetaData(): '''提供商适配器描述.''' provider_type: ProviderType = ProviderType.CHAT_COMPLETION cls_type: Type = None + + default_config_tmpl: dict = None + '''平台的默认配置模板''' + provider_display_name: str = None + '''显示在 WebUI 配置页中的提供商名称,如空则是 type''' @dataclass class ProviderRequest(): diff --git a/astrbot/core/provider/register.py b/astrbot/core/provider/register.py index 61e64408f..ecc4e1ec4 100644 --- a/astrbot/core/provider/register.py +++ b/astrbot/core/provider/register.py @@ -13,22 +13,33 @@ llm_tools = FuncCall() def register_provider_adapter( provider_type_name: str, desc: str, - provider_type: ProviderType = ProviderType.CHAT_COMPLETION + provider_type: ProviderType = ProviderType.CHAT_COMPLETION, + default_config_tmpl: dict = None, + provider_display_name: str = None ): '''用于注册平台适配器的带参装饰器''' def decorator(cls): if provider_type_name in provider_cls_map: raise ValueError(f"检测到大模型提供商适配器 {provider_type_name} 已经注册,可能发生了大模型提供商适配器类型命名冲突。") + + # 添加必备选项 + if default_config_tmpl: + if 'type' not in default_config_tmpl: + default_config_tmpl['type'] = provider_type_name + if 'enable' not in default_config_tmpl: + default_config_tmpl['enable'] = False pm = ProviderMetaData( type=provider_type_name, desc=desc, provider_type=provider_type, - cls_type=cls + cls_type=cls, + default_config_tmpl=default_config_tmpl, + provider_display_name=provider_display_name ) provider_registry.append(pm) provider_cls_map[provider_type_name] = pm - logger.debug(f"Provider {provider_type_name} 已注册") + logger.debug(f"服务提供商 Provider {provider_type_name} 已注册") return cls return decorator diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 1f80e50d3..0e14b6036 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -8,6 +8,7 @@ from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.star.config import update_config from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.platform.register import platform_registry +from astrbot.core.provider.register import provider_registry def try_cast(value: str, type_: str): if type_ == "int" and value.isdigit(): @@ -123,11 +124,18 @@ class ConfigRoute(Route): async def _get_astrbot_config(self): config = self.config + # 平台适配器的默认配置模板注入 platform_default_tmpl = CONFIG_METADATA_2['platform_group']['metadata']['platform']['config_template'] for platform in platform_registry: if platform.default_config_tmpl: platform_default_tmpl[platform.name] = platform.default_config_tmpl + # 服务提供商的默认配置模板注入 + provider_default_tmpl = CONFIG_METADATA_2['provider_group']['metadata']['provider']['config_template'] + for provider in provider_registry: + if provider.default_config_tmpl: + provider_default_tmpl[provider.type] = provider.default_config_tmpl + return { "metadata": CONFIG_METADATA_2, "config": config diff --git a/packages/python_interpreter/main.py b/packages/python_interpreter/main.py index a423e76c1..5cf4d70b9 100644 --- a/packages/python_interpreter/main.py +++ b/packages/python_interpreter/main.py @@ -150,7 +150,7 @@ class Main(star.Star): return f"{self.config['sandbox']['docker_mirror']}/{self.config['sandbox']['image']}" return self.config["sandbox"]["image"] - async def _save_config(self): + def _save_config(self): with open(PATH, "w") as f: json.dump(self.config, f) @@ -207,7 +207,7 @@ class Main(star.Star): """) else: self.config["sandbox"]["docker_mirror"] = url - await self._save_config() + self._save_config() yield event.plain_result("设置 Docker 镜像地址成功。") @pi.command("repull")