diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index a8a4b0ad5..9929e9ce2 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -203,6 +203,23 @@ class BaseDatabase(abc.ABC): """Get platform message history for a specific user.""" ... + @abc.abstractmethod + async def search_platform_sessions( + self, + creator: str, + query: str, + context_len: int = 40, + page: int = 1, + page_size: int = 10, + ) -> tuple[list[dict], int]: + """Search platform sessions (title or message content) for a given creator. + + Returns a tuple of (results, total) where results is a list of dicts with keys: + session_id, title, match_field, match_index, match_length, snippet, snippet_start, + created_at, updated_at + """ + ... + @abc.abstractmethod async def get_platform_message_history_by_id( self, diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index 3af08f248..7899113e3 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -1,4 +1,5 @@ import asyncio +import json import threading import typing as T from collections.abc import Awaitable, Callable @@ -483,6 +484,144 @@ class SQLiteDatabase(BaseDatabase): result = await session.execute(query.offset(offset).limit(page_size)) return result.scalars().all() + def _build_snippet(self, text: str, match_index: int, match_length: int, context_len: int): + if match_index < 0: + return "", 0 + start = max(match_index - context_len, 0) + end = min(match_index + match_length + context_len, len(text)) + return text[start:end], start + + async def search_platform_sessions( + self, + creator: str, + query: str, + context_len: int = 40, + page: int = 1, + page_size: int = 10, + ) -> tuple[list[dict], int]: + """Search platform sessions (by title or by message content) for a given creator. + + This implementation performs searching at DB level using SQL LIKE on + `platform_sessions.display_name` and the JSON `platform_message_history.content`. + To keep work minimal and compatible with SQLite JSON storage, the content + column is searched as text using LIKE. + Returns (results, total) where results are dicts suitable for the caller. + """ + async with self.get_db() as session: + session: AsyncSession + pattern = f"%{query}%" + + # 1) Title matches + title_q = ( + select(PlatformSession) + .where(col(PlatformSession.creator) == creator) + .where(col(PlatformSession.display_name).ilike(pattern)) + .order_by(desc(PlatformSession.updated_at)) + ) + title_result = await session.execute(title_q) + title_rows = title_result.scalars().all() + + results: list[dict] = [] + for session_row in title_rows: + title = session_row.display_name or "" + title_lower = title.lower() + qlower = query.lower() + match_index = title_lower.find(qlower) if title else -1 + snippet, snippet_start = self._build_snippet(title, match_index, len(query), context_len) + results.append( + { + "session_id": session_row.session_id, + "title": session_row.display_name, + "match_field": "title", + "match_index": match_index, + "match_length": len(query), + "snippet": snippet, + "snippet_start": snippet_start, + "created_at": session_row.created_at.astimezone().isoformat(), + "updated_at": session_row.updated_at.astimezone().isoformat(), + } + ) + + # 2) Content matches: find latest matching message per session (user_id) + # Use a subquery to select the latest message id per user that matches the pattern + subq = ( + select(func.max(col(PlatformMessageHistory.id)).label("max_id")) + .select_from(PlatformMessageHistory) + .join( + PlatformSession, + col(PlatformMessageHistory.user_id) == col(PlatformSession.session_id), + ) + .where(col(PlatformSession.creator) == creator) + .where(col(PlatformMessageHistory.content).ilike(pattern)) + .group_by(col(PlatformMessageHistory.user_id)) + ) + + ids_result = await session.execute(subq) + id_rows = [r[0] for r in ids_result.fetchall() if r[0] is not None] + + if id_rows: + q = select(PlatformMessageHistory).where(col(PlatformMessageHistory.id).in_(id_rows)) + q = q.order_by(desc(PlatformMessageHistory.created_at)) + hist_result = await session.execute(q) + histories = hist_result.scalars().all() + + for history in histories: + # find associated session to get display_name/created_at/updated_at + ps_q = select(PlatformSession).where(col(PlatformSession.session_id) == history.user_id) + ps_res = await session.execute(ps_q) + ps = ps_res.scalar_one_or_none() + text = None + try: + # convert content json to plain text similar to ChatRoute._extract_plain_text + msg = history.content + if isinstance(msg, dict): + message = msg.get("message") + if isinstance(message, str): + text = message + elif isinstance(message, list): + parts = [] + for part in message: + if not isinstance(part, dict): + continue + part_type = part.get("type") + if part_type == "plain" and part.get("text"): + parts.append(str(part.get("text"))) + elif part_type == "reply" and part.get("selected_text"): + parts.append(str(part.get("selected_text"))) + text = "\n".join(parts) + except Exception: + text = None + + if not text: + # fallback to stringified JSON + text = json.dumps(history.content, ensure_ascii=False) + + lower_text = text.lower() if isinstance(text, str) else "" + match_index = lower_text.find(query.lower()) + if match_index == -1: + continue + snippet, snippet_start = self._build_snippet(text, match_index, len(query), context_len) + results.append( + { + "session_id": history.user_id, + "title": ps.display_name if ps else None, + "match_field": "content", + "match_index": match_index, + "match_length": len(query), + "snippet": snippet, + "snippet_start": snippet_start, + "created_at": ps.created_at.astimezone().isoformat() if ps else history.created_at.astimezone().isoformat(), + "updated_at": ps.updated_at.astimezone().isoformat() if ps else history.updated_at.astimezone().isoformat(), + } + ) + + # sort and paginate + results.sort(key=lambda item: item["updated_at"], reverse=True) + total = len(results) + offset = (page - 1) * page_size + paged = results[offset : offset + page_size] + return paged, total + async def get_platform_message_history_by_id( self, message_id: int ) -> PlatformMessageHistory | None: diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index 92ff4c3fe..805a524c6 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -46,6 +46,7 @@ class ChatRoute(Route): "POST", self.update_session_display_name, ), + "/chat/search": ("GET", self.search_sessions), "/chat/get_file": ("GET", self.get_file), "/chat/get_attachment": ("GET", self.get_attachment), "/chat/post_file": ("POST", self.post_file), @@ -63,6 +64,35 @@ class ChatRoute(Route): self.running_convs: dict[str, bool] = {} + @staticmethod + def _extract_plain_text(content: dict) -> str: + if not isinstance(content, dict): + return "" + message = content.get("message") + if isinstance(message, str): + return message + if not isinstance(message, list): + return "" + + parts = [] + for part in message: + if not isinstance(part, dict): + continue + part_type = part.get("type") + if part_type == "plain" and part.get("text"): + parts.append(str(part.get("text"))) + elif part_type == "reply" and part.get("selected_text"): + parts.append(str(part.get("selected_text"))) + return "\n".join(parts) + + @staticmethod + def _build_snippet(text: str, match_index: int, match_length: int, context_len: int): + if match_index < 0: + return "", 0 + start = max(match_index - context_len, 0) + end = min(match_index + match_length + context_len, len(text)) + return text[start:end], start + async def get_file(self): filename = request.args.get("filename") if not filename: @@ -731,6 +761,57 @@ class ChatRoute(Route): return Response().ok(data=sessions_data).__dict__ + async def search_sessions(self): + """Search sessions by title or content, with pagination.""" + username = g.get("username", "guest") + query = request.args.get("query", "", type=str).strip() + page = max(request.args.get("page", 1, type=int), 1) + page_size = min(max(request.args.get("page_size", 10, type=int), 1), 100) + context_len = min(max(request.args.get("context", 40, type=int), 0), 200) + + if not query: + return ( + Response() + .ok( + data={ + "results": [], + "pagination": { + "page": page, + "page_size": page_size, + "total": 0, + "total_pages": 1, + }, + } + ) + .__dict__ + ) + + # Delegate searching to the database implementation for efficiency + paged_results, total = await self.db.search_platform_sessions( + creator=username, + query=query, + context_len=context_len, + page=page, + page_size=page_size, + ) + + total_pages = (total + page_size - 1) // page_size if total > 0 else 1 + return ( + Response() + .ok( + data={ + "results": paged_results, + "pagination": { + "page": page, + "page_size": page_size, + "total": total, + "total_pages": total_pages, + }, + } + ) + .__dict__ + ) + async def get_session(self): """Get session information and message history by session_id.""" session_id = request.args.get("session_id") diff --git a/dashboard/src/components/chat/Chat.vue b/dashboard/src/components/chat/Chat.vue index 71e46e690..d650b1e50 100644 --- a/dashboard/src/components/chat/Chat.vue +++ b/dashboard/src/components/chat/Chat.vue @@ -26,6 +26,7 @@ @createProject="showCreateProjectDialog" @editProject="showEditProjectDialog" @deleteProject="handleDeleteProject" + @openSearch="handleOpenSearch" /> @@ -42,65 +43,99 @@ - -
- - -