perf(db): 优化重构command相关数据库操作

This commit is contained in:
Oscar
2025-12-11 17:00:15 +08:00
parent 89fdb18936
commit 7e43cca134
+163 -100
View File
@@ -1,6 +1,7 @@
import asyncio
import threading
import typing as T
from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta, timezone
from sqlalchemy import CursorResult
@@ -28,6 +29,7 @@ from astrbot.core.db.po import (
)
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
TxResult = T.TypeVar("TxResult")
class SQLiteDatabase(BaseDatabase):
@@ -676,6 +678,79 @@ class SQLiteDatabase(BaseDatabase):
# Command Configuration & Conflict Tracking
# ====
async def _run_in_tx(
self,
fn: Callable[[AsyncSession], Awaitable[TxResult]],
) -> TxResult:
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
return await fn(session)
@staticmethod
def _apply_updates(model, **updates) -> None:
for field, value in updates.items():
if value is not None:
setattr(model, field, value)
@staticmethod
def _new_command_config(
handler_full_name: str,
plugin_name: str,
module_path: str,
original_command: str,
*,
resolved_command: str | None = None,
enabled: bool | None = None,
keep_original_alias: bool | None = None,
conflict_key: str | None = None,
resolution_strategy: str | None = None,
note: str | None = None,
extra_data: dict | None = None,
auto_managed: bool | None = None,
) -> CommandConfig:
return CommandConfig(
handler_full_name=handler_full_name,
plugin_name=plugin_name,
module_path=module_path,
original_command=original_command,
resolved_command=resolved_command,
enabled=True if enabled is None else enabled,
keep_original_alias=False
if keep_original_alias is None
else keep_original_alias,
conflict_key=conflict_key or original_command,
resolution_strategy=resolution_strategy,
note=note,
extra_data=extra_data,
auto_managed=bool(auto_managed),
)
@staticmethod
def _new_command_conflict(
conflict_key: str,
handler_full_name: str,
plugin_name: str,
*,
status: str | None = None,
resolution: str | None = None,
resolved_command: str | None = None,
note: str | None = None,
extra_data: dict | None = None,
auto_generated: bool | None = None,
) -> CommandConflict:
return CommandConflict(
conflict_key=conflict_key,
handler_full_name=handler_full_name,
plugin_name=plugin_name,
status=status or "pending",
resolution=resolution,
resolved_command=resolved_command,
note=note,
extra_data=extra_data,
auto_generated=bool(auto_generated),
)
async def get_command_configs(self) -> list[CommandConfig]:
async with self.get_db() as session:
session: AsyncSession
@@ -706,68 +781,60 @@ class SQLiteDatabase(BaseDatabase):
extra_data: dict | None = None,
auto_managed: bool | None = None,
) -> CommandConfig:
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
config = await session.get(CommandConfig, handler_full_name)
if not config:
config = CommandConfig(
handler_full_name=handler_full_name,
plugin_name=plugin_name,
module_path=module_path,
original_command=original_command,
resolved_command=resolved_command,
enabled=enabled if enabled is not None else True,
keep_original_alias=keep_original_alias or False,
conflict_key=conflict_key or original_command,
resolution_strategy=resolution_strategy,
note=note,
extra_data=extra_data,
auto_managed=bool(auto_managed),
)
session.add(config)
else:
config.plugin_name = plugin_name or config.plugin_name
config.module_path = module_path or config.module_path
config.original_command = (
original_command or config.original_command
)
if resolved_command is not None:
config.resolved_command = resolved_command
if enabled is not None:
config.enabled = enabled
if keep_original_alias is not None:
config.keep_original_alias = keep_original_alias
if conflict_key is not None:
config.conflict_key = conflict_key
if resolution_strategy is not None:
config.resolution_strategy = resolution_strategy
if note is not None:
config.note = note
if extra_data is not None:
config.extra_data = extra_data
if auto_managed is not None:
config.auto_managed = auto_managed
await session.flush()
await session.refresh(config)
await session.commit()
async def _op(session: AsyncSession) -> CommandConfig:
config = await session.get(CommandConfig, handler_full_name)
if not config:
config = self._new_command_config(
handler_full_name,
plugin_name,
module_path,
original_command,
resolved_command=resolved_command,
enabled=enabled,
keep_original_alias=keep_original_alias,
conflict_key=conflict_key,
resolution_strategy=resolution_strategy,
note=note,
extra_data=extra_data,
auto_managed=auto_managed,
)
session.add(config)
else:
self._apply_updates(
config,
plugin_name=plugin_name,
module_path=module_path,
original_command=original_command,
resolved_command=resolved_command,
enabled=enabled,
keep_original_alias=keep_original_alias,
conflict_key=conflict_key,
resolution_strategy=resolution_strategy,
note=note,
extra_data=extra_data,
auto_managed=auto_managed,
)
await session.flush()
await session.refresh(config)
return config
return await self._run_in_tx(_op)
async def delete_command_config(self, handler_full_name: str) -> None:
await self.delete_command_configs([handler_full_name])
async def delete_command_configs(self, handler_full_names: list[str]) -> None:
if not handler_full_names:
return
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
await session.execute(
delete(CommandConfig).where(
col(CommandConfig.handler_full_name).in_(handler_full_names),
),
)
await session.commit()
async def _op(session: AsyncSession) -> None:
await session.execute(
delete(CommandConfig).where(
col(CommandConfig.handler_full_name).in_(handler_full_names),
),
)
await self._run_in_tx(_op)
async def list_command_conflicts(
self,
@@ -794,58 +861,54 @@ class SQLiteDatabase(BaseDatabase):
extra_data: dict | None = None,
auto_generated: bool | None = None,
) -> CommandConflict:
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
result = await session.execute(
select(CommandConflict).where(
CommandConflict.conflict_key == conflict_key,
CommandConflict.handler_full_name == handler_full_name,
),
async def _op(session: AsyncSession) -> CommandConflict:
result = await session.execute(
select(CommandConflict).where(
CommandConflict.conflict_key == conflict_key,
CommandConflict.handler_full_name == handler_full_name,
),
)
record = result.scalar_one_or_none()
if not record:
record = self._new_command_conflict(
conflict_key,
handler_full_name,
plugin_name,
status=status,
resolution=resolution,
resolved_command=resolved_command,
note=note,
extra_data=extra_data,
auto_generated=auto_generated,
)
record = result.scalar_one_or_none()
if not record:
record = CommandConflict(
conflict_key=conflict_key,
handler_full_name=handler_full_name,
plugin_name=plugin_name,
status=status or "pending",
resolution=resolution,
resolved_command=resolved_command,
note=note,
extra_data=extra_data,
auto_generated=bool(auto_generated),
)
session.add(record)
else:
record.plugin_name = plugin_name or record.plugin_name
if status is not None:
record.status = status
if resolution is not None:
record.resolution = resolution
if resolved_command is not None:
record.resolved_command = resolved_command
if note is not None:
record.note = note
if extra_data is not None:
record.extra_data = extra_data
if auto_generated is not None:
record.auto_generated = auto_generated
await session.flush()
await session.refresh(record)
await session.commit()
session.add(record)
else:
self._apply_updates(
record,
plugin_name=plugin_name,
status=status,
resolution=resolution,
resolved_command=resolved_command,
note=note,
extra_data=extra_data,
auto_generated=auto_generated,
)
await session.flush()
await session.refresh(record)
return record
return await self._run_in_tx(_op)
async def delete_command_conflicts(self, ids: list[int]) -> None:
if not ids:
return
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
await session.execute(
delete(CommandConflict).where(col(CommandConflict.id).in_(ids)),
)
await session.commit()
async def _op(session: AsyncSession) -> None:
await session.execute(
delete(CommandConflict).where(col(CommandConflict.id).in_(ids)),
)
await self._run_in_tx(_op)
# ====
# Deprecated Methods