diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 3c805cf99..f8dbbf469 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -1974,22 +1974,10 @@ CONFIG_METADATA_2 = { "description": "API Base URL", "type": "string", }, - "model_config": { - "description": "模型配置", - "type": "object", - "items": { - "model": { - "description": "模型名称", - "type": "string", - "hint": "模型名称,如 gpt-4o-mini, deepseek-chat。", - }, - "max_tokens": { - "description": "模型最大输出长度(tokens)", - "type": "int", - }, - "temperature": {"description": "温度", "type": "float"}, - "top_p": {"description": "Top P值", "type": "float"}, - }, + "model": { + "description": "模型 ID", + "type": "string", + "hint": "模型名称,如 gpt-4o-mini, deepseek-chat。", }, "dify_api_key": { "description": "API Key", diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index e3e0e7595..e4f2a9e24 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -37,6 +37,8 @@ class ProviderManager: config = acm.confs["default"] self.providers_config: list = config["provider"] self.provider_sources_config: list = config.get("provider_sources", []) + self.merged_provider_config: dict = {} + """合并 provider 和 provider_sources 配置后的结果""" self.provider_settings: dict = config["provider_settings"] self.provider_stt_settings: dict = config.get("provider_stt_settings", {}) self.provider_tts_settings: dict = config.get("provider_tts_settings", {}) @@ -270,6 +272,8 @@ class ProviderManager: merged_config["id"] = provider_config["id"] provider_config = merged_config + self.merged_provider_config[provider_config["id"]] = provider_config + if not provider_config["enable"]: logger.info(f"Provider {provider_config['id']} is disabled, skipping") return diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index f2b7fac6f..21acd87e8 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -45,7 +45,7 @@ class ProviderAnthropic(Provider): base_url=self.base_url, ) - self.set_model(provider_config["model_config"]["model"]) + self.set_model(provider_config.get("model", "unknown")) def _prepare_payload(self, messages: list[dict]): """准备 Anthropic API 的请求 payload @@ -285,10 +285,9 @@ class ProviderAnthropic(Provider): system_prompt, new_messages = self._prepare_payload(context_query) - model_config = self.provider_config.get("model_config", {}) - model_config["model"] = model or self.get_model() + model = model or self.get_model() - payloads = {"messages": new_messages, **model_config} + payloads = {"messages": new_messages, "model": model} # Anthropic has a different way of handling system prompts if system_prompt: @@ -298,7 +297,6 @@ class ProviderAnthropic(Provider): try: llm_response = await self._query(payloads, func_tool) except Exception as e: - # logger.error(f"发生了错误。Provider 配置如下: {model_config}") raise e return llm_response @@ -340,10 +338,9 @@ class ProviderAnthropic(Provider): system_prompt, new_messages = self._prepare_payload(context_query) - model_config = self.provider_config.get("model_config", {}) - model_config["model"] = model or self.get_model() + model = model or self.get_model() - payloads = {"messages": new_messages, **model_config} + payloads = {"messages": new_messages, "model": model} # Anthropic has a different way of handling system prompts if system_prompt: diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index e2efc6aab..ebb000c03 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -68,7 +68,7 @@ class ProviderGoogleGenAI(Provider): self.api_base = self.api_base[:-1] self._init_client() - self.set_model(provider_config["model_config"]["model"]) + self.set_model(provider_config.get("model", "unknown")) self._init_safety_settings() def _init_client(self) -> None: @@ -652,10 +652,9 @@ class ProviderGoogleGenAI(Provider): for tcr in tool_calls_result: context_query.extend(tcr.to_openai_messages()) - model_config = self.provider_config.get("model_config", {}) - model_config["model"] = model or self.get_model() + model = model or self.get_model() - payloads = {"messages": context_query, **model_config} + payloads = {"messages": context_query, "model": model} retry = 10 keys = self.api_keys.copy() @@ -705,10 +704,9 @@ class ProviderGoogleGenAI(Provider): for tcr in tool_calls_result: context_query.extend(tcr.to_openai_messages()) - model_config = self.provider_config.get("model_config", {}) - model_config["model"] = model or self.get_model() + model = model or self.get_model() - payloads = {"messages": context_query, **model_config} + payloads = {"messages": context_query, "model": model} retry = 10 keys = self.api_keys.copy() diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 788b649a9..6e3bd0bfd 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -68,8 +68,7 @@ class ProviderOpenAIOfficial(Provider): self.client.chat.completions.create, ).parameters.keys() - model_config = provider_config.get("model_config", {}) - model = model_config.get("model", "unknown") + model = provider_config.get("model", "unknown") self.set_model(model) self.reasoning_key = "reasoning_content" @@ -358,10 +357,9 @@ class ProviderOpenAIOfficial(Provider): for tcr in tool_calls_result: context_query.extend(tcr.to_openai_messages()) - model_config = self.provider_config.get("model_config", {}) - model_config["model"] = model or self.get_model() + model = model or self.get_model() - payloads = {"messages": context_query, **model_config} + payloads = {"messages": context_query, "model": model} # xAI origin search tool inject self._maybe_inject_xai_search(payloads, **kwargs) diff --git a/astrbot/core/utils/migra_helper.py b/astrbot/core/utils/migra_helper.py index 52715caf9..42046eab8 100644 --- a/astrbot/core/utils/migra_helper.py +++ b/astrbot/core/utils/migra_helper.py @@ -51,6 +51,7 @@ def _migra_provider_to_source_structure(conf: AstrBotConfig) -> None: "model", "modalities", "custom_extra_body", + "enable", } # Fields that should not go to source diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 22a24892e..6e9942e5b 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -188,9 +188,80 @@ class ConfigRoute(Route): "GET", self.get_provider_source_models, ), + "/config/provider_sources//update": ( + "POST", + self.update_provider_source, + ), } self.register_routes() + async def update_provider_source(self, provider_source_id: str): + """更新或新增 provider_source,并重载关联的 providers""" + + post_data = await request.json + if not post_data: + return Response().error("缺少配置数据").__dict__ + + new_source_config = post_data.get("config") or post_data + original_id = post_data.get("original_id") or provider_source_id + + if not isinstance(new_source_config, dict): + return Response().error("缺少或错误的配置数据").__dict__ + + # 确保配置中有 id 字段 + if not new_source_config.get("id"): + new_source_config["id"] = original_id + + provider_sources = self.config.get("provider_sources", []) + + # 查找旧的 provider_source,若不存在则追加为新配置 + target_idx = next( + (i for i, ps in enumerate(provider_sources) if ps.get("id") == original_id), + -1, + ) + + old_id = original_id + if target_idx == -1: + provider_sources.append(new_source_config) + else: + old_id = provider_sources[target_idx].get("id") + provider_sources[target_idx] = new_source_config + + # 更新引用了该 provider_source 的 providers + affected_providers = [] + for provider in self.config.get("provider", []): + if provider.get("provider_source_id") == old_id: + provider["provider_source_id"] = new_source_config["id"] + affected_providers.append(provider) + + # 写回配置 + self.config["provider_sources"] = provider_sources + + try: + save_config(self.config, self.config, is_core=True) + except Exception as e: + logger.error(traceback.format_exc()) + return Response().error(str(e)).__dict__ + + # 重载受影响的 providers,使新的 source 配置生效 + reload_errors = [] + prov_mgr = self.core_lifecycle.provider_manager + for provider in affected_providers: + try: + await prov_mgr.reload(provider) + except Exception as e: + logger.error(traceback.format_exc()) + reload_errors.append(f"{provider.get('id')}: {e}") + + if reload_errors: + return ( + Response() + .error("更新成功,但部分提供商重载失败: " + ", ".join(reload_errors)) + .__dict__ + ) + + return Response().ok(message="更新 provider source 成功").__dict__ + async def get_provider_template(self): provider_config = astrbot_config["provider"] config_schema = { @@ -449,8 +520,8 @@ class ConfigRoute(Route): return Response().error("缺少参数 provider_type").__dict__ provider_type_ls = provider_type.split(",") provider_list = [] - astrbot_config = self.core_lifecycle.astrbot_config - for provider in astrbot_config["provider"]: + pc = self.core_lifecycle.provider_manager.merged_provider_config + for provider in pc.values(): if provider.get("provider_type", None) in provider_type_ls: provider_list.append(provider) return Response().ok(provider_list).__dict__ diff --git a/dashboard/src/components/shared/ProviderSelector.vue b/dashboard/src/components/shared/ProviderSelector.vue index 050738a94..ffade98d1 100644 --- a/dashboard/src/components/shared/ProviderSelector.vue +++ b/dashboard/src/components/shared/ProviderSelector.vue @@ -51,7 +51,7 @@ {{ provider.id }} {{ provider.type || provider.provider_type || tm('providerSelector.unknownType') }} - - {{ provider.model_config.model }} + - {{ provider.model }}