添加阿里百炼重排序模型
This commit is contained in:
@@ -1308,6 +1308,20 @@ CONFIG_METADATA_2 = {
|
||||
"timeout": 20,
|
||||
"launch_model_if_not_running": False,
|
||||
},
|
||||
"阿里云百炼重排序": {
|
||||
"id": "bailian_rerank",
|
||||
"type": "bailian_rerank",
|
||||
"provider": "bailian",
|
||||
"provider_type": "rerank",
|
||||
"enable": True,
|
||||
"rerank_api_key": "",
|
||||
"rerank_api_base": "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank",
|
||||
"rerank_model": "qwen3-rerank",
|
||||
"timeout": 30,
|
||||
"top_n": 3,
|
||||
"return_documents": False,
|
||||
"instruct": "",
|
||||
},
|
||||
"Xinference STT": {
|
||||
"id": "xinference_stt",
|
||||
"type": "xinference_stt",
|
||||
@@ -1342,6 +1356,21 @@ CONFIG_METADATA_2 = {
|
||||
"description": "重排序模型名称",
|
||||
"type": "string",
|
||||
},
|
||||
"top_n": {
|
||||
"description": "返回排序后的top_n个文档",
|
||||
"type": "int",
|
||||
"hint": "默认返回全部文档。如果指定的值大于文档总数,将返回全部文档。",
|
||||
},
|
||||
"return_documents": {
|
||||
"description": "是否在排序结果中返回文档原文",
|
||||
"type": "bool",
|
||||
"hint": "默认值false,以减少网络传输开销。",
|
||||
},
|
||||
"instruct": {
|
||||
"description": "自定义排序任务类型说明",
|
||||
"type": "string",
|
||||
"hint": "仅在使用 qwen3-rerank 模型时生效。建议使用英文撰写。",
|
||||
},
|
||||
"launch_model_if_not_running": {
|
||||
"description": "模型未运行时自动启动",
|
||||
"type": "bool",
|
||||
|
||||
@@ -331,6 +331,10 @@ class ProviderManager:
|
||||
from .sources.xinference_rerank_source import (
|
||||
XinferenceRerankProvider as XinferenceRerankProvider,
|
||||
)
|
||||
case "bailian_rerank":
|
||||
from .sources.bailian_rerank_source import (
|
||||
BailianRerankProvider as BailianRerankProvider,
|
||||
)
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.critical(
|
||||
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。",
|
||||
|
||||
@@ -0,0 +1,171 @@
|
||||
import os
|
||||
|
||||
import aiohttp
|
||||
|
||||
from astrbot import logger
|
||||
|
||||
from ..entities import ProviderType, RerankResult
|
||||
from ..provider import RerankProvider
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
|
||||
@register_provider_adapter(
    "bailian_rerank", "阿里云百炼文本排序适配器", provider_type=ProviderType.RERANK
)
class BailianRerankProvider(RerankProvider):
    """Rerank provider backed by the Alibaba Cloud Bailian text-rerank HTTP API.

    The adapter POSTs a query plus candidate documents to the configured
    endpoint and converts the ``output.results`` entries of the JSON reply
    into ``RerankResult`` objects.
    """

    def __init__(self, provider_config: dict, provider_settings: dict) -> None:
        """Read the provider configuration and open the HTTP session.

        Args:
            provider_config: Provider entry; the keys read here are
                ``rerank_api_key``, ``rerank_model``, ``timeout``, ``top_n``,
                ``return_documents``, ``instruct`` and ``rerank_api_base``.
            provider_settings: Global provider settings, stored unchanged.

        Raises:
            ValueError: If no API key is found in the configuration or in the
                ``DASHSCOPE_API_KEY`` environment variable.
        """
        super().__init__(provider_config, provider_settings)
        self.provider_config = provider_config
        self.provider_settings = provider_settings

        # The configured key wins; the environment variable is the fallback.
        self.api_key = provider_config.get("rerank_api_key", "")
        if not self.api_key:
            self.api_key = os.getenv("DASHSCOPE_API_KEY", "")

        if not self.api_key:
            raise ValueError(
                "阿里云百炼 API Key 不能为空,请在配置中设置 rerank_api_key 或设置环境变量 DASHSCOPE_API_KEY"
            )

        self.model = provider_config.get("rerank_model", "qwen3-rerank")
        self.timeout = provider_config.get("timeout", 30)
        # NOTE(review): the config template added in the same change ships
        # top_n=3 while this fallback is 5 -- confirm which one is intended.
        self.default_top_n = provider_config.get("top_n", 5)
        self.return_documents = provider_config.get("return_documents", False)
        self.instruct = provider_config.get("instruct", "")

        self.base_url = provider_config.get(
            "rerank_api_base",
            "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank",
        )

        # One shared session carries the auth header and the total timeout
        # for every request made by this provider.
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        self.client = aiohttp.ClientSession(
            headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
        )

        self.set_model(self.model)

        logger.info(f"AstrBot 百炼 Rerank 初始化完成。模型: {self.model}")

    def _build_payload(
        self, query: str, documents: list[str], top_n: int | None
    ) -> dict:
        """Assemble the JSON request body, adding only the options that are set."""
        payload = {
            "model": self.model,
            "input": {"query": query, "documents": documents},
        }

        parameters = {}
        if top_n is not None:
            parameters["top_n"] = top_n
        if self.return_documents:
            parameters["return_documents"] = True
        # The instruction is only forwarded for the qwen3-rerank model.
        if self.instruct and self.model == "qwen3-rerank":
            parameters["instruct"] = self.instruct

        if parameters:
            payload["parameters"] = parameters
        return payload

    @staticmethod
    def _raise_if_api_error(response_data: dict) -> None:
        """Raise when the response body carries a ``code`` other than "200"."""
        if "code" in response_data and response_data["code"] != "200":
            error_msg = response_data.get("message", "未知错误")
            raise Exception(
                f"百炼 API 返回错误: {response_data['code']} - {error_msg}"
            )

    async def _post(self, payload: dict) -> dict:
        """Send the request and return the decoded JSON body.

        On an HTTP error status the body is inspected first, so that an error
        code/message sent by the service is reported instead of only the bare
        HTTP status line.
        """
        async with self.client.post(self.base_url, json=payload) as response:
            if response.status >= 400:
                try:
                    # content_type=None: decode even if the error reply is not
                    # labelled as application/json.
                    error_body = await response.json(content_type=None)
                except (aiohttp.ClientError, ValueError):
                    error_body = None
                if isinstance(error_body, dict):
                    self._raise_if_api_error(error_body)
                # No usable detail in the body: fall back to the HTTP error.
                response.raise_for_status()
            return await response.json()

    async def rerank(
        self,
        query: str,
        documents: list[str],
        top_n: int | None = None,
    ) -> list[RerankResult]:
        """Rerank ``documents`` by relevance to ``query``.

        Args:
            query: Query text; a blank query yields an empty result.
            documents: Candidate documents; only the first 500 are sent.
            top_n: Number of results to request. ``None`` falls back to the
                configured default.

        Returns:
            One ``RerankResult`` per entry returned by the service, or an
            empty list when the input is empty or the service returns nothing.

        Raises:
            Exception: On network failures, API-level errors, or a response
                that cannot be parsed. The original error is chained as
                ``__cause__``.
        """
        if not documents:
            logger.warning("文档列表为空,返回空结果")
            return []

        if not query.strip():
            logger.warning("查询文本为空,返回空结果")
            return []

        # Cap the batch at 500 documents (presumably the per-request limit of
        # the upstream API -- confirm against the service documentation).
        if len(documents) > 500:
            logger.warning(
                f"文档数量({len(documents)})超过限制(500),将截断前500个文档"
            )
            documents = documents[:500]

        final_top_n = top_n if top_n is not None else self.default_top_n

        try:
            payload = self._build_payload(query, documents, final_top_n)

            logger.debug(
                f"百炼 Rerank 请求: query='{query[:50]}...', 文档数量={len(documents)}"
            )

            response_data = await self._post(payload)
            self._raise_if_api_error(response_data)

            # `or {}` / `or []` tolerate explicit JSON nulls in the reply.
            output = response_data.get("output") or {}
            results = output.get("results") or []

            if not results:
                logger.warning(f"百炼 Rerank 返回空结果: {response_data}")
                return []

            rerank_results = [
                RerankResult(
                    index=item["index"], relevance_score=item["relevance_score"]
                )
                for item in results
            ]

            logger.debug(f"百炼 Rerank 成功返回 {len(rerank_results)} 个结果")

            usage = response_data.get("usage") or {}
            total_tokens = usage.get("total_tokens") or 0
            if total_tokens > 0:
                logger.debug(f"百炼 Rerank 消耗 Token 数量: {total_tokens}")

            return rerank_results

        except aiohttp.ClientError as e:
            logger.error(f"百炼 Rerank 网络请求失败: {e}")
            raise Exception(f"网络请求失败: {e}") from e
        except Exception as e:
            logger.error(f"百炼 Rerank 处理失败: {e}")
            raise Exception(f"重排序失败: {e}") from e

    async def terminate(self) -> None:
        """Close the HTTP session; close errors are logged, not raised."""
        if self.client:
            logger.info("关闭 百炼 Rerank 客户端会话")
            try:
                await self.client.close()
            except Exception as e:
                logger.error(f"关闭 百炼 Rerank 客户端时出错: {e}")
            finally:
                # Drop the reference even if close() failed.
                self.client = None
|
||||
Reference in New Issue
Block a user