842c3c8ea9
* stage * stage * refactor: using SQLAlchemy as ORM framework, switch to async-based sqlite operation - using sqlmodel as ORM (based on SQLAlchemy and pydantic) - add Persona, Preference, PlatformMessageHistory table * fix: conversation * fix: remove redundant explicit session.commit, and fix some type error * fix: conversation context issue * chore: remove comments * chore: remove exclude_content param
249 lines
6.9 KiB
Python
249 lines
6.9 KiB
Python
import abc
|
|
import datetime
|
|
import typing as T
|
|
from deprecated import deprecated
|
|
from dataclasses import dataclass
|
|
from astrbot.core.db.po import (
|
|
Stats,
|
|
PlatformStat,
|
|
ConversationV2,
|
|
PlatformMessageHistory,
|
|
Attachment,
|
|
Persona,
|
|
Preference,
|
|
)
|
|
from contextlib import asynccontextmanager
|
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
|
from sqlalchemy.orm import sessionmaker
|
|
|
|
|
|
@dataclass
class BaseDatabase(abc.ABC):
    """Abstract base class for database backends.

    The constructor builds an async SQLAlchemy engine from ``DATABASE_URL``
    and a session factory bound to it. Concrete backends implement the
    abstract methods below; sessions are obtained through :meth:`get_db`.
    """

    # Connection URL handed to ``create_async_engine``. The empty default is
    # a placeholder; subclasses presumably override it before ``__init__``
    # runs (TODO confirm against the concrete backends).
    DATABASE_URL = ""

    # Class-level default for the lazy-initialization flag read by
    # ``get_db``. Without it, the first ``get_db()`` call raised
    # AttributeError unless a subclass had assigned ``self.inited`` itself.
    # Deliberately left un-annotated so ``@dataclass`` does not treat it as a
    # field; an instance assignment in a subclass still takes precedence.
    inited = False

    def __init__(self) -> None:
        """Create the async engine and the session factory."""
        self.engine = create_async_engine(
            self.DATABASE_URL,
            echo=False,
            future=True,
        )
        self.AsyncSessionLocal = sessionmaker(
            self.engine, class_=AsyncSession, expire_on_commit=False
        )

    async def initialize(self) -> None:
        """Initialize the database connection.

        The base implementation does nothing; it is awaited once by
        :meth:`get_db` before the first session is handed out.
        """
        pass

    @asynccontextmanager
    async def get_db(self) -> T.AsyncGenerator[AsyncSession, None]:
        """Get a database session.

        On first use this awaits :meth:`initialize` and then marks the
        instance as initialized. The yielded session is closed when the
        ``async with`` block exits.

        Yields:
            An ``AsyncSession`` created by ``self.AsyncSessionLocal``.
        """
        # NOTE: the flag is set only after initialize() returns, so callers
        # that enter concurrently before that point may each run initialize().
        if not self.inited:
            await self.initialize()
            self.inited = True
        async with self.AsyncSessionLocal() as session:
            yield session

    @deprecated(version="4.0.0", reason="Use get_platform_stats instead")
    @abc.abstractmethod
    def get_base_stats(self, offset_sec: int = 86400) -> Stats:
        """Get basic statistics.

        Deprecated since 4.0.0; use :meth:`get_platform_stats` instead.
        """
        raise NotImplementedError

    @deprecated(version="4.0.0", reason="Use get_platform_stats instead")
    @abc.abstractmethod
    def get_total_message_count(self) -> int:
        """Get the total message count.

        Deprecated since 4.0.0; use :meth:`get_platform_stats` instead.
        """
        raise NotImplementedError

    @deprecated(version="4.0.0", reason="Use get_platform_stats instead")
    @abc.abstractmethod
    def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
        """Get basic statistics (merged).

        Deprecated since 4.0.0; use :meth:`get_platform_stats` instead.
        """
        raise NotImplementedError

    # New methods in v4.0.0

    @abc.abstractmethod
    async def insert_platform_stats(
        self,
        platform_id: str,
        platform_type: str,
        count: int = 1,
        timestamp: datetime.datetime | None = None,
    ) -> None:
        """Insert a new platform statistic record.

        ``timestamp`` may be omitted (``None``); how implementations fill it
        in is not defined here.
        """
        ...

    @abc.abstractmethod
    async def count_platform_stats(self) -> int:
        """Count the number of platform statistics records."""
        ...

    @abc.abstractmethod
    async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]:
        """Get platform statistics within the specified offset in seconds and group by platform_id."""
        ...

    @abc.abstractmethod
    async def get_conversations(
        self, user_id: str | None = None, platform_id: str | None = None
    ) -> list[ConversationV2]:
        """Get all conversations for a specific user and platform_id (both optional).

        Content is not included in the result.
        """
        ...

    @abc.abstractmethod
    async def get_conversation_by_id(self, cid: str) -> ConversationV2:
        """Get a specific conversation by its ID."""
        ...

    @abc.abstractmethod
    async def get_all_conversations(
        self, page: int = 1, page_size: int = 20
    ) -> list[ConversationV2]:
        """Get all conversations with pagination."""
        ...

    @abc.abstractmethod
    async def get_filtered_conversations(
        self,
        page: int = 1,
        page_size: int = 20,
        platform_ids: list[str] | None = None,
        search_query: str = "",
        **kwargs,
    ) -> tuple[list[ConversationV2], int]:
        """Get conversations filtered by platform IDs and search query.

        Returns a ``(conversations, count)`` tuple; the integer is presumably
        the total number of matches across all pages (TODO confirm against
        the concrete backends).
        """
        ...

    @abc.abstractmethod
    async def create_conversation(
        self,
        user_id: str,
        platform_id: str,
        content: list[dict] | None = None,
        title: str | None = None,
        persona_id: str | None = None,
        cid: str | None = None,
        created_at: datetime.datetime | None = None,
        updated_at: datetime.datetime | None = None,
    ) -> ConversationV2:
        """Create a new conversation.

        Every argument after ``platform_id`` is optional and defaults to
        ``None``.
        """
        ...

    @abc.abstractmethod
    async def update_conversation(
        self,
        cid: str,
        title: str | None = None,
        persona_id: str | None = None,
        content: list[dict] | None = None,
    ) -> None:
        """Update a conversation's title, persona and/or content.

        Whether a ``None`` argument means "leave unchanged" is up to the
        implementation; it is not specified here.
        """
        ...

    @abc.abstractmethod
    async def delete_conversation(self, cid: str) -> None:
        """Delete a conversation by its ID."""
        ...

    @abc.abstractmethod
    async def insert_platform_message_history(
        self,
        platform_id: str,
        user_id: str,
        content: list[dict],
        sender_id: str | None = None,
        sender_name: str | None = None,
    ) -> None:
        """Insert a new platform message history record."""
        ...

    @abc.abstractmethod
    async def delete_platform_message_offset(
        self, platform_id: str, user_id: str, offset_sec: int = 86400
    ) -> None:
        """Delete platform message history records older than the specified offset."""
        ...

    @abc.abstractmethod
    async def get_platform_message_history(
        self,
        platform_id: str,
        user_id: str,
        page: int = 1,
        page_size: int = 20,
    ) -> list[PlatformMessageHistory]:
        """Get one page of platform message history for a specific user on a platform."""
        ...

    @abc.abstractmethod
    async def insert_attachment(
        self,
        path: str,
        type: str,
        mime_type: str,
    ):
        """Insert a new attachment record."""
        ...

    @abc.abstractmethod
    async def get_attachment_by_id(self, attachment_id: str) -> Attachment:
        """Get an attachment by its ID."""
        ...

    @abc.abstractmethod
    async def insert_persona(
        self,
        persona_id: str,
        system_prompt: str,
        begin_dialogs: list[str] | None = None,
    ) -> Persona:
        """Insert a new persona record."""
        ...

    @abc.abstractmethod
    async def get_persona_by_id(self, persona_id: str) -> Persona:
        """Get a persona by its ID."""
        ...

    @abc.abstractmethod
    async def get_personas(self) -> list[Persona]:
        """Get all personas."""
        ...

    @abc.abstractmethod
    async def insert_preference_or_update(self, key: str, value: str) -> Preference:
        """Insert a preference record or, as the method name indicates, update an existing one."""
        ...

    @abc.abstractmethod
    async def get_preference(self, key: str) -> Preference:
        """Get a preference by its key."""
        ...

    # The stubs below are disabled (commented out) and kept for reference.
    #
    # @abc.abstractmethod
    # async def insert_llm_message(
    #     self,
    #     cid: str,
    #     role: str,
    #     content: list,
    #     tool_calls: list = None,
    #     tool_call_id: str = None,
    #     parent_id: str = None,
    # ) -> LLMMessage:
    #     """Insert a new LLM message into the conversation."""
    #     ...

    # @abc.abstractmethod
    # async def get_llm_messages(self, cid: str) -> list[LLMMessage]:
    #     """Get all LLM messages for a specific conversation."""
    #     ...
|