Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 394dcf3199 | |||
| e6deb46332 |
@@ -1,8 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
from astrbot.core.config import AstrBotConfig
|
from astrbot.core.config import AstrBotConfig
|
||||||
from astrbot.core.config.default import DB_PATH
|
from astrbot.core.config.default import DB_PATH
|
||||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
from astrbot.core.db.sqlite import BaseDatabase
|
||||||
from astrbot.core.file_token_service import FileTokenService
|
from astrbot.core.file_token_service import FileTokenService
|
||||||
from astrbot.core.utils.pip_installer import PipInstaller
|
from astrbot.core.utils.pip_installer import PipInstaller
|
||||||
from astrbot.core.utils.shared_preferences import SharedPreferences
|
from astrbot.core.utils.shared_preferences import SharedPreferences
|
||||||
@@ -14,13 +16,44 @@ from .utils.astrbot_path import get_astrbot_data_path
|
|||||||
# 初始化数据存储文件夹
|
# 初始化数据存储文件夹
|
||||||
os.makedirs(get_astrbot_data_path(), exist_ok=True)
|
os.makedirs(get_astrbot_data_path(), exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
class AstrBotMySQLSettings(BaseSettings):
|
||||||
|
host: str = "localhost"
|
||||||
|
port: int = 3306
|
||||||
|
user: str = "root"
|
||||||
|
password: str = ""
|
||||||
|
database: str = "astrbot"
|
||||||
|
charset: str = "utf8mb4"
|
||||||
|
|
||||||
|
model_config = SettingsConfigDict(env_file=".env", env_prefix="ASTR_MYSQL_")
|
||||||
|
|
||||||
|
|
||||||
|
def get_db_helper() -> BaseDatabase:
|
||||||
|
db_type = os.getenv("ASTR_DB_TYPE", "sqlite")
|
||||||
|
match db_type:
|
||||||
|
case "sqlite":
|
||||||
|
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||||
|
|
||||||
|
return SQLiteDatabase(DB_PATH)
|
||||||
|
case "mysql":
|
||||||
|
from astrbot.core.db.mysql import MySQLDatabase
|
||||||
|
|
||||||
|
mysql_settings = AstrBotMySQLSettings()
|
||||||
|
|
||||||
|
return MySQLDatabase(**mysql_settings.model_dump())
|
||||||
|
case _:
|
||||||
|
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||||
|
|
||||||
|
return SQLiteDatabase(DB_PATH)
|
||||||
|
|
||||||
|
|
||||||
DEMO_MODE = os.getenv("DEMO_MODE", False)
|
DEMO_MODE = os.getenv("DEMO_MODE", False)
|
||||||
|
|
||||||
astrbot_config = AstrBotConfig()
|
astrbot_config = AstrBotConfig()
|
||||||
t2i_base_url = astrbot_config.get("t2i_endpoint", "https://t2i.soulter.top/text2img")
|
t2i_base_url = astrbot_config.get("t2i_endpoint", "https://t2i.soulter.top/text2img")
|
||||||
html_renderer = HtmlRenderer(t2i_base_url)
|
html_renderer = HtmlRenderer(t2i_base_url)
|
||||||
logger = LogManager.GetLogger(log_name="astrbot")
|
logger = LogManager.GetLogger(log_name="astrbot")
|
||||||
db_helper = SQLiteDatabase(DB_PATH)
|
db_helper = get_db_helper()
|
||||||
# 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中
|
# 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中
|
||||||
sp = SharedPreferences(db_helper=db_helper)
|
sp = SharedPreferences(db_helper=db_helper)
|
||||||
# 文件令牌服务
|
# 文件令牌服务
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import datetime
|
|||||||
import typing as T
|
import typing as T
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
from deprecated import deprecated
|
from deprecated import deprecated
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||||
@@ -20,11 +21,17 @@ from astrbot.core.db.po import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseType(Enum):
|
||||||
|
SQLITE = "sqlite"
|
||||||
|
MYSQL = "mysql"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BaseDatabase(abc.ABC):
|
class BaseDatabase(abc.ABC):
|
||||||
"""数据库基类"""
|
"""数据库基类"""
|
||||||
|
|
||||||
DATABASE_URL = ""
|
DATABASE_URL = ""
|
||||||
|
database_type: DatabaseType
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.engine = create_async_engine(
|
self.engine = create_async_engine(
|
||||||
@@ -83,7 +90,7 @@ class BaseDatabase(abc.ABC):
|
|||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def count_platform_stats(self) -> int:
|
async def count_platform_stats(self) -> int:
|
||||||
"""Count the number of platform statistics records."""
|
"""Sum the count of platform statistics records."""
|
||||||
...
|
...
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
@@ -91,6 +98,16 @@ class BaseDatabase(abc.ABC):
|
|||||||
"""Get platform statistics within the specified offset in seconds and group by platform_id."""
|
"""Get platform statistics within the specified offset in seconds and group by platform_id."""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def get_platform_stats_time_series(
|
||||||
|
self, offset_sec: int = 86400
|
||||||
|
) -> list[tuple[int, int]]:
|
||||||
|
"""Get platform statistics time series data grouped by hour.
|
||||||
|
|
||||||
|
Returns a list of tuples (hour_timestamp, count) sorted by timestamp ascending.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def get_conversations(
|
async def get_conversations(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import os
|
|||||||
|
|
||||||
from astrbot.api import logger, sp
|
from astrbot.api import logger, sp
|
||||||
from astrbot.core.config import AstrBotConfig
|
from astrbot.core.config import AstrBotConfig
|
||||||
from astrbot.core.db import BaseDatabase
|
from astrbot.core.db import BaseDatabase, DatabaseType
|
||||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
|
|
||||||
from .migra_3_to_4 import (
|
from .migra_3_to_4 import (
|
||||||
@@ -24,6 +24,10 @@ async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool:
|
|||||||
|
|
||||||
if not os.path.exists(data_v3_db):
|
if not os.path.exists(data_v3_db):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
if db_helper.database_type == DatabaseType.MYSQL:
|
||||||
|
return False
|
||||||
|
|
||||||
migration_done = await db_helper.get_preference(
|
migration_done = await db_helper.get_preference(
|
||||||
"global",
|
"global",
|
||||||
"global",
|
"global",
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ async def migrate_webchat_session(db_helper: BaseDatabase):
|
|||||||
query = (
|
query = (
|
||||||
select(
|
select(
|
||||||
col(PlatformMessageHistory.user_id),
|
col(PlatformMessageHistory.user_id),
|
||||||
col(PlatformMessageHistory.sender_name),
|
func.max(PlatformMessageHistory.sender_name).label("sender_name"),
|
||||||
func.min(PlatformMessageHistory.created_at).label("earliest"),
|
func.min(PlatformMessageHistory.created_at).label("earliest"),
|
||||||
func.max(PlatformMessageHistory.updated_at).label("latest"),
|
func.max(PlatformMessageHistory.updated_at).label("latest"),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -0,0 +1,875 @@
|
|||||||
|
import asyncio
|
||||||
|
import typing as T
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from sqlmodel import col, delete, desc, func, or_, select, text, update
|
||||||
|
|
||||||
|
from astrbot.core.db import BaseDatabase, DatabaseType
|
||||||
|
from astrbot.core.db.po import (
|
||||||
|
Attachment,
|
||||||
|
ConversationV2,
|
||||||
|
Persona,
|
||||||
|
PlatformMessageHistory,
|
||||||
|
PlatformSession,
|
||||||
|
PlatformStat,
|
||||||
|
Preference,
|
||||||
|
SQLModel,
|
||||||
|
)
|
||||||
|
from astrbot.core.db.po import Stats as DeprecatedStats
|
||||||
|
|
||||||
|
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
|
||||||
|
|
||||||
|
|
||||||
|
class MySQLDatabase(BaseDatabase):
|
||||||
|
"""MySQL 数据库实现
|
||||||
|
|
||||||
|
使用方式:
|
||||||
|
db = MySQLDatabase(
|
||||||
|
host="localhost",
|
||||||
|
port=3306,
|
||||||
|
user="root",
|
||||||
|
password="password",
|
||||||
|
database="astrbot"
|
||||||
|
)
|
||||||
|
await db.initialize()
|
||||||
|
"""
|
||||||
|
|
||||||
|
database_type = DatabaseType.MYSQL
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
host: str = "localhost",
|
||||||
|
port: int = 3306,
|
||||||
|
user: str = "root",
|
||||||
|
password: str = "",
|
||||||
|
database: str = "astrbot",
|
||||||
|
charset: str = "utf8mb4",
|
||||||
|
) -> None:
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.user = user
|
||||||
|
self.password = password
|
||||||
|
self.database = database
|
||||||
|
self.charset = charset
|
||||||
|
self.DATABASE_URL = (
|
||||||
|
f"mysql+aiomysql://{user}:{password}@{host}:{port}/{database}"
|
||||||
|
f"?charset={charset}"
|
||||||
|
)
|
||||||
|
self.inited = False
|
||||||
|
self._current_loop: asyncio.AbstractEventLoop | None = None
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def _recreate_engine(self) -> None:
|
||||||
|
"""重新创建数据库引擎和会话工厂,用于处理事件循环切换的情况"""
|
||||||
|
self.engine = create_async_engine(
|
||||||
|
self.DATABASE_URL,
|
||||||
|
echo=False,
|
||||||
|
future=True,
|
||||||
|
)
|
||||||
|
self.AsyncSessionLocal = sessionmaker(
|
||||||
|
self.engine,
|
||||||
|
class_=AsyncSession,
|
||||||
|
expire_on_commit=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def get_db(self) -> T.AsyncGenerator[AsyncSession, None]:
|
||||||
|
"""Get a database session.
|
||||||
|
|
||||||
|
此方法会检查当前事件循环,如果事件循环发生变化会重新创建引擎,
|
||||||
|
以解决 aiomysql 的 "attached to a different loop" 问题。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
current_loop = asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
current_loop = None
|
||||||
|
|
||||||
|
# 检查事件循环是否变化,如果变化则重新创建引擎
|
||||||
|
if current_loop is not None and self._current_loop != current_loop:
|
||||||
|
self._recreate_engine()
|
||||||
|
self._current_loop = current_loop
|
||||||
|
self.inited = False # 需要重新初始化
|
||||||
|
|
||||||
|
if not self.inited:
|
||||||
|
await self.initialize()
|
||||||
|
self.inited = True
|
||||||
|
|
||||||
|
async with self.AsyncSessionLocal() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
"""Initialize the database by creating tables if they do not exist."""
|
||||||
|
async with self.engine.begin() as conn:
|
||||||
|
await conn.run_sync(SQLModel.metadata.create_all)
|
||||||
|
await conn.commit()
|
||||||
|
|
||||||
|
# ====
|
||||||
|
# Platform Statistics
|
||||||
|
# ====
|
||||||
|
|
||||||
|
async def insert_platform_stats(
|
||||||
|
self,
|
||||||
|
platform_id,
|
||||||
|
platform_type,
|
||||||
|
count=1,
|
||||||
|
timestamp=None,
|
||||||
|
) -> None:
|
||||||
|
"""Insert a new platform statistic record."""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
async with session.begin():
|
||||||
|
if timestamp is None:
|
||||||
|
timestamp = datetime.now().replace(
|
||||||
|
minute=0,
|
||||||
|
second=0,
|
||||||
|
microsecond=0,
|
||||||
|
)
|
||||||
|
current_hour = timestamp
|
||||||
|
await session.execute(
|
||||||
|
text("""
|
||||||
|
INSERT INTO platform_stats (timestamp, platform_id, platform_type, count)
|
||||||
|
VALUES (:timestamp, :platform_id, :platform_type, :count)
|
||||||
|
ON DUPLICATE KEY UPDATE
|
||||||
|
count = count + VALUES(count)
|
||||||
|
"""),
|
||||||
|
{
|
||||||
|
"timestamp": current_hour,
|
||||||
|
"platform_id": platform_id,
|
||||||
|
"platform_type": platform_type,
|
||||||
|
"count": count,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def count_platform_stats(self) -> int:
|
||||||
|
"""Count the number of platform statistics records."""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
result = await session.execute(
|
||||||
|
select(func.sum(PlatformStat.count)).select_from(PlatformStat),
|
||||||
|
)
|
||||||
|
count = result.scalar_one_or_none()
|
||||||
|
return count or 0
|
||||||
|
|
||||||
|
async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]:
|
||||||
|
"""Get platform statistics within the specified offset in seconds and group by platform_id."""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
now = datetime.now()
|
||||||
|
start_time = now - timedelta(seconds=offset_sec)
|
||||||
|
result = await session.execute(
|
||||||
|
text("""
|
||||||
|
SELECT platform_id, platform_type, SUM(count) as total_count, MAX(timestamp) as latest_ts
|
||||||
|
FROM platform_stats
|
||||||
|
WHERE timestamp >= :start_time
|
||||||
|
GROUP BY platform_id, platform_type
|
||||||
|
ORDER BY latest_ts DESC
|
||||||
|
"""),
|
||||||
|
{"start_time": start_time},
|
||||||
|
)
|
||||||
|
rows = result.fetchall()
|
||||||
|
return [
|
||||||
|
PlatformStat(
|
||||||
|
id=0,
|
||||||
|
platform_id=row.platform_id,
|
||||||
|
platform_type=row.platform_type,
|
||||||
|
count=row.total_count,
|
||||||
|
timestamp=row.latest_ts,
|
||||||
|
)
|
||||||
|
for row in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
async def get_platform_stats_time_series(
|
||||||
|
self, offset_sec: int = 86400
|
||||||
|
) -> list[tuple[int, int]]:
|
||||||
|
"""Get platform statistics time series data grouped by hour."""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
now = datetime.now()
|
||||||
|
start_time = now - timedelta(seconds=offset_sec)
|
||||||
|
result = await session.execute(
|
||||||
|
text("""
|
||||||
|
SELECT UNIX_TIMESTAMP(DATE_FORMAT(timestamp, '%Y-%m-%d %H:00:00')) as hour_ts, SUM(count) as total_count
|
||||||
|
FROM platform_stats
|
||||||
|
WHERE timestamp >= :start_time
|
||||||
|
GROUP BY hour_ts
|
||||||
|
ORDER BY hour_ts ASC
|
||||||
|
"""),
|
||||||
|
{"start_time": start_time},
|
||||||
|
)
|
||||||
|
rows = result.fetchall()
|
||||||
|
return [(int(row.hour_ts), row.total_count) for row in rows]
|
||||||
|
|
||||||
|
# ====
|
||||||
|
# Conversation Management
|
||||||
|
# ====
|
||||||
|
|
||||||
|
async def get_conversations(self, user_id=None, platform_id=None):
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
query = select(ConversationV2)
|
||||||
|
|
||||||
|
if user_id:
|
||||||
|
query = query.where(ConversationV2.user_id == user_id)
|
||||||
|
if platform_id:
|
||||||
|
query = query.where(ConversationV2.platform_id == platform_id)
|
||||||
|
# order by
|
||||||
|
query = query.order_by(desc(ConversationV2.created_at))
|
||||||
|
result = await session.execute(query)
|
||||||
|
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def get_conversation_by_id(self, cid):
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
query = select(ConversationV2).where(ConversationV2.conversation_id == cid)
|
||||||
|
result = await session.execute(query)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def get_all_conversations(self, page=1, page_size=20):
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
offset = (page - 1) * page_size
|
||||||
|
result = await session.execute(
|
||||||
|
select(ConversationV2)
|
||||||
|
.order_by(desc(ConversationV2.created_at))
|
||||||
|
.offset(offset)
|
||||||
|
.limit(page_size),
|
||||||
|
)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def get_filtered_conversations(
|
||||||
|
self,
|
||||||
|
page=1,
|
||||||
|
page_size=20,
|
||||||
|
platform_ids=None,
|
||||||
|
search_query="",
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
# Build the base query with filters
|
||||||
|
base_query = select(ConversationV2)
|
||||||
|
|
||||||
|
if platform_ids:
|
||||||
|
base_query = base_query.where(
|
||||||
|
col(ConversationV2.platform_id).in_(platform_ids),
|
||||||
|
)
|
||||||
|
if search_query:
|
||||||
|
search_query = search_query.encode("unicode_escape").decode("utf-8")
|
||||||
|
base_query = base_query.where(
|
||||||
|
or_(
|
||||||
|
col(ConversationV2.title).ilike(f"%{search_query}%"),
|
||||||
|
col(ConversationV2.content).ilike(f"%{search_query}%"),
|
||||||
|
col(ConversationV2.user_id).ilike(f"%{search_query}%"),
|
||||||
|
col(ConversationV2.conversation_id).ilike(f"%{search_query}%"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if "message_types" in kwargs and len(kwargs["message_types"]) > 0:
|
||||||
|
for msg_type in kwargs["message_types"]:
|
||||||
|
base_query = base_query.where(
|
||||||
|
col(ConversationV2.user_id).ilike(f"%:{msg_type}:%"),
|
||||||
|
)
|
||||||
|
if "platforms" in kwargs and len(kwargs["platforms"]) > 0:
|
||||||
|
base_query = base_query.where(
|
||||||
|
col(ConversationV2.platform_id).in_(kwargs["platforms"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get total count matching the filters
|
||||||
|
count_query = select(func.count()).select_from(base_query.subquery())
|
||||||
|
total_count = await session.execute(count_query)
|
||||||
|
total = total_count.scalar_one()
|
||||||
|
|
||||||
|
# Get paginated results
|
||||||
|
offset = (page - 1) * page_size
|
||||||
|
result_query = (
|
||||||
|
base_query.order_by(desc(ConversationV2.created_at))
|
||||||
|
.offset(offset)
|
||||||
|
.limit(page_size)
|
||||||
|
)
|
||||||
|
result = await session.execute(result_query)
|
||||||
|
conversations = result.scalars().all()
|
||||||
|
|
||||||
|
return conversations, total
|
||||||
|
|
||||||
|
async def create_conversation(
|
||||||
|
self,
|
||||||
|
user_id,
|
||||||
|
platform_id,
|
||||||
|
content=None,
|
||||||
|
title=None,
|
||||||
|
persona_id=None,
|
||||||
|
cid=None,
|
||||||
|
created_at=None,
|
||||||
|
updated_at=None,
|
||||||
|
):
|
||||||
|
kwargs = {}
|
||||||
|
if cid:
|
||||||
|
kwargs["conversation_id"] = cid
|
||||||
|
if created_at:
|
||||||
|
kwargs["created_at"] = created_at
|
||||||
|
if updated_at:
|
||||||
|
kwargs["updated_at"] = updated_at
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
async with session.begin():
|
||||||
|
new_conversation = ConversationV2(
|
||||||
|
user_id=user_id,
|
||||||
|
content=content or [],
|
||||||
|
platform_id=platform_id,
|
||||||
|
title=title,
|
||||||
|
persona_id=persona_id,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
session.add(new_conversation)
|
||||||
|
return new_conversation
|
||||||
|
|
||||||
|
async def update_conversation(self, cid, title=None, persona_id=None, content=None):
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
async with session.begin():
|
||||||
|
query = update(ConversationV2).where(
|
||||||
|
col(ConversationV2.conversation_id) == cid,
|
||||||
|
)
|
||||||
|
values = {}
|
||||||
|
if title is not None:
|
||||||
|
values["title"] = title
|
||||||
|
if persona_id is not None:
|
||||||
|
values["persona_id"] = persona_id
|
||||||
|
if content is not None:
|
||||||
|
values["content"] = content
|
||||||
|
if not values:
|
||||||
|
return None
|
||||||
|
query = query.values(**values)
|
||||||
|
await session.execute(query)
|
||||||
|
return await self.get_conversation_by_id(cid)
|
||||||
|
|
||||||
|
async def delete_conversation(self, cid):
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
async with session.begin():
|
||||||
|
await session.execute(
|
||||||
|
delete(ConversationV2).where(
|
||||||
|
col(ConversationV2.conversation_id) == cid,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def delete_conversations_by_user_id(self, user_id: str) -> None:
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
async with session.begin():
|
||||||
|
await session.execute(
|
||||||
|
delete(ConversationV2).where(
|
||||||
|
col(ConversationV2.user_id) == user_id
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_session_conversations(
|
||||||
|
self,
|
||||||
|
page=1,
|
||||||
|
page_size=20,
|
||||||
|
search_query=None,
|
||||||
|
platform=None,
|
||||||
|
) -> tuple[list[dict], int]:
|
||||||
|
"""Get paginated session conversations with joined conversation and persona details."""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
offset = (page - 1) * page_size
|
||||||
|
|
||||||
|
# MySQL 使用 JSON_EXTRACT 函数(与 SQLite 的 json_extract 兼容)
|
||||||
|
base_query = (
|
||||||
|
select(
|
||||||
|
col(Preference.scope_id).label("session_id"),
|
||||||
|
func.json_extract(Preference.value, "$.val").label(
|
||||||
|
"conversation_id",
|
||||||
|
), # type: ignore
|
||||||
|
col(ConversationV2.persona_id).label("persona_id"),
|
||||||
|
col(ConversationV2.title).label("title"),
|
||||||
|
col(Persona.persona_id).label("persona_name"),
|
||||||
|
)
|
||||||
|
.select_from(Preference)
|
||||||
|
.outerjoin(
|
||||||
|
ConversationV2,
|
||||||
|
func.json_extract(Preference.value, "$.val")
|
||||||
|
== ConversationV2.conversation_id,
|
||||||
|
)
|
||||||
|
.outerjoin(
|
||||||
|
Persona,
|
||||||
|
col(ConversationV2.persona_id) == Persona.persona_id,
|
||||||
|
)
|
||||||
|
.where(Preference.scope == "umo", Preference.key == "sel_conv_id")
|
||||||
|
)
|
||||||
|
|
||||||
|
# 搜索筛选
|
||||||
|
if search_query:
|
||||||
|
search_pattern = f"%{search_query}%"
|
||||||
|
base_query = base_query.where(
|
||||||
|
or_(
|
||||||
|
col(Preference.scope_id).ilike(search_pattern),
|
||||||
|
col(ConversationV2.title).ilike(search_pattern),
|
||||||
|
col(Persona.persona_id).ilike(search_pattern),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 平台筛选
|
||||||
|
if platform:
|
||||||
|
platform_pattern = f"{platform}:%"
|
||||||
|
base_query = base_query.where(
|
||||||
|
col(Preference.scope_id).like(platform_pattern),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 排序
|
||||||
|
base_query = base_query.order_by(Preference.scope_id)
|
||||||
|
|
||||||
|
# 分页结果
|
||||||
|
result_query = base_query.offset(offset).limit(page_size)
|
||||||
|
result = await session.execute(result_query)
|
||||||
|
rows = result.fetchall()
|
||||||
|
|
||||||
|
# 查询总数(应用相同的筛选条件)
|
||||||
|
count_base_query = (
|
||||||
|
select(func.count(col(Preference.scope_id)))
|
||||||
|
.select_from(Preference)
|
||||||
|
.outerjoin(
|
||||||
|
ConversationV2,
|
||||||
|
func.json_extract(Preference.value, "$.val")
|
||||||
|
== ConversationV2.conversation_id,
|
||||||
|
)
|
||||||
|
.outerjoin(
|
||||||
|
Persona,
|
||||||
|
col(ConversationV2.persona_id) == Persona.persona_id,
|
||||||
|
)
|
||||||
|
.where(Preference.scope == "umo", Preference.key == "sel_conv_id")
|
||||||
|
)
|
||||||
|
|
||||||
|
# 应用相同的搜索和平台筛选条件到计数查询
|
||||||
|
if search_query:
|
||||||
|
search_pattern = f"%{search_query}%"
|
||||||
|
count_base_query = count_base_query.where(
|
||||||
|
or_(
|
||||||
|
col(Preference.scope_id).ilike(search_pattern),
|
||||||
|
col(ConversationV2.title).ilike(search_pattern),
|
||||||
|
col(Persona.persona_id).ilike(search_pattern),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if platform:
|
||||||
|
platform_pattern = f"{platform}:%"
|
||||||
|
count_base_query = count_base_query.where(
|
||||||
|
col(Preference.scope_id).like(platform_pattern),
|
||||||
|
)
|
||||||
|
|
||||||
|
total_result = await session.execute(count_base_query)
|
||||||
|
total = total_result.scalar() or 0
|
||||||
|
|
||||||
|
sessions_data = [
|
||||||
|
{
|
||||||
|
"session_id": row.session_id,
|
||||||
|
"conversation_id": row.conversation_id,
|
||||||
|
"persona_id": row.persona_id,
|
||||||
|
"title": row.title,
|
||||||
|
"persona_name": row.persona_name,
|
||||||
|
}
|
||||||
|
for row in rows
|
||||||
|
]
|
||||||
|
return sessions_data, total
|
||||||
|
|
||||||
|
async def insert_platform_message_history(
|
||||||
|
self,
|
||||||
|
platform_id,
|
||||||
|
user_id,
|
||||||
|
content,
|
||||||
|
sender_id=None,
|
||||||
|
sender_name=None,
|
||||||
|
):
|
||||||
|
"""Insert a new platform message history record."""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
async with session.begin():
|
||||||
|
new_history = PlatformMessageHistory(
|
||||||
|
platform_id=platform_id,
|
||||||
|
user_id=user_id,
|
||||||
|
content=content,
|
||||||
|
sender_id=sender_id,
|
||||||
|
sender_name=sender_name,
|
||||||
|
)
|
||||||
|
session.add(new_history)
|
||||||
|
return new_history
|
||||||
|
|
||||||
|
async def delete_platform_message_offset(
|
||||||
|
self,
|
||||||
|
platform_id,
|
||||||
|
user_id,
|
||||||
|
offset_sec=86400,
|
||||||
|
):
|
||||||
|
"""Delete platform message history records newer than the specified offset."""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
async with session.begin():
|
||||||
|
now = datetime.now()
|
||||||
|
cutoff_time = now - timedelta(seconds=offset_sec)
|
||||||
|
await session.execute(
|
||||||
|
delete(PlatformMessageHistory).where(
|
||||||
|
col(PlatformMessageHistory.platform_id) == platform_id,
|
||||||
|
col(PlatformMessageHistory.user_id) == user_id,
|
||||||
|
col(PlatformMessageHistory.created_at) >= cutoff_time,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_platform_message_history(
|
||||||
|
self,
|
||||||
|
platform_id,
|
||||||
|
user_id,
|
||||||
|
page=1,
|
||||||
|
page_size=20,
|
||||||
|
):
|
||||||
|
"""Get platform message history records."""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
offset = (page - 1) * page_size
|
||||||
|
query = (
|
||||||
|
select(PlatformMessageHistory)
|
||||||
|
.where(
|
||||||
|
PlatformMessageHistory.platform_id == platform_id,
|
||||||
|
PlatformMessageHistory.user_id == user_id,
|
||||||
|
)
|
||||||
|
.order_by(desc(PlatformMessageHistory.created_at))
|
||||||
|
)
|
||||||
|
result = await session.execute(query.offset(offset).limit(page_size))
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def get_platform_message_history_by_id(
|
||||||
|
self, message_id: int
|
||||||
|
) -> PlatformMessageHistory | None:
|
||||||
|
"""Get a platform message history record by its ID."""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
query = select(PlatformMessageHistory).where(
|
||||||
|
PlatformMessageHistory.id == message_id
|
||||||
|
)
|
||||||
|
result = await session.execute(query)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def insert_attachment(self, path, type, mime_type):
|
||||||
|
"""Insert a new attachment record."""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
async with session.begin():
|
||||||
|
new_attachment = Attachment(
|
||||||
|
path=path,
|
||||||
|
type=type,
|
||||||
|
mime_type=mime_type,
|
||||||
|
)
|
||||||
|
session.add(new_attachment)
|
||||||
|
return new_attachment
|
||||||
|
|
||||||
|
async def get_attachment_by_id(self, attachment_id):
|
||||||
|
"""Get an attachment by its ID."""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
query = select(Attachment).where(Attachment.attachment_id == attachment_id)
|
||||||
|
result = await session.execute(query)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def get_attachments(self, attachment_ids: list[str]) -> list:
|
||||||
|
"""Get multiple attachments by their IDs."""
|
||||||
|
if not attachment_ids:
|
||||||
|
return []
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
query = select(Attachment).where(
|
||||||
|
Attachment.attachment_id.in_(attachment_ids)
|
||||||
|
)
|
||||||
|
result = await session.execute(query)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|
||||||
|
async def delete_attachment(self, attachment_id: str) -> bool:
|
||||||
|
"""Delete an attachment by its ID.
|
||||||
|
|
||||||
|
Returns True if the attachment was deleted, False if it was not found.
|
||||||
|
"""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
async with session.begin():
|
||||||
|
query = delete(Attachment).where(
|
||||||
|
col(Attachment.attachment_id) == attachment_id
|
||||||
|
)
|
||||||
|
result = await session.execute(query)
|
||||||
|
return result.rowcount > 0
|
||||||
|
|
||||||
|
async def delete_attachments(self, attachment_ids: list[str]) -> int:
|
||||||
|
"""Delete multiple attachments by their IDs.
|
||||||
|
|
||||||
|
Returns the number of attachments deleted.
|
||||||
|
"""
|
||||||
|
if not attachment_ids:
|
||||||
|
return 0
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
async with session.begin():
|
||||||
|
query = delete(Attachment).where(
|
||||||
|
col(Attachment.attachment_id).in_(attachment_ids)
|
||||||
|
)
|
||||||
|
result = await session.execute(query)
|
||||||
|
return result.rowcount
|
||||||
|
|
||||||
|
async def insert_persona(
|
||||||
|
self,
|
||||||
|
persona_id,
|
||||||
|
system_prompt,
|
||||||
|
begin_dialogs=None,
|
||||||
|
tools=None,
|
||||||
|
):
|
||||||
|
"""Insert a new persona record."""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
async with session.begin():
|
||||||
|
new_persona = Persona(
|
||||||
|
persona_id=persona_id,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
begin_dialogs=begin_dialogs or [],
|
||||||
|
tools=tools,
|
||||||
|
)
|
||||||
|
session.add(new_persona)
|
||||||
|
return new_persona
|
||||||
|
|
||||||
|
async def get_persona_by_id(self, persona_id):
|
||||||
|
"""Get a persona by its ID."""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
query = select(Persona).where(Persona.persona_id == persona_id)
|
||||||
|
result = await session.execute(query)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def get_personas(self):
|
||||||
|
"""Get all personas for a specific bot."""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
query = select(Persona)
|
||||||
|
result = await session.execute(query)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def update_persona(
|
||||||
|
self,
|
||||||
|
persona_id,
|
||||||
|
system_prompt=None,
|
||||||
|
begin_dialogs=None,
|
||||||
|
tools=NOT_GIVEN,
|
||||||
|
):
|
||||||
|
"""Update a persona's system prompt or begin dialogs."""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
async with session.begin():
|
||||||
|
query = update(Persona).where(col(Persona.persona_id) == persona_id)
|
||||||
|
values = {}
|
||||||
|
if system_prompt is not None:
|
||||||
|
values["system_prompt"] = system_prompt
|
||||||
|
if begin_dialogs is not None:
|
||||||
|
values["begin_dialogs"] = begin_dialogs
|
||||||
|
if tools is not NOT_GIVEN:
|
||||||
|
values["tools"] = tools
|
||||||
|
if not values:
|
||||||
|
return None
|
||||||
|
query = query.values(**values)
|
||||||
|
await session.execute(query)
|
||||||
|
return await self.get_persona_by_id(persona_id)
|
||||||
|
|
||||||
|
async def delete_persona(self, persona_id):
|
||||||
|
"""Delete a persona by its ID."""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
async with session.begin():
|
||||||
|
await session.execute(
|
||||||
|
delete(Persona).where(col(Persona.persona_id) == persona_id),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def insert_preference_or_update(self, scope, scope_id, key, value):
|
||||||
|
"""Insert a new preference record or update if it exists."""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
async with session.begin():
|
||||||
|
query = select(Preference).where(
|
||||||
|
Preference.scope == scope,
|
||||||
|
Preference.scope_id == scope_id,
|
||||||
|
Preference.key == key,
|
||||||
|
)
|
||||||
|
result = await session.execute(query)
|
||||||
|
existing_preference = result.scalar_one_or_none()
|
||||||
|
if existing_preference:
|
||||||
|
existing_preference.value = value
|
||||||
|
else:
|
||||||
|
new_preference = Preference(
|
||||||
|
scope=scope,
|
||||||
|
scope_id=scope_id,
|
||||||
|
key=key,
|
||||||
|
value=value,
|
||||||
|
)
|
||||||
|
session.add(new_preference)
|
||||||
|
return existing_preference or new_preference
|
||||||
|
|
||||||
|
async def get_preference(self, scope, scope_id, key):
|
||||||
|
"""Get a preference by key."""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
query = select(Preference).where(
|
||||||
|
Preference.scope == scope,
|
||||||
|
Preference.scope_id == scope_id,
|
||||||
|
Preference.key == key,
|
||||||
|
)
|
||||||
|
result = await session.execute(query)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def get_preferences(self, scope, scope_id=None, key=None):
|
||||||
|
"""Get all preferences for a specific scope ID or key."""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
query = select(Preference).where(Preference.scope == scope)
|
||||||
|
if scope_id is not None:
|
||||||
|
query = query.where(Preference.scope_id == scope_id)
|
||||||
|
if key is not None:
|
||||||
|
query = query.where(Preference.key == key)
|
||||||
|
result = await session.execute(query)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def remove_preference(self, scope, scope_id, key):
|
||||||
|
"""Remove a preference by scope ID and key."""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
async with session.begin():
|
||||||
|
await session.execute(
|
||||||
|
delete(Preference).where(
|
||||||
|
col(Preference.scope) == scope,
|
||||||
|
col(Preference.scope_id) == scope_id,
|
||||||
|
col(Preference.key) == key,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async def clear_preferences(self, scope, scope_id):
|
||||||
|
"""Clear all preferences for a specific scope ID."""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
async with session.begin():
|
||||||
|
await session.execute(
|
||||||
|
delete(Preference).where(
|
||||||
|
col(Preference.scope) == scope,
|
||||||
|
col(Preference.scope_id) == scope_id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# ====
|
||||||
|
# Platform Session Management
|
||||||
|
# ====
|
||||||
|
|
||||||
|
async def create_platform_session(
|
||||||
|
self,
|
||||||
|
creator: str,
|
||||||
|
platform_id: str = "webchat",
|
||||||
|
session_id: str | None = None,
|
||||||
|
display_name: str | None = None,
|
||||||
|
is_group: int = 0,
|
||||||
|
) -> PlatformSession:
|
||||||
|
"""Create a new Platform session."""
|
||||||
|
kwargs = {}
|
||||||
|
if session_id:
|
||||||
|
kwargs["session_id"] = session_id
|
||||||
|
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
async with session.begin():
|
||||||
|
new_session = PlatformSession(
|
||||||
|
creator=creator,
|
||||||
|
platform_id=platform_id,
|
||||||
|
display_name=display_name,
|
||||||
|
is_group=is_group,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
session.add(new_session)
|
||||||
|
await session.flush()
|
||||||
|
await session.refresh(new_session)
|
||||||
|
return new_session
|
||||||
|
|
||||||
|
async def get_platform_session_by_id(
|
||||||
|
self, session_id: str
|
||||||
|
) -> PlatformSession | None:
|
||||||
|
"""Get a Platform session by its ID."""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
query = select(PlatformSession).where(
|
||||||
|
PlatformSession.session_id == session_id,
|
||||||
|
)
|
||||||
|
result = await session.execute(query)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def get_platform_sessions_by_creator(
|
||||||
|
self,
|
||||||
|
creator: str,
|
||||||
|
platform_id: str | None = None,
|
||||||
|
page: int = 1,
|
||||||
|
page_size: int = 20,
|
||||||
|
) -> list[PlatformSession]:
|
||||||
|
"""Get all Platform sessions for a specific creator (username) and optionally platform."""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
offset = (page - 1) * page_size
|
||||||
|
query = select(PlatformSession).where(PlatformSession.creator == creator)
|
||||||
|
|
||||||
|
if platform_id:
|
||||||
|
query = query.where(PlatformSession.platform_id == platform_id)
|
||||||
|
|
||||||
|
query = (
|
||||||
|
query.order_by(desc(PlatformSession.updated_at))
|
||||||
|
.offset(offset)
|
||||||
|
.limit(page_size)
|
||||||
|
)
|
||||||
|
result = await session.execute(query)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|
||||||
|
async def update_platform_session(
|
||||||
|
self,
|
||||||
|
session_id: str,
|
||||||
|
display_name: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Update a Platform session's updated_at timestamp and optionally display_name."""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
async with session.begin():
|
||||||
|
values: dict[str, T.Any] = {"updated_at": datetime.now(timezone.utc)}
|
||||||
|
if display_name is not None:
|
||||||
|
values["display_name"] = display_name
|
||||||
|
|
||||||
|
await session.execute(
|
||||||
|
update(PlatformSession)
|
||||||
|
.where(col(PlatformSession.session_id) == session_id)
|
||||||
|
.values(**values),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def delete_platform_session(self, session_id: str) -> None:
|
||||||
|
"""Delete a Platform session by its ID."""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
async with session.begin():
|
||||||
|
await session.execute(
|
||||||
|
delete(PlatformSession).where(
|
||||||
|
col(PlatformSession.session_id) == session_id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ====
|
||||||
|
# Deprecated Methods
|
||||||
|
# ====
|
||||||
|
|
||||||
|
def get_base_stats(self, offset_sec=86400):
|
||||||
|
"""Get base statistics within the specified offset in seconds."""
|
||||||
|
return DeprecatedStats()
|
||||||
|
|
||||||
|
def get_total_message_count(self):
|
||||||
|
"""Get the total message count from platform statistics."""
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def get_grouped_base_stats(self, offset_sec=86400):
|
||||||
|
# group by platform_id
|
||||||
|
return DeprecatedStats()
|
||||||
+137
-105
@@ -6,7 +6,7 @@ from datetime import datetime, timedelta, timezone
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlmodel import col, delete, desc, func, or_, select, text, update
|
from sqlmodel import col, delete, desc, func, or_, select, text, update
|
||||||
|
|
||||||
from astrbot.core.db import BaseDatabase
|
from astrbot.core.db import BaseDatabase, DatabaseType
|
||||||
from astrbot.core.db.po import (
|
from astrbot.core.db.po import (
|
||||||
Attachment,
|
Attachment,
|
||||||
ConversationV2,
|
ConversationV2,
|
||||||
@@ -28,6 +28,8 @@ NOT_GIVEN = T.TypeVar("NOT_GIVEN")
|
|||||||
|
|
||||||
|
|
||||||
class SQLiteDatabase(BaseDatabase):
|
class SQLiteDatabase(BaseDatabase):
|
||||||
|
database_type = DatabaseType.SQLITE
|
||||||
|
|
||||||
def __init__(self, db_path: str) -> None:
|
def __init__(self, db_path: str) -> None:
|
||||||
self.db_path = db_path
|
self.db_path = db_path
|
||||||
self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
|
self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
|
||||||
@@ -88,12 +90,10 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
async with self.get_db() as session:
|
async with self.get_db() as session:
|
||||||
session: AsyncSession
|
session: AsyncSession
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(func.count(col(PlatformStat.platform_id))).select_from(
|
select(func.sum(PlatformStat.count)).select_from(PlatformStat),
|
||||||
PlatformStat,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
count = result.scalar_one_or_none()
|
count = result.scalar_one_or_none()
|
||||||
return count if count is not None else 0
|
return count or 0
|
||||||
|
|
||||||
async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]:
|
async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]:
|
||||||
"""Get platform statistics within the specified offset in seconds and group by platform_id."""
|
"""Get platform statistics within the specified offset in seconds and group by platform_id."""
|
||||||
@@ -103,14 +103,46 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
start_time = now - timedelta(seconds=offset_sec)
|
start_time = now - timedelta(seconds=offset_sec)
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
text("""
|
text("""
|
||||||
SELECT * FROM platform_stats
|
SELECT platform_id, platform_type, SUM(count) as total_count, MAX(timestamp) as latest_ts
|
||||||
|
FROM platform_stats
|
||||||
WHERE timestamp >= :start_time
|
WHERE timestamp >= :start_time
|
||||||
GROUP BY platform_id
|
GROUP BY platform_id, platform_type
|
||||||
ORDER BY timestamp DESC
|
ORDER BY latest_ts DESC
|
||||||
"""),
|
"""),
|
||||||
{"start_time": start_time},
|
{"start_time": start_time},
|
||||||
)
|
)
|
||||||
return list(result.scalars().all())
|
rows = result.fetchall()
|
||||||
|
return [
|
||||||
|
PlatformStat(
|
||||||
|
id=0,
|
||||||
|
platform_id=row.platform_id,
|
||||||
|
platform_type=row.platform_type,
|
||||||
|
count=row.total_count,
|
||||||
|
timestamp=row.latest_ts,
|
||||||
|
)
|
||||||
|
for row in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
async def get_platform_stats_time_series(
|
||||||
|
self, offset_sec: int = 86400
|
||||||
|
) -> list[tuple[int, int]]:
|
||||||
|
"""Get platform statistics time series data grouped by hour."""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
now = datetime.now()
|
||||||
|
start_time = now - timedelta(seconds=offset_sec)
|
||||||
|
result = await session.execute(
|
||||||
|
text("""
|
||||||
|
SELECT strftime('%s', datetime(timestamp, 'start of hour')) as hour_ts, SUM(count) as total_count
|
||||||
|
FROM platform_stats
|
||||||
|
WHERE timestamp >= :start_time
|
||||||
|
GROUP BY hour_ts
|
||||||
|
ORDER BY hour_ts ASC
|
||||||
|
"""),
|
||||||
|
{"start_time": start_time},
|
||||||
|
)
|
||||||
|
rows = result.fetchall()
|
||||||
|
return [(int(row.hour_ts), row.total_count) for row in rows]
|
||||||
|
|
||||||
# ====
|
# ====
|
||||||
# Conversation Management
|
# Conversation Management
|
||||||
@@ -669,102 +701,6 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
)
|
)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
# ====
|
|
||||||
# Deprecated Methods
|
|
||||||
# ====
|
|
||||||
|
|
||||||
def get_base_stats(self, offset_sec=86400):
|
|
||||||
"""Get base statistics within the specified offset in seconds."""
|
|
||||||
|
|
||||||
async def _inner():
|
|
||||||
async with self.get_db() as session:
|
|
||||||
session: AsyncSession
|
|
||||||
now = datetime.now()
|
|
||||||
start_time = now - timedelta(seconds=offset_sec)
|
|
||||||
result = await session.execute(
|
|
||||||
select(PlatformStat).where(PlatformStat.timestamp >= start_time),
|
|
||||||
)
|
|
||||||
all_datas = result.scalars().all()
|
|
||||||
deprecated_stats = DeprecatedStats()
|
|
||||||
for data in all_datas:
|
|
||||||
deprecated_stats.platform.append(
|
|
||||||
DeprecatedPlatformStat(
|
|
||||||
name=data.platform_id,
|
|
||||||
count=data.count,
|
|
||||||
timestamp=int(data.timestamp.timestamp()),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
return deprecated_stats
|
|
||||||
|
|
||||||
result = None
|
|
||||||
|
|
||||||
def runner():
|
|
||||||
nonlocal result
|
|
||||||
result = asyncio.run(_inner())
|
|
||||||
|
|
||||||
t = threading.Thread(target=runner)
|
|
||||||
t.start()
|
|
||||||
t.join()
|
|
||||||
return result
|
|
||||||
|
|
||||||
def get_total_message_count(self):
|
|
||||||
"""Get the total message count from platform statistics."""
|
|
||||||
|
|
||||||
async def _inner():
|
|
||||||
async with self.get_db() as session:
|
|
||||||
session: AsyncSession
|
|
||||||
result = await session.execute(
|
|
||||||
select(func.sum(PlatformStat.count)).select_from(PlatformStat),
|
|
||||||
)
|
|
||||||
total_count = result.scalar_one_or_none()
|
|
||||||
return total_count if total_count is not None else 0
|
|
||||||
|
|
||||||
result = None
|
|
||||||
|
|
||||||
def runner():
|
|
||||||
nonlocal result
|
|
||||||
result = asyncio.run(_inner())
|
|
||||||
|
|
||||||
t = threading.Thread(target=runner)
|
|
||||||
t.start()
|
|
||||||
t.join()
|
|
||||||
return result
|
|
||||||
|
|
||||||
def get_grouped_base_stats(self, offset_sec=86400):
|
|
||||||
# group by platform_id
|
|
||||||
async def _inner():
|
|
||||||
async with self.get_db() as session:
|
|
||||||
session: AsyncSession
|
|
||||||
now = datetime.now()
|
|
||||||
start_time = now - timedelta(seconds=offset_sec)
|
|
||||||
result = await session.execute(
|
|
||||||
select(PlatformStat.platform_id, func.sum(PlatformStat.count))
|
|
||||||
.where(PlatformStat.timestamp >= start_time)
|
|
||||||
.group_by(PlatformStat.platform_id),
|
|
||||||
)
|
|
||||||
grouped_stats = result.all()
|
|
||||||
deprecated_stats = DeprecatedStats()
|
|
||||||
for platform_id, count in grouped_stats:
|
|
||||||
deprecated_stats.platform.append(
|
|
||||||
DeprecatedPlatformStat(
|
|
||||||
name=platform_id,
|
|
||||||
count=count,
|
|
||||||
timestamp=int(start_time.timestamp()),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
return deprecated_stats
|
|
||||||
|
|
||||||
result = None
|
|
||||||
|
|
||||||
def runner():
|
|
||||||
nonlocal result
|
|
||||||
result = asyncio.run(_inner())
|
|
||||||
|
|
||||||
t = threading.Thread(target=runner)
|
|
||||||
t.start()
|
|
||||||
t.join()
|
|
||||||
return result
|
|
||||||
|
|
||||||
# ====
|
# ====
|
||||||
# Platform Session Management
|
# Platform Session Management
|
||||||
# ====
|
# ====
|
||||||
@@ -862,3 +798,99 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
col(PlatformSession.session_id) == session_id,
|
col(PlatformSession.session_id) == session_id,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ====
|
||||||
|
# Deprecated Methods
|
||||||
|
# ====
|
||||||
|
|
||||||
|
def get_base_stats(self, offset_sec=86400):
|
||||||
|
"""Get base statistics within the specified offset in seconds."""
|
||||||
|
|
||||||
|
async def _inner():
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
now = datetime.now()
|
||||||
|
start_time = now - timedelta(seconds=offset_sec)
|
||||||
|
result = await session.execute(
|
||||||
|
select(PlatformStat).where(PlatformStat.timestamp >= start_time),
|
||||||
|
)
|
||||||
|
all_datas = result.scalars().all()
|
||||||
|
deprecated_stats = DeprecatedStats()
|
||||||
|
for data in all_datas:
|
||||||
|
deprecated_stats.platform.append(
|
||||||
|
DeprecatedPlatformStat(
|
||||||
|
name=data.platform_id,
|
||||||
|
count=data.count,
|
||||||
|
timestamp=int(data.timestamp.timestamp()),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return deprecated_stats
|
||||||
|
|
||||||
|
result = None
|
||||||
|
|
||||||
|
def runner():
|
||||||
|
nonlocal result
|
||||||
|
result = asyncio.run(_inner())
|
||||||
|
|
||||||
|
t = threading.Thread(target=runner)
|
||||||
|
t.start()
|
||||||
|
t.join()
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_total_message_count(self):
|
||||||
|
"""Get the total message count from platform statistics."""
|
||||||
|
|
||||||
|
async def _inner():
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
result = await session.execute(
|
||||||
|
select(func.sum(PlatformStat.count)).select_from(PlatformStat),
|
||||||
|
)
|
||||||
|
total_count = result.scalar_one_or_none()
|
||||||
|
return total_count if total_count is not None else 0
|
||||||
|
|
||||||
|
result = None
|
||||||
|
|
||||||
|
def runner():
|
||||||
|
nonlocal result
|
||||||
|
result = asyncio.run(_inner())
|
||||||
|
|
||||||
|
t = threading.Thread(target=runner)
|
||||||
|
t.start()
|
||||||
|
t.join()
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_grouped_base_stats(self, offset_sec=86400):
|
||||||
|
# group by platform_id
|
||||||
|
async def _inner():
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
now = datetime.now()
|
||||||
|
start_time = now - timedelta(seconds=offset_sec)
|
||||||
|
result = await session.execute(
|
||||||
|
select(PlatformStat.platform_id, func.sum(PlatformStat.count))
|
||||||
|
.where(PlatformStat.timestamp >= start_time)
|
||||||
|
.group_by(PlatformStat.platform_id),
|
||||||
|
)
|
||||||
|
grouped_stats = result.all()
|
||||||
|
deprecated_stats = DeprecatedStats()
|
||||||
|
for platform_id, count in grouped_stats:
|
||||||
|
deprecated_stats.platform.append(
|
||||||
|
DeprecatedPlatformStat(
|
||||||
|
name=platform_id,
|
||||||
|
count=count,
|
||||||
|
timestamp=int(start_time.timestamp()),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return deprecated_stats
|
||||||
|
|
||||||
|
result = None
|
||||||
|
|
||||||
|
def runner():
|
||||||
|
nonlocal result
|
||||||
|
result = asyncio.run(_inner())
|
||||||
|
|
||||||
|
t = threading.Thread(target=runner)
|
||||||
|
t.start()
|
||||||
|
t.join()
|
||||||
|
return result
|
||||||
|
|||||||
@@ -84,23 +84,14 @@ class StatRoute(Route):
|
|||||||
offset_sec = request.args.get("offset_sec", 86400)
|
offset_sec = request.args.get("offset_sec", 86400)
|
||||||
offset_sec = int(offset_sec)
|
offset_sec = int(offset_sec)
|
||||||
try:
|
try:
|
||||||
stat = self.db_helper.get_base_stats(offset_sec)
|
platform_stats = await self.db_helper.get_platform_stats(offset_sec)
|
||||||
now = int(time.time())
|
|
||||||
start_time = now - offset_sec
|
|
||||||
message_time_based_stats = []
|
|
||||||
|
|
||||||
idx = 0
|
# 获取时间序列数据(按小时分桶)
|
||||||
for bucket_end in range(start_time, now, 3600):
|
time_series = await self.db_helper.get_platform_stats_time_series(
|
||||||
cnt = 0
|
offset_sec
|
||||||
while (
|
)
|
||||||
idx < len(stat.platform)
|
message_time_based_stats = [[ts, cnt] for ts, cnt in time_series]
|
||||||
and stat.platform[idx].timestamp < bucket_end
|
message_count = await self.db_helper.count_platform_stats()
|
||||||
):
|
|
||||||
cnt += stat.platform[idx].count
|
|
||||||
idx += 1
|
|
||||||
message_time_based_stats.append([bucket_end, cnt])
|
|
||||||
|
|
||||||
stat_dict = stat.__dict__
|
|
||||||
|
|
||||||
cpu_percent = psutil.cpu_percent(interval=0.5)
|
cpu_percent = psutil.cpu_percent(interval=0.5)
|
||||||
thread_count = threading.active_count()
|
thread_count = threading.active_count()
|
||||||
@@ -121,28 +112,36 @@ class StatRoute(Route):
|
|||||||
int(time.time()) - self.core_lifecycle.start_time,
|
int(time.time()) - self.core_lifecycle.start_time,
|
||||||
)
|
)
|
||||||
|
|
||||||
stat_dict.update(
|
# 构建平台统计数据
|
||||||
|
platform_data = [
|
||||||
{
|
{
|
||||||
"platform": self.db_helper.get_grouped_base_stats(
|
"name": stat.platform_id,
|
||||||
offset_sec,
|
"count": stat.count,
|
||||||
).platform,
|
"timestamp": int(stat.timestamp.timestamp())
|
||||||
"message_count": self.db_helper.get_total_message_count() or 0,
|
if stat.timestamp
|
||||||
"platform_count": len(
|
else 0,
|
||||||
self.core_lifecycle.platform_manager.get_insts(),
|
}
|
||||||
),
|
for stat in platform_stats
|
||||||
"plugin_count": len(plugins),
|
]
|
||||||
"plugins": plugin_info,
|
|
||||||
"message_time_series": message_time_based_stats,
|
stat_dict = {
|
||||||
"running": running_time, # 现在返回时间组件而不是格式化的字符串
|
"platform": platform_data,
|
||||||
"memory": {
|
"message_count": message_count,
|
||||||
"process": psutil.Process().memory_info().rss >> 20,
|
"platform_count": len(
|
||||||
"system": psutil.virtual_memory().total >> 20,
|
self.core_lifecycle.platform_manager.get_insts(),
|
||||||
},
|
),
|
||||||
"cpu_percent": round(cpu_percent, 1),
|
"plugin_count": len(plugins),
|
||||||
"thread_count": thread_count,
|
"plugins": plugin_info,
|
||||||
"start_time": self.core_lifecycle.start_time,
|
"message_time_series": message_time_based_stats,
|
||||||
|
"running": running_time, # 现在返回时间组件而不是格式化的字符串
|
||||||
|
"memory": {
|
||||||
|
"process": psutil.Process().memory_info().rss >> 20,
|
||||||
|
"system": psutil.virtual_memory().total >> 20,
|
||||||
},
|
},
|
||||||
)
|
"cpu_percent": round(cpu_percent, 1),
|
||||||
|
"thread_count": thread_count,
|
||||||
|
"start_time": self.core_lifecycle.start_time,
|
||||||
|
}
|
||||||
|
|
||||||
return Response().ok(stat_dict).__dict__
|
return Response().ok(stat_dict).__dict__
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -60,6 +60,7 @@ dependencies = [
|
|||||||
"markitdown-no-magika[docx,xls,xlsx]>=0.1.2",
|
"markitdown-no-magika[docx,xls,xlsx]>=0.1.2",
|
||||||
"xinference-client",
|
"xinference-client",
|
||||||
"tenacity>=9.1.2",
|
"tenacity>=9.1.2",
|
||||||
|
"aiomysql>=0.3.2",
|
||||||
]
|
]
|
||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
|
|||||||
Reference in New Issue
Block a user