feat: 支持注册消息平台适配器的 logo (#2109)

* feat: 添加平台适配器 logo 支持

* 优化平台logo注册逻辑,增加缓存机制并支持并行处理

* 去除判断绝对路径

---------

Co-authored-by: Soulter <905617992@qq.com>
This commit is contained in:
kterna
2025-10-02 14:36:15 +08:00
committed by GitHub
parent cef0c22f52
commit 8bdd748aec
4 changed files with 103 additions and 4 deletions
@@ -14,3 +14,5 @@ class PlatformMetadata:
"""平台的默认配置模板"""
adapter_display_name: str = None
"""显示在 WebUI 配置页中的平台名称,如空则是 name"""
logo_path: str = None
"""平台适配器的 logo 文件路径(相对于插件目录)"""
+3
View File
@@ -13,10 +13,12 @@ def register_platform_adapter(
desc: str,
default_config_tmpl: dict = None,
adapter_display_name: str = None,
logo_path: str = None,
):
"""用于注册平台适配器的带参装饰器。
default_config_tmpl 指定了平台适配器的默认配置模板。用户填写好后将会作为 platform_config 传入你的 Platform 类的实现类。
logo_path 指定了平台适配器的 logo 文件路径,是相对于插件目录的路径。
"""
def decorator(cls):
@@ -39,6 +41,7 @@ def register_platform_adapter(
description=desc,
default_config_tmpl=default_config_tmpl,
adapter_display_name=adapter_display_name,
logo_path=logo_path,
)
platform_registry.append(pm)
platform_cls_map[adapter_name] = cls
+88 -2
View File
@@ -1,6 +1,7 @@
import typing
import traceback
import os
import inspect
from .route import Route, Response, RouteContext
from astrbot.core.provider.entities import ProviderType
from quart import request
@@ -13,10 +14,10 @@ from astrbot.core.config.default import (
from astrbot.core.utils.astrbot_path import get_astrbot_path
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.platform.register import platform_registry
from astrbot.core.platform.register import platform_registry, platform_cls_map
from astrbot.core.provider.register import provider_registry
from astrbot.core.star.star import star_registry
from astrbot.core import logger
from astrbot.core import logger, file_token_service
from astrbot.core.provider import Provider
from astrbot.core.provider.provider import RerankProvider
import asyncio
@@ -149,6 +150,7 @@ class ConfigRoute(Route):
super().__init__(context)
self.core_lifecycle = core_lifecycle
self.config: AstrBotConfig = core_lifecycle.astrbot_config
self._logo_token_cache = {} # 缓存logo token,避免重复注册
self.acm = core_lifecycle.astrbot_config_mgr
self.routes = {
"/config/abconf/new": ("POST", self.create_abconf),
@@ -655,6 +657,78 @@ class ConfigRoute(Route):
return Response().error(str(e)).__dict__
return Response().ok(None, "删除成功,已经实时生效~").__dict__
async def get_llm_tools(self):
"""获取函数调用工具。包含了本地加载的以及 MCP 服务的工具"""
tool_mgr = self.core_lifecycle.provider_manager.llm_tools
tools = tool_mgr.get_func_desc_openai_style()
return Response().ok(tools).__dict__
async def _register_platform_logo(self, platform, platform_default_tmpl):
"""注册平台logo文件并生成访问令牌"""
if not platform.logo_path:
return
try:
# 检查缓存
cache_key = f"{platform.name}:{platform.logo_path}"
if cache_key in self._logo_token_cache:
cached_token = self._logo_token_cache[cache_key]
# 确保platform_default_tmpl[platform.name]存在且为字典
if platform.name not in platform_default_tmpl:
platform_default_tmpl[platform.name] = {}
elif not isinstance(platform_default_tmpl[platform.name], dict):
platform_default_tmpl[platform.name] = {}
platform_default_tmpl[platform.name]["logo_token"] = cached_token
logger.debug(f"Using cached logo token for platform {platform.name}")
return
# 获取平台适配器类
platform_cls = platform_cls_map.get(platform.name)
if not platform_cls:
logger.warning(f"Platform class not found for {platform.name}")
return
# 获取插件目录路径
module_file = inspect.getfile(platform_cls)
plugin_dir = os.path.dirname(module_file)
# 解析logo文件路径
logo_file_path = os.path.join(plugin_dir, platform.logo_path)
# 检查文件是否存在并注册令牌
if os.path.exists(logo_file_path):
logo_token = await file_token_service.register_file(
logo_file_path, timeout=3600
)
# 确保platform_default_tmpl[platform.name]存在且为字典
if platform.name not in platform_default_tmpl:
platform_default_tmpl[platform.name] = {}
elif not isinstance(platform_default_tmpl[platform.name], dict):
platform_default_tmpl[platform.name] = {}
platform_default_tmpl[platform.name]["logo_token"] = logo_token
# 缓存token
self._logo_token_cache[cache_key] = logo_token
logger.debug(f"Logo token registered for platform {platform.name}")
else:
logger.warning(
f"Platform {platform.name} logo file not found: {logo_file_path}"
)
except (ImportError, AttributeError) as e:
logger.warning(
f"Failed to import required modules for platform {platform.name}: {e}"
)
except (OSError, IOError) as e:
logger.warning(f"File system error for platform {platform.name} logo: {e}")
except Exception as e:
logger.warning(
f"Unexpected error registering logo for platform {platform.name}: {e}"
)
async def _get_astrbot_config(self):
config = self.config
@@ -662,9 +736,21 @@ class ConfigRoute(Route):
platform_default_tmpl = CONFIG_METADATA_2["platform_group"]["metadata"][
"platform"
]["config_template"]
# 收集需要注册logo的平台
logo_registration_tasks = []
for platform in platform_registry:
if platform.default_config_tmpl:
platform_default_tmpl[platform.name] = platform.default_config_tmpl
# 收集logo注册任务
if platform.logo_path:
logo_registration_tasks.append(
self._register_platform_logo(platform, platform_default_tmpl)
)
# 并行执行logo注册
if logo_registration_tasks:
await asyncio.gather(*logo_registration_tasks, return_exceptions=True)
# 服务提供商的默认配置模板注入
provider_default_tmpl = CONFIG_METADATA_2["provider_group"]["metadata"][
+10 -2
View File
@@ -114,7 +114,7 @@
<v-card-actions>
<v-spacer></v-spacer>
<v-btn color="grey" variant="text" @click="handleIdConflictConfirm(false)">{{ tm('dialog.idConflict.confirm')
}}</v-btn>
}}</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
@@ -241,7 +241,15 @@ export default {
methods: {
// 从工具函数导入
getPlatformIcon,
getPlatformIcon(platform_id) {
// 首先检查是否有来自插件的 logo_token
const template = this.metadata['platform_group']?.metadata?.platform?.config_template?.[platform_id];
if (template && template.logo_token) {
// 通过文件服务访问插件提供的 logo
return `/api/file/${template.logo_token}`;
}
return getPlatformIcon(platform_id);
},
openTutorial() {
const tutorialUrl = getTutorialLink(this.newSelectedPlatformConfig.type);