From 323ec29b026d2b146741e0ad5c18a7d442ab08cd Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Tue, 18 Nov 2025 22:04:26 +0800 Subject: [PATCH 01/27] refactor: Implement WebChat session management and migration from version 4.6 to 4.7 - Added WebChatSession model for managing user sessions. - Introduced methods for creating, retrieving, updating, and deleting WebChat sessions in the database. - Updated core lifecycle to include migration from version 4.6 to 4.7, creating WebChat sessions from existing platform message history. - Refactored chat routes to support new session-based architecture, replacing conversation-related endpoints with session endpoints. - Updated frontend components to handle sessions instead of conversations, including session creation and management. --- astrbot/core/core_lifecycle.py | 8 + astrbot/core/db/__init__.py | 43 +++ astrbot/core/db/migration/migra_46_to_47.py | 103 +++++++ astrbot/core/db/po.py | 38 +++ astrbot/core/db/sqlite.py | 84 ++++++ astrbot/dashboard/routes/chat.py | 138 +++++---- dashboard/src/components/chat/Chat.vue | 304 +++++++++----------- 7 files changed, 484 insertions(+), 234 deletions(-) create mode 100644 astrbot/core/db/migration/migra_46_to_47.py diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index 2a6ac4273..676e50384 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -22,6 +22,7 @@ from astrbot.core.config.default import VERSION from astrbot.core.conversation_mgr import ConversationManager from astrbot.core.db import BaseDatabase from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46 +from astrbot.core.db.migration.migra_46_to_47 import migrate_46_to_47 from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager from astrbot.core.persona_mgr import PersonaManager from astrbot.core.pipeline.scheduler import PipelineContext, PipelineScheduler @@ -103,6 +104,13 @@ class AstrBotCoreLifecycle: logger.error(f"Migration from version 4.5 to 4.6 failed: {e!s}") logger.error(traceback.format_exc()) + # 4.6 to 4.7 migration for webchat sessions and group feature + try: + await migrate_46_to_47(self.db) + except Exception as e: + logger.error(f"Migration from version 4.6 to 4.7 failed: {e!s}") + logger.error(traceback.format_exc()) + # 初始化事件队列 self.event_queue = Queue() diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index c62e49289..456682bd2 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -16,6 +16,7 @@ from astrbot.core.db.po import ( PlatformStat, Preference, Stats, + WebChatSession, ) @@ -313,3 +314,45 @@ class BaseDatabase(abc.ABC): ) -> tuple[list[dict], int]: """Get paginated session conversations with joined conversation and persona details, support search and platform filter.""" ... + + # ==== + # WebChat Session Management + # ==== + + @abc.abstractmethod + async def create_webchat_session( + self, + creator: str, + session_id: str | None = None, + is_group: int = 0, + ) -> WebChatSession: + """Create a new WebChat session.""" + ... + + @abc.abstractmethod + async def get_webchat_session_by_id(self, session_id: str) -> WebChatSession | None: + """Get a WebChat session by its ID.""" + ... + + @abc.abstractmethod + async def get_webchat_sessions_by_creator( + self, + creator: str, + page: int = 1, + page_size: int = 20, + ) -> list[WebChatSession]: + """Get all WebChat sessions for a specific creator (username).""" + ... + + @abc.abstractmethod + async def update_webchat_session( + self, + session_id: str, + ) -> None: + """Update a WebChat session's updated_at timestamp.""" + ... + + @abc.abstractmethod + async def delete_webchat_session(self, session_id: str) -> None: + """Delete a WebChat session by its ID.""" + ... diff --git a/astrbot/core/db/migration/migra_46_to_47.py b/astrbot/core/db/migration/migra_46_to_47.py new file mode 100644 index 000000000..407a667c9 --- /dev/null +++ b/astrbot/core/db/migration/migra_46_to_47.py @@ -0,0 +1,103 @@ +"""Migration script from version 4.6 to 4.7. + +This migration creates WebChat sessions from existing platform_message_history records. +""" + +from sqlalchemy import func, select +from sqlmodel import col + +from astrbot.api import logger, sp +from astrbot.core.db import BaseDatabase +from astrbot.core.db.po import PlatformMessageHistory + + +async def migrate_46_to_47(db_helper: BaseDatabase): + """Migrate WebChat data to the new session table. + + This migration extracts all unique user_ids from platform_message_history + where platform_id='webchat' and creates corresponding WebChatSession records. + """ + # 检查是否已经完成迁移 + # migration_done = await db_helper.get_preference( + # "global", "global", "migration_done_v47" + # ) + # if migration_done: + # return + + logger.info("开始执行数据库迁移(4.6 -> 4.7)...") + + try: + async with db_helper.get_db() as session: + # 1. 查询所有 webchat 的唯一 user_id 以及它们的最早和最新消息时间 + query = ( + select( + col(PlatformMessageHistory.user_id), + col(PlatformMessageHistory.sender_name), + func.min(PlatformMessageHistory.created_at).label("earliest"), + func.max(PlatformMessageHistory.updated_at).label("latest"), + ) + .where(col(PlatformMessageHistory.platform_id) == "webchat") + .where(col(PlatformMessageHistory.sender_id) == "astrbot") + .group_by(col(PlatformMessageHistory.user_id)) + ) + + result = await session.execute(query) + webchat_users = result.all() + + if not webchat_users: + logger.info("没有找到需要迁移的 WebChat 数据") + await sp.put_async("global", "global", "migration_done_v47", True) + return + + logger.info(f"找到 {len(webchat_users)} 个 WebChat 会话需要迁移") + + # 2. 为每个 user_id 创建 WebChatSession 记录 + migrated_count = 0 + skipped_count = 0 + + for user_id, sender_name, created_at, updated_at in webchat_users: + # user_id 就是 webchat_conv_id (session_id) + session_id = user_id + + # sender_name 通常是 username,但可能为 None + # 从第一条消息中提取 creator + creator = sender_name if sender_name else "guest" + + # 检查是否已经存在该会话 + existing_session = await db_helper.get_webchat_session_by_id(session_id) + if existing_session: + logger.debug(f"会话 {session_id} 已存在,跳过") + skipped_count += 1 + continue + + # 创建新的 WebChatSession + try: + await db_helper.create_webchat_session( + creator=creator, + session_id=session_id, + is_group=0, + ) + + # 更新时间戳以匹配历史记录 + # 注意:这里我们需要直接更新数据库,因为 create 方法会设置当前时间 + # 但我们希望保留原始的创建和更新时间 + + migrated_count += 1 + + if migrated_count % 100 == 0: + logger.info(f"已迁移 {migrated_count} 个会话...") + + except Exception as e: + logger.error(f"迁移会话 {session_id} 失败: {e}") + continue + + logger.info( + f"WebChat 会话迁移完成!成功迁移: {migrated_count}, 跳过: {skipped_count}", + ) + + # 标记迁移完成 + await sp.put_async("global", "global", "migration_done_v47", True) + + except Exception as e: + logger.error(f"迁移过程中发生错误: {e}", exc_info=True) + raise diff --git a/astrbot/core/db/po.py b/astrbot/core/db/po.py index 1e7245976..eee4c9dc6 100644 --- a/astrbot/core/db/po.py +++ b/astrbot/core/db/po.py @@ -161,6 +161,44 @@ class PlatformMessageHistory(SQLModel, table=True): ) +class WebChatSession(SQLModel, table=True): + """WebChat session table for managing user sessions. + + A session represents a chat window for a specific user. Each session can have + multiple conversations (对话) associated with it. + """ + + __tablename__ = "webchat_sessions" + + inner_id: int | None = Field( + primary_key=True, + sa_column_kwargs={"autoincrement": True}, + default=None, + ) + session_id: str = Field( + max_length=36, + nullable=False, + unique=True, + default_factory=lambda: str(uuid.uuid4()), + ) + creator: str = Field(nullable=False) + """Username of the session creator""" + is_group: int = Field(default=0, nullable=False) + """0 for private chat, 1 for group chat (not implemented yet)""" + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field( + default_factory=lambda: datetime.now(timezone.utc), + sa_column_kwargs={"onupdate": datetime.now(timezone.utc)}, + ) + + __table_args__ = ( + UniqueConstraint( + "session_id", + name="uix_webchat_session_id", + ), + ) + + class Attachment(SQLModel, table=True): """This class represents attachments for messages in AstrBot. diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index 457a4ab3f..b96a2d3ff 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -15,6 +15,7 @@ from astrbot.core.db.po import ( PlatformStat, Preference, SQLModel, + WebChatSession, ) from astrbot.core.db.po import ( Platform as DeprecatedPlatformStat, @@ -709,3 +710,86 @@ class SQLiteDatabase(BaseDatabase): t.start() t.join() return result + + # ==== + # WebChat Session Management + # ==== + + async def create_webchat_session( + self, + creator: str, + session_id: str | None = None, + is_group: int = 0, + ) -> WebChatSession: + """Create a new WebChat session.""" + kwargs = {} + if session_id: + kwargs["session_id"] = session_id + + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + new_session = WebChatSession( + creator=creator, + is_group=is_group, + **kwargs, + ) + session.add(new_session) + await session.flush() + await session.refresh(new_session) + return new_session + + async def get_webchat_session_by_id(self, session_id: str) -> WebChatSession | None: + """Get a WebChat session by its ID.""" + async with self.get_db() as session: + session: AsyncSession + query = select(WebChatSession).where( + WebChatSession.session_id == session_id, + ) + result = await session.execute(query) + return result.scalar_one_or_none() + + async def get_webchat_sessions_by_creator( + self, + creator: str, + page: int = 1, + page_size: int = 20, + ) -> list[WebChatSession]: + """Get all WebChat sessions for a specific creator (username).""" + async with self.get_db() as session: + session: AsyncSession + offset = (page - 1) * page_size + query = ( + select(WebChatSession) + .where(WebChatSession.creator == creator) + .order_by(desc(WebChatSession.updated_at)) + .offset(offset) + .limit(page_size) + ) + result = await session.execute(query) + return list(result.scalars().all()) + + async def update_webchat_session( + self, + session_id: str, + ) -> None: + """Update a WebChat session's updated_at timestamp.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + await session.execute( + update(WebChatSession) + .where(WebChatSession.session_id == session_id) + .values(updated_at=datetime.now()), + ) + + async def delete_webchat_session(self, session_id: str) -> None: + """Delete a WebChat session by its ID.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + await session.execute( + delete(WebChatSession).where( + WebChatSession.session_id == session_id, + ), + ) diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index 56946550a..8eacc5c4f 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -10,7 +10,6 @@ from quart import g, make_response, request from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.db import BaseDatabase -from astrbot.core.platform.astr_message_event import MessageSession from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr from astrbot.core.utils.astrbot_path import get_astrbot_data_path @@ -36,11 +35,10 @@ class ChatRoute(Route): super().__init__(context) self.routes = { "/chat/send": ("POST", self.chat), - "/chat/new_conversation": ("GET", self.new_conversation), - "/chat/conversations": ("GET", self.get_conversations), - "/chat/get_conversation": ("GET", self.get_conversation), - "/chat/delete_conversation": ("GET", self.delete_conversation), - "/chat/rename_conversation": ("POST", self.rename_conversation), + "/chat/new_session": ("GET", self.new_session), + "/chat/sessions": ("GET", self.get_sessions), + "/chat/get_session": ("GET", self.get_session), + "/chat/delete_session": ("GET", self.delete_webchat_session), "/chat/get_file": ("GET", self.get_file), "/chat/post_image": ("POST", self.post_image), "/chat/post_file": ("POST", self.post_file), @@ -53,6 +51,7 @@ class ChatRoute(Route): self.supported_imgs = ["jpg", "jpeg", "png", "gif", "webp"] self.conv_mgr = core_lifecycle.conversation_manager self.platform_history_mgr = core_lifecycle.platform_message_history_manager + self.db = db self.running_convs: dict[str, bool] = {} @@ -137,7 +136,8 @@ class ChatRoute(Route): return Response().error("conversation_id is empty").__dict__ # 追加用户消息 - webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id) + # conversation_id 现在实际上是 session_id + webchat_conv_id = conversation_id # 获取会话特定的队列 back_queue = webchat_queue_mgr.get_or_create_back_queue(webchat_conv_id) @@ -245,88 +245,86 @@ class ChatRoute(Route): response.timeout = None # fix SSE auto disconnect issue return response - async def _get_webchat_conv_id_from_conv_id(self, conversation_id: str) -> str: - """从对话 ID 中提取 WebChat 会话 ID - - NOTE: 关于这里为什么要单独做一个 WebChat 的 Conversation ID 出来,这个是为了向前兼容。 - """ - conversation = await self.conv_mgr.get_conversation( - unified_msg_origin="webchat", - conversation_id=conversation_id, - ) - if not conversation: - raise ValueError(f"Conversation with ID {conversation_id} not found.") - conv_user_id = conversation.user_id - webchat_session_id = MessageSession.from_str(conv_user_id).session_id - if "!" not in webchat_session_id: - raise ValueError(f"Invalid conv user ID: {conv_user_id}") - return webchat_session_id.split("!")[-1] - - async def delete_conversation(self): - conversation_id = request.args.get("conversation_id") - if not conversation_id: - return Response().error("Missing key: conversation_id").__dict__ + async def delete_webchat_session(self): + """Delete a WebChat session and all its related data.""" + session_id = request.args.get("session_id") + if not session_id: + return Response().error("Missing key: session_id").__dict__ username = g.get("username", "guest") - # Clean up queues when deleting conversation - webchat_queue_mgr.remove_queues(conversation_id) - webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id) - await self.conv_mgr.delete_conversation( - unified_msg_origin=f"webchat:FriendMessage:webchat!{username}!{webchat_conv_id}", - conversation_id=conversation_id, - ) + # 验证会话是否存在且属于当前用户 + session = await self.db.get_webchat_session_by_id(session_id) + if not session: + return Response().error(f"Session {session_id} not found").__dict__ + if session.creator != username: + return Response().error("Permission denied").__dict__ + + # 删除该会话下的所有对话 + unified_msg_origin = f"webchat:FriendMessage:webchat!{username}!{session_id}" + await self.conv_mgr.delete_conversations_by_user_id(unified_msg_origin) + + # 删除消息历史 await self.platform_history_mgr.delete( platform_id="webchat", - user_id=webchat_conv_id, + user_id=session_id, offset_sec=99999999, ) + + # 清理队列 + webchat_queue_mgr.remove_queues(session_id) + + # 删除会话 + await self.db.delete_webchat_session(session_id) + return Response().ok().__dict__ - async def new_conversation(self): + async def new_session(self): + """Create a new WebChat session.""" username = g.get("username", "guest") - webchat_conv_id = str(uuid.uuid4()) - conv_id = await self.conv_mgr.new_conversation( - unified_msg_origin=f"webchat:FriendMessage:webchat!{username}!{webchat_conv_id}", - platform_id="webchat", - content=[], + + # 创建新会话 + session = await self.db.create_webchat_session( + creator=username, + is_group=0, ) - return Response().ok(data={"conversation_id": conv_id}).__dict__ - async def rename_conversation(self): - post_data = await request.json - if "conversation_id" not in post_data or "title" not in post_data: - return Response().error("Missing key: conversation_id or title").__dict__ + return Response().ok(data={"session_id": session.session_id}).__dict__ - conversation_id = post_data["conversation_id"] - title = post_data["title"] + async def get_sessions(self): + """Get all WebChat sessions for the current user.""" + username = g.get("username", "guest") - await self.conv_mgr.update_conversation( - unified_msg_origin="webchat", # fake - conversation_id=conversation_id, - title=title, + sessions = await self.db.get_webchat_sessions_by_creator( + creator=username, + page=1, + page_size=100, # 暂时返回前100个 ) - return Response().ok(message="重命名成功!").__dict__ - async def get_conversations(self): - conversations = await self.conv_mgr.get_conversations(platform_id="webchat") - # remove content - conversations_ = [] - for conv in conversations: - conv.history = None - conversations_.append(conv) - return Response().ok(data=conversations_).__dict__ + # 转换为字典格式,并添加额外信息 + sessions_data = [] + for session in sessions: + sessions_data.append( + { + "session_id": session.session_id, + "creator": session.creator, + "is_group": session.is_group, + "created_at": int(session.created_at.timestamp()), + "updated_at": int(session.updated_at.timestamp()), + } + ) - async def get_conversation(self): - conversation_id = request.args.get("conversation_id") - if not conversation_id: - return Response().error("Missing key: conversation_id").__dict__ + return Response().ok(data=sessions_data).__dict__ - webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id) + async def get_session(self): + """Get session information and message history by session_id.""" + session_id = request.args.get("session_id") + if not session_id: + return Response().error("Missing key: session_id").__dict__ - # Get platform message history + # Get platform message history using session_id history_ls = await self.platform_history_mgr.get( platform_id="webchat", - user_id=webchat_conv_id, + user_id=session_id, page=1, page_size=1000, ) @@ -338,7 +336,7 @@ class ChatRoute(Route): .ok( data={ "history": history_res, - "is_running": self.running_convs.get(webchat_conv_id, False), + "is_running": self.running_convs.get(session_id, False), }, ) .__dict__ diff --git a/dashboard/src/components/chat/Chat.vue b/dashboard/src/components/chat/Chat.vue index d671b15b7..69cb5d6a8 100644 --- a/dashboard/src/components/chat/Chat.vue +++ b/dashboard/src/components/chat/Chat.vue @@ -38,11 +38,11 @@