添加阿里百炼重排序模型

This commit is contained in:
piexian
2025-11-20 08:05:42 +08:00
parent 8488c9aeab
commit 788ceb9721
3 changed files with 204 additions and 0 deletions
+29
View File
@@ -1308,6 +1308,20 @@ CONFIG_METADATA_2 = {
"timeout": 20,
"launch_model_if_not_running": False,
},
"阿里云百炼重排序": {
"id": "bailian_rerank",
"type": "bailian_rerank",
"provider": "bailian",
"provider_type": "rerank",
"enable": True,
"rerank_api_key": "",
"rerank_api_base": "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank",
"rerank_model": "qwen3-rerank",
"timeout": 30,
"top_n": 3,
"return_documents": False,
"instruct": "",
},
"Xinference STT": {
"id": "xinference_stt",
"type": "xinference_stt",
@@ -1342,6 +1356,21 @@ CONFIG_METADATA_2 = {
"description": "重排序模型名称",
"type": "string",
},
"top_n": {
"description": "返回排序后的top_n个文档",
"type": "int",
"hint": "默认返回全部文档。如果指定的值大于文档总数,将返回全部文档。",
},
"return_documents": {
"description": "是否在排序结果中返回文档原文",
"type": "bool",
"hint": "默认值false,以减少网络传输开销。",
},
"instruct": {
"description": "自定义排序任务类型说明",
"type": "string",
"hint": "仅在使用 qwen3-rerank 模型时生效。建议使用英文撰写。",
},
"launch_model_if_not_running": {
"description": "模型未运行时自动启动",
"type": "bool",
+4
View File
@@ -331,6 +331,10 @@ class ProviderManager:
from .sources.xinference_rerank_source import (
XinferenceRerankProvider as XinferenceRerankProvider,
)
case "bailian_rerank":
from .sources.bailian_rerank_source import (
BailianRerankProvider as BailianRerankProvider,
)
except (ImportError, ModuleNotFoundError) as e:
logger.critical(
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。",
@@ -0,0 +1,171 @@
import os
import aiohttp
from astrbot import logger
from ..entities import ProviderType, RerankResult
from ..provider import RerankProvider
from ..register import register_provider_adapter
@register_provider_adapter(
"bailian_rerank", "阿里云百炼文本排序适配器", provider_type=ProviderType.RERANK
)
class BailianRerankProvider(RerankProvider):
"""阿里云百炼文本重排序适配器"""
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config, provider_settings)
self.provider_config = provider_config
self.provider_settings = provider_settings
# API配置
self.api_key = provider_config.get("rerank_api_key", "")
if not self.api_key:
self.api_key = os.getenv("DASHSCOPE_API_KEY", "")
if not self.api_key:
raise ValueError(
"阿里云百炼 API Key 不能为空,请在配置中设置 rerank_api_key 或设置环境变量 DASHSCOPE_API_KEY"
)
self.model = provider_config.get("rerank_model", "qwen3-rerank")
self.timeout = provider_config.get("timeout", 30)
self.default_top_n = provider_config.get("top_n", 5)
self.return_documents = provider_config.get("return_documents", False)
self.instruct = provider_config.get("instruct", "")
self.base_url = provider_config.get(
"rerank_api_base",
"https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank",
)
# 设置HTTP客户端
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
self.client = aiohttp.ClientSession(
headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
)
# 设置模型名称
self.set_model(self.model)
logger.info(f"AstrBot 百炼 Rerank 初始化完成。模型: {self.model}")
async def rerank(
self,
query: str,
documents: list[str],
top_n: int | None = None,
) -> list[RerankResult]:
"""
对文档进行重排序
Args:
query: 查询文本
documents: 待排序的文档列表
top_n: 返回前N个结果如果为None则使用配置中的默认值
Returns:
重排序结果列表
"""
if not documents:
logger.warning("文档列表为空,返回空结果")
return []
if not query.strip():
logger.warning("查询文本为空,返回空结果")
return []
# 检查限制
if len(documents) > 500:
logger.warning(
f"文档数量({len(documents)})超过限制(500),将截断前500个文档"
)
documents = documents[:500]
# 使用传入的top_n或默认配置
final_top_n = top_n if top_n is not None else self.default_top_n
try:
# 构建请求载荷
payload = {
"model": self.model,
"input": {"query": query, "documents": documents},
}
# 添加可选参数
parameters = {}
if final_top_n is not None:
parameters["top_n"] = final_top_n
if self.return_documents:
parameters["return_documents"] = True
if self.instruct and self.model == "qwen3-rerank":
parameters["instruct"] = self.instruct
if parameters:
payload["parameters"] = parameters
logger.debug(
f"百炼 Rerank 请求: query='{query[:50]}...', 文档数量={len(documents)}"
)
# 发送请求
async with self.client.post(self.base_url, json=payload) as response:
response.raise_for_status()
response_data = await response.json()
# 检查响应状态
if "code" in response_data and response_data["code"] != "200":
error_msg = response_data.get("message", "未知错误")
raise Exception(
f"百炼 API 返回错误: {response_data['code']} - {error_msg}"
)
# 解析结果
output = response_data.get("output", {})
results = output.get("results", [])
if not results:
logger.warning(f"百炼 Rerank 返回空结果: {response_data}")
return []
# 转换为RerankResult对象
rerank_results = []
for result in results:
rerank_result = RerankResult(
index=result["index"], relevance_score=result["relevance_score"]
)
rerank_results.append(rerank_result)
logger.debug(f"百炼 Rerank 成功返回 {len(rerank_results)} 个结果")
# 记录使用量信息
usage = response_data.get("usage", {})
total_tokens = usage.get("total_tokens", 0)
if total_tokens > 0:
logger.debug(f"百炼 Rerank 消耗 Token 数量: {total_tokens}")
return rerank_results
except aiohttp.ClientError as e:
logger.error(f"百炼 Rerank 网络请求失败: {e}")
raise Exception(f"网络请求失败: {e}")
except Exception as e:
logger.error(f"百炼 Rerank 处理失败: {e}")
raise Exception(f"重排序失败: {e}")
async def terminate(self) -> None:
"""关闭HTTP客户端会话"""
if self.client:
logger.info("关闭 百炼 Rerank 客户端会话")
try:
await self.client.close()
except Exception as e:
logger.error(f"关闭 百炼 Rerank 客户端时出错: {e}")
finally:
self.client = None