feat: 实现知识库核心后端模块
- 实现完整的知识库数据模型(知识库、文档、文档块、会话配置) - 实现基于 SQLite 的向量数据库存储和检索 - 实现文档解析器(PDF、TXT)和固定大小分块器 - 实现混合检索系统(密集向量检索 + BM25 稀疏检索 + RRF 融合) - 实现知识库生命周期管理和消息注入器 - 支持会话级别的知识库配置和关联
This commit is contained in:
@@ -0,0 +1,34 @@
|
||||
"""
|
||||
知识库管理模块
|
||||
|
||||
提供文档上传、解析、分块、向量化、检索等功能
|
||||
"""
|
||||
|
||||
from astrbot.core.db.po import KBSessionConfig
|
||||
from astrbot.core.knowledge_base.models import (
|
||||
KBChunk,
|
||||
KBDocument,
|
||||
KBMedia,
|
||||
KnowledgeBase,
|
||||
)
|
||||
|
||||
# 注意: 以下导入在对应模块实现后取消注释
|
||||
from .database import KBDatabase
|
||||
from .manager import KBManager
|
||||
from .manager_ops import KBManagerOps
|
||||
from .session_config_db import SessionConfigDB
|
||||
|
||||
# from .injector import KnowledgeBaseInjector
|
||||
|
||||
__all__ = [
|
||||
"KnowledgeBase",
|
||||
"KBDocument",
|
||||
"KBChunk",
|
||||
"KBMedia",
|
||||
"KBSessionConfig",
|
||||
"KBDatabase",
|
||||
"SessionConfigDB",
|
||||
"KBManager",
|
||||
"KBManagerOps",
|
||||
# "KnowledgeBaseInjector",
|
||||
]
|
||||
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
文档分块模块
|
||||
"""
|
||||
|
||||
from .base import BaseChunker
|
||||
from .fixed_size import FixedSizeChunker
|
||||
|
||||
__all__ = [
|
||||
"BaseChunker",
|
||||
"FixedSizeChunker",
|
||||
]
|
||||
@@ -0,0 +1,24 @@
|
||||
"""文档分块器基类
|
||||
|
||||
定义了文档分块处理的抽象接口。
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseChunker(ABC):
|
||||
"""分块器基类
|
||||
|
||||
所有分块器都应该继承此类并实现 chunk 方法。
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def chunk(self, text: str) -> list[str]:
|
||||
"""将文本分块
|
||||
|
||||
Args:
|
||||
text: 输入文本
|
||||
|
||||
Returns:
|
||||
list[str]: 分块后的文本列表
|
||||
"""
|
||||
@@ -0,0 +1,52 @@
|
||||
"""固定大小分块器
|
||||
|
||||
按照固定的字符数将文本分块,支持重叠区域。
|
||||
"""
|
||||
|
||||
from astrbot.core.knowledge_base.chunking.base import BaseChunker
|
||||
|
||||
|
||||
class FixedSizeChunker(BaseChunker):
|
||||
"""固定大小分块器
|
||||
|
||||
按照固定的字符数分块,并支持块之间的重叠。
|
||||
"""
|
||||
|
||||
def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50):
|
||||
"""初始化分块器
|
||||
|
||||
Args:
|
||||
chunk_size: 块的大小(字符数)
|
||||
chunk_overlap: 块之间的重叠字符数
|
||||
"""
|
||||
self.chunk_size = chunk_size
|
||||
self.chunk_overlap = chunk_overlap
|
||||
|
||||
async def chunk(self, text: str) -> list[str]:
|
||||
"""固定大小分块
|
||||
|
||||
Args:
|
||||
text: 输入文本
|
||||
|
||||
Returns:
|
||||
list[str]: 分块后的文本列表
|
||||
"""
|
||||
chunks = []
|
||||
start = 0
|
||||
text_len = len(text)
|
||||
|
||||
while start < text_len:
|
||||
end = start + self.chunk_size
|
||||
chunk = text[start:end]
|
||||
|
||||
if chunk:
|
||||
chunks.append(chunk)
|
||||
|
||||
# 移动窗口,保留重叠部分
|
||||
start = end - self.chunk_overlap
|
||||
|
||||
# 防止无限循环: 如果重叠过大,直接移到end
|
||||
if start >= end or self.chunk_overlap >= self.chunk_size:
|
||||
start = end
|
||||
|
||||
return chunks
|
||||
@@ -0,0 +1,347 @@
|
||||
"""知识库数据库操作类
|
||||
|
||||
该模块封装知识库、文档、块、多媒体和会话配置相关的数据库查询操作。
|
||||
|
||||
注意:
|
||||
- 该模块操作的是独立的知识库数据库 (data/knowledge_base/kb.db)
|
||||
- 会话配置也存储在此数据库中,会话ID来源于主数据库
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from astrbot.core.knowledge_base.kb_sqlite import KBSQLiteDatabase
|
||||
from astrbot.core.knowledge_base.models import (
|
||||
KBChunk,
|
||||
KBDocument,
|
||||
KBMedia,
|
||||
KBSessionConfig,
|
||||
KnowledgeBase,
|
||||
)
|
||||
|
||||
|
||||
class KBDatabase:
|
||||
"""知识库数据库操作类
|
||||
|
||||
职责:
|
||||
- 封装知识库、文档、块、多媒体和会话配置的数据库查询操作
|
||||
- 统一异常处理
|
||||
|
||||
注意:
|
||||
- 该类操作独立的知识库数据库 (kb.db)
|
||||
- 会话配置存储会话ID与知识库的绑定关系,会话ID来源于主数据库
|
||||
"""
|
||||
|
||||
def __init__(self, kb_db: KBSQLiteDatabase):
|
||||
"""初始化知识库数据库操作类
|
||||
|
||||
Args:
|
||||
kb_db: 知识库独立数据库实例,而非主数据库
|
||||
"""
|
||||
self.db = kb_db
|
||||
|
||||
# ===== 知识库查询 =====
|
||||
|
||||
async def get_kb_by_id(self, kb_id: str) -> Optional[KnowledgeBase]:
|
||||
"""根据 ID 获取知识库"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_kb_by_name(self, kb_name: str) -> Optional[KnowledgeBase]:
|
||||
"""根据名称获取知识库"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KnowledgeBase).where(KnowledgeBase.kb_name == kb_name)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_kbs(self, offset: int = 0, limit: int = 100) -> list[KnowledgeBase]:
|
||||
"""列出所有知识库"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = (
|
||||
select(KnowledgeBase)
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.order_by(KnowledgeBase.created_at.desc())
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def count_kbs(self) -> int:
|
||||
"""统计知识库数量"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(func.count(KnowledgeBase.id))
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar() or 0
|
||||
|
||||
# ===== 文档查询 =====
|
||||
|
||||
async def get_document_by_id(self, doc_id: str) -> Optional[KBDocument]:
|
||||
"""根据 ID 获取文档"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBDocument).where(KBDocument.doc_id == doc_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_documents_by_kb(
|
||||
self, kb_id: str, offset: int = 0, limit: int = 100
|
||||
) -> list[KBDocument]:
|
||||
"""列出知识库的所有文档"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = (
|
||||
select(KBDocument)
|
||||
.where(KBDocument.kb_id == kb_id)
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.order_by(KBDocument.created_at.desc())
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def count_documents_by_kb(self, kb_id: str) -> int:
|
||||
"""统计知识库的文档数量"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(func.count(KBDocument.id)).where(KBDocument.kb_id == kb_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar() or 0
|
||||
|
||||
# ===== 块查询 =====
|
||||
|
||||
async def get_chunk_by_id(self, chunk_id: str) -> Optional[KBChunk]:
|
||||
"""根据 ID 获取块"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBChunk).where(KBChunk.chunk_id == chunk_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_chunks_by_kb_ids(self, kb_ids: list[str]) -> list[KBChunk]:
|
||||
"""根据知识库 ID 列表获取所有块"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBChunk).where(KBChunk.kb_id.in_(kb_ids))
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_chunk_by_vec_doc_id(self, vec_doc_id: str) -> Optional[KBChunk]:
|
||||
"""根据向量文档 ID 获取块"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBChunk).where(KBChunk.vec_doc_id == vec_doc_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_chunk_with_metadata(self, chunk_id: str) -> Optional[dict]:
|
||||
"""获取块及其关联的文档和知识库元数据"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = (
|
||||
select(KBChunk, KBDocument, KnowledgeBase)
|
||||
.join(KBDocument, KBChunk.doc_id == KBDocument.doc_id)
|
||||
.join(KnowledgeBase, KBChunk.kb_id == KnowledgeBase.kb_id)
|
||||
.where(KBChunk.chunk_id == chunk_id)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
row = result.first()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
chunk, doc, kb = row
|
||||
return {
|
||||
"chunk": chunk,
|
||||
"document": doc,
|
||||
"knowledge_base": kb,
|
||||
}
|
||||
|
||||
async def list_chunks_by_doc(
|
||||
self, doc_id: str, offset: int = 0, limit: int = 100
|
||||
) -> list[KBChunk]:
|
||||
"""列出文档的所有块"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = (
|
||||
select(KBChunk)
|
||||
.where(KBChunk.doc_id == doc_id)
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.order_by(KBChunk.chunk_index)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
# ===== 多媒体查询 =====
|
||||
|
||||
async def list_media_by_doc(self, doc_id: str) -> list[KBMedia]:
|
||||
"""列出文档的所有多媒体资源"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBMedia).where(KBMedia.doc_id == doc_id)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_media_by_id(self, media_id: str) -> Optional[KBMedia]:
|
||||
"""根据 ID 获取多媒体资源"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBMedia).where(KBMedia.media_id == media_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
# ===== 会话配置查询 =====
|
||||
|
||||
async def get_session_kb_ids(self, session_id: str) -> list[str]:
|
||||
"""获取会话关联的知识库 ID 列表
|
||||
|
||||
查找顺序:
|
||||
1. 会话级别配置 (优先)
|
||||
2. 平台级别配置
|
||||
3. 返回空列表
|
||||
|
||||
Args:
|
||||
session_id: 会话ID(来自主数据库)
|
||||
|
||||
Returns:
|
||||
知识库ID列表
|
||||
"""
|
||||
async with self.db.get_db() as session:
|
||||
# 1. 查找会话级别配置
|
||||
stmt = select(KBSessionConfig).where(
|
||||
KBSessionConfig.scope == "session",
|
||||
KBSessionConfig.scope_id == session_id,
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
config = result.scalar_one_or_none()
|
||||
|
||||
if config:
|
||||
return json.loads(config.kb_ids)
|
||||
|
||||
# 2. 提取平台 ID (格式: platform:xxx:session_id)
|
||||
parts = session_id.split(":")
|
||||
if len(parts) >= 2:
|
||||
platform_id = parts[0]
|
||||
|
||||
# 查找平台级别配置
|
||||
stmt = select(KBSessionConfig).where(
|
||||
KBSessionConfig.scope == "platform",
|
||||
KBSessionConfig.scope_id == platform_id,
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
config = result.scalar_one_or_none()
|
||||
|
||||
if config:
|
||||
return json.loads(config.kb_ids)
|
||||
|
||||
# 3. 无配置
|
||||
return []
|
||||
|
||||
async def set_session_kb_ids(
|
||||
self,
|
||||
scope: str,
|
||||
scope_id: str,
|
||||
kb_ids: list[str],
|
||||
top_k: Optional[int] = None,
|
||||
enable_rerank: Optional[bool] = None,
|
||||
) -> KBSessionConfig:
|
||||
"""设置会话知识库配置
|
||||
|
||||
Args:
|
||||
scope: 配置范围 (session/platform)
|
||||
scope_id: 范围标识 (会话 ID 或平台 ID,来自主数据库)
|
||||
kb_ids: 知识库 ID 列表
|
||||
top_k: 返回结果数量 (可选)
|
||||
enable_rerank: 是否启用 Rerank (可选)
|
||||
|
||||
Returns:
|
||||
配置对象
|
||||
"""
|
||||
async with self.db.get_db() as session:
|
||||
# 查找现有配置
|
||||
stmt = select(KBSessionConfig).where(
|
||||
KBSessionConfig.scope == scope,
|
||||
KBSessionConfig.scope_id == scope_id,
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
config = result.scalar_one_or_none()
|
||||
|
||||
if config:
|
||||
# 更新现有配置
|
||||
config.kb_ids = json.dumps(kb_ids)
|
||||
if top_k is not None:
|
||||
config.top_k = top_k
|
||||
if enable_rerank is not None:
|
||||
config.enable_rerank = enable_rerank
|
||||
else:
|
||||
# 创建新配置
|
||||
config = KBSessionConfig(
|
||||
scope=scope,
|
||||
scope_id=scope_id,
|
||||
kb_ids=json.dumps(kb_ids),
|
||||
top_k=top_k,
|
||||
enable_rerank=enable_rerank,
|
||||
)
|
||||
session.add(config)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(config)
|
||||
return config
|
||||
|
||||
async def delete_session_kb_config(self, scope: str, scope_id: str) -> bool:
|
||||
"""删除会话知识库配置
|
||||
|
||||
Args:
|
||||
scope: 配置范围 (session/platform)
|
||||
scope_id: 范围标识 (会话 ID 或平台 ID)
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBSessionConfig).where(
|
||||
KBSessionConfig.scope == scope,
|
||||
KBSessionConfig.scope_id == scope_id,
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
config = result.scalar_one_or_none()
|
||||
|
||||
if not config:
|
||||
return False
|
||||
|
||||
await session.delete(config)
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
async def delete_session_kb_config_by_session_id(self, session_id: str) -> bool:
|
||||
"""根据会话ID删除会话配置(用于主数据库会话删除时的级联清理)
|
||||
|
||||
Args:
|
||||
session_id: 会话ID(来自主数据库)
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
return await self.delete_session_kb_config("session", session_id)
|
||||
|
||||
async def list_all_session_configs(
|
||||
self, offset: int = 0, limit: int = 100, scope: Optional[str] = None
|
||||
) -> list[KBSessionConfig]:
|
||||
"""列出所有会话配置
|
||||
|
||||
Args:
|
||||
offset: 偏移量
|
||||
limit: 限制数量
|
||||
scope: 可选的范围过滤 (session/platform)
|
||||
|
||||
Returns:
|
||||
会话配置列表
|
||||
"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBSessionConfig)
|
||||
|
||||
if scope:
|
||||
stmt = stmt.where(KBSessionConfig.scope == scope)
|
||||
|
||||
stmt = (
|
||||
stmt.offset(offset)
|
||||
.limit(limit)
|
||||
.order_by(KBSessionConfig.created_at.desc())
|
||||
)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
@@ -0,0 +1,139 @@
|
||||
"""知识库上下文注入器
|
||||
|
||||
负责检索相关知识并格式化为 LLM 可用的上下文文本
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from astrbot.core.knowledge_base.database import KBDatabase
|
||||
from astrbot.core.knowledge_base.retrieval.manager import (
|
||||
RetrievalManager,
|
||||
RetrievalResult,
|
||||
)
|
||||
|
||||
|
||||
class KnowledgeBaseInjector:
|
||||
"""知识库上下文注入器
|
||||
|
||||
职责:
|
||||
- 检索相关知识
|
||||
- 格式化为上下文文本
|
||||
- 注入到 LLM Prompt
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kb_db: KBDatabase,
|
||||
retrieval_manager: RetrievalManager,
|
||||
):
|
||||
"""初始化知识库上下文注入器
|
||||
|
||||
Args:
|
||||
kb_db: 知识库数据库实例
|
||||
retrieval_manager: 检索管理器实例
|
||||
"""
|
||||
self.kb_db = kb_db
|
||||
self.retrieval_manager = retrieval_manager
|
||||
|
||||
async def retrieve_and_inject(
|
||||
self,
|
||||
unified_msg_origin: str,
|
||||
query: str,
|
||||
top_k: int = 5,
|
||||
) -> Optional[dict]:
|
||||
"""检索并注入知识库上下文
|
||||
|
||||
Args:
|
||||
unified_msg_origin: 统一消息来源 ID (会话 ID)
|
||||
query: 用户查询
|
||||
top_k: 返回结果数量
|
||||
|
||||
Returns:
|
||||
Optional[dict]: 包含检索结果和格式化上下文的字典,如果无结果则返回 None
|
||||
{
|
||||
"context_text": str, # 格式化的上下文文本
|
||||
"results": List[dict], # 原始检索结果列表
|
||||
}
|
||||
"""
|
||||
# 1. 获取会话关联的知识库
|
||||
kb_ids = await self.kb_db.get_session_kb_ids(unified_msg_origin)
|
||||
|
||||
if not kb_ids:
|
||||
return None
|
||||
|
||||
# 2. 检索知识
|
||||
results = await self.retrieval_manager.retrieve(
|
||||
query=query,
|
||||
kb_ids=kb_ids,
|
||||
top_m_final=top_k,
|
||||
)
|
||||
|
||||
if not results:
|
||||
return None
|
||||
|
||||
# 3. 格式化上下文
|
||||
context_text = self._format_context(results)
|
||||
|
||||
# 4. 转换结果为字典格式
|
||||
results_dict = [
|
||||
{
|
||||
"chunk_id": r.chunk_id,
|
||||
"doc_id": r.doc_id,
|
||||
"kb_id": r.kb_id,
|
||||
"kb_name": r.kb_name,
|
||||
"doc_name": r.doc_name,
|
||||
"chunk_index": r.metadata.get("chunk_index", 0),
|
||||
"content": r.content,
|
||||
"score": r.score,
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
|
||||
return {
|
||||
"context_text": context_text,
|
||||
"results": results_dict,
|
||||
}
|
||||
|
||||
async def inject(
|
||||
self,
|
||||
session_id: str,
|
||||
query: str,
|
||||
top_k: int = 5,
|
||||
) -> Optional[str]:
|
||||
"""注入知识库上下文 (简化版本,仅返回文本)
|
||||
|
||||
Args:
|
||||
session_id: 会话 ID (来自主数据库)
|
||||
query: 用户查询
|
||||
top_k: 返回结果数量
|
||||
|
||||
Returns:
|
||||
Optional[str]: 格式化的知识上下文,如果无结果则返回 None
|
||||
"""
|
||||
result = await self.retrieve_and_inject(
|
||||
unified_msg_origin=session_id,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
return result["context_text"] if result else None
|
||||
|
||||
def _format_context(self, results: List[RetrievalResult]) -> str:
|
||||
"""格式化知识上下文
|
||||
|
||||
Args:
|
||||
results: 检索结果列表
|
||||
|
||||
Returns:
|
||||
str: 格式化的上下文文本
|
||||
"""
|
||||
lines = ["以下是相关的知识库内容,请参考这些信息回答用户的问题:\n"]
|
||||
|
||||
for i, result in enumerate(results, 1):
|
||||
lines.append(f"【知识 {i}】")
|
||||
lines.append(f"来源: {result.kb_name} / {result.doc_name}")
|
||||
lines.append(f"内容: {result.content}")
|
||||
lines.append(f"相关度: {result.score:.2f}")
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
@@ -0,0 +1,358 @@
|
||||
"""
|
||||
知识库管理器
|
||||
负责知识库模块的初始化、配置和资源管理
|
||||
|
||||
架构说明:
|
||||
- 知识库数据存储在独立的数据库 (data/knowledge_base/kb.db)
|
||||
- 会话配置存储在主数据库 (data/astrbot.db) 以便于会话关联
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
|
||||
|
||||
class KnowledgeBaseManager:
|
||||
"""知识库管理器
|
||||
|
||||
职责:
|
||||
- 知识库模块的初始化
|
||||
- Embedding Provider 和 Rerank Provider 的选择
|
||||
- 各个子组件的协调管理
|
||||
- 注册会话删除回调,实现级联清理
|
||||
|
||||
架构说明:
|
||||
- 知识库数据存储在独立数据库 (kb.db)
|
||||
- 会话配置存储在独立数据库 (kb.db),会话ID来自主数据库
|
||||
- 通过回调机制实现与主数据库的生命周期同步
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: dict,
|
||||
main_db: BaseDatabase,
|
||||
provider_manager: ProviderManager,
|
||||
):
|
||||
"""初始化知识库管理器
|
||||
|
||||
Args:
|
||||
config: 配置字典
|
||||
main_db: 主数据库实例 (不直接使用,仅用于类型引用)
|
||||
provider_manager: Provider 管理器
|
||||
"""
|
||||
self.config = config.get("knowledge_base", {})
|
||||
self.provider_manager = provider_manager
|
||||
|
||||
# 知识库独立数据库
|
||||
self.kb_db = None
|
||||
|
||||
# 组件实例
|
||||
self.kb_database = None
|
||||
self.kb_manager = None
|
||||
self.kb_vec_db = None
|
||||
self.retrieval_manager = None
|
||||
self.kb_injector = None
|
||||
|
||||
self._initialized = False
|
||||
self._session_deleted_callback_registered = False
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化知识库模块"""
|
||||
if not self.config.get("enabled", False):
|
||||
logger.info("知识库功能未启用")
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info("正在初始化知识库模块...")
|
||||
|
||||
# 1. 检查并选择 Embedding Provider
|
||||
embedding_provider = self._select_embedding_provider()
|
||||
if not embedding_provider:
|
||||
logger.warning("未配置 Embedding Provider,知识库功能无法使用")
|
||||
return
|
||||
|
||||
# 2. 初始化数据库
|
||||
await self._init_kb_database()
|
||||
await self._init_database()
|
||||
|
||||
# 3. 初始化向量数据库
|
||||
await self._init_vector_db(embedding_provider)
|
||||
|
||||
# 4. 初始化解析器和分块器
|
||||
parsers = self._init_parsers()
|
||||
chunker = self._init_chunker()
|
||||
|
||||
# 5. 初始化知识库管理器
|
||||
await self._init_kb_manager(parsers, chunker)
|
||||
|
||||
# 6. 初始化检索管理器
|
||||
await self._init_retrieval_manager()
|
||||
|
||||
# 7. 初始化上下文注入器
|
||||
await self._init_injector()
|
||||
|
||||
self._initialized = True
|
||||
logger.info("知识库模块初始化完成")
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(f"知识库模块导入失败: {e}")
|
||||
logger.warning("请确保已安装所需依赖: pypdf, aiofiles, Pillow, rank-bm25")
|
||||
except Exception as e:
|
||||
logger.error(f"知识库模块初始化失败: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def _init_kb_database(self):
|
||||
"""初始化知识库独立数据库"""
|
||||
from astrbot.core.knowledge_base.kb_sqlite import KBSQLiteDatabase
|
||||
|
||||
db_path = self.config.get("storage", {}).get(
|
||||
"kb_db_path", "data/knowledge_base/kb.db"
|
||||
)
|
||||
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.kb_db = KBSQLiteDatabase(db_path)
|
||||
await self.kb_db.initialize()
|
||||
await self.kb_db.migrate_to_v1()
|
||||
|
||||
logger.info(f"知识库独立数据库已初始化: {db_path}")
|
||||
|
||||
async def _init_database(self):
|
||||
"""初始化知识库数据库操作类"""
|
||||
from astrbot.core.knowledge_base.database import KBDatabase
|
||||
|
||||
self.kb_database = KBDatabase(self.kb_db)
|
||||
|
||||
async def _init_vector_db(self, embedding_provider):
|
||||
"""初始化向量数据库"""
|
||||
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
|
||||
|
||||
storage_path = self.config.get("storage", {}).get(
|
||||
"vector_db_path", "data/knowledge_base/vectors"
|
||||
)
|
||||
Path(storage_path).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.kb_vec_db = FaissVecDB(
|
||||
doc_store_path=f"{storage_path}/documents.db",
|
||||
index_store_path=f"{storage_path}/index.faiss",
|
||||
embedding_provider=embedding_provider,
|
||||
)
|
||||
await self.kb_vec_db.initialize()
|
||||
|
||||
def _init_parsers(self) -> dict:
|
||||
"""初始化文档解析器"""
|
||||
from astrbot.core.knowledge_base.parsers.text_parser import TextParser
|
||||
from astrbot.core.knowledge_base.parsers.pdf_parser import PDFParser
|
||||
|
||||
return {
|
||||
"txt": TextParser(),
|
||||
"md": TextParser(),
|
||||
"markdown": TextParser(),
|
||||
"pdf": PDFParser(),
|
||||
}
|
||||
|
||||
def _init_chunker(self):
|
||||
"""初始化分块器"""
|
||||
from astrbot.core.knowledge_base.chunking.fixed_size import FixedSizeChunker
|
||||
|
||||
chunking_config = self.config.get("chunking", {})
|
||||
return FixedSizeChunker(
|
||||
chunk_size=chunking_config.get("chunk_size", 512),
|
||||
chunk_overlap=chunking_config.get("chunk_overlap", 50),
|
||||
)
|
||||
|
||||
async def _init_kb_manager(self, parsers: dict, chunker):
|
||||
"""初始化知识库管理器"""
|
||||
from astrbot.core.knowledge_base.manager import KBManager
|
||||
|
||||
files_path = self.config.get("storage", {}).get(
|
||||
"files_path", "data/knowledge_base"
|
||||
)
|
||||
|
||||
self.kb_manager = KBManager(
|
||||
db=self.kb_db, # 使用独立的知识库数据库
|
||||
vec_db=self.kb_vec_db,
|
||||
storage_path=files_path,
|
||||
parsers=parsers,
|
||||
chunker=chunker,
|
||||
)
|
||||
|
||||
async def _init_retrieval_manager(self):
|
||||
"""初始化检索管理器"""
|
||||
from astrbot.core.knowledge_base.retrieval.manager import RetrievalManager
|
||||
from astrbot.core.knowledge_base.retrieval.sparse_retriever import (
|
||||
SparseRetriever,
|
||||
)
|
||||
from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion
|
||||
|
||||
sparse_retriever = SparseRetriever(self.kb_database)
|
||||
rank_fusion = RankFusion(self.kb_database)
|
||||
|
||||
# 选择 Rerank Provider (可选)
|
||||
rerank_provider = self._select_rerank_provider()
|
||||
|
||||
self.retrieval_manager = RetrievalManager(
|
||||
vec_db=self.kb_vec_db,
|
||||
sparse_retriever=sparse_retriever,
|
||||
rank_fusion=rank_fusion,
|
||||
kb_db=self.kb_database,
|
||||
rerank_provider=rerank_provider,
|
||||
)
|
||||
|
||||
async def _init_injector(self):
|
||||
"""初始化上下文注入器"""
|
||||
from astrbot.core.knowledge_base.injector import KnowledgeBaseInjector
|
||||
|
||||
self.kb_injector = KnowledgeBaseInjector(
|
||||
kb_db=self.kb_database,
|
||||
retrieval_manager=self.retrieval_manager,
|
||||
)
|
||||
|
||||
def _select_embedding_provider(self):
|
||||
"""选择 Embedding Provider
|
||||
|
||||
逻辑:
|
||||
- 如果配置了 embedding_provider_id,则使用指定的 provider
|
||||
- 如果没有配置,但有 embedding provider,则使用第一个
|
||||
- 如果有多个 embedding provider 但没有指定,则警告并使用第一个
|
||||
"""
|
||||
embedding_providers = self.provider_manager.embedding_provider_insts
|
||||
|
||||
if not embedding_providers:
|
||||
return None
|
||||
|
||||
configured_provider_id = self.config.get("embedding_provider_id")
|
||||
|
||||
if configured_provider_id:
|
||||
# 按 ID 查找
|
||||
for provider in embedding_providers:
|
||||
provider_id = provider.meta().id
|
||||
if provider_id == configured_provider_id:
|
||||
logger.info(f"知识库使用 Embedding Provider: {provider_id}")
|
||||
return provider
|
||||
logger.warning(
|
||||
f"未找到配置的 Embedding Provider ID: {configured_provider_id},"
|
||||
f"将使用第一个可用的"
|
||||
)
|
||||
|
||||
if len(embedding_providers) > 1 and not configured_provider_id:
|
||||
logger.warning(
|
||||
f"检测到 {len(embedding_providers)} 个 Embedding Provider,"
|
||||
f"但未指定使用哪个,将默认使用第一个"
|
||||
)
|
||||
|
||||
provider = embedding_providers[0]
|
||||
provider_id = provider.meta().id
|
||||
logger.info(f"知识库使用 Embedding Provider: {provider_id}")
|
||||
return provider
|
||||
|
||||
def _select_rerank_provider(self):
|
||||
"""选择 Rerank Provider (可选)"""
|
||||
if not self.config.get("retrieval", {}).get("enable_rerank", True):
|
||||
return None
|
||||
|
||||
rerank_providers = self.provider_manager.rerank_provider_insts
|
||||
if not rerank_providers:
|
||||
return None
|
||||
|
||||
configured_provider_id = self.config.get("rerank_provider_id")
|
||||
|
||||
if configured_provider_id:
|
||||
for provider in rerank_providers:
|
||||
provider_id = provider.meta().id
|
||||
if provider_id == configured_provider_id:
|
||||
logger.info(f"知识库使用 Rerank Provider: {provider_id}")
|
||||
return provider
|
||||
logger.warning(f"未找到配置的 Rerank Provider ID: {configured_provider_id}")
|
||||
|
||||
if len(rerank_providers) > 0:
|
||||
provider = rerank_providers[0]
|
||||
provider_id = provider.meta().id
|
||||
logger.info(f"知识库使用 Rerank Provider: {provider_id}")
|
||||
return provider
|
||||
|
||||
return None
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
"""检查是否已初始化"""
|
||||
return self._initialized
|
||||
|
||||
def get_kb_manager(self):
|
||||
"""获取知识库管理器"""
|
||||
return self.kb_manager if self._initialized else None
|
||||
|
||||
def get_kb_injector(self):
|
||||
"""获取知识库上下文注入器"""
|
||||
return self.kb_injector if self._initialized else None
|
||||
|
||||
def register_session_lifecycle_hooks(self, conversation_manager):
|
||||
"""注册会话生命周期钩子
|
||||
|
||||
在会话删除时自动清理知识库配置,实现零侵入的级联清理。
|
||||
|
||||
Args:
|
||||
conversation_manager: 会话管理器实例
|
||||
"""
|
||||
if self._session_deleted_callback_registered or not self._initialized:
|
||||
return
|
||||
|
||||
async def on_session_deleted(session_id: str):
|
||||
"""会话删除回调:清理知识库配置"""
|
||||
try:
|
||||
await self.kb_database.delete_session_kb_config_by_session_id(session_id)
|
||||
logger.info(f"已清理会话知识库配置: {session_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"清理会话知识库配置失败 ({session_id}): {e}")
|
||||
|
||||
conversation_manager.register_on_session_deleted(on_session_deleted)
|
||||
self._session_deleted_callback_registered = True
|
||||
logger.info("已注册知识库会话删除回调")
|
||||
|
||||
async def reinitialize(self):
|
||||
"""重新初始化知识库模块
|
||||
|
||||
用于在运行时动态初始化知识库模块(例如用户添加了 embedding provider 后)
|
||||
"""
|
||||
if self._initialized:
|
||||
logger.info("知识库模块已初始化,将重新初始化")
|
||||
await self.terminate()
|
||||
|
||||
await self.initialize()
|
||||
return self._initialized
|
||||
|
||||
async def terminate(self):
|
||||
"""终止知识库模块,清理资源"""
|
||||
if not self._initialized:
|
||||
return
|
||||
|
||||
logger.info("正在终止知识库模块...")
|
||||
|
||||
# 关闭向量数据库连接
|
||||
if self.kb_vec_db:
|
||||
try:
|
||||
await self.kb_vec_db.close()
|
||||
logger.debug("向量数据库已关闭")
|
||||
except Exception as e:
|
||||
logger.warning(f"关闭向量数据库时出错: {e}")
|
||||
|
||||
# 关闭知识库独立数据库连接
|
||||
if self.kb_db:
|
||||
try:
|
||||
await self.kb_db.close()
|
||||
logger.debug("知识库数据库已关闭")
|
||||
except Exception as e:
|
||||
logger.warning(f"关闭知识库数据库时出错: {e}")
|
||||
|
||||
# 清理资源
|
||||
self._initialized = False
|
||||
self.kb_db = None
|
||||
self.kb_database = None
|
||||
self.kb_manager = None
|
||||
self.kb_vec_db = None
|
||||
self.retrieval_manager = None
|
||||
self.kb_injector = None
|
||||
|
||||
logger.info("知识库模块已终止")
|
||||
@@ -0,0 +1,231 @@
|
||||
"""
|
||||
知识库独立 SQLite 数据库
|
||||
|
||||
该模块提供知识库专用的独立 SQLite 数据库,与主数据库 (astrbot.db) 完全隔离。
|
||||
职责:
|
||||
- 管理知识库相关表 (knowledge_bases, kb_documents, kb_chunks, kb_media)
|
||||
- 提供数据库连接和会话管理
|
||||
- 执行数据库迁移和初始化
|
||||
"""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
class KBSQLiteDatabase:
|
||||
"""知识库独立 SQLite 数据库
|
||||
|
||||
与主数据库 (astrbot.db) 完全隔离的独立数据库,专门用于存储知识库数据。
|
||||
|
||||
特点:
|
||||
- 数据隔离: 知识库数据不会影响主数据库格式
|
||||
- 独立备份: 可以单独备份和恢复知识库数据
|
||||
- 性能隔离: 大量知识库查询不会影响主业务性能
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: str = "data/knowledge_base/kb.db") -> None:
|
||||
"""初始化知识库数据库
|
||||
|
||||
Args:
|
||||
db_path: 数据库文件路径,默认为 data/knowledge_base/kb.db
|
||||
"""
|
||||
self.db_path = db_path
|
||||
self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
|
||||
self.inited = False
|
||||
|
||||
# 确保目录存在
|
||||
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 创建异步引擎
|
||||
self.engine = create_async_engine(
|
||||
self.DATABASE_URL,
|
||||
echo=False,
|
||||
pool_pre_ping=True,
|
||||
pool_recycle=3600,
|
||||
)
|
||||
|
||||
# 创建会话工厂
|
||||
self.async_session = async_sessionmaker(
|
||||
self.engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_db(self):
|
||||
"""获取数据库会话
|
||||
|
||||
用法:
|
||||
async with kb_db.get_db() as session:
|
||||
# 执行数据库操作
|
||||
result = await session.execute(stmt)
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
yield session
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""初始化数据库,创建表并配置 SQLite 参数"""
|
||||
from astrbot.core.knowledge_base.models import (
|
||||
KBChunk,
|
||||
KBDocument,
|
||||
KBMedia,
|
||||
KBSessionConfig,
|
||||
KnowledgeBase,
|
||||
)
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
async with self.engine.begin() as conn:
|
||||
# 创建所有知识库相关表
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
|
||||
# 配置 SQLite 性能优化参数
|
||||
await conn.execute(text("PRAGMA journal_mode=WAL"))
|
||||
await conn.execute(text("PRAGMA synchronous=NORMAL"))
|
||||
await conn.execute(text("PRAGMA cache_size=20000"))
|
||||
await conn.execute(text("PRAGMA temp_store=MEMORY"))
|
||||
await conn.execute(text("PRAGMA mmap_size=134217728"))
|
||||
await conn.execute(text("PRAGMA optimize"))
|
||||
await conn.commit()
|
||||
|
||||
self.inited = True
|
||||
logger.info(f"知识库数据库已初始化: {self.db_path}")
|
||||
|
||||
async def migrate_to_v1(self) -> None:
|
||||
"""执行知识库数据库 v1 迁移
|
||||
|
||||
创建所有必要的索引以优化查询性能
|
||||
"""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
# 创建知识库表索引
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_kb_kb_id "
|
||||
"ON knowledge_bases(kb_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_kb_name "
|
||||
"ON knowledge_bases(kb_name)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_kb_created_at "
|
||||
"ON knowledge_bases(created_at)"
|
||||
)
|
||||
)
|
||||
|
||||
# 创建文档表索引
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_doc_doc_id "
|
||||
"ON kb_documents(doc_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_doc_kb_id "
|
||||
"ON kb_documents(kb_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_doc_name "
|
||||
"ON kb_documents(doc_name)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_doc_type "
|
||||
"ON kb_documents(file_type)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_doc_created_at "
|
||||
"ON kb_documents(created_at)"
|
||||
)
|
||||
)
|
||||
|
||||
# 创建块表索引
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_chunk_chunk_id "
|
||||
"ON kb_chunks(chunk_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_chunk_doc_id "
|
||||
"ON kb_chunks(doc_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_chunk_kb_id "
|
||||
"ON kb_chunks(kb_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_chunk_vec_doc_id "
|
||||
"ON kb_chunks(vec_doc_id)"
|
||||
)
|
||||
)
|
||||
|
||||
# 创建多媒体表索引
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_media_media_id "
|
||||
"ON kb_media(media_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_media_doc_id "
|
||||
"ON kb_media(doc_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_media_kb_id "
|
||||
"ON kb_media(kb_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_media_type "
|
||||
"ON kb_media(media_type)"
|
||||
)
|
||||
)
|
||||
|
||||
# 创建会话配置表索引
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_session_config_scope_id "
|
||||
"ON kb_session_config(scope_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_session_config_scope "
|
||||
"ON kb_session_config(scope)"
|
||||
)
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
|
||||
logger.info("知识库数据库迁移 v1 完成")
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭数据库连接"""
|
||||
await self.engine.dispose()
|
||||
logger.info(f"知识库数据库已关闭: {self.db_path}")
|
||||
@@ -0,0 +1,349 @@
|
||||
"""知识库管理器
|
||||
|
||||
该模块提供知识库的CRUD操作和文档上传处理流程。
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import aiofiles
|
||||
from sqlalchemy import func, select, update
|
||||
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.vec_db.base import BaseVecDB
|
||||
from astrbot.core.knowledge_base.chunking.base import BaseChunker
|
||||
from astrbot.core.knowledge_base.models import KBChunk, KBDocument, KnowledgeBase
|
||||
from astrbot.core.knowledge_base.parsers.base import BaseParser
|
||||
|
||||
|
||||
class KBManager:
|
||||
"""知识库管理器
|
||||
|
||||
职责:
|
||||
- 知识库的 CRUD 操作
|
||||
- 文档上传与解析
|
||||
- 文档块生成与存储
|
||||
- 多媒体资源管理
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db: BaseDatabase,
|
||||
vec_db: BaseVecDB,
|
||||
storage_path: str,
|
||||
parsers: dict[str, BaseParser],
|
||||
chunker: BaseChunker,
|
||||
):
|
||||
self.db = db
|
||||
self.vec_db = vec_db
|
||||
self.storage_path = Path(storage_path)
|
||||
self.media_path = self.storage_path / "media"
|
||||
self.files_path = self.storage_path / "files"
|
||||
self.parsers = parsers
|
||||
self.chunker = chunker
|
||||
|
||||
# 确保目录存在
|
||||
self.media_path.mkdir(parents=True, exist_ok=True)
|
||||
self.files_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ===== 知识库操作 =====
|
||||
|
||||
async def create_kb(
|
||||
self,
|
||||
kb_name: str,
|
||||
description: Optional[str] = None,
|
||||
emoji: Optional[str] = None,
|
||||
embedding_provider_id: Optional[str] = None,
|
||||
rerank_provider_id: Optional[str] = None,
|
||||
chunk_size: Optional[int] = None,
|
||||
chunk_overlap: Optional[int] = None,
|
||||
top_k_dense: Optional[int] = None,
|
||||
top_k_sparse: Optional[int] = None,
|
||||
top_m_final: Optional[int] = None,
|
||||
enable_rerank: Optional[bool] = None,
|
||||
) -> KnowledgeBase:
|
||||
"""创建知识库"""
|
||||
kb = KnowledgeBase(
|
||||
kb_name=kb_name,
|
||||
description=description,
|
||||
emoji=emoji or "📚",
|
||||
embedding_provider_id=embedding_provider_id,
|
||||
rerank_provider_id=rerank_provider_id,
|
||||
chunk_size=chunk_size if chunk_size is not None else 512,
|
||||
chunk_overlap=chunk_overlap if chunk_overlap is not None else 50,
|
||||
top_k_dense=top_k_dense if top_k_dense is not None else 50,
|
||||
top_k_sparse=top_k_sparse if top_k_sparse is not None else 50,
|
||||
top_m_final=top_m_final if top_m_final is not None else 5,
|
||||
enable_rerank=enable_rerank if enable_rerank is not None else True,
|
||||
)
|
||||
async with self.db.get_db() as session:
|
||||
session.add(kb)
|
||||
await session.commit()
|
||||
await session.refresh(kb)
|
||||
return kb
|
||||
|
||||
async def get_kb(self, kb_id: str) -> Optional[KnowledgeBase]:
|
||||
"""获取知识库"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_kbs(self, offset: int = 0, limit: int = 100) -> list[KnowledgeBase]:
|
||||
"""列出所有知识库"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = (
|
||||
select(KnowledgeBase)
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.order_by(KnowledgeBase.created_at.desc())
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def update_kb(
|
||||
self,
|
||||
kb_id: str,
|
||||
kb_name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
emoji: Optional[str] = None,
|
||||
embedding_provider_id: Optional[str] = None,
|
||||
rerank_provider_id: Optional[str] = None,
|
||||
chunk_size: Optional[int] = None,
|
||||
chunk_overlap: Optional[int] = None,
|
||||
top_k_dense: Optional[int] = None,
|
||||
top_k_sparse: Optional[int] = None,
|
||||
top_m_final: Optional[int] = None,
|
||||
enable_rerank: Optional[bool] = None,
|
||||
) -> Optional[KnowledgeBase]:
|
||||
"""更新知识库"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id)
|
||||
result = await session.execute(stmt)
|
||||
kb = result.scalar_one_or_none()
|
||||
if not kb:
|
||||
return None
|
||||
|
||||
if kb_name is not None:
|
||||
kb.kb_name = kb_name
|
||||
if description is not None:
|
||||
kb.description = description
|
||||
if emoji is not None:
|
||||
kb.emoji = emoji
|
||||
if embedding_provider_id is not None:
|
||||
kb.embedding_provider_id = embedding_provider_id
|
||||
if rerank_provider_id is not None:
|
||||
kb.rerank_provider_id = rerank_provider_id
|
||||
if chunk_size is not None:
|
||||
kb.chunk_size = chunk_size
|
||||
if chunk_overlap is not None:
|
||||
kb.chunk_overlap = chunk_overlap
|
||||
if top_k_dense is not None:
|
||||
kb.top_k_dense = top_k_dense
|
||||
if top_k_sparse is not None:
|
||||
kb.top_k_sparse = top_k_sparse
|
||||
if top_m_final is not None:
|
||||
kb.top_m_final = top_m_final
|
||||
if enable_rerank is not None:
|
||||
kb.enable_rerank = enable_rerank
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(kb)
|
||||
return kb
|
||||
|
||||
async def delete_kb(self, kb_id: str) -> bool:
|
||||
"""删除知识库(级联删除所有文档和资源)"""
|
||||
# 1. 获取所有文档
|
||||
from astrbot.core.knowledge_base.manager_ops import KBManagerOps
|
||||
|
||||
ops = KBManagerOps(self)
|
||||
docs = await ops.list_documents(kb_id)
|
||||
|
||||
# 2. 删除所有文档(包括文件和向量)
|
||||
for doc in docs:
|
||||
await ops.delete_document(doc.doc_id)
|
||||
|
||||
# 3. 删除知识库记录
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id)
|
||||
result = await session.execute(stmt)
|
||||
kb = result.scalar_one_or_none()
|
||||
if not kb:
|
||||
return False
|
||||
|
||||
await session.delete(kb)
|
||||
await session.commit()
|
||||
|
||||
return True
|
||||
|
||||
# ===== 文档上传 =====
|
||||
|
||||
async def upload_document(
|
||||
self,
|
||||
kb_id: str,
|
||||
file_name: str,
|
||||
file_content: bytes,
|
||||
file_type: str,
|
||||
) -> KBDocument:
|
||||
"""上传并处理文档(带原子性保证和失败清理)
|
||||
|
||||
流程:
|
||||
1. 保存原始文件
|
||||
2. 解析文档内容
|
||||
3. 提取多媒体资源
|
||||
4. 分块处理
|
||||
5. 生成向量并存储
|
||||
6. 保存元数据(事务)
|
||||
7. 更新统计
|
||||
"""
|
||||
doc_id = str(uuid.uuid4())
|
||||
file_path = None
|
||||
media_paths = []
|
||||
vec_doc_ids = []
|
||||
|
||||
try:
|
||||
# 1. 保存原始文件
|
||||
file_path = self.files_path / kb_id / f"{doc_id}.{file_type}"
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async with aiofiles.open(file_path, "wb") as f:
|
||||
await f.write(file_content)
|
||||
|
||||
# 2. 解析文档
|
||||
parser = self.parsers.get(file_type)
|
||||
if not parser:
|
||||
raise ValueError(f"不支持的文件类型: {file_type}")
|
||||
|
||||
parse_result = await parser.parse(file_content, file_name)
|
||||
text_content = parse_result.text
|
||||
media_items = parse_result.media
|
||||
|
||||
# 3. 保存多媒体资源
|
||||
from astrbot.core.knowledge_base.manager_ops import KBManagerOps
|
||||
|
||||
ops = KBManagerOps(self)
|
||||
saved_media = []
|
||||
for media_item in media_items:
|
||||
media = await ops._save_media(
|
||||
kb_id=kb_id,
|
||||
doc_id=doc_id,
|
||||
media_type=media_item.media_type,
|
||||
file_name=media_item.file_name,
|
||||
content=media_item.content,
|
||||
mime_type=media_item.mime_type,
|
||||
)
|
||||
saved_media.append(media)
|
||||
media_paths.append(Path(media.file_path))
|
||||
|
||||
# 4. 文档分块
|
||||
chunks_text = await self.chunker.chunk(text_content)
|
||||
|
||||
# 5. 生成向量并存储
|
||||
saved_chunks = []
|
||||
for idx, chunk_text in enumerate(chunks_text):
|
||||
# 存储到向量数据库
|
||||
vec_doc_id = await self.vec_db.insert(
|
||||
content=chunk_text,
|
||||
metadata={
|
||||
"kb_id": kb_id,
|
||||
"doc_id": doc_id,
|
||||
"chunk_index": idx,
|
||||
},
|
||||
)
|
||||
vec_doc_ids.append(str(vec_doc_id))
|
||||
|
||||
# 保存块元数据
|
||||
chunk = KBChunk(
|
||||
doc_id=doc_id,
|
||||
kb_id=kb_id,
|
||||
chunk_index=idx,
|
||||
content=chunk_text,
|
||||
char_count=len(chunk_text),
|
||||
vec_doc_id=str(vec_doc_id),
|
||||
)
|
||||
saved_chunks.append(chunk)
|
||||
|
||||
# 6. 保存文档元数据(事务)
|
||||
doc = KBDocument(
|
||||
doc_id=doc_id,
|
||||
kb_id=kb_id,
|
||||
doc_name=file_name,
|
||||
file_type=file_type,
|
||||
file_size=len(file_content),
|
||||
file_path=str(file_path),
|
||||
chunk_count=len(saved_chunks),
|
||||
media_count=len(saved_media),
|
||||
)
|
||||
|
||||
async with self.db.get_db() as session:
|
||||
async with session.begin():
|
||||
session.add(doc)
|
||||
for chunk in saved_chunks:
|
||||
session.add(chunk)
|
||||
for media in saved_media:
|
||||
session.add(media)
|
||||
await session.commit()
|
||||
|
||||
await session.refresh(doc)
|
||||
|
||||
# 7. 更新知识库统计
|
||||
await self._update_kb_stats(kb_id)
|
||||
|
||||
return doc
|
||||
|
||||
except Exception as e:
|
||||
# 失败清理:删除已创建的资源
|
||||
from astrbot.core import logger
|
||||
|
||||
logger.error(f"文档上传失败,开始清理资源: {e}")
|
||||
|
||||
# 清理向量数据库
|
||||
for vec_id in vec_doc_ids:
|
||||
try:
|
||||
await self.vec_db.delete(vec_id)
|
||||
except Exception as ve:
|
||||
logger.warning(f"清理向量失败 {vec_id}: {ve}")
|
||||
|
||||
# 清理多媒体文件
|
||||
for media_path in media_paths:
|
||||
try:
|
||||
if media_path.exists():
|
||||
media_path.unlink()
|
||||
except Exception as me:
|
||||
logger.warning(f"清理多媒体文件失败 {media_path}: {me}")
|
||||
|
||||
# 清理文档文件
|
||||
if file_path and file_path.exists():
|
||||
try:
|
||||
file_path.unlink()
|
||||
except Exception as fe:
|
||||
logger.warning(f"清理文档文件失败 {file_path}: {fe}")
|
||||
|
||||
# 重新抛出原始异常
|
||||
raise
|
||||
|
||||
# ===== 统计更新 =====
|
||||
|
||||
async def _update_kb_stats(self, kb_id: str):
|
||||
"""更新知识库统计信息(事务中执行)"""
|
||||
async with self.db.get_db() as session:
|
||||
async with session.begin():
|
||||
# 统计文档数(在事务中查询)
|
||||
doc_count = await session.scalar(
|
||||
select(func.count(KBDocument.id)).where(KBDocument.kb_id == kb_id)
|
||||
) or 0
|
||||
|
||||
# 统计块数(在事务中查询)
|
||||
chunk_count = await session.scalar(
|
||||
select(func.count(KBChunk.id)).where(KBChunk.kb_id == kb_id)
|
||||
) or 0
|
||||
|
||||
# 更新知识库(在同一事务中)
|
||||
await session.execute(
|
||||
update(KnowledgeBase)
|
||||
.where(KnowledgeBase.kb_id == kb_id)
|
||||
.values(doc_count=doc_count, chunk_count=chunk_count)
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
@@ -0,0 +1,306 @@
|
||||
"""知识库管理器辅助操作
|
||||
|
||||
该模块提供文档、块和多媒体的管理操作。
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiofiles
|
||||
from sqlalchemy import delete, func, select
|
||||
|
||||
from astrbot.core.knowledge_base.models import KBChunk, KBDocument, KBMedia
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.knowledge_base.manager import KBManager
|
||||
|
||||
|
||||
class KBManagerOps:
|
||||
"""知识库管理器辅助操作类
|
||||
|
||||
职责:
|
||||
- 文档管理操作
|
||||
- 块管理操作
|
||||
- 多媒体管理操作
|
||||
"""
|
||||
|
||||
def __init__(self, manager: "KBManager"):
|
||||
self.manager = manager
|
||||
self.db = manager.db
|
||||
self.vec_db = manager.vec_db
|
||||
self.media_path = manager.media_path
|
||||
self.files_path = manager.files_path
|
||||
|
||||
# ===== 文档操作 =====
|
||||
|
||||
async def list_documents(
|
||||
self, kb_id: str, offset: int = 0, limit: int = 100
|
||||
) -> list[KBDocument]:
|
||||
"""列出知识库的所有文档"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = (
|
||||
select(KBDocument)
|
||||
.where(KBDocument.kb_id == kb_id)
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.order_by(KBDocument.created_at.desc())
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_document(self, doc_id: str) -> KBDocument | None:
|
||||
"""获取文档详情"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBDocument).where(KBDocument.doc_id == doc_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def delete_document(self, doc_id: str) -> bool:
|
||||
"""删除文档(级联删除块、多媒体、向量)
|
||||
|
||||
采用三阶段删除策略:
|
||||
1. 删除向量数据库中的向量(允许部分失败)
|
||||
2. 删除SQL数据库中的记录(事务保证原子性)
|
||||
3. 删除文件系统中的文件(失败不影响数据一致性)
|
||||
"""
|
||||
from astrbot.core import logger
|
||||
|
||||
# 0. 获取文档信息
|
||||
doc = await self.get_document(doc_id)
|
||||
if not doc:
|
||||
return False
|
||||
|
||||
# 收集所有需要删除的资源
|
||||
chunks = await self.list_chunks(doc_id)
|
||||
media_list = await self.list_media(doc_id)
|
||||
|
||||
# ===== 第一阶段: 删除向量(可重试) =====
|
||||
vec_ids_to_delete = [chunk.vec_doc_id for chunk in chunks]
|
||||
deleted_vec_ids = []
|
||||
failed_vec_ids = []
|
||||
|
||||
for vec_id in vec_ids_to_delete:
|
||||
try:
|
||||
await self.vec_db.delete(vec_id)
|
||||
deleted_vec_ids.append(vec_id)
|
||||
except Exception as e:
|
||||
logger.error(f"删除向量失败: {vec_id}, {e}")
|
||||
failed_vec_ids.append(vec_id)
|
||||
|
||||
# 如果向量删除失败过多(超过50%),中止操作
|
||||
if len(failed_vec_ids) > len(vec_ids_to_delete) * 0.5:
|
||||
logger.error(
|
||||
f"向量删除失败过多 ({len(failed_vec_ids)}/{len(vec_ids_to_delete)}), 中止文档删除"
|
||||
)
|
||||
return False
|
||||
|
||||
# 记录部分失败但继续执行
|
||||
if failed_vec_ids:
|
||||
logger.warning(
|
||||
f"部分向量删除失败 ({len(failed_vec_ids)}/{len(vec_ids_to_delete)}), 但继续执行删除操作"
|
||||
)
|
||||
|
||||
# ===== 第二阶段: 删除数据库记录(事务) =====
|
||||
async with self.db.get_db() as session:
|
||||
async with session.begin():
|
||||
# 删除块记录
|
||||
await session.execute(delete(KBChunk).where(KBChunk.doc_id == doc_id))
|
||||
|
||||
# 删除多媒体记录
|
||||
await session.execute(delete(KBMedia).where(KBMedia.doc_id == doc_id))
|
||||
|
||||
# 删除文档记录
|
||||
await session.execute(delete(KBDocument).where(KBDocument.doc_id == doc_id))
|
||||
|
||||
await session.commit()
|
||||
|
||||
# ===== 第三阶段: 删除文件(失败不影响) =====
|
||||
# 删除多媒体文件
|
||||
for media in media_list:
|
||||
try:
|
||||
media_path = Path(media.file_path)
|
||||
if media_path.exists():
|
||||
media_path.unlink()
|
||||
except Exception as e:
|
||||
logger.warning(f"删除多媒体文件失败: {media.file_path}, {e}")
|
||||
|
||||
# 删除文档文件
|
||||
try:
|
||||
file_path = Path(doc.file_path)
|
||||
if file_path.exists():
|
||||
file_path.unlink()
|
||||
except Exception as e:
|
||||
logger.warning(f"删除文档文件失败: {doc.file_path}, {e}")
|
||||
|
||||
# ===== 更新统计 =====
|
||||
await self.manager._update_kb_stats(doc.kb_id)
|
||||
|
||||
return True
|
||||
|
||||
# ===== 块操作 =====
|
||||
|
||||
async def list_chunks(self, doc_id: str) -> list[KBChunk]:
|
||||
"""列出文档的所有块"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = (
|
||||
select(KBChunk)
|
||||
.where(KBChunk.doc_id == doc_id)
|
||||
.order_by(KBChunk.chunk_index)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def delete_chunk(self, chunk_id: str) -> bool:
|
||||
"""删除单个块
|
||||
|
||||
流程:
|
||||
1. 查询块信息
|
||||
2. 删除向量
|
||||
3. 删除数据库记录
|
||||
4. 更新文档统计
|
||||
"""
|
||||
from astrbot.core import logger
|
||||
|
||||
# 1. 查询块信息
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBChunk).where(KBChunk.chunk_id == chunk_id)
|
||||
result = await session.execute(stmt)
|
||||
chunk = result.scalar_one_or_none()
|
||||
if not chunk:
|
||||
return False
|
||||
|
||||
doc_id = chunk.doc_id
|
||||
vec_doc_id = chunk.vec_doc_id
|
||||
|
||||
# 2. 删除向量
|
||||
try:
|
||||
await self.vec_db.delete(vec_doc_id)
|
||||
except Exception as e:
|
||||
logger.error(f"删除向量失败: {vec_doc_id}, {e}")
|
||||
return False
|
||||
|
||||
# 3. 删除数据库记录
|
||||
async with self.db.get_db() as session:
|
||||
async with session.begin():
|
||||
await session.execute(delete(KBChunk).where(KBChunk.chunk_id == chunk_id))
|
||||
await session.commit()
|
||||
|
||||
# 4. 更新文档统计
|
||||
await self._update_doc_stats(doc_id)
|
||||
|
||||
return True
|
||||
|
||||
# ===== 多媒体操作 =====
|
||||
|
||||
async def list_media(self, doc_id: str) -> list[KBMedia]:
|
||||
"""列出文档的所有多媒体资源"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBMedia).where(KBMedia.doc_id == doc_id)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def delete_media(self, media_id: str) -> bool:
|
||||
"""删除多媒体资源
|
||||
|
||||
流程:
|
||||
1. 查询媒体信息
|
||||
2. 删除数据库记录
|
||||
3. 删除文件(失败不影响)
|
||||
4. 更新文档统计
|
||||
"""
|
||||
from astrbot.core import logger
|
||||
|
||||
# 1. 查询媒体信息
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBMedia).where(KBMedia.media_id == media_id)
|
||||
result = await session.execute(stmt)
|
||||
media = result.scalar_one_or_none()
|
||||
if not media:
|
||||
return False
|
||||
|
||||
doc_id = media.doc_id
|
||||
file_path_str = media.file_path
|
||||
|
||||
# 2. 删除数据库记录
|
||||
async with self.db.get_db() as session:
|
||||
async with session.begin():
|
||||
await session.execute(delete(KBMedia).where(KBMedia.media_id == media_id))
|
||||
await session.commit()
|
||||
|
||||
# 3. 删除文件(失败不影响)
|
||||
try:
|
||||
media_path = Path(file_path_str)
|
||||
if media_path.exists():
|
||||
media_path.unlink()
|
||||
except Exception as e:
|
||||
logger.warning(f"删除多媒体文件失败: {file_path_str}, {e}")
|
||||
|
||||
# 4. 更新文档统计
|
||||
await self._update_doc_stats(doc_id)
|
||||
|
||||
return True
|
||||
|
||||
# ===== 内部辅助方法 =====
|
||||
|
||||
async def _save_media(
|
||||
self,
|
||||
kb_id: str,
|
||||
doc_id: str,
|
||||
media_type: str,
|
||||
file_name: str,
|
||||
content: bytes,
|
||||
mime_type: str,
|
||||
) -> KBMedia:
|
||||
"""保存多媒体资源"""
|
||||
media_id = str(uuid.uuid4())
|
||||
ext = Path(file_name).suffix
|
||||
|
||||
# 保存文件
|
||||
file_path = self.media_path / kb_id / doc_id / f"{media_id}{ext}"
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async with aiofiles.open(file_path, "wb") as f:
|
||||
await f.write(content)
|
||||
|
||||
# 创建记录
|
||||
media = KBMedia(
|
||||
media_id=media_id,
|
||||
doc_id=doc_id,
|
||||
kb_id=kb_id,
|
||||
media_type=media_type,
|
||||
file_name=file_name,
|
||||
file_path=str(file_path),
|
||||
file_size=len(content),
|
||||
mime_type=mime_type,
|
||||
)
|
||||
|
||||
return media
|
||||
|
||||
async def _update_doc_stats(self, doc_id: str):
|
||||
"""更新文档统计信息(事务中执行)"""
|
||||
async with self.db.get_db() as session:
|
||||
async with session.begin():
|
||||
# 统计块数
|
||||
chunk_count = (
|
||||
await session.scalar(
|
||||
select(func.count(KBChunk.id)).where(KBChunk.doc_id == doc_id)
|
||||
)
|
||||
) or 0
|
||||
|
||||
# 统计多媒体数
|
||||
media_count = (
|
||||
await session.scalar(
|
||||
select(func.count(KBMedia.id)).where(KBMedia.doc_id == doc_id)
|
||||
)
|
||||
) or 0
|
||||
|
||||
# 更新文档
|
||||
doc = await session.scalar(
|
||||
select(KBDocument).where(KBDocument.doc_id == doc_id)
|
||||
)
|
||||
if doc:
|
||||
doc.chunk_count = chunk_count
|
||||
doc.media_count = media_count
|
||||
|
||||
await session.commit()
|
||||
@@ -0,0 +1,184 @@
|
||||
"""知识库管理功能的数据模型定义
|
||||
|
||||
该模块定义了知识库系统所需的数据模型,包括:
|
||||
- KnowledgeBase: 知识库表 (存储在独立的 kb.db)
|
||||
- KBDocument: 文档表 (存储在独立的 kb.db)
|
||||
- KBChunk: 文档块表 (存储在独立的 kb.db)
|
||||
- KBMedia: 多媒体资源表 (存储在独立的 kb.db)
|
||||
- KBSessionConfig: 会话配置表 (存储在独立的 kb.db)
|
||||
|
||||
注意:
|
||||
- 所有模型存储在独立的知识库数据库 (data/knowledge_base/kb.db)
|
||||
- 与主数据库 (astrbot.db) 完全解耦
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from sqlmodel import Field, SQLModel, Text, UniqueConstraint
|
||||
|
||||
|
||||
class KnowledgeBase(SQLModel, table=True):
|
||||
"""知识库表
|
||||
|
||||
存储知识库的基本信息和统计数据。
|
||||
"""
|
||||
|
||||
__tablename__ = "knowledge_bases"
|
||||
|
||||
id: int | None = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
|
||||
)
|
||||
kb_id: str = Field(
|
||||
max_length=36,
|
||||
nullable=False,
|
||||
unique=True,
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
index=True,
|
||||
)
|
||||
kb_name: str = Field(max_length=100, nullable=False)
|
||||
description: Optional[str] = Field(default=None, sa_type=Text)
|
||||
emoji: Optional[str] = Field(default="📚", max_length=10)
|
||||
embedding_provider_id: Optional[str] = Field(default=None, max_length=100)
|
||||
rerank_provider_id: Optional[str] = Field(default=None, max_length=100)
|
||||
# 分块配置参数
|
||||
chunk_size: Optional[int] = Field(default=512, nullable=True)
|
||||
chunk_overlap: Optional[int] = Field(default=50, nullable=True)
|
||||
# 检索配置参数
|
||||
top_k_dense: Optional[int] = Field(default=50, nullable=True)
|
||||
top_k_sparse: Optional[int] = Field(default=50, nullable=True)
|
||||
top_m_final: Optional[int] = Field(default=5, nullable=True)
|
||||
enable_rerank: Optional[bool] = Field(default=True, nullable=True)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
doc_count: int = Field(default=0, nullable=False)
|
||||
chunk_count: int = Field(default=0, nullable=False)
|
||||
|
||||
|
||||
class KBDocument(SQLModel, table=True):
|
||||
"""文档表
|
||||
|
||||
存储上传到知识库的文档元数据。
|
||||
"""
|
||||
|
||||
__tablename__ = "kb_documents"
|
||||
|
||||
id: int | None = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
|
||||
)
|
||||
doc_id: str = Field(
|
||||
max_length=36,
|
||||
nullable=False,
|
||||
unique=True,
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
index=True,
|
||||
)
|
||||
kb_id: str = Field(max_length=36, nullable=False, index=True)
|
||||
doc_name: str = Field(max_length=255, nullable=False)
|
||||
file_type: str = Field(max_length=20, nullable=False)
|
||||
file_size: int = Field(nullable=False)
|
||||
file_path: str = Field(max_length=512, nullable=False)
|
||||
chunk_count: int = Field(default=0, nullable=False)
|
||||
media_count: int = Field(default=0, nullable=False)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
|
||||
class KBChunk(SQLModel, table=True):
|
||||
"""文档块表
|
||||
|
||||
存储文档分块后的文本内容和向量索引关联信息。
|
||||
"""
|
||||
|
||||
__tablename__ = "kb_chunks"
|
||||
|
||||
id: int | None = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
|
||||
)
|
||||
chunk_id: str = Field(
|
||||
max_length=36,
|
||||
nullable=False,
|
||||
unique=True,
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
index=True,
|
||||
)
|
||||
doc_id: str = Field(max_length=36, nullable=False, index=True)
|
||||
kb_id: str = Field(max_length=36, nullable=False, index=True)
|
||||
chunk_index: int = Field(nullable=False)
|
||||
content: str = Field(sa_type=Text, nullable=False)
|
||||
char_count: int = Field(nullable=False)
|
||||
vec_doc_id: str = Field(max_length=100, nullable=False, index=True)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
class KBMedia(SQLModel, table=True):
|
||||
"""多媒体资源表
|
||||
|
||||
存储从文档中提取的图片、视频等多媒体资源。
|
||||
"""
|
||||
|
||||
__tablename__ = "kb_media"
|
||||
|
||||
id: int | None = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
|
||||
)
|
||||
media_id: str = Field(
|
||||
max_length=36,
|
||||
nullable=False,
|
||||
unique=True,
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
index=True,
|
||||
)
|
||||
doc_id: str = Field(max_length=36, nullable=False, index=True)
|
||||
kb_id: str = Field(max_length=36, nullable=False, index=True)
|
||||
media_type: str = Field(max_length=20, nullable=False)
|
||||
file_name: str = Field(max_length=255, nullable=False)
|
||||
file_path: str = Field(max_length=512, nullable=False)
|
||||
file_size: int = Field(nullable=False)
|
||||
mime_type: str = Field(max_length=100, nullable=False)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
class KBSessionConfig(SQLModel, table=True):
|
||||
"""会话知识库配置表
|
||||
|
||||
存储会话或平台级别的知识库关联配置。
|
||||
该表存储在知识库独立数据库中,保持完全解耦。
|
||||
|
||||
支持两种配置范围:
|
||||
- platform: 平台级别配置 (如 'qq', 'telegram')
|
||||
- session: 会话级别配置 (如 'qq:group:12345')
|
||||
"""
|
||||
|
||||
__tablename__ = "kb_session_config"
|
||||
|
||||
id: int | None = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
|
||||
)
|
||||
config_id: str = Field(
|
||||
max_length=36,
|
||||
nullable=False,
|
||||
unique=True,
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
)
|
||||
scope: str = Field(max_length=20, nullable=False)
|
||||
scope_id: str = Field(max_length=255, nullable=False, index=True)
|
||||
kb_ids: str = Field(sa_type=Text, nullable=False)
|
||||
top_k: Optional[int] = Field(default=None, nullable=True)
|
||||
enable_rerank: Optional[bool] = Field(default=None, nullable=True)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("scope", "scope_id", name="uix_scope_scope_id"),
|
||||
)
|
||||
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
文档解析器模块
|
||||
"""
|
||||
|
||||
from .base import BaseParser, MediaItem, ParseResult
|
||||
from .text_parser import TextParser
|
||||
from .pdf_parser import PDFParser
|
||||
|
||||
__all__ = [
|
||||
"BaseParser",
|
||||
"MediaItem",
|
||||
"ParseResult",
|
||||
"TextParser",
|
||||
"PDFParser",
|
||||
]
|
||||
@@ -0,0 +1,50 @@
|
||||
"""文档解析器基类和数据结构
|
||||
|
||||
定义了文档解析器的抽象接口和相关数据类。
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class MediaItem:
|
||||
"""多媒体项
|
||||
|
||||
表示从文档中提取的多媒体资源。
|
||||
"""
|
||||
|
||||
media_type: str # image, video
|
||||
file_name: str
|
||||
content: bytes
|
||||
mime_type: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParseResult:
|
||||
"""解析结果
|
||||
|
||||
包含解析后的文本内容和提取的多媒体资源。
|
||||
"""
|
||||
|
||||
text: str
|
||||
media: list[MediaItem]
|
||||
|
||||
|
||||
class BaseParser(ABC):
|
||||
"""文档解析器基类
|
||||
|
||||
所有文档解析器都应该继承此类并实现 parse 方法。
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def parse(self, file_content: bytes, file_name: str) -> ParseResult:
|
||||
"""解析文档
|
||||
|
||||
Args:
|
||||
file_content: 文件内容
|
||||
file_name: 文件名
|
||||
|
||||
Returns:
|
||||
ParseResult: 解析结果
|
||||
"""
|
||||
@@ -0,0 +1,100 @@
|
||||
"""PDF 文件解析器
|
||||
|
||||
支持解析 PDF 文件中的文本和图片资源。
|
||||
"""
|
||||
|
||||
import io
|
||||
|
||||
from pypdf import PdfReader
|
||||
|
||||
from astrbot.core.knowledge_base.parsers.base import (
|
||||
BaseParser,
|
||||
MediaItem,
|
||||
ParseResult,
|
||||
)
|
||||
|
||||
|
||||
class PDFParser(BaseParser):
|
||||
"""PDF 文档解析器
|
||||
|
||||
提取 PDF 中的文本内容和嵌入的图片资源。
|
||||
"""
|
||||
|
||||
async def parse(self, file_content: bytes, file_name: str) -> ParseResult:
|
||||
"""解析 PDF 文件
|
||||
|
||||
Args:
|
||||
file_content: 文件内容
|
||||
file_name: 文件名
|
||||
|
||||
Returns:
|
||||
ParseResult: 包含文本和图片的解析结果
|
||||
"""
|
||||
pdf_file = io.BytesIO(file_content)
|
||||
reader = PdfReader(pdf_file)
|
||||
|
||||
text_parts = []
|
||||
media_items = []
|
||||
|
||||
# 提取文本
|
||||
for page in reader.pages:
|
||||
text = page.extract_text()
|
||||
if text:
|
||||
text_parts.append(text)
|
||||
|
||||
# 提取图片
|
||||
image_counter = 0
|
||||
for page_num, page in enumerate(reader.pages):
|
||||
try:
|
||||
# 安全检查 Resources
|
||||
if "/Resources" not in page:
|
||||
continue
|
||||
|
||||
resources = page["/Resources"]
|
||||
if not resources or "/XObject" not in resources:
|
||||
continue
|
||||
|
||||
xobjects = resources["/XObject"].get_object()
|
||||
if not xobjects:
|
||||
continue
|
||||
|
||||
for obj_name in xobjects:
|
||||
try:
|
||||
obj = xobjects[obj_name]
|
||||
|
||||
if obj.get("/Subtype") != "/Image":
|
||||
continue
|
||||
|
||||
# 提取图片数据
|
||||
image_data = obj.get_data()
|
||||
|
||||
# 确定格式
|
||||
filter_type = obj.get("/Filter", "")
|
||||
if filter_type == "/DCTDecode":
|
||||
ext = "jpg"
|
||||
mime_type = "image/jpeg"
|
||||
elif filter_type == "/FlateDecode":
|
||||
ext = "png"
|
||||
mime_type = "image/png"
|
||||
else:
|
||||
ext = "png"
|
||||
mime_type = "image/png"
|
||||
|
||||
image_counter += 1
|
||||
media_items.append(
|
||||
MediaItem(
|
||||
media_type="image",
|
||||
file_name=f"page_{page_num}_img_{image_counter}.{ext}",
|
||||
content=image_data,
|
||||
mime_type=mime_type,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
# 单个图片提取失败不影响整体
|
||||
continue
|
||||
except Exception:
|
||||
# 页面处理失败不影响其他页面
|
||||
continue
|
||||
|
||||
full_text = "\n\n".join(text_parts)
|
||||
return ParseResult(text=full_text, media=media_items)
|
||||
@@ -0,0 +1,41 @@
|
||||
"""文本文件解析器
|
||||
|
||||
支持解析 TXT 和 Markdown 文件。
|
||||
"""
|
||||
|
||||
from astrbot.core.knowledge_base.parsers.base import BaseParser, ParseResult
|
||||
|
||||
|
||||
class TextParser(BaseParser):
|
||||
"""TXT/MD 文本解析器
|
||||
|
||||
支持多种字符编码的自动检测。
|
||||
"""
|
||||
|
||||
async def parse(self, file_content: bytes, file_name: str) -> ParseResult:
|
||||
"""解析文本文件
|
||||
|
||||
尝试使用多种编码解析文件内容。
|
||||
|
||||
Args:
|
||||
file_content: 文件内容
|
||||
file_name: 文件名
|
||||
|
||||
Returns:
|
||||
ParseResult: 解析结果,不包含多媒体资源
|
||||
|
||||
Raises:
|
||||
ValueError: 如果无法解码文件
|
||||
"""
|
||||
# 尝试多种编码
|
||||
for encoding in ["utf-8", "gbk", "gb2312", "gb18030"]:
|
||||
try:
|
||||
text = file_content.decode(encoding)
|
||||
break
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f"无法解码文件: {file_name}")
|
||||
|
||||
# 文本文件无多媒体资源
|
||||
return ParseResult(text=text, media=[])
|
||||
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
检索模块
|
||||
"""
|
||||
|
||||
from .manager import RetrievalManager, RetrievalResult
|
||||
from .sparse_retriever import SparseRetriever, SparseResult
|
||||
from .rank_fusion import RankFusion, FusedResult
|
||||
|
||||
__all__ = [
|
||||
"RetrievalManager",
|
||||
"RetrievalResult",
|
||||
"SparseRetriever",
|
||||
"SparseResult",
|
||||
"RankFusion",
|
||||
"FusedResult",
|
||||
]
|
||||
@@ -0,0 +1,224 @@
|
||||
"""检索管理器
|
||||
|
||||
协调稠密检索、稀疏检索和 Rerank,提供统一的检索接口
|
||||
"""
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
from astrbot.core.db.vec_db.base import BaseVecDB
|
||||
from astrbot.core.knowledge_base.database import KBDatabase
|
||||
from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion
|
||||
from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseRetriever
|
||||
from astrbot.core.provider.provider import RerankProvider
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievalResult:
|
||||
"""检索结果"""
|
||||
|
||||
chunk_id: str
|
||||
doc_id: str
|
||||
doc_name: str
|
||||
kb_id: str
|
||||
kb_name: str
|
||||
content: str
|
||||
score: float
|
||||
metadata: dict
|
||||
|
||||
|
||||
class RetrievalManager:
|
||||
"""检索管理器
|
||||
|
||||
职责:
|
||||
- 协调稠密检索、稀疏检索和 Rerank
|
||||
- 结果融合和排序
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vec_db: BaseVecDB,
|
||||
sparse_retriever: SparseRetriever,
|
||||
rank_fusion: RankFusion,
|
||||
kb_db: KBDatabase,
|
||||
rerank_provider: Optional[RerankProvider] = None,
|
||||
):
|
||||
"""初始化检索管理器
|
||||
|
||||
Args:
|
||||
vec_db: 向量数据库实例
|
||||
sparse_retriever: 稀疏检索器
|
||||
rank_fusion: 结果融合器
|
||||
kb_db: 知识库数据库实例
|
||||
rerank_provider: Rerank 提供商 (可选)
|
||||
"""
|
||||
self.vec_db = vec_db
|
||||
self.sparse_retriever = sparse_retriever
|
||||
self.rank_fusion = rank_fusion
|
||||
self.kb_db = kb_db
|
||||
self.rerank_provider = rerank_provider
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
query: str,
|
||||
kb_ids: List[str],
|
||||
top_k_dense: int = 50,
|
||||
top_k_sparse: int = 50,
|
||||
top_n_fusion: int = 20,
|
||||
top_m_final: int = 5,
|
||||
enable_rerank: bool = True,
|
||||
) -> List[RetrievalResult]:
|
||||
"""混合检索
|
||||
|
||||
流程:
|
||||
1. 稠密检索 (向量相似度)
|
||||
2. 稀疏检索 (BM25)
|
||||
3. 结果融合 (RRF)
|
||||
4. Rerank 重排序
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
kb_ids: 知识库 ID 列表
|
||||
top_k_dense: 稠密检索返回数量
|
||||
top_k_sparse: 稀疏检索返回数量
|
||||
top_n_fusion: 融合后返回数量
|
||||
top_m_final: 最终返回数量
|
||||
enable_rerank: 是否启用 Rerank
|
||||
|
||||
Returns:
|
||||
List[RetrievalResult]: 检索结果列表
|
||||
"""
|
||||
# 1. 稠密检索
|
||||
dense_results = await self._dense_retrieve(
|
||||
query=query,
|
||||
kb_ids=kb_ids,
|
||||
top_k=top_k_dense,
|
||||
)
|
||||
|
||||
# 2. 稀疏检索
|
||||
sparse_results = await self.sparse_retriever.retrieve(
|
||||
query=query,
|
||||
kb_ids=kb_ids,
|
||||
top_k=top_k_sparse,
|
||||
)
|
||||
|
||||
# 3. 结果融合
|
||||
fused_results = await self.rank_fusion.fuse(
|
||||
dense_results=dense_results,
|
||||
sparse_results=sparse_results,
|
||||
top_k=top_n_fusion,
|
||||
)
|
||||
|
||||
# 4. 转换为 RetrievalResult (获取元数据)
|
||||
retrieval_results = []
|
||||
for fr in fused_results:
|
||||
metadata_dict = await self.kb_db.get_chunk_with_metadata(fr.chunk_id)
|
||||
if metadata_dict:
|
||||
retrieval_results.append(
|
||||
RetrievalResult(
|
||||
chunk_id=fr.chunk_id,
|
||||
doc_id=fr.doc_id,
|
||||
doc_name=metadata_dict["document"].doc_name,
|
||||
kb_id=fr.kb_id,
|
||||
kb_name=metadata_dict["knowledge_base"].kb_name,
|
||||
content=fr.content,
|
||||
score=fr.score,
|
||||
metadata={
|
||||
"chunk_index": metadata_dict["chunk"].chunk_index,
|
||||
"char_count": metadata_dict["chunk"].char_count,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# 5. Rerank (可选)
|
||||
if enable_rerank and self.rerank_provider and retrieval_results:
|
||||
retrieval_results = await self._rerank(
|
||||
query=query,
|
||||
results=retrieval_results,
|
||||
top_k=top_m_final,
|
||||
)
|
||||
else:
|
||||
retrieval_results = retrieval_results[:top_m_final]
|
||||
|
||||
return retrieval_results
|
||||
|
||||
async def _dense_retrieve(
|
||||
self,
|
||||
query: str,
|
||||
kb_ids: List[str],
|
||||
top_k: int,
|
||||
):
|
||||
"""稠密检索 (向量相似度)
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
kb_ids: 知识库 ID 列表
|
||||
top_k: 返回结果数量
|
||||
|
||||
Returns:
|
||||
List[Result]: 检索结果列表
|
||||
"""
|
||||
# 直接调用向量数据库检索
|
||||
vec_results = await self.vec_db.retrieve(
|
||||
query=query,
|
||||
k=top_k * len(kb_ids) * 2, # 增加候选数量以便过滤
|
||||
)
|
||||
|
||||
# 过滤:只保留指定知识库的结果
|
||||
filtered_results = []
|
||||
for result in vec_results:
|
||||
metadata_str = result.data.get("metadata", "{}")
|
||||
try:
|
||||
metadata = json.loads(metadata_str)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
metadata = {}
|
||||
|
||||
if metadata.get("kb_id") in kb_ids:
|
||||
filtered_results.append(result)
|
||||
|
||||
if len(filtered_results) >= top_k:
|
||||
break
|
||||
|
||||
return filtered_results[:top_k]
|
||||
|
||||
async def _rerank(
|
||||
self,
|
||||
query: str,
|
||||
results: List[RetrievalResult],
|
||||
top_k: int,
|
||||
) -> List[RetrievalResult]:
|
||||
"""Rerank 重排序
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
results: 检索结果列表
|
||||
top_k: 返回结果数量
|
||||
|
||||
Returns:
|
||||
List[RetrievalResult]: 重排序后的结果列表
|
||||
"""
|
||||
if not results:
|
||||
return []
|
||||
|
||||
# 准备文档列表
|
||||
docs = [r.content for r in results]
|
||||
|
||||
# 调用 Rerank Provider
|
||||
rerank_results = await self.rerank_provider.rerank(
|
||||
query=query,
|
||||
documents=docs,
|
||||
)
|
||||
|
||||
# 更新分数并重新排序
|
||||
reranked_list = []
|
||||
for rerank_result in rerank_results:
|
||||
idx = rerank_result.index
|
||||
if idx < len(results):
|
||||
result = results[idx]
|
||||
result.score = rerank_result.relevance_score
|
||||
reranked_list.append(result)
|
||||
|
||||
reranked_list.sort(key=lambda x: x.score, reverse=True)
|
||||
|
||||
return reranked_list[:top_k]
|
||||
@@ -0,0 +1,134 @@
|
||||
"""检索结果融合器
|
||||
|
||||
使用 Reciprocal Rank Fusion (RRF) 算法融合稠密检索和稀疏检索的结果
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List
|
||||
|
||||
from astrbot.core.db.vec_db.base import Result
|
||||
from astrbot.core.knowledge_base.database import KBDatabase
|
||||
from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseResult
|
||||
|
||||
|
||||
@dataclass
|
||||
class FusedResult:
|
||||
"""融合后的检索结果"""
|
||||
|
||||
chunk_id: str
|
||||
doc_id: str
|
||||
kb_id: str
|
||||
content: str
|
||||
score: float
|
||||
|
||||
|
||||
class RankFusion:
|
||||
"""检索结果融合器
|
||||
|
||||
职责:
|
||||
- 融合稠密检索和稀疏检索的结果
|
||||
- 使用 Reciprocal Rank Fusion (RRF) 算法
|
||||
"""
|
||||
|
||||
def __init__(self, kb_db: KBDatabase, k: int = 60):
|
||||
"""初始化结果融合器
|
||||
|
||||
Args:
|
||||
kb_db: 知识库数据库实例
|
||||
k: RRF 参数,用于平滑排名
|
||||
"""
|
||||
self.kb_db = kb_db
|
||||
self.k = k
|
||||
|
||||
async def fuse(
|
||||
self,
|
||||
dense_results: List[Result],
|
||||
sparse_results: List[SparseResult],
|
||||
top_k: int = 20,
|
||||
) -> List[FusedResult]:
|
||||
"""融合稠密和稀疏检索结果
|
||||
|
||||
RRF 公式:
|
||||
score(doc) = sum(1 / (k + rank_i))
|
||||
|
||||
Args:
|
||||
dense_results: 稠密检索结果
|
||||
sparse_results: 稀疏检索结果
|
||||
top_k: 返回结果数量
|
||||
|
||||
Returns:
|
||||
List[FusedResult]: 融合后的结果列表
|
||||
"""
|
||||
# 1. 构建排名映射
|
||||
dense_ranks = {r.data["doc_id"]: (idx + 1) for idx, r in enumerate(dense_results)}
|
||||
sparse_ranks = {r.chunk_id: (idx + 1) for idx, r in enumerate(sparse_results)}
|
||||
|
||||
# 2. 收集所有唯一的 ID (来自稠密检索的是 vec_doc_id, 稀疏检索的是 chunk_id)
|
||||
# 需要统一为 chunk_id
|
||||
all_chunk_ids = set()
|
||||
vec_doc_id_to_dense = {} # vec_doc_id -> Result
|
||||
chunk_id_to_sparse = {} # chunk_id -> SparseResult
|
||||
|
||||
# 处理稀疏检索结果
|
||||
for r in sparse_results:
|
||||
all_chunk_ids.add(r.chunk_id)
|
||||
chunk_id_to_sparse[r.chunk_id] = r
|
||||
|
||||
# 处理稠密检索结果 (需要转换 vec_doc_id 到 chunk_id)
|
||||
for r in dense_results:
|
||||
vec_doc_id = r.data["doc_id"]
|
||||
all_chunk_ids.add(vec_doc_id)
|
||||
vec_doc_id_to_dense[vec_doc_id] = r
|
||||
|
||||
# 3. 计算 RRF 分数
|
||||
rrf_scores: Dict[str, float] = {}
|
||||
|
||||
for identifier in all_chunk_ids:
|
||||
score = 0.0
|
||||
|
||||
# 来自稠密检索的贡献
|
||||
if identifier in dense_ranks:
|
||||
score += 1.0 / (self.k + dense_ranks[identifier])
|
||||
|
||||
# 来自稀疏检索的贡献
|
||||
if identifier in sparse_ranks:
|
||||
score += 1.0 / (self.k + sparse_ranks[identifier])
|
||||
|
||||
rrf_scores[identifier] = score
|
||||
|
||||
# 4. 排序
|
||||
sorted_ids = sorted(
|
||||
rrf_scores.keys(), key=lambda cid: rrf_scores[cid], reverse=True
|
||||
)[:top_k]
|
||||
|
||||
# 5. 构建融合结果
|
||||
fused_results = []
|
||||
for identifier in sorted_ids:
|
||||
# 优先从稀疏检索获取完整信息
|
||||
if identifier in chunk_id_to_sparse:
|
||||
sr = chunk_id_to_sparse[identifier]
|
||||
fused_results.append(
|
||||
FusedResult(
|
||||
chunk_id=sr.chunk_id,
|
||||
doc_id=sr.doc_id,
|
||||
kb_id=sr.kb_id,
|
||||
content=sr.content,
|
||||
score=rrf_scores[identifier],
|
||||
)
|
||||
)
|
||||
elif identifier in vec_doc_id_to_dense:
|
||||
# 从向量检索获取信息,需要从数据库获取块的详细信息
|
||||
dr = vec_doc_id_to_dense[identifier]
|
||||
chunk = await self.kb_db.get_chunk_by_vec_doc_id(identifier)
|
||||
if chunk:
|
||||
fused_results.append(
|
||||
FusedResult(
|
||||
chunk_id=chunk.chunk_id,
|
||||
doc_id=chunk.doc_id,
|
||||
kb_id=chunk.kb_id,
|
||||
content=chunk.content,
|
||||
score=rrf_scores[identifier],
|
||||
)
|
||||
)
|
||||
|
||||
return fused_results
|
||||
@@ -0,0 +1,90 @@
|
||||
"""稀疏检索器
|
||||
|
||||
使用 BM25 算法进行基于关键词的文档检索
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
from rank_bm25 import BM25Okapi
|
||||
|
||||
from astrbot.core.knowledge_base.database import KBDatabase
|
||||
|
||||
|
||||
@dataclass
|
||||
class SparseResult:
|
||||
"""稀疏检索结果"""
|
||||
|
||||
chunk_id: str
|
||||
doc_id: str
|
||||
kb_id: str
|
||||
content: str
|
||||
score: float
|
||||
|
||||
|
||||
class SparseRetriever:
|
||||
"""BM25 稀疏检索器
|
||||
|
||||
职责:
|
||||
- 基于关键词的文档检索
|
||||
- 使用 BM25 算法计算相关度
|
||||
"""
|
||||
|
||||
def __init__(self, kb_db: KBDatabase):
|
||||
"""初始化稀疏检索器
|
||||
|
||||
Args:
|
||||
kb_db: 知识库数据库实例
|
||||
"""
|
||||
self.kb_db = kb_db
|
||||
self._index_cache = {} # 缓存 BM25 索引
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
query: str,
|
||||
kb_ids: List[str],
|
||||
top_k: int = 50,
|
||||
) -> List[SparseResult]:
|
||||
"""执行稀疏检索
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
kb_ids: 知识库 ID 列表
|
||||
top_k: 返回结果数量
|
||||
|
||||
Returns:
|
||||
List[SparseResult]: 检索结果列表
|
||||
"""
|
||||
# 1. 获取所有相关块
|
||||
chunks = await self.kb_db.get_chunks_by_kb_ids(kb_ids)
|
||||
|
||||
if not chunks:
|
||||
return []
|
||||
|
||||
# 2. 准备文档和索引
|
||||
corpus = [chunk.content for chunk in chunks]
|
||||
tokenized_corpus = [doc.split() for doc in corpus]
|
||||
|
||||
# 3. 构建 BM25 索引
|
||||
bm25 = BM25Okapi(tokenized_corpus)
|
||||
|
||||
# 4. 执行检索
|
||||
tokenized_query = query.split()
|
||||
scores = bm25.get_scores(tokenized_query)
|
||||
|
||||
# 5. 排序并返回 Top-K
|
||||
results = []
|
||||
for idx, score in enumerate(scores):
|
||||
chunk = chunks[idx]
|
||||
results.append(
|
||||
SparseResult(
|
||||
chunk_id=chunk.chunk_id,
|
||||
doc_id=chunk.doc_id,
|
||||
kb_id=chunk.kb_id,
|
||||
content=chunk.content,
|
||||
score=float(score),
|
||||
)
|
||||
)
|
||||
|
||||
results.sort(key=lambda x: x.score, reverse=True)
|
||||
return results[:top_k]
|
||||
@@ -0,0 +1,157 @@
|
||||
"""会话知识库配置数据库操作
|
||||
|
||||
该模块封装会话知识库配置的数据库查询操作。
|
||||
|
||||
注意: 会话配置表 (kb_session_config) 存储在知识库独立数据库 (kb.db) 中,
|
||||
而不是主数据库 (astrbot.db) 中,以实现完全解耦。
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from astrbot.core.knowledge_base.kb_sqlite import KBSQLiteDatabase
|
||||
from astrbot.core.knowledge_base.models import KBSessionConfig
|
||||
|
||||
|
||||
class SessionConfigDB:
|
||||
"""会话知识库配置数据库操作类
|
||||
|
||||
职责:
|
||||
- 提供会话知识库配置管理
|
||||
- 统一异常处理
|
||||
|
||||
注意: 该类操作知识库独立数据库,实现完全解耦
|
||||
"""
|
||||
|
||||
def __init__(self, db: KBSQLiteDatabase):
|
||||
"""初始化会话配置数据库操作类
|
||||
|
||||
Args:
|
||||
db: 知识库独立数据库实例 (kb.db),不是主数据库
|
||||
"""
|
||||
self.db = db
|
||||
|
||||
async def get_session_kb_ids(self, session_id: str) -> list[str]:
|
||||
"""获取会话关联的知识库 ID 列表
|
||||
|
||||
查找顺序:
|
||||
1. 会话级别配置 (优先)
|
||||
2. 平台级别配置
|
||||
3. 返回空列表
|
||||
"""
|
||||
async with self.db.get_db() as session:
|
||||
# 1. 查找会话级别配置
|
||||
stmt = select(KBSessionConfig).where(
|
||||
KBSessionConfig.scope == "session",
|
||||
KBSessionConfig.scope_id == session_id,
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
config = result.scalar_one_or_none()
|
||||
|
||||
if config:
|
||||
return json.loads(config.kb_ids)
|
||||
|
||||
# 2. 提取平台 ID (格式: platform:xxx:session_id)
|
||||
parts = session_id.split(":")
|
||||
if len(parts) >= 2:
|
||||
platform_id = parts[0]
|
||||
|
||||
# 查找平台级别配置
|
||||
stmt = select(KBSessionConfig).where(
|
||||
KBSessionConfig.scope == "platform",
|
||||
KBSessionConfig.scope_id == platform_id,
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
config = result.scalar_one_or_none()
|
||||
|
||||
if config:
|
||||
return json.loads(config.kb_ids)
|
||||
|
||||
# 3. 无配置
|
||||
return []
|
||||
|
||||
async def set_session_kb_ids(
|
||||
self,
|
||||
scope: str,
|
||||
scope_id: str,
|
||||
kb_ids: list[str],
|
||||
top_k: Optional[int] = None,
|
||||
enable_rerank: Optional[bool] = None,
|
||||
) -> KBSessionConfig:
|
||||
"""设置会话知识库配置
|
||||
|
||||
Args:
|
||||
scope: 配置范围 (session/platform)
|
||||
scope_id: 范围标识 (会话 ID 或平台 ID)
|
||||
kb_ids: 知识库 ID 列表
|
||||
top_k: 返回结果数量 (可选)
|
||||
enable_rerank: 是否启用 Rerank (可选)
|
||||
"""
|
||||
async with self.db.get_db() as session:
|
||||
# 查找现有配置
|
||||
stmt = select(KBSessionConfig).where(
|
||||
KBSessionConfig.scope == scope,
|
||||
KBSessionConfig.scope_id == scope_id,
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
config = result.scalar_one_or_none()
|
||||
|
||||
if config:
|
||||
# 更新现有配置
|
||||
config.kb_ids = json.dumps(kb_ids)
|
||||
if top_k is not None:
|
||||
config.top_k = top_k
|
||||
if enable_rerank is not None:
|
||||
config.enable_rerank = enable_rerank
|
||||
else:
|
||||
# 创建新配置
|
||||
config = KBSessionConfig(
|
||||
scope=scope,
|
||||
scope_id=scope_id,
|
||||
kb_ids=json.dumps(kb_ids),
|
||||
top_k=top_k,
|
||||
enable_rerank=enable_rerank,
|
||||
)
|
||||
session.add(config)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(config)
|
||||
return config
|
||||
|
||||
async def delete_session_kb_config(self, scope: str, scope_id: str) -> bool:
|
||||
"""删除会话知识库配置"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBSessionConfig).where(
|
||||
KBSessionConfig.scope == scope,
|
||||
KBSessionConfig.scope_id == scope_id,
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
config = result.scalar_one_or_none()
|
||||
|
||||
if not config:
|
||||
return False
|
||||
|
||||
await session.delete(config)
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
async def list_all_session_configs(
|
||||
self, offset: int = 0, limit: int = 100, scope: Optional[str] = None
|
||||
) -> list[KBSessionConfig]:
|
||||
"""列出所有会话配置"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBSessionConfig)
|
||||
|
||||
if scope:
|
||||
stmt = stmt.where(KBSessionConfig.scope == scope)
|
||||
|
||||
stmt = (
|
||||
stmt.offset(offset)
|
||||
.limit(limit)
|
||||
.order_by(KBSessionConfig.created_at.desc())
|
||||
)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
Reference in New Issue
Block a user