From 1542ea3e03bd30080907d02189e6949a9bb2594d Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Mon, 22 Sep 2025 17:22:50 +0800 Subject: [PATCH] fix: context.get_provider_by_id issue --- astrbot/core/provider/manager.py | 25 +++++++++++++++++++------ astrbot/core/star/context.py | 11 +++++++---- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 4b9204a68..a23788a76 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -7,7 +7,13 @@ from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.db import BaseDatabase from .entities import ProviderType -from .provider import Provider, STTProvider, TTSProvider, EmbeddingProvider +from .provider import ( + Provider, + STTProvider, + TTSProvider, + EmbeddingProvider, + RerankProvider, +) from .register import llm_tools, provider_cls_map from ..persona_mgr import PersonaManager @@ -38,7 +44,12 @@ class ProviderManager: """加载的 Text To Speech Provider 的实例""" self.embedding_provider_insts: List[EmbeddingProvider] = [] """加载的 Embedding Provider 的实例""" - self.inst_map: dict[str, Provider | STTProvider | TTSProvider] = {} + self.rerank_provider_insts: List[RerankProvider] = [] + """加载的 Rerank Provider 的实例""" + self.inst_map: dict[ + str, + Provider | STTProvider | TTSProvider | EmbeddingProvider | RerankProvider, + ] = {} """Provider 实例映射. key: provider_id, value: Provider 实例""" self.llm_tools = llm_tools @@ -378,14 +389,16 @@ class ProviderManager: if not self.curr_provider_inst: self.curr_provider_inst = inst - elif provider_metadata.provider_type in [ - ProviderType.EMBEDDING, - ProviderType.RERANK, - ]: + elif provider_metadata.provider_type == ProviderType.EMBEDDING: inst = cls_type(provider_config, self.provider_settings) if getattr(inst, "initialize", None): await inst.initialize() self.embedding_provider_insts.append(inst) + elif provider_metadata.provider_type == ProviderType.RERANK: + inst = cls_type(provider_config, self.provider_settings) + if getattr(inst, "initialize", None): + await inst.initialize() + self.rerank_provider_insts.append(inst) self.inst_map[provider_config["id"]] = inst except Exception as e: diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 005266a02..31616e7da 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -6,6 +6,7 @@ from astrbot.core.provider.provider import ( TTSProvider, STTProvider, EmbeddingProvider, + RerankProvider, ) from astrbot.core.provider.entities import ProviderType from astrbot.core.db import BaseDatabase @@ -103,11 +104,13 @@ class Context: """ self.provider_manager.provider_insts.append(provider) - def get_provider_by_id(self, provider_id: str) -> Provider | None: - """通过 ID 获取对应的 LLM Provider(Chat_Completion 类型)。""" + def get_provider_by_id( + self, provider_id: str + ) -> ( + Provider | TTSProvider | STTProvider | EmbeddingProvider | RerankProvider | None + ): + """通过 ID 获取对应的 LLM Provider。""" prov = self.provider_manager.inst_map.get(provider_id) - if prov and not isinstance(prov, Provider): - raise ValueError("返回的 Provider 不是 Provider 类型") return prov def get_all_providers(self) -> List[Provider]: