Compare commits

...

2 Commits

Author SHA1 Message Date
Soulter 394dcf3199 chore: ruff format 2025-12-04 14:24:03 +08:00
Soulter e6deb46332 feat: integrate MySQL support and enhance database management
- Added MySQL database implementation with connection settings and session management.
- Introduced a new AstrBotMySQLSettings class for configuration.
- Updated database helper functions to support both SQLite and MySQL.
- Enhanced platform statistics retrieval with time series data for both database types.
- Refactored existing SQLite methods to align with new database structure and functionality.

closes: #3848
2025-12-04 14:22:54 +08:00
8 changed files with 1107 additions and 146 deletions
+35 -2
View File
@@ -1,8 +1,10 @@
import os import os
from pydantic_settings import BaseSettings, SettingsConfigDict
from astrbot.core.config import AstrBotConfig from astrbot.core.config import AstrBotConfig
from astrbot.core.config.default import DB_PATH from astrbot.core.config.default import DB_PATH
from astrbot.core.db.sqlite import SQLiteDatabase from astrbot.core.db.sqlite import BaseDatabase
from astrbot.core.file_token_service import FileTokenService from astrbot.core.file_token_service import FileTokenService
from astrbot.core.utils.pip_installer import PipInstaller from astrbot.core.utils.pip_installer import PipInstaller
from astrbot.core.utils.shared_preferences import SharedPreferences from astrbot.core.utils.shared_preferences import SharedPreferences
@@ -14,13 +16,44 @@ from .utils.astrbot_path import get_astrbot_data_path
# 初始化数据存储文件夹 # 初始化数据存储文件夹
os.makedirs(get_astrbot_data_path(), exist_ok=True) os.makedirs(get_astrbot_data_path(), exist_ok=True)
class AstrBotMySQLSettings(BaseSettings):
host: str = "localhost"
port: int = 3306
user: str = "root"
password: str = ""
database: str = "astrbot"
charset: str = "utf8mb4"
model_config = SettingsConfigDict(env_file=".env", env_prefix="ASTR_MYSQL_")
def get_db_helper() -> BaseDatabase:
db_type = os.getenv("ASTR_DB_TYPE", "sqlite")
match db_type:
case "sqlite":
from astrbot.core.db.sqlite import SQLiteDatabase
return SQLiteDatabase(DB_PATH)
case "mysql":
from astrbot.core.db.mysql import MySQLDatabase
mysql_settings = AstrBotMySQLSettings()
return MySQLDatabase(**mysql_settings.model_dump())
case _:
from astrbot.core.db.sqlite import SQLiteDatabase
return SQLiteDatabase(DB_PATH)
DEMO_MODE = os.getenv("DEMO_MODE", False) DEMO_MODE = os.getenv("DEMO_MODE", False)
astrbot_config = AstrBotConfig() astrbot_config = AstrBotConfig()
t2i_base_url = astrbot_config.get("t2i_endpoint", "https://t2i.soulter.top/text2img") t2i_base_url = astrbot_config.get("t2i_endpoint", "https://t2i.soulter.top/text2img")
html_renderer = HtmlRenderer(t2i_base_url) html_renderer = HtmlRenderer(t2i_base_url)
logger = LogManager.GetLogger(log_name="astrbot") logger = LogManager.GetLogger(log_name="astrbot")
db_helper = SQLiteDatabase(DB_PATH) db_helper = get_db_helper()
# 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中 # 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中
sp = SharedPreferences(db_helper=db_helper) sp = SharedPreferences(db_helper=db_helper)
# 文件令牌服务 # 文件令牌服务
+18 -1
View File
@@ -3,6 +3,7 @@ import datetime
import typing as T import typing as T
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum
from deprecated import deprecated from deprecated import deprecated
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
@@ -20,11 +21,17 @@ from astrbot.core.db.po import (
) )
class DatabaseType(Enum):
SQLITE = "sqlite"
MYSQL = "mysql"
@dataclass @dataclass
class BaseDatabase(abc.ABC): class BaseDatabase(abc.ABC):
"""数据库基类""" """数据库基类"""
DATABASE_URL = "" DATABASE_URL = ""
database_type: DatabaseType
def __init__(self) -> None: def __init__(self) -> None:
self.engine = create_async_engine( self.engine = create_async_engine(
@@ -83,7 +90,7 @@ class BaseDatabase(abc.ABC):
@abc.abstractmethod @abc.abstractmethod
async def count_platform_stats(self) -> int: async def count_platform_stats(self) -> int:
"""Count the number of platform statistics records.""" """Sum the count of platform statistics records."""
... ...
@abc.abstractmethod @abc.abstractmethod
@@ -91,6 +98,16 @@ class BaseDatabase(abc.ABC):
"""Get platform statistics within the specified offset in seconds and group by platform_id.""" """Get platform statistics within the specified offset in seconds and group by platform_id."""
... ...
@abc.abstractmethod
async def get_platform_stats_time_series(
self, offset_sec: int = 86400
) -> list[tuple[int, int]]:
"""Get platform statistics time series data grouped by hour.
Returns a list of tuples (hour_timestamp, count) sorted by timestamp ascending.
"""
...
@abc.abstractmethod @abc.abstractmethod
async def get_conversations( async def get_conversations(
self, self,
+5 -1
View File
@@ -2,7 +2,7 @@ import os
from astrbot.api import logger, sp from astrbot.api import logger, sp
from astrbot.core.config import AstrBotConfig from astrbot.core.config import AstrBotConfig
from astrbot.core.db import BaseDatabase from astrbot.core.db import BaseDatabase, DatabaseType
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from .migra_3_to_4 import ( from .migra_3_to_4 import (
@@ -24,6 +24,10 @@ async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool:
if not os.path.exists(data_v3_db): if not os.path.exists(data_v3_db):
return False return False
if db_helper.database_type == DatabaseType.MYSQL:
return False
migration_done = await db_helper.get_preference( migration_done = await db_helper.get_preference(
"global", "global",
"global", "global",
@@ -38,7 +38,7 @@ async def migrate_webchat_session(db_helper: BaseDatabase):
query = ( query = (
select( select(
col(PlatformMessageHistory.user_id), col(PlatformMessageHistory.user_id),
col(PlatformMessageHistory.sender_name), func.max(PlatformMessageHistory.sender_name).label("sender_name"),
func.min(PlatformMessageHistory.created_at).label("earliest"), func.min(PlatformMessageHistory.created_at).label("earliest"),
func.max(PlatformMessageHistory.updated_at).label("latest"), func.max(PlatformMessageHistory.updated_at).label("latest"),
) )
+875
View File
@@ -0,0 +1,875 @@
import asyncio
import typing as T
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlmodel import col, delete, desc, func, or_, select, text, update
from astrbot.core.db import BaseDatabase, DatabaseType
from astrbot.core.db.po import (
Attachment,
ConversationV2,
Persona,
PlatformMessageHistory,
PlatformSession,
PlatformStat,
Preference,
SQLModel,
)
from astrbot.core.db.po import Stats as DeprecatedStats
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
class MySQLDatabase(BaseDatabase):
"""MySQL 数据库实现
使用方式:
db = MySQLDatabase(
host="localhost",
port=3306,
user="root",
password="password",
database="astrbot"
)
await db.initialize()
"""
database_type = DatabaseType.MYSQL
def __init__(
self,
host: str = "localhost",
port: int = 3306,
user: str = "root",
password: str = "",
database: str = "astrbot",
charset: str = "utf8mb4",
) -> None:
self.host = host
self.port = port
self.user = user
self.password = password
self.database = database
self.charset = charset
self.DATABASE_URL = (
f"mysql+aiomysql://{user}:{password}@{host}:{port}/{database}"
f"?charset={charset}"
)
self.inited = False
self._current_loop: asyncio.AbstractEventLoop | None = None
super().__init__()
def _recreate_engine(self) -> None:
"""重新创建数据库引擎和会话工厂,用于处理事件循环切换的情况"""
self.engine = create_async_engine(
self.DATABASE_URL,
echo=False,
future=True,
)
self.AsyncSessionLocal = sessionmaker(
self.engine,
class_=AsyncSession,
expire_on_commit=False,
)
@asynccontextmanager
async def get_db(self) -> T.AsyncGenerator[AsyncSession, None]:
"""Get a database session.
此方法会检查当前事件循环,如果事件循环发生变化会重新创建引擎,
以解决 aiomysql 的 "attached to a different loop" 问题。
"""
try:
current_loop = asyncio.get_running_loop()
except RuntimeError:
current_loop = None
# 检查事件循环是否变化,如果变化则重新创建引擎
if current_loop is not None and self._current_loop != current_loop:
self._recreate_engine()
self._current_loop = current_loop
self.inited = False # 需要重新初始化
if not self.inited:
await self.initialize()
self.inited = True
async with self.AsyncSessionLocal() as session:
yield session
async def initialize(self) -> None:
"""Initialize the database by creating tables if they do not exist."""
async with self.engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
await conn.commit()
# ====
# Platform Statistics
# ====
async def insert_platform_stats(
self,
platform_id,
platform_type,
count=1,
timestamp=None,
) -> None:
"""Insert a new platform statistic record."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
if timestamp is None:
timestamp = datetime.now().replace(
minute=0,
second=0,
microsecond=0,
)
current_hour = timestamp
await session.execute(
text("""
INSERT INTO platform_stats (timestamp, platform_id, platform_type, count)
VALUES (:timestamp, :platform_id, :platform_type, :count)
ON DUPLICATE KEY UPDATE
count = count + VALUES(count)
"""),
{
"timestamp": current_hour,
"platform_id": platform_id,
"platform_type": platform_type,
"count": count,
},
)
async def count_platform_stats(self) -> int:
"""Count the number of platform statistics records."""
async with self.get_db() as session:
session: AsyncSession
result = await session.execute(
select(func.sum(PlatformStat.count)).select_from(PlatformStat),
)
count = result.scalar_one_or_none()
return count or 0
async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]:
"""Get platform statistics within the specified offset in seconds and group by platform_id."""
async with self.get_db() as session:
session: AsyncSession
now = datetime.now()
start_time = now - timedelta(seconds=offset_sec)
result = await session.execute(
text("""
SELECT platform_id, platform_type, SUM(count) as total_count, MAX(timestamp) as latest_ts
FROM platform_stats
WHERE timestamp >= :start_time
GROUP BY platform_id, platform_type
ORDER BY latest_ts DESC
"""),
{"start_time": start_time},
)
rows = result.fetchall()
return [
PlatformStat(
id=0,
platform_id=row.platform_id,
platform_type=row.platform_type,
count=row.total_count,
timestamp=row.latest_ts,
)
for row in rows
]
async def get_platform_stats_time_series(
self, offset_sec: int = 86400
) -> list[tuple[int, int]]:
"""Get platform statistics time series data grouped by hour."""
async with self.get_db() as session:
session: AsyncSession
now = datetime.now()
start_time = now - timedelta(seconds=offset_sec)
result = await session.execute(
text("""
SELECT UNIX_TIMESTAMP(DATE_FORMAT(timestamp, '%Y-%m-%d %H:00:00')) as hour_ts, SUM(count) as total_count
FROM platform_stats
WHERE timestamp >= :start_time
GROUP BY hour_ts
ORDER BY hour_ts ASC
"""),
{"start_time": start_time},
)
rows = result.fetchall()
return [(int(row.hour_ts), row.total_count) for row in rows]
# ====
# Conversation Management
# ====
async def get_conversations(self, user_id=None, platform_id=None):
async with self.get_db() as session:
session: AsyncSession
query = select(ConversationV2)
if user_id:
query = query.where(ConversationV2.user_id == user_id)
if platform_id:
query = query.where(ConversationV2.platform_id == platform_id)
# order by
query = query.order_by(desc(ConversationV2.created_at))
result = await session.execute(query)
return result.scalars().all()
async def get_conversation_by_id(self, cid):
async with self.get_db() as session:
session: AsyncSession
query = select(ConversationV2).where(ConversationV2.conversation_id == cid)
result = await session.execute(query)
return result.scalar_one_or_none()
async def get_all_conversations(self, page=1, page_size=20):
async with self.get_db() as session:
session: AsyncSession
offset = (page - 1) * page_size
result = await session.execute(
select(ConversationV2)
.order_by(desc(ConversationV2.created_at))
.offset(offset)
.limit(page_size),
)
return result.scalars().all()
async def get_filtered_conversations(
self,
page=1,
page_size=20,
platform_ids=None,
search_query="",
**kwargs,
):
async with self.get_db() as session:
session: AsyncSession
# Build the base query with filters
base_query = select(ConversationV2)
if platform_ids:
base_query = base_query.where(
col(ConversationV2.platform_id).in_(platform_ids),
)
if search_query:
search_query = search_query.encode("unicode_escape").decode("utf-8")
base_query = base_query.where(
or_(
col(ConversationV2.title).ilike(f"%{search_query}%"),
col(ConversationV2.content).ilike(f"%{search_query}%"),
col(ConversationV2.user_id).ilike(f"%{search_query}%"),
col(ConversationV2.conversation_id).ilike(f"%{search_query}%"),
),
)
if "message_types" in kwargs and len(kwargs["message_types"]) > 0:
for msg_type in kwargs["message_types"]:
base_query = base_query.where(
col(ConversationV2.user_id).ilike(f"%:{msg_type}:%"),
)
if "platforms" in kwargs and len(kwargs["platforms"]) > 0:
base_query = base_query.where(
col(ConversationV2.platform_id).in_(kwargs["platforms"]),
)
# Get total count matching the filters
count_query = select(func.count()).select_from(base_query.subquery())
total_count = await session.execute(count_query)
total = total_count.scalar_one()
# Get paginated results
offset = (page - 1) * page_size
result_query = (
base_query.order_by(desc(ConversationV2.created_at))
.offset(offset)
.limit(page_size)
)
result = await session.execute(result_query)
conversations = result.scalars().all()
return conversations, total
async def create_conversation(
self,
user_id,
platform_id,
content=None,
title=None,
persona_id=None,
cid=None,
created_at=None,
updated_at=None,
):
kwargs = {}
if cid:
kwargs["conversation_id"] = cid
if created_at:
kwargs["created_at"] = created_at
if updated_at:
kwargs["updated_at"] = updated_at
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
new_conversation = ConversationV2(
user_id=user_id,
content=content or [],
platform_id=platform_id,
title=title,
persona_id=persona_id,
**kwargs,
)
session.add(new_conversation)
return new_conversation
async def update_conversation(self, cid, title=None, persona_id=None, content=None):
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
query = update(ConversationV2).where(
col(ConversationV2.conversation_id) == cid,
)
values = {}
if title is not None:
values["title"] = title
if persona_id is not None:
values["persona_id"] = persona_id
if content is not None:
values["content"] = content
if not values:
return None
query = query.values(**values)
await session.execute(query)
return await self.get_conversation_by_id(cid)
async def delete_conversation(self, cid):
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
await session.execute(
delete(ConversationV2).where(
col(ConversationV2.conversation_id) == cid,
),
)
async def delete_conversations_by_user_id(self, user_id: str) -> None:
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
await session.execute(
delete(ConversationV2).where(
col(ConversationV2.user_id) == user_id
),
)
async def get_session_conversations(
self,
page=1,
page_size=20,
search_query=None,
platform=None,
) -> tuple[list[dict], int]:
"""Get paginated session conversations with joined conversation and persona details."""
async with self.get_db() as session:
session: AsyncSession
offset = (page - 1) * page_size
# MySQL 使用 JSON_EXTRACT 函数(与 SQLite 的 json_extract 兼容)
base_query = (
select(
col(Preference.scope_id).label("session_id"),
func.json_extract(Preference.value, "$.val").label(
"conversation_id",
), # type: ignore
col(ConversationV2.persona_id).label("persona_id"),
col(ConversationV2.title).label("title"),
col(Persona.persona_id).label("persona_name"),
)
.select_from(Preference)
.outerjoin(
ConversationV2,
func.json_extract(Preference.value, "$.val")
== ConversationV2.conversation_id,
)
.outerjoin(
Persona,
col(ConversationV2.persona_id) == Persona.persona_id,
)
.where(Preference.scope == "umo", Preference.key == "sel_conv_id")
)
# 搜索筛选
if search_query:
search_pattern = f"%{search_query}%"
base_query = base_query.where(
or_(
col(Preference.scope_id).ilike(search_pattern),
col(ConversationV2.title).ilike(search_pattern),
col(Persona.persona_id).ilike(search_pattern),
),
)
# 平台筛选
if platform:
platform_pattern = f"{platform}:%"
base_query = base_query.where(
col(Preference.scope_id).like(platform_pattern),
)
# 排序
base_query = base_query.order_by(Preference.scope_id)
# 分页结果
result_query = base_query.offset(offset).limit(page_size)
result = await session.execute(result_query)
rows = result.fetchall()
# 查询总数(应用相同的筛选条件)
count_base_query = (
select(func.count(col(Preference.scope_id)))
.select_from(Preference)
.outerjoin(
ConversationV2,
func.json_extract(Preference.value, "$.val")
== ConversationV2.conversation_id,
)
.outerjoin(
Persona,
col(ConversationV2.persona_id) == Persona.persona_id,
)
.where(Preference.scope == "umo", Preference.key == "sel_conv_id")
)
# 应用相同的搜索和平台筛选条件到计数查询
if search_query:
search_pattern = f"%{search_query}%"
count_base_query = count_base_query.where(
or_(
col(Preference.scope_id).ilike(search_pattern),
col(ConversationV2.title).ilike(search_pattern),
col(Persona.persona_id).ilike(search_pattern),
),
)
if platform:
platform_pattern = f"{platform}:%"
count_base_query = count_base_query.where(
col(Preference.scope_id).like(platform_pattern),
)
total_result = await session.execute(count_base_query)
total = total_result.scalar() or 0
sessions_data = [
{
"session_id": row.session_id,
"conversation_id": row.conversation_id,
"persona_id": row.persona_id,
"title": row.title,
"persona_name": row.persona_name,
}
for row in rows
]
return sessions_data, total
async def insert_platform_message_history(
self,
platform_id,
user_id,
content,
sender_id=None,
sender_name=None,
):
"""Insert a new platform message history record."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
new_history = PlatformMessageHistory(
platform_id=platform_id,
user_id=user_id,
content=content,
sender_id=sender_id,
sender_name=sender_name,
)
session.add(new_history)
return new_history
async def delete_platform_message_offset(
self,
platform_id,
user_id,
offset_sec=86400,
):
"""Delete platform message history records newer than the specified offset."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
now = datetime.now()
cutoff_time = now - timedelta(seconds=offset_sec)
await session.execute(
delete(PlatformMessageHistory).where(
col(PlatformMessageHistory.platform_id) == platform_id,
col(PlatformMessageHistory.user_id) == user_id,
col(PlatformMessageHistory.created_at) >= cutoff_time,
),
)
async def get_platform_message_history(
self,
platform_id,
user_id,
page=1,
page_size=20,
):
"""Get platform message history records."""
async with self.get_db() as session:
session: AsyncSession
offset = (page - 1) * page_size
query = (
select(PlatformMessageHistory)
.where(
PlatformMessageHistory.platform_id == platform_id,
PlatformMessageHistory.user_id == user_id,
)
.order_by(desc(PlatformMessageHistory.created_at))
)
result = await session.execute(query.offset(offset).limit(page_size))
return result.scalars().all()
async def get_platform_message_history_by_id(
self, message_id: int
) -> PlatformMessageHistory | None:
"""Get a platform message history record by its ID."""
async with self.get_db() as session:
session: AsyncSession
query = select(PlatformMessageHistory).where(
PlatformMessageHistory.id == message_id
)
result = await session.execute(query)
return result.scalar_one_or_none()
async def insert_attachment(self, path, type, mime_type):
"""Insert a new attachment record."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
new_attachment = Attachment(
path=path,
type=type,
mime_type=mime_type,
)
session.add(new_attachment)
return new_attachment
async def get_attachment_by_id(self, attachment_id):
"""Get an attachment by its ID."""
async with self.get_db() as session:
session: AsyncSession
query = select(Attachment).where(Attachment.attachment_id == attachment_id)
result = await session.execute(query)
return result.scalar_one_or_none()
async def get_attachments(self, attachment_ids: list[str]) -> list:
"""Get multiple attachments by their IDs."""
if not attachment_ids:
return []
async with self.get_db() as session:
session: AsyncSession
query = select(Attachment).where(
Attachment.attachment_id.in_(attachment_ids)
)
result = await session.execute(query)
return list(result.scalars().all())
async def delete_attachment(self, attachment_id: str) -> bool:
"""Delete an attachment by its ID.
Returns True if the attachment was deleted, False if it was not found.
"""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
query = delete(Attachment).where(
col(Attachment.attachment_id) == attachment_id
)
result = await session.execute(query)
return result.rowcount > 0
async def delete_attachments(self, attachment_ids: list[str]) -> int:
"""Delete multiple attachments by their IDs.
Returns the number of attachments deleted.
"""
if not attachment_ids:
return 0
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
query = delete(Attachment).where(
col(Attachment.attachment_id).in_(attachment_ids)
)
result = await session.execute(query)
return result.rowcount
async def insert_persona(
self,
persona_id,
system_prompt,
begin_dialogs=None,
tools=None,
):
"""Insert a new persona record."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
new_persona = Persona(
persona_id=persona_id,
system_prompt=system_prompt,
begin_dialogs=begin_dialogs or [],
tools=tools,
)
session.add(new_persona)
return new_persona
async def get_persona_by_id(self, persona_id):
"""Get a persona by its ID."""
async with self.get_db() as session:
session: AsyncSession
query = select(Persona).where(Persona.persona_id == persona_id)
result = await session.execute(query)
return result.scalar_one_or_none()
async def get_personas(self):
"""Get all personas for a specific bot."""
async with self.get_db() as session:
session: AsyncSession
query = select(Persona)
result = await session.execute(query)
return result.scalars().all()
async def update_persona(
self,
persona_id,
system_prompt=None,
begin_dialogs=None,
tools=NOT_GIVEN,
):
"""Update a persona's system prompt or begin dialogs."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
query = update(Persona).where(col(Persona.persona_id) == persona_id)
values = {}
if system_prompt is not None:
values["system_prompt"] = system_prompt
if begin_dialogs is not None:
values["begin_dialogs"] = begin_dialogs
if tools is not NOT_GIVEN:
values["tools"] = tools
if not values:
return None
query = query.values(**values)
await session.execute(query)
return await self.get_persona_by_id(persona_id)
async def delete_persona(self, persona_id):
"""Delete a persona by its ID."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
await session.execute(
delete(Persona).where(col(Persona.persona_id) == persona_id),
)
async def insert_preference_or_update(self, scope, scope_id, key, value):
"""Insert a new preference record or update if it exists."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
query = select(Preference).where(
Preference.scope == scope,
Preference.scope_id == scope_id,
Preference.key == key,
)
result = await session.execute(query)
existing_preference = result.scalar_one_or_none()
if existing_preference:
existing_preference.value = value
else:
new_preference = Preference(
scope=scope,
scope_id=scope_id,
key=key,
value=value,
)
session.add(new_preference)
return existing_preference or new_preference
async def get_preference(self, scope, scope_id, key):
"""Get a preference by key."""
async with self.get_db() as session:
session: AsyncSession
query = select(Preference).where(
Preference.scope == scope,
Preference.scope_id == scope_id,
Preference.key == key,
)
result = await session.execute(query)
return result.scalar_one_or_none()
async def get_preferences(self, scope, scope_id=None, key=None):
"""Get all preferences for a specific scope ID or key."""
async with self.get_db() as session:
session: AsyncSession
query = select(Preference).where(Preference.scope == scope)
if scope_id is not None:
query = query.where(Preference.scope_id == scope_id)
if key is not None:
query = query.where(Preference.key == key)
result = await session.execute(query)
return result.scalars().all()
async def remove_preference(self, scope, scope_id, key):
"""Remove a preference by scope ID and key."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
await session.execute(
delete(Preference).where(
col(Preference.scope) == scope,
col(Preference.scope_id) == scope_id,
col(Preference.key) == key,
),
)
await session.commit()
async def clear_preferences(self, scope, scope_id):
"""Clear all preferences for a specific scope ID."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
await session.execute(
delete(Preference).where(
col(Preference.scope) == scope,
col(Preference.scope_id) == scope_id,
),
)
await session.commit()
# ====
# Platform Session Management
# ====
async def create_platform_session(
self,
creator: str,
platform_id: str = "webchat",
session_id: str | None = None,
display_name: str | None = None,
is_group: int = 0,
) -> PlatformSession:
"""Create a new Platform session."""
kwargs = {}
if session_id:
kwargs["session_id"] = session_id
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
new_session = PlatformSession(
creator=creator,
platform_id=platform_id,
display_name=display_name,
is_group=is_group,
**kwargs,
)
session.add(new_session)
await session.flush()
await session.refresh(new_session)
return new_session
async def get_platform_session_by_id(
self, session_id: str
) -> PlatformSession | None:
"""Get a Platform session by its ID."""
async with self.get_db() as session:
session: AsyncSession
query = select(PlatformSession).where(
PlatformSession.session_id == session_id,
)
result = await session.execute(query)
return result.scalar_one_or_none()
async def get_platform_sessions_by_creator(
self,
creator: str,
platform_id: str | None = None,
page: int = 1,
page_size: int = 20,
) -> list[PlatformSession]:
"""Get all Platform sessions for a specific creator (username) and optionally platform."""
async with self.get_db() as session:
session: AsyncSession
offset = (page - 1) * page_size
query = select(PlatformSession).where(PlatformSession.creator == creator)
if platform_id:
query = query.where(PlatformSession.platform_id == platform_id)
query = (
query.order_by(desc(PlatformSession.updated_at))
.offset(offset)
.limit(page_size)
)
result = await session.execute(query)
return list(result.scalars().all())
async def update_platform_session(
self,
session_id: str,
display_name: str | None = None,
) -> None:
"""Update a Platform session's updated_at timestamp and optionally display_name."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
values: dict[str, T.Any] = {"updated_at": datetime.now(timezone.utc)}
if display_name is not None:
values["display_name"] = display_name
await session.execute(
update(PlatformSession)
.where(col(PlatformSession.session_id) == session_id)
.values(**values),
)
async def delete_platform_session(self, session_id: str) -> None:
"""Delete a Platform session by its ID."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
await session.execute(
delete(PlatformSession).where(
col(PlatformSession.session_id) == session_id,
),
)
# ====
# Deprecated Methods
# ====
def get_base_stats(self, offset_sec=86400):
"""Get base statistics within the specified offset in seconds."""
return DeprecatedStats()
def get_total_message_count(self):
"""Get the total message count from platform statistics."""
return 0
def get_grouped_base_stats(self, offset_sec=86400):
# group by platform_id
return DeprecatedStats()
+137 -105
View File
@@ -6,7 +6,7 @@ from datetime import datetime, timedelta, timezone
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import col, delete, desc, func, or_, select, text, update from sqlmodel import col, delete, desc, func, or_, select, text, update
from astrbot.core.db import BaseDatabase from astrbot.core.db import BaseDatabase, DatabaseType
from astrbot.core.db.po import ( from astrbot.core.db.po import (
Attachment, Attachment,
ConversationV2, ConversationV2,
@@ -28,6 +28,8 @@ NOT_GIVEN = T.TypeVar("NOT_GIVEN")
class SQLiteDatabase(BaseDatabase): class SQLiteDatabase(BaseDatabase):
database_type = DatabaseType.SQLITE
def __init__(self, db_path: str) -> None: def __init__(self, db_path: str) -> None:
self.db_path = db_path self.db_path = db_path
self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}" self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
@@ -88,12 +90,10 @@ class SQLiteDatabase(BaseDatabase):
async with self.get_db() as session: async with self.get_db() as session:
session: AsyncSession session: AsyncSession
result = await session.execute( result = await session.execute(
select(func.count(col(PlatformStat.platform_id))).select_from( select(func.sum(PlatformStat.count)).select_from(PlatformStat),
PlatformStat,
),
) )
count = result.scalar_one_or_none() count = result.scalar_one_or_none()
return count if count is not None else 0 return count or 0
async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]: async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]:
"""Get platform statistics within the specified offset in seconds and group by platform_id.""" """Get platform statistics within the specified offset in seconds and group by platform_id."""
@@ -103,14 +103,46 @@ class SQLiteDatabase(BaseDatabase):
start_time = now - timedelta(seconds=offset_sec) start_time = now - timedelta(seconds=offset_sec)
result = await session.execute( result = await session.execute(
text(""" text("""
SELECT * FROM platform_stats SELECT platform_id, platform_type, SUM(count) as total_count, MAX(timestamp) as latest_ts
FROM platform_stats
WHERE timestamp >= :start_time WHERE timestamp >= :start_time
GROUP BY platform_id GROUP BY platform_id, platform_type
ORDER BY timestamp DESC ORDER BY latest_ts DESC
"""), """),
{"start_time": start_time}, {"start_time": start_time},
) )
return list(result.scalars().all()) rows = result.fetchall()
return [
PlatformStat(
id=0,
platform_id=row.platform_id,
platform_type=row.platform_type,
count=row.total_count,
timestamp=row.latest_ts,
)
for row in rows
]
async def get_platform_stats_time_series(
self, offset_sec: int = 86400
) -> list[tuple[int, int]]:
"""Get platform statistics time series data grouped by hour."""
async with self.get_db() as session:
session: AsyncSession
now = datetime.now()
start_time = now - timedelta(seconds=offset_sec)
result = await session.execute(
text("""
SELECT strftime('%s', datetime(timestamp, 'start of hour')) as hour_ts, SUM(count) as total_count
FROM platform_stats
WHERE timestamp >= :start_time
GROUP BY hour_ts
ORDER BY hour_ts ASC
"""),
{"start_time": start_time},
)
rows = result.fetchall()
return [(int(row.hour_ts), row.total_count) for row in rows]
# ==== # ====
# Conversation Management # Conversation Management
@@ -669,102 +701,6 @@ class SQLiteDatabase(BaseDatabase):
) )
await session.commit() await session.commit()
# ====
# Deprecated Methods
# ====
def get_base_stats(self, offset_sec=86400):
"""Get base statistics within the specified offset in seconds."""
async def _inner():
async with self.get_db() as session:
session: AsyncSession
now = datetime.now()
start_time = now - timedelta(seconds=offset_sec)
result = await session.execute(
select(PlatformStat).where(PlatformStat.timestamp >= start_time),
)
all_datas = result.scalars().all()
deprecated_stats = DeprecatedStats()
for data in all_datas:
deprecated_stats.platform.append(
DeprecatedPlatformStat(
name=data.platform_id,
count=data.count,
timestamp=int(data.timestamp.timestamp()),
),
)
return deprecated_stats
result = None
def runner():
nonlocal result
result = asyncio.run(_inner())
t = threading.Thread(target=runner)
t.start()
t.join()
return result
def get_total_message_count(self):
"""Get the total message count from platform statistics."""
async def _inner():
async with self.get_db() as session:
session: AsyncSession
result = await session.execute(
select(func.sum(PlatformStat.count)).select_from(PlatformStat),
)
total_count = result.scalar_one_or_none()
return total_count if total_count is not None else 0
result = None
def runner():
nonlocal result
result = asyncio.run(_inner())
t = threading.Thread(target=runner)
t.start()
t.join()
return result
def get_grouped_base_stats(self, offset_sec=86400):
# group by platform_id
async def _inner():
async with self.get_db() as session:
session: AsyncSession
now = datetime.now()
start_time = now - timedelta(seconds=offset_sec)
result = await session.execute(
select(PlatformStat.platform_id, func.sum(PlatformStat.count))
.where(PlatformStat.timestamp >= start_time)
.group_by(PlatformStat.platform_id),
)
grouped_stats = result.all()
deprecated_stats = DeprecatedStats()
for platform_id, count in grouped_stats:
deprecated_stats.platform.append(
DeprecatedPlatformStat(
name=platform_id,
count=count,
timestamp=int(start_time.timestamp()),
),
)
return deprecated_stats
result = None
def runner():
nonlocal result
result = asyncio.run(_inner())
t = threading.Thread(target=runner)
t.start()
t.join()
return result
# ==== # ====
# Platform Session Management # Platform Session Management
# ==== # ====
@@ -862,3 +798,99 @@ class SQLiteDatabase(BaseDatabase):
col(PlatformSession.session_id) == session_id, col(PlatformSession.session_id) == session_id,
), ),
) )
# ====
# Deprecated Methods
# ====
def get_base_stats(self, offset_sec=86400):
"""Get base statistics within the specified offset in seconds."""
async def _inner():
async with self.get_db() as session:
session: AsyncSession
now = datetime.now()
start_time = now - timedelta(seconds=offset_sec)
result = await session.execute(
select(PlatformStat).where(PlatformStat.timestamp >= start_time),
)
all_datas = result.scalars().all()
deprecated_stats = DeprecatedStats()
for data in all_datas:
deprecated_stats.platform.append(
DeprecatedPlatformStat(
name=data.platform_id,
count=data.count,
timestamp=int(data.timestamp.timestamp()),
),
)
return deprecated_stats
result = None
def runner():
nonlocal result
result = asyncio.run(_inner())
t = threading.Thread(target=runner)
t.start()
t.join()
return result
def get_total_message_count(self):
"""Get the total message count from platform statistics."""
async def _inner():
async with self.get_db() as session:
session: AsyncSession
result = await session.execute(
select(func.sum(PlatformStat.count)).select_from(PlatformStat),
)
total_count = result.scalar_one_or_none()
return total_count if total_count is not None else 0
result = None
def runner():
nonlocal result
result = asyncio.run(_inner())
t = threading.Thread(target=runner)
t.start()
t.join()
return result
def get_grouped_base_stats(self, offset_sec=86400):
# group by platform_id
async def _inner():
async with self.get_db() as session:
session: AsyncSession
now = datetime.now()
start_time = now - timedelta(seconds=offset_sec)
result = await session.execute(
select(PlatformStat.platform_id, func.sum(PlatformStat.count))
.where(PlatformStat.timestamp >= start_time)
.group_by(PlatformStat.platform_id),
)
grouped_stats = result.all()
deprecated_stats = DeprecatedStats()
for platform_id, count in grouped_stats:
deprecated_stats.platform.append(
DeprecatedPlatformStat(
name=platform_id,
count=count,
timestamp=int(start_time.timestamp()),
),
)
return deprecated_stats
result = None
def runner():
nonlocal result
result = asyncio.run(_inner())
t = threading.Thread(target=runner)
t.start()
t.join()
return result
+35 -36
View File
@@ -84,23 +84,14 @@ class StatRoute(Route):
offset_sec = request.args.get("offset_sec", 86400) offset_sec = request.args.get("offset_sec", 86400)
offset_sec = int(offset_sec) offset_sec = int(offset_sec)
try: try:
stat = self.db_helper.get_base_stats(offset_sec) platform_stats = await self.db_helper.get_platform_stats(offset_sec)
now = int(time.time())
start_time = now - offset_sec
message_time_based_stats = []
idx = 0 # 获取时间序列数据(按小时分桶)
for bucket_end in range(start_time, now, 3600): time_series = await self.db_helper.get_platform_stats_time_series(
cnt = 0 offset_sec
while ( )
idx < len(stat.platform) message_time_based_stats = [[ts, cnt] for ts, cnt in time_series]
and stat.platform[idx].timestamp < bucket_end message_count = await self.db_helper.count_platform_stats()
):
cnt += stat.platform[idx].count
idx += 1
message_time_based_stats.append([bucket_end, cnt])
stat_dict = stat.__dict__
cpu_percent = psutil.cpu_percent(interval=0.5) cpu_percent = psutil.cpu_percent(interval=0.5)
thread_count = threading.active_count() thread_count = threading.active_count()
@@ -121,28 +112,36 @@ class StatRoute(Route):
int(time.time()) - self.core_lifecycle.start_time, int(time.time()) - self.core_lifecycle.start_time,
) )
stat_dict.update( # 构建平台统计数据
platform_data = [
{ {
"platform": self.db_helper.get_grouped_base_stats( "name": stat.platform_id,
offset_sec, "count": stat.count,
).platform, "timestamp": int(stat.timestamp.timestamp())
"message_count": self.db_helper.get_total_message_count() or 0, if stat.timestamp
"platform_count": len( else 0,
self.core_lifecycle.platform_manager.get_insts(), }
), for stat in platform_stats
"plugin_count": len(plugins), ]
"plugins": plugin_info,
"message_time_series": message_time_based_stats, stat_dict = {
"running": running_time, # 现在返回时间组件而不是格式化的字符串 "platform": platform_data,
"memory": { "message_count": message_count,
"process": psutil.Process().memory_info().rss >> 20, "platform_count": len(
"system": psutil.virtual_memory().total >> 20, self.core_lifecycle.platform_manager.get_insts(),
}, ),
"cpu_percent": round(cpu_percent, 1), "plugin_count": len(plugins),
"thread_count": thread_count, "plugins": plugin_info,
"start_time": self.core_lifecycle.start_time, "message_time_series": message_time_based_stats,
"running": running_time, # 现在返回时间组件而不是格式化的字符串
"memory": {
"process": psutil.Process().memory_info().rss >> 20,
"system": psutil.virtual_memory().total >> 20,
}, },
) "cpu_percent": round(cpu_percent, 1),
"thread_count": thread_count,
"start_time": self.core_lifecycle.start_time,
}
return Response().ok(stat_dict).__dict__ return Response().ok(stat_dict).__dict__
except Exception as e: except Exception as e:
+1
View File
@@ -60,6 +60,7 @@ dependencies = [
"markitdown-no-magika[docx,xls,xlsx]>=0.1.2", "markitdown-no-magika[docx,xls,xlsx]>=0.1.2",
"xinference-client", "xinference-client",
"tenacity>=9.1.2", "tenacity>=9.1.2",
"aiomysql>=0.3.2",
] ]
[dependency-groups] [dependency-groups]