Files
AstrBot/astrbot/dashboard/routes/config.py
T
Dt8333 f624971613 chore: fix bunches of type checking errors (#3213)
* chore(core.utils): 🚨 修正错误Lint

* chore(core.provider): 🚨 修复基类错误Lint

* chore(core.utils): 补全session_get()的重载

* chore(core.provider): 🚨 修正实现错误Lint

* chore(core.platform): 🚨 修正platform基类和webchat的错误Lint

* chore(core.platform): 修正错误实现Lint

* fix(core.provider): 修复循环调用和错误assert

* chore(core.platform): 修复部分实现Lint

* chore(core.provider): 补充Dify.text_chat_stream的参数类型

* chore(core.pipeline): 🚨 修复错误Lint

* fix(core.slack): 补充遗漏导入

* chore(core.utils): 修复错误的session_get声明

* chore(core.platform): 移除Lark adapter import中的wildcard

* chore(core.db): 修复声明和部分逻辑

* chore(core.db): 添加typings,使faiss参数能被正确识别。

* chore(core): 修复声明

* chore(core): 修改声明

* chore: 补充faiss声明

* chore(dashboard): 修改实现,减少报错

* chore(package): 修改部分声明与实现,减少报错

* chore(core): 添加Handler的overload,以去除部分assert同时通过类型检查

* chore(core.pipeline): 修改Pipeline Scheduler的execute,将判断属性改为判断类型,通过静态类型检查

* chore(core.config): 添加类型标注,通过类型检查

* chore(core.message): 为File._download_file添加检查,通过类型检查

* fix: 将断言改为条件判断以实现优雅关闭的容错性

* refactor: 移除 discord 客户端中的 assert,改用 if None 判断并抛出异常

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: DiscordPlatformAdapter 对 self.client.user 为 None 做日志并返回,移除断言

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 增强 Lark 相关空值/异常检查并完善日志输出

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 将断言替换为条件检查并加入日志与错误处理

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* chore: 移除LLM生成的无用注释

* refactor: 使用 File.get_file 替换下载逻辑并移除 assert,提供默认 filename

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: Slack Socket 未初始化抛出运行时异常,图片 URL 判空改为非空判断

* refactor: 将 WeChatPadProAdapter 的断言改为空值判断并添加日志

* refactor: 使用 isinstance 替代断言实现类型判断,便于静态检查

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 去除cast,直接使用字段与字典访问,修正端口解析

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 使用 match-case 重构 ProviderManager 加载并通过类型检查抛出 TypeError

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: group_name_display 时若 group 对象为空则记录错误并返回

* fix: 将 _get_current_persona_id 的 assert 替换成 if guard 并返回 None

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 优化插件目录存在性检查及图片URL非空验证,更新JSON排序配置

* fix: 将 datetime_str 的 assert 替换为显式检查并抛出异常

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 移除 cast,改为运行时检查并在找不到调度器时跳过

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 移除 cast,改用 isinstance 检查 FaissVecDB 并警告

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 删除 typing.cast 导入,并在获取文件绝对路径前校验 file_

* refactor: 移除 typing.cast,简化内容安全检查调用

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 将 PlatformMetadata.id 设为必填并在注册时传入 id,移除 cast

* refactor: 移除 cast,改用 HasInitialize 与 isinstance 进行初始化

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 为 ProviderManager.initialize 增加ID类型判断,避免 None 导致 get 失败

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 为 OTTSProvider 与 AzureNativeProvider 引入 _client 与 client 属性改进上下文管理

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 为 Whisper 自托管源添加模型未初始化校验并直接调用 transcribe

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 移除未使用的 cast 导入并简化 platform_name 赋值

* refactor: 引入 cast 并对 id 使用 cast(str, ...) 提升类型安全

* fix: 将 _id_to_sid 返回改为 str,空值返回空串;对 id 与 message_id 使用 cast

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 重构 Discord 处理逻辑:强制 类型转换、优先斜杠指令并优化提及判断

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 统一对 id 获取执行 cast,并在微信消息解析失败时抛错

* Revert "fix: 去除cast,直接使用字段与字典访问,修正端口解析"

This reverts commit 1cbfdf9d1b.

* fix: 百炼 Rerank 会话关闭时返回空结果;初始化 request.prompt 避免空值拼接

* fix: 统一处理搜索结果链接为字符串,新增 _get_url 助手并适配 Bing/Sogo

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 调整 call_handler 泛型、Discord 通道注解及 FishAudioTTS API 请求类型

* refactor: 使用 col(...) 替代列引用并对结果进行 CursorResult 强转

* chore: ruff format

---------

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
Co-authored-by: Soulter <905617992@qq.com>
2025-12-09 14:13:47 +08:00

834 lines
33 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
import inspect
import os
import traceback
import uuid
from typing import Any
from quart import request
from astrbot.core import file_token_service, logger
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.config.default import (
CONFIG_METADATA_2,
CONFIG_METADATA_3,
CONFIG_METADATA_3_SYSTEM,
DEFAULT_CONFIG,
DEFAULT_VALUE_MAP,
WEBHOOK_SUPPORTED_PLATFORMS,
)
from astrbot.core.config.i18n_utils import ConfigMetadataI18n
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.platform.register import platform_cls_map, platform_registry
from astrbot.core.provider import Provider
from astrbot.core.provider.register import provider_registry
from astrbot.core.star.star import star_registry
from .route import Response, Route, RouteContext
def try_cast(value: Any, type_: str):
if type_ == "int":
try:
return int(value)
except (ValueError, TypeError):
return None
elif (
type_ == "float"
and isinstance(value, str)
and value.replace(".", "", 1).isdigit()
) or (type_ == "float" and isinstance(value, int)):
return float(value)
elif type_ == "float":
try:
return float(value)
except (ValueError, TypeError):
return None
def validate_config(data, schema: dict, is_core: bool) -> tuple[list[str], dict]:
errors = []
def validate(data: dict, metadata: dict = schema, path=""):
for key, value in data.items():
if key not in metadata:
continue
meta = metadata[key]
if "type" not in meta:
logger.debug(f"配置项 {path}{key} 没有类型定义, 跳过校验")
continue
# null 转换
if value is None:
data[key] = DEFAULT_VALUE_MAP[meta["type"]]
continue
if meta["type"] == "list" and not isinstance(value, list):
errors.append(
f"错误的类型 {path}{key}: 期望是 list, 得到了 {type(value).__name__}",
)
elif (
meta["type"] == "list"
and isinstance(value, list)
and value
and "items" in meta
and isinstance(value[0], dict)
):
# 当前仅针对 list[dict] 的情况进行类型校验,以适配 AstrBot 中 platform、provider 的配置
for item in value:
validate(item, meta["items"], path=f"{path}{key}.")
elif meta["type"] == "object" and isinstance(value, dict):
validate(value, meta["items"], path=f"{path}{key}.")
if meta["type"] == "int" and not isinstance(value, int):
casted = try_cast(value, "int")
if casted is None:
errors.append(
f"错误的类型 {path}{key}: 期望是 int, 得到了 {type(value).__name__}",
)
data[key] = casted
elif meta["type"] == "float" and not isinstance(value, float):
casted = try_cast(value, "float")
if casted is None:
errors.append(
f"错误的类型 {path}{key}: 期望是 float, 得到了 {type(value).__name__}",
)
data[key] = casted
elif meta["type"] == "bool" and not isinstance(value, bool):
errors.append(
f"错误的类型 {path}{key}: 期望是 bool, 得到了 {type(value).__name__}",
)
elif meta["type"] in ["string", "text"] and not isinstance(value, str):
errors.append(
f"错误的类型 {path}{key}: 期望是 string, 得到了 {type(value).__name__}",
)
elif meta["type"] == "list" and not isinstance(value, list):
errors.append(
f"错误的类型 {path}{key}: 期望是 list, 得到了 {type(value).__name__}",
)
elif meta["type"] == "object" and not isinstance(value, dict):
errors.append(
f"错误的类型 {path}{key}: 期望是 dict, 得到了 {type(value).__name__}",
)
if is_core:
meta_all = {
**schema["platform_group"]["metadata"],
**schema["provider_group"]["metadata"],
**schema["misc_config_group"]["metadata"],
}
validate(data, meta_all)
else:
validate(data, schema)
return errors, data
def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False):
"""验证并保存配置"""
errors = None
logger.info(f"Saving config, is_core={is_core}")
try:
if is_core:
errors, post_config = validate_config(
post_config,
CONFIG_METADATA_2,
is_core,
)
else:
errors, post_config = validate_config(
post_config, getattr(config, "schema", {}), is_core
)
except BaseException as e:
logger.error(traceback.format_exc())
logger.warning(f"验证配置时出现异常: {e}")
raise ValueError(f"验证配置时出现异常: {e}")
if errors:
raise ValueError(f"格式校验未通过: {errors}")
config.save_config(post_config)
class ConfigRoute(Route):
def __init__(
self,
context: RouteContext,
core_lifecycle: AstrBotCoreLifecycle,
) -> None:
super().__init__(context)
self.core_lifecycle = core_lifecycle
self.config: AstrBotConfig = core_lifecycle.astrbot_config
self._logo_token_cache = {} # 缓存logo token,避免重复注册
self.acm = core_lifecycle.astrbot_config_mgr
self.ucr = core_lifecycle.umop_config_router
self.routes = {
"/config/abconf/new": ("POST", self.create_abconf),
"/config/abconf": ("GET", self.get_abconf),
"/config/abconfs": ("GET", self.get_abconf_list),
"/config/abconf/delete": ("POST", self.delete_abconf),
"/config/abconf/update": ("POST", self.update_abconf),
"/config/umo_abconf_routes": ("GET", self.get_uc_table),
"/config/umo_abconf_route/update_all": ("POST", self.update_ucr_all),
"/config/umo_abconf_route/update": ("POST", self.update_ucr),
"/config/umo_abconf_route/delete": ("POST", self.delete_ucr),
"/config/get": ("GET", self.get_configs),
"/config/default": ("GET", self.get_default_config),
"/config/astrbot/update": ("POST", self.post_astrbot_configs),
"/config/plugin/update": ("POST", self.post_plugin_configs),
"/config/platform/new": ("POST", self.post_new_platform),
"/config/platform/update": ("POST", self.post_update_platform),
"/config/platform/delete": ("POST", self.post_delete_platform),
"/config/platform/list": ("GET", self.get_platform_list),
"/config/provider/new": ("POST", self.post_new_provider),
"/config/provider/update": ("POST", self.post_update_provider),
"/config/provider/delete": ("POST", self.post_delete_provider),
"/config/provider/check_one": ("GET", self.check_one_provider_status),
"/config/provider/list": ("GET", self.get_provider_config_list),
"/config/provider/model_list": ("GET", self.get_provider_model_list),
"/config/provider/get_embedding_dim": ("POST", self.get_embedding_dim),
}
self.register_routes()
async def get_uc_table(self):
"""获取 UMOP 配置路由表"""
return Response().ok({"routing": self.ucr.umop_to_conf_id}).__dict__
async def update_ucr_all(self):
"""更新 UMOP 配置路由表的全部内容"""
post_data = await request.json
if not post_data:
return Response().error("缺少配置数据").__dict__
new_routing = post_data.get("routing", None)
if not new_routing or not isinstance(new_routing, dict):
return Response().error("缺少或错误的路由表数据").__dict__
try:
await self.ucr.update_routing_data(new_routing)
return Response().ok(message="更新成功").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"更新路由表失败: {e!s}").__dict__
async def update_ucr(self):
"""更新 UMOP 配置路由表"""
post_data = await request.json
if not post_data:
return Response().error("缺少配置数据").__dict__
umo = post_data.get("umo", None)
conf_id = post_data.get("conf_id", None)
if not umo or not conf_id:
return Response().error("缺少 UMO 或配置文件 ID").__dict__
try:
await self.ucr.update_route(umo, conf_id)
return Response().ok(message="更新成功").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"更新路由表失败: {e!s}").__dict__
async def delete_ucr(self):
"""删除 UMOP 配置路由表中的一项"""
post_data = await request.json
if not post_data:
return Response().error("缺少配置数据").__dict__
umo = post_data.get("umo", None)
if not umo:
return Response().error("缺少 UMO").__dict__
try:
if umo in self.ucr.umop_to_conf_id:
del self.ucr.umop_to_conf_id[umo]
await self.ucr.update_routing_data(self.ucr.umop_to_conf_id)
return Response().ok(message="删除成功").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"删除路由表项失败: {e!s}").__dict__
async def get_default_config(self):
"""获取默认配置文件"""
metadata = ConfigMetadataI18n.convert_to_i18n_keys(CONFIG_METADATA_3)
return Response().ok({"config": DEFAULT_CONFIG, "metadata": metadata}).__dict__
async def get_abconf_list(self):
"""获取所有 AstrBot 配置文件的列表"""
abconf_list = self.acm.get_conf_list()
return Response().ok({"info_list": abconf_list}).__dict__
async def create_abconf(self):
"""创建新的 AstrBot 配置文件"""
post_data = await request.json
if not post_data:
return Response().error("缺少配置数据").__dict__
name = post_data.get("name", None)
config = post_data.get("config", DEFAULT_CONFIG)
try:
conf_id = self.acm.create_conf(name=name, config=config)
return Response().ok(message="创建成功", data={"conf_id": conf_id}).__dict__
except ValueError as e:
return Response().error(str(e)).__dict__
async def get_abconf(self):
"""获取指定 AstrBot 配置文件"""
abconf_id = request.args.get("id")
system_config = request.args.get("system_config", "0").lower() == "1"
if not abconf_id and not system_config:
return Response().error("缺少配置文件 ID").__dict__
try:
if system_config:
abconf = self.acm.confs["default"]
metadata = ConfigMetadataI18n.convert_to_i18n_keys(
CONFIG_METADATA_3_SYSTEM
)
return Response().ok({"config": abconf, "metadata": metadata}).__dict__
if abconf_id is None:
raise ValueError("abconf_id cannot be None")
abconf = self.acm.confs[abconf_id]
metadata = ConfigMetadataI18n.convert_to_i18n_keys(CONFIG_METADATA_3)
return Response().ok({"config": abconf, "metadata": metadata}).__dict__
except ValueError as e:
return Response().error(str(e)).__dict__
async def delete_abconf(self):
"""删除指定 AstrBot 配置文件"""
post_data = await request.json
if not post_data:
return Response().error("缺少配置数据").__dict__
conf_id = post_data.get("id")
if not conf_id:
return Response().error("缺少配置文件 ID").__dict__
try:
success = self.acm.delete_conf(conf_id)
if success:
return Response().ok(message="删除成功").__dict__
return Response().error("删除失败").__dict__
except ValueError as e:
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"删除配置文件失败: {e!s}").__dict__
async def update_abconf(self):
"""更新指定 AstrBot 配置文件信息"""
post_data = await request.json
if not post_data:
return Response().error("缺少配置数据").__dict__
conf_id = post_data.get("id")
if not conf_id:
return Response().error("缺少配置文件 ID").__dict__
name = post_data.get("name")
try:
success = self.acm.update_conf_info(conf_id, name=name)
if success:
return Response().ok(message="更新成功").__dict__
return Response().error("更新失败").__dict__
except ValueError as e:
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"更新配置文件失败: {e!s}").__dict__
async def _test_single_provider(self, provider):
"""辅助函数:测试单个 provider 的可用性"""
meta = provider.meta()
provider_name = provider.provider_config.get("id", "Unknown Provider")
provider_capability_type = meta.provider_type
status_info = {
"id": getattr(meta, "id", "Unknown ID"),
"model": getattr(meta, "model", "Unknown Model"),
"type": provider_capability_type.value,
"name": provider_name,
"status": "unavailable", # 默认为不可用
"error": None,
}
logger.debug(
f"Attempting to check provider: {status_info['name']} (ID: {status_info['id']}, Type: {status_info['type']}, Model: {status_info['model']})",
)
try:
await provider.test()
status_info["status"] = "available"
logger.info(
f"Provider {status_info['name']} (ID: {status_info['id']}) is available.",
)
except Exception as e:
error_message = str(e)
status_info["error"] = error_message
logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}",
)
logger.debug(
f"Traceback for {status_info['name']}:\n{traceback.format_exc()}",
)
return status_info
def _error_response(
self,
message: str,
status_code: int = 500,
log_fn=logger.error,
):
log_fn(message)
# 记录更详细的traceback信息,但只在是严重错误时
if status_code == 500:
log_fn(traceback.format_exc())
return Response().error(message).__dict__
async def check_one_provider_status(self):
"""API: check a single LLM Provider's status by id"""
provider_id = request.args.get("id")
if not provider_id:
return self._error_response(
"Missing provider_id parameter",
400,
logger.warning,
)
logger.info(f"API call: /config/provider/check_one id={provider_id}")
try:
prov_mgr = self.core_lifecycle.provider_manager
target = prov_mgr.inst_map.get(provider_id)
if not target:
logger.warning(
f"Provider with id '{provider_id}' not found in provider_manager.",
)
return (
Response()
.error(f"Provider with id '{provider_id}' not found")
.__dict__
)
result = await self._test_single_provider(target)
return Response().ok(result).__dict__
except Exception as e:
return self._error_response(
f"Critical error checking provider {provider_id}: {e}",
500,
)
async def get_configs(self):
# plugin_name 为空时返回 AstrBot 配置
# 否则返回指定 plugin_name 的插件配置
plugin_name = request.args.get("plugin_name", None)
if not plugin_name:
return Response().ok(await self._get_astrbot_config()).__dict__
return Response().ok(await self._get_plugin_config(plugin_name)).__dict__
async def get_provider_config_list(self):
provider_type = request.args.get("provider_type", None)
if not provider_type:
return Response().error("缺少参数 provider_type").__dict__
provider_type_ls = provider_type.split(",")
provider_list = []
astrbot_config = self.core_lifecycle.astrbot_config
for provider in astrbot_config["provider"]:
if provider.get("provider_type", None) in provider_type_ls:
provider_list.append(provider)
return Response().ok(provider_list).__dict__
async def get_provider_model_list(self):
"""获取指定提供商的模型列表"""
provider_id = request.args.get("provider_id", None)
if not provider_id:
return Response().error("缺少参数 provider_id").__dict__
prov_mgr = self.core_lifecycle.provider_manager
provider = prov_mgr.inst_map.get(provider_id, None)
if not provider:
return Response().error(f"未找到 ID 为 {provider_id} 的提供商").__dict__
if not isinstance(provider, Provider):
return (
Response()
.error(f"提供商 {provider_id} 类型不支持获取模型列表")
.__dict__
)
try:
models = await provider.get_models()
ret = {
"models": models,
"provider_id": provider_id,
}
return Response().ok(ret).__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(str(e)).__dict__
async def get_embedding_dim(self):
"""获取嵌入模型的维度"""
post_data = await request.json
provider_config = post_data.get("provider_config", None)
if not provider_config:
return Response().error("缺少参数 provider_config").__dict__
try:
# 动态导入 EmbeddingProvider
from astrbot.core.provider.provider import EmbeddingProvider
from astrbot.core.provider.register import provider_cls_map
# 获取 provider 类型
provider_type = provider_config.get("type", None)
if not provider_type:
return Response().error("provider_config 缺少 type 字段").__dict__
# 获取对应的 provider 类
if provider_type not in provider_cls_map:
return (
Response()
.error(f"未找到适用于 {provider_type} 的提供商适配器")
.__dict__
)
provider_metadata = provider_cls_map[provider_type]
cls_type = provider_metadata.cls_type
if not cls_type:
return Response().error(f"无法找到 {provider_type} 的类").__dict__
# 实例化 provider
inst = cls_type(provider_config, {})
# 检查是否是 EmbeddingProvider
if not isinstance(inst, EmbeddingProvider):
return Response().error("提供商不是 EmbeddingProvider 类型").__dict__
init_fn = getattr(inst, "initialize", None)
if inspect.iscoroutinefunction(init_fn):
await init_fn()
# 获取嵌入向量维度
vec = await inst.get_embedding("echo")
dim = len(vec)
logger.info(
f"检测到 {provider_config.get('id', 'unknown')} 的嵌入向量维度为 {dim}",
)
return Response().ok({"embedding_dimensions": dim}).__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"获取嵌入维度失败: {e!s}").__dict__
async def get_platform_list(self):
"""获取所有平台的列表"""
platform_list = []
for platform in self.config["platform"]:
platform_list.append(platform)
return Response().ok({"platforms": platform_list}).__dict__
async def post_astrbot_configs(self):
data = await request.json
config = data.get("config", None)
conf_id = data.get("conf_id", None)
try:
await self._save_astrbot_configs(config, conf_id)
await self.core_lifecycle.reload_pipeline_scheduler(conf_id)
return Response().ok(None, "保存成功~").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(str(e)).__dict__
async def post_plugin_configs(self):
post_configs = await request.json
plugin_name = request.args.get("plugin_name", "unknown")
try:
await self._save_plugin_configs(post_configs, plugin_name)
await self.core_lifecycle.plugin_manager.reload(plugin_name)
return (
Response()
.ok(None, f"保存插件 {plugin_name} 成功~ 机器人正在热重载插件。")
.__dict__
)
except Exception as e:
return Response().error(str(e)).__dict__
async def post_new_platform(self):
new_platform_config = await request.json
# 如果是支持统一 webhook 模式的平台,且启用了统一 webhook 模式,自动生成 webhook_uuid
platform_type = new_platform_config.get("type", "")
if platform_type in WEBHOOK_SUPPORTED_PLATFORMS:
if new_platform_config.get("unified_webhook_mode", False):
# 如果没有 webhook_uuid 或为空,自动生成
if not new_platform_config.get("webhook_uuid"):
new_platform_config["webhook_uuid"] = uuid.uuid4().hex[:16]
self.config["platform"].append(new_platform_config)
try:
save_config(self.config, self.config, is_core=True)
await self.core_lifecycle.platform_manager.load_platform(
new_platform_config,
)
except Exception as e:
return Response().error(str(e)).__dict__
return Response().ok(None, "新增平台配置成功~").__dict__
async def post_new_provider(self):
new_provider_config = await request.json
self.config["provider"].append(new_provider_config)
try:
save_config(self.config, self.config, is_core=True)
await self.core_lifecycle.provider_manager.load_provider(
new_provider_config,
)
except Exception as e:
return Response().error(str(e)).__dict__
return Response().ok(None, "新增服务提供商配置成功~").__dict__
async def post_update_platform(self):
update_platform_config = await request.json
platform_id = update_platform_config.get("id", None)
new_config = update_platform_config.get("config", None)
if not platform_id or not new_config:
return Response().error("参数错误").__dict__
# 如果是支持统一 webhook 模式的平台,且启用了统一 webhook 模式,确保有 webhook_uuid
platform_type = new_config.get("type", "")
if platform_type in WEBHOOK_SUPPORTED_PLATFORMS:
if new_config.get("unified_webhook_mode", False):
# 如果没有 webhook_uuid 或为空,自动生成
if not new_config.get("webhook_uuid"):
new_config["webhook_uuid"] = uuid.uuid4().hex
for i, platform in enumerate(self.config["platform"]):
if platform["id"] == platform_id:
self.config["platform"][i] = new_config
break
else:
return Response().error("未找到对应平台").__dict__
try:
save_config(self.config, self.config, is_core=True)
await self.core_lifecycle.platform_manager.reload(new_config)
except Exception as e:
return Response().error(str(e)).__dict__
return Response().ok(None, "更新平台配置成功~").__dict__
async def post_update_provider(self):
update_provider_config = await request.json
provider_id = update_provider_config.get("id", None)
new_config = update_provider_config.get("config", None)
if not provider_id or not new_config:
return Response().error("参数错误").__dict__
for i, provider in enumerate(self.config["provider"]):
if provider["id"] == provider_id:
self.config["provider"][i] = new_config
break
else:
return Response().error("未找到对应服务提供商").__dict__
try:
save_config(self.config, self.config, is_core=True)
await self.core_lifecycle.provider_manager.reload(new_config)
except Exception as e:
return Response().error(str(e)).__dict__
return Response().ok(None, "更新成功,已经实时生效~").__dict__
async def post_delete_platform(self):
platform_id = await request.json
platform_id = platform_id.get("id")
for i, platform in enumerate(self.config["platform"]):
if platform["id"] == platform_id:
del self.config["platform"][i]
break
else:
return Response().error("未找到对应平台").__dict__
try:
save_config(self.config, self.config, is_core=True)
await self.core_lifecycle.platform_manager.terminate_platform(platform_id)
except Exception as e:
return Response().error(str(e)).__dict__
return Response().ok(None, "删除平台配置成功~").__dict__
async def post_delete_provider(self):
provider_id = await request.json
provider_id = provider_id.get("id")
for i, provider in enumerate(self.config["provider"]):
if provider["id"] == provider_id:
del self.config["provider"][i]
break
else:
return Response().error("未找到对应服务提供商").__dict__
try:
save_config(self.config, self.config, is_core=True)
await self.core_lifecycle.provider_manager.terminate_provider(provider_id)
except Exception as e:
return Response().error(str(e)).__dict__
return Response().ok(None, "删除成功,已经实时生效~").__dict__
async def get_llm_tools(self):
"""获取函数调用工具。包含了本地加载的以及 MCP 服务的工具"""
tool_mgr = self.core_lifecycle.provider_manager.llm_tools
tools = tool_mgr.get_func_desc_openai_style()
return Response().ok(tools).__dict__
async def _register_platform_logo(self, platform, platform_default_tmpl):
"""注册平台logo文件并生成访问令牌"""
if not platform.logo_path:
return
try:
# 检查缓存
cache_key = f"{platform.name}:{platform.logo_path}"
if cache_key in self._logo_token_cache:
cached_token = self._logo_token_cache[cache_key]
# 确保platform_default_tmpl[platform.name]存在且为字典
if platform.name not in platform_default_tmpl or not isinstance(
platform_default_tmpl[platform.name], dict
):
platform_default_tmpl[platform.name] = {}
platform_default_tmpl[platform.name]["logo_token"] = cached_token
logger.debug(f"Using cached logo token for platform {platform.name}")
return
# 获取平台适配器类
platform_cls = platform_cls_map.get(platform.name)
if not platform_cls:
logger.warning(f"Platform class not found for {platform.name}")
return
# 获取插件目录路径
module_file = inspect.getfile(platform_cls)
plugin_dir = os.path.dirname(module_file)
# 解析logo文件路径
logo_file_path = os.path.join(plugin_dir, platform.logo_path)
# 检查文件是否存在并注册令牌
if os.path.exists(logo_file_path):
logo_token = await file_token_service.register_file(
logo_file_path,
timeout=3600,
)
# 确保platform_default_tmpl[platform.name]存在且为字典
if platform.name not in platform_default_tmpl or not isinstance(
platform_default_tmpl[platform.name], dict
):
platform_default_tmpl[platform.name] = {}
platform_default_tmpl[platform.name]["logo_token"] = logo_token
# 缓存token
self._logo_token_cache[cache_key] = logo_token
logger.debug(f"Logo token registered for platform {platform.name}")
else:
logger.warning(
f"Platform {platform.name} logo file not found: {logo_file_path}",
)
except (ImportError, AttributeError) as e:
logger.warning(
f"Failed to import required modules for platform {platform.name}: {e}",
)
except OSError as e:
logger.warning(f"File system error for platform {platform.name} logo: {e}")
except Exception as e:
logger.warning(
f"Unexpected error registering logo for platform {platform.name}: {e}",
)
async def _get_astrbot_config(self):
config = self.config
# 平台适配器的默认配置模板注入
platform_default_tmpl = CONFIG_METADATA_2["platform_group"]["metadata"][
"platform"
]["config_template"]
# 收集需要注册logo的平台
logo_registration_tasks = []
for platform in platform_registry:
if platform.default_config_tmpl:
platform_default_tmpl[platform.name] = platform.default_config_tmpl
# 收集logo注册任务
if platform.logo_path:
logo_registration_tasks.append(
self._register_platform_logo(platform, platform_default_tmpl),
)
# 并行执行logo注册
if logo_registration_tasks:
await asyncio.gather(*logo_registration_tasks, return_exceptions=True)
# 服务提供商的默认配置模板注入
provider_default_tmpl = CONFIG_METADATA_2["provider_group"]["metadata"][
"provider"
]["config_template"]
for provider in provider_registry:
if provider.default_config_tmpl:
provider_default_tmpl[provider.type] = provider.default_config_tmpl
return {"metadata": CONFIG_METADATA_2, "config": config}
async def _get_plugin_config(self, plugin_name: str):
ret: dict = {"metadata": None, "config": None}
for plugin_md in star_registry:
if plugin_md.name == plugin_name:
if not plugin_md.config:
break
ret["config"] = (
plugin_md.config
) # 这是自定义的 Dict 类(AstrBotConfig
ret["metadata"] = {
plugin_name: {
"description": f"{plugin_name} 配置",
"type": "object",
"items": plugin_md.config.schema, # 初始化时通过 __setattr__ 存入了 schema
},
}
break
return ret
async def _save_astrbot_configs(
self, post_configs: dict, conf_id: str | None = None
):
try:
if conf_id not in self.acm.confs:
raise ValueError(f"配置文件 {conf_id} 不存在")
astrbot_config = self.acm.confs[conf_id]
# 保留服务端的 t2i_active_template 值
if "t2i_active_template" in astrbot_config:
post_configs["t2i_active_template"] = astrbot_config[
"t2i_active_template"
]
save_config(post_configs, astrbot_config, is_core=True)
except Exception as e:
raise e
async def _save_plugin_configs(self, post_configs: dict, plugin_name: str):
md = None
for plugin_md in star_registry:
if plugin_md.name == plugin_name:
md = plugin_md
if not md:
raise ValueError(f"插件 {plugin_name} 不存在")
if not md.config:
raise ValueError(f"插件 {plugin_name} 没有注册配置")
try:
save_config(post_configs, md.config)
except Exception as e:
raise e