diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 8bccae959..5aeef1eff 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -46,8 +46,6 @@ jobs: include: - language: python build-mode: none - - language: javascript-typescript - build-mode: none # CodeQL supports the following values keywords for 'language': 'c-cpp', 'csharp', 'go', 'java-kotlin', 'javascript-typescript', 'python', 'ruby', 'swift' # Use `c-cpp` to analyze code written in C, C++ or both # Use 'java-kotlin' to analyze code written in Java, Kotlin or both diff --git a/.github/workflows/coverage_test.yml b/.github/workflows/coverage_test.yml index bd7beceb1..f0019ee7e 100644 --- a/.github/workflows/coverage_test.yml +++ b/.github/workflows/coverage_test.yml @@ -23,13 +23,12 @@ jobs: - name: Set up Python uses: actions/setup-python@v6 - with: - python-version: "3.12" - name: Install dependencies run: | - python -m pip install --upgrade pip uv - uv sync --group dev + python -m pip install --upgrade pip + pip install pytest pytest-asyncio pytest-cov + pip install --editable . - name: Run tests run: | @@ -38,7 +37,7 @@ jobs: mkdir -p data/temp export TESTING=true export ZHIPU_API_KEY=${{ secrets.OPENAI_API_KEY }} - uv run pytest --cov=astrbot -v -o log_cli=true -o log_level=DEBUG + pytest --cov=astrbot -v -o log_cli=true -o log_level=DEBUG - name: Upload results to Codecov uses: codecov/codecov-action@v5 diff --git a/.github/workflows/dashboard_ci.yml b/.github/workflows/dashboard_ci.yml index 921fc4992..46d2fea73 100644 --- a/.github/workflows/dashboard_ci.yml +++ b/.github/workflows/dashboard_ci.yml @@ -13,23 +13,18 @@ jobs: - name: Checkout repository uses: actions/checkout@v6 - - name: Setup pnpm - uses: pnpm/action-setup@v4 - with: - version: 10.28.2 - - name: Setup Node.js uses: actions/setup-node@v6 with: node-version: '24.13.0' - cache: "pnpm" - cache-dependency-path: dashboard/pnpm-lock.yaml - - name: Install and build + - name: npm install, build run: | - pnpm --dir dashboard install --frozen-lockfile - pnpm --dir dashboard run typecheck - pnpm --dir dashboard run build + cd dashboard + npm install pnpm -g + pnpm install + pnpm i --save-dev @types/markdown-it + pnpm run build - name: Inject Commit SHA id: get_sha diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml index 6300a65a0..18c8d4926 100644 --- a/.github/workflows/docker-image.yml +++ b/.github/workflows/docker-image.yml @@ -25,18 +25,6 @@ jobs: fetch-depth: 1 fetch-tag: true - - name: Setup pnpm - uses: pnpm/action-setup@v4 - with: - version: 10.28.2 - - - name: Setup Node.js - uses: actions/setup-node@v6 - with: - node-version: '24.13.0' - cache: "pnpm" - cache-dependency-path: dashboard/pnpm-lock.yaml - - name: Check for new commits today if: github.event_name == 'schedule' id: check-commits @@ -58,10 +46,12 @@ jobs: - name: Build Dashboard run: | - pnpm --dir dashboard install --frozen-lockfile - pnpm --dir dashboard run build - mkdir -p dashboard/dist/assets - echo $(git rev-parse HEAD) > dashboard/dist/assets/version + cd dashboard + npm install + npm run build + mkdir -p dist/assets + echo $(git rev-parse HEAD) > dist/assets/version + cd .. mkdir -p data cp -r dashboard/dist data/ @@ -133,18 +123,6 @@ jobs: fetch-depth: 1 fetch-tag: true - - name: Setup pnpm - uses: pnpm/action-setup@v4 - with: - version: 10.28.2 - - - name: Setup Node.js - uses: actions/setup-node@v6 - with: - node-version: '24.13.0' - cache: "pnpm" - cache-dependency-path: dashboard/pnpm-lock.yaml - - name: Get latest tag (only on manual trigger) id: get-latest-tag if: github.event_name == 'workflow_dispatch' @@ -175,10 +153,12 @@ jobs: - name: Build Dashboard run: | - pnpm --dir dashboard install --frozen-lockfile - pnpm --dir dashboard run build - mkdir -p dashboard/dist/assets - echo $(git rev-parse HEAD) > dashboard/dist/assets/version + cd dashboard + npm install + npm run build + mkdir -p dist/assets + echo $(git rev-parse HEAD) > dist/assets/version + cd .. mkdir -p data cp -r dashboard/dist data/ diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 4950b7a4b..41f59f0a6 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -18,29 +18,6 @@ permissions: contents: write jobs: - verify-core: - name: Verify Core Quality Gate - runs-on: ubuntu-24.04 - steps: - - name: Checkout repository - uses: actions/checkout@v6 - with: - fetch-depth: 0 - ref: ${{ inputs.ref || github.ref }} - - - name: Set up Python - uses: actions/setup-python@v6 - with: - python-version: "3.12" - - - name: Install uv - shell: bash - run: python -m pip install uv - - - name: Run local PR gate checks - shell: bash - run: make pr-test-neo - build-dashboard: name: Build Dashboard runs-on: ubuntu-24.04 @@ -108,8 +85,7 @@ jobs: VERSION_TAG: ${{ steps.tag.outputs.tag }} shell: bash run: | - sudo apt-get update - sudo apt-get install -y rclone + curl https://rclone.org/install.sh | sudo bash mkdir -p ~/.config/rclone cat < ~/.config/rclone/rclone.conf @@ -130,7 +106,6 @@ jobs: name: Publish GitHub Release runs-on: ubuntu-24.04 needs: - - verify-core - build-dashboard steps: - name: Checkout repository @@ -251,7 +226,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v6 with: - python-version: "3.12" + python-version: "3.10" - name: Install uv shell: bash diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4c2a126e9..8611e2698 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,7 +8,7 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.15.1 + rev: v0.14.1 hooks: # Run the linter. - id: ruff-check @@ -22,4 +22,4 @@ repos: rev: v3.21.0 hooks: - id: pyupgrade - args: [--py312-plus] + args: [--py310-plus] diff --git a/Dockerfile b/Dockerfile index 544c6d6ce..992060d6e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -13,9 +13,10 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ bash \ ffmpeg \ curl \ + gnupg \ git \ - nodejs \ - npm \ + && curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - \ + && apt-get install -y --no-install-recommends nodejs \ && apt-get clean \ && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* diff --git a/README.md b/README.md index 9ac80f287..e3b096a32 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@
-python +python zread Docker pull diff --git a/README_fr.md b/README_fr.md index a6e778df9..3a586adfc 100644 --- a/README_fr.md +++ b/README_fr.md @@ -19,7 +19,7 @@
-python +python zread Docker pull diff --git a/README_ja.md b/README_ja.md index c34106143..43b73884d 100644 --- a/README_ja.md +++ b/README_ja.md @@ -19,7 +19,7 @@
-python +python zread Docker pull diff --git a/README_ru.md b/README_ru.md index 1bc1f5554..8848dd92d 100644 --- a/README_ru.md +++ b/README_ru.md @@ -19,7 +19,7 @@
-python +python zread Docker pull diff --git a/README_zh-TW.md b/README_zh-TW.md index 3bd2455b2..e3291d0b0 100644 --- a/README_zh-TW.md +++ b/README_zh-TW.md @@ -19,7 +19,7 @@
-python +python zread Docker pull diff --git a/README_zh.md b/README_zh.md index dc2c015f0..7a85217b4 100644 --- a/README_zh.md +++ b/README_zh.md @@ -17,7 +17,7 @@
-python +python zread Docker pull diff --git a/astrbot/cli/utils/plugin.py b/astrbot/cli/utils/plugin.py index 2764935d0..c06dda350 100644 --- a/astrbot/cli/utils/plugin.py +++ b/astrbot/cli/utils/plugin.py @@ -1,6 +1,6 @@ import shutil import tempfile -from enum import StrEnum +from enum import Enum from io import BytesIO from pathlib import Path from zipfile import ZipFile @@ -12,7 +12,7 @@ import yaml from .version_comparator import VersionComparator -class PluginStatus(StrEnum): +class PluginStatus(str, Enum): INSTALLED = "installed" NEED_UPDATE = "needs-update" NOT_INSTALLED = "not-installed" diff --git a/astrbot/core/agent/agent.py b/astrbot/core/agent/agent.py index f4606b6da..d6e2e7cb4 100644 --- a/astrbot/core/agent/agent.py +++ b/astrbot/core/agent/agent.py @@ -1,12 +1,13 @@ from dataclasses import dataclass -from typing import Any +from typing import Any, Generic from .hooks import BaseAgentRunHooks +from .run_context import TContext from .tool import FunctionTool @dataclass -class Agent[TContext]: +class Agent(Generic[TContext]): name: str instructions: str | None = None tools: list[str | FunctionTool] | None = None diff --git a/astrbot/core/agent/handoff.py b/astrbot/core/agent/handoff.py index 01fc5159c..8475009d3 100644 --- a/astrbot/core/agent/handoff.py +++ b/astrbot/core/agent/handoff.py @@ -1,8 +1,11 @@ +from typing import Generic + from .agent import Agent +from .run_context import TContext from .tool import FunctionTool -class HandoffTool[TContext](FunctionTool): +class HandoffTool(FunctionTool, Generic[TContext]): """Handoff tool for delegating tasks to another agent.""" def __init__( diff --git a/astrbot/core/agent/hooks.py b/astrbot/core/agent/hooks.py index 451a95753..74ca6335b 100644 --- a/astrbot/core/agent/hooks.py +++ b/astrbot/core/agent/hooks.py @@ -1,12 +1,14 @@ +from typing import Generic + import mcp from astrbot.core.agent.tool import FunctionTool from astrbot.core.provider.entities import LLMResponse -from .run_context import ContextWrapper +from .run_context import ContextWrapper, TContext -class BaseAgentRunHooks[TContext]: +class BaseAgentRunHooks(Generic[TContext]): async def on_agent_begin(self, run_context: ContextWrapper[TContext]) -> None: ... async def on_tool_start( self, diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index 5c4c19fad..18f4d47e0 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -2,6 +2,7 @@ import asyncio import logging from contextlib import AsyncExitStack from datetime import timedelta +from typing import Generic from tenacity import ( before_sleep_log, @@ -15,6 +16,7 @@ from astrbot import logger from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.utils.log_pipe import LogPipe +from .run_context import TContext from .tool import FunctionTool try: @@ -99,7 +101,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: return True, "" return False, f"HTTP {response.status}: {response.reason}" - except TimeoutError: + except asyncio.TimeoutError: return False, f"Connection timeout: {timeout} seconds" except Exception as e: return False, f"{e!s}" @@ -358,7 +360,7 @@ class MCPClient: self.running_event.set() -class MCPTool[TContext](FunctionTool): +class MCPTool(FunctionTool, Generic[TContext]): """A function tool that calls an MCP service.""" def __init__( diff --git a/astrbot/core/agent/runners/base.py b/astrbot/core/agent/runners/base.py index e1e3ff8e3..21e796433 100644 --- a/astrbot/core/agent/runners/base.py +++ b/astrbot/core/agent/runners/base.py @@ -7,7 +7,7 @@ from astrbot.core.provider.entities import LLMResponse from ..hooks import BaseAgentRunHooks from ..response import AgentResponse -from ..run_context import ContextWrapper +from ..run_context import ContextWrapper, TContext class AgentState(Enum): @@ -19,7 +19,7 @@ class AgentState(Enum): ERROR = auto() # Error state -class BaseAgentRunner[TContext]: +class BaseAgentRunner(T.Generic[TContext]): @abc.abstractmethod async def reset( self, diff --git a/astrbot/core/agent/runners/coze/coze_agent_runner.py b/astrbot/core/agent/runners/coze/coze_agent_runner.py index 0d7fab207..a8300bb71 100644 --- a/astrbot/core/agent/runners/coze/coze_agent_runner.py +++ b/astrbot/core/agent/runners/coze/coze_agent_runner.py @@ -1,7 +1,7 @@ import base64 import json +import sys import typing as T -from typing import override import astrbot.core.message.components as Comp from astrbot import logger @@ -18,6 +18,11 @@ from ...run_context import ContextWrapper, TContext from ..base import AgentResponse, AgentState, BaseAgentRunner from .coze_api_client import CozeAPIClient +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override + class CozeAgentRunner(BaseAgentRunner[TContext]): """Coze Agent Runner""" @@ -246,7 +251,7 @@ class CozeAgentRunner(BaseAgentRunner[TContext]): conversation_id=conversation_id, auto_save_history=self.auto_save_history, stream=True, - timeout_seconds=self.timeout, + timeout=self.timeout, ): event_type = chunk.get("event") data = chunk.get("data", {}) diff --git a/astrbot/core/agent/runners/coze/coze_api_client.py b/astrbot/core/agent/runners/coze/coze_api_client.py index 03dbe64cc..f5799dfbb 100644 --- a/astrbot/core/agent/runners/coze/coze_api_client.py +++ b/astrbot/core/agent/runners/coze/coze_api_client.py @@ -2,7 +2,6 @@ import asyncio import io import json from collections.abc import AsyncGenerator -from pathlib import Path from typing import Any import aiohttp @@ -91,7 +90,7 @@ class CozeAPIClient: logger.debug(f"[Coze] 图片上传成功,file_id: {file_id}") return file_id - except TimeoutError: + except asyncio.TimeoutError: logger.error("文件上传超时") raise Exception("文件上传超时") except Exception as e: @@ -129,7 +128,7 @@ class CozeAPIClient: conversation_id: str | None = None, auto_save_history: bool = True, stream: bool = True, - timeout_seconds: float = 120, + timeout: float = 120, ) -> AsyncGenerator[dict[str, Any], None]: """发送聊天消息并返回流式响应 @@ -140,7 +139,7 @@ class CozeAPIClient: conversation_id: 会话ID auto_save_history: 是否自动保存历史 stream: 是否流式响应 - timeout_seconds: 超时时间 + timeout: 超时时间 """ session = await self._ensure_session() @@ -167,7 +166,7 @@ class CozeAPIClient: url, json=payload, params=params, - timeout=aiohttp.ClientTimeout(total=timeout_seconds), + timeout=aiohttp.ClientTimeout(total=timeout), ) as response: if response.status == 401: raise Exception("Coze API 认证失败,请检查 API Key 是否正确") @@ -204,8 +203,8 @@ class CozeAPIClient: except json.JSONDecodeError: event_data = {"content": data_str} - except TimeoutError: - raise Exception(f"Coze API 流式请求超时 ({timeout_seconds}秒)") + except asyncio.TimeoutError: + raise Exception(f"Coze API 流式请求超时 ({timeout}秒)") except Exception as e: raise Exception(f"Coze API 流式请求失败: {e!s}") @@ -237,7 +236,7 @@ class CozeAPIClient: except json.JSONDecodeError: raise Exception("Coze API 返回非JSON格式") - except TimeoutError: + except asyncio.TimeoutError: raise Exception("Coze API 请求超时") except aiohttp.ClientError as e: raise Exception(f"Coze API 请求失败: {e!s}") @@ -295,7 +294,8 @@ if __name__ == "__main__": client = CozeAPIClient(api_key=api_key) try: - file_data = await asyncio.to_thread(Path("README.md").read_bytes) + with open("README.md", "rb") as f: + file_data = f.read() file_id = await client.upload_file(file_data) print(f"Uploaded file_id: {file_id}") async for event in client.chat_messages( diff --git a/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py b/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py index 080e627d5..1aaf6e3b9 100644 --- a/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py +++ b/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py @@ -2,9 +2,9 @@ import asyncio import functools import queue import re +import sys import threading import typing as T -from typing import override from dashscope import Application from dashscope.app.application_response import ApplicationResponse @@ -22,6 +22,11 @@ from ...response import AgentResponseData from ...run_context import ContextWrapper, TContext from ..base import AgentResponse, AgentState, BaseAgentRunner +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override + class DashscopeAgentRunner(BaseAgentRunner[TContext]): """Dashscope Agent Runner""" diff --git a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py index 9e4a11471..50ec7c826 100644 --- a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py +++ b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py @@ -1,10 +1,10 @@ import asyncio import hashlib import json +import sys import typing as T from collections import deque from dataclasses import dataclass, field -from typing import override from uuid import uuid4 import astrbot.core.message.components as Comp @@ -40,6 +40,11 @@ from .deerflow_stream_utils import ( get_message_id, ) +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override + class DeerFlowAgentRunner(BaseAgentRunner[TContext]): """DeerFlow Agent Runner via LangGraph HTTP API.""" @@ -373,9 +378,7 @@ class DeerFlowAgentRunner(BaseAgentRunner[TContext]): if thread_id: return thread_id - thread = await self.api_client.create_thread( - timeout_seconds=min(30, self.timeout) - ) + thread = await self.api_client.create_thread(timeout=min(30, self.timeout)) thread_id = thread.get("thread_id", "") if not thread_id: raise Exception( @@ -636,7 +639,7 @@ class DeerFlowAgentRunner(BaseAgentRunner[TContext]): async for event in self.api_client.stream_run( thread_id=thread_id, payload=payload, - timeout_seconds=self.timeout, + timeout=self.timeout, ): event_type = event.get("event") data = event.get("data") @@ -663,7 +666,7 @@ class DeerFlowAgentRunner(BaseAgentRunner[TContext]): if event_type == "end": break - except TimeoutError: + except (asyncio.TimeoutError, TimeoutError): logger.warning( "DeerFlow stream timed out after %ss for thread_id=%s; returning partial result.", self.timeout, diff --git a/astrbot/core/agent/runners/deerflow/deerflow_api_client.py b/astrbot/core/agent/runners/deerflow/deerflow_api_client.py index 4ae9432e0..37a23f243 100644 --- a/astrbot/core/agent/runners/deerflow/deerflow_api_client.py +++ b/astrbot/core/agent/runners/deerflow/deerflow_api_client.py @@ -139,7 +139,7 @@ class DeerFlowAPIClient: ) -> None: await self.close() - async def create_thread(self, timeout_seconds: float = 20) -> dict[str, Any]: + async def create_thread(self, timeout: float = 20) -> dict[str, Any]: session = self._get_session() url = f"{self.api_base}/api/langgraph/threads" payload = {"metadata": {}} @@ -147,7 +147,7 @@ class DeerFlowAPIClient: url, json=payload, headers=self.headers, - timeout=timeout_seconds, + timeout=timeout, proxy=self.proxy, ) as resp: if resp.status not in (200, 201): @@ -161,7 +161,7 @@ class DeerFlowAPIClient: self, thread_id: str, payload: dict[str, Any], - timeout_seconds: float = 120, + timeout: float = 120, ) -> AsyncGenerator[dict[str, Any], None]: session = self._get_session() url = f"{self.api_base}/api/langgraph/threads/{thread_id}/runs/stream" @@ -183,9 +183,9 @@ class DeerFlowAPIClient: # Use socket read timeout so active heartbeats/chunks can keep the stream alive. stream_timeout = ClientTimeout( total=None, - connect=min(timeout_seconds, 30), - sock_connect=min(timeout_seconds, 30), - sock_read=timeout_seconds, + connect=min(timeout, 30), + sock_connect=min(timeout, 30), + sock_read=timeout, ) async with session.post( url, diff --git a/astrbot/core/agent/runners/dify/dify_agent_runner.py b/astrbot/core/agent/runners/dify/dify_agent_runner.py index 1630ebf08..93f8d3570 100644 --- a/astrbot/core/agent/runners/dify/dify_agent_runner.py +++ b/astrbot/core/agent/runners/dify/dify_agent_runner.py @@ -1,7 +1,7 @@ import base64 import os +import sys import typing as T -from typing import override import astrbot.core.message.components as Comp from astrbot.core import logger, sp @@ -19,6 +19,11 @@ from ...run_context import ContextWrapper, TContext from ..base import AgentResponse, AgentState, BaseAgentRunner from .dify_api_client import DifyAPIClient +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override + class DifyAgentRunner(BaseAgentRunner[TContext]): """Dify Agent Runner""" @@ -171,7 +176,7 @@ class DifyAgentRunner(BaseAgentRunner[TContext]): user=session_id, conversation_id=conversation_id, files=files_payload, - timeout_seconds=self.timeout, + timeout=self.timeout, ): logger.debug(f"dify resp chunk: {chunk}") if chunk["event"] == "message" or chunk["event"] == "agent_message": @@ -211,7 +216,7 @@ class DifyAgentRunner(BaseAgentRunner[TContext]): }, user=session_id, files=files_payload, - timeout_seconds=self.timeout, + timeout=self.timeout, ): logger.debug(f"dify workflow resp chunk: {chunk}") match chunk["event"]: diff --git a/astrbot/core/agent/runners/dify/dify_api_client.py b/astrbot/core/agent/runners/dify/dify_api_client.py index db7b923fc..26da6dfe9 100644 --- a/astrbot/core/agent/runners/dify/dify_api_client.py +++ b/astrbot/core/agent/runners/dify/dify_api_client.py @@ -1,8 +1,6 @@ -import asyncio import codecs import json from collections.abc import AsyncGenerator -from pathlib import Path from typing import Any from aiohttp import ClientResponse, ClientSession, FormData @@ -49,20 +47,20 @@ class DifyAPIClient: response_mode: str = "streaming", conversation_id: str = "", files: list[dict[str, Any]] | None = None, - timeout_seconds: float = 60, + timeout: float = 60, ) -> AsyncGenerator[dict[str, Any], None]: if files is None: files = [] url = f"{self.api_base}/chat-messages" payload = locals() payload.pop("self") - payload.pop("timeout_seconds") + payload.pop("timeout") logger.info(f"chat_messages payload: {payload}") async with self.session.post( url, json=payload, headers=self.headers, - timeout=timeout_seconds, + timeout=timeout, ) as resp: if resp.status != 200: text = await resp.text() @@ -78,20 +76,20 @@ class DifyAPIClient: user: str, response_mode: str = "streaming", files: list[dict[str, Any]] | None = None, - timeout_seconds: float = 60, + timeout: float = 60, ): if files is None: files = [] url = f"{self.api_base}/workflows/run" payload = locals() payload.pop("self") - payload.pop("timeout_seconds") + payload.pop("timeout") logger.info(f"workflow_run payload: {payload}") async with self.session.post( url, json=payload, headers=self.headers, - timeout=timeout_seconds, + timeout=timeout, ) as resp: if resp.status != 200: text = await resp.text() @@ -136,13 +134,14 @@ class DifyAPIClient: # 使用文件路径 import os - file_content = await asyncio.to_thread(Path(file_path).read_bytes) - form.add_field( - "file", - file_content, - filename=os.path.basename(file_path), - content_type=mime_type or "application/octet-stream", - ) + with open(file_path, "rb") as f: + file_content = f.read() + form.add_field( + "file", + file_content, + filename=os.path.basename(file_path), + content_type=mime_type or "application/octet-stream", + ) else: raise ValueError("file_path 和 file_data 不能同时为 None") diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index cc231be69..743b28007 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -1,10 +1,10 @@ import asyncio import copy +import sys import time import traceback import typing as T from dataclasses import dataclass, field -from typing import override from mcp.types import ( BlobResourceContents, @@ -44,6 +44,11 @@ from ..run_context import ContextWrapper, TContext from ..tool_executor import BaseFunctionToolExecutor from .base import AgentResponse, AgentState, BaseAgentRunner +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override + @dataclass(slots=True) class _HandleFunctionToolsResult: diff --git a/astrbot/core/agent/tool.py b/astrbot/core/agent/tool.py index 98f354ae4..c2536708e 100644 --- a/astrbot/core/agent/tool.py +++ b/astrbot/core/agent/tool.py @@ -1,6 +1,6 @@ import copy from collections.abc import AsyncGenerator, Awaitable, Callable -from typing import Any +from typing import Any, Generic import jsonschema import mcp @@ -10,7 +10,7 @@ from pydantic.dataclasses import dataclass from astrbot.core.message.message_event_result import MessageEventResult -from .run_context import ContextWrapper +from .run_context import ContextWrapper, TContext ParametersType = dict[str, Any] ToolExecResult = str | mcp.types.CallToolResult @@ -38,7 +38,7 @@ class ToolSchema: @dataclass -class FunctionTool[TContext](ToolSchema): +class FunctionTool(ToolSchema, Generic[TContext]): """A callable tool, for function calling.""" handler: ( diff --git a/astrbot/core/agent/tool_executor.py b/astrbot/core/agent/tool_executor.py index 8708fd97d..2704119d4 100644 --- a/astrbot/core/agent/tool_executor.py +++ b/astrbot/core/agent/tool_executor.py @@ -1,13 +1,13 @@ from collections.abc import AsyncGenerator -from typing import Any +from typing import Any, Generic import mcp -from .run_context import ContextWrapper +from .run_context import ContextWrapper, TContext from .tool import FunctionTool -class BaseFunctionToolExecutor[TContext]: +class BaseFunctionToolExecutor(Generic[TContext]): @classmethod async def execute( cls, diff --git a/astrbot/core/astr_agent_run_util.py b/astrbot/core/astr_agent_run_util.py index f7178fe3b..dd65f92e6 100644 --- a/astrbot/core/astr_agent_run_util.py +++ b/astrbot/core/astr_agent_run_util.py @@ -3,7 +3,6 @@ import re import time import traceback from collections.abc import AsyncGenerator -from pathlib import Path from astrbot.core import logger from astrbot.core.agent.message import Message @@ -510,7 +509,8 @@ async def _simulated_stream_tts( audio_path = await tts_provider.get_audio(text) if audio_path: - audio_data = await asyncio.to_thread(Path(audio_path).read_bytes) + with open(audio_path, "rb") as f: + audio_data = f.read() await audio_queue.put((text, audio_data)) except Exception as e: logger.error( diff --git a/astrbot/core/astr_agent_tool_exec.py b/astrbot/core/astr_agent_tool_exec.py index be51ced72..0dc8b9eeb 100644 --- a/astrbot/core/astr_agent_tool_exec.py +++ b/astrbot/core/astr_agent_tool_exec.py @@ -625,7 +625,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): exc_info=True, ) yield None - except TimeoutError: + except asyncio.TimeoutError: raise Exception( f"tool {tool.name} execution timeout after {tool_call_timeout or run_context.tool_call_timeout} seconds.", ) diff --git a/astrbot/core/astr_main_agent_resources.py b/astrbot/core/astr_main_agent_resources.py index 933346d1f..2e0d8b0aa 100644 --- a/astrbot/core/astr_main_agent_resources.py +++ b/astrbot/core/astr_main_agent_resources.py @@ -1,4 +1,3 @@ -import asyncio import base64 import json import os @@ -242,7 +241,7 @@ class SendMessageToUserTool(FunctionTool[AstrAgentContext]): bool: indicates whether the file was downloaded from sandbox. """ - if await asyncio.to_thread(os.path.exists, path): + if os.path.exists(path): return path, False # Try to check if the file exists in the sandbox diff --git a/astrbot/core/backup/exporter.py b/astrbot/core/backup/exporter.py index 5658bf23a..a92237599 100644 --- a/astrbot/core/backup/exporter.py +++ b/astrbot/core/backup/exporter.py @@ -4,12 +4,11 @@ 导出格式为 JSON,这是数据库无关的方案,支持未来向 MySQL/PostgreSQL 迁移。 """ -import asyncio import hashlib import json import os import zipfile -from datetime import UTC, datetime +from datetime import datetime, timezone from pathlib import Path from typing import TYPE_CHECKING, Any @@ -84,7 +83,7 @@ class AstrBotExporter: output_dir = get_astrbot_backups_path() # 确保输出目录存在 - await asyncio.to_thread(Path(output_dir).mkdir, parents=True, exist_ok=True) + Path(output_dir).mkdir(parents=True, exist_ok=True) timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") zip_filename = f"astrbot_backup_{timestamp}.zip" @@ -161,10 +160,9 @@ class AstrBotExporter: # 3. 导出配置文件 if progress_callback: await progress_callback("config", 0, 100, "正在导出配置文件...") - config_content = await asyncio.to_thread( - self._read_text_if_exists, self.config_path - ) - if config_content is not None: + if os.path.exists(self.config_path): + with open(self.config_path, encoding="utf-8") as f: + config_content = f.read() zf.writestr("config/cmd_config.json", config_content) self._add_checksum("config/cmd_config.json", config_content) if progress_callback: @@ -201,7 +199,7 @@ class AstrBotExporter: except Exception as e: logger.error(f"备份导出失败: {e}") # 清理失败的文件 - if await asyncio.to_thread(os.path.exists, zip_path): + if os.path.exists(zip_path): os.remove(zip_path) raise @@ -319,7 +317,7 @@ class AstrBotExporter: for dir_name, dir_path in backup_directories.items(): full_path = Path(dir_path) - if not await asyncio.to_thread(full_path.exists): + if not full_path.exists(): logger.debug(f"目录不存在,跳过: {full_path}") continue @@ -361,44 +359,17 @@ class AstrBotExporter: self, zf: zipfile.ZipFile, attachments: list[dict] ) -> None: """导出附件文件""" - await asyncio.to_thread(self._export_attachments_sync, zf, attachments) - - def _export_attachments_sync( - self, zf: zipfile.ZipFile, attachments: list[dict] - ) -> None: - """在单个线程中批量导出附件,减少高频线程切换。""" for attachment in attachments: - file_path = attachment.get("path", "") - attachment_id = attachment.get("attachment_id") try: - if not file_path: - continue - if not attachment_id: - logger.warning( - f"跳过附件导出:attachment_id 为空 (path={file_path})" - ) - continue - # 使用 attachment_id 作为文件名 - ext = os.path.splitext(file_path)[1] - archive_path = f"files/attachments/{attachment_id}{ext}" - zf.write(file_path, archive_path) - except FileNotFoundError: - # 和旧逻辑保持一致:缺失文件直接跳过。 - continue - except OSError as e: - logger.warning( - f"导出附件失败 (path={file_path}, attachment_id={attachment_id or 'unknown'}): {e}" - ) + file_path = attachment.get("path", "") + if file_path and os.path.exists(file_path): + # 使用 attachment_id 作为文件名 + attachment_id = attachment.get("attachment_id", "") + ext = os.path.splitext(file_path)[1] + archive_path = f"files/attachments/{attachment_id}{ext}" + zf.write(file_path, archive_path) except Exception as e: - logger.warning( - f"导出附件时发生非预期错误,已跳过 (path={file_path}, attachment_id={attachment_id or 'unknown'}): {e}" - ) - - def _read_text_if_exists(self, file_path: str) -> str | None: - """Read text file when it exists in a single synchronous call.""" - if not os.path.exists(file_path): - return None - return Path(file_path).read_text(encoding="utf-8") + logger.warning(f"导出附件失败: {e}") def _model_to_dict(self, record: Any) -> dict: """将 SQLModel 实例转换为字典 @@ -475,7 +446,7 @@ class AstrBotExporter: manifest = { "version": BACKUP_MANIFEST_VERSION, "astrbot_version": VERSION, - "exported_at": datetime.now(UTC).isoformat(), + "exported_at": datetime.now(timezone.utc).isoformat(), "origin": "exported", # 标记备份来源:exported=本实例导出, uploaded=用户上传 "schema_version": { "main_db": "v4", diff --git a/astrbot/core/backup/importer.py b/astrbot/core/backup/importer.py index 5362ab3cb..b51c7d956 100644 --- a/astrbot/core/backup/importer.py +++ b/astrbot/core/backup/importer.py @@ -7,13 +7,12 @@ - 版本匹配时也需要用户确认 """ -import asyncio import json import os import shutil import zipfile from dataclasses import dataclass, field -from datetime import UTC, datetime +from datetime import datetime, timezone from pathlib import Path from typing import TYPE_CHECKING, Any @@ -365,7 +364,7 @@ class AstrBotImporter: """ result = ImportResult() - if not await asyncio.to_thread(os.path.exists, zip_path): + if not os.path.exists(zip_path): result.add_error(f"备份文件不存在: {zip_path}") return result @@ -447,13 +446,12 @@ class AstrBotImporter: try: config_content = zf.read("config/cmd_config.json") # 备份现有配置 - if await asyncio.to_thread(os.path.exists, self.config_path): + if os.path.exists(self.config_path): backup_path = f"{self.config_path}.bak" shutil.copy2(self.config_path, backup_path) - await asyncio.to_thread( - Path(self.config_path).write_bytes, config_content - ) + with open(self.config_path, "wb") as f: + f.write(config_content) result.imported_files["config"] = 1 except Exception as e: result.add_warning(f"导入配置文件失败: {e}") @@ -677,9 +675,9 @@ class AstrBotImporter: if isinstance(value, datetime): dt = value if dt.tzinfo is None: - dt = dt.replace(tzinfo=UTC) + dt = dt.replace(tzinfo=timezone.utc) else: - dt = dt.astimezone(UTC) + dt = dt.astimezone(timezone.utc) return dt.isoformat() if isinstance(value, str): timestamp = value.strip() @@ -690,9 +688,9 @@ class AstrBotImporter: try: dt = datetime.fromisoformat(timestamp) if dt.tzinfo is None: - dt = dt.replace(tzinfo=UTC) + dt = dt.replace(tzinfo=timezone.utc) else: - dt = dt.astimezone(UTC) + dt = dt.astimezone(timezone.utc) return dt.isoformat() except ValueError: return None @@ -755,8 +753,8 @@ class AstrBotImporter: if faiss_path in zf.namelist(): try: target_path = kb_dir / "index.faiss" - with zf.open(faiss_path) as src: - await asyncio.to_thread(target_path.write_bytes, src.read()) + with zf.open(faiss_path) as src, open(target_path, "wb") as dst: + dst.write(src.read()) except Exception as e: result.add_warning(f"导入知识库 {kb_id} 的 FAISS 索引失败: {e}") @@ -768,8 +766,8 @@ class AstrBotImporter: rel_path = name[len(media_prefix) :] target_path = kb_dir / rel_path target_path.parent.mkdir(parents=True, exist_ok=True) - with zf.open(name) as src: - await asyncio.to_thread(target_path.write_bytes, src.read()) + with zf.open(name) as src, open(target_path, "wb") as dst: + dst.write(src.read()) except Exception as e: result.add_warning(f"导入媒体文件 {name} 失败: {e}") @@ -830,8 +828,8 @@ class AstrBotImporter: target_path = attachments_dir / os.path.basename(name) target_path.parent.mkdir(parents=True, exist_ok=True) - with zf.open(name) as src: - await asyncio.to_thread(target_path.write_bytes, src.read()) + with zf.open(name) as src, open(target_path, "wb") as dst: + dst.write(src.read()) count += 1 except Exception as e: logger.warning(f"导入附件 {name} 失败: {e}") @@ -887,15 +885,15 @@ class AstrBotImporter: continue # 备份现有目录(如果存在) - if await asyncio.to_thread(target_dir.exists): + if target_dir.exists(): backup_path = Path(f"{target_dir}.bak") - if await asyncio.to_thread(backup_path.exists): + if backup_path.exists(): shutil.rmtree(backup_path) shutil.move(str(target_dir), str(backup_path)) logger.debug(f"已备份现有目录 {target_dir} 到 {backup_path}") # 创建目标目录 - await asyncio.to_thread(target_dir.mkdir, parents=True, exist_ok=True) + target_dir.mkdir(parents=True, exist_ok=True) # 解压文件 for name in dir_files: @@ -908,8 +906,8 @@ class AstrBotImporter: target_path = target_dir / rel_path target_path.parent.mkdir(parents=True, exist_ok=True) - with zf.open(name) as src: - await asyncio.to_thread(target_path.write_bytes, src.read()) + with zf.open(name) as src, open(target_path, "wb") as dst: + dst.write(src.read()) file_count += 1 except Exception as e: result.add_warning(f"导入文件 {name} 失败: {e}") diff --git a/astrbot/core/computer/booters/bay_manager.py b/astrbot/core/computer/booters/bay_manager.py index da429eb65..24fa379e8 100644 --- a/astrbot/core/computer/booters/bay_manager.py +++ b/astrbot/core/computer/booters/bay_manager.py @@ -118,10 +118,10 @@ class BayContainerManager: return f"http://127.0.0.1:{self._host_port}" - async def wait_healthy(self, timeout_seconds: int = HEALTH_TIMEOUT_S) -> None: + async def wait_healthy(self, timeout: int = HEALTH_TIMEOUT_S) -> None: """Block until Bay's ``/health`` endpoint returns 200.""" url = f"http://127.0.0.1:{self._host_port}/health" - deadline = asyncio.get_event_loop().time() + timeout_seconds + deadline = asyncio.get_event_loop().time() + timeout last_error: str = "" async with aiohttp.ClientSession() as session: @@ -140,7 +140,7 @@ class BayContainerManager: await asyncio.sleep(HEALTH_POLL_INTERVAL_S) raise TimeoutError( - f"Bay did not become healthy within {timeout_seconds}s (last error: {last_error})" + f"Bay did not become healthy within {timeout}s (last error: {last_error})" ) async def read_credentials(self) -> str: diff --git a/astrbot/core/computer/booters/boxlite.py b/astrbot/core/computer/booters/boxlite.py index 337f5a68e..70064fdd4 100644 --- a/astrbot/core/computer/booters/boxlite.py +++ b/astrbot/core/computer/booters/boxlite.py @@ -1,6 +1,5 @@ import asyncio import random -from pathlib import Path from typing import Any import aiohttp @@ -47,7 +46,8 @@ class MockShipyardSandboxClient: try: # Read file content - file_content = await asyncio.to_thread(Path(path).read_bytes) + with open(path, "rb") as f: + file_content = f.read() # Create multipart form data data = aiohttp.FormData() @@ -88,7 +88,7 @@ class MockShipyardSandboxClient: "error": f"Connection error: {str(e)}", "message": "File upload failed", } - except TimeoutError: + except asyncio.TimeoutError: return { "success": False, "error": "File upload timeout", diff --git a/astrbot/core/computer/booters/local.py b/astrbot/core/computer/booters/local.py index 011ac45f4..a80ef0da2 100644 --- a/astrbot/core/computer/booters/local.py +++ b/astrbot/core/computer/booters/local.py @@ -59,7 +59,7 @@ class LocalShellComponent(ShellComponent): command: str, cwd: str | None = None, env: dict[str, str] | None = None, - timeout_seconds: int | None = 30, + timeout: int | None = 30, shell: bool = True, background: bool = False, ) -> dict[str, Any]: @@ -87,7 +87,7 @@ class LocalShellComponent(ShellComponent): shell=shell, cwd=working_dir, env=run_env, - timeout=timeout_seconds, + timeout=timeout, capture_output=True, text=True, ) @@ -106,14 +106,14 @@ class LocalPythonComponent(PythonComponent): self, code: str, kernel_id: str | None = None, - timeout_seconds: int = 30, + timeout: int = 30, silent: bool = False, ) -> dict[str, Any]: def _run() -> dict[str, Any]: try: result = subprocess.run( [os.environ.get("PYTHON", sys.executable), "-c", code], - timeout=timeout_seconds, + timeout=timeout, capture_output=True, text=True, ) diff --git a/astrbot/core/computer/booters/shipyard_neo.py b/astrbot/core/computer/booters/shipyard_neo.py index 6c6f62bb5..6304696ad 100644 --- a/astrbot/core/computer/booters/shipyard_neo.py +++ b/astrbot/core/computer/booters/shipyard_neo.py @@ -1,9 +1,7 @@ from __future__ import annotations -import asyncio import os import shlex -from pathlib import Path from typing import Any, cast from astrbot.api import logger @@ -35,11 +33,11 @@ class NeoPythonComponent(PythonComponent): self, code: str, kernel_id: str | None = None, - timeout_seconds: int = 30, + timeout: int = 30, silent: bool = False, ) -> dict[str, Any]: _ = kernel_id # Bay runtime does not expose kernel_id in current SDK. - result = await self._sandbox.python.exec(code, timeout=timeout_seconds) + result = await self._sandbox.python.exec(code, timeout=timeout) payload = _maybe_model_dump(result) output_text = payload.get("output", "") or "" @@ -77,7 +75,7 @@ class NeoShellComponent(ShellComponent): command: str, cwd: str | None = None, env: dict[str, str] | None = None, - timeout_seconds: int | None = 30, + timeout: int | None = 30, shell: bool = True, background: bool = False, ) -> dict[str, Any]: @@ -101,7 +99,7 @@ class NeoShellComponent(ShellComponent): result = await self._sandbox.shell.exec( run_command, - timeout=timeout_seconds or 30, + timeout=timeout or 30, cwd=cwd, ) payload = _maybe_model_dump(result) @@ -194,7 +192,7 @@ class NeoBrowserComponent(BrowserComponent): async def exec( self, cmd: str, - timeout_seconds: int = 30, + timeout: int = 30, description: str | None = None, tags: str | None = None, learn: bool = False, @@ -202,7 +200,7 @@ class NeoBrowserComponent(BrowserComponent): ) -> dict[str, Any]: result = await self._sandbox.browser.exec( cmd, - timeout=timeout_seconds, + timeout=timeout, description=description, tags=tags, learn=learn, @@ -213,7 +211,7 @@ class NeoBrowserComponent(BrowserComponent): async def exec_batch( self, commands: list[str], - timeout_seconds: int = 60, + timeout: int = 60, stop_on_error: bool = True, description: str | None = None, tags: str | None = None, @@ -222,7 +220,7 @@ class NeoBrowserComponent(BrowserComponent): ) -> dict[str, Any]: result = await self._sandbox.browser.exec_batch( commands, - timeout=timeout_seconds, + timeout=timeout, stop_on_error=stop_on_error, description=description, tags=tags, @@ -234,7 +232,7 @@ class NeoBrowserComponent(BrowserComponent): async def run_skill( self, skill_key: str, - timeout_seconds: int = 60, + timeout: int = 60, stop_on_error: bool = True, include_trace: bool = False, description: str | None = None, @@ -242,7 +240,7 @@ class NeoBrowserComponent(BrowserComponent): ) -> dict[str, Any]: result = await self._sandbox.browser.run_skill( skill_key=skill_key, - timeout=timeout_seconds, + timeout=timeout, stop_on_error=stop_on_error, include_trace=include_trace, description=description, @@ -470,7 +468,8 @@ class ShipyardNeoBooter(ComputerBooter): async def upload_file(self, path: str, file_name: str) -> dict: if self._sandbox is None: raise RuntimeError("ShipyardNeoBooter is not initialized.") - content = await asyncio.to_thread(Path(path).read_bytes) + with open(path, "rb") as f: + content = f.read() remote_path = file_name.lstrip("/") await self._sandbox.filesystem.upload(remote_path, content) logger.info("[Computer] File uploaded to Neo sandbox: %s", remote_path) @@ -487,7 +486,8 @@ class ShipyardNeoBooter(ComputerBooter): local_dir = os.path.dirname(local_path) if local_dir: os.makedirs(local_dir, exist_ok=True) - await asyncio.to_thread(Path(local_path).write_bytes, cast(bytes, content)) + with open(local_path, "wb") as f: + f.write(cast(bytes, content)) logger.info( "[Computer] File downloaded from Neo sandbox: %s -> %s", remote_path, diff --git a/astrbot/core/computer/computer_client.py b/astrbot/core/computer/computer_client.py index 1adaeae08..aa10d125e 100644 --- a/astrbot/core/computer/computer_client.py +++ b/astrbot/core/computer/computer_client.py @@ -1,4 +1,3 @@ -import asyncio import json import os import shutil @@ -373,12 +372,12 @@ async def _sync_skills_to_sandbox(booter: ComputerBooter) -> None: splitting into `apply` and `scan` phases. """ skills_root = Path(get_astrbot_skills_path()) - if not await asyncio.to_thread(skills_root.is_dir): + if not skills_root.is_dir(): return local_skill_dirs = _list_local_skill_dirs(skills_root) temp_dir = Path(get_astrbot_temp_path()) - await asyncio.to_thread(temp_dir.mkdir, parents=True, exist_ok=True) + temp_dir.mkdir(parents=True, exist_ok=True) zip_base = temp_dir / "skills_bundle" zip_path = zip_base.with_suffix(".zip") diff --git a/astrbot/core/computer/olayer/browser.py b/astrbot/core/computer/olayer/browser.py index 5bc40a446..aa69f4501 100644 --- a/astrbot/core/computer/olayer/browser.py +++ b/astrbot/core/computer/olayer/browser.py @@ -11,7 +11,7 @@ class BrowserComponent(Protocol): async def exec( self, cmd: str, - timeout_seconds: int = 30, + timeout: int = 30, description: str | None = None, tags: str | None = None, learn: bool = False, @@ -23,7 +23,7 @@ class BrowserComponent(Protocol): async def exec_batch( self, commands: list[str], - timeout_seconds: int = 60, + timeout: int = 60, stop_on_error: bool = True, description: str | None = None, tags: str | None = None, @@ -36,7 +36,7 @@ class BrowserComponent(Protocol): async def run_skill( self, skill_key: str, - timeout_seconds: int = 60, + timeout: int = 60, stop_on_error: bool = True, include_trace: bool = False, description: str | None = None, diff --git a/astrbot/core/computer/olayer/python.py b/astrbot/core/computer/olayer/python.py index 09bf497db..625504146 100644 --- a/astrbot/core/computer/olayer/python.py +++ b/astrbot/core/computer/olayer/python.py @@ -12,7 +12,7 @@ class PythonComponent(Protocol): self, code: str, kernel_id: str | None = None, - timeout_seconds: int = 30, + timeout: int = 30, silent: bool = False, ) -> dict[str, Any]: """Execute Python code""" diff --git a/astrbot/core/computer/olayer/shell.py b/astrbot/core/computer/olayer/shell.py index 67d9f95ef..df2263b65 100644 --- a/astrbot/core/computer/olayer/shell.py +++ b/astrbot/core/computer/olayer/shell.py @@ -13,7 +13,7 @@ class ShellComponent(Protocol): command: str, cwd: str | None = None, env: dict[str, str] | None = None, - timeout_seconds: int | None = 30, + timeout: int | None = 30, shell: bool = True, background: bool = False, ) -> dict[str, Any]: diff --git a/astrbot/core/computer/tools/browser.py b/astrbot/core/computer/tools/browser.py index 80a9be11a..70061ac31 100644 --- a/astrbot/core/computer/tools/browser.py +++ b/astrbot/core/computer/tools/browser.py @@ -71,23 +71,19 @@ class BrowserExecTool(FunctionTool): self, context: ContextWrapper[AstrAgentContext], cmd: str, - timeout_seconds: int = 30, + timeout: int = 30, description: str | None = None, tags: str | None = None, learn: bool = False, include_trace: bool = False, - **kwargs: Any, ) -> ToolExecResult: - legacy_timeout = kwargs.pop("timeout", None) - if legacy_timeout is not None: - timeout_seconds = int(legacy_timeout) if err := _ensure_admin(context): return err try: browser = await _get_browser_component(context) result = await browser.exec( cmd=cmd, - timeout_seconds=timeout_seconds, + timeout=timeout, description=description, tags=tags, learn=learn, @@ -137,24 +133,20 @@ class BrowserBatchExecTool(FunctionTool): self, context: ContextWrapper[AstrAgentContext], commands: list[str], - timeout_seconds: int = 60, + timeout: int = 60, stop_on_error: bool = True, description: str | None = None, tags: str | None = None, learn: bool = False, include_trace: bool = False, - **kwargs: Any, ) -> ToolExecResult: - legacy_timeout = kwargs.pop("timeout", None) - if legacy_timeout is not None: - timeout_seconds = int(legacy_timeout) if err := _ensure_admin(context): return err try: browser = await _get_browser_component(context) result = await browser.exec_batch( commands=commands, - timeout_seconds=timeout_seconds, + timeout=timeout, stop_on_error=stop_on_error, description=description, tags=tags, @@ -189,23 +181,19 @@ class RunBrowserSkillTool(FunctionTool): self, context: ContextWrapper[AstrAgentContext], skill_key: str, - timeout_seconds: int = 60, + timeout: int = 60, stop_on_error: bool = True, include_trace: bool = False, description: str | None = None, tags: str | None = None, - **kwargs: Any, ) -> ToolExecResult: - legacy_timeout = kwargs.pop("timeout", None) - if legacy_timeout is not None: - timeout_seconds = int(legacy_timeout) if err := _ensure_admin(context): return err try: browser = await _get_browser_component(context) result = await browser.run_skill( skill_key=skill_key, - timeout_seconds=timeout_seconds, + timeout=timeout, stop_on_error=stop_on_error, include_trace=include_trace, description=description, diff --git a/astrbot/core/computer/tools/fs.py b/astrbot/core/computer/tools/fs.py index d50025f4d..31b7f3f51 100644 --- a/astrbot/core/computer/tools/fs.py +++ b/astrbot/core/computer/tools/fs.py @@ -1,4 +1,3 @@ -import asyncio import os import uuid from dataclasses import dataclass, field @@ -112,10 +111,10 @@ class FileUploadTool(FunctionTool): ) try: # Check if file exists - if not await asyncio.to_thread(os.path.exists, local_path): + if not os.path.exists(local_path): return f"Error: File does not exist: {local_path}" - if not await asyncio.to_thread(os.path.isfile, local_path): + if not os.path.isfile(local_path): return f"Error: Path is not a file: {local_path}" # Use basename if sandbox_filename is not provided diff --git a/astrbot/core/cron/manager.py b/astrbot/core/cron/manager.py index 211514f7a..d12878be3 100644 --- a/astrbot/core/cron/manager.py +++ b/astrbot/core/cron/manager.py @@ -1,7 +1,7 @@ import asyncio import json from collections.abc import Awaitable, Callable -from datetime import UTC, datetime +from datetime import datetime, timezone from typing import TYPE_CHECKING, Any from zoneinfo import ZoneInfo @@ -192,7 +192,7 @@ class CronJobManager: job = await self.db.get_cron_job(job_id) if not job or not job.enabled: return - start_time = datetime.now(UTC) + start_time = datetime.now(timezone.utc) await self.db.update_cron_job( job_id, status="running", last_run_at=start_time, last_error=None ) diff --git a/astrbot/core/db/migration/helper.py b/astrbot/core/db/migration/helper.py index 47ecadf04..d7bca3067 100644 --- a/astrbot/core/db/migration/helper.py +++ b/astrbot/core/db/migration/helper.py @@ -1,4 +1,3 @@ -import asyncio import os from astrbot.api import logger, sp @@ -23,7 +22,7 @@ async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool: data_dir = get_astrbot_data_path() data_v3_db = os.path.join(data_dir, "data_v3.db") - if not await asyncio.to_thread(os.path.exists, data_v3_db): + if not os.path.exists(data_v3_db): return False migration_done = await db_helper.get_preference( "global", diff --git a/astrbot/core/db/migration/migra_3_to_4.py b/astrbot/core/db/migration/migra_3_to_4.py index d7a57a6d6..727d97b29 100644 --- a/astrbot/core/db/migration/migra_3_to_4.py +++ b/astrbot/core/db/migration/migra_3_to_4.py @@ -106,8 +106,8 @@ async def migration_platform_table( db_path=DB_PATH.replace("data_v4.db", "data_v3.db"), ) secs_from_2023_4_10_to_now = ( - datetime.datetime.now(datetime.UTC) - - datetime.datetime(2023, 4, 10, tzinfo=datetime.UTC) + datetime.datetime.now(datetime.timezone.utc) + - datetime.datetime(2023, 4, 10, tzinfo=datetime.timezone.utc) ).total_seconds() offset_sec = int(secs_from_2023_4_10_to_now) logger.info(f"迁移旧平台数据,offset_sec: {offset_sec} 秒。") @@ -162,7 +162,7 @@ async def migration_platform_table( { "timestamp": datetime.datetime.fromtimestamp( bucket_end, - tz=datetime.UTC, + tz=datetime.timezone.utc, ), "platform_id": platform_id, "platform_type": platform_type, diff --git a/astrbot/core/db/po.py b/astrbot/core/db/po.py index 1b8179f07..451f054f6 100644 --- a/astrbot/core/db/po.py +++ b/astrbot/core/db/po.py @@ -1,16 +1,16 @@ import uuid from dataclasses import dataclass, field -from datetime import UTC, datetime +from datetime import datetime, timezone from typing import TypedDict from sqlmodel import JSON, Field, SQLModel, Text, UniqueConstraint class TimestampMixin(SQLModel): - created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) updated_at: datetime = Field( - default_factory=lambda: datetime.now(UTC), - sa_column_kwargs={"onupdate": lambda: datetime.now(UTC)}, + default_factory=lambda: datetime.now(timezone.utc), + sa_column_kwargs={"onupdate": lambda: datetime.now(timezone.utc)}, ) diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index e356e85aa..f496e19d5 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -2,7 +2,7 @@ import asyncio import threading import typing as T from collections.abc import Awaitable, Callable -from datetime import UTC, datetime, timedelta +from datetime import datetime, timedelta, timezone from sqlalchemy import CursorResult, Row from sqlalchemy.ext.asyncio import AsyncSession @@ -633,7 +633,7 @@ class SQLiteDatabase(BaseDatabase): """Get an active API key by hash (not revoked, not expired).""" async with self.get_db() as session: session: AsyncSession - now = datetime.now(UTC) + now = datetime.now(timezone.utc) query = select(ApiKey).where( ApiKey.key_hash == key_hash, col(ApiKey.revoked_at).is_(None), @@ -650,7 +650,7 @@ class SQLiteDatabase(BaseDatabase): await session.execute( update(ApiKey) .where(col(ApiKey.key_id) == key_id) - .values(last_used_at=datetime.now(UTC)), + .values(last_used_at=datetime.now(timezone.utc)), ) async def revoke_api_key(self, key_id: str) -> bool: @@ -661,7 +661,7 @@ class SQLiteDatabase(BaseDatabase): query = ( update(ApiKey) .where(col(ApiKey.key_id) == key_id) - .values(revoked_at=datetime.now(UTC)) + .values(revoked_at=datetime.now(timezone.utc)) ) result = T.cast(CursorResult, await session.execute(query)) return result.rowcount > 0 @@ -1534,7 +1534,7 @@ class SQLiteDatabase(BaseDatabase): async with self.get_db() as session: session: AsyncSession async with session.begin(): - values: dict[str, T.Any] = {"updated_at": datetime.now(UTC)} + values: dict[str, T.Any] = {"updated_at": datetime.now(timezone.utc)} if display_name is not None: values["display_name"] = display_name @@ -1622,7 +1622,7 @@ class SQLiteDatabase(BaseDatabase): async with self.get_db() as session: session: AsyncSession async with session.begin(): - values: dict[str, T.Any] = {"updated_at": datetime.now(UTC)} + values: dict[str, T.Any] = {"updated_at": datetime.now(timezone.utc)} if title is not None: values["title"] = title if emoji is not None: diff --git a/astrbot/core/file_token_service.py b/astrbot/core/file_token_service.py index 5aa897e79..42fbd23df 100644 --- a/astrbot/core/file_token_service.py +++ b/astrbot/core/file_token_service.py @@ -28,17 +28,12 @@ class FileTokenService: await self._cleanup_expired_tokens() return file_token not in self.staged_files - async def register_file( - self, - file_path: str, - timeout_seconds: float | None = None, - **kwargs, - ) -> str: + async def register_file(self, file_path: str, timeout: float | None = None) -> str: """向令牌服务注册一个文件。 Args: file_path(str): 文件路径 - timeout_seconds(float): 超时时间,单位秒(可选) + timeout(float): 超时时间,单位秒(可选) Returns: str: 一个单次令牌 @@ -63,18 +58,15 @@ class FileTokenService: async with self.lock: await self._cleanup_expired_tokens() - legacy_timeout = kwargs.pop("timeout", None) - if legacy_timeout is not None: - timeout_seconds = float(legacy_timeout) - if not await asyncio.to_thread(os.path.exists, local_path): + if not os.path.exists(local_path): raise FileNotFoundError( f"文件不存在: {local_path} (原始输入: {file_path})", ) file_token = str(uuid.uuid4()) expire_time = time.time() + ( - timeout_seconds if timeout_seconds is not None else self.default_timeout + timeout if timeout is not None else self.default_timeout ) # 存储转换后的真实路径 self.staged_files[file_token] = (local_path, expire_time) @@ -101,6 +93,6 @@ class FileTokenService: raise KeyError(f"无效或过期的文件 token: {file_token}") file_path, _ = self.staged_files.pop(file_token) - if not await asyncio.to_thread(os.path.exists, file_path): + if not os.path.exists(file_path): raise FileNotFoundError(f"文件不存在: {file_path}") return file_path diff --git a/astrbot/core/knowledge_base/models.py b/astrbot/core/knowledge_base/models.py index 10277a926..da919a384 100644 --- a/astrbot/core/knowledge_base/models.py +++ b/astrbot/core/knowledge_base/models.py @@ -1,5 +1,5 @@ import uuid -from datetime import UTC, datetime +from datetime import datetime, timezone from sqlmodel import Field, MetaData, SQLModel, Text, UniqueConstraint @@ -40,10 +40,10 @@ class KnowledgeBase(BaseKBModel, table=True): top_k_dense: int | None = Field(default=50, nullable=True) top_k_sparse: int | None = Field(default=50, nullable=True) top_m_final: int | None = Field(default=5, nullable=True) - created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) updated_at: datetime = Field( - default_factory=lambda: datetime.now(UTC), - sa_column_kwargs={"onupdate": datetime.now(UTC)}, + default_factory=lambda: datetime.now(timezone.utc), + sa_column_kwargs={"onupdate": datetime.now(timezone.utc)}, ) doc_count: int = Field(default=0, nullable=False) chunk_count: int = Field(default=0, nullable=False) @@ -83,10 +83,10 @@ class KBDocument(BaseKBModel, table=True): file_path: str = Field(max_length=512, nullable=False) chunk_count: int = Field(default=0, nullable=False) media_count: int = Field(default=0, nullable=False) - created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) updated_at: datetime = Field( - default_factory=lambda: datetime.now(UTC), - sa_column_kwargs={"onupdate": datetime.now(UTC)}, + default_factory=lambda: datetime.now(timezone.utc), + sa_column_kwargs={"onupdate": datetime.now(timezone.utc)}, ) @@ -117,4 +117,4 @@ class KBMedia(BaseKBModel, table=True): file_path: str = Field(max_length=512, nullable=False) file_size: int = Field(nullable=False) mime_type: str = Field(max_length=100, nullable=False) - created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 901dcd2ff..15265c38d 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -27,8 +27,7 @@ import json import os import sys import uuid -from enum import StrEnum -from pathlib import Path +from enum import Enum if sys.version_info >= (3, 14): from pydantic import BaseModel @@ -40,17 +39,7 @@ from astrbot.core.utils.astrbot_path import get_astrbot_temp_path from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64 -def _absolute_path(path: str) -> str: - return os.path.abspath(path) - - -def _absolute_path_if_exists(path: str | None) -> str | None: - if not path or not os.path.exists(path): - return None - return os.path.abspath(path) - - -class ComponentType(StrEnum): +class ComponentType(str, Enum): # Basic Segment Types Plain = "Plain" # plain text message Image = "Image" # image @@ -169,18 +158,18 @@ class Record(BaseMessageComponent): return self.file[8:] if self.file.startswith("http"): file_path = await download_image_by_url(self.file) - return await asyncio.to_thread(_absolute_path, file_path) + return os.path.abspath(file_path) if self.file.startswith("base64://"): bs64_data = self.file.removeprefix("base64://") image_bytes = base64.b64decode(bs64_data) file_path = os.path.join( get_astrbot_temp_path(), f"recordseg_{uuid.uuid4()}.jpg" ) - await asyncio.to_thread(Path(file_path).write_bytes, image_bytes) - return await asyncio.to_thread(_absolute_path, file_path) - local_path = await asyncio.to_thread(_absolute_path_if_exists, self.file) - if local_path: - return local_path + with open(file_path, "wb") as f: + f.write(image_bytes) + return os.path.abspath(file_path) + if os.path.exists(self.file): + return os.path.abspath(self.file) raise Exception(f"not a valid file: {self.file}") async def convert_to_base64(self) -> str: @@ -194,17 +183,16 @@ class Record(BaseMessageComponent): if not self.file: raise Exception(f"not a valid file: {self.file}") if self.file.startswith("file:///"): - bs64_data = await file_to_base64(self.file[8:]) + bs64_data = file_to_base64(self.file[8:]) elif self.file.startswith("http"): file_path = await download_image_by_url(self.file) - bs64_data = await file_to_base64(file_path) + bs64_data = file_to_base64(file_path) elif self.file.startswith("base64://"): bs64_data = self.file + elif os.path.exists(self.file): + bs64_data = file_to_base64(self.file) else: - try: - bs64_data = await file_to_base64(self.file) - except OSError as exc: - raise Exception(f"not a valid file: {self.file}") from exc + raise Exception(f"not a valid file: {self.file}") bs64_data = bs64_data.removeprefix("base64://") return bs64_data @@ -268,15 +256,11 @@ class Video(BaseMessageComponent): get_astrbot_temp_path(), f"videoseg_{uuid.uuid4().hex}" ) await download_file(url, video_file_path) - local_path = await asyncio.to_thread( - _absolute_path_if_exists, video_file_path - ) - if local_path: - return local_path + if os.path.exists(video_file_path): + return os.path.abspath(video_file_path) raise Exception(f"download failed: {url}") - local_path = await asyncio.to_thread(_absolute_path_if_exists, url) - if local_path: - return local_path + if os.path.exists(url): + return os.path.abspath(url) raise Exception(f"not a valid file: {url}") async def register_to_file_service(self) -> str: @@ -465,18 +449,18 @@ class Image(BaseMessageComponent): return url[8:] if url.startswith("http"): image_file_path = await download_image_by_url(url) - return await asyncio.to_thread(_absolute_path, image_file_path) + return os.path.abspath(image_file_path) if url.startswith("base64://"): bs64_data = url.removeprefix("base64://") image_bytes = base64.b64decode(bs64_data) image_file_path = os.path.join( get_astrbot_temp_path(), f"imgseg_{uuid.uuid4()}.jpg" ) - await asyncio.to_thread(Path(image_file_path).write_bytes, image_bytes) - return await asyncio.to_thread(_absolute_path, image_file_path) - local_path = await asyncio.to_thread(_absolute_path_if_exists, url) - if local_path: - return local_path + with open(image_file_path, "wb") as f: + f.write(image_bytes) + return os.path.abspath(image_file_path) + if os.path.exists(url): + return os.path.abspath(url) raise Exception(f"not a valid file: {url}") async def convert_to_base64(self) -> str: @@ -491,17 +475,16 @@ class Image(BaseMessageComponent): if not url: raise ValueError("No valid file or URL provided") if url.startswith("file:///"): - bs64_data = await file_to_base64(url[8:]) + bs64_data = file_to_base64(url[8:]) elif url.startswith("http"): image_file_path = await download_image_by_url(url) - bs64_data = await file_to_base64(image_file_path) + bs64_data = file_to_base64(image_file_path) elif url.startswith("base64://"): bs64_data = url + elif os.path.exists(url): + bs64_data = file_to_base64(url) else: - try: - bs64_data = await file_to_base64(url) - except OSError as exc: - raise Exception(f"not a valid file: {url}") from exc + raise Exception(f"not a valid file: {url}") bs64_data = bs64_data.removeprefix("base64://") return bs64_data @@ -752,9 +735,8 @@ class File(BaseMessageComponent): ): path = path[1:] - local_path = await asyncio.to_thread(_absolute_path_if_exists, path) - if local_path: - return local_path + if os.path.exists(path): + return os.path.abspath(path) if self.url: await self._download_file() @@ -769,7 +751,7 @@ class File(BaseMessageComponent): and path[2] == ":" ): path = path[1:] - return await asyncio.to_thread(_absolute_path, path) + return os.path.abspath(path) return "" @@ -785,7 +767,7 @@ class File(BaseMessageComponent): filename = f"fileseg_{uuid.uuid4().hex}" file_path = os.path.join(download_dir, filename) await download_file(self.url, file_path) - self.file_ = await asyncio.to_thread(_absolute_path, file_path) + self.file_ = os.path.abspath(file_path) async def register_to_file_service(self) -> str: """将文件注册到文件服务。 diff --git a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py index e823aac9d..2d9b45cc1 100644 --- a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +++ b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py @@ -254,7 +254,7 @@ class DingtalkPlatformAdapter(Platform): "robotCode": robot_code, } temp_dir = Path(get_astrbot_temp_path()) - await asyncio.to_thread(temp_dir.mkdir, parents=True, exist_ok=True) + temp_dir.mkdir(parents=True, exist_ok=True) f_path = temp_dir / f"dingtalk_{uuid.uuid4()}.{ext}" async with ( aiohttp.ClientSession() as session, @@ -412,7 +412,7 @@ class DingtalkPlatformAdapter(Platform): form = aiohttp.FormData() form.add_field( "media", - await asyncio.to_thread(media_file_path.read_bytes), + media_file_path.read_bytes(), filename=media_file_path.name, content_type="application/octet-stream", ) diff --git a/astrbot/core/platform/sources/discord/client.py b/astrbot/core/platform/sources/discord/client.py index 36ee5710b..ebd32c471 100644 --- a/astrbot/core/platform/sources/discord/client.py +++ b/astrbot/core/platform/sources/discord/client.py @@ -1,10 +1,15 @@ +import sys from collections.abc import Awaitable, Callable -from typing import override import discord from astrbot import logger +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override + # Discord Bot客户端 class DiscordBotClient(discord.Bot): diff --git a/astrbot/core/platform/sources/discord/discord_platform_adapter.py b/astrbot/core/platform/sources/discord/discord_platform_adapter.py index 40be87a63..7657962a1 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_adapter.py +++ b/astrbot/core/platform/sources/discord/discord_platform_adapter.py @@ -1,6 +1,7 @@ import asyncio import re -from typing import Any, cast, override +import sys +from typing import Any, cast import discord from discord.abc import GuildChannel, Messageable, PrivateChannel @@ -26,6 +27,11 @@ from astrbot.core.star.star_handler import StarHandlerMetadata, star_handlers_re from .client import DiscordBotClient from .discord_platform_event import DiscordPlatformEvent +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override + # 注册平台适配器 @register_platform_adapter( diff --git a/astrbot/core/platform/sources/kook/kook_adapter.py b/astrbot/core/platform/sources/kook/kook_adapter.py index b7d047291..1124c6841 100644 --- a/astrbot/core/platform/sources/kook/kook_adapter.py +++ b/astrbot/core/platform/sources/kook/kook_adapter.py @@ -130,7 +130,7 @@ class KookPlatformAdapter(Platform): await asyncio.wait_for( self.client.wait_until_closed(), timeout=1.0 ) - except TimeoutError: + except asyncio.TimeoutError: # 正常超时,继续下一轮 while 检查 continue diff --git a/astrbot/core/platform/sources/kook/kook_client.py b/astrbot/core/platform/sources/kook/kook_client.py index 34078e2ac..9a452a9c3 100644 --- a/astrbot/core/platform/sources/kook/kook_client.py +++ b/astrbot/core/platform/sources/kook/kook_client.py @@ -171,7 +171,7 @@ class KookClient: # 处理不同类型的信令 await self._handle_signal(data) - except TimeoutError: + except asyncio.TimeoutError: # 超时检查,继续循环 continue except websockets.exceptions.ConnectionClosed: @@ -362,14 +362,12 @@ class KookClient: b64_str = file_url.removeprefix("base64://") bytes_data = base64.b64decode(b64_str) - elif file_url.startswith("file://") or await asyncio.to_thread( - os.path.exists, file_url - ): + elif file_url.startswith("file://") or os.path.exists(file_url): file_url = file_url.removeprefix("file:///") file_url = file_url.removeprefix("file://") try: - target_path = await asyncio.to_thread(Path(file_url).resolve) + target_path = Path(file_url).resolve() except Exception as exp: logger.error(f'[KOOK] 获取文件 "{file_url}" 绝对路径失败: "{exp}"') raise FileNotFoundError( diff --git a/astrbot/core/platform/sources/lark/lark_adapter.py b/astrbot/core/platform/sources/lark/lark_adapter.py index 6b500dc5c..be1c81c26 100644 --- a/astrbot/core/platform/sources/lark/lark_adapter.py +++ b/astrbot/core/platform/sources/lark/lark_adapter.py @@ -429,7 +429,7 @@ class LarkPlatformAdapter(Platform): suffix = Path(file_name).suffix if file_name else default_suffix temp_dir = Path(get_astrbot_temp_path()) - await asyncio.to_thread(temp_dir.mkdir, parents=True, exist_ok=True) + temp_dir.mkdir(parents=True, exist_ok=True) temp_path = ( temp_dir / f"lark_{message_type}_{file_name}_{uuid4().hex[:4]}{suffix}" ) diff --git a/astrbot/core/platform/sources/lark/lark_event.py b/astrbot/core/platform/sources/lark/lark_event.py index a513f4500..92e3a32b9 100644 --- a/astrbot/core/platform/sources/lark/lark_event.py +++ b/astrbot/core/platform/sources/lark/lark_event.py @@ -1,10 +1,8 @@ -import asyncio import base64 import json import os import uuid from io import BytesIO -from pathlib import Path import lark_oapi as lark from lark_oapi.api.im.v1 import ( @@ -138,7 +136,7 @@ class LarkMessageEvent(AstrMessageEvent): Returns: 成功返回file_key,失败返回None """ - if not path or not await asyncio.to_thread(os.path.exists, path): + if not path or not os.path.exists(path): logger.error(f"[Lark] 文件不存在: {path}") return None @@ -147,32 +145,36 @@ class LarkMessageEvent(AstrMessageEvent): return None try: - file_obj = BytesIO(await asyncio.to_thread(Path(path).read_bytes)) - body_builder = ( - CreateFileRequestBody.builder() - .file_type(file_type) - .file_name(os.path.basename(path)) - .file(file_obj) - ) - if duration is not None: - body_builder.duration(duration) + with open(path, "rb") as file_obj: + body_builder = ( + CreateFileRequestBody.builder() + .file_type(file_type) + .file_name(os.path.basename(path)) + .file(file_obj) + ) + if duration is not None: + body_builder.duration(duration) - request = ( - CreateFileRequest.builder().request_body(body_builder.build()).build() - ) - response = await lark_client.im.v1.file.acreate(request) + request = ( + CreateFileRequest.builder() + .request_body(body_builder.build()) + .build() + ) + response = await lark_client.im.v1.file.acreate(request) - if not response.success(): - logger.error(f"[Lark] 无法上传文件({response.code}): {response.msg}") - return None + if not response.success(): + logger.error( + f"[Lark] 无法上传文件({response.code}): {response.msg}" + ) + return None - if response.data is None: - logger.error("[Lark] 上传文件成功但未返回数据(data is None)") - return None + if response.data is None: + logger.error("[Lark] 上传文件成功但未返回数据(data is None)") + return None - file_key = response.data.file_key - logger.debug(f"[Lark] 文件上传成功: {file_key}") - return file_key + file_key = response.data.file_key + logger.debug(f"[Lark] 文件上传成功: {file_key}") + return file_key except Exception as e: logger.error(f"[Lark] 无法打开或上传文件: {e}") @@ -205,9 +207,8 @@ class LarkMessageEvent(AstrMessageEvent): temp_dir, f"lark_image_{uuid.uuid4().hex[:8]}.jpg", ) - await asyncio.to_thread( - Path(file_path).write_bytes, BytesIO(image_data).getvalue() - ) + with open(file_path, "wb") as f: + f.write(BytesIO(image_data).getvalue()) else: file_path = comp.file if comp.file else "" @@ -216,9 +217,7 @@ class LarkMessageEvent(AstrMessageEvent): logger.error("[Lark] 图片路径为空,无法上传") continue try: - image_file = BytesIO( - await asyncio.to_thread(Path(file_path).read_bytes) - ) + image_file = open(file_path, "rb") except Exception as e: logger.error(f"[Lark] 无法打开图片文件: {e}") continue @@ -413,9 +412,7 @@ class LarkMessageEvent(AstrMessageEvent): logger.error(f"[Lark] 无法获取音频文件路径: {e}") return - if not original_audio_path or not await asyncio.to_thread( - os.path.exists, original_audio_path - ): + if not original_audio_path or not os.path.exists(original_audio_path): logger.error(f"[Lark] 音频文件不存在: {original_audio_path}") return @@ -445,9 +442,7 @@ class LarkMessageEvent(AstrMessageEvent): ) # 清理转换后的临时音频文件 - if converted_audio_path and await asyncio.to_thread( - os.path.exists, converted_audio_path - ): + if converted_audio_path and os.path.exists(converted_audio_path): try: os.remove(converted_audio_path) logger.debug(f"[Lark] 已删除转换后的音频文件: {converted_audio_path}") @@ -490,9 +485,7 @@ class LarkMessageEvent(AstrMessageEvent): logger.error(f"[Lark] 无法获取视频文件路径: {e}") return - if not original_video_path or not await asyncio.to_thread( - os.path.exists, original_video_path - ): + if not original_video_path or not os.path.exists(original_video_path): logger.error(f"[Lark] 视频文件不存在: {original_video_path}") return @@ -522,9 +515,7 @@ class LarkMessageEvent(AstrMessageEvent): ) # 清理转换后的临时视频文件 - if converted_video_path and await asyncio.to_thread( - os.path.exists, converted_video_path - ): + if converted_video_path and os.path.exists(converted_video_path): try: os.remove(converted_video_path) logger.debug(f"[Lark] 已删除转换后的视频文件: {converted_video_path}") diff --git a/astrbot/core/platform/sources/line/line_event.py b/astrbot/core/platform/sources/line/line_event.py index a16bdd18f..8b82ad182 100644 --- a/astrbot/core/platform/sources/line/line_event.py +++ b/astrbot/core/platform/sources/line/line_event.py @@ -161,7 +161,7 @@ class LineMessageEvent(AstrMessageEvent): try: video_path = await segment.convert_to_file_path() temp_dir = Path(get_astrbot_temp_path()) - await asyncio.to_thread(temp_dir.mkdir, parents=True, exist_ok=True) + temp_dir.mkdir(parents=True, exist_ok=True) thumb_path = temp_dir / f"line_video_preview_{uuid.uuid4().hex}.jpg" process = await asyncio.create_subprocess_exec( @@ -201,8 +201,8 @@ class LineMessageEvent(AstrMessageEvent): async def _resolve_file_size(segment: File) -> int: try: file_path = await segment.get_file(allow_return_url=False) - if file_path and await asyncio.to_thread(os.path.exists, file_path): - return int(await asyncio.to_thread(os.path.getsize, file_path)) + if file_path and os.path.exists(file_path): + return int(os.path.getsize(file_path)) except Exception as e: logger.debug("[LINE] resolve file size failed: %s", e) return 0 diff --git a/astrbot/core/platform/sources/misskey/misskey_adapter.py b/astrbot/core/platform/sources/misskey/misskey_adapter.py index e1169decb..fd61c3e50 100644 --- a/astrbot/core/platform/sources/misskey/misskey_adapter.py +++ b/astrbot/core/platform/sources/misskey/misskey_adapter.py @@ -499,8 +499,7 @@ class MisskeyPlatformAdapter(Platform): # 清理临时文件 if local_path and isinstance(local_path, str): data_temp = get_astrbot_temp_path() - if local_path.startswith(data_temp) and await asyncio.to_thread( - os.path.exists, + if local_path.startswith(data_temp) and os.path.exists( local_path, ): try: diff --git a/astrbot/core/platform/sources/misskey/misskey_api.py b/astrbot/core/platform/sources/misskey/misskey_api.py index 64728a561..3e5eb9a90 100644 --- a/astrbot/core/platform/sources/misskey/misskey_api.py +++ b/astrbot/core/platform/sources/misskey/misskey_api.py @@ -3,7 +3,6 @@ import json import random import uuid from collections.abc import Awaitable, Callable -from pathlib import Path from typing import Any, NoReturn try: @@ -556,19 +555,22 @@ class MisskeyAPI: form.add_field("folderId", str(folder_id)) try: - file_bytes = await asyncio.to_thread(Path(file_path).read_bytes) + f = open(file_path, "rb") except FileNotFoundError as e: logger.error(f"[Misskey API] 本地文件不存在: {file_path}") raise APIError(f"File not found: {file_path}") from e - form.add_field("file", file_bytes, filename=filename) - async with self.session.post(url, data=form) as resp: - result = await self._process_response(resp, "drive/files/create") - file_id = FileIDExtractor.extract_file_id(result) - logger.debug( - f"[Misskey API] 本地文件上传成功: {filename} -> {file_id}", - ) - return {"id": file_id, "raw": result} + try: + form.add_field("file", f, filename=filename) + async with self.session.post(url, data=form) as resp: + result = await self._process_response(resp, "drive/files/create") + file_id = FileIDExtractor.extract_file_id(result) + logger.debug( + f"[Misskey API] 本地文件上传成功: {filename} -> {file_id}", + ) + return {"id": file_id, "raw": result} + finally: + f.close() except aiohttp.ClientError as e: logger.error(f"[Misskey API] 文件上传网络错误: {e}") raise APIConnectionError(f"Upload failed: {e}") from e diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py index 55050a821..868ec8a65 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py @@ -339,7 +339,7 @@ class QQOfficialMessageEvent(AstrMessageEvent): payload = {"file_type": file_type, "srv_send_msg": srv_send_msg} # 处理文件数据 - if await asyncio.to_thread(os.path.exists, file_source): + if os.path.exists(file_source): # 读取本地文件 async with aiofiles.open(file_source, "rb") as f: file_content = await f.read() @@ -421,15 +421,15 @@ class QQOfficialMessageEvent(AstrMessageEvent): plain_text += i.text elif isinstance(i, Image) and not image_base64: if i.file and i.file.startswith("file:///"): - image_base64 = await file_to_base64(i.file[8:]) + image_base64 = file_to_base64(i.file[8:]) image_file_path = i.file[8:] elif i.file and i.file.startswith("http"): image_file_path = await download_image_by_url(i.file) - image_base64 = await file_to_base64(image_file_path) + image_base64 = file_to_base64(image_file_path) elif i.file and i.file.startswith("base64://"): image_base64 = i.file elif i.file: - image_base64 = await file_to_base64(i.file) + image_base64 = file_to_base64(i.file) else: raise ValueError("Unsupported image file format") image_base64 = image_base64.removeprefix("base64://") diff --git a/astrbot/core/platform/sources/telegram/tg_adapter.py b/astrbot/core/platform/sources/telegram/tg_adapter.py index 76e3f9d98..2dd72bd0c 100644 --- a/astrbot/core/platform/sources/telegram/tg_adapter.py +++ b/astrbot/core/platform/sources/telegram/tg_adapter.py @@ -1,8 +1,9 @@ import asyncio import os import re +import sys import uuid -from typing import cast, override +from typing import cast from apscheduler.schedulers.asyncio import AsyncIOScheduler from telegram import BotCommand, Update @@ -32,6 +33,11 @@ from astrbot.core.utils.media_utils import convert_audio_to_wav from .tg_event import TelegramPlatformEvent +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override + @register_platform_adapter("telegram", "telegram 适配器") class TelegramPlatformAdapter(Platform): diff --git a/astrbot/core/platform/sources/webchat/message_parts_helper.py b/astrbot/core/platform/sources/webchat/message_parts_helper.py index 3a1371e72..43072ec1c 100644 --- a/astrbot/core/platform/sources/webchat/message_parts_helper.py +++ b/astrbot/core/platform/sources/webchat/message_parts_helper.py @@ -1,4 +1,3 @@ -import asyncio import json import mimetypes import shutil @@ -140,15 +139,13 @@ async def parse_webchat_message_parts( continue file_path = Path(str(path)) - if verify_media_path_exists and not await asyncio.to_thread(file_path.exists): + if verify_media_path_exists and not file_path.exists(): if strict: raise ValueError(f"file not found: {file_path!s}") continue file_path_str = ( - str(await asyncio.to_thread(file_path.resolve)) - if verify_media_path_exists - else str(file_path) + str(file_path.resolve()) if verify_media_path_exists else str(file_path) ) has_content = True if part_type == "image": @@ -369,7 +366,7 @@ async def message_chain_to_storage_message_parts( attachments_dir: str | Path, ) -> list[dict]: target_dir = Path(attachments_dir) - await asyncio.to_thread(target_dir.mkdir, parents=True, exist_ok=True) + target_dir.mkdir(parents=True, exist_ok=True) parts: list[dict] = [] for comp in message_chain.chain: @@ -445,9 +442,7 @@ async def _copy_file_to_attachment_part( display_name: str | None = None, ) -> dict | None: src_path = Path(file_path) - if not await asyncio.to_thread(src_path.exists) or not await asyncio.to_thread( - src_path.is_file - ): + if not src_path.exists() or not src_path.is_file(): return None suffix = src_path.suffix diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index aacb0e12d..b7da864aa 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -1,10 +1,8 @@ -import asyncio import base64 import json import os import shutil import uuid -from pathlib import Path from astrbot.api import logger from astrbot.api.event import AstrMessageEvent, MessageChain @@ -82,9 +80,8 @@ class WebChatMessageEvent(AstrMessageEvent): filename = f"{str(uuid.uuid4())}.jpg" path = os.path.join(attachments_dir, filename) image_base64 = await comp.convert_to_base64() - await asyncio.to_thread( - Path(path).write_bytes, base64.b64decode(image_base64) - ) + with open(path, "wb") as f: + f.write(base64.b64decode(image_base64)) data = f"[IMAGE]{filename}" await web_chat_back_queue.put( { @@ -99,9 +96,8 @@ class WebChatMessageEvent(AstrMessageEvent): filename = f"{str(uuid.uuid4())}.wav" path = os.path.join(attachments_dir, filename) record_base64 = await comp.convert_to_base64() - await asyncio.to_thread( - Path(path).write_bytes, base64.b64decode(record_base64) - ) + with open(path, "wb") as f: + f.write(base64.b64decode(record_base64)) data = f"[RECORD]{filename}" await web_chat_back_queue.put( { diff --git a/astrbot/core/platform/sources/wecom/wecom_adapter.py b/astrbot/core/platform/sources/wecom/wecom_adapter.py index b77a0da3e..6647db89f 100644 --- a/astrbot/core/platform/sources/wecom/wecom_adapter.py +++ b/astrbot/core/platform/sources/wecom/wecom_adapter.py @@ -1,9 +1,9 @@ import asyncio import os +import sys import uuid from collections.abc import Awaitable, Callable -from pathlib import Path -from typing import Any, cast, override +from typing import Any, cast import quart from requests import Response @@ -33,6 +33,11 @@ from .wecom_event import WecomPlatformEvent from .wecom_kf import WeChatKF from .wecom_kf_message import WeChatKFMessage +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override + class WecomServer: def __init__(self, event_queue: asyncio.Queue, config: dict) -> None: @@ -341,7 +346,8 @@ class WecomPlatformAdapter(Platform): ) temp_dir = get_astrbot_temp_path() path = os.path.join(temp_dir, f"wecom_{msg.media_id}.amr") - await asyncio.to_thread(Path(path).write_bytes, resp.content) + with open(path, "wb") as f: + f.write(resp.content) try: path_wav = os.path.join(temp_dir, f"wecom_{msg.media_id}.wav") @@ -396,7 +402,8 @@ class WecomPlatformAdapter(Platform): ) temp_dir = get_astrbot_temp_path() path = os.path.join(temp_dir, f"weixinkefu_{media_id}.jpg") - await asyncio.to_thread(Path(path).write_bytes, resp.content) + with open(path, "wb") as f: + f.write(resp.content) abm.message = [Image(file=path, url=path)] elif msgtype == "voice": media_id = msg.get("voice", {}).get("media_id", "") @@ -408,7 +415,8 @@ class WecomPlatformAdapter(Platform): temp_dir = get_astrbot_temp_path() path = os.path.join(temp_dir, f"weixinkefu_{media_id}.amr") - await asyncio.to_thread(Path(path).write_bytes, resp.content) + with open(path, "wb") as f: + f.write(resp.content) try: path_wav = os.path.join(temp_dir, f"weixinkefu_{media_id}.wav") diff --git a/astrbot/core/platform/sources/wecom/wecom_event.py b/astrbot/core/platform/sources/wecom/wecom_event.py index 83a91a872..7aee26e47 100644 --- a/astrbot/core/platform/sources/wecom/wecom_event.py +++ b/astrbot/core/platform/sources/wecom/wecom_event.py @@ -12,13 +12,6 @@ from astrbot.core.utils.media_utils import convert_audio_to_amr from .wecom_kf_message import WeChatKFMessage -def _upload_media_from_path( - client: WeChatClient, media_type: str, file_path: str -) -> dict: - with open(file_path, "rb") as f: - return client.media.upload(media_type, f) - - class WecomPlatformEvent(AstrMessageEvent): def __init__( self, @@ -107,52 +100,45 @@ class WecomPlatformEvent(AstrMessageEvent): elif isinstance(comp, Image): img_path = await comp.convert_to_file_path() - try: - response = await asyncio.to_thread( - _upload_media_from_path, - self.client, - "image", - img_path, + with open(img_path, "rb") as f: + try: + response = self.client.media.upload("image", f) + except Exception as e: + logger.error(f"微信客服上传图片失败: {e}") + await self.send( + MessageChain().message(f"微信客服上传图片失败: {e}"), + ) + return + logger.debug(f"微信客服上传图片返回: {response}") + kf_message_api.send_image( + user_id, + self.get_self_id(), + response["media_id"], ) - except Exception as e: - logger.error(f"微信客服上传图片失败: {e}") - await self.send( - MessageChain().message(f"微信客服上传图片失败: {e}"), - ) - return - logger.debug(f"微信客服上传图片返回: {response}") - kf_message_api.send_image( - user_id, - self.get_self_id(), - response["media_id"], - ) elif isinstance(comp, Record): record_path = await comp.convert_to_file_path() record_path_amr = await convert_audio_to_amr(record_path) try: - try: - response = await asyncio.to_thread( - _upload_media_from_path, - self.client, - "voice", - record_path_amr, + with open(record_path_amr, "rb") as f: + try: + response = self.client.media.upload("voice", f) + except Exception as e: + logger.error(f"微信客服上传语音失败: {e}") + await self.send( + MessageChain().message( + f"微信客服上传语音失败: {e}" + ), + ) + return + logger.info(f"微信客服上传语音返回: {response}") + kf_message_api.send_voice( + user_id, + self.get_self_id(), + response["media_id"], ) - except Exception as e: - logger.error(f"微信客服上传语音失败: {e}") - await self.send( - MessageChain().message(f"微信客服上传语音失败: {e}"), - ) - return - logger.info(f"微信客服上传语音返回: {response}") - kf_message_api.send_voice( - user_id, - self.get_self_id(), - response["media_id"], - ) finally: - if record_path_amr != record_path and await asyncio.to_thread( - os.path.exists, + if record_path_amr != record_path and os.path.exists( record_path_amr, ): try: @@ -162,47 +148,39 @@ class WecomPlatformEvent(AstrMessageEvent): elif isinstance(comp, File): file_path = await comp.get_file() - try: - response = await asyncio.to_thread( - _upload_media_from_path, - self.client, - "file", - file_path, + with open(file_path, "rb") as f: + try: + response = self.client.media.upload("file", f) + except Exception as e: + logger.error(f"微信客服上传文件失败: {e}") + await self.send( + MessageChain().message(f"微信客服上传文件失败: {e}"), + ) + return + logger.debug(f"微信客服上传文件返回: {response}") + kf_message_api.send_file( + user_id, + self.get_self_id(), + response["media_id"], ) - except Exception as e: - logger.error(f"微信客服上传文件失败: {e}") - await self.send( - MessageChain().message(f"微信客服上传文件失败: {e}"), - ) - return - logger.debug(f"微信客服上传文件返回: {response}") - kf_message_api.send_file( - user_id, - self.get_self_id(), - response["media_id"], - ) elif isinstance(comp, Video): video_path = await comp.convert_to_file_path() - try: - response = await asyncio.to_thread( - _upload_media_from_path, - self.client, - "video", - video_path, + with open(video_path, "rb") as f: + try: + response = self.client.media.upload("video", f) + except Exception as e: + logger.error(f"微信客服上传视频失败: {e}") + await self.send( + MessageChain().message(f"微信客服上传视频失败: {e}"), + ) + return + logger.debug(f"微信客服上传视频返回: {response}") + kf_message_api.send_video( + user_id, + self.get_self_id(), + response["media_id"], ) - except Exception as e: - logger.error(f"微信客服上传视频失败: {e}") - await self.send( - MessageChain().message(f"微信客服上传视频失败: {e}"), - ) - return - logger.debug(f"微信客服上传视频返回: {response}") - kf_message_api.send_video( - user_id, - self.get_self_id(), - response["media_id"], - ) else: logger.warning(f"还没实现这个消息类型的发送逻辑: {comp.type}。") else: @@ -221,52 +199,45 @@ class WecomPlatformEvent(AstrMessageEvent): elif isinstance(comp, Image): img_path = await comp.convert_to_file_path() - try: - response = await asyncio.to_thread( - _upload_media_from_path, - self.client, - "image", - img_path, + with open(img_path, "rb") as f: + try: + response = self.client.media.upload("image", f) + except Exception as e: + logger.error(f"企业微信上传图片失败: {e}") + await self.send( + MessageChain().message(f"企业微信上传图片失败: {e}"), + ) + return + logger.debug(f"企业微信上传图片返回: {response}") + self.client.message.send_image( + message_obj.self_id, + message_obj.session_id, + response["media_id"], ) - except Exception as e: - logger.error(f"企业微信上传图片失败: {e}") - await self.send( - MessageChain().message(f"企业微信上传图片失败: {e}"), - ) - return - logger.debug(f"企业微信上传图片返回: {response}") - self.client.message.send_image( - message_obj.self_id, - message_obj.session_id, - response["media_id"], - ) elif isinstance(comp, Record): record_path = await comp.convert_to_file_path() record_path_amr = await convert_audio_to_amr(record_path) try: - try: - response = await asyncio.to_thread( - _upload_media_from_path, - self.client, - "voice", - record_path_amr, + with open(record_path_amr, "rb") as f: + try: + response = self.client.media.upload("voice", f) + except Exception as e: + logger.error(f"企业微信上传语音失败: {e}") + await self.send( + MessageChain().message( + f"企业微信上传语音失败: {e}" + ), + ) + return + logger.info(f"企业微信上传语音返回: {response}") + self.client.message.send_voice( + message_obj.self_id, + message_obj.session_id, + response["media_id"], ) - except Exception as e: - logger.error(f"企业微信上传语音失败: {e}") - await self.send( - MessageChain().message(f"企业微信上传语音失败: {e}"), - ) - return - logger.info(f"企业微信上传语音返回: {response}") - self.client.message.send_voice( - message_obj.self_id, - message_obj.session_id, - response["media_id"], - ) finally: - if record_path_amr != record_path and await asyncio.to_thread( - os.path.exists, + if record_path_amr != record_path and os.path.exists( record_path_amr, ): try: @@ -276,47 +247,39 @@ class WecomPlatformEvent(AstrMessageEvent): elif isinstance(comp, File): file_path = await comp.get_file() - try: - response = await asyncio.to_thread( - _upload_media_from_path, - self.client, - "file", - file_path, + with open(file_path, "rb") as f: + try: + response = self.client.media.upload("file", f) + except Exception as e: + logger.error(f"企业微信上传文件失败: {e}") + await self.send( + MessageChain().message(f"企业微信上传文件失败: {e}"), + ) + return + logger.debug(f"企业微信上传文件返回: {response}") + self.client.message.send_file( + message_obj.self_id, + message_obj.session_id, + response["media_id"], ) - except Exception as e: - logger.error(f"企业微信上传文件失败: {e}") - await self.send( - MessageChain().message(f"企业微信上传文件失败: {e}"), - ) - return - logger.debug(f"企业微信上传文件返回: {response}") - self.client.message.send_file( - message_obj.self_id, - message_obj.session_id, - response["media_id"], - ) elif isinstance(comp, Video): video_path = await comp.convert_to_file_path() - try: - response = await asyncio.to_thread( - _upload_media_from_path, - self.client, - "video", - video_path, + with open(video_path, "rb") as f: + try: + response = self.client.media.upload("video", f) + except Exception as e: + logger.error(f"企业微信上传视频失败: {e}") + await self.send( + MessageChain().message(f"企业微信上传视频失败: {e}"), + ) + return + logger.debug(f"企业微信上传视频返回: {response}") + self.client.message.send_video( + message_obj.self_id, + message_obj.session_id, + response["media_id"], ) - except Exception as e: - logger.error(f"企业微信上传视频失败: {e}") - await self.send( - MessageChain().message(f"企业微信上传视频失败: {e}"), - ) - return - logger.debug(f"企业微信上传视频返回: {response}") - self.client.message.send_video( - message_obj.self_id, - message_obj.session_id, - response["media_id"], - ) else: logger.warning(f"还没实现这个消息类型的发送逻辑: {comp.type}。") diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py index 6dbfda7b4..f7cbe380d 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py @@ -2,6 +2,7 @@ 提供常量定义、工具函数和辅助方法 """ +import asyncio import base64 import hashlib import secrets @@ -173,7 +174,7 @@ async def process_encrypted_image( response.raise_for_status() encrypted_data = await response.read() logger.info("图片下载成功,大小: %d 字节", len(encrypted_data)) - except (TimeoutError, aiohttp.ClientError) as e: + except (aiohttp.ClientError, asyncio.TimeoutError) as e: error_msg = f"下载图片失败: {e!s}" logger.error(error_msg) return False, error_msg diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_webhook.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_webhook.py index c305411d4..6f42f264b 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_webhook.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_webhook.py @@ -2,7 +2,6 @@ from __future__ import annotations -import asyncio import base64 import hashlib import mimetypes @@ -104,9 +103,7 @@ class WecomAIBotWebhookClient: async def upload_media( self, file_path: Path, media_type: Literal["file", "voice"] ) -> str: - if not await asyncio.to_thread(file_path.exists) or not await asyncio.to_thread( - file_path.is_file - ): + if not file_path.exists() or not file_path.is_file(): raise WecomAIBotWebhookError(f"文件不存在: {file_path}") content_type = ( @@ -115,7 +112,7 @@ class WecomAIBotWebhookClient: form = aiohttp.FormData() form.add_field( "media", - await asyncio.to_thread(file_path.read_bytes), + file_path.read_bytes(), filename=file_path.name, content_type=content_type, ) diff --git a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py index 59f8ebd8c..c01355974 100644 --- a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +++ b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py @@ -1,10 +1,10 @@ import asyncio import os +import sys import time import uuid from collections.abc import Callable, Coroutine -from pathlib import Path -from typing import Any, cast, override +from typing import Any, cast import quart from requests import Response @@ -32,6 +32,11 @@ from astrbot.core.utils.webhook_utils import log_webhook_info from .weixin_offacc_event import WeixinOfficialAccountPlatformEvent +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override + class WeixinOfficialAccountServer: def __init__( @@ -374,7 +379,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform): ) # wait for 180s logger.debug(f"Got future result: {result}") return result - except TimeoutError: + except asyncio.TimeoutError: logger.info(f"callback 处理消息超时: message_id={msg.id}") return create_reply("处理消息超时,请稍后再试。", msg) except Exception as e: @@ -463,7 +468,8 @@ class WeixinOfficialAccountPlatformAdapter(Platform): ) temp_dir = get_astrbot_temp_path() path = os.path.join(temp_dir, f"weixin_offacc_{msg.media_id}.amr") - await asyncio.to_thread(Path(path).write_bytes, resp.content) + with open(path, "wb") as f: + f.write(resp.content) try: path_wav = os.path.join( diff --git a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py index 0797e4dee..ae536593c 100644 --- a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +++ b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py @@ -12,13 +12,6 @@ from astrbot.api.platform import AstrBotMessage, PlatformMetadata from astrbot.core.utils.media_utils import convert_audio_to_amr -def _upload_media_from_path( - client: WeChatClient, media_type: str, file_path: str -) -> dict: - with open(file_path, "rb") as f: - return client.media.upload(media_type, f) - - class WeixinOfficialAccountPlatformEvent(AstrMessageEvent): def __init__( self, @@ -108,63 +101,24 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent): elif isinstance(comp, Image): img_path = await comp.convert_to_file_path() - try: - response = await asyncio.to_thread( - _upload_media_from_path, - self.client, - "image", - img_path, - ) - except Exception as e: - logger.error(f"微信公众平台上传图片失败: {e}") - await self.send( - MessageChain().message(f"微信公众平台上传图片失败: {e}"), - ) - return - logger.debug(f"微信公众平台上传图片返回: {response}") - - if active_send_mode: - self.client.message.send_image( - message_obj.sender.user_id, - response["media_id"], - ) - else: - reply = ImageReply( - media_id=response["media_id"], - message=cast(dict, self.message_obj.raw_message)["message"], - ) - xml = reply.render() - future = cast(dict, self.message_obj.raw_message)["future"] - assert isinstance(future, asyncio.Future) - future.set_result(xml) - - elif isinstance(comp, Record): - record_path = await comp.convert_to_file_path() - record_path_amr = await convert_audio_to_amr(record_path) - - try: + with open(img_path, "rb") as f: try: - response = await asyncio.to_thread( - _upload_media_from_path, - self.client, - "voice", - record_path_amr, - ) + response = self.client.media.upload("image", f) except Exception as e: - logger.error(f"微信公众平台上传语音失败: {e}") + logger.error(f"微信公众平台上传图片失败: {e}") await self.send( - MessageChain().message(f"微信公众平台上传语音失败: {e}"), + MessageChain().message(f"微信公众平台上传图片失败: {e}"), ) return - logger.info(f"微信公众平台上传语音返回: {response}") + logger.debug(f"微信公众平台上传图片返回: {response}") if active_send_mode: - self.client.message.send_voice( + self.client.message.send_image( message_obj.sender.user_id, response["media_id"], ) else: - reply = VoiceReply( + reply = ImageReply( media_id=response["media_id"], message=cast(dict, self.message_obj.raw_message)["message"], ) @@ -172,9 +126,44 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent): future = cast(dict, self.message_obj.raw_message)["future"] assert isinstance(future, asyncio.Future) future.set_result(xml) + + elif isinstance(comp, Record): + record_path = await comp.convert_to_file_path() + record_path_amr = await convert_audio_to_amr(record_path) + + try: + with open(record_path_amr, "rb") as f: + try: + response = self.client.media.upload("voice", f) + except Exception as e: + logger.error(f"微信公众平台上传语音失败: {e}") + await self.send( + MessageChain().message( + f"微信公众平台上传语音失败: {e}" + ), + ) + return + logger.info(f"微信公众平台上传语音返回: {response}") + + if active_send_mode: + self.client.message.send_voice( + message_obj.sender.user_id, + response["media_id"], + ) + else: + reply = VoiceReply( + media_id=response["media_id"], + message=cast(dict, self.message_obj.raw_message)[ + "message" + ], + ) + xml = reply.render() + future = cast(dict, self.message_obj.raw_message)["future"] + assert isinstance(future, asyncio.Future) + future.set_result(xml) finally: - if record_path_amr != record_path and await asyncio.to_thread( - os.path.exists, record_path_amr + if record_path_amr != record_path and os.path.exists( + record_path_amr ): try: os.remove(record_path_amr) diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index aea04645d..20c5a7947 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -1,11 +1,9 @@ from __future__ import annotations -import asyncio import base64 import enum import json from dataclasses import dataclass, field -from pathlib import Path from typing import Any from anthropic.types import Message as AnthropicMessage @@ -220,10 +218,9 @@ class ProviderRequest: """将图片转换为 base64""" if image_url.startswith("base64://"): return image_url.replace("base64://", "data:image/jpeg;base64,") - image_bs64 = base64.b64encode( - await asyncio.to_thread(Path(image_url).read_bytes) - ).decode("utf-8") - return "data:image/jpeg;base64," + image_bs64 + with open(image_url, "rb") as f: + image_bs64 = base64.b64encode(f.read()).decode("utf-8") + return "data:image/jpeg;base64," + image_bs64 return "" diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 4239a8b47..2fd391fc9 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -8,7 +8,6 @@ import threading import urllib.parse from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping from dataclasses import dataclass -from pathlib import Path from types import MappingProxyType from typing import Any @@ -199,7 +198,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: return True, "" return False, f"HTTP {response.status}: {response.reason}" - except TimeoutError: + except asyncio.TimeoutError: return False, f"连接超时: {timeout}秒" except Exception as e: return False, f"{e!s}" @@ -378,24 +377,15 @@ class FunctionToolManager: data_dir = get_astrbot_data_path() mcp_json_file = os.path.join(data_dir, "mcp_server.json") - if not await asyncio.to_thread(os.path.exists, mcp_json_file): + if not os.path.exists(mcp_json_file): # 配置文件不存在错误处理 - config_text = json.dumps(DEFAULT_MCP_CONFIG, ensure_ascii=False, indent=4) - await asyncio.to_thread( - Path(mcp_json_file).write_text, - config_text, - encoding="utf-8", - ) + with open(mcp_json_file, "w", encoding="utf-8") as f: + json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4) logger.info(f"未找到 MCP 服务配置文件,已创建默认配置文件 {mcp_json_file}") return MCPInitSummary(total=0, success=0, failed=[]) - mcp_json_content = await asyncio.to_thread( - Path(mcp_json_file).read_text, - encoding="utf-8", - ) - mcp_server_json_obj: dict[str, dict] = json.loads(mcp_json_content)[ - "mcpServers" - ] + with open(mcp_json_file, encoding="utf-8") as f: + mcp_server_json_obj: dict[str, dict] = json.load(f)["mcpServers"] init_timeout_value = _resolve_timeout( timeout=init_timeout, @@ -469,7 +459,7 @@ class FunctionToolManager: cfg: dict, *, shutdown_event: asyncio.Event | None = None, - timeout_seconds: float, + timeout: float, ) -> None: """Initialize MCP server with timeout and register task/event together. @@ -479,7 +469,7 @@ class FunctionToolManager: async with self._runtime_lock: if name in self._mcp_server_runtime or name in self._mcp_starting: logger.warning( - f"MCP 服务 {name} 已在运行,忽略本次启用请求(timeout={timeout_seconds:g})。" + f"MCP 服务 {name} 已在运行,忽略本次启用请求(timeout={timeout:g})。" ) self._log_safe_mcp_debug_config(cfg) return @@ -492,11 +482,11 @@ class FunctionToolManager: try: mcp_client = await asyncio.wait_for( self._init_mcp_client(name, cfg), - timeout=timeout_seconds, + timeout=timeout, ) - except TimeoutError as exc: + except asyncio.TimeoutError as exc: raise MCPInitTimeoutError( - f"MCP 服务 {name} 初始化超时({timeout_seconds:g} 秒)" + f"MCP 服务 {name} 初始化超时({timeout:g} 秒)" ) from exc except Exception: logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True) @@ -529,7 +519,7 @@ class FunctionToolManager: async def _shutdown_runtimes( self, runtimes: list[_MCPServerRuntime], - timeout_seconds: float, + timeout: float, *, strict: bool = True, ) -> list[str]: @@ -548,9 +538,9 @@ class FunctionToolManager: try: results = await asyncio.wait_for( asyncio.gather(*lifecycle_tasks, return_exceptions=True), - timeout=timeout_seconds, + timeout=timeout, ) - except TimeoutError: + except asyncio.TimeoutError: pending_names = [ runtime.name for runtime in runtimes @@ -561,10 +551,10 @@ class FunctionToolManager: task.cancel() await asyncio.gather(*lifecycle_tasks, return_exceptions=True) if strict: - raise MCPShutdownTimeoutError(pending_names, timeout_seconds) + raise MCPShutdownTimeoutError(pending_names, timeout) logger.warning( "MCP 服务关闭超时(%s 秒),以下服务未完全关闭:%s", - f"{timeout_seconds:g}", + f"{timeout:g}", ", ".join(pending_names), ) return pending_names @@ -675,8 +665,7 @@ class FunctionToolManager: name: str, config: dict, shutdown_event: asyncio.Event | None = None, - timeout_seconds: float | int | str | None = None, - **kwargs: Any, + timeout: float | int | str | None = None, ) -> None: """Enable a new MCP server and initialize it. @@ -684,22 +673,18 @@ class FunctionToolManager: name: The name of the MCP server. config: Configuration for the MCP server. shutdown_event: Event to signal when the MCP client should shut down. - timeout_seconds: Timeout in seconds for initialization. + timeout: Timeout in seconds for initialization. Uses ASTRBOT_MCP_ENABLE_TIMEOUT by default (separate from init timeout). Raises: MCPInitTimeoutError: If initialization does not complete within timeout. Exception: If there is an error during initialization. """ - legacy_timeout = kwargs.pop("timeout", None) - if legacy_timeout is not None: - timeout_seconds = legacy_timeout - - if timeout_seconds is None: + if timeout is None: timeout_value = self._enable_timeout_default else: timeout_value = _resolve_timeout( - timeout=timeout_seconds, + timeout=timeout, env_name=ENABLE_MCP_TIMEOUT_ENV, default=self._enable_timeout_default, ) @@ -707,45 +692,36 @@ class FunctionToolManager: name=name, cfg=config, shutdown_event=shutdown_event, - timeout_seconds=timeout_value, + timeout=timeout_value, ) async def disable_mcp_server( self, name: str | None = None, - timeout_seconds: float = 10, - **kwargs: Any, + timeout: float = 10, ) -> None: """Disable an MCP server by its name. Args: name (str): The name of the MCP server to disable. If None, ALL MCP servers will be disabled. - timeout_seconds (int): Timeout. + timeout (int): Timeout. Raises: MCPShutdownTimeoutError: If shutdown does not complete within timeout. Only raised when disabling a specific server (name is not None). """ - legacy_timeout = kwargs.pop("timeout", None) - if legacy_timeout is not None: - timeout_seconds = float(legacy_timeout) - if name: async with self._runtime_lock: runtime = self._mcp_server_runtime.get(name) if runtime is None: return - await self._shutdown_runtimes( - [runtime], timeout_seconds=timeout_seconds, strict=True - ) + await self._shutdown_runtimes([runtime], timeout, strict=True) else: async with self._runtime_lock: runtimes = list(self._mcp_server_runtime.values()) - await self._shutdown_runtimes( - runtimes, timeout_seconds=timeout_seconds, strict=False - ) + await self._shutdown_runtimes(runtimes, timeout, strict=False) def _warn_on_timeout_mismatch( self, diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 08c525485..901efd005 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -2,8 +2,7 @@ import abc import asyncio import os from collections.abc import AsyncGenerator -from pathlib import Path -from typing import Any +from typing import TypeAlias, Union from astrbot.core.agent.message import ContentPart, Message from astrbot.core.agent.tool import ToolSet @@ -16,9 +15,13 @@ from astrbot.core.provider.entities import ( from astrbot.core.provider.register import provider_cls_map from astrbot.core.utils.astrbot_path import get_astrbot_path -type Providers = ( - "Provider" | "STTProvider" | "TTSProvider" | "EmbeddingProvider" | "RerankProvider" -) +Providers: TypeAlias = Union[ + "Provider", + "STTProvider", + "TTSProvider", + "EmbeddingProvider", + "RerankProvider", +] class AbstractProvider(abc.ABC): @@ -185,13 +188,10 @@ class Provider(AbstractProvider): return dicts - async def test(self, timeout_seconds: float = 45.0, **kwargs: Any) -> None: - legacy_timeout = kwargs.pop("timeout", None) - if legacy_timeout is not None: - timeout_seconds = float(legacy_timeout) + async def test(self, timeout: float = 45.0) -> None: await asyncio.wait_for( self.text_chat(prompt="REPLY `PONG` ONLY"), - timeout=timeout_seconds, + timeout=timeout, ) @@ -268,9 +268,8 @@ class TTSProvider(AbstractProvider): # 调用原有的 get_audio 方法获取音频文件路径 audio_path = await self.get_audio(accumulated_text) # 读取音频文件内容 - audio_data = await asyncio.to_thread( - Path(audio_path).read_bytes - ) + with open(audio_path, "rb") as f: + audio_data = f.read() await audio_queue.put((accumulated_text, audio_data)) except Exception: # 出错时也要发送 None 结束标记 diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index 7f7a51859..ec3c395a4 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -1,8 +1,6 @@ -import asyncio import base64 import json from collections.abc import AsyncGenerator -from pathlib import Path import anthropic import httpx @@ -639,10 +637,11 @@ class ProviderAnthropic(Provider): except Exception: mime_type = "image/jpeg" return f"data:{mime_type};base64,{raw_base64}", mime_type - image_bytes = await asyncio.to_thread(Path(image_url).read_bytes) - mime_type = self._detect_image_mime_type(image_bytes) - image_bs64 = base64.b64encode(image_bytes).decode("utf-8") - return f"data:{mime_type};base64,{image_bs64}", mime_type + with open(image_url, "rb") as f: + image_bytes = f.read() + mime_type = self._detect_image_mime_type(image_bytes) + image_bs64 = base64.b64encode(image_bytes).decode("utf-8") + return f"data:{mime_type};base64,{image_bs64}", mime_type return "", "image/jpeg" def get_current_key(self) -> str: diff --git a/astrbot/core/provider/sources/dashscope_tts.py b/astrbot/core/provider/sources/dashscope_tts.py index bd12f37b9..9b6816859 100644 --- a/astrbot/core/provider/sources/dashscope_tts.py +++ b/astrbot/core/provider/sources/dashscope_tts.py @@ -3,7 +3,6 @@ import base64 import logging import os import uuid -from pathlib import Path import aiohttp import dashscope @@ -60,7 +59,8 @@ class ProviderDashscopeTTSAPI(TTSProvider): ) path = os.path.join(temp_dir, f"dashscope_tts_{uuid.uuid4()}{ext}") - await asyncio.to_thread(Path(path).write_bytes, audio_bytes) + with open(path, "wb") as f: + f.write(audio_bytes) return path def _call_qwen_tts(self, model: str, text: str): @@ -129,7 +129,7 @@ class ProviderDashscopeTTSAPI(TTSProvider): ) as response, ): return await response.read() - except (TimeoutError, aiohttp.ClientError, OSError) as e: + except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as e: logging.exception(f"Failed to download audio from URL {url}: {e}") return None diff --git a/astrbot/core/provider/sources/edge_tts_source.py b/astrbot/core/provider/sources/edge_tts_source.py index 147c925ec..503bd275b 100644 --- a/astrbot/core/provider/sources/edge_tts_source.py +++ b/astrbot/core/provider/sources/edge_tts_source.py @@ -1,129 +1,126 @@ -import asyncio -import os -import subprocess -import uuid - -import edge_tts - -from astrbot.core import logger -from astrbot.core.utils.astrbot_path import get_astrbot_temp_path - -from ..entities import ProviderType -from ..provider import TTSProvider -from ..register import register_provider_adapter - -""" -edge_tts 方式,能够免费、快速生成语音,使用需要先安装edge-tts库 -``` -pip install edge_tts -``` -Windows 如果提示找不到指定文件,以管理员身份运行命令行窗口,然后再次运行 AstrBot -""" - - -@register_provider_adapter( - "edge_tts", - "Microsoft Edge TTS", - provider_type=ProviderType.TEXT_TO_SPEECH, -) -class ProviderEdgeTTS(TTSProvider): - def __init__( - self, - provider_config: dict, - provider_settings: dict, - ) -> None: - super().__init__(provider_config, provider_settings) - - # 设置默认语音,如果没有指定则使用中文小萱 - self.voice = provider_config.get("edge-tts-voice", "zh-CN-XiaoxiaoNeural") - self.rate = provider_config.get("rate") - self.volume = provider_config.get("volume") - self.pitch = provider_config.get("pitch") - self.timeout = provider_config.get("timeout", 30) - - self.proxy = os.getenv("https_proxy", None) - - self.set_model("edge_tts") - - async def get_audio(self, text: str) -> str: - temp_dir = get_astrbot_temp_path() - mp3_path = os.path.join(temp_dir, f"edge_tts_temp_{uuid.uuid4()}.mp3") - wav_path = os.path.join(temp_dir, f"edge_tts_{uuid.uuid4()}.wav") - - # 构建 Edge TTS 参数 - kwargs = {"text": text, "voice": self.voice} - if self.rate: - kwargs["rate"] = self.rate - if self.volume: - kwargs["volume"] = self.volume - if self.pitch: - kwargs["pitch"] = self.pitch - - try: - communicate = edge_tts.Communicate(proxy=self.proxy, **kwargs) - await communicate.save(mp3_path) - - try: - from pyffmpeg import FFmpeg - - ff = FFmpeg() - ff.convert(input_file=mp3_path, output_file=wav_path) - except Exception as e: - logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换") - # use ffmpeg command line - - # 使用ffmpeg将MP3转换为标准WAV格式 - p = await asyncio.create_subprocess_exec( - "ffmpeg", - "-y", # 覆盖输出文件 - "-i", - mp3_path, # 输入文件 - "-acodec", - "pcm_s16le", # 16位PCM编码 - "-ar", - "24000", # 采样率24kHz (适合微信语音) - "-ac", - "1", # 单声道 - "-af", - "apad=pad_dur=2", # 确保输出时长准确 - "-fflags", - "+genpts", # 强制生成时间戳 - "-hide_banner", # 隐藏版本信息 - wav_path, # 输出文件 - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - # 等待进程完成并获取输出 - stdout, stderr = await p.communicate() - logger.info(f"[EdgeTTS] FFmpeg 标准输出: {stdout.decode().strip()}") - logger.debug(f"FFmpeg错误输出: {stderr.decode().strip()}") - logger.info(f"[EdgeTTS] 返回值(0代表成功): {p.returncode}") - - os.remove(mp3_path) - if ( - await asyncio.to_thread(os.path.exists, wav_path) - and await asyncio.to_thread(os.path.getsize, wav_path) > 0 - ): - return wav_path - logger.error("生成的WAV文件不存在或为空") - raise RuntimeError("生成的WAV文件不存在或为空") - - except subprocess.CalledProcessError as e: - logger.error( - f"FFmpeg 转换失败: {e.stderr.decode() if e.stderr else str(e)}", - ) - try: - if await asyncio.to_thread(os.path.exists, mp3_path): - os.remove(mp3_path) - except Exception: - pass - raise RuntimeError(f"FFmpeg 转换失败: {e!s}") - - except Exception as e: - logger.error(f"音频生成失败: {e!s}") - try: - if await asyncio.to_thread(os.path.exists, mp3_path): - os.remove(mp3_path) - except Exception: - pass - raise RuntimeError(f"音频生成失败: {e!s}") +import asyncio +import os +import subprocess +import uuid + +import edge_tts + +from astrbot.core import logger +from astrbot.core.utils.astrbot_path import get_astrbot_temp_path + +from ..entities import ProviderType +from ..provider import TTSProvider +from ..register import register_provider_adapter + +""" +edge_tts 方式,能够免费、快速生成语音,使用需要先安装edge-tts库 +``` +pip install edge_tts +``` +Windows 如果提示找不到指定文件,以管理员身份运行命令行窗口,然后再次运行 AstrBot +""" + + +@register_provider_adapter( + "edge_tts", + "Microsoft Edge TTS", + provider_type=ProviderType.TEXT_TO_SPEECH, +) +class ProviderEdgeTTS(TTSProvider): + def __init__( + self, + provider_config: dict, + provider_settings: dict, + ) -> None: + super().__init__(provider_config, provider_settings) + + # 设置默认语音,如果没有指定则使用中文小萱 + self.voice = provider_config.get("edge-tts-voice", "zh-CN-XiaoxiaoNeural") + self.rate = provider_config.get("rate") + self.volume = provider_config.get("volume") + self.pitch = provider_config.get("pitch") + self.timeout = provider_config.get("timeout", 30) + + self.proxy = os.getenv("https_proxy", None) + + self.set_model("edge_tts") + + async def get_audio(self, text: str) -> str: + temp_dir = get_astrbot_temp_path() + mp3_path = os.path.join(temp_dir, f"edge_tts_temp_{uuid.uuid4()}.mp3") + wav_path = os.path.join(temp_dir, f"edge_tts_{uuid.uuid4()}.wav") + + # 构建 Edge TTS 参数 + kwargs = {"text": text, "voice": self.voice} + if self.rate: + kwargs["rate"] = self.rate + if self.volume: + kwargs["volume"] = self.volume + if self.pitch: + kwargs["pitch"] = self.pitch + + try: + communicate = edge_tts.Communicate(proxy=self.proxy, **kwargs) + await communicate.save(mp3_path) + + try: + from pyffmpeg import FFmpeg + + ff = FFmpeg() + ff.convert(input_file=mp3_path, output_file=wav_path) + except Exception as e: + logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换") + # use ffmpeg command line + + # 使用ffmpeg将MP3转换为标准WAV格式 + p = await asyncio.create_subprocess_exec( + "ffmpeg", + "-y", # 覆盖输出文件 + "-i", + mp3_path, # 输入文件 + "-acodec", + "pcm_s16le", # 16位PCM编码 + "-ar", + "24000", # 采样率24kHz (适合微信语音) + "-ac", + "1", # 单声道 + "-af", + "apad=pad_dur=2", # 确保输出时长准确 + "-fflags", + "+genpts", # 强制生成时间戳 + "-hide_banner", # 隐藏版本信息 + wav_path, # 输出文件 + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + # 等待进程完成并获取输出 + stdout, stderr = await p.communicate() + logger.info(f"[EdgeTTS] FFmpeg 标准输出: {stdout.decode().strip()}") + logger.debug(f"FFmpeg错误输出: {stderr.decode().strip()}") + logger.info(f"[EdgeTTS] 返回值(0代表成功): {p.returncode}") + + os.remove(mp3_path) + if os.path.exists(wav_path) and os.path.getsize(wav_path) > 0: + return wav_path + logger.error("生成的WAV文件不存在或为空") + raise RuntimeError("生成的WAV文件不存在或为空") + + except subprocess.CalledProcessError as e: + logger.error( + f"FFmpeg 转换失败: {e.stderr.decode() if e.stderr else str(e)}", + ) + try: + if os.path.exists(mp3_path): + os.remove(mp3_path) + except Exception: + pass + raise RuntimeError(f"FFmpeg 转换失败: {e!s}") + + except Exception as e: + logger.error(f"音频生成失败: {e!s}") + try: + if os.path.exists(mp3_path): + os.remove(mp3_path) + except Exception: + pass + raise RuntimeError(f"音频生成失败: {e!s}") diff --git a/astrbot/core/provider/sources/fishaudio_tts_api_source.py b/astrbot/core/provider/sources/fishaudio_tts_api_source.py index c1b62ef39..35945b7b6 100644 --- a/astrbot/core/provider/sources/fishaudio_tts_api_source.py +++ b/astrbot/core/provider/sources/fishaudio_tts_api_source.py @@ -1,8 +1,6 @@ -import asyncio import os import re import uuid -from pathlib import Path from typing import Annotated, Literal import ormsgpack @@ -161,10 +159,9 @@ class ProviderFishAudioTTSAPI(TTSProvider): if response.status_code == 200 and response.headers.get( "content-type", "" ).startswith("audio/"): - audio_data = bytearray() - async for chunk in response.aiter_bytes(): - audio_data.extend(chunk) - await asyncio.to_thread(Path(path).write_bytes, bytes(audio_data)) + with open(path, "wb") as f: + async for chunk in response.aiter_bytes(): + f.write(chunk) return path error_bytes = await response.aread() error_text = error_bytes.decode("utf-8", errors="replace")[:1024] diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 25e5c6e3a..9557f3dbc 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -4,7 +4,6 @@ import json import logging import random from collections.abc import AsyncGenerator -from pathlib import Path from typing import cast from google import genai @@ -925,10 +924,9 @@ class ProviderGoogleGenAI(Provider): """将图片转换为 base64""" if image_url.startswith("base64://"): return image_url.replace("base64://", "data:image/jpeg;base64,") - image_bs64 = base64.b64encode( - await asyncio.to_thread(Path(image_url).read_bytes) - ).decode("utf-8") - return "data:image/jpeg;base64," + image_bs64 + with open(image_url, "rb") as f: + image_bs64 = base64.b64encode(f.read()).decode("utf-8") + return "data:image/jpeg;base64," + image_bs64 async def terminate(self) -> None: if self.client: diff --git a/astrbot/core/provider/sources/genie_tts.py b/astrbot/core/provider/sources/genie_tts.py index 62b4b3f81..8f9b6d91d 100644 --- a/astrbot/core/provider/sources/genie_tts.py +++ b/astrbot/core/provider/sources/genie_tts.py @@ -1,7 +1,6 @@ import asyncio import os import uuid -from pathlib import Path from astrbot.core import logger from astrbot.core.provider.entities import ProviderType @@ -73,7 +72,7 @@ class GenieTTSProvider(TTSProvider): try: await loop.run_in_executor(None, _generate, path) - if await asyncio.to_thread(os.path.exists, path): + if os.path.exists(path): return path raise RuntimeError("Genie TTS did not save to file.") @@ -110,8 +109,9 @@ class GenieTTSProvider(TTSProvider): await loop.run_in_executor(None, _generate, path, text) - if await asyncio.to_thread(os.path.exists, path): - audio_data = await asyncio.to_thread(Path(path).read_bytes) + if os.path.exists(path): + with open(path, "rb") as f: + audio_data = f.read() # Put (text, bytes) into queue so frontend can display text await audio_queue.put((text, audio_data)) diff --git a/astrbot/core/provider/sources/gsv_selfhosted_source.py b/astrbot/core/provider/sources/gsv_selfhosted_source.py index a9ebfe9a6..fc8bccea8 100644 --- a/astrbot/core/provider/sources/gsv_selfhosted_source.py +++ b/astrbot/core/provider/sources/gsv_selfhosted_source.py @@ -1,7 +1,6 @@ import asyncio import os import uuid -from pathlib import Path import aiohttp @@ -130,7 +129,8 @@ class ProviderGSVTTS(TTSProvider): result = await self._make_request(endpoint, params) if isinstance(result, bytes): - await asyncio.to_thread(Path(path).write_bytes, result) + with open(path, "wb") as f: + f.write(result) return path raise Exception(f"[GSV TTS] 合成失败,输入文本:{text},错误信息:{result}") diff --git a/astrbot/core/provider/sources/gsvi_tts_source.py b/astrbot/core/provider/sources/gsvi_tts_source.py index f92485b72..425e801f4 100644 --- a/astrbot/core/provider/sources/gsvi_tts_source.py +++ b/astrbot/core/provider/sources/gsvi_tts_source.py @@ -1,62 +1,59 @@ -import asyncio -import os -import urllib.parse -import uuid -from pathlib import Path - -import aiohttp - -from astrbot.core.utils.astrbot_path import get_astrbot_temp_path - -from ..entities import ProviderType -from ..provider import TTSProvider -from ..register import register_provider_adapter - - -@register_provider_adapter( - "gsvi_tts_api", - "GSVI TTS API", - provider_type=ProviderType.TEXT_TO_SPEECH, -) -class ProviderGSVITTS(TTSProvider): - def __init__( - self, - provider_config: dict, - provider_settings: dict, - ) -> None: - super().__init__(provider_config, provider_settings) - self.api_base = provider_config.get("api_base", "http://127.0.0.1:5000") - self.api_base = self.api_base.removesuffix("/") - self.character = provider_config.get("character") - self.emotion = provider_config.get("emotion") - - async def get_audio(self, text: str) -> str: - temp_dir = get_astrbot_temp_path() - path = os.path.join(temp_dir, f"gsvi_tts_{uuid.uuid4()}.wav") - params = {"text": text} - - if self.character: - params["character"] = self.character - if self.emotion: - params["emotion"] = self.emotion - - query_parts = [] - for key, value in params.items(): - encoded_value = urllib.parse.quote(str(value)) - query_parts.append(f"{key}={encoded_value}") - - url = f"{self.api_base}/tts?{'&'.join(query_parts)}" - - async with aiohttp.ClientSession() as session: - async with session.get(url) as response: - if response.status == 200: - await asyncio.to_thread( - Path(path).write_bytes, await response.read() - ) - else: - error_text = await response.text() - raise Exception( - f"GSVI TTS API 请求失败,状态码: {response.status},错误: {error_text}", - ) - - return path +import os +import urllib.parse +import uuid + +import aiohttp + +from astrbot.core.utils.astrbot_path import get_astrbot_temp_path + +from ..entities import ProviderType +from ..provider import TTSProvider +from ..register import register_provider_adapter + + +@register_provider_adapter( + "gsvi_tts_api", + "GSVI TTS API", + provider_type=ProviderType.TEXT_TO_SPEECH, +) +class ProviderGSVITTS(TTSProvider): + def __init__( + self, + provider_config: dict, + provider_settings: dict, + ) -> None: + super().__init__(provider_config, provider_settings) + self.api_base = provider_config.get("api_base", "http://127.0.0.1:5000") + self.api_base = self.api_base.removesuffix("/") + self.character = provider_config.get("character") + self.emotion = provider_config.get("emotion") + + async def get_audio(self, text: str) -> str: + temp_dir = get_astrbot_temp_path() + path = os.path.join(temp_dir, f"gsvi_tts_{uuid.uuid4()}.wav") + params = {"text": text} + + if self.character: + params["character"] = self.character + if self.emotion: + params["emotion"] = self.emotion + + query_parts = [] + for key, value in params.items(): + encoded_value = urllib.parse.quote(str(value)) + query_parts.append(f"{key}={encoded_value}") + + url = f"{self.api_base}/tts?{'&'.join(query_parts)}" + + async with aiohttp.ClientSession() as session: + async with session.get(url) as response: + if response.status == 200: + with open(path, "wb") as f: + f.write(await response.read()) + else: + error_text = await response.text() + raise Exception( + f"GSVI TTS API 请求失败,状态码: {response.status},错误: {error_text}", + ) + + return path diff --git a/astrbot/core/provider/sources/minimax_tts_api_source.py b/astrbot/core/provider/sources/minimax_tts_api_source.py index ad2e34536..69860111c 100644 --- a/astrbot/core/provider/sources/minimax_tts_api_source.py +++ b/astrbot/core/provider/sources/minimax_tts_api_source.py @@ -1,9 +1,7 @@ -import asyncio import json import os import uuid from collections.abc import AsyncIterator -from pathlib import Path import aiohttp @@ -157,7 +155,8 @@ class ProviderMiniMaxTTSAPI(TTSProvider): audio = await self._audio_play(audio_stream) # 结果保存至文件 - await asyncio.to_thread(Path(path).write_bytes, audio) + with open(path, "wb") as file: + file.write(audio) return path diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 3f0c007b3..adee24073 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -5,7 +5,6 @@ import json import random import re from collections.abc import AsyncGenerator -from pathlib import Path from typing import Any import httpx @@ -950,10 +949,9 @@ class ProviderOpenAIOfficial(Provider): """将图片转换为 base64""" if image_url.startswith("base64://"): return image_url.replace("base64://", "data:image/jpeg;base64,") - image_bs64 = base64.b64encode( - await asyncio.to_thread(Path(image_url).read_bytes) - ).decode("utf-8") - return "data:image/jpeg;base64," + image_bs64 + with open(image_url, "rb") as f: + image_bs64 = base64.b64encode(f.read()).decode("utf-8") + return "data:image/jpeg;base64," + image_bs64 async def terminate(self): if self.client: diff --git a/astrbot/core/provider/sources/openai_tts_api_source.py b/astrbot/core/provider/sources/openai_tts_api_source.py index 35ac1d5a8..217b18925 100644 --- a/astrbot/core/provider/sources/openai_tts_api_source.py +++ b/astrbot/core/provider/sources/openai_tts_api_source.py @@ -1,7 +1,5 @@ -import asyncio import os import uuid -from pathlib import Path import httpx from openai import NOT_GIVEN, AsyncOpenAI @@ -56,10 +54,9 @@ class ProviderOpenAITTSAPI(TTSProvider): response_format="wav", input=text, ) as response: - audio_data = bytearray() - async for chunk in response.iter_bytes(chunk_size=1024): - audio_data.extend(chunk) - await asyncio.to_thread(Path(path).write_bytes, bytes(audio_data)) + with open(path, "wb") as f: + async for chunk in response.iter_bytes(chunk_size=1024): + f.write(chunk) return path async def terminate(self): diff --git a/astrbot/core/provider/sources/sensevoice_selfhosted_source.py b/astrbot/core/provider/sources/sensevoice_selfhosted_source.py index b77665796..af6c0f631 100644 --- a/astrbot/core/provider/sources/sensevoice_selfhosted_source.py +++ b/astrbot/core/provider/sources/sensevoice_selfhosted_source.py @@ -53,12 +53,14 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider): async def get_timestamped_path(self) -> str: timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") temp_dir = Path(get_astrbot_temp_path()) - await asyncio.to_thread(temp_dir.mkdir, parents=True, exist_ok=True) + temp_dir.mkdir(parents=True, exist_ok=True) return str(temp_dir / timestamp) async def _is_silk_file(self, file_path) -> bool: silk_header = b"SILK" - file_header = (await asyncio.to_thread(Path(file_path).read_bytes))[:8] + with open(file_path, "rb") as f: + file_header = f.read(8) + if silk_header in file_header: return True return False @@ -74,7 +76,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider): await download_file(audio_url, path) audio_url = path - if not await asyncio.to_thread(os.path.isfile, audio_url): + if not os.path.isfile(audio_url): raise FileNotFoundError(f"文件不存在: {audio_url}") if audio_url.endswith((".amr", ".silk")) or is_tencent: diff --git a/astrbot/core/provider/sources/volcengine_tts.py b/astrbot/core/provider/sources/volcengine_tts.py index 508220071..349815907 100644 --- a/astrbot/core/provider/sources/volcengine_tts.py +++ b/astrbot/core/provider/sources/volcengine_tts.py @@ -4,7 +4,6 @@ import json import os import traceback import uuid -from pathlib import Path import aiohttp @@ -101,9 +100,10 @@ class ProviderVolcengineTTS(TTSProvider): f"volcengine_tts_{uuid.uuid4()}.mp3", ) - await asyncio.to_thread( - Path(file_path).write_bytes, - audio_data, + loop = asyncio.get_running_loop() + await loop.run_in_executor( + None, + lambda: open(file_path, "wb").write(audio_data), ) return file_path diff --git a/astrbot/core/provider/sources/whisper_api_source.py b/astrbot/core/provider/sources/whisper_api_source.py index 00c87075d..386da063d 100644 --- a/astrbot/core/provider/sources/whisper_api_source.py +++ b/astrbot/core/provider/sources/whisper_api_source.py @@ -1,7 +1,5 @@ -import asyncio import os import uuid -from pathlib import Path from openai import NOT_GIVEN, AsyncOpenAI @@ -46,7 +44,8 @@ class ProviderOpenAIWhisperAPI(STTProvider): amr_header = b"#!AMR" try: - file_header = (await asyncio.to_thread(Path(file_path).read_bytes))[:8] + with open(file_path, "rb") as f: + file_header = f.read(8) except FileNotFoundError: return None @@ -74,7 +73,7 @@ class ProviderOpenAIWhisperAPI(STTProvider): await download_file(audio_url, path) audio_url = path - if not await asyncio.to_thread(os.path.exists, audio_url): + if not os.path.exists(audio_url): raise FileNotFoundError(f"文件不存在: {audio_url}") if audio_url.endswith(".amr") or audio_url.endswith(".silk") or is_tencent: @@ -101,14 +100,13 @@ class ProviderOpenAIWhisperAPI(STTProvider): audio_url = output_path - audio_bytes = await asyncio.to_thread(Path(audio_url).read_bytes) result = await self.client.audio.transcriptions.create( model=self.model_name, - file=("audio.wav", audio_bytes), + file=("audio.wav", open(audio_url, "rb")), ) # remove temp file - if output_path and await asyncio.to_thread(os.path.exists, output_path): + if output_path and os.path.exists(output_path): try: os.remove(audio_url) except Exception as e: diff --git a/astrbot/core/provider/sources/whisper_selfhosted_source.py b/astrbot/core/provider/sources/whisper_selfhosted_source.py index d85c84f9b..678deb948 100644 --- a/astrbot/core/provider/sources/whisper_selfhosted_source.py +++ b/astrbot/core/provider/sources/whisper_selfhosted_source.py @@ -1,7 +1,6 @@ import asyncio import os import uuid -from pathlib import Path from typing import cast import whisper @@ -43,7 +42,9 @@ class ProviderOpenAIWhisperSelfHost(STTProvider): async def _is_silk_file(self, file_path) -> bool: silk_header = b"SILK" - file_header = (await asyncio.to_thread(Path(file_path).read_bytes))[:8] + with open(file_path, "rb") as f: + file_header = f.read(8) + if silk_header in file_header: return True return False @@ -65,7 +66,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider): await download_file(audio_url, path) audio_url = path - if not await asyncio.to_thread(os.path.exists, audio_url): + if not os.path.exists(audio_url): raise FileNotFoundError(f"文件不存在: {audio_url}") if audio_url.endswith(".amr") or audio_url.endswith(".silk") or is_tencent: diff --git a/astrbot/core/provider/sources/xinference_stt_provider.py b/astrbot/core/provider/sources/xinference_stt_provider.py index 7b6068dd8..0a22e456e 100644 --- a/astrbot/core/provider/sources/xinference_stt_provider.py +++ b/astrbot/core/provider/sources/xinference_stt_provider.py @@ -1,7 +1,5 @@ -import asyncio import os import uuid -from pathlib import Path import aiohttp from xinference_client.client.restful.async_restful_client import ( @@ -104,8 +102,9 @@ class ProviderXinferenceSTT(STTProvider): f"Failed to download audio from {audio_url}, status: {resp.status}", ) return "" - elif await asyncio.to_thread(os.path.exists, audio_url): - audio_bytes = await asyncio.to_thread(Path(audio_url).read_bytes) + elif os.path.exists(audio_url): + with open(audio_url, "rb") as f: + audio_bytes = f.read() else: logger.error(f"File not found: {audio_url}") return "" @@ -144,7 +143,8 @@ class ProviderXinferenceSTT(STTProvider): ) temp_files.extend([input_path, output_path]) - await asyncio.to_thread(Path(input_path).write_bytes, audio_bytes) + with open(input_path, "wb") as f: + f.write(audio_bytes) if conversion_type == "silk": logger.info("Converting silk to wav ...") @@ -153,7 +153,8 @@ class ProviderXinferenceSTT(STTProvider): logger.info("Converting amr to wav ...") await convert_to_pcm_wav(input_path, output_path) - audio_bytes = await asyncio.to_thread(Path(output_path).read_bytes) + with open(output_path, "rb") as f: + audio_bytes = f.read() # 4. Transcribe # 官方asyncCLient的客户端似乎实现有点问题,这里直接用aiohttp实现openai标准兼容请求,提交issue等待官方修复后再改回来 @@ -198,7 +199,7 @@ class ProviderXinferenceSTT(STTProvider): # 5. Cleanup for temp_file in temp_files: try: - if await asyncio.to_thread(os.path.exists, temp_file): + if os.path.exists(temp_file): os.remove(temp_file) logger.debug(f"Removed temporary file: {temp_file}") except Exception as e: diff --git a/astrbot/core/skills/neo_skill_sync.py b/astrbot/core/skills/neo_skill_sync.py index 2bb4c50f8..5fe2b7832 100644 --- a/astrbot/core/skills/neo_skill_sync.py +++ b/astrbot/core/skills/neo_skill_sync.py @@ -5,7 +5,7 @@ import json import os import re from dataclasses import dataclass -from datetime import UTC, datetime +from datetime import datetime, timezone from pathlib import Path from typing import Any @@ -19,7 +19,7 @@ _SKILL_NAME_RE = re.compile(r"[^a-zA-Z0-9._-]+") def _now_iso() -> str: - return datetime.now(UTC).isoformat() + return datetime.now(timezone.utc).isoformat() def _to_jsonable(model_like: Any) -> dict[str, Any]: diff --git a/astrbot/core/skills/skill_manager.py b/astrbot/core/skills/skill_manager.py index a24ddac9e..d15876526 100644 --- a/astrbot/core/skills/skill_manager.py +++ b/astrbot/core/skills/skill_manager.py @@ -7,7 +7,7 @@ import shutil import tempfile import zipfile from dataclasses import dataclass -from datetime import UTC, datetime +from datetime import datetime, timezone from pathlib import Path, PurePosixPath from astrbot.core.utils.astrbot_path import ( @@ -175,7 +175,7 @@ class SkillManager: def _save_sandbox_skills_cache(self, cache: dict) -> None: cache["version"] = _SANDBOX_SKILLS_CACHE_VERSION - cache["updated_at"] = datetime.now(UTC).isoformat() + cache["updated_at"] = datetime.now(timezone.utc).isoformat() with open(self.sandbox_skills_cache_path, "w", encoding="utf-8") as f: json.dump(cache, f, ensure_ascii=False, indent=2) diff --git a/astrbot/core/star/star_handler.py b/astrbot/core/star/star_handler.py index f6afc08e1..d28ac726a 100644 --- a/astrbot/core/star/star_handler.py +++ b/astrbot/core/star/star_handler.py @@ -3,7 +3,7 @@ from __future__ import annotations import enum from collections.abc import AsyncGenerator, Awaitable, Callable from dataclasses import dataclass, field -from typing import Any, Literal, TypeVar, overload +from typing import Any, Generic, Literal, TypeVar, overload from .filter import HandlerFilter from .star import star_map @@ -11,7 +11,7 @@ from .star import star_map T = TypeVar("T", bound="StarHandlerMetadata") -class StarHandlerRegistry[T: "StarHandlerMetadata"]: +class StarHandlerRegistry(Generic[T]): def __init__(self) -> None: self.star_handlers_map: dict[str, StarHandlerMetadata] = {} self._handlers: list[StarHandlerMetadata] = [] @@ -227,7 +227,7 @@ H = TypeVar("H", bound=Callable[..., Any]) @dataclass -class StarHandlerMetadata[H: Callable[..., Any]]: +class StarHandlerMetadata(Generic[H]): """描述一个 Star 所注册的某一个 Handler。""" event_type: EventType diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index c5fa63bee..68c58fdae 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -8,7 +8,6 @@ import logging import os import sys import traceback -from pathlib import Path from types import ModuleType import yaml @@ -189,7 +188,7 @@ class PluginManager: 如果 target_plugin 为 None,则检查所有插件的依赖 """ plugin_dir = self.plugin_store_path - if not await asyncio.to_thread(os.path.exists, plugin_dir): + if not os.path.exists(plugin_dir): return False to_update = [] if target_plugin: @@ -199,9 +198,7 @@ class PluginManager: to_update.append(p.root_dir_name) for p in to_update: plugin_path = os.path.join(plugin_dir, p) - if await asyncio.to_thread( - os.path.exists, os.path.join(plugin_path, "requirements.txt") - ): + if os.path.exists(os.path.join(plugin_path, "requirements.txt")): pth = os.path.join(plugin_path, "requirements.txt") logger.info(f"正在安装插件 {p} 所需的依赖库: {pth}") try: @@ -220,7 +217,7 @@ class PluginManager: try: return __import__(path, fromlist=[module_str]) except (ModuleNotFoundError, ImportError) as import_exc: - if await asyncio.to_thread(os.path.exists, requirements_path): + if os.path.exists(requirements_path): try: logger.info( f"插件 {root_dir_name} 导入失败,尝试从已安装依赖恢复: {import_exc!s}" @@ -654,19 +651,16 @@ class PluginManager: plugin_dir_path, self.conf_schema_fname, ) - if await asyncio.to_thread(os.path.exists, plugin_schema_path): + if os.path.exists(plugin_schema_path): # 加载插件配置 - plugin_schema_text = await asyncio.to_thread( - Path(plugin_schema_path).read_text, - encoding="utf-8", - ) - plugin_config = AstrBotConfig( - config_path=os.path.join( - self.plugin_config_path, - f"{root_dir_name}_config.json", - ), - schema=json.loads(plugin_schema_text), - ) + with open(plugin_schema_path, encoding="utf-8") as f: + plugin_config = AstrBotConfig( + config_path=os.path.join( + self.plugin_config_path, + f"{root_dir_name}_config.json", + ), + schema=json.loads(f.read()), + ) logo_path = os.path.join(plugin_dir_path, self.logo_fname) if path in star_map: @@ -842,7 +836,7 @@ class PluginManager: metadata.activated = False # Plugin logo path - if await asyncio.to_thread(os.path.exists, logo_path): + if os.path.exists(logo_path): metadata.logo_path = logo_path assert metadata.module_path, f"插件 {metadata.name} 模块路径为空" @@ -961,7 +955,7 @@ class PluginManager: except Exception: logger.warning(traceback.format_exc()) - if await asyncio.to_thread(os.path.exists, plugin_path): + if os.path.exists(plugin_path): try: remove_dir(plugin_path) logger.warning(f"已清理安装失败的插件目录: {plugin_path}") @@ -974,7 +968,7 @@ class PluginManager: self.plugin_config_path, f"{dir_name}_config.json", ) - if await asyncio.to_thread(os.path.exists, plugin_config_path): + if os.path.exists(plugin_config_path): try: os.remove(plugin_config_path) logger.warning(f"已清理安装失败插件配置: {plugin_config_path}") @@ -1106,14 +1100,13 @@ class PluginManager: # Extract README.md content if exists readme_content = None readme_path = os.path.join(plugin_path, "README.md") - if not await asyncio.to_thread(os.path.exists, readme_path): + if not os.path.exists(readme_path): readme_path = os.path.join(plugin_path, "readme.md") - if await asyncio.to_thread(os.path.exists, readme_path): + if os.path.exists(readme_path): try: - readme_content = await asyncio.to_thread( - Path(readme_path).read_text, encoding="utf-8" - ) + with open(readme_path, encoding="utf-8") as f: + readme_content = f.read() except Exception as e: logger.warning( f"读取插件 {dir_name} 的 README.md 文件失败: {e!s}", @@ -1218,7 +1211,7 @@ class PluginManager: self._cleanup_plugin_state(dir_name) plugin_path = os.path.join(self.plugin_store_path, dir_name) - if await asyncio.to_thread(os.path.exists, plugin_path): + if os.path.exists(plugin_path): try: remove_dir(plugin_path) except Exception as e: @@ -1505,14 +1498,13 @@ class PluginManager: # Extract README.md content if exists readme_content = None readme_path = os.path.join(desti_dir, "README.md") - if not await asyncio.to_thread(os.path.exists, readme_path): + if not os.path.exists(readme_path): readme_path = os.path.join(desti_dir, "readme.md") - if await asyncio.to_thread(os.path.exists, readme_path): + if os.path.exists(readme_path): try: - readme_content = await asyncio.to_thread( - Path(readme_path).read_text, encoding="utf-8" - ) + with open(readme_path, encoding="utf-8") as f: + readme_content = f.read() except Exception as e: logger.warning(f"读取插件 {dir_name} 的 README.md 文件失败: {e!s}") diff --git a/astrbot/core/utils/datetime_utils.py b/astrbot/core/utils/datetime_utils.py index 431c9cd50..97b8196dd 100644 --- a/astrbot/core/utils/datetime_utils.py +++ b/astrbot/core/utils/datetime_utils.py @@ -1,4 +1,4 @@ -from datetime import UTC, datetime +from datetime import datetime, timezone def normalize_datetime_utc(dt: datetime | None) -> datetime | None: @@ -9,8 +9,8 @@ def normalize_datetime_utc(dt: datetime | None) -> datetime | None: if dt is None: return None if dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None: - return dt.replace(tzinfo=UTC) - return dt.astimezone(UTC) + return dt.replace(tzinfo=timezone.utc) + return dt.astimezone(timezone.utc) def to_utc_isoformat(dt: datetime | None) -> str | None: diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py index d37e4fbb1..b56592674 100644 --- a/astrbot/core/utils/io.py +++ b/astrbot/core/utils/io.py @@ -1,4 +1,3 @@ -import asyncio import base64 import logging import os @@ -9,7 +8,6 @@ import time import uuid import zipfile from pathlib import Path -from typing import BinaryIO import aiohttp import certifi @@ -19,8 +17,6 @@ from PIL import Image from .astrbot_path import get_astrbot_data_path, get_astrbot_path, get_astrbot_temp_path logger = logging.getLogger("astrbot") -_DOWNLOAD_READ_CHUNK_SIZE = 64 * 1024 -_DOWNLOAD_FLUSH_THRESHOLD = 256 * 1024 def on_error(func, path, exc_info) -> None: @@ -62,7 +58,8 @@ def save_temp_img(img: Image.Image | bytes) -> str: if isinstance(img, Image.Image): img.save(p) else: - Path(p).write_bytes(img) + with open(p, "wb") as f: + f.write(img) return p @@ -86,13 +83,15 @@ async def download_image_by_url( async with session.post(url, json=post_data) as resp: if not path: return save_temp_img(await resp.read()) - await asyncio.to_thread(Path(path).write_bytes, await resp.read()) + with open(path, "wb") as f: + f.write(await resp.read()) return path else: async with session.get(url) as resp: if not path: return save_temp_img(await resp.read()) - await asyncio.to_thread(Path(path).write_bytes, await resp.read()) + with open(path, "wb") as f: + f.write(await resp.read()) return path except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError): # 关闭SSL验证(仅在证书验证失败时作为fallback) @@ -110,13 +109,15 @@ async def download_image_by_url( async with session.post(url, json=post_data, ssl=ssl_context) as resp: if not path: return save_temp_img(await resp.read()) - await asyncio.to_thread(Path(path).write_bytes, await resp.read()) + with open(path, "wb") as f: + f.write(await resp.read()) return path else: async with session.get(url, ssl=ssl_context) as resp: if not path: return save_temp_img(await resp.read()) - await asyncio.to_thread(Path(path).write_bytes, await resp.read()) + with open(path, "wb") as f: + f.write(await resp.read()) return path except Exception as e: raise e @@ -137,20 +138,28 @@ async def download_file(url: str, path: str, show_progress: bool = False) -> Non if resp.status != 200: raise Exception(f"下载文件失败: {resp.status}") total_size = int(resp.headers.get("content-length", 0)) + downloaded_size = 0 start_time = time.time() if show_progress: print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}") - file_obj = await asyncio.to_thread(Path(path).open, "wb") - try: - await _stream_to_file( - resp.content, - file_obj, - total_size=total_size, - start_time=start_time, - show_progress=show_progress, - ) - finally: - await asyncio.to_thread(file_obj.close) + with open(path, "wb") as f: + while True: + chunk = await resp.content.read(8192) + if not chunk: + break + f.write(chunk) + downloaded_size += len(chunk) + if show_progress: + elapsed_time = ( + time.time() - start_time + if time.time() - start_time > 0 + else 1 + ) + speed = downloaded_size / 1024 / elapsed_time # KB/s + print( + f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s", + end="", + ) except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError): # 关闭SSL验证(仅在证书验证失败时作为fallback) logger.warning( @@ -168,76 +177,32 @@ async def download_file(url: str, path: str, show_progress: bool = False) -> Non async with aiohttp.ClientSession() as session: async with session.get(url, ssl=ssl_context, timeout=120) as resp: total_size = int(resp.headers.get("content-length", 0)) + downloaded_size = 0 start_time = time.time() if show_progress: print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}") - file_obj = await asyncio.to_thread(Path(path).open, "wb") - try: - await _stream_to_file( - resp.content, - file_obj, - total_size=total_size, - start_time=start_time, - show_progress=show_progress, - ) - finally: - await asyncio.to_thread(file_obj.close) + with open(path, "wb") as f: + while True: + chunk = await resp.content.read(8192) + if not chunk: + break + f.write(chunk) + downloaded_size += len(chunk) + if show_progress: + elapsed_time = time.time() - start_time + speed = downloaded_size / 1024 / elapsed_time # KB/s + print( + f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s", + end="", + ) if show_progress: print() -async def _stream_to_file( - stream: aiohttp.StreamReader, - file_obj: BinaryIO, - *, - total_size: int, - start_time: float, - show_progress: bool, -) -> None: - """Stream HTTP response into file with buffered thread-offloaded writes.""" - downloaded_size = 0 - known_total = total_size if total_size > 0 else None - buffered = bytearray() - - try: - while True: - chunk = await stream.read(_DOWNLOAD_READ_CHUNK_SIZE) - if not chunk: - break - - buffered.extend(chunk) - downloaded_size += len(chunk) - - if len(buffered) >= _DOWNLOAD_FLUSH_THRESHOLD: - await asyncio.to_thread(file_obj.write, bytes(buffered)) - buffered.clear() - - if show_progress: - _print_download_progress(downloaded_size, known_total, start_time) - finally: - if buffered: - # Ensure buffered data is flushed even on cancellation. - await asyncio.shield(asyncio.to_thread(file_obj.write, bytes(buffered))) - - -def _print_download_progress( - downloaded_size: int, total_size: int | None, start_time: float -) -> None: - elapsed_time = max(time.time() - start_time, 1e-6) - speed = downloaded_size / 1024 / elapsed_time # KB/s - - if total_size: - percent = downloaded_size / total_size - msg = f"\r下载进度: {percent:.2%} 速度: {speed:.2f} KB/s" - else: - msg = f"\r已下载: {downloaded_size} 字节 速度: {speed:.2f} KB/s" - - print(msg, end="") - - -async def file_to_base64(file_path: str) -> str: - data_bytes = await asyncio.to_thread(Path(file_path).read_bytes) - base64_str = base64.b64encode(data_bytes).decode() +def file_to_base64(file_path: str) -> str: + with open(file_path, "rb") as f: + data_bytes = f.read() + base64_str = base64.b64encode(data_bytes).decode() return "base64://" + base64_str @@ -256,18 +221,17 @@ def get_local_ip_addresses(): async def get_dashboard_version(): # First check user data directory (manually updated / downloaded dashboard). dist_dir = os.path.join(get_astrbot_data_path(), "dist") - if not await asyncio.to_thread(os.path.exists, dist_dir): + if not os.path.exists(dist_dir): # Fall back to the dist bundled inside the installed wheel. _bundled = Path(get_astrbot_path()) / "astrbot" / "dashboard" / "dist" - if await asyncio.to_thread(_bundled.exists): + if _bundled.exists(): dist_dir = str(_bundled) - if await asyncio.to_thread(os.path.exists, dist_dir): + if os.path.exists(dist_dir): version_file = os.path.join(dist_dir, "assets", "version") - if await asyncio.to_thread(os.path.exists, version_file): - v = ( - await asyncio.to_thread(Path(version_file).read_text, encoding="utf-8") - ).strip() - return v + if os.path.exists(version_file): + with open(version_file, encoding="utf-8") as f: + v = f.read().strip() + return v return None @@ -280,12 +244,9 @@ async def download_dashboard( ) -> None: """下载管理面板文件""" if path is None: - zip_path = ( - await asyncio.to_thread(Path(get_astrbot_data_path()).absolute) - / "dashboard.zip" - ) + zip_path = Path(get_astrbot_data_path()).absolute() / "dashboard.zip" else: - zip_path = await asyncio.to_thread(Path(path).absolute) + zip_path = Path(path).absolute() if latest or len(str(version)) != 40: ver_name = "latest" if latest else version diff --git a/astrbot/core/utils/media_utils.py b/astrbot/core/utils/media_utils.py index 7ecebcad4..8d833514f 100644 --- a/astrbot/core/utils/media_utils.py +++ b/astrbot/core/utils/media_utils.py @@ -108,7 +108,7 @@ async def convert_audio_to_opus(audio_path: str, output_path: str | None = None) if process.returncode != 0: # 清理可能已生成但无效的临时文件 - if output_path and await asyncio.to_thread(os.path.exists, output_path): + if output_path and os.path.exists(output_path): try: os.remove(output_path) logger.debug( @@ -183,7 +183,7 @@ async def convert_video_format( if process.returncode != 0: # 清理可能已生成但无效的临时文件 - if output_path and await asyncio.to_thread(os.path.exists, output_path): + if output_path and os.path.exists(output_path): try: os.remove(output_path) logger.debug( @@ -231,7 +231,7 @@ async def convert_audio_format( if output_path is None: temp_dir = Path(get_astrbot_temp_path()) - await asyncio.to_thread(temp_dir.mkdir, parents=True, exist_ok=True) + temp_dir.mkdir(parents=True, exist_ok=True) output_path = str(temp_dir / f"media_audio_{uuid.uuid4().hex}.{output_format}") args = ["ffmpeg", "-y", "-i", audio_path] @@ -249,7 +249,7 @@ async def convert_audio_format( ) _, stderr = await process.communicate() if process.returncode != 0: - if output_path and await asyncio.to_thread(os.path.exists, output_path): + if output_path and os.path.exists(output_path): try: os.remove(output_path) except OSError as e: @@ -287,7 +287,7 @@ async def extract_video_cover( """从视频中提取封面图(JPG)。""" if output_path is None: temp_dir = Path(get_astrbot_temp_path()) - await asyncio.to_thread(temp_dir.mkdir, parents=True, exist_ok=True) + temp_dir.mkdir(parents=True, exist_ok=True) output_path = str(temp_dir / f"media_cover_{uuid.uuid4().hex}.jpg") try: @@ -306,7 +306,7 @@ async def extract_video_cover( ) _, stderr = await process.communicate() if process.returncode != 0: - if output_path and await asyncio.to_thread(os.path.exists, output_path): + if output_path and os.path.exists(output_path): try: os.remove(output_path) except OSError as e: diff --git a/astrbot/core/utils/session_waiter.py b/astrbot/core/utils/session_waiter.py index a6c62c495..b327a6184 100644 --- a/astrbot/core/utils/session_waiter.py +++ b/astrbot/core/utils/session_waiter.py @@ -71,11 +71,11 @@ class SessionController: asyncio.create_task(self._holding(new_event, timeout)) # 开始新的 keep - async def _holding(self, event: asyncio.Event, timeout_seconds: float) -> None: + async def _holding(self, event: asyncio.Event, timeout: float) -> None: """等待事件结束或超时""" try: - await asyncio.wait_for(event.wait(), timeout_seconds) - except TimeoutError: + await asyncio.wait_for(event.wait(), timeout) + except asyncio.TimeoutError: if not self.future.done(): self.future.set_exception(TimeoutError("等待超时")) except asyncio.CancelledError: @@ -124,14 +124,14 @@ class SessionWaiter: async def register_wait( self, handler: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]], - timeout_seconds: int = 30, + timeout: int = 30, ) -> Any: """等待外部输入并处理""" self.handler = handler USER_SESSIONS[self.session_id] = self # 开始一个会话保持事件 - self.session_controller.keep(timeout_seconds, reset_timeout=True) + self.session_controller.keep(timeout, reset_timeout=True) try: return await self.session_controller.future diff --git a/astrbot/core/utils/temp_dir_cleaner.py b/astrbot/core/utils/temp_dir_cleaner.py index 668ee4513..c0c060098 100644 --- a/astrbot/core/utils/temp_dir_cleaner.py +++ b/astrbot/core/utils/temp_dir_cleaner.py @@ -141,7 +141,7 @@ class TempDirCleaner: self._stop_event.wait(), timeout=self.CHECK_INTERVAL_SECONDS, ) - except TimeoutError: + except asyncio.TimeoutError: continue logger.info("TempDirCleaner stopped.") diff --git a/astrbot/core/utils/tencent_record_helper.py b/astrbot/core/utils/tencent_record_helper.py index 1abd6d1c0..f342484bd 100644 --- a/astrbot/core/utils/tencent_record_helper.py +++ b/astrbot/core/utils/tencent_record_helper.py @@ -5,7 +5,6 @@ import subprocess import tempfile import wave from io import BytesIO -from pathlib import Path from astrbot.core import logger from astrbot.core.utils.astrbot_path import get_astrbot_temp_path @@ -14,18 +13,19 @@ from astrbot.core.utils.astrbot_path import get_astrbot_temp_path async def tencent_silk_to_wav(silk_path: str, output_path: str) -> str: import pysilk - input_data = await asyncio.to_thread(Path(silk_path).read_bytes) - if input_data.startswith(b"\x02"): - input_data = input_data[1:] - input_io = BytesIO(input_data) - output_io = BytesIO() - pysilk.decode(input_io, output_io, 24000) - output_io.seek(0) - with wave.open(output_path, "wb") as wav: - wav.setnchannels(1) - wav.setsampwidth(2) - wav.setframerate(24000) - wav.writeframes(output_io.read()) + with open(silk_path, "rb") as f: + input_data = f.read() + if input_data.startswith(b"\x02"): + input_data = input_data[1:] + input_io = BytesIO(input_data) + output_io = BytesIO() + pysilk.decode(input_io, output_io, 24000) + output_io.seek(0) + with wave.open(output_path, "wb") as wav: + wav.setnchannels(1) + wav.setsampwidth(2) + wav.setframerate(24000) + wav.writeframes(output_io.read()) return output_path @@ -97,10 +97,7 @@ async def convert_to_pcm_wav(input_path: str, output_path: str) -> str: logger.debug(f"[FFmpeg] stderr: {stderr.decode().strip()}") logger.info(f"[FFmpeg] return code: {p.returncode}") - if ( - await asyncio.to_thread(os.path.exists, output_path) - and await asyncio.to_thread(os.path.getsize, output_path) > 0 - ): + if os.path.exists(output_path) and os.path.getsize(output_path) > 0: return output_path raise RuntimeError("生成的WAV文件不存在或为空") @@ -159,12 +156,13 @@ async def audio_to_tencent_silk_base64(audio_path: str) -> tuple[str, float]: tencent=True, ) - silk_bytes = await asyncio.to_thread(Path(silk_path).read_bytes) - silk_b64 = base64.b64encode(silk_bytes).decode("utf-8") + with open(silk_path, "rb") as f: + silk_bytes = await asyncio.to_thread(f.read) + silk_b64 = base64.b64encode(silk_bytes).decode("utf-8") return silk_b64, duration # 已是秒 finally: - if await asyncio.to_thread(os.path.exists, wav_path) and wav_path != audio_path: + if os.path.exists(wav_path) and wav_path != audio_path: os.remove(wav_path) - if await asyncio.to_thread(os.path.exists, silk_path): + if os.path.exists(silk_path): os.remove(silk_path) diff --git a/astrbot/dashboard/routes/api_key.py b/astrbot/dashboard/routes/api_key.py index 6d89de910..4b957fe8e 100644 --- a/astrbot/dashboard/routes/api_key.py +++ b/astrbot/dashboard/routes/api_key.py @@ -1,6 +1,6 @@ import hashlib import secrets -from datetime import UTC, datetime, timedelta +from datetime import datetime, timedelta, timezone from quart import g, request @@ -59,7 +59,7 @@ class ApiKeyRoute(Route): "expires_at": ApiKeyRoute._serialize_datetime(key.expires_at), "revoked_at": ApiKeyRoute._serialize_datetime(key.revoked_at), "is_revoked": key.revoked_at is not None, - "is_expired": bool(expires_at and expires_at < datetime.now(UTC)), + "is_expired": bool(expires_at and expires_at < datetime.now(timezone.utc)), } async def list_api_keys(self): @@ -98,7 +98,9 @@ class ApiKeyRoute(Route): return ( Response().error("expires_in_days must be greater than 0").__dict__ ) - expires_at = datetime.now(UTC) + timedelta(days=expires_in_days_int) + expires_at = datetime.now(timezone.utc) + timedelta( + days=expires_in_days_int + ) raw_key = f"abk_{secrets.token_urlsafe(32)}" key_hash = self._hash_key(raw_key) diff --git a/astrbot/dashboard/routes/auth.py b/astrbot/dashboard/routes/auth.py index f9bdc51d8..40db1f60b 100644 --- a/astrbot/dashboard/routes/auth.py +++ b/astrbot/dashboard/routes/auth.py @@ -82,7 +82,7 @@ class AuthRoute(Route): def generate_jwt(self, username): payload = { "username": username, - "exp": datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=7), + "exp": datetime.datetime.utcnow() + datetime.timedelta(days=7), } jwt_token = self.config["dashboard"].get("jwt_secret", None) if not jwt_token: diff --git a/astrbot/dashboard/routes/backup.py b/astrbot/dashboard/routes/backup.py index 674bbbfdd..952806beb 100644 --- a/astrbot/dashboard/routes/backup.py +++ b/astrbot/dashboard/routes/backup.py @@ -32,18 +32,6 @@ CHUNK_SIZE = 1024 * 1024 # 1MB UPLOAD_EXPIRE_SECONDS = 3600 # 上传会话过期时间(1小时) -def _merge_backup_chunks(output_path: str, chunk_dir: str, total: int) -> None: - with open(output_path, "wb") as outfile: - for i in range(total): - chunk_path = os.path.join(chunk_dir, f"{i}.part") - with open(chunk_path, "rb") as chunk_file: - while True: - data_block = chunk_file.read(8192) - if not data_block: - break - outfile.write(data_block) - - def secure_filename(filename: str) -> str: """清洗文件名,移除路径遍历字符和危险字符 @@ -252,7 +240,7 @@ class BackupRoute(Route): if upload_id in self.upload_sessions: session = self.upload_sessions[upload_id] chunk_dir = session.get("chunk_dir") - if chunk_dir and await asyncio.to_thread(os.path.exists, chunk_dir): + if chunk_dir and os.path.exists(chunk_dir): try: shutil.rmtree(chunk_dir) except Exception as e: @@ -295,9 +283,7 @@ class BackupRoute(Route): page_size = request.args.get("page_size", 20, type=int) # 确保备份目录存在 - await asyncio.to_thread( - Path(self.backup_dir).mkdir, parents=True, exist_ok=True - ) + Path(self.backup_dir).mkdir(parents=True, exist_ok=True) # 获取所有备份文件 backup_files = [] @@ -307,7 +293,7 @@ class BackupRoute(Route): continue file_path = os.path.join(self.backup_dir, filename) - if not await asyncio.to_thread(os.path.isfile, file_path): + if not os.path.isfile(file_path): continue # 读取 manifest.json 获取备份信息 @@ -417,7 +403,7 @@ class BackupRoute(Route): result={ "filename": os.path.basename(zip_path), "path": zip_path, - "size": await asyncio.to_thread(os.path.getsize, zip_path), + "size": os.path.getsize(zip_path), }, ) except Exception as e: @@ -451,9 +437,7 @@ class BackupRoute(Route): unique_filename = generate_unique_filename(safe_filename) # 保存上传的文件 - await asyncio.to_thread( - Path(self.backup_dir).mkdir, parents=True, exist_ok=True - ) + Path(self.backup_dir).mkdir(parents=True, exist_ok=True) zip_path = os.path.join(self.backup_dir, unique_filename) await file.save(zip_path) @@ -467,7 +451,7 @@ class BackupRoute(Route): { "filename": unique_filename, "original_filename": file.filename, - "size": await asyncio.to_thread(os.path.getsize, zip_path), + "size": os.path.getsize(zip_path), } ) .__dict__ @@ -515,7 +499,7 @@ class BackupRoute(Route): # 创建分片存储目录 chunk_dir = os.path.join(self.chunks_dir, upload_id) - await asyncio.to_thread(Path(chunk_dir).mkdir, parents=True, exist_ok=True) + Path(chunk_dir).mkdir(parents=True, exist_ok=True) # 清洗文件名 safe_filename = secure_filename(filename) @@ -701,20 +685,22 @@ class BackupRoute(Route): chunk_dir = session["chunk_dir"] filename = session["filename"] - await asyncio.to_thread( - Path(self.backup_dir).mkdir, parents=True, exist_ok=True - ) + Path(self.backup_dir).mkdir(parents=True, exist_ok=True) output_path = os.path.join(self.backup_dir, filename) try: - await asyncio.to_thread( - _merge_backup_chunks, - output_path, - chunk_dir, - total, - ) + with open(output_path, "wb") as outfile: + for i in range(total): + chunk_path = os.path.join(chunk_dir, f"{i}.part") + with open(chunk_path, "rb") as chunk_file: + # 分块读取,避免内存溢出 + while True: + data_block = chunk_file.read(8192) + if not data_block: + break + outfile.write(data_block) - file_size = await asyncio.to_thread(os.path.getsize, output_path) + file_size = os.path.getsize(output_path) # 标记备份为上传来源(修改 manifest.json 中的 origin 字段) self._mark_backup_as_uploaded(output_path) @@ -739,7 +725,7 @@ class BackupRoute(Route): ) except Exception as e: # 如果合并失败,删除不完整的文件 - if await asyncio.to_thread(os.path.exists, output_path): + if os.path.exists(output_path): os.remove(output_path) raise e @@ -801,7 +787,7 @@ class BackupRoute(Route): return Response().error("无效的文件名").__dict__ zip_path = os.path.join(self.backup_dir, filename) - if not await asyncio.to_thread(os.path.exists, zip_path): + if not os.path.exists(zip_path): return Response().error(f"备份文件不存在: {filename}").__dict__ # 获取知识库管理器(用于构造 importer) @@ -855,7 +841,7 @@ class BackupRoute(Route): return Response().error("无效的文件名").__dict__ zip_path = os.path.join(self.backup_dir, filename) - if not await asyncio.to_thread(os.path.exists, zip_path): + if not os.path.exists(zip_path): return Response().error(f"备份文件不存在: {filename}").__dict__ # 生成任务ID @@ -1002,7 +988,7 @@ class BackupRoute(Route): return Response().error("无效的文件名").__dict__ file_path = os.path.join(self.backup_dir, filename) - if not await asyncio.to_thread(os.path.exists, file_path): + if not os.path.exists(file_path): return Response().error("备份文件不存在").__dict__ return await send_file( @@ -1033,7 +1019,7 @@ class BackupRoute(Route): return Response().error("无效的文件名").__dict__ file_path = os.path.join(self.backup_dir, filename) - if not await asyncio.to_thread(os.path.exists, file_path): + if not os.path.exists(file_path): return Response().error("备份文件不存在").__dict__ os.remove(file_path) @@ -1081,12 +1067,12 @@ class BackupRoute(Route): # 检查原文件是否存在 old_path = os.path.join(self.backup_dir, filename) - if not await asyncio.to_thread(os.path.exists, old_path): + if not os.path.exists(old_path): return Response().error("备份文件不存在").__dict__ # 检查新文件名是否已存在 new_path = os.path.join(self.backup_dir, new_filename) - if await asyncio.to_thread(os.path.exists, new_path): + if os.path.exists(new_path): return Response().error(f"文件名 '{new_filename}' 已存在").__dict__ # 执行重命名 diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index f76aa7f9f..a914f3cbf 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -80,23 +80,17 @@ class ChatRoute(Route): try: file_path = os.path.join(self.attachments_dir, os.path.basename(filename)) - real_file_path = await asyncio.to_thread(os.path.realpath, file_path) - real_imgs_dir = await asyncio.to_thread( - os.path.realpath, self.attachments_dir - ) + real_file_path = os.path.realpath(file_path) + real_imgs_dir = os.path.realpath(self.attachments_dir) - if not await asyncio.to_thread(os.path.exists, real_file_path): + if not os.path.exists(real_file_path): # try legacy file_path = os.path.join( self.legacy_img_dir, os.path.basename(filename) ) - if await asyncio.to_thread(os.path.exists, file_path): - real_file_path = await asyncio.to_thread( - os.path.realpath, file_path - ) - real_imgs_dir = await asyncio.to_thread( - os.path.realpath, self.legacy_img_dir - ) + if os.path.exists(file_path): + real_file_path = os.path.realpath(file_path) + real_imgs_dir = os.path.realpath(self.legacy_img_dir) if not real_file_path.startswith(real_imgs_dir): return Response().error("Invalid file path").__dict__ @@ -123,7 +117,7 @@ class ChatRoute(Route): return Response().error("Attachment not found").__dict__ file_path = attachment.path - real_file_path = await asyncio.to_thread(os.path.realpath, file_path) + real_file_path = os.path.realpath(file_path) return await send_file(real_file_path, mimetype=attachment.mime_type) @@ -350,7 +344,7 @@ class ChatRoute(Route): while True: try: result = await asyncio.wait_for(back_queue.get(), timeout=1) - except TimeoutError: + except asyncio.TimeoutError: continue except asyncio.CancelledError: logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。") @@ -658,7 +652,7 @@ class ChatRoute(Route): try: attachments = await self.db.get_attachments(attachment_ids) for attachment in attachments: - if not await asyncio.to_thread(os.path.exists, attachment.path): + if not os.path.exists(attachment.path): continue try: os.remove(attachment.path) diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 1ed80a218..823d0fb9d 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -1103,9 +1103,7 @@ class ConfigRoute(Route): if not files: return Response().error("No files uploaded").__dict__ - storage_root_path = await asyncio.to_thread( - Path(get_astrbot_plugin_data_path()).resolve, strict=False - ) + storage_root_path = Path(get_astrbot_plugin_data_path()).resolve(strict=False) plugin_root_path = (storage_root_path / name).resolve(strict=False) try: plugin_root_path.relative_to(storage_root_path) @@ -1181,9 +1179,7 @@ class ConfigRoute(Route): if not md: return Response().error(f"Plugin {name} not found").__dict__ - storage_root_path = await asyncio.to_thread( - Path(get_astrbot_plugin_data_path()).resolve, strict=False - ) + storage_root_path = Path(get_astrbot_plugin_data_path()).resolve(strict=False) plugin_root_path = (storage_root_path / name).resolve(strict=False) try: plugin_root_path.relative_to(storage_root_path) @@ -1211,9 +1207,7 @@ class ConfigRoute(Route): if not meta or meta.get("type") != "file": return Response().error("Config item not found or not file type").__dict__ - storage_root_path = await asyncio.to_thread( - Path(get_astrbot_plugin_data_path()).resolve, strict=False - ) + storage_root_path = Path(get_astrbot_plugin_data_path()).resolve(strict=False) plugin_root_path = (storage_root_path / name).resolve(strict=False) try: plugin_root_path.relative_to(storage_root_path) @@ -1381,7 +1375,7 @@ class ConfigRoute(Route): logo_file_path = os.path.join(plugin_dir, platform.logo_path) # 检查文件是否存在并注册令牌 - if await asyncio.to_thread(os.path.exists, logo_file_path): + if os.path.exists(logo_file_path): logo_token = await file_token_service.register_file( logo_file_path, timeout=3600, diff --git a/astrbot/dashboard/routes/knowledge_base.py b/astrbot/dashboard/routes/knowledge_base.py index d06414c01..f0ac5d43d 100644 --- a/astrbot/dashboard/routes/knowledge_base.py +++ b/astrbot/dashboard/routes/knowledge_base.py @@ -729,7 +729,7 @@ class KnowledgeBaseRoute(Route): ) finally: # 清理临时文件 - if await asyncio.to_thread(os.path.exists, temp_file_path): + if os.path.exists(temp_file_path): os.remove(temp_file_path) # 获取知识库 diff --git a/astrbot/dashboard/routes/live_chat.py b/astrbot/dashboard/routes/live_chat.py index 58398d24c..8d0af938d 100644 --- a/astrbot/dashboard/routes/live_chat.py +++ b/astrbot/dashboard/routes/live_chat.py @@ -86,7 +86,7 @@ class LiveChatSession: self.temp_audio_path = audio_path logger.info( - f"[Live Chat] 音频文件已保存: {audio_path}, 大小: {await asyncio.to_thread(os.path.getsize, audio_path)} bytes" + f"[Live Chat] 音频文件已保存: {audio_path}, 大小: {os.path.getsize(audio_path)} bytes" ) return audio_path, time.time() - start_time @@ -491,7 +491,7 @@ class LiveChatRoute(Route): try: result = await asyncio.wait_for(back_queue.get(), timeout=1) - except TimeoutError: + except asyncio.TimeoutError: continue if not result: @@ -790,7 +790,7 @@ class LiveChatRoute(Route): try: result = await asyncio.wait_for(back_queue.get(), timeout=0.5) - except TimeoutError: + except asyncio.TimeoutError: continue if not result: diff --git a/astrbot/dashboard/routes/open_api.py b/astrbot/dashboard/routes/open_api.py index 763d05db0..9a736b176 100644 --- a/astrbot/dashboard/routes/open_api.py +++ b/astrbot/dashboard/routes/open_api.py @@ -369,7 +369,7 @@ class OpenApiRoute(Route): while True: try: result = await asyncio.wait_for(back_queue.get(), timeout=1) - except TimeoutError: + except asyncio.TimeoutError: continue if not result: diff --git a/astrbot/dashboard/routes/plugin.py b/astrbot/dashboard/routes/plugin.py index f3d1d69ee..bb7769926 100644 --- a/astrbot/dashboard/routes/plugin.py +++ b/astrbot/dashboard/routes/plugin.py @@ -6,7 +6,6 @@ import ssl import traceback from dataclasses import dataclass from datetime import datetime -from pathlib import Path import aiohttp import certifi @@ -739,20 +738,19 @@ class PluginRoute(Route): plugin_obj.root_dir_name, ) - if not await asyncio.to_thread(os.path.isdir, plugin_dir): + if not os.path.isdir(plugin_dir): logger.warning(f"无法找到插件目录: {plugin_dir}") return Response().error(f"无法找到插件 {plugin_name} 的目录").__dict__ readme_path = os.path.join(plugin_dir, "README.md") - if not await asyncio.to_thread(os.path.isfile, readme_path): + if not os.path.isfile(readme_path): logger.warning(f"插件 {plugin_name} 没有README文件") return Response().error(f"插件 {plugin_name} 没有README文件").__dict__ try: - readme_content = await asyncio.to_thread( - Path(readme_path).read_text, encoding="utf-8" - ) + with open(readme_path, encoding="utf-8") as f: + readme_content = f.read() return ( Response() @@ -801,7 +799,7 @@ class PluginRoute(Route): plugin_obj.root_dir_name, ) - if not await asyncio.to_thread(os.path.isdir, plugin_dir): + if not os.path.isdir(plugin_dir): logger.warning(f"无法找到插件目录: {plugin_dir}") return Response().error(f"无法找到插件 {plugin_name} 的目录").__dict__ @@ -809,11 +807,10 @@ class PluginRoute(Route): changelog_names = ["CHANGELOG.md", "changelog.md", "CHANGELOG", "changelog"] for name in changelog_names: changelog_path = os.path.join(plugin_dir, name) - if await asyncio.to_thread(os.path.isfile, changelog_path): + if os.path.isfile(changelog_path): try: - changelog_content = await asyncio.to_thread( - Path(changelog_path).read_text, encoding="utf-8" - ) + with open(changelog_path, encoding="utf-8") as f: + changelog_content = f.read() return ( Response() .ok({"content": changelog_content}, "成功获取更新日志") diff --git a/astrbot/dashboard/routes/skills.py b/astrbot/dashboard/routes/skills.py index b003e2010..adad49615 100644 --- a/astrbot/dashboard/routes/skills.py +++ b/astrbot/dashboard/routes/skills.py @@ -1,4 +1,3 @@ -import asyncio import os import re import shutil @@ -183,7 +182,7 @@ class SkillsRoute(Route): logger.error(traceback.format_exc()) return Response().error(str(e)).__dict__ finally: - if temp_path and await asyncio.to_thread(os.path.exists, temp_path): + if temp_path and os.path.exists(temp_path): try: os.remove(temp_path) except Exception: diff --git a/astrbot/dashboard/routes/stat.py b/astrbot/dashboard/routes/stat.py index 238b6aa4c..532238ac7 100644 --- a/astrbot/dashboard/routes/stat.py +++ b/astrbot/dashboard/routes/stat.py @@ -1,4 +1,3 @@ -import asyncio import os import re import threading @@ -215,17 +214,13 @@ class StatRoute(Route): changelog_path = os.path.join(changelogs_dir, filename) # 规范化路径,防止符号链接攻击 - changelog_path = await asyncio.to_thread(os.path.realpath, changelog_path) - changelogs_dir = await asyncio.to_thread(os.path.realpath, changelogs_dir) + changelog_path = os.path.realpath(changelog_path) + changelogs_dir = os.path.realpath(changelogs_dir) # 验证最终路径在预期的 changelogs 目录内(防止路径遍历) # 确保规范化后的路径以 changelogs_dir 开头,且是目录内的文件 - changelog_path_normalized = await asyncio.to_thread( - os.path.normpath, changelog_path - ) - changelogs_dir_normalized = await asyncio.to_thread( - os.path.normpath, changelogs_dir - ) + changelog_path_normalized = os.path.normpath(changelog_path) + changelogs_dir_normalized = os.path.normpath(changelogs_dir) # 检查路径是否在预期目录内(必须是目录的子文件,不能是目录本身) expected_prefix = changelogs_dir_normalized + os.sep @@ -235,22 +230,21 @@ class StatRoute(Route): ) return Response().error("Invalid version format").__dict__ - if not await asyncio.to_thread(os.path.exists, changelog_path): + if not os.path.exists(changelog_path): return ( Response() .error(f"Changelog for version {version} not found") .__dict__ ) - if not await asyncio.to_thread(os.path.isfile, changelog_path): + if not os.path.isfile(changelog_path): return ( Response() .error(f"Changelog for version {version} not found") .__dict__ ) - content = await asyncio.to_thread( - Path(changelog_path).read_text, encoding="utf-8" - ) + with open(changelog_path, encoding="utf-8") as f: + content = f.read() return Response().ok({"content": content, "version": version}).__dict__ except Exception as e: @@ -263,7 +257,7 @@ class StatRoute(Route): project_path = get_astrbot_path() changelogs_dir = os.path.join(project_path, "changelogs") - if not await asyncio.to_thread(os.path.exists, changelogs_dir): + if not os.path.exists(changelogs_dir): return Response().ok({"versions": []}).__dict__ versions = [] diff --git a/main.py b/main.py index b8c42d78e..36c46fca3 100644 --- a/main.py +++ b/main.py @@ -69,13 +69,13 @@ async def check_dashboard_files(webui_dir: str | None = None): """下载管理面板文件""" # 指定webui目录 if webui_dir: - if await asyncio.to_thread(os.path.exists, webui_dir): + if os.path.exists(webui_dir): logger.info(f"使用指定的 WebUI 目录: {webui_dir}") return webui_dir logger.warning(f"指定的 WebUI 目录 {webui_dir} 不存在,将使用默认逻辑。") data_dist_path = os.path.join(get_astrbot_data_path(), "dist") - if await asyncio.to_thread(os.path.exists, data_dist_path): + if os.path.exists(data_dist_path): v = await get_dashboard_version() if v is not None: # 存在文件 diff --git a/pyproject.toml b/pyproject.toml index 408fec56e..b59960aab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,7 +81,7 @@ astrbot = "astrbot.cli.__main__:cli" [tool.ruff] exclude = ["astrbot/core/utils/t2i/local_strategy.py", "astrbot/api/all.py", "tests"] line-length = 88 -target-version = "py312" +target-version = "py310" [tool.ruff.lint] select = [ @@ -99,11 +99,13 @@ ignore = [ "F403", "F405", "E501", + "ASYNC230", # TODO: handle ASYNC230 in AstrBot + "ASYNC240", # TODO: handle ASYNC240 in AstrBot ] [tool.pyright] typeCheckingMode = "basic" -pythonVersion = "3.12" +pythonVersion = "3.10" reportMissingTypeStubs = false reportMissingImports = false include = ["astrbot"] diff --git a/tests/test_skill_manager_sandbox_cache.py b/tests/test_skill_manager_sandbox_cache.py index 5707148c6..88923ec10 100644 --- a/tests/test_skill_manager_sandbox_cache.py +++ b/tests/test_skill_manager_sandbox_cache.py @@ -2,8 +2,6 @@ from __future__ import annotations from pathlib import Path -import pytest - from astrbot.core.skills.skill_manager import SkillManager @@ -58,7 +56,7 @@ def test_list_skills_merges_local_and_sandbox_cache(monkeypatch, tmp_path: Path) assert by_name["custom-local"].description == "local description" assert by_name["custom-local"].path == "skills/custom-local/SKILL.md" assert by_name["python-sandbox"].description == "ship built-in" - assert by_name["python-sandbox"].path == "/workspace/skills/python-sandbox/SKILL.md" + assert by_name["python-sandbox"].path == "skills/python-sandbox/SKILL.md" def test_sandbox_cached_skill_respects_active_and_display_path( @@ -100,8 +98,7 @@ def test_sandbox_cached_skill_respects_active_and_display_path( assert len(all_skills) == 1 assert all_skills[0].path == "/app/skills/browser-automation/SKILL.md" - with pytest.raises(PermissionError): - mgr.set_skill_active("browser-automation", False) + mgr.set_skill_active("browser-automation", False) active_skills = mgr.list_skills(runtime="sandbox", active_only=True) - assert len(active_skills) == 1 - assert active_skills[0].name == "browser-automation" + assert active_skills == [] + diff --git a/tests/unit/test_io_file_to_base64.py b/tests/unit/test_io_file_to_base64.py deleted file mode 100644 index b490ffed8..000000000 --- a/tests/unit/test_io_file_to_base64.py +++ /dev/null @@ -1,16 +0,0 @@ -import base64 - -import pytest - -from astrbot.core.utils.io import file_to_base64 - - -@pytest.mark.asyncio -async def test_file_to_base64_reads_file_async(tmp_path): - sample_file = tmp_path / "sample.bin" - sample_file.write_bytes(b"astrbot") - - result = await file_to_base64(str(sample_file)) - - expected = "base64://" + base64.b64encode(b"astrbot").decode() - assert result == expected