fix: 修复Pyright静态类型检查报错 (#5437)
* refactor: 修正 Sqlite 查询、下载回调、接口重构与类型调整 * feat: 为 OneBotClient 增加 CallAction 协议与异步调用支持
This commit is contained in:
@@ -4,7 +4,7 @@ import typing as T
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy import CursorResult
|
||||
from sqlalchemy import CursorResult, Row
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import col, delete, desc, func, or_, select, text, update
|
||||
|
||||
@@ -626,7 +626,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
query = select(ApiKey).where(
|
||||
ApiKey.key_hash == key_hash,
|
||||
col(ApiKey.revoked_at).is_(None),
|
||||
or_(col(ApiKey.expires_at).is_(None), ApiKey.expires_at > now),
|
||||
or_(col(ApiKey.expires_at).is_(None), col(ApiKey.expires_at) > now),
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
@@ -638,7 +638,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
update(ApiKey)
|
||||
.where(ApiKey.key_id == key_id)
|
||||
.where(col(ApiKey.key_id) == key_id)
|
||||
.values(last_used_at=datetime.now(timezone.utc)),
|
||||
)
|
||||
|
||||
@@ -649,7 +649,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
async with session.begin():
|
||||
query = (
|
||||
update(ApiKey)
|
||||
.where(ApiKey.key_id == key_id)
|
||||
.where(col(ApiKey.key_id) == key_id)
|
||||
.values(revoked_at=datetime.now(timezone.utc))
|
||||
)
|
||||
result = T.cast(CursorResult, await session.execute(query))
|
||||
@@ -663,7 +663,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
result = T.cast(
|
||||
CursorResult,
|
||||
await session.execute(
|
||||
delete(ApiKey).where(ApiKey.key_id == key_id)
|
||||
delete(ApiKey).where(col(ApiKey.key_id) == key_id)
|
||||
),
|
||||
)
|
||||
return result.rowcount > 0
|
||||
@@ -1457,7 +1457,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
return query
|
||||
|
||||
@staticmethod
|
||||
def _rows_to_session_dicts(rows: list[tuple]) -> list[dict]:
|
||||
def _rows_to_session_dicts(rows: T.Sequence[Row[tuple]]) -> list[dict]:
|
||||
sessions_with_projects = []
|
||||
for row in rows:
|
||||
platform_session = row[0]
|
||||
|
||||
@@ -3,6 +3,7 @@ import os
|
||||
import re
|
||||
import sys
|
||||
import uuid
|
||||
from typing import cast
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from telegram import BotCommand, Update
|
||||
@@ -27,7 +28,7 @@ from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.star_handler import star_handlers_registry
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core.utils.io import download_file
|
||||
from astrbot.core.utils.media_utils import convert_audio_to_wav
|
||||
|
||||
from .tg_event import TelegramPlatformEvent
|
||||
@@ -380,10 +381,10 @@ class TelegramPlatformAdapter(Platform):
|
||||
elif update.message.voice:
|
||||
file = await update.message.voice.get_file()
|
||||
|
||||
file_basename = os.path.basename(file.file_path)
|
||||
file_basename = os.path.basename(cast(str, file.file_path))
|
||||
temp_dir = get_astrbot_temp_path()
|
||||
temp_path = os.path.join(temp_dir, file_basename)
|
||||
temp_path = await download_image_by_url(file.file_path, path=temp_path)
|
||||
await download_file(cast(str, file.file_path), path=temp_path)
|
||||
path_wav = os.path.join(
|
||||
temp_dir,
|
||||
f"{file_basename}.wav",
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Callable
|
||||
from collections.abc import Callable, Coroutine
|
||||
from typing import Any, cast
|
||||
|
||||
import quart
|
||||
@@ -65,7 +65,9 @@ class WeixinOfficialAccountServer:
|
||||
|
||||
self.event_queue = event_queue
|
||||
|
||||
self.callback: Callable[[BaseMessage], Awaitable[None]] | None = None
|
||||
self.callback: (
|
||||
Callable[[BaseMessage], Coroutine[Any, Any, str | None]] | None
|
||||
) = None
|
||||
self.shutdown_event = asyncio.Event()
|
||||
|
||||
self._wx_msg_time_out = 4.0 # 微信服务器要求 5 秒内回复
|
||||
|
||||
@@ -105,6 +105,22 @@ class StarHandlerRegistry(Generic[T]):
|
||||
plugins_name: list[str] | None = None,
|
||||
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
|
||||
|
||||
@overload
|
||||
def get_handlers_by_event_type(
|
||||
self,
|
||||
event_type: Literal[EventType.OnPluginLoadedEvent],
|
||||
only_activated=True,
|
||||
plugins_name: list[str] | None = None,
|
||||
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
|
||||
|
||||
@overload
|
||||
def get_handlers_by_event_type(
|
||||
self,
|
||||
event_type: Literal[EventType.OnPluginUnloadedEvent],
|
||||
only_activated=True,
|
||||
plugins_name: list[str] | None = None,
|
||||
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
|
||||
|
||||
@overload
|
||||
def get_handlers_by_event_type(
|
||||
self,
|
||||
|
||||
@@ -19,7 +19,7 @@ from astrbot.core.message.components import (
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.utils.string_utils import normalize_and_dedupe_strings
|
||||
|
||||
from .image_refs import looks_like_image_file_name, normalize_file_like_url
|
||||
from .image_refs import looks_like_image_file_name
|
||||
from .settings import SETTINGS, QuotedMessageParserSettings
|
||||
|
||||
_FORWARD_PLACEHOLDER_PATTERN = re.compile(
|
||||
@@ -296,11 +296,11 @@ def _parse_onebot_segments(
|
||||
or "file"
|
||||
)
|
||||
text_parts.append(f"[File:{file_name}]")
|
||||
candidate_url = seg_data.get("url")
|
||||
candidate_url = seg_data.get("url", "")
|
||||
if (
|
||||
isinstance(candidate_url, str)
|
||||
and candidate_url.strip()
|
||||
and looks_like_image_file_name(normalize_file_like_url(candidate_url))
|
||||
and looks_like_image_file_name(candidate_url)
|
||||
):
|
||||
image_refs.append(candidate_url.strip())
|
||||
candidate_file = seg_data.get("file")
|
||||
@@ -308,11 +308,7 @@ def _parse_onebot_segments(
|
||||
isinstance(candidate_file, str)
|
||||
and candidate_file.strip()
|
||||
and looks_like_image_file_name(
|
||||
normalize_file_like_url(
|
||||
seg_data.get("name")
|
||||
or seg_data.get("file_name")
|
||||
or candidate_file
|
||||
)
|
||||
seg_data.get("name") or seg_data.get("file_name") or candidate_file
|
||||
)
|
||||
):
|
||||
image_refs.append(candidate_file.strip())
|
||||
@@ -368,7 +364,9 @@ def _extract_text_forward_ids_and_images_from_forward_nodes(
|
||||
if not isinstance(node, dict):
|
||||
continue
|
||||
|
||||
sender = node.get("sender") if isinstance(node.get("sender"), dict) else {}
|
||||
sender = node.get("sender")
|
||||
if not isinstance(sender, dict):
|
||||
sender = {}
|
||||
sender_name = (
|
||||
sender.get("nickname")
|
||||
or sender.get("card")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from collections.abc import Awaitable
|
||||
from typing import Any, Protocol
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
@@ -17,6 +18,10 @@ def _unwrap_action_response(ret: dict[str, Any] | None) -> dict[str, Any]:
|
||||
return ret
|
||||
|
||||
|
||||
class CallAction(Protocol):
|
||||
def __call__(self, action: str, **params: Any) -> Awaitable[Any] | Any: ...
|
||||
|
||||
|
||||
class OneBotClient:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -27,7 +32,7 @@ class OneBotClient:
|
||||
self._settings = settings
|
||||
|
||||
@staticmethod
|
||||
def _resolve_call_action(event: AstrMessageEvent):
|
||||
def _resolve_call_action(event: AstrMessageEvent) -> CallAction | None:
|
||||
bot = getattr(event, "bot", None)
|
||||
api = getattr(bot, "api", None)
|
||||
call_action = getattr(api, "call_action", None)
|
||||
|
||||
Reference in New Issue
Block a user