diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index 2a6ac4273..17fd52138 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -22,6 +22,7 @@ from astrbot.core.config.default import VERSION from astrbot.core.conversation_mgr import ConversationManager from astrbot.core.db import BaseDatabase from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46 +from astrbot.core.db.migration.migra_webchat_session import migrate_webchat_session from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager from astrbot.core.persona_mgr import PersonaManager from astrbot.core.pipeline.scheduler import PipelineContext, PipelineScheduler @@ -103,6 +104,13 @@ class AstrBotCoreLifecycle: logger.error(f"Migration from version 4.5 to 4.6 failed: {e!s}") logger.error(traceback.format_exc()) + # migration for webchat session + try: + await migrate_webchat_session(self.db) + except Exception as e: + logger.error(f"Migration for webchat session failed: {e!s}") + logger.error(traceback.format_exc()) + # 初始化事件队列 self.event_queue = Queue() diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index c62e49289..2af0428d0 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -13,6 +13,7 @@ from astrbot.core.db.po import ( ConversationV2, Persona, PlatformMessageHistory, + PlatformSession, PlatformStat, Preference, Stats, @@ -183,7 +184,7 @@ class BaseDatabase(abc.ABC): user_id: str, offset_sec: int = 86400, ) -> None: - """Delete platform message history records older than the specified offset.""" + """Delete platform message history records newer than the specified offset.""" ... @abc.abstractmethod @@ -313,3 +314,51 @@ class BaseDatabase(abc.ABC): ) -> tuple[list[dict], int]: """Get paginated session conversations with joined conversation and persona details, support search and platform filter.""" ... + + # ==== + # Platform Session Management + # ==== + + @abc.abstractmethod + async def create_platform_session( + self, + creator: str, + platform_id: str = "webchat", + session_id: str | None = None, + display_name: str | None = None, + is_group: int = 0, + ) -> PlatformSession: + """Create a new Platform session.""" + ... + + @abc.abstractmethod + async def get_platform_session_by_id( + self, session_id: str + ) -> PlatformSession | None: + """Get a Platform session by its ID.""" + ... + + @abc.abstractmethod + async def get_platform_sessions_by_creator( + self, + creator: str, + platform_id: str | None = None, + page: int = 1, + page_size: int = 20, + ) -> list[PlatformSession]: + """Get all Platform sessions for a specific creator (username) and optionally platform.""" + ... + + @abc.abstractmethod + async def update_platform_session( + self, + session_id: str, + display_name: str | None = None, + ) -> None: + """Update a Platform session's updated_at timestamp and optionally display_name.""" + ... + + @abc.abstractmethod + async def delete_platform_session(self, session_id: str) -> None: + """Delete a Platform session by its ID.""" + ... diff --git a/astrbot/core/db/migration/migra_webchat_session.py b/astrbot/core/db/migration/migra_webchat_session.py new file mode 100644 index 000000000..6cb483464 --- /dev/null +++ b/astrbot/core/db/migration/migra_webchat_session.py @@ -0,0 +1,131 @@ +"""Migration script for WebChat sessions. + +This migration creates PlatformSession from existing platform_message_history records. + +Changes: +- Creates platform_sessions table +- Adds platform_id field (default: 'webchat') +- Adds display_name field +- Session_id format: {platform_id}_{uuid} +""" + +from sqlalchemy import func, select +from sqlmodel import col + +from astrbot.api import logger, sp +from astrbot.core.db import BaseDatabase +from astrbot.core.db.po import ConversationV2, PlatformMessageHistory, PlatformSession + + +async def migrate_webchat_session(db_helper: BaseDatabase): + """Create PlatformSession records from platform_message_history. + + This migration extracts all unique user_ids from platform_message_history + where platform_id='webchat' and creates corresponding PlatformSession records. + """ + # 检查是否已经完成迁移 + migration_done = await db_helper.get_preference( + "global", "global", "migration_done_webchat_session" + ) + if migration_done: + return + + logger.info("开始执行数据库迁移(WebChat 会话迁移)...") + + try: + async with db_helper.get_db() as session: + # 从 platform_message_history 创建 PlatformSession + query = ( + select( + col(PlatformMessageHistory.user_id), + col(PlatformMessageHistory.sender_name), + func.min(PlatformMessageHistory.created_at).label("earliest"), + func.max(PlatformMessageHistory.updated_at).label("latest"), + ) + .where(col(PlatformMessageHistory.platform_id) == "webchat") + .where(col(PlatformMessageHistory.sender_id) == "astrbot") + .group_by(col(PlatformMessageHistory.user_id)) + ) + + result = await session.execute(query) + webchat_users = result.all() + + if not webchat_users: + logger.info("没有找到需要迁移的 WebChat 数据") + await sp.put_async( + "global", "global", "migration_done_webchat_session", True + ) + return + + logger.info(f"找到 {len(webchat_users)} 个 WebChat 会话需要迁移") + + # 检查已存在的会话 + existing_query = select(col(PlatformSession.session_id)) + existing_result = await session.execute(existing_query) + existing_session_ids = {row[0] for row in existing_result.fetchall()} + + # 查询 Conversations 表中的 title,用于设置 display_name + # 对于每个 user_id,对应的 conversation user_id 格式为: webchat:FriendMessage:webchat!astrbot!{user_id} + user_ids_to_query = [ + f"webchat:FriendMessage:webchat!astrbot!{user_id}" + for user_id, _, _, _ in webchat_users + ] + conv_query = select( + col(ConversationV2.user_id), col(ConversationV2.title) + ).where(col(ConversationV2.user_id).in_(user_ids_to_query)) + conv_result = await session.execute(conv_query) + # 创建 user_id -> title 的映射字典 + title_map = { + user_id.replace("webchat:FriendMessage:webchat!astrbot!", ""): title + for user_id, title in conv_result.fetchall() + } + + # 批量创建 PlatformSession 记录 + sessions_to_add = [] + skipped_count = 0 + + for user_id, sender_name, created_at, updated_at in webchat_users: + # user_id 就是 webchat_conv_id (session_id) + session_id = user_id + + # sender_name 通常是 username,但可能为 None + creator = sender_name if sender_name else "guest" + + # 检查是否已经存在该会话 + if session_id in existing_session_ids: + logger.debug(f"会话 {session_id} 已存在,跳过") + skipped_count += 1 + continue + + # 从 Conversations 表中获取 display_name + display_name = title_map.get(user_id) + + # 创建新的 PlatformSession(保留原有的时间戳) + new_session = PlatformSession( + session_id=session_id, + platform_id="webchat", + creator=creator, + is_group=0, + created_at=created_at, + updated_at=updated_at, + display_name=display_name, + ) + sessions_to_add.append(new_session) + + # 批量插入 + if sessions_to_add: + session.add_all(sessions_to_add) + await session.commit() + + logger.info( + f"WebChat 会话迁移完成!成功迁移: {len(sessions_to_add)}, 跳过: {skipped_count}", + ) + else: + logger.info("没有新会话需要迁移") + + # 标记迁移完成 + await sp.put_async("global", "global", "migration_done_webchat_session", True) + + except Exception as e: + logger.error(f"迁移过程中发生错误: {e}", exc_info=True) + raise diff --git a/astrbot/core/db/po.py b/astrbot/core/db/po.py index 5cf25ec13..d6621d072 100644 --- a/astrbot/core/db/po.py +++ b/astrbot/core/db/po.py @@ -155,6 +155,48 @@ class PlatformMessageHistory(SQLModel, table=True): ) +class PlatformSession(SQLModel, table=True): + """Platform session table for managing user sessions across different platforms. + + A session represents a chat window for a specific user on a specific platform. + Each session can have multiple conversations (对话) associated with it. + """ + + __tablename__ = "platform_sessions" # type: ignore + + inner_id: int | None = Field( + primary_key=True, + sa_column_kwargs={"autoincrement": True}, + default=None, + ) + session_id: str = Field( + max_length=100, + nullable=False, + unique=True, + default_factory=lambda: f"webchat_{uuid.uuid4()}", + ) + platform_id: str = Field(default="webchat", nullable=False) + """Platform identifier (e.g., 'webchat', 'qq', 'discord')""" + creator: str = Field(nullable=False) + """Username of the session creator""" + display_name: str | None = Field(default=None, max_length=255) + """Display name for the session""" + is_group: int = Field(default=0, nullable=False) + """0 for private chat, 1 for group chat (not implemented yet)""" + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field( + default_factory=lambda: datetime.now(timezone.utc), + sa_column_kwargs={"onupdate": datetime.now(timezone.utc)}, + ) + + __table_args__ = ( + UniqueConstraint( + "session_id", + name="uix_platform_session_id", + ), + ) + + class Attachment(SQLModel, table=True): """This class represents attachments for messages in AstrBot. diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index 457a4ab3f..194618612 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -1,7 +1,7 @@ import asyncio import threading import typing as T -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from sqlalchemy.ext.asyncio import AsyncSession from sqlmodel import col, delete, desc, func, or_, select, text, update @@ -12,6 +12,7 @@ from astrbot.core.db.po import ( ConversationV2, Persona, PlatformMessageHistory, + PlatformSession, PlatformStat, Preference, SQLModel, @@ -412,7 +413,7 @@ class SQLiteDatabase(BaseDatabase): user_id, offset_sec=86400, ): - """Delete platform message history records older than the specified offset.""" + """Delete platform message history records newer than the specified offset.""" async with self.get_db() as session: session: AsyncSession async with session.begin(): @@ -422,7 +423,7 @@ class SQLiteDatabase(BaseDatabase): delete(PlatformMessageHistory).where( col(PlatformMessageHistory.platform_id) == platform_id, col(PlatformMessageHistory.user_id) == user_id, - col(PlatformMessageHistory.created_at) < cutoff_time, + col(PlatformMessageHistory.created_at) >= cutoff_time, ), ) @@ -709,3 +710,101 @@ class SQLiteDatabase(BaseDatabase): t.start() t.join() return result + + # ==== + # Platform Session Management + # ==== + + async def create_platform_session( + self, + creator: str, + platform_id: str = "webchat", + session_id: str | None = None, + display_name: str | None = None, + is_group: int = 0, + ) -> PlatformSession: + """Create a new Platform session.""" + kwargs = {} + if session_id: + kwargs["session_id"] = session_id + + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + new_session = PlatformSession( + creator=creator, + platform_id=platform_id, + display_name=display_name, + is_group=is_group, + **kwargs, + ) + session.add(new_session) + await session.flush() + await session.refresh(new_session) + return new_session + + async def get_platform_session_by_id( + self, session_id: str + ) -> PlatformSession | None: + """Get a Platform session by its ID.""" + async with self.get_db() as session: + session: AsyncSession + query = select(PlatformSession).where( + PlatformSession.session_id == session_id, + ) + result = await session.execute(query) + return result.scalar_one_or_none() + + async def get_platform_sessions_by_creator( + self, + creator: str, + platform_id: str | None = None, + page: int = 1, + page_size: int = 20, + ) -> list[PlatformSession]: + """Get all Platform sessions for a specific creator (username) and optionally platform.""" + async with self.get_db() as session: + session: AsyncSession + offset = (page - 1) * page_size + query = select(PlatformSession).where(PlatformSession.creator == creator) + + if platform_id: + query = query.where(PlatformSession.platform_id == platform_id) + + query = ( + query.order_by(desc(PlatformSession.updated_at)) + .offset(offset) + .limit(page_size) + ) + result = await session.execute(query) + return list(result.scalars().all()) + + async def update_platform_session( + self, + session_id: str, + display_name: str | None = None, + ) -> None: + """Update a Platform session's updated_at timestamp and optionally display_name.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + values: dict[str, T.Any] = {"updated_at": datetime.now(timezone.utc)} + if display_name is not None: + values["display_name"] = display_name + + await session.execute( + update(PlatformSession) + .where(col(PlatformSession.session_id == session_id)) + .values(**values), + ) + + async def delete_platform_session(self, session_id: str) -> None: + """Delete a Platform session by its ID.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + await session.execute( + delete(PlatformSession).where( + col(PlatformSession.session_id == session_id), + ), + ) diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index 56946550a..1ad789563 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -10,7 +10,6 @@ from quart import g, make_response, request from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.db import BaseDatabase -from astrbot.core.platform.astr_message_event import MessageSession from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr from astrbot.core.utils.astrbot_path import get_astrbot_data_path @@ -36,11 +35,14 @@ class ChatRoute(Route): super().__init__(context) self.routes = { "/chat/send": ("POST", self.chat), - "/chat/new_conversation": ("GET", self.new_conversation), - "/chat/conversations": ("GET", self.get_conversations), - "/chat/get_conversation": ("GET", self.get_conversation), - "/chat/delete_conversation": ("GET", self.delete_conversation), - "/chat/rename_conversation": ("POST", self.rename_conversation), + "/chat/new_session": ("GET", self.new_session), + "/chat/sessions": ("GET", self.get_sessions), + "/chat/get_session": ("GET", self.get_session), + "/chat/delete_session": ("GET", self.delete_webchat_session), + "/chat/update_session_display_name": ( + "POST", + self.update_session_display_name, + ), "/chat/get_file": ("GET", self.get_file), "/chat/post_image": ("POST", self.post_image), "/chat/post_file": ("POST", self.post_file), @@ -53,6 +55,7 @@ class ChatRoute(Route): self.supported_imgs = ["jpg", "jpeg", "png", "gif", "webp"] self.conv_mgr = core_lifecycle.conversation_manager self.platform_history_mgr = core_lifecycle.platform_message_history_manager + self.db = db self.running_convs: dict[str, bool] = {} @@ -116,11 +119,14 @@ class ChatRoute(Route): if "message" not in post_data and "image_url" not in post_data: return Response().error("Missing key: message or image_url").__dict__ - if "conversation_id" not in post_data: - return Response().error("Missing key: conversation_id").__dict__ + if "session_id" not in post_data and "conversation_id" not in post_data: + return ( + Response().error("Missing key: session_id or conversation_id").__dict__ + ) message = post_data["message"] - conversation_id = post_data["conversation_id"] + # conversation_id = post_data["conversation_id"] + session_id = post_data.get("session_id", post_data.get("conversation_id")) image_url = post_data.get("image_url") audio_url = post_data.get("audio_url") selected_provider = post_data.get("selected_provider") @@ -133,11 +139,11 @@ class ChatRoute(Route): .error("Message and image_url and audio_url are empty") .__dict__ ) - if not conversation_id: - return Response().error("conversation_id is empty").__dict__ + if not session_id: + return Response().error("session_id is empty").__dict__ # 追加用户消息 - webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id) + webchat_conv_id = session_id # 获取会话特定的队列 back_queue = webchat_queue_mgr.get_or_create_back_queue(webchat_conv_id) @@ -245,88 +251,110 @@ class ChatRoute(Route): response.timeout = None # fix SSE auto disconnect issue return response - async def _get_webchat_conv_id_from_conv_id(self, conversation_id: str) -> str: - """从对话 ID 中提取 WebChat 会话 ID - - NOTE: 关于这里为什么要单独做一个 WebChat 的 Conversation ID 出来,这个是为了向前兼容。 - """ - conversation = await self.conv_mgr.get_conversation( - unified_msg_origin="webchat", - conversation_id=conversation_id, - ) - if not conversation: - raise ValueError(f"Conversation with ID {conversation_id} not found.") - conv_user_id = conversation.user_id - webchat_session_id = MessageSession.from_str(conv_user_id).session_id - if "!" not in webchat_session_id: - raise ValueError(f"Invalid conv user ID: {conv_user_id}") - return webchat_session_id.split("!")[-1] - - async def delete_conversation(self): - conversation_id = request.args.get("conversation_id") - if not conversation_id: - return Response().error("Missing key: conversation_id").__dict__ + async def delete_webchat_session(self): + """Delete a Platform session and all its related data.""" + session_id = request.args.get("session_id") + if not session_id: + return Response().error("Missing key: session_id").__dict__ username = g.get("username", "guest") - # Clean up queues when deleting conversation - webchat_queue_mgr.remove_queues(conversation_id) - webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id) - await self.conv_mgr.delete_conversation( - unified_msg_origin=f"webchat:FriendMessage:webchat!{username}!{webchat_conv_id}", - conversation_id=conversation_id, - ) + # 验证会话是否存在且属于当前用户 + session = await self.db.get_platform_session_by_id(session_id) + if not session: + return Response().error(f"Session {session_id} not found").__dict__ + if session.creator != username: + return Response().error("Permission denied").__dict__ + + # 删除该会话下的所有对话 + unified_msg_origin = f"{session.platform_id}:FriendMessage:{session.platform_id}!{username}!{session_id}" + await self.conv_mgr.delete_conversations_by_user_id(unified_msg_origin) + + # 删除消息历史 await self.platform_history_mgr.delete( - platform_id="webchat", - user_id=webchat_conv_id, + platform_id=session.platform_id, + user_id=session_id, offset_sec=99999999, ) + + # 清理队列(仅对 webchat) + if session.platform_id == "webchat": + webchat_queue_mgr.remove_queues(session_id) + + # 删除会话 + await self.db.delete_platform_session(session_id) + return Response().ok().__dict__ - async def new_conversation(self): + async def new_session(self): + """Create a new Platform session (default: webchat).""" username = g.get("username", "guest") - webchat_conv_id = str(uuid.uuid4()) - conv_id = await self.conv_mgr.new_conversation( - unified_msg_origin=f"webchat:FriendMessage:webchat!{username}!{webchat_conv_id}", - platform_id="webchat", - content=[], + + # 获取可选的 platform_id 参数,默认为 webchat + platform_id = request.args.get("platform_id", "webchat") + + # 创建新会话 + session = await self.db.create_platform_session( + creator=username, + platform_id=platform_id, + is_group=0, ) - return Response().ok(data={"conversation_id": conv_id}).__dict__ - async def rename_conversation(self): - post_data = await request.json - if "conversation_id" not in post_data or "title" not in post_data: - return Response().error("Missing key: conversation_id or title").__dict__ - - conversation_id = post_data["conversation_id"] - title = post_data["title"] - - await self.conv_mgr.update_conversation( - unified_msg_origin="webchat", # fake - conversation_id=conversation_id, - title=title, + return ( + Response() + .ok( + data={ + "session_id": session.session_id, + "platform_id": session.platform_id, + } + ) + .__dict__ ) - return Response().ok(message="重命名成功!").__dict__ - async def get_conversations(self): - conversations = await self.conv_mgr.get_conversations(platform_id="webchat") - # remove content - conversations_ = [] - for conv in conversations: - conv.history = None - conversations_.append(conv) - return Response().ok(data=conversations_).__dict__ + async def get_sessions(self): + """Get all Platform sessions for the current user.""" + username = g.get("username", "guest") - async def get_conversation(self): - conversation_id = request.args.get("conversation_id") - if not conversation_id: - return Response().error("Missing key: conversation_id").__dict__ + # 获取可选的 platform_id 参数 + platform_id = request.args.get("platform_id") - webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id) + sessions = await self.db.get_platform_sessions_by_creator( + creator=username, + platform_id=platform_id, + page=1, + page_size=100, # 暂时返回前100个 + ) - # Get platform message history + # 转换为字典格式,并添加额外信息 + sessions_data = [] + for session in sessions: + sessions_data.append( + { + "session_id": session.session_id, + "platform_id": session.platform_id, + "creator": session.creator, + "display_name": session.display_name, + "is_group": session.is_group, + "created_at": session.created_at.astimezone().isoformat(), + "updated_at": session.updated_at.astimezone().isoformat(), + } + ) + + return Response().ok(data=sessions_data).__dict__ + + async def get_session(self): + """Get session information and message history by session_id.""" + session_id = request.args.get("session_id") + if not session_id: + return Response().error("Missing key: session_id").__dict__ + + # 获取会话信息以确定 platform_id + session = await self.db.get_platform_session_by_id(session_id) + platform_id = session.platform_id if session else "webchat" + + # Get platform message history using session_id history_ls = await self.platform_history_mgr.get( - platform_id="webchat", - user_id=webchat_conv_id, + platform_id=platform_id, + user_id=session_id, page=1, page_size=1000, ) @@ -338,8 +366,37 @@ class ChatRoute(Route): .ok( data={ "history": history_res, - "is_running": self.running_convs.get(webchat_conv_id, False), + "is_running": self.running_convs.get(session_id, False), }, ) .__dict__ ) + + async def update_session_display_name(self): + """Update a Platform session's display name.""" + post_data = await request.json + + session_id = post_data.get("session_id") + display_name = post_data.get("display_name") + + if not session_id: + return Response().error("Missing key: session_id").__dict__ + if display_name is None: + return Response().error("Missing key: display_name").__dict__ + + username = g.get("username", "guest") + + # 验证会话是否存在且属于当前用户 + session = await self.db.get_platform_session_by_id(session_id) + if not session: + return Response().error(f"Session {session_id} not found").__dict__ + if session.creator != username: + return Response().error("Permission denied").__dict__ + + # 更新 display_name + await self.db.update_platform_session( + session_id=session_id, + display_name=display_name, + ) + + return Response().ok().__dict__ diff --git a/dashboard/src/components/chat/Chat.vue b/dashboard/src/components/chat/Chat.vue index bb3418d67..09acd1b7e 100644 --- a/dashboard/src/components/chat/Chat.vue +++ b/dashboard/src/components/chat/Chat.vue @@ -6,9 +6,9 @@
@@ -62,7 +62,7 @@ @@ -142,7 +142,7 @@ import LanguageSwitcher from '@/components/shared/LanguageSwitcher.vue'; import MessageList from '@/components/chat/MessageList.vue'; import ConversationSidebar from '@/components/chat/ConversationSidebar.vue'; import ChatInput from '@/components/chat/ChatInput.vue'; -import { useConversations } from '@/composables/useConversations'; +import { useSessions } from '@/composables/useSessions'; import { useMessages } from '@/composables/useMessages'; import { useMediaHandling } from '@/composables/useMediaHandling'; import { useRecording } from '@/composables/useRecording'; @@ -169,22 +169,22 @@ const previewImageUrl = ref(''); // 使用 composables const { - conversations, - selectedConversations, - currCid, - pendingCid, + sessions, + selectedSessions, + currSessionId, + pendingSessionId, editTitleDialog, editingTitle, - editingCid, - getCurrentConversation, - getConversations, - newConversation, - deleteConversation: deleteConv, + editingSessionId, + getCurrentSession, + getSessions, + newSession, + deleteSession: deleteSessionFn, showEditTitleDialog, saveTitle, - updateConversationTitle, + updateSessionTitle, newChat -} = useConversations(props.chatboxMode); +} = useSessions(props.chatboxMode); const { stagedImagesName, @@ -206,10 +206,10 @@ const { isStreaming, isConvRunning, enableStreaming, - getConversationMessages: getConvMessages, + getSessionMessages: getSessionMsg, sendMessage: sendMsg, toggleStreaming -} = useMessages(currCid, getMediaFile, updateConversationTitle, getConversations); +} = useMessages(currSessionId, getMediaFile, updateSessionTitle, getSessions); // 组件引用 const messageList = ref | null>(null); @@ -248,13 +248,13 @@ function openImagePreview(imageUrl: string) { imagePreviewDialog.value = true; } -async function handleSelectConversation(cids: string[]) { - if (!cids[0]) return; +async function handleSelectConversation(sessionIds: string[]) { + if (!sessionIds[0]) return; // 更新 URL const basePath = props.chatboxMode ? '/chatbox' : '/chat'; - if (route.path !== `${basePath}/${cids[0]}`) { - router.push(`${basePath}/${cids[0]}`); + if (route.path !== `${basePath}/${sessionIds[0]}`) { + router.push(`${basePath}/${sessionIds[0]}`); return; } @@ -263,10 +263,10 @@ async function handleSelectConversation(cids: string[]) { closeMobileSidebar(); } - currCid.value = cids[0]; - selectedConversations.value = [cids[0]]; + currSessionId.value = sessionIds[0]; + selectedSessions.value = [sessionIds[0]]; - await getConvMessages(cids[0], router); + await getSessionMsg(sessionIds[0], router); nextTick(() => { messageList.value?.scrollToBottom(); @@ -278,8 +278,8 @@ function handleNewChat() { messages.value = []; } -async function handleDeleteConversation(cid: string) { - await deleteConv(cid); +async function handleDeleteConversation(sessionId: string) { + await deleteSessionFn(sessionId); messages.value = []; } @@ -303,8 +303,8 @@ async function handleSendMessage() { return; } - if (!currCid.value) { - await newConversation(); + if (!currSessionId.value) { + await newSession(); } const promptToSend = prompt.value.trim(); @@ -340,15 +340,15 @@ watch( } if (to.startsWith('/chat/') || to.startsWith('/chatbox/')) { - const pathCid = to.split('/')[2]; - if (pathCid && pathCid !== currCid.value) { - if (conversations.value.length > 0) { - const conversation = conversations.value.find(c => c.cid === pathCid); - if (conversation) { - handleSelectConversation([pathCid]); + const pathSessionId = to.split('/')[2]; + if (pathSessionId && pathSessionId !== currSessionId.value) { + if (sessions.value.length > 0) { + const session = sessions.value.find(s => s.session_id === pathSessionId); + if (session) { + handleSelectConversation([pathSessionId]); } } else { - pendingCid.value = pathCid; + pendingSessionId.value = pathSessionId; } } } @@ -357,25 +357,25 @@ watch( ); // 会话列表加载后处理待定会话 -watch(conversations, (newConversations) => { - if (pendingCid.value && newConversations.length > 0) { - const conversation = newConversations.find(c => c.cid === pendingCid.value); - if (conversation) { - selectedConversations.value = [pendingCid.value]; - handleSelectConversation([pendingCid.value]); - pendingCid.value = null; +watch(sessions, (newSessions) => { + if (pendingSessionId.value && newSessions.length > 0) { + const session = newSessions.find(s => s.session_id === pendingSessionId.value); + if (session) { + selectedSessions.value = [pendingSessionId.value]; + handleSelectConversation([pendingSessionId.value]); + pendingSessionId.value = null; } - } else if (!currCid.value && newConversations.length > 0) { - const firstConversation = newConversations[0]; - selectedConversations.value = [firstConversation.cid]; - handleSelectConversation([firstConversation.cid]); + } else if (!currSessionId.value && newSessions.length > 0) { + const firstSession = newSessions[0]; + selectedSessions.value = [firstSession.session_id]; + handleSelectConversation([firstSession.session_id]); } }); onMounted(() => { checkMobile(); window.addEventListener('resize', checkMobile); - getConversations(); + getSessions(); }); onBeforeUnmount(() => { @@ -506,4 +506,4 @@ onBeforeUnmount(() => { padding: 0 !important; } } - \ No newline at end of file + diff --git a/dashboard/src/components/chat/ConversationSidebar.vue b/dashboard/src/components/chat/ConversationSidebar.vue index 80b574c7f..5abc1bed8 100644 --- a/dashboard/src/components/chat/ConversationSidebar.vue +++ b/dashboard/src/components/chat/ConversationSidebar.vue @@ -31,10 +31,10 @@
- {{ tm('actions.newChat') }} -
@@ -44,27 +44,27 @@
- + - - {{ item.title || tm('conversation.newConversation') }} + {{ item.display_name || tm('conversation.newConversation') }} - {{ formatDate(item.updated_at) }} + {{ new Date(item.updated_at).toLocaleString() }} @@ -72,7 +72,7 @@ -
+
{{ tm('conversation.noHistory') }} @@ -86,12 +86,12 @@ + diff --git a/dashboard/src/composables/useMessages.ts b/dashboard/src/composables/useMessages.ts index dc243f1b4..5e1a6c7a1 100644 --- a/dashboard/src/composables/useMessages.ts +++ b/dashboard/src/composables/useMessages.ts @@ -18,10 +18,10 @@ export interface Message { } export function useMessages( - currCid: Ref, + currSessionId: Ref, getMediaFile: (filename: string) => Promise, - updateConversationTitle: (cid: string, title: string) => void, - onConversationsUpdate: () => void + updateSessionTitle: (sessionId: string, title: string) => void, + onSessionsUpdate: () => void ) { const messages = ref([]); const isStreaming = ref(false); @@ -41,23 +41,23 @@ export function useMessages( localStorage.setItem('enableStreaming', JSON.stringify(enableStreaming.value)); } - async function getConversationMessages(cid: string, router: any) { - if (!cid) return; + async function getSessionMessages(sessionId: string, router: any) { + if (!sessionId) return; try { - const response = await axios.get('/api/chat/get_conversation?conversation_id=' + cid); + const response = await axios.get('/api/chat/get_session?session_id=' + sessionId); isConvRunning.value = response.data.data.is_running || false; let history = response.data.data.history; if (isConvRunning.value) { if (!isToastedRunningInfo.value) { - useToast().info("该对话正在运行中。", { timeout: 5000 }); + useToast().info("该会话正在运行中。", { timeout: 5000 }); isToastedRunningInfo.value = true; } - // 如果对话还在运行,3秒后重新获取消息 + // 如果会话还在运行,3秒后重新获取消息 setTimeout(() => { - getConversationMessages(currCid.value, router); + getSessionMessages(currSessionId.value, router); }, 3000); } @@ -159,7 +159,7 @@ export function useMessages( }, body: JSON.stringify({ message: prompt, - conversation_id: currCid.value, + session_id: currSessionId.value, image_url: imageNames, audio_url: audioName ? [audioName] : [], selected_provider: selectedProviderId, @@ -256,7 +256,7 @@ export function useMessages( } } } else if (chunk_json.type === 'update_title') { - updateConversationTitle(chunk_json.cid, chunk_json.data); + updateSessionTitle(chunk_json.session_id, chunk_json.data); } if ((chunk_json.type === 'break' && chunk_json.streaming) || !chunk_json.streaming) { @@ -272,8 +272,8 @@ export function useMessages( } } - // 获取最新的对话列表 - onConversationsUpdate(); + // 获取最新的会话列表 + onSessionsUpdate(); } catch (err) { console.error('发送消息失败:', err); @@ -296,8 +296,9 @@ export function useMessages( isStreaming, isConvRunning, enableStreaming, - getConversationMessages, + getSessionMessages, sendMessage, toggleStreaming }; } + diff --git a/dashboard/src/composables/useSessions.ts b/dashboard/src/composables/useSessions.ts new file mode 100644 index 000000000..f14e3aa11 --- /dev/null +++ b/dashboard/src/composables/useSessions.ts @@ -0,0 +1,145 @@ +import { ref, computed } from 'vue'; +import axios from 'axios'; +import { useRouter } from 'vue-router'; + +export interface Session { + session_id: string; + display_name: string; + updated_at: string; +} + +export function useSessions(chatboxMode: boolean = false) { + const router = useRouter(); + const sessions = ref([]); + const selectedSessions = ref([]); + const currSessionId = ref(''); + const pendingSessionId = ref(null); + + // 编辑标题相关 + const editTitleDialog = ref(false); + const editingTitle = ref(''); + const editingSessionId = ref(''); + + const getCurrentSession = computed(() => { + if (!currSessionId.value) return null; + return sessions.value.find(s => s.session_id === currSessionId.value); + }); + + async function getSessions() { + try { + const response = await axios.get('/api/chat/sessions'); + sessions.value = response.data.data; + + // 处理待加载的会话 + if (pendingSessionId.value) { + const session = sessions.value.find(s => s.session_id === pendingSessionId.value); + if (session) { + selectedSessions.value = [pendingSessionId.value]; + pendingSessionId.value = null; + } + } else if (!currSessionId.value && sessions.value.length > 0) { + // 默认选择第一个会话 + const firstSession = sessions.value[0]; + selectedSessions.value = [firstSession.session_id]; + } + } catch (err: any) { + if (err.response?.status === 401) { + router.push('/auth/login?redirect=/chatbox'); + } + console.error(err); + } + } + + async function newSession() { + try { + const response = await axios.get('/api/chat/new_session'); + const sessionId = response.data.data.session_id; + currSessionId.value = sessionId; + + // 更新 URL + const basePath = chatboxMode ? '/chatbox' : '/chat'; + router.push(`${basePath}/${sessionId}`); + + await getSessions(); + return sessionId; + } catch (err) { + console.error(err); + throw err; + } + } + + async function deleteSession(sessionId: string) { + try { + await axios.get('/api/chat/delete_session?session_id=' + sessionId); + await getSessions(); + currSessionId.value = ''; + selectedSessions.value = []; + } catch (err) { + console.error(err); + } + } + + function showEditTitleDialog(sessionId: string, title: string) { + editingSessionId.value = sessionId; + editingTitle.value = title || ''; + editTitleDialog.value = true; + } + + async function saveTitle() { + if (!editingSessionId.value) return; + + const trimmedTitle = editingTitle.value.trim(); + try { + await axios.post('/api/chat/update_session_display_name', { + session_id: editingSessionId.value, + display_name: trimmedTitle + }); + + // 更新本地会话标题 + const session = sessions.value.find(s => s.session_id === editingSessionId.value); + if (session) { + session.display_name = trimmedTitle; + } + editTitleDialog.value = false; + } catch (err) { + console.error('重命名会话失败:', err); + } + } + + function updateSessionTitle(sessionId: string, title: string) { + const session = sessions.value.find(s => s.session_id === sessionId); + if (session) { + session.display_name = title; + } + } + + function newChat(closeMobileSidebar?: () => void) { + currSessionId.value = ''; + selectedSessions.value = []; + + const basePath = chatboxMode ? '/chatbox' : '/chat'; + router.push(basePath); + + if (closeMobileSidebar) { + closeMobileSidebar(); + } + } + + return { + sessions, + selectedSessions, + currSessionId, + pendingSessionId, + editTitleDialog, + editingTitle, + editingSessionId, + getCurrentSession, + getSessions, + newSession, + deleteSession, + showEditTitleDialog, + saveTitle, + updateSessionTitle, + newChat + }; +} diff --git a/dashboard/src/i18n/locales/en-US/features/chat.json b/dashboard/src/i18n/locales/en-US/features/chat.json index bce552aff..bc83787ff 100644 --- a/dashboard/src/i18n/locales/en-US/features/chat.json +++ b/dashboard/src/i18n/locales/en-US/features/chat.json @@ -47,7 +47,11 @@ "noHistory": "No conversation history", "systemStatus": "System Status", "llmService": "LLM Service", - "speechToText": "Speech to Text" + "speechToText": "Speech to Text", + "editDisplayName": "Edit Session Name", + "displayName": "Session Name", + "displayNameUpdated": "Session name updated", + "displayNameUpdateFailed": "Failed to update session name" }, "modes": { "darkMode": "Switch to Dark Mode", diff --git a/dashboard/src/i18n/locales/zh-CN/features/chat.json b/dashboard/src/i18n/locales/zh-CN/features/chat.json index 37bacd408..002ea626e 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/chat.json +++ b/dashboard/src/i18n/locales/zh-CN/features/chat.json @@ -43,11 +43,15 @@ "exitFullscreen": "退出全屏" }, "conversation": { - "newConversation": "新对话", + "newConversation": "新的聊天", "noHistory": "暂无对话历史", "systemStatus": "系统状态", "llmService": "LLM 服务", - "speechToText": "语音转文本" + "speechToText": "语音转文本", + "editDisplayName": "编辑会话名称", + "displayName": "会话名称", + "displayNameUpdated": "会话名称已更新", + "displayNameUpdateFailed": "更新会话名称失败" }, "modes": { "darkMode": "切换到夜间模式",