diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 7710ebb40..ad268c12a 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -1262,6 +1262,18 @@ CONFIG_METADATA_2 = { "rerank_model": "BAAI/bge-reranker-base", "timeout": 20, }, + "Xinference Rerank": { + "id": "xinference_rerank", + "type": "xinference_rerank", + "provider": "xinference", + "provider_type": "rerank", + "enable": True, + "rerank_api_key": "", + "rerank_api_base": "http://127.0.0.1:9997", + "rerank_model": "BAAI/bge-reranker-base", + "timeout": 20, + "launch_model_if_not_running": False, + }, }, "items": { "rerank_api_base": { @@ -1278,6 +1290,11 @@ CONFIG_METADATA_2 = { "description": "重排序模型名称", "type": "string", }, + "launch_model_if_not_running": { + "description": "模型未运行时自动启动", + "type": "bool", + "hint": "如果模型当前未在 Xinference 服务中运行,是否尝试自动启动它。在生产环境中建议关闭。", + }, "modalities": { "description": "模型能力", "type": "list", diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index edb9b767a..ef86ed602 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -311,6 +311,10 @@ class ProviderManager: from .sources.vllm_rerank_source import ( VLLMRerankProvider as VLLMRerankProvider, ) + case "xinference_rerank": + from .sources.xinference_rerank_source import ( + XinferenceRerankProvider as XinferenceRerankProvider, + ) except (ImportError, ModuleNotFoundError) as e: logger.critical( f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。" diff --git a/astrbot/core/provider/sources/xinference_rerank_source.py b/astrbot/core/provider/sources/xinference_rerank_source.py new file mode 100644 index 000000000..3c27d7c3a --- /dev/null +++ b/astrbot/core/provider/sources/xinference_rerank_source.py @@ -0,0 +1,108 @@ +from xinference_client.client.restful.async_restful_client import ( + AsyncClient as Client, +) +from astrbot import logger +from ..provider import RerankProvider +from ..register import register_provider_adapter +from ..entities import ProviderType, RerankResult + + +@register_provider_adapter( + "xinference_rerank", + "Xinference Rerank 适配器", + provider_type=ProviderType.RERANK, +) +class XinferenceRerankProvider(RerankProvider): + def __init__(self, provider_config: dict, provider_settings: dict) -> None: + super().__init__(provider_config, provider_settings) + self.provider_config = provider_config + self.provider_settings = provider_settings + self.base_url = provider_config.get("rerank_api_base", "http://127.0.0.1:8000") + self.base_url = self.base_url.rstrip("/") + self.timeout = provider_config.get("timeout", 20) + self.model_name = provider_config.get("rerank_model", "BAAI/bge-reranker-base") + self.api_key = provider_config.get("rerank_api_key") + self.launch_model_if_not_running = provider_config.get( + "launch_model_if_not_running", False + ) + self.client = None + self.model = None + self.model_uid = None + + async def initialize(self): + if self.api_key: + logger.info("Xinference Rerank: Using API key for authentication.") + self.client = Client(self.base_url, api_key=self.api_key) + else: + logger.info("Xinference Rerank: No API key provided.") + self.client = Client(self.base_url) + + try: + running_models = await self.client.list_models() + for uid, model_spec in running_models.items(): + if model_spec.get("model_name") == self.model_name: + logger.info( + f"Model '{self.model_name}' is already running with UID: {uid}" + ) + self.model_uid = uid + break + + if self.model_uid is None: + if self.launch_model_if_not_running: + logger.info(f"Launching {self.model_name} model...") + self.model_uid = await self.client.launch_model( + model_name=self.model_name, model_type="rerank" + ) + logger.info("Model launched.") + else: + logger.warning( + f"Model '{self.model_name}' is not running and auto-launch is disabled. Provider will not be available." + ) + return + + if self.model_uid: + self.model = await self.client.get_model(self.model_uid) + + except Exception as e: + logger.error(f"Failed to initialize Xinference model: {e}") + logger.debug( + f"Xinference initialization failed with exception: {e}", exc_info=True + ) + self.model = None + + async def rerank( + self, query: str, documents: list[str], top_n: int | None = None + ) -> list[RerankResult]: + if not self.model: + logger.error("Xinference rerank model is not initialized.") + return [] + try: + response = await self.model.rerank(documents, query, top_n) + results = response.get("results", []) + logger.debug(f"Rerank API response: {response}") + + if not results: + logger.warning( + f"Rerank API returned an empty list. Original response: {response}" + ) + + return [ + RerankResult( + index=result["index"], + relevance_score=result["relevance_score"], + ) + for result in results + ] + except Exception as e: + logger.error(f"Xinference rerank failed: {e}") + logger.debug(f"Xinference rerank failed with exception: {e}", exc_info=True) + return [] + + async def terminate(self) -> None: + """关闭客户端会话""" + if self.client: + logger.info("Closing Xinference rerank client...") + try: + await self.client.close() + except Exception as e: + logger.error(f"Failed to close Xinference client: {e}", exc_info=True) diff --git a/pyproject.toml b/pyproject.toml index d8780b623..59bd15672 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ dependencies = [ "rank-bm25>=0.2.2", "jieba>=0.42.1", "markitdown-no-magika[docx,xls,xlsx]>=0.1.2", + "xinference-client", ] [project.scripts] diff --git a/requirements.txt b/requirements.txt index 714676e4e..308c27e51 100644 --- a/requirements.txt +++ b/requirements.txt @@ -48,4 +48,5 @@ pypdf aiofiles rank-bm25 jieba -markitdown-no-magika[docx,xls,xlsx] \ No newline at end of file +markitdown-no-magika[docx,xls,xlsx] +xinference-client \ No newline at end of file