diff --git a/astrbot/core/__init__.py b/astrbot/core/__init__.py index 16f108ece..272a0f417 100644 --- a/astrbot/core/__init__.py +++ b/astrbot/core/__init__.py @@ -1,5 +1,4 @@ import os -import asyncio from .log import LogManager, LogBroker # noqa from astrbot.core.utils.t2i.renderer import HtmlRenderer from astrbot.core.utils.shared_preferences import SharedPreferences diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index d24d701c6..40a602d3d 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -6,8 +6,8 @@ import os from astrbot.core.utils.astrbot_path import get_astrbot_data_path -VERSION = "3.5.22" -DB_PATH = os.path.join(get_astrbot_data_path(), "data_v3.db") +VERSION = "4.0.0" +DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db") # 默认配置 DEFAULT_CONFIG = { diff --git a/astrbot/core/conversation_mgr.py b/astrbot/core/conversation_mgr.py index b665488e4..62079b344 100644 --- a/astrbot/core/conversation_mgr.py +++ b/astrbot/core/conversation_mgr.py @@ -5,13 +5,12 @@ AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json 在一个会话中可以建立多个对话, 并且支持对话的切换和删除 """ -import uuid import json import asyncio from astrbot.core import sp from typing import Dict, List from astrbot.core.db import BaseDatabase -from astrbot.core.db.po import Conversation +from astrbot.core.db.po import Conversation, ConversationV2 class ConversationManager: @@ -38,7 +37,29 @@ class ConversationManager: """保存会话对话映射关系到存储中""" sp.put("session_conversation", self.session_conversations) - async def new_conversation(self, unified_msg_origin: str) -> str: + def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation: + """将 ConversationV2 对象转换为 Conversation 对象""" + created_at = int(conv_v2.created_at.timestamp()) + updated_at = int(conv_v2.updated_at.timestamp()) + return Conversation( + platform_id=conv_v2.platform_id, + user_id=conv_v2.user_id, + cid=conv_v2.conversation_id, + history=json.dumps(conv_v2.content or []), + title=conv_v2.title, + persona_id=conv_v2.persona_id, + created_at=created_at, + updated_at=updated_at, + ) + + async def new_conversation( + self, + unified_msg_origin: str, + platform_id: str = None, + content: list[dict] = None, + title: str = None, + persona_id: str = None, + ) -> str: """新建对话,并将当前会话的对话转移到新对话 Args: @@ -46,11 +67,23 @@ class ConversationManager: Returns: conversation_id (str): 对话 ID, 是 uuid 格式的字符串 """ - conversation_id = str(uuid.uuid4()) - self.db.new_conversation(user_id=unified_msg_origin, cid=conversation_id) - self.session_conversations[unified_msg_origin] = conversation_id + if not platform_id: + # 如果没有提供 platform_id,则从 unified_msg_origin 中解析 + parts = unified_msg_origin.split(":") + if len(parts) >= 3: + platform_id = parts[0] + if not platform_id: + platform_id = "unknown" + conv = await self.db.create_conversation( + user_id=unified_msg_origin, + platform_id=platform_id, + content=content, + title=title, + persona_id=persona_id, + ) + self.session_conversations[unified_msg_origin] = conv.conversation_id sp.put("session_conversation", self.session_conversations) - return conversation_id + return str(conv.conversation_id) async def switch_conversation(self, unified_msg_origin: str, conversation_id: str): """切换会话的对话 @@ -71,11 +104,16 @@ class ConversationManager: unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id conversation_id (str): 对话 ID, 是 uuid 格式的字符串 """ - conversation_id = self.session_conversations.get(unified_msg_origin) + f = False + if not conversation_id: + conversation_id = self.session_conversations.get(unified_msg_origin) + if conversation_id: + f = True if conversation_id: - self.db.delete_conversation(user_id=unified_msg_origin, cid=conversation_id) - del self.session_conversations[unified_msg_origin] - sp.put("session_conversation", self.session_conversations) + await self.db.delete_conversation(cid=conversation_id) + if f: + self.session_conversations.pop(unified_msg_origin, None) + sp.put("session_conversation", self.session_conversations) async def get_curr_conversation_id(self, unified_msg_origin: str) -> str: """获取会话当前的对话 ID @@ -92,7 +130,7 @@ class ConversationManager: unified_msg_origin: str, conversation_id: str, create_if_not_exists: bool = False, - ) -> Conversation: + ) -> Conversation | None: """获取会话的对话 Args: @@ -101,27 +139,74 @@ class ConversationManager: Returns: conversation (Conversation): 对话对象 """ - conv = self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id) + conv = await self.db.get_conversation_by_id(cid=conversation_id) if not conv and create_if_not_exists: # 如果对话不存在且需要创建,则新建一个对话 conversation_id = await self.new_conversation(unified_msg_origin) - return self.db.get_conversation_by_user_id( - unified_msg_origin, conversation_id - ) - return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id) + conv = await self.db.get_conversation_by_id(cid=conversation_id) + conv_res = None + if conv: + conv_res = self._convert_conv_from_v2_to_v1(conv) + return conv_res - async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]: - """获取会话的所有对话 + async def get_conversations( + self, unified_msg_origin: str = None, platform_id: str = None + ) -> List[Conversation]: + """获取对话列表 Args: - unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id + unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id,可选 + platform_id (str): 平台 ID, 可选参数, 用于过滤对话 Returns: conversations (List[Conversation]): 对话对象列表 """ - return self.db.get_conversations(unified_msg_origin) + convs = await self.db.get_conversations( + user_id=unified_msg_origin, platform_id=platform_id + ) + convs_res = [] + for conv in convs: + conv_res = self._convert_conv_from_v2_to_v1(conv) + convs_res.append(conv_res) + return convs_res + + async def get_filtered_conversations( + self, + page: int = 1, + page_size: int = 20, + platform_ids: list[str] | None = None, + search_query: str = "", + **kwargs, + ) -> tuple[list[Conversation], int]: + """获取过滤后的对话列表 + + Args: + page (int): 页码, 默认为 1 + page_size (int): 每页大小, 默认为 20 + platform_ids (list[str]): 平台 ID 列表, 可选 + search_query (str): 搜索查询字符串, 可选 + Returns: + conversations (list[Conversation]): 对话对象列表 + """ + convs, cnt = await self.db.get_filtered_conversations( + page=page, + page_size=page_size, + platform_ids=platform_ids, + search_query=search_query, + **kwargs, + ) + convs_res = [] + for conv in convs: + conv_res = self._convert_conv_from_v2_to_v1(conv) + convs_res.append(conv_res) + return convs_res, cnt async def update_conversation( - self, unified_msg_origin: str, conversation_id: str, history: List[Dict] + self, + unified_msg_origin: str, + conversation_id: str = None, + history: list[dict] = None, + title: str = None, + persona_id: str = None, ): """更新会话的对话 @@ -130,40 +215,52 @@ class ConversationManager: conversation_id (str): 对话 ID, 是 uuid 格式的字符串 history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段 """ + if not conversation_id: + # 如果没有提供 conversation_id,则从 session_conversations 中获取当前的 + conversation_id = self.session_conversations.get(unified_msg_origin) if conversation_id: - self.db.update_conversation( - user_id=unified_msg_origin, + await self.db.update_conversation( cid=conversation_id, - history=json.dumps(history), + title=title, + persona_id=persona_id, + content=history or [], ) - async def update_conversation_title(self, unified_msg_origin: str, title: str): + async def update_conversation_title( + self, unified_msg_origin: str, title: str, conversation_id: str = None + ): """更新会话的对话标题 Args: unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id title (str): 对话标题 + + Deprecated: + Use `update_conversation` with `title` parameter instead. """ - conversation_id = self.session_conversations.get(unified_msg_origin) - if conversation_id: - self.db.update_conversation_title( - user_id=unified_msg_origin, cid=conversation_id, title=title - ) + await self.update_conversation( + unified_msg_origin=unified_msg_origin, + conversation_id=conversation_id, + title=title, + ) async def update_conversation_persona_id( - self, unified_msg_origin: str, persona_id: str + self, unified_msg_origin: str, persona_id: str, conversation_id: str = None ): """更新会话的对话 Persona ID Args: unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id persona_id (str): 对话 Persona ID + + Deprecated: + Use `update_conversation` with `persona_id` parameter instead. """ - conversation_id = self.session_conversations.get(unified_msg_origin) - if conversation_id: - self.db.update_conversation_persona_id( - user_id=unified_msg_origin, cid=conversation_id, persona_id=persona_id - ) + await self.update_conversation( + unified_msg_origin=unified_msg_origin, + conversation_id=conversation_id, + persona_id=persona_id, + ) async def get_human_readable_context( self, unified_msg_origin, conversation_id, page=1, page_size=10 diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index eccffbd64..8412d5bea 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -29,8 +29,10 @@ from astrbot.core.updator import AstrBotUpdator from astrbot.core import logger from astrbot.core.config.default import VERSION from astrbot.core.conversation_mgr import ConversationManager +from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager from astrbot.core.star.star_handler import star_handlers_registry, EventType from astrbot.core.star.star_handler import star_map +from astrbot.core.db.migration.helper import do_migration_v4 class AstrBotCoreLifecycle: @@ -66,6 +68,9 @@ class AstrBotCoreLifecycle: else: logger.setLevel(self.astrbot_config["log_level"]) # 设置日志级别 + await self.db.initialize() + await do_migration_v4(self.db, {}) + # 初始化事件队列 self.event_queue = Queue() @@ -78,6 +83,9 @@ class AstrBotCoreLifecycle: # 初始化对话管理器 self.conversation_manager = ConversationManager(self.db) + # 初始化平台消息历史管理器 + self.platform_message_history_manager = PlatformMessageHistoryManager(self.db) + # 初始化提供给插件的上下文 self.star_context = Context( self.event_queue, diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index 6688dcced..53fccacfc 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -1,7 +1,20 @@ import abc +import datetime +import typing as T +from deprecated import deprecated from dataclasses import dataclass -from typing import List, Dict, Any, Tuple -from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, Conversation +from astrbot.core.db.po import ( + Stats, + PlatformStat, + ConversationV2, + PlatformMessageHistory, + Attachment, + Persona, + Preference, +) +from contextlib import asynccontextmanager +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.orm import sessionmaker @dataclass @@ -10,152 +23,226 @@ class BaseDatabase(abc.ABC): 数据库基类 """ + DATABASE_URL = "" + def __init__(self) -> None: + self.engine = create_async_engine( + self.DATABASE_URL, + echo=False, + future=True, + ) + self.AsyncSessionLocal = sessionmaker( + self.engine, class_=AsyncSession, expire_on_commit=False + ) + + async def initialize(self): + """初始化数据库连接""" pass - def insert_base_metrics(self, metrics: dict): - """插入基础指标数据""" - self.insert_platform_metrics(metrics["platform_stats"]) - self.insert_plugin_metrics(metrics["plugin_stats"]) - self.insert_command_metrics(metrics["command_stats"]) - self.insert_llm_metrics(metrics["llm_stats"]) - - @abc.abstractmethod - def insert_platform_metrics(self, metrics: dict): - """插入平台指标数据""" - raise NotImplementedError - - @abc.abstractmethod - def insert_plugin_metrics(self, metrics: dict): - """插入插件指标数据""" - raise NotImplementedError - - @abc.abstractmethod - def insert_command_metrics(self, metrics: dict): - """插入指令指标数据""" - raise NotImplementedError - - @abc.abstractmethod - def insert_llm_metrics(self, metrics: dict): - """插入 LLM 指标数据""" - raise NotImplementedError - - @abc.abstractmethod - def update_llm_history(self, session_id: str, content: str, provider_type: str): - """更新 LLM 历史记录。当不存在 session_id 时插入""" - raise NotImplementedError - - @abc.abstractmethod - def get_llm_history( - self, session_id: str = None, provider_type: str = None - ) -> List[LLMHistory]: - """获取 LLM 历史记录, 如果 session_id 为 None, 返回所有""" - raise NotImplementedError + @asynccontextmanager + async def get_db(self) -> T.AsyncGenerator[AsyncSession, None]: + """Get a database session.""" + if not self.inited: + await self.initialize() + self.inited = True + async with self.AsyncSessionLocal() as session: + yield session + @deprecated(version="4.0.0", reason="Use get_platform_stats instead") @abc.abstractmethod def get_base_stats(self, offset_sec: int = 86400) -> Stats: """获取基础统计数据""" raise NotImplementedError + @deprecated(version="4.0.0", reason="Use get_platform_stats instead") @abc.abstractmethod def get_total_message_count(self) -> int: """获取总消息数""" raise NotImplementedError + @deprecated(version="4.0.0", reason="Use get_platform_stats instead") @abc.abstractmethod def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats: """获取基础统计数据(合并)""" raise NotImplementedError - @abc.abstractmethod - def insert_atri_vision_data(self, vision_data: ATRIVision): - """插入 ATRI 视觉数据""" - raise NotImplementedError + # New methods in v4.0.0 @abc.abstractmethod - def get_atri_vision_data(self) -> List[ATRIVision]: - """获取 ATRI 视觉数据""" - raise NotImplementedError + async def insert_platform_stats( + self, + platform_id: str, + platform_type: str, + count: int = 1, + timestamp: datetime.datetime = None, + ) -> None: + """Insert a new platform statistic record.""" + ... @abc.abstractmethod - def get_atri_vision_data_by_path_or_id( - self, url_or_path: str, id: str - ) -> ATRIVision: - """通过 url 或 path 获取 ATRI 视觉数据""" - raise NotImplementedError + async def count_platform_stats(self) -> int: + """Count the number of platform statistics records.""" + ... @abc.abstractmethod - def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation: - """通过 user_id 和 cid 获取 Conversation""" - raise NotImplementedError + async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]: + """Get platform statistics within the specified offset in seconds and group by platform_id.""" + ... @abc.abstractmethod - def new_conversation(self, user_id: str, cid: str): - """新建 Conversation""" - raise NotImplementedError + async def get_conversations( + self, user_id: str = None, platform_id: str = None + ) -> list[ConversationV2]: + """Get all conversations for a specific user and platform_id(optional). - @abc.abstractmethod - def get_conversations(self, user_id: str) -> List[Conversation]: - raise NotImplementedError - - @abc.abstractmethod - def update_conversation(self, user_id: str, cid: str, history: str): - """更新 Conversation""" - raise NotImplementedError - - @abc.abstractmethod - def delete_conversation(self, user_id: str, cid: str): - """删除 Conversation""" - raise NotImplementedError - - @abc.abstractmethod - def update_conversation_title(self, user_id: str, cid: str, title: str): - """更新 Conversation 标题""" - raise NotImplementedError - - @abc.abstractmethod - def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str): - """更新 Conversation Persona ID""" - raise NotImplementedError - - @abc.abstractmethod - def get_all_conversations( - self, page: int = 1, page_size: int = 20 - ) -> Tuple[List[Dict[str, Any]], int]: - """获取所有对话,支持分页 - - Args: - page: 页码,从1开始 - page_size: 每页数量 - - Returns: - Tuple[List[Dict[str, Any]], int]: 返回一个元组,包含对话列表和总对话数 + content is not included in the result. """ - raise NotImplementedError + ... @abc.abstractmethod - def get_filtered_conversations( + async def get_conversation_by_id(self, cid: str) -> ConversationV2: + """Get a specific conversation by its ID.""" + ... + + @abc.abstractmethod + async def get_all_conversations( + self, page: int = 1, page_size: int = 20 + ) -> list[ConversationV2]: + """Get all conversations with pagination.""" + ... + + @abc.abstractmethod + async def get_filtered_conversations( self, page: int = 1, page_size: int = 20, - platforms: List[str] = None, - message_types: List[str] = None, - search_query: str = None, - exclude_ids: List[str] = None, - exclude_platforms: List[str] = None, - ) -> Tuple[List[Dict[str, Any]], int]: - """获取筛选后的对话列表 + platform_ids: list[str] | None = None, + search_query: str = "", + **kwargs, + ) -> tuple[list[ConversationV2], int]: + """Get conversations filtered by platform IDs and search query.""" + ... - Args: - page: 页码 - page_size: 每页数量 - platforms: 平台筛选列表 - message_types: 消息类型筛选列表 - search_query: 搜索关键词 - exclude_ids: 排除的用户ID列表 - exclude_platforms: 排除的平台列表 + @abc.abstractmethod + async def create_conversation( + self, + user_id: str, + platform_id: str, + content: list[dict] = None, + title: str = None, + persona_id: str = None, + cid: str = None, + created_at: datetime.datetime = None, + updated_at: datetime.datetime = None, + ) -> ConversationV2: + """Create a new conversation.""" + ... - Returns: - Tuple[List[Dict[str, Any]], int]: 返回一个元组,包含对话列表和总对话数 - """ - raise NotImplementedError + @abc.abstractmethod + async def update_conversation( + self, + cid: str, + title: str = None, + persona_id: str = None, + content: list[dict] = None, + ) -> None: + """Update a conversation's history.""" + ... + + @abc.abstractmethod + async def delete_conversation(self, cid: str) -> None: + """Delete a conversation by its ID.""" + ... + + @abc.abstractmethod + async def insert_platform_message_history( + self, + platform_id: str, + user_id: str, + content: list[dict], + sender_id: str = None, + sender_name: str = None, + ) -> None: + """Insert a new platform message history record.""" + ... + + @abc.abstractmethod + async def delete_platform_message_offset( + self, platform_id: str, user_id: str, offset_sec: int = 86400 + ) -> None: + """Delete platform message history records older than the specified offset.""" + ... + + @abc.abstractmethod + async def get_platform_message_history( + self, + platform_id: str, + user_id: str, + page: int = 1, + page_size: int = 20, + ) -> list[PlatformMessageHistory]: + """Get platform message history for a specific user.""" + ... + + @abc.abstractmethod + async def insert_attachment( + self, + path: str, + type: str, + mime_type: str, + ): + """Insert a new attachment record.""" + ... + + @abc.abstractmethod + async def get_attachment_by_id(self, attachment_id: str) -> Attachment: + """Get an attachment by its ID.""" + ... + + @abc.abstractmethod + async def insert_persona( + self, + persona_id: str, + system_prompt: str, + begin_dialogs: list[str] = None, + ) -> Persona: + """Insert a new persona record.""" + ... + + @abc.abstractmethod + async def get_persona_by_id(self, persona_id: str) -> Persona: + """Get a persona by its ID.""" + ... + + @abc.abstractmethod + async def get_personas(self) -> list[Persona]: + """Get all personas for a specific bot.""" + ... + + @abc.abstractmethod + async def insert_preference_or_update(self, key: str, value: str) -> Preference: + """Insert a new preference record.""" + ... + + @abc.abstractmethod + async def get_preference(self, key: str) -> Preference: + """Get a preference by bot ID and key.""" + ... + + # @abc.abstractmethod + # async def insert_llm_message( + # self, + # cid: str, + # role: str, + # content: list, + # tool_calls: list = None, + # tool_call_id: str = None, + # parent_id: str = None, + # ) -> LLMMessage: + # """Insert a new LLM message into the conversation.""" + # ... + + # @abc.abstractmethod + # async def get_llm_messages(self, cid: str) -> list[LLMMessage]: + # """Get all LLM messages for a specific conversation.""" + # ... diff --git a/astrbot/core/db/migration/helper.py b/astrbot/core/db/migration/helper.py new file mode 100644 index 000000000..d4b9c99f9 --- /dev/null +++ b/astrbot/core/db/migration/helper.py @@ -0,0 +1,51 @@ +import os +from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from astrbot.core.db import BaseDatabase +from astrbot.api import logger +from .migra_3_to_4 import ( + migration_conversation_table, + migration_platform_table, + migration_webchat_data, +) + + +async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool: + """ + 检查是否需要进行数据库迁移 + 如果存在 data_v3.db 并且 preference 中没有 migration_done_v4,则需要进行迁移。 + """ + data_v3_exists = os.path.exists(get_astrbot_data_path()) + if not data_v3_exists: + return False + migration_done = await db_helper.get_preference("migration_done_v4") + if migration_done: + return False + return True + + +async def do_migration_v4( + db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]] +): + """ + 执行数据库迁移 + 迁移旧的 webchat_conversation 表到新的 conversation 表。 + 迁移旧的 platform 到新的 platform_stats 表。 + """ + if not await check_migration_needed_v4(db_helper): + return + + logger.info("开始执行数据库迁移...") + + # 执行会话表迁移 + await migration_conversation_table(db_helper, platform_id_map) + + # 执行平台统计表迁移 + await migration_platform_table(db_helper, platform_id_map) + + # 执行 WebChat 数据迁移 + await migration_webchat_data(db_helper, platform_id_map) + + # 标记迁移完成 + await db_helper.insert_preference_or_update("migration_done_v4", "true") + + logger.info("数据库迁移完成。") diff --git a/astrbot/core/db/migration/migra_3_to_4.py b/astrbot/core/db/migration/migra_3_to_4.py new file mode 100644 index 000000000..b6de3214e --- /dev/null +++ b/astrbot/core/db/migration/migra_3_to_4.py @@ -0,0 +1,214 @@ +import json +import datetime +from .. import BaseDatabase +from .sqlite_v3 import SQLiteDatabase as SQLiteV3DatabaseV3 +from astrbot.core.config.default import DB_PATH +from astrbot.api import logger +from astrbot.core.platform.astr_message_event import MessageSesion +from sqlalchemy.ext.asyncio import AsyncSession +from astrbot.core.db.po import ( + ConversationV2, + PlatformMessageHistory, +) +from sqlalchemy import text + +""" +1. 迁移旧的 webchat_conversation 表到新的 conversation 表。 +2. 迁移旧的 platform 到新的 platform_stats 表。 +""" + + +def get_platform_id( + platform_id_map: dict[str, dict[str, str]], old_platform_name: str +) -> str: + return platform_id_map.get( + old_platform_name, + {"platform_id": old_platform_name, "platform_type": old_platform_name}, + ).get("platform_id", old_platform_name) + + +def get_platform_type( + platform_id_map: dict[str, dict[str, str]], old_platform_name: str +) -> str: + return platform_id_map.get( + old_platform_name, + {"platform_id": old_platform_name, "platform_type": old_platform_name}, + ).get("platform_type", old_platform_name) + + +async def migration_conversation_table( + db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]] +): + db_helper_v3 = SQLiteV3DatabaseV3( + db_path=DB_PATH.replace("data_v4.db", "data_v3.db") + ) + conversations, total_cnt = db_helper_v3.get_all_conversations( + page=1, page_size=10000000 + ) + logger.info(f"迁移 {total_cnt} 条旧的会话数据到新的表中...") + + async with db_helper.get_db() as dbsession: + dbsession: AsyncSession + async with dbsession.begin(): + for conversation in conversations: + try: + conv = db_helper_v3.get_conversation_by_user_id( + user_id=conversation.get("user_id", "unknown"), + cid=conversation.get("cid", "unknown"), + ) + if not conv: + logger.warning( + f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。" + ) + if ":" not in conv.user_id: + logger.warning( + f"跳过 user_id 为 {conv.user_id} 的会话,它可能是 WebChat 的消息历史记录。" + ) + continue + session = MessageSesion.from_str(session_str=conv.user_id) + platform_id = get_platform_id( + platform_id_map, session.platform_name + ) + session.platform_name = platform_id # 更新平台名称为新的 ID + conv_v2 = ConversationV2( + user_id=str(session), + content=json.loads(conv.history) if conv.history else [], + platform_id=platform_id, + title=conv.title, + persona_id=conv.persona_id, + conversation_id=conv.cid, + created_at=datetime.datetime.fromtimestamp(conv.created_at), + updated_at=datetime.datetime.fromtimestamp(conv.updated_at), + ) + dbsession.add(conv_v2) + if conv_v2: + logger.info(f"迁移旧会话 {conv.cid} 到新表成功。") + except Exception as e: + logger.error( + f"迁移旧会话 {conversation.get('cid', 'unknown')} 失败: {e}", + exc_info=True, + ) + + +async def migration_platform_table( + db_helper: BaseDatabase, platform_id_map: dict[str, str] +): + db_helper_v3 = SQLiteV3DatabaseV3( + db_path=DB_PATH.replace("data_v4.db", "data_v3.db") + ) + secs_from_2023_4_10_to_now = ( + datetime.datetime.now(datetime.timezone.utc) + - datetime.datetime(2023, 4, 10, tzinfo=datetime.timezone.utc) + ).total_seconds() + offset_sec = int(secs_from_2023_4_10_to_now) + logger.info(f"迁移旧平台数据,offset_sec: {offset_sec} 秒。") + stats = db_helper_v3.get_base_stats(offset_sec=offset_sec) + logger.info(f"迁移 {len(stats.platform)} 条旧的平台数据到新的表中...") + platform_stats_v3 = stats.platform + + if not platform_stats_v3: + logger.warning("没有找到旧平台数据,跳过迁移。") + return + + first_time_stamp = platform_stats_v3[0].timestamp + end_time_stamp = platform_stats_v3[-1].timestamp + start_time = first_time_stamp - (first_time_stamp % 3600) # 向下取整到小时 + end_time = end_time_stamp + (3600 - (end_time_stamp % 3600)) # 向上取整到小时 + + idx = 0 + + async with db_helper.get_db() as dbsession: + dbsession: AsyncSession + async with dbsession.begin(): + for bucket_end in range(start_time, end_time, 3600): + cnt = 0 + while ( + idx < len(platform_stats_v3) + and platform_stats_v3[idx].timestamp < bucket_end + ): + cnt += platform_stats_v3[idx].count + idx += 1 + if cnt == 0: + continue + platform_id = get_platform_id( + platform_id_map, platform_stats_v3[idx].name + ) + platform_type = get_platform_type( + platform_id_map, platform_stats_v3[idx].name + ) + logger.info( + f"迁移平台统计数据: {platform_id}, {platform_type}, 时间戳: {bucket_end}, 计数: {cnt}" + ) + try: + await dbsession.execute( + text(""" + INSERT INTO platform_stats (timestamp, platform_id, platform_type, count) + VALUES (:timestamp, :platform_id, :platform_type, :count) + ON CONFLICT(timestamp, platform_id, platform_type) DO UPDATE SET + count = platform_stats.count + EXCLUDED.count + """), + { + "timestamp": datetime.datetime.fromtimestamp( + bucket_end, tz=datetime.timezone.utc + ), + "platform_id": platform_id, + "platform_type": platform_type, + "count": cnt, + }, + ) + except Exception: + logger.error( + f"迁移平台统计数据失败: {platform_id}, {platform_type}, 时间戳: {bucket_end}", + exc_info=True, + ) + + +async def migration_webchat_data( + db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]] +): + """迁移 WebChat 的历史记录到新的 PlatformMessageHistory 表中""" + db_helper_v3 = SQLiteV3DatabaseV3( + db_path=DB_PATH.replace("data_v4.db", "data_v3.db") + ) + conversations, total_cnt = db_helper_v3.get_all_conversations( + page=1, page_size=10000000 + ) + logger.info(f"迁移 {total_cnt} 条旧的 WebChat 会话数据到新的表中...") + + async with db_helper.get_db() as dbsession: + dbsession: AsyncSession + async with dbsession.begin(): + for conversation in conversations: + try: + conv = db_helper_v3.get_conversation_by_user_id( + user_id=conversation.get("user_id", "unknown"), + cid=conversation.get("cid", "unknown"), + ) + if not conv: + logger.warning( + f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。" + ) + if ":" in conv.user_id: + logger.warning( + f"跳过 user_id 为 {conv.user_id} 的会话,它不是 WebChat 的消息历史记录。" + ) + continue + platform_id = "webchat" + history = json.loads(conv.history) if conv.history else [] + for msg in history: + type_ = msg.get("type") # user type, "bot" or "user" + new_history = PlatformMessageHistory( + platform_id=platform_id, + user_id=conv.cid, # we use conv.cid as user_id for webchat + content=msg, + sender_id=type_, + sender_name=type_, + ) + dbsession.add(new_history) + + logger.info(f"迁移旧 WebChat 会话 {conv.cid} 到新表成功。") + except Exception: + logger.error( + f"迁移旧 WebChat 会话 {conversation.get('cid', 'unknown')} 失败", + exc_info=True, + ) diff --git a/astrbot/core/db/migration/sqlite_v3.py b/astrbot/core/db/migration/sqlite_v3.py new file mode 100644 index 000000000..e7e734abd --- /dev/null +++ b/astrbot/core/db/migration/sqlite_v3.py @@ -0,0 +1,493 @@ +import sqlite3 +import time +from astrbot.core.db.po import Platform, Stats +from typing import Tuple, List, Dict, Any +from dataclasses import dataclass + +@dataclass +class Conversation: + """LLM 对话存储 + + 对于网页聊天,history 存储了包括指令、回复、图片等在内的所有消息。 + 对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。 + """ + + user_id: str + cid: str + history: str = "" + """字符串格式的列表。""" + created_at: int = 0 + updated_at: int = 0 + title: str = "" + persona_id: str = "" + + +INIT_SQL = """ +CREATE TABLE IF NOT EXISTS platform( + name VARCHAR(32), + count INTEGER, + timestamp INTEGER +); +CREATE TABLE IF NOT EXISTS llm( + name VARCHAR(32), + count INTEGER, + timestamp INTEGER +); +CREATE TABLE IF NOT EXISTS plugin( + name VARCHAR(32), + count INTEGER, + timestamp INTEGER +); +CREATE TABLE IF NOT EXISTS command( + name VARCHAR(32), + count INTEGER, + timestamp INTEGER +); +CREATE TABLE IF NOT EXISTS llm_history( + provider_type VARCHAR(32), + session_id VARCHAR(32), + content TEXT +); + +-- ATRI +CREATE TABLE IF NOT EXISTS atri_vision( + id TEXT, + url_or_path TEXT, + caption TEXT, + is_meme BOOLEAN, + keywords TEXT, + platform_name VARCHAR(32), + session_id VARCHAR(32), + sender_nickname VARCHAR(32), + timestamp INTEGER +); + +CREATE TABLE IF NOT EXISTS webchat_conversation( + user_id TEXT, -- 会话 id + cid TEXT, -- 对话 id + history TEXT, + created_at INTEGER, + updated_at INTEGER, + title TEXT, + persona_id TEXT +); + +PRAGMA encoding = 'UTF-8'; +""" + + +class SQLiteDatabase(): + def __init__(self, db_path: str) -> None: + super().__init__() + self.db_path = db_path + + sql = INIT_SQL + + # 初始化数据库 + self.conn = self._get_conn(self.db_path) + c = self.conn.cursor() + c.executescript(sql) + self.conn.commit() + + # 检查 webchat_conversation 的 title 字段是否存在 + c.execute( + """ + PRAGMA table_info(webchat_conversation) + """ + ) + res = c.fetchall() + has_title = False + has_persona_id = False + for row in res: + if row[1] == "title": + has_title = True + if row[1] == "persona_id": + has_persona_id = True + if not has_title: + c.execute( + """ + ALTER TABLE webchat_conversation ADD COLUMN title TEXT; + """ + ) + self.conn.commit() + if not has_persona_id: + c.execute( + """ + ALTER TABLE webchat_conversation ADD COLUMN persona_id TEXT; + """ + ) + self.conn.commit() + + c.close() + + def _get_conn(self, db_path: str) -> sqlite3.Connection: + conn = sqlite3.connect(self.db_path) + conn.text_factory = str + return conn + + def _exec_sql(self, sql: str, params: Tuple = None): + conn = self.conn + try: + c = self.conn.cursor() + except sqlite3.ProgrammingError: + conn = self._get_conn(self.db_path) + c = conn.cursor() + + if params: + c.execute(sql, params) + c.close() + else: + c.execute(sql) + c.close() + + conn.commit() + + def insert_platform_metrics(self, metrics: dict): + for k, v in metrics.items(): + self._exec_sql( + """ + INSERT INTO platform(name, count, timestamp) VALUES (?, ?, ?) + """, + (k, v, int(time.time())), + ) + + def insert_llm_metrics(self, metrics: dict): + for k, v in metrics.items(): + self._exec_sql( + """ + INSERT INTO llm(name, count, timestamp) VALUES (?, ?, ?) + """, + (k, v, int(time.time())), + ) + + def get_base_stats(self, offset_sec: int = 86400) -> Stats: + """获取 offset_sec 秒前到现在的基础统计数据""" + where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}" + + try: + c = self.conn.cursor() + except sqlite3.ProgrammingError: + c = self._get_conn(self.db_path).cursor() + + c.execute( + """ + SELECT * FROM platform + """ + + where_clause + ) + + platform = [] + for row in c.fetchall(): + platform.append(Platform(*row)) + + c.close() + + return Stats(platform=platform) + + def get_total_message_count(self) -> int: + try: + c = self.conn.cursor() + except sqlite3.ProgrammingError: + c = self._get_conn(self.db_path).cursor() + + c.execute( + """ + SELECT SUM(count) FROM platform + """ + ) + res = c.fetchone() + c.close() + return res[0] + + def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats: + """获取 offset_sec 秒前到现在的基础统计数据(合并)""" + where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}" + + try: + c = self.conn.cursor() + except sqlite3.ProgrammingError: + c = self._get_conn(self.db_path).cursor() + + c.execute( + """ + SELECT name, SUM(count), timestamp FROM platform + """ + + where_clause + + " GROUP BY name" + ) + + platform = [] + for row in c.fetchall(): + platform.append(Platform(*row)) + + c.close() + + return Stats(platform, [], []) + + def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation: + try: + c = self.conn.cursor() + except sqlite3.ProgrammingError: + c = self._get_conn(self.db_path).cursor() + + c.execute( + """ + SELECT * FROM webchat_conversation WHERE user_id = ? AND cid = ? + """, + (user_id, cid), + ) + + res = c.fetchone() + c.close() + + if not res: + return + + return Conversation(*res) + + def new_conversation(self, user_id: str, cid: str): + history = "[]" + updated_at = int(time.time()) + created_at = updated_at + self._exec_sql( + """ + INSERT INTO webchat_conversation(user_id, cid, history, updated_at, created_at) VALUES (?, ?, ?, ?, ?) + """, + (user_id, cid, history, updated_at, created_at), + ) + + def get_conversations(self, user_id: str) -> Tuple: + try: + c = self.conn.cursor() + except sqlite3.ProgrammingError: + c = self._get_conn(self.db_path).cursor() + + c.execute( + """ + SELECT cid, created_at, updated_at, title, persona_id FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC + """, + (user_id,), + ) + + res = c.fetchall() + c.close() + conversations = [] + for row in res: + cid = row[0] + created_at = row[1] + updated_at = row[2] + title = row[3] + persona_id = row[4] + conversations.append( + Conversation("", cid, "[]", created_at, updated_at, title, persona_id) + ) + return conversations + + def update_conversation(self, user_id: str, cid: str, history: str): + """更新对话,并且同时更新时间""" + updated_at = int(time.time()) + self._exec_sql( + """ + UPDATE webchat_conversation SET history = ?, updated_at = ? WHERE user_id = ? AND cid = ? + """, + (history, updated_at, user_id, cid), + ) + + def update_conversation_title(self, user_id: str, cid: str, title: str): + self._exec_sql( + """ + UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ? + """, + (title, user_id, cid), + ) + + def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str): + self._exec_sql( + """ + UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ? + """, + (persona_id, user_id, cid), + ) + + def delete_conversation(self, user_id: str, cid: str): + self._exec_sql( + """ + DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ? + """, + (user_id, cid), + ) + + def get_all_conversations( + self, page: int = 1, page_size: int = 20 + ) -> Tuple[List[Dict[str, Any]], int]: + """获取所有对话,支持分页,按更新时间降序排序""" + try: + c = self.conn.cursor() + except sqlite3.ProgrammingError: + c = self._get_conn(self.db_path).cursor() + + try: + # 获取总记录数 + c.execute(""" + SELECT COUNT(*) FROM webchat_conversation + """) + total_count = c.fetchone()[0] + + # 计算偏移量 + offset = (page - 1) * page_size + + # 获取分页数据,按更新时间降序排序 + c.execute( + """ + SELECT user_id, cid, created_at, updated_at, title, persona_id + FROM webchat_conversation + ORDER BY updated_at DESC + LIMIT ? OFFSET ? + """, + (page_size, offset), + ) + + rows = c.fetchall() + + conversations = [] + + for row in rows: + user_id, cid, created_at, updated_at, title, persona_id = row + # 确保 cid 是字符串类型且至少有8个字符,否则使用一个默认值 + safe_cid = str(cid) if cid else "unknown" + display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid + + conversations.append( + { + "user_id": user_id or "", + "cid": safe_cid, + "title": title or f"对话 {display_cid}", + "persona_id": persona_id or "", + "created_at": created_at or 0, + "updated_at": updated_at or 0, + } + ) + + return conversations, total_count + + except Exception as _: + # 返回空列表和0,确保即使出错也有有效的返回值 + return [], 0 + finally: + c.close() + + def get_filtered_conversations( + self, + page: int = 1, + page_size: int = 20, + platforms: List[str] = None, + message_types: List[str] = None, + search_query: str = None, + exclude_ids: List[str] = None, + exclude_platforms: List[str] = None, + ) -> Tuple[List[Dict[str, Any]], int]: + """获取筛选后的对话列表""" + try: + c = self.conn.cursor() + except sqlite3.ProgrammingError: + c = self._get_conn(self.db_path).cursor() + + try: + # 构建查询条件 + where_clauses = [] + params = [] + + # 平台筛选 + if platforms and len(platforms) > 0: + platform_conditions = [] + for platform in platforms: + platform_conditions.append("user_id LIKE ?") + params.append(f"{platform}:%") + + if platform_conditions: + where_clauses.append(f"({' OR '.join(platform_conditions)})") + + # 消息类型筛选 + if message_types and len(message_types) > 0: + message_type_conditions = [] + for msg_type in message_types: + message_type_conditions.append("user_id LIKE ?") + params.append(f"%:{msg_type}:%") + + if message_type_conditions: + where_clauses.append(f"({' OR '.join(message_type_conditions)})") + + # 搜索关键词 + if search_query: + search_query = search_query.encode("unicode_escape").decode("utf-8") + where_clauses.append( + "(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)" + ) + search_param = f"%{search_query}%" + params.extend([search_param, search_param, search_param, search_param]) + + # 排除特定用户ID + if exclude_ids and len(exclude_ids) > 0: + for exclude_id in exclude_ids: + where_clauses.append("user_id NOT LIKE ?") + params.append(f"{exclude_id}%") + + # 排除特定平台 + if exclude_platforms and len(exclude_platforms) > 0: + for exclude_platform in exclude_platforms: + where_clauses.append("user_id NOT LIKE ?") + params.append(f"{exclude_platform}:%") + + # 构建完整的 WHERE 子句 + where_sql = " WHERE " + " AND ".join(where_clauses) if where_clauses else "" + + # 构建计数查询 + count_sql = f"SELECT COUNT(*) FROM webchat_conversation{where_sql}" + + # 获取总记录数 + c.execute(count_sql, params) + total_count = c.fetchone()[0] + + # 计算偏移量 + offset = (page - 1) * page_size + + # 构建分页数据查询 + data_sql = f""" + SELECT user_id, cid, created_at, updated_at, title, persona_id + FROM webchat_conversation + {where_sql} + ORDER BY updated_at DESC + LIMIT ? OFFSET ? + """ + query_params = params + [page_size, offset] + + # 获取分页数据 + c.execute(data_sql, query_params) + rows = c.fetchall() + + conversations = [] + + for row in rows: + user_id, cid, created_at, updated_at, title, persona_id = row + # 确保 cid 是字符串类型,否则使用一个默认值 + safe_cid = str(cid) if cid else "unknown" + display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid + + conversations.append( + { + "user_id": user_id or "", + "cid": safe_cid, + "title": title or f"对话 {display_cid}", + "persona_id": persona_id or "", + "created_at": created_at or 0, + "updated_at": updated_at or 0, + } + ) + + return conversations, total_count + + except Exception as _: + # 返回空列表和0,确保即使出错也有有效的返回值 + return [], 0 + finally: + c.close() diff --git a/astrbot/core/db/po.py b/astrbot/core/db/po.py index 49adb2781..30f7188a1 100644 --- a/astrbot/core/db/po.py +++ b/astrbot/core/db/po.py @@ -1,7 +1,190 @@ -"""指标数据""" +import uuid +from datetime import datetime, timezone from dataclasses import dataclass, field -from typing import List +from sqlmodel import ( + SQLModel, + Text, + JSON, + UniqueConstraint, + Field, +) +from typing import Optional + + +class PlatformStat(SQLModel, table=True): + """This class represents the statistics of bot usage across different platforms. + + Note: In astrbot v4, we moved `platform` table to here. + """ + + __tablename__ = "platform_stats" + + id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True}) + timestamp: datetime = Field(nullable=False) + platform_id: str = Field(nullable=False) + platform_type: str = Field(nullable=False) # such as "aiocqhttp", "slack", etc. + count: int = Field(default=0, nullable=False) + + __table_args__ = ( + UniqueConstraint( + "timestamp", + "platform_id", + "platform_type", + name="uix_platform_stats", + ), + ) + + +class ConversationV2(SQLModel, table=True): + __tablename__ = "conversations" + + inner_conversation_id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True}) + conversation_id: str = Field( + max_length=36, + nullable=False, + unique=True, + default_factory=lambda: str(uuid.uuid4()) + ) + platform_id: str = Field(nullable=False) + user_id: str = Field(nullable=False) + content: Optional[list] = Field(default=None, sa_type=JSON) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field( + default_factory=lambda: datetime.now(timezone.utc), + sa_column_kwargs={"onupdate": datetime.now(timezone.utc)}, + ) + title: Optional[str] = Field(default=None, max_length=255) + persona_id: Optional[str] = Field(default=None) + + __table_args__ = ( + UniqueConstraint( + "conversation_id", + name="uix_conversation_id", + ), + ) + + +class Persona(SQLModel, table=True): + """Persona is a set of instructions for LLMs to follow. + + It can be used to customize the behavior of LLMs. + """ + + __tablename__ = "personas" + + id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True}) + persona_id: str = Field(max_length=255, nullable=False) + system_prompt: str = Field(sa_type=Text, nullable=False) + begin_dialogs: Optional[list] = Field(default=None, sa_type=JSON) + """a list of strings, each representing a dialog to start with""" + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field( + default_factory=lambda: datetime.now(timezone.utc), + sa_column_kwargs={"onupdate": datetime.now(timezone.utc)}, + ) + + __table_args__ = ( + UniqueConstraint( + "persona_id", + name="uix_persona_id", + ), + ) + + +class Preference(SQLModel, table=True): + """This class represents user preferences for bots.""" + + __tablename__ = "preferences" + + key: str = Field(primary_key=True, nullable=False) + value: str = Field(sa_type=Text, nullable=False) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field( + default_factory=lambda: datetime.now(timezone.utc), + sa_column_kwargs={"onupdate": datetime.now(timezone.utc)}, + ) + + +class PlatformMessageHistory(SQLModel, table=True): + """This class represents the message history for a specific platform. + + It is used to store messages that are not LLM-generated, such as user messages + or platform-specific messages. + """ + + __tablename__ = "platform_message_history" + + id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True}) + platform_id: str = Field(nullable=False) + user_id: str = Field(nullable=False) # An id of group, user in platform + sender_id: Optional[str] = Field(default=None) # ID of the sender in the platform + sender_name: Optional[str] = Field(default=None) # Name of the sender in the platform + content: dict = Field(sa_type=JSON, nullable=False) # a message chain list + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field( + default_factory=lambda: datetime.now(timezone.utc), + sa_column_kwargs={"onupdate": datetime.now(timezone.utc)}, + ) + + +class Attachment(SQLModel, table=True): + """This class represents attachments for messages in AstrBot. + + Attachments can be images, files, or other media types. + """ + + __tablename__ = "attachments" + + inner_attachment_id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True}) + attachment_id: str = Field( + max_length=36, + nullable=False, + unique=True, + default_factory=lambda: str(uuid.uuid4()) + ) + path: str = Field(nullable=False) # Path to the file on disk + type: str = Field(nullable=False) # Type of the file (e.g., 'image', 'file') + mime_type: str = Field(nullable=False) # MIME type of the file + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field( + default_factory=lambda: datetime.now(timezone.utc), + sa_column_kwargs={"onupdate": datetime.now(timezone.utc)}, + ) + + __table_args__ = ( + UniqueConstraint( + "attachment_id", + name="uix_attachment_id", + ), + ) + + +@dataclass +class Conversation: + """LLM 对话类 + + 对于 WebChat,history 存储了包括指令、回复、图片等在内的所有消息。 + 对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。 + + 在 v4.0.0 版本及之后,WebChat 的历史记录被迁移至 `PlatformMessageHistory` 表中, + """ + + platform_id: str + user_id: str + cid: str + """对话 ID, 是 uuid 格式的字符串""" + history: str = "" + """字符串格式的对话列表。""" + title: str = "" + persona_id: str = "" + created_at: int = 0 + updated_at: int = 0 + + +# ==== +# Deprecated, and will be removed in future versions. +# ==== @dataclass @@ -13,77 +196,6 @@ class Platform: timestamp: int -@dataclass -class Provider: - """供应商使用统计数据""" - - name: str - count: int - timestamp: int - - -@dataclass -class Plugin: - """插件使用统计数据""" - - name: str - count: int - timestamp: int - - -@dataclass -class Command: - """命令使用统计数据""" - - name: str - count: int - timestamp: int - - @dataclass class Stats: - platform: List[Platform] = field(default_factory=list) - command: List[Command] = field(default_factory=list) - llm: List[Provider] = field(default_factory=list) - - -@dataclass -class LLMHistory: - """LLM 聊天时持久化的信息""" - - provider_type: str - session_id: str - content: str - - -@dataclass -class ATRIVision: - """Deprecated""" - - id: str - url_or_path: str - caption: str - is_meme: bool - keywords: List[str] - platform_name: str - session_id: str - sender_nickname: str - timestamp: int = -1 - - -@dataclass -class Conversation: - """LLM 对话存储 - - 对于网页聊天,history 存储了包括指令、回复、图片等在内的所有消息。 - 对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。 - """ - - user_id: str - cid: str - history: str = "" - """字符串格式的列表。""" - created_at: int = 0 - updated_at: int = 0 - title: str = "" - persona_id: str = "" + platform: list[Platform] = field(default_factory=list) diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index 2abba1de9..0308807da 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -1,567 +1,461 @@ -import sqlite3 -import os -import time -from astrbot.core.db.po import Platform, Stats, LLMHistory, ATRIVision, Conversation -from . import BaseDatabase -from typing import Tuple, List, Dict, Any +import asyncio +import typing as T +import threading +from datetime import datetime, timedelta +from astrbot.core.db import BaseDatabase +from astrbot.core.db.po import ( + ConversationV2, + PlatformStat, + PlatformMessageHistory, + Attachment, + Persona, + Preference, + Stats as DeprecatedStats, + Platform as DeprecatedPlatformStat, + SQLModel, +) + +from sqlalchemy import select, update, delete, text +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.sql import func class SQLiteDatabase(BaseDatabase): def __init__(self, db_path: str) -> None: - super().__init__() self.db_path = db_path + self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}" + self.inited = False + super().__init__() - with open( - os.path.dirname(__file__) + "/sqlite_init.sql", "r", encoding="utf-8" - ) as f: - sql = f.read() + async def initialize(self) -> None: + """Initialize the database by creating tables if they do not exist.""" + async with self.engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) + await conn.commit() - # 初始化数据库 - self.conn = self._get_conn(self.db_path) - c = self.conn.cursor() - c.executescript(sql) - self.conn.commit() + # ==== + # Platform Statistics + # ==== - # 检查 webchat_conversation 的 title 字段是否存在 - c.execute( - """ - PRAGMA table_info(webchat_conversation) - """ - ) - res = c.fetchall() - has_title = False - has_persona_id = False - for row in res: - if row[1] == "title": - has_title = True - if row[1] == "persona_id": - has_persona_id = True - if not has_title: - c.execute( - """ - ALTER TABLE webchat_conversation ADD COLUMN title TEXT; - """ - ) - self.conn.commit() - if not has_persona_id: - c.execute( - """ - ALTER TABLE webchat_conversation ADD COLUMN persona_id TEXT; - """ - ) - self.conn.commit() - - c.close() - - def _get_conn(self, db_path: str) -> sqlite3.Connection: - conn = sqlite3.connect(self.db_path) - conn.text_factory = str - return conn - - def _exec_sql(self, sql: str, params: Tuple = None): - conn = self.conn - try: - c = self.conn.cursor() - except sqlite3.ProgrammingError: - conn = self._get_conn(self.db_path) - c = conn.cursor() - - if params: - c.execute(sql, params) - c.close() - else: - c.execute(sql) - c.close() - - conn.commit() - - def insert_platform_metrics(self, metrics: dict): - for k, v in metrics.items(): - self._exec_sql( - """ - INSERT INTO platform(name, count, timestamp) VALUES (?, ?, ?) - """, - (k, v, int(time.time())), - ) - - def insert_plugin_metrics(self, metrics: dict): - pass - - def insert_command_metrics(self, metrics: dict): - for k, v in metrics.items(): - self._exec_sql( - """ - INSERT INTO command(name, count, timestamp) VALUES (?, ?, ?) - """, - (k, v, int(time.time())), - ) - - def insert_llm_metrics(self, metrics: dict): - for k, v in metrics.items(): - self._exec_sql( - """ - INSERT INTO llm(name, count, timestamp) VALUES (?, ?, ?) - """, - (k, v, int(time.time())), - ) - - def update_llm_history(self, session_id: str, content: str, provider_type: str): - res = self.get_llm_history(session_id, provider_type) - if res: - self._exec_sql( - """ - UPDATE llm_history SET content = ? WHERE session_id = ? AND provider_type = ? - """, - (content, session_id, provider_type), - ) - else: - self._exec_sql( - """ - INSERT INTO llm_history(provider_type, session_id, content) VALUES (?, ?, ?) - """, - (provider_type, session_id, content), - ) - - def get_llm_history( - self, session_id: str = None, provider_type: str = None - ) -> Tuple: - try: - c = self.conn.cursor() - except sqlite3.ProgrammingError: - c = self._get_conn(self.db_path).cursor() - - conditions = [] - params = [] - - if session_id: - conditions.append("session_id = ?") - params.append(session_id) - - if provider_type: - conditions.append("provider_type = ?") - params.append(provider_type) - - sql = "SELECT * FROM llm_history" - if conditions: - sql += " WHERE " + " AND ".join(conditions) - - c.execute(sql, params) - - res = c.fetchall() - histories = [] - for row in res: - histories.append(LLMHistory(*row)) - c.close() - return histories - - def get_base_stats(self, offset_sec: int = 86400) -> Stats: - """获取 offset_sec 秒前到现在的基础统计数据""" - where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}" - - try: - c = self.conn.cursor() - except sqlite3.ProgrammingError: - c = self._get_conn(self.db_path).cursor() - - c.execute( - """ - SELECT * FROM platform - """ - + where_clause - ) - - platform = [] - for row in c.fetchall(): - platform.append(Platform(*row)) - - # c.execute( - # ''' - # SELECT * FROM command - # ''' + where_clause - # ) - - # command = [] - # for row in c.fetchall(): - # command.append(Command(*row)) - - # c.execute( - # ''' - # SELECT * FROM llm - # ''' + where_clause - # ) - - # llm = [] - # for row in c.fetchall(): - # llm.append(Provider(*row)) - - c.close() - - return Stats(platform, [], []) - - def get_total_message_count(self) -> int: - try: - c = self.conn.cursor() - except sqlite3.ProgrammingError: - c = self._get_conn(self.db_path).cursor() - - c.execute( - """ - SELECT SUM(count) FROM platform - """ - ) - res = c.fetchone() - c.close() - return res[0] - - def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats: - """获取 offset_sec 秒前到现在的基础统计数据(合并)""" - where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}" - - try: - c = self.conn.cursor() - except sqlite3.ProgrammingError: - c = self._get_conn(self.db_path).cursor() - - c.execute( - """ - SELECT name, SUM(count), timestamp FROM platform - """ - + where_clause - + " GROUP BY name" - ) - - platform = [] - for row in c.fetchall(): - platform.append(Platform(*row)) - - c.close() - - return Stats(platform, [], []) - - def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation: - try: - c = self.conn.cursor() - except sqlite3.ProgrammingError: - c = self._get_conn(self.db_path).cursor() - - c.execute( - """ - SELECT * FROM webchat_conversation WHERE user_id = ? AND cid = ? - """, - (user_id, cid), - ) - - res = c.fetchone() - c.close() - - if not res: - return - - return Conversation(*res) - - def new_conversation(self, user_id: str, cid: str): - history = "[]" - updated_at = int(time.time()) - created_at = updated_at - self._exec_sql( - """ - INSERT INTO webchat_conversation(user_id, cid, history, updated_at, created_at) VALUES (?, ?, ?, ?, ?) - """, - (user_id, cid, history, updated_at, created_at), - ) - - def get_conversations(self, user_id: str) -> Tuple: - try: - c = self.conn.cursor() - except sqlite3.ProgrammingError: - c = self._get_conn(self.db_path).cursor() - - c.execute( - """ - SELECT cid, created_at, updated_at, title, persona_id FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC - """, - (user_id,), - ) - - res = c.fetchall() - c.close() - conversations = [] - for row in res: - cid = row[0] - created_at = row[1] - updated_at = row[2] - title = row[3] - persona_id = row[4] - conversations.append( - Conversation("", cid, "[]", created_at, updated_at, title, persona_id) - ) - return conversations - - def update_conversation(self, user_id: str, cid: str, history: str): - """更新对话,并且同时更新时间""" - updated_at = int(time.time()) - self._exec_sql( - """ - UPDATE webchat_conversation SET history = ?, updated_at = ? WHERE user_id = ? AND cid = ? - """, - (history, updated_at, user_id, cid), - ) - - def update_conversation_title(self, user_id: str, cid: str, title: str): - self._exec_sql( - """ - UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ? - """, - (title, user_id, cid), - ) - - def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str): - self._exec_sql( - """ - UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ? - """, - (persona_id, user_id, cid), - ) - - def delete_conversation(self, user_id: str, cid: str): - self._exec_sql( - """ - DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ? - """, - (user_id, cid), - ) - - def insert_atri_vision_data(self, vision: ATRIVision): - ts = int(time.time()) - keywords = ",".join(vision.keywords) - self._exec_sql( - """ - INSERT INTO atri_vision(id, url_or_path, caption, is_meme, keywords, platform_name, session_id, sender_nickname, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - vision.id, - vision.url_or_path, - vision.caption, - vision.is_meme, - keywords, - vision.platform_name, - vision.session_id, - vision.sender_nickname, - ts, - ), - ) - - def get_atri_vision_data(self) -> Tuple: - try: - c = self.conn.cursor() - except sqlite3.ProgrammingError: - c = self._get_conn(self.db_path).cursor() - - c.execute( - """ - SELECT * FROM atri_vision - """ - ) - - res = c.fetchall() - visions = [] - for row in res: - visions.append(ATRIVision(*row)) - c.close() - return visions - - def get_atri_vision_data_by_path_or_id( - self, url_or_path: str, id: str - ) -> ATRIVision: - try: - c = self.conn.cursor() - except sqlite3.ProgrammingError: - c = self._get_conn(self.db_path).cursor() - - c.execute( - """ - SELECT * FROM atri_vision WHERE url_or_path = ? OR id = ? - """, - (url_or_path, id), - ) - - res = c.fetchone() - c.close() - if res: - return ATRIVision(*res) - return None - - def get_all_conversations( - self, page: int = 1, page_size: int = 20 - ) -> Tuple[List[Dict[str, Any]], int]: - """获取所有对话,支持分页,按更新时间降序排序""" - try: - c = self.conn.cursor() - except sqlite3.ProgrammingError: - c = self._get_conn(self.db_path).cursor() - - try: - # 获取总记录数 - c.execute(""" - SELECT COUNT(*) FROM webchat_conversation - """) - total_count = c.fetchone()[0] - - # 计算偏移量 - offset = (page - 1) * page_size - - # 获取分页数据,按更新时间降序排序 - c.execute( - """ - SELECT user_id, cid, created_at, updated_at, title, persona_id - FROM webchat_conversation - ORDER BY updated_at DESC - LIMIT ? OFFSET ? - """, - (page_size, offset), - ) - - rows = c.fetchall() - - conversations = [] - - for row in rows: - user_id, cid, created_at, updated_at, title, persona_id = row - # 确保 cid 是字符串类型且至少有8个字符,否则使用一个默认值 - safe_cid = str(cid) if cid else "unknown" - display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid - - conversations.append( - { - "user_id": user_id or "", - "cid": safe_cid, - "title": title or f"对话 {display_cid}", - "persona_id": persona_id or "", - "created_at": created_at or 0, - "updated_at": updated_at or 0, - } - ) - - return conversations, total_count - - except Exception as _: - # 返回空列表和0,确保即使出错也有有效的返回值 - return [], 0 - finally: - c.close() - - def get_filtered_conversations( + async def insert_platform_stats( self, - page: int = 1, - page_size: int = 20, - platforms: List[str] = None, - message_types: List[str] = None, - search_query: str = None, - exclude_ids: List[str] = None, - exclude_platforms: List[str] = None, - ) -> Tuple[List[Dict[str, Any]], int]: - """获取筛选后的对话列表""" - try: - c = self.conn.cursor() - except sqlite3.ProgrammingError: - c = self._get_conn(self.db_path).cursor() - - try: - # 构建查询条件 - where_clauses = [] - params = [] - - # 平台筛选 - if platforms and len(platforms) > 0: - platform_conditions = [] - for platform in platforms: - platform_conditions.append("user_id LIKE ?") - params.append(f"{platform}:%") - - if platform_conditions: - where_clauses.append(f"({' OR '.join(platform_conditions)})") - - # 消息类型筛选 - if message_types and len(message_types) > 0: - message_type_conditions = [] - for msg_type in message_types: - message_type_conditions.append("user_id LIKE ?") - params.append(f"%:{msg_type}:%") - - if message_type_conditions: - where_clauses.append(f"({' OR '.join(message_type_conditions)})") - - # 搜索关键词 - if search_query: - search_query = search_query.encode("unicode_escape").decode("utf-8") - where_clauses.append( - "(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)" - ) - search_param = f"%{search_query}%" - params.extend([search_param, search_param, search_param, search_param]) - - # 排除特定用户ID - if exclude_ids and len(exclude_ids) > 0: - for exclude_id in exclude_ids: - where_clauses.append("user_id NOT LIKE ?") - params.append(f"{exclude_id}%") - - # 排除特定平台 - if exclude_platforms and len(exclude_platforms) > 0: - for exclude_platform in exclude_platforms: - where_clauses.append("user_id NOT LIKE ?") - params.append(f"{exclude_platform}:%") - - # 构建完整的 WHERE 子句 - where_sql = " WHERE " + " AND ".join(where_clauses) if where_clauses else "" - - # 构建计数查询 - count_sql = f"SELECT COUNT(*) FROM webchat_conversation{where_sql}" - - # 获取总记录数 - c.execute(count_sql, params) - total_count = c.fetchone()[0] - - # 计算偏移量 - offset = (page - 1) * page_size - - # 构建分页数据查询 - data_sql = f""" - SELECT user_id, cid, created_at, updated_at, title, persona_id - FROM webchat_conversation - {where_sql} - ORDER BY updated_at DESC - LIMIT ? OFFSET ? - """ - query_params = params + [page_size, offset] - - # 获取分页数据 - c.execute(data_sql, query_params) - rows = c.fetchall() - - conversations = [] - - for row in rows: - user_id, cid, created_at, updated_at, title, persona_id = row - # 确保 cid 是字符串类型,否则使用一个默认值 - safe_cid = str(cid) if cid else "unknown" - display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid - - conversations.append( + platform_id: str, + platform_type: str, + count: int = 1, + timestamp: datetime = None, + ) -> None: + """Insert a new platform statistic record.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + if timestamp is None: + timestamp = datetime.now().replace( + minute=0, second=0, microsecond=0 + ) + current_hour = timestamp + await session.execute( + text(""" + INSERT INTO platform_stats (timestamp, platform_id, platform_type, count) + VALUES (:timestamp, :platform_id, :platform_type, :count) + ON CONFLICT(timestamp, platform_id, platform_type) DO UPDATE SET + count = platform_stats.count + EXCLUDED.count + """), { - "user_id": user_id or "", - "cid": safe_cid, - "title": title or f"对话 {display_cid}", - "persona_id": persona_id or "", - "created_at": created_at or 0, - "updated_at": updated_at or 0, - } + "timestamp": current_hour, + "platform_id": platform_id, + "platform_type": platform_type, + "count": count, + }, ) - return conversations, total_count + async def count_platform_stats(self) -> int: + """Count the number of platform statistics records.""" + async with self.get_db() as session: + session: AsyncSession + result = await session.execute( + select(func.count(PlatformStat.platform_id)).select_from(PlatformStat) + ) + count = result.scalar_one_or_none() + return count if count is not None else 0 - except Exception as _: - # 返回空列表和0,确保即使出错也有有效的返回值 - return [], 0 - finally: - c.close() + async def get_platform_stats(self, offset_sec: int = 86400) -> T.List[PlatformStat]: + """Get platform statistics within the specified offset in seconds and group by platform_id.""" + async with self.get_db() as session: + session: AsyncSession + now = datetime.now() + start_time = now - timedelta(seconds=offset_sec) + result = await session.execute( + text(""" + SELECT * FROM platform_stats + WHERE timestamp >= :start_time + ORDER BY timestamp DESC + GROUP BY platform_id + """), + {"start_time": start_time}, + ) + return result.scalars().all() + + # ==== + # Conversation Management + # ==== + + async def get_conversations( + self, user_id=None, platform_id=None + ): + async with self.get_db() as session: + session: AsyncSession + query = select(ConversationV2) + + if user_id: + query = query.where(ConversationV2.user_id == user_id) + if platform_id: + query = query.where(ConversationV2.platform_id == platform_id) + # order by + query = query.order_by(ConversationV2.created_at.desc()) + result = await session.execute(query) + + return result.scalars().all() + + async def get_conversation_by_id(self, cid): + async with self.get_db() as session: + session: AsyncSession + query = select(ConversationV2).where(ConversationV2.conversation_id == cid) + result = await session.execute(query) + return result.scalar_one_or_none() + + async def get_all_conversations(self, page=1, page_size=20): + async with self.get_db() as session: + session: AsyncSession + offset = (page - 1) * page_size + result = await session.execute( + select(ConversationV2) + .order_by(ConversationV2.created_at.desc()) + .offset(offset) + .limit(page_size) + ) + return result.scalars().all() + + async def get_filtered_conversations( + self, + page=1, + page_size=20, + platform_ids=None, + search_query="", + **kwargs, + ): + async with self.get_db() as session: + session: AsyncSession + # Build the base query with filters + base_query = select(ConversationV2) + + if platform_ids: + base_query = base_query.where( + ConversationV2.platform_id.in_(platform_ids) + ) + if search_query: + base_query = base_query.where( + ConversationV2.title.ilike(f"%{search_query}%") + ) + + # Get total count matching the filters + count_query = select(func.count()).select_from(base_query.subquery()) + total_count = await session.execute(count_query) + total = total_count.scalar_one() + + # Get paginated results + offset = (page - 1) * page_size + result_query = ( + base_query.order_by(ConversationV2.created_at.desc()) + .offset(offset) + .limit(page_size) + ) + result = await session.execute(result_query) + conversations = result.scalars().all() + + return conversations, total + + async def create_conversation( + self, + user_id, + platform_id, + content=None, + title=None, + persona_id=None, + cid=None, + created_at=None, + updated_at=None, + ): + kwargs = {} + if cid: + kwargs["conversation_id"] = cid + if created_at: + kwargs["created_at"] = created_at + if updated_at: + kwargs["updated_at"] = updated_at + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + new_conversation = ConversationV2( + user_id=user_id, + content=content or [], + platform_id=platform_id, + title=title, + persona_id=persona_id, + **kwargs, + ) + session.add(new_conversation) + return new_conversation + + async def update_conversation(self, cid, title=None, persona_id=None, content=None): + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + query = update(ConversationV2).where( + ConversationV2.conversation_id == cid + ) + values = {} + if title is not None: + values["title"] = title + if persona_id is not None: + values["persona_id"] = persona_id + if content is not None: + values["content"] = content + if not values: + return + query = query.values(**values) + await session.execute(query) + return await self.get_conversation_by_id(cid) + + async def delete_conversation(self, cid): + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + await session.execute( + delete(ConversationV2).where(ConversationV2.conversation_id == cid) + ) + + async def insert_platform_message_history( + self, + platform_id, + user_id, + content, + sender_id=None, + sender_name=None, + ): + """Insert a new platform message history record.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + new_history = PlatformMessageHistory( + platform_id=platform_id, + user_id=user_id, + content=content, + sender_id=sender_id, + sender_name=sender_name, + ) + session.add(new_history) + return new_history + + async def delete_platform_message_offset( + self, platform_id, user_id, offset_sec=86400 + ): + """Delete platform message history records older than the specified offset.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + now = datetime.now() + cutoff_time = now - timedelta(seconds=offset_sec) + await session.execute( + delete(PlatformMessageHistory).where( + PlatformMessageHistory.platform_id == platform_id, + PlatformMessageHistory.user_id == user_id, + PlatformMessageHistory.created_at < cutoff_time, + ) + ) + + async def get_platform_message_history( + self, platform_id, user_id, page=1, page_size=20 + ): + """Get platform message history records.""" + async with self.get_db() as session: + session: AsyncSession + offset = (page - 1) * page_size + query = ( + select(PlatformMessageHistory) + .where( + PlatformMessageHistory.platform_id == platform_id, + PlatformMessageHistory.user_id == user_id, + ) + .order_by(PlatformMessageHistory.created_at.desc()) + ) + result = await session.execute(query.offset(offset).limit(page_size)) + return result.scalars().all() + + async def insert_attachment(self, path, type, mime_type): + """Insert a new attachment record.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + new_attachment = Attachment( + path=path, + type=type, + mime_type=mime_type, + ) + session.add(new_attachment) + return new_attachment + + async def get_attachment_by_id(self, attachment_id): + """Get an attachment by its ID.""" + async with self.get_db() as session: + session: AsyncSession + query = select(Attachment).where(Attachment.id == attachment_id) + result = await session.execute(query) + return result.scalar_one_or_none() + + async def insert_persona(self, persona_id, system_prompt, begin_dialogs=None): + """Insert a new persona record.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + new_persona = Persona( + persona_id=persona_id, + system_prompt=system_prompt, + begin_dialogs=begin_dialogs or [], + ) + session.add(new_persona) + return new_persona + + async def get_persona_by_id(self, persona_id): + """Get a persona by its ID.""" + async with self.get_db() as session: + session: AsyncSession + query = select(Persona).where(Persona.persona_id == persona_id) + result = await session.execute(query) + return result.scalar_one_or_none() + + async def get_personas(self): + """Get all personas for a specific bot.""" + async with self.get_db() as session: + session: AsyncSession + query = select(Persona) + result = await session.execute(query) + return result.scalars().all() + + async def insert_preference_or_update(self, key, value): + """Insert a new preference record or update if it exists.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + query = select(Preference).where(Preference.key == key) + result = await session.execute(query) + existing_preference = result.scalar_one_or_none() + if existing_preference: + existing_preference.value = value + else: + new_preference = Preference(key=key, value=value) + session.add(new_preference) + return existing_preference or new_preference + + async def get_preference(self, key): + """Get a preference by key.""" + async with self.get_db() as session: + session: AsyncSession + query = select(Preference).where(Preference.key == key) + result = await session.execute(query) + return result.scalar_one_or_none() + + # ==== + # Deprecated Methods + # ==== + + def get_base_stats(self, offset_sec=86400): + """Get base statistics within the specified offset in seconds.""" + + async def _inner(): + async with self.get_db() as session: + session: AsyncSession + now = datetime.now() + start_time = now - timedelta(seconds=offset_sec) + result = await session.execute( + select(PlatformStat).where(PlatformStat.timestamp >= start_time) + ) + all_datas = result.scalars().all() + deprecated_stats = DeprecatedStats() + for data in all_datas: + deprecated_stats.platform.append( + DeprecatedPlatformStat( + name=data.platform_id, + count=data.count, + timestamp=data.timestamp.timestamp(), + ) + ) + return deprecated_stats + + result = None + + def runner(): + nonlocal result + result = asyncio.run(_inner()) + + t = threading.Thread(target=runner) + t.start() + t.join() + return result + + def get_total_message_count(self): + """Get the total message count from platform statistics.""" + + async def _inner(): + async with self.get_db() as session: + session: AsyncSession + result = await session.execute( + select(func.sum(PlatformStat.count)).select_from(PlatformStat) + ) + total_count = result.scalar_one_or_none() + return total_count if total_count is not None else 0 + + result = None + + def runner(): + nonlocal result + result = asyncio.run(_inner()) + + t = threading.Thread(target=runner) + t.start() + t.join() + return result + + def get_grouped_base_stats(self, offset_sec=86400): + # group by platform_id + async def _inner(): + async with self.get_db() as session: + session: AsyncSession + now = datetime.now() + start_time = now - timedelta(seconds=offset_sec) + result = await session.execute( + select(PlatformStat.platform_id, func.sum(PlatformStat.count)) + .where(PlatformStat.timestamp >= start_time) + .group_by(PlatformStat.platform_id) + ) + grouped_stats = result.all() + deprecated_stats = DeprecatedStats() + for platform_id, count in grouped_stats: + deprecated_stats.platform.append( + DeprecatedPlatformStat( + name=platform_id, + count=count, + timestamp=start_time.timestamp(), + ) + ) + return deprecated_stats + + result = None + + def runner(): + nonlocal result + result = asyncio.run(_inner()) + + t = threading.Thread(target=runner) + t.start() + t.join() + return result diff --git a/astrbot/core/db/sqlite_init.sql b/astrbot/core/db/sqlite_init.sql deleted file mode 100644 index a1ebc54b5..000000000 --- a/astrbot/core/db/sqlite_init.sql +++ /dev/null @@ -1,50 +0,0 @@ -CREATE TABLE IF NOT EXISTS platform( - name VARCHAR(32), - count INTEGER, - timestamp INTEGER -); -CREATE TABLE IF NOT EXISTS llm( - name VARCHAR(32), - count INTEGER, - timestamp INTEGER -); -CREATE TABLE IF NOT EXISTS plugin( - name VARCHAR(32), - count INTEGER, - timestamp INTEGER -); -CREATE TABLE IF NOT EXISTS command( - name VARCHAR(32), - count INTEGER, - timestamp INTEGER -); -CREATE TABLE IF NOT EXISTS llm_history( - provider_type VARCHAR(32), - session_id VARCHAR(32), - content TEXT -); - --- ATRI -CREATE TABLE IF NOT EXISTS atri_vision( - id TEXT, - url_or_path TEXT, - caption TEXT, - is_meme BOOLEAN, - keywords TEXT, - platform_name VARCHAR(32), - session_id VARCHAR(32), - sender_nickname VARCHAR(32), - timestamp INTEGER -); - -CREATE TABLE IF NOT EXISTS webchat_conversation( - user_id TEXT, -- 会话 id - cid TEXT, -- 对话 id - history TEXT, - created_at INTEGER, - updated_at INTEGER, - title TEXT, - persona_id TEXT -); - -PRAGMA encoding = 'UTF-8'; \ No newline at end of file diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 2d9392e6d..1344d4dec 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -266,12 +266,12 @@ class LLMRequestSubStage(Stage): else: async for _ in requesting(): yield + await self._save_to_history(event, req, tool_loop_agent.get_final_llm_resp()) # 异步处理 WebChat 特殊情况 if event.get_platform_name() == "webchat": asyncio.create_task(self._handle_webchat(event, req, provider)) - await self._save_to_history(event, req, tool_loop_agent.get_final_llm_resp()) async def _handle_webchat( self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index 7a3102de5..e565c13ce 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -26,7 +26,7 @@ from .platform_metadata import PlatformMetadata @dataclass -class MessageSesion: +class MessageSession: platform_name: str message_type: MessageType session_id: str @@ -37,8 +37,9 @@ class MessageSesion: @staticmethod def from_str(session_str: str): platform_name, message_type, session_id = session_str.split(":") - return MessageSesion(platform_name, MessageType(message_type), session_id) + return MessageSession(platform_name, MessageType(message_type), session_id) +MessageSesion = MessageSession # back compatibility class AstrMessageEvent(abc.ABC): def __init__( diff --git a/astrbot/core/platform_message_history_mgr.py b/astrbot/core/platform_message_history_mgr.py new file mode 100644 index 000000000..16e59a5cc --- /dev/null +++ b/astrbot/core/platform_message_history_mgr.py @@ -0,0 +1,47 @@ +from astrbot.core.db import BaseDatabase +from astrbot.core.db.po import PlatformMessageHistory + + +class PlatformMessageHistoryManager: + def __init__(self, db_helper: BaseDatabase): + self.db = db_helper + + async def insert( + self, + platform_id: str, + user_id: str, + content: list[dict], # TODO: parse from message chain + sender_id: str = None, + sender_name: str = None, + ): + """Insert a new platform message history record.""" + await self.db.insert_platform_message_history( + platform_id=platform_id, + user_id=user_id, + content=content, + sender_id=sender_id, + sender_name=sender_name, + ) + + async def get( + self, + platform_id: str, + user_id: str, + page: int = 1, + page_size: int = 200, + ) -> list[PlatformMessageHistory]: + """Get platform message history for a specific user.""" + history = await self.db.get_platform_message_history( + platform_id=platform_id, + user_id=user_id, + page=page, + page_size=page_size, + ) + history.reverse() + return history + + async def delete(self, platform_id: str, user_id: str, offset_sec: int = 86400): + """Delete platform message history records older than the specified offset.""" + await self.db.delete_platform_message_offset( + platform_id=platform_id, user_id=user_id, offset_sec=offset_sec + ) diff --git a/astrbot/core/utils/metrics.py b/astrbot/core/utils/metrics.py index a3a73fcc8..7fe9bde05 100644 --- a/astrbot/core/utils/metrics.py +++ b/astrbot/core/utils/metrics.py @@ -58,9 +58,10 @@ class Metric: pass try: if "adapter_name" in kwargs: - db_helper.insert_platform_metrics({kwargs["adapter_name"]: 1}) - if "llm_name" in kwargs: - db_helper.insert_llm_metrics({kwargs["llm_name"]: 1}) + await db_helper.insert_platform_stats( + platform_id=kwargs["adapter_name"], + platform_type=kwargs.get("adapter_type", "unknown"), + ) except Exception as e: logger.error(f"保存指标到数据库失败: {e}") pass diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index 651f1b65c..dde4c7644 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -9,6 +9,7 @@ import asyncio from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from astrbot.core.platform.astr_message_event import MessageSession class ChatRoute(Route): @@ -31,13 +32,14 @@ class ChatRoute(Route): "/chat/post_file": ("POST", self.post_file), "/chat/status": ("GET", self.status), } - self.db = db self.core_lifecycle = core_lifecycle self.register_routes() self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs") os.makedirs(self.imgs_dir, exist_ok=True) self.supported_imgs = ["jpg", "jpeg", "png", "gif", "webp"] + self.conv_mgr = core_lifecycle.conversation_manager + self.platform_history_mgr = core_lifecycle.platform_message_history_manager async def status(self): has_llm_enabled = ( @@ -131,24 +133,23 @@ class ChatRoute(Route): if not conversation_id: return Response().error("conversation_id is empty").__dict__ - # Get conversation-specific queues - back_queue = webchat_queue_mgr.get_or_create_back_queue(conversation_id) - # append user message - conversation = self.db.get_conversation_by_user_id(username, conversation_id) - try: - history = json.loads(conversation.history) - except BaseException as e: - logger.error(f"Failed to parse conversation history: {e}") - history = [] + webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id) + + # Get conversation-specific queues + back_queue = webchat_queue_mgr.get_or_create_back_queue(webchat_conv_id) + new_his = {"type": "user", "message": message} if image_url: new_his["image_url"] = image_url if audio_url: new_his["audio_url"] = audio_url - history.append(new_his) - self.db.update_conversation( - username, conversation_id, history=json.dumps(history) + await self.platform_history_mgr.insert( + platform_id="webchat", + user_id=webchat_conv_id, + content=new_his, + sender_id=username, + sender_name=username, ) async def stream(): @@ -164,7 +165,6 @@ class ChatRoute(Route): result_text = result["data"] type = result.get("type") - cid = result.get("cid") streaming = result.get("streaming", False) yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n" await asyncio.sleep(0.05) @@ -173,17 +173,13 @@ class ChatRoute(Route): break elif (streaming and type == "complete") or not streaming: # append bot message - conversation = self.db.get_conversation_by_user_id( - username, cid - ) - try: - history = json.loads(conversation.history) - except BaseException as e: - logger.error(f"Failed to parse conversation history: {e}") - history = [] - history.append({"type": "bot", "message": result_text}) - self.db.update_conversation( - username, cid, history=json.dumps(history) + new_his = {"type": "bot", "message": result_text} + await self.platform_history_mgr.insert( + platform_id="webchat", + user_id=webchat_conv_id, + content=new_his, + sender_id="bot", + sender_name="bot", ) except BaseException as _: @@ -191,11 +187,11 @@ class ChatRoute(Route): return # Put message to conversation-specific queue - chat_queue = webchat_queue_mgr.get_or_create_queue(conversation_id) + chat_queue = webchat_queue_mgr.get_or_create_queue(webchat_conv_id) await chat_queue.put( ( username, - conversation_id, + webchat_conv_id, { "message": message, "image_url": image_url, # list @@ -217,25 +213,51 @@ class ChatRoute(Route): ) return response + async def _get_webchat_conv_id_from_conv_id(self, conversation_id: str) -> str: + """从对话 ID 中提取 WebChat 会话 ID + + NOTE: 关于这里为什么要单独做一个 WebChat 的 Conversation ID 出来,这个是为了向前兼容。 + """ + conversation = await self.conv_mgr.get_conversation( + unified_msg_origin="webchat", conversation_id=conversation_id + ) + if not conversation: + raise ValueError(f"Conversation with ID {conversation_id} not found.") + conv_user_id = conversation.user_id + webchat_session_id = MessageSession.from_str(conv_user_id).session_id + if "!" not in webchat_session_id: + raise ValueError(f"Invalid conv user ID: {conv_user_id}") + return webchat_session_id.split("!")[-1] + async def delete_conversation(self): - username = g.get("username", "guest") conversation_id = request.args.get("conversation_id") if not conversation_id: return Response().error("Missing key: conversation_id").__dict__ + username = g.get("username", "guest") # Clean up queues when deleting conversation webchat_queue_mgr.remove_queues(conversation_id) - self.db.delete_conversation(username, conversation_id) + webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id) + await self.conv_mgr.delete_conversation( + unified_msg_origin=f"webchat:FriendMessage:webchat!{username}!{webchat_conv_id}", + conversation_id=conversation_id, + ) + await self.platform_history_mgr.delete( + platform_id="webchat", user_id=webchat_conv_id, offset_sec=99999999 + ) return Response().ok().__dict__ async def new_conversation(self): username = g.get("username", "guest") - conversation_id = str(uuid.uuid4()) - self.db.new_conversation(username, conversation_id) - return Response().ok(data={"conversation_id": conversation_id}).__dict__ + webchat_conv_id = str(uuid.uuid4()) + conv_id = await self.conv_mgr.new_conversation( + unified_msg_origin=f"webchat:FriendMessage:webchat!{username}!{webchat_conv_id}", + platform_id="webchat", + content=[], + ) + return Response().ok(data={"conversation_id": conv_id}).__dict__ async def rename_conversation(self): - username = g.get("username", "guest") post_data = await request.json if "conversation_id" not in post_data or "title" not in post_data: return Response().error("Missing key: conversation_id or title").__dict__ @@ -243,20 +265,42 @@ class ChatRoute(Route): conversation_id = post_data["conversation_id"] title = post_data["title"] - self.db.update_conversation_title(username, conversation_id, title=title) + await self.conv_mgr.update_conversation( + unified_msg_origin="webchat", # fake + conversation_id=conversation_id, + title=title, + ) return Response().ok(message="重命名成功!").__dict__ async def get_conversations(self): - username = g.get("username", "guest") - conversations = self.db.get_conversations(username) - return Response().ok(data=conversations).__dict__ + conversations = await self.conv_mgr.get_conversations(platform_id="webchat") + # remove content + conversations_ = [] + for conv in conversations: + conv.history = None + conversations_.append(conv) + return Response().ok(data=conversations_).__dict__ async def get_conversation(self): - username = g.get("username", "guest") conversation_id = request.args.get("conversation_id") if not conversation_id: return Response().error("Missing key: conversation_id").__dict__ - conversation = self.db.get_conversation_by_user_id(username, conversation_id) + webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id) - return Response().ok(data=conversation).__dict__ + # Get platform message history + history_ls = await self.platform_history_mgr.get( + platform_id="webchat", user_id=webchat_conv_id, page=1, page_size=1000 + ) + + history_res = [history.model_dump() for history in history_ls] + + return ( + Response() + .ok( + data={ + "history": history_res, + } + ) + .__dict__ + ) diff --git a/astrbot/dashboard/routes/conversation.py b/astrbot/dashboard/routes/conversation.py index dde6f9a5a..fb5d3e10e 100644 --- a/astrbot/dashboard/routes/conversation.py +++ b/astrbot/dashboard/routes/conversation.py @@ -29,6 +29,7 @@ class ConversationRoute(Route): ), } self.db_helper = db_helper + self.conv_mgr = core_lifecycle.conversation_manager self.core_lifecycle = core_lifecycle self.register_routes() @@ -54,7 +55,6 @@ class ConversationRoute(Route): exclude_platforms.split(",") if exclude_platforms else [] ) - # 限制页面大小,防止请求过大数据 if page < 1: page = 1 if page_size < 1: @@ -62,9 +62,11 @@ class ConversationRoute(Route): if page_size > 100: page_size = 100 - # 使用数据库的分页方法获取会话列表和总数,传入筛选条件 try: - conversations, total_count = self.db_helper.get_filtered_conversations( + ( + conversations, + total_count, + ) = await self.conv_mgr.get_filtered_conversations( page=page, page_size=page_size, platforms=platform_list, @@ -108,7 +110,9 @@ class ConversationRoute(Route): if not user_id or not cid: return Response().error("缺少必要参数: user_id 和 cid").__dict__ - conversation = self.db_helper.get_conversation_by_user_id(user_id, cid) + conversation = await self.conv_mgr.get_conversation( + unified_msg_origin=user_id, conversation_id=cid + ) if not conversation: return Response().error("对话不存在").__dict__ @@ -143,14 +147,18 @@ class ConversationRoute(Route): if not user_id or not cid: return Response().error("缺少必要参数: user_id 和 cid").__dict__ - conversation = self.db_helper.get_conversation_by_user_id(user_id, cid) + conversation = await self.conv_mgr.get_conversation( + unified_msg_origin=user_id, conversation_id=cid + ) if not conversation: return Response().error("对话不存在").__dict__ - if title is not None: - self.db_helper.update_conversation_title(user_id, cid, title) - if persona_id is not None: - self.db_helper.update_conversation_persona_id(user_id, cid, persona_id) - + if title is not None or persona_id is not None: + await self.conv_mgr.update_conversation( + unified_msg_origin=user_id, + conversation_id=cid, + title=title, + persona_id=persona_id, + ) return Response().ok({"message": "对话信息更新成功"}).__dict__ except Exception as e: @@ -201,11 +209,17 @@ class ConversationRoute(Route): Response().error("history 必须是有效的 JSON 字符串或数组").__dict__ ) - conversation = self.db_helper.get_conversation_by_user_id(user_id, cid) + conversation = await self.conv_mgr.get_conversation( + unified_msg_origin=user_id, conversation_id=cid + ) if not conversation: return Response().error("对话不存在").__dict__ - self.db_helper.update_conversation(user_id, cid, history) + history = json.loads(history) if isinstance(history, str) else history + + await self.conv_mgr.update_conversation( + unified_msg_origin=user_id, conversation_id=cid, history=history + ) return Response().ok({"message": "对话历史更新成功"}).__dict__ diff --git a/astrbot/dashboard/routes/session_management.py b/astrbot/dashboard/routes/session_management.py index fdcbdbf73..86b526be3 100644 --- a/astrbot/dashboard/routes/session_management.py +++ b/astrbot/dashboard/routes/session_management.py @@ -32,7 +32,7 @@ class SessionManagementRoute(Route): "/session/update_name": ("POST", self.update_session_name), "/session/update_status": ("POST", self.update_session_status), } - self.db_helper = db_helper + self.conv_mgr = core_lifecycle.conversation_manager self.core_lifecycle = core_lifecycle self.register_routes() @@ -90,8 +90,8 @@ class SessionManagementRoute(Route): } # 获取对话信息 - conversation = self.db_helper.get_conversation_by_user_id( - session_id, conversation_id + conversation = await self.conv_mgr.get_conversation( + unified_msg_origin=session_id, conversation_id=conversation_id ) if conversation: session_info["persona_id"] = conversation.persona_id @@ -358,8 +358,8 @@ class SessionManagementRoute(Route): ) # 获取对话信息 - conversation = self.db_helper.get_conversation_by_user_id( - session_id, conversation_id + conversation = await self.conv_mgr.get_conversation( + unified_msg_origin=session_id, conversation_id=conversation_id ) if conversation: session_info["persona_id"] = conversation.persona_id diff --git a/astrbot/dashboard/routes/stat.py b/astrbot/dashboard/routes/stat.py index 79397290e..5bc401a0d 100644 --- a/astrbot/dashboard/routes/stat.py +++ b/astrbot/dashboard/routes/stat.py @@ -86,7 +86,7 @@ class StatRoute(Route): message_time_based_stats = [] idx = 0 - for bucket_end in range(start_time, now, 1800): + for bucket_end in range(start_time, now, 3600): cnt = 0 while ( idx < len(stat.platform) diff --git a/dashboard/src/views/ChatPage.vue b/dashboard/src/views/ChatPage.vue index 099cf09c0..208a751ab 100644 --- a/dashboard/src/views/ChatPage.vue +++ b/dashboard/src/views/ChatPage.vue @@ -147,24 +147,24 @@