f624971613
* chore(core.utils): 🚨 修正错误Lint
* chore(core.provider): 🚨 修复基类错误Lint
* chore(core.utils): 补全session_get()的重载
* chore(core.provider): 🚨 修正实现错误Lint
* chore(core.platform): 🚨 修正platform基类和webchat的错误Lint
* chore(core.platform): 修正错误实现Lint
* fix(core.provider): 修复循环调用和错误assert
* chore(core.platform): 修复部分实现Lint
* chore(core.provider): 补充Dify.text_chat_stream的参数类型
* chore(core.pipeline): 🚨 修复错误Lint
* fix(core.slack): 补充遗漏导入
* chore(core.utils): 修复错误的session_get声明
* chore(core.platform): 移除Lark adapter import中的wildcard
* chore(core.db): 修复声明和部分逻辑
* chore(core.db): 添加typings,使faiss参数能被正确识别。
* chore(core): 修复声明
* chore(core): 修改声明
* chore: 补充faiss声明
* chore(dashboard): 修改实现,减少报错
* chore(package): 修改部分声明与实现,减少报错
* chore(core): 添加Handler的overload,以去除部分assert同时通过类型检查
* chore(core.pipeline): 修改Pipeline Scheduler的execute,将判断属性改为判断类型,通过静态类型检查
* chore(core.config): 添加类型标注,通过类型检查
* chore(core.message): 为File._download_file添加检查,通过类型检查
* fix: 将断言改为条件判断以实现优雅关闭的容错性
* refactor: 移除 discord 客户端中的 assert,改用 if None 判断并抛出异常
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: DiscordPlatformAdapter 对 self.client.user 为 None 做日志并返回,移除断言
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: 增强 Lark 相关空值/异常检查并完善日志输出
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 将断言替换为条件检查并加入日志与错误处理
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* chore: 移除LLM生成的无用注释
* refactor: 使用 File.get_file 替换下载逻辑并移除 assert,提供默认 filename
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: Slack Socket 未初始化抛出运行时异常,图片 URL 判空改为非空判断
* refactor: 将 WeChatPadProAdapter 的断言改为空值判断并添加日志
* refactor: 使用 isinstance 替代断言实现类型判断,便于静态检查
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: 去除cast,直接使用字段与字典访问,修正端口解析
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 使用 match-case 重构 ProviderManager 加载并通过类型检查抛出 TypeError
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: group_name_display 时若 group 对象为空则记录错误并返回
* fix: 将 _get_current_persona_id 的 assert 替换成 if guard 并返回 None
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: 优化插件目录存在性检查及图片URL非空验证,更新JSON排序配置
* fix: 将 datetime_str 的 assert 替换为显式检查并抛出异常
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 移除 cast,改为运行时检查并在找不到调度器时跳过
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 移除 cast,改用 isinstance 检查 FaissVecDB 并警告
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: 删除 typing.cast 导入,并在获取文件绝对路径前校验 file_
* refactor: 移除 typing.cast,简化内容安全检查调用
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 将 PlatformMetadata.id 设为必填并在注册时传入 id,移除 cast
* refactor: 移除 cast,改用 HasInitialize 与 isinstance 进行初始化
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: 为 ProviderManager.initialize 增加ID类型判断,避免 None 导致 get 失败
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 为 OTTSProvider 与 AzureNativeProvider 引入 _client 与 client 属性改进上下文管理
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: 为 Whisper 自托管源添加模型未初始化校验并直接调用 transcribe
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 移除未使用的 cast 导入并简化 platform_name 赋值
* refactor: 引入 cast 并对 id 使用 cast(str, ...) 提升类型安全
* fix: 将 _id_to_sid 返回改为 str,空值返回空串;对 id 与 message_id 使用 cast
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 重构 Discord 处理逻辑:强制 类型转换、优先斜杠指令并优化提及判断
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: 统一对 id 获取执行 cast,并在微信消息解析失败时抛错
* Revert "fix: 去除cast,直接使用字段与字典访问,修正端口解析"
This reverts commit 1cbfdf9d1b.
* fix: 百炼 Rerank 会话关闭时返回空结果;初始化 request.prompt 避免空值拼接
* fix: 统一处理搜索结果链接为字符串,新增 _get_url 助手并适配 Bing/Sogo
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 调整 call_handler 泛型、Discord 通道注解及 FishAudioTTS API 请求类型
* refactor: 使用 col(...) 替代列引用并对结果进行 CursorResult 强转
* chore: ruff format
---------
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
Co-authored-by: Soulter <905617992@qq.com>
281 lines
8.4 KiB
Python
281 lines
8.4 KiB
Python
"""检索管理器
|
|
|
|
协调稠密检索、稀疏检索和 Rerank,提供统一的检索接口
|
|
"""
|
|
|
|
import time
|
|
from dataclasses import dataclass
|
|
|
|
from astrbot import logger
|
|
from astrbot.core.db.vec_db.base import Result
|
|
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
|
|
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
|
|
from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion
|
|
from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseRetriever
|
|
from astrbot.core.provider.provider import RerankProvider
|
|
|
|
from ..kb_helper import KBHelper
|
|
|
|
|
|
@dataclass
|
|
class RetrievalResult:
|
|
"""检索结果"""
|
|
|
|
chunk_id: str
|
|
doc_id: str
|
|
doc_name: str
|
|
kb_id: str
|
|
kb_name: str
|
|
content: str
|
|
score: float
|
|
metadata: dict
|
|
|
|
|
|
class RetrievalManager:
|
|
"""检索管理器
|
|
|
|
职责:
|
|
- 协调稠密检索、稀疏检索和 Rerank
|
|
- 结果融合和排序
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
sparse_retriever: SparseRetriever,
|
|
rank_fusion: RankFusion,
|
|
kb_db: KBSQLiteDatabase,
|
|
):
|
|
"""初始化检索管理器
|
|
|
|
Args:
|
|
vec_db_factory: 向量数据库工厂
|
|
sparse_retriever: 稀疏检索器
|
|
rank_fusion: 结果融合器
|
|
kb_db: 知识库数据库实例
|
|
|
|
"""
|
|
self.sparse_retriever = sparse_retriever
|
|
self.rank_fusion = rank_fusion
|
|
self.kb_db = kb_db
|
|
|
|
async def retrieve(
|
|
self,
|
|
query: str,
|
|
kb_ids: list[str],
|
|
kb_id_helper_map: dict[str, KBHelper],
|
|
top_k_fusion: int = 20,
|
|
top_m_final: int = 5,
|
|
) -> list[RetrievalResult]:
|
|
"""混合检索
|
|
|
|
流程:
|
|
1. 稠密检索 (向量相似度)
|
|
2. 稀疏检索 (BM25)
|
|
3. 结果融合 (RRF)
|
|
4. Rerank 重排序
|
|
|
|
Args:
|
|
query: 查询文本
|
|
kb_ids: 知识库 ID 列表
|
|
top_m_final: 最终返回数量
|
|
enable_rerank: 是否启用 Rerank
|
|
|
|
Returns:
|
|
List[RetrievalResult]: 检索结果列表
|
|
|
|
"""
|
|
if not kb_ids:
|
|
return []
|
|
|
|
kb_options: dict = {}
|
|
new_kb_ids = []
|
|
for kb_id in kb_ids:
|
|
kb_helper = kb_id_helper_map.get(kb_id)
|
|
if kb_helper:
|
|
kb = kb_helper.kb
|
|
kb_options[kb_id] = {
|
|
"top_k_dense": kb.top_k_dense or 50,
|
|
"top_k_sparse": kb.top_k_sparse or 50,
|
|
"top_m_final": kb.top_m_final or 5,
|
|
"vec_db": kb_helper.vec_db,
|
|
"rerank_provider_id": kb.rerank_provider_id,
|
|
}
|
|
new_kb_ids.append(kb_id)
|
|
else:
|
|
logger.warning(f"知识库 ID {kb_id} 实例未找到, 已跳过该知识库的检索")
|
|
|
|
kb_ids = new_kb_ids
|
|
|
|
# 1. 稠密检索
|
|
time_start = time.time()
|
|
dense_results = await self._dense_retrieve(
|
|
query=query,
|
|
kb_ids=kb_ids,
|
|
kb_options=kb_options,
|
|
)
|
|
time_end = time.time()
|
|
logger.debug(
|
|
f"Dense retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(dense_results)} results.",
|
|
)
|
|
|
|
# 2. 稀疏检索
|
|
time_start = time.time()
|
|
sparse_results = await self.sparse_retriever.retrieve(
|
|
query=query,
|
|
kb_ids=kb_ids,
|
|
kb_options=kb_options,
|
|
)
|
|
time_end = time.time()
|
|
logger.debug(
|
|
f"Sparse retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(sparse_results)} results.",
|
|
)
|
|
|
|
# 3. 结果融合
|
|
time_start = time.time()
|
|
fused_results = await self.rank_fusion.fuse(
|
|
dense_results=dense_results,
|
|
sparse_results=sparse_results,
|
|
top_k=top_k_fusion,
|
|
)
|
|
time_end = time.time()
|
|
logger.debug(
|
|
f"Rank fusion took {time_end - time_start:.2f}s and returned {len(fused_results)} results.",
|
|
)
|
|
|
|
# 4. 转换为 RetrievalResult (获取元数据)
|
|
retrieval_results = []
|
|
for fr in fused_results:
|
|
metadata_dict = await self.kb_db.get_document_with_metadata(fr.doc_id)
|
|
if metadata_dict:
|
|
retrieval_results.append(
|
|
RetrievalResult(
|
|
chunk_id=fr.chunk_id,
|
|
doc_id=fr.doc_id,
|
|
doc_name=metadata_dict["document"].doc_name,
|
|
kb_id=fr.kb_id,
|
|
kb_name=metadata_dict["knowledge_base"].kb_name,
|
|
content=fr.content,
|
|
score=fr.score,
|
|
metadata={
|
|
"chunk_index": fr.chunk_index,
|
|
"char_count": len(fr.content),
|
|
},
|
|
),
|
|
)
|
|
|
|
# 5. Rerank
|
|
first_rerank = None
|
|
for kb_id in kb_ids:
|
|
vec_db = kb_options[kb_id]["vec_db"]
|
|
if not isinstance(vec_db, FaissVecDB):
|
|
logger.warning(f"vec_db for kb_id {kb_id} is not FaissVecDB")
|
|
continue
|
|
|
|
rerank_pi = kb_options[kb_id]["rerank_provider_id"]
|
|
if (
|
|
vec_db
|
|
and vec_db.rerank_provider
|
|
and rerank_pi
|
|
and rerank_pi == vec_db.rerank_provider.meta().id
|
|
):
|
|
first_rerank = vec_db.rerank_provider
|
|
break
|
|
if first_rerank and retrieval_results:
|
|
retrieval_results = await self._rerank(
|
|
query=query,
|
|
results=retrieval_results,
|
|
top_k=top_m_final,
|
|
rerank_provider=first_rerank,
|
|
)
|
|
|
|
return retrieval_results[:top_m_final]
|
|
|
|
async def _dense_retrieve(
|
|
self,
|
|
query: str,
|
|
kb_ids: list[str],
|
|
kb_options: dict,
|
|
):
|
|
"""稠密检索 (向量相似度)
|
|
|
|
为每个知识库使用独立的向量数据库进行检索,然后合并结果。
|
|
|
|
Args:
|
|
query: 查询文本
|
|
kb_ids: 知识库 ID 列表
|
|
top_k: 返回结果数量
|
|
|
|
Returns:
|
|
List[Result]: 检索结果列表
|
|
|
|
"""
|
|
all_results: list[Result] = []
|
|
for kb_id in kb_ids:
|
|
if kb_id not in kb_options:
|
|
continue
|
|
try:
|
|
vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
|
|
dense_k = int(kb_options[kb_id]["top_k_dense"])
|
|
vec_results = await vec_db.retrieve(
|
|
query=query,
|
|
k=dense_k,
|
|
fetch_k=dense_k * 2,
|
|
rerank=False, # 稠密检索阶段不进行 rerank
|
|
metadata_filters={"kb_id": kb_id},
|
|
)
|
|
|
|
all_results.extend(vec_results)
|
|
except Exception as e:
|
|
from astrbot.core import logger
|
|
|
|
logger.warning(f"知识库 {kb_id} 稠密检索失败: {e}")
|
|
continue
|
|
|
|
# 按相似度排序并返回 top_k
|
|
all_results.sort(key=lambda x: x.similarity, reverse=True)
|
|
# return all_results[: len(all_results) // len(kb_ids)]
|
|
return all_results
|
|
|
|
async def _rerank(
|
|
self,
|
|
query: str,
|
|
results: list[RetrievalResult],
|
|
top_k: int,
|
|
rerank_provider: RerankProvider,
|
|
) -> list[RetrievalResult]:
|
|
"""Rerank 重排序
|
|
|
|
Args:
|
|
query: 查询文本
|
|
results: 检索结果列表
|
|
top_k: 返回结果数量
|
|
|
|
Returns:
|
|
List[RetrievalResult]: 重排序后的结果列表
|
|
|
|
"""
|
|
if not results:
|
|
return []
|
|
|
|
# 准备文档列表
|
|
docs = [r.content for r in results]
|
|
|
|
# 调用 Rerank Provider
|
|
rerank_results = await rerank_provider.rerank(
|
|
query=query,
|
|
documents=docs,
|
|
)
|
|
|
|
# 更新分数并重新排序
|
|
reranked_list = []
|
|
for rerank_result in rerank_results:
|
|
idx = rerank_result.index
|
|
if idx < len(results):
|
|
result = results[idx]
|
|
result.score = rerank_result.relevance_score
|
|
reranked_list.append(result)
|
|
|
|
reranked_list.sort(key=lambda x: x.score, reverse=True)
|
|
|
|
return reranked_list[:top_k]
|