diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index 9dc03d4af..783b1fa98 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -1,6 +1,6 @@ """ Astrbot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。 -该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、EventBus等。 +该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus等。 该类还负责加载和执行插件, 以及处理事件总线的分发。 工作流程: @@ -28,7 +28,6 @@ from astrbot.core.db import BaseDatabase from astrbot.core.updator import AstrBotUpdator from astrbot.core import logger from astrbot.core.config.default import VERSION -from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager from astrbot.core.conversation_mgr import ConversationManager from astrbot.core.star.star_handler import star_handlers_registry, EventType from astrbot.core.star.star_handler import star_map @@ -37,7 +36,7 @@ from astrbot.core.star.star_handler import star_map class AstrBotCoreLifecycle: """ AstrBot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。 - 该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、 + 该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、 EventBus 等。 该类还负责加载和执行插件, 以及处理事件总线的分发。 """ @@ -54,7 +53,7 @@ class AstrBotCoreLifecycle: async def initialize(self): """ - 初始化 AstrBot 核心生命周期管理类, 负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。 + 初始化 AstrBot 核心生命周期管理类, 负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。 """ # 初始化日志代理 @@ -73,9 +72,6 @@ class AstrBotCoreLifecycle: # 初始化平台管理器 self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue) - # 初始化知识库管理器 - self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config) - # 初始化对话管理器 self.conversation_manager = ConversationManager(self.db) @@ -87,7 +83,6 @@ class AstrBotCoreLifecycle: self.provider_manager, self.platform_manager, self.conversation_manager, - self.knowledge_db_manager, ) # 初始化插件管理器 diff --git a/astrbot/core/db/plugin/sqlite_impl.py b/astrbot/core/db/plugin/sqlite_impl.py deleted file mode 100644 index 53cfb8284..000000000 --- a/astrbot/core/db/plugin/sqlite_impl.py +++ /dev/null @@ -1,113 +0,0 @@ -import json -import aiosqlite -import os -from typing import Any -from .plugin_storage import PluginStorage -from astrbot.core.utils.astrbot_path import get_astrbot_data_path - -DBPATH = os.path.join(get_astrbot_data_path(), "plugin_data", "sqlite", "plugin_data.db") - - -class SQLitePluginStorage(PluginStorage): - """插件数据的 SQLite 存储实现类。 - - 该类提供异步方式将插件数据存储到 SQLite 数据库中,支持数据的增删改查操作。 - 所有数据以 (plugin, key) 作为复合主键进行索引。 - """ - - _instance = None # Standalone instance of the class - _db_conn = None - db_path = None - - def __new__(cls): - """ - 创建或获取 SQLitePluginStorage 的单例实例。 - 如果实例已存在,则返回现有实例;否则创建一个新实例。 - 数据在 `data/plugin_data/sqlite/plugin_data.db` 下。 - """ - os.makedirs(os.path.dirname(DBPATH), exist_ok=True) - if cls._instance is None: - cls._instance = super(SQLitePluginStorage, cls).__new__(cls) - cls._instance.db_path = DBPATH - return cls._instance - - async def _init_db(self): - """初始化数据库连接(只执行一次)""" - if SQLitePluginStorage._db_conn is None: - SQLitePluginStorage._db_conn = await aiosqlite.connect(self.db_path) - await self._setup_db() - - async def _setup_db(self): - """ - 异步初始化数据库。 - - 创建插件数据表,如果表不存在则创建,表结构包含 plugin、key 和 value 字段, - 其中 plugin 和 key 组合作为主键。 - """ - await self._db_conn.execute(""" - CREATE TABLE IF NOT EXISTS plugin_data ( - plugin TEXT, - key TEXT, - value TEXT, - PRIMARY KEY (plugin, key) - ) - """) - await self._db_conn.commit() - - async def set(self, plugin: str, key: str, value: Any): - """ - 异步存储数据。 - - 将指定插件的键值对存入数据库,如果键已存在则更新值。 - 值会被序列化为 JSON 字符串后存储。 - - Args: - plugin: 插件标识符 - key: 数据键名 - value: 要存储的数据值(任意类型,将被 JSON 序列化) - """ - await self._init_db() - await self._db_conn.execute( - "INSERT INTO plugin_data (plugin, key, value) VALUES (?, ?, ?) " - "ON CONFLICT(plugin, key) DO UPDATE SET value = excluded.value", - (plugin, key, json.dumps(value)), - ) - await self._db_conn.commit() - - async def get(self, plugin: str, key: str) -> Any: - """ - 异步获取数据。 - - 从数据库中获取指定插件和键名对应的值, - 返回的值会从 JSON 字符串反序列化为原始数据类型。 - - Args: - plugin: 插件标识符 - key: 数据键名 - - Returns: - Any: 存储的数据值,如果未找到则返回 None - """ - await self._init_db() - async with self._db_conn.execute( - "SELECT value FROM plugin_data WHERE plugin = ? AND key = ?", - (plugin, key), - ) as cursor: - row = await cursor.fetchone() - return json.loads(row[0]) if row else None - - async def delete(self, plugin: str, key: str): - """ - 异步删除数据。 - - 从数据库中删除指定插件和键名对应的数据项。 - - Args: - plugin: 插件标识符 - key: 要删除的数据键名 - """ - await self._init_db() - await self._db_conn.execute( - "DELETE FROM plugin_data WHERE plugin = ? AND key = ?", (plugin, key) - ) - await self._db_conn.commit() diff --git a/astrbot/core/db/vec_db/base.py b/astrbot/core/db/vec_db/base.py new file mode 100644 index 000000000..d6e5e56be --- /dev/null +++ b/astrbot/core/db/vec_db/base.py @@ -0,0 +1,46 @@ +import abc +from dataclasses import dataclass + + +@dataclass +class Result: + similarity: float + data: dict + + +class BaseVecDB: + async def initialize(self): + """ + 初始化向量数据库 + """ + pass + + @abc.abstractmethod + async def insert(self, content: str, metadata: dict = None, id: str = None) -> int: + """ + 插入一条文本和其对应向量,自动生成 ID 并保持一致性。 + """ + ... + + @abc.abstractmethod + async def retrieve(self, query: str, top_k: int = 5) -> list[Result]: + """ + 搜索最相似的文档。 + Args: + query (str): 查询文本 + top_k (int): 返回的最相似文档的数量 + Returns: + List[Result]: 查询结果 + """ + ... + + @abc.abstractmethod + async def delete(self, doc_id: str) -> bool: + """ + 删除指定文档。 + Args: + doc_id (str): 要删除的文档 ID + Returns: + bool: 删除是否成功 + """ + ... diff --git a/astrbot/core/db/vec_db/faiss_impl/__init__.py b/astrbot/core/db/vec_db/faiss_impl/__init__.py new file mode 100644 index 000000000..11fc79d60 --- /dev/null +++ b/astrbot/core/db/vec_db/faiss_impl/__init__.py @@ -0,0 +1,3 @@ +from .vec_db import FaissVecDB + +__all__ = ["FaissVecDB"] \ No newline at end of file diff --git a/astrbot/core/db/vec_db/faiss_impl/document_storage.py b/astrbot/core/db/vec_db/faiss_impl/document_storage.py new file mode 100644 index 000000000..ee44da66c --- /dev/null +++ b/astrbot/core/db/vec_db/faiss_impl/document_storage.py @@ -0,0 +1,121 @@ +import aiosqlite +import os + + +class DocumentStorage: + def __init__(self, db_path: str): + self.db_path = db_path + self.connection = None + self.sqlite_init_path = os.path.join( + os.path.dirname(__file__), "sqlite_init.sql" + ) + + async def initialize(self): + """Initialize the SQLite database and create the documents table if it doesn't exist.""" + if not os.path.exists(self.db_path): + await self.connect() + async with self.connection.cursor() as cursor: + with open(self.sqlite_init_path, "r", encoding="utf-8") as f: + sql_script = f.read() + await cursor.executescript(sql_script) + await self.connection.commit() + else: + await self.connect() + + async def connect(self): + """Connect to the SQLite database.""" + self.connection = await aiosqlite.connect(self.db_path) + + async def get_documents(self, metadata_filters: dict, ids: list = None): + """Retrieve documents by metadata filters and ids. + + Args: + metadata_filters (dict): The metadata filters to apply. + + Returns: + list: The list of document IDs(primary key, not doc_id) that match the filters. + """ + # metadata filter -> SQL WHERE clause + where_clauses = [] + values = [] + for key, val in metadata_filters.items(): + where_clauses.append(f"json_extract(metadata, '$.{key}') = ?") + values.append(val) + if ids is not None and len(ids) > 0: + ids = [str(i) for i in ids if i != -1] + where_clauses.append("id IN ({})".format(",".join("?" * len(ids)))) + values.extend(ids) + where_sql = " AND ".join(where_clauses) or "1=1" + + result = [] + async with self.connection.cursor() as cursor: + sql = "SELECT * FROM documents WHERE " + where_sql + await cursor.execute(sql, values) + for row in await cursor.fetchall(): + result.append(await self.tuple_to_dict(row)) + return result + + async def get_document_by_doc_id(self, doc_id: str): + """Retrieve a document by its doc_id. + + Args: + doc_id (str): The doc_id of the document to retrieve. + + Returns: + dict: The document data. + """ + async with self.connection.cursor() as cursor: + await cursor.execute("SELECT * FROM documents WHERE doc_id = ?", (doc_id,)) + row = await cursor.fetchone() + if row: + return await self.tuple_to_dict(row) + else: + return None + + async def update_document_by_doc_id(self, doc_id: str, new_text: str): + """Retrieve a document by its doc_id. + + Args: + doc_id (str): The doc_id. + new_text (str): The new text to update the document with. + """ + async with self.connection.cursor() as cursor: + await cursor.execute( + "UPDATE documents SET text = ? WHERE doc_id = ?", (new_text, doc_id) + ) + await self.connection.commit() + + async def get_user_ids(self) -> list[str]: + """Retrieve all user IDs from the documents table. + + Returns: + list: A list of user IDs. + """ + async with self.connection.cursor() as cursor: + await cursor.execute("SELECT DISTINCT user_id FROM documents") + rows = await cursor.fetchall() + return [row[0] for row in rows] + + async def tuple_to_dict(self, row): + """Convert a tuple to a dictionary. + + Args: + row (tuple): The row to convert. + + Returns: + dict: The converted dictionary. + """ + return { + "id": row[0], + "doc_id": row[1], + "text": row[2], + "metadata": row[3], + "created_at": row[4], + "updated_at": row[5], + } + + async def close(self): + """Close the connection to the SQLite database.""" + if self.connection: + await self.connection.close() + self.connection = None diff --git a/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py b/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py new file mode 100644 index 000000000..262a459e3 --- /dev/null +++ b/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py @@ -0,0 +1,59 @@ +try: + import faiss +except ModuleNotFoundError: + raise ImportError( + "faiss 未安装。请使用 'pip install faiss-cpu' 或 'pip install faiss-gpu' 安装。" + ) +import os +import numpy as np + + +class EmbeddingStorage: + def __init__(self, dimension: int, path: str = None): + self.dimension = dimension + self.path = path + self.index = None + if path and os.path.exists(path): + self.index = faiss.read_index(path) + else: + base_index = faiss.IndexFlatL2(dimension) + self.index = faiss.IndexIDMap(base_index) + self.storage = {} + + async def insert(self, vector: np.ndarray, id: int): + """插入向量 + + Args: + vector (np.ndarray): 要插入的向量 + id (int): 向量的ID + Raises: + ValueError: 如果向量的维度与存储的维度不匹配 + """ + if vector.shape[0] != self.dimention: + raise ValueError( + f"向量维度不匹配, 期望: {self.dimention}, 实际: {vector.shape[0]}" + ) + self.index.add_with_ids(vector.reshape(1, -1), np.array([id])) + self.storage[id] = vector + await self.save_index() + + async def search(self, vector: np.ndarray, k: int) -> tuple: + """搜索最相似的向量 + + Args: + vector (np.ndarray): 查询向量 + k (int): 返回的最相似向量的数量 + Returns: + tuple: (距离, 索引) + """ + faiss.normalize_L2(vector) + distances, indices = self.index.search(vector, k) + return distances, indices + + async def save_index(self): + """保存索引 + + Args: + path (str): 保存索引的路径 + """ + faiss.write_index(self.index, self.path) diff --git a/astrbot/core/db/vec_db/faiss_impl/sqlite_init.sql b/astrbot/core/db/vec_db/faiss_impl/sqlite_init.sql new file mode 100644 index 000000000..1e04d70e3 --- /dev/null +++ b/astrbot/core/db/vec_db/faiss_impl/sqlite_init.sql @@ -0,0 +1,17 @@ +-- 创建文档存储表,包含 faiss 中文档的 id,文档文本,create_at,updated_at +CREATE TABLE documents ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + doc_id TEXT NOT NULL, + text TEXT NOT NULL, + metadata TEXT, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP +); + +ALTER TABLE documents +ADD COLUMN group_id TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.group_id')) STORED; +ALTER TABLE documents +ADD COLUMN user_id TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.user_id')) STORED; + +CREATE INDEX idx_documents_user_id ON documents(user_id); +CREATE INDEX idx_documents_group_id ON documents(group_id); \ No newline at end of file diff --git a/astrbot/core/db/vec_db/faiss_impl/vec_db.py b/astrbot/core/db/vec_db/faiss_impl/vec_db.py new file mode 100644 index 000000000..e4122e547 --- /dev/null +++ b/astrbot/core/db/vec_db/faiss_impl/vec_db.py @@ -0,0 +1,123 @@ +import uuid +import json +import numpy as np +from .document_storage import DocumentStorage +from .embedding_storage import EmbeddingStorage +from ..base import Result, BaseVecDB +from astrbot.core.provider.provider import EmbeddingProvider + + +class FaissVecDB(BaseVecDB): + """ + A class to represent a vector database. + """ + + def __init__( + self, + doc_store_path: str, + index_store_path: str, + embedding_provider: EmbeddingProvider, + ): + self.doc_store_path = doc_store_path + self.index_store_path = index_store_path + self.embedding_provider = embedding_provider + self.document_storage = DocumentStorage(doc_store_path) + self.embedding_storage = EmbeddingStorage( + embedding_provider.get_dim(), index_store_path + ) + self.embedding_provider = embedding_provider + + async def initialize(self): + await self.document_storage.initialize() + + async def insert( + self, + content: str, + metadata: dict = None, + id: str = None, + ) -> int: + """ + 插入一条文本和其对应向量,自动生成 ID 并保持一致性。 + """ + metadata = metadata or {} + str_id = id or str(uuid.uuid4()) # 使用 UUID 作为原始 ID + + # 获取向量 + vector = await self.embedding_provider.get_embedding(content) + vector = np.array(vector, dtype=np.float32) + async with self.document_storage.connection.cursor() as cursor: + await cursor.execute( + "INSERT INTO documents (doc_id, text, metadata) VALUES (?, ?, ?)", + (str_id, content, json.dumps(metadata)), + ) + await self.document_storage.connection.commit() + result = await self.document_storage.get_document_by_doc_id(str_id) + int_id = result["id"] + + # 插入向量到 FAISS + await self.embedding_storage.insert(vector, int_id) + return int_id + + async def retrieve( + self, query: str, k: int = 5, fetch_k: int = 20, metadata_filters: dict = None + ) -> list[Result]: + """ + 搜索最相似的文档。 + + Args: + query (str): 查询文本 + k (int): 返回的最相似文档的数量 + fetch_k (int): 在根据 metadata 过滤前从 FAISS 中获取的数量 + metadata_filters (dict): 元数据过滤器 + + Returns: + List[Result]: 查询结果 + """ + embedding = await self.embedding_provider.get_embedding(query) + scores, indices = await self.embedding_storage.search( + vector=np.array([embedding]).astype("float32"), + k=fetch_k if metadata_filters else k, + ) + # TODO: rerank + if len(indices[0]) == 0 or indices[0][0] == -1: + return [] + # normalize scores + scores[0] = 1.0 - (scores[0] / 2.0) + # NOTE: maybe the size is less than k. + fetched_docs = await self.document_storage.get_documents( + metadata_filters=metadata_filters or {}, ids=indices[0] + ) + if not fetched_docs: + return [] + result_docs = [] + + idx_pos = {fetch_doc["id"]: idx for idx, fetch_doc in enumerate(fetched_docs)} + for i, indice_idx in enumerate(indices[0]): + pos = idx_pos.get(indice_idx) + if pos is None: + continue + fetch_doc = fetched_docs[pos] + score = scores[0][i] + result_docs.append(Result(similarity=float(score), data=fetch_doc)) + return result_docs[:k] + + async def delete(self, doc_id: int): + """ + 删除一条文档 + """ + await self.document_storage.connection.execute( + "DELETE FROM documents WHERE doc_id = ?", (doc_id,) + ) + await self.document_storage.connection.commit() + + async def close(self): + await self.document_storage.close() + + async def count_documents(self) -> int: + """ + 计算文档数量 + """ + async with self.document_storage.connection.cursor() as cursor: + await cursor.execute("SELECT COUNT(*) FROM documents") + count = await cursor.fetchone() + return count[0] if count else 0 diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 96547c5c2..7019113c7 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -179,3 +179,20 @@ class TTSProvider(AbstractProvider): async def get_audio(self, text: str) -> str: """获取文本的音频,返回音频文件路径""" raise NotImplementedError() + + +class EmbeddingProvider(AbstractProvider): + def __init__(self, provider_config: dict, provider_settings: dict) -> None: + super().__init__(provider_config) + self.provider_config = provider_config + self.provider_settings = provider_settings + + @abc.abstractmethod + async def get_embedding(self, text: str) -> list[float]: + """获取文本的向量""" + ... + + @abc.abstractmethod + def get_dim(self) -> int: + """获取向量的维度""" + ... diff --git a/astrbot/core/rag/embedding/openai_source.py b/astrbot/core/rag/embedding/openai_source.py deleted file mode 100644 index dc09d84dc..000000000 --- a/astrbot/core/rag/embedding/openai_source.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import List -from openai import AsyncOpenAI - - -class SimpleOpenAIEmbedding: - def __init__( - self, - model, - api_key, - api_base=None, - ) -> None: - self.client = AsyncOpenAI(api_key=api_key, base_url=api_base) - self.model = model - - async def get_embedding(self, text) -> List[float]: - """ - 获取文本的嵌入 - """ - embedding = await self.client.embeddings.create(input=text, model=self.model) - return embedding.data[0].embedding diff --git a/astrbot/core/rag/knowledge_db_mgr.py b/astrbot/core/rag/knowledge_db_mgr.py deleted file mode 100644 index f1c1f386c..000000000 --- a/astrbot/core/rag/knowledge_db_mgr.py +++ /dev/null @@ -1,95 +0,0 @@ -import os -from typing import List, Dict -from astrbot.core import logger -from .store import Store -from astrbot.core.config import AstrBotConfig -from astrbot.core.utils.astrbot_path import get_astrbot_data_path - - -class KnowledgeDBManager: - def __init__(self, astrbot_config: AstrBotConfig) -> None: - self.db_path = os.path.join(get_astrbot_data_path(), "knowledge_db") - self.config = astrbot_config.get("knowledge_db", {}) - self.astrbot_config = astrbot_config - if not os.path.exists(self.db_path): - os.makedirs(self.db_path) - self.store_insts: Dict[str, Store] = {} - for name, cfg in self.config.items(): - if cfg["strategy"] == "embedding": - logger.info(f"加载 Chroma Vector Store:{name}") - try: - from .store.chroma_db import ChromaVectorStore - except ImportError as ie: - logger.error(f"{ie} 可能未安装 chromadb 库。") - continue - self.store_insts[name] = ChromaVectorStore( - name, cfg["embedding_config"] - ) - else: - logger.error(f"不支持的策略:{cfg['strategy']}") - - async def list_knowledge_db(self) -> List[str]: - return [ - f - for f in os.listdir(self.db_path) - if os.path.isfile(os.path.join(self.db_path, f)) - ] - - async def create_knowledge_db(self, name: str, config: Dict): - """ - config 格式: - ``` - { - "strategy": "embedding", # 目前只支持 embedding - "chunk_method": { - "strategy": "fixed", - "chunk_size": 100, - "overlap_size": 10 - }, - "embedding_config": { - "strategy": "openai", - "base_url": "", - "model": "", - "api_key": "" - } - } - ``` - """ - if name in self.config: - raise ValueError(f"知识库已存在:{name}") - - self.config[name] = config - self.astrbot_config["knowledge_db"] = self.config - self.astrbot_config.save_config() - - async def insert_record(self, name: str, text: str): - if name not in self.store_insts: - raise ValueError(f"未找到知识库:{name}") - - ret = [] - match self.config[name]["chunk_method"]["strategy"]: - case "fixed": - chunk_size = self.config[name]["chunk_method"]["chunk_size"] - chunk_overlap = self.config[name]["chunk_method"]["overlap_size"] - ret = self._fixed_chunk(text, chunk_size, chunk_overlap) - case _: - pass - - for chunk in ret: - await self.store_insts[name].save(chunk) - - async def retrive_records(self, name: str, query: str, top_n: int = 3) -> List[str]: - if name not in self.store_insts: - raise ValueError(f"未找到知识库:{name}") - - inst = self.store_insts[name] - return await inst.query(query, top_n) - - def _fixed_chunk(self, text: str, chunk_size: int, chunk_overlap: int) -> List[str]: - chunks = [] - start = 0 - while start < len(text): - end = start + chunk_size - chunks.append(text[start:end]) - start += chunk_size - chunk_overlap - return chunks diff --git a/astrbot/core/rag/store/__init__.py b/astrbot/core/rag/store/__init__.py deleted file mode 100644 index 0e74c5a07..000000000 --- a/astrbot/core/rag/store/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from typing import List - - -class Store: - async def save(self, text: str): - pass - - async def query(self, query: str, top_n: int = 3) -> List[str]: - pass diff --git a/astrbot/core/rag/store/chroma_db.py b/astrbot/core/rag/store/chroma_db.py deleted file mode 100644 index d4cfae946..000000000 --- a/astrbot/core/rag/store/chroma_db.py +++ /dev/null @@ -1,44 +0,0 @@ -import chromadb -import uuid -from typing import List, Dict -from astrbot.api import logger -from ..embedding.openai_source import SimpleOpenAIEmbedding -from . import Store -from astrbot.core.utils.astrbot_path import get_astrbot_data_path - - -class ChromaVectorStore(Store): - def __init__(self, name: str, embedding_cfg: Dict) -> None: - import os - self.chroma_client = chromadb.PersistentClient( - path=os.path.join(get_astrbot_data_path(), "long_term_memory_chroma.db") - ) - self.collection = self.chroma_client.get_or_create_collection(name=name) - self.embedding = None - if embedding_cfg["strategy"] == "openai": - self.embedding = SimpleOpenAIEmbedding( - model=embedding_cfg["model"], - api_key=embedding_cfg["api_key"], - api_base=embedding_cfg.get("base_url", None), - ) - - async def save(self, text: str, metadata: Dict = None): - logger.debug(f"Saving text: {text}") - embedding = await self.embedding.get_embedding(text) - - self.collection.upsert( - documents=text, - metadatas=metadata, - ids=str(uuid.uuid4()), - embeddings=embedding, - ) - - async def query( - self, query: str, top_n=3, metadata_filter: Dict = None - ) -> List[str]: - embedding = await self.embedding.get_embedding(query) - - results = self.collection.query( - query_embeddings=embedding, n_results=top_n, where=metadata_filter - ) - return results["documents"][0] diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index d2df31ec4..7cb3ffd1c 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -16,7 +16,6 @@ from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType from .filter.command import CommandFilter from .filter.regex import RegexFilter from typing import Awaitable -from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager from astrbot.core.conversation_mgr import ConversationManager from astrbot.core.star.filter.platform_adapter_type import ( PlatformAdapterType, @@ -42,6 +41,8 @@ class Context: platform_manager: PlatformManager = None + registered_web_apis: list = [] + # back compatibility _register_tasks: List[Awaitable] = [] _star_manager = None @@ -54,14 +55,12 @@ class Context: provider_manager: ProviderManager = None, platform_manager: PlatformManager = None, conversation_manager: ConversationManager = None, - knowledge_db_manager: KnowledgeDBManager = None, ): self._event_queue = event_queue self._config = config self._db = db self.provider_manager = provider_manager self.platform_manager = platform_manager - self.knowledge_db_manager = knowledge_db_manager self.conversation_manager = conversation_manager def get_registered_star(self, star_name: str) -> StarMetadata: @@ -301,3 +300,6 @@ class Context: 注册一个异步任务。 """ self._register_tasks.append(task) + + def register_web_api(self, route: str, view_handler: Awaitable, methods: list, desc: str): + self.registered_web_apis.append((route, view_handler, methods, desc)) diff --git a/astrbot/dashboard/routes/plugin.py b/astrbot/dashboard/routes/plugin.py index 9f76e2a71..af2d897c2 100644 --- a/astrbot/dashboard/routes/plugin.py +++ b/astrbot/dashboard/routes/plugin.py @@ -102,7 +102,10 @@ class PluginRoute(Route): async def get_plugins(self): _plugin_resp = [] + plugin_name = request.args.get("name") for plugin in self.plugin_manager.context.get_all_stars(): + if plugin_name and plugin.name != plugin_name: + continue _t = { "name": plugin.name, "repo": "" if plugin.repo is None else plugin.repo, diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index 124291718..d8c1a1dd9 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -15,6 +15,8 @@ from astrbot.core.db import BaseDatabase from astrbot.core.utils.io import get_local_ip_addresses from astrbot.core.utils.astrbot_path import get_astrbot_data_path +APP: Quart = None + class AstrBotDashboard: def __init__( @@ -27,6 +29,7 @@ class AstrBotDashboard: self.config = core_lifecycle.astrbot_config self.data_path = os.path.abspath(os.path.join(get_astrbot_data_path(), "dist")) self.app = Quart("dashboard", static_folder=self.data_path, static_url_path="/") + APP = self.app # noqa self.app.config["MAX_CONTENT_LENGTH"] = ( 128 * 1024 * 1024 ) # 将 Flask 允许的最大上传文件体大小设置为 128 MB @@ -51,8 +54,25 @@ class AstrBotDashboard: self.conversation_route = ConversationRoute(self.context, db, core_lifecycle) self.file_route = FileRoute(self.context) + self.app.add_url_rule( + "/api/plug/", + view_func=self.srv_plug_route, + methods=["GET", "POST"], + ) + self.shutdown_event = shutdown_event + async def srv_plug_route(self, subpath, *args, **kwargs): + """ + 插件路由 + """ + registered_web_apis = self.core_lifecycle.star_context.registered_web_apis + for api in registered_web_apis: + route, view_handler, methods, _ = api + if route == f"/{subpath}" and request.method in methods: + return await view_handler(*args, **kwargs) + return jsonify(Response().error("未找到该路由").__dict__) + async def auth_middleware(self): if not request.path.startswith("/api"): return diff --git a/dashboard/package.json b/dashboard/package.json index a7edd4935..cd621b0b7 100644 --- a/dashboard/package.json +++ b/dashboard/package.json @@ -20,6 +20,7 @@ "axios": "^1.6.2", "axios-mock-adapter": "^1.22.0", "chance": "1.1.11", + "d3": "^7.9.0", "date-fns": "2.30.0", "highlight.js": "^11.11.1", "js-md5": "^0.8.3", diff --git a/dashboard/src/layouts/full/vertical-sidebar/VerticalSidebar.vue b/dashboard/src/layouts/full/vertical-sidebar/VerticalSidebar.vue index d505e0422..82f160746 100644 --- a/dashboard/src/layouts/full/vertical-sidebar/VerticalSidebar.vue +++ b/dashboard/src/layouts/full/vertical-sidebar/VerticalSidebar.vue @@ -166,6 +166,10 @@ function endDrag() {
+ + 🔧 设置 + +
官方文档 diff --git a/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts b/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts index 803642e95..f541d1d91 100644 --- a/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts +++ b/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts @@ -65,11 +65,11 @@ const sidebarItem: menu[] = [ icon: 'mdi-console', to: '/console' }, - { - title: '设置', - icon: 'mdi-wrench', - to: '/settings' - }, + // { + // title: 'Alkaid', + // icon: 'mdi-test-tube', + // to: '/alkaid' + // }, { title: '关于', icon: 'mdi-information', diff --git a/dashboard/src/router/MainRoutes.ts b/dashboard/src/router/MainRoutes.ts index f1ac3002d..80910b8e3 100644 --- a/dashboard/src/router/MainRoutes.ts +++ b/dashboard/src/router/MainRoutes.ts @@ -57,9 +57,26 @@ const MainRoutes = { component: () => import('@/views/ConsolePage.vue') }, { - name: 'Project ATRI', - path: '/project-atri', - component: () => import('@/views/ATRIProject.vue') + name: 'Alkaid', + path: '/alkaid', + component: () => import('@/views/AlkaidPage.vue'), + children: [ + { + path: 'knowledge-base', + name: 'KnowledgeBase', + component: () => import('@/views/alkaid/KnowledgeBase.vue') + }, + { + path: 'long-term-memory', + name: 'LongTermMemory', + component: () => import('@/views/alkaid/LongTermMemory.vue') + }, + { + path: 'other', + name: 'OtherFeatures', + component: () => import('@/views/alkaid/Other.vue') + } + ] }, { name: 'Chat', diff --git a/dashboard/src/views/ATRIProject.vue b/dashboard/src/views/ATRIProject.vue deleted file mode 100644 index 4c9a771d4..000000000 --- a/dashboard/src/views/ATRIProject.vue +++ /dev/null @@ -1,87 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/dashboard/src/views/AlkaidPage.vue b/dashboard/src/views/AlkaidPage.vue new file mode 100644 index 000000000..326cddc28 --- /dev/null +++ b/dashboard/src/views/AlkaidPage.vue @@ -0,0 +1,80 @@ + + + + + \ No newline at end of file diff --git a/dashboard/src/views/AlkaidPage_sigma.vue b/dashboard/src/views/AlkaidPage_sigma.vue new file mode 100644 index 000000000..1b5180955 --- /dev/null +++ b/dashboard/src/views/AlkaidPage_sigma.vue @@ -0,0 +1,432 @@ + + + + + + + + \ No newline at end of file diff --git a/dashboard/src/views/ChatPage.vue b/dashboard/src/views/ChatPage.vue index 07eb41c22..ad149bcd0 100644 --- a/dashboard/src/views/ChatPage.vue +++ b/dashboard/src/views/ChatPage.vue @@ -12,19 +12,22 @@ marked.setOptions({
-