7bf44bd8d2
* feat: support persona custom error reply message with fallback * refactor: centralize persona custom error message helpers
1854 lines
67 KiB
Python
1854 lines
67 KiB
Python
import asyncio
|
|
import threading
|
|
import typing as T
|
|
from collections.abc import Awaitable, Callable
|
|
from datetime import datetime, timedelta, timezone
|
|
|
|
from sqlalchemy import CursorResult, Row
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlmodel import col, delete, desc, func, or_, select, text, update
|
|
|
|
from astrbot.core.db import BaseDatabase
|
|
from astrbot.core.db.po import (
|
|
ApiKey,
|
|
Attachment,
|
|
ChatUIProject,
|
|
CommandConfig,
|
|
CommandConflict,
|
|
ConversationV2,
|
|
CronJob,
|
|
Persona,
|
|
PersonaFolder,
|
|
PlatformMessageHistory,
|
|
PlatformSession,
|
|
PlatformStat,
|
|
Preference,
|
|
SessionProjectRelation,
|
|
SQLModel,
|
|
)
|
|
from astrbot.core.db.po import (
|
|
Platform as DeprecatedPlatformStat,
|
|
)
|
|
from astrbot.core.db.po import (
|
|
Stats as DeprecatedStats,
|
|
)
|
|
from astrbot.core.sentinels import NOT_GIVEN
|
|
|
|
TxResult = T.TypeVar("TxResult")
|
|
CRON_FIELD_NOT_SET = object()
|
|
|
|
|
|
class SQLiteDatabase(BaseDatabase):
|
|
def __init__(self, db_path: str) -> None:
|
|
self.db_path = db_path
|
|
self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
|
|
self.inited = False
|
|
super().__init__()
|
|
|
|
async def initialize(self) -> None:
|
|
"""Initialize the database by creating tables if they do not exist."""
|
|
async with self.engine.begin() as conn:
|
|
await conn.run_sync(SQLModel.metadata.create_all)
|
|
await conn.execute(text("PRAGMA journal_mode=WAL"))
|
|
await conn.execute(text("PRAGMA synchronous=NORMAL"))
|
|
await conn.execute(text("PRAGMA cache_size=20000"))
|
|
await conn.execute(text("PRAGMA temp_store=MEMORY"))
|
|
await conn.execute(text("PRAGMA mmap_size=134217728"))
|
|
await conn.execute(text("PRAGMA optimize"))
|
|
# 确保 personas 表有 folder_id、sort_order、skills 列(前向兼容)
|
|
await self._ensure_persona_folder_columns(conn)
|
|
await self._ensure_persona_skills_column(conn)
|
|
await self._ensure_persona_custom_error_message_column(conn)
|
|
await conn.commit()
|
|
|
|
async def _ensure_persona_folder_columns(self, conn) -> None:
|
|
"""确保 personas 表有 folder_id 和 sort_order 列。
|
|
|
|
这是为了支持旧版数据库的平滑升级。新版数据库通过 SQLModel
|
|
的 metadata.create_all 自动创建这些列。
|
|
"""
|
|
result = await conn.execute(text("PRAGMA table_info(personas)"))
|
|
columns = {row[1] for row in result.fetchall()}
|
|
|
|
if "folder_id" not in columns:
|
|
await conn.execute(
|
|
text(
|
|
"ALTER TABLE personas ADD COLUMN folder_id VARCHAR(36) DEFAULT NULL"
|
|
)
|
|
)
|
|
if "sort_order" not in columns:
|
|
await conn.execute(
|
|
text("ALTER TABLE personas ADD COLUMN sort_order INTEGER DEFAULT 0")
|
|
)
|
|
|
|
async def _ensure_persona_skills_column(self, conn) -> None:
|
|
"""确保 personas 表有 skills 列。
|
|
|
|
这是为了支持旧版数据库的平滑升级。新版数据库通过 SQLModel
|
|
的 metadata.create_all 自动创建这些列。
|
|
"""
|
|
result = await conn.execute(text("PRAGMA table_info(personas)"))
|
|
columns = {row[1] for row in result.fetchall()}
|
|
|
|
if "skills" not in columns:
|
|
await conn.execute(text("ALTER TABLE personas ADD COLUMN skills JSON"))
|
|
|
|
async def _ensure_persona_custom_error_message_column(self, conn) -> None:
|
|
"""确保 personas 表有 custom_error_message 列。"""
|
|
result = await conn.execute(text("PRAGMA table_info(personas)"))
|
|
columns = {row[1] for row in result.fetchall()}
|
|
|
|
if "custom_error_message" not in columns:
|
|
await conn.execute(
|
|
text("ALTER TABLE personas ADD COLUMN custom_error_message TEXT")
|
|
)
|
|
|
|
# ====
|
|
# Platform Statistics
|
|
# ====
|
|
|
|
async def insert_platform_stats(
|
|
self,
|
|
platform_id,
|
|
platform_type,
|
|
count=1,
|
|
timestamp=None,
|
|
) -> None:
|
|
"""Insert a new platform statistic record."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
if timestamp is None:
|
|
timestamp = datetime.now().replace(
|
|
minute=0,
|
|
second=0,
|
|
microsecond=0,
|
|
)
|
|
current_hour = timestamp
|
|
await session.execute(
|
|
text("""
|
|
INSERT INTO platform_stats (timestamp, platform_id, platform_type, count)
|
|
VALUES (:timestamp, :platform_id, :platform_type, :count)
|
|
ON CONFLICT(timestamp, platform_id, platform_type) DO UPDATE SET
|
|
count = platform_stats.count + EXCLUDED.count
|
|
"""),
|
|
{
|
|
"timestamp": current_hour,
|
|
"platform_id": platform_id,
|
|
"platform_type": platform_type,
|
|
"count": count,
|
|
},
|
|
)
|
|
|
|
async def count_platform_stats(self) -> int:
|
|
"""Count the number of platform statistics records."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
result = await session.execute(
|
|
select(func.count(col(PlatformStat.platform_id))).select_from(
|
|
PlatformStat,
|
|
),
|
|
)
|
|
count = result.scalar_one_or_none()
|
|
return count if count is not None else 0
|
|
|
|
async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]:
|
|
"""Get platform statistics within the specified offset in seconds and group by platform_id."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
now = datetime.now()
|
|
start_time = now - timedelta(seconds=offset_sec)
|
|
result = await session.execute(
|
|
text("""
|
|
SELECT * FROM platform_stats
|
|
WHERE timestamp >= :start_time
|
|
GROUP BY platform_id
|
|
ORDER BY timestamp DESC
|
|
"""),
|
|
{"start_time": start_time},
|
|
)
|
|
return list(result.scalars().all())
|
|
|
|
# ====
|
|
# Conversation Management
|
|
# ====
|
|
|
|
async def get_conversations(self, user_id=None, platform_id=None):
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
query = select(ConversationV2)
|
|
|
|
if user_id:
|
|
query = query.where(ConversationV2.user_id == user_id)
|
|
if platform_id:
|
|
query = query.where(ConversationV2.platform_id == platform_id)
|
|
# order by
|
|
query = query.order_by(desc(ConversationV2.created_at))
|
|
result = await session.execute(query)
|
|
|
|
return result.scalars().all()
|
|
|
|
async def get_conversation_by_id(self, cid):
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
query = select(ConversationV2).where(ConversationV2.conversation_id == cid)
|
|
result = await session.execute(query)
|
|
return result.scalar_one_or_none()
|
|
|
|
async def get_all_conversations(self, page=1, page_size=20):
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
offset = (page - 1) * page_size
|
|
result = await session.execute(
|
|
select(ConversationV2)
|
|
.order_by(desc(ConversationV2.created_at))
|
|
.offset(offset)
|
|
.limit(page_size),
|
|
)
|
|
return result.scalars().all()
|
|
|
|
async def get_filtered_conversations(
|
|
self,
|
|
page=1,
|
|
page_size=20,
|
|
platform_ids=None,
|
|
search_query="",
|
|
**kwargs,
|
|
):
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
# Build the base query with filters
|
|
base_query = select(ConversationV2)
|
|
|
|
if platform_ids:
|
|
base_query = base_query.where(
|
|
col(ConversationV2.platform_id).in_(platform_ids),
|
|
)
|
|
if search_query:
|
|
search_query = search_query.encode("unicode_escape").decode("utf-8")
|
|
base_query = base_query.where(
|
|
or_(
|
|
col(ConversationV2.title).ilike(f"%{search_query}%"),
|
|
col(ConversationV2.content).ilike(f"%{search_query}%"),
|
|
col(ConversationV2.user_id).ilike(f"%{search_query}%"),
|
|
col(ConversationV2.conversation_id).ilike(f"%{search_query}%"),
|
|
),
|
|
)
|
|
if "message_types" in kwargs and len(kwargs["message_types"]) > 0:
|
|
for msg_type in kwargs["message_types"]:
|
|
base_query = base_query.where(
|
|
col(ConversationV2.user_id).ilike(f"%:{msg_type}:%"),
|
|
)
|
|
if "platforms" in kwargs and len(kwargs["platforms"]) > 0:
|
|
base_query = base_query.where(
|
|
col(ConversationV2.platform_id).in_(kwargs["platforms"]),
|
|
)
|
|
|
|
# Get total count matching the filters
|
|
count_query = select(func.count()).select_from(base_query.subquery())
|
|
total_count = await session.execute(count_query)
|
|
total = total_count.scalar_one()
|
|
|
|
# Get paginated results
|
|
offset = (page - 1) * page_size
|
|
result_query = (
|
|
base_query.order_by(desc(ConversationV2.created_at))
|
|
.offset(offset)
|
|
.limit(page_size)
|
|
)
|
|
result = await session.execute(result_query)
|
|
conversations = result.scalars().all()
|
|
|
|
return conversations, total
|
|
|
|
async def create_conversation(
|
|
self,
|
|
user_id,
|
|
platform_id,
|
|
content=None,
|
|
title=None,
|
|
persona_id=None,
|
|
cid=None,
|
|
created_at=None,
|
|
updated_at=None,
|
|
):
|
|
kwargs = {}
|
|
if cid:
|
|
kwargs["conversation_id"] = cid
|
|
if created_at:
|
|
kwargs["created_at"] = created_at
|
|
if updated_at:
|
|
kwargs["updated_at"] = updated_at
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
new_conversation = ConversationV2(
|
|
user_id=user_id,
|
|
content=content or [],
|
|
platform_id=platform_id,
|
|
title=title,
|
|
persona_id=persona_id,
|
|
**kwargs,
|
|
)
|
|
session.add(new_conversation)
|
|
return new_conversation
|
|
|
|
async def update_conversation(
|
|
self, cid, title=None, persona_id=None, content=None, token_usage=None
|
|
):
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
query = update(ConversationV2).where(
|
|
col(ConversationV2.conversation_id) == cid,
|
|
)
|
|
values = {}
|
|
if title is not None:
|
|
values["title"] = title
|
|
if persona_id is not None:
|
|
values["persona_id"] = persona_id
|
|
if content is not None:
|
|
values["content"] = content
|
|
if token_usage is not None:
|
|
values["token_usage"] = token_usage
|
|
if not values:
|
|
return None
|
|
query = query.values(**values)
|
|
await session.execute(query)
|
|
return await self.get_conversation_by_id(cid)
|
|
|
|
async def delete_conversation(self, cid) -> None:
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
await session.execute(
|
|
delete(ConversationV2).where(
|
|
col(ConversationV2.conversation_id) == cid,
|
|
),
|
|
)
|
|
|
|
async def delete_conversations_by_user_id(self, user_id: str) -> None:
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
await session.execute(
|
|
delete(ConversationV2).where(
|
|
col(ConversationV2.user_id) == user_id
|
|
),
|
|
)
|
|
|
|
async def get_session_conversations(
|
|
self,
|
|
page=1,
|
|
page_size=20,
|
|
search_query=None,
|
|
platform=None,
|
|
) -> tuple[list[dict], int]:
|
|
"""Get paginated session conversations with joined conversation and persona details."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
offset = (page - 1) * page_size
|
|
|
|
base_query = (
|
|
select(
|
|
col(Preference.scope_id).label("session_id"),
|
|
func.json_extract(Preference.value, "$.val").label(
|
|
"conversation_id",
|
|
), # type: ignore
|
|
col(ConversationV2.persona_id).label("persona_id"),
|
|
col(ConversationV2.title).label("title"),
|
|
col(Persona.persona_id).label("persona_name"),
|
|
)
|
|
.select_from(Preference)
|
|
.outerjoin(
|
|
ConversationV2,
|
|
func.json_extract(Preference.value, "$.val")
|
|
== ConversationV2.conversation_id,
|
|
)
|
|
.outerjoin(
|
|
Persona,
|
|
col(ConversationV2.persona_id) == Persona.persona_id,
|
|
)
|
|
.where(Preference.scope == "umo", Preference.key == "sel_conv_id")
|
|
)
|
|
|
|
# 搜索筛选
|
|
if search_query:
|
|
search_pattern = f"%{search_query}%"
|
|
base_query = base_query.where(
|
|
or_(
|
|
col(Preference.scope_id).ilike(search_pattern),
|
|
col(ConversationV2.title).ilike(search_pattern),
|
|
col(Persona.persona_id).ilike(search_pattern),
|
|
),
|
|
)
|
|
|
|
# 平台筛选
|
|
if platform:
|
|
platform_pattern = f"{platform}:%"
|
|
base_query = base_query.where(
|
|
col(Preference.scope_id).like(platform_pattern),
|
|
)
|
|
|
|
# 排序
|
|
base_query = base_query.order_by(Preference.scope_id)
|
|
|
|
# 分页结果
|
|
result_query = base_query.offset(offset).limit(page_size)
|
|
result = await session.execute(result_query)
|
|
rows = result.fetchall()
|
|
|
|
# 查询总数(应用相同的筛选条件)
|
|
count_base_query = (
|
|
select(func.count(col(Preference.scope_id)))
|
|
.select_from(Preference)
|
|
.outerjoin(
|
|
ConversationV2,
|
|
func.json_extract(Preference.value, "$.val")
|
|
== ConversationV2.conversation_id,
|
|
)
|
|
.outerjoin(
|
|
Persona,
|
|
col(ConversationV2.persona_id) == Persona.persona_id,
|
|
)
|
|
.where(Preference.scope == "umo", Preference.key == "sel_conv_id")
|
|
)
|
|
|
|
# 应用相同的搜索和平台筛选条件到计数查询
|
|
if search_query:
|
|
search_pattern = f"%{search_query}%"
|
|
count_base_query = count_base_query.where(
|
|
or_(
|
|
col(Preference.scope_id).ilike(search_pattern),
|
|
col(ConversationV2.title).ilike(search_pattern),
|
|
col(Persona.persona_id).ilike(search_pattern),
|
|
),
|
|
)
|
|
|
|
if platform:
|
|
platform_pattern = f"{platform}:%"
|
|
count_base_query = count_base_query.where(
|
|
col(Preference.scope_id).like(platform_pattern),
|
|
)
|
|
|
|
total_result = await session.execute(count_base_query)
|
|
total = total_result.scalar() or 0
|
|
|
|
sessions_data = [
|
|
{
|
|
"session_id": row.session_id,
|
|
"conversation_id": row.conversation_id,
|
|
"persona_id": row.persona_id,
|
|
"title": row.title,
|
|
"persona_name": row.persona_name,
|
|
}
|
|
for row in rows
|
|
]
|
|
return sessions_data, total
|
|
|
|
async def insert_platform_message_history(
|
|
self,
|
|
platform_id,
|
|
user_id,
|
|
content,
|
|
sender_id=None,
|
|
sender_name=None,
|
|
):
|
|
"""Insert a new platform message history record."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
new_history = PlatformMessageHistory(
|
|
platform_id=platform_id,
|
|
user_id=user_id,
|
|
content=content,
|
|
sender_id=sender_id,
|
|
sender_name=sender_name,
|
|
)
|
|
session.add(new_history)
|
|
return new_history
|
|
|
|
async def delete_platform_message_offset(
|
|
self,
|
|
platform_id,
|
|
user_id,
|
|
offset_sec=86400,
|
|
) -> None:
|
|
"""Delete platform message history records newer than the specified offset."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
now = datetime.now()
|
|
cutoff_time = now - timedelta(seconds=offset_sec)
|
|
await session.execute(
|
|
delete(PlatformMessageHistory).where(
|
|
col(PlatformMessageHistory.platform_id) == platform_id,
|
|
col(PlatformMessageHistory.user_id) == user_id,
|
|
col(PlatformMessageHistory.created_at) >= cutoff_time,
|
|
),
|
|
)
|
|
|
|
async def get_platform_message_history(
|
|
self,
|
|
platform_id,
|
|
user_id,
|
|
page=1,
|
|
page_size=20,
|
|
):
|
|
"""Get platform message history records."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
offset = (page - 1) * page_size
|
|
query = (
|
|
select(PlatformMessageHistory)
|
|
.where(
|
|
PlatformMessageHistory.platform_id == platform_id,
|
|
PlatformMessageHistory.user_id == user_id,
|
|
)
|
|
.order_by(desc(PlatformMessageHistory.created_at))
|
|
)
|
|
result = await session.execute(query.offset(offset).limit(page_size))
|
|
return result.scalars().all()
|
|
|
|
async def get_platform_message_history_by_id(
|
|
self, message_id: int
|
|
) -> PlatformMessageHistory | None:
|
|
"""Get a platform message history record by its ID."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
query = select(PlatformMessageHistory).where(
|
|
PlatformMessageHistory.id == message_id
|
|
)
|
|
result = await session.execute(query)
|
|
return result.scalar_one_or_none()
|
|
|
|
async def insert_attachment(self, path, type, mime_type):
|
|
"""Insert a new attachment record."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
new_attachment = Attachment(
|
|
path=path,
|
|
type=type,
|
|
mime_type=mime_type,
|
|
)
|
|
session.add(new_attachment)
|
|
return new_attachment
|
|
|
|
async def get_attachment_by_id(self, attachment_id):
|
|
"""Get an attachment by its ID."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
query = select(Attachment).where(Attachment.attachment_id == attachment_id)
|
|
result = await session.execute(query)
|
|
return result.scalar_one_or_none()
|
|
|
|
async def get_attachments(self, attachment_ids: list[str]) -> list:
|
|
"""Get multiple attachments by their IDs."""
|
|
if not attachment_ids:
|
|
return []
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
query = select(Attachment).where(
|
|
col(Attachment.attachment_id).in_(attachment_ids)
|
|
)
|
|
result = await session.execute(query)
|
|
return list(result.scalars().all())
|
|
|
|
async def delete_attachment(self, attachment_id: str) -> bool:
|
|
"""Delete an attachment by its ID.
|
|
|
|
Returns True if the attachment was deleted, False if it was not found.
|
|
"""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
query = delete(Attachment).where(
|
|
col(Attachment.attachment_id) == attachment_id
|
|
)
|
|
result = T.cast(CursorResult, await session.execute(query))
|
|
return result.rowcount > 0
|
|
|
|
async def delete_attachments(self, attachment_ids: list[str]) -> int:
|
|
"""Delete multiple attachments by their IDs.
|
|
|
|
Returns the number of attachments deleted.
|
|
"""
|
|
if not attachment_ids:
|
|
return 0
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
query = delete(Attachment).where(
|
|
col(Attachment.attachment_id).in_(attachment_ids)
|
|
)
|
|
result = T.cast(CursorResult, await session.execute(query))
|
|
return result.rowcount
|
|
|
|
async def create_api_key(
|
|
self,
|
|
name: str,
|
|
key_hash: str,
|
|
key_prefix: str,
|
|
scopes: list[str] | None,
|
|
created_by: str,
|
|
expires_at: datetime | None = None,
|
|
) -> ApiKey:
|
|
"""Create a new API key record."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
api_key = ApiKey(
|
|
name=name,
|
|
key_hash=key_hash,
|
|
key_prefix=key_prefix,
|
|
scopes=scopes,
|
|
created_by=created_by,
|
|
expires_at=expires_at,
|
|
)
|
|
session.add(api_key)
|
|
await session.flush()
|
|
await session.refresh(api_key)
|
|
return api_key
|
|
|
|
async def list_api_keys(self) -> list[ApiKey]:
|
|
"""List all API keys."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
result = await session.execute(
|
|
select(ApiKey).order_by(desc(ApiKey.created_at))
|
|
)
|
|
return list(result.scalars().all())
|
|
|
|
async def get_api_key_by_id(self, key_id: str) -> ApiKey | None:
|
|
"""Get an API key by key_id."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
result = await session.execute(
|
|
select(ApiKey).where(ApiKey.key_id == key_id)
|
|
)
|
|
return result.scalar_one_or_none()
|
|
|
|
async def get_active_api_key_by_hash(self, key_hash: str) -> ApiKey | None:
|
|
"""Get an active API key by hash (not revoked, not expired)."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
now = datetime.now(timezone.utc)
|
|
query = select(ApiKey).where(
|
|
ApiKey.key_hash == key_hash,
|
|
col(ApiKey.revoked_at).is_(None),
|
|
or_(col(ApiKey.expires_at).is_(None), col(ApiKey.expires_at) > now),
|
|
)
|
|
result = await session.execute(query)
|
|
return result.scalar_one_or_none()
|
|
|
|
async def touch_api_key(self, key_id: str) -> None:
|
|
"""Update last_used_at of an API key."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
await session.execute(
|
|
update(ApiKey)
|
|
.where(col(ApiKey.key_id) == key_id)
|
|
.values(last_used_at=datetime.now(timezone.utc)),
|
|
)
|
|
|
|
async def revoke_api_key(self, key_id: str) -> bool:
|
|
"""Revoke an API key."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
query = (
|
|
update(ApiKey)
|
|
.where(col(ApiKey.key_id) == key_id)
|
|
.values(revoked_at=datetime.now(timezone.utc))
|
|
)
|
|
result = T.cast(CursorResult, await session.execute(query))
|
|
return result.rowcount > 0
|
|
|
|
async def delete_api_key(self, key_id: str) -> bool:
|
|
"""Delete an API key."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
result = T.cast(
|
|
CursorResult,
|
|
await session.execute(
|
|
delete(ApiKey).where(col(ApiKey.key_id) == key_id)
|
|
),
|
|
)
|
|
return result.rowcount > 0
|
|
|
|
async def insert_persona(
|
|
self,
|
|
persona_id,
|
|
system_prompt,
|
|
begin_dialogs=None,
|
|
tools=None,
|
|
skills=None,
|
|
custom_error_message=None,
|
|
folder_id=None,
|
|
sort_order=0,
|
|
):
|
|
"""Insert a new persona record."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
new_persona = Persona(
|
|
persona_id=persona_id,
|
|
system_prompt=system_prompt,
|
|
begin_dialogs=begin_dialogs or [],
|
|
tools=tools,
|
|
skills=skills,
|
|
custom_error_message=custom_error_message,
|
|
folder_id=folder_id,
|
|
sort_order=sort_order,
|
|
)
|
|
session.add(new_persona)
|
|
await session.flush()
|
|
await session.refresh(new_persona)
|
|
return new_persona
|
|
|
|
async def get_persona_by_id(self, persona_id):
|
|
"""Get a persona by its ID."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
query = select(Persona).where(Persona.persona_id == persona_id)
|
|
result = await session.execute(query)
|
|
return result.scalar_one_or_none()
|
|
|
|
async def get_personas(self):
|
|
"""Get all personas for a specific bot."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
query = select(Persona)
|
|
result = await session.execute(query)
|
|
return result.scalars().all()
|
|
|
|
async def update_persona(
|
|
self,
|
|
persona_id,
|
|
system_prompt=None,
|
|
begin_dialogs=None,
|
|
tools=NOT_GIVEN,
|
|
skills=NOT_GIVEN,
|
|
custom_error_message=NOT_GIVEN,
|
|
):
|
|
"""Update a persona's system prompt or begin dialogs."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
query = update(Persona).where(col(Persona.persona_id) == persona_id)
|
|
values = {}
|
|
if system_prompt is not None:
|
|
values["system_prompt"] = system_prompt
|
|
if begin_dialogs is not None:
|
|
values["begin_dialogs"] = begin_dialogs
|
|
if tools is not NOT_GIVEN:
|
|
values["tools"] = tools
|
|
if skills is not NOT_GIVEN:
|
|
values["skills"] = skills
|
|
if custom_error_message is not NOT_GIVEN:
|
|
values["custom_error_message"] = custom_error_message
|
|
if not values:
|
|
return None
|
|
query = query.values(**values)
|
|
await session.execute(query)
|
|
return await self.get_persona_by_id(persona_id)
|
|
|
|
async def delete_persona(self, persona_id) -> None:
|
|
"""Delete a persona by its ID."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
await session.execute(
|
|
delete(Persona).where(col(Persona.persona_id) == persona_id),
|
|
)
|
|
|
|
# ====
|
|
# Persona Folder Management
|
|
# ====
|
|
|
|
async def insert_persona_folder(
|
|
self,
|
|
name: str,
|
|
parent_id: str | None = None,
|
|
description: str | None = None,
|
|
sort_order: int = 0,
|
|
) -> PersonaFolder:
|
|
"""Insert a new persona folder."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
new_folder = PersonaFolder(
|
|
name=name,
|
|
parent_id=parent_id,
|
|
description=description,
|
|
sort_order=sort_order,
|
|
)
|
|
session.add(new_folder)
|
|
await session.flush()
|
|
await session.refresh(new_folder)
|
|
return new_folder
|
|
|
|
async def get_persona_folder_by_id(self, folder_id: str) -> PersonaFolder | None:
|
|
"""Get a persona folder by its folder_id."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
query = select(PersonaFolder).where(PersonaFolder.folder_id == folder_id)
|
|
result = await session.execute(query)
|
|
return result.scalar_one_or_none()
|
|
|
|
async def get_persona_folders(
|
|
self, parent_id: str | None = None
|
|
) -> list[PersonaFolder]:
|
|
"""Get all persona folders, optionally filtered by parent_id.
|
|
|
|
Args:
|
|
parent_id: If None, returns root folders only. If specified, returns
|
|
children of that folder.
|
|
"""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
if parent_id is None:
|
|
# Get root folders (parent_id is NULL)
|
|
query = (
|
|
select(PersonaFolder)
|
|
.where(col(PersonaFolder.parent_id).is_(None))
|
|
.order_by(col(PersonaFolder.sort_order), col(PersonaFolder.name))
|
|
)
|
|
else:
|
|
query = (
|
|
select(PersonaFolder)
|
|
.where(PersonaFolder.parent_id == parent_id)
|
|
.order_by(col(PersonaFolder.sort_order), col(PersonaFolder.name))
|
|
)
|
|
result = await session.execute(query)
|
|
return list(result.scalars().all())
|
|
|
|
async def get_all_persona_folders(self) -> list[PersonaFolder]:
|
|
"""Get all persona folders."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
query = select(PersonaFolder).order_by(
|
|
col(PersonaFolder.sort_order), col(PersonaFolder.name)
|
|
)
|
|
result = await session.execute(query)
|
|
return list(result.scalars().all())
|
|
|
|
async def update_persona_folder(
|
|
self,
|
|
folder_id: str,
|
|
name: str | None = None,
|
|
parent_id: T.Any = NOT_GIVEN,
|
|
description: T.Any = NOT_GIVEN,
|
|
sort_order: int | None = None,
|
|
) -> PersonaFolder | None:
|
|
"""Update a persona folder."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
query = update(PersonaFolder).where(
|
|
col(PersonaFolder.folder_id) == folder_id
|
|
)
|
|
values: dict[str, T.Any] = {}
|
|
if name is not None:
|
|
values["name"] = name
|
|
if parent_id is not NOT_GIVEN:
|
|
values["parent_id"] = parent_id
|
|
if description is not NOT_GIVEN:
|
|
values["description"] = description
|
|
if sort_order is not None:
|
|
values["sort_order"] = sort_order
|
|
if not values:
|
|
return None
|
|
query = query.values(**values)
|
|
await session.execute(query)
|
|
return await self.get_persona_folder_by_id(folder_id)
|
|
|
|
async def delete_persona_folder(self, folder_id: str) -> None:
|
|
"""Delete a persona folder by its folder_id.
|
|
|
|
Note: This will also set folder_id to NULL for all personas in this folder,
|
|
moving them to the root directory.
|
|
"""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
# Move personas to root directory
|
|
await session.execute(
|
|
update(Persona)
|
|
.where(col(Persona.folder_id) == folder_id)
|
|
.values(folder_id=None)
|
|
)
|
|
# Delete the folder
|
|
await session.execute(
|
|
delete(PersonaFolder).where(
|
|
col(PersonaFolder.folder_id) == folder_id
|
|
),
|
|
)
|
|
|
|
async def move_persona_to_folder(
|
|
self, persona_id: str, folder_id: str | None
|
|
) -> Persona | None:
|
|
"""Move a persona to a folder (or root if folder_id is None)."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
await session.execute(
|
|
update(Persona)
|
|
.where(col(Persona.persona_id) == persona_id)
|
|
.values(folder_id=folder_id)
|
|
)
|
|
return await self.get_persona_by_id(persona_id)
|
|
|
|
async def get_personas_by_folder(
|
|
self, folder_id: str | None = None
|
|
) -> list[Persona]:
|
|
"""Get all personas in a specific folder.
|
|
|
|
Args:
|
|
folder_id: If None, returns personas in root directory.
|
|
"""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
if folder_id is None:
|
|
query = (
|
|
select(Persona)
|
|
.where(col(Persona.folder_id).is_(None))
|
|
.order_by(col(Persona.sort_order), col(Persona.persona_id))
|
|
)
|
|
else:
|
|
query = (
|
|
select(Persona)
|
|
.where(Persona.folder_id == folder_id)
|
|
.order_by(col(Persona.sort_order), col(Persona.persona_id))
|
|
)
|
|
result = await session.execute(query)
|
|
return list(result.scalars().all())
|
|
|
|
async def batch_update_sort_order(
|
|
self,
|
|
items: list[dict],
|
|
) -> None:
|
|
"""Batch update sort_order for personas and/or folders.
|
|
|
|
Args:
|
|
items: List of dicts with keys:
|
|
- id: The persona_id or folder_id
|
|
- type: Either "persona" or "folder"
|
|
- sort_order: The new sort_order value
|
|
"""
|
|
if not items:
|
|
return
|
|
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
for item in items:
|
|
item_id = item.get("id")
|
|
item_type = item.get("type")
|
|
sort_order = item.get("sort_order")
|
|
|
|
if item_id is None or item_type is None or sort_order is None:
|
|
continue
|
|
|
|
if item_type == "persona":
|
|
await session.execute(
|
|
update(Persona)
|
|
.where(col(Persona.persona_id) == item_id)
|
|
.values(sort_order=sort_order)
|
|
)
|
|
elif item_type == "folder":
|
|
await session.execute(
|
|
update(PersonaFolder)
|
|
.where(col(PersonaFolder.folder_id) == item_id)
|
|
.values(sort_order=sort_order)
|
|
)
|
|
|
|
async def insert_preference_or_update(self, scope, scope_id, key, value):
|
|
"""Insert a new preference record or update if it exists."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
query = select(Preference).where(
|
|
Preference.scope == scope,
|
|
Preference.scope_id == scope_id,
|
|
Preference.key == key,
|
|
)
|
|
result = await session.execute(query)
|
|
existing_preference = result.scalar_one_or_none()
|
|
if existing_preference:
|
|
existing_preference.value = value
|
|
else:
|
|
new_preference = Preference(
|
|
scope=scope,
|
|
scope_id=scope_id,
|
|
key=key,
|
|
value=value,
|
|
)
|
|
session.add(new_preference)
|
|
return existing_preference or new_preference
|
|
|
|
async def get_preference(self, scope, scope_id, key):
|
|
"""Get a preference by key."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
query = select(Preference).where(
|
|
Preference.scope == scope,
|
|
Preference.scope_id == scope_id,
|
|
Preference.key == key,
|
|
)
|
|
result = await session.execute(query)
|
|
return result.scalar_one_or_none()
|
|
|
|
async def get_preferences(self, scope, scope_id=None, key=None):
|
|
"""Get all preferences for a specific scope ID or key."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
query = select(Preference).where(Preference.scope == scope)
|
|
if scope_id is not None:
|
|
query = query.where(Preference.scope_id == scope_id)
|
|
if key is not None:
|
|
query = query.where(Preference.key == key)
|
|
result = await session.execute(query)
|
|
return result.scalars().all()
|
|
|
|
async def remove_preference(self, scope, scope_id, key) -> None:
|
|
"""Remove a preference by scope ID and key."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
await session.execute(
|
|
delete(Preference).where(
|
|
col(Preference.scope) == scope,
|
|
col(Preference.scope_id) == scope_id,
|
|
col(Preference.key) == key,
|
|
),
|
|
)
|
|
await session.commit()
|
|
|
|
async def clear_preferences(self, scope, scope_id) -> None:
|
|
"""Clear all preferences for a specific scope ID."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
await session.execute(
|
|
delete(Preference).where(
|
|
col(Preference.scope) == scope,
|
|
col(Preference.scope_id) == scope_id,
|
|
),
|
|
)
|
|
await session.commit()
|
|
|
|
# ====
|
|
# Command Configuration & Conflict Tracking
|
|
# ====
|
|
|
|
async def _run_in_tx(
|
|
self,
|
|
fn: Callable[[AsyncSession], Awaitable[TxResult]],
|
|
) -> TxResult:
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
return await fn(session)
|
|
|
|
@staticmethod
|
|
def _apply_updates(model, **updates) -> None:
|
|
for field, value in updates.items():
|
|
if value is not None:
|
|
setattr(model, field, value)
|
|
|
|
@staticmethod
|
|
def _new_command_config(
|
|
handler_full_name: str,
|
|
plugin_name: str,
|
|
module_path: str,
|
|
original_command: str,
|
|
*,
|
|
resolved_command: str | None = None,
|
|
enabled: bool | None = None,
|
|
keep_original_alias: bool | None = None,
|
|
conflict_key: str | None = None,
|
|
resolution_strategy: str | None = None,
|
|
note: str | None = None,
|
|
extra_data: dict | None = None,
|
|
auto_managed: bool | None = None,
|
|
) -> CommandConfig:
|
|
return CommandConfig(
|
|
handler_full_name=handler_full_name,
|
|
plugin_name=plugin_name,
|
|
module_path=module_path,
|
|
original_command=original_command,
|
|
resolved_command=resolved_command,
|
|
enabled=True if enabled is None else enabled,
|
|
keep_original_alias=False
|
|
if keep_original_alias is None
|
|
else keep_original_alias,
|
|
conflict_key=conflict_key or original_command,
|
|
resolution_strategy=resolution_strategy,
|
|
note=note,
|
|
extra_data=extra_data,
|
|
auto_managed=bool(auto_managed),
|
|
)
|
|
|
|
@staticmethod
|
|
def _new_command_conflict(
|
|
conflict_key: str,
|
|
handler_full_name: str,
|
|
plugin_name: str,
|
|
*,
|
|
status: str | None = None,
|
|
resolution: str | None = None,
|
|
resolved_command: str | None = None,
|
|
note: str | None = None,
|
|
extra_data: dict | None = None,
|
|
auto_generated: bool | None = None,
|
|
) -> CommandConflict:
|
|
return CommandConflict(
|
|
conflict_key=conflict_key,
|
|
handler_full_name=handler_full_name,
|
|
plugin_name=plugin_name,
|
|
status=status or "pending",
|
|
resolution=resolution,
|
|
resolved_command=resolved_command,
|
|
note=note,
|
|
extra_data=extra_data,
|
|
auto_generated=bool(auto_generated),
|
|
)
|
|
|
|
async def get_command_configs(self) -> list[CommandConfig]:
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
result = await session.execute(select(CommandConfig))
|
|
return list(result.scalars().all())
|
|
|
|
async def get_command_config(
|
|
self,
|
|
handler_full_name: str,
|
|
) -> CommandConfig | None:
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
return await session.get(CommandConfig, handler_full_name)
|
|
|
|
async def upsert_command_config(
|
|
self,
|
|
handler_full_name: str,
|
|
plugin_name: str,
|
|
module_path: str,
|
|
original_command: str,
|
|
*,
|
|
resolved_command: str | None = None,
|
|
enabled: bool | None = None,
|
|
keep_original_alias: bool | None = None,
|
|
conflict_key: str | None = None,
|
|
resolution_strategy: str | None = None,
|
|
note: str | None = None,
|
|
extra_data: dict | None = None,
|
|
auto_managed: bool | None = None,
|
|
) -> CommandConfig:
|
|
async def _op(session: AsyncSession) -> CommandConfig:
|
|
config = await session.get(CommandConfig, handler_full_name)
|
|
if not config:
|
|
config = self._new_command_config(
|
|
handler_full_name,
|
|
plugin_name,
|
|
module_path,
|
|
original_command,
|
|
resolved_command=resolved_command,
|
|
enabled=enabled,
|
|
keep_original_alias=keep_original_alias,
|
|
conflict_key=conflict_key,
|
|
resolution_strategy=resolution_strategy,
|
|
note=note,
|
|
extra_data=extra_data,
|
|
auto_managed=auto_managed,
|
|
)
|
|
session.add(config)
|
|
else:
|
|
self._apply_updates(
|
|
config,
|
|
plugin_name=plugin_name,
|
|
module_path=module_path,
|
|
original_command=original_command,
|
|
resolved_command=resolved_command,
|
|
enabled=enabled,
|
|
keep_original_alias=keep_original_alias,
|
|
conflict_key=conflict_key,
|
|
resolution_strategy=resolution_strategy,
|
|
note=note,
|
|
extra_data=extra_data,
|
|
auto_managed=auto_managed,
|
|
)
|
|
await session.flush()
|
|
await session.refresh(config)
|
|
return config
|
|
|
|
return await self._run_in_tx(_op)
|
|
|
|
async def delete_command_config(self, handler_full_name: str) -> None:
|
|
await self.delete_command_configs([handler_full_name])
|
|
|
|
async def delete_command_configs(self, handler_full_names: list[str]) -> None:
|
|
if not handler_full_names:
|
|
return
|
|
|
|
async def _op(session: AsyncSession) -> None:
|
|
await session.execute(
|
|
delete(CommandConfig).where(
|
|
col(CommandConfig.handler_full_name).in_(handler_full_names),
|
|
),
|
|
)
|
|
|
|
await self._run_in_tx(_op)
|
|
|
|
async def list_command_conflicts(
|
|
self,
|
|
status: str | None = None,
|
|
) -> list[CommandConflict]:
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
query = select(CommandConflict)
|
|
if status:
|
|
query = query.where(CommandConflict.status == status)
|
|
result = await session.execute(query)
|
|
return list(result.scalars().all())
|
|
|
|
async def upsert_command_conflict(
|
|
self,
|
|
conflict_key: str,
|
|
handler_full_name: str,
|
|
plugin_name: str,
|
|
*,
|
|
status: str | None = None,
|
|
resolution: str | None = None,
|
|
resolved_command: str | None = None,
|
|
note: str | None = None,
|
|
extra_data: dict | None = None,
|
|
auto_generated: bool | None = None,
|
|
) -> CommandConflict:
|
|
async def _op(session: AsyncSession) -> CommandConflict:
|
|
result = await session.execute(
|
|
select(CommandConflict).where(
|
|
CommandConflict.conflict_key == conflict_key,
|
|
CommandConflict.handler_full_name == handler_full_name,
|
|
),
|
|
)
|
|
record = result.scalar_one_or_none()
|
|
if not record:
|
|
record = self._new_command_conflict(
|
|
conflict_key,
|
|
handler_full_name,
|
|
plugin_name,
|
|
status=status,
|
|
resolution=resolution,
|
|
resolved_command=resolved_command,
|
|
note=note,
|
|
extra_data=extra_data,
|
|
auto_generated=auto_generated,
|
|
)
|
|
session.add(record)
|
|
else:
|
|
self._apply_updates(
|
|
record,
|
|
plugin_name=plugin_name,
|
|
status=status,
|
|
resolution=resolution,
|
|
resolved_command=resolved_command,
|
|
note=note,
|
|
extra_data=extra_data,
|
|
auto_generated=auto_generated,
|
|
)
|
|
await session.flush()
|
|
await session.refresh(record)
|
|
return record
|
|
|
|
return await self._run_in_tx(_op)
|
|
|
|
async def delete_command_conflicts(self, ids: list[int]) -> None:
|
|
if not ids:
|
|
return
|
|
|
|
async def _op(session: AsyncSession) -> None:
|
|
await session.execute(
|
|
delete(CommandConflict).where(col(CommandConflict.id).in_(ids)),
|
|
)
|
|
|
|
await self._run_in_tx(_op)
|
|
|
|
# ====
|
|
# Deprecated Methods
|
|
# ====
|
|
|
|
def get_base_stats(self, offset_sec=86400):
|
|
"""Get base statistics within the specified offset in seconds."""
|
|
|
|
async def _inner():
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
now = datetime.now()
|
|
start_time = now - timedelta(seconds=offset_sec)
|
|
result = await session.execute(
|
|
select(PlatformStat).where(PlatformStat.timestamp >= start_time),
|
|
)
|
|
all_datas = result.scalars().all()
|
|
deprecated_stats = DeprecatedStats()
|
|
for data in all_datas:
|
|
deprecated_stats.platform.append(
|
|
DeprecatedPlatformStat(
|
|
name=data.platform_id,
|
|
count=data.count,
|
|
timestamp=int(data.timestamp.timestamp()),
|
|
),
|
|
)
|
|
return deprecated_stats
|
|
|
|
result = None
|
|
|
|
def runner() -> None:
|
|
nonlocal result
|
|
result = asyncio.run(_inner())
|
|
|
|
t = threading.Thread(target=runner)
|
|
t.start()
|
|
t.join()
|
|
return result
|
|
|
|
def get_total_message_count(self):
|
|
"""Get the total message count from platform statistics."""
|
|
|
|
async def _inner():
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
result = await session.execute(
|
|
select(func.sum(PlatformStat.count)).select_from(PlatformStat),
|
|
)
|
|
total_count = result.scalar_one_or_none()
|
|
return total_count if total_count is not None else 0
|
|
|
|
result = None
|
|
|
|
def runner() -> None:
|
|
nonlocal result
|
|
result = asyncio.run(_inner())
|
|
|
|
t = threading.Thread(target=runner)
|
|
t.start()
|
|
t.join()
|
|
return result
|
|
|
|
def get_grouped_base_stats(self, offset_sec=86400):
|
|
# group by platform_id
|
|
async def _inner():
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
now = datetime.now()
|
|
start_time = now - timedelta(seconds=offset_sec)
|
|
result = await session.execute(
|
|
select(PlatformStat.platform_id, func.sum(PlatformStat.count))
|
|
.where(PlatformStat.timestamp >= start_time)
|
|
.group_by(PlatformStat.platform_id),
|
|
)
|
|
grouped_stats = result.all()
|
|
deprecated_stats = DeprecatedStats()
|
|
for platform_id, count in grouped_stats:
|
|
deprecated_stats.platform.append(
|
|
DeprecatedPlatformStat(
|
|
name=platform_id,
|
|
count=count,
|
|
timestamp=int(start_time.timestamp()),
|
|
),
|
|
)
|
|
return deprecated_stats
|
|
|
|
result = None
|
|
|
|
def runner() -> None:
|
|
nonlocal result
|
|
result = asyncio.run(_inner())
|
|
|
|
t = threading.Thread(target=runner)
|
|
t.start()
|
|
t.join()
|
|
return result
|
|
|
|
# ====
|
|
# Platform Session Management
|
|
# ====
|
|
|
|
async def create_platform_session(
|
|
self,
|
|
creator: str,
|
|
platform_id: str = "webchat",
|
|
session_id: str | None = None,
|
|
display_name: str | None = None,
|
|
is_group: int = 0,
|
|
) -> PlatformSession:
|
|
"""Create a new Platform session."""
|
|
kwargs = {}
|
|
if session_id:
|
|
kwargs["session_id"] = session_id
|
|
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
new_session = PlatformSession(
|
|
creator=creator,
|
|
platform_id=platform_id,
|
|
display_name=display_name,
|
|
is_group=is_group,
|
|
**kwargs,
|
|
)
|
|
session.add(new_session)
|
|
await session.flush()
|
|
await session.refresh(new_session)
|
|
return new_session
|
|
|
|
async def get_platform_session_by_id(
|
|
self, session_id: str
|
|
) -> PlatformSession | None:
|
|
"""Get a Platform session by its ID."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
query = select(PlatformSession).where(
|
|
PlatformSession.session_id == session_id,
|
|
)
|
|
result = await session.execute(query)
|
|
return result.scalar_one_or_none()
|
|
|
|
async def get_platform_sessions_by_creator(
|
|
self,
|
|
creator: str,
|
|
platform_id: str | None = None,
|
|
page: int = 1,
|
|
page_size: int = 20,
|
|
) -> list[dict]:
|
|
"""Get all Platform sessions for a specific creator (username) and optionally platform.
|
|
|
|
Returns a list of dicts containing session info and project info (if session belongs to a project).
|
|
"""
|
|
(
|
|
sessions_with_projects,
|
|
_,
|
|
) = await self.get_platform_sessions_by_creator_paginated(
|
|
creator=creator,
|
|
platform_id=platform_id,
|
|
page=page,
|
|
page_size=page_size,
|
|
exclude_project_sessions=False,
|
|
)
|
|
return sessions_with_projects
|
|
|
|
@staticmethod
|
|
def _build_platform_sessions_query(
|
|
creator: str,
|
|
platform_id: str | None = None,
|
|
exclude_project_sessions: bool = False,
|
|
):
|
|
query = (
|
|
select(
|
|
PlatformSession,
|
|
col(ChatUIProject.project_id),
|
|
col(ChatUIProject.title).label("project_title"),
|
|
col(ChatUIProject.emoji).label("project_emoji"),
|
|
)
|
|
.outerjoin(
|
|
SessionProjectRelation,
|
|
col(PlatformSession.session_id)
|
|
== col(SessionProjectRelation.session_id),
|
|
)
|
|
.outerjoin(
|
|
ChatUIProject,
|
|
col(SessionProjectRelation.project_id) == col(ChatUIProject.project_id),
|
|
)
|
|
.where(col(PlatformSession.creator) == creator)
|
|
)
|
|
|
|
if platform_id:
|
|
query = query.where(PlatformSession.platform_id == platform_id)
|
|
if exclude_project_sessions:
|
|
query = query.where(col(ChatUIProject.project_id).is_(None))
|
|
|
|
return query
|
|
|
|
@staticmethod
|
|
def _rows_to_session_dicts(rows: T.Sequence[Row[tuple]]) -> list[dict]:
|
|
sessions_with_projects = []
|
|
for row in rows:
|
|
platform_session = row[0]
|
|
project_id = row[1]
|
|
project_title = row[2]
|
|
project_emoji = row[3]
|
|
|
|
session_dict = {
|
|
"session": platform_session,
|
|
"project_id": project_id,
|
|
"project_title": project_title,
|
|
"project_emoji": project_emoji,
|
|
}
|
|
sessions_with_projects.append(session_dict)
|
|
|
|
return sessions_with_projects
|
|
|
|
async def get_platform_sessions_by_creator_paginated(
|
|
self,
|
|
creator: str,
|
|
platform_id: str | None = None,
|
|
page: int = 1,
|
|
page_size: int = 20,
|
|
exclude_project_sessions: bool = False,
|
|
) -> tuple[list[dict], int]:
|
|
"""Get paginated Platform sessions for a creator with total count."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
offset = (page - 1) * page_size
|
|
|
|
base_query = self._build_platform_sessions_query(
|
|
creator=creator,
|
|
platform_id=platform_id,
|
|
exclude_project_sessions=exclude_project_sessions,
|
|
)
|
|
|
|
total_result = await session.execute(
|
|
select(func.count()).select_from(base_query.subquery())
|
|
)
|
|
total = int(total_result.scalar_one() or 0)
|
|
|
|
result_query = (
|
|
base_query.order_by(desc(PlatformSession.updated_at))
|
|
.offset(offset)
|
|
.limit(page_size)
|
|
)
|
|
result = await session.execute(result_query)
|
|
|
|
sessions_with_projects = self._rows_to_session_dicts(result.all())
|
|
return sessions_with_projects, total
|
|
|
|
async def update_platform_session(
|
|
self,
|
|
session_id: str,
|
|
display_name: str | None = None,
|
|
) -> None:
|
|
"""Update a Platform session's updated_at timestamp and optionally display_name."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
values: dict[str, T.Any] = {"updated_at": datetime.now(timezone.utc)}
|
|
if display_name is not None:
|
|
values["display_name"] = display_name
|
|
|
|
await session.execute(
|
|
update(PlatformSession)
|
|
.where(col(PlatformSession.session_id) == session_id)
|
|
.values(**values),
|
|
)
|
|
|
|
async def delete_platform_session(self, session_id: str) -> None:
|
|
"""Delete a Platform session by its ID."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
await session.execute(
|
|
delete(PlatformSession).where(
|
|
col(PlatformSession.session_id) == session_id,
|
|
),
|
|
)
|
|
|
|
# ====
|
|
# ChatUI Project Management
|
|
# ====
|
|
|
|
async def create_chatui_project(
|
|
self,
|
|
creator: str,
|
|
title: str,
|
|
emoji: str | None = "📁",
|
|
description: str | None = None,
|
|
) -> ChatUIProject:
|
|
"""Create a new ChatUI project."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
project = ChatUIProject(
|
|
creator=creator,
|
|
title=title,
|
|
emoji=emoji,
|
|
description=description,
|
|
)
|
|
session.add(project)
|
|
await session.flush()
|
|
await session.refresh(project)
|
|
return project
|
|
|
|
async def get_chatui_project_by_id(self, project_id: str) -> ChatUIProject | None:
|
|
"""Get a ChatUI project by its ID."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
result = await session.execute(
|
|
select(ChatUIProject).where(
|
|
col(ChatUIProject.project_id) == project_id,
|
|
),
|
|
)
|
|
return result.scalar_one_or_none()
|
|
|
|
async def get_chatui_projects_by_creator(
|
|
self,
|
|
creator: str,
|
|
page: int = 1,
|
|
page_size: int = 100,
|
|
) -> list[ChatUIProject]:
|
|
"""Get all ChatUI projects for a specific creator."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
offset = (page - 1) * page_size
|
|
result = await session.execute(
|
|
select(ChatUIProject)
|
|
.where(col(ChatUIProject.creator) == creator)
|
|
.order_by(desc(ChatUIProject.updated_at))
|
|
.limit(page_size)
|
|
.offset(offset),
|
|
)
|
|
return list(result.scalars().all())
|
|
|
|
async def update_chatui_project(
|
|
self,
|
|
project_id: str,
|
|
title: str | None = None,
|
|
emoji: str | None = None,
|
|
description: str | None = None,
|
|
) -> None:
|
|
"""Update a ChatUI project."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
values: dict[str, T.Any] = {"updated_at": datetime.now(timezone.utc)}
|
|
if title is not None:
|
|
values["title"] = title
|
|
if emoji is not None:
|
|
values["emoji"] = emoji
|
|
if description is not None:
|
|
values["description"] = description
|
|
|
|
await session.execute(
|
|
update(ChatUIProject)
|
|
.where(col(ChatUIProject.project_id) == project_id)
|
|
.values(**values),
|
|
)
|
|
|
|
async def delete_chatui_project(self, project_id: str) -> None:
|
|
"""Delete a ChatUI project by its ID."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
# First remove all session relations
|
|
await session.execute(
|
|
delete(SessionProjectRelation).where(
|
|
col(SessionProjectRelation.project_id) == project_id,
|
|
),
|
|
)
|
|
# Then delete the project
|
|
await session.execute(
|
|
delete(ChatUIProject).where(
|
|
col(ChatUIProject.project_id) == project_id,
|
|
),
|
|
)
|
|
|
|
async def add_session_to_project(
|
|
self,
|
|
session_id: str,
|
|
project_id: str,
|
|
) -> SessionProjectRelation:
|
|
"""Add a session to a project."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
# First remove existing relation if any
|
|
await session.execute(
|
|
delete(SessionProjectRelation).where(
|
|
col(SessionProjectRelation.session_id) == session_id,
|
|
),
|
|
)
|
|
# Then create new relation
|
|
relation = SessionProjectRelation(
|
|
session_id=session_id,
|
|
project_id=project_id,
|
|
)
|
|
session.add(relation)
|
|
await session.flush()
|
|
await session.refresh(relation)
|
|
return relation
|
|
|
|
async def remove_session_from_project(self, session_id: str) -> None:
|
|
"""Remove a session from its project."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
await session.execute(
|
|
delete(SessionProjectRelation).where(
|
|
col(SessionProjectRelation.session_id) == session_id,
|
|
),
|
|
)
|
|
|
|
async def get_project_sessions(
|
|
self,
|
|
project_id: str,
|
|
page: int = 1,
|
|
page_size: int = 100,
|
|
) -> list[PlatformSession]:
|
|
"""Get all sessions in a project."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
offset = (page - 1) * page_size
|
|
result = await session.execute(
|
|
select(PlatformSession)
|
|
.join(
|
|
SessionProjectRelation,
|
|
col(PlatformSession.session_id)
|
|
== col(SessionProjectRelation.session_id),
|
|
)
|
|
.where(col(SessionProjectRelation.project_id) == project_id)
|
|
.order_by(desc(PlatformSession.updated_at))
|
|
.limit(page_size)
|
|
.offset(offset),
|
|
)
|
|
return list(result.scalars().all())
|
|
|
|
async def get_project_by_session(
|
|
self, session_id: str, creator: str
|
|
) -> ChatUIProject | None:
|
|
"""Get the project that a session belongs to."""
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
result = await session.execute(
|
|
select(ChatUIProject)
|
|
.join(
|
|
SessionProjectRelation,
|
|
col(ChatUIProject.project_id)
|
|
== col(SessionProjectRelation.project_id),
|
|
)
|
|
.where(
|
|
col(SessionProjectRelation.session_id) == session_id,
|
|
col(ChatUIProject.creator) == creator,
|
|
),
|
|
)
|
|
return result.scalar_one_or_none()
|
|
|
|
# ====
|
|
# Cron Job Management
|
|
# ====
|
|
|
|
async def create_cron_job(
|
|
self,
|
|
name: str,
|
|
job_type: str,
|
|
cron_expression: str | None,
|
|
*,
|
|
timezone: str | None = None,
|
|
payload: dict | None = None,
|
|
description: str | None = None,
|
|
enabled: bool = True,
|
|
persistent: bool = True,
|
|
run_once: bool = False,
|
|
status: str | None = None,
|
|
job_id: str | None = None,
|
|
) -> CronJob:
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
job = CronJob(
|
|
name=name,
|
|
job_type=job_type,
|
|
cron_expression=cron_expression,
|
|
timezone=timezone,
|
|
payload=payload or {},
|
|
description=description,
|
|
enabled=enabled,
|
|
persistent=persistent,
|
|
run_once=run_once,
|
|
status=status or "scheduled",
|
|
)
|
|
if job_id:
|
|
job.job_id = job_id
|
|
session.add(job)
|
|
await session.flush()
|
|
await session.refresh(job)
|
|
return job
|
|
|
|
async def update_cron_job(
|
|
self,
|
|
job_id: str,
|
|
*,
|
|
name: str | None | object = CRON_FIELD_NOT_SET,
|
|
cron_expression: str | None | object = CRON_FIELD_NOT_SET,
|
|
timezone: str | None | object = CRON_FIELD_NOT_SET,
|
|
payload: dict | None | object = CRON_FIELD_NOT_SET,
|
|
description: str | None | object = CRON_FIELD_NOT_SET,
|
|
enabled: bool | None | object = CRON_FIELD_NOT_SET,
|
|
persistent: bool | None | object = CRON_FIELD_NOT_SET,
|
|
run_once: bool | None | object = CRON_FIELD_NOT_SET,
|
|
status: str | None | object = CRON_FIELD_NOT_SET,
|
|
next_run_time: datetime | None | object = CRON_FIELD_NOT_SET,
|
|
last_run_at: datetime | None | object = CRON_FIELD_NOT_SET,
|
|
last_error: str | None | object = CRON_FIELD_NOT_SET,
|
|
) -> CronJob | None:
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
updates: dict = {}
|
|
for key, val in {
|
|
"name": name,
|
|
"cron_expression": cron_expression,
|
|
"timezone": timezone,
|
|
"payload": payload,
|
|
"description": description,
|
|
"enabled": enabled,
|
|
"persistent": persistent,
|
|
"run_once": run_once,
|
|
"status": status,
|
|
"next_run_time": next_run_time,
|
|
"last_run_at": last_run_at,
|
|
"last_error": last_error,
|
|
}.items():
|
|
if val is CRON_FIELD_NOT_SET:
|
|
continue
|
|
updates[key] = val
|
|
|
|
stmt = (
|
|
update(CronJob)
|
|
.where(col(CronJob.job_id) == job_id)
|
|
.values(**updates)
|
|
.execution_options(synchronize_session="fetch")
|
|
)
|
|
await session.execute(stmt)
|
|
result = await session.execute(
|
|
select(CronJob).where(col(CronJob.job_id) == job_id)
|
|
)
|
|
return result.scalar_one_or_none()
|
|
|
|
async def delete_cron_job(self, job_id: str) -> None:
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
async with session.begin():
|
|
await session.execute(
|
|
delete(CronJob).where(col(CronJob.job_id) == job_id)
|
|
)
|
|
|
|
async def get_cron_job(self, job_id: str) -> CronJob | None:
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
result = await session.execute(
|
|
select(CronJob).where(col(CronJob.job_id) == job_id)
|
|
)
|
|
return result.scalar_one_or_none()
|
|
|
|
async def list_cron_jobs(self, job_type: str | None = None) -> list[CronJob]:
|
|
async with self.get_db() as session:
|
|
session: AsyncSession
|
|
query = select(CronJob)
|
|
if job_type:
|
|
query = query.where(col(CronJob.job_type) == job_type)
|
|
query = query.order_by(desc(CronJob.created_at))
|
|
result = await session.execute(query)
|
|
return list(result.scalars().all())
|