AstrBot/astrbot/core/db/sqlite.py
Ruochen Pan 8910ab3a47 feat: implement persona folder for advanced persona management (#4443)
* feat(db): add persona folder management for hierarchical organization

Implement hierarchical folder structure for organizing personas:
- Add PersonaFolder model with recursive parent-child relationships
- Add folder_id and sort_order fields to Persona model
- Implement CRUD operations for persona folders in database layer
- Add migration support for existing databases
- Extend PersonaManager with folder management methods
- Add dashboard API routes for folder operations
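
The migration step follows an idempotent "inspect, then ALTER only what is missing" pattern. The sketch below shows that pattern with the synchronous stdlib `sqlite3` module instead of the async SQLAlchemy connection used in this file. The helper name and its list-of-added-columns return value are illustrative assumptions. The column definitions match the ones used by `_ensure_persona_folder_columns`.

```python
import sqlite3


def ensure_persona_folder_columns(conn: sqlite3.Connection) -> list[str]:
    """Add folder_id/sort_order to an existing personas table if missing.

    Returns the names of the columns that were added (empty when the schema
    is already up to date), so running it on every startup is harmless.
    """
    # PRAGMA table_info yields one row per column; index 1 is the column name.
    existing = {row[1] for row in conn.execute("PRAGMA table_info(personas)")}
    added: list[str] = []
    if "folder_id" not in existing:
        conn.execute(
            "ALTER TABLE personas ADD COLUMN folder_id VARCHAR(36) DEFAULT NULL"
        )
        added.append("folder_id")
    if "sort_order" not in existing:
        conn.execute("ALTER TABLE personas ADD COLUMN sort_order INTEGER DEFAULT 0")
        added.append("sort_order")
    conn.commit()
    return added


if __name__ == "__main__":
    conn = sqlite3.connect(":memory:")
    # A pre-folder schema standing in for an old database (hypothetical columns).
    conn.execute("CREATE TABLE personas (persona_id TEXT PRIMARY KEY, system_prompt TEXT)")
    conn.execute("INSERT INTO personas VALUES ('helper', 'You are helpful.')")
    print(ensure_persona_folder_columns(conn))  # first run adds both columns
    print(ensure_persona_folder_columns(conn))  # second run is a no-op
    print(conn.execute("SELECT folder_id, sort_order FROM personas").fetchall())
```

Rows that existed before the migration pick up the column defaults (`NULL` folder, sort order `0`), so old personas land in the root folder.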

* feat(persona): add batch sort order update endpoint for personas and folders

Add new API endpoint POST /persona/reorder to batch update sort_order
for both personas and folders. This enables drag-and-drop reordering
in the dashboard UI.

Changes:
- Add abstract batch_update_sort_order method to BaseDatabase
- Implement batch_update_sort_order in SQLiteDatabase
- Add batch_update_sort_order to PersonaManager with cache refresh
- Add reorder_items route handler with input validation
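
A drag-and-drop reorder sends many `{id, type, sort_order}` items at once. Applying them in a single transaction means a partial failure cannot leave the list half-reordered. The sketch below mirrors the shape of `batch_update_sort_order` using stdlib `sqlite3`. The `persona_folders` table name is an assumption, since the `PersonaFolder` table name is not shown here, and the returned row count is an illustrative addition.

```python
import sqlite3
from typing import Any

# Maps an item's "type" to (table, key column). The table names are assumed.
_TARGETS = {
    "persona": ("personas", "persona_id"),
    "folder": ("persona_folders", "folder_id"),
}


def batch_update_sort_order(conn: sqlite3.Connection, items: list[dict[str, Any]]) -> int:
    """Apply every sort_order change in one transaction; return rows updated."""
    updated = 0
    with conn:  # commits on success, rolls back if any statement raises
        for item in items:
            item_id = item.get("id")
            item_type = item.get("type")
            order = item.get("sort_order")
            if item_id is None or order is None or item_type not in _TARGETS:
                continue  # skip malformed entries instead of failing the batch
            table, key = _TARGETS[item_type]  # fixed names, never user input
            cur = conn.execute(
                f"UPDATE {table} SET sort_order = ? WHERE {key} = ?", (order, item_id)
            )
            updated += cur.rowcount
    return updated
```

Skipping incomplete items, rather than rejecting the whole request, matches how the database method treats entries with a missing `id`, `type`, or `sort_order`.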

* feat(persona): add folder_id and sort_order params to persona creation

Extend persona creation flow to support folder placement and ordering:
- Add folder_id and sort_order parameters to insert_persona in db layer
- Update PersonaManager.create_persona to accept and pass folder params
- Add get_folder_detail API endpoint for retrieving folder information
- Include folder_id and sort_order in persona creation response
- Add session flush/refresh to return complete persona object

* feat(dashboard): implement persona folder management UI

- Add folder management system with tree view and breadcrumbs
- Implement create, rename, delete, and move operations for folders
- Add drag-and-drop support for organizing personas and folders
- Create new PersonaManager component and Pinia store for state management
- Refactor PersonaPage to support hierarchical structure
- Update locale files with folder-related translations
- Handle empty parent_id correctly in backend route

* feat(dashboard): centralize folder expansion state in persona store

Move folder expansion logic from local component state to global Pinia
store to persist expansion state.
- Add `expandedFolderIds` state and toggle actions to `personaStore`
- Update `FolderTreeNode` to use store state instead of local data
- Automatically navigate to target folder after moving a persona

* feat(dashboard): add reusable folder management component library

Extract folder management UI into reusable base components and create
persona-specific wrapper components that integrate with personaStore.

- Add base folder components (tree, breadcrumb, card, dialogs) with
  customizable labels for i18n support
- Create useFolderManager composable for folder state management
- Implement drag-and-drop support for moving personas between folders
- Add persona-specific wrapper components connecting to personaStore
- Reorganize PersonaManager into views/persona directory structure
- Include comprehensive README documentation for component usage

* refactor(dashboard): remove legacy persona folder management components

Remove deprecated persona folder management Vue components that have been
superseded by the new reusable folder management component library.

Deleted components:
- CreateFolderDialog.vue
- FolderBreadcrumb.vue
- FolderCard.vue
- FolderTree.vue
- FolderTreeNode.vue
- MoveTargetNode.vue
- MoveToFolderDialog.vue
- PersonaCard.vue
- PersonaManager.vue

These components are replaced by the centralized folder management
implementation introduced in commit 3fbb3db2.

* fix(dashboard): add delayed skeleton loading to prevent UI flicker

Implement a 150ms delay before showing the skeleton loader in
PersonaManager to prevent visual flicker during fast loading operations.

- Add showSkeleton state with timer-based delay mechanism
- Use v-fade-transition for smooth skeleton visibility transitions
- Clean up timer on component unmount to prevent memory leaks
- Only display skeleton when loading exceeds threshold duration

* feat(dashboard): add generic folder item selector component for persona selection

Introduce BaseFolderItemSelector.vue as a reusable component for selecting
items within folder hierarchies. Refactor PersonaSelector to use this new
base component instead of its previous flat list implementation.

Changes:
- Add BaseFolderItemSelector with folder tree navigation and item selection
- Extend folder types with SelectableItem and FolderItemSelectorLabels
- Refactor PersonaSelector to leverage the new base component
- Add i18n translations for rootFolder and emptyFolder labels

* feat(persona): add tree-view display for persona list command

Add hierarchical folder tree output for the persona list command,
showing personas organized by folders with visual tree connectors.

- Add _build_tree_output method for recursive tree structure rendering
- Display folders with 📁 icon and personas with 👤 icon
- Show root-level personas separately from folder contents
- Include total persona count in output

* refactor(persona): simplify tree-view output with shorter indentation lines

Replace complex tree connector logic with simpler depth-based indentation
using "│ " prefix. Remove unnecessary parameters (prefix, is_last) and
computed variables (has_content, total_items, item_idx) in favor of a
cleaner depth-based approach.
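
The depth-based approach reduces indentation to `"│ " * depth`. The sketch below illustrates it. The `Folder` dataclass, the choice to list subfolders before a folder's own personas, and the `Total:` label are assumptions for illustration; the actual `_build_tree_output` is not part of this file.

```python
from dataclasses import dataclass, field


@dataclass
class Folder:  # hypothetical in-memory shape of a persona folder
    name: str
    personas: list[str] = field(default_factory=list)
    children: list["Folder"] = field(default_factory=list)


def build_tree_output(folders: list[Folder], depth: int = 0) -> list[str]:
    """Render folders recursively, indenting each level with one "│ " per depth."""
    lines: list[str] = []
    prefix = "│ " * depth
    for folder in folders:
        lines.append(f"{prefix}📁 {folder.name}")
        lines.extend(build_tree_output(folder.children, depth + 1))
        for persona_id in folder.personas:
            lines.append(f"{prefix}│ 👤 {persona_id}")  # one level deeper than its folder
    return lines


def render_persona_list(root_personas: list[str], folders: list[Folder]) -> str:
    lines = build_tree_output(folders)
    lines += [f"👤 {pid}" for pid in root_personas]  # root-level personas listed separately

    def count(fs: list[Folder]) -> int:
        return sum(len(f.personas) + count(f.children) for f in fs)

    lines.append(f"Total: {len(root_personas) + count(folders)}")
    return "\n".join(lines)
```

Because the prefix depends only on depth, the renderer no longer needs the `prefix`/`is_last` bookkeeping that connector-style trees require.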

* feat(dashboard): add duplicate persona ID validation in create form

Add frontend validation to prevent creating personas with duplicate IDs.
Load existing persona IDs when opening the create form and validate
against them in real-time.

- Add existingPersonaIds array and loadExistingPersonaIds method
- Add validation rule to check for duplicate persona IDs
- Add i18n messages for duplicate ID error (en-US and zh-CN)
- Fix minLength validation to require at least 1 character

* i18n(persona): add createButton translation key for folder dialog

Move create button label to folder-specific translation path
instead of using generic buttons.create key.

* feat(persona): show target folder name in persona creation dialog

Add visual feedback showing which folder a new persona will be created in.

- Add info alert in PersonaForm displaying the target folder name
- Pass currentFolderName prop from PersonaManager and PersonaSelector
- Add recursive findFolderName helper to resolve folder ID to name
- Add i18n translations for createInFolder and rootFolder labels
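
The `findFolderName` helper is a frontend function. The Python transliteration below shows the same recursive lookup. It assumes each tree node is a dict with `folder_id`, `name`, and `children` keys and that `None` means "root or not found"; the node shape and function name are assumptions.

```python
from typing import Any, Optional

FolderNode = dict[str, Any]  # assumed shape: {"folder_id", "name", "children"}


def find_folder_name(nodes: list[FolderNode], folder_id: Optional[str]) -> Optional[str]:
    """Depth-first search of a folder tree; None means root or not found."""
    if folder_id is None:
        return None
    for node in nodes:
        if node.get("folder_id") == folder_id:
            return node.get("name")
        found = find_folder_name(node.get("children", []), folder_id)
        if found is not None:
            return found
    return None
```

A caller would map a `None` result to the translated `rootFolder` label, e.g. `find_folder_name(tree, fid) or root_label`.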

* style: format code

* fix: remove 'persistent' attribute from dialog components

---------

Co-authored-by: Soulter <905617992@qq.com>
2026-01-21 13:05:33 +08:00

1561 lines
56 KiB
Python

import asyncio
import threading
import typing as T
from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta, timezone

from sqlalchemy import CursorResult
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import col, delete, desc, func, or_, select, text, update

from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import (
    Attachment,
    ChatUIProject,
    CommandConfig,
    CommandConflict,
    ConversationV2,
    Persona,
    PersonaFolder,
    PlatformMessageHistory,
    PlatformSession,
    PlatformStat,
    Preference,
    SessionProjectRelation,
    SQLModel,
)
from astrbot.core.db.po import (
    Platform as DeprecatedPlatformStat,
)
from astrbot.core.db.po import (
    Stats as DeprecatedStats,
)

NOT_GIVEN = T.TypeVar("NOT_GIVEN")
TxResult = T.TypeVar("TxResult")


class SQLiteDatabase(BaseDatabase):
    def __init__(self, db_path: str) -> None:
        self.db_path = db_path
        self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
        self.inited = False
        super().__init__()

    async def initialize(self) -> None:
        """Initialize the database by creating tables if they do not exist."""
        async with self.engine.begin() as conn:
            await conn.run_sync(SQLModel.metadata.create_all)
            await conn.execute(text("PRAGMA journal_mode=WAL"))
            await conn.execute(text("PRAGMA synchronous=NORMAL"))
            await conn.execute(text("PRAGMA cache_size=20000"))
            await conn.execute(text("PRAGMA temp_store=MEMORY"))
            await conn.execute(text("PRAGMA mmap_size=134217728"))
            await conn.execute(text("PRAGMA optimize"))
            # Ensure the personas table has the folder_id and sort_order columns
            # (keeps databases created by older versions compatible).
            await self._ensure_persona_folder_columns(conn)
            await conn.commit()

    async def _ensure_persona_folder_columns(self, conn) -> None:
        """Ensure the personas table has the folder_id and sort_order columns.

        This lets databases created by older versions upgrade smoothly. New
        databases get these columns automatically from SQLModel's
        metadata.create_all.
        """
        result = await conn.execute(text("PRAGMA table_info(personas)"))
        columns = {row[1] for row in result.fetchall()}

        if "folder_id" not in columns:
            await conn.execute(
                text(
                    "ALTER TABLE personas ADD COLUMN folder_id VARCHAR(36) DEFAULT NULL"
                )
            )

        if "sort_order" not in columns:
            await conn.execute(
                text("ALTER TABLE personas ADD COLUMN sort_order INTEGER DEFAULT 0")
            )

    # ====
    # Platform Statistics
    # ====

    async def insert_platform_stats(
        self,
        platform_id,
        platform_type,
        count=1,
        timestamp=None,
    ) -> None:
        """Insert a new platform statistic record."""
        async with self.get_db() as session:
            session: AsyncSession
            async with session.begin():
                if timestamp is None:
                    timestamp = datetime.now().replace(
                        minute=0,
                        second=0,
                        microsecond=0,
                    )
                current_hour = timestamp
                await session.execute(
                    text("""
                    INSERT INTO platform_stats (timestamp, platform_id, platform_type, count)
                    VALUES (:timestamp, :platform_id, :platform_type, :count)
                    ON CONFLICT(timestamp, platform_id, platform_type) DO UPDATE SET
                        count = platform_stats.count + EXCLUDED.count
                    """),
                    {
                        "timestamp": current_hour,
                        "platform_id": platform_id,
                        "platform_type": platform_type,
                        "count": count,
                    },
                )

    async def count_platform_stats(self) -> int:
        """Count the number of platform statistics records."""
        async with self.get_db() as session:
            session: AsyncSession
            result = await session.execute(
                select(func.count(col(PlatformStat.platform_id))).select_from(
                    PlatformStat,
                ),
            )
            count = result.scalar_one_or_none()
            return count if count is not None else 0

    async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]:
        """Get platform statistics within the specified offset in seconds and group by platform_id."""
        async with self.get_db() as session:
            session: AsyncSession
            now = datetime.now()
            start_time = now - timedelta(seconds=offset_sec)
            result = await session.execute(
                text("""
                SELECT * FROM platform_stats
                WHERE timestamp >= :start_time
                GROUP BY platform_id
                ORDER BY timestamp DESC
                """),
                {"start_time": start_time},
            )
            return list(result.scalars().all())

    # ====
    # Conversation Management
    # ====

    async def get_conversations(self, user_id=None, platform_id=None):
        async with self.get_db() as session:
            session: AsyncSession
            query = select(ConversationV2)

            if user_id:
                query = query.where(ConversationV2.user_id == user_id)
            if platform_id:
                query = query.where(ConversationV2.platform_id == platform_id)
            # order by
            query = query.order_by(desc(ConversationV2.created_at))
            result = await session.execute(query)

            return result.scalars().all()

    async def get_conversation_by_id(self, cid):
        async with self.get_db() as session:
            session: AsyncSession
            query = select(ConversationV2).where(ConversationV2.conversation_id == cid)
            result = await session.execute(query)
            return result.scalar_one_or_none()

    async def get_all_conversations(self, page=1, page_size=20):
        async with self.get_db() as session:
            session: AsyncSession
            offset = (page - 1) * page_size
            result = await session.execute(
                select(ConversationV2)
                .order_by(desc(ConversationV2.created_at))
                .offset(offset)
                .limit(page_size),
            )
            return result.scalars().all()

    async def get_filtered_conversations(
        self,
        page=1,
        page_size=20,
        platform_ids=None,
        search_query="",
        **kwargs,
    ):
        async with self.get_db() as session:
            session: AsyncSession
            # Build the base query with filters
            base_query = select(ConversationV2)

            if platform_ids:
                base_query = base_query.where(
                    col(ConversationV2.platform_id).in_(platform_ids),
                )
            if search_query:
                search_query = search_query.encode("unicode_escape").decode("utf-8")
                base_query = base_query.where(
                    or_(
                        col(ConversationV2.title).ilike(f"%{search_query}%"),
                        col(ConversationV2.content).ilike(f"%{search_query}%"),
                        col(ConversationV2.user_id).ilike(f"%{search_query}%"),
                        col(ConversationV2.conversation_id).ilike(f"%{search_query}%"),
                    ),
                )
            if "message_types" in kwargs and len(kwargs["message_types"]) > 0:
                for msg_type in kwargs["message_types"]:
                    base_query = base_query.where(
                        col(ConversationV2.user_id).ilike(f"%:{msg_type}:%"),
                    )
            if "platforms" in kwargs and len(kwargs["platforms"]) > 0:
                base_query = base_query.where(
                    col(ConversationV2.platform_id).in_(kwargs["platforms"]),
                )

            # Get total count matching the filters
            count_query = select(func.count()).select_from(base_query.subquery())
            total_count = await session.execute(count_query)
            total = total_count.scalar_one()

            # Get paginated results
            offset = (page - 1) * page_size
            result_query = (
                base_query.order_by(desc(ConversationV2.created_at))
                .offset(offset)
                .limit(page_size)
            )
            result = await session.execute(result_query)
            conversations = result.scalars().all()

            return conversations, total

    async def create_conversation(
        self,
        user_id,
        platform_id,
        content=None,
        title=None,
        persona_id=None,
        cid=None,
        created_at=None,
        updated_at=None,
    ):
        kwargs = {}
        if cid:
            kwargs["conversation_id"] = cid
        if created_at:
            kwargs["created_at"] = created_at
        if updated_at:
            kwargs["updated_at"] = updated_at
        async with self.get_db() as session:
            session: AsyncSession
            async with session.begin():
                new_conversation = ConversationV2(
                    user_id=user_id,
                    content=content or [],
                    platform_id=platform_id,
                    title=title,
                    persona_id=persona_id,
                    **kwargs,
                )
                session.add(new_conversation)
                return new_conversation

    async def update_conversation(
        self, cid, title=None, persona_id=None, content=None, token_usage=None
    ):
        async with self.get_db() as session:
            session: AsyncSession
            async with session.begin():
                query = update(ConversationV2).where(
                    col(ConversationV2.conversation_id) == cid,
                )
                values = {}
                if title is not None:
                    values["title"] = title
                if persona_id is not None:
                    values["persona_id"] = persona_id
                if content is not None:
                    values["content"] = content
                if token_usage is not None:
                    values["token_usage"] = token_usage
                if not values:
                    return None
                query = query.values(**values)
                await session.execute(query)
        return await self.get_conversation_by_id(cid)

    async def delete_conversation(self, cid):
        async with self.get_db() as session:
            session: AsyncSession
            async with session.begin():
                await session.execute(
                    delete(ConversationV2).where(
                        col(ConversationV2.conversation_id) == cid,
                    ),
                )

    async def delete_conversations_by_user_id(self, user_id: str) -> None:
        async with self.get_db() as session:
            session: AsyncSession
            async with session.begin():
                await session.execute(
                    delete(ConversationV2).where(
                        col(ConversationV2.user_id) == user_id
                    ),
                )

    async def get_session_conversations(
        self,
        page=1,
        page_size=20,
        search_query=None,
        platform=None,
    ) -> tuple[list[dict], int]:
        """Get paginated session conversations with joined conversation and persona details."""
        async with self.get_db() as session:
            session: AsyncSession
            offset = (page - 1) * page_size

            base_query = (
                select(
                    col(Preference.scope_id).label("session_id"),
                    func.json_extract(Preference.value, "$.val").label(
                        "conversation_id",
                    ),  # type: ignore
                    col(ConversationV2.persona_id).label("persona_id"),
                    col(ConversationV2.title).label("title"),
                    col(Persona.persona_id).label("persona_name"),
                )
                .select_from(Preference)
                .outerjoin(
                    ConversationV2,
                    func.json_extract(Preference.value, "$.val")
                    == ConversationV2.conversation_id,
                )
                .outerjoin(
                    Persona,
                    col(ConversationV2.persona_id) == Persona.persona_id,
                )
                .where(Preference.scope == "umo", Preference.key == "sel_conv_id")
            )

            # Search filter
            if search_query:
                search_pattern = f"%{search_query}%"
                base_query = base_query.where(
                    or_(
                        col(Preference.scope_id).ilike(search_pattern),
                        col(ConversationV2.title).ilike(search_pattern),
                        col(Persona.persona_id).ilike(search_pattern),
                    ),
                )

            # Platform filter
            if platform:
                platform_pattern = f"{platform}:%"
                base_query = base_query.where(
                    col(Preference.scope_id).like(platform_pattern),
                )

            # Sort by session ID
            base_query = base_query.order_by(Preference.scope_id)

            # Paginated results
            result_query = base_query.offset(offset).limit(page_size)
            result = await session.execute(result_query)
            rows = result.fetchall()

            # Total count, with the same filters applied
            count_base_query = (
                select(func.count(col(Preference.scope_id)))
                .select_from(Preference)
                .outerjoin(
                    ConversationV2,
                    func.json_extract(Preference.value, "$.val")
                    == ConversationV2.conversation_id,
                )
                .outerjoin(
                    Persona,
                    col(ConversationV2.persona_id) == Persona.persona_id,
                )
                .where(Preference.scope == "umo", Preference.key == "sel_conv_id")
            )

            # Apply the same search and platform filters to the count query
            if search_query:
                search_pattern = f"%{search_query}%"
                count_base_query = count_base_query.where(
                    or_(
                        col(Preference.scope_id).ilike(search_pattern),
                        col(ConversationV2.title).ilike(search_pattern),
                        col(Persona.persona_id).ilike(search_pattern),
                    ),
                )

            if platform:
                platform_pattern = f"{platform}:%"
                count_base_query = count_base_query.where(
                    col(Preference.scope_id).like(platform_pattern),
                )

            total_result = await session.execute(count_base_query)
            total = total_result.scalar() or 0

            sessions_data = [
                {
                    "session_id": row.session_id,
                    "conversation_id": row.conversation_id,
                    "persona_id": row.persona_id,
                    "title": row.title,
                    "persona_name": row.persona_name,
                }
                for row in rows
            ]
            return sessions_data, total

    async def insert_platform_message_history(
        self,
        platform_id,
        user_id,
        content,
        sender_id=None,
        sender_name=None,
    ):
        """Insert a new platform message history record."""
        async with self.get_db() as session:
            session: AsyncSession
            async with session.begin():
                new_history = PlatformMessageHistory(
                    platform_id=platform_id,
                    user_id=user_id,
                    content=content,
                    sender_id=sender_id,
                    sender_name=sender_name,
                )
                session.add(new_history)
                return new_history

    async def delete_platform_message_offset(
        self,
        platform_id,
        user_id,
        offset_sec=86400,
    ):
        """Delete platform message history records newer than the specified offset."""
        async with self.get_db() as session:
            session: AsyncSession
            async with session.begin():
                now = datetime.now()
                cutoff_time = now - timedelta(seconds=offset_sec)
                await session.execute(
                    delete(PlatformMessageHistory).where(
                        col(PlatformMessageHistory.platform_id) == platform_id,
                        col(PlatformMessageHistory.user_id) == user_id,
                        col(PlatformMessageHistory.created_at) >= cutoff_time,
                    ),
                )

    async def get_platform_message_history(
        self,
        platform_id,
        user_id,
        page=1,
        page_size=20,
    ):
        """Get platform message history records."""
        async with self.get_db() as session:
            session: AsyncSession
            offset = (page - 1) * page_size
            query = (
                select(PlatformMessageHistory)
                .where(
                    PlatformMessageHistory.platform_id == platform_id,
                    PlatformMessageHistory.user_id == user_id,
                )
                .order_by(desc(PlatformMessageHistory.created_at))
            )
            result = await session.execute(query.offset(offset).limit(page_size))
            return result.scalars().all()

    async def get_platform_message_history_by_id(
        self, message_id: int
    ) -> PlatformMessageHistory | None:
        """Get a platform message history record by its ID."""
        async with self.get_db() as session:
            session: AsyncSession
            query = select(PlatformMessageHistory).where(
                PlatformMessageHistory.id == message_id
            )
            result = await session.execute(query)
            return result.scalar_one_or_none()

    async def insert_attachment(self, path, type, mime_type):
        """Insert a new attachment record."""
        async with self.get_db() as session:
            session: AsyncSession
            async with session.begin():
                new_attachment = Attachment(
                    path=path,
                    type=type,
                    mime_type=mime_type,
                )
                session.add(new_attachment)
                return new_attachment

    async def get_attachment_by_id(self, attachment_id):
        """Get an attachment by its ID."""
        async with self.get_db() as session:
            session: AsyncSession
            query = select(Attachment).where(Attachment.attachment_id == attachment_id)
            result = await session.execute(query)
            return result.scalar_one_or_none()

    async def get_attachments(self, attachment_ids: list[str]) -> list:
        """Get multiple attachments by their IDs."""
        if not attachment_ids:
            return []
        async with self.get_db() as session:
            session: AsyncSession
            query = select(Attachment).where(
                col(Attachment.attachment_id).in_(attachment_ids)
            )
            result = await session.execute(query)
            return list(result.scalars().all())

    async def delete_attachment(self, attachment_id: str) -> bool:
        """Delete an attachment by its ID.

        Returns True if the attachment was deleted, False if it was not found.
        """
        async with self.get_db() as session:
            session: AsyncSession
            async with session.begin():
                query = delete(Attachment).where(
                    col(Attachment.attachment_id) == attachment_id
                )
                result = T.cast(CursorResult, await session.execute(query))
                return result.rowcount > 0

    async def delete_attachments(self, attachment_ids: list[str]) -> int:
        """Delete multiple attachments by their IDs.

        Returns the number of attachments deleted.
        """
        if not attachment_ids:
            return 0
        async with self.get_db() as session:
            session: AsyncSession
            async with session.begin():
                query = delete(Attachment).where(
                    col(Attachment.attachment_id).in_(attachment_ids)
                )
                result = T.cast(CursorResult, await session.execute(query))
                return result.rowcount

    async def insert_persona(
        self,
        persona_id,
        system_prompt,
        begin_dialogs=None,
        tools=None,
        folder_id=None,
        sort_order=0,
    ):
        """Insert a new persona record."""
        async with self.get_db() as session:
            session: AsyncSession
            async with session.begin():
                new_persona = Persona(
                    persona_id=persona_id,
                    system_prompt=system_prompt,
                    begin_dialogs=begin_dialogs or [],
                    tools=tools,
                    folder_id=folder_id,
                    sort_order=sort_order,
                )
                session.add(new_persona)
                await session.flush()
                await session.refresh(new_persona)
                return new_persona

    async def get_persona_by_id(self, persona_id):
        """Get a persona by its ID."""
        async with self.get_db() as session:
            session: AsyncSession
            query = select(Persona).where(Persona.persona_id == persona_id)
            result = await session.execute(query)
            return result.scalar_one_or_none()

    async def get_personas(self):
        """Get all personas."""
        async with self.get_db() as session:
            session: AsyncSession
            query = select(Persona)
            result = await session.execute(query)
            return result.scalars().all()

    async def update_persona(
        self,
        persona_id,
        system_prompt=None,
        begin_dialogs=None,
        tools=NOT_GIVEN,
    ):
        """Update a persona's system prompt, begin dialogs, and/or tools."""
        async with self.get_db() as session:
            session: AsyncSession
            async with session.begin():
                query = update(Persona).where(col(Persona.persona_id) == persona_id)
                values = {}
                if system_prompt is not None:
                    values["system_prompt"] = system_prompt
                if begin_dialogs is not None:
                    values["begin_dialogs"] = begin_dialogs
                if tools is not NOT_GIVEN:
                    values["tools"] = tools
                if not values:
                    return None
                query = query.values(**values)
                await session.execute(query)
        return await self.get_persona_by_id(persona_id)

    async def delete_persona(self, persona_id):
        """Delete a persona by its ID."""
        async with self.get_db() as session:
            session: AsyncSession
            async with session.begin():
                await session.execute(
                    delete(Persona).where(col(Persona.persona_id) == persona_id),
                )

    # ====
    # Persona Folder Management
    # ====

    async def insert_persona_folder(
        self,
        name: str,
        parent_id: str | None = None,
        description: str | None = None,
        sort_order: int = 0,
    ) -> PersonaFolder:
        """Insert a new persona folder."""
        async with self.get_db() as session:
            session: AsyncSession
            async with session.begin():
                new_folder = PersonaFolder(
                    name=name,
                    parent_id=parent_id,
                    description=description,
                    sort_order=sort_order,
                )
                session.add(new_folder)
                await session.flush()
                await session.refresh(new_folder)
                return new_folder

    async def get_persona_folder_by_id(self, folder_id: str) -> PersonaFolder | None:
        """Get a persona folder by its folder_id."""
        async with self.get_db() as session:
            session: AsyncSession
            query = select(PersonaFolder).where(PersonaFolder.folder_id == folder_id)
            result = await session.execute(query)
            return result.scalar_one_or_none()

    async def get_persona_folders(
        self, parent_id: str | None = None
    ) -> list[PersonaFolder]:
        """Get all persona folders, optionally filtered by parent_id.

        Args:
            parent_id: If None, returns root folders only. If specified, returns
                children of that folder.
        """
        async with self.get_db() as session:
            session: AsyncSession
            if parent_id is None:
                # Get root folders (parent_id is NULL)
                query = (
                    select(PersonaFolder)
                    .where(col(PersonaFolder.parent_id).is_(None))
                    .order_by(col(PersonaFolder.sort_order), col(PersonaFolder.name))
                )
            else:
                query = (
                    select(PersonaFolder)
                    .where(PersonaFolder.parent_id == parent_id)
                    .order_by(col(PersonaFolder.sort_order), col(PersonaFolder.name))
                )
            result = await session.execute(query)
            return list(result.scalars().all())

    async def get_all_persona_folders(self) -> list[PersonaFolder]:
        """Get all persona folders."""
        async with self.get_db() as session:
            session: AsyncSession
            query = select(PersonaFolder).order_by(
                col(PersonaFolder.sort_order), col(PersonaFolder.name)
            )
            result = await session.execute(query)
            return list(result.scalars().all())

    async def update_persona_folder(
        self,
        folder_id: str,
        name: str | None = None,
        parent_id: T.Any = NOT_GIVEN,
        description: T.Any = NOT_GIVEN,
        sort_order: int | None = None,
    ) -> PersonaFolder | None:
        """Update a persona folder."""
        async with self.get_db() as session:
            session: AsyncSession
            async with session.begin():
                query = update(PersonaFolder).where(
                    col(PersonaFolder.folder_id) == folder_id
                )
                values: dict[str, T.Any] = {}
                if name is not None:
                    values["name"] = name
                if parent_id is not NOT_GIVEN:
                    values["parent_id"] = parent_id
                if description is not NOT_GIVEN:
                    values["description"] = description
                if sort_order is not None:
                    values["sort_order"] = sort_order
                if not values:
                    return None
                query = query.values(**values)
                await session.execute(query)
        return await self.get_persona_folder_by_id(folder_id)

    async def delete_persona_folder(self, folder_id: str) -> None:
        """Delete a persona folder by its folder_id.

        Note: This will also set folder_id to NULL for all personas in this folder,
        moving them to the root directory.
        """
        async with self.get_db() as session:
            session: AsyncSession
            async with session.begin():
                # Move personas to root directory
                await session.execute(
                    update(Persona)
                    .where(col(Persona.folder_id) == folder_id)
                    .values(folder_id=None)
                )
                # Delete the folder
                await session.execute(
                    delete(PersonaFolder).where(
                        col(PersonaFolder.folder_id) == folder_id
                    ),
                )

    async def move_persona_to_folder(
        self, persona_id: str, folder_id: str | None
    ) -> Persona | None:
        """Move a persona to a folder (or root if folder_id is None)."""
        async with self.get_db() as session:
            session: AsyncSession
            async with session.begin():
                await session.execute(
                    update(Persona)
                    .where(col(Persona.persona_id) == persona_id)
                    .values(folder_id=folder_id)
                )
        return await self.get_persona_by_id(persona_id)

    async def get_personas_by_folder(
        self, folder_id: str | None = None
    ) -> list[Persona]:
        """Get all personas in a specific folder.

        Args:
            folder_id: If None, returns personas in root directory.
        """
        async with self.get_db() as session:
            session: AsyncSession
            if folder_id is None:
                query = (
                    select(Persona)
                    .where(col(Persona.folder_id).is_(None))
                    .order_by(col(Persona.sort_order), col(Persona.persona_id))
                )
            else:
                query = (
                    select(Persona)
                    .where(Persona.folder_id == folder_id)
                    .order_by(col(Persona.sort_order), col(Persona.persona_id))
                )
            result = await session.execute(query)
            return list(result.scalars().all())

    async def batch_update_sort_order(
        self,
        items: list[dict],
    ) -> None:
        """Batch update sort_order for personas and/or folders.

        Args:
            items: List of dicts with keys:
                - id: The persona_id or folder_id
                - type: Either "persona" or "folder"
                - sort_order: The new sort_order value
        """
        if not items:
            return

        async with self.get_db() as session:
            session: AsyncSession
            async with session.begin():
                for item in items:
                    item_id = item.get("id")
                    item_type = item.get("type")
                    sort_order = item.get("sort_order")

                    if item_id is None or item_type is None or sort_order is None:
                        continue

                    if item_type == "persona":
                        await session.execute(
                            update(Persona)
                            .where(col(Persona.persona_id) == item_id)
                            .values(sort_order=sort_order)
                        )
                    elif item_type == "folder":
                        await session.execute(
                            update(PersonaFolder)
                            .where(col(PersonaFolder.folder_id) == item_id)
                            .values(sort_order=sort_order)
                        )

    async def insert_preference_or_update(self, scope, scope_id, key, value):
        """Insert a new preference record or update if it exists."""
        async with self.get_db() as session:
            session: AsyncSession
            async with session.begin():
                query = select(Preference).where(
                    Preference.scope == scope,
                    Preference.scope_id == scope_id,
                    Preference.key == key,
                )
                result = await session.execute(query)
                existing_preference = result.scalar_one_or_none()
                if existing_preference:
                    existing_preference.value = value
                else:
                    new_preference = Preference(
                        scope=scope,
                        scope_id=scope_id,
                        key=key,
                        value=value,
                    )
                    session.add(new_preference)
                return existing_preference or new_preference

    async def get_preference(self, scope, scope_id, key):
        """Get a preference by key."""
        async with self.get_db() as session:
            session: AsyncSession
            query = select(Preference).where(
                Preference.scope == scope,
                Preference.scope_id == scope_id,
                Preference.key == key,
            )
            result = await session.execute(query)
            return result.scalar_one_or_none()

    async def get_preferences(self, scope, scope_id=None, key=None):
        """Get all preferences for a specific scope ID or key."""
        async with self.get_db() as session:
            session: AsyncSession
            query = select(Preference).where(Preference.scope == scope)
            if scope_id is not None:
                query = query.where(Preference.scope_id == scope_id)
            if key is not None:
                query = query.where(Preference.key == key)
            result = await session.execute(query)
            return result.scalars().all()

    async def remove_preference(self, scope, scope_id, key):
        """Remove a preference by scope ID and key."""
        async with self.get_db() as session:
            session: AsyncSession
            async with session.begin():
                await session.execute(
                    delete(Preference).where(
                        col(Preference.scope) == scope,
                        col(Preference.scope_id) == scope_id,
                        col(Preference.key) == key,
                    ),
                )
                await session.commit()

    async def clear_preferences(self, scope, scope_id):
        """Clear all preferences for a specific scope ID."""
        async with self.get_db() as session:
            session: AsyncSession
            async with session.begin():
                await session.execute(
                    delete(Preference).where(
                        col(Preference.scope) == scope,
                        col(Preference.scope_id) == scope_id,
                    ),
                )
                await session.commit()
# ====
# Command Configuration & Conflict Tracking
# ====
async def _run_in_tx(
self,
fn: Callable[[AsyncSession], Awaitable[TxResult]],
) -> TxResult:
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
return await fn(session)
@staticmethod
def _apply_updates(model, **updates) -> None:
for field, value in updates.items():
if value is not None:
setattr(model, field, value)

    @staticmethod
    def _new_command_config(
        handler_full_name: str,
        plugin_name: str,
        module_path: str,
        original_command: str,
        *,
        resolved_command: str | None = None,
        enabled: bool | None = None,
        keep_original_alias: bool | None = None,
        conflict_key: str | None = None,
        resolution_strategy: str | None = None,
        note: str | None = None,
        extra_data: dict | None = None,
        auto_managed: bool | None = None,
    ) -> CommandConfig:
        return CommandConfig(
            handler_full_name=handler_full_name,
            plugin_name=plugin_name,
            module_path=module_path,
            original_command=original_command,
            resolved_command=resolved_command,
            enabled=True if enabled is None else enabled,
            keep_original_alias=False
            if keep_original_alias is None
            else keep_original_alias,
            conflict_key=conflict_key or original_command,
            resolution_strategy=resolution_strategy,
            note=note,
            extra_data=extra_data,
            auto_managed=bool(auto_managed),
        )

    @staticmethod
    def _new_command_conflict(
        conflict_key: str,
        handler_full_name: str,
        plugin_name: str,
        *,
        status: str | None = None,
        resolution: str | None = None,
        resolved_command: str | None = None,
        note: str | None = None,
        extra_data: dict | None = None,
        auto_generated: bool | None = None,
    ) -> CommandConflict:
        return CommandConflict(
            conflict_key=conflict_key,
            handler_full_name=handler_full_name,
            plugin_name=plugin_name,
            status=status or "pending",
            resolution=resolution,
            resolved_command=resolved_command,
            note=note,
            extra_data=extra_data,
            auto_generated=bool(auto_generated),
        )

    async def get_command_configs(self) -> list[CommandConfig]:
        async with self.get_db() as session:
            session: AsyncSession
            result = await session.execute(select(CommandConfig))
            return list(result.scalars().all())

    async def get_command_config(
        self,
        handler_full_name: str,
    ) -> CommandConfig | None:
        async with self.get_db() as session:
            session: AsyncSession
            return await session.get(CommandConfig, handler_full_name)

    async def upsert_command_config(
        self,
        handler_full_name: str,
        plugin_name: str,
        module_path: str,
        original_command: str,
        *,
        resolved_command: str | None = None,
        enabled: bool | None = None,
        keep_original_alias: bool | None = None,
        conflict_key: str | None = None,
        resolution_strategy: str | None = None,
        note: str | None = None,
        extra_data: dict | None = None,
        auto_managed: bool | None = None,
    ) -> CommandConfig:
        async def _op(session: AsyncSession) -> CommandConfig:
            config = await session.get(CommandConfig, handler_full_name)
            if not config:
                config = self._new_command_config(
                    handler_full_name,
                    plugin_name,
                    module_path,
                    original_command,
                    resolved_command=resolved_command,
                    enabled=enabled,
                    keep_original_alias=keep_original_alias,
                    conflict_key=conflict_key,
                    resolution_strategy=resolution_strategy,
                    note=note,
                    extra_data=extra_data,
                    auto_managed=auto_managed,
                )
                session.add(config)
            else:
                self._apply_updates(
                    config,
                    plugin_name=plugin_name,
                    module_path=module_path,
                    original_command=original_command,
                    resolved_command=resolved_command,
                    enabled=enabled,
                    keep_original_alias=keep_original_alias,
                    conflict_key=conflict_key,
                    resolution_strategy=resolution_strategy,
                    note=note,
                    extra_data=extra_data,
                    auto_managed=auto_managed,
                )
            await session.flush()
            await session.refresh(config)
            return config

        return await self._run_in_tx(_op)
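
`upsert_command_config` reads the row by primary key, then either inserts a new one or patches the existing one, all inside one transaction. If that read-then-write round trip ever matters, SQLite 3.24 and later can do the same in a single `INSERT ... ON CONFLICT DO UPDATE` statement, with `COALESCE` reproducing the skip-`None` rule. This is an alternative sketch, not what the module does; the trimmed-down `command_config` table and the `upsert` helper are hypothetical.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE command_config ("
    " handler_full_name TEXT PRIMARY KEY,"
    " original_command TEXT NOT NULL,"
    " resolved_command TEXT,"
    " enabled INTEGER NOT NULL)"
)

# Requires SQLite 3.24+ for INSERT ... ON CONFLICT DO UPDATE.
# In DO UPDATE, unqualified column names refer to the row already stored.
UPSERT_SQL = """
INSERT INTO command_config (handler_full_name, original_command, resolved_command, enabled)
VALUES (:name, :original, :resolved, COALESCE(:enabled, 1))
ON CONFLICT(handler_full_name) DO UPDATE SET
    original_command = excluded.original_command,
    resolved_command = COALESCE(:resolved, resolved_command),
    enabled = COALESCE(:enabled, enabled)
"""


def upsert(conn, name, original, resolved=None, enabled=None):
    # A None argument keeps the stored value on update, like _apply_updates.
    params = {"name": name, "original": original, "resolved": resolved, "enabled": enabled}
    with conn:
        conn.execute(UPSERT_SQL, params)


upsert(conn, "demo.plugin.help", "help", resolved="help2")
upsert(conn, "demo.plugin.help", "help", enabled=False)
row = conn.execute(
    "SELECT resolved_command, enabled FROM command_config WHERE handler_full_name = ?",
    ("demo.plugin.help",),
).fetchone()
```

One trade-off: with raw SQL the insert defaults (here `enabled` = 1) live in the statement rather than in `_new_command_config`.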

    async def delete_command_config(self, handler_full_name: str) -> None:
        await self.delete_command_configs([handler_full_name])

    async def delete_command_configs(self, handler_full_names: list[str]) -> None:
        if not handler_full_names:
            return

        async def _op(session: AsyncSession) -> None:
            await session.execute(
                delete(CommandConfig).where(
                    col(CommandConfig.handler_full_name).in_(handler_full_names),
                ),
            )

        await self._run_in_tx(_op)
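
`delete_command_configs` turns a list of names into one `DELETE ... WHERE handler_full_name IN (...)` via `col(...).in_(...)`, and returns early on an empty list so no transaction is opened. A plain `sqlite3` sketch of the same shape, with a hypothetical one-column table and a hypothetical `delete_by_names` helper:

```python
import sqlite3


def delete_by_names(conn: sqlite3.Connection, names: list[str]) -> int:
    """Delete rows whose handler_full_name is in `names`; return how many were removed."""
    if not names:
        # Same early exit as delete_command_configs: no statement, no transaction.
        return 0
    # One "?" per value; the values are bound as parameters, never spliced into SQL.
    placeholders = ", ".join("?" for _ in names)
    with conn:
        cur = conn.execute(
            f"DELETE FROM command_config WHERE handler_full_name IN ({placeholders})",
            names,
        )
    return cur.rowcount


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE command_config (handler_full_name TEXT PRIMARY KEY)")
conn.executemany(
    "INSERT INTO command_config VALUES (?)", [("a.x",), ("a.y",), ("b.z",)]
)
conn.commit()

removed = delete_by_names(conn, ["a.x", "b.z", "missing"])
skipped = delete_by_names(conn, [])
left = [r[0] for r in conn.execute("SELECT handler_full_name FROM command_config")]
```

Very long lists may run into SQLite's limit on bound parameters, so a caller deleting thousands of handlers at once might need to split the list into chunks.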

    async def list_command_conflicts(
        self,
        status: str | None = None,
    ) -> list[CommandConflict]:
        async with self.get_db() as session:
            session: AsyncSession
            query = select(CommandConflict)
            if status:
                query = query.where(CommandConflict.status == status)
            result = await session.execute(query)
            return list(result.scalars().all())

    async def upsert_command_conflict(
        self,
        conflict_key: str,
        handler_full_name: str,
        plugin_name: str,
        *,
        status: str | None = None,
        resolution: str | None = None,
        resolved_command: str | None = None,
        note: str | None = None,
        extra_data: dict | None = None,
        auto_generated: bool | None = None,
    ) -> CommandConflict:
        async def _op(session: AsyncSession) -> CommandConflict:
            result = await session.execute(
                select(CommandConflict).where(
                    CommandConflict.conflict_key == conflict_key,
                    CommandConflict.handler_full_name == handler_full_name,
                ),
            )
            record = result.scalar_one_or_none()
            if not record:
                record = self._new_command_conflict(
                    conflict_key,
                    handler_full_name,
                    plugin_name,
                    status=status,
                    resolution=resolution,
                    resolved_command=resolved_command,
                    note=note,
                    extra_data=extra_data,
                    auto_generated=auto_generated,
                )
                session.add(record)
            else:
                self._apply_updates(
                    record,
                    plugin_name=plugin_name,
                    status=status,
                    resolution=resolution,
                    resolved_command=resolved_command,
                    note=note,
                    extra_data=extra_data,
                    auto_generated=auto_generated,
                )
            await session.flush()
            await session.refresh(record)
            return record

        return await self._run_in_tx(_op)

    async def delete_command_conflicts(self, ids: list[int]) -> None:
        if not ids:
            return

        async def _op(session: AsyncSession) -> None:
            await session.execute(
                delete(CommandConflict).where(col(CommandConflict.id).in_(ids)),
            )

        await self._run_in_tx(_op)

    # ====
    # Deprecated Methods
    # ====

    def get_base_stats(self, offset_sec=86400):
        """Get platform statistics recorded within the last ``offset_sec`` seconds."""

        async def _inner():
            async with self.get_db() as session:
                session: AsyncSession
                now = datetime.now()
                start_time = now - timedelta(seconds=offset_sec)
                result = await session.execute(
                    select(PlatformStat).where(PlatformStat.timestamp >= start_time),
                )
                all_datas = result.scalars().all()
                deprecated_stats = DeprecatedStats()
                for data in all_datas:
                    deprecated_stats.platform.append(
                        DeprecatedPlatformStat(
                            name=data.platform_id,
                            count=data.count,
                            timestamp=int(data.timestamp.timestamp()),
                        ),
                    )
                return deprecated_stats

        result = None

        # Run the coroutine on a fresh event loop in a worker thread, so this
        # synchronous method also works when the calling thread already has a
        # running loop (where asyncio.run() would raise RuntimeError).
        def runner():
            nonlocal result
            result = asyncio.run(_inner())

        t = threading.Thread(target=runner)
        t.start()
        t.join()
        return result
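
The three deprecated methods keep synchronous signatures over async queries by running `asyncio.run()` in a throwaway thread and joining it. A standalone sketch of that pattern follows; the helper `run_coro_blocking` and the coroutines `fetch_total` and `broken` are hypothetical. One deliberate difference: as written, the deprecated methods return `None` if the query raises, because an exception in a `threading.Thread` target is only reported through `threading.excepthook`, whereas this variant re-raises it in the caller.

```python
import asyncio
import threading


def run_coro_blocking(coro_factory):
    """Run coro_factory() to completion from synchronous code and return its result.

    The coroutine runs on a fresh event loop in a worker thread, so this also
    works when the calling thread already has a running loop.
    """
    result = None
    error = None

    def runner():
        nonlocal result, error
        try:
            result = asyncio.run(coro_factory())
        except Exception as exc:  # hand the failure back to the caller
            error = exc

    t = threading.Thread(target=runner)
    t.start()
    t.join()  # blocks the caller (and its loop, if any) until the work is done
    if error is not None:
        raise error
    return result


async def fetch_total() -> int:
    await asyncio.sleep(0)  # stand-in for an awaited database query
    return 42


async def broken() -> int:
    raise ValueError("query failed")


async def main() -> int:
    # Calling asyncio.run() directly here would raise RuntimeError,
    # because this coroutine is already running inside a loop.
    return run_coro_blocking(fetch_total)


value = asyncio.run(main())
try:
    run_coro_blocking(broken)
    propagated = False
except ValueError:
    propagated = True
```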

    def get_total_message_count(self):
        """Get the total message count from platform statistics."""

        async def _inner():
            async with self.get_db() as session:
                session: AsyncSession
                result = await session.execute(
                    select(func.sum(PlatformStat.count)).select_from(PlatformStat),
                )
                total_count = result.scalar_one_or_none()
                return total_count if total_count is not None else 0

        result = None

        def runner():
            nonlocal result
            result = asyncio.run(_inner())

        t = threading.Thread(target=runner)
        t.start()
        t.join()
        return result

    def get_grouped_base_stats(self, offset_sec=86400):
        """Get per-platform message totals for the last ``offset_sec`` seconds.

        Counts are summed by ``platform_id``; every entry's timestamp is the
        start of the window.
        """

        async def _inner():
            async with self.get_db() as session:
                session: AsyncSession
                now = datetime.now()
                start_time = now - timedelta(seconds=offset_sec)
                result = await session.execute(
                    select(PlatformStat.platform_id, func.sum(PlatformStat.count))
                    .where(PlatformStat.timestamp >= start_time)
                    .group_by(PlatformStat.platform_id),
                )
                grouped_stats = result.all()
                deprecated_stats = DeprecatedStats()
                for platform_id, count in grouped_stats:
                    deprecated_stats.platform.append(
                        DeprecatedPlatformStat(
                            name=platform_id,
                            count=count,
                            timestamp=int(start_time.timestamp()),
                        ),
                    )
                return deprecated_stats

        result = None

        def runner():
            nonlocal result
            result = asyncio.run(_inner())

        t = threading.Thread(target=runner)
        t.start()
        t.join()
        return result
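
`get_grouped_base_stats` filters rows to a trailing time window and sums `count` per `platform_id`. The same aggregation in plain SQL, as a sketch: the `platform_stat` table is a hypothetical simplification of `PlatformStat`, timestamps are stored as ISO-8601 strings, and `now` is fixed so the result is deterministic.

```python
import sqlite3
from datetime import datetime, timedelta

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE platform_stat (platform_id TEXT, count INTEGER, timestamp TEXT)"
)

now = datetime(2024, 1, 2, 12, 0, 0)  # fixed "now" for a reproducible example
rows = [
    ("aiocqhttp", 5, now - timedelta(hours=1)),
    ("aiocqhttp", 3, now - timedelta(hours=2)),
    ("webchat", 7, now - timedelta(hours=3)),
    ("webchat", 100, now - timedelta(days=2)),  # outside a one-day window
]
conn.executemany(
    "INSERT INTO platform_stat VALUES (?, ?, ?)",
    [(p, c, ts.isoformat(sep=" ")) for p, c, ts in rows],
)


def grouped_counts(conn, now, offset_sec=86400):
    start = (now - timedelta(seconds=offset_sec)).isoformat(sep=" ")
    # Same-format ISO-8601 strings sort chronologically, so >= acts as a time filter.
    cur = conn.execute(
        "SELECT platform_id, SUM(count) FROM platform_stat "
        "WHERE timestamp >= ? GROUP BY platform_id ORDER BY platform_id",
        (start,),
    )
    return cur.fetchall()


totals = grouped_counts(conn, now)
```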

    # ====
    # Platform Session Management
    # ====

    async def create_platform_session(
        self,
        creator: str,
        platform_id: str = "webchat",
        session_id: str | None = None,
        display_name: str | None = None,
        is_group: int = 0,
    ) -> PlatformSession:
        """Create a new Platform session."""
        # Pass session_id only when given, so the model's default applies otherwise.
        kwargs = {}
        if session_id:
            kwargs["session_id"] = session_id
        async with self.get_db() as session:
            session: AsyncSession
            async with session.begin():
                new_session = PlatformSession(
                    creator=creator,
                    platform_id=platform_id,
                    display_name=display_name,
                    is_group=is_group,
                    **kwargs,
                )
                session.add(new_session)
                await session.flush()
                await session.refresh(new_session)
                return new_session

    async def get_platform_session_by_id(
        self, session_id: str
    ) -> PlatformSession | None:
        """Get a Platform session by its ID."""
        async with self.get_db() as session:
            session: AsyncSession
            query = select(PlatformSession).where(
                PlatformSession.session_id == session_id,
            )
            result = await session.execute(query)
            return result.scalar_one_or_none()

    async def get_platform_sessions_by_creator(
        self,
        creator: str,
        platform_id: str | None = None,
        page: int = 1,
        page_size: int = 20,
    ) -> list[dict]:
        """Get one page of Platform sessions created by ``creator``.

        Optionally filter by ``platform_id``. Sessions are ordered by
        ``updated_at``, newest first. Each item is a dict with the keys
        ``session``, ``project_id``, ``project_title`` and ``project_emoji``;
        the project fields are ``None`` when the session is not in a project.
        """
        async with self.get_db() as session:
            session: AsyncSession
            offset = (page - 1) * page_size
            # LEFT JOIN with SessionProjectRelation and ChatUIProject to get project info
            query = (
                select(
                    PlatformSession,
                    col(ChatUIProject.project_id),
                    col(ChatUIProject.title).label("project_title"),
                    col(ChatUIProject.emoji).label("project_emoji"),
                )
                .outerjoin(
                    SessionProjectRelation,
                    col(PlatformSession.session_id)
                    == col(SessionProjectRelation.session_id),
                )
                .outerjoin(
                    ChatUIProject,
                    col(SessionProjectRelation.project_id)
                    == col(ChatUIProject.project_id),
                )
                .where(col(PlatformSession.creator) == creator)
            )
            if platform_id:
                query = query.where(PlatformSession.platform_id == platform_id)
            query = (
                query.order_by(desc(PlatformSession.updated_at))
                .offset(offset)
                .limit(page_size)
            )
            result = await session.execute(query)
            # Convert to list of dicts with session and project info
            sessions_with_projects = []
            for row in result.all():
                platform_session = row[0]
                project_id = row[1]
                project_title = row[2]
                project_emoji = row[3]
                session_dict = {
                    "session": platform_session,
                    "project_id": project_id,
                    "project_title": project_title,
                    "project_emoji": project_emoji,
                }
                sessions_with_projects.append(session_dict)
            return sessions_with_projects
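
The query above combines two ideas: a LEFT JOIN chain so sessions without a project still appear (with `None` project fields), and 1-based page numbers converted to `LIMIT`/`OFFSET`. A `sqlite3` sketch of the same shape; the three tables are hypothetical simplifications of `PlatformSession`, `SessionProjectRelation` and `ChatUIProject`, and `sessions_page` is a hypothetical helper.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE platform_session (session_id TEXT PRIMARY KEY, creator TEXT, updated_at TEXT);
    CREATE TABLE session_project (session_id TEXT PRIMARY KEY, project_id TEXT);
    CREATE TABLE project (project_id TEXT PRIMARY KEY, title TEXT);
    INSERT INTO platform_session VALUES
        ('s1', 'alice', '2024-01-01'),
        ('s2', 'alice', '2024-01-03'),
        ('s3', 'alice', '2024-01-02'),
        ('s4', 'bob',   '2024-01-04');
    INSERT INTO project VALUES ('p1', 'Research');
    INSERT INTO session_project VALUES ('s2', 'p1');
    """
)


def sessions_page(conn, creator, page=1, page_size=2):
    offset = (page - 1) * page_size  # page numbers are 1-based
    cur = conn.execute(
        """
        SELECT s.session_id, p.project_id, p.title
        FROM platform_session AS s
        LEFT JOIN session_project AS sp ON sp.session_id = s.session_id
        LEFT JOIN project AS p ON p.project_id = sp.project_id
        WHERE s.creator = ?
        ORDER BY s.updated_at DESC
        LIMIT ? OFFSET ?
        """,
        (creator, page_size, offset),
    )
    return [
        {"session_id": sid, "project_id": pid, "project_title": title}
        for sid, pid, title in cur.fetchall()
    ]


first = sessions_page(conn, "alice", page=1)
second = sessions_page(conn, "alice", page=2)
```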

    async def update_platform_session(
        self,
        session_id: str,
        display_name: str | None = None,
    ) -> None:
        """Update a Platform session's updated_at timestamp and optionally display_name."""
        async with self.get_db() as session:
            session: AsyncSession
            async with session.begin():
                values: dict[str, T.Any] = {"updated_at": datetime.now(timezone.utc)}
                if display_name is not None:
                    values["display_name"] = display_name
                await session.execute(
                    update(PlatformSession)
                    .where(col(PlatformSession.session_id) == session_id)
                    .values(**values),
                )

    async def delete_platform_session(self, session_id: str) -> None:
        """Delete a Platform session by its ID."""
        async with self.get_db() as session:
            session: AsyncSession
            async with session.begin():
                await session.execute(
                    delete(PlatformSession).where(
                        col(PlatformSession.session_id) == session_id,
                    ),
                )

    # ====
    # ChatUI Project Management
    # ====

    async def create_chatui_project(
        self,
        creator: str,
        title: str,
        emoji: str | None = "📁",
        description: str | None = None,
    ) -> ChatUIProject:
        """Create a new ChatUI project."""
        async with self.get_db() as session:
            session: AsyncSession
            async with session.begin():
                project = ChatUIProject(
                    creator=creator,
                    title=title,
                    emoji=emoji,
                    description=description,
                )
                session.add(project)
                await session.flush()
                await session.refresh(project)
                return project

    async def get_chatui_project_by_id(self, project_id: str) -> ChatUIProject | None:
        """Get a ChatUI project by its ID."""
        async with self.get_db() as session:
            session: AsyncSession
            result = await session.execute(
                select(ChatUIProject).where(
                    col(ChatUIProject.project_id) == project_id,
                ),
            )
            return result.scalar_one_or_none()

    async def get_chatui_projects_by_creator(
        self,
        creator: str,
        page: int = 1,
        page_size: int = 100,
    ) -> list[ChatUIProject]:
        """Get one page of ChatUI projects owned by ``creator``, most recently updated first."""
        async with self.get_db() as session:
            session: AsyncSession
            offset = (page - 1) * page_size
            result = await session.execute(
                select(ChatUIProject)
                .where(col(ChatUIProject.creator) == creator)
                .order_by(desc(ChatUIProject.updated_at))
                .limit(page_size)
                .offset(offset),
            )
            return list(result.scalars().all())

    async def update_chatui_project(
        self,
        project_id: str,
        title: str | None = None,
        emoji: str | None = None,
        description: str | None = None,
    ) -> None:
        """Update a ChatUI project.

        Only arguments that are not ``None`` are changed; ``updated_at`` is
        always refreshed.
        """
        async with self.get_db() as session:
            session: AsyncSession
            async with session.begin():
                values: dict[str, T.Any] = {"updated_at": datetime.now(timezone.utc)}
                if title is not None:
                    values["title"] = title
                if emoji is not None:
                    values["emoji"] = emoji
                if description is not None:
                    values["description"] = description
                await session.execute(
                    update(ChatUIProject)
                    .where(col(ChatUIProject.project_id) == project_id)
                    .values(**values),
                )

    async def delete_chatui_project(self, project_id: str) -> None:
        """Delete a ChatUI project by its ID, together with its session relations.

        The sessions themselves are kept; they are only detached from the project.
        """
        async with self.get_db() as session:
            session: AsyncSession
            async with session.begin():
                # First remove all session relations
                await session.execute(
                    delete(SessionProjectRelation).where(
                        col(SessionProjectRelation.project_id) == project_id,
                    ),
                )
                # Then delete the project
                await session.execute(
                    delete(ChatUIProject).where(
                        col(ChatUIProject.project_id) == project_id,
                    ),
                )
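
`delete_chatui_project` removes the relation rows before the project, inside one transaction, so no `SessionProjectRelation` is left pointing at a missing project and a failure leaves both tables untouched. A `sqlite3` sketch of that ordering; the tables and the `delete_project` helper are hypothetical. An alternative would be foreign keys declared with `ON DELETE CASCADE`, which SQLite only enforces when `PRAGMA foreign_keys = ON` is set on the connection.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE project (project_id TEXT PRIMARY KEY, title TEXT);
    CREATE TABLE session_project (session_id TEXT PRIMARY KEY, project_id TEXT);
    INSERT INTO project VALUES ('p1', 'Research'), ('p2', 'Notes');
    INSERT INTO session_project VALUES ('s1', 'p1'), ('s2', 'p1'), ('s3', 'p2');
    """
)


def delete_project(conn: sqlite3.Connection, project_id: str) -> None:
    # Both deletes share one transaction: `with conn:` commits if both succeed
    # and rolls back both if either raises.
    with conn:
        conn.execute("DELETE FROM session_project WHERE project_id = ?", (project_id,))
        conn.execute("DELETE FROM project WHERE project_id = ?", (project_id,))


delete_project(conn, "p1")
projects = [r[0] for r in conn.execute("SELECT project_id FROM project")]
relations = conn.execute("SELECT session_id, project_id FROM session_project").fetchall()
```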

    async def add_session_to_project(
        self,
        session_id: str,
        project_id: str,
    ) -> SessionProjectRelation:
        """Assign a session to a project, replacing any project it was previously in."""
        async with self.get_db() as session:
            session: AsyncSession
            async with session.begin():
                # First remove existing relation if any
                await session.execute(
                    delete(SessionProjectRelation).where(
                        col(SessionProjectRelation.session_id) == session_id,
                    ),
                )
                # Then create new relation
                relation = SessionProjectRelation(
                    session_id=session_id,
                    project_id=project_id,
                )
                session.add(relation)
                await session.flush()
                await session.refresh(relation)
                return relation
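
`add_session_to_project` deletes any existing relation for the session before inserting the new one, which keeps each session in at most one project even if the table itself does not enforce it (whether `SessionProjectRelation` has a unique constraint on `session_id` is not visible in this file). A `sqlite3` sketch with a deliberately unconstrained, hypothetical `session_project` table and a hypothetical `assign` helper:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# No UNIQUE constraint on session_id here, so only the delete step below
# keeps each session in at most one project.
conn.execute("CREATE TABLE session_project (session_id TEXT, project_id TEXT)")


def assign(conn: sqlite3.Connection, session_id: str, project_id: str) -> None:
    with conn:  # delete + insert commit together, or not at all
        conn.execute("DELETE FROM session_project WHERE session_id = ?", (session_id,))
        conn.execute(
            "INSERT INTO session_project VALUES (?, ?)", (session_id, project_id)
        )


assign(conn, "s1", "p1")
assign(conn, "s1", "p2")  # moves s1 instead of adding a second membership
rows = conn.execute("SELECT session_id, project_id FROM session_project").fetchall()
```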

    async def remove_session_from_project(self, session_id: str) -> None:
        """Remove a session from its project."""
        async with self.get_db() as session:
            session: AsyncSession
            async with session.begin():
                await session.execute(
                    delete(SessionProjectRelation).where(
                        col(SessionProjectRelation.session_id) == session_id,
                    ),
                )

    async def get_project_sessions(
        self,
        project_id: str,
        page: int = 1,
        page_size: int = 100,
    ) -> list[PlatformSession]:
        """Get one page of sessions in a project, most recently updated first."""
        async with self.get_db() as session:
            session: AsyncSession
            offset = (page - 1) * page_size
            result = await session.execute(
                select(PlatformSession)
                .join(
                    SessionProjectRelation,
                    col(PlatformSession.session_id)
                    == col(SessionProjectRelation.session_id),
                )
                .where(col(SessionProjectRelation.project_id) == project_id)
                .order_by(desc(PlatformSession.updated_at))
                .limit(page_size)
                .offset(offset),
            )
            return list(result.scalars().all())

    async def get_project_by_session(
        self, session_id: str, creator: str
    ) -> ChatUIProject | None:
        """Get the project a session belongs to.

        Returns ``None`` if the session has no project or the project is not
        owned by ``creator``.
        """
        async with self.get_db() as session:
            session: AsyncSession
            result = await session.execute(
                select(ChatUIProject)
                .join(
                    SessionProjectRelation,
                    col(ChatUIProject.project_id)
                    == col(SessionProjectRelation.project_id),
                )
                .where(
                    col(SessionProjectRelation.session_id) == session_id,
                    col(ChatUIProject.creator) == creator,
                ),
            )
            return result.scalar_one_or_none()