From 323ec29b026d2b146741e0ad5c18a7d442ab08cd Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Tue, 18 Nov 2025 22:04:26 +0800 Subject: [PATCH 01/27] refactor: Implement WebChat session management and migration from version 4.6 to 4.7 - Added WebChatSession model for managing user sessions. - Introduced methods for creating, retrieving, updating, and deleting WebChat sessions in the database. - Updated core lifecycle to include migration from version 4.6 to 4.7, creating WebChat sessions from existing platform message history. - Refactored chat routes to support new session-based architecture, replacing conversation-related endpoints with session endpoints. - Updated frontend components to handle sessions instead of conversations, including session creation and management. --- astrbot/core/core_lifecycle.py | 8 + astrbot/core/db/__init__.py | 43 +++ astrbot/core/db/migration/migra_46_to_47.py | 103 +++++++ astrbot/core/db/po.py | 38 +++ astrbot/core/db/sqlite.py | 84 ++++++ astrbot/dashboard/routes/chat.py | 138 +++++---- dashboard/src/components/chat/Chat.vue | 304 +++++++++----------- 7 files changed, 484 insertions(+), 234 deletions(-) create mode 100644 astrbot/core/db/migration/migra_46_to_47.py diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index 2a6ac4273..676e50384 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -22,6 +22,7 @@ from astrbot.core.config.default import VERSION from astrbot.core.conversation_mgr import ConversationManager from astrbot.core.db import BaseDatabase from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46 +from astrbot.core.db.migration.migra_46_to_47 import migrate_46_to_47 from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager from astrbot.core.persona_mgr import PersonaManager from astrbot.core.pipeline.scheduler import PipelineContext, PipelineScheduler @@ -103,6 +104,13 @@ class AstrBotCoreLifecycle: logger.error(f"Migration from version 4.5 to 4.6 failed: {e!s}") logger.error(traceback.format_exc()) + # 4.6 to 4.7 migration for webchat sessions and group feature + try: + await migrate_46_to_47(self.db) + except Exception as e: + logger.error(f"Migration from version 4.6 to 4.7 failed: {e!s}") + logger.error(traceback.format_exc()) + # 初始化事件队列 self.event_queue = Queue() diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index c62e49289..456682bd2 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -16,6 +16,7 @@ from astrbot.core.db.po import ( PlatformStat, Preference, Stats, + WebChatSession, ) @@ -313,3 +314,45 @@ class BaseDatabase(abc.ABC): ) -> tuple[list[dict], int]: """Get paginated session conversations with joined conversation and persona details, support search and platform filter.""" ... + + # ==== + # WebChat Session Management + # ==== + + @abc.abstractmethod + async def create_webchat_session( + self, + creator: str, + session_id: str | None = None, + is_group: int = 0, + ) -> WebChatSession: + """Create a new WebChat session.""" + ... + + @abc.abstractmethod + async def get_webchat_session_by_id(self, session_id: str) -> WebChatSession | None: + """Get a WebChat session by its ID.""" + ... + + @abc.abstractmethod + async def get_webchat_sessions_by_creator( + self, + creator: str, + page: int = 1, + page_size: int = 20, + ) -> list[WebChatSession]: + """Get all WebChat sessions for a specific creator (username).""" + ... + + @abc.abstractmethod + async def update_webchat_session( + self, + session_id: str, + ) -> None: + """Update a WebChat session's updated_at timestamp.""" + ... + + @abc.abstractmethod + async def delete_webchat_session(self, session_id: str) -> None: + """Delete a WebChat session by its ID.""" + ... diff --git a/astrbot/core/db/migration/migra_46_to_47.py b/astrbot/core/db/migration/migra_46_to_47.py new file mode 100644 index 000000000..407a667c9 --- /dev/null +++ b/astrbot/core/db/migration/migra_46_to_47.py @@ -0,0 +1,103 @@ +"""Migration script from version 4.6 to 4.7. + +This migration creates WebChat sessions from existing platform_message_history records. +""" + +from sqlalchemy import func, select +from sqlmodel import col + +from astrbot.api import logger, sp +from astrbot.core.db import BaseDatabase +from astrbot.core.db.po import PlatformMessageHistory + + +async def migrate_46_to_47(db_helper: BaseDatabase): + """Migrate WebChat data to the new session table. + + This migration extracts all unique user_ids from platform_message_history + where platform_id='webchat' and creates corresponding WebChatSession records. + """ + # 检查是否已经完成迁移 + # migration_done = await db_helper.get_preference( + # "global", "global", "migration_done_v47" + # ) + # if migration_done: + # return + + logger.info("开始执行数据库迁移(4.6 -> 4.7)...") + + try: + async with db_helper.get_db() as session: + # 1. 查询所有 webchat 的唯一 user_id 以及它们的最早和最新消息时间 + query = ( + select( + col(PlatformMessageHistory.user_id), + col(PlatformMessageHistory.sender_name), + func.min(PlatformMessageHistory.created_at).label("earliest"), + func.max(PlatformMessageHistory.updated_at).label("latest"), + ) + .where(col(PlatformMessageHistory.platform_id) == "webchat") + .where(col(PlatformMessageHistory.sender_id) == "astrbot") + .group_by(col(PlatformMessageHistory.user_id)) + ) + + result = await session.execute(query) + webchat_users = result.all() + + if not webchat_users: + logger.info("没有找到需要迁移的 WebChat 数据") + await sp.put_async("global", "global", "migration_done_v47", True) + return + + logger.info(f"找到 {len(webchat_users)} 个 WebChat 会话需要迁移") + + # 2. 为每个 user_id 创建 WebChatSession 记录 + migrated_count = 0 + skipped_count = 0 + + for user_id, sender_name, created_at, updated_at in webchat_users: + # user_id 就是 webchat_conv_id (session_id) + session_id = user_id + + # sender_name 通常是 username,但可能为 None + # 从第一条消息中提取 creator + creator = sender_name if sender_name else "guest" + + # 检查是否已经存在该会话 + existing_session = await db_helper.get_webchat_session_by_id(session_id) + if existing_session: + logger.debug(f"会话 {session_id} 已存在,跳过") + skipped_count += 1 + continue + + # 创建新的 WebChatSession + try: + await db_helper.create_webchat_session( + creator=creator, + session_id=session_id, + is_group=0, + ) + + # 更新时间戳以匹配历史记录 + # 注意:这里我们需要直接更新数据库,因为 create 方法会设置当前时间 + # 但我们希望保留原始的创建和更新时间 + + migrated_count += 1 + + if migrated_count % 100 == 0: + logger.info(f"已迁移 {migrated_count} 个会话...") + + except Exception as e: + logger.error(f"迁移会话 {session_id} 失败: {e}") + continue + + logger.info( + f"WebChat 会话迁移完成!成功迁移: {migrated_count}, 跳过: {skipped_count}", + ) + + # 标记迁移完成 + await sp.put_async("global", "global", "migration_done_v47", True) + + except Exception as e: + logger.error(f"迁移过程中发生错误: {e}", exc_info=True) + raise diff --git a/astrbot/core/db/po.py b/astrbot/core/db/po.py index 1e7245976..eee4c9dc6 100644 --- a/astrbot/core/db/po.py +++ b/astrbot/core/db/po.py @@ -161,6 +161,44 @@ class PlatformMessageHistory(SQLModel, table=True): ) +class WebChatSession(SQLModel, table=True): + """WebChat session table for managing user sessions. + + A session represents a chat window for a specific user. Each session can have + multiple conversations (对话) associated with it. + """ + + __tablename__ = "webchat_sessions" + + inner_id: int | None = Field( + primary_key=True, + sa_column_kwargs={"autoincrement": True}, + default=None, + ) + session_id: str = Field( + max_length=36, + nullable=False, + unique=True, + default_factory=lambda: str(uuid.uuid4()), + ) + creator: str = Field(nullable=False) + """Username of the session creator""" + is_group: int = Field(default=0, nullable=False) + """0 for private chat, 1 for group chat (not implemented yet)""" + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field( + default_factory=lambda: datetime.now(timezone.utc), + sa_column_kwargs={"onupdate": datetime.now(timezone.utc)}, + ) + + __table_args__ = ( + UniqueConstraint( + "session_id", + name="uix_webchat_session_id", + ), + ) + + class Attachment(SQLModel, table=True): """This class represents attachments for messages in AstrBot. diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index 457a4ab3f..b96a2d3ff 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -15,6 +15,7 @@ from astrbot.core.db.po import ( PlatformStat, Preference, SQLModel, + WebChatSession, ) from astrbot.core.db.po import ( Platform as DeprecatedPlatformStat, @@ -709,3 +710,86 @@ class SQLiteDatabase(BaseDatabase): t.start() t.join() return result + + # ==== + # WebChat Session Management + # ==== + + async def create_webchat_session( + self, + creator: str, + session_id: str | None = None, + is_group: int = 0, + ) -> WebChatSession: + """Create a new WebChat session.""" + kwargs = {} + if session_id: + kwargs["session_id"] = session_id + + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + new_session = WebChatSession( + creator=creator, + is_group=is_group, + **kwargs, + ) + session.add(new_session) + await session.flush() + await session.refresh(new_session) + return new_session + + async def get_webchat_session_by_id(self, session_id: str) -> WebChatSession | None: + """Get a WebChat session by its ID.""" + async with self.get_db() as session: + session: AsyncSession + query = select(WebChatSession).where( + WebChatSession.session_id == session_id, + ) + result = await session.execute(query) + return result.scalar_one_or_none() + + async def get_webchat_sessions_by_creator( + self, + creator: str, + page: int = 1, + page_size: int = 20, + ) -> list[WebChatSession]: + """Get all WebChat sessions for a specific creator (username).""" + async with self.get_db() as session: + session: AsyncSession + offset = (page - 1) * page_size + query = ( + select(WebChatSession) + .where(WebChatSession.creator == creator) + .order_by(desc(WebChatSession.updated_at)) + .offset(offset) + .limit(page_size) + ) + result = await session.execute(query) + return list(result.scalars().all()) + + async def update_webchat_session( + self, + session_id: str, + ) -> None: + """Update a WebChat session's updated_at timestamp.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + await session.execute( + update(WebChatSession) + .where(WebChatSession.session_id == session_id) + .values(updated_at=datetime.now()), + ) + + async def delete_webchat_session(self, session_id: str) -> None: + """Delete a WebChat session by its ID.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + await session.execute( + delete(WebChatSession).where( + WebChatSession.session_id == session_id, + ), + ) diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index 56946550a..8eacc5c4f 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -10,7 +10,6 @@ from quart import g, make_response, request from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.db import BaseDatabase -from astrbot.core.platform.astr_message_event import MessageSession from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr from astrbot.core.utils.astrbot_path import get_astrbot_data_path @@ -36,11 +35,10 @@ class ChatRoute(Route): super().__init__(context) self.routes = { "/chat/send": ("POST", self.chat), - "/chat/new_conversation": ("GET", self.new_conversation), - "/chat/conversations": ("GET", self.get_conversations), - "/chat/get_conversation": ("GET", self.get_conversation), - "/chat/delete_conversation": ("GET", self.delete_conversation), - "/chat/rename_conversation": ("POST", self.rename_conversation), + "/chat/new_session": ("GET", self.new_session), + "/chat/sessions": ("GET", self.get_sessions), + "/chat/get_session": ("GET", self.get_session), + "/chat/delete_session": ("GET", self.delete_webchat_session), "/chat/get_file": ("GET", self.get_file), "/chat/post_image": ("POST", self.post_image), "/chat/post_file": ("POST", self.post_file), @@ -53,6 +51,7 @@ class ChatRoute(Route): self.supported_imgs = ["jpg", "jpeg", "png", "gif", "webp"] self.conv_mgr = core_lifecycle.conversation_manager self.platform_history_mgr = core_lifecycle.platform_message_history_manager + self.db = db self.running_convs: dict[str, bool] = {} @@ -137,7 +136,8 @@ class ChatRoute(Route): return Response().error("conversation_id is empty").__dict__ # 追加用户消息 - webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id) + # conversation_id 现在实际上是 session_id + webchat_conv_id = conversation_id # 获取会话特定的队列 back_queue = webchat_queue_mgr.get_or_create_back_queue(webchat_conv_id) @@ -245,88 +245,86 @@ class ChatRoute(Route): response.timeout = None # fix SSE auto disconnect issue return response - async def _get_webchat_conv_id_from_conv_id(self, conversation_id: str) -> str: - """从对话 ID 中提取 WebChat 会话 ID - - NOTE: 关于这里为什么要单独做一个 WebChat 的 Conversation ID 出来,这个是为了向前兼容。 - """ - conversation = await self.conv_mgr.get_conversation( - unified_msg_origin="webchat", - conversation_id=conversation_id, - ) - if not conversation: - raise ValueError(f"Conversation with ID {conversation_id} not found.") - conv_user_id = conversation.user_id - webchat_session_id = MessageSession.from_str(conv_user_id).session_id - if "!" not in webchat_session_id: - raise ValueError(f"Invalid conv user ID: {conv_user_id}") - return webchat_session_id.split("!")[-1] - - async def delete_conversation(self): - conversation_id = request.args.get("conversation_id") - if not conversation_id: - return Response().error("Missing key: conversation_id").__dict__ + async def delete_webchat_session(self): + """Delete a WebChat session and all its related data.""" + session_id = request.args.get("session_id") + if not session_id: + return Response().error("Missing key: session_id").__dict__ username = g.get("username", "guest") - # Clean up queues when deleting conversation - webchat_queue_mgr.remove_queues(conversation_id) - webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id) - await self.conv_mgr.delete_conversation( - unified_msg_origin=f"webchat:FriendMessage:webchat!{username}!{webchat_conv_id}", - conversation_id=conversation_id, - ) + # 验证会话是否存在且属于当前用户 + session = await self.db.get_webchat_session_by_id(session_id) + if not session: + return Response().error(f"Session {session_id} not found").__dict__ + if session.creator != username: + return Response().error("Permission denied").__dict__ + + # 删除该会话下的所有对话 + unified_msg_origin = f"webchat:FriendMessage:webchat!{username}!{session_id}" + await self.conv_mgr.delete_conversations_by_user_id(unified_msg_origin) + + # 删除消息历史 await self.platform_history_mgr.delete( platform_id="webchat", - user_id=webchat_conv_id, + user_id=session_id, offset_sec=99999999, ) + + # 清理队列 + webchat_queue_mgr.remove_queues(session_id) + + # 删除会话 + await self.db.delete_webchat_session(session_id) + return Response().ok().__dict__ - async def new_conversation(self): + async def new_session(self): + """Create a new WebChat session.""" username = g.get("username", "guest") - webchat_conv_id = str(uuid.uuid4()) - conv_id = await self.conv_mgr.new_conversation( - unified_msg_origin=f"webchat:FriendMessage:webchat!{username}!{webchat_conv_id}", - platform_id="webchat", - content=[], + + # 创建新会话 + session = await self.db.create_webchat_session( + creator=username, + is_group=0, ) - return Response().ok(data={"conversation_id": conv_id}).__dict__ - async def rename_conversation(self): - post_data = await request.json - if "conversation_id" not in post_data or "title" not in post_data: - return Response().error("Missing key: conversation_id or title").__dict__ + return Response().ok(data={"session_id": session.session_id}).__dict__ - conversation_id = post_data["conversation_id"] - title = post_data["title"] + async def get_sessions(self): + """Get all WebChat sessions for the current user.""" + username = g.get("username", "guest") - await self.conv_mgr.update_conversation( - unified_msg_origin="webchat", # fake - conversation_id=conversation_id, - title=title, + sessions = await self.db.get_webchat_sessions_by_creator( + creator=username, + page=1, + page_size=100, # 暂时返回前100个 ) - return Response().ok(message="重命名成功!").__dict__ - async def get_conversations(self): - conversations = await self.conv_mgr.get_conversations(platform_id="webchat") - # remove content - conversations_ = [] - for conv in conversations: - conv.history = None - conversations_.append(conv) - return Response().ok(data=conversations_).__dict__ + # 转换为字典格式,并添加额外信息 + sessions_data = [] + for session in sessions: + sessions_data.append( + { + "session_id": session.session_id, + "creator": session.creator, + "is_group": session.is_group, + "created_at": int(session.created_at.timestamp()), + "updated_at": int(session.updated_at.timestamp()), + } + ) - async def get_conversation(self): - conversation_id = request.args.get("conversation_id") - if not conversation_id: - return Response().error("Missing key: conversation_id").__dict__ + return Response().ok(data=sessions_data).__dict__ - webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id) + async def get_session(self): + """Get session information and message history by session_id.""" + session_id = request.args.get("session_id") + if not session_id: + return Response().error("Missing key: session_id").__dict__ - # Get platform message history + # Get platform message history using session_id history_ls = await self.platform_history_mgr.get( platform_id="webchat", - user_id=webchat_conv_id, + user_id=session_id, page=1, page_size=1000, ) @@ -338,7 +336,7 @@ class ChatRoute(Route): .ok( data={ "history": history_res, - "is_running": self.running_convs.get(webchat_conv_id, False), + "is_running": self.running_convs.get(session_id, False), }, ) .__dict__ diff --git a/dashboard/src/components/chat/Chat.vue b/dashboard/src/components/chat/Chat.vue index d671b15b7..69cb5d6a8 100644 --- a/dashboard/src/components/chat/Chat.vue +++ b/dashboard/src/components/chat/Chat.vue @@ -38,11 +38,11 @@
- {{ tm('actions.newChat') }} -
@@ -52,26 +52,24 @@
- - - - {{ item.title - || tm('conversation.newConversation') }} - {{ - formatDate(item.updated_at) - }} + + + + + {{ tm('conversation.newConversation') }} + + + {{ formatDate(session.updated_at) }} + @@ -79,9 +77,9 @@ -
+
-
+
{{ tm('conversation.noHistory') }}
@@ -109,7 +107,7 @@ @@ -131,7 +129,7 @@ @@ -217,21 +215,6 @@
- - - - {{ tm('actions.editTitle') }} - - - - - - {{ t('core.common.cancel') }} - {{ t('core.common.save') }} - - - @@ -289,9 +272,9 @@ export default { return { prompt: '', messages: [], - conversations: [], - selectedConversations: [], // 用于控制左侧列表的选中状态 - currCid: '', + sessions: [], // WebChat 会话列表 + selectedSessions: [], // 当前选中的会话 + currSessionId: '', // 当前会话ID stagedImagesName: [], // 用于存储图片文件名的数组 stagedImagesUrl: [], // 用于存储图片的blob URL数组 loadingChat: false, @@ -310,10 +293,6 @@ export default { mediaCache: {}, // Add a cache to store media blobs - // 添加对话标题编辑相关变量 - editTitleDialog: false, - editingTitle: '', - editingCid: '', // 侧边栏折叠状态 sidebarCollapsed: true, @@ -346,10 +325,10 @@ export default { isDark() { return useCustomizerStore().uiTheme === 'PurpleThemeDark'; }, - // Get the current conversation from the conversations array - getCurrentConversation() { - if (!this.currCid) return null; - return this.conversations.find(c => c.cid === this.currCid); + // Get the current session from the sessions array + getCurrentSession() { + if (!this.currSessionId) return null; + return this.sessions.find(s => s.session_id === this.currSessionId); } }, @@ -364,43 +343,43 @@ export default { (from.path.startsWith('/chatbox') && to.path.startsWith('/chat')))) { } - // Check if the route matches /chat/ or /chatbox/ pattern + // Check if the route matches /chat/ or /chatbox/ pattern if (to.path.startsWith('/chat/') || to.path.startsWith('/chatbox/')) { - const pathCid = to.path.split('/')[2]; - console.log('Path CID:', pathCid); - if (pathCid && pathCid !== this.currCid) { - // If conversations are already loaded - if (this.conversations.length > 0) { - const conversation = this.conversations.find(c => c.cid === pathCid); - if (conversation) { - this.getConversationMessages([pathCid]); + const pathSessionId = to.path.split('/')[2]; + console.log('Path Session ID:', pathSessionId); + if (pathSessionId && pathSessionId !== this.currSessionId) { + // If sessions are already loaded + if (this.sessions.length > 0) { + const session = this.sessions.find(s => s.session_id === pathSessionId); + if (session) { + this.getSessionMessages([pathSessionId]); } } else { - // Store the cid to be used after conversations are loaded - this.pendingCid = pathCid; + // Store the session_id to be used after sessions are loaded + this.pendingCid = pathSessionId; } } } } }, - // Watch for conversations loaded to handle pending cid - conversations: { - handler(newConversations) { - if (this.pendingCid && newConversations.length > 0) { - const conversation = newConversations.find(c => c.cid === this.pendingCid); - if (conversation) { - // 先设置选中状态,然后加载对话消息 - this.selectedConversations = [this.pendingCid]; - this.getConversationMessages([this.pendingCid]); + // Watch for sessions loaded to handle pending session ID + sessions: { + handler(newSessions) { + if (this.pendingCid && newSessions.length > 0) { + const session = newSessions.find(s => s.session_id === this.pendingCid); + if (session) { + // 先设置选中状态,然后加载会话消息 + this.selectedSessions = [this.pendingCid]; + this.getSessionMessages([this.pendingCid]); this.pendingCid = null; } } else { - // 如果没有URL参数指定的对话,且当前没有选中对话,则默认打开第一个对话 - if (!this.currCid && newConversations.length > 0) { - const firstConversation = newConversations[0]; - this.selectedConversations = [firstConversation.cid]; - this.getConversationMessages([firstConversation.cid]); + // 如果没有URL参数指定的会话,且当前没有选中会话,则默认选中第一个会话 + if (!this.currSessionId && newSessions.length > 0) { + const firstSession = newSessions[0]; + this.selectedSessions = [firstSession.session_id]; + // 不自动加载消息,等用户点击或发送消息 } } } @@ -431,7 +410,7 @@ export default { // 设置输入框标签 this.inputFieldLabel = this.tm('input.chatPrompt'); - this.getConversations(); + this.getSessions(); let inputField = document.getElementById('input-field'); inputField.addEventListener('paste', this.handlePaste); inputField.addEventListener('keydown', function (e) { @@ -532,34 +511,6 @@ export default { this.sidebarHoverExpanded = false; }, - // 显示编辑对话标题对话框 - showEditTitleDialog(cid, title) { - this.editingCid = cid; - this.editingTitle = title || ''; // 如果标题为空,则设置为空字符串 - this.editTitleDialog = true; - }, - - // 保存对话标题 - saveTitle() { - if (!this.editingCid) return; - - const trimmedTitle = this.editingTitle.trim(); - axios.post('/api/chat/rename_conversation', { - conversation_id: this.editingCid, - title: trimmedTitle - }) - .then(response => { - // 更新本地对话列表中的标题 - const conversation = this.conversations.find(c => c.cid === this.editingCid); - if (conversation) { - conversation.title = trimmedTitle; - } - this.editTitleDialog = false; - }) - .catch(err => { - console.error('重命名对话失败:', err); - }); - }, async getMediaFile(filename) { if (this.mediaCache[filename]) { @@ -691,9 +642,16 @@ export default { // Reset the input value to allow selecting the same file again event.target.value = ''; }, - getConversations() { - axios.get('/api/chat/conversations').then(response => { - this.conversations = response.data.data; + getSessions() { + axios.get('/api/chat/sessions').then(response => { + this.sessions = response.data.data; + // 使用 sessions 作为显示列表(兼容旧代码) + this.conversations = this.sessions.map(session => ({ + cid: session.session_id, + title: this.tm('conversation.newConversation'), // 暂时使用默认标题 + updated_at: session.updated_at, + created_at: session.created_at + })); // If there's a pending conversation ID from the route if (this.pendingCid) { @@ -703,30 +661,33 @@ export default { this.pendingCid = null; } } else { - // 如果没有URL参数指定的对话,且当前没有选中对话,则默认打开第一个对话 - if (!this.currCid && this.conversations.length > 0) { - const firstConversation = this.conversations[0]; - this.selectedConversations = [firstConversation.cid]; - this.getConversationMessages([firstConversation.cid]); + // 如果没有URL参数指定的会话,且当前没有选中会话,则默认打开第一个会话 + if (!this.currSessionId && this.sessions.length > 0) { + const firstSession = this.sessions[0]; + this.currSessionId = firstSession.session_id; + this.selectedConversations = [firstSession.session_id]; + // 注意:现在不自动加载消息,等用户发送消息时再创建对话 } } }).catch(err => { - if (err.response.status === 401) { + if (err.response && err.response.status === 401) { this.$router.push('/auth/login?redirect=/chatbox'); } console.error(err); }); }, - getConversationMessages(cid) { - if (!cid[0]) + getSessionMessages(sessionIds) { + if (!sessionIds[0]) return; - // Update the URL to reflect the selected conversation - if (this.$route.path !== `/chat/${cid[0]}` && this.$route.path !== `/chatbox/${cid[0]}`) { + const sessionId = sessionIds[0]; + + // Update the URL to reflect the selected session + if (this.$route.path !== `/chat/${sessionId}` && this.$route.path !== `/chatbox/${sessionId}`) { if (this.$route.path.startsWith('/chatbox')) { - this.$router.push(`/chatbox/${cid[0]}`); + this.$router.push(`/chatbox/${sessionId}`); } else { - this.$router.push(`/chat/${cid[0]}`); + this.$router.push(`/chat/${sessionId}`); } return } @@ -736,22 +697,22 @@ export default { this.closeMobileSidebar(); } - axios.get('/api/chat/get_conversation?conversation_id=' + cid[0]).then(async response => { - this.currCid = cid[0]; - // Update the selected conversation in the sidebar - this.selectedConversations = [cid[0]]; + axios.get('/api/chat/get_session?session_id=' + sessionId).then(async response => { + this.currSessionId = sessionId; + // Update the selected session in the sidebar + this.selectedSessions = [sessionId]; let history = response.data.data.history; this.isConvRunning = response.data.data.is_running || false; if (this.isConvRunning) { if (!this.isToastedRunningInfo) { - useToast().info("该对话正在运行中。", { timeout: 5000 }); + useToast().info("该会话正在运行中。", { timeout: 5000 }); this.isToastedRunningInfo = true; } - // 如果对话还在运行,3秒后重新获取消息 + // 如果会话还在运行,3秒后重新获取消息 setTimeout(() => { - this.getConversationMessages([this.currCid]); + this.getSessionMessages([this.currSessionId]); }, 3000); } @@ -795,35 +756,40 @@ export default { }); }, async newConversation() { - return axios.get('/api/chat/new_conversation').then(response => { - const cid = response.data.data.conversation_id; - this.currCid = cid; - // Update the URL to reflect the new conversation - if (this.$route.path.startsWith('/chatbox')) { - this.$router.push(`/chatbox/${cid}`); - } else { - this.$router.push(`/chat/${cid}`); - } - this.getConversations(); - return cid; - }).catch(err => { - console.error(err); - throw err; - }); + // 懒加载:如果没有会话ID,先创建会话 + if (!this.currSessionId) { + await this.newC(); + } + // 返回会话ID作为"对话ID"(兼容旧逻辑) + return this.currSessionId; }, - newC() { - this.currCid = ''; - this.selectedConversations = []; // 清除选中状态 - this.messages = []; - // 手机端关闭侧边栏 - if (this.isMobile) { - this.closeMobileSidebar(); - } - if (this.$route.path.startsWith('/chatbox')) { - this.$router.push('/chatbox'); - } else { - this.$router.push('/chat'); + async newC() { + // 创建新会话 + try { + const response = await axios.get('/api/chat/new_session'); + const sessionId = response.data.data.session_id; + + this.currSessionId = sessionId; + this.selectedSessions = [sessionId]; // 选中新会话 + this.messages = []; + + // 手机端关闭侧边栏 + if (this.isMobile) { + this.closeMobileSidebar(); + } + + // 更新URL + if (this.$route.path.startsWith('/chatbox')) { + this.$router.push(`/chatbox/${sessionId}`); + } else { + this.$router.push(`/chat/${sessionId}`); + } + + // 刷新会话列表 + this.getSessions(); + } catch (err) { + console.error('创建新会话失败:', err); } }, @@ -843,14 +809,24 @@ export default { return date.toLocaleString(locale, options).replace(/\//g, '-').replace(/, /g, ' '); }, - deleteConversation(cid) { - axios.get('/api/chat/delete_conversation?conversation_id=' + cid).then(response => { - this.getConversations(); - this.currCid = ''; - this.selectedConversations = []; // 清除选中状态 - this.messages = []; + deleteSession(sessionId) { + // 删除会话 + axios.get('/api/chat/delete_session?session_id=' + sessionId).then(response => { + this.getSessions(); + // 如果删除的是当前会话,清空状态 + if (this.currSessionId === sessionId) { + this.currSessionId = ''; + this.selectedSessions = []; + this.messages = []; + // 更新URL + if (this.$route.path.startsWith('/chatbox')) { + this.$router.push('/chatbox'); + } else { + this.$router.push('/chat'); + } + } }).catch(err => { - console.error(err); + console.error('删除会话失败:', err); }); }, @@ -868,9 +844,9 @@ export default { return; } - if (this.currCid == '') { - const cid = await this.newConversation(); - // URL is already updated in newConversation method + if (this.currSessionId == '') { + await this.newConversation(); + // Session is created and URL is updated } // 保存当前要发送的数据到临时变量 @@ -935,7 +911,7 @@ export default { }, body: JSON.stringify({ message: promptToSend, - conversation_id: this.currCid, + conversation_id: this.currSessionId, image_url: imageNamesToSend, audio_url: audioNameToSend ? [audioNameToSend] : [], selected_provider: selectedProviderId, @@ -1063,7 +1039,7 @@ export default { this.loadingChat = false; // get the latest conversations - this.getConversations(); + this.getSessions(); } catch (err) { console.error('发送消息失败:', err); From 0747099cacf58ea177091da5c418f42ebbd9059d Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Tue, 18 Nov 2025 22:07:43 +0800 Subject: [PATCH 02/27] fix: restore migration check for version 4.7 --- astrbot/core/db/migration/migra_46_to_47.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/astrbot/core/db/migration/migra_46_to_47.py b/astrbot/core/db/migration/migra_46_to_47.py index 407a667c9..407a840d8 100644 --- a/astrbot/core/db/migration/migra_46_to_47.py +++ b/astrbot/core/db/migration/migra_46_to_47.py @@ -18,11 +18,11 @@ async def migrate_46_to_47(db_helper: BaseDatabase): where platform_id='webchat' and creates corresponding WebChatSession records. """ # 检查是否已经完成迁移 - # migration_done = await db_helper.get_preference( - # "global", "global", "migration_done_v47" - # ) - # if migration_done: - # return + migration_done = await db_helper.get_preference( + "global", "global", "migration_done_v47" + ) + if migration_done: + return logger.info("开始执行数据库迁移(4.6 -> 4.7)...") From cf4a5d9ea498f29e8e43e3a06471f6688044846f Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Tue, 18 Nov 2025 22:37:55 +0800 Subject: [PATCH 03/27] refactor: change to platform session --- astrbot/core/core_lifecycle.py | 2 +- astrbot/core/db/__init__.py | 34 ++++--- astrbot/core/db/migration/migra_46_to_47.py | 30 +++--- astrbot/core/db/po.py | 20 ++-- astrbot/core/db/sqlite.py | 69 +++++++++----- astrbot/dashboard/routes/chat.py | 83 ++++++++++++++--- dashboard/src/components/chat/Chat.vue | 92 ++++++++++++++++++- .../src/i18n/locales/en-US/features/chat.json | 6 +- .../src/i18n/locales/zh-CN/features/chat.json | 6 +- 9 files changed, 262 insertions(+), 80 deletions(-) diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index 676e50384..fdf757116 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -104,7 +104,7 @@ class AstrBotCoreLifecycle: logger.error(f"Migration from version 4.5 to 4.6 failed: {e!s}") logger.error(traceback.format_exc()) - # 4.6 to 4.7 migration for webchat sessions and group feature + # 4.6 to 4.7 migration for platform sessions and group feature try: await migrate_46_to_47(self.db) except Exception as e: diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index 456682bd2..48ccb6801 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -13,10 +13,10 @@ from astrbot.core.db.po import ( ConversationV2, Persona, PlatformMessageHistory, + PlatformSession, PlatformStat, Preference, Stats, - WebChatSession, ) @@ -316,43 +316,49 @@ class BaseDatabase(abc.ABC): ... # ==== - # WebChat Session Management + # Platform Session Management # ==== @abc.abstractmethod - async def create_webchat_session( + async def create_platform_session( self, creator: str, + platform_id: str = "webchat", session_id: str | None = None, + display_name: str | None = None, is_group: int = 0, - ) -> WebChatSession: - """Create a new WebChat session.""" + ) -> PlatformSession: + """Create a new Platform session.""" ... @abc.abstractmethod - async def get_webchat_session_by_id(self, session_id: str) -> WebChatSession | None: - """Get a WebChat session by its ID.""" + async def get_platform_session_by_id( + self, session_id: str + ) -> PlatformSession | None: + """Get a Platform session by its ID.""" ... @abc.abstractmethod - async def get_webchat_sessions_by_creator( + async def get_platform_sessions_by_creator( self, creator: str, + platform_id: str | None = None, page: int = 1, page_size: int = 20, - ) -> list[WebChatSession]: - """Get all WebChat sessions for a specific creator (username).""" + ) -> list[PlatformSession]: + """Get all Platform sessions for a specific creator (username) and optionally platform.""" ... @abc.abstractmethod - async def update_webchat_session( + async def update_platform_session( self, session_id: str, + display_name: str | None = None, ) -> None: - """Update a WebChat session's updated_at timestamp.""" + """Update a Platform session's updated_at timestamp and optionally display_name.""" ... @abc.abstractmethod - async def delete_webchat_session(self, session_id: str) -> None: - """Delete a WebChat session by its ID.""" + async def delete_platform_session(self, session_id: str) -> None: + """Delete a Platform session by its ID.""" ... diff --git a/astrbot/core/db/migration/migra_46_to_47.py b/astrbot/core/db/migration/migra_46_to_47.py index 407a840d8..e523551f6 100644 --- a/astrbot/core/db/migration/migra_46_to_47.py +++ b/astrbot/core/db/migration/migra_46_to_47.py @@ -1,6 +1,12 @@ """Migration script from version 4.6 to 4.7. -This migration creates WebChat sessions from existing platform_message_history records. +This migration creates PlatformSession from existing platform_message_history records. + +Changes: +- Creates platform_sessions table +- Adds platform_id field (default: 'webchat') +- Adds display_name field +- Session_id format: {platform_id}_{uuid} """ from sqlalchemy import func, select @@ -12,10 +18,10 @@ from astrbot.core.db.po import PlatformMessageHistory async def migrate_46_to_47(db_helper: BaseDatabase): - """Migrate WebChat data to the new session table. + """Create PlatformSession records from platform_message_history. This migration extracts all unique user_ids from platform_message_history - where platform_id='webchat' and creates corresponding WebChatSession records. + where platform_id='webchat' and creates corresponding PlatformSession records. """ # 检查是否已经完成迁移 migration_done = await db_helper.get_preference( @@ -28,7 +34,7 @@ async def migrate_46_to_47(db_helper: BaseDatabase): try: async with db_helper.get_db() as session: - # 1. 查询所有 webchat 的唯一 user_id 以及它们的最早和最新消息时间 + # 从 platform_message_history 创建 PlatformSession query = ( select( col(PlatformMessageHistory.user_id), @@ -51,7 +57,7 @@ async def migrate_46_to_47(db_helper: BaseDatabase): logger.info(f"找到 {len(webchat_users)} 个 WebChat 会话需要迁移") - # 2. 为每个 user_id 创建 WebChatSession 记录 + # 为每个 user_id 创建 PlatformSession 记录 migrated_count = 0 skipped_count = 0 @@ -60,28 +66,26 @@ async def migrate_46_to_47(db_helper: BaseDatabase): session_id = user_id # sender_name 通常是 username,但可能为 None - # 从第一条消息中提取 creator creator = sender_name if sender_name else "guest" # 检查是否已经存在该会话 - existing_session = await db_helper.get_webchat_session_by_id(session_id) + existing_session = await db_helper.get_platform_session_by_id( + session_id + ) if existing_session: logger.debug(f"会话 {session_id} 已存在,跳过") skipped_count += 1 continue - # 创建新的 WebChatSession + # 创建新的 PlatformSession try: - await db_helper.create_webchat_session( + await db_helper.create_platform_session( creator=creator, session_id=session_id, + platform_id="webchat", is_group=0, ) - # 更新时间戳以匹配历史记录 - # 注意:这里我们需要直接更新数据库,因为 create 方法会设置当前时间 - # 但我们希望保留原始的创建和更新时间 - migrated_count += 1 if migrated_count % 100 == 0: diff --git a/astrbot/core/db/po.py b/astrbot/core/db/po.py index eee4c9dc6..9fc871d08 100644 --- a/astrbot/core/db/po.py +++ b/astrbot/core/db/po.py @@ -161,14 +161,14 @@ class PlatformMessageHistory(SQLModel, table=True): ) -class WebChatSession(SQLModel, table=True): - """WebChat session table for managing user sessions. +class PlatformSession(SQLModel, table=True): + """Platform session table for managing user sessions across different platforms. - A session represents a chat window for a specific user. Each session can have - multiple conversations (对话) associated with it. + A session represents a chat window for a specific user on a specific platform. + Each session can have multiple conversations (对话) associated with it. """ - __tablename__ = "webchat_sessions" + __tablename__ = "platform_sessions" inner_id: int | None = Field( primary_key=True, @@ -176,13 +176,17 @@ class WebChatSession(SQLModel, table=True): default=None, ) session_id: str = Field( - max_length=36, + max_length=100, nullable=False, unique=True, - default_factory=lambda: str(uuid.uuid4()), + default_factory=lambda: f"webchat_{uuid.uuid4()}", ) + platform_id: str = Field(default="webchat", nullable=False) + """Platform identifier (e.g., 'webchat', 'qq', 'discord')""" creator: str = Field(nullable=False) """Username of the session creator""" + display_name: str | None = Field(default=None, max_length=255) + """Display name for the session""" is_group: int = Field(default=0, nullable=False) """0 for private chat, 1 for group chat (not implemented yet)""" created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) @@ -194,7 +198,7 @@ class WebChatSession(SQLModel, table=True): __table_args__ = ( UniqueConstraint( "session_id", - name="uix_webchat_session_id", + name="uix_platform_session_id", ), ) diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index b96a2d3ff..202c9d892 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -1,6 +1,7 @@ import asyncio import threading import typing as T +import uuid from datetime import datetime, timedelta from sqlalchemy.ext.asyncio import AsyncSession @@ -12,10 +13,10 @@ from astrbot.core.db.po import ( ConversationV2, Persona, PlatformMessageHistory, + PlatformSession, PlatformStat, Preference, SQLModel, - WebChatSession, ) from astrbot.core.db.po import ( Platform as DeprecatedPlatformStat, @@ -712,25 +713,32 @@ class SQLiteDatabase(BaseDatabase): return result # ==== - # WebChat Session Management + # Platform Session Management # ==== - async def create_webchat_session( + async def create_platform_session( self, creator: str, + platform_id: str = "webchat", session_id: str | None = None, + display_name: str | None = None, is_group: int = 0, - ) -> WebChatSession: - """Create a new WebChat session.""" + ) -> PlatformSession: + """Create a new Platform session.""" kwargs = {} if session_id: kwargs["session_id"] = session_id + else: + # Auto-generate session_id with platform_id prefix + kwargs["session_id"] = f"{platform_id}_{uuid.uuid4()}" async with self.get_db() as session: session: AsyncSession async with session.begin(): - new_session = WebChatSession( + new_session = PlatformSession( creator=creator, + platform_id=platform_id, + display_name=display_name, is_group=is_group, **kwargs, ) @@ -739,57 +747,68 @@ class SQLiteDatabase(BaseDatabase): await session.refresh(new_session) return new_session - async def get_webchat_session_by_id(self, session_id: str) -> WebChatSession | None: - """Get a WebChat session by its ID.""" + async def get_platform_session_by_id( + self, session_id: str + ) -> PlatformSession | None: + """Get a Platform session by its ID.""" async with self.get_db() as session: session: AsyncSession - query = select(WebChatSession).where( - WebChatSession.session_id == session_id, + query = select(PlatformSession).where( + PlatformSession.session_id == session_id, ) result = await session.execute(query) return result.scalar_one_or_none() - async def get_webchat_sessions_by_creator( + async def get_platform_sessions_by_creator( self, creator: str, + platform_id: str | None = None, page: int = 1, page_size: int = 20, - ) -> list[WebChatSession]: - """Get all WebChat sessions for a specific creator (username).""" + ) -> list[PlatformSession]: + """Get all Platform sessions for a specific creator (username) and optionally platform.""" async with self.get_db() as session: session: AsyncSession offset = (page - 1) * page_size + query = select(PlatformSession).where(PlatformSession.creator == creator) + + if platform_id: + query = query.where(PlatformSession.platform_id == platform_id) + query = ( - select(WebChatSession) - .where(WebChatSession.creator == creator) - .order_by(desc(WebChatSession.updated_at)) + query.order_by(desc(PlatformSession.updated_at)) .offset(offset) .limit(page_size) ) result = await session.execute(query) return list(result.scalars().all()) - async def update_webchat_session( + async def update_platform_session( self, session_id: str, + display_name: str | None = None, ) -> None: - """Update a WebChat session's updated_at timestamp.""" + """Update a Platform session's updated_at timestamp and optionally display_name.""" async with self.get_db() as session: session: AsyncSession async with session.begin(): + values = {"updated_at": datetime.now()} + if display_name is not None: + values["display_name"] = display_name + await session.execute( - update(WebChatSession) - .where(WebChatSession.session_id == session_id) - .values(updated_at=datetime.now()), + update(PlatformSession) + .where(PlatformSession.session_id == session_id) + .values(**values), ) - async def delete_webchat_session(self, session_id: str) -> None: - """Delete a WebChat session by its ID.""" + async def delete_platform_session(self, session_id: str) -> None: + """Delete a Platform session by its ID.""" async with self.get_db() as session: session: AsyncSession async with session.begin(): await session.execute( - delete(WebChatSession).where( - WebChatSession.session_id == session_id, + delete(PlatformSession).where( + PlatformSession.session_id == session_id, ), ) diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index 8eacc5c4f..70e52e778 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -39,6 +39,10 @@ class ChatRoute(Route): "/chat/sessions": ("GET", self.get_sessions), "/chat/get_session": ("GET", self.get_session), "/chat/delete_session": ("GET", self.delete_webchat_session), + "/chat/update_session_display_name": ( + "POST", + self.update_session_display_name, + ), "/chat/get_file": ("GET", self.get_file), "/chat/post_image": ("POST", self.post_image), "/chat/post_file": ("POST", self.post_file), @@ -246,56 +250,74 @@ class ChatRoute(Route): return response async def delete_webchat_session(self): - """Delete a WebChat session and all its related data.""" + """Delete a Platform session and all its related data.""" session_id = request.args.get("session_id") if not session_id: return Response().error("Missing key: session_id").__dict__ username = g.get("username", "guest") # 验证会话是否存在且属于当前用户 - session = await self.db.get_webchat_session_by_id(session_id) + session = await self.db.get_platform_session_by_id(session_id) if not session: return Response().error(f"Session {session_id} not found").__dict__ if session.creator != username: return Response().error("Permission denied").__dict__ # 删除该会话下的所有对话 - unified_msg_origin = f"webchat:FriendMessage:webchat!{username}!{session_id}" + unified_msg_origin = f"{session.platform_id}:FriendMessage:{session.platform_id}!{username}!{session_id}" await self.conv_mgr.delete_conversations_by_user_id(unified_msg_origin) # 删除消息历史 await self.platform_history_mgr.delete( - platform_id="webchat", + platform_id=session.platform_id, user_id=session_id, offset_sec=99999999, ) - # 清理队列 - webchat_queue_mgr.remove_queues(session_id) + # 清理队列(仅对 webchat) + if session.platform_id == "webchat": + webchat_queue_mgr.remove_queues(session_id) # 删除会话 - await self.db.delete_webchat_session(session_id) + await self.db.delete_platform_session(session_id) return Response().ok().__dict__ async def new_session(self): - """Create a new WebChat session.""" + """Create a new Platform session (default: webchat).""" username = g.get("username", "guest") + # 获取可选的 platform_id 参数,默认为 webchat + platform_id = request.args.get("platform_id", "webchat") + # 创建新会话 - session = await self.db.create_webchat_session( + session = await self.db.create_platform_session( creator=username, + platform_id=platform_id, is_group=0, ) - return Response().ok(data={"session_id": session.session_id}).__dict__ + return ( + Response() + .ok( + data={ + "session_id": session.session_id, + "platform_id": session.platform_id, + } + ) + .__dict__ + ) async def get_sessions(self): - """Get all WebChat sessions for the current user.""" + """Get all Platform sessions for the current user.""" username = g.get("username", "guest") - sessions = await self.db.get_webchat_sessions_by_creator( + # 获取可选的 platform_id 参数 + platform_id = request.args.get("platform_id") + + sessions = await self.db.get_platform_sessions_by_creator( creator=username, + platform_id=platform_id, page=1, page_size=100, # 暂时返回前100个 ) @@ -306,7 +328,9 @@ class ChatRoute(Route): sessions_data.append( { "session_id": session.session_id, + "platform_id": session.platform_id, "creator": session.creator, + "display_name": session.display_name, "is_group": session.is_group, "created_at": int(session.created_at.timestamp()), "updated_at": int(session.updated_at.timestamp()), @@ -321,9 +345,13 @@ class ChatRoute(Route): if not session_id: return Response().error("Missing key: session_id").__dict__ + # 获取会话信息以确定 platform_id + session = await self.db.get_platform_session_by_id(session_id) + platform_id = session.platform_id if session else "webchat" + # Get platform message history using session_id history_ls = await self.platform_history_mgr.get( - platform_id="webchat", + platform_id=platform_id, user_id=session_id, page=1, page_size=1000, @@ -341,3 +369,32 @@ class ChatRoute(Route): ) .__dict__ ) + + async def update_session_display_name(self): + """Update a Platform session's display name.""" + post_data = await request.json + + session_id = post_data.get("session_id") + display_name = post_data.get("display_name") + + if not session_id: + return Response().error("Missing key: session_id").__dict__ + if display_name is None: + return Response().error("Missing key: display_name").__dict__ + + username = g.get("username", "guest") + + # 验证会话是否存在且属于当前用户 + session = await self.db.get_platform_session_by_id(session_id) + if not session: + return Response().error(f"Session {session_id} not found").__dict__ + if session.creator != username: + return Response().error("Permission denied").__dict__ + + # 更新 display_name + await self.db.update_platform_session( + session_id=session_id, + display_name=display_name, + ) + + return Response().ok().__dict__ diff --git a/dashboard/src/components/chat/Chat.vue b/dashboard/src/components/chat/Chat.vue index 69cb5d6a8..e1a912032 100644 --- a/dashboard/src/components/chat/Chat.vue +++ b/dashboard/src/components/chat/Chat.vue @@ -59,7 +59,7 @@ - {{ tm('conversation.newConversation') }} + {{ session.display_name || tm('conversation.newConversation') }} {{ formatDate(session.updated_at) }} @@ -67,6 +67,9 @@ @@ -1368,14 +1435,31 @@ export default { transition: all 0.2s ease; } +.session-item:hover .session-actions { + opacity: 1; + visibility: visible; +} + +.session-actions { + display: flex; + gap: 4px; + opacity: 0; + visibility: hidden; + transition: all 0.2s ease; +} + .edit-title-btn, -.delete-conversation-btn { +.delete-conversation-btn, +.edit-session-btn, +.delete-session-btn { opacity: 0.7; transition: opacity 0.2s ease; } .edit-title-btn:hover, -.delete-conversation-btn:hover { +.delete-conversation-btn:hover, +.edit-session-btn:hover, +.delete-session-btn:hover { opacity: 1; } diff --git a/dashboard/src/i18n/locales/en-US/features/chat.json b/dashboard/src/i18n/locales/en-US/features/chat.json index bce552aff..bc83787ff 100644 --- a/dashboard/src/i18n/locales/en-US/features/chat.json +++ b/dashboard/src/i18n/locales/en-US/features/chat.json @@ -47,7 +47,11 @@ "noHistory": "No conversation history", "systemStatus": "System Status", "llmService": "LLM Service", - "speechToText": "Speech to Text" + "speechToText": "Speech to Text", + "editDisplayName": "Edit Session Name", + "displayName": "Session Name", + "displayNameUpdated": "Session name updated", + "displayNameUpdateFailed": "Failed to update session name" }, "modes": { "darkMode": "Switch to Dark Mode", diff --git a/dashboard/src/i18n/locales/zh-CN/features/chat.json b/dashboard/src/i18n/locales/zh-CN/features/chat.json index 37bacd408..8130ecd8b 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/chat.json +++ b/dashboard/src/i18n/locales/zh-CN/features/chat.json @@ -47,7 +47,11 @@ "noHistory": "暂无对话历史", "systemStatus": "系统状态", "llmService": "LLM 服务", - "speechToText": "语音转文本" + "speechToText": "语音转文本", + "editDisplayName": "编辑会话名称", + "displayName": "会话名称", + "displayNameUpdated": "会话名称已更新", + "displayNameUpdateFailed": "更新会话名称失败" }, "modes": { "darkMode": "切换到夜间模式", From cdf617feac5eb89e1a2d4a9ab6e5a35ffb600ffe Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Wed, 19 Nov 2025 19:16:15 +0800 Subject: [PATCH 04/27] refactor: optimize WebChat session migration by batch inserting records --- astrbot/core/db/migration/migra_46_to_47.py | 54 +++++++++++---------- 1 file changed, 28 insertions(+), 26 deletions(-) diff --git a/astrbot/core/db/migration/migra_46_to_47.py b/astrbot/core/db/migration/migra_46_to_47.py index e523551f6..79e72eb49 100644 --- a/astrbot/core/db/migration/migra_46_to_47.py +++ b/astrbot/core/db/migration/migra_46_to_47.py @@ -14,7 +14,7 @@ from sqlmodel import col from astrbot.api import logger, sp from astrbot.core.db import BaseDatabase -from astrbot.core.db.po import PlatformMessageHistory +from astrbot.core.db.po import PlatformMessageHistory, PlatformSession async def migrate_46_to_47(db_helper: BaseDatabase): @@ -57,8 +57,13 @@ async def migrate_46_to_47(db_helper: BaseDatabase): logger.info(f"找到 {len(webchat_users)} 个 WebChat 会话需要迁移") - # 为每个 user_id 创建 PlatformSession 记录 - migrated_count = 0 + # 检查已存在的会话 + existing_query = select(col(PlatformSession.session_id)) + existing_result = await session.execute(existing_query) + existing_session_ids = {row[0] for row in existing_result.fetchall()} + + # 批量创建 PlatformSession 记录 + sessions_to_add = [] skipped_count = 0 for user_id, sender_name, created_at, updated_at in webchat_users: @@ -69,35 +74,32 @@ async def migrate_46_to_47(db_helper: BaseDatabase): creator = sender_name if sender_name else "guest" # 检查是否已经存在该会话 - existing_session = await db_helper.get_platform_session_by_id( - session_id - ) - if existing_session: + if session_id in existing_session_ids: logger.debug(f"会话 {session_id} 已存在,跳过") skipped_count += 1 continue - # 创建新的 PlatformSession - try: - await db_helper.create_platform_session( - creator=creator, - session_id=session_id, - platform_id="webchat", - is_group=0, - ) + # 创建新的 PlatformSession(保留原有的时间戳) + new_session = PlatformSession( + session_id=session_id, + platform_id="webchat", + creator=creator, + is_group=0, + created_at=created_at, + updated_at=updated_at, + ) + sessions_to_add.append(new_session) - migrated_count += 1 + # 批量插入 + if sessions_to_add: + session.add_all(sessions_to_add) + await session.commit() - if migrated_count % 100 == 0: - logger.info(f"已迁移 {migrated_count} 个会话...") - - except Exception as e: - logger.error(f"迁移会话 {session_id} 失败: {e}") - continue - - logger.info( - f"WebChat 会话迁移完成!成功迁移: {migrated_count}, 跳过: {skipped_count}", - ) + logger.info( + f"WebChat 会话迁移完成!成功迁移: {len(sessions_to_add)}, 跳过: {skipped_count}", + ) + else: + logger.info("没有新会话需要迁移") # 标记迁移完成 await sp.put_async("global", "global", "migration_done_v47", True) From be3e5f3f8bf79bbc54e347d9d55bfcbb1add8647 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Wed, 19 Nov 2025 19:41:25 +0800 Subject: [PATCH 05/27] refactor: update message history deletion logic to remove newer records based on offset --- astrbot/core/db/__init__.py | 2 +- astrbot/core/db/sqlite.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index 48ccb6801..2af0428d0 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -184,7 +184,7 @@ class BaseDatabase(abc.ABC): user_id: str, offset_sec: int = 86400, ) -> None: - """Delete platform message history records older than the specified offset.""" + """Delete platform message history records newer than the specified offset.""" ... @abc.abstractmethod diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index 202c9d892..140fb2a26 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -414,7 +414,7 @@ class SQLiteDatabase(BaseDatabase): user_id, offset_sec=86400, ): - """Delete platform message history records older than the specified offset.""" + """Delete platform message history records newer than the specified offset.""" async with self.get_db() as session: session: AsyncSession async with session.begin(): @@ -424,7 +424,7 @@ class SQLiteDatabase(BaseDatabase): delete(PlatformMessageHistory).where( col(PlatformMessageHistory.platform_id) == platform_id, col(PlatformMessageHistory.user_id) == user_id, - col(PlatformMessageHistory.created_at) < cutoff_time, + col(PlatformMessageHistory.created_at) >= cutoff_time, ), ) From e7609563535dd153740b2b8e919ed3e50e8efb31 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Wed, 19 Nov 2025 19:41:57 +0800 Subject: [PATCH 06/27] refactor: enhance PlatformSession migration by adding display_name from Conversations and improve session item styling --- astrbot/core/db/migration/migra_46_to_47.py | 22 ++++++++++++++++++- dashboard/src/components/chat/Chat.vue | 12 +++++++++- .../src/i18n/locales/zh-CN/features/chat.json | 2 +- 3 files changed, 33 insertions(+), 3 deletions(-) diff --git a/astrbot/core/db/migration/migra_46_to_47.py b/astrbot/core/db/migration/migra_46_to_47.py index 79e72eb49..e9907338e 100644 --- a/astrbot/core/db/migration/migra_46_to_47.py +++ b/astrbot/core/db/migration/migra_46_to_47.py @@ -14,7 +14,7 @@ from sqlmodel import col from astrbot.api import logger, sp from astrbot.core.db import BaseDatabase -from astrbot.core.db.po import PlatformMessageHistory, PlatformSession +from astrbot.core.db.po import ConversationV2, PlatformMessageHistory, PlatformSession async def migrate_46_to_47(db_helper: BaseDatabase): @@ -62,6 +62,22 @@ async def migrate_46_to_47(db_helper: BaseDatabase): existing_result = await session.execute(existing_query) existing_session_ids = {row[0] for row in existing_result.fetchall()} + # 查询 Conversations 表中的 title,用于设置 display_name + # 对于每个 user_id,对应的 conversation user_id 格式为: webchat:FriendMessage:webchat!astrbot!{user_id} + user_ids_to_query = [ + f"webchat:FriendMessage:webchat!astrbot!{user_id}" + for user_id, _, _, _ in webchat_users + ] + conv_query = select( + col(ConversationV2.user_id), col(ConversationV2.title) + ).where(col(ConversationV2.user_id).in_(user_ids_to_query)) + conv_result = await session.execute(conv_query) + # 创建 user_id -> title 的映射字典 + title_map = { + user_id.replace("webchat:FriendMessage:webchat!astrbot!", ""): title + for user_id, title in conv_result.fetchall() + } + # 批量创建 PlatformSession 记录 sessions_to_add = [] skipped_count = 0 @@ -79,6 +95,9 @@ async def migrate_46_to_47(db_helper: BaseDatabase): skipped_count += 1 continue + # 从 Conversations 表中获取 display_name + display_name = title_map.get(user_id) + # 创建新的 PlatformSession(保留原有的时间戳) new_session = PlatformSession( session_id=session_id, @@ -87,6 +106,7 @@ async def migrate_46_to_47(db_helper: BaseDatabase): is_group=0, created_at=created_at, updated_at=updated_at, + display_name=display_name, ) sessions_to_add.append(new_session) diff --git a/dashboard/src/components/chat/Chat.vue b/dashboard/src/components/chat/Chat.vue index e1a912032..366641a8a 100644 --- a/dashboard/src/components/chat/Chat.vue +++ b/dashboard/src/components/chat/Chat.vue @@ -57,7 +57,7 @@ style="background-color: transparent;" v-model:selected="selectedSessions" @update:selected="getSessionMessages"> + class="session-item" active-color="secondary"> {{ session.display_name || tm('conversation.newConversation') }} @@ -1435,6 +1435,16 @@ export default { transition: all 0.2s ease; } +.session-item { + margin-bottom: 6px; + border-radius: 8px !important; + transition: all 0.2s ease; + height: auto !important; + min-height: 56px; + padding: 8px 16px !important; + position: relative; +} + .session-item:hover .session-actions { opacity: 1; visibility: visible; diff --git a/dashboard/src/i18n/locales/zh-CN/features/chat.json b/dashboard/src/i18n/locales/zh-CN/features/chat.json index 8130ecd8b..002ea626e 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/chat.json +++ b/dashboard/src/i18n/locales/zh-CN/features/chat.json @@ -43,7 +43,7 @@ "exitFullscreen": "退出全屏" }, "conversation": { - "newConversation": "新对话", + "newConversation": "新的聊天", "noHistory": "暂无对话历史", "systemStatus": "系统状态", "llmService": "LLM 服务", From 1935ce4700ac757581108c827e3a209dd824df6e Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Wed, 19 Nov 2025 19:54:29 +0800 Subject: [PATCH 07/27] refactor: update session handling by replacing conversation_id with session_id in chat routes and components --- astrbot/core/db/sqlite.py | 4 ++-- astrbot/dashboard/routes/chat.py | 16 +++++++++------- dashboard/src/components/chat/Chat.vue | 2 +- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index 140fb2a26..69203bf6d 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -729,8 +729,8 @@ class SQLiteDatabase(BaseDatabase): if session_id: kwargs["session_id"] = session_id else: - # Auto-generate session_id with platform_id prefix - kwargs["session_id"] = f"{platform_id}_{uuid.uuid4()}" + # Auto-generate session_id + kwargs["session_id"] = uuid.uuid4() async with self.get_db() as session: session: AsyncSession diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index 70e52e778..2620f1a17 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -119,11 +119,14 @@ class ChatRoute(Route): if "message" not in post_data and "image_url" not in post_data: return Response().error("Missing key: message or image_url").__dict__ - if "conversation_id" not in post_data: - return Response().error("Missing key: conversation_id").__dict__ + if "session_id" not in post_data and "conversation_id" not in post_data: + return ( + Response().error("Missing key: session_id or conversation_id").__dict__ + ) message = post_data["message"] - conversation_id = post_data["conversation_id"] + # conversation_id = post_data["conversation_id"] + session_id = post_data.get("session_id", post_data.get("conversation_id")) image_url = post_data.get("image_url") audio_url = post_data.get("audio_url") selected_provider = post_data.get("selected_provider") @@ -136,12 +139,11 @@ class ChatRoute(Route): .error("Message and image_url and audio_url are empty") .__dict__ ) - if not conversation_id: - return Response().error("conversation_id is empty").__dict__ + if not session_id: + return Response().error("session_id is empty").__dict__ # 追加用户消息 - # conversation_id 现在实际上是 session_id - webchat_conv_id = conversation_id + webchat_conv_id = session_id # 获取会话特定的队列 back_queue = webchat_queue_mgr.get_or_create_back_queue(webchat_conv_id) diff --git a/dashboard/src/components/chat/Chat.vue b/dashboard/src/components/chat/Chat.vue index 366641a8a..81b3c327f 100644 --- a/dashboard/src/components/chat/Chat.vue +++ b/dashboard/src/components/chat/Chat.vue @@ -947,7 +947,7 @@ export default { }, body: JSON.stringify({ message: promptToSend, - conversation_id: this.currSessionId, + session_id: this.currSessionId, image_url: imageNamesToSend, audio_url: audioNameToSend ? [audioNameToSend] : [], selected_provider: selectedProviderId, From aa595322874e7682d4e3206643360e5256d13c7a Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Thu, 20 Nov 2025 15:58:27 +0800 Subject: [PATCH 08/27] refactor: implement migration for WebChat sessions by creating PlatformSession records from platform_message_history --- astrbot/core/core_lifecycle.py | 8 ++++---- ...{migra_46_to_47.py => migra_webchat_session.py} | 14 ++++++++------ 2 files changed, 12 insertions(+), 10 deletions(-) rename astrbot/core/db/migration/{migra_46_to_47.py => migra_webchat_session.py} (91%) diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index fdf757116..17fd52138 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -22,7 +22,7 @@ from astrbot.core.config.default import VERSION from astrbot.core.conversation_mgr import ConversationManager from astrbot.core.db import BaseDatabase from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46 -from astrbot.core.db.migration.migra_46_to_47 import migrate_46_to_47 +from astrbot.core.db.migration.migra_webchat_session import migrate_webchat_session from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager from astrbot.core.persona_mgr import PersonaManager from astrbot.core.pipeline.scheduler import PipelineContext, PipelineScheduler @@ -104,11 +104,11 @@ class AstrBotCoreLifecycle: logger.error(f"Migration from version 4.5 to 4.6 failed: {e!s}") logger.error(traceback.format_exc()) - # 4.6 to 4.7 migration for platform sessions and group feature + # migration for webchat session try: - await migrate_46_to_47(self.db) + await migrate_webchat_session(self.db) except Exception as e: - logger.error(f"Migration from version 4.6 to 4.7 failed: {e!s}") + logger.error(f"Migration for webchat session failed: {e!s}") logger.error(traceback.format_exc()) # 初始化事件队列 diff --git a/astrbot/core/db/migration/migra_46_to_47.py b/astrbot/core/db/migration/migra_webchat_session.py similarity index 91% rename from astrbot/core/db/migration/migra_46_to_47.py rename to astrbot/core/db/migration/migra_webchat_session.py index e9907338e..6cb483464 100644 --- a/astrbot/core/db/migration/migra_46_to_47.py +++ b/astrbot/core/db/migration/migra_webchat_session.py @@ -1,4 +1,4 @@ -"""Migration script from version 4.6 to 4.7. +"""Migration script for WebChat sessions. This migration creates PlatformSession from existing platform_message_history records. @@ -17,7 +17,7 @@ from astrbot.core.db import BaseDatabase from astrbot.core.db.po import ConversationV2, PlatformMessageHistory, PlatformSession -async def migrate_46_to_47(db_helper: BaseDatabase): +async def migrate_webchat_session(db_helper: BaseDatabase): """Create PlatformSession records from platform_message_history. This migration extracts all unique user_ids from platform_message_history @@ -25,12 +25,12 @@ async def migrate_46_to_47(db_helper: BaseDatabase): """ # 检查是否已经完成迁移 migration_done = await db_helper.get_preference( - "global", "global", "migration_done_v47" + "global", "global", "migration_done_webchat_session" ) if migration_done: return - logger.info("开始执行数据库迁移(4.6 -> 4.7)...") + logger.info("开始执行数据库迁移(WebChat 会话迁移)...") try: async with db_helper.get_db() as session: @@ -52,7 +52,9 @@ async def migrate_46_to_47(db_helper: BaseDatabase): if not webchat_users: logger.info("没有找到需要迁移的 WebChat 数据") - await sp.put_async("global", "global", "migration_done_v47", True) + await sp.put_async( + "global", "global", "migration_done_webchat_session", True + ) return logger.info(f"找到 {len(webchat_users)} 个 WebChat 会话需要迁移") @@ -122,7 +124,7 @@ async def migrate_46_to_47(db_helper: BaseDatabase): logger.info("没有新会话需要迁移") # 标记迁移完成 - await sp.put_async("global", "global", "migration_done_v47", True) + await sp.put_async("global", "global", "migration_done_webchat_session", True) except Exception as e: logger.error(f"迁移过程中发生错误: {e}", exc_info=True) From 6d6fefc4355ce71cbe935449a6bc763f0cb8a83a Mon Sep 17 00:00:00 2001 From: Soulter <37870767+Soulter@users.noreply.github.com> Date: Thu, 20 Nov 2025 16:01:22 +0800 Subject: [PATCH 09/27] fix: anyio.ClosedResourceError when calling mcp tools (#3700) * fix: anyio.ClosedResourceError when calling mcp tools added reconnect mechanism fixes: 3676 * fix(mcp_client): implement thread-safe reconnection using asyncio.Lock --- astrbot/core/agent/mcp_client.py | 180 +++++++++++++++++---- astrbot/core/provider/func_tool_manager.py | 21 +-- pyproject.toml | 3 +- requirements.txt | 1 + 4 files changed, 168 insertions(+), 37 deletions(-) diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index 05980b212..88cab486e 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -4,6 +4,14 @@ from contextlib import AsyncExitStack from datetime import timedelta from typing import Generic +from tenacity import ( + before_sleep_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + from astrbot import logger from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.utils.log_pipe import LogPipe @@ -12,21 +20,24 @@ from .run_context import TContext from .tool import FunctionTool try: + import anyio import mcp from mcp.client.sse import sse_client except (ModuleNotFoundError, ImportError): - logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。") + logger.warning( + "Warning: Missing 'mcp' dependency, MCP services will be unavailable." + ) try: from mcp.client.streamable_http import streamablehttp_client except (ModuleNotFoundError, ImportError): logger.warning( - "警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。", + "Warning: Missing 'mcp' dependency or MCP library version too old, Streamable HTTP connection unavailable.", ) def _prepare_config(config: dict) -> dict: - """准备配置,处理嵌套格式""" + """Prepare configuration, handle nested format""" if config.get("mcpServers"): first_key = next(iter(config["mcpServers"])) config = config["mcpServers"][first_key] @@ -35,7 +46,7 @@ def _prepare_config(config: dict) -> dict: async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: - """快速测试 MCP 服务器可达性""" + """Quick test MCP server connectivity""" import aiohttp cfg = _prepare_config(config.copy()) @@ -50,7 +61,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: elif "type" in cfg: transport_type = cfg["type"] else: - raise Exception("MCP 连接配置缺少 transport 或 type 字段") + raise Exception("MCP connection config missing transport or type field") async with aiohttp.ClientSession() as session: if transport_type == "streamable_http": @@ -91,7 +102,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: return False, f"HTTP {response.status}: {response.reason}" except asyncio.TimeoutError: - return False, f"连接超时: {timeout}秒" + return False, f"Connection timeout: {timeout} seconds" except Exception as e: return False, f"{e!s}" @@ -101,6 +112,7 @@ class MCPClient: # Initialize session and client objects self.session: mcp.ClientSession | None = None self.exit_stack = AsyncExitStack() + self._old_exit_stacks: list[AsyncExitStack] = [] # Track old stacks for cleanup self.name: str | None = None self.active: bool = True @@ -108,22 +120,32 @@ class MCPClient: self.server_errlogs: list[str] = [] self.running_event = asyncio.Event() - async def connect_to_server(self, mcp_server_config: dict, name: str): - """连接到 MCP 服务器 + # Store connection config for reconnection + self._mcp_server_config: dict | None = None + self._server_name: str | None = None + self._reconnect_lock = asyncio.Lock() # Lock for thread-safe reconnection + self._reconnecting: bool = False # For logging and debugging - 如果 `url` 参数存在: - 1. 当 transport 指定为 `streamable_http` 时,使用 Streamable HTTP 连接方式。 - 1. 当 transport 指定为 `sse` 时,使用 SSE 连接方式。 - 2. 如果没有指定,默认使用 SSE 的方式连接到 MCP 服务。 + async def connect_to_server(self, mcp_server_config: dict, name: str): + """Connect to MCP server + + If `url` parameter exists: + 1. When transport is specified as `streamable_http`, use Streamable HTTP connection. + 2. When transport is specified as `sse`, use SSE connection. + 3. If not specified, default to SSE connection to MCP service. Args: mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server """ + # Store config for reconnection + self._mcp_server_config = mcp_server_config + self._server_name = name + cfg = _prepare_config(mcp_server_config.copy()) def logging_callback(msg: str): - # 处理 MCP 服务的错误日志 + # Handle MCP service error logs print(f"MCP Server {name} Error: {msg}") self.server_errlogs.append(msg) @@ -137,7 +159,7 @@ class MCPClient: elif "type" in cfg: transport_type = cfg["type"] else: - raise Exception("MCP 连接配置缺少 transport 或 type 字段") + raise Exception("MCP connection config missing transport or type field") if transport_type != "streamable_http": # SSE transport method @@ -193,7 +215,7 @@ class MCPClient: ) def callback(msg: str): - # 处理 MCP 服务的错误日志 + # Handle MCP service error logs self.server_errlogs.append(msg) stdio_transport = await self.exit_stack.enter_async_context( @@ -222,10 +244,120 @@ class MCPClient: self.tools = response.tools return response + async def _reconnect(self) -> None: + """Reconnect to the MCP server using the stored configuration. + + Uses asyncio.Lock to ensure thread-safe reconnection in concurrent environments. + + Raises: + Exception: raised when reconnection fails + """ + async with self._reconnect_lock: + # Check if already reconnecting (useful for logging) + if self._reconnecting: + logger.debug( + f"MCP Client {self._server_name} is already reconnecting, skipping" + ) + return + + if not self._mcp_server_config or not self._server_name: + raise Exception("Cannot reconnect: missing connection configuration") + + self._reconnecting = True + try: + logger.info( + f"Attempting to reconnect to MCP server {self._server_name}..." + ) + + # Save old exit_stack for later cleanup (don't close it now to avoid cancel scope issues) + if self.exit_stack: + self._old_exit_stacks.append(self.exit_stack) + + # Mark old session as invalid + self.session = None + + # Create new exit stack for new connection + self.exit_stack = AsyncExitStack() + + # Reconnect using stored config + await self.connect_to_server(self._mcp_server_config, self._server_name) + await self.list_tools_and_save() + + logger.info( + f"Successfully reconnected to MCP server {self._server_name}" + ) + except Exception as e: + logger.error( + f"Failed to reconnect to MCP server {self._server_name}: {e}" + ) + raise + finally: + self._reconnecting = False + + async def call_tool_with_reconnect( + self, + tool_name: str, + arguments: dict, + read_timeout_seconds: timedelta, + ) -> mcp.types.CallToolResult: + """Call MCP tool with automatic reconnection on failure, max 2 retries. + + Args: + tool_name: tool name + arguments: tool arguments + read_timeout_seconds: read timeout + + Returns: + MCP tool call result + + Raises: + ValueError: MCP session is not available + anyio.ClosedResourceError: raised after reconnection failure + """ + + @retry( + retry=retry_if_exception_type(anyio.ClosedResourceError), + stop=stop_after_attempt(2), + wait=wait_exponential(multiplier=1, min=1, max=3), + before_sleep=before_sleep_log(logger, logging.WARNING), + reraise=True, + ) + async def _call_with_retry(): + if not self.session: + raise ValueError("MCP session is not available for MCP function tools.") + + try: + return await self.session.call_tool( + name=tool_name, + arguments=arguments, + read_timeout_seconds=read_timeout_seconds, + ) + except anyio.ClosedResourceError: + logger.warning( + f"MCP tool {tool_name} call failed (ClosedResourceError), attempting to reconnect..." + ) + # Attempt to reconnect + await self._reconnect() + # Reraise the exception to trigger tenacity retry + raise + + return await _call_with_retry() + async def cleanup(self): - """Clean up resources""" - await self.exit_stack.aclose() - self.running_event.set() # Set the running event to indicate cleanup is done + """Clean up resources including old exit stacks from reconnections""" + # Set running_event first to unblock any waiting tasks + self.running_event.set() + + # Close current exit stack + try: + await self.exit_stack.aclose() + except Exception as e: + logger.debug(f"Error closing current exit stack: {e}") + + # Don't close old exit stacks as they may be in different task contexts + # They will be garbage collected naturally + # Just clear the list to release references + self._old_exit_stacks.clear() class MCPTool(FunctionTool, Generic[TContext]): @@ -246,14 +378,8 @@ class MCPTool(FunctionTool, Generic[TContext]): async def call( self, context: ContextWrapper[TContext], **kwargs ) -> mcp.types.CallToolResult: - session = self.mcp_client.session - if not session: - raise ValueError("MCP session is not available for MCP function tools.") - res = await session.call_tool( - name=self.mcp_tool.name, + return await self.mcp_client.call_tool_with_reconnect( + tool_name=self.mcp_tool.name, arguments=kwargs, - read_timeout_seconds=timedelta( - seconds=context.tool_call_timeout, - ), + read_timeout_seconds=timedelta(seconds=context.tool_call_timeout), ) - return res diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 7cdbeec01..8e04423ed 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -280,19 +280,22 @@ class FunctionToolManager: async def _terminate_mcp_client(self, name: str) -> None: """关闭并清理MCP客户端""" if name in self.mcp_client_dict: + client = self.mcp_client_dict[name] try: # 关闭MCP连接 - await self.mcp_client_dict[name].cleanup() - self.mcp_client_dict.pop(name) + await client.cleanup() except Exception as e: logger.error(f"清空 MCP 客户端资源 {name}: {e}。") - # 移除关联的FuncTool - self.func_list = [ - f - for f in self.func_list - if not (isinstance(f, MCPTool) and f.mcp_server_name == name) - ] - logger.info(f"已关闭 MCP 服务 {name}") + finally: + # Remove client from dict after cleanup attempt (successful or not) + self.mcp_client_dict.pop(name, None) + # 移除关联的FuncTool + self.func_list = [ + f + for f in self.func_list + if not (isinstance(f, MCPTool) and f.mcp_server_name == name) + ] + logger.info(f"已关闭 MCP 服务 {name}") @staticmethod async def test_mcp_server_connection(config: dict) -> list[str]: diff --git a/pyproject.toml b/pyproject.toml index 576bc1966..707581846 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ dependencies = [ "jieba>=0.42.1", "markitdown-no-magika[docx,xls,xlsx]>=0.1.2", "xinference-client", + "tenacity>=9.1.2", ] [dependency-groups] @@ -107,4 +108,4 @@ exclude = ["dashboard", "node_modules", "dist", "data", "tests"] [build-system] requires = ["hatchling"] -build-backend = "hatchling.build" \ No newline at end of file +build-backend = "hatchling.build" diff --git a/requirements.txt b/requirements.txt index e8b3dee3c..b56741192 100644 --- a/requirements.txt +++ b/requirements.txt @@ -52,3 +52,4 @@ rank-bm25>=0.2.2 jieba>=0.42.1 markitdown-no-magika[docx,xls,xlsx]>=0.1.2 xinference-client +tenacity>=9.1.2 \ No newline at end of file From 164a4226ea0daa971a08f939dba19f81cce053e4 Mon Sep 17 00:00:00 2001 From: Soulter <37870767+Soulter@users.noreply.github.com> Date: Thu, 20 Nov 2025 16:07:09 +0800 Subject: [PATCH 10/27] feat(chat): refactor chat component structure and add new features (#3701) - Introduced `ConversationSidebar.vue` for improved conversation management and sidebar functionality. - Enhanced `MessageList.vue` to handle loading states and improved message rendering. - Created new composables: `useConversations`, `useMessages`, `useMediaHandling`, `useRecording` for better code organization and reusability. - Added loading indicators and improved user experience during message processing. - Ensured backward compatibility and maintained existing functionalities. --- dashboard/src/components/chat/Chat.vue | 1700 ++++------------- dashboard/src/components/chat/ChatInput.vue | 283 +++ .../components/chat/ConversationSidebar.vue | 310 +++ dashboard/src/components/chat/MessageList.vue | 92 +- dashboard/src/composables/useConversations.ts | 145 ++ dashboard/src/composables/useMediaHandling.ts | 104 + dashboard/src/composables/useMessages.ts | 303 +++ dashboard/src/composables/useRecording.ts | 74 + 8 files changed, 1615 insertions(+), 1396 deletions(-) create mode 100644 dashboard/src/components/chat/ChatInput.vue create mode 100644 dashboard/src/components/chat/ConversationSidebar.vue create mode 100644 dashboard/src/composables/useConversations.ts create mode 100644 dashboard/src/composables/useMediaHandling.ts create mode 100644 dashboard/src/composables/useMessages.ts create mode 100644 dashboard/src/composables/useRecording.ts diff --git a/dashboard/src/components/chat/Chat.vue b/dashboard/src/components/chat/Chat.vue index d671b15b7..bb3418d67 100644 --- a/dashboard/src/components/chat/Chat.vue +++ b/dashboard/src/components/chat/Chat.vue @@ -5,89 +5,20 @@
- +
@@ -149,69 +80,23 @@
-
-
- -
-
- - - - - - -
-
- - - - - -
-
-
- - -
-
- - -
- -
- - - {{ tm('voice.recording') }} - - -
-
-
+
@@ -227,8 +112,8 @@ - {{ t('core.common.cancel') }} - {{ t('core.common.save') }} + {{ t('core.common.cancel') }} + {{ t('core.common.save') }}
@@ -247,989 +132,271 @@ - - \ No newline at end of file diff --git a/dashboard/src/components/chat/ChatInput.vue b/dashboard/src/components/chat/ChatInput.vue new file mode 100644 index 000000000..7ca0ec94a --- /dev/null +++ b/dashboard/src/components/chat/ChatInput.vue @@ -0,0 +1,283 @@ + + + + + diff --git a/dashboard/src/components/chat/ConversationSidebar.vue b/dashboard/src/components/chat/ConversationSidebar.vue new file mode 100644 index 000000000..80b574c7f --- /dev/null +++ b/dashboard/src/components/chat/ConversationSidebar.vue @@ -0,0 +1,310 @@ + + + + + diff --git a/dashboard/src/components/chat/MessageList.vue b/dashboard/src/components/chat/MessageList.vue index 9cd69241f..15f3c1d31 100644 --- a/dashboard/src/components/chat/MessageList.vue +++ b/dashboard/src/components/chat/MessageList.vue @@ -37,42 +37,49 @@
- -
-
- - {{ isReasoningExpanded(index) ? 'mdi-chevron-down' : 'mdi-chevron-right' }} - - {{ tm('reasoning.thinking') }} -
-
-
-
+ +
+ {{ tm('message.loading') }}
- -
- - -
-
- +
-
+
@@ -841,6 +848,29 @@ export default { margin: 10px 0; } +.loading-container { + display: flex; + align-items: center; + gap: 12px; + padding: 8px 0; + margin-top: 2px; +} + +.loading-text { + font-size: 14px; + color: var(--v-theme-secondaryText); + animation: pulse 1.5s ease-in-out infinite; +} + +@keyframes pulse { + 0%, 100% { + opacity: 0.6; + } + 50% { + opacity: 1; + } +} + .markdown-content blockquote { border-left: 4px solid var(--v-theme-secondary); padding-left: 16px; diff --git a/dashboard/src/composables/useConversations.ts b/dashboard/src/composables/useConversations.ts new file mode 100644 index 000000000..cf86246c8 --- /dev/null +++ b/dashboard/src/composables/useConversations.ts @@ -0,0 +1,145 @@ +import { ref, computed } from 'vue'; +import axios from 'axios'; +import { useRouter } from 'vue-router'; + +export interface Conversation { + cid: string; + title: string; + updated_at: number; +} + +export function useConversations(chatboxMode: boolean = false) { + const router = useRouter(); + const conversations = ref([]); + const selectedConversations = ref([]); + const currCid = ref(''); + const pendingCid = ref(null); + + // 编辑标题相关 + const editTitleDialog = ref(false); + const editingTitle = ref(''); + const editingCid = ref(''); + + const getCurrentConversation = computed(() => { + if (!currCid.value) return null; + return conversations.value.find(c => c.cid === currCid.value); + }); + + async function getConversations() { + try { + const response = await axios.get('/api/chat/conversations'); + conversations.value = response.data.data; + + // 处理待加载的会话 + if (pendingCid.value) { + const conversation = conversations.value.find(c => c.cid === pendingCid.value); + if (conversation) { + selectedConversations.value = [pendingCid.value]; + pendingCid.value = null; + } + } else if (!currCid.value && conversations.value.length > 0) { + // 默认选择第一个会话 + const firstConversation = conversations.value[0]; + selectedConversations.value = [firstConversation.cid]; + } + } catch (err: any) { + if (err.response?.status === 401) { + router.push('/auth/login?redirect=/chatbox'); + } + console.error(err); + } + } + + async function newConversation() { + try { + const response = await axios.get('/api/chat/new_conversation'); + const cid = response.data.data.conversation_id; + currCid.value = cid; + + // 更新 URL + const basePath = chatboxMode ? '/chatbox' : '/chat'; + router.push(`${basePath}/${cid}`); + + await getConversations(); + return cid; + } catch (err) { + console.error(err); + throw err; + } + } + + async function deleteConversation(cid: string) { + try { + await axios.get('/api/chat/delete_conversation?conversation_id=' + cid); + await getConversations(); + currCid.value = ''; + selectedConversations.value = []; + } catch (err) { + console.error(err); + } + } + + function showEditTitleDialog(cid: string, title: string) { + editingCid.value = cid; + editingTitle.value = title || ''; + editTitleDialog.value = true; + } + + async function saveTitle() { + if (!editingCid.value) return; + + const trimmedTitle = editingTitle.value.trim(); + try { + await axios.post('/api/chat/rename_conversation', { + conversation_id: editingCid.value, + title: trimmedTitle + }); + + // 更新本地会话标题 + const conversation = conversations.value.find(c => c.cid === editingCid.value); + if (conversation) { + conversation.title = trimmedTitle; + } + editTitleDialog.value = false; + } catch (err) { + console.error('重命名对话失败:', err); + } + } + + function updateConversationTitle(cid: string, title: string) { + const conversation = conversations.value.find(c => c.cid === cid); + if (conversation) { + conversation.title = title; + } + } + + function newChat(closeMobileSidebar?: () => void) { + currCid.value = ''; + selectedConversations.value = []; + + const basePath = chatboxMode ? '/chatbox' : '/chat'; + router.push(basePath); + + if (closeMobileSidebar) { + closeMobileSidebar(); + } + } + + return { + conversations, + selectedConversations, + currCid, + pendingCid, + editTitleDialog, + editingTitle, + editingCid, + getCurrentConversation, + getConversations, + newConversation, + deleteConversation, + showEditTitleDialog, + saveTitle, + updateConversationTitle, + newChat + }; +} diff --git a/dashboard/src/composables/useMediaHandling.ts b/dashboard/src/composables/useMediaHandling.ts new file mode 100644 index 000000000..e24c25fb8 --- /dev/null +++ b/dashboard/src/composables/useMediaHandling.ts @@ -0,0 +1,104 @@ +import { ref } from 'vue'; +import axios from 'axios'; + +export function useMediaHandling() { + const stagedImagesName = ref([]); + const stagedImagesUrl = ref([]); + const stagedAudioUrl = ref(''); + const mediaCache = ref>({}); + + async function getMediaFile(filename: string): Promise { + if (mediaCache.value[filename]) { + return mediaCache.value[filename]; + } + + try { + const response = await axios.get('/api/chat/get_file', { + params: { filename }, + responseType: 'blob' + }); + + const blobUrl = URL.createObjectURL(response.data); + mediaCache.value[filename] = blobUrl; + return blobUrl; + } catch (error) { + console.error('Error fetching media file:', error); + return ''; + } + } + + async function processAndUploadImage(file: File) { + const formData = new FormData(); + formData.append('file', file); + + try { + const response = await axios.post('/api/chat/post_image', formData, { + headers: { + 'Content-Type': 'multipart/form-data' + } + }); + + const img = response.data.data.filename; + stagedImagesName.value.push(img); + stagedImagesUrl.value.push(URL.createObjectURL(file)); + } catch (err) { + console.error('Error uploading image:', err); + } + } + + async function handlePaste(event: ClipboardEvent) { + const items = event.clipboardData?.items; + if (!items) return; + + for (let i = 0; i < items.length; i++) { + if (items[i].type.indexOf('image') !== -1) { + const file = items[i].getAsFile(); + if (file) { + await processAndUploadImage(file); + } + } + } + } + + function removeImage(index: number) { + const urlToRevoke = stagedImagesUrl.value[index]; + if (urlToRevoke && urlToRevoke.startsWith('blob:')) { + URL.revokeObjectURL(urlToRevoke); + } + + stagedImagesName.value.splice(index, 1); + stagedImagesUrl.value.splice(index, 1); + } + + function removeAudio() { + stagedAudioUrl.value = ''; + } + + function clearStaged() { + stagedImagesName.value = []; + stagedImagesUrl.value = []; + stagedAudioUrl.value = ''; + } + + function cleanupMediaCache() { + Object.values(mediaCache.value).forEach(url => { + if (url.startsWith('blob:')) { + URL.revokeObjectURL(url); + } + }); + mediaCache.value = {}; + } + + return { + stagedImagesName, + stagedImagesUrl, + stagedAudioUrl, + getMediaFile, + processAndUploadImage, + handlePaste, + removeImage, + removeAudio, + clearStaged, + cleanupMediaCache + }; +} diff --git a/dashboard/src/composables/useMessages.ts b/dashboard/src/composables/useMessages.ts new file mode 100644 index 000000000..dc243f1b4 --- /dev/null +++ b/dashboard/src/composables/useMessages.ts @@ -0,0 +1,303 @@ +import { ref, reactive, type Ref } from 'vue'; +import axios from 'axios'; +import { useToast } from '@/utils/toast'; + +export interface MessageContent { + type: string; + message: string; + reasoning?: string; + image_url?: string[]; + audio_url?: string; + embedded_images?: string[]; + embedded_audio?: string; + isLoading?: boolean; +} + +export interface Message { + content: MessageContent; +} + +export function useMessages( + currCid: Ref, + getMediaFile: (filename: string) => Promise, + updateConversationTitle: (cid: string, title: string) => void, + onConversationsUpdate: () => void +) { + const messages = ref([]); + const isStreaming = ref(false); + const isConvRunning = ref(false); + const isToastedRunningInfo = ref(false); + const activeSSECount = ref(0); + const enableStreaming = ref(true); + + // 从 localStorage 读取流式响应开关状态 + const savedStreamingState = localStorage.getItem('enableStreaming'); + if (savedStreamingState !== null) { + enableStreaming.value = JSON.parse(savedStreamingState); + } + + function toggleStreaming() { + enableStreaming.value = !enableStreaming.value; + localStorage.setItem('enableStreaming', JSON.stringify(enableStreaming.value)); + } + + async function getConversationMessages(cid: string, router: any) { + if (!cid) return; + + try { + const response = await axios.get('/api/chat/get_conversation?conversation_id=' + cid); + isConvRunning.value = response.data.data.is_running || false; + let history = response.data.data.history; + + if (isConvRunning.value) { + if (!isToastedRunningInfo.value) { + useToast().info("该对话正在运行中。", { timeout: 5000 }); + isToastedRunningInfo.value = true; + } + + // 如果对话还在运行,3秒后重新获取消息 + setTimeout(() => { + getConversationMessages(currCid.value, router); + }, 3000); + } + + // 处理历史消息中的媒体文件 + for (let i = 0; i < history.length; i++) { + let content = history[i].content; + + if (content.message?.startsWith('[IMAGE]')) { + let img = content.message.replace('[IMAGE]', ''); + const imageUrl = await getMediaFile(img); + if (!content.embedded_images) { + content.embedded_images = []; + } + content.embedded_images.push(imageUrl); + content.message = ''; + } + + if (content.message?.startsWith('[RECORD]')) { + let audio = content.message.replace('[RECORD]', ''); + const audioUrl = await getMediaFile(audio); + content.embedded_audio = audioUrl; + content.message = ''; + } + + if (content.image_url && content.image_url.length > 0) { + for (let j = 0; j < content.image_url.length; j++) { + content.image_url[j] = await getMediaFile(content.image_url[j]); + } + } + + if (content.audio_url) { + content.audio_url = await getMediaFile(content.audio_url); + } + } + + messages.value = history; + } catch (err) { + console.error(err); + } + } + + async function sendMessage( + prompt: string, + imageNames: string[], + audioName: string, + selectedProviderId: string, + selectedModelName: string + ) { + // Create user message + const userMessage: MessageContent = { + type: 'user', + message: prompt, + image_url: [], + audio_url: undefined + }; + + // Convert image filenames to blob URLs + if (imageNames.length > 0) { + const imagePromises = imageNames.map(name => { + if (!name.startsWith('blob:')) { + return getMediaFile(name); + } + return Promise.resolve(name); + }); + userMessage.image_url = await Promise.all(imagePromises); + } + + // Convert audio filename to blob URL + if (audioName) { + if (!audioName.startsWith('blob:')) { + userMessage.audio_url = await getMediaFile(audioName); + } else { + userMessage.audio_url = audioName; + } + } + + messages.value.push({ content: userMessage }); + + // 添加一个加载中的机器人消息占位符 + const loadingMessage = reactive({ + type: 'bot', + message: '', + reasoning: '', + isLoading: true + }); + messages.value.push({ content: loadingMessage }); + + try { + activeSSECount.value++; + if (activeSSECount.value === 1) { + isConvRunning.value = true; + } + + const response = await fetch('/api/chat/send', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Authorization': 'Bearer ' + localStorage.getItem('token') + }, + body: JSON.stringify({ + message: prompt, + conversation_id: currCid.value, + image_url: imageNames, + audio_url: audioName ? [audioName] : [], + selected_provider: selectedProviderId, + selected_model: selectedModelName, + enable_streaming: enableStreaming.value + }) + }); + + if (!response.ok) { + throw new Error(`HTTP error! status: ${response.status}`); + } + + const reader = response.body!.getReader(); + const decoder = new TextDecoder(); + let in_streaming = false; + let message_obj: any = null; + + isStreaming.value = true; + + while (true) { + try { + const { done, value } = await reader.read(); + if (done) { + console.log('SSE stream completed'); + break; + } + + const chunk = decoder.decode(value, { stream: true }); + const lines = chunk.split('\n\n'); + + for (let i = 0; i < lines.length; i++) { + let line = lines[i].trim(); + if (!line) continue; + + let chunk_json; + try { + chunk_json = JSON.parse(line.replace('data: ', '')); + } catch (parseError) { + console.warn('JSON解析失败:', line, parseError); + continue; + } + + if (!chunk_json || typeof chunk_json !== 'object' || !chunk_json.hasOwnProperty('type')) { + console.warn('无效的数据对象:', chunk_json); + continue; + } + + if (chunk_json.type === 'error') { + console.error('Error received:', chunk_json.data); + continue; + } + + if (chunk_json.type === 'image') { + let img = chunk_json.data.replace('[IMAGE]', ''); + const imageUrl = await getMediaFile(img); + let bot_resp: MessageContent = { + type: 'bot', + message: '', + embedded_images: [imageUrl] + }; + messages.value.push({ content: bot_resp }); + } else if (chunk_json.type === 'record') { + let audio = chunk_json.data.replace('[RECORD]', ''); + const audioUrl = await getMediaFile(audio); + let bot_resp: MessageContent = { + type: 'bot', + message: '', + embedded_audio: audioUrl + }; + messages.value.push({ content: bot_resp }); + } else if (chunk_json.type === 'plain') { + const chain_type = chunk_json.chain_type || 'normal'; + + if (!in_streaming) { + // 移除加载占位符 + const lastMsg = messages.value[messages.value.length - 1]; + if (lastMsg?.content?.isLoading) { + messages.value.pop(); + } + + message_obj = reactive({ + type: 'bot', + message: chain_type === 'reasoning' ? '' : chunk_json.data, + reasoning: chain_type === 'reasoning' ? chunk_json.data : '', + }); + messages.value.push({ content: message_obj }); + in_streaming = true; + } else { + if (chain_type === 'reasoning') { + // 使用 reactive 对象,直接修改属性会触发响应式更新 + message_obj.reasoning = (message_obj.reasoning || '') + chunk_json.data; + } else { + message_obj.message = (message_obj.message || '') + chunk_json.data; + } + } + } else if (chunk_json.type === 'update_title') { + updateConversationTitle(chunk_json.cid, chunk_json.data); + } + + if ((chunk_json.type === 'break' && chunk_json.streaming) || !chunk_json.streaming) { + in_streaming = false; + if (!chunk_json.streaming) { + isStreaming.value = false; + } + } + } + } catch (readError) { + console.error('SSE读取错误:', readError); + break; + } + } + + // 获取最新的对话列表 + onConversationsUpdate(); + + } catch (err) { + console.error('发送消息失败:', err); + // 移除加载占位符 + const lastMsg = messages.value[messages.value.length - 1]; + if (lastMsg?.content?.isLoading) { + messages.value.pop(); + } + } finally { + isStreaming.value = false; + activeSSECount.value--; + if (activeSSECount.value === 0) { + isConvRunning.value = false; + } + } + } + + return { + messages, + isStreaming, + isConvRunning, + enableStreaming, + getConversationMessages, + sendMessage, + toggleStreaming + }; +} diff --git a/dashboard/src/composables/useRecording.ts b/dashboard/src/composables/useRecording.ts new file mode 100644 index 000000000..4b03e8508 --- /dev/null +++ b/dashboard/src/composables/useRecording.ts @@ -0,0 +1,74 @@ +import { ref } from 'vue'; +import axios from 'axios'; + +export function useRecording() { + const isRecording = ref(false); + const audioChunks = ref([]); + const mediaRecorder = ref(null); + + async function startRecording(onStart?: (label: string) => void) { + try { + const stream = await navigator.mediaDevices.getUserMedia({ audio: true }); + mediaRecorder.value = new MediaRecorder(stream); + + mediaRecorder.value.ondataavailable = (event) => { + audioChunks.value.push(event.data); + }; + + mediaRecorder.value.start(); + isRecording.value = true; + + if (onStart) { + onStart('录音中...'); + } + } catch (error) { + console.error('Failed to start recording:', error); + } + } + + async function stopRecording(onStop?: (label: string) => void): Promise { + return new Promise((resolve, reject) => { + if (!mediaRecorder.value) { + reject('No media recorder'); + return; + } + + isRecording.value = false; + if (onStop) { + onStop('聊天输入框'); + } + + mediaRecorder.value.stop(); + mediaRecorder.value.onstop = async () => { + const audioBlob = new Blob(audioChunks.value, { type: 'audio/wav' }); + audioChunks.value = []; + + mediaRecorder.value?.stream.getTracks().forEach(track => track.stop()); + + const formData = new FormData(); + formData.append('file', audioBlob); + + try { + const response = await axios.post('/api/chat/post_file', formData, { + headers: { + 'Content-Type': 'multipart/form-data' + } + }); + + const audio = response.data.data.filename; + console.log('Audio uploaded:', audio); + resolve(audio); + } catch (err) { + console.error('Error uploading audio:', err); + reject(err); + } + }; + }); + } + + return { + isRecording, + startRecording, + stopRecording + }; +} From 8e511bf14b49defc6bcacebb7331e1543f0e12f4 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Wed, 19 Nov 2025 15:40:41 +0800 Subject: [PATCH 11/27] fix: build docker ci failed --- .github/workflows/docker-image.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml index dfef51365..02bff6a5b 100644 --- a/.github/workflows/docker-image.yml +++ b/.github/workflows/docker-image.yml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest env: DOCKER_HUB_USERNAME: ${{ secrets.DOCKER_HUB_USERNAME }} - GHCR_OWNER: ${{ github.repository_owner }} + GHCR_OWNER: soulter HAS_GHCR_TOKEN: ${{ secrets.GHCR_GITHUB_TOKEN != '' }} steps: From 77dd89b8eb3098f9bcf357fd3c80f8f7a35a6b8e Mon Sep 17 00:00:00 2001 From: Soulter <37870767+Soulter@users.noreply.github.com> Date: Wed, 19 Nov 2025 18:54:56 +0800 Subject: [PATCH 12/27] feat: add supports for gemini-3 series thought signature (#3698) * feat: add supports for gemini-3 series thought signature * feat: refactor tools_call_extra_content to use a dictionary for better structure --- astrbot/core/agent/message.py | 7 ++++ astrbot/core/db/po.py | 20 ++++------- astrbot/core/provider/entities.py | 31 ++++++++++++----- .../core/provider/sources/gemini_source.py | 33 ++++++++++++++----- .../core/provider/sources/openai_source.py | 25 +++++--------- 5 files changed, 70 insertions(+), 46 deletions(-) diff --git a/astrbot/core/agent/message.py b/astrbot/core/agent/message.py index 4a2e1b149..4c65c32f6 100644 --- a/astrbot/core/agent/message.py +++ b/astrbot/core/agent/message.py @@ -119,6 +119,13 @@ class ToolCall(BaseModel): """The ID of the tool call.""" function: FunctionBody """The function body of the tool call.""" + extra_content: dict[str, Any] | None = None + """Extra metadata for the tool call.""" + + def model_dump(self, **kwargs: Any) -> dict[str, Any]: + if self.extra_content is None: + kwargs.setdefault("exclude", set()).add("extra_content") + return super().model_dump(**kwargs) class ToolCallPart(BaseModel): diff --git a/astrbot/core/db/po.py b/astrbot/core/db/po.py index 9fc871d08..8fb14c19f 100644 --- a/astrbot/core/db/po.py +++ b/astrbot/core/db/po.py @@ -3,13 +3,7 @@ from dataclasses import dataclass, field from datetime import datetime, timezone from typing import TypedDict -from sqlmodel import ( - JSON, - Field, - SQLModel, - Text, - UniqueConstraint, -) +from sqlmodel import JSON, Field, SQLModel, Text, UniqueConstraint class PlatformStat(SQLModel, table=True): @@ -18,7 +12,7 @@ class PlatformStat(SQLModel, table=True): Note: In astrbot v4, we moved `platform` table to here. """ - __tablename__ = "platform_stats" + __tablename__ = "platform_stats" # type: ignore id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True}) timestamp: datetime = Field(nullable=False) @@ -37,7 +31,7 @@ class PlatformStat(SQLModel, table=True): class ConversationV2(SQLModel, table=True): - __tablename__ = "conversations" + __tablename__ = "conversations" # type: ignore inner_conversation_id: int = Field( primary_key=True, @@ -74,7 +68,7 @@ class Persona(SQLModel, table=True): It can be used to customize the behavior of LLMs. """ - __tablename__ = "personas" + __tablename__ = "personas" # type: ignore id: int | None = Field( primary_key=True, @@ -104,7 +98,7 @@ class Persona(SQLModel, table=True): class Preference(SQLModel, table=True): """This class represents preferences for bots.""" - __tablename__ = "preferences" + __tablename__ = "preferences" # type: ignore id: int | None = Field( default=None, @@ -140,7 +134,7 @@ class PlatformMessageHistory(SQLModel, table=True): or platform-specific messages. """ - __tablename__ = "platform_message_history" + __tablename__ = "platform_message_history" # type: ignore id: int | None = Field( primary_key=True, @@ -209,7 +203,7 @@ class Attachment(SQLModel, table=True): Attachments can be images, files, or other media types. """ - __tablename__ = "attachments" + __tablename__ = "attachments" # type: ignore inner_attachment_id: int | None = Field( primary_key=True, diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index c6978e7b9..dc188f141 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -211,6 +211,8 @@ class LLMResponse: """Tool call names.""" tools_call_ids: list[str] = field(default_factory=list) """Tool call IDs.""" + tools_call_extra_content: dict[str, dict[str, Any]] = field(default_factory=dict) + """Tool call extra content. tool_call_id -> extra_content dict""" reasoning_content: str = "" """The reasoning content extracted from the LLM, if any.""" @@ -233,6 +235,7 @@ class LLMResponse: tools_call_args: list[dict[str, Any]] | None = None, tools_call_name: list[str] | None = None, tools_call_ids: list[str] | None = None, + tools_call_extra_content: dict[str, dict[str, Any]] | None = None, raw_completion: ChatCompletion | GenerateContentResponse | AnthropicMessage @@ -256,6 +259,8 @@ class LLMResponse: tools_call_name = [] if tools_call_ids is None: tools_call_ids = [] + if tools_call_extra_content is None: + tools_call_extra_content = {} self.role = role self.completion_text = completion_text @@ -263,6 +268,7 @@ class LLMResponse: self.tools_call_args = tools_call_args self.tools_call_name = tools_call_name self.tools_call_ids = tools_call_ids + self.tools_call_extra_content = tools_call_extra_content self.raw_completion = raw_completion self.is_chunk = is_chunk @@ -288,16 +294,19 @@ class LLMResponse: """Convert to OpenAI tool calls format. Deprecated, use to_openai_to_calls_model instead.""" ret = [] for idx, tool_call_arg in enumerate(self.tools_call_args): - ret.append( - { - "id": self.tools_call_ids[idx], - "function": { - "name": self.tools_call_name[idx], - "arguments": json.dumps(tool_call_arg), - }, - "type": "function", + payload = { + "id": self.tools_call_ids[idx], + "function": { + "name": self.tools_call_name[idx], + "arguments": json.dumps(tool_call_arg), }, - ) + "type": "function", + } + if self.tools_call_extra_content.get(self.tools_call_ids[idx]): + payload["extra_content"] = self.tools_call_extra_content[ + self.tools_call_ids[idx] + ] + ret.append(payload) return ret def to_openai_to_calls_model(self) -> list[ToolCall]: @@ -311,6 +320,10 @@ class LLMResponse: name=self.tools_call_name[idx], arguments=json.dumps(tool_call_arg), ), + # the extra_content will not serialize if it's None when calling ToolCall.model_dump() + extra_content=self.tools_call_extra_content.get( + self.tools_call_ids[idx] + ), ), ) return ret diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index b9159eec9..e14140d43 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -290,13 +290,24 @@ class ProviderGoogleGenAI(Provider): parts = [types.Part.from_text(text=content)] append_or_extend(gemini_contents, parts, types.ModelContent) elif not native_tool_enabled and "tool_calls" in message: - parts = [ - types.Part.from_function_call( + parts = [] + for tool in message["tool_calls"]: + part = types.Part.from_function_call( name=tool["function"]["name"], args=json.loads(tool["function"]["arguments"]), ) - for tool in message["tool_calls"] - ] + # we should set thought_signature back to part if exists + # for more info about thought_signature, see: + # https://ai.google.dev/gemini-api/docs/thought-signatures + if "extra_content" in tool: + ts_bs64 = ( + tool["extra_content"] + .get("google", {}) + .get("thought_signature") + ) + if ts_bs64: + part.thought_signature = base64.b64decode(ts_bs64) + parts.append(part) append_or_extend(gemini_contents, parts, types.ModelContent) else: logger.warning("assistant 角色的消息内容为空,已添加空格占位") @@ -393,10 +404,15 @@ class ProviderGoogleGenAI(Provider): llm_response.role = "tool" llm_response.tools_call_name.append(part.function_call.name) llm_response.tools_call_args.append(part.function_call.args) - # gemini 返回的 function_call.id 可能为 None - llm_response.tools_call_ids.append( - part.function_call.id or part.function_call.name, - ) + # function_call.id might be None, use name as fallback + tool_call_id = part.function_call.id or part.function_call.name + llm_response.tools_call_ids.append(tool_call_id) + # extra_content + if part.thought_signature: + ts_bs64 = base64.b64encode(part.thought_signature).decode("utf-8") + llm_response.tools_call_extra_content[tool_call_id] = { + "google": {"thought_signature": ts_bs64} + } elif ( part.inline_data and part.inline_data.mime_type @@ -435,6 +451,7 @@ class ProviderGoogleGenAI(Provider): contents=conversation, config=config, ) + logger.debug(f"genai result: {result}") if not result.candidates: logger.error(f"请求失败, 返回的 candidates 为空: {result}") diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index da2ce68f8..3f1d283ce 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -8,7 +8,7 @@ import re from collections.abc import AsyncGenerator from openai import AsyncAzureOpenAI, AsyncOpenAI -from openai._exceptions import NotFoundError, UnprocessableEntityError +from openai._exceptions import NotFoundError from openai.lib.streaming.chat._completions import ChatCompletionStreamState from openai.types.chat.chat_completion import ChatCompletion from openai.types.chat.chat_completion_chunk import ChatCompletionChunk @@ -279,6 +279,7 @@ class ProviderOpenAIOfficial(Provider): args_ls = [] func_name_ls = [] tool_call_ids = [] + tool_call_extra_content_dict = {} for tool_call in choice.message.tool_calls: if isinstance(tool_call, str): # workaround for #1359 @@ -296,11 +297,16 @@ class ProviderOpenAIOfficial(Provider): args_ls.append(args) func_name_ls.append(tool_call.function.name) tool_call_ids.append(tool_call.id) + + # gemini-2.5 / gemini-3 series extra_content handling + extra_content = getattr(tool_call, "extra_content", None) + if extra_content is not None: + tool_call_extra_content_dict[tool_call.id] = extra_content llm_response.role = "tool" llm_response.tools_call_args = args_ls llm_response.tools_call_name = func_name_ls llm_response.tools_call_ids = tool_call_ids - + llm_response.tools_call_extra_content = tool_call_extra_content_dict # specially handle finish reason if choice.finish_reason == "content_filter": raise Exception( @@ -353,7 +359,7 @@ class ProviderOpenAIOfficial(Provider): payloads = {"messages": context_query, **model_config} - # xAI 原生搜索参数(最小侵入地在此处注入) + # xAI origin search tool inject self._maybe_inject_xai_search(payloads, **kwargs) return payloads, context_query @@ -475,12 +481,6 @@ class ProviderOpenAIOfficial(Provider): self.client.api_key = chosen_key llm_response = await self._query(payloads, func_tool) break - except UnprocessableEntityError as e: - logger.warning(f"不可处理的实体错误:{e},尝试删除图片。") - # 尝试删除所有 image - new_contexts = await self._remove_image_from_context(context_query) - payloads["messages"] = new_contexts - context_query = new_contexts except Exception as e: last_exception = e ( @@ -545,12 +545,6 @@ class ProviderOpenAIOfficial(Provider): async for response in self._query_stream(payloads, func_tool): yield response break - except UnprocessableEntityError as e: - logger.warning(f"不可处理的实体错误:{e},尝试删除图片。") - # 尝试删除所有 image - new_contexts = await self._remove_image_from_context(context_query) - payloads["messages"] = new_contexts - context_query = new_contexts except Exception as e: last_exception = e ( @@ -646,4 +640,3 @@ class ProviderOpenAIOfficial(Provider): with open(image_url, "rb") as f: image_bs64 = base64.b64encode(f.read()).decode("utf-8") return "data:image/jpeg;base64," + image_bs64 - return "" From 67a9663eff7ed41da6108c3ab84e6f46a240a225 Mon Sep 17 00:00:00 2001 From: Dt8333 <25431943+Dt8333@users.noreply.github.com> Date: Wed, 19 Nov 2025 21:36:34 +0800 Subject: [PATCH 13/27] fix(dashboard.i18n): complete the missing i18n keys(#3699) #3679 --- .../src/i18n/locales/en-US/core/common.json | 3 +- .../src/i18n/locales/en-US/core/header.json | 3 +- .../i18n/locales/en-US/core/navigation.json | 4 +- .../i18n/locales/en-US/features/config.json | 8 +++- .../en-US/features/knowledge-base/detail.json | 24 ++++++++++++ .../features/knowledge-base/document.json | 5 ++- .../i18n/locales/en-US/features/tool-use.json | 38 ++++++++++++++++++- 7 files changed, 76 insertions(+), 9 deletions(-) diff --git a/dashboard/src/i18n/locales/en-US/core/common.json b/dashboard/src/i18n/locales/en-US/core/common.json index 4aff41001..37b384199 100644 --- a/dashboard/src/i18n/locales/en-US/core/common.json +++ b/dashboard/src/i18n/locales/en-US/core/common.json @@ -72,7 +72,8 @@ "enabled": "Enabled", "disabled": "Disabled", "delete": "Delete", + "copy": "Copy", "edit": "Edit", "noData": "No data available" } -} \ No newline at end of file +} diff --git a/dashboard/src/i18n/locales/en-US/core/header.json b/dashboard/src/i18n/locales/en-US/core/header.json index 41c5ac0dc..718cc60e6 100644 --- a/dashboard/src/i18n/locales/en-US/core/header.json +++ b/dashboard/src/i18n/locales/en-US/core/header.json @@ -32,7 +32,6 @@ "issueLink": "GitHub Issues" }, "tip": "💡 TIP:", - "tipLink": "", "tipContinue": "By default, the corresponding version of the WebUI files will be downloaded when switching versions. The WebUI code is located in the dashboard directory of the project, and you can use npm to build it yourself.", "dockerTip": "When switching versions, it will try to update both the bot main program and the dashboard. If you are using Docker deployment, you can also re-pull the image or use", "dockerTipLink": "watchtower", @@ -91,4 +90,4 @@ "updateFailed": "Update failed, please try again" } } -} \ No newline at end of file +} diff --git a/dashboard/src/i18n/locales/en-US/core/navigation.json b/dashboard/src/i18n/locales/en-US/core/navigation.json index 809b10183..9351d1da4 100644 --- a/dashboard/src/i18n/locales/en-US/core/navigation.json +++ b/dashboard/src/i18n/locales/en-US/core/navigation.json @@ -5,8 +5,8 @@ "persona": "Persona", "toolUse": "MCP Tools", "config": "Config", - "extension": "Extensions", "chat": "Chat", + "extension": "Extensions", "conversation": "Conversations", "sessionManagement": "Session Management", "console": "Console", @@ -20,4 +20,4 @@ "groups": { "more": "More Features" } -} \ No newline at end of file +} diff --git a/dashboard/src/i18n/locales/en-US/features/config.json b/dashboard/src/i18n/locales/en-US/features/config.json index c6fee467f..eebab4a73 100644 --- a/dashboard/src/i18n/locales/en-US/features/config.json +++ b/dashboard/src/i18n/locales/en-US/features/config.json @@ -30,7 +30,11 @@ "configApplyError": "Configuration not applied, JSON format error.", "saveSuccess": "Configuration saved successfully", "saveError": "Failed to save configuration", - "loadError": "Failed to load configuration" + "loadError": "Failed to load configuration", + "deleteSuccess": "Deleted successfully", + "deleteError": "Failed to delete", + "updateSuccess": "Updated successfully", + "updateError": "Failed to update" }, "sections": { "general": "General Settings", @@ -59,4 +63,4 @@ "rateLimit": "Rate Limit", "encryption": "Encryption Settings" } -} \ No newline at end of file +} diff --git a/dashboard/src/i18n/locales/en-US/features/knowledge-base/detail.json b/dashboard/src/i18n/locales/en-US/features/knowledge-base/detail.json index 1ae3c09da..90d3e6158 100644 --- a/dashboard/src/i18n/locales/en-US/features/knowledge-base/detail.json +++ b/dashboard/src/i18n/locales/en-US/features/knowledge-base/detail.json @@ -74,6 +74,30 @@ "urlHint": "The main content will be automatically extracted from the target URL as a document. Currently supports {supported} pages. Before use, please ensure that the target web page allows crawler access.", "beta": "Beta" }, + "retrieval": { + "title": "Retrieval", + "subtitle": "Test the knowledge base using dense and sparse retrieval methods", + "query": "Query", + "queryPlaceholder": "Enter a query...", + "search": "Search", + "searching": "Searching...", + "results": "Results", + "noResults": "No results found", + "tryDifferentQuery": "Try a different query", + "settings": "Retrieval Settings", + "topK": "Number of Results", + "topKHint": "Maximum number of results to return", + "enableRerank": "Enable Rerank", + "enableRerankHint": "Use a rerank model to improve retrieval quality", + "score": "Relevance Score", + "document": "Document", + "chunk": "Chunk #{index}", + "content": "Content", + "charCount": "{count} characters", + "searchSuccess": "Search completed, found {count} results", + "searchFailed": "Search failed", + "queryRequired": "Please enter a query" + }, "settings": { "title": "Knowledge Base Settings", "basic": "Basic Settings", diff --git a/dashboard/src/i18n/locales/en-US/features/knowledge-base/document.json b/dashboard/src/i18n/locales/en-US/features/knowledge-base/document.json index 35c430aad..d3a3b65c9 100644 --- a/dashboard/src/i18n/locales/en-US/features/knowledge-base/document.json +++ b/dashboard/src/i18n/locales/en-US/features/knowledge-base/document.json @@ -22,7 +22,10 @@ "preview": "Preview", "search": "Search Chunks", "searchPlaceholder": "Enter keywords to search chunks...", - "showing": "Showing" + "showing": "Showing", + "deleteConfirm": "Are you sure you want to delete this chunk?", + "deleteSuccess": "Chunk deleted successfully", + "deleteFailed": "Failed to delete chunk" }, "edit": { "title": "Edit Chunk", diff --git a/dashboard/src/i18n/locales/en-US/features/tool-use.json b/dashboard/src/i18n/locales/en-US/features/tool-use.json index 2887d78fa..8a6ccd492 100644 --- a/dashboard/src/i18n/locales/en-US/features/tool-use.json +++ b/dashboard/src/i18n/locales/en-US/features/tool-use.json @@ -96,6 +96,42 @@ }, "confirmDelete": "Are you sure you want to delete server {name}?" }, + "syncProvider": { + "title": "Sync MCP Servers", + "subtitle": "Sync MCP server configurations from providers to local", + "steps": { + "selectProvider": "Step 1: Select Provider", + "configureAuth": "Step 2: Configure Authentication", + "syncServers": "Step 3: Sync Servers" + }, + "providers": { + "modelscope": "ModelScope", + "description": "ModelScope is an open model community providing MCP servers for various machine learning and AI services" + }, + "fields": { + "provider": "Select Provider", + "accessToken": "Access Token", + "tokenRequired": "Access token is required", + "tokenHint": "Please enter your ModelScope access token" + }, + "buttons": { + "cancel": "Cancel", + "previous": "Previous", + "next": "Next", + "sync": "Start Sync", + "getToken": "Get Token" + }, + "status": { + "selectProvider": "Please select an MCP server provider", + "enterToken": "Please enter the access token to continue", + "readyToSync": "Ready to sync server configurations" + }, + "messages": { + "syncSuccess": "MCP servers synced successfully!", + "syncError": "Sync failed: {error}", + "tokenHelp": "How to get a ModelScope access token? Click the button on the right for instructions" + } + }, "messages": { "getServersError": "Failed to get MCP server list: {error}", "getToolsError": "Failed to get function tools list: {error}", @@ -117,4 +153,4 @@ "toggleToolError": "Failed to toggle tool status: {error}", "testError": "Test connection failed: {error}" } -} \ No newline at end of file +} From d5280dcd88af1cd6eaf0b7d72184bfa82e49402e Mon Sep 17 00:00:00 2001 From: Dt8333 <25431943+Dt8333@users.noreply.github.com> Date: Wed, 19 Nov 2025 21:44:38 +0800 Subject: [PATCH 14/27] =?UTF-8?q?fix(core.platform):=20=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E5=90=AF=E7=94=A8=E5=A4=9A=E4=B8=AA=E4=BC=81=E4=B8=9A=E5=BE=AE?= =?UTF-8?q?=E4=BF=A1=E6=99=BA=E8=83=BD=E6=9C=BA=E5=99=A8=E4=BA=BA=E9=80=82?= =?UTF-8?q?=E9=85=8D=E5=99=A8=E6=97=B6=E6=B6=88=E6=81=AF=E6=B7=B7=E4=B9=B1?= =?UTF-8?q?=E7=9A=84=E9=97=AE=E9=A2=98=20(#3693)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(core.platform): 修复启用多个企业微信智能机器人适配器时消息混乱的问题 移除了全局的消息队列,改为每个适配器处理自己的队列。修改相关方法适应该更改。 #3673 * chore: apply suggestions from code review Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> --------- Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com> Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> --- .../sources/wecom_ai_bot/wecomai_adapter.py | 23 +++++++++++-------- .../sources/wecom_ai_bot/wecomai_event.py | 12 ++++++---- .../sources/wecom_ai_bot/wecomai_queue_mgr.py | 4 ---- 3 files changed, 21 insertions(+), 18 deletions(-) diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py index 29ac02653..9c13cfeff 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py @@ -30,7 +30,7 @@ from .wecomai_api import ( WecomAIBotStreamMessageBuilder, ) from .wecomai_event import WecomAIBotMessageEvent -from .wecomai_queue_mgr import WecomAIQueueMgr, wecomai_queue_mgr +from .wecomai_queue_mgr import WecomAIQueueMgr from .wecomai_server import WecomAIBotServer from .wecomai_utils import ( WecomAIBotConstants, @@ -144,9 +144,12 @@ class WecomAIBotAdapter(Platform): # 事件循环和关闭信号 self.shutdown_event = asyncio.Event() + # 队列管理器 + self.queue_mgr = WecomAIQueueMgr() + # 队列监听器 self.queue_listener = WecomAIQueueListener( - wecomai_queue_mgr, + self.queue_mgr, self._handle_queued_message, ) @@ -189,7 +192,7 @@ class WecomAIBotAdapter(Platform): stream_id, session_id, ) - wecomai_queue_mgr.set_pending_response(stream_id, callback_params) + self.queue_mgr.set_pending_response(stream_id, callback_params) resp = WecomAIBotStreamMessageBuilder.make_text_stream( stream_id, @@ -207,7 +210,7 @@ class WecomAIBotAdapter(Platform): elif msgtype == "stream": # wechat server is requesting for updates of a stream stream_id = message_data["stream"]["id"] - if not wecomai_queue_mgr.has_back_queue(stream_id): + if not self.queue_mgr.has_back_queue(stream_id): logger.error(f"Cannot find back queue for stream_id: {stream_id}") # 返回结束标志,告诉微信服务器流已结束 @@ -222,7 +225,7 @@ class WecomAIBotAdapter(Platform): callback_params["timestamp"], ) return resp - queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id) + queue = self.queue_mgr.get_or_create_back_queue(stream_id) if queue.empty(): logger.debug( f"No new messages in back queue for stream_id: {stream_id}", @@ -242,10 +245,9 @@ class WecomAIBotAdapter(Platform): elif msg["type"] == "end": # stream end finish = True - wecomai_queue_mgr.remove_queues(stream_id) + self.queue_mgr.remove_queues(stream_id) break - else: - pass + logger.debug( f"Aggregated content: {latest_plain_content}, image: {len(image_base64)}, finish: {finish}", ) @@ -313,8 +315,8 @@ class WecomAIBotAdapter(Platform): session_id: str, ): """将消息放入队列进行异步处理""" - input_queue = wecomai_queue_mgr.get_or_create_queue(stream_id) - _ = wecomai_queue_mgr.get_or_create_back_queue(stream_id) + input_queue = self.queue_mgr.get_or_create_queue(stream_id) + _ = self.queue_mgr.get_or_create_back_queue(stream_id) message_payload = { "message_data": message_data, "callback_params": callback_params, @@ -453,6 +455,7 @@ class WecomAIBotAdapter(Platform): platform_meta=self.meta(), session_id=message.session_id, api_client=self.api_client, + queue_mgr=self.queue_mgr, ) self.commit_event(message_event) diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py index 130182b48..0091783a4 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py @@ -8,7 +8,7 @@ from astrbot.api.message_components import ( ) from .wecomai_api import WecomAIBotAPIClient -from .wecomai_queue_mgr import wecomai_queue_mgr +from .wecomai_queue_mgr import WecomAIQueueMgr class WecomAIBotMessageEvent(AstrMessageEvent): @@ -21,6 +21,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent): platform_meta, session_id: str, api_client: WecomAIBotAPIClient, + queue_mgr: WecomAIQueueMgr, ): """初始化消息事件 @@ -34,14 +35,16 @@ class WecomAIBotMessageEvent(AstrMessageEvent): """ super().__init__(message_str, message_obj, platform_meta, session_id) self.api_client = api_client + self.queue_mgr = queue_mgr @staticmethod async def _send( message_chain: MessageChain, stream_id: str, + queue_mgr: WecomAIQueueMgr, streaming: bool = False, ): - back_queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id) + back_queue = queue_mgr.get_or_create_back_queue(stream_id) if not message_chain: await back_queue.put( @@ -94,7 +97,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent): "wecom_ai_bot platform event raw_message should be a dict" ) stream_id = raw.get("stream_id", self.session_id) - await WecomAIBotMessageEvent._send(message, stream_id) + await WecomAIBotMessageEvent._send(message, stream_id, self.queue_mgr) await super().send(message) async def send_streaming(self, generator, use_fallback=False): @@ -105,7 +108,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent): "wecom_ai_bot platform event raw_message should be a dict" ) stream_id = raw.get("stream_id", self.session_id) - back_queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id) + back_queue = self.queue_mgr.get_or_create_back_queue(stream_id) # 企业微信智能机器人不支持增量发送,因此我们需要在这里将增量内容累积起来,积累发送 increment_plain = "" @@ -134,6 +137,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent): final_data += await WecomAIBotMessageEvent._send( chain, stream_id=stream_id, + queue_mgr=self.queue_mgr, streaming=True, ) diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py index eb3455292..3a982bdf7 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py @@ -151,7 +151,3 @@ class WecomAIQueueMgr: "output_queues": len(self.back_queues), "pending_responses": len(self.pending_responses), } - - -# 全局队列管理器实例 -wecomai_queue_mgr = WecomAIQueueMgr() From e9805ba20541daa31e9fc6d80f678a32cc0aaca7 Mon Sep 17 00:00:00 2001 From: Soulter <37870767+Soulter@users.noreply.github.com> Date: Thu, 20 Nov 2025 16:01:22 +0800 Subject: [PATCH 15/27] fix: anyio.ClosedResourceError when calling mcp tools (#3700) * fix: anyio.ClosedResourceError when calling mcp tools added reconnect mechanism fixes: 3676 * fix(mcp_client): implement thread-safe reconnection using asyncio.Lock --- astrbot/core/agent/mcp_client.py | 180 +++++++++++++++++---- astrbot/core/provider/func_tool_manager.py | 21 +-- pyproject.toml | 3 +- requirements.txt | 1 + 4 files changed, 168 insertions(+), 37 deletions(-) diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index 05980b212..88cab486e 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -4,6 +4,14 @@ from contextlib import AsyncExitStack from datetime import timedelta from typing import Generic +from tenacity import ( + before_sleep_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + from astrbot import logger from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.utils.log_pipe import LogPipe @@ -12,21 +20,24 @@ from .run_context import TContext from .tool import FunctionTool try: + import anyio import mcp from mcp.client.sse import sse_client except (ModuleNotFoundError, ImportError): - logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。") + logger.warning( + "Warning: Missing 'mcp' dependency, MCP services will be unavailable." + ) try: from mcp.client.streamable_http import streamablehttp_client except (ModuleNotFoundError, ImportError): logger.warning( - "警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。", + "Warning: Missing 'mcp' dependency or MCP library version too old, Streamable HTTP connection unavailable.", ) def _prepare_config(config: dict) -> dict: - """准备配置,处理嵌套格式""" + """Prepare configuration, handle nested format""" if config.get("mcpServers"): first_key = next(iter(config["mcpServers"])) config = config["mcpServers"][first_key] @@ -35,7 +46,7 @@ def _prepare_config(config: dict) -> dict: async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: - """快速测试 MCP 服务器可达性""" + """Quick test MCP server connectivity""" import aiohttp cfg = _prepare_config(config.copy()) @@ -50,7 +61,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: elif "type" in cfg: transport_type = cfg["type"] else: - raise Exception("MCP 连接配置缺少 transport 或 type 字段") + raise Exception("MCP connection config missing transport or type field") async with aiohttp.ClientSession() as session: if transport_type == "streamable_http": @@ -91,7 +102,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: return False, f"HTTP {response.status}: {response.reason}" except asyncio.TimeoutError: - return False, f"连接超时: {timeout}秒" + return False, f"Connection timeout: {timeout} seconds" except Exception as e: return False, f"{e!s}" @@ -101,6 +112,7 @@ class MCPClient: # Initialize session and client objects self.session: mcp.ClientSession | None = None self.exit_stack = AsyncExitStack() + self._old_exit_stacks: list[AsyncExitStack] = [] # Track old stacks for cleanup self.name: str | None = None self.active: bool = True @@ -108,22 +120,32 @@ class MCPClient: self.server_errlogs: list[str] = [] self.running_event = asyncio.Event() - async def connect_to_server(self, mcp_server_config: dict, name: str): - """连接到 MCP 服务器 + # Store connection config for reconnection + self._mcp_server_config: dict | None = None + self._server_name: str | None = None + self._reconnect_lock = asyncio.Lock() # Lock for thread-safe reconnection + self._reconnecting: bool = False # For logging and debugging - 如果 `url` 参数存在: - 1. 当 transport 指定为 `streamable_http` 时,使用 Streamable HTTP 连接方式。 - 1. 当 transport 指定为 `sse` 时,使用 SSE 连接方式。 - 2. 如果没有指定,默认使用 SSE 的方式连接到 MCP 服务。 + async def connect_to_server(self, mcp_server_config: dict, name: str): + """Connect to MCP server + + If `url` parameter exists: + 1. When transport is specified as `streamable_http`, use Streamable HTTP connection. + 2. When transport is specified as `sse`, use SSE connection. + 3. If not specified, default to SSE connection to MCP service. Args: mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server """ + # Store config for reconnection + self._mcp_server_config = mcp_server_config + self._server_name = name + cfg = _prepare_config(mcp_server_config.copy()) def logging_callback(msg: str): - # 处理 MCP 服务的错误日志 + # Handle MCP service error logs print(f"MCP Server {name} Error: {msg}") self.server_errlogs.append(msg) @@ -137,7 +159,7 @@ class MCPClient: elif "type" in cfg: transport_type = cfg["type"] else: - raise Exception("MCP 连接配置缺少 transport 或 type 字段") + raise Exception("MCP connection config missing transport or type field") if transport_type != "streamable_http": # SSE transport method @@ -193,7 +215,7 @@ class MCPClient: ) def callback(msg: str): - # 处理 MCP 服务的错误日志 + # Handle MCP service error logs self.server_errlogs.append(msg) stdio_transport = await self.exit_stack.enter_async_context( @@ -222,10 +244,120 @@ class MCPClient: self.tools = response.tools return response + async def _reconnect(self) -> None: + """Reconnect to the MCP server using the stored configuration. + + Uses asyncio.Lock to ensure thread-safe reconnection in concurrent environments. + + Raises: + Exception: raised when reconnection fails + """ + async with self._reconnect_lock: + # Check if already reconnecting (useful for logging) + if self._reconnecting: + logger.debug( + f"MCP Client {self._server_name} is already reconnecting, skipping" + ) + return + + if not self._mcp_server_config or not self._server_name: + raise Exception("Cannot reconnect: missing connection configuration") + + self._reconnecting = True + try: + logger.info( + f"Attempting to reconnect to MCP server {self._server_name}..." + ) + + # Save old exit_stack for later cleanup (don't close it now to avoid cancel scope issues) + if self.exit_stack: + self._old_exit_stacks.append(self.exit_stack) + + # Mark old session as invalid + self.session = None + + # Create new exit stack for new connection + self.exit_stack = AsyncExitStack() + + # Reconnect using stored config + await self.connect_to_server(self._mcp_server_config, self._server_name) + await self.list_tools_and_save() + + logger.info( + f"Successfully reconnected to MCP server {self._server_name}" + ) + except Exception as e: + logger.error( + f"Failed to reconnect to MCP server {self._server_name}: {e}" + ) + raise + finally: + self._reconnecting = False + + async def call_tool_with_reconnect( + self, + tool_name: str, + arguments: dict, + read_timeout_seconds: timedelta, + ) -> mcp.types.CallToolResult: + """Call MCP tool with automatic reconnection on failure, max 2 retries. + + Args: + tool_name: tool name + arguments: tool arguments + read_timeout_seconds: read timeout + + Returns: + MCP tool call result + + Raises: + ValueError: MCP session is not available + anyio.ClosedResourceError: raised after reconnection failure + """ + + @retry( + retry=retry_if_exception_type(anyio.ClosedResourceError), + stop=stop_after_attempt(2), + wait=wait_exponential(multiplier=1, min=1, max=3), + before_sleep=before_sleep_log(logger, logging.WARNING), + reraise=True, + ) + async def _call_with_retry(): + if not self.session: + raise ValueError("MCP session is not available for MCP function tools.") + + try: + return await self.session.call_tool( + name=tool_name, + arguments=arguments, + read_timeout_seconds=read_timeout_seconds, + ) + except anyio.ClosedResourceError: + logger.warning( + f"MCP tool {tool_name} call failed (ClosedResourceError), attempting to reconnect..." + ) + # Attempt to reconnect + await self._reconnect() + # Reraise the exception to trigger tenacity retry + raise + + return await _call_with_retry() + async def cleanup(self): - """Clean up resources""" - await self.exit_stack.aclose() - self.running_event.set() # Set the running event to indicate cleanup is done + """Clean up resources including old exit stacks from reconnections""" + # Set running_event first to unblock any waiting tasks + self.running_event.set() + + # Close current exit stack + try: + await self.exit_stack.aclose() + except Exception as e: + logger.debug(f"Error closing current exit stack: {e}") + + # Don't close old exit stacks as they may be in different task contexts + # They will be garbage collected naturally + # Just clear the list to release references + self._old_exit_stacks.clear() class MCPTool(FunctionTool, Generic[TContext]): @@ -246,14 +378,8 @@ class MCPTool(FunctionTool, Generic[TContext]): async def call( self, context: ContextWrapper[TContext], **kwargs ) -> mcp.types.CallToolResult: - session = self.mcp_client.session - if not session: - raise ValueError("MCP session is not available for MCP function tools.") - res = await session.call_tool( - name=self.mcp_tool.name, + return await self.mcp_client.call_tool_with_reconnect( + tool_name=self.mcp_tool.name, arguments=kwargs, - read_timeout_seconds=timedelta( - seconds=context.tool_call_timeout, - ), + read_timeout_seconds=timedelta(seconds=context.tool_call_timeout), ) - return res diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 7cdbeec01..8e04423ed 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -280,19 +280,22 @@ class FunctionToolManager: async def _terminate_mcp_client(self, name: str) -> None: """关闭并清理MCP客户端""" if name in self.mcp_client_dict: + client = self.mcp_client_dict[name] try: # 关闭MCP连接 - await self.mcp_client_dict[name].cleanup() - self.mcp_client_dict.pop(name) + await client.cleanup() except Exception as e: logger.error(f"清空 MCP 客户端资源 {name}: {e}。") - # 移除关联的FuncTool - self.func_list = [ - f - for f in self.func_list - if not (isinstance(f, MCPTool) and f.mcp_server_name == name) - ] - logger.info(f"已关闭 MCP 服务 {name}") + finally: + # Remove client from dict after cleanup attempt (successful or not) + self.mcp_client_dict.pop(name, None) + # 移除关联的FuncTool + self.func_list = [ + f + for f in self.func_list + if not (isinstance(f, MCPTool) and f.mcp_server_name == name) + ] + logger.info(f"已关闭 MCP 服务 {name}") @staticmethod async def test_mcp_server_connection(config: dict) -> list[str]: diff --git a/pyproject.toml b/pyproject.toml index 576bc1966..707581846 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ dependencies = [ "jieba>=0.42.1", "markitdown-no-magika[docx,xls,xlsx]>=0.1.2", "xinference-client", + "tenacity>=9.1.2", ] [dependency-groups] @@ -107,4 +108,4 @@ exclude = ["dashboard", "node_modules", "dist", "data", "tests"] [build-system] requires = ["hatchling"] -build-backend = "hatchling.build" \ No newline at end of file +build-backend = "hatchling.build" diff --git a/requirements.txt b/requirements.txt index e8b3dee3c..b56741192 100644 --- a/requirements.txt +++ b/requirements.txt @@ -52,3 +52,4 @@ rank-bm25>=0.2.2 jieba>=0.42.1 markitdown-no-magika[docx,xls,xlsx]>=0.1.2 xinference-client +tenacity>=9.1.2 \ No newline at end of file From 6dc3d161e735380989e92c66818468f58ba0b371 Mon Sep 17 00:00:00 2001 From: Soulter <37870767+Soulter@users.noreply.github.com> Date: Thu, 20 Nov 2025 16:07:09 +0800 Subject: [PATCH 16/27] feat(chat): refactor chat component structure and add new features (#3701) - Introduced `ConversationSidebar.vue` for improved conversation management and sidebar functionality. - Enhanced `MessageList.vue` to handle loading states and improved message rendering. - Created new composables: `useConversations`, `useMessages`, `useMediaHandling`, `useRecording` for better code organization and reusability. - Added loading indicators and improved user experience during message processing. - Ensured backward compatibility and maintained existing functionalities. --- dashboard/src/components/chat/Chat.vue | 1796 ++++------------- dashboard/src/components/chat/ChatInput.vue | 283 +++ .../components/chat/ConversationSidebar.vue | 310 +++ dashboard/src/components/chat/MessageList.vue | 92 +- dashboard/src/composables/useMediaHandling.ts | 104 + dashboard/src/composables/useMessages.ts | 303 +++ dashboard/src/composables/useRecording.ts | 74 + dashboard/src/composables/useSessions.ts | 145 ++ 8 files changed, 1628 insertions(+), 1479 deletions(-) create mode 100644 dashboard/src/components/chat/ChatInput.vue create mode 100644 dashboard/src/components/chat/ConversationSidebar.vue create mode 100644 dashboard/src/composables/useMediaHandling.ts create mode 100644 dashboard/src/composables/useMessages.ts create mode 100644 dashboard/src/composables/useRecording.ts create mode 100644 dashboard/src/composables/useSessions.ts diff --git a/dashboard/src/components/chat/Chat.vue b/dashboard/src/components/chat/Chat.vue index 81b3c327f..aa1192d09 100644 --- a/dashboard/src/components/chat/Chat.vue +++ b/dashboard/src/components/chat/Chat.vue @@ -5,90 +5,20 @@
- +
@@ -150,74 +80,43 @@
-
-
- -
-
- - - - - - -
-
- - - - - -
-
-
- - -
-
- - -
- -
- - - {{ tm('voice.recording') }} - - -
-
-
+
+ + + + {{ tm('actions.editTitle') }} + + + + + + {{ t('core.common.cancel') }} + {{ t('core.common.save') }} + + + @@ -231,1048 +130,273 @@ - - - - - - {{ tm('conversation.editDisplayName') }} - - - - - - - - - {{ t('core.common.cancel') }} - - - {{ t('core.common.save') }} - - - - - - \ No newline at end of file diff --git a/dashboard/src/components/chat/ChatInput.vue b/dashboard/src/components/chat/ChatInput.vue new file mode 100644 index 000000000..7ca0ec94a --- /dev/null +++ b/dashboard/src/components/chat/ChatInput.vue @@ -0,0 +1,283 @@ + + + + + diff --git a/dashboard/src/components/chat/ConversationSidebar.vue b/dashboard/src/components/chat/ConversationSidebar.vue new file mode 100644 index 000000000..2ea9a4819 --- /dev/null +++ b/dashboard/src/components/chat/ConversationSidebar.vue @@ -0,0 +1,310 @@ + + + + + diff --git a/dashboard/src/components/chat/MessageList.vue b/dashboard/src/components/chat/MessageList.vue index 9cd69241f..15f3c1d31 100644 --- a/dashboard/src/components/chat/MessageList.vue +++ b/dashboard/src/components/chat/MessageList.vue @@ -37,42 +37,49 @@
- -
-
- - {{ isReasoningExpanded(index) ? 'mdi-chevron-down' : 'mdi-chevron-right' }} - - {{ tm('reasoning.thinking') }} -
-
-
-
+ +
+ {{ tm('message.loading') }}
- -
- - -
-
- +
-
+
@@ -841,6 +848,29 @@ export default { margin: 10px 0; } +.loading-container { + display: flex; + align-items: center; + gap: 12px; + padding: 8px 0; + margin-top: 2px; +} + +.loading-text { + font-size: 14px; + color: var(--v-theme-secondaryText); + animation: pulse 1.5s ease-in-out infinite; +} + +@keyframes pulse { + 0%, 100% { + opacity: 0.6; + } + 50% { + opacity: 1; + } +} + .markdown-content blockquote { border-left: 4px solid var(--v-theme-secondary); padding-left: 16px; diff --git a/dashboard/src/composables/useMediaHandling.ts b/dashboard/src/composables/useMediaHandling.ts new file mode 100644 index 000000000..e24c25fb8 --- /dev/null +++ b/dashboard/src/composables/useMediaHandling.ts @@ -0,0 +1,104 @@ +import { ref } from 'vue'; +import axios from 'axios'; + +export function useMediaHandling() { + const stagedImagesName = ref([]); + const stagedImagesUrl = ref([]); + const stagedAudioUrl = ref(''); + const mediaCache = ref>({}); + + async function getMediaFile(filename: string): Promise { + if (mediaCache.value[filename]) { + return mediaCache.value[filename]; + } + + try { + const response = await axios.get('/api/chat/get_file', { + params: { filename }, + responseType: 'blob' + }); + + const blobUrl = URL.createObjectURL(response.data); + mediaCache.value[filename] = blobUrl; + return blobUrl; + } catch (error) { + console.error('Error fetching media file:', error); + return ''; + } + } + + async function processAndUploadImage(file: File) { + const formData = new FormData(); + formData.append('file', file); + + try { + const response = await axios.post('/api/chat/post_image', formData, { + headers: { + 'Content-Type': 'multipart/form-data' + } + }); + + const img = response.data.data.filename; + stagedImagesName.value.push(img); + stagedImagesUrl.value.push(URL.createObjectURL(file)); + } catch (err) { + console.error('Error uploading image:', err); + } + } + + async function handlePaste(event: ClipboardEvent) { + const items = event.clipboardData?.items; + if (!items) return; + + for (let i = 0; i < items.length; i++) { + if (items[i].type.indexOf('image') !== -1) { + const file = items[i].getAsFile(); + if (file) { + await processAndUploadImage(file); + } + } + } + } + + function removeImage(index: number) { + const urlToRevoke = stagedImagesUrl.value[index]; + if (urlToRevoke && urlToRevoke.startsWith('blob:')) { + URL.revokeObjectURL(urlToRevoke); + } + + stagedImagesName.value.splice(index, 1); + stagedImagesUrl.value.splice(index, 1); + } + + function removeAudio() { + stagedAudioUrl.value = ''; + } + + function clearStaged() { + stagedImagesName.value = []; + stagedImagesUrl.value = []; + stagedAudioUrl.value = ''; + } + + function cleanupMediaCache() { + Object.values(mediaCache.value).forEach(url => { + if (url.startsWith('blob:')) { + URL.revokeObjectURL(url); + } + }); + mediaCache.value = {}; + } + + return { + stagedImagesName, + stagedImagesUrl, + stagedAudioUrl, + getMediaFile, + processAndUploadImage, + handlePaste, + removeImage, + removeAudio, + clearStaged, + cleanupMediaCache + }; +} diff --git a/dashboard/src/composables/useMessages.ts b/dashboard/src/composables/useMessages.ts new file mode 100644 index 000000000..3779576cd --- /dev/null +++ b/dashboard/src/composables/useMessages.ts @@ -0,0 +1,303 @@ +import { ref, reactive, type Ref } from 'vue'; +import axios from 'axios'; +import { useToast } from '@/utils/toast'; + +export interface MessageContent { + type: string; + message: string; + reasoning?: string; + image_url?: string[]; + audio_url?: string; + embedded_images?: string[]; + embedded_audio?: string; + isLoading?: boolean; +} + +export interface Message { + content: MessageContent; +} + +export function useMessages( + currSessionId: Ref, + getMediaFile: (filename: string) => Promise, + updateSessionTitle: (sessionId: string, title: string) => void, + onSessionsUpdate: () => void +) { + const messages = ref([]); + const isStreaming = ref(false); + const isConvRunning = ref(false); + const isToastedRunningInfo = ref(false); + const activeSSECount = ref(0); + const enableStreaming = ref(true); + + // 从 localStorage 读取流式响应开关状态 + const savedStreamingState = localStorage.getItem('enableStreaming'); + if (savedStreamingState !== null) { + enableStreaming.value = JSON.parse(savedStreamingState); + } + + function toggleStreaming() { + enableStreaming.value = !enableStreaming.value; + localStorage.setItem('enableStreaming', JSON.stringify(enableStreaming.value)); + } + + async function getSessionMessages(sessionId: string, router: any) { + if (!sessionId) return; + + try { + const response = await axios.get('/api/chat/get_session?session_id=' + sessionId); + isConvRunning.value = response.data.data.is_running || false; + let history = response.data.data.history; + + if (isConvRunning.value) { + if (!isToastedRunningInfo.value) { + useToast().info("该会话正在运行中。", { timeout: 5000 }); + isToastedRunningInfo.value = true; + } + + // 如果会话还在运行,3秒后重新获取消息 + setTimeout(() => { + getSessionMessages(currSessionId.value, router); + }, 3000); + } + + // 处理历史消息中的媒体文件 + for (let i = 0; i < history.length; i++) { + let content = history[i].content; + + if (content.message?.startsWith('[IMAGE]')) { + let img = content.message.replace('[IMAGE]', ''); + const imageUrl = await getMediaFile(img); + if (!content.embedded_images) { + content.embedded_images = []; + } + content.embedded_images.push(imageUrl); + content.message = ''; + } + + if (content.message?.startsWith('[RECORD]')) { + let audio = content.message.replace('[RECORD]', ''); + const audioUrl = await getMediaFile(audio); + content.embedded_audio = audioUrl; + content.message = ''; + } + + if (content.image_url && content.image_url.length > 0) { + for (let j = 0; j < content.image_url.length; j++) { + content.image_url[j] = await getMediaFile(content.image_url[j]); + } + } + + if (content.audio_url) { + content.audio_url = await getMediaFile(content.audio_url); + } + } + + messages.value = history; + } catch (err) { + console.error(err); + } + } + + async function sendMessage( + prompt: string, + imageNames: string[], + audioName: string, + selectedProviderId: string, + selectedModelName: string + ) { + // Create user message + const userMessage: MessageContent = { + type: 'user', + message: prompt, + image_url: [], + audio_url: undefined + }; + + // Convert image filenames to blob URLs + if (imageNames.length > 0) { + const imagePromises = imageNames.map(name => { + if (!name.startsWith('blob:')) { + return getMediaFile(name); + } + return Promise.resolve(name); + }); + userMessage.image_url = await Promise.all(imagePromises); + } + + // Convert audio filename to blob URL + if (audioName) { + if (!audioName.startsWith('blob:')) { + userMessage.audio_url = await getMediaFile(audioName); + } else { + userMessage.audio_url = audioName; + } + } + + messages.value.push({ content: userMessage }); + + // 添加一个加载中的机器人消息占位符 + const loadingMessage = reactive({ + type: 'bot', + message: '', + reasoning: '', + isLoading: true + }); + messages.value.push({ content: loadingMessage }); + + try { + activeSSECount.value++; + if (activeSSECount.value === 1) { + isConvRunning.value = true; + } + + const response = await fetch('/api/chat/send', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Authorization': 'Bearer ' + localStorage.getItem('token') + }, + body: JSON.stringify({ + message: prompt, + session_id: currSessionId.value, + image_url: imageNames, + audio_url: audioName ? [audioName] : [], + selected_provider: selectedProviderId, + selected_model: selectedModelName, + enable_streaming: enableStreaming.value + }) + }); + + if (!response.ok) { + throw new Error(`HTTP error! status: ${response.status}`); + } + + const reader = response.body!.getReader(); + const decoder = new TextDecoder(); + let in_streaming = false; + let message_obj: any = null; + + isStreaming.value = true; + + while (true) { + try { + const { done, value } = await reader.read(); + if (done) { + console.log('SSE stream completed'); + break; + } + + const chunk = decoder.decode(value, { stream: true }); + const lines = chunk.split('\n\n'); + + for (let i = 0; i < lines.length; i++) { + let line = lines[i].trim(); + if (!line) continue; + + let chunk_json; + try { + chunk_json = JSON.parse(line.replace('data: ', '')); + } catch (parseError) { + console.warn('JSON解析失败:', line, parseError); + continue; + } + + if (!chunk_json || typeof chunk_json !== 'object' || !chunk_json.hasOwnProperty('type')) { + console.warn('无效的数据对象:', chunk_json); + continue; + } + + if (chunk_json.type === 'error') { + console.error('Error received:', chunk_json.data); + continue; + } + + if (chunk_json.type === 'image') { + let img = chunk_json.data.replace('[IMAGE]', ''); + const imageUrl = await getMediaFile(img); + let bot_resp: MessageContent = { + type: 'bot', + message: '', + embedded_images: [imageUrl] + }; + messages.value.push({ content: bot_resp }); + } else if (chunk_json.type === 'record') { + let audio = chunk_json.data.replace('[RECORD]', ''); + const audioUrl = await getMediaFile(audio); + let bot_resp: MessageContent = { + type: 'bot', + message: '', + embedded_audio: audioUrl + }; + messages.value.push({ content: bot_resp }); + } else if (chunk_json.type === 'plain') { + const chain_type = chunk_json.chain_type || 'normal'; + + if (!in_streaming) { + // 移除加载占位符 + const lastMsg = messages.value[messages.value.length - 1]; + if (lastMsg?.content?.isLoading) { + messages.value.pop(); + } + + message_obj = reactive({ + type: 'bot', + message: chain_type === 'reasoning' ? '' : chunk_json.data, + reasoning: chain_type === 'reasoning' ? chunk_json.data : '', + }); + messages.value.push({ content: message_obj }); + in_streaming = true; + } else { + if (chain_type === 'reasoning') { + // 使用 reactive 对象,直接修改属性会触发响应式更新 + message_obj.reasoning = (message_obj.reasoning || '') + chunk_json.data; + } else { + message_obj.message = (message_obj.message || '') + chunk_json.data; + } + } + } else if (chunk_json.type === 'update_title') { + updateSessionTitle(chunk_json.session_id, chunk_json.data); + } + + if ((chunk_json.type === 'break' && chunk_json.streaming) || !chunk_json.streaming) { + in_streaming = false; + if (!chunk_json.streaming) { + isStreaming.value = false; + } + } + } + } catch (readError) { + console.error('SSE读取错误:', readError); + break; + } + } + + // 获取最新的会话列表 + onSessionsUpdate(); + + } catch (err) { + console.error('发送消息失败:', err); + // 移除加载占位符 + const lastMsg = messages.value[messages.value.length - 1]; + if (lastMsg?.content?.isLoading) { + messages.value.pop(); + } + } finally { + isStreaming.value = false; + activeSSECount.value--; + if (activeSSECount.value === 0) { + isConvRunning.value = false; + } + } + } + + return { + messages, + isStreaming, + isConvRunning, + enableStreaming, + getSessionMessages, + sendMessage, + toggleStreaming + }; +} diff --git a/dashboard/src/composables/useRecording.ts b/dashboard/src/composables/useRecording.ts new file mode 100644 index 000000000..4b03e8508 --- /dev/null +++ b/dashboard/src/composables/useRecording.ts @@ -0,0 +1,74 @@ +import { ref } from 'vue'; +import axios from 'axios'; + +export function useRecording() { + const isRecording = ref(false); + const audioChunks = ref([]); + const mediaRecorder = ref(null); + + async function startRecording(onStart?: (label: string) => void) { + try { + const stream = await navigator.mediaDevices.getUserMedia({ audio: true }); + mediaRecorder.value = new MediaRecorder(stream); + + mediaRecorder.value.ondataavailable = (event) => { + audioChunks.value.push(event.data); + }; + + mediaRecorder.value.start(); + isRecording.value = true; + + if (onStart) { + onStart('录音中...'); + } + } catch (error) { + console.error('Failed to start recording:', error); + } + } + + async function stopRecording(onStop?: (label: string) => void): Promise { + return new Promise((resolve, reject) => { + if (!mediaRecorder.value) { + reject('No media recorder'); + return; + } + + isRecording.value = false; + if (onStop) { + onStop('聊天输入框'); + } + + mediaRecorder.value.stop(); + mediaRecorder.value.onstop = async () => { + const audioBlob = new Blob(audioChunks.value, { type: 'audio/wav' }); + audioChunks.value = []; + + mediaRecorder.value?.stream.getTracks().forEach(track => track.stop()); + + const formData = new FormData(); + formData.append('file', audioBlob); + + try { + const response = await axios.post('/api/chat/post_file', formData, { + headers: { + 'Content-Type': 'multipart/form-data' + } + }); + + const audio = response.data.data.filename; + console.log('Audio uploaded:', audio); + resolve(audio); + } catch (err) { + console.error('Error uploading audio:', err); + reject(err); + } + }; + }); + } + + return { + isRecording, + startRecording, + stopRecording + }; +} diff --git a/dashboard/src/composables/useSessions.ts b/dashboard/src/composables/useSessions.ts new file mode 100644 index 000000000..860668a94 --- /dev/null +++ b/dashboard/src/composables/useSessions.ts @@ -0,0 +1,145 @@ +import { ref, computed } from 'vue'; +import axios from 'axios'; +import { useRouter } from 'vue-router'; + +export interface Session { + session_id: string; + display_name: string; + updated_at: number; +} + +export function useSessions(chatboxMode: boolean = false) { + const router = useRouter(); + const sessions = ref([]); + const selectedSessions = ref([]); + const currSessionId = ref(''); + const pendingSessionId = ref(null); + + // 编辑标题相关 + const editTitleDialog = ref(false); + const editingTitle = ref(''); + const editingSessionId = ref(''); + + const getCurrentSession = computed(() => { + if (!currSessionId.value) return null; + return sessions.value.find(s => s.session_id === currSessionId.value); + }); + + async function getSessions() { + try { + const response = await axios.get('/api/chat/sessions'); + sessions.value = response.data.data; + + // 处理待加载的会话 + if (pendingSessionId.value) { + const session = sessions.value.find(s => s.session_id === pendingSessionId.value); + if (session) { + selectedSessions.value = [pendingSessionId.value]; + pendingSessionId.value = null; + } + } else if (!currSessionId.value && sessions.value.length > 0) { + // 默认选择第一个会话 + const firstSession = sessions.value[0]; + selectedSessions.value = [firstSession.session_id]; + } + } catch (err: any) { + if (err.response?.status === 401) { + router.push('/auth/login?redirect=/chatbox'); + } + console.error(err); + } + } + + async function newSession() { + try { + const response = await axios.get('/api/chat/new_session'); + const sessionId = response.data.data.session_id; + currSessionId.value = sessionId; + + // 更新 URL + const basePath = chatboxMode ? '/chatbox' : '/chat'; + router.push(`${basePath}/${sessionId}`); + + await getSessions(); + return sessionId; + } catch (err) { + console.error(err); + throw err; + } + } + + async function deleteSession(sessionId: string) { + try { + await axios.get('/api/chat/delete_session?session_id=' + sessionId); + await getSessions(); + currSessionId.value = ''; + selectedSessions.value = []; + } catch (err) { + console.error(err); + } + } + + function showEditTitleDialog(sessionId: string, title: string) { + editingSessionId.value = sessionId; + editingTitle.value = title || ''; + editTitleDialog.value = true; + } + + async function saveTitle() { + if (!editingSessionId.value) return; + + const trimmedTitle = editingTitle.value.trim(); + try { + await axios.post('/api/chat/update_session_display_name', { + session_id: editingSessionId.value, + display_name: trimmedTitle + }); + + // 更新本地会话标题 + const session = sessions.value.find(s => s.session_id === editingSessionId.value); + if (session) { + session.display_name = trimmedTitle; + } + editTitleDialog.value = false; + } catch (err) { + console.error('重命名会话失败:', err); + } + } + + function updateSessionTitle(sessionId: string, title: string) { + const session = sessions.value.find(s => s.session_id === sessionId); + if (session) { + session.display_name = title; + } + } + + function newChat(closeMobileSidebar?: () => void) { + currSessionId.value = ''; + selectedSessions.value = []; + + const basePath = chatboxMode ? '/chatbox' : '/chat'; + router.push(basePath); + + if (closeMobileSidebar) { + closeMobileSidebar(); + } + } + + return { + sessions, + selectedSessions, + currSessionId, + pendingSessionId, + editTitleDialog, + editingTitle, + editingSessionId, + getCurrentSession, + getSessions, + newSession, + deleteSession, + showEditTitleDialog, + saveTitle, + updateSessionTitle, + newChat + }; +} From 1d3928d1450e59ada75ab59a5a0bc50ceb78ee24 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Thu, 20 Nov 2025 16:33:57 +0800 Subject: [PATCH 17/27] refactor(sqlite): remove auto-generation of session_id in insert method --- astrbot/core/db/sqlite.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index 69203bf6d..4a7f25609 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -1,7 +1,6 @@ import asyncio import threading import typing as T -import uuid from datetime import datetime, timedelta from sqlalchemy.ext.asyncio import AsyncSession @@ -728,9 +727,6 @@ class SQLiteDatabase(BaseDatabase): kwargs = {} if session_id: kwargs["session_id"] = session_id - else: - # Auto-generate session_id - kwargs["session_id"] = uuid.uuid4() async with self.get_db() as session: session: AsyncSession From cb087b5ff94efb8074cd97bec0517f69808ce4ce Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Thu, 20 Nov 2025 17:02:01 +0800 Subject: [PATCH 18/27] refactor: update timestamp handling in session management and chat components --- astrbot/core/db/po.py | 2 +- astrbot/core/db/sqlite.py | 8 ++++---- astrbot/dashboard/routes/chat.py | 4 ++-- .../src/components/chat/ConversationSidebar.vue | 17 +---------------- dashboard/src/composables/useSessions.ts | 2 +- 5 files changed, 9 insertions(+), 24 deletions(-) diff --git a/astrbot/core/db/po.py b/astrbot/core/db/po.py index 8fb14c19f..d6621d072 100644 --- a/astrbot/core/db/po.py +++ b/astrbot/core/db/po.py @@ -162,7 +162,7 @@ class PlatformSession(SQLModel, table=True): Each session can have multiple conversations (对话) associated with it. """ - __tablename__ = "platform_sessions" + __tablename__ = "platform_sessions" # type: ignore inner_id: int | None = Field( primary_key=True, diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index 4a7f25609..194618612 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -1,7 +1,7 @@ import asyncio import threading import typing as T -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from sqlalchemy.ext.asyncio import AsyncSession from sqlmodel import col, delete, desc, func, or_, select, text, update @@ -788,13 +788,13 @@ class SQLiteDatabase(BaseDatabase): async with self.get_db() as session: session: AsyncSession async with session.begin(): - values = {"updated_at": datetime.now()} + values: dict[str, T.Any] = {"updated_at": datetime.now(timezone.utc)} if display_name is not None: values["display_name"] = display_name await session.execute( update(PlatformSession) - .where(PlatformSession.session_id == session_id) + .where(col(PlatformSession.session_id == session_id)) .values(**values), ) @@ -805,6 +805,6 @@ class SQLiteDatabase(BaseDatabase): async with session.begin(): await session.execute( delete(PlatformSession).where( - PlatformSession.session_id == session_id, + col(PlatformSession.session_id == session_id), ), ) diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index 2620f1a17..1ad789563 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -334,8 +334,8 @@ class ChatRoute(Route): "creator": session.creator, "display_name": session.display_name, "is_group": session.is_group, - "created_at": int(session.created_at.timestamp()), - "updated_at": int(session.updated_at.timestamp()), + "created_at": session.created_at.astimezone().isoformat(), + "updated_at": session.updated_at.astimezone().isoformat(), } ) diff --git a/dashboard/src/components/chat/ConversationSidebar.vue b/dashboard/src/components/chat/ConversationSidebar.vue index 2ea9a4819..b2ebd3fef 100644 --- a/dashboard/src/components/chat/ConversationSidebar.vue +++ b/dashboard/src/components/chat/ConversationSidebar.vue @@ -54,7 +54,7 @@ {{ item.display_name || tm('conversation.newConversation') }} - {{ formatDate(item.updated_at) }} + {{ new Date(item.updated_at).toLocaleString() }}