Files
AstrBot/astrbot/core/db/sqlite.py
T
Dt8333 f624971613 chore: fix bunches of type checking errors (#3213)
* chore(core.utils): 🚨 修正错误Lint

* chore(core.provider): 🚨 修复基类错误Lint

* chore(core.utils): 补全session_get()的重载

* chore(core.provider): 🚨 修正实现错误Lint

* chore(core.platform): 🚨 修正platform基类和webchat的错误Lint

* chore(core.platform): 修正错误实现Lint

* fix(core.provider): 修复循环调用和错误assert

* chore(core.platform): 修复部分实现Lint

* chore(core.provider): 补充Dify.text_chat_stream的参数类型

* chore(core.pipeline): 🚨 修复错误Lint

* fix(core.slack): 补充遗漏导入

* chore(core.utils): 修复错误的session_get声明

* chore(core.platform): 移除Lark adapter import中的wildcard

* chore(core.db): 修复声明和部分逻辑

* chore(core.db): 添加typings,使faiss参数能被正确识别。

* chore(core): 修复声明

* chore(core): 修改声明

* chore: 补充faiss声明

* chore(dashboard): 修改实现,减少报错

* chore(package): 修改部分声明与实现,减少报错

* chore(core): 添加Handler的overload,以去除部分assert同时通过类型检查

* chore(core.pipeline): 修改Pipeline Scheduler的execute,将判断属性改为判断类型,通过静态类型检查

* chore(core.config): 添加类型标注,通过类型检查

* chore(core.message): 为File._download_file添加检查,通过类型检查

* fix: 将断言改为条件判断以实现优雅关闭的容错性

* refactor: 移除 discord 客户端中的 assert,改用 if None 判断并抛出异常

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: DiscordPlatformAdapter 对 self.client.user 为 None 做日志并返回,移除断言

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 增强 Lark 相关空值/异常检查并完善日志输出

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 将断言替换为条件检查并加入日志与错误处理

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* chore: 移除LLM生成的无用注释

* refactor: 使用 File.get_file 替换下载逻辑并移除 assert,提供默认 filename

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: Slack Socket 未初始化抛出运行时异常,图片 URL 判空改为非空判断

* refactor: 将 WeChatPadProAdapter 的断言改为空值判断并添加日志

* refactor: 使用 isinstance 替代断言实现类型判断,便于静态检查

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 去除cast,直接使用字段与字典访问,修正端口解析

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 使用 match-case 重构 ProviderManager 加载并通过类型检查抛出 TypeError

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: group_name_display 时若 group 对象为空则记录错误并返回

* fix: 将 _get_current_persona_id 的 assert 替换成 if guard 并返回 None

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 优化插件目录存在性检查及图片URL非空验证,更新JSON排序配置

* fix: 将 datetime_str 的 assert 替换为显式检查并抛出异常

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 移除 cast,改为运行时检查并在找不到调度器时跳过

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 移除 cast,改用 isinstance 检查 FaissVecDB 并警告

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 删除 typing.cast 导入,并在获取文件绝对路径前校验 file_

* refactor: 移除 typing.cast,简化内容安全检查调用

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 将 PlatformMetadata.id 设为必填并在注册时传入 id,移除 cast

* refactor: 移除 cast,改用 HasInitialize 与 isinstance 进行初始化

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 为 ProviderManager.initialize 增加ID类型判断,避免 None 导致 get 失败

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 为 OTTSProvider 与 AzureNativeProvider 引入 _client 与 client 属性改进上下文管理

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 为 Whisper 自托管源添加模型未初始化校验并直接调用 transcribe

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 移除未使用的 cast 导入并简化 platform_name 赋值

* refactor: 引入 cast 并对 id 使用 cast(str, ...) 提升类型安全

* fix: 将 _id_to_sid 返回改为 str,空值返回空串;对 id 与 message_id 使用 cast

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 重构 Discord 处理逻辑:强制 类型转换、优先斜杠指令并优化提及判断

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 统一对 id 获取执行 cast,并在微信消息解析失败时抛错

* Revert "fix: 去除cast,直接使用字段与字典访问,修正端口解析"

This reverts commit 1cbfdf9d1b.

* fix: 百炼 Rerank 会话关闭时返回空结果;初始化 request.prompt 避免空值拼接

* fix: 统一处理搜索结果链接为字符串,新增 _get_url 助手并适配 Bing/Sogo

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 调整 call_handler 泛型、Discord 通道注解及 FishAudioTTS API 请求类型

* refactor: 使用 col(...) 替代列引用并对结果进行 CursorResult 强转

* chore: ruff format

---------

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
Co-authored-by: Soulter <905617992@qq.com>
2025-12-09 14:13:47 +08:00

866 lines
32 KiB
Python

import asyncio
import threading
import typing as T
from datetime import datetime, timedelta, timezone
from sqlalchemy import CursorResult
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import col, delete, desc, func, or_, select, text, update
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import (
Attachment,
ConversationV2,
Persona,
PlatformMessageHistory,
PlatformSession,
PlatformStat,
Preference,
SQLModel,
)
from astrbot.core.db.po import (
Platform as DeprecatedPlatformStat,
)
from astrbot.core.db.po import (
Stats as DeprecatedStats,
)
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
class SQLiteDatabase(BaseDatabase):
def __init__(self, db_path: str) -> None:
self.db_path = db_path
self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
self.inited = False
super().__init__()
async def initialize(self) -> None:
"""Initialize the database by creating tables if they do not exist."""
async with self.engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
await conn.execute(text("PRAGMA journal_mode=WAL"))
await conn.execute(text("PRAGMA synchronous=NORMAL"))
await conn.execute(text("PRAGMA cache_size=20000"))
await conn.execute(text("PRAGMA temp_store=MEMORY"))
await conn.execute(text("PRAGMA mmap_size=134217728"))
await conn.execute(text("PRAGMA optimize"))
await conn.commit()
# ====
# Platform Statistics
# ====
async def insert_platform_stats(
self,
platform_id,
platform_type,
count=1,
timestamp=None,
) -> None:
"""Insert a new platform statistic record."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
if timestamp is None:
timestamp = datetime.now().replace(
minute=0,
second=0,
microsecond=0,
)
current_hour = timestamp
await session.execute(
text("""
INSERT INTO platform_stats (timestamp, platform_id, platform_type, count)
VALUES (:timestamp, :platform_id, :platform_type, :count)
ON CONFLICT(timestamp, platform_id, platform_type) DO UPDATE SET
count = platform_stats.count + EXCLUDED.count
"""),
{
"timestamp": current_hour,
"platform_id": platform_id,
"platform_type": platform_type,
"count": count,
},
)
async def count_platform_stats(self) -> int:
"""Count the number of platform statistics records."""
async with self.get_db() as session:
session: AsyncSession
result = await session.execute(
select(func.count(col(PlatformStat.platform_id))).select_from(
PlatformStat,
),
)
count = result.scalar_one_or_none()
return count if count is not None else 0
async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]:
"""Get platform statistics within the specified offset in seconds and group by platform_id."""
async with self.get_db() as session:
session: AsyncSession
now = datetime.now()
start_time = now - timedelta(seconds=offset_sec)
result = await session.execute(
text("""
SELECT * FROM platform_stats
WHERE timestamp >= :start_time
GROUP BY platform_id
ORDER BY timestamp DESC
"""),
{"start_time": start_time},
)
return list(result.scalars().all())
# ====
# Conversation Management
# ====
async def get_conversations(self, user_id=None, platform_id=None):
async with self.get_db() as session:
session: AsyncSession
query = select(ConversationV2)
if user_id:
query = query.where(ConversationV2.user_id == user_id)
if platform_id:
query = query.where(ConversationV2.platform_id == platform_id)
# order by
query = query.order_by(desc(ConversationV2.created_at))
result = await session.execute(query)
return result.scalars().all()
async def get_conversation_by_id(self, cid):
async with self.get_db() as session:
session: AsyncSession
query = select(ConversationV2).where(ConversationV2.conversation_id == cid)
result = await session.execute(query)
return result.scalar_one_or_none()
async def get_all_conversations(self, page=1, page_size=20):
async with self.get_db() as session:
session: AsyncSession
offset = (page - 1) * page_size
result = await session.execute(
select(ConversationV2)
.order_by(desc(ConversationV2.created_at))
.offset(offset)
.limit(page_size),
)
return result.scalars().all()
async def get_filtered_conversations(
self,
page=1,
page_size=20,
platform_ids=None,
search_query="",
**kwargs,
):
async with self.get_db() as session:
session: AsyncSession
# Build the base query with filters
base_query = select(ConversationV2)
if platform_ids:
base_query = base_query.where(
col(ConversationV2.platform_id).in_(platform_ids),
)
if search_query:
search_query = search_query.encode("unicode_escape").decode("utf-8")
base_query = base_query.where(
or_(
col(ConversationV2.title).ilike(f"%{search_query}%"),
col(ConversationV2.content).ilike(f"%{search_query}%"),
col(ConversationV2.user_id).ilike(f"%{search_query}%"),
col(ConversationV2.conversation_id).ilike(f"%{search_query}%"),
),
)
if "message_types" in kwargs and len(kwargs["message_types"]) > 0:
for msg_type in kwargs["message_types"]:
base_query = base_query.where(
col(ConversationV2.user_id).ilike(f"%:{msg_type}:%"),
)
if "platforms" in kwargs and len(kwargs["platforms"]) > 0:
base_query = base_query.where(
col(ConversationV2.platform_id).in_(kwargs["platforms"]),
)
# Get total count matching the filters
count_query = select(func.count()).select_from(base_query.subquery())
total_count = await session.execute(count_query)
total = total_count.scalar_one()
# Get paginated results
offset = (page - 1) * page_size
result_query = (
base_query.order_by(desc(ConversationV2.created_at))
.offset(offset)
.limit(page_size)
)
result = await session.execute(result_query)
conversations = result.scalars().all()
return conversations, total
async def create_conversation(
self,
user_id,
platform_id,
content=None,
title=None,
persona_id=None,
cid=None,
created_at=None,
updated_at=None,
):
kwargs = {}
if cid:
kwargs["conversation_id"] = cid
if created_at:
kwargs["created_at"] = created_at
if updated_at:
kwargs["updated_at"] = updated_at
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
new_conversation = ConversationV2(
user_id=user_id,
content=content or [],
platform_id=platform_id,
title=title,
persona_id=persona_id,
**kwargs,
)
session.add(new_conversation)
return new_conversation
async def update_conversation(self, cid, title=None, persona_id=None, content=None):
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
query = update(ConversationV2).where(
col(ConversationV2.conversation_id) == cid,
)
values = {}
if title is not None:
values["title"] = title
if persona_id is not None:
values["persona_id"] = persona_id
if content is not None:
values["content"] = content
if not values:
return None
query = query.values(**values)
await session.execute(query)
return await self.get_conversation_by_id(cid)
async def delete_conversation(self, cid):
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
await session.execute(
delete(ConversationV2).where(
col(ConversationV2.conversation_id) == cid,
),
)
async def delete_conversations_by_user_id(self, user_id: str) -> None:
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
await session.execute(
delete(ConversationV2).where(
col(ConversationV2.user_id) == user_id
),
)
async def get_session_conversations(
self,
page=1,
page_size=20,
search_query=None,
platform=None,
) -> tuple[list[dict], int]:
"""Get paginated session conversations with joined conversation and persona details."""
async with self.get_db() as session:
session: AsyncSession
offset = (page - 1) * page_size
base_query = (
select(
col(Preference.scope_id).label("session_id"),
func.json_extract(Preference.value, "$.val").label(
"conversation_id",
), # type: ignore
col(ConversationV2.persona_id).label("persona_id"),
col(ConversationV2.title).label("title"),
col(Persona.persona_id).label("persona_name"),
)
.select_from(Preference)
.outerjoin(
ConversationV2,
func.json_extract(Preference.value, "$.val")
== ConversationV2.conversation_id,
)
.outerjoin(
Persona,
col(ConversationV2.persona_id) == Persona.persona_id,
)
.where(Preference.scope == "umo", Preference.key == "sel_conv_id")
)
# 搜索筛选
if search_query:
search_pattern = f"%{search_query}%"
base_query = base_query.where(
or_(
col(Preference.scope_id).ilike(search_pattern),
col(ConversationV2.title).ilike(search_pattern),
col(Persona.persona_id).ilike(search_pattern),
),
)
# 平台筛选
if platform:
platform_pattern = f"{platform}:%"
base_query = base_query.where(
col(Preference.scope_id).like(platform_pattern),
)
# 排序
base_query = base_query.order_by(Preference.scope_id)
# 分页结果
result_query = base_query.offset(offset).limit(page_size)
result = await session.execute(result_query)
rows = result.fetchall()
# 查询总数(应用相同的筛选条件)
count_base_query = (
select(func.count(col(Preference.scope_id)))
.select_from(Preference)
.outerjoin(
ConversationV2,
func.json_extract(Preference.value, "$.val")
== ConversationV2.conversation_id,
)
.outerjoin(
Persona,
col(ConversationV2.persona_id) == Persona.persona_id,
)
.where(Preference.scope == "umo", Preference.key == "sel_conv_id")
)
# 应用相同的搜索和平台筛选条件到计数查询
if search_query:
search_pattern = f"%{search_query}%"
count_base_query = count_base_query.where(
or_(
col(Preference.scope_id).ilike(search_pattern),
col(ConversationV2.title).ilike(search_pattern),
col(Persona.persona_id).ilike(search_pattern),
),
)
if platform:
platform_pattern = f"{platform}:%"
count_base_query = count_base_query.where(
col(Preference.scope_id).like(platform_pattern),
)
total_result = await session.execute(count_base_query)
total = total_result.scalar() or 0
sessions_data = [
{
"session_id": row.session_id,
"conversation_id": row.conversation_id,
"persona_id": row.persona_id,
"title": row.title,
"persona_name": row.persona_name,
}
for row in rows
]
return sessions_data, total
async def insert_platform_message_history(
self,
platform_id,
user_id,
content,
sender_id=None,
sender_name=None,
):
"""Insert a new platform message history record."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
new_history = PlatformMessageHistory(
platform_id=platform_id,
user_id=user_id,
content=content,
sender_id=sender_id,
sender_name=sender_name,
)
session.add(new_history)
return new_history
async def delete_platform_message_offset(
self,
platform_id,
user_id,
offset_sec=86400,
):
"""Delete platform message history records newer than the specified offset."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
now = datetime.now()
cutoff_time = now - timedelta(seconds=offset_sec)
await session.execute(
delete(PlatformMessageHistory).where(
col(PlatformMessageHistory.platform_id) == platform_id,
col(PlatformMessageHistory.user_id) == user_id,
col(PlatformMessageHistory.created_at) >= cutoff_time,
),
)
async def get_platform_message_history(
self,
platform_id,
user_id,
page=1,
page_size=20,
):
"""Get platform message history records."""
async with self.get_db() as session:
session: AsyncSession
offset = (page - 1) * page_size
query = (
select(PlatformMessageHistory)
.where(
PlatformMessageHistory.platform_id == platform_id,
PlatformMessageHistory.user_id == user_id,
)
.order_by(desc(PlatformMessageHistory.created_at))
)
result = await session.execute(query.offset(offset).limit(page_size))
return result.scalars().all()
async def get_platform_message_history_by_id(
self, message_id: int
) -> PlatformMessageHistory | None:
"""Get a platform message history record by its ID."""
async with self.get_db() as session:
session: AsyncSession
query = select(PlatformMessageHistory).where(
PlatformMessageHistory.id == message_id
)
result = await session.execute(query)
return result.scalar_one_or_none()
async def insert_attachment(self, path, type, mime_type):
"""Insert a new attachment record."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
new_attachment = Attachment(
path=path,
type=type,
mime_type=mime_type,
)
session.add(new_attachment)
return new_attachment
async def get_attachment_by_id(self, attachment_id):
"""Get an attachment by its ID."""
async with self.get_db() as session:
session: AsyncSession
query = select(Attachment).where(Attachment.attachment_id == attachment_id)
result = await session.execute(query)
return result.scalar_one_or_none()
async def get_attachments(self, attachment_ids: list[str]) -> list:
"""Get multiple attachments by their IDs."""
if not attachment_ids:
return []
async with self.get_db() as session:
session: AsyncSession
query = select(Attachment).where(
col(Attachment.attachment_id).in_(attachment_ids)
)
result = await session.execute(query)
return list(result.scalars().all())
async def delete_attachment(self, attachment_id: str) -> bool:
"""Delete an attachment by its ID.
Returns True if the attachment was deleted, False if it was not found.
"""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
query = delete(Attachment).where(
col(Attachment.attachment_id) == attachment_id
)
result = T.cast(CursorResult, await session.execute(query))
return result.rowcount > 0
async def delete_attachments(self, attachment_ids: list[str]) -> int:
"""Delete multiple attachments by their IDs.
Returns the number of attachments deleted.
"""
if not attachment_ids:
return 0
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
query = delete(Attachment).where(
col(Attachment.attachment_id).in_(attachment_ids)
)
result = T.cast(CursorResult, await session.execute(query))
return result.rowcount
async def insert_persona(
self,
persona_id,
system_prompt,
begin_dialogs=None,
tools=None,
):
"""Insert a new persona record."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
new_persona = Persona(
persona_id=persona_id,
system_prompt=system_prompt,
begin_dialogs=begin_dialogs or [],
tools=tools,
)
session.add(new_persona)
return new_persona
async def get_persona_by_id(self, persona_id):
"""Get a persona by its ID."""
async with self.get_db() as session:
session: AsyncSession
query = select(Persona).where(Persona.persona_id == persona_id)
result = await session.execute(query)
return result.scalar_one_or_none()
async def get_personas(self):
"""Get all personas for a specific bot."""
async with self.get_db() as session:
session: AsyncSession
query = select(Persona)
result = await session.execute(query)
return result.scalars().all()
async def update_persona(
self,
persona_id,
system_prompt=None,
begin_dialogs=None,
tools=NOT_GIVEN,
):
"""Update a persona's system prompt or begin dialogs."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
query = update(Persona).where(col(Persona.persona_id) == persona_id)
values = {}
if system_prompt is not None:
values["system_prompt"] = system_prompt
if begin_dialogs is not None:
values["begin_dialogs"] = begin_dialogs
if tools is not NOT_GIVEN:
values["tools"] = tools
if not values:
return None
query = query.values(**values)
await session.execute(query)
return await self.get_persona_by_id(persona_id)
async def delete_persona(self, persona_id):
"""Delete a persona by its ID."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
await session.execute(
delete(Persona).where(col(Persona.persona_id) == persona_id),
)
async def insert_preference_or_update(self, scope, scope_id, key, value):
"""Insert a new preference record or update if it exists."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
query = select(Preference).where(
Preference.scope == scope,
Preference.scope_id == scope_id,
Preference.key == key,
)
result = await session.execute(query)
existing_preference = result.scalar_one_or_none()
if existing_preference:
existing_preference.value = value
else:
new_preference = Preference(
scope=scope,
scope_id=scope_id,
key=key,
value=value,
)
session.add(new_preference)
return existing_preference or new_preference
async def get_preference(self, scope, scope_id, key):
"""Get a preference by key."""
async with self.get_db() as session:
session: AsyncSession
query = select(Preference).where(
Preference.scope == scope,
Preference.scope_id == scope_id,
Preference.key == key,
)
result = await session.execute(query)
return result.scalar_one_or_none()
async def get_preferences(self, scope, scope_id=None, key=None):
"""Get all preferences for a specific scope ID or key."""
async with self.get_db() as session:
session: AsyncSession
query = select(Preference).where(Preference.scope == scope)
if scope_id is not None:
query = query.where(Preference.scope_id == scope_id)
if key is not None:
query = query.where(Preference.key == key)
result = await session.execute(query)
return result.scalars().all()
async def remove_preference(self, scope, scope_id, key):
"""Remove a preference by scope ID and key."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
await session.execute(
delete(Preference).where(
col(Preference.scope) == scope,
col(Preference.scope_id) == scope_id,
col(Preference.key) == key,
),
)
await session.commit()
async def clear_preferences(self, scope, scope_id):
"""Clear all preferences for a specific scope ID."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
await session.execute(
delete(Preference).where(
col(Preference.scope) == scope,
col(Preference.scope_id) == scope_id,
),
)
await session.commit()
# ====
# Deprecated Methods
# ====
def get_base_stats(self, offset_sec=86400):
"""Get base statistics within the specified offset in seconds."""
async def _inner():
async with self.get_db() as session:
session: AsyncSession
now = datetime.now()
start_time = now - timedelta(seconds=offset_sec)
result = await session.execute(
select(PlatformStat).where(PlatformStat.timestamp >= start_time),
)
all_datas = result.scalars().all()
deprecated_stats = DeprecatedStats()
for data in all_datas:
deprecated_stats.platform.append(
DeprecatedPlatformStat(
name=data.platform_id,
count=data.count,
timestamp=int(data.timestamp.timestamp()),
),
)
return deprecated_stats
result = None
def runner():
nonlocal result
result = asyncio.run(_inner())
t = threading.Thread(target=runner)
t.start()
t.join()
return result
def get_total_message_count(self):
"""Get the total message count from platform statistics."""
async def _inner():
async with self.get_db() as session:
session: AsyncSession
result = await session.execute(
select(func.sum(PlatformStat.count)).select_from(PlatformStat),
)
total_count = result.scalar_one_or_none()
return total_count if total_count is not None else 0
result = None
def runner():
nonlocal result
result = asyncio.run(_inner())
t = threading.Thread(target=runner)
t.start()
t.join()
return result
def get_grouped_base_stats(self, offset_sec=86400):
# group by platform_id
async def _inner():
async with self.get_db() as session:
session: AsyncSession
now = datetime.now()
start_time = now - timedelta(seconds=offset_sec)
result = await session.execute(
select(PlatformStat.platform_id, func.sum(PlatformStat.count))
.where(PlatformStat.timestamp >= start_time)
.group_by(PlatformStat.platform_id),
)
grouped_stats = result.all()
deprecated_stats = DeprecatedStats()
for platform_id, count in grouped_stats:
deprecated_stats.platform.append(
DeprecatedPlatformStat(
name=platform_id,
count=count,
timestamp=int(start_time.timestamp()),
),
)
return deprecated_stats
result = None
def runner():
nonlocal result
result = asyncio.run(_inner())
t = threading.Thread(target=runner)
t.start()
t.join()
return result
# ====
# Platform Session Management
# ====
async def create_platform_session(
self,
creator: str,
platform_id: str = "webchat",
session_id: str | None = None,
display_name: str | None = None,
is_group: int = 0,
) -> PlatformSession:
"""Create a new Platform session."""
kwargs = {}
if session_id:
kwargs["session_id"] = session_id
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
new_session = PlatformSession(
creator=creator,
platform_id=platform_id,
display_name=display_name,
is_group=is_group,
**kwargs,
)
session.add(new_session)
await session.flush()
await session.refresh(new_session)
return new_session
async def get_platform_session_by_id(
self, session_id: str
) -> PlatformSession | None:
"""Get a Platform session by its ID."""
async with self.get_db() as session:
session: AsyncSession
query = select(PlatformSession).where(
PlatformSession.session_id == session_id,
)
result = await session.execute(query)
return result.scalar_one_or_none()
async def get_platform_sessions_by_creator(
self,
creator: str,
platform_id: str | None = None,
page: int = 1,
page_size: int = 20,
) -> list[PlatformSession]:
"""Get all Platform sessions for a specific creator (username) and optionally platform."""
async with self.get_db() as session:
session: AsyncSession
offset = (page - 1) * page_size
query = select(PlatformSession).where(PlatformSession.creator == creator)
if platform_id:
query = query.where(PlatformSession.platform_id == platform_id)
query = (
query.order_by(desc(PlatformSession.updated_at))
.offset(offset)
.limit(page_size)
)
result = await session.execute(query)
return list(result.scalars().all())
async def update_platform_session(
self,
session_id: str,
display_name: str | None = None,
) -> None:
"""Update a Platform session's updated_at timestamp and optionally display_name."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
values: dict[str, T.Any] = {"updated_at": datetime.now(timezone.utc)}
if display_name is not None:
values["display_name"] = display_name
await session.execute(
update(PlatformSession)
.where(col(PlatformSession.session_id) == session_id)
.values(**values),
)
async def delete_platform_session(self, session_id: str) -> None:
"""Delete a Platform session by its ID."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
await session.execute(
delete(PlatformSession).where(
col(PlatformSession.session_id) == session_id,
),
)