From 7e43cca134ce5cde5256f3765f54529d98ce3542 Mon Sep 17 00:00:00 2001 From: Oscar Date: Thu, 11 Dec 2025 17:00:15 +0800 Subject: [PATCH] =?UTF-8?q?perf(db):=20=E4=BC=98=E5=8C=96=E9=87=8D?= =?UTF-8?q?=E6=9E=84command=E7=9B=B8=E5=85=B3=E6=95=B0=E6=8D=AE=E5=BA=93?= =?UTF-8?q?=E6=93=8D=E4=BD=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/db/sqlite.py | 263 +++++++++++++++++++++++--------------- 1 file changed, 163 insertions(+), 100 deletions(-) diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index 7203a40d1..fa3ca9a76 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -1,6 +1,7 @@ import asyncio import threading import typing as T +from collections.abc import Awaitable, Callable from datetime import datetime, timedelta, timezone from sqlalchemy import CursorResult @@ -28,6 +29,7 @@ from astrbot.core.db.po import ( ) NOT_GIVEN = T.TypeVar("NOT_GIVEN") +TxResult = T.TypeVar("TxResult") class SQLiteDatabase(BaseDatabase): @@ -676,6 +678,79 @@ class SQLiteDatabase(BaseDatabase): # Command Configuration & Conflict Tracking # ==== + async def _run_in_tx( + self, + fn: Callable[[AsyncSession], Awaitable[TxResult]], + ) -> TxResult: + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + return await fn(session) + + @staticmethod + def _apply_updates(model, **updates) -> None: + for field, value in updates.items(): + if value is not None: + setattr(model, field, value) + + @staticmethod + def _new_command_config( + handler_full_name: str, + plugin_name: str, + module_path: str, + original_command: str, + *, + resolved_command: str | None = None, + enabled: bool | None = None, + keep_original_alias: bool | None = None, + conflict_key: str | None = None, + resolution_strategy: str | None = None, + note: str | None = None, + extra_data: dict | None = None, + auto_managed: bool | None = None, + ) -> CommandConfig: + return CommandConfig( + handler_full_name=handler_full_name, + plugin_name=plugin_name, + module_path=module_path, + original_command=original_command, + resolved_command=resolved_command, + enabled=True if enabled is None else enabled, + keep_original_alias=False + if keep_original_alias is None + else keep_original_alias, + conflict_key=conflict_key or original_command, + resolution_strategy=resolution_strategy, + note=note, + extra_data=extra_data, + auto_managed=bool(auto_managed), + ) + + @staticmethod + def _new_command_conflict( + conflict_key: str, + handler_full_name: str, + plugin_name: str, + *, + status: str | None = None, + resolution: str | None = None, + resolved_command: str | None = None, + note: str | None = None, + extra_data: dict | None = None, + auto_generated: bool | None = None, + ) -> CommandConflict: + return CommandConflict( + conflict_key=conflict_key, + handler_full_name=handler_full_name, + plugin_name=plugin_name, + status=status or "pending", + resolution=resolution, + resolved_command=resolved_command, + note=note, + extra_data=extra_data, + auto_generated=bool(auto_generated), + ) + async def get_command_configs(self) -> list[CommandConfig]: async with self.get_db() as session: session: AsyncSession @@ -706,68 +781,60 @@ class SQLiteDatabase(BaseDatabase): extra_data: dict | None = None, auto_managed: bool | None = None, ) -> CommandConfig: - async with self.get_db() as session: - session: AsyncSession - async with session.begin(): - config = await session.get(CommandConfig, handler_full_name) - if not config: - config = CommandConfig( - handler_full_name=handler_full_name, - plugin_name=plugin_name, - module_path=module_path, - original_command=original_command, - resolved_command=resolved_command, - enabled=enabled if enabled is not None else True, - keep_original_alias=keep_original_alias or False, - conflict_key=conflict_key or original_command, - resolution_strategy=resolution_strategy, - note=note, - extra_data=extra_data, - auto_managed=bool(auto_managed), - ) - session.add(config) - else: - config.plugin_name = plugin_name or config.plugin_name - config.module_path = module_path or config.module_path - config.original_command = ( - original_command or config.original_command - ) - if resolved_command is not None: - config.resolved_command = resolved_command - if enabled is not None: - config.enabled = enabled - if keep_original_alias is not None: - config.keep_original_alias = keep_original_alias - if conflict_key is not None: - config.conflict_key = conflict_key - if resolution_strategy is not None: - config.resolution_strategy = resolution_strategy - if note is not None: - config.note = note - if extra_data is not None: - config.extra_data = extra_data - if auto_managed is not None: - config.auto_managed = auto_managed - await session.flush() - await session.refresh(config) - await session.commit() + async def _op(session: AsyncSession) -> CommandConfig: + config = await session.get(CommandConfig, handler_full_name) + if not config: + config = self._new_command_config( + handler_full_name, + plugin_name, + module_path, + original_command, + resolved_command=resolved_command, + enabled=enabled, + keep_original_alias=keep_original_alias, + conflict_key=conflict_key, + resolution_strategy=resolution_strategy, + note=note, + extra_data=extra_data, + auto_managed=auto_managed, + ) + session.add(config) + else: + self._apply_updates( + config, + plugin_name=plugin_name, + module_path=module_path, + original_command=original_command, + resolved_command=resolved_command, + enabled=enabled, + keep_original_alias=keep_original_alias, + conflict_key=conflict_key, + resolution_strategy=resolution_strategy, + note=note, + extra_data=extra_data, + auto_managed=auto_managed, + ) + await session.flush() + await session.refresh(config) return config + return await self._run_in_tx(_op) + async def delete_command_config(self, handler_full_name: str) -> None: await self.delete_command_configs([handler_full_name]) async def delete_command_configs(self, handler_full_names: list[str]) -> None: if not handler_full_names: return - async with self.get_db() as session: - session: AsyncSession - async with session.begin(): - await session.execute( - delete(CommandConfig).where( - col(CommandConfig.handler_full_name).in_(handler_full_names), - ), - ) - await session.commit() + + async def _op(session: AsyncSession) -> None: + await session.execute( + delete(CommandConfig).where( + col(CommandConfig.handler_full_name).in_(handler_full_names), + ), + ) + + await self._run_in_tx(_op) async def list_command_conflicts( self, @@ -794,58 +861,54 @@ class SQLiteDatabase(BaseDatabase): extra_data: dict | None = None, auto_generated: bool | None = None, ) -> CommandConflict: - async with self.get_db() as session: - session: AsyncSession - async with session.begin(): - result = await session.execute( - select(CommandConflict).where( - CommandConflict.conflict_key == conflict_key, - CommandConflict.handler_full_name == handler_full_name, - ), + async def _op(session: AsyncSession) -> CommandConflict: + result = await session.execute( + select(CommandConflict).where( + CommandConflict.conflict_key == conflict_key, + CommandConflict.handler_full_name == handler_full_name, + ), + ) + record = result.scalar_one_or_none() + if not record: + record = self._new_command_conflict( + conflict_key, + handler_full_name, + plugin_name, + status=status, + resolution=resolution, + resolved_command=resolved_command, + note=note, + extra_data=extra_data, + auto_generated=auto_generated, ) - record = result.scalar_one_or_none() - if not record: - record = CommandConflict( - conflict_key=conflict_key, - handler_full_name=handler_full_name, - plugin_name=plugin_name, - status=status or "pending", - resolution=resolution, - resolved_command=resolved_command, - note=note, - extra_data=extra_data, - auto_generated=bool(auto_generated), - ) - session.add(record) - else: - record.plugin_name = plugin_name or record.plugin_name - if status is not None: - record.status = status - if resolution is not None: - record.resolution = resolution - if resolved_command is not None: - record.resolved_command = resolved_command - if note is not None: - record.note = note - if extra_data is not None: - record.extra_data = extra_data - if auto_generated is not None: - record.auto_generated = auto_generated - await session.flush() - await session.refresh(record) - await session.commit() + session.add(record) + else: + self._apply_updates( + record, + plugin_name=plugin_name, + status=status, + resolution=resolution, + resolved_command=resolved_command, + note=note, + extra_data=extra_data, + auto_generated=auto_generated, + ) + await session.flush() + await session.refresh(record) return record + return await self._run_in_tx(_op) + async def delete_command_conflicts(self, ids: list[int]) -> None: if not ids: return - async with self.get_db() as session: - session: AsyncSession - async with session.begin(): - await session.execute( - delete(CommandConflict).where(col(CommandConflict.id).in_(ids)), - ) - await session.commit() + + async def _op(session: AsyncSession) -> None: + await session.execute( + delete(CommandConflict).where(col(CommandConflict.id).in_(ids)), + ) + + await self._run_in_tx(_op) # ==== # Deprecated Methods