refactor: change to platform session

This commit is contained in:
Soulter
2025-11-18 22:37:55 +08:00
parent 0fe87d6b98
commit 0e2adab3fd
9 changed files with 262 additions and 80 deletions
+1 -1
View File
@@ -104,7 +104,7 @@ class AstrBotCoreLifecycle:
logger.error(f"Migration from version 4.5 to 4.6 failed: {e!s}")
logger.error(traceback.format_exc())
# 4.6 to 4.7 migration for webchat sessions and group feature
# 4.6 to 4.7 migration for platform sessions and group feature
try:
await migrate_46_to_47(self.db)
except Exception as e:
+20 -14
View File
@@ -13,10 +13,10 @@ from astrbot.core.db.po import (
ConversationV2,
Persona,
PlatformMessageHistory,
PlatformSession,
PlatformStat,
Preference,
Stats,
WebChatSession,
)
@@ -316,43 +316,49 @@ class BaseDatabase(abc.ABC):
...
# ====
# WebChat Session Management
# Platform Session Management
# ====
@abc.abstractmethod
async def create_webchat_session(
async def create_platform_session(
self,
creator: str,
platform_id: str = "webchat",
session_id: str | None = None,
display_name: str | None = None,
is_group: int = 0,
) -> WebChatSession:
"""Create a new WebChat session."""
) -> PlatformSession:
"""Create a new Platform session."""
...
@abc.abstractmethod
async def get_webchat_session_by_id(self, session_id: str) -> WebChatSession | None:
"""Get a WebChat session by its ID."""
async def get_platform_session_by_id(
self, session_id: str
) -> PlatformSession | None:
"""Get a Platform session by its ID."""
...
@abc.abstractmethod
async def get_webchat_sessions_by_creator(
async def get_platform_sessions_by_creator(
self,
creator: str,
platform_id: str | None = None,
page: int = 1,
page_size: int = 20,
) -> list[WebChatSession]:
"""Get all WebChat sessions for a specific creator (username)."""
) -> list[PlatformSession]:
"""Get all Platform sessions for a specific creator (username) and optionally platform."""
...
@abc.abstractmethod
async def update_webchat_session(
async def update_platform_session(
self,
session_id: str,
display_name: str | None = None,
) -> None:
"""Update a WebChat session's updated_at timestamp."""
"""Update a Platform session's updated_at timestamp and optionally display_name."""
...
@abc.abstractmethod
async def delete_webchat_session(self, session_id: str) -> None:
"""Delete a WebChat session by its ID."""
async def delete_platform_session(self, session_id: str) -> None:
"""Delete a Platform session by its ID."""
...
+17 -13
View File
@@ -1,6 +1,12 @@
"""Migration script from version 4.6 to 4.7.
This migration creates WebChat sessions from existing platform_message_history records.
This migration creates PlatformSession from existing platform_message_history records.
Changes:
- Creates platform_sessions table
- Adds platform_id field (default: 'webchat')
- Adds display_name field
- Session_id format: {platform_id}_{uuid}
"""
from sqlalchemy import func, select
@@ -12,10 +18,10 @@ from astrbot.core.db.po import PlatformMessageHistory
async def migrate_46_to_47(db_helper: BaseDatabase):
"""Migrate WebChat data to the new session table.
"""Create PlatformSession records from platform_message_history.
This migration extracts all unique user_ids from platform_message_history
where platform_id='webchat' and creates corresponding WebChatSession records.
where platform_id='webchat' and creates corresponding PlatformSession records.
"""
# 检查是否已经完成迁移
migration_done = await db_helper.get_preference(
@@ -28,7 +34,7 @@ async def migrate_46_to_47(db_helper: BaseDatabase):
try:
async with db_helper.get_db() as session:
# 1. 查询所有 webchat 的唯一 user_id 以及它们的最早和最新消息时间
# 从 platform_message_history 创建 PlatformSession
query = (
select(
col(PlatformMessageHistory.user_id),
@@ -51,7 +57,7 @@ async def migrate_46_to_47(db_helper: BaseDatabase):
logger.info(f"找到 {len(webchat_users)} 个 WebChat 会话需要迁移")
# 2. 为每个 user_id 创建 WebChatSession 记录
# 为每个 user_id 创建 PlatformSession 记录
migrated_count = 0
skipped_count = 0
@@ -60,28 +66,26 @@ async def migrate_46_to_47(db_helper: BaseDatabase):
session_id = user_id
# sender_name 通常是 username,但可能为 None
# 从第一条消息中提取 creator
creator = sender_name if sender_name else "guest"
# 检查是否已经存在该会话
existing_session = await db_helper.get_webchat_session_by_id(session_id)
existing_session = await db_helper.get_platform_session_by_id(
session_id
)
if existing_session:
logger.debug(f"会话 {session_id} 已存在,跳过")
skipped_count += 1
continue
# 创建新的 WebChatSession
# 创建新的 PlatformSession
try:
await db_helper.create_webchat_session(
await db_helper.create_platform_session(
creator=creator,
session_id=session_id,
platform_id="webchat",
is_group=0,
)
# 更新时间戳以匹配历史记录
# 注意:这里我们需要直接更新数据库,因为 create 方法会设置当前时间
# 但我们希望保留原始的创建和更新时间
migrated_count += 1
if migrated_count % 100 == 0:
+12 -8
View File
@@ -155,14 +155,14 @@ class PlatformMessageHistory(SQLModel, table=True):
)
class WebChatSession(SQLModel, table=True):
"""WebChat session table for managing user sessions.
class PlatformSession(SQLModel, table=True):
"""Platform session table for managing user sessions across different platforms.
A session represents a chat window for a specific user. Each session can have
multiple conversations (对话) associated with it.
A session represents a chat window for a specific user on a specific platform.
Each session can have multiple conversations (对话) associated with it.
"""
__tablename__ = "webchat_sessions"
__tablename__ = "platform_sessions"
inner_id: int | None = Field(
primary_key=True,
@@ -170,13 +170,17 @@ class WebChatSession(SQLModel, table=True):
default=None,
)
session_id: str = Field(
max_length=36,
max_length=100,
nullable=False,
unique=True,
default_factory=lambda: str(uuid.uuid4()),
default_factory=lambda: f"webchat_{uuid.uuid4()}",
)
platform_id: str = Field(default="webchat", nullable=False)
"""Platform identifier (e.g., 'webchat', 'qq', 'discord')"""
creator: str = Field(nullable=False)
"""Username of the session creator"""
display_name: str | None = Field(default=None, max_length=255)
"""Display name for the session"""
is_group: int = Field(default=0, nullable=False)
"""0 for private chat, 1 for group chat (not implemented yet)"""
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
@@ -188,7 +192,7 @@ class WebChatSession(SQLModel, table=True):
__table_args__ = (
UniqueConstraint(
"session_id",
name="uix_webchat_session_id",
name="uix_platform_session_id",
),
)
+44 -25
View File
@@ -1,6 +1,7 @@
import asyncio
import threading
import typing as T
import uuid
from datetime import datetime, timedelta
from sqlalchemy.ext.asyncio import AsyncSession
@@ -12,10 +13,10 @@ from astrbot.core.db.po import (
ConversationV2,
Persona,
PlatformMessageHistory,
PlatformSession,
PlatformStat,
Preference,
SQLModel,
WebChatSession,
)
from astrbot.core.db.po import (
Platform as DeprecatedPlatformStat,
@@ -712,25 +713,32 @@ class SQLiteDatabase(BaseDatabase):
return result
# ====
# WebChat Session Management
# Platform Session Management
# ====
async def create_webchat_session(
async def create_platform_session(
self,
creator: str,
platform_id: str = "webchat",
session_id: str | None = None,
display_name: str | None = None,
is_group: int = 0,
) -> WebChatSession:
"""Create a new WebChat session."""
) -> PlatformSession:
"""Create a new Platform session."""
kwargs = {}
if session_id:
kwargs["session_id"] = session_id
else:
# Auto-generate session_id with platform_id prefix
kwargs["session_id"] = f"{platform_id}_{uuid.uuid4()}"
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
new_session = WebChatSession(
new_session = PlatformSession(
creator=creator,
platform_id=platform_id,
display_name=display_name,
is_group=is_group,
**kwargs,
)
@@ -739,57 +747,68 @@ class SQLiteDatabase(BaseDatabase):
await session.refresh(new_session)
return new_session
async def get_webchat_session_by_id(self, session_id: str) -> WebChatSession | None:
"""Get a WebChat session by its ID."""
async def get_platform_session_by_id(
self, session_id: str
) -> PlatformSession | None:
"""Get a Platform session by its ID."""
async with self.get_db() as session:
session: AsyncSession
query = select(WebChatSession).where(
WebChatSession.session_id == session_id,
query = select(PlatformSession).where(
PlatformSession.session_id == session_id,
)
result = await session.execute(query)
return result.scalar_one_or_none()
async def get_webchat_sessions_by_creator(
async def get_platform_sessions_by_creator(
self,
creator: str,
platform_id: str | None = None,
page: int = 1,
page_size: int = 20,
) -> list[WebChatSession]:
"""Get all WebChat sessions for a specific creator (username)."""
) -> list[PlatformSession]:
"""Get all Platform sessions for a specific creator (username) and optionally platform."""
async with self.get_db() as session:
session: AsyncSession
offset = (page - 1) * page_size
query = select(PlatformSession).where(PlatformSession.creator == creator)
if platform_id:
query = query.where(PlatformSession.platform_id == platform_id)
query = (
select(WebChatSession)
.where(WebChatSession.creator == creator)
.order_by(desc(WebChatSession.updated_at))
query.order_by(desc(PlatformSession.updated_at))
.offset(offset)
.limit(page_size)
)
result = await session.execute(query)
return list(result.scalars().all())
async def update_webchat_session(
async def update_platform_session(
self,
session_id: str,
display_name: str | None = None,
) -> None:
"""Update a WebChat session's updated_at timestamp."""
"""Update a Platform session's updated_at timestamp and optionally display_name."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
values = {"updated_at": datetime.now()}
if display_name is not None:
values["display_name"] = display_name
await session.execute(
update(WebChatSession)
.where(WebChatSession.session_id == session_id)
.values(updated_at=datetime.now()),
update(PlatformSession)
.where(PlatformSession.session_id == session_id)
.values(**values),
)
async def delete_webchat_session(self, session_id: str) -> None:
"""Delete a WebChat session by its ID."""
async def delete_platform_session(self, session_id: str) -> None:
"""Delete a Platform session by its ID."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
await session.execute(
delete(WebChatSession).where(
WebChatSession.session_id == session_id,
delete(PlatformSession).where(
PlatformSession.session_id == session_id,
),
)
+70 -13
View File
@@ -39,6 +39,10 @@ class ChatRoute(Route):
"/chat/sessions": ("GET", self.get_sessions),
"/chat/get_session": ("GET", self.get_session),
"/chat/delete_session": ("GET", self.delete_webchat_session),
"/chat/update_session_display_name": (
"POST",
self.update_session_display_name,
),
"/chat/get_file": ("GET", self.get_file),
"/chat/post_image": ("POST", self.post_image),
"/chat/post_file": ("POST", self.post_file),
@@ -246,56 +250,74 @@ class ChatRoute(Route):
return response
async def delete_webchat_session(self):
"""Delete a WebChat session and all its related data."""
"""Delete a Platform session and all its related data."""
session_id = request.args.get("session_id")
if not session_id:
return Response().error("Missing key: session_id").__dict__
username = g.get("username", "guest")
# 验证会话是否存在且属于当前用户
session = await self.db.get_webchat_session_by_id(session_id)
session = await self.db.get_platform_session_by_id(session_id)
if not session:
return Response().error(f"Session {session_id} not found").__dict__
if session.creator != username:
return Response().error("Permission denied").__dict__
# 删除该会话下的所有对话
unified_msg_origin = f"webchat:FriendMessage:webchat!{username}!{session_id}"
unified_msg_origin = f"{session.platform_id}:FriendMessage:{session.platform_id}!{username}!{session_id}"
await self.conv_mgr.delete_conversations_by_user_id(unified_msg_origin)
# 删除消息历史
await self.platform_history_mgr.delete(
platform_id="webchat",
platform_id=session.platform_id,
user_id=session_id,
offset_sec=99999999,
)
# 清理队列
webchat_queue_mgr.remove_queues(session_id)
# 清理队列(仅对 webchat
if session.platform_id == "webchat":
webchat_queue_mgr.remove_queues(session_id)
# 删除会话
await self.db.delete_webchat_session(session_id)
await self.db.delete_platform_session(session_id)
return Response().ok().__dict__
async def new_session(self):
"""Create a new WebChat session."""
"""Create a new Platform session (default: webchat)."""
username = g.get("username", "guest")
# 获取可选的 platform_id 参数,默认为 webchat
platform_id = request.args.get("platform_id", "webchat")
# 创建新会话
session = await self.db.create_webchat_session(
session = await self.db.create_platform_session(
creator=username,
platform_id=platform_id,
is_group=0,
)
return Response().ok(data={"session_id": session.session_id}).__dict__
return (
Response()
.ok(
data={
"session_id": session.session_id,
"platform_id": session.platform_id,
}
)
.__dict__
)
async def get_sessions(self):
"""Get all WebChat sessions for the current user."""
"""Get all Platform sessions for the current user."""
username = g.get("username", "guest")
sessions = await self.db.get_webchat_sessions_by_creator(
# 获取可选的 platform_id 参数
platform_id = request.args.get("platform_id")
sessions = await self.db.get_platform_sessions_by_creator(
creator=username,
platform_id=platform_id,
page=1,
page_size=100, # 暂时返回前100个
)
@@ -306,7 +328,9 @@ class ChatRoute(Route):
sessions_data.append(
{
"session_id": session.session_id,
"platform_id": session.platform_id,
"creator": session.creator,
"display_name": session.display_name,
"is_group": session.is_group,
"created_at": int(session.created_at.timestamp()),
"updated_at": int(session.updated_at.timestamp()),
@@ -321,9 +345,13 @@ class ChatRoute(Route):
if not session_id:
return Response().error("Missing key: session_id").__dict__
# 获取会话信息以确定 platform_id
session = await self.db.get_platform_session_by_id(session_id)
platform_id = session.platform_id if session else "webchat"
# Get platform message history using session_id
history_ls = await self.platform_history_mgr.get(
platform_id="webchat",
platform_id=platform_id,
user_id=session_id,
page=1,
page_size=1000,
@@ -341,3 +369,32 @@ class ChatRoute(Route):
)
.__dict__
)
async def update_session_display_name(self):
"""Update a Platform session's display name."""
post_data = await request.json
session_id = post_data.get("session_id")
display_name = post_data.get("display_name")
if not session_id:
return Response().error("Missing key: session_id").__dict__
if display_name is None:
return Response().error("Missing key: display_name").__dict__
username = g.get("username", "guest")
# 验证会话是否存在且属于当前用户
session = await self.db.get_platform_session_by_id(session_id)
if not session:
return Response().error(f"Session {session_id} not found").__dict__
if session.creator != username:
return Response().error("Permission denied").__dict__
# 更新 display_name
await self.db.update_platform_session(
session_id=session_id,
display_name=display_name,
)
return Response().ok().__dict__
+88 -4
View File
@@ -59,7 +59,7 @@
<v-list-item v-for="(session, i) in sessions" :key="session.session_id" :value="session.session_id"
rounded="lg" class="session-item" active-color="secondary">
<v-list-item-title v-if="!sidebarCollapsed || isMobile" class="session-title">
{{ tm('conversation.newConversation') }}
{{ session.display_name || tm('conversation.newConversation') }}
</v-list-item-title>
<v-list-item-subtitle v-if="!sidebarCollapsed || isMobile" class="timestamp">
{{ formatDate(session.updated_at) }}
@@ -67,6 +67,9 @@
<template v-if="!sidebarCollapsed || isMobile" v-slot:append>
<div class="session-actions">
<v-btn icon="mdi-pencil" size="x-small" variant="text"
class="edit-session-btn" color="primary"
@click.stop="editSessionDisplayName(session)" />
<v-btn icon="mdi-delete" size="x-small" variant="text"
class="delete-session-btn" color="error"
@click.stop="deleteSession(session.session_id)" />
@@ -228,6 +231,34 @@
</v-card-text>
</v-card>
</v-dialog>
<!-- 编辑会话名称对话框 -->
<v-dialog v-model="editDisplayNameDialog" max-width="500px">
<v-card>
<v-card-title class="d-flex justify-space-between align-center pa-4">
<span>{{ tm('conversation.editDisplayName') }}</span>
<v-btn icon="mdi-close" variant="text" @click="editDisplayNameDialog = false" />
</v-card-title>
<v-card-text class="pa-4">
<v-text-field
v-model="editingDisplayName"
:label="tm('conversation.displayName')"
variant="outlined"
density="comfortable"
@keydown.enter="saveDisplayName"
/>
</v-card-text>
<v-card-actions class="pa-4">
<v-spacer />
<v-btn variant="text" @click="editDisplayNameDialog = false">
{{ t('core.common.cancel') }}
</v-btn>
<v-btn color="primary" variant="flat" @click="saveDisplayName">
{{ t('core.common.save') }}
</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
</template>
<script>
@@ -272,7 +303,7 @@ export default {
return {
prompt: '',
messages: [],
sessions: [], // WebChat 会话列表
sessions: [], // Platform 会话列表
selectedSessions: [], // 当前选中的会话
currSessionId: '', // 当前会话ID
stagedImagesName: [], // 用于存储图片文件名的数组
@@ -318,6 +349,11 @@ export default {
// 手机端相关变量
isMobile: false,
mobileMenuOpen: false,
// 编辑会话名称相关变量
editDisplayNameDialog: false,
editingDisplayName: '',
editingSession: null,
}
},
@@ -1095,6 +1131,37 @@ export default {
});
this.mediaCache = {};
},
// 编辑会话显示名称
editSessionDisplayName(session) {
this.editingSession = session;
this.editingDisplayName = session.display_name || '';
this.editDisplayNameDialog = true;
},
// 保存会话显示名称
async saveDisplayName() {
if (!this.editingSession) return;
try {
await axios.post('/api/chat/update_session_display_name', {
session_id: this.editingSession.session_id,
display_name: this.editingDisplayName,
});
// 更新本地数据
this.editingSession.display_name = this.editingDisplayName;
// 刷新会话列表
this.getSessions();
this.editDisplayNameDialog = false;
useToast().success(this.tm('conversation.displayNameUpdated'), { timeout: 2000 });
} catch (err) {
console.error('更新会话名称失败:', err);
useToast().error(this.tm('conversation.displayNameUpdateFailed'), { timeout: 3000 });
}
},
},
}
</script>
@@ -1368,14 +1435,31 @@ export default {
transition: all 0.2s ease;
}
.session-item:hover .session-actions {
opacity: 1;
visibility: visible;
}
.session-actions {
display: flex;
gap: 4px;
opacity: 0;
visibility: hidden;
transition: all 0.2s ease;
}
.edit-title-btn,
.delete-conversation-btn {
.delete-conversation-btn,
.edit-session-btn,
.delete-session-btn {
opacity: 0.7;
transition: opacity 0.2s ease;
}
.edit-title-btn:hover,
.delete-conversation-btn:hover {
.delete-conversation-btn:hover,
.edit-session-btn:hover,
.delete-session-btn:hover {
opacity: 1;
}
@@ -47,7 +47,11 @@
"noHistory": "No conversation history",
"systemStatus": "System Status",
"llmService": "LLM Service",
"speechToText": "Speech to Text"
"speechToText": "Speech to Text",
"editDisplayName": "Edit Session Name",
"displayName": "Session Name",
"displayNameUpdated": "Session name updated",
"displayNameUpdateFailed": "Failed to update session name"
},
"modes": {
"darkMode": "Switch to Dark Mode",
@@ -47,7 +47,11 @@
"noHistory": "暂无对话历史",
"systemStatus": "系统状态",
"llmService": "LLM 服务",
"speechToText": "语音转文本"
"speechToText": "语音转文本",
"editDisplayName": "编辑会话名称",
"displayName": "会话名称",
"displayNameUpdated": "会话名称已更新",
"displayNameUpdateFailed": "更新会话名称失败"
},
"modes": {
"darkMode": "切换到夜间模式",