From f7a716af43f238e75c50251422666539265c4ef3 Mon Sep 17 00:00:00 2001 From: Soulter <37870767+Soulter@users.noreply.github.com> Date: Tue, 2 Dec 2025 17:11:08 +0800 Subject: [PATCH] refactor: message storage format of webchat, support reply and file message segment (#3845) * refactor: message storage format of webchat * refactor: update image and record handling in webchat event processing * fix: thinking placeholder in webchat * feat: supports file upload in webchat * feat: supports to delete attachments when webchat session is deleted * perf: improve performance of file downloading * refactor: remove unused import in chat route * feat: add message timestamp formatting and localization support in chat * fix: handle missing filename in file upload for chat route * feat: enhance file handling in chat and webchat, supporting video uploads and improved attachment management * fix: update property name for embedded files in message handling * fix: compute variable errors after uninstalling plugins * feat: supported for reply message and standarlize the message param * fix: ensure message actions are displayed for the last message in the list --- astrbot/core/db/__init__.py | 31 +- astrbot/core/db/sqlite.py | 54 +++ .../sources/webchat/webchat_adapter.py | 119 ++++-- .../platform/sources/webchat/webchat_event.py | 66 ++-- astrbot/core/platform_message_history_mgr.py | 6 +- astrbot/dashboard/routes/chat.py | 370 ++++++++++++++---- dashboard/src/components/chat/Chat.vue | 82 +++- dashboard/src/components/chat/ChatInput.vue | 99 ++++- dashboard/src/components/chat/MessageList.vue | 290 +++++++++++++- .../src/components/chat/StandaloneChat.vue | 13 +- dashboard/src/composables/useMediaHandling.ts | 112 +++++- dashboard/src/composables/useMessages.ts | 275 ++++++++++--- .../src/i18n/locales/en-US/features/chat.json | 11 +- .../src/i18n/locales/zh-CN/features/chat.json | 11 +- dashboard/src/views/ExtensionPage.vue | 13 +- 15 files changed, 1323 insertions(+), 229 deletions(-) diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index 2af0428d0..58d1c6a9c 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -173,7 +173,7 @@ class BaseDatabase(abc.ABC): content: dict, sender_id: str | None = None, sender_name: str | None = None, - ) -> None: + ) -> PlatformMessageHistory: """Insert a new platform message history record.""" ... @@ -198,6 +198,14 @@ class BaseDatabase(abc.ABC): """Get platform message history for a specific user.""" ... + @abc.abstractmethod + async def get_platform_message_history_by_id( + self, + message_id: int, + ) -> PlatformMessageHistory | None: + """Get a platform message history record by its ID.""" + ... + @abc.abstractmethod async def insert_attachment( self, @@ -213,6 +221,27 @@ class BaseDatabase(abc.ABC): """Get an attachment by its ID.""" ... + @abc.abstractmethod + async def get_attachments(self, attachment_ids: list[str]) -> list[Attachment]: + """Get multiple attachments by their IDs.""" + ... + + @abc.abstractmethod + async def delete_attachment(self, attachment_id: str) -> bool: + """Delete an attachment by its ID. + + Returns True if the attachment was deleted, False if it was not found. + """ + ... + + @abc.abstractmethod + async def delete_attachments(self, attachment_ids: list[str]) -> int: + """Delete multiple attachments by their IDs. + + Returns the number of attachments deleted. + """ + ... + @abc.abstractmethod async def insert_persona( self, diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index 276f5821f..5b603abd0 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -449,6 +449,18 @@ class SQLiteDatabase(BaseDatabase): result = await session.execute(query.offset(offset).limit(page_size)) return result.scalars().all() + async def get_platform_message_history_by_id( + self, message_id: int + ) -> PlatformMessageHistory | None: + """Get a platform message history record by its ID.""" + async with self.get_db() as session: + session: AsyncSession + query = select(PlatformMessageHistory).where( + PlatformMessageHistory.id == message_id + ) + result = await session.execute(query) + return result.scalar_one_or_none() + async def insert_attachment(self, path, type, mime_type): """Insert a new attachment record.""" async with self.get_db() as session: @@ -470,6 +482,48 @@ class SQLiteDatabase(BaseDatabase): result = await session.execute(query) return result.scalar_one_or_none() + async def get_attachments(self, attachment_ids: list[str]) -> list: + """Get multiple attachments by their IDs.""" + if not attachment_ids: + return [] + async with self.get_db() as session: + session: AsyncSession + query = select(Attachment).where( + Attachment.attachment_id.in_(attachment_ids) + ) + result = await session.execute(query) + return list(result.scalars().all()) + + async def delete_attachment(self, attachment_id: str) -> bool: + """Delete an attachment by its ID. + + Returns True if the attachment was deleted, False if it was not found. + """ + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + query = delete(Attachment).where( + col(Attachment.attachment_id) == attachment_id + ) + result = await session.execute(query) + return result.rowcount > 0 + + async def delete_attachments(self, attachment_ids: list[str]) -> int: + """Delete multiple attachments by their IDs. + + Returns the number of attachments deleted. + """ + if not attachment_ids: + return 0 + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + query = delete(Attachment).where( + col(Attachment.attachment_id).in_(attachment_ids) + ) + result = await session.execute(query) + return result.rowcount + async def insert_persona( self, persona_id, diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py index ff5482f58..80df6d80d 100644 --- a/astrbot/core/platform/sources/webchat/webchat_adapter.py +++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py @@ -6,7 +6,9 @@ from collections.abc import Awaitable, Callable from typing import Any from astrbot import logger -from astrbot.core.message.components import Image, Plain, Record +from astrbot.core import db_helper +from astrbot.core.db.po import PlatformMessageHistory +from astrbot.core.message.components import File, Image, Plain, Record, Reply, Video from astrbot.core.message.message_event_result import MessageChain from astrbot.core.platform import ( AstrBotMessage, @@ -96,6 +98,92 @@ class WebChatAdapter(Platform): await WebChatMessageEvent._send(message_chain, session.session_id) await super().send_by_session(session, message_chain) + async def _get_message_history( + self, message_id: int + ) -> PlatformMessageHistory | None: + return await db_helper.get_platform_message_history_by_id(message_id) + + async def _parse_message_parts( + self, + message_parts: list, + depth: int = 0, + max_depth: int = 1, + ) -> tuple[list, list[str]]: + """解析消息段列表,返回消息组件列表和纯文本列表 + + Args: + message_parts: 消息段列表 + depth: 当前递归深度 + max_depth: 最大递归深度(用于处理 reply) + + Returns: + tuple[list, list[str]]: (消息组件列表, 纯文本列表) + """ + components = [] + text_parts = [] + + for part in message_parts: + part_type = part.get("type") + if part_type == "plain": + text = part.get("text", "") + components.append(Plain(text)) + text_parts.append(text) + elif part_type == "reply": + message_id = part.get("message_id") + reply_chain = [] + reply_message_str = "" + sender_id = None + sender_name = None + + # recursively get the content of the referenced message + if depth < max_depth and message_id: + history = await self._get_message_history(message_id) + if history and history.content: + reply_parts = history.content.get("message", []) + if isinstance(reply_parts, list): + ( + reply_chain, + reply_text_parts, + ) = await self._parse_message_parts( + reply_parts, + depth=depth + 1, + max_depth=max_depth, + ) + reply_message_str = "".join(reply_text_parts) + sender_id = history.sender_id + sender_name = history.sender_name + + components.append( + Reply( + id=message_id, + chain=reply_chain, + message_str=reply_message_str, + sender_id=sender_id, + sender_nickname=sender_name, + ) + ) + elif part_type == "image": + path = part.get("path") + if path: + components.append(Image.fromFileSystem(path)) + elif part_type == "record": + path = part.get("path") + if path: + components.append(Record.fromFileSystem(path)) + elif part_type == "file": + path = part.get("path") + if path: + filename = part.get("filename") or ( + os.path.basename(path) if path else "file" + ) + components.append(File(name=filename, file=path)) + elif part_type == "video": + path = part.get("path") + if path: + components.append(Video.fromFileSystem(path)) + + return components, text_parts + async def convert_message(self, data: tuple) -> AstrBotMessage: username, cid, payload = data @@ -108,36 +196,15 @@ class WebChatAdapter(Platform): abm.session_id = f"webchat!{username}!{cid}" abm.message_id = str(uuid.uuid4()) - abm.message = [] - if payload["message"]: - abm.message.append(Plain(payload["message"])) - if payload["image_url"]: - if isinstance(payload["image_url"], list): - for img in payload["image_url"]: - abm.message.append( - Image.fromFileSystem(os.path.join(self.imgs_dir, img)), - ) - else: - abm.message.append( - Image.fromFileSystem( - os.path.join(self.imgs_dir, payload["image_url"]), - ), - ) - if payload["audio_url"]: - if isinstance(payload["audio_url"], list): - for audio in payload["audio_url"]: - path = os.path.join(self.imgs_dir, audio) - abm.message.append(Record(file=path, path=path)) - else: - path = os.path.join(self.imgs_dir, payload["audio_url"]) - abm.message.append(Record(file=path, path=path)) + # 处理消息段列表 + message_parts = payload.get("message", []) + abm.message, message_str_parts = await self._parse_message_parts(message_parts) logger.debug(f"WebChatAdapter: {abm.message}") - message_str = payload["message"] abm.timestamp = int(time.time()) - abm.message_str = message_str + abm.message_str = "".join(message_str_parts) abm.raw_message = data return abm diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index 4ced79b19..70c834e65 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -1,12 +1,12 @@ import base64 import os +import shutil import uuid from astrbot.api import logger from astrbot.api.event import AstrMessageEvent, MessageChain -from astrbot.api.message_components import Image, Plain, Record +from astrbot.api.message_components import File, Image, Plain, Record from astrbot.core.utils.astrbot_path import get_astrbot_data_path -from astrbot.core.utils.io import download_image_by_url from .webchat_queue_mgr import webchat_queue_mgr @@ -19,7 +19,9 @@ class WebChatMessageEvent(AstrMessageEvent): os.makedirs(imgs_dir, exist_ok=True) @staticmethod - async def _send(message: MessageChain, session_id: str, streaming: bool = False): + async def _send( + message: MessageChain | None, session_id: str, streaming: bool = False + ) -> str | None: cid = session_id.split("!")[-1] web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid) if not message: @@ -30,7 +32,7 @@ class WebChatMessageEvent(AstrMessageEvent): "streaming": False, }, # end means this request is finished ) - return "" + return data = "" for comp in message.chain: @@ -47,24 +49,11 @@ class WebChatMessageEvent(AstrMessageEvent): ) elif isinstance(comp, Image): # save image to local - filename = str(uuid.uuid4()) + ".jpg" + filename = f"{str(uuid.uuid4())}.jpg" path = os.path.join(imgs_dir, filename) - if comp.file and comp.file.startswith("file:///"): - ph = comp.file[8:] - with open(path, "wb") as f: - with open(ph, "rb") as f2: - f.write(f2.read()) - elif comp.file.startswith("base64://"): - base64_str = comp.file[9:] - image_data = base64.b64decode(base64_str) - with open(path, "wb") as f: - f.write(image_data) - elif comp.file and comp.file.startswith("http"): - await download_image_by_url(comp.file, path=path) - else: - with open(path, "wb") as f: - with open(comp.file, "rb") as f2: - f.write(f2.read()) + image_base64 = await comp.convert_to_base64() + with open(path, "wb") as f: + f.write(base64.b64decode(image_base64)) data = f"[IMAGE]{filename}" await web_chat_back_queue.put( { @@ -76,19 +65,11 @@ class WebChatMessageEvent(AstrMessageEvent): ) elif isinstance(comp, Record): # save record to local - filename = str(uuid.uuid4()) + ".wav" + filename = f"{str(uuid.uuid4())}.wav" path = os.path.join(imgs_dir, filename) - if comp.file and comp.file.startswith("file:///"): - ph = comp.file[8:] - with open(path, "wb") as f: - with open(ph, "rb") as f2: - f.write(f2.read()) - elif comp.file and comp.file.startswith("http"): - await download_image_by_url(comp.file, path=path) - else: - with open(path, "wb") as f: - with open(comp.file, "rb") as f2: - f.write(f2.read()) + record_base64 = await comp.convert_to_base64() + with open(path, "wb") as f: + f.write(base64.b64decode(record_base64)) data = f"[RECORD]{filename}" await web_chat_back_queue.put( { @@ -98,6 +79,23 @@ class WebChatMessageEvent(AstrMessageEvent): "streaming": streaming, }, ) + elif isinstance(comp, File): + # save file to local + file_path = await comp.get_file() + original_name = comp.name or os.path.basename(file_path) + ext = os.path.splitext(original_name)[1] or "" + filename = f"{uuid.uuid4()!s}{ext}" + dest_path = os.path.join(imgs_dir, filename) + shutil.copy2(file_path, dest_path) + data = f"[FILE]{filename}|{original_name}" + await web_chat_back_queue.put( + { + "type": "file", + "cid": cid, + "data": data, + "streaming": streaming, + }, + ) else: logger.debug(f"webchat 忽略: {comp.type}") @@ -131,6 +129,8 @@ class WebChatMessageEvent(AstrMessageEvent): session_id=self.session_id, streaming=True, ) + if not r: + continue if chain.type == "reasoning": reasoning_content += chain.get_plain_text() else: diff --git a/astrbot/core/platform_message_history_mgr.py b/astrbot/core/platform_message_history_mgr.py index 0e079e893..d6d524698 100644 --- a/astrbot/core/platform_message_history_mgr.py +++ b/astrbot/core/platform_message_history_mgr.py @@ -10,12 +10,12 @@ class PlatformMessageHistoryManager: self, platform_id: str, user_id: str, - content: list[dict], # TODO: parse from message chain + content: dict, # TODO: parse from message chain sender_id: str | None = None, sender_name: str | None = None, - ): + ) -> PlatformMessageHistory: """Insert a new platform message history record.""" - await self.db.insert_platform_message_history( + return await self.db.insert_platform_message_history( platform_id=platform_id, user_id=user_id, content=content, diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index 5381b5649..56f98bfbb 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -1,11 +1,11 @@ import asyncio import json +import mimetypes import os import uuid from contextlib import asynccontextmanager -from quart import Response as QuartResponse -from quart import g, make_response, request +from quart import g, make_response, request, send_file from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle @@ -44,7 +44,7 @@ class ChatRoute(Route): self.update_session_display_name, ), "/chat/get_file": ("GET", self.get_file), - "/chat/post_image": ("POST", self.post_image), + "/chat/get_attachment": ("GET", self.get_attachment), "/chat/post_file": ("POST", self.post_file), } self.core_lifecycle = core_lifecycle @@ -73,52 +73,184 @@ class ChatRoute(Route): if not real_file_path.startswith(real_imgs_dir): return Response().error("Invalid file path").__dict__ - with open(real_file_path, "rb") as f: - filename_ext = os.path.splitext(filename)[1].lower() - - if filename_ext == ".wav": - return QuartResponse(f.read(), mimetype="audio/wav") - if filename_ext[1:] in self.supported_imgs: - return QuartResponse(f.read(), mimetype="image/jpeg") - return QuartResponse(f.read()) + filename_ext = os.path.splitext(filename)[1].lower() + if filename_ext == ".wav": + return await send_file(real_file_path, mimetype="audio/wav") + if filename_ext[1:] in self.supported_imgs: + return await send_file(real_file_path, mimetype="image/jpeg") + return await send_file(real_file_path) except (FileNotFoundError, OSError): return Response().error("File access error").__dict__ - async def post_image(self): - post_data = await request.files - if "file" not in post_data: - return Response().error("Missing key: file").__dict__ + async def get_attachment(self): + """Get attachment file by attachment_id.""" + attachment_id = request.args.get("attachment_id") + if not attachment_id: + return Response().error("Missing key: attachment_id").__dict__ - file = post_data["file"] - filename = str(uuid.uuid4()) + ".jpg" - path = os.path.join(self.imgs_dir, filename) - await file.save(path) + try: + attachment = await self.db.get_attachment_by_id(attachment_id) + if not attachment: + return Response().error("Attachment not found").__dict__ - return Response().ok(data={"filename": filename}).__dict__ + file_path = attachment.path + real_file_path = os.path.realpath(file_path) + + return await send_file(real_file_path, mimetype=attachment.mime_type) + + except (FileNotFoundError, OSError): + return Response().error("File access error").__dict__ async def post_file(self): + """Upload a file and create an attachment record, return attachment_id.""" post_data = await request.files if "file" not in post_data: return Response().error("Missing key: file").__dict__ file = post_data["file"] - filename = f"{uuid.uuid4()!s}" - # 通过文件格式判断文件类型 - if file.content_type.startswith("audio"): - filename += ".wav" + filename = file.filename or f"{uuid.uuid4()!s}" + content_type = file.content_type or "application/octet-stream" + + # 根据 content_type 判断文件类型并添加扩展名 + if content_type.startswith("image"): + attach_type = "image" + elif content_type.startswith("audio"): + attach_type = "record" + elif content_type.startswith("video"): + attach_type = "video" + else: + attach_type = "file" path = os.path.join(self.imgs_dir, filename) await file.save(path) - return Response().ok(data={"filename": filename}).__dict__ + # 创建 attachment 记录 + attachment = await self.db.insert_attachment( + path=path, + type=attach_type, + mime_type=content_type, + ) + + if not attachment: + return Response().error("Failed to create attachment").__dict__ + + filename = os.path.basename(attachment.path) + + return ( + Response() + .ok( + data={ + "attachment_id": attachment.attachment_id, + "filename": filename, + "type": attach_type, + } + ) + .__dict__ + ) + + async def _build_user_message_parts(self, message: str | list) -> list[dict]: + """构建用户消息的部分列表 + + Args: + message: 文本消息 (str) 或消息段列表 (list) + """ + parts = [] + + if isinstance(message, list): + for part in message: + part_type = part.get("type") + if part_type == "plain": + parts.append({"type": "plain", "text": part.get("text", "")}) + elif part_type == "reply": + parts.append( + {"type": "reply", "message_id": part.get("message_id")} + ) + elif attachment_id := part.get("attachment_id"): + attachment = await self.db.get_attachment_by_id(attachment_id) + if attachment: + parts.append( + { + "type": attachment.type, + "attachment_id": attachment.attachment_id, + "filename": os.path.basename(attachment.path), + "path": attachment.path, # will be deleted + } + ) + return parts + + if message: + parts.append({"type": "plain", "text": message}) + + return parts + + async def _create_attachment_from_file( + self, filename: str, attach_type: str + ) -> dict | None: + """从本地文件创建 attachment 并返回消息部分 + + 用于处理 bot 回复中的媒体文件 + + Args: + filename: 存储的文件名 + attach_type: 附件类型 (image, record, file, video) + """ + file_path = os.path.join(self.imgs_dir, os.path.basename(filename)) + if not os.path.exists(file_path): + return None + + # guess mime type + mime_type, _ = mimetypes.guess_type(filename) + if not mime_type: + mime_type = "application/octet-stream" + + # insert attachment + attachment = await self.db.insert_attachment( + path=file_path, + type=attach_type, + mime_type=mime_type, + ) + if not attachment: + return None + + return { + "type": attach_type, + "attachment_id": attachment.attachment_id, + "filename": os.path.basename(file_path), + } + + async def _save_bot_message( + self, + webchat_conv_id: str, + text: str, + media_parts: list, + reasoning: str, + ): + """保存 bot 消息到历史记录,返回保存的记录""" + bot_message_parts = [] + if text: + bot_message_parts.append({"type": "plain", "text": text}) + bot_message_parts.extend(media_parts) + + new_his = {"type": "bot", "message": bot_message_parts} + if reasoning: + new_his["reasoning"] = reasoning + + record = await self.platform_history_mgr.insert( + platform_id="webchat", + user_id=webchat_conv_id, + content=new_his, + sender_id="bot", + sender_name="bot", + ) + return record async def chat(self): username = g.get("username", "guest") post_data = await request.json - if "message" not in post_data and "image_url" not in post_data: - return Response().error("Missing key: message or image_url").__dict__ + if "message" not in post_data and "files" not in post_data: + return Response().error("Missing key: message or files").__dict__ if "session_id" not in post_data and "conversation_id" not in post_data: return ( @@ -126,44 +258,40 @@ class ChatRoute(Route): ) message = post_data["message"] - # conversation_id = post_data["conversation_id"] session_id = post_data.get("session_id", post_data.get("conversation_id")) - image_url = post_data.get("image_url") - audio_url = post_data.get("audio_url") selected_provider = post_data.get("selected_provider") selected_model = post_data.get("selected_model") - enable_streaming = post_data.get("enable_streaming", True) # 默认为 True + enable_streaming = post_data.get("enable_streaming", True) - if not message and not image_url and not audio_url: - return ( - Response() - .error("Message and image_url and audio_url are empty") - .__dict__ + # 检查消息是否为空 + if isinstance(message, list): + has_content = any( + part.get("type") in ("plain", "image", "record", "file", "video") + for part in message ) + if not has_content: + return ( + Response() + .error("Message content is empty (reply only is not allowed)") + .__dict__ + ) + elif not message: + return Response().error("Message are both empty").__dict__ + if not session_id: return Response().error("session_id is empty").__dict__ - # 追加用户消息 webchat_conv_id = session_id - - # 获取会话特定的队列 back_queue = webchat_queue_mgr.get_or_create_back_queue(webchat_conv_id) - new_his = {"type": "user", "message": message} - if image_url: - new_his["image_url"] = image_url - if audio_url: - new_his["audio_url"] = audio_url - await self.platform_history_mgr.insert( - platform_id="webchat", - user_id=webchat_conv_id, - content=new_his, - sender_id=username, - sender_name=username, - ) + # 构建用户消息段(包含 path 用于传递给 adapter) + message_parts = await self._build_user_message_parts(message) async def stream(): client_disconnected = False + accumulated_parts = [] + accumulated_text = "" + accumulated_reasoning = "" try: async with track_conversation(self.running_convs, webchat_conv_id): @@ -182,16 +310,17 @@ class ChatRoute(Route): continue result_text = result["data"] - type = result.get("type") + msg_type = result.get("type") streaming = result.get("streaming", False) + # 发送 SSE 数据 try: if not client_disconnected: yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n" except Exception as e: if not client_disconnected: logger.debug( - f"[WebChat] 用户 {username} 断开聊天长连接。 {e}", + f"[WebChat] 用户 {username} 断开聊天长连接。 {e}" ) client_disconnected = True @@ -202,24 +331,68 @@ class ChatRoute(Route): logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。") client_disconnected = True - if type == "end": + # 累积消息部分 + if msg_type == "plain": + chain_type = result.get("chain_type", "normal") + if chain_type == "reasoning": + accumulated_reasoning += result_text + else: + accumulated_text += result_text + elif msg_type == "image": + filename = result_text.replace("[IMAGE]", "") + part = await self._create_attachment_from_file( + filename, "image" + ) + if part: + accumulated_parts.append(part) + elif msg_type == "record": + filename = result_text.replace("[RECORD]", "") + part = await self._create_attachment_from_file( + filename, "record" + ) + if part: + accumulated_parts.append(part) + elif msg_type == "file": + # 格式: [FILE]filename + filename = result_text.replace("[FILE]", "") + part = await self._create_attachment_from_file( + filename, "file" + ) + if part: + accumulated_parts.append(part) + + # 消息结束处理 + if msg_type == "end": break elif ( - (streaming and type == "complete") + (streaming and msg_type == "complete") or not streaming - or type == "break" + or msg_type == "break" ): - # 追加机器人消息 - new_his = {"type": "bot", "message": result_text} - if "reasoning" in result: - new_his["reasoning"] = result["reasoning"] - await self.platform_history_mgr.insert( - platform_id="webchat", - user_id=webchat_conv_id, - content=new_his, - sender_id="bot", - sender_name="bot", + saved_record = await self._save_bot_message( + webchat_conv_id, + accumulated_text, + accumulated_parts, + accumulated_reasoning, ) + # 发送保存的消息信息给前端 + if saved_record and not client_disconnected: + saved_info = { + "type": "message_saved", + "data": { + "id": saved_record.id, + "created_at": saved_record.created_at.astimezone().isoformat(), + }, + } + try: + yield f"data: {json.dumps(saved_info, ensure_ascii=False)}\n\n" + except Exception: + pass + # 重置累积变量 (对于 break 后的下一段消息) + if msg_type == "break": + accumulated_parts = [] + accumulated_text = "" + accumulated_reasoning = "" except BaseException as e: logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True) @@ -230,9 +403,7 @@ class ChatRoute(Route): username, webchat_conv_id, { - "message": message, - "image_url": image_url, # list - "audio_url": audio_url, + "message": message_parts, "selected_provider": selected_provider, "selected_model": selected_model, "enable_streaming": enable_streaming, @@ -240,6 +411,19 @@ class ChatRoute(Route): ), ) + message_parts_for_storage = [] + for part in message_parts: + part_copy = {k: v for k, v in part.items() if k != "path"} + message_parts_for_storage.append(part_copy) + + await self.platform_history_mgr.insert( + platform_id="webchat", + user_id=webchat_conv_id, + content={"type": "user", "message": message_parts_for_storage}, + sender_id=username, + sender_name=username, + ) + response = await make_response( stream(), { @@ -249,7 +433,7 @@ class ChatRoute(Route): "Connection": "keep-alive", }, ) - response.timeout = None # fix SSE auto disconnect issue + response.timeout = None # fix SSE auto disconnect issue # pyright: ignore[reportAttributeAccessIssue] return response async def delete_webchat_session(self): @@ -271,6 +455,17 @@ class ChatRoute(Route): unified_msg_origin = f"{session.platform_id}:{message_type}:{session.platform_id}!{username}!{session_id}" await self.conv_mgr.delete_conversations_by_user_id(unified_msg_origin) + # 获取消息历史中的所有附件 ID 并删除附件 + history_list = await self.platform_history_mgr.get( + platform_id=session.platform_id, + user_id=session_id, + page=1, + page_size=100000, # 获取足够多的记录 + ) + attachment_ids = self._extract_attachment_ids(history_list) + if attachment_ids: + await self._delete_attachments(attachment_ids) + # 删除消息历史 await self.platform_history_mgr.delete( platform_id=session.platform_id, @@ -297,6 +492,41 @@ class ChatRoute(Route): return Response().ok().__dict__ + def _extract_attachment_ids(self, history_list) -> list[str]: + """从消息历史中提取所有 attachment_id""" + attachment_ids = [] + for history in history_list: + content = history.content + if not content or "message" not in content: + continue + message_parts = content.get("message", []) + for part in message_parts: + if isinstance(part, dict) and "attachment_id" in part: + attachment_ids.append(part["attachment_id"]) + return attachment_ids + + async def _delete_attachments(self, attachment_ids: list[str]): + """删除附件(包括数据库记录和磁盘文件)""" + try: + attachments = await self.db.get_attachments(attachment_ids) + for attachment in attachments: + if not os.path.exists(attachment.path): + continue + try: + os.remove(attachment.path) + except OSError as e: + logger.warning( + f"Failed to delete attachment file {attachment.path}: {e}" + ) + except Exception as e: + logger.warning(f"Failed to get attachments: {e}") + + # 批量删除数据库记录 + try: + await self.db.delete_attachments(attachment_ids) + except Exception as e: + logger.warning(f"Failed to delete attachments: {e}") + async def new_session(self): """Create a new Platform session (default: webchat).""" username = g.get("username", "guest") diff --git a/dashboard/src/components/chat/Chat.vue b/dashboard/src/components/chat/Chat.vue index caff448cc..509971ca8 100644 --- a/dashboard/src/components/chat/Chat.vue +++ b/dashboard/src/components/chat/Chat.vue @@ -71,6 +71,7 @@
@@ -84,19 +85,23 @@ v-model:prompt="prompt" :stagedImagesUrl="stagedImagesUrl" :stagedAudioUrl="stagedAudioUrl" + :stagedFiles="stagedNonImageFiles" :disabled="isStreaming" :enableStreaming="enableStreaming" :isRecording="isRecording" :session-id="currSessionId || null" :current-session="getCurrentSession" + :replyTo="replyTo" @send="handleSendMessage" @toggleStreaming="toggleStreaming" @removeImage="removeImage" @removeAudio="removeAudio" + @removeFile="removeFile" @startRecording="handleStartRecording" @stopRecording="handleStopRecording" @pasteImage="handlePaste" @fileSelect="handleFileSelect" + @clearReply="clearReply" ref="chatInputRef" />
@@ -189,14 +194,17 @@ const { } = useSessions(props.chatboxMode); const { - stagedImagesName, stagedImagesUrl, stagedAudioUrl, + stagedFiles, + stagedNonImageFiles, getMediaFile, processAndUploadImage, + processAndUploadFile, handlePaste, removeImage, removeAudio, + removeFile, clearStaged, cleanupMediaCache } = useMediaHandling(); @@ -220,6 +228,13 @@ const chatInputRef = ref | null>(null); // 输入状态 const prompt = ref(''); +// 引用消息状态 +interface ReplyInfo { + messageId: number; // PlatformSessionHistoryMessage 的 id + messageContent: string; // 用于显示的消息内容 +} +const replyTo = ref(null); + const isDark = computed(() => useCustomizerStore().uiTheme === 'PurpleThemeDark'); // 检测是否为手机端 @@ -250,6 +265,41 @@ function openImagePreview(imageUrl: string) { imagePreviewDialog.value = true; } +function handleReplyMessage(msg: any, index: number) { + // 从消息中获取 id (PlatformSessionHistoryMessage 的 id) + const messageId = msg.id; + if (!messageId) { + console.warn('Message does not have an id'); + return; + } + + // 获取消息内容用于显示 + let messageContent = ''; + if (typeof msg.content.message === 'string') { + messageContent = msg.content.message; + } else if (Array.isArray(msg.content.message)) { + // 从消息段数组中提取纯文本 + const textParts = msg.content.message + .filter((part: any) => part.type === 'plain' && part.text) + .map((part: any) => part.text); + messageContent = textParts.join(''); + } + + // 截断过长的内容 + if (messageContent.length > 100) { + messageContent = messageContent.substring(0, 100) + '...'; + } + + replyTo.value = { + messageId, + messageContent: messageContent || '[媒体内容]' + }; +} + +function clearReply() { + replyTo.value = null; +} + async function handleSelectConversation(sessionIds: string[]) { if (!sessionIds[0]) return; @@ -265,6 +315,9 @@ async function handleSelectConversation(sessionIds: string[]) { closeMobileSidebar(); } + // 清除引用状态 + clearReply(); + currSessionId.value = sessionIds[0]; selectedSessions.value = [sessionIds[0]]; @@ -278,6 +331,7 @@ async function handleSelectConversation(sessionIds: string[]) { function handleNewChat() { newChat(closeMobileSidebar); messages.value = []; + clearReply(); } async function handleDeleteConversation(sessionId: string) { @@ -295,13 +349,19 @@ async function handleStopRecording() { } async function handleFileSelect(files: FileList) { + const imageTypes = ['image/jpeg', 'image/png', 'image/gif', 'image/webp']; for (const file of files) { - await processAndUploadImage(file); + if (imageTypes.includes(file.type)) { + await processAndUploadImage(file); + } else { + await processAndUploadFile(file); + } } } async function handleSendMessage() { - if (!prompt.value.trim() && stagedImagesName.value.length === 0 && !stagedAudioUrl.value) { + // 只有引用不能发送,必须有输入内容 + if (!prompt.value.trim() && stagedFiles.value.length === 0 && !stagedAudioUrl.value) { return; } @@ -310,12 +370,19 @@ async function handleSendMessage() { } const promptToSend = prompt.value.trim(); - const imageNamesToSend = [...stagedImagesName.value]; const audioNameToSend = stagedAudioUrl.value; + const filesToSend = stagedFiles.value.map(f => ({ + attachment_id: f.attachment_id, + url: f.url, + original_name: f.original_name, + type: f.type + })); + const replyToSend = replyTo.value ? { ...replyTo.value } : null; - // 清空输入和附件 + // 清空输入和附件和引用 prompt.value = ''; clearStaged(); + clearReply(); // 获取选择的提供商和模型 const selection = chatInputRef.value?.getCurrentSelection(); @@ -324,10 +391,11 @@ async function handleSendMessage() { await sendMsg( promptToSend, - imageNamesToSend, + filesToSend, audioNameToSend, selectedProviderId, - selectedModelName + selectedModelName, + replyToSend ); } diff --git a/dashboard/src/components/chat/ChatInput.vue b/dashboard/src/components/chat/ChatInput.vue index 79ce27654..53e1e30c0 100644 --- a/dashboard/src/components/chat/ChatInput.vue +++ b/dashboard/src/components/chat/ChatInput.vue @@ -2,6 +2,14 @@
+ +
+
+ mdi-reply + "{{ props.replyTo.messageContent }}" +
+ +