stage
This commit is contained in:
@@ -113,7 +113,7 @@ class AstrBotCoreLifecycle:
|
||||
|
||||
# 初始化知识库管理器
|
||||
self.kb_manager = KnowledgeBaseManager(
|
||||
self.astrbot_config, self.db, self.provider_manager
|
||||
self.astrbot_config, self.provider_manager
|
||||
)
|
||||
|
||||
# 初始化提供给插件的上下文
|
||||
|
||||
@@ -16,14 +16,23 @@ class BaseVecDB:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
|
||||
async def insert(
|
||||
self, content: str, metadata: dict | None = None, id: str | None = None
|
||||
) -> int:
|
||||
"""
|
||||
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def retrieve(self, query: str, top_k: int = 5) -> list[Result]:
|
||||
async def retrieve(
|
||||
self,
|
||||
query: str,
|
||||
top_k: int = 5,
|
||||
fetch_k: int = 20,
|
||||
rerank: bool = False,
|
||||
metadata_filters: dict | None = None,
|
||||
) -> list[Result]:
|
||||
"""
|
||||
搜索最相似的文档。
|
||||
Args:
|
||||
@@ -44,3 +53,6 @@ class BaseVecDB:
|
||||
bool: 删除是否成功
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def close(self): ...
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
- 会话配置也存储在此数据库中,会话ID来源于主数据库
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import func, select
|
||||
@@ -17,7 +16,6 @@ from astrbot.core.knowledge_base.models import (
|
||||
KBChunk,
|
||||
KBDocument,
|
||||
KBMedia,
|
||||
KBSessionConfig,
|
||||
KnowledgeBase,
|
||||
)
|
||||
|
||||
@@ -183,165 +181,3 @@ class KBDatabase:
|
||||
stmt = select(KBMedia).where(KBMedia.media_id == media_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
# ===== 会话配置查询 =====
|
||||
|
||||
async def get_session_kb_ids(self, session_id: str) -> list[str]:
|
||||
"""获取会话关联的知识库 ID 列表
|
||||
|
||||
查找顺序:
|
||||
1. 会话级别配置 (优先)
|
||||
2. 平台级别配置
|
||||
3. 返回空列表
|
||||
|
||||
Args:
|
||||
session_id: 会话ID(来自主数据库)
|
||||
|
||||
Returns:
|
||||
知识库ID列表
|
||||
"""
|
||||
async with self.db.get_db() as session:
|
||||
# 1. 查找会话级别配置
|
||||
stmt = select(KBSessionConfig).where(
|
||||
KBSessionConfig.scope == "session",
|
||||
KBSessionConfig.scope_id == session_id,
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
config = result.scalar_one_or_none()
|
||||
|
||||
if config:
|
||||
return json.loads(config.kb_ids)
|
||||
|
||||
# 2. 提取平台 ID (格式: platform:xxx:session_id)
|
||||
parts = session_id.split(":")
|
||||
if len(parts) >= 2:
|
||||
platform_id = parts[0]
|
||||
|
||||
# 查找平台级别配置
|
||||
stmt = select(KBSessionConfig).where(
|
||||
KBSessionConfig.scope == "platform",
|
||||
KBSessionConfig.scope_id == platform_id,
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
config = result.scalar_one_or_none()
|
||||
|
||||
if config:
|
||||
return json.loads(config.kb_ids)
|
||||
|
||||
# 3. 无配置
|
||||
return []
|
||||
|
||||
async def set_session_kb_ids(
|
||||
self,
|
||||
scope: str,
|
||||
scope_id: str,
|
||||
kb_ids: list[str],
|
||||
top_k: Optional[int] = None,
|
||||
enable_rerank: Optional[bool] = None,
|
||||
) -> KBSessionConfig:
|
||||
"""设置会话知识库配置
|
||||
|
||||
Args:
|
||||
scope: 配置范围 (session/platform)
|
||||
scope_id: 范围标识 (会话 ID 或平台 ID,来自主数据库)
|
||||
kb_ids: 知识库 ID 列表
|
||||
top_k: 返回结果数量 (可选)
|
||||
enable_rerank: 是否启用 Rerank (可选)
|
||||
|
||||
Returns:
|
||||
配置对象
|
||||
"""
|
||||
async with self.db.get_db() as session:
|
||||
# 查找现有配置
|
||||
stmt = select(KBSessionConfig).where(
|
||||
KBSessionConfig.scope == scope,
|
||||
KBSessionConfig.scope_id == scope_id,
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
config = result.scalar_one_or_none()
|
||||
|
||||
if config:
|
||||
# 更新现有配置
|
||||
config.kb_ids = json.dumps(kb_ids)
|
||||
if top_k is not None:
|
||||
config.top_k = top_k
|
||||
if enable_rerank is not None:
|
||||
config.enable_rerank = enable_rerank
|
||||
else:
|
||||
# 创建新配置
|
||||
config = KBSessionConfig(
|
||||
scope=scope,
|
||||
scope_id=scope_id,
|
||||
kb_ids=json.dumps(kb_ids),
|
||||
top_k=top_k,
|
||||
enable_rerank=enable_rerank,
|
||||
)
|
||||
session.add(config)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(config)
|
||||
return config
|
||||
|
||||
async def delete_session_kb_config(self, scope: str, scope_id: str) -> bool:
|
||||
"""删除会话知识库配置
|
||||
|
||||
Args:
|
||||
scope: 配置范围 (session/platform)
|
||||
scope_id: 范围标识 (会话 ID 或平台 ID)
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBSessionConfig).where(
|
||||
KBSessionConfig.scope == scope,
|
||||
KBSessionConfig.scope_id == scope_id,
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
config = result.scalar_one_or_none()
|
||||
|
||||
if not config:
|
||||
return False
|
||||
|
||||
await session.delete(config)
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
async def delete_session_kb_config_by_session_id(self, session_id: str) -> bool:
|
||||
"""根据会话ID删除会话配置(用于主数据库会话删除时的级联清理)
|
||||
|
||||
Args:
|
||||
session_id: 会话ID(来自主数据库)
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
return await self.delete_session_kb_config("session", session_id)
|
||||
|
||||
async def list_all_session_configs(
|
||||
self, offset: int = 0, limit: int = 100, scope: Optional[str] = None
|
||||
) -> list[KBSessionConfig]:
|
||||
"""列出所有会话配置
|
||||
|
||||
Args:
|
||||
offset: 偏移量
|
||||
limit: 限制数量
|
||||
scope: 可选的范围过滤 (session/platform)
|
||||
|
||||
Returns:
|
||||
会话配置列表
|
||||
"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBSessionConfig)
|
||||
|
||||
if scope:
|
||||
stmt = stmt.where(KBSessionConfig.scope == scope)
|
||||
|
||||
stmt = (
|
||||
stmt.offset(offset)
|
||||
.limit(limit)
|
||||
.order_by(KBSessionConfig.created_at.desc())
|
||||
)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
@@ -10,6 +10,7 @@ from astrbot.core.knowledge_base.retrieval.manager import (
|
||||
RetrievalManager,
|
||||
RetrievalResult,
|
||||
)
|
||||
from .vec_db_factory import VecDBFactory
|
||||
|
||||
|
||||
class KnowledgeBaseInjector:
|
||||
@@ -24,6 +25,7 @@ class KnowledgeBaseInjector:
|
||||
def __init__(
|
||||
self,
|
||||
kb_db: KBDatabase,
|
||||
vec_db_factory: VecDBFactory,
|
||||
retrieval_manager: RetrievalManager,
|
||||
):
|
||||
"""初始化知识库上下文注入器
|
||||
@@ -33,18 +35,18 @@ class KnowledgeBaseInjector:
|
||||
retrieval_manager: 检索管理器实例
|
||||
"""
|
||||
self.kb_db = kb_db
|
||||
self.vec_db_factory = vec_db_factory
|
||||
self.retrieval_manager = retrieval_manager
|
||||
|
||||
async def retrieve_and_inject(
|
||||
self,
|
||||
unified_msg_origin: str,
|
||||
kb_ids: list[str],
|
||||
query: str,
|
||||
top_k: int = 5,
|
||||
) -> Optional[dict]:
|
||||
"""检索并注入知识库上下文
|
||||
|
||||
Args:
|
||||
unified_msg_origin: 统一消息来源 ID (会话 ID)
|
||||
query: 用户查询
|
||||
top_k: 返回结果数量
|
||||
|
||||
@@ -55,14 +57,9 @@ class KnowledgeBaseInjector:
|
||||
"results": List[dict], # 原始检索结果列表
|
||||
}
|
||||
"""
|
||||
# 1. 获取会话关联的知识库
|
||||
kb_ids = await self.kb_db.get_session_kb_ids(unified_msg_origin)
|
||||
|
||||
if not kb_ids:
|
||||
return None
|
||||
|
||||
# 2. 检索知识
|
||||
results = await self.retrieval_manager.retrieve(
|
||||
vec_db_factory=self.vec_db_factory,
|
||||
query=query,
|
||||
kb_ids=kb_ids,
|
||||
top_m_final=top_k,
|
||||
|
||||
@@ -9,8 +9,18 @@
|
||||
|
||||
from pathlib import Path
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
from .injector import KnowledgeBaseInjector
|
||||
from .retrieval.manager import RetrievalManager
|
||||
from .retrieval.sparse_retriever import SparseRetriever
|
||||
from .retrieval.rank_fusion import RankFusion
|
||||
from .kb_sqlite import KBSQLiteDatabase
|
||||
from .database import KBDatabase
|
||||
from .vec_db_factory import VecDBFactory
|
||||
from .manager import KBManager
|
||||
from .parsers.text_parser import TextParser
|
||||
from .parsers.pdf_parser import PDFParser
|
||||
from .chunking.fixed_size import FixedSizeChunker
|
||||
|
||||
|
||||
class KnowledgeBaseManager:
|
||||
@@ -28,32 +38,26 @@ class KnowledgeBaseManager:
|
||||
- 通过回调机制实现与主数据库的生命周期同步
|
||||
"""
|
||||
|
||||
kb_db: KBSQLiteDatabase
|
||||
vec_db_factory: VecDBFactory
|
||||
kb_database: KBDatabase
|
||||
kb_manager: KBManager
|
||||
retrieval_manager: RetrievalManager
|
||||
kb_injector: KnowledgeBaseInjector
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: dict,
|
||||
main_db: BaseDatabase,
|
||||
provider_manager: ProviderManager,
|
||||
):
|
||||
"""初始化知识库管理器
|
||||
|
||||
Args:
|
||||
config: 配置字典
|
||||
main_db: 主数据库实例 (不直接使用,仅用于类型引用)
|
||||
provider_manager: Provider 管理器
|
||||
"""
|
||||
self.config = config.get("knowledge_base", {})
|
||||
self.provider_manager = provider_manager
|
||||
|
||||
# 知识库独立数据库
|
||||
self.kb_db = None
|
||||
|
||||
# 组件实例
|
||||
self.kb_database = None
|
||||
self.kb_manager = None
|
||||
self.kb_vec_db = None
|
||||
self.retrieval_manager = None
|
||||
self.kb_injector = None
|
||||
|
||||
self._initialized = False
|
||||
self._session_deleted_callback_registered = False
|
||||
|
||||
@@ -66,31 +70,54 @@ class KnowledgeBaseManager:
|
||||
try:
|
||||
logger.info("正在初始化知识库模块...")
|
||||
|
||||
# 1. 检查并选择 Embedding Provider
|
||||
embedding_provider = self._select_embedding_provider()
|
||||
if not embedding_provider:
|
||||
logger.warning("未配置 Embedding Provider,知识库功能无法使用")
|
||||
return
|
||||
|
||||
# 2. 初始化数据库
|
||||
# 初始化数据库
|
||||
await self._init_kb_database()
|
||||
await self._init_database()
|
||||
|
||||
# 3. 初始化向量数据库
|
||||
await self._init_vector_db(embedding_provider)
|
||||
# 初始化向量数据库工厂
|
||||
await self._init_vector_db_factory()
|
||||
|
||||
# 4. 初始化解析器和分块器
|
||||
parsers = self._init_parsers()
|
||||
chunker = self._init_chunker()
|
||||
# 初始化解析器和分块器
|
||||
parsers = {
|
||||
"txt": TextParser(),
|
||||
"md": TextParser(),
|
||||
"markdown": TextParser(),
|
||||
"pdf": PDFParser(),
|
||||
}
|
||||
chunking_config = self.config.get("chunking", {})
|
||||
chunker = FixedSizeChunker(
|
||||
chunk_size=chunking_config.get("chunk_size", 512),
|
||||
chunk_overlap=chunking_config.get("chunk_overlap", 50),
|
||||
)
|
||||
|
||||
# 5. 初始化知识库管理器
|
||||
await self._init_kb_manager(parsers, chunker)
|
||||
# 初始化知识库管理器
|
||||
files_path = self.config.get("storage", {}).get(
|
||||
"files_path", "data/knowledge_base"
|
||||
)
|
||||
self.kb_manager = KBManager(
|
||||
db=self.kb_db,
|
||||
vec_db_factory=self.vec_db_factory,
|
||||
storage_path=files_path,
|
||||
parsers=parsers,
|
||||
chunker=chunker,
|
||||
provider_manager=self.provider_manager,
|
||||
)
|
||||
|
||||
# 6. 初始化检索管理器
|
||||
await self._init_retrieval_manager()
|
||||
# 初始化检索管理器
|
||||
sparse_retriever = SparseRetriever(self.kb_database)
|
||||
rank_fusion = RankFusion(self.kb_database)
|
||||
self.retrieval_manager = RetrievalManager(
|
||||
vec_db_factory=self.vec_db_factory,
|
||||
sparse_retriever=sparse_retriever,
|
||||
rank_fusion=rank_fusion,
|
||||
kb_db=self.kb_database,
|
||||
)
|
||||
|
||||
# 7. 初始化上下文注入器
|
||||
await self._init_injector()
|
||||
# 初始化上下文注入器
|
||||
self.kb_injector = KnowledgeBaseInjector(
|
||||
kb_db=self.kb_database,
|
||||
vec_db_factory=self.vec_db_factory,
|
||||
retrieval_manager=self.retrieval_manager,
|
||||
)
|
||||
|
||||
self._initialized = True
|
||||
logger.info("知识库模块初始化完成")
|
||||
@@ -106,8 +133,6 @@ class KnowledgeBaseManager:
|
||||
|
||||
async def _init_kb_database(self):
|
||||
"""初始化知识库独立数据库"""
|
||||
from astrbot.core.knowledge_base.kb_sqlite import KBSQLiteDatabase
|
||||
|
||||
db_path = self.config.get("storage", {}).get(
|
||||
"kb_db_path", "data/knowledge_base/kb.db"
|
||||
)
|
||||
@@ -116,168 +141,16 @@ class KnowledgeBaseManager:
|
||||
self.kb_db = KBSQLiteDatabase(db_path)
|
||||
await self.kb_db.initialize()
|
||||
await self.kb_db.migrate_to_v1()
|
||||
|
||||
logger.info(f"知识库独立数据库已初始化: {db_path}")
|
||||
|
||||
async def _init_database(self):
|
||||
"""初始化知识库数据库操作类"""
|
||||
from astrbot.core.knowledge_base.database import KBDatabase
|
||||
|
||||
self.kb_database = KBDatabase(self.kb_db)
|
||||
logger.info(f"KnowledgeBase database initialized: {db_path}")
|
||||
|
||||
async def _init_vector_db(self, embedding_provider):
|
||||
"""初始化向量数据库"""
|
||||
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
|
||||
|
||||
async def _init_vector_db_factory(self):
|
||||
"""初始化向量数据库工厂"""
|
||||
storage_path = self.config.get("storage", {}).get(
|
||||
"vector_db_path", "data/knowledge_base/vectors"
|
||||
)
|
||||
Path(storage_path).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.kb_vec_db = FaissVecDB(
|
||||
doc_store_path=f"{storage_path}/documents.db",
|
||||
index_store_path=f"{storage_path}/index.faiss",
|
||||
embedding_provider=embedding_provider,
|
||||
)
|
||||
await self.kb_vec_db.initialize()
|
||||
|
||||
def _init_parsers(self) -> dict:
|
||||
"""初始化文档解析器"""
|
||||
from astrbot.core.knowledge_base.parsers.text_parser import TextParser
|
||||
from astrbot.core.knowledge_base.parsers.pdf_parser import PDFParser
|
||||
|
||||
return {
|
||||
"txt": TextParser(),
|
||||
"md": TextParser(),
|
||||
"markdown": TextParser(),
|
||||
"pdf": PDFParser(),
|
||||
}
|
||||
|
||||
def _init_chunker(self):
|
||||
"""初始化分块器"""
|
||||
from astrbot.core.knowledge_base.chunking.fixed_size import FixedSizeChunker
|
||||
|
||||
chunking_config = self.config.get("chunking", {})
|
||||
return FixedSizeChunker(
|
||||
chunk_size=chunking_config.get("chunk_size", 512),
|
||||
chunk_overlap=chunking_config.get("chunk_overlap", 50),
|
||||
)
|
||||
|
||||
async def _init_kb_manager(self, parsers: dict, chunker):
|
||||
"""初始化知识库管理器"""
|
||||
from astrbot.core.knowledge_base.manager import KBManager
|
||||
|
||||
files_path = self.config.get("storage", {}).get(
|
||||
"files_path", "data/knowledge_base"
|
||||
)
|
||||
|
||||
self.kb_manager = KBManager(
|
||||
db=self.kb_db, # 使用独立的知识库数据库
|
||||
vec_db=self.kb_vec_db,
|
||||
storage_path=files_path,
|
||||
parsers=parsers,
|
||||
chunker=chunker,
|
||||
provider_manager=self.provider_manager,
|
||||
)
|
||||
|
||||
async def _init_retrieval_manager(self):
|
||||
"""初始化检索管理器"""
|
||||
from astrbot.core.knowledge_base.retrieval.manager import RetrievalManager
|
||||
from astrbot.core.knowledge_base.retrieval.sparse_retriever import (
|
||||
SparseRetriever,
|
||||
)
|
||||
from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion
|
||||
|
||||
sparse_retriever = SparseRetriever(self.kb_database)
|
||||
rank_fusion = RankFusion(self.kb_database)
|
||||
|
||||
# 选择 Rerank Provider (可选)
|
||||
rerank_provider = self._select_rerank_provider()
|
||||
|
||||
self.retrieval_manager = RetrievalManager(
|
||||
vec_db=self.kb_vec_db,
|
||||
sparse_retriever=sparse_retriever,
|
||||
rank_fusion=rank_fusion,
|
||||
kb_db=self.kb_database,
|
||||
rerank_provider=rerank_provider,
|
||||
)
|
||||
|
||||
async def _init_injector(self):
|
||||
"""初始化上下文注入器"""
|
||||
from astrbot.core.knowledge_base.injector import KnowledgeBaseInjector
|
||||
|
||||
self.kb_injector = KnowledgeBaseInjector(
|
||||
kb_db=self.kb_database,
|
||||
retrieval_manager=self.retrieval_manager,
|
||||
)
|
||||
|
||||
def _select_embedding_provider(self):
|
||||
"""选择 Embedding Provider
|
||||
|
||||
逻辑:
|
||||
- 如果配置了 embedding_provider_id,则使用指定的 provider
|
||||
- 如果没有配置,但有 embedding provider,则使用第一个
|
||||
- 如果有多个 embedding provider 但没有指定,则警告并使用第一个
|
||||
"""
|
||||
embedding_providers = self.provider_manager.embedding_provider_insts
|
||||
|
||||
if not embedding_providers:
|
||||
return None
|
||||
|
||||
configured_provider_id = self.config.get("embedding_provider_id")
|
||||
|
||||
if configured_provider_id:
|
||||
# 按 ID 查找
|
||||
for provider in embedding_providers:
|
||||
provider_id = provider.meta().id
|
||||
if provider_id == configured_provider_id:
|
||||
logger.info(f"知识库使用 Embedding Provider: {provider_id}")
|
||||
return provider
|
||||
logger.warning(
|
||||
f"未找到配置的 Embedding Provider ID: {configured_provider_id},"
|
||||
f"将使用第一个可用的"
|
||||
)
|
||||
|
||||
if len(embedding_providers) > 1 and not configured_provider_id:
|
||||
provider = embedding_providers[0]
|
||||
provider_id = provider.meta().id
|
||||
logger.info(
|
||||
f"检测到 {len(embedding_providers)} 个 Embedding Provider,"
|
||||
f"未在配置文件中指定 embedding_provider_id,将使用第一个: {provider_id}"
|
||||
)
|
||||
return provider
|
||||
|
||||
provider = embedding_providers[0]
|
||||
provider_id = provider.meta().id
|
||||
logger.info(f"知识库使用 Embedding Provider: {provider_id}")
|
||||
return provider
|
||||
|
||||
def _select_rerank_provider(self):
|
||||
"""选择 Rerank Provider (可选)"""
|
||||
if not self.config.get("retrieval", {}).get("enable_rerank", True):
|
||||
return None
|
||||
|
||||
rerank_providers = self.provider_manager.rerank_provider_insts
|
||||
if not rerank_providers:
|
||||
return None
|
||||
|
||||
configured_provider_id = self.config.get("rerank_provider_id")
|
||||
|
||||
if configured_provider_id:
|
||||
for provider in rerank_providers:
|
||||
provider_id = provider.meta().id
|
||||
if provider_id == configured_provider_id:
|
||||
logger.info(f"知识库使用 Rerank Provider: {provider_id}")
|
||||
return provider
|
||||
logger.warning(f"未找到配置的 Rerank Provider ID: {configured_provider_id}")
|
||||
|
||||
if len(rerank_providers) > 0:
|
||||
provider = rerank_providers[0]
|
||||
provider_id = provider.meta().id
|
||||
logger.info(f"知识库使用 Rerank Provider: {provider_id}")
|
||||
return provider
|
||||
|
||||
return None
|
||||
self.vec_db_factory = VecDBFactory(storage_base_path=storage_path)
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
@@ -292,31 +165,6 @@ class KnowledgeBaseManager:
|
||||
"""获取知识库上下文注入器"""
|
||||
return self.kb_injector if self._initialized else None
|
||||
|
||||
def register_session_lifecycle_hooks(self, conversation_manager):
|
||||
"""注册会话生命周期钩子
|
||||
|
||||
在会话删除时自动清理知识库配置,实现零侵入的级联清理。
|
||||
|
||||
Args:
|
||||
conversation_manager: 会话管理器实例
|
||||
"""
|
||||
if self._session_deleted_callback_registered or not self._initialized:
|
||||
return
|
||||
|
||||
async def on_session_deleted(session_id: str):
|
||||
"""会话删除回调:清理知识库配置"""
|
||||
try:
|
||||
await self.kb_database.delete_session_kb_config_by_session_id(
|
||||
session_id
|
||||
)
|
||||
logger.info(f"已清理会话知识库配置: {session_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"清理会话知识库配置失败 ({session_id}): {e}")
|
||||
|
||||
conversation_manager.register_on_session_deleted(on_session_deleted)
|
||||
self._session_deleted_callback_registered = True
|
||||
logger.info("已注册知识库会话删除回调")
|
||||
|
||||
async def reinitialize(self):
|
||||
"""重新初始化知识库模块
|
||||
|
||||
@@ -336,13 +184,13 @@ class KnowledgeBaseManager:
|
||||
|
||||
logger.info("正在终止知识库模块...")
|
||||
|
||||
# 关闭向量数据库连接
|
||||
if self.kb_vec_db:
|
||||
# 关闭向量数据库工厂(关闭所有向量数据库实例)
|
||||
if self.vec_db_factory:
|
||||
try:
|
||||
await self.kb_vec_db.close()
|
||||
logger.debug("向量数据库已关闭")
|
||||
await self.vec_db_factory.close_all()
|
||||
logger.debug("向量数据库工厂已关闭")
|
||||
except Exception as e:
|
||||
logger.warning(f"关闭向量数据库时出错: {e}")
|
||||
logger.warning(f"关闭向量数据库工厂时出错: {e}")
|
||||
|
||||
# 关闭知识库独立数据库连接
|
||||
if self.kb_db:
|
||||
@@ -352,13 +200,6 @@ class KnowledgeBaseManager:
|
||||
except Exception as e:
|
||||
logger.warning(f"关闭知识库数据库时出错: {e}")
|
||||
|
||||
# 清理资源
|
||||
self._initialized = False
|
||||
self.kb_db = None
|
||||
self.kb_database = None
|
||||
self.kb_manager = None
|
||||
self.kb_vec_db = None
|
||||
self.retrieval_manager = None
|
||||
self.kb_injector = None
|
||||
|
||||
logger.info("知识库模块已终止")
|
||||
|
||||
@@ -10,12 +10,11 @@ from typing import Optional
|
||||
import aiofiles
|
||||
from sqlalchemy import func, select, update
|
||||
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.vec_db.base import BaseVecDB
|
||||
from .kb_sqlite import KBSQLiteDatabase
|
||||
from astrbot.core.knowledge_base.chunking.base import BaseChunker
|
||||
from astrbot.core.knowledge_base.models import KBChunk, KBDocument, KnowledgeBase
|
||||
from astrbot.core.knowledge_base.parsers.base import BaseParser
|
||||
|
||||
from .vec_db_factory import VecDBFactory
|
||||
|
||||
class KBManager:
|
||||
"""知识库管理器
|
||||
@@ -29,15 +28,15 @@ class KBManager:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db: BaseDatabase,
|
||||
vec_db: BaseVecDB,
|
||||
db: KBSQLiteDatabase,
|
||||
vec_db_factory: VecDBFactory,
|
||||
storage_path: str,
|
||||
parsers: dict[str, BaseParser],
|
||||
chunker: BaseChunker,
|
||||
provider_manager=None,
|
||||
):
|
||||
self.db = db
|
||||
self.vec_db = vec_db
|
||||
self.vec_db_factory = vec_db_factory
|
||||
self.storage_path = Path(storage_path)
|
||||
self.media_path = self.storage_path / "media"
|
||||
self.files_path = self.storage_path / "files"
|
||||
@@ -49,6 +48,48 @@ class KBManager:
|
||||
self.media_path.mkdir(parents=True, exist_ok=True)
|
||||
self.files_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async def _get_embedding_provider_for_kb(self, kb_id: str):
|
||||
"""根据知识库配置获取 Embedding Provider
|
||||
|
||||
Args:
|
||||
kb_id: 知识库 ID
|
||||
|
||||
Returns:
|
||||
EmbeddingProvider: Embedding Provider 实例
|
||||
|
||||
Raises:
|
||||
ValueError: 如果找不到合适的 embedding provider
|
||||
"""
|
||||
from astrbot.core.knowledge_base.database import KBDatabase
|
||||
|
||||
# 获取知识库配置
|
||||
kb_database = KBDatabase(self.db)
|
||||
kb = await kb_database.get_kb_by_id(kb_id)
|
||||
if not kb:
|
||||
raise ValueError(f"知识库不存在: {kb_id}")
|
||||
|
||||
embedding_provider_id = kb.embedding_provider_id
|
||||
|
||||
# 如果没有 provider_manager,使用默认的第一个
|
||||
if not self.provider_manager:
|
||||
raise ValueError("Provider Manager 未初始化")
|
||||
|
||||
embedding_providers = self.provider_manager.embedding_provider_insts
|
||||
if not embedding_providers:
|
||||
raise ValueError("系统中没有可用的 Embedding Provider")
|
||||
|
||||
# 如果指定了 provider ID,则查找该 provider
|
||||
if embedding_provider_id:
|
||||
for provider in embedding_providers:
|
||||
if provider.meta().id == embedding_provider_id:
|
||||
return provider
|
||||
raise ValueError(
|
||||
f"未找到配置的 Embedding Provider: {embedding_provider_id}"
|
||||
)
|
||||
|
||||
# 使用第一个可用的 provider
|
||||
return embedding_providers[0]
|
||||
|
||||
# ===== 知识库操作 =====
|
||||
|
||||
async def create_kb(
|
||||
@@ -77,7 +118,7 @@ class KBManager:
|
||||
# 检查是否有可用的 rerank provider
|
||||
has_rerank_provider = (
|
||||
self.provider_manager
|
||||
and hasattr(self.provider_manager, 'rerank_provider_insts')
|
||||
and hasattr(self.provider_manager, "rerank_provider_insts")
|
||||
and len(self.provider_manager.rerank_provider_insts) > 0
|
||||
)
|
||||
enable_rerank = has_rerank_provider
|
||||
@@ -182,7 +223,10 @@ class KBManager:
|
||||
for doc in docs:
|
||||
await ops.delete_document(doc.doc_id)
|
||||
|
||||
# 3. 删除知识库记录
|
||||
# 3. 删除向量数据库
|
||||
await self.vec_db_factory.delete_vec_db(kb_id)
|
||||
|
||||
# 4. 删除知识库记录
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id)
|
||||
result = await session.execute(stmt)
|
||||
@@ -257,11 +301,15 @@ class KBManager:
|
||||
# 4. 文档分块
|
||||
chunks_text = await self.chunker.chunk(text_content)
|
||||
|
||||
# 5. 生成向量并存储
|
||||
# 5. 获取 Embedding Provider 和向量数据库
|
||||
embedding_provider = await self._get_embedding_provider_for_kb(kb_id)
|
||||
vec_db = await self.vec_db_factory.get_vec_db(kb_id, embedding_provider)
|
||||
|
||||
# 6. 生成向量并存储
|
||||
saved_chunks = []
|
||||
for idx, chunk_text in enumerate(chunks_text):
|
||||
# 存储到向量数据库
|
||||
vec_doc_id = await self.vec_db.insert(
|
||||
vec_doc_id = await vec_db.insert(
|
||||
content=chunk_text,
|
||||
metadata={
|
||||
"kb_id": kb_id,
|
||||
@@ -282,7 +330,7 @@ class KBManager:
|
||||
)
|
||||
saved_chunks.append(chunk)
|
||||
|
||||
# 6. 保存文档元数据(事务)
|
||||
# 7. 保存文档元数据(事务)
|
||||
doc = KBDocument(
|
||||
doc_id=doc_id,
|
||||
kb_id=kb_id,
|
||||
@@ -305,7 +353,7 @@ class KBManager:
|
||||
|
||||
await session.refresh(doc)
|
||||
|
||||
# 7. 更新知识库统计
|
||||
# 8. 更新知识库统计
|
||||
await self._update_kb_stats(kb_id)
|
||||
|
||||
return doc
|
||||
@@ -316,12 +364,19 @@ class KBManager:
|
||||
|
||||
logger.error(f"文档上传失败,开始清理资源: {e}")
|
||||
|
||||
# 清理向量数据库
|
||||
for vec_id in vec_doc_ids:
|
||||
try:
|
||||
await self.vec_db.delete(vec_id)
|
||||
except Exception as ve:
|
||||
logger.warning(f"清理向量失败 {vec_id}: {ve}")
|
||||
# 获取知识库的向量数据库
|
||||
try:
|
||||
embedding_provider = await self._get_embedding_provider_for_kb(kb_id)
|
||||
vec_db = await self.vec_db_factory.get_vec_db(kb_id, embedding_provider)
|
||||
|
||||
# 清理向量数据库
|
||||
for vec_id in vec_doc_ids:
|
||||
try:
|
||||
await vec_db.delete(vec_id)
|
||||
except Exception as ve:
|
||||
logger.warning(f"清理向量失败 {vec_id}: {ve}")
|
||||
except Exception as vfe:
|
||||
logger.error(f"获取向量数据库失败: {vfe}")
|
||||
|
||||
# 清理多媒体文件
|
||||
for media_path in media_paths:
|
||||
|
||||
@@ -28,7 +28,7 @@ class KBManagerOps:
|
||||
def __init__(self, manager: "KBManager"):
|
||||
self.manager = manager
|
||||
self.db = manager.db
|
||||
self.vec_db = manager.vec_db
|
||||
self.vec_db_factory = manager.vec_db_factory
|
||||
self.media_path = manager.media_path
|
||||
self.files_path = manager.files_path
|
||||
|
||||
@@ -75,6 +75,12 @@ class KBManagerOps:
|
||||
chunks = await self.list_chunks(doc_id)
|
||||
media_list = await self.list_media(doc_id)
|
||||
|
||||
# 获取知识库的向量数据库
|
||||
embedding_provider = await self.manager._get_embedding_provider_for_kb(
|
||||
doc.kb_id
|
||||
)
|
||||
vec_db = await self.vec_db_factory.get_vec_db(doc.kb_id, embedding_provider)
|
||||
|
||||
# ===== 第一阶段: 删除向量(可重试) =====
|
||||
vec_ids_to_delete = [chunk.vec_doc_id for chunk in chunks]
|
||||
deleted_vec_ids = []
|
||||
@@ -82,7 +88,7 @@ class KBManagerOps:
|
||||
|
||||
for vec_id in vec_ids_to_delete:
|
||||
try:
|
||||
await self.vec_db.delete(vec_id)
|
||||
await vec_db.delete(vec_id)
|
||||
deleted_vec_ids.append(vec_id)
|
||||
except Exception as e:
|
||||
logger.error(f"删除向量失败: {vec_id}, {e}")
|
||||
@@ -173,11 +179,16 @@ class KBManagerOps:
|
||||
return False
|
||||
|
||||
doc_id = chunk.doc_id
|
||||
kb_id = chunk.kb_id
|
||||
vec_doc_id = chunk.vec_doc_id
|
||||
|
||||
# 2. 删除向量
|
||||
# 2. 获取知识库的向量数据库并删除向量
|
||||
try:
|
||||
await self.vec_db.delete(vec_doc_id)
|
||||
embedding_provider = await self.manager._get_embedding_provider_for_kb(
|
||||
kb_id
|
||||
)
|
||||
vec_db = await self.vec_db_factory.get_vec_db(kb_id, embedding_provider)
|
||||
await vec_db.delete(vec_doc_id)
|
||||
except Exception as e:
|
||||
logger.error(f"删除向量失败: {vec_doc_id}, {e}")
|
||||
return False
|
||||
|
||||
@@ -16,7 +16,7 @@ import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from sqlmodel import Field, SQLModel, Text, UniqueConstraint
|
||||
from sqlmodel import Field, SQLModel, Text
|
||||
|
||||
|
||||
class KnowledgeBase(SQLModel, table=True):
|
||||
@@ -25,7 +25,7 @@ class KnowledgeBase(SQLModel, table=True):
|
||||
存储知识库的基本信息和统计数据。
|
||||
"""
|
||||
|
||||
__tablename__ = "knowledge_bases"
|
||||
__tablename__ = "knowledge_bases" # type: ignore
|
||||
|
||||
id: int | None = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
|
||||
@@ -65,7 +65,7 @@ class KBDocument(SQLModel, table=True):
|
||||
存储上传到知识库的文档元数据。
|
||||
"""
|
||||
|
||||
__tablename__ = "kb_documents"
|
||||
__tablename__ = "kb_documents" # type: ignore
|
||||
|
||||
id: int | None = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
|
||||
@@ -97,7 +97,7 @@ class KBChunk(SQLModel, table=True):
|
||||
存储文档分块后的文本内容和向量索引关联信息。
|
||||
"""
|
||||
|
||||
__tablename__ = "kb_chunks"
|
||||
__tablename__ = "kb_chunks" # type: ignore
|
||||
|
||||
id: int | None = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
|
||||
@@ -124,7 +124,7 @@ class KBMedia(SQLModel, table=True):
|
||||
存储从文档中提取的图片、视频等多媒体资源。
|
||||
"""
|
||||
|
||||
__tablename__ = "kb_media"
|
||||
__tablename__ = "kb_media" # type: ignore
|
||||
|
||||
id: int | None = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
|
||||
@@ -144,39 +144,3 @@ class KBMedia(SQLModel, table=True):
|
||||
file_size: int = Field(nullable=False)
|
||||
mime_type: str = Field(max_length=100, nullable=False)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
class KBSessionConfig(SQLModel, table=True):
|
||||
"""会话知识库配置表
|
||||
|
||||
存储会话或平台级别的知识库关联配置。
|
||||
该表存储在知识库独立数据库中,保持完全解耦。
|
||||
|
||||
支持两种配置范围:
|
||||
- platform: 平台级别配置 (如 'qq', 'telegram')
|
||||
- session: 会话级别配置 (如 'qq:group:12345')
|
||||
"""
|
||||
|
||||
__tablename__ = "kb_session_config"
|
||||
|
||||
id: int | None = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
|
||||
)
|
||||
config_id: str = Field(
|
||||
max_length=36,
|
||||
nullable=False,
|
||||
unique=True,
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
)
|
||||
scope: str = Field(max_length=20, nullable=False)
|
||||
scope_id: str = Field(max_length=255, nullable=False, index=True)
|
||||
kb_ids: str = Field(sa_type=Text, nullable=False)
|
||||
top_k: Optional[int] = Field(default=None, nullable=True)
|
||||
enable_rerank: Optional[bool] = Field(default=None, nullable=True)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
__table_args__ = (UniqueConstraint("scope", "scope_id", name="uix_scope_scope_id"),)
|
||||
|
||||
@@ -3,16 +3,15 @@
|
||||
协调稠密检索、稀疏检索和 Rerank,提供统一的检索接口
|
||||
"""
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
|
||||
from astrbot.core.db.vec_db.base import BaseVecDB
|
||||
from astrbot.core.knowledge_base.database import KBDatabase
|
||||
from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion
|
||||
from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseRetriever
|
||||
from astrbot.core.provider.provider import RerankProvider
|
||||
|
||||
from astrbot.core.db.vec_db.base import BaseVecDB, Result
|
||||
from ..vec_db_factory import VecDBFactory
|
||||
|
||||
@dataclass
|
||||
class RetrievalResult:
|
||||
@@ -38,36 +37,34 @@ class RetrievalManager:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vec_db: BaseVecDB,
|
||||
vec_db_factory, # VecDBFactory
|
||||
sparse_retriever: SparseRetriever,
|
||||
rank_fusion: RankFusion,
|
||||
kb_db: KBDatabase,
|
||||
rerank_provider: Optional[RerankProvider] = None,
|
||||
):
|
||||
"""初始化检索管理器
|
||||
|
||||
Args:
|
||||
vec_db: 向量数据库实例
|
||||
vec_db_factory: 向量数据库工厂
|
||||
sparse_retriever: 稀疏检索器
|
||||
rank_fusion: 结果融合器
|
||||
kb_db: 知识库数据库实例
|
||||
rerank_provider: Rerank 提供商 (可选)
|
||||
"""
|
||||
self.vec_db = vec_db
|
||||
self.vec_db_factory = vec_db_factory
|
||||
self.sparse_retriever = sparse_retriever
|
||||
self.rank_fusion = rank_fusion
|
||||
self.kb_db = kb_db
|
||||
self.rerank_provider = rerank_provider
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
vec_db_factory: VecDBFactory,
|
||||
query: str,
|
||||
kb_ids: List[str],
|
||||
top_k_dense: int = 50,
|
||||
top_k_sparse: int = 50,
|
||||
top_n_fusion: int = 20,
|
||||
top_m_final: int = 5,
|
||||
enable_rerank: bool = True,
|
||||
rerank_provider: RerankProvider | None = None,
|
||||
) -> List[RetrievalResult]:
|
||||
"""混合检索
|
||||
|
||||
@@ -94,6 +91,7 @@ class RetrievalManager:
|
||||
query=query,
|
||||
kb_ids=kb_ids,
|
||||
top_k=top_k_dense,
|
||||
vec_db=vec_db,
|
||||
)
|
||||
|
||||
# 2. 稀疏检索
|
||||
@@ -131,13 +129,13 @@ class RetrievalManager:
|
||||
)
|
||||
)
|
||||
|
||||
# 5. Rerank (可选)
|
||||
if enable_rerank and self.rerank_provider and retrieval_results:
|
||||
# 5. Rerank
|
||||
if rerank_provider and retrieval_results:
|
||||
retrieval_results = await self._rerank(
|
||||
query=query,
|
||||
results=retrieval_results,
|
||||
top_k=top_m_final,
|
||||
rerank_provider=self.rerank_provider,
|
||||
rerank_provider=rerank_provider,
|
||||
)
|
||||
else:
|
||||
retrieval_results = retrieval_results[:top_m_final]
|
||||
@@ -149,9 +147,12 @@ class RetrievalManager:
|
||||
query: str,
|
||||
kb_ids: List[str],
|
||||
top_k: int,
|
||||
vec_db: BaseVecDB,
|
||||
):
|
||||
"""稠密检索 (向量相似度)
|
||||
|
||||
为每个知识库使用独立的向量数据库进行检索,然后合并结果。
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
kb_ids: 知识库 ID 列表
|
||||
@@ -160,28 +161,27 @@ class RetrievalManager:
|
||||
Returns:
|
||||
List[Result]: 检索结果列表
|
||||
"""
|
||||
# 直接调用向量数据库检索
|
||||
vec_results = await self.vec_db.retrieve(
|
||||
query=query,
|
||||
top_k=top_k * len(kb_ids) * 2, # 增加候选数量以便过滤
|
||||
)
|
||||
all_results: list[Result] = []
|
||||
|
||||
# 过滤:只保留指定知识库的结果
|
||||
filtered_results = []
|
||||
for result in vec_results:
|
||||
metadata_str = result.data.get("metadata", "{}")
|
||||
for kb_id in kb_ids:
|
||||
try:
|
||||
metadata = json.loads(metadata_str)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
metadata = {}
|
||||
vec_results = await vec_db.retrieve(
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
fetch_k=top_k * 2,
|
||||
metadata_filters={"kb_id": kb_id},
|
||||
)
|
||||
|
||||
if metadata.get("kb_id") in kb_ids:
|
||||
filtered_results.append(result)
|
||||
all_results.extend(vec_results)
|
||||
except Exception as e:
|
||||
from astrbot.core import logger
|
||||
|
||||
if len(filtered_results) >= top_k:
|
||||
break
|
||||
logger.warning(f"知识库 {kb_id} 稠密检索失败: {e}")
|
||||
continue
|
||||
|
||||
return filtered_results[:top_k]
|
||||
# 按相似度排序并返回 top_k
|
||||
all_results.sort(key=lambda x: x.similarity, reverse=True)
|
||||
return all_results[:top_k]
|
||||
|
||||
async def _rerank(
|
||||
self,
|
||||
|
||||
@@ -1,157 +0,0 @@
|
||||
"""会话知识库配置数据库操作
|
||||
|
||||
该模块封装会话知识库配置的数据库查询操作。
|
||||
|
||||
注意: 会话配置表 (kb_session_config) 存储在知识库独立数据库 (kb.db) 中,
|
||||
而不是主数据库 (astrbot.db) 中,以实现完全解耦。
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from astrbot.core.knowledge_base.kb_sqlite import KBSQLiteDatabase
|
||||
from astrbot.core.knowledge_base.models import KBSessionConfig
|
||||
|
||||
|
||||
class SessionConfigDB:
    """Database operations for per-session knowledge-base configuration.

    Responsibilities:
    - manage the knowledge-base configuration rows (``KBSessionConfig``)
    - keep all queries against that table in one place

    Note: this class works on the database handle it is given; per the
    module docstring that is the standalone knowledge-base database, not
    the main application database.
    """

    def __init__(self, db: KBSQLiteDatabase):
        """Store the database handle used by every operation.

        Args:
            db: Knowledge-base database instance; must expose ``get_db()``
                as an async context manager yielding a session.
        """
        self.db = db

    async def _fetch_config(self, session, scope: str, scope_id: str):
        """Return the config row for ``(scope, scope_id)`` or ``None``."""
        query = select(KBSessionConfig).where(
            KBSessionConfig.scope == scope,
            KBSessionConfig.scope_id == scope_id,
        )
        rows = await session.execute(query)
        return rows.scalar_one_or_none()

    async def get_session_kb_ids(self, session_id: str) -> list[str]:
        """Resolve the knowledge-base IDs bound to a session.

        Lookup order:
        1. session-level configuration (takes precedence)
        2. platform-level configuration, keyed by the text before the
           first ``:`` of ``session_id``
        3. an empty list when neither exists

        Args:
            session_id: Session identifier, e.g. ``platform:xxx:session_id``.

        Returns:
            The decoded ``kb_ids`` JSON of the first matching row, else ``[]``.
        """
        async with self.db.get_db() as session:
            # 1. Session-level configuration wins when present.
            session_cfg = await self._fetch_config(session, "session", session_id)
            if session_cfg:
                return json.loads(session_cfg.kb_ids)

            # 2. Fall back to the platform prefix; only attempted when the
            #    id actually contains a ":" separator.
            platform_id, separator, _ = session_id.partition(":")
            if separator:
                platform_cfg = await self._fetch_config(
                    session, "platform", platform_id
                )
                if platform_cfg:
                    return json.loads(platform_cfg.kb_ids)

            # 3. Nothing configured.
            return []

    async def set_session_kb_ids(
        self,
        scope: str,
        scope_id: str,
        kb_ids: list[str],
        top_k: Optional[int] = None,
        enable_rerank: Optional[bool] = None,
    ) -> KBSessionConfig:
        """Create or update the configuration for ``(scope, scope_id)``.

        Args:
            scope: Configuration scope (``session`` / ``platform``).
            scope_id: Scope identifier (session ID or platform ID).
            kb_ids: Knowledge-base IDs; stored JSON-encoded.
            top_k: Number of results to return (optional). On update, an
                omitted value leaves the stored one untouched.
            enable_rerank: Whether rerank is enabled (optional). Same
                update semantics as ``top_k``.

        Returns:
            The persisted row, refreshed after commit.
        """
        async with self.db.get_db() as session:
            record = await self._fetch_config(session, scope, scope_id)

            if not record:
                # No row yet: insert a fresh one with whatever was supplied.
                record = KBSessionConfig(
                    scope=scope,
                    scope_id=scope_id,
                    kb_ids=json.dumps(kb_ids),
                    top_k=top_k,
                    enable_rerank=enable_rerank,
                )
                session.add(record)
            else:
                # Existing row: always replace kb_ids, patch the rest
                # only when a value was explicitly passed.
                record.kb_ids = json.dumps(kb_ids)
                if top_k is not None:
                    record.top_k = top_k
                if enable_rerank is not None:
                    record.enable_rerank = enable_rerank

            await session.commit()
            await session.refresh(record)
            return record

    async def delete_session_kb_config(self, scope: str, scope_id: str) -> bool:
        """Delete the configuration for ``(scope, scope_id)``.

        Returns:
            ``True`` when a row was deleted, ``False`` when none existed.
        """
        async with self.db.get_db() as session:
            record = await self._fetch_config(session, scope, scope_id)
            if record:
                await session.delete(record)
                await session.commit()
                return True
            return False

    async def list_all_session_configs(
        self, offset: int = 0, limit: int = 100, scope: Optional[str] = None
    ) -> list[KBSessionConfig]:
        """List configuration rows, newest first.

        Args:
            offset: Number of rows to skip.
            limit: Maximum number of rows to return.
            scope: When truthy, restrict the listing to this scope.
        """
        async with self.db.get_db() as session:
            query = select(KBSessionConfig)
            if scope:
                query = query.where(KBSessionConfig.scope == scope)

            paged = query.offset(offset).limit(limit)
            ordered = paged.order_by(KBSessionConfig.created_at.desc())

            rows = await session.execute(ordered)
            return list(rows.scalars().all())
|
||||
@@ -0,0 +1,161 @@
|
||||
"""向量数据库工厂
|
||||
|
||||
负责为每个知识库创建和管理独立的向量数据库实例。
|
||||
|
||||
架构说明:
|
||||
- 每个知识库拥有独立的向量数据库实例
|
||||
- 向量数据库文件以 kb_id 命名
|
||||
- 工厂类负责实例的创建、缓存和生命周期管理
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.db.vec_db.base import BaseVecDB
|
||||
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
|
||||
from astrbot.core.provider.provider import EmbeddingProvider
|
||||
|
||||
|
||||
class VecDBFactory:
    """Factory for per-knowledge-base vector database instances.

    Responsibilities:
    - create an independent vector database for each knowledge base
    - cache created instances so they are reused
    - manage the lifecycle (close / delete) of those instances

    Creation and deletion for a given ``kb_id`` are serialized with a
    per-id lock, so concurrent callers never open two instances over the
    same on-disk files.
    """

    def __init__(
        self,
        storage_base_path: str,
    ):
        """Initialize the factory.

        Args:
            storage_base_path: Base directory under which each knowledge
                base gets its own sub-directory named after its ``kb_id``.
        """
        self.storage_base_path = Path(storage_base_path)
        self._instances: Dict[str, BaseVecDB] = {}
        # kb_id -> asyncio.Lock guarding creation/deletion of that instance.
        self._locks: dict = {}

        # Make sure the base directory exists.
        self.storage_base_path.mkdir(parents=True, exist_ok=True)

    def _lock_for(self, kb_id: str):
        """Return the lock guarding ``kb_id``, creating it on first use.

        Locks are intentionally never removed: dropping one while another
        coroutine is still waiting on it would let a newcomer obtain a
        different lock for the same id and re-introduce the race.
        """
        import asyncio

        lock = self._locks.get(kb_id)
        if lock is None:
            # No await between the lookup and the store, so this is atomic
            # with respect to other coroutines on the same event loop.
            lock = self._locks[kb_id] = asyncio.Lock()
        return lock

    async def get_vec_db(
        self, kb_id: str, embedding_provider: EmbeddingProvider
    ) -> BaseVecDB:
        """Get or create the vector database instance for a knowledge base.

        Args:
            kb_id: Knowledge base ID.
            embedding_provider: Embedding provider used when a new instance
                has to be created.

        Returns:
            BaseVecDB: The cached or newly created instance.
        """
        # Fast path: already created.
        cached = self._instances.get(kb_id)
        if cached is not None:
            return cached

        async with self._lock_for(kb_id):
            # Re-check: another coroutine may have finished creating the
            # instance while we were waiting for the lock.
            cached = self._instances.get(kb_id)
            if cached is not None:
                return cached

            vec_db = await self._create_vec_db(kb_id, embedding_provider)
            self._instances[kb_id] = vec_db

        logger.debug(f"创建知识库 {kb_id} 的向量数据库实例")

        return vec_db

    async def _create_vec_db(
        self, kb_id: str, embedding_provider: EmbeddingProvider
    ) -> BaseVecDB:
        """Create and initialize a vector database instance.

        Args:
            kb_id: Knowledge base ID; also the name of its storage directory.
            embedding_provider: Embedding provider handed to the instance.

        Returns:
            BaseVecDB: The initialized instance (not yet cached).
        """
        # Each knowledge base gets its own storage directory.
        kb_storage_path = self.storage_base_path / kb_id
        kb_storage_path.mkdir(parents=True, exist_ok=True)

        doc_store_path = str(kb_storage_path / "documents.db")
        index_store_path = str(kb_storage_path / "index.faiss")

        vec_db = FaissVecDB(
            doc_store_path=doc_store_path,
            index_store_path=index_store_path,
            embedding_provider=embedding_provider,
        )

        await vec_db.initialize()

        return vec_db

    async def delete_vec_db(self, kb_id: str) -> bool:
        """Delete the vector database of a knowledge base.

        Closes and evicts the cached instance (best effort), then removes
        the knowledge base's storage directory.

        Args:
            kb_id: Knowledge base ID.

        Returns:
            bool: ``True`` when the files were removed or did not exist,
            ``False`` when removing them failed.
        """
        import shutil

        # Hold the per-id lock so the directory is not removed while a
        # creation for the same id is still in progress.
        async with self._lock_for(kb_id):
            # Close and evict the cached instance.
            vec_db = self._instances.pop(kb_id, None)
            if vec_db is not None:
                try:
                    await vec_db.close()
                except Exception as e:
                    # Best effort: a failed close must not block deletion.
                    logger.warning(f"关闭向量数据库失败 ({kb_id}): {e}")

            # Remove the on-disk files.
            kb_storage_path = self.storage_base_path / kb_id
            if kb_storage_path.exists():
                try:
                    shutil.rmtree(kb_storage_path)
                    logger.info(f"已删除知识库 {kb_id} 的向量数据库文件")
                    return True
                except Exception as e:
                    logger.error(f"删除向量数据库文件失败 ({kb_id}): {e}")
                    return False

            return True

    async def close_all(self):
        """Close every cached vector database instance and clear the cache."""
        # Iterate over a snapshot: close() awaits, and the cache may be
        # touched by other coroutines in the meantime.
        for kb_id, vec_db in list(self._instances.items()):
            try:
                await vec_db.close()
                logger.debug(f"已关闭知识库 {kb_id} 的向量数据库")
            except Exception as e:
                logger.warning(f"关闭向量数据库失败 ({kb_id}): {e}")

        self._instances.clear()

    def has_instance(self, kb_id: str) -> bool:
        """Report whether an instance for ``kb_id`` is currently cached.

        Args:
            kb_id: Knowledge base ID.

        Returns:
            bool: ``True`` when an instance has been created and cached.
        """
        return kb_id in self._instances

    def get_cached_instance(self, kb_id: str) -> Optional[BaseVecDB]:
        """Return the cached instance for ``kb_id`` without creating one.

        Args:
            kb_id: Knowledge base ID.

        Returns:
            Optional[BaseVecDB]: The cached instance, or ``None`` when no
            instance has been created for this id.
        """
        return self._instances.get(kb_id)
|
||||
@@ -1280,7 +1280,13 @@ class KnowledgeBaseRoute(Route):
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok({"sessions": session_list, "total": len(session_list), "kb_id": kb_id})
|
||||
.ok(
|
||||
{
|
||||
"sessions": session_list,
|
||||
"total": len(session_list),
|
||||
"kb_id": kb_id,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user