diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index a8a4b0ad5..9929e9ce2 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -203,6 +203,23 @@ class BaseDatabase(abc.ABC): """Get platform message history for a specific user.""" ... + @abc.abstractmethod + async def search_platform_sessions( + self, + creator: str, + query: str, + context_len: int = 40, + page: int = 1, + page_size: int = 10, + ) -> tuple[list[dict], int]: + """Search platform sessions (title or message content) for a given creator. + + Returns a tuple of (results, total) where results is a list of dicts with keys: + session_id, title, match_field, match_index, match_length, snippet, snippet_start, + created_at, updated_at + """ + ... + @abc.abstractmethod async def get_platform_message_history_by_id( self, diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index 3af08f248..7899113e3 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -1,4 +1,5 @@ import asyncio +import json import threading import typing as T from collections.abc import Awaitable, Callable @@ -483,6 +484,144 @@ class SQLiteDatabase(BaseDatabase): result = await session.execute(query.offset(offset).limit(page_size)) return result.scalars().all() + def _build_snippet(self, text: str, match_index: int, match_length: int, context_len: int): + if match_index < 0: + return "", 0 + start = max(match_index - context_len, 0) + end = min(match_index + match_length + context_len, len(text)) + return text[start:end], start + + async def search_platform_sessions( + self, + creator: str, + query: str, + context_len: int = 40, + page: int = 1, + page_size: int = 10, + ) -> tuple[list[dict], int]: + """Search platform sessions (by title or by message content) for a given creator. + + This implementation performs searching at DB level using SQL LIKE on + `platform_sessions.display_name` and the JSON `platform_message_history.content`. + To keep work minimal and compatible with SQLite JSON storage, the content + column is searched as text using LIKE. + Returns (results, total) where results are dicts suitable for the caller. + """ + async with self.get_db() as session: + session: AsyncSession + pattern = f"%{query}%" + + # 1) Title matches + title_q = ( + select(PlatformSession) + .where(col(PlatformSession.creator) == creator) + .where(col(PlatformSession.display_name).ilike(pattern)) + .order_by(desc(PlatformSession.updated_at)) + ) + title_result = await session.execute(title_q) + title_rows = title_result.scalars().all() + + results: list[dict] = [] + for session_row in title_rows: + title = session_row.display_name or "" + title_lower = title.lower() + qlower = query.lower() + match_index = title_lower.find(qlower) if title else -1 + snippet, snippet_start = self._build_snippet(title, match_index, len(query), context_len) + results.append( + { + "session_id": session_row.session_id, + "title": session_row.display_name, + "match_field": "title", + "match_index": match_index, + "match_length": len(query), + "snippet": snippet, + "snippet_start": snippet_start, + "created_at": session_row.created_at.astimezone().isoformat(), + "updated_at": session_row.updated_at.astimezone().isoformat(), + } + ) + + # 2) Content matches: find latest matching message per session (user_id) + # Use a subquery to select the latest message id per user that matches the pattern + subq = ( + select(func.max(col(PlatformMessageHistory.id)).label("max_id")) + .select_from(PlatformMessageHistory) + .join( + PlatformSession, + col(PlatformMessageHistory.user_id) == col(PlatformSession.session_id), + ) + .where(col(PlatformSession.creator) == creator) + .where(col(PlatformMessageHistory.content).ilike(pattern)) + .group_by(col(PlatformMessageHistory.user_id)) + ) + + ids_result = await session.execute(subq) + id_rows = [r[0] for r in ids_result.fetchall() if r[0] is not None] + + if id_rows: + q = select(PlatformMessageHistory).where(col(PlatformMessageHistory.id).in_(id_rows)) + q = q.order_by(desc(PlatformMessageHistory.created_at)) + hist_result = await session.execute(q) + histories = hist_result.scalars().all() + + for history in histories: + # find associated session to get display_name/created_at/updated_at + ps_q = select(PlatformSession).where(col(PlatformSession.session_id) == history.user_id) + ps_res = await session.execute(ps_q) + ps = ps_res.scalar_one_or_none() + text = None + try: + # convert content json to plain text similar to ChatRoute._extract_plain_text + msg = history.content + if isinstance(msg, dict): + message = msg.get("message") + if isinstance(message, str): + text = message + elif isinstance(message, list): + parts = [] + for part in message: + if not isinstance(part, dict): + continue + part_type = part.get("type") + if part_type == "plain" and part.get("text"): + parts.append(str(part.get("text"))) + elif part_type == "reply" and part.get("selected_text"): + parts.append(str(part.get("selected_text"))) + text = "\n".join(parts) + except Exception: + text = None + + if not text: + # fallback to stringified JSON + text = json.dumps(history.content, ensure_ascii=False) + + lower_text = text.lower() if isinstance(text, str) else "" + match_index = lower_text.find(query.lower()) + if match_index == -1: + continue + snippet, snippet_start = self._build_snippet(text, match_index, len(query), context_len) + results.append( + { + "session_id": history.user_id, + "title": ps.display_name if ps else None, + "match_field": "content", + "match_index": match_index, + "match_length": len(query), + "snippet": snippet, + "snippet_start": snippet_start, + "created_at": ps.created_at.astimezone().isoformat() if ps else history.created_at.astimezone().isoformat(), + "updated_at": ps.updated_at.astimezone().isoformat() if ps else history.updated_at.astimezone().isoformat(), + } + ) + + # sort and paginate + results.sort(key=lambda item: item["updated_at"], reverse=True) + total = len(results) + offset = (page - 1) * page_size + paged = results[offset : offset + page_size] + return paged, total + async def get_platform_message_history_by_id( self, message_id: int ) -> PlatformMessageHistory | None: diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index 92ff4c3fe..805a524c6 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -46,6 +46,7 @@ class ChatRoute(Route): "POST", self.update_session_display_name, ), + "/chat/search": ("GET", self.search_sessions), "/chat/get_file": ("GET", self.get_file), "/chat/get_attachment": ("GET", self.get_attachment), "/chat/post_file": ("POST", self.post_file), @@ -63,6 +64,35 @@ class ChatRoute(Route): self.running_convs: dict[str, bool] = {} + @staticmethod + def _extract_plain_text(content: dict) -> str: + if not isinstance(content, dict): + return "" + message = content.get("message") + if isinstance(message, str): + return message + if not isinstance(message, list): + return "" + + parts = [] + for part in message: + if not isinstance(part, dict): + continue + part_type = part.get("type") + if part_type == "plain" and part.get("text"): + parts.append(str(part.get("text"))) + elif part_type == "reply" and part.get("selected_text"): + parts.append(str(part.get("selected_text"))) + return "\n".join(parts) + + @staticmethod + def _build_snippet(text: str, match_index: int, match_length: int, context_len: int): + if match_index < 0: + return "", 0 + start = max(match_index - context_len, 0) + end = min(match_index + match_length + context_len, len(text)) + return text[start:end], start + async def get_file(self): filename = request.args.get("filename") if not filename: @@ -731,6 +761,57 @@ class ChatRoute(Route): return Response().ok(data=sessions_data).__dict__ + async def search_sessions(self): + """Search sessions by title or content, with pagination.""" + username = g.get("username", "guest") + query = request.args.get("query", "", type=str).strip() + page = max(request.args.get("page", 1, type=int), 1) + page_size = min(max(request.args.get("page_size", 10, type=int), 1), 100) + context_len = min(max(request.args.get("context", 40, type=int), 0), 200) + + if not query: + return ( + Response() + .ok( + data={ + "results": [], + "pagination": { + "page": page, + "page_size": page_size, + "total": 0, + "total_pages": 1, + }, + } + ) + .__dict__ + ) + + # Delegate searching to the database implementation for efficiency + paged_results, total = await self.db.search_platform_sessions( + creator=username, + query=query, + context_len=context_len, + page=page, + page_size=page_size, + ) + + total_pages = (total + page_size - 1) // page_size if total > 0 else 1 + return ( + Response() + .ok( + data={ + "results": paged_results, + "pagination": { + "page": page, + "page_size": page_size, + "total": total, + "total_pages": total_pages, + }, + } + ) + .__dict__ + ) + async def get_session(self): """Get session information and message history by session_id.""" session_id = request.args.get("session_id") diff --git a/dashboard/src/components/chat/Chat.vue b/dashboard/src/components/chat/Chat.vue index 71e46e690..d650b1e50 100644 --- a/dashboard/src/components/chat/Chat.vue +++ b/dashboard/src/components/chat/Chat.vue @@ -26,6 +26,7 @@ @createProject="showCreateProjectDialog" @editProject="showEditProjectDialog" @deleteProject="handleDeleteProject" + @openSearch="handleOpenSearch" /> @@ -42,65 +43,99 @@ - - - -
- -
-
- - - - + @@ -200,6 +208,7 @@ + + diff --git a/dashboard/src/components/chat/ConversationSidebar.vue b/dashboard/src/components/chat/ConversationSidebar.vue index fe25ef34c..f4e875a25 100644 --- a/dashboard/src/components/chat/ConversationSidebar.vue +++ b/dashboard/src/components/chat/ConversationSidebar.vue @@ -20,6 +20,15 @@ +
+ + {{ t('core.actions.search') }} + + +
+
{{ tm('actions.newChat') }} @@ -178,6 +187,7 @@ const emit = defineEmits<{ createProject: []; editProject: [project: Project]; deleteProject: [projectId: string]; + openSearch: []; }>(); const { t } = useI18n(); @@ -264,6 +274,13 @@ function handleDeleteConversation(session: Session) { padding: 8px 16px !important; } +.search-chat-btn { + justify-content: flex-start; + background-color: transparent !important; + border-radius: 20px; + padding: 8px 16px !important; +} + .conversation-item { /* margin-bottom: 4px; */ border-radius: 20px !important; @@ -359,4 +376,3 @@ function handleDeleteConversation(session: Session) { justify-content: center; } - diff --git a/dashboard/src/i18n/locales/en-US/features/chat.json b/dashboard/src/i18n/locales/en-US/features/chat.json index 385914065..c81412109 100644 --- a/dashboard/src/i18n/locales/en-US/features/chat.json +++ b/dashboard/src/i18n/locales/en-US/features/chat.json @@ -98,6 +98,18 @@ "noSessions": "No conversations in this project", "confirmDelete": "Are you sure you want to delete project \"{title}\"? Conversations in this project will not be deleted." }, + "search": { + "title": "Search", + "placeholder": "Enter keywords to search titles or content", + "hint": "Enter keywords to start searching", + "noResults": "No matching conversations found", + "matchTitle": "Title match", + "matchContent": "Content match", + "matchPosition": "Match position", + "createdAt": "Created", + "updatedAt": "Updated", + "pageSize": "Items per page" + }, "time": { "today": "Today", "yesterday": "Yesterday" @@ -133,4 +145,4 @@ "sendMessageFailed": "Failed to send message, please try again", "createSessionFailed": "Failed to create session, please refresh the page" } -} \ No newline at end of file +} diff --git a/dashboard/src/i18n/locales/zh-CN/features/chat.json b/dashboard/src/i18n/locales/zh-CN/features/chat.json index 086e335a1..e3b13706e 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/chat.json +++ b/dashboard/src/i18n/locales/zh-CN/features/chat.json @@ -100,6 +100,18 @@ "noSessions": "该项目暂无对话", "confirmDelete": "确定要删除项目 \"{title}\" 吗?项目中的对话不会被删除。" }, + "search": { + "title": "搜索", + "placeholder": "输入关键词搜索标题或内容", + "hint": "输入关键词开始搜索", + "noResults": "没有找到匹配的对话", + "matchTitle": "标题匹配", + "matchContent": "内容匹配", + "matchPosition": "匹配位置", + "createdAt": "创建", + "updatedAt": "更新", + "pageSize": "每页条数" + }, "time": { "today": "今天", "yesterday": "昨天" @@ -135,4 +147,4 @@ "sendMessageFailed": "发送消息失败,请重试", "createSessionFailed": "创建会话失败,请刷新页面重试" } -} \ No newline at end of file +} diff --git a/dashboard/src/stores/chatSearch.ts b/dashboard/src/stores/chatSearch.ts new file mode 100644 index 000000000..7d0818f9e --- /dev/null +++ b/dashboard/src/stores/chatSearch.ts @@ -0,0 +1,112 @@ +import { defineStore } from 'pinia'; +import { ref } from 'vue'; +import axios from 'axios'; + +export interface ChatSearchResult { + session_id: string; + title: string | null; + match_field: 'title' | 'content'; + match_index: number; + match_length: number; + snippet: string; + snippet_start: number; + created_at: string; + updated_at: string; +} + +interface ChatSearchPagination { + page: number; + page_size: number; + total: number; + total_pages: number; +} + +const defaultPagination: ChatSearchPagination = { + page: 1, + page_size: 10, + total: 0, + total_pages: 1 +}; + +export const useChatSearchStore = defineStore('chatSearch', () => { + const active = ref(false); + const query = ref(''); + const results = ref([]); + const pagination = ref({ ...defaultPagination }); + const isLoading = ref(false); + const searchPerformed = ref(false); + const contextLength = ref(40); + + function openSearch() { + active.value = true; + } + + function closeSearch() { + active.value = false; + } + + async function search() { + const trimmedQuery = query.value.trim(); + if (!trimmedQuery) { + results.value = []; + pagination.value = { ...defaultPagination }; + searchPerformed.value = false; + return; + } + + searchPerformed.value = true; + isLoading.value = true; + + try { + const response = await axios.get('/api/chat/search', { + params: { + query: trimmedQuery, + page: pagination.value.page, + page_size: pagination.value.page_size, + context: contextLength.value + } + }); + + const data = response.data?.data || {}; + results.value = data.results || []; + pagination.value = data.pagination || { ...defaultPagination }; + } catch (error) { + console.error('Search sessions failed:', error); + results.value = []; + } finally { + isLoading.value = false; + } + } + + async function setPage(page: number) { + pagination.value.page = page; + await search(); + } + + async function setPageSize(pageSize: number) { + pagination.value.page_size = pageSize; + pagination.value.page = 1; + await search(); + } + + async function runNewSearch() { + pagination.value.page = 1; + await search(); + } + + return { + active, + query, + results, + pagination, + isLoading, + searchPerformed, + contextLength, + openSearch, + closeSearch, + search, + setPage, + setPageSize, + runNewSearch + }; +});