feat: 支持注册消息平台适配器的 logo (#2109)
* feat: 添加平台适配器 logo 支持 * 优化平台logo注册逻辑,增加缓存机制并支持并行处理 * 去除判断绝对路径 --------- Co-authored-by: Soulter <905617992@qq.com>
This commit is contained in:
@@ -14,3 +14,5 @@ class PlatformMetadata:
|
||||
"""平台的默认配置模板"""
|
||||
adapter_display_name: str = None
|
||||
"""显示在 WebUI 配置页中的平台名称,如空则是 name"""
|
||||
logo_path: str = None
|
||||
"""平台适配器的 logo 文件路径(相对于插件目录)"""
|
||||
|
||||
@@ -13,10 +13,12 @@ def register_platform_adapter(
|
||||
desc: str,
|
||||
default_config_tmpl: dict = None,
|
||||
adapter_display_name: str = None,
|
||||
logo_path: str = None,
|
||||
):
|
||||
"""用于注册平台适配器的带参装饰器。
|
||||
|
||||
default_config_tmpl 指定了平台适配器的默认配置模板。用户填写好后将会作为 platform_config 传入你的 Platform 类的实现类。
|
||||
logo_path 指定了平台适配器的 logo 文件路径,是相对于插件目录的路径。
|
||||
"""
|
||||
|
||||
def decorator(cls):
|
||||
@@ -39,6 +41,7 @@ def register_platform_adapter(
|
||||
description=desc,
|
||||
default_config_tmpl=default_config_tmpl,
|
||||
adapter_display_name=adapter_display_name,
|
||||
logo_path=logo_path,
|
||||
)
|
||||
platform_registry.append(pm)
|
||||
platform_cls_map[adapter_name] = cls
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import typing
|
||||
import traceback
|
||||
import os
|
||||
import inspect
|
||||
from .route import Route, Response, RouteContext
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
from quart import request
|
||||
@@ -13,10 +14,10 @@ from astrbot.core.config.default import (
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_path
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.platform.register import platform_registry
|
||||
from astrbot.core.platform.register import platform_registry, platform_cls_map
|
||||
from astrbot.core.provider.register import provider_registry
|
||||
from astrbot.core.star.star import star_registry
|
||||
from astrbot.core import logger
|
||||
from astrbot.core import logger, file_token_service
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot.core.provider.provider import RerankProvider
|
||||
import asyncio
|
||||
@@ -149,6 +150,7 @@ class ConfigRoute(Route):
|
||||
super().__init__(context)
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.config: AstrBotConfig = core_lifecycle.astrbot_config
|
||||
self._logo_token_cache = {} # 缓存logo token,避免重复注册
|
||||
self.acm = core_lifecycle.astrbot_config_mgr
|
||||
self.routes = {
|
||||
"/config/abconf/new": ("POST", self.create_abconf),
|
||||
@@ -655,6 +657,78 @@ class ConfigRoute(Route):
|
||||
return Response().error(str(e)).__dict__
|
||||
return Response().ok(None, "删除成功,已经实时生效~").__dict__
|
||||
|
||||
async def get_llm_tools(self):
|
||||
"""获取函数调用工具。包含了本地加载的以及 MCP 服务的工具"""
|
||||
tool_mgr = self.core_lifecycle.provider_manager.llm_tools
|
||||
tools = tool_mgr.get_func_desc_openai_style()
|
||||
return Response().ok(tools).__dict__
|
||||
|
||||
async def _register_platform_logo(self, platform, platform_default_tmpl):
|
||||
"""注册平台logo文件并生成访问令牌"""
|
||||
if not platform.logo_path:
|
||||
return
|
||||
|
||||
try:
|
||||
# 检查缓存
|
||||
cache_key = f"{platform.name}:{platform.logo_path}"
|
||||
if cache_key in self._logo_token_cache:
|
||||
cached_token = self._logo_token_cache[cache_key]
|
||||
# 确保platform_default_tmpl[platform.name]存在且为字典
|
||||
if platform.name not in platform_default_tmpl:
|
||||
platform_default_tmpl[platform.name] = {}
|
||||
elif not isinstance(platform_default_tmpl[platform.name], dict):
|
||||
platform_default_tmpl[platform.name] = {}
|
||||
platform_default_tmpl[platform.name]["logo_token"] = cached_token
|
||||
logger.debug(f"Using cached logo token for platform {platform.name}")
|
||||
return
|
||||
|
||||
# 获取平台适配器类
|
||||
platform_cls = platform_cls_map.get(platform.name)
|
||||
if not platform_cls:
|
||||
logger.warning(f"Platform class not found for {platform.name}")
|
||||
return
|
||||
|
||||
# 获取插件目录路径
|
||||
module_file = inspect.getfile(platform_cls)
|
||||
plugin_dir = os.path.dirname(module_file)
|
||||
|
||||
# 解析logo文件路径
|
||||
logo_file_path = os.path.join(plugin_dir, platform.logo_path)
|
||||
|
||||
# 检查文件是否存在并注册令牌
|
||||
if os.path.exists(logo_file_path):
|
||||
logo_token = await file_token_service.register_file(
|
||||
logo_file_path, timeout=3600
|
||||
)
|
||||
|
||||
# 确保platform_default_tmpl[platform.name]存在且为字典
|
||||
if platform.name not in platform_default_tmpl:
|
||||
platform_default_tmpl[platform.name] = {}
|
||||
elif not isinstance(platform_default_tmpl[platform.name], dict):
|
||||
platform_default_tmpl[platform.name] = {}
|
||||
|
||||
platform_default_tmpl[platform.name]["logo_token"] = logo_token
|
||||
|
||||
# 缓存token
|
||||
self._logo_token_cache[cache_key] = logo_token
|
||||
|
||||
logger.debug(f"Logo token registered for platform {platform.name}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"Platform {platform.name} logo file not found: {logo_file_path}"
|
||||
)
|
||||
|
||||
except (ImportError, AttributeError) as e:
|
||||
logger.warning(
|
||||
f"Failed to import required modules for platform {platform.name}: {e}"
|
||||
)
|
||||
except (OSError, IOError) as e:
|
||||
logger.warning(f"File system error for platform {platform.name} logo: {e}")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Unexpected error registering logo for platform {platform.name}: {e}"
|
||||
)
|
||||
|
||||
async def _get_astrbot_config(self):
|
||||
config = self.config
|
||||
|
||||
@@ -662,9 +736,21 @@ class ConfigRoute(Route):
|
||||
platform_default_tmpl = CONFIG_METADATA_2["platform_group"]["metadata"][
|
||||
"platform"
|
||||
]["config_template"]
|
||||
|
||||
# 收集需要注册logo的平台
|
||||
logo_registration_tasks = []
|
||||
for platform in platform_registry:
|
||||
if platform.default_config_tmpl:
|
||||
platform_default_tmpl[platform.name] = platform.default_config_tmpl
|
||||
# 收集logo注册任务
|
||||
if platform.logo_path:
|
||||
logo_registration_tasks.append(
|
||||
self._register_platform_logo(platform, platform_default_tmpl)
|
||||
)
|
||||
|
||||
# 并行执行logo注册
|
||||
if logo_registration_tasks:
|
||||
await asyncio.gather(*logo_registration_tasks, return_exceptions=True)
|
||||
|
||||
# 服务提供商的默认配置模板注入
|
||||
provider_default_tmpl = CONFIG_METADATA_2["provider_group"]["metadata"][
|
||||
|
||||
@@ -114,7 +114,7 @@
|
||||
<v-card-actions>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn color="grey" variant="text" @click="handleIdConflictConfirm(false)">{{ tm('dialog.idConflict.confirm')
|
||||
}}</v-btn>
|
||||
}}</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
@@ -241,7 +241,15 @@ export default {
|
||||
|
||||
methods: {
|
||||
// 从工具函数导入
|
||||
getPlatformIcon,
|
||||
getPlatformIcon(platform_id) {
|
||||
// 首先检查是否有来自插件的 logo_token
|
||||
const template = this.metadata['platform_group']?.metadata?.platform?.config_template?.[platform_id];
|
||||
if (template && template.logo_token) {
|
||||
// 通过文件服务访问插件提供的 logo
|
||||
return `/api/file/${template.logo_token}`;
|
||||
}
|
||||
return getPlatformIcon(platform_id);
|
||||
},
|
||||
|
||||
openTutorial() {
|
||||
const tutorialUrl = getTutorialLink(this.newSelectedPlatformConfig.type);
|
||||
|
||||
Reference in New Issue
Block a user