feat: 添加分页和搜索功能以获取会话列表,优化前端与后端的数据交互 (#2906)

* feat: 添加分页和搜索功能以获取会话列表,优化前端与后端的数据交互

* fix: 修复会话计数显示,使用总项数替代会话数组长度

* fix: 将参数类型和名称与实现内容匹配。

* perf: convert for loop into list comprehension

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* fix: type checking error

* fix: 优化 persona_id 的获取逻辑

---------

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
Co-authored-by: Soulter <905617992@qq.com>
This commit is contained in:
ctrlkk
2025-09-28 23:25:30 +08:00
committed by GitHub
parent 9c6b31e71c
commit 68ff8951de
5 changed files with 260 additions and 100 deletions
+12 -1
View File
@@ -164,7 +164,7 @@ class BaseDatabase(abc.ABC):
self,
platform_id: str,
user_id: str,
content: list[dict],
content: dict,
sender_id: str | None = None,
sender_name: str | None = None,
) -> None:
@@ -287,3 +287,14 @@ class BaseDatabase(abc.ABC):
# async def get_llm_messages(self, cid: str) -> list[LLMMessage]:
# """Get all LLM messages for a specific conversation."""
# ...
@abc.abstractmethod
async def get_session_conversations(
self,
page: int = 1,
page_size: int = 20,
search_query: str | None = None,
platform: str | None = None,
) -> tuple[list[dict], int]:
"""Get paginated session conversations with joined conversation and persona details, support search and platform filter."""
...
+8 -4
View File
@@ -75,7 +75,9 @@ class Persona(SQLModel, table=True):
__tablename__ = "personas"
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
id: int | None = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
)
persona_id: str = Field(max_length=255, nullable=False)
system_prompt: str = Field(sa_type=Text, nullable=False)
begin_dialogs: Optional[list] = Field(default=None, sa_type=JSON)
@@ -135,7 +137,9 @@ class PlatformMessageHistory(SQLModel, table=True):
__tablename__ = "platform_message_history"
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
id: int | None = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
)
platform_id: str = Field(nullable=False)
user_id: str = Field(nullable=False) # An id of group, user in platform
sender_id: Optional[str] = Field(default=None) # ID of the sender in the platform
@@ -158,8 +162,8 @@ class Attachment(SQLModel, table=True):
__tablename__ = "attachments"
inner_attachment_id: int = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}
inner_attachment_id: int | None = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
)
attachment_id: str = Field(
max_length=36,
+144 -34
View File
@@ -15,10 +15,8 @@ from astrbot.core.db.po import (
SQLModel,
)
from sqlalchemy import select, update, delete, text
from sqlmodel import select, update, delete, text, func, or_, desc, col
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import func
from sqlalchemy import or_
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
@@ -42,10 +40,10 @@ class SQLiteDatabase(BaseDatabase):
async def insert_platform_stats(
self,
platform_id: str,
platform_type: str,
count: int = 1,
timestamp: datetime = None,
platform_id,
platform_type,
count=1,
timestamp=None,
) -> None:
"""Insert a new platform statistic record."""
async with self.get_db() as session:
@@ -76,7 +74,9 @@ class SQLiteDatabase(BaseDatabase):
async with self.get_db() as session:
session: AsyncSession
result = await session.execute(
select(func.count(PlatformStat.platform_id)).select_from(PlatformStat)
select(func.count(col(PlatformStat.platform_id))).select_from(
PlatformStat
)
)
count = result.scalar_one_or_none()
return count if count is not None else 0
@@ -96,7 +96,7 @@ class SQLiteDatabase(BaseDatabase):
"""),
{"start_time": start_time},
)
return result.scalars().all()
return list(result.scalars().all())
# ====
# Conversation Management
@@ -112,7 +112,7 @@ class SQLiteDatabase(BaseDatabase):
if platform_id:
query = query.where(ConversationV2.platform_id == platform_id)
# order by
query = query.order_by(ConversationV2.created_at.desc())
query = query.order_by(desc(ConversationV2.created_at))
result = await session.execute(query)
return result.scalars().all()
@@ -130,7 +130,7 @@ class SQLiteDatabase(BaseDatabase):
offset = (page - 1) * page_size
result = await session.execute(
select(ConversationV2)
.order_by(ConversationV2.created_at.desc())
.order_by(desc(ConversationV2.created_at))
.offset(offset)
.limit(page_size)
)
@@ -151,25 +151,25 @@ class SQLiteDatabase(BaseDatabase):
if platform_ids:
base_query = base_query.where(
ConversationV2.platform_id.in_(platform_ids)
col(ConversationV2.platform_id).in_(platform_ids)
)
if search_query:
search_query = search_query.encode("unicode_escape").decode("utf-8")
base_query = base_query.where(
or_(
ConversationV2.title.ilike(f"%{search_query}%"),
ConversationV2.content.ilike(f"%{search_query}%"),
ConversationV2.user_id.ilike(f"%{search_query}%"),
col(ConversationV2.title).ilike(f"%{search_query}%"),
col(ConversationV2.content).ilike(f"%{search_query}%"),
col(ConversationV2.user_id).ilike(f"%{search_query}%"),
)
)
if "message_types" in kwargs and len(kwargs["message_types"]) > 0:
for msg_type in kwargs["message_types"]:
base_query = base_query.where(
ConversationV2.user_id.ilike(f"%:{msg_type}:%")
col(ConversationV2.user_id).ilike(f"%:{msg_type}:%")
)
if "platforms" in kwargs and len(kwargs["platforms"]) > 0:
base_query = base_query.where(
ConversationV2.platform_id.in_(kwargs["platforms"])
col(ConversationV2.platform_id).in_(kwargs["platforms"])
)
# Get total count matching the filters
@@ -180,7 +180,7 @@ class SQLiteDatabase(BaseDatabase):
# Get paginated results
offset = (page - 1) * page_size
result_query = (
base_query.order_by(ConversationV2.created_at.desc())
base_query.order_by(desc(ConversationV2.created_at))
.offset(offset)
.limit(page_size)
)
@@ -226,7 +226,7 @@ class SQLiteDatabase(BaseDatabase):
session: AsyncSession
async with session.begin():
query = update(ConversationV2).where(
ConversationV2.conversation_id == cid
col(ConversationV2.conversation_id) == cid
)
values = {}
if title is not None:
@@ -246,7 +246,9 @@ class SQLiteDatabase(BaseDatabase):
session: AsyncSession
async with session.begin():
await session.execute(
delete(ConversationV2).where(ConversationV2.conversation_id == cid)
delete(ConversationV2).where(
col(ConversationV2.conversation_id) == cid
)
)
async def delete_conversations_by_user_id(self, user_id: str) -> None:
@@ -254,9 +256,116 @@ class SQLiteDatabase(BaseDatabase):
session: AsyncSession
async with session.begin():
await session.execute(
delete(ConversationV2).where(ConversationV2.user_id == user_id)
delete(ConversationV2).where(col(ConversationV2.user_id) == user_id)
)
async def get_session_conversations(
self,
page=1,
page_size=20,
search_query=None,
platform=None,
) -> tuple[list[dict], int]:
"""Get paginated session conversations with joined conversation and persona details."""
async with self.get_db() as session:
session: AsyncSession
offset = (page - 1) * page_size
base_query = (
select(
col(Preference.scope_id).label("session_id"),
func.json_extract(Preference.value, "$.val").label(
"conversation_id"
), # type: ignore
col(ConversationV2.persona_id).label("persona_id"),
col(ConversationV2.title).label("title"),
col(Persona.persona_id).label("persona_name"),
)
.select_from(Preference)
.outerjoin(
ConversationV2,
func.json_extract(Preference.value, "$.val")
== ConversationV2.conversation_id,
)
.outerjoin(
Persona, col(ConversationV2.persona_id) == Persona.persona_id
)
.where(Preference.scope == "umo", Preference.key == "sel_conv_id")
)
# 搜索筛选
if search_query:
search_pattern = f"%{search_query}%"
base_query = base_query.where(
or_(
col(Preference.scope_id).ilike(search_pattern),
col(ConversationV2.title).ilike(search_pattern),
col(Persona.persona_id).ilike(search_pattern),
)
)
# 平台筛选
if platform:
platform_pattern = f"{platform}:%"
base_query = base_query.where(
col(Preference.scope_id).like(platform_pattern)
)
# 排序
base_query = base_query.order_by(Preference.scope_id)
# 分页结果
result_query = base_query.offset(offset).limit(page_size)
result = await session.execute(result_query)
rows = result.fetchall()
# 查询总数(应用相同的筛选条件)
count_base_query = (
select(func.count(col(Preference.scope_id)))
.select_from(Preference)
.outerjoin(
ConversationV2,
func.json_extract(Preference.value, "$.val")
== ConversationV2.conversation_id,
)
.outerjoin(
Persona, col(ConversationV2.persona_id) == Persona.persona_id
)
.where(Preference.scope == "umo", Preference.key == "sel_conv_id")
)
# 应用相同的搜索和平台筛选条件到计数查询
if search_query:
search_pattern = f"%{search_query}%"
count_base_query = count_base_query.where(
or_(
col(Preference.scope_id).ilike(search_pattern),
col(ConversationV2.title).ilike(search_pattern),
col(Persona.persona_id).ilike(search_pattern),
)
)
if platform:
platform_pattern = f"{platform}:%"
count_base_query = count_base_query.where(
col(Preference.scope_id).like(platform_pattern)
)
total_result = await session.execute(count_base_query)
total = total_result.scalar() or 0
sessions_data = [
{
"session_id": row.session_id,
"conversation_id": row.conversation_id,
"persona_id": row.persona_id,
"title": row.title,
"persona_name": row.persona_name,
}
for row in rows
]
return sessions_data, total
async def insert_platform_message_history(
self,
platform_id,
@@ -290,9 +399,9 @@ class SQLiteDatabase(BaseDatabase):
cutoff_time = now - timedelta(seconds=offset_sec)
await session.execute(
delete(PlatformMessageHistory).where(
PlatformMessageHistory.platform_id == platform_id,
PlatformMessageHistory.user_id == user_id,
PlatformMessageHistory.created_at < cutoff_time,
col(PlatformMessageHistory.platform_id) == platform_id,
col(PlatformMessageHistory.user_id) == user_id,
col(PlatformMessageHistory.created_at) < cutoff_time,
)
)
@@ -309,7 +418,7 @@ class SQLiteDatabase(BaseDatabase):
PlatformMessageHistory.platform_id == platform_id,
PlatformMessageHistory.user_id == user_id,
)
.order_by(PlatformMessageHistory.created_at.desc())
.order_by(desc(PlatformMessageHistory.created_at))
)
result = await session.execute(query.offset(offset).limit(page_size))
return result.scalars().all()
@@ -331,7 +440,7 @@ class SQLiteDatabase(BaseDatabase):
"""Get an attachment by its ID."""
async with self.get_db() as session:
session: AsyncSession
query = select(Attachment).where(Attachment.id == attachment_id)
query = select(Attachment).where(Attachment.attachment_id == attachment_id)
result = await session.execute(query)
return result.scalar_one_or_none()
@@ -374,7 +483,7 @@ class SQLiteDatabase(BaseDatabase):
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
query = update(Persona).where(Persona.persona_id == persona_id)
query = update(Persona).where(col(Persona.persona_id) == persona_id)
values = {}
if system_prompt is not None:
values["system_prompt"] = system_prompt
@@ -394,7 +503,7 @@ class SQLiteDatabase(BaseDatabase):
session: AsyncSession
async with session.begin():
await session.execute(
delete(Persona).where(Persona.persona_id == persona_id)
delete(Persona).where(col(Persona.persona_id) == persona_id)
)
async def insert_preference_or_update(self, scope, scope_id, key, value):
@@ -449,9 +558,9 @@ class SQLiteDatabase(BaseDatabase):
async with session.begin():
await session.execute(
delete(Preference).where(
Preference.scope == scope,
Preference.scope_id == scope_id,
Preference.key == key,
col(Preference.scope) == scope,
col(Preference.scope_id) == scope_id,
col(Preference.key) == key,
)
)
await session.commit()
@@ -463,7 +572,8 @@ class SQLiteDatabase(BaseDatabase):
async with session.begin():
await session.execute(
delete(Preference).where(
Preference.scope == scope, Preference.scope_id == scope_id
col(Preference.scope) == scope,
col(Preference.scope_id) == scope_id,
)
)
await session.commit()
@@ -490,7 +600,7 @@ class SQLiteDatabase(BaseDatabase):
DeprecatedPlatformStat(
name=data.platform_id,
count=data.count,
timestamp=data.timestamp.timestamp(),
timestamp=int(data.timestamp.timestamp()),
)
)
return deprecated_stats
@@ -548,7 +658,7 @@ class SQLiteDatabase(BaseDatabase):
DeprecatedPlatformStat(
name=platform_id,
count=count,
timestamp=start_time.timestamp(),
timestamp=int(start_time.timestamp()),
)
)
return deprecated_stats
+37 -29
View File
@@ -20,6 +20,7 @@ class SessionManagementRoute(Route):
core_lifecycle: AstrBotCoreLifecycle,
) -> None:
super().__init__(context)
self.db_helper = db_helper
self.routes = {
"/session/list": ("GET", self.list_sessions),
"/session/update_persona": ("POST", self.update_session_persona),
@@ -39,22 +40,42 @@ class SessionManagementRoute(Route):
async def list_sessions(self):
"""获取所有会话的列表,包括 persona 和 provider 信息"""
try:
preferences = await sp.session_get(umo=None, key="sel_conv_id", default=[])
session_conversations = {}
for pref in preferences:
session_conversations[pref.scope_id] = pref.value["val"]
page = int(request.args.get("page", 1))
page_size = int(request.args.get("page_size", 20))
search_query = request.args.get("search", "")
platform = request.args.get("platform", "")
# 获取活跃的会话数据(处于对话内的会话)
sessions_data, total = await self.db_helper.get_session_conversations(
page, page_size, search_query, platform
)
provider_manager = self.core_lifecycle.provider_manager
persona_mgr = self.core_lifecycle.persona_mgr
personas = persona_mgr.personas_v3
sessions = []
# 构建会话信息
for session_id, conversation_id in session_conversations.items():
# 循环补充非数据库信息,如 provider 和 session 状态
for data in sessions_data:
session_id = data["session_id"]
conversation_id = data["conversation_id"]
conv_persona_id = data["persona_id"]
title = data["title"]
persona_name = data["persona_name"]
# 处理 persona 显示
if conv_persona_id == "[%None]":
persona_name = "无人格"
else:
default_persona = persona_mgr.selected_default_persona_v3
if default_persona:
persona_name = default_persona["name"]
session_info = {
"session_id": session_id,
"conversation_id": conversation_id,
"persona_id": None,
"persona_id": persona_name,
"chat_provider_id": None,
"stt_provider_id": None,
"tts_provider_id": None,
@@ -79,31 +100,10 @@ class SessionManagementRoute(Route):
"session_raw_name": session_id.split(":")[2]
if session_id.count(":") >= 2
else session_id,
"title": title,
}
# 获取对话信息
conversation = await self.conv_mgr.get_conversation(
unified_msg_origin=session_id, conversation_id=conversation_id
)
if conversation:
session_info["persona_id"] = conversation.persona_id
# 查找 persona 名称
if conversation.persona_id and conversation.persona_id != "[%None]":
for persona in personas:
if persona["name"] == conversation.persona_id:
session_info["persona_id"] = persona["name"]
break
elif conversation.persona_id == "[%None]":
session_info["persona_id"] = "无人格"
else:
# 使用默认人格
default_persona = persona_mgr.selected_default_persona_v3
if default_persona:
session_info["persona_id"] = default_persona["name"]
# 获取 provider 信息
provider_manager = self.core_lifecycle.provider_manager
chat_provider = provider_manager.get_using_provider(
provider_type=ProviderType.CHAT_COMPLETION, umo=session_id
)
@@ -172,6 +172,14 @@ class SessionManagementRoute(Route):
"available_chat_providers": available_chat_providers,
"available_stt_providers": available_stt_providers,
"available_tts_providers": available_tts_providers,
"pagination": {
"page": page,
"page_size": page_size,
"total": total,
"total_pages": (total + page_size - 1) // page_size
if page_size > 0
else 0,
},
}
return Response().ok(result).__dict__
+59 -32
View File
@@ -4,13 +4,13 @@
<v-card flat>
<v-card-title class="d-flex align-center py-3 px-4">
<span class="text-h4">{{ tm('sessions.activeSessions') }}</span>
<v-chip size="small" class="ml-2">{{ sessions.length }} {{ tm('sessions.sessionCount') }}</v-chip>
<v-chip size="small" class="ml-2">{{ totalItems }} {{ tm('sessions.sessionCount') }}</v-chip>
<v-row class="me-4 ms-4" dense>
<v-text-field v-model="searchQuery" prepend-inner-icon="mdi-magnify" :label="tm('search.placeholder')"
hide-details clearable variant="solo-filled" flat class="me-4" density="compact"></v-text-field>
hide-details clearable variant="solo-filled" flat class="me-4" density="compact" @update:model-value="handleSearchChange"></v-text-field>
<v-select v-model="filterPlatform" :items="platformOptions" :label="tm('search.platformFilter')"
hide-details clearable variant="solo-filled" flat class="me-4" style="max-width: 150px;"
density="compact"></v-select>
density="compact" @update:model-value="handlePlatformChange"></v-select>
</v-row>
<v-btn color="primary" prepend-icon="mdi-refresh" variant="tonal" @click="refreshSessions" :loading="loading"
size="small">
@@ -22,8 +22,17 @@
<v-card-text class="pa-0">
<!-- 会话列表 -->
<v-data-table :headers="headers" :items="filteredSessions" :loading="loading" :items-per-page="itemsPerPage" density="compact"
class="elevation-0" style="font-size: 11px;">
<v-data-table-server
:headers="headers"
:items="sessions"
:loading="loading"
:items-per-page="itemsPerPage"
:page="currentPage"
:items-length="totalItems"
@update:options="handlePaginationUpdate"
density="compact"
class="elevation-0"
style="font-size: 11px;">
<!-- 会话启停 -->
<template v-slot:item.session_enabled="{ item }">
@@ -160,7 +169,7 @@
<div class="text-body-2 text-grey-500">{{ tm('sessions.noActiveSessionsDesc') }}</div>
</div>
</template>
</v-data-table>
</v-data-table-server>
</v-card-text>
</v-card>
@@ -357,7 +366,10 @@ export default {
filterPlatform: null,
// 分页相关
currentPage: 1,
itemsPerPage: 10,
totalItems: 0,
totalPages: 0,
// 可用选项
availablePersonas: [],
@@ -424,30 +436,6 @@ export default {
]
},
// 懒加载过滤会话 - 使用客户端分页
filteredSessions() {
let filtered = this.sessions;
// 搜索筛选
if (this.searchQuery) {
const query = this.searchQuery.toLowerCase().trim();
filtered = filtered.filter(session =>
session.session_name.toLowerCase().includes(query) ||
session.platform.toLowerCase().includes(query) ||
session.persona_name?.toLowerCase().includes(query) ||
session.chat_provider_name?.toLowerCase().includes(query) ||
session.session_id.toLowerCase().includes(query)
);
}
// 平台筛选
if (this.filterPlatform) {
filtered = filtered.filter(session => session.platform === this.filterPlatform);
}
return filtered;
},
platformOptions() {
const platforms = [...new Set(this.sessions.map(s => s.platform))];
return platforms.map(p => ({ title: p, value: p }));
@@ -494,7 +482,20 @@ export default {
async loadSessions() {
this.loading = true;
try {
const response = await axios.get('/api/session/list');
const params = {
page: this.currentPage,
page_size: this.itemsPerPage
};
// 添加搜索和平台筛选参数
if (this.searchQuery) {
params.search = this.searchQuery;
}
if (this.filterPlatform) {
params.platform = this.filterPlatform;
}
const response = await axios.get('/api/session/list', { params });
if (response.data.status === 'ok') {
const data = response.data.data;
this.sessions = data.sessions.map(session => ({
@@ -507,6 +508,13 @@ export default {
this.availableChatProviders = data.available_chat_providers;
this.availableSttProviders = data.available_stt_providers;
this.availableTtsProviders = data.available_tts_providers;
// 处理分页信息
if (data.pagination) {
this.totalItems = data.pagination.total;
this.totalPages = data.pagination.total_pages;
this.currentPage = data.pagination.page;
}
} else {
this.showError(response.data.message || this.tm('messages.loadSessionsError'));
}
@@ -679,7 +687,7 @@ export default {
let totalErrorCount = 0;
let allErrorSessions = [];
const sessions = this.filteredSessions;
const sessions = this.sessions;
try {
// 定义批量操作任务
@@ -936,6 +944,25 @@ export default {
session.deleting = false;
},
// 处理分页更新事件
handlePaginationUpdate(options) {
this.currentPage = options.page;
this.itemsPerPage = options.itemsPerPage;
this.loadSessions();
},
// 处理搜索变化
handleSearchChange() {
this.currentPage = 1; // 重置到第一页
this.loadSessions();
},
// 处理平台筛选变化
handlePlatformChange() {
this.currentPage = 1; // 重置到第一页
this.loadSessions();
},
},
}
</script>