diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 528611d2d..da0fefbf2 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -1308,6 +1308,20 @@ CONFIG_METADATA_2 = { "timeout": 20, "launch_model_if_not_running": False, }, + "阿里云百炼重排序": { + "id": "bailian_rerank", + "type": "bailian_rerank", + "provider": "bailian", + "provider_type": "rerank", + "enable": True, + "rerank_api_key": "", + "rerank_api_base": "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank", + "rerank_model": "qwen3-rerank", + "timeout": 30, + "top_n": 3, + "return_documents": False, + "instruct": "", + }, "Xinference STT": { "id": "xinference_stt", "type": "xinference_stt", @@ -1342,6 +1356,21 @@ CONFIG_METADATA_2 = { "description": "重排序模型名称", "type": "string", }, + "top_n": { + "description": "返回排序后的top_n个文档", + "type": "int", + "hint": "默认返回全部文档。如果指定的值大于文档总数,将返回全部文档。", + }, + "return_documents": { + "description": "是否在排序结果中返回文档原文", + "type": "bool", + "hint": "默认值false,以减少网络传输开销。", + }, + "instruct": { + "description": "自定义排序任务类型说明", + "type": "string", + "hint": "仅在使用 qwen3-rerank 模型时生效。建议使用英文撰写。", + }, "launch_model_if_not_running": { "description": "模型未运行时自动启动", "type": "bool", diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 320c98d4e..ec2550415 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -331,6 +331,10 @@ class ProviderManager: from .sources.xinference_rerank_source import ( XinferenceRerankProvider as XinferenceRerankProvider, ) + case "bailian_rerank": + from .sources.bailian_rerank_source import ( + BailianRerankProvider as BailianRerankProvider, + ) except (ImportError, ModuleNotFoundError) as e: logger.critical( f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。", diff --git a/astrbot/core/provider/sources/bailian_rerank_source.py b/astrbot/core/provider/sources/bailian_rerank_source.py new file mode 100644 index 000000000..e32b4ea0d --- /dev/null +++ b/astrbot/core/provider/sources/bailian_rerank_source.py @@ -0,0 +1,171 @@ +import os + +import aiohttp + +from astrbot import logger + +from ..entities import ProviderType, RerankResult +from ..provider import RerankProvider +from ..register import register_provider_adapter + + +@register_provider_adapter( + "bailian_rerank", "阿里云百炼文本排序适配器", provider_type=ProviderType.RERANK +) +class BailianRerankProvider(RerankProvider): + """阿里云百炼文本重排序适配器""" + + def __init__(self, provider_config: dict, provider_settings: dict) -> None: + super().__init__(provider_config, provider_settings) + self.provider_config = provider_config + self.provider_settings = provider_settings + + # API配置 + self.api_key = provider_config.get("rerank_api_key", "") + if not self.api_key: + self.api_key = os.getenv("DASHSCOPE_API_KEY", "") + + if not self.api_key: + raise ValueError( + "阿里云百炼 API Key 不能为空,请在配置中设置 rerank_api_key 或设置环境变量 DASHSCOPE_API_KEY" + ) + + self.model = provider_config.get("rerank_model", "qwen3-rerank") + self.timeout = provider_config.get("timeout", 30) + self.default_top_n = provider_config.get("top_n", 5) + self.return_documents = provider_config.get("return_documents", False) + self.instruct = provider_config.get("instruct", "") + + self.base_url = provider_config.get( + "rerank_api_base", + "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank", + ) + + # 设置HTTP客户端 + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + + self.client = aiohttp.ClientSession( + headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout) + ) + + # 设置模型名称 + self.set_model(self.model) + + logger.info(f"AstrBot 百炼 Rerank 初始化完成。模型: {self.model}") + + async def rerank( + self, + query: str, + documents: list[str], + top_n: int | None = None, + ) -> list[RerankResult]: + """ + 对文档进行重排序 + + Args: + query: 查询文本 + documents: 待排序的文档列表 + top_n: 返回前N个结果,如果为None则使用配置中的默认值 + + Returns: + 重排序结果列表 + """ + if not documents: + logger.warning("文档列表为空,返回空结果") + return [] + + if not query.strip(): + logger.warning("查询文本为空,返回空结果") + return [] + + # 检查限制 + if len(documents) > 500: + logger.warning( + f"文档数量({len(documents)})超过限制(500),将截断前500个文档" + ) + documents = documents[:500] + + # 使用传入的top_n或默认配置 + final_top_n = top_n if top_n is not None else self.default_top_n + + try: + # 构建请求载荷 + payload = { + "model": self.model, + "input": {"query": query, "documents": documents}, + } + + # 添加可选参数 + parameters = {} + if final_top_n is not None: + parameters["top_n"] = final_top_n + if self.return_documents: + parameters["return_documents"] = True + if self.instruct and self.model == "qwen3-rerank": + parameters["instruct"] = self.instruct + + if parameters: + payload["parameters"] = parameters + + logger.debug( + f"百炼 Rerank 请求: query='{query[:50]}...', 文档数量={len(documents)}" + ) + + # 发送请求 + async with self.client.post(self.base_url, json=payload) as response: + response.raise_for_status() + response_data = await response.json() + + # 检查响应状态 + if "code" in response_data and response_data["code"] != "200": + error_msg = response_data.get("message", "未知错误") + raise Exception( + f"百炼 API 返回错误: {response_data['code']} - {error_msg}" + ) + + # 解析结果 + output = response_data.get("output", {}) + results = output.get("results", []) + + if not results: + logger.warning(f"百炼 Rerank 返回空结果: {response_data}") + return [] + + # 转换为RerankResult对象 + rerank_results = [] + for result in results: + rerank_result = RerankResult( + index=result["index"], relevance_score=result["relevance_score"] + ) + rerank_results.append(rerank_result) + + logger.debug(f"百炼 Rerank 成功返回 {len(rerank_results)} 个结果") + + # 记录使用量信息 + usage = response_data.get("usage", {}) + total_tokens = usage.get("total_tokens", 0) + if total_tokens > 0: + logger.debug(f"百炼 Rerank 消耗 Token 数量: {total_tokens}") + + return rerank_results + + except aiohttp.ClientError as e: + logger.error(f"百炼 Rerank 网络请求失败: {e}") + raise Exception(f"网络请求失败: {e}") + except Exception as e: + logger.error(f"百炼 Rerank 处理失败: {e}") + raise Exception(f"重排序失败: {e}") + + async def terminate(self) -> None: + """关闭HTTP客户端会话""" + if self.client: + logger.info("关闭 百炼 Rerank 客户端会话") + try: + await self.client.close() + except Exception as e: + logger.error(f"关闭 百炼 Rerank 客户端时出错: {e}") + finally: + self.client = None