diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index 2a6ac4273..17fd52138 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -22,6 +22,7 @@ from astrbot.core.config.default import VERSION from astrbot.core.conversation_mgr import ConversationManager from astrbot.core.db import BaseDatabase from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46 +from astrbot.core.db.migration.migra_webchat_session import migrate_webchat_session from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager from astrbot.core.persona_mgr import PersonaManager from astrbot.core.pipeline.scheduler import PipelineContext, PipelineScheduler @@ -103,6 +104,13 @@ class AstrBotCoreLifecycle: logger.error(f"Migration from version 4.5 to 4.6 failed: {e!s}") logger.error(traceback.format_exc()) + # migration for webchat session + try: + await migrate_webchat_session(self.db) + except Exception as e: + logger.error(f"Migration for webchat session failed: {e!s}") + logger.error(traceback.format_exc()) + # 初始化事件队列 self.event_queue = Queue() diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index c62e49289..2af0428d0 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -13,6 +13,7 @@ from astrbot.core.db.po import ( ConversationV2, Persona, PlatformMessageHistory, + PlatformSession, PlatformStat, Preference, Stats, @@ -183,7 +184,7 @@ class BaseDatabase(abc.ABC): user_id: str, offset_sec: int = 86400, ) -> None: - """Delete platform message history records older than the specified offset.""" + """Delete platform message history records newer than the specified offset.""" ... @abc.abstractmethod @@ -313,3 +314,51 @@ class BaseDatabase(abc.ABC): ) -> tuple[list[dict], int]: """Get paginated session conversations with joined conversation and persona details, support search and platform filter.""" ... + + # ==== + # Platform Session Management + # ==== + + @abc.abstractmethod + async def create_platform_session( + self, + creator: str, + platform_id: str = "webchat", + session_id: str | None = None, + display_name: str | None = None, + is_group: int = 0, + ) -> PlatformSession: + """Create a new Platform session.""" + ... + + @abc.abstractmethod + async def get_platform_session_by_id( + self, session_id: str + ) -> PlatformSession | None: + """Get a Platform session by its ID.""" + ... + + @abc.abstractmethod + async def get_platform_sessions_by_creator( + self, + creator: str, + platform_id: str | None = None, + page: int = 1, + page_size: int = 20, + ) -> list[PlatformSession]: + """Get all Platform sessions for a specific creator (username) and optionally platform.""" + ... + + @abc.abstractmethod + async def update_platform_session( + self, + session_id: str, + display_name: str | None = None, + ) -> None: + """Update a Platform session's updated_at timestamp and optionally display_name.""" + ... + + @abc.abstractmethod + async def delete_platform_session(self, session_id: str) -> None: + """Delete a Platform session by its ID.""" + ... diff --git a/astrbot/core/db/migration/migra_webchat_session.py b/astrbot/core/db/migration/migra_webchat_session.py new file mode 100644 index 000000000..6cb483464 --- /dev/null +++ b/astrbot/core/db/migration/migra_webchat_session.py @@ -0,0 +1,131 @@ +"""Migration script for WebChat sessions. + +This migration creates PlatformSession from existing platform_message_history records. + +Changes: +- Creates platform_sessions table +- Adds platform_id field (default: 'webchat') +- Adds display_name field +- Session_id format: {platform_id}_{uuid} +""" + +from sqlalchemy import func, select +from sqlmodel import col + +from astrbot.api import logger, sp +from astrbot.core.db import BaseDatabase +from astrbot.core.db.po import ConversationV2, PlatformMessageHistory, PlatformSession + + +async def migrate_webchat_session(db_helper: BaseDatabase): + """Create PlatformSession records from platform_message_history. + + This migration extracts all unique user_ids from platform_message_history + where platform_id='webchat' and creates corresponding PlatformSession records. + """ + # 检查是否已经完成迁移 + migration_done = await db_helper.get_preference( + "global", "global", "migration_done_webchat_session" + ) + if migration_done: + return + + logger.info("开始执行数据库迁移(WebChat 会话迁移)...") + + try: + async with db_helper.get_db() as session: + # 从 platform_message_history 创建 PlatformSession + query = ( + select( + col(PlatformMessageHistory.user_id), + col(PlatformMessageHistory.sender_name), + func.min(PlatformMessageHistory.created_at).label("earliest"), + func.max(PlatformMessageHistory.updated_at).label("latest"), + ) + .where(col(PlatformMessageHistory.platform_id) == "webchat") + .where(col(PlatformMessageHistory.sender_id) == "astrbot") + .group_by(col(PlatformMessageHistory.user_id)) + ) + + result = await session.execute(query) + webchat_users = result.all() + + if not webchat_users: + logger.info("没有找到需要迁移的 WebChat 数据") + await sp.put_async( + "global", "global", "migration_done_webchat_session", True + ) + return + + logger.info(f"找到 {len(webchat_users)} 个 WebChat 会话需要迁移") + + # 检查已存在的会话 + existing_query = select(col(PlatformSession.session_id)) + existing_result = await session.execute(existing_query) + existing_session_ids = {row[0] for row in existing_result.fetchall()} + + # 查询 Conversations 表中的 title,用于设置 display_name + # 对于每个 user_id,对应的 conversation user_id 格式为: webchat:FriendMessage:webchat!astrbot!{user_id} + user_ids_to_query = [ + f"webchat:FriendMessage:webchat!astrbot!{user_id}" + for user_id, _, _, _ in webchat_users + ] + conv_query = select( + col(ConversationV2.user_id), col(ConversationV2.title) + ).where(col(ConversationV2.user_id).in_(user_ids_to_query)) + conv_result = await session.execute(conv_query) + # 创建 user_id -> title 的映射字典 + title_map = { + user_id.replace("webchat:FriendMessage:webchat!astrbot!", ""): title + for user_id, title in conv_result.fetchall() + } + + # 批量创建 PlatformSession 记录 + sessions_to_add = [] + skipped_count = 0 + + for user_id, sender_name, created_at, updated_at in webchat_users: + # user_id 就是 webchat_conv_id (session_id) + session_id = user_id + + # sender_name 通常是 username,但可能为 None + creator = sender_name if sender_name else "guest" + + # 检查是否已经存在该会话 + if session_id in existing_session_ids: + logger.debug(f"会话 {session_id} 已存在,跳过") + skipped_count += 1 + continue + + # 从 Conversations 表中获取 display_name + display_name = title_map.get(user_id) + + # 创建新的 PlatformSession(保留原有的时间戳) + new_session = PlatformSession( + session_id=session_id, + platform_id="webchat", + creator=creator, + is_group=0, + created_at=created_at, + updated_at=updated_at, + display_name=display_name, + ) + sessions_to_add.append(new_session) + + # 批量插入 + if sessions_to_add: + session.add_all(sessions_to_add) + await session.commit() + + logger.info( + f"WebChat 会话迁移完成!成功迁移: {len(sessions_to_add)}, 跳过: {skipped_count}", + ) + else: + logger.info("没有新会话需要迁移") + + # 标记迁移完成 + await sp.put_async("global", "global", "migration_done_webchat_session", True) + + except Exception as e: + logger.error(f"迁移过程中发生错误: {e}", exc_info=True) + raise diff --git a/astrbot/core/db/po.py b/astrbot/core/db/po.py index 5cf25ec13..d6621d072 100644 --- a/astrbot/core/db/po.py +++ b/astrbot/core/db/po.py @@ -155,6 +155,48 @@ class PlatformMessageHistory(SQLModel, table=True): ) +class PlatformSession(SQLModel, table=True): + """Platform session table for managing user sessions across different platforms. + + A session represents a chat window for a specific user on a specific platform. + Each session can have multiple conversations (对话) associated with it. + """ + + __tablename__ = "platform_sessions" # type: ignore + + inner_id: int | None = Field( + primary_key=True, + sa_column_kwargs={"autoincrement": True}, + default=None, + ) + session_id: str = Field( + max_length=100, + nullable=False, + unique=True, + default_factory=lambda: f"webchat_{uuid.uuid4()}", + ) + platform_id: str = Field(default="webchat", nullable=False) + """Platform identifier (e.g., 'webchat', 'qq', 'discord')""" + creator: str = Field(nullable=False) + """Username of the session creator""" + display_name: str | None = Field(default=None, max_length=255) + """Display name for the session""" + is_group: int = Field(default=0, nullable=False) + """0 for private chat, 1 for group chat (not implemented yet)""" + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field( + default_factory=lambda: datetime.now(timezone.utc), + sa_column_kwargs={"onupdate": datetime.now(timezone.utc)}, + ) + + __table_args__ = ( + UniqueConstraint( + "session_id", + name="uix_platform_session_id", + ), + ) + + class Attachment(SQLModel, table=True): """This class represents attachments for messages in AstrBot. diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index 457a4ab3f..194618612 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -1,7 +1,7 @@ import asyncio import threading import typing as T -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from sqlalchemy.ext.asyncio import AsyncSession from sqlmodel import col, delete, desc, func, or_, select, text, update @@ -12,6 +12,7 @@ from astrbot.core.db.po import ( ConversationV2, Persona, PlatformMessageHistory, + PlatformSession, PlatformStat, Preference, SQLModel, @@ -412,7 +413,7 @@ class SQLiteDatabase(BaseDatabase): user_id, offset_sec=86400, ): - """Delete platform message history records older than the specified offset.""" + """Delete platform message history records newer than the specified offset.""" async with self.get_db() as session: session: AsyncSession async with session.begin(): @@ -422,7 +423,7 @@ class SQLiteDatabase(BaseDatabase): delete(PlatformMessageHistory).where( col(PlatformMessageHistory.platform_id) == platform_id, col(PlatformMessageHistory.user_id) == user_id, - col(PlatformMessageHistory.created_at) < cutoff_time, + col(PlatformMessageHistory.created_at) >= cutoff_time, ), ) @@ -709,3 +710,101 @@ class SQLiteDatabase(BaseDatabase): t.start() t.join() return result + + # ==== + # Platform Session Management + # ==== + + async def create_platform_session( + self, + creator: str, + platform_id: str = "webchat", + session_id: str | None = None, + display_name: str | None = None, + is_group: int = 0, + ) -> PlatformSession: + """Create a new Platform session.""" + kwargs = {} + if session_id: + kwargs["session_id"] = session_id + + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + new_session = PlatformSession( + creator=creator, + platform_id=platform_id, + display_name=display_name, + is_group=is_group, + **kwargs, + ) + session.add(new_session) + await session.flush() + await session.refresh(new_session) + return new_session + + async def get_platform_session_by_id( + self, session_id: str + ) -> PlatformSession | None: + """Get a Platform session by its ID.""" + async with self.get_db() as session: + session: AsyncSession + query = select(PlatformSession).where( + PlatformSession.session_id == session_id, + ) + result = await session.execute(query) + return result.scalar_one_or_none() + + async def get_platform_sessions_by_creator( + self, + creator: str, + platform_id: str | None = None, + page: int = 1, + page_size: int = 20, + ) -> list[PlatformSession]: + """Get all Platform sessions for a specific creator (username) and optionally platform.""" + async with self.get_db() as session: + session: AsyncSession + offset = (page - 1) * page_size + query = select(PlatformSession).where(PlatformSession.creator == creator) + + if platform_id: + query = query.where(PlatformSession.platform_id == platform_id) + + query = ( + query.order_by(desc(PlatformSession.updated_at)) + .offset(offset) + .limit(page_size) + ) + result = await session.execute(query) + return list(result.scalars().all()) + + async def update_platform_session( + self, + session_id: str, + display_name: str | None = None, + ) -> None: + """Update a Platform session's updated_at timestamp and optionally display_name.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + values: dict[str, T.Any] = {"updated_at": datetime.now(timezone.utc)} + if display_name is not None: + values["display_name"] = display_name + + await session.execute( + update(PlatformSession) + .where(col(PlatformSession.session_id == session_id)) + .values(**values), + ) + + async def delete_platform_session(self, session_id: str) -> None: + """Delete a Platform session by its ID.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + await session.execute( + delete(PlatformSession).where( + col(PlatformSession.session_id == session_id), + ), + ) diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index 56946550a..1ad789563 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -10,7 +10,6 @@ from quart import g, make_response, request from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.db import BaseDatabase -from astrbot.core.platform.astr_message_event import MessageSession from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr from astrbot.core.utils.astrbot_path import get_astrbot_data_path @@ -36,11 +35,14 @@ class ChatRoute(Route): super().__init__(context) self.routes = { "/chat/send": ("POST", self.chat), - "/chat/new_conversation": ("GET", self.new_conversation), - "/chat/conversations": ("GET", self.get_conversations), - "/chat/get_conversation": ("GET", self.get_conversation), - "/chat/delete_conversation": ("GET", self.delete_conversation), - "/chat/rename_conversation": ("POST", self.rename_conversation), + "/chat/new_session": ("GET", self.new_session), + "/chat/sessions": ("GET", self.get_sessions), + "/chat/get_session": ("GET", self.get_session), + "/chat/delete_session": ("GET", self.delete_webchat_session), + "/chat/update_session_display_name": ( + "POST", + self.update_session_display_name, + ), "/chat/get_file": ("GET", self.get_file), "/chat/post_image": ("POST", self.post_image), "/chat/post_file": ("POST", self.post_file), @@ -53,6 +55,7 @@ class ChatRoute(Route): self.supported_imgs = ["jpg", "jpeg", "png", "gif", "webp"] self.conv_mgr = core_lifecycle.conversation_manager self.platform_history_mgr = core_lifecycle.platform_message_history_manager + self.db = db self.running_convs: dict[str, bool] = {} @@ -116,11 +119,14 @@ class ChatRoute(Route): if "message" not in post_data and "image_url" not in post_data: return Response().error("Missing key: message or image_url").__dict__ - if "conversation_id" not in post_data: - return Response().error("Missing key: conversation_id").__dict__ + if "session_id" not in post_data and "conversation_id" not in post_data: + return ( + Response().error("Missing key: session_id or conversation_id").__dict__ + ) message = post_data["message"] - conversation_id = post_data["conversation_id"] + # conversation_id = post_data["conversation_id"] + session_id = post_data.get("session_id", post_data.get("conversation_id")) image_url = post_data.get("image_url") audio_url = post_data.get("audio_url") selected_provider = post_data.get("selected_provider") @@ -133,11 +139,11 @@ class ChatRoute(Route): .error("Message and image_url and audio_url are empty") .__dict__ ) - if not conversation_id: - return Response().error("conversation_id is empty").__dict__ + if not session_id: + return Response().error("session_id is empty").__dict__ # 追加用户消息 - webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id) + webchat_conv_id = session_id # 获取会话特定的队列 back_queue = webchat_queue_mgr.get_or_create_back_queue(webchat_conv_id) @@ -245,88 +251,110 @@ class ChatRoute(Route): response.timeout = None # fix SSE auto disconnect issue return response - async def _get_webchat_conv_id_from_conv_id(self, conversation_id: str) -> str: - """从对话 ID 中提取 WebChat 会话 ID - - NOTE: 关于这里为什么要单独做一个 WebChat 的 Conversation ID 出来,这个是为了向前兼容。 - """ - conversation = await self.conv_mgr.get_conversation( - unified_msg_origin="webchat", - conversation_id=conversation_id, - ) - if not conversation: - raise ValueError(f"Conversation with ID {conversation_id} not found.") - conv_user_id = conversation.user_id - webchat_session_id = MessageSession.from_str(conv_user_id).session_id - if "!" not in webchat_session_id: - raise ValueError(f"Invalid conv user ID: {conv_user_id}") - return webchat_session_id.split("!")[-1] - - async def delete_conversation(self): - conversation_id = request.args.get("conversation_id") - if not conversation_id: - return Response().error("Missing key: conversation_id").__dict__ + async def delete_webchat_session(self): + """Delete a Platform session and all its related data.""" + session_id = request.args.get("session_id") + if not session_id: + return Response().error("Missing key: session_id").__dict__ username = g.get("username", "guest") - # Clean up queues when deleting conversation - webchat_queue_mgr.remove_queues(conversation_id) - webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id) - await self.conv_mgr.delete_conversation( - unified_msg_origin=f"webchat:FriendMessage:webchat!{username}!{webchat_conv_id}", - conversation_id=conversation_id, - ) + # 验证会话是否存在且属于当前用户 + session = await self.db.get_platform_session_by_id(session_id) + if not session: + return Response().error(f"Session {session_id} not found").__dict__ + if session.creator != username: + return Response().error("Permission denied").__dict__ + + # 删除该会话下的所有对话 + unified_msg_origin = f"{session.platform_id}:FriendMessage:{session.platform_id}!{username}!{session_id}" + await self.conv_mgr.delete_conversations_by_user_id(unified_msg_origin) + + # 删除消息历史 await self.platform_history_mgr.delete( - platform_id="webchat", - user_id=webchat_conv_id, + platform_id=session.platform_id, + user_id=session_id, offset_sec=99999999, ) + + # 清理队列(仅对 webchat) + if session.platform_id == "webchat": + webchat_queue_mgr.remove_queues(session_id) + + # 删除会话 + await self.db.delete_platform_session(session_id) + return Response().ok().__dict__ - async def new_conversation(self): + async def new_session(self): + """Create a new Platform session (default: webchat).""" username = g.get("username", "guest") - webchat_conv_id = str(uuid.uuid4()) - conv_id = await self.conv_mgr.new_conversation( - unified_msg_origin=f"webchat:FriendMessage:webchat!{username}!{webchat_conv_id}", - platform_id="webchat", - content=[], + + # 获取可选的 platform_id 参数,默认为 webchat + platform_id = request.args.get("platform_id", "webchat") + + # 创建新会话 + session = await self.db.create_platform_session( + creator=username, + platform_id=platform_id, + is_group=0, ) - return Response().ok(data={"conversation_id": conv_id}).__dict__ - async def rename_conversation(self): - post_data = await request.json - if "conversation_id" not in post_data or "title" not in post_data: - return Response().error("Missing key: conversation_id or title").__dict__ - - conversation_id = post_data["conversation_id"] - title = post_data["title"] - - await self.conv_mgr.update_conversation( - unified_msg_origin="webchat", # fake - conversation_id=conversation_id, - title=title, + return ( + Response() + .ok( + data={ + "session_id": session.session_id, + "platform_id": session.platform_id, + } + ) + .__dict__ ) - return Response().ok(message="重命名成功!").__dict__ - async def get_conversations(self): - conversations = await self.conv_mgr.get_conversations(platform_id="webchat") - # remove content - conversations_ = [] - for conv in conversations: - conv.history = None - conversations_.append(conv) - return Response().ok(data=conversations_).__dict__ + async def get_sessions(self): + """Get all Platform sessions for the current user.""" + username = g.get("username", "guest") - async def get_conversation(self): - conversation_id = request.args.get("conversation_id") - if not conversation_id: - return Response().error("Missing key: conversation_id").__dict__ + # 获取可选的 platform_id 参数 + platform_id = request.args.get("platform_id") - webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id) + sessions = await self.db.get_platform_sessions_by_creator( + creator=username, + platform_id=platform_id, + page=1, + page_size=100, # 暂时返回前100个 + ) - # Get platform message history + # 转换为字典格式,并添加额外信息 + sessions_data = [] + for session in sessions: + sessions_data.append( + { + "session_id": session.session_id, + "platform_id": session.platform_id, + "creator": session.creator, + "display_name": session.display_name, + "is_group": session.is_group, + "created_at": session.created_at.astimezone().isoformat(), + "updated_at": session.updated_at.astimezone().isoformat(), + } + ) + + return Response().ok(data=sessions_data).__dict__ + + async def get_session(self): + """Get session information and message history by session_id.""" + session_id = request.args.get("session_id") + if not session_id: + return Response().error("Missing key: session_id").__dict__ + + # 获取会话信息以确定 platform_id + session = await self.db.get_platform_session_by_id(session_id) + platform_id = session.platform_id if session else "webchat" + + # Get platform message history using session_id history_ls = await self.platform_history_mgr.get( - platform_id="webchat", - user_id=webchat_conv_id, + platform_id=platform_id, + user_id=session_id, page=1, page_size=1000, ) @@ -338,8 +366,37 @@ class ChatRoute(Route): .ok( data={ "history": history_res, - "is_running": self.running_convs.get(webchat_conv_id, False), + "is_running": self.running_convs.get(session_id, False), }, ) .__dict__ ) + + async def update_session_display_name(self): + """Update a Platform session's display name.""" + post_data = await request.json + + session_id = post_data.get("session_id") + display_name = post_data.get("display_name") + + if not session_id: + return Response().error("Missing key: session_id").__dict__ + if display_name is None: + return Response().error("Missing key: display_name").__dict__ + + username = g.get("username", "guest") + + # 验证会话是否存在且属于当前用户 + session = await self.db.get_platform_session_by_id(session_id) + if not session: + return Response().error(f"Session {session_id} not found").__dict__ + if session.creator != username: + return Response().error("Permission denied").__dict__ + + # 更新 display_name + await self.db.update_platform_session( + session_id=session_id, + display_name=display_name, + ) + + return Response().ok().__dict__ diff --git a/dashboard/src/components/chat/Chat.vue b/dashboard/src/components/chat/Chat.vue index bb3418d67..09acd1b7e 100644 --- a/dashboard/src/components/chat/Chat.vue +++ b/dashboard/src/components/chat/Chat.vue @@ -6,9 +6,9 @@