This reverts commit a9c16febf4.
This commit is contained in:
Soulter
2026-03-05 01:34:07 +08:00
parent 2d27bfb6d0
commit 6beca2144c
116 changed files with 1022 additions and 1299 deletions
-2
View File
@@ -46,8 +46,6 @@ jobs:
include:
- language: python
build-mode: none
- language: javascript-typescript
build-mode: none
# CodeQL supports the following values keywords for 'language': 'c-cpp', 'csharp', 'go', 'java-kotlin', 'javascript-typescript', 'python', 'ruby', 'swift'
# Use `c-cpp` to analyze code written in C, C++ or both
# Use 'java-kotlin' to analyze code written in Java, Kotlin or both
+4 -5
View File
@@ -23,13 +23,12 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: "3.12"
- name: Install dependencies
run: |
python -m pip install --upgrade pip uv
uv sync --group dev
python -m pip install --upgrade pip
pip install pytest pytest-asyncio pytest-cov
pip install --editable .
- name: Run tests
run: |
@@ -38,7 +37,7 @@ jobs:
mkdir -p data/temp
export TESTING=true
export ZHIPU_API_KEY=${{ secrets.OPENAI_API_KEY }}
uv run pytest --cov=astrbot -v -o log_cli=true -o log_level=DEBUG
pytest --cov=astrbot -v -o log_cli=true -o log_level=DEBUG
- name: Upload results to Codecov
uses: codecov/codecov-action@v5
+6 -11
View File
@@ -13,23 +13,18 @@ jobs:
- name: Checkout repository
uses: actions/checkout@v6
- name: Setup pnpm
uses: pnpm/action-setup@v4
with:
version: 10.28.2
- name: Setup Node.js
uses: actions/setup-node@v6
with:
node-version: '24.13.0'
cache: "pnpm"
cache-dependency-path: dashboard/pnpm-lock.yaml
- name: Install and build
- name: npm install, build
run: |
pnpm --dir dashboard install --frozen-lockfile
pnpm --dir dashboard run typecheck
pnpm --dir dashboard run build
cd dashboard
npm install pnpm -g
pnpm install
pnpm i --save-dev @types/markdown-it
pnpm run build
- name: Inject Commit SHA
id: get_sha
+12 -32
View File
@@ -25,18 +25,6 @@ jobs:
fetch-depth: 1
fetch-tag: true
- name: Setup pnpm
uses: pnpm/action-setup@v4
with:
version: 10.28.2
- name: Setup Node.js
uses: actions/setup-node@v6
with:
node-version: '24.13.0'
cache: "pnpm"
cache-dependency-path: dashboard/pnpm-lock.yaml
- name: Check for new commits today
if: github.event_name == 'schedule'
id: check-commits
@@ -58,10 +46,12 @@ jobs:
- name: Build Dashboard
run: |
pnpm --dir dashboard install --frozen-lockfile
pnpm --dir dashboard run build
mkdir -p dashboard/dist/assets
echo $(git rev-parse HEAD) > dashboard/dist/assets/version
cd dashboard
npm install
npm run build
mkdir -p dist/assets
echo $(git rev-parse HEAD) > dist/assets/version
cd ..
mkdir -p data
cp -r dashboard/dist data/
@@ -133,18 +123,6 @@ jobs:
fetch-depth: 1
fetch-tag: true
- name: Setup pnpm
uses: pnpm/action-setup@v4
with:
version: 10.28.2
- name: Setup Node.js
uses: actions/setup-node@v6
with:
node-version: '24.13.0'
cache: "pnpm"
cache-dependency-path: dashboard/pnpm-lock.yaml
- name: Get latest tag (only on manual trigger)
id: get-latest-tag
if: github.event_name == 'workflow_dispatch'
@@ -175,10 +153,12 @@ jobs:
- name: Build Dashboard
run: |
pnpm --dir dashboard install --frozen-lockfile
pnpm --dir dashboard run build
mkdir -p dashboard/dist/assets
echo $(git rev-parse HEAD) > dashboard/dist/assets/version
cd dashboard
npm install
npm run build
mkdir -p dist/assets
echo $(git rev-parse HEAD) > dist/assets/version
cd ..
mkdir -p data
cp -r dashboard/dist data/
+2 -27
View File
@@ -18,29 +18,6 @@ permissions:
contents: write
jobs:
verify-core:
name: Verify Core Quality Gate
runs-on: ubuntu-24.04
steps:
- name: Checkout repository
uses: actions/checkout@v6
with:
fetch-depth: 0
ref: ${{ inputs.ref || github.ref }}
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: "3.12"
- name: Install uv
shell: bash
run: python -m pip install uv
- name: Run local PR gate checks
shell: bash
run: make pr-test-neo
build-dashboard:
name: Build Dashboard
runs-on: ubuntu-24.04
@@ -108,8 +85,7 @@ jobs:
VERSION_TAG: ${{ steps.tag.outputs.tag }}
shell: bash
run: |
sudo apt-get update
sudo apt-get install -y rclone
curl https://rclone.org/install.sh | sudo bash
mkdir -p ~/.config/rclone
cat <<EOF > ~/.config/rclone/rclone.conf
@@ -130,7 +106,6 @@ jobs:
name: Publish GitHub Release
runs-on: ubuntu-24.04
needs:
- verify-core
- build-dashboard
steps:
- name: Checkout repository
@@ -251,7 +226,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: "3.12"
python-version: "3.10"
- name: Install uv
shell: bash
+2 -2
View File
@@ -8,7 +8,7 @@ ci:
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.15.1
rev: v0.14.1
hooks:
# Run the linter.
- id: ruff-check
@@ -22,4 +22,4 @@ repos:
rev: v3.21.0
hooks:
- id: pyupgrade
args: [--py312-plus]
args: [--py310-plus]
+3 -2
View File
@@ -13,9 +13,10 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
bash \
ffmpeg \
curl \
gnupg \
git \
nodejs \
npm \
&& curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - \
&& apt-get install -y --no-install-recommends nodejs \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
+1 -1
View File
@@ -19,7 +19,7 @@
<div>
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
<img src="https://img.shields.io/badge/python-3.12+-blue.svg" alt="python">
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<img src="https://deepwiki.com/badge.svg" href="https://deepwiki.com/AstrBotDevs/AstrBot">
<a href="https://zread.ai/AstrBotDevs/AstrBot" target="_blank"><img src="https://img.shields.io/badge/Ask_Zread-_.svg?style=flat&color=00b0aa&labelColor=000000&logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iMTYiIGhlaWdodD0iMTYiIHZpZXdCb3g9IjAgMCAxNiAxNiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTQuOTYxNTYgMS42MDAxSDIuMjQxNTZDMS44ODgxIDEuNjAwMSAxLjYwMTU2IDEuODg2NjQgMS42MDE1NiAyLjI0MDFWNC45NjAxQzEuNjAxNTYgNS4zMTM1NiAxLjg4ODEgNS42MDAxIDIuMjQxNTYgNS42MDAxSDQuOTYxNTZDNS4zMTUwMiA1LjYwMDEgNS42MDE1NiA1LjMxMzU2IDUuNjAxNTYgNC45NjAxVjIuMjQwMUM1LjYwMTU2IDEuODg2NjQgNS4zMTUwMiAxLjYwMDEgNC45NjE1NiAxLjYwMDFaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00Ljk2MTU2IDEwLjM5OTlIMi4yNDE1NkMxLjg4ODEgMTAuMzk5OSAxLjYwMTU2IDEwLjY4NjQgMS42MDE1NiAxMS4wMzk5VjEzLjc1OTlDMS42MDE1NiAxNC4xMTM0IDEuODg4MSAxNC4zOTk5IDIuMjQxNTYgMTQuMzk5OUg0Ljk2MTU2QzUuMzE1MDIgMTQuMzk5OSA1LjYwMTU2IDE0LjExMzQgNS42MDE1NiAxMy43NTk5VjExLjAzOTlDNS42MDE1NiAxMC42ODY0IDUuMzE1MDIgMTAuMzk5OSA0Ljk2MTU2IDEwLjM5OTlaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik0xMy43NTg0IDEuNjAwMUgxMS4wMzg0QzEwLjY4NSAxLjYwMDEgMTAuMzk4NCAxLjg4NjY0IDEwLjM5ODQgMi4yNDAxVjQuOTYwMUMxMC4zOTg0IDUuMzEzNTYgMTAuNjg1IDUuNjAwMSAxMS4wMzg0IDUuNjAwMUgxMy43NTg0QzE0LjExMTkgNS42MDAxIDE0LjM5ODQgNS4zMTM1NiAxNC4zOTg0IDQuOTYwMVYyLjI0MDFDMTQuMzk4NCAxLjg4NjY0IDE0LjExMTkgMS42MDAxIDEzLjc1ODQgMS42MDAxWiIgZmlsbD0iI2ZmZiIvPgo8cGF0aCBkPSJNNCAxMkwxMiA0TDQgMTJaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDQiIHN0cm9rZT0iI2ZmZiIgc3Ryb2tlLXdpZHRoPSIxLjUiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIvPgo8L3N2Zz4K&logoColor=ffffff" alt="zread"/></a>
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?color=76bad9"/></a>
+1 -1
View File
@@ -19,7 +19,7 @@
<div>
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
<img src="https://img.shields.io/badge/python-3.12+-blue.svg" alt="python">
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<img src="https://deepwiki.com/badge.svg" href="https://deepwiki.com/AstrBotDevs/AstrBot">
<a href="https://zread.ai/AstrBotDevs/AstrBot" target="_blank"><img src="https://img.shields.io/badge/Ask_Zread-_.svg?style=flat&color=00b0aa&labelColor=000000&logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iMTYiIGhlaWdodD0iMTYiIHZpZXdCb3g9IjAgMCAxNiAxNiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTQuOTYxNTYgMS42MDAxSDIuMjQxNTZDMS44ODgxIDEuNjAwMSAxLjYwMTU2IDEuODg2NjQgMS42MDE1NiAyLjI0MDFWNC45NjAxQzEuNjAxNTYgNS4zMTM1NiAxLjg4ODEgNS42MDAxIDIuMjQxNTYgNS42MDAxSDQuOTYxNTZDNS4zMTUwMiA1LjYwMDEgNS42MDE1NiA1LjMxMzU2IDUuNjAxNTYgNC45NjAxVjIuMjQwMUM1LjYwMTU2IDEuODg2NjQgNS4zMTUwMiAxLjYwMDEgNC45NjE1NiAxLjYwMDFZIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00Ljk2MTU2IDEwLjM5OTlIMi4yNDE1NkMxLjg4ODEgMTAuMzk5OSAxLjYwMTU2IDEwLjY4NjQgMS42MDE1NiAxMS4wMzk5VjEzLjc1OTlDMS42MDE1NiAxNC4xMTM0IDEuODg4MSAxNC4zOTk5IDIuMjQxNTYgMTQuMzk5OUg0Ljk2MTU2QzUuMzE1MDIgMTQuMzk5OSA1LjYwMTU2IDE0LjExMzQgNS42MDE1NiAxMy43NTk5VjExLjAzOTlDNS42MDE1NiAxMC42ODY0IDUuMzE1MDIgMTAuMzk5OSA0Ljk2MTU2IDEwLjM5OTlaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik0xMy43NTg0IDEuNjAwMUgxMS4wMzg0QzEwLjY4NSAxLjYwMDEgMTAuMzk4NCAxLjg4NjY0IDEwLjM5ODQgMi4yNDAxVjQuOTYwMUMxMC4zOTg0IDUuMzEzNTYgMTAuNjg1IDUuNjAwMSAxMS4wMzg0IDUuNjAwMUgxMy43NTg0QzE0LjExMTkgNS42MDAxIDE0LjM5ODQgNS4zMTM1NiAxNC4zOTg0IDQuOTYwMVYyLjI0MDFDMTQuMzk4NCAxLjg4NjY0IDE0LjExMTkgMS42MDAxIDEzLjc1ODQgMS42MDAxWiIgZmlsbD0iI2ZmZiIvPgo8cGF0aCBkPSJNNCAxMkwxMiA0TDQgMTJaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDQiIHN0cm9rZT0iI2ZmZiIgc3Ryb2tlLXdpZHRoPSIxLjUiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIvPgo8L3N2Zz4K&logoColor=ffffff" alt="zread"/></a>
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?color=76bad9"/></a>
+1 -1
View File
@@ -19,7 +19,7 @@
<div>
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
<img src="https://img.shields.io/badge/python-3.12+-blue.svg" alt="python">
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<img src="https://deepwiki.com/badge.svg" href="https://deepwiki.com/AstrBotDevs/AstrBot">
<a href="https://zread.ai/AstrBotDevs/AstrBot" target="_blank"><img src="https://img.shields.io/badge/Ask_Zread-_.svg?style=flat&color=00b0aa&labelColor=000000&logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iMTYiIGhlaWdodD0iMTYiIHZpZXdCb3g9IjAgMCAxNiAxNiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTQuOTYxNTYgMS42MDAxSDIuMjQxNTZDMS44ODgxIDEuNjAwMSAxLjYwMTU2IDEuODg2NjQgMS42MDE1NiAyLjI0MDFWNC45NjAxQzEuNjAxNTYgNS4zMTM1NiAxLjg4ODEgNS42MDAxIDIuMjQxNTYgNS42MDAxSDQuOTYxNTZDNS4zMTUwMiA1LjYwMDEgNS42MDE1NiA1LjMxMzU2IDUuNjAxNTYgNC45NjAxVjIuMjQwMUM1LjYwMTU2IDEuODg2NjQgNS4zMTUwMiAxLjYwMDEgNC45NjE1NiAxLjYwMDFZIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00Ljk2MTU2IDEwLjM5OTlIMi4yNDE1NkMxLjg4ODEgMTAuMzk5OSAxLjYwMTU2IDEwLjY4NjQgMS42MDE1NiAxMS4wMzk5VjEzLjc1OTlDMS42MDE1NiAxNC4xMTM0IDEuODg4MSAxNC4zOTk5IDIuMjQxNTYgMTQuMzk5OUg0LjYxNTZDNS4zMTUwMiAxNC4zOTk5IDUuNjAxNTYgMTQuMTEzNCA1LjYwMTU2IDEzLjc1OTlWMTEuMDM5OUM1LjYwMTU2IDEwLjY4NjQgNS4zMTUwMiAxMC4zOTk5IDQuOTYxNTYgMTAuMzk5OVoiIGZpbGw9IiNmZmYiLz4KPHBhdGggZD0iTTEzLjc1ODQgMS42MDAxSDExLjAzODRDMTAuNjg1IDEuNjAwMSAxMC4zOTg0IDEuODg2NjQgMTAuMzk4NCAyLjI0MDFWNC45NjAxQzEwLjM5ODQgNS4zMTM1NiAxMC42ODUgNS42MDAxIDExLjAzODQgNS42MDAxSDEzLjc1ODRDMTQuMTExOSA1LjYwMDEgMTQuMzk4NCA1LjMxMzU2IDE0LjM5ODQgNC45NjAxVjIuMjQwMUMxNC4zOTg0IDEuODg2NjQgMTQuMTExOSAxLjYwMDEgMTMuNzU4NCAxLjYwMDFZIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDRMNCAxMlpFIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDQiIHN0cm9rZT0iI2ZmZiIgc3Ryb2tlLXdpZHRoPSIxLjUiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIvPgo8L3N2Zz4K&logoColor=ffffff" alt="zread"/></a>
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?color=76bad9"/></a>
+1 -1
View File
@@ -19,7 +19,7 @@
<div>
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
<img src="https://img.shields.io/badge/python-3.12+-blue.svg" alt="python">
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<img src="https://deepwiki.com/badge.svg" href="https://deepwiki.com/AstrBotDevs/AstrBot">
<a href="https://zread.ai/AstrBotDevs/AstrBot" target="_blank"><img src="https://img.shields.io/badge/Ask_Zread-_.svg?style=flat&color=00b0aa&labelColor=000000&logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iMTYiIGhlaWdodD0iMTYiIHZpZXdCb3g9IjAgMCAxNiAxNiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTQuOTYxNTYgMS42MDAxSDIuMjQxNTZDMS44ODgxIDEuNjAwMSAxLjYwMTU2IDEuODg2NjQgMS42MDE1NiAyLjI0MDFWNC45NjAxQzEuNjAxNTYgNS4zMTM1NiAxLjg4ODEgNS42MDAxIDIuMjQxNTYgNS42MDAxSDQuOTYxNTZDNS4zMTUwMiA1LjYwMDEgNS42MDE1NiA1LjMxMzU2IDUuNjAxNTYgNC45NjAxVjIuMjQwMUM1LjYwMTU2IDEuODg2NjQgNS4zMTUwMiAxLjYwMDEgNC45NjE1NiAxLjYwMDFZIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00Ljk2MTU2IDEwLjM5OTlIMi4yNDE1NkMxLjg4ODEgMTAuMzk5OSAxLjYwMTU2IDEwLjY4NjQgMS42MDE1NiAxMS4wMzk5VjEzLjc1OTlDMS42MDE1NiAxNC4xMTM0IDEuODg4MSAxNC4zOTk5IDIuMjQxNTYgMTQuMzk5OUg0Ljk2MTU2QzUuMzE1MDIgMTQuMzk5OSA1LjYwMTU2IDE0LjExMzQgNS42MDE1NiAxMy43NTk5VjExLjAzOTlDNS42MDE1NiAxMC42ODY0IDUuMzE1MDIgMTAuMzk5OSA0Ljk2MTU2IDEwLjM5OTlaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik0xMy43NTg0IDEuNjAwMUgxMS4wMzg0QzEwLjY4NSAxLjYwMDEgMTAuMzk4NCAxLjg4NjY0IDEwLjM5ODQgMi4yNDAxVjQuOTYwMUMxMC4zOTg0IDUuMzEzNTYgMTAuNjg1IDUuNjAwMSAxMS4wMzg0IDUuNjAwMUgxMy43NTg0QzE0LjExMTkgNS42MDAxIDE0LjM5ODQgNS4zMTM1NiAxNC4zOTg0IDQuOTYwMVYyLjI0MDFDMTQuMzk4NCAxLjg4NjY0IDE0LjExMTkgMS42MDAxIDEzLjczODQgMS42MDAxWiIgZmlsbD0iI2ZmZiIvPgo8cGF0aCBkPSJNNCAxMkwxMiA0TDQgMTJaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDQiIHN0cm9rZT0iI2ZmZiIgc3Ryb2tlLXdpZHRoPSIxLjUiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIvPgo8L3N2Zz4K&logoColor=ffffff" alt="zread"/></a>
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?color=76bad9"/></a>
+1 -1
View File
@@ -19,7 +19,7 @@
<div>
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
<img src="https://img.shields.io/badge/python-3.12+-blue.svg" alt="python">
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<img src="https://deepwiki.com/badge.svg" href="https://deepwiki.com/AstrBotDevs/AstrBot">
<a href="https://zread.ai/AstrBotDevs/AstrBot" target="_blank"><img src="https://img.shields.io/badge/Ask_Zread-_.svg?style=flat&color=00b0aa&labelColor=000000&logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iMTYiIGhlaWdodD0iMTYiIHZpZXdCb3g9IjAgMCAxNiAxNiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTQuOTYxNTYgMS42MDAxSDIuMjQxNTZDMS44ODgxIDEuNjAwMSAxLjYwMTU2IDEuODg2NjQgMS42MDE1NiAyLjI0MDFWNC45NjAxQzEuNjAxNTYgNS4zMTM1NiAxLjg4ODEgNS42MDAxIDIuMjQxNTYgNS42MDAxSDQuOTYxNTZDNS4zMTUwMiA1LjYwMDEgNS42MDE1NiA1LjMxMzU2IDUuNjAxNTYgNC45NjAxVjIuMjQwMUM1LjYwMTU2IDEuODg2NjQgNS4zMTUwMiAxLjYwMDEgNC45NjE1NiAxLjYwMDFaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00Ljk2MTU2IDEwLjM5OTlIMi4yNDE1NkMxLjg4ODEgMTAuMzk5OSAxLjYwMTU2IDEwLjY4NjQgMS42MDE1NiAxMS4wMzk5VjEzLjc1OTlDMS42MDE1NiAxNC4xMTM0IDEuODg4MSAxNC4zOTk5IDIuMjQxNTYgMTQuMzk5OUg0Ljk2MTU2QzUuMzE1MDIgMTQuMzk5OSA1LjYwMTU2IDE0LjExMzQgNS42MDE1NiAxMy43NTk5VjExLjAzOTlDNS42MDE1NiAxMC42ODY0IDUuMzE1MDIgMTAuMzk5OSA0Ljk2MTU2IDEwLjM5OTlaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik0xMy43NTg0IDEuNjAwMUgxMS4wMzg0QzEwLjY4NSAxLjYwMDEgMTAuMzk4NCAxLjg4NjY0IDEwLjM5ODQgMi4yNDAxVjQuOTYwMUMxMC4zOTg0IDUuMzEzNTYgMTAuNjg1IDUuNjAwMSAxMS4wMzg0IDUuNjAwMUgxMy43NTg0QzE0LjExMTkgNS42MDAxIDE0LjM5ODQgNS4zMTM1NiAxNC4zOTg0IDQuOTYwMVYyLjI0MDFDMTQuMzk4NCAxLjg4NjY0IDE0LjExMTkgMS42MDAxIDEzLjc1ODQgMS42MDAxWiIgZmlsbD0iI2ZmZiIvPgo8cGF0aCBkPSJNNCAxMkwxMiA0TDQgMTJaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDQiIHN0cm9rZT0iI2ZmZiIgc3Ryb2tlLXdpZHRoPSIxLjUiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIvPgo8L3N2Zz4K&logoColor=ffffff" alt="zread"/></a>
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?color=76bad9"/></a>
+1 -1
View File
@@ -17,7 +17,7 @@
<div>
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
<img src="https://img.shields.io/badge/python-3.12+-blue.svg" alt="python">
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<img src="https://deepwiki.com/badge.svg" href="https://deepwiki.com/AstrBotDevs/AstrBot">
<a href="https://zread.ai/AstrBotDevs/AstrBot" target="_blank"><img src="https://img.shields.io/badge/Ask_Zread-_.svg?style=flat&color=00b0aa&labelColor=000000&logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iMTYiIGhlaWdodD0iMTYiIHZpZXdCb3g9IjAgMCAxNiAxNiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTQuOTYxNTYgMS42MDAxSDIuMjQxNTZDMS44ODgxIDEuNjAwMSAxLjYwMTU2IDEuODg2NjQgMS42MDE1NiAyLjI0MDFWNC45NjAxQzEuNjAxNTYgNS4zMTM1NiAxLjg4ODEgNS42MDAxIDIuMjQxNTYgNS42MDAxSDQuOTYxNTZDNS4zMTUwMiA1LjYwMDEgNS42MDE1NiA1LjMxMzU2IDUuNjAxNTYgNC45NjAxVjIuMjQwMUM1LjYwMTU2IDEuODg2NjQgNS4zMTUwMiAxLjYwMDEgNC45NjE1NiAxLjYwMDFaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00Ljk2MTU2IDEwLjM5OTlIMi4yNDE1NkMxLjg4ODEgMTAuMzk5OSAxLjYwMTU2IDEwLjY4NjQgMS42MDE1NiAxMS4wMzk5VjEzLjc1OTlDMS42MDE1NiAxNC4xMTM0IDEuODg4MSAxNC4zOTk5IDIuMjQxNTYgMTQuMzk5OUg0Ljk2MTU2QzUuMzE1MDIgMTQuMzk5OSA1LjYwMTU2IDE0LjExMzQgNS42MDE1NiAxMy43NTk5VjExLjAzOTlDNS42MDE1NiAxMC42ODY0IDUuMzE1MDIgMTAuMzk5OSA0Ljk2MTU2IDEwLjM5OTlaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik0xMy43NTg0IDEuNjAwMUgxMS4wMzg0QzEwLjY4NSAxLjYwMDEgMTAuMzk4NCAxLjg4NjY0IDEwLjM5ODQgMi4yNDAxVjQuOTYwMUMxMC4zOTg0IDUuMzEzNTYgMTAuNjg1IDUuNjAwMSAxMS4wMzg0IDUuNjAwMUgxMy43NTg0QzE0LjExMTkgNS42MDAxIDE0LjM5ODQgNS4zMTM1NiAxNC4zOTg0IDQuOTYwMVYyLjI0MDFDMTQuMzk4NCAxLjg4NjY0IDE0LjExMTkgMS42MDAxIDEzLjc1ODQgMS42MDAxWiIgZmlsbD0iI2ZmZiIvPgo8cGF0aCBkPSJNNCAxMkwxMiA0TDQgMTJaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDQiIHN0cm9rZT0iI2ZmZiIgc3Ryb2tlLXdpZHRoPSIxLjUiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIvPgo8L3N2Zz4K&logoColor=ffffff" alt="zread"/></a>
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?color=76bad9"/></a>
+2 -2
View File
@@ -1,6 +1,6 @@
import shutil
import tempfile
from enum import StrEnum
from enum import Enum
from io import BytesIO
from pathlib import Path
from zipfile import ZipFile
@@ -12,7 +12,7 @@ import yaml
from .version_comparator import VersionComparator
class PluginStatus(StrEnum):
class PluginStatus(str, Enum):
INSTALLED = "installed"
NEED_UPDATE = "needs-update"
NOT_INSTALLED = "not-installed"
+3 -2
View File
@@ -1,12 +1,13 @@
from dataclasses import dataclass
from typing import Any
from typing import Any, Generic
from .hooks import BaseAgentRunHooks
from .run_context import TContext
from .tool import FunctionTool
@dataclass
class Agent[TContext]:
class Agent(Generic[TContext]):
name: str
instructions: str | None = None
tools: list[str | FunctionTool] | None = None
+4 -1
View File
@@ -1,8 +1,11 @@
from typing import Generic
from .agent import Agent
from .run_context import TContext
from .tool import FunctionTool
class HandoffTool[TContext](FunctionTool):
class HandoffTool(FunctionTool, Generic[TContext]):
"""Handoff tool for delegating tasks to another agent."""
def __init__(
+4 -2
View File
@@ -1,12 +1,14 @@
from typing import Generic
import mcp
from astrbot.core.agent.tool import FunctionTool
from astrbot.core.provider.entities import LLMResponse
from .run_context import ContextWrapper
from .run_context import ContextWrapper, TContext
class BaseAgentRunHooks[TContext]:
class BaseAgentRunHooks(Generic[TContext]):
async def on_agent_begin(self, run_context: ContextWrapper[TContext]) -> None: ...
async def on_tool_start(
self,
+4 -2
View File
@@ -2,6 +2,7 @@ import asyncio
import logging
from contextlib import AsyncExitStack
from datetime import timedelta
from typing import Generic
from tenacity import (
before_sleep_log,
@@ -15,6 +16,7 @@ from astrbot import logger
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.utils.log_pipe import LogPipe
from .run_context import TContext
from .tool import FunctionTool
try:
@@ -99,7 +101,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
return True, ""
return False, f"HTTP {response.status}: {response.reason}"
except TimeoutError:
except asyncio.TimeoutError:
return False, f"Connection timeout: {timeout} seconds"
except Exception as e:
return False, f"{e!s}"
@@ -358,7 +360,7 @@ class MCPClient:
self.running_event.set()
class MCPTool[TContext](FunctionTool):
class MCPTool(FunctionTool, Generic[TContext]):
"""A function tool that calls an MCP service."""
def __init__(
+2 -2
View File
@@ -7,7 +7,7 @@ from astrbot.core.provider.entities import LLMResponse
from ..hooks import BaseAgentRunHooks
from ..response import AgentResponse
from ..run_context import ContextWrapper
from ..run_context import ContextWrapper, TContext
class AgentState(Enum):
@@ -19,7 +19,7 @@ class AgentState(Enum):
ERROR = auto() # Error state
class BaseAgentRunner[TContext]:
class BaseAgentRunner(T.Generic[TContext]):
@abc.abstractmethod
async def reset(
self,
@@ -1,7 +1,7 @@
import base64
import json
import sys
import typing as T
from typing import override
import astrbot.core.message.components as Comp
from astrbot import logger
@@ -18,6 +18,11 @@ from ...run_context import ContextWrapper, TContext
from ..base import AgentResponse, AgentState, BaseAgentRunner
from .coze_api_client import CozeAPIClient
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
class CozeAgentRunner(BaseAgentRunner[TContext]):
"""Coze Agent Runner"""
@@ -246,7 +251,7 @@ class CozeAgentRunner(BaseAgentRunner[TContext]):
conversation_id=conversation_id,
auto_save_history=self.auto_save_history,
stream=True,
timeout_seconds=self.timeout,
timeout=self.timeout,
):
event_type = chunk.get("event")
data = chunk.get("data", {})
@@ -2,7 +2,6 @@ import asyncio
import io
import json
from collections.abc import AsyncGenerator
from pathlib import Path
from typing import Any
import aiohttp
@@ -91,7 +90,7 @@ class CozeAPIClient:
logger.debug(f"[Coze] 图片上传成功,file_id: {file_id}")
return file_id
except TimeoutError:
except asyncio.TimeoutError:
logger.error("文件上传超时")
raise Exception("文件上传超时")
except Exception as e:
@@ -129,7 +128,7 @@ class CozeAPIClient:
conversation_id: str | None = None,
auto_save_history: bool = True,
stream: bool = True,
timeout_seconds: float = 120,
timeout: float = 120,
) -> AsyncGenerator[dict[str, Any], None]:
"""发送聊天消息并返回流式响应
@@ -140,7 +139,7 @@ class CozeAPIClient:
conversation_id: 会话ID
auto_save_history: 是否自动保存历史
stream: 是否流式响应
timeout_seconds: 超时时间
timeout: 超时时间
"""
session = await self._ensure_session()
@@ -167,7 +166,7 @@ class CozeAPIClient:
url,
json=payload,
params=params,
timeout=aiohttp.ClientTimeout(total=timeout_seconds),
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
if response.status == 401:
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
@@ -204,8 +203,8 @@ class CozeAPIClient:
except json.JSONDecodeError:
event_data = {"content": data_str}
except TimeoutError:
raise Exception(f"Coze API 流式请求超时 ({timeout_seconds}秒)")
except asyncio.TimeoutError:
raise Exception(f"Coze API 流式请求超时 ({timeout}秒)")
except Exception as e:
raise Exception(f"Coze API 流式请求失败: {e!s}")
@@ -237,7 +236,7 @@ class CozeAPIClient:
except json.JSONDecodeError:
raise Exception("Coze API 返回非JSON格式")
except TimeoutError:
except asyncio.TimeoutError:
raise Exception("Coze API 请求超时")
except aiohttp.ClientError as e:
raise Exception(f"Coze API 请求失败: {e!s}")
@@ -295,7 +294,8 @@ if __name__ == "__main__":
client = CozeAPIClient(api_key=api_key)
try:
file_data = await asyncio.to_thread(Path("README.md").read_bytes)
with open("README.md", "rb") as f:
file_data = f.read()
file_id = await client.upload_file(file_data)
print(f"Uploaded file_id: {file_id}")
async for event in client.chat_messages(
@@ -2,9 +2,9 @@ import asyncio
import functools
import queue
import re
import sys
import threading
import typing as T
from typing import override
from dashscope import Application
from dashscope.app.application_response import ApplicationResponse
@@ -22,6 +22,11 @@ from ...response import AgentResponseData
from ...run_context import ContextWrapper, TContext
from ..base import AgentResponse, AgentState, BaseAgentRunner
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
class DashscopeAgentRunner(BaseAgentRunner[TContext]):
"""Dashscope Agent Runner"""
@@ -1,10 +1,10 @@
import asyncio
import hashlib
import json
import sys
import typing as T
from collections import deque
from dataclasses import dataclass, field
from typing import override
from uuid import uuid4
import astrbot.core.message.components as Comp
@@ -40,6 +40,11 @@ from .deerflow_stream_utils import (
get_message_id,
)
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
class DeerFlowAgentRunner(BaseAgentRunner[TContext]):
"""DeerFlow Agent Runner via LangGraph HTTP API."""
@@ -373,9 +378,7 @@ class DeerFlowAgentRunner(BaseAgentRunner[TContext]):
if thread_id:
return thread_id
thread = await self.api_client.create_thread(
timeout_seconds=min(30, self.timeout)
)
thread = await self.api_client.create_thread(timeout=min(30, self.timeout))
thread_id = thread.get("thread_id", "")
if not thread_id:
raise Exception(
@@ -636,7 +639,7 @@ class DeerFlowAgentRunner(BaseAgentRunner[TContext]):
async for event in self.api_client.stream_run(
thread_id=thread_id,
payload=payload,
timeout_seconds=self.timeout,
timeout=self.timeout,
):
event_type = event.get("event")
data = event.get("data")
@@ -663,7 +666,7 @@ class DeerFlowAgentRunner(BaseAgentRunner[TContext]):
if event_type == "end":
break
except TimeoutError:
except (asyncio.TimeoutError, TimeoutError):
logger.warning(
"DeerFlow stream timed out after %ss for thread_id=%s; returning partial result.",
self.timeout,
@@ -139,7 +139,7 @@ class DeerFlowAPIClient:
) -> None:
await self.close()
async def create_thread(self, timeout_seconds: float = 20) -> dict[str, Any]:
async def create_thread(self, timeout: float = 20) -> dict[str, Any]:
session = self._get_session()
url = f"{self.api_base}/api/langgraph/threads"
payload = {"metadata": {}}
@@ -147,7 +147,7 @@ class DeerFlowAPIClient:
url,
json=payload,
headers=self.headers,
timeout=timeout_seconds,
timeout=timeout,
proxy=self.proxy,
) as resp:
if resp.status not in (200, 201):
@@ -161,7 +161,7 @@ class DeerFlowAPIClient:
self,
thread_id: str,
payload: dict[str, Any],
timeout_seconds: float = 120,
timeout: float = 120,
) -> AsyncGenerator[dict[str, Any], None]:
session = self._get_session()
url = f"{self.api_base}/api/langgraph/threads/{thread_id}/runs/stream"
@@ -183,9 +183,9 @@ class DeerFlowAPIClient:
# Use socket read timeout so active heartbeats/chunks can keep the stream alive.
stream_timeout = ClientTimeout(
total=None,
connect=min(timeout_seconds, 30),
sock_connect=min(timeout_seconds, 30),
sock_read=timeout_seconds,
connect=min(timeout, 30),
sock_connect=min(timeout, 30),
sock_read=timeout,
)
async with session.post(
url,
@@ -1,7 +1,7 @@
import base64
import os
import sys
import typing as T
from typing import override
import astrbot.core.message.components as Comp
from astrbot.core import logger, sp
@@ -19,6 +19,11 @@ from ...run_context import ContextWrapper, TContext
from ..base import AgentResponse, AgentState, BaseAgentRunner
from .dify_api_client import DifyAPIClient
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
class DifyAgentRunner(BaseAgentRunner[TContext]):
"""Dify Agent Runner"""
@@ -171,7 +176,7 @@ class DifyAgentRunner(BaseAgentRunner[TContext]):
user=session_id,
conversation_id=conversation_id,
files=files_payload,
timeout_seconds=self.timeout,
timeout=self.timeout,
):
logger.debug(f"dify resp chunk: {chunk}")
if chunk["event"] == "message" or chunk["event"] == "agent_message":
@@ -211,7 +216,7 @@ class DifyAgentRunner(BaseAgentRunner[TContext]):
},
user=session_id,
files=files_payload,
timeout_seconds=self.timeout,
timeout=self.timeout,
):
logger.debug(f"dify workflow resp chunk: {chunk}")
match chunk["event"]:
@@ -1,8 +1,6 @@
import asyncio
import codecs
import json
from collections.abc import AsyncGenerator
from pathlib import Path
from typing import Any
from aiohttp import ClientResponse, ClientSession, FormData
@@ -49,20 +47,20 @@ class DifyAPIClient:
response_mode: str = "streaming",
conversation_id: str = "",
files: list[dict[str, Any]] | None = None,
timeout_seconds: float = 60,
timeout: float = 60,
) -> AsyncGenerator[dict[str, Any], None]:
if files is None:
files = []
url = f"{self.api_base}/chat-messages"
payload = locals()
payload.pop("self")
payload.pop("timeout_seconds")
payload.pop("timeout")
logger.info(f"chat_messages payload: {payload}")
async with self.session.post(
url,
json=payload,
headers=self.headers,
timeout=timeout_seconds,
timeout=timeout,
) as resp:
if resp.status != 200:
text = await resp.text()
@@ -78,20 +76,20 @@ class DifyAPIClient:
user: str,
response_mode: str = "streaming",
files: list[dict[str, Any]] | None = None,
timeout_seconds: float = 60,
timeout: float = 60,
):
if files is None:
files = []
url = f"{self.api_base}/workflows/run"
payload = locals()
payload.pop("self")
payload.pop("timeout_seconds")
payload.pop("timeout")
logger.info(f"workflow_run payload: {payload}")
async with self.session.post(
url,
json=payload,
headers=self.headers,
timeout=timeout_seconds,
timeout=timeout,
) as resp:
if resp.status != 200:
text = await resp.text()
@@ -136,13 +134,14 @@ class DifyAPIClient:
# 使用文件路径
import os
file_content = await asyncio.to_thread(Path(file_path).read_bytes)
form.add_field(
"file",
file_content,
filename=os.path.basename(file_path),
content_type=mime_type or "application/octet-stream",
)
with open(file_path, "rb") as f:
file_content = f.read()
form.add_field(
"file",
file_content,
filename=os.path.basename(file_path),
content_type=mime_type or "application/octet-stream",
)
else:
raise ValueError("file_path 和 file_data 不能同时为 None")
@@ -1,10 +1,10 @@
import asyncio
import copy
import sys
import time
import traceback
import typing as T
from dataclasses import dataclass, field
from typing import override
from mcp.types import (
BlobResourceContents,
@@ -44,6 +44,11 @@ from ..run_context import ContextWrapper, TContext
from ..tool_executor import BaseFunctionToolExecutor
from .base import AgentResponse, AgentState, BaseAgentRunner
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
@dataclass(slots=True)
class _HandleFunctionToolsResult:
+3 -3
View File
@@ -1,6 +1,6 @@
import copy
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any
from typing import Any, Generic
import jsonschema
import mcp
@@ -10,7 +10,7 @@ from pydantic.dataclasses import dataclass
from astrbot.core.message.message_event_result import MessageEventResult
from .run_context import ContextWrapper
from .run_context import ContextWrapper, TContext
ParametersType = dict[str, Any]
ToolExecResult = str | mcp.types.CallToolResult
@@ -38,7 +38,7 @@ class ToolSchema:
@dataclass
class FunctionTool[TContext](ToolSchema):
class FunctionTool(ToolSchema, Generic[TContext]):
"""A callable tool, for function calling."""
handler: (
+3 -3
View File
@@ -1,13 +1,13 @@
from collections.abc import AsyncGenerator
from typing import Any
from typing import Any, Generic
import mcp
from .run_context import ContextWrapper
from .run_context import ContextWrapper, TContext
from .tool import FunctionTool
class BaseFunctionToolExecutor[TContext]:
class BaseFunctionToolExecutor(Generic[TContext]):
@classmethod
async def execute(
cls,
+2 -2
View File
@@ -3,7 +3,6 @@ import re
import time
import traceback
from collections.abc import AsyncGenerator
from pathlib import Path
from astrbot.core import logger
from astrbot.core.agent.message import Message
@@ -510,7 +509,8 @@ async def _simulated_stream_tts(
audio_path = await tts_provider.get_audio(text)
if audio_path:
audio_data = await asyncio.to_thread(Path(audio_path).read_bytes)
with open(audio_path, "rb") as f:
audio_data = f.read()
await audio_queue.put((text, audio_data))
except Exception as e:
logger.error(
+1 -1
View File
@@ -625,7 +625,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
exc_info=True,
)
yield None
except TimeoutError:
except asyncio.TimeoutError:
raise Exception(
f"tool {tool.name} execution timeout after {tool_call_timeout or run_context.tool_call_timeout} seconds.",
)
+1 -2
View File
@@ -1,4 +1,3 @@
import asyncio
import base64
import json
import os
@@ -242,7 +241,7 @@ class SendMessageToUserTool(FunctionTool[AstrAgentContext]):
bool: indicates whether the file was downloaded from sandbox.
"""
if await asyncio.to_thread(os.path.exists, path):
if os.path.exists(path):
return path, False
# Try to check if the file exists in the sandbox
+16 -45
View File
@@ -4,12 +4,11 @@
导出格式为 JSON这是数据库无关的方案支持未来向 MySQL/PostgreSQL 迁移
"""
import asyncio
import hashlib
import json
import os
import zipfile
from datetime import UTC, datetime
from datetime import datetime, timezone
from pathlib import Path
from typing import TYPE_CHECKING, Any
@@ -84,7 +83,7 @@ class AstrBotExporter:
output_dir = get_astrbot_backups_path()
# 确保输出目录存在
await asyncio.to_thread(Path(output_dir).mkdir, parents=True, exist_ok=True)
Path(output_dir).mkdir(parents=True, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
zip_filename = f"astrbot_backup_{timestamp}.zip"
@@ -161,10 +160,9 @@ class AstrBotExporter:
# 3. 导出配置文件
if progress_callback:
await progress_callback("config", 0, 100, "正在导出配置文件...")
config_content = await asyncio.to_thread(
self._read_text_if_exists, self.config_path
)
if config_content is not None:
if os.path.exists(self.config_path):
with open(self.config_path, encoding="utf-8") as f:
config_content = f.read()
zf.writestr("config/cmd_config.json", config_content)
self._add_checksum("config/cmd_config.json", config_content)
if progress_callback:
@@ -201,7 +199,7 @@ class AstrBotExporter:
except Exception as e:
logger.error(f"备份导出失败: {e}")
# 清理失败的文件
if await asyncio.to_thread(os.path.exists, zip_path):
if os.path.exists(zip_path):
os.remove(zip_path)
raise
@@ -319,7 +317,7 @@ class AstrBotExporter:
for dir_name, dir_path in backup_directories.items():
full_path = Path(dir_path)
if not await asyncio.to_thread(full_path.exists):
if not full_path.exists():
logger.debug(f"目录不存在,跳过: {full_path}")
continue
@@ -361,44 +359,17 @@ class AstrBotExporter:
self, zf: zipfile.ZipFile, attachments: list[dict]
) -> None:
"""导出附件文件"""
await asyncio.to_thread(self._export_attachments_sync, zf, attachments)
def _export_attachments_sync(
self, zf: zipfile.ZipFile, attachments: list[dict]
) -> None:
"""在单个线程中批量导出附件,减少高频线程切换。"""
for attachment in attachments:
file_path = attachment.get("path", "")
attachment_id = attachment.get("attachment_id")
try:
if not file_path:
continue
if not attachment_id:
logger.warning(
f"跳过附件导出:attachment_id 为空 (path={file_path})"
)
continue
# 使用 attachment_id 作为文件名
ext = os.path.splitext(file_path)[1]
archive_path = f"files/attachments/{attachment_id}{ext}"
zf.write(file_path, archive_path)
except FileNotFoundError:
# 和旧逻辑保持一致:缺失文件直接跳过。
continue
except OSError as e:
logger.warning(
f"导出附件失败 (path={file_path}, attachment_id={attachment_id or 'unknown'}): {e}"
)
file_path = attachment.get("path", "")
if file_path and os.path.exists(file_path):
# 使用 attachment_id 作为文件名
attachment_id = attachment.get("attachment_id", "")
ext = os.path.splitext(file_path)[1]
archive_path = f"files/attachments/{attachment_id}{ext}"
zf.write(file_path, archive_path)
except Exception as e:
logger.warning(
f"导出附件时发生非预期错误,已跳过 (path={file_path}, attachment_id={attachment_id or 'unknown'}): {e}"
)
def _read_text_if_exists(self, file_path: str) -> str | None:
"""Read text file when it exists in a single synchronous call."""
if not os.path.exists(file_path):
return None
return Path(file_path).read_text(encoding="utf-8")
logger.warning(f"导出附件失败: {e}")
def _model_to_dict(self, record: Any) -> dict:
"""将 SQLModel 实例转换为字典
@@ -475,7 +446,7 @@ class AstrBotExporter:
manifest = {
"version": BACKUP_MANIFEST_VERSION,
"astrbot_version": VERSION,
"exported_at": datetime.now(UTC).isoformat(),
"exported_at": datetime.now(timezone.utc).isoformat(),
"origin": "exported", # 标记备份来源:exported=本实例导出, uploaded=用户上传
"schema_version": {
"main_db": "v4",
+20 -22
View File
@@ -7,13 +7,12 @@
- 版本匹配时也需要用户确认
"""
import asyncio
import json
import os
import shutil
import zipfile
from dataclasses import dataclass, field
from datetime import UTC, datetime
from datetime import datetime, timezone
from pathlib import Path
from typing import TYPE_CHECKING, Any
@@ -365,7 +364,7 @@ class AstrBotImporter:
"""
result = ImportResult()
if not await asyncio.to_thread(os.path.exists, zip_path):
if not os.path.exists(zip_path):
result.add_error(f"备份文件不存在: {zip_path}")
return result
@@ -447,13 +446,12 @@ class AstrBotImporter:
try:
config_content = zf.read("config/cmd_config.json")
# 备份现有配置
if await asyncio.to_thread(os.path.exists, self.config_path):
if os.path.exists(self.config_path):
backup_path = f"{self.config_path}.bak"
shutil.copy2(self.config_path, backup_path)
await asyncio.to_thread(
Path(self.config_path).write_bytes, config_content
)
with open(self.config_path, "wb") as f:
f.write(config_content)
result.imported_files["config"] = 1
except Exception as e:
result.add_warning(f"导入配置文件失败: {e}")
@@ -677,9 +675,9 @@ class AstrBotImporter:
if isinstance(value, datetime):
dt = value
if dt.tzinfo is None:
dt = dt.replace(tzinfo=UTC)
dt = dt.replace(tzinfo=timezone.utc)
else:
dt = dt.astimezone(UTC)
dt = dt.astimezone(timezone.utc)
return dt.isoformat()
if isinstance(value, str):
timestamp = value.strip()
@@ -690,9 +688,9 @@ class AstrBotImporter:
try:
dt = datetime.fromisoformat(timestamp)
if dt.tzinfo is None:
dt = dt.replace(tzinfo=UTC)
dt = dt.replace(tzinfo=timezone.utc)
else:
dt = dt.astimezone(UTC)
dt = dt.astimezone(timezone.utc)
return dt.isoformat()
except ValueError:
return None
@@ -755,8 +753,8 @@ class AstrBotImporter:
if faiss_path in zf.namelist():
try:
target_path = kb_dir / "index.faiss"
with zf.open(faiss_path) as src:
await asyncio.to_thread(target_path.write_bytes, src.read())
with zf.open(faiss_path) as src, open(target_path, "wb") as dst:
dst.write(src.read())
except Exception as e:
result.add_warning(f"导入知识库 {kb_id} 的 FAISS 索引失败: {e}")
@@ -768,8 +766,8 @@ class AstrBotImporter:
rel_path = name[len(media_prefix) :]
target_path = kb_dir / rel_path
target_path.parent.mkdir(parents=True, exist_ok=True)
with zf.open(name) as src:
await asyncio.to_thread(target_path.write_bytes, src.read())
with zf.open(name) as src, open(target_path, "wb") as dst:
dst.write(src.read())
except Exception as e:
result.add_warning(f"导入媒体文件 {name} 失败: {e}")
@@ -830,8 +828,8 @@ class AstrBotImporter:
target_path = attachments_dir / os.path.basename(name)
target_path.parent.mkdir(parents=True, exist_ok=True)
with zf.open(name) as src:
await asyncio.to_thread(target_path.write_bytes, src.read())
with zf.open(name) as src, open(target_path, "wb") as dst:
dst.write(src.read())
count += 1
except Exception as e:
logger.warning(f"导入附件 {name} 失败: {e}")
@@ -887,15 +885,15 @@ class AstrBotImporter:
continue
# 备份现有目录(如果存在)
if await asyncio.to_thread(target_dir.exists):
if target_dir.exists():
backup_path = Path(f"{target_dir}.bak")
if await asyncio.to_thread(backup_path.exists):
if backup_path.exists():
shutil.rmtree(backup_path)
shutil.move(str(target_dir), str(backup_path))
logger.debug(f"已备份现有目录 {target_dir}{backup_path}")
# 创建目标目录
await asyncio.to_thread(target_dir.mkdir, parents=True, exist_ok=True)
target_dir.mkdir(parents=True, exist_ok=True)
# 解压文件
for name in dir_files:
@@ -908,8 +906,8 @@ class AstrBotImporter:
target_path = target_dir / rel_path
target_path.parent.mkdir(parents=True, exist_ok=True)
with zf.open(name) as src:
await asyncio.to_thread(target_path.write_bytes, src.read())
with zf.open(name) as src, open(target_path, "wb") as dst:
dst.write(src.read())
file_count += 1
except Exception as e:
result.add_warning(f"导入文件 {name} 失败: {e}")
+3 -3
View File
@@ -118,10 +118,10 @@ class BayContainerManager:
return f"http://127.0.0.1:{self._host_port}"
async def wait_healthy(self, timeout_seconds: int = HEALTH_TIMEOUT_S) -> None:
async def wait_healthy(self, timeout: int = HEALTH_TIMEOUT_S) -> None:
"""Block until Bay's ``/health`` endpoint returns 200."""
url = f"http://127.0.0.1:{self._host_port}/health"
deadline = asyncio.get_event_loop().time() + timeout_seconds
deadline = asyncio.get_event_loop().time() + timeout
last_error: str = ""
async with aiohttp.ClientSession() as session:
@@ -140,7 +140,7 @@ class BayContainerManager:
await asyncio.sleep(HEALTH_POLL_INTERVAL_S)
raise TimeoutError(
f"Bay did not become healthy within {timeout_seconds}s (last error: {last_error})"
f"Bay did not become healthy within {timeout}s (last error: {last_error})"
)
async def read_credentials(self) -> str:
+3 -3
View File
@@ -1,6 +1,5 @@
import asyncio
import random
from pathlib import Path
from typing import Any
import aiohttp
@@ -47,7 +46,8 @@ class MockShipyardSandboxClient:
try:
# Read file content
file_content = await asyncio.to_thread(Path(path).read_bytes)
with open(path, "rb") as f:
file_content = f.read()
# Create multipart form data
data = aiohttp.FormData()
@@ -88,7 +88,7 @@ class MockShipyardSandboxClient:
"error": f"Connection error: {str(e)}",
"message": "File upload failed",
}
except TimeoutError:
except asyncio.TimeoutError:
return {
"success": False,
"error": "File upload timeout",
+4 -4
View File
@@ -59,7 +59,7 @@ class LocalShellComponent(ShellComponent):
command: str,
cwd: str | None = None,
env: dict[str, str] | None = None,
timeout_seconds: int | None = 30,
timeout: int | None = 30,
shell: bool = True,
background: bool = False,
) -> dict[str, Any]:
@@ -87,7 +87,7 @@ class LocalShellComponent(ShellComponent):
shell=shell,
cwd=working_dir,
env=run_env,
timeout=timeout_seconds,
timeout=timeout,
capture_output=True,
text=True,
)
@@ -106,14 +106,14 @@ class LocalPythonComponent(PythonComponent):
self,
code: str,
kernel_id: str | None = None,
timeout_seconds: int = 30,
timeout: int = 30,
silent: bool = False,
) -> dict[str, Any]:
def _run() -> dict[str, Any]:
try:
result = subprocess.run(
[os.environ.get("PYTHON", sys.executable), "-c", code],
timeout=timeout_seconds,
timeout=timeout,
capture_output=True,
text=True,
)
+14 -14
View File
@@ -1,9 +1,7 @@
from __future__ import annotations
import asyncio
import os
import shlex
from pathlib import Path
from typing import Any, cast
from astrbot.api import logger
@@ -35,11 +33,11 @@ class NeoPythonComponent(PythonComponent):
self,
code: str,
kernel_id: str | None = None,
timeout_seconds: int = 30,
timeout: int = 30,
silent: bool = False,
) -> dict[str, Any]:
_ = kernel_id # Bay runtime does not expose kernel_id in current SDK.
result = await self._sandbox.python.exec(code, timeout=timeout_seconds)
result = await self._sandbox.python.exec(code, timeout=timeout)
payload = _maybe_model_dump(result)
output_text = payload.get("output", "") or ""
@@ -77,7 +75,7 @@ class NeoShellComponent(ShellComponent):
command: str,
cwd: str | None = None,
env: dict[str, str] | None = None,
timeout_seconds: int | None = 30,
timeout: int | None = 30,
shell: bool = True,
background: bool = False,
) -> dict[str, Any]:
@@ -101,7 +99,7 @@ class NeoShellComponent(ShellComponent):
result = await self._sandbox.shell.exec(
run_command,
timeout=timeout_seconds or 30,
timeout=timeout or 30,
cwd=cwd,
)
payload = _maybe_model_dump(result)
@@ -194,7 +192,7 @@ class NeoBrowserComponent(BrowserComponent):
async def exec(
self,
cmd: str,
timeout_seconds: int = 30,
timeout: int = 30,
description: str | None = None,
tags: str | None = None,
learn: bool = False,
@@ -202,7 +200,7 @@ class NeoBrowserComponent(BrowserComponent):
) -> dict[str, Any]:
result = await self._sandbox.browser.exec(
cmd,
timeout=timeout_seconds,
timeout=timeout,
description=description,
tags=tags,
learn=learn,
@@ -213,7 +211,7 @@ class NeoBrowserComponent(BrowserComponent):
async def exec_batch(
self,
commands: list[str],
timeout_seconds: int = 60,
timeout: int = 60,
stop_on_error: bool = True,
description: str | None = None,
tags: str | None = None,
@@ -222,7 +220,7 @@ class NeoBrowserComponent(BrowserComponent):
) -> dict[str, Any]:
result = await self._sandbox.browser.exec_batch(
commands,
timeout=timeout_seconds,
timeout=timeout,
stop_on_error=stop_on_error,
description=description,
tags=tags,
@@ -234,7 +232,7 @@ class NeoBrowserComponent(BrowserComponent):
async def run_skill(
self,
skill_key: str,
timeout_seconds: int = 60,
timeout: int = 60,
stop_on_error: bool = True,
include_trace: bool = False,
description: str | None = None,
@@ -242,7 +240,7 @@ class NeoBrowserComponent(BrowserComponent):
) -> dict[str, Any]:
result = await self._sandbox.browser.run_skill(
skill_key=skill_key,
timeout=timeout_seconds,
timeout=timeout,
stop_on_error=stop_on_error,
include_trace=include_trace,
description=description,
@@ -470,7 +468,8 @@ class ShipyardNeoBooter(ComputerBooter):
async def upload_file(self, path: str, file_name: str) -> dict:
if self._sandbox is None:
raise RuntimeError("ShipyardNeoBooter is not initialized.")
content = await asyncio.to_thread(Path(path).read_bytes)
with open(path, "rb") as f:
content = f.read()
remote_path = file_name.lstrip("/")
await self._sandbox.filesystem.upload(remote_path, content)
logger.info("[Computer] File uploaded to Neo sandbox: %s", remote_path)
@@ -487,7 +486,8 @@ class ShipyardNeoBooter(ComputerBooter):
local_dir = os.path.dirname(local_path)
if local_dir:
os.makedirs(local_dir, exist_ok=True)
await asyncio.to_thread(Path(local_path).write_bytes, cast(bytes, content))
with open(local_path, "wb") as f:
f.write(cast(bytes, content))
logger.info(
"[Computer] File downloaded from Neo sandbox: %s -> %s",
remote_path,
+2 -3
View File
@@ -1,4 +1,3 @@
import asyncio
import json
import os
import shutil
@@ -373,12 +372,12 @@ async def _sync_skills_to_sandbox(booter: ComputerBooter) -> None:
splitting into `apply` and `scan` phases.
"""
skills_root = Path(get_astrbot_skills_path())
if not await asyncio.to_thread(skills_root.is_dir):
if not skills_root.is_dir():
return
local_skill_dirs = _list_local_skill_dirs(skills_root)
temp_dir = Path(get_astrbot_temp_path())
await asyncio.to_thread(temp_dir.mkdir, parents=True, exist_ok=True)
temp_dir.mkdir(parents=True, exist_ok=True)
zip_base = temp_dir / "skills_bundle"
zip_path = zip_base.with_suffix(".zip")
+3 -3
View File
@@ -11,7 +11,7 @@ class BrowserComponent(Protocol):
async def exec(
self,
cmd: str,
timeout_seconds: int = 30,
timeout: int = 30,
description: str | None = None,
tags: str | None = None,
learn: bool = False,
@@ -23,7 +23,7 @@ class BrowserComponent(Protocol):
async def exec_batch(
self,
commands: list[str],
timeout_seconds: int = 60,
timeout: int = 60,
stop_on_error: bool = True,
description: str | None = None,
tags: str | None = None,
@@ -36,7 +36,7 @@ class BrowserComponent(Protocol):
async def run_skill(
self,
skill_key: str,
timeout_seconds: int = 60,
timeout: int = 60,
stop_on_error: bool = True,
include_trace: bool = False,
description: str | None = None,
+1 -1
View File
@@ -12,7 +12,7 @@ class PythonComponent(Protocol):
self,
code: str,
kernel_id: str | None = None,
timeout_seconds: int = 30,
timeout: int = 30,
silent: bool = False,
) -> dict[str, Any]:
"""Execute Python code"""
+1 -1
View File
@@ -13,7 +13,7 @@ class ShellComponent(Protocol):
command: str,
cwd: str | None = None,
env: dict[str, str] | None = None,
timeout_seconds: int | None = 30,
timeout: int | None = 30,
shell: bool = True,
background: bool = False,
) -> dict[str, Any]:
+6 -18
View File
@@ -71,23 +71,19 @@ class BrowserExecTool(FunctionTool):
self,
context: ContextWrapper[AstrAgentContext],
cmd: str,
timeout_seconds: int = 30,
timeout: int = 30,
description: str | None = None,
tags: str | None = None,
learn: bool = False,
include_trace: bool = False,
**kwargs: Any,
) -> ToolExecResult:
legacy_timeout = kwargs.pop("timeout", None)
if legacy_timeout is not None:
timeout_seconds = int(legacy_timeout)
if err := _ensure_admin(context):
return err
try:
browser = await _get_browser_component(context)
result = await browser.exec(
cmd=cmd,
timeout_seconds=timeout_seconds,
timeout=timeout,
description=description,
tags=tags,
learn=learn,
@@ -137,24 +133,20 @@ class BrowserBatchExecTool(FunctionTool):
self,
context: ContextWrapper[AstrAgentContext],
commands: list[str],
timeout_seconds: int = 60,
timeout: int = 60,
stop_on_error: bool = True,
description: str | None = None,
tags: str | None = None,
learn: bool = False,
include_trace: bool = False,
**kwargs: Any,
) -> ToolExecResult:
legacy_timeout = kwargs.pop("timeout", None)
if legacy_timeout is not None:
timeout_seconds = int(legacy_timeout)
if err := _ensure_admin(context):
return err
try:
browser = await _get_browser_component(context)
result = await browser.exec_batch(
commands=commands,
timeout_seconds=timeout_seconds,
timeout=timeout,
stop_on_error=stop_on_error,
description=description,
tags=tags,
@@ -189,23 +181,19 @@ class RunBrowserSkillTool(FunctionTool):
self,
context: ContextWrapper[AstrAgentContext],
skill_key: str,
timeout_seconds: int = 60,
timeout: int = 60,
stop_on_error: bool = True,
include_trace: bool = False,
description: str | None = None,
tags: str | None = None,
**kwargs: Any,
) -> ToolExecResult:
legacy_timeout = kwargs.pop("timeout", None)
if legacy_timeout is not None:
timeout_seconds = int(legacy_timeout)
if err := _ensure_admin(context):
return err
try:
browser = await _get_browser_component(context)
result = await browser.run_skill(
skill_key=skill_key,
timeout_seconds=timeout_seconds,
timeout=timeout,
stop_on_error=stop_on_error,
include_trace=include_trace,
description=description,
+2 -3
View File
@@ -1,4 +1,3 @@
import asyncio
import os
import uuid
from dataclasses import dataclass, field
@@ -112,10 +111,10 @@ class FileUploadTool(FunctionTool):
)
try:
# Check if file exists
if not await asyncio.to_thread(os.path.exists, local_path):
if not os.path.exists(local_path):
return f"Error: File does not exist: {local_path}"
if not await asyncio.to_thread(os.path.isfile, local_path):
if not os.path.isfile(local_path):
return f"Error: Path is not a file: {local_path}"
# Use basename if sandbox_filename is not provided
+2 -2
View File
@@ -1,7 +1,7 @@
import asyncio
import json
from collections.abc import Awaitable, Callable
from datetime import UTC, datetime
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any
from zoneinfo import ZoneInfo
@@ -192,7 +192,7 @@ class CronJobManager:
job = await self.db.get_cron_job(job_id)
if not job or not job.enabled:
return
start_time = datetime.now(UTC)
start_time = datetime.now(timezone.utc)
await self.db.update_cron_job(
job_id, status="running", last_run_at=start_time, last_error=None
)
+1 -2
View File
@@ -1,4 +1,3 @@
import asyncio
import os
from astrbot.api import logger, sp
@@ -23,7 +22,7 @@ async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool:
data_dir = get_astrbot_data_path()
data_v3_db = os.path.join(data_dir, "data_v3.db")
if not await asyncio.to_thread(os.path.exists, data_v3_db):
if not os.path.exists(data_v3_db):
return False
migration_done = await db_helper.get_preference(
"global",
+3 -3
View File
@@ -106,8 +106,8 @@ async def migration_platform_table(
db_path=DB_PATH.replace("data_v4.db", "data_v3.db"),
)
secs_from_2023_4_10_to_now = (
datetime.datetime.now(datetime.UTC)
- datetime.datetime(2023, 4, 10, tzinfo=datetime.UTC)
datetime.datetime.now(datetime.timezone.utc)
- datetime.datetime(2023, 4, 10, tzinfo=datetime.timezone.utc)
).total_seconds()
offset_sec = int(secs_from_2023_4_10_to_now)
logger.info(f"迁移旧平台数据,offset_sec: {offset_sec} 秒。")
@@ -162,7 +162,7 @@ async def migration_platform_table(
{
"timestamp": datetime.datetime.fromtimestamp(
bucket_end,
tz=datetime.UTC,
tz=datetime.timezone.utc,
),
"platform_id": platform_id,
"platform_type": platform_type,
+4 -4
View File
@@ -1,16 +1,16 @@
import uuid
from dataclasses import dataclass, field
from datetime import UTC, datetime
from datetime import datetime, timezone
from typing import TypedDict
from sqlmodel import JSON, Field, SQLModel, Text, UniqueConstraint
class TimestampMixin(SQLModel):
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(UTC),
sa_column_kwargs={"onupdate": lambda: datetime.now(UTC)},
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": lambda: datetime.now(timezone.utc)},
)
+6 -6
View File
@@ -2,7 +2,7 @@ import asyncio
import threading
import typing as T
from collections.abc import Awaitable, Callable
from datetime import UTC, datetime, timedelta
from datetime import datetime, timedelta, timezone
from sqlalchemy import CursorResult, Row
from sqlalchemy.ext.asyncio import AsyncSession
@@ -633,7 +633,7 @@ class SQLiteDatabase(BaseDatabase):
"""Get an active API key by hash (not revoked, not expired)."""
async with self.get_db() as session:
session: AsyncSession
now = datetime.now(UTC)
now = datetime.now(timezone.utc)
query = select(ApiKey).where(
ApiKey.key_hash == key_hash,
col(ApiKey.revoked_at).is_(None),
@@ -650,7 +650,7 @@ class SQLiteDatabase(BaseDatabase):
await session.execute(
update(ApiKey)
.where(col(ApiKey.key_id) == key_id)
.values(last_used_at=datetime.now(UTC)),
.values(last_used_at=datetime.now(timezone.utc)),
)
async def revoke_api_key(self, key_id: str) -> bool:
@@ -661,7 +661,7 @@ class SQLiteDatabase(BaseDatabase):
query = (
update(ApiKey)
.where(col(ApiKey.key_id) == key_id)
.values(revoked_at=datetime.now(UTC))
.values(revoked_at=datetime.now(timezone.utc))
)
result = T.cast(CursorResult, await session.execute(query))
return result.rowcount > 0
@@ -1534,7 +1534,7 @@ class SQLiteDatabase(BaseDatabase):
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
values: dict[str, T.Any] = {"updated_at": datetime.now(UTC)}
values: dict[str, T.Any] = {"updated_at": datetime.now(timezone.utc)}
if display_name is not None:
values["display_name"] = display_name
@@ -1622,7 +1622,7 @@ class SQLiteDatabase(BaseDatabase):
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
values: dict[str, T.Any] = {"updated_at": datetime.now(UTC)}
values: dict[str, T.Any] = {"updated_at": datetime.now(timezone.utc)}
if title is not None:
values["title"] = title
if emoji is not None:
+5 -13
View File
@@ -28,17 +28,12 @@ class FileTokenService:
await self._cleanup_expired_tokens()
return file_token not in self.staged_files
async def register_file(
self,
file_path: str,
timeout_seconds: float | None = None,
**kwargs,
) -> str:
async def register_file(self, file_path: str, timeout: float | None = None) -> str:
"""向令牌服务注册一个文件。
Args:
file_path(str): 文件路径
timeout_seconds(float): 超时时间单位秒可选
timeout(float): 超时时间单位秒可选
Returns:
str: 一个单次令牌
@@ -63,18 +58,15 @@ class FileTokenService:
async with self.lock:
await self._cleanup_expired_tokens()
legacy_timeout = kwargs.pop("timeout", None)
if legacy_timeout is not None:
timeout_seconds = float(legacy_timeout)
if not await asyncio.to_thread(os.path.exists, local_path):
if not os.path.exists(local_path):
raise FileNotFoundError(
f"文件不存在: {local_path} (原始输入: {file_path})",
)
file_token = str(uuid.uuid4())
expire_time = time.time() + (
timeout_seconds if timeout_seconds is not None else self.default_timeout
timeout if timeout is not None else self.default_timeout
)
# 存储转换后的真实路径
self.staged_files[file_token] = (local_path, expire_time)
@@ -101,6 +93,6 @@ class FileTokenService:
raise KeyError(f"无效或过期的文件 token: {file_token}")
file_path, _ = self.staged_files.pop(file_token)
if not await asyncio.to_thread(os.path.exists, file_path):
if not os.path.exists(file_path):
raise FileNotFoundError(f"文件不存在: {file_path}")
return file_path
+8 -8
View File
@@ -1,5 +1,5 @@
import uuid
from datetime import UTC, datetime
from datetime import datetime, timezone
from sqlmodel import Field, MetaData, SQLModel, Text, UniqueConstraint
@@ -40,10 +40,10 @@ class KnowledgeBase(BaseKBModel, table=True):
top_k_dense: int | None = Field(default=50, nullable=True)
top_k_sparse: int | None = Field(default=50, nullable=True)
top_m_final: int | None = Field(default=5, nullable=True)
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(UTC),
sa_column_kwargs={"onupdate": datetime.now(UTC)},
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
doc_count: int = Field(default=0, nullable=False)
chunk_count: int = Field(default=0, nullable=False)
@@ -83,10 +83,10 @@ class KBDocument(BaseKBModel, table=True):
file_path: str = Field(max_length=512, nullable=False)
chunk_count: int = Field(default=0, nullable=False)
media_count: int = Field(default=0, nullable=False)
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(UTC),
sa_column_kwargs={"onupdate": datetime.now(UTC)},
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
@@ -117,4 +117,4 @@ class KBMedia(BaseKBModel, table=True):
file_path: str = Field(max_length=512, nullable=False)
file_size: int = Field(nullable=False)
mime_type: str = Field(max_length=100, nullable=False)
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
+32 -50
View File
@@ -27,8 +27,7 @@ import json
import os
import sys
import uuid
from enum import StrEnum
from pathlib import Path
from enum import Enum
if sys.version_info >= (3, 14):
from pydantic import BaseModel
@@ -40,17 +39,7 @@ from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64
def _absolute_path(path: str) -> str:
return os.path.abspath(path)
def _absolute_path_if_exists(path: str | None) -> str | None:
if not path or not os.path.exists(path):
return None
return os.path.abspath(path)
class ComponentType(StrEnum):
class ComponentType(str, Enum):
# Basic Segment Types
Plain = "Plain" # plain text message
Image = "Image" # image
@@ -169,18 +158,18 @@ class Record(BaseMessageComponent):
return self.file[8:]
if self.file.startswith("http"):
file_path = await download_image_by_url(self.file)
return await asyncio.to_thread(_absolute_path, file_path)
return os.path.abspath(file_path)
if self.file.startswith("base64://"):
bs64_data = self.file.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data)
file_path = os.path.join(
get_astrbot_temp_path(), f"recordseg_{uuid.uuid4()}.jpg"
)
await asyncio.to_thread(Path(file_path).write_bytes, image_bytes)
return await asyncio.to_thread(_absolute_path, file_path)
local_path = await asyncio.to_thread(_absolute_path_if_exists, self.file)
if local_path:
return local_path
with open(file_path, "wb") as f:
f.write(image_bytes)
return os.path.abspath(file_path)
if os.path.exists(self.file):
return os.path.abspath(self.file)
raise Exception(f"not a valid file: {self.file}")
async def convert_to_base64(self) -> str:
@@ -194,17 +183,16 @@ class Record(BaseMessageComponent):
if not self.file:
raise Exception(f"not a valid file: {self.file}")
if self.file.startswith("file:///"):
bs64_data = await file_to_base64(self.file[8:])
bs64_data = file_to_base64(self.file[8:])
elif self.file.startswith("http"):
file_path = await download_image_by_url(self.file)
bs64_data = await file_to_base64(file_path)
bs64_data = file_to_base64(file_path)
elif self.file.startswith("base64://"):
bs64_data = self.file
elif os.path.exists(self.file):
bs64_data = file_to_base64(self.file)
else:
try:
bs64_data = await file_to_base64(self.file)
except OSError as exc:
raise Exception(f"not a valid file: {self.file}") from exc
raise Exception(f"not a valid file: {self.file}")
bs64_data = bs64_data.removeprefix("base64://")
return bs64_data
@@ -268,15 +256,11 @@ class Video(BaseMessageComponent):
get_astrbot_temp_path(), f"videoseg_{uuid.uuid4().hex}"
)
await download_file(url, video_file_path)
local_path = await asyncio.to_thread(
_absolute_path_if_exists, video_file_path
)
if local_path:
return local_path
if os.path.exists(video_file_path):
return os.path.abspath(video_file_path)
raise Exception(f"download failed: {url}")
local_path = await asyncio.to_thread(_absolute_path_if_exists, url)
if local_path:
return local_path
if os.path.exists(url):
return os.path.abspath(url)
raise Exception(f"not a valid file: {url}")
async def register_to_file_service(self) -> str:
@@ -465,18 +449,18 @@ class Image(BaseMessageComponent):
return url[8:]
if url.startswith("http"):
image_file_path = await download_image_by_url(url)
return await asyncio.to_thread(_absolute_path, image_file_path)
return os.path.abspath(image_file_path)
if url.startswith("base64://"):
bs64_data = url.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data)
image_file_path = os.path.join(
get_astrbot_temp_path(), f"imgseg_{uuid.uuid4()}.jpg"
)
await asyncio.to_thread(Path(image_file_path).write_bytes, image_bytes)
return await asyncio.to_thread(_absolute_path, image_file_path)
local_path = await asyncio.to_thread(_absolute_path_if_exists, url)
if local_path:
return local_path
with open(image_file_path, "wb") as f:
f.write(image_bytes)
return os.path.abspath(image_file_path)
if os.path.exists(url):
return os.path.abspath(url)
raise Exception(f"not a valid file: {url}")
async def convert_to_base64(self) -> str:
@@ -491,17 +475,16 @@ class Image(BaseMessageComponent):
if not url:
raise ValueError("No valid file or URL provided")
if url.startswith("file:///"):
bs64_data = await file_to_base64(url[8:])
bs64_data = file_to_base64(url[8:])
elif url.startswith("http"):
image_file_path = await download_image_by_url(url)
bs64_data = await file_to_base64(image_file_path)
bs64_data = file_to_base64(image_file_path)
elif url.startswith("base64://"):
bs64_data = url
elif os.path.exists(url):
bs64_data = file_to_base64(url)
else:
try:
bs64_data = await file_to_base64(url)
except OSError as exc:
raise Exception(f"not a valid file: {url}") from exc
raise Exception(f"not a valid file: {url}")
bs64_data = bs64_data.removeprefix("base64://")
return bs64_data
@@ -752,9 +735,8 @@ class File(BaseMessageComponent):
):
path = path[1:]
local_path = await asyncio.to_thread(_absolute_path_if_exists, path)
if local_path:
return local_path
if os.path.exists(path):
return os.path.abspath(path)
if self.url:
await self._download_file()
@@ -769,7 +751,7 @@ class File(BaseMessageComponent):
and path[2] == ":"
):
path = path[1:]
return await asyncio.to_thread(_absolute_path, path)
return os.path.abspath(path)
return ""
@@ -785,7 +767,7 @@ class File(BaseMessageComponent):
filename = f"fileseg_{uuid.uuid4().hex}"
file_path = os.path.join(download_dir, filename)
await download_file(self.url, file_path)
self.file_ = await asyncio.to_thread(_absolute_path, file_path)
self.file_ = os.path.abspath(file_path)
async def register_to_file_service(self) -> str:
"""将文件注册到文件服务。
@@ -254,7 +254,7 @@ class DingtalkPlatformAdapter(Platform):
"robotCode": robot_code,
}
temp_dir = Path(get_astrbot_temp_path())
await asyncio.to_thread(temp_dir.mkdir, parents=True, exist_ok=True)
temp_dir.mkdir(parents=True, exist_ok=True)
f_path = temp_dir / f"dingtalk_{uuid.uuid4()}.{ext}"
async with (
aiohttp.ClientSession() as session,
@@ -412,7 +412,7 @@ class DingtalkPlatformAdapter(Platform):
form = aiohttp.FormData()
form.add_field(
"media",
await asyncio.to_thread(media_file_path.read_bytes),
media_file_path.read_bytes(),
filename=media_file_path.name,
content_type="application/octet-stream",
)
@@ -1,10 +1,15 @@
import sys
from collections.abc import Awaitable, Callable
from typing import override
import discord
from astrbot import logger
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
# Discord Bot客户端
class DiscordBotClient(discord.Bot):
@@ -1,6 +1,7 @@
import asyncio
import re
from typing import Any, cast, override
import sys
from typing import Any, cast
import discord
from discord.abc import GuildChannel, Messageable, PrivateChannel
@@ -26,6 +27,11 @@ from astrbot.core.star.star_handler import StarHandlerMetadata, star_handlers_re
from .client import DiscordBotClient
from .discord_platform_event import DiscordPlatformEvent
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
# 注册平台适配器
@register_platform_adapter(
@@ -130,7 +130,7 @@ class KookPlatformAdapter(Platform):
await asyncio.wait_for(
self.client.wait_until_closed(), timeout=1.0
)
except TimeoutError:
except asyncio.TimeoutError:
# 正常超时,继续下一轮 while 检查
continue
@@ -171,7 +171,7 @@ class KookClient:
# 处理不同类型的信令
await self._handle_signal(data)
except TimeoutError:
except asyncio.TimeoutError:
# 超时检查,继续循环
continue
except websockets.exceptions.ConnectionClosed:
@@ -362,14 +362,12 @@ class KookClient:
b64_str = file_url.removeprefix("base64://")
bytes_data = base64.b64decode(b64_str)
elif file_url.startswith("file://") or await asyncio.to_thread(
os.path.exists, file_url
):
elif file_url.startswith("file://") or os.path.exists(file_url):
file_url = file_url.removeprefix("file:///")
file_url = file_url.removeprefix("file://")
try:
target_path = await asyncio.to_thread(Path(file_url).resolve)
target_path = Path(file_url).resolve()
except Exception as exp:
logger.error(f'[KOOK] 获取文件 "{file_url}" 绝对路径失败: "{exp}"')
raise FileNotFoundError(
@@ -429,7 +429,7 @@ class LarkPlatformAdapter(Platform):
suffix = Path(file_name).suffix if file_name else default_suffix
temp_dir = Path(get_astrbot_temp_path())
await asyncio.to_thread(temp_dir.mkdir, parents=True, exist_ok=True)
temp_dir.mkdir(parents=True, exist_ok=True)
temp_path = (
temp_dir / f"lark_{message_type}_{file_name}_{uuid4().hex[:4]}{suffix}"
)
@@ -1,10 +1,8 @@
import asyncio
import base64
import json
import os
import uuid
from io import BytesIO
from pathlib import Path
import lark_oapi as lark
from lark_oapi.api.im.v1 import (
@@ -138,7 +136,7 @@ class LarkMessageEvent(AstrMessageEvent):
Returns:
成功返回file_key失败返回None
"""
if not path or not await asyncio.to_thread(os.path.exists, path):
if not path or not os.path.exists(path):
logger.error(f"[Lark] 文件不存在: {path}")
return None
@@ -147,32 +145,36 @@ class LarkMessageEvent(AstrMessageEvent):
return None
try:
file_obj = BytesIO(await asyncio.to_thread(Path(path).read_bytes))
body_builder = (
CreateFileRequestBody.builder()
.file_type(file_type)
.file_name(os.path.basename(path))
.file(file_obj)
)
if duration is not None:
body_builder.duration(duration)
with open(path, "rb") as file_obj:
body_builder = (
CreateFileRequestBody.builder()
.file_type(file_type)
.file_name(os.path.basename(path))
.file(file_obj)
)
if duration is not None:
body_builder.duration(duration)
request = (
CreateFileRequest.builder().request_body(body_builder.build()).build()
)
response = await lark_client.im.v1.file.acreate(request)
request = (
CreateFileRequest.builder()
.request_body(body_builder.build())
.build()
)
response = await lark_client.im.v1.file.acreate(request)
if not response.success():
logger.error(f"[Lark] 无法上传文件({response.code}): {response.msg}")
return None
if not response.success():
logger.error(
f"[Lark] 无法上传文件({response.code}): {response.msg}"
)
return None
if response.data is None:
logger.error("[Lark] 上传文件成功但未返回数据(data is None)")
return None
if response.data is None:
logger.error("[Lark] 上传文件成功但未返回数据(data is None)")
return None
file_key = response.data.file_key
logger.debug(f"[Lark] 文件上传成功: {file_key}")
return file_key
file_key = response.data.file_key
logger.debug(f"[Lark] 文件上传成功: {file_key}")
return file_key
except Exception as e:
logger.error(f"[Lark] 无法打开或上传文件: {e}")
@@ -205,9 +207,8 @@ class LarkMessageEvent(AstrMessageEvent):
temp_dir,
f"lark_image_{uuid.uuid4().hex[:8]}.jpg",
)
await asyncio.to_thread(
Path(file_path).write_bytes, BytesIO(image_data).getvalue()
)
with open(file_path, "wb") as f:
f.write(BytesIO(image_data).getvalue())
else:
file_path = comp.file if comp.file else ""
@@ -216,9 +217,7 @@ class LarkMessageEvent(AstrMessageEvent):
logger.error("[Lark] 图片路径为空,无法上传")
continue
try:
image_file = BytesIO(
await asyncio.to_thread(Path(file_path).read_bytes)
)
image_file = open(file_path, "rb")
except Exception as e:
logger.error(f"[Lark] 无法打开图片文件: {e}")
continue
@@ -413,9 +412,7 @@ class LarkMessageEvent(AstrMessageEvent):
logger.error(f"[Lark] 无法获取音频文件路径: {e}")
return
if not original_audio_path or not await asyncio.to_thread(
os.path.exists, original_audio_path
):
if not original_audio_path or not os.path.exists(original_audio_path):
logger.error(f"[Lark] 音频文件不存在: {original_audio_path}")
return
@@ -445,9 +442,7 @@ class LarkMessageEvent(AstrMessageEvent):
)
# 清理转换后的临时音频文件
if converted_audio_path and await asyncio.to_thread(
os.path.exists, converted_audio_path
):
if converted_audio_path and os.path.exists(converted_audio_path):
try:
os.remove(converted_audio_path)
logger.debug(f"[Lark] 已删除转换后的音频文件: {converted_audio_path}")
@@ -490,9 +485,7 @@ class LarkMessageEvent(AstrMessageEvent):
logger.error(f"[Lark] 无法获取视频文件路径: {e}")
return
if not original_video_path or not await asyncio.to_thread(
os.path.exists, original_video_path
):
if not original_video_path or not os.path.exists(original_video_path):
logger.error(f"[Lark] 视频文件不存在: {original_video_path}")
return
@@ -522,9 +515,7 @@ class LarkMessageEvent(AstrMessageEvent):
)
# 清理转换后的临时视频文件
if converted_video_path and await asyncio.to_thread(
os.path.exists, converted_video_path
):
if converted_video_path and os.path.exists(converted_video_path):
try:
os.remove(converted_video_path)
logger.debug(f"[Lark] 已删除转换后的视频文件: {converted_video_path}")
@@ -161,7 +161,7 @@ class LineMessageEvent(AstrMessageEvent):
try:
video_path = await segment.convert_to_file_path()
temp_dir = Path(get_astrbot_temp_path())
await asyncio.to_thread(temp_dir.mkdir, parents=True, exist_ok=True)
temp_dir.mkdir(parents=True, exist_ok=True)
thumb_path = temp_dir / f"line_video_preview_{uuid.uuid4().hex}.jpg"
process = await asyncio.create_subprocess_exec(
@@ -201,8 +201,8 @@ class LineMessageEvent(AstrMessageEvent):
async def _resolve_file_size(segment: File) -> int:
try:
file_path = await segment.get_file(allow_return_url=False)
if file_path and await asyncio.to_thread(os.path.exists, file_path):
return int(await asyncio.to_thread(os.path.getsize, file_path))
if file_path and os.path.exists(file_path):
return int(os.path.getsize(file_path))
except Exception as e:
logger.debug("[LINE] resolve file size failed: %s", e)
return 0
@@ -499,8 +499,7 @@ class MisskeyPlatformAdapter(Platform):
# 清理临时文件
if local_path and isinstance(local_path, str):
data_temp = get_astrbot_temp_path()
if local_path.startswith(data_temp) and await asyncio.to_thread(
os.path.exists,
if local_path.startswith(data_temp) and os.path.exists(
local_path,
):
try:
@@ -3,7 +3,6 @@ import json
import random
import uuid
from collections.abc import Awaitable, Callable
from pathlib import Path
from typing import Any, NoReturn
try:
@@ -556,19 +555,22 @@ class MisskeyAPI:
form.add_field("folderId", str(folder_id))
try:
file_bytes = await asyncio.to_thread(Path(file_path).read_bytes)
f = open(file_path, "rb")
except FileNotFoundError as e:
logger.error(f"[Misskey API] 本地文件不存在: {file_path}")
raise APIError(f"File not found: {file_path}") from e
form.add_field("file", file_bytes, filename=filename)
async with self.session.post(url, data=form) as resp:
result = await self._process_response(resp, "drive/files/create")
file_id = FileIDExtractor.extract_file_id(result)
logger.debug(
f"[Misskey API] 本地文件上传成功: {filename} -> {file_id}",
)
return {"id": file_id, "raw": result}
try:
form.add_field("file", f, filename=filename)
async with self.session.post(url, data=form) as resp:
result = await self._process_response(resp, "drive/files/create")
file_id = FileIDExtractor.extract_file_id(result)
logger.debug(
f"[Misskey API] 本地文件上传成功: {filename} -> {file_id}",
)
return {"id": file_id, "raw": result}
finally:
f.close()
except aiohttp.ClientError as e:
logger.error(f"[Misskey API] 文件上传网络错误: {e}")
raise APIConnectionError(f"Upload failed: {e}") from e
@@ -339,7 +339,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
payload = {"file_type": file_type, "srv_send_msg": srv_send_msg}
# 处理文件数据
if await asyncio.to_thread(os.path.exists, file_source):
if os.path.exists(file_source):
# 读取本地文件
async with aiofiles.open(file_source, "rb") as f:
file_content = await f.read()
@@ -421,15 +421,15 @@ class QQOfficialMessageEvent(AstrMessageEvent):
plain_text += i.text
elif isinstance(i, Image) and not image_base64:
if i.file and i.file.startswith("file:///"):
image_base64 = await file_to_base64(i.file[8:])
image_base64 = file_to_base64(i.file[8:])
image_file_path = i.file[8:]
elif i.file and i.file.startswith("http"):
image_file_path = await download_image_by_url(i.file)
image_base64 = await file_to_base64(image_file_path)
image_base64 = file_to_base64(image_file_path)
elif i.file and i.file.startswith("base64://"):
image_base64 = i.file
elif i.file:
image_base64 = await file_to_base64(i.file)
image_base64 = file_to_base64(i.file)
else:
raise ValueError("Unsupported image file format")
image_base64 = image_base64.removeprefix("base64://")
@@ -1,8 +1,9 @@
import asyncio
import os
import re
import sys
import uuid
from typing import cast, override
from typing import cast
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from telegram import BotCommand, Update
@@ -32,6 +33,11 @@ from astrbot.core.utils.media_utils import convert_audio_to_wav
from .tg_event import TelegramPlatformEvent
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
@register_platform_adapter("telegram", "telegram 适配器")
class TelegramPlatformAdapter(Platform):
@@ -1,4 +1,3 @@
import asyncio
import json
import mimetypes
import shutil
@@ -140,15 +139,13 @@ async def parse_webchat_message_parts(
continue
file_path = Path(str(path))
if verify_media_path_exists and not await asyncio.to_thread(file_path.exists):
if verify_media_path_exists and not file_path.exists():
if strict:
raise ValueError(f"file not found: {file_path!s}")
continue
file_path_str = (
str(await asyncio.to_thread(file_path.resolve))
if verify_media_path_exists
else str(file_path)
str(file_path.resolve()) if verify_media_path_exists else str(file_path)
)
has_content = True
if part_type == "image":
@@ -369,7 +366,7 @@ async def message_chain_to_storage_message_parts(
attachments_dir: str | Path,
) -> list[dict]:
target_dir = Path(attachments_dir)
await asyncio.to_thread(target_dir.mkdir, parents=True, exist_ok=True)
target_dir.mkdir(parents=True, exist_ok=True)
parts: list[dict] = []
for comp in message_chain.chain:
@@ -445,9 +442,7 @@ async def _copy_file_to_attachment_part(
display_name: str | None = None,
) -> dict | None:
src_path = Path(file_path)
if not await asyncio.to_thread(src_path.exists) or not await asyncio.to_thread(
src_path.is_file
):
if not src_path.exists() or not src_path.is_file():
return None
suffix = src_path.suffix
@@ -1,10 +1,8 @@
import asyncio
import base64
import json
import os
import shutil
import uuid
from pathlib import Path
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
@@ -82,9 +80,8 @@ class WebChatMessageEvent(AstrMessageEvent):
filename = f"{str(uuid.uuid4())}.jpg"
path = os.path.join(attachments_dir, filename)
image_base64 = await comp.convert_to_base64()
await asyncio.to_thread(
Path(path).write_bytes, base64.b64decode(image_base64)
)
with open(path, "wb") as f:
f.write(base64.b64decode(image_base64))
data = f"[IMAGE]{filename}"
await web_chat_back_queue.put(
{
@@ -99,9 +96,8 @@ class WebChatMessageEvent(AstrMessageEvent):
filename = f"{str(uuid.uuid4())}.wav"
path = os.path.join(attachments_dir, filename)
record_base64 = await comp.convert_to_base64()
await asyncio.to_thread(
Path(path).write_bytes, base64.b64decode(record_base64)
)
with open(path, "wb") as f:
f.write(base64.b64decode(record_base64))
data = f"[RECORD]{filename}"
await web_chat_back_queue.put(
{
@@ -1,9 +1,9 @@
import asyncio
import os
import sys
import uuid
from collections.abc import Awaitable, Callable
from pathlib import Path
from typing import Any, cast, override
from typing import Any, cast
import quart
from requests import Response
@@ -33,6 +33,11 @@ from .wecom_event import WecomPlatformEvent
from .wecom_kf import WeChatKF
from .wecom_kf_message import WeChatKFMessage
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
class WecomServer:
def __init__(self, event_queue: asyncio.Queue, config: dict) -> None:
@@ -341,7 +346,8 @@ class WecomPlatformAdapter(Platform):
)
temp_dir = get_astrbot_temp_path()
path = os.path.join(temp_dir, f"wecom_{msg.media_id}.amr")
await asyncio.to_thread(Path(path).write_bytes, resp.content)
with open(path, "wb") as f:
f.write(resp.content)
try:
path_wav = os.path.join(temp_dir, f"wecom_{msg.media_id}.wav")
@@ -396,7 +402,8 @@ class WecomPlatformAdapter(Platform):
)
temp_dir = get_astrbot_temp_path()
path = os.path.join(temp_dir, f"weixinkefu_{media_id}.jpg")
await asyncio.to_thread(Path(path).write_bytes, resp.content)
with open(path, "wb") as f:
f.write(resp.content)
abm.message = [Image(file=path, url=path)]
elif msgtype == "voice":
media_id = msg.get("voice", {}).get("media_id", "")
@@ -408,7 +415,8 @@ class WecomPlatformAdapter(Platform):
temp_dir = get_astrbot_temp_path()
path = os.path.join(temp_dir, f"weixinkefu_{media_id}.amr")
await asyncio.to_thread(Path(path).write_bytes, resp.content)
with open(path, "wb") as f:
f.write(resp.content)
try:
path_wav = os.path.join(temp_dir, f"weixinkefu_{media_id}.wav")
+118 -155
View File
@@ -12,13 +12,6 @@ from astrbot.core.utils.media_utils import convert_audio_to_amr
from .wecom_kf_message import WeChatKFMessage
def _upload_media_from_path(
client: WeChatClient, media_type: str, file_path: str
) -> dict:
with open(file_path, "rb") as f:
return client.media.upload(media_type, f)
class WecomPlatformEvent(AstrMessageEvent):
def __init__(
self,
@@ -107,52 +100,45 @@ class WecomPlatformEvent(AstrMessageEvent):
elif isinstance(comp, Image):
img_path = await comp.convert_to_file_path()
try:
response = await asyncio.to_thread(
_upload_media_from_path,
self.client,
"image",
img_path,
with open(img_path, "rb") as f:
try:
response = self.client.media.upload("image", f)
except Exception as e:
logger.error(f"微信客服上传图片失败: {e}")
await self.send(
MessageChain().message(f"微信客服上传图片失败: {e}"),
)
return
logger.debug(f"微信客服上传图片返回: {response}")
kf_message_api.send_image(
user_id,
self.get_self_id(),
response["media_id"],
)
except Exception as e:
logger.error(f"微信客服上传图片失败: {e}")
await self.send(
MessageChain().message(f"微信客服上传图片失败: {e}"),
)
return
logger.debug(f"微信客服上传图片返回: {response}")
kf_message_api.send_image(
user_id,
self.get_self_id(),
response["media_id"],
)
elif isinstance(comp, Record):
record_path = await comp.convert_to_file_path()
record_path_amr = await convert_audio_to_amr(record_path)
try:
try:
response = await asyncio.to_thread(
_upload_media_from_path,
self.client,
"voice",
record_path_amr,
with open(record_path_amr, "rb") as f:
try:
response = self.client.media.upload("voice", f)
except Exception as e:
logger.error(f"微信客服上传语音失败: {e}")
await self.send(
MessageChain().message(
f"微信客服上传语音失败: {e}"
),
)
return
logger.info(f"微信客服上传语音返回: {response}")
kf_message_api.send_voice(
user_id,
self.get_self_id(),
response["media_id"],
)
except Exception as e:
logger.error(f"微信客服上传语音失败: {e}")
await self.send(
MessageChain().message(f"微信客服上传语音失败: {e}"),
)
return
logger.info(f"微信客服上传语音返回: {response}")
kf_message_api.send_voice(
user_id,
self.get_self_id(),
response["media_id"],
)
finally:
if record_path_amr != record_path and await asyncio.to_thread(
os.path.exists,
if record_path_amr != record_path and os.path.exists(
record_path_amr,
):
try:
@@ -162,47 +148,39 @@ class WecomPlatformEvent(AstrMessageEvent):
elif isinstance(comp, File):
file_path = await comp.get_file()
try:
response = await asyncio.to_thread(
_upload_media_from_path,
self.client,
"file",
file_path,
with open(file_path, "rb") as f:
try:
response = self.client.media.upload("file", f)
except Exception as e:
logger.error(f"微信客服上传文件失败: {e}")
await self.send(
MessageChain().message(f"微信客服上传文件失败: {e}"),
)
return
logger.debug(f"微信客服上传文件返回: {response}")
kf_message_api.send_file(
user_id,
self.get_self_id(),
response["media_id"],
)
except Exception as e:
logger.error(f"微信客服上传文件失败: {e}")
await self.send(
MessageChain().message(f"微信客服上传文件失败: {e}"),
)
return
logger.debug(f"微信客服上传文件返回: {response}")
kf_message_api.send_file(
user_id,
self.get_self_id(),
response["media_id"],
)
elif isinstance(comp, Video):
video_path = await comp.convert_to_file_path()
try:
response = await asyncio.to_thread(
_upload_media_from_path,
self.client,
"video",
video_path,
with open(video_path, "rb") as f:
try:
response = self.client.media.upload("video", f)
except Exception as e:
logger.error(f"微信客服上传视频失败: {e}")
await self.send(
MessageChain().message(f"微信客服上传视频失败: {e}"),
)
return
logger.debug(f"微信客服上传视频返回: {response}")
kf_message_api.send_video(
user_id,
self.get_self_id(),
response["media_id"],
)
except Exception as e:
logger.error(f"微信客服上传视频失败: {e}")
await self.send(
MessageChain().message(f"微信客服上传视频失败: {e}"),
)
return
logger.debug(f"微信客服上传视频返回: {response}")
kf_message_api.send_video(
user_id,
self.get_self_id(),
response["media_id"],
)
else:
logger.warning(f"还没实现这个消息类型的发送逻辑: {comp.type}")
else:
@@ -221,52 +199,45 @@ class WecomPlatformEvent(AstrMessageEvent):
elif isinstance(comp, Image):
img_path = await comp.convert_to_file_path()
try:
response = await asyncio.to_thread(
_upload_media_from_path,
self.client,
"image",
img_path,
with open(img_path, "rb") as f:
try:
response = self.client.media.upload("image", f)
except Exception as e:
logger.error(f"企业微信上传图片失败: {e}")
await self.send(
MessageChain().message(f"企业微信上传图片失败: {e}"),
)
return
logger.debug(f"企业微信上传图片返回: {response}")
self.client.message.send_image(
message_obj.self_id,
message_obj.session_id,
response["media_id"],
)
except Exception as e:
logger.error(f"企业微信上传图片失败: {e}")
await self.send(
MessageChain().message(f"企业微信上传图片失败: {e}"),
)
return
logger.debug(f"企业微信上传图片返回: {response}")
self.client.message.send_image(
message_obj.self_id,
message_obj.session_id,
response["media_id"],
)
elif isinstance(comp, Record):
record_path = await comp.convert_to_file_path()
record_path_amr = await convert_audio_to_amr(record_path)
try:
try:
response = await asyncio.to_thread(
_upload_media_from_path,
self.client,
"voice",
record_path_amr,
with open(record_path_amr, "rb") as f:
try:
response = self.client.media.upload("voice", f)
except Exception as e:
logger.error(f"企业微信上传语音失败: {e}")
await self.send(
MessageChain().message(
f"企业微信上传语音失败: {e}"
),
)
return
logger.info(f"企业微信上传语音返回: {response}")
self.client.message.send_voice(
message_obj.self_id,
message_obj.session_id,
response["media_id"],
)
except Exception as e:
logger.error(f"企业微信上传语音失败: {e}")
await self.send(
MessageChain().message(f"企业微信上传语音失败: {e}"),
)
return
logger.info(f"企业微信上传语音返回: {response}")
self.client.message.send_voice(
message_obj.self_id,
message_obj.session_id,
response["media_id"],
)
finally:
if record_path_amr != record_path and await asyncio.to_thread(
os.path.exists,
if record_path_amr != record_path and os.path.exists(
record_path_amr,
):
try:
@@ -276,47 +247,39 @@ class WecomPlatformEvent(AstrMessageEvent):
elif isinstance(comp, File):
file_path = await comp.get_file()
try:
response = await asyncio.to_thread(
_upload_media_from_path,
self.client,
"file",
file_path,
with open(file_path, "rb") as f:
try:
response = self.client.media.upload("file", f)
except Exception as e:
logger.error(f"企业微信上传文件失败: {e}")
await self.send(
MessageChain().message(f"企业微信上传文件失败: {e}"),
)
return
logger.debug(f"企业微信上传文件返回: {response}")
self.client.message.send_file(
message_obj.self_id,
message_obj.session_id,
response["media_id"],
)
except Exception as e:
logger.error(f"企业微信上传文件失败: {e}")
await self.send(
MessageChain().message(f"企业微信上传文件失败: {e}"),
)
return
logger.debug(f"企业微信上传文件返回: {response}")
self.client.message.send_file(
message_obj.self_id,
message_obj.session_id,
response["media_id"],
)
elif isinstance(comp, Video):
video_path = await comp.convert_to_file_path()
try:
response = await asyncio.to_thread(
_upload_media_from_path,
self.client,
"video",
video_path,
with open(video_path, "rb") as f:
try:
response = self.client.media.upload("video", f)
except Exception as e:
logger.error(f"企业微信上传视频失败: {e}")
await self.send(
MessageChain().message(f"企业微信上传视频失败: {e}"),
)
return
logger.debug(f"企业微信上传视频返回: {response}")
self.client.message.send_video(
message_obj.self_id,
message_obj.session_id,
response["media_id"],
)
except Exception as e:
logger.error(f"企业微信上传视频失败: {e}")
await self.send(
MessageChain().message(f"企业微信上传视频失败: {e}"),
)
return
logger.debug(f"企业微信上传视频返回: {response}")
self.client.message.send_video(
message_obj.self_id,
message_obj.session_id,
response["media_id"],
)
else:
logger.warning(f"还没实现这个消息类型的发送逻辑: {comp.type}")
@@ -2,6 +2,7 @@
提供常量定义工具函数和辅助方法
"""
import asyncio
import base64
import hashlib
import secrets
@@ -173,7 +174,7 @@ async def process_encrypted_image(
response.raise_for_status()
encrypted_data = await response.read()
logger.info("图片下载成功,大小: %d 字节", len(encrypted_data))
except (TimeoutError, aiohttp.ClientError) as e:
except (aiohttp.ClientError, asyncio.TimeoutError) as e:
error_msg = f"下载图片失败: {e!s}"
logger.error(error_msg)
return False, error_msg
@@ -2,7 +2,6 @@
from __future__ import annotations
import asyncio
import base64
import hashlib
import mimetypes
@@ -104,9 +103,7 @@ class WecomAIBotWebhookClient:
async def upload_media(
self, file_path: Path, media_type: Literal["file", "voice"]
) -> str:
if not await asyncio.to_thread(file_path.exists) or not await asyncio.to_thread(
file_path.is_file
):
if not file_path.exists() or not file_path.is_file():
raise WecomAIBotWebhookError(f"文件不存在: {file_path}")
content_type = (
@@ -115,7 +112,7 @@ class WecomAIBotWebhookClient:
form = aiohttp.FormData()
form.add_field(
"media",
await asyncio.to_thread(file_path.read_bytes),
file_path.read_bytes(),
filename=file_path.name,
content_type=content_type,
)
@@ -1,10 +1,10 @@
import asyncio
import os
import sys
import time
import uuid
from collections.abc import Callable, Coroutine
from pathlib import Path
from typing import Any, cast, override
from typing import Any, cast
import quart
from requests import Response
@@ -32,6 +32,11 @@ from astrbot.core.utils.webhook_utils import log_webhook_info
from .weixin_offacc_event import WeixinOfficialAccountPlatformEvent
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
class WeixinOfficialAccountServer:
def __init__(
@@ -374,7 +379,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
) # wait for 180s
logger.debug(f"Got future result: {result}")
return result
except TimeoutError:
except asyncio.TimeoutError:
logger.info(f"callback 处理消息超时: message_id={msg.id}")
return create_reply("处理消息超时,请稍后再试。", msg)
except Exception as e:
@@ -463,7 +468,8 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
)
temp_dir = get_astrbot_temp_path()
path = os.path.join(temp_dir, f"weixin_offacc_{msg.media_id}.amr")
await asyncio.to_thread(Path(path).write_bytes, resp.content)
with open(path, "wb") as f:
f.write(resp.content)
try:
path_wav = os.path.join(
@@ -12,13 +12,6 @@ from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.core.utils.media_utils import convert_audio_to_amr
def _upload_media_from_path(
client: WeChatClient, media_type: str, file_path: str
) -> dict:
with open(file_path, "rb") as f:
return client.media.upload(media_type, f)
class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
def __init__(
self,
@@ -108,63 +101,24 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
elif isinstance(comp, Image):
img_path = await comp.convert_to_file_path()
try:
response = await asyncio.to_thread(
_upload_media_from_path,
self.client,
"image",
img_path,
)
except Exception as e:
logger.error(f"微信公众平台上传图片失败: {e}")
await self.send(
MessageChain().message(f"微信公众平台上传图片失败: {e}"),
)
return
logger.debug(f"微信公众平台上传图片返回: {response}")
if active_send_mode:
self.client.message.send_image(
message_obj.sender.user_id,
response["media_id"],
)
else:
reply = ImageReply(
media_id=response["media_id"],
message=cast(dict, self.message_obj.raw_message)["message"],
)
xml = reply.render()
future = cast(dict, self.message_obj.raw_message)["future"]
assert isinstance(future, asyncio.Future)
future.set_result(xml)
elif isinstance(comp, Record):
record_path = await comp.convert_to_file_path()
record_path_amr = await convert_audio_to_amr(record_path)
try:
with open(img_path, "rb") as f:
try:
response = await asyncio.to_thread(
_upload_media_from_path,
self.client,
"voice",
record_path_amr,
)
response = self.client.media.upload("image", f)
except Exception as e:
logger.error(f"微信公众平台上传语音失败: {e}")
logger.error(f"微信公众平台上传图片失败: {e}")
await self.send(
MessageChain().message(f"微信公众平台上传语音失败: {e}"),
MessageChain().message(f"微信公众平台上传图片失败: {e}"),
)
return
logger.info(f"微信公众平台上传语音返回: {response}")
logger.debug(f"微信公众平台上传图片返回: {response}")
if active_send_mode:
self.client.message.send_voice(
self.client.message.send_image(
message_obj.sender.user_id,
response["media_id"],
)
else:
reply = VoiceReply(
reply = ImageReply(
media_id=response["media_id"],
message=cast(dict, self.message_obj.raw_message)["message"],
)
@@ -172,9 +126,44 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
future = cast(dict, self.message_obj.raw_message)["future"]
assert isinstance(future, asyncio.Future)
future.set_result(xml)
elif isinstance(comp, Record):
record_path = await comp.convert_to_file_path()
record_path_amr = await convert_audio_to_amr(record_path)
try:
with open(record_path_amr, "rb") as f:
try:
response = self.client.media.upload("voice", f)
except Exception as e:
logger.error(f"微信公众平台上传语音失败: {e}")
await self.send(
MessageChain().message(
f"微信公众平台上传语音失败: {e}"
),
)
return
logger.info(f"微信公众平台上传语音返回: {response}")
if active_send_mode:
self.client.message.send_voice(
message_obj.sender.user_id,
response["media_id"],
)
else:
reply = VoiceReply(
media_id=response["media_id"],
message=cast(dict, self.message_obj.raw_message)[
"message"
],
)
xml = reply.render()
future = cast(dict, self.message_obj.raw_message)["future"]
assert isinstance(future, asyncio.Future)
future.set_result(xml)
finally:
if record_path_amr != record_path and await asyncio.to_thread(
os.path.exists, record_path_amr
if record_path_amr != record_path and os.path.exists(
record_path_amr
):
try:
os.remove(record_path_amr)
+3 -6
View File
@@ -1,11 +1,9 @@
from __future__ import annotations
import asyncio
import base64
import enum
import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from anthropic.types import Message as AnthropicMessage
@@ -220,10 +218,9 @@ class ProviderRequest:
"""将图片转换为 base64"""
if image_url.startswith("base64://"):
return image_url.replace("base64://", "data:image/jpeg;base64,")
image_bs64 = base64.b64encode(
await asyncio.to_thread(Path(image_url).read_bytes)
).decode("utf-8")
return "data:image/jpeg;base64," + image_bs64
with open(image_url, "rb") as f:
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
return "data:image/jpeg;base64," + image_bs64
return ""
+25 -49
View File
@@ -8,7 +8,6 @@ import threading
import urllib.parse
from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping
from dataclasses import dataclass
from pathlib import Path
from types import MappingProxyType
from typing import Any
@@ -199,7 +198,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
return True, ""
return False, f"HTTP {response.status}: {response.reason}"
except TimeoutError:
except asyncio.TimeoutError:
return False, f"连接超时: {timeout}"
except Exception as e:
return False, f"{e!s}"
@@ -378,24 +377,15 @@ class FunctionToolManager:
data_dir = get_astrbot_data_path()
mcp_json_file = os.path.join(data_dir, "mcp_server.json")
if not await asyncio.to_thread(os.path.exists, mcp_json_file):
if not os.path.exists(mcp_json_file):
# 配置文件不存在错误处理
config_text = json.dumps(DEFAULT_MCP_CONFIG, ensure_ascii=False, indent=4)
await asyncio.to_thread(
Path(mcp_json_file).write_text,
config_text,
encoding="utf-8",
)
with open(mcp_json_file, "w", encoding="utf-8") as f:
json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4)
logger.info(f"未找到 MCP 服务配置文件,已创建默认配置文件 {mcp_json_file}")
return MCPInitSummary(total=0, success=0, failed=[])
mcp_json_content = await asyncio.to_thread(
Path(mcp_json_file).read_text,
encoding="utf-8",
)
mcp_server_json_obj: dict[str, dict] = json.loads(mcp_json_content)[
"mcpServers"
]
with open(mcp_json_file, encoding="utf-8") as f:
mcp_server_json_obj: dict[str, dict] = json.load(f)["mcpServers"]
init_timeout_value = _resolve_timeout(
timeout=init_timeout,
@@ -469,7 +459,7 @@ class FunctionToolManager:
cfg: dict,
*,
shutdown_event: asyncio.Event | None = None,
timeout_seconds: float,
timeout: float,
) -> None:
"""Initialize MCP server with timeout and register task/event together.
@@ -479,7 +469,7 @@ class FunctionToolManager:
async with self._runtime_lock:
if name in self._mcp_server_runtime or name in self._mcp_starting:
logger.warning(
f"MCP 服务 {name} 已在运行,忽略本次启用请求(timeout={timeout_seconds:g})。"
f"MCP 服务 {name} 已在运行,忽略本次启用请求(timeout={timeout:g})。"
)
self._log_safe_mcp_debug_config(cfg)
return
@@ -492,11 +482,11 @@ class FunctionToolManager:
try:
mcp_client = await asyncio.wait_for(
self._init_mcp_client(name, cfg),
timeout=timeout_seconds,
timeout=timeout,
)
except TimeoutError as exc:
except asyncio.TimeoutError as exc:
raise MCPInitTimeoutError(
f"MCP 服务 {name} 初始化超时({timeout_seconds:g} 秒)"
f"MCP 服务 {name} 初始化超时({timeout:g} 秒)"
) from exc
except Exception:
logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True)
@@ -529,7 +519,7 @@ class FunctionToolManager:
async def _shutdown_runtimes(
self,
runtimes: list[_MCPServerRuntime],
timeout_seconds: float,
timeout: float,
*,
strict: bool = True,
) -> list[str]:
@@ -548,9 +538,9 @@ class FunctionToolManager:
try:
results = await asyncio.wait_for(
asyncio.gather(*lifecycle_tasks, return_exceptions=True),
timeout=timeout_seconds,
timeout=timeout,
)
except TimeoutError:
except asyncio.TimeoutError:
pending_names = [
runtime.name
for runtime in runtimes
@@ -561,10 +551,10 @@ class FunctionToolManager:
task.cancel()
await asyncio.gather(*lifecycle_tasks, return_exceptions=True)
if strict:
raise MCPShutdownTimeoutError(pending_names, timeout_seconds)
raise MCPShutdownTimeoutError(pending_names, timeout)
logger.warning(
"MCP 服务关闭超时(%s 秒),以下服务未完全关闭:%s",
f"{timeout_seconds:g}",
f"{timeout:g}",
", ".join(pending_names),
)
return pending_names
@@ -675,8 +665,7 @@ class FunctionToolManager:
name: str,
config: dict,
shutdown_event: asyncio.Event | None = None,
timeout_seconds: float | int | str | None = None,
**kwargs: Any,
timeout: float | int | str | None = None,
) -> None:
"""Enable a new MCP server and initialize it.
@@ -684,22 +673,18 @@ class FunctionToolManager:
name: The name of the MCP server.
config: Configuration for the MCP server.
shutdown_event: Event to signal when the MCP client should shut down.
timeout_seconds: Timeout in seconds for initialization.
timeout: Timeout in seconds for initialization.
Uses ASTRBOT_MCP_ENABLE_TIMEOUT by default (separate from init timeout).
Raises:
MCPInitTimeoutError: If initialization does not complete within timeout.
Exception: If there is an error during initialization.
"""
legacy_timeout = kwargs.pop("timeout", None)
if legacy_timeout is not None:
timeout_seconds = legacy_timeout
if timeout_seconds is None:
if timeout is None:
timeout_value = self._enable_timeout_default
else:
timeout_value = _resolve_timeout(
timeout=timeout_seconds,
timeout=timeout,
env_name=ENABLE_MCP_TIMEOUT_ENV,
default=self._enable_timeout_default,
)
@@ -707,45 +692,36 @@ class FunctionToolManager:
name=name,
cfg=config,
shutdown_event=shutdown_event,
timeout_seconds=timeout_value,
timeout=timeout_value,
)
async def disable_mcp_server(
self,
name: str | None = None,
timeout_seconds: float = 10,
**kwargs: Any,
timeout: float = 10,
) -> None:
"""Disable an MCP server by its name.
Args:
name (str): The name of the MCP server to disable. If None, ALL MCP servers will be disabled.
timeout_seconds (int): Timeout.
timeout (int): Timeout.
Raises:
MCPShutdownTimeoutError: If shutdown does not complete within timeout.
Only raised when disabling a specific server (name is not None).
"""
legacy_timeout = kwargs.pop("timeout", None)
if legacy_timeout is not None:
timeout_seconds = float(legacy_timeout)
if name:
async with self._runtime_lock:
runtime = self._mcp_server_runtime.get(name)
if runtime is None:
return
await self._shutdown_runtimes(
[runtime], timeout_seconds=timeout_seconds, strict=True
)
await self._shutdown_runtimes([runtime], timeout, strict=True)
else:
async with self._runtime_lock:
runtimes = list(self._mcp_server_runtime.values())
await self._shutdown_runtimes(
runtimes, timeout_seconds=timeout_seconds, strict=False
)
await self._shutdown_runtimes(runtimes, timeout, strict=False)
def _warn_on_timeout_mismatch(
self,
+12 -13
View File
@@ -2,8 +2,7 @@ import abc
import asyncio
import os
from collections.abc import AsyncGenerator
from pathlib import Path
from typing import Any
from typing import TypeAlias, Union
from astrbot.core.agent.message import ContentPart, Message
from astrbot.core.agent.tool import ToolSet
@@ -16,9 +15,13 @@ from astrbot.core.provider.entities import (
from astrbot.core.provider.register import provider_cls_map
from astrbot.core.utils.astrbot_path import get_astrbot_path
type Providers = (
"Provider" | "STTProvider" | "TTSProvider" | "EmbeddingProvider" | "RerankProvider"
)
Providers: TypeAlias = Union[
"Provider",
"STTProvider",
"TTSProvider",
"EmbeddingProvider",
"RerankProvider",
]
class AbstractProvider(abc.ABC):
@@ -185,13 +188,10 @@ class Provider(AbstractProvider):
return dicts
async def test(self, timeout_seconds: float = 45.0, **kwargs: Any) -> None:
legacy_timeout = kwargs.pop("timeout", None)
if legacy_timeout is not None:
timeout_seconds = float(legacy_timeout)
async def test(self, timeout: float = 45.0) -> None:
await asyncio.wait_for(
self.text_chat(prompt="REPLY `PONG` ONLY"),
timeout=timeout_seconds,
timeout=timeout,
)
@@ -268,9 +268,8 @@ class TTSProvider(AbstractProvider):
# 调用原有的 get_audio 方法获取音频文件路径
audio_path = await self.get_audio(accumulated_text)
# 读取音频文件内容
audio_data = await asyncio.to_thread(
Path(audio_path).read_bytes
)
with open(audio_path, "rb") as f:
audio_data = f.read()
await audio_queue.put((accumulated_text, audio_data))
except Exception:
# 出错时也要发送 None 结束标记
@@ -1,8 +1,6 @@
import asyncio
import base64
import json
from collections.abc import AsyncGenerator
from pathlib import Path
import anthropic
import httpx
@@ -639,10 +637,11 @@ class ProviderAnthropic(Provider):
except Exception:
mime_type = "image/jpeg"
return f"data:{mime_type};base64,{raw_base64}", mime_type
image_bytes = await asyncio.to_thread(Path(image_url).read_bytes)
mime_type = self._detect_image_mime_type(image_bytes)
image_bs64 = base64.b64encode(image_bytes).decode("utf-8")
return f"data:{mime_type};base64,{image_bs64}", mime_type
with open(image_url, "rb") as f:
image_bytes = f.read()
mime_type = self._detect_image_mime_type(image_bytes)
image_bs64 = base64.b64encode(image_bytes).decode("utf-8")
return f"data:{mime_type};base64,{image_bs64}", mime_type
return "", "image/jpeg"
def get_current_key(self) -> str:
@@ -3,7 +3,6 @@ import base64
import logging
import os
import uuid
from pathlib import Path
import aiohttp
import dashscope
@@ -60,7 +59,8 @@ class ProviderDashscopeTTSAPI(TTSProvider):
)
path = os.path.join(temp_dir, f"dashscope_tts_{uuid.uuid4()}{ext}")
await asyncio.to_thread(Path(path).write_bytes, audio_bytes)
with open(path, "wb") as f:
f.write(audio_bytes)
return path
def _call_qwen_tts(self, model: str, text: str):
@@ -129,7 +129,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
) as response,
):
return await response.read()
except (TimeoutError, aiohttp.ClientError, OSError) as e:
except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as e:
logging.exception(f"Failed to download audio from URL {url}: {e}")
return None
+126 -129
View File
@@ -1,129 +1,126 @@
import asyncio
import os
import subprocess
import uuid
import edge_tts
from astrbot.core import logger
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from ..entities import ProviderType
from ..provider import TTSProvider
from ..register import register_provider_adapter
"""
edge_tts 方式能够免费快速生成语音使用需要先安装edge-tts库
```
pip install edge_tts
```
Windows 如果提示找不到指定文件以管理员身份运行命令行窗口然后再次运行 AstrBot
"""
@register_provider_adapter(
"edge_tts",
"Microsoft Edge TTS",
provider_type=ProviderType.TEXT_TO_SPEECH,
)
class ProviderEdgeTTS(TTSProvider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
# 设置默认语音,如果没有指定则使用中文小萱
self.voice = provider_config.get("edge-tts-voice", "zh-CN-XiaoxiaoNeural")
self.rate = provider_config.get("rate")
self.volume = provider_config.get("volume")
self.pitch = provider_config.get("pitch")
self.timeout = provider_config.get("timeout", 30)
self.proxy = os.getenv("https_proxy", None)
self.set_model("edge_tts")
async def get_audio(self, text: str) -> str:
temp_dir = get_astrbot_temp_path()
mp3_path = os.path.join(temp_dir, f"edge_tts_temp_{uuid.uuid4()}.mp3")
wav_path = os.path.join(temp_dir, f"edge_tts_{uuid.uuid4()}.wav")
# 构建 Edge TTS 参数
kwargs = {"text": text, "voice": self.voice}
if self.rate:
kwargs["rate"] = self.rate
if self.volume:
kwargs["volume"] = self.volume
if self.pitch:
kwargs["pitch"] = self.pitch
try:
communicate = edge_tts.Communicate(proxy=self.proxy, **kwargs)
await communicate.save(mp3_path)
try:
from pyffmpeg import FFmpeg
ff = FFmpeg()
ff.convert(input_file=mp3_path, output_file=wav_path)
except Exception as e:
logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
# use ffmpeg command line
# 使用ffmpeg将MP3转换为标准WAV格式
p = await asyncio.create_subprocess_exec(
"ffmpeg",
"-y", # 覆盖输出文件
"-i",
mp3_path, # 输入文件
"-acodec",
"pcm_s16le", # 16位PCM编码
"-ar",
"24000", # 采样率24kHz (适合微信语音)
"-ac",
"1", # 单声道
"-af",
"apad=pad_dur=2", # 确保输出时长准确
"-fflags",
"+genpts", # 强制生成时间戳
"-hide_banner", # 隐藏版本信息
wav_path, # 输出文件
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
# 等待进程完成并获取输出
stdout, stderr = await p.communicate()
logger.info(f"[EdgeTTS] FFmpeg 标准输出: {stdout.decode().strip()}")
logger.debug(f"FFmpeg错误输出: {stderr.decode().strip()}")
logger.info(f"[EdgeTTS] 返回值(0代表成功): {p.returncode}")
os.remove(mp3_path)
if (
await asyncio.to_thread(os.path.exists, wav_path)
and await asyncio.to_thread(os.path.getsize, wav_path) > 0
):
return wav_path
logger.error("生成的WAV文件不存在或为空")
raise RuntimeError("生成的WAV文件不存在或为空")
except subprocess.CalledProcessError as e:
logger.error(
f"FFmpeg 转换失败: {e.stderr.decode() if e.stderr else str(e)}",
)
try:
if await asyncio.to_thread(os.path.exists, mp3_path):
os.remove(mp3_path)
except Exception:
pass
raise RuntimeError(f"FFmpeg 转换失败: {e!s}")
except Exception as e:
logger.error(f"音频生成失败: {e!s}")
try:
if await asyncio.to_thread(os.path.exists, mp3_path):
os.remove(mp3_path)
except Exception:
pass
raise RuntimeError(f"音频生成失败: {e!s}")
import asyncio
import os
import subprocess
import uuid
import edge_tts
from astrbot.core import logger
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from ..entities import ProviderType
from ..provider import TTSProvider
from ..register import register_provider_adapter
"""
edge_tts 方式能够免费快速生成语音使用需要先安装edge-tts库
```
pip install edge_tts
```
Windows 如果提示找不到指定文件以管理员身份运行命令行窗口然后再次运行 AstrBot
"""
@register_provider_adapter(
"edge_tts",
"Microsoft Edge TTS",
provider_type=ProviderType.TEXT_TO_SPEECH,
)
class ProviderEdgeTTS(TTSProvider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
# 设置默认语音,如果没有指定则使用中文小萱
self.voice = provider_config.get("edge-tts-voice", "zh-CN-XiaoxiaoNeural")
self.rate = provider_config.get("rate")
self.volume = provider_config.get("volume")
self.pitch = provider_config.get("pitch")
self.timeout = provider_config.get("timeout", 30)
self.proxy = os.getenv("https_proxy", None)
self.set_model("edge_tts")
async def get_audio(self, text: str) -> str:
temp_dir = get_astrbot_temp_path()
mp3_path = os.path.join(temp_dir, f"edge_tts_temp_{uuid.uuid4()}.mp3")
wav_path = os.path.join(temp_dir, f"edge_tts_{uuid.uuid4()}.wav")
# 构建 Edge TTS 参数
kwargs = {"text": text, "voice": self.voice}
if self.rate:
kwargs["rate"] = self.rate
if self.volume:
kwargs["volume"] = self.volume
if self.pitch:
kwargs["pitch"] = self.pitch
try:
communicate = edge_tts.Communicate(proxy=self.proxy, **kwargs)
await communicate.save(mp3_path)
try:
from pyffmpeg import FFmpeg
ff = FFmpeg()
ff.convert(input_file=mp3_path, output_file=wav_path)
except Exception as e:
logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
# use ffmpeg command line
# 使用ffmpeg将MP3转换为标准WAV格式
p = await asyncio.create_subprocess_exec(
"ffmpeg",
"-y", # 覆盖输出文件
"-i",
mp3_path, # 输入文件
"-acodec",
"pcm_s16le", # 16位PCM编码
"-ar",
"24000", # 采样率24kHz (适合微信语音)
"-ac",
"1", # 单声道
"-af",
"apad=pad_dur=2", # 确保输出时长准确
"-fflags",
"+genpts", # 强制生成时间戳
"-hide_banner", # 隐藏版本信息
wav_path, # 输出文件
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
# 等待进程完成并获取输出
stdout, stderr = await p.communicate()
logger.info(f"[EdgeTTS] FFmpeg 标准输出: {stdout.decode().strip()}")
logger.debug(f"FFmpeg错误输出: {stderr.decode().strip()}")
logger.info(f"[EdgeTTS] 返回值(0代表成功): {p.returncode}")
os.remove(mp3_path)
if os.path.exists(wav_path) and os.path.getsize(wav_path) > 0:
return wav_path
logger.error("生成的WAV文件不存在或为空")
raise RuntimeError("生成的WAV文件不存在或为空")
except subprocess.CalledProcessError as e:
logger.error(
f"FFmpeg 转换失败: {e.stderr.decode() if e.stderr else str(e)}",
)
try:
if os.path.exists(mp3_path):
os.remove(mp3_path)
except Exception:
pass
raise RuntimeError(f"FFmpeg 转换失败: {e!s}")
except Exception as e:
logger.error(f"音频生成失败: {e!s}")
try:
if os.path.exists(mp3_path):
os.remove(mp3_path)
except Exception:
pass
raise RuntimeError(f"音频生成失败: {e!s}")
@@ -1,8 +1,6 @@
import asyncio
import os
import re
import uuid
from pathlib import Path
from typing import Annotated, Literal
import ormsgpack
@@ -161,10 +159,9 @@ class ProviderFishAudioTTSAPI(TTSProvider):
if response.status_code == 200 and response.headers.get(
"content-type", ""
).startswith("audio/"):
audio_data = bytearray()
async for chunk in response.aiter_bytes():
audio_data.extend(chunk)
await asyncio.to_thread(Path(path).write_bytes, bytes(audio_data))
with open(path, "wb") as f:
async for chunk in response.aiter_bytes():
f.write(chunk)
return path
error_bytes = await response.aread()
error_text = error_bytes.decode("utf-8", errors="replace")[:1024]
@@ -4,7 +4,6 @@ import json
import logging
import random
from collections.abc import AsyncGenerator
from pathlib import Path
from typing import cast
from google import genai
@@ -925,10 +924,9 @@ class ProviderGoogleGenAI(Provider):
"""将图片转换为 base64"""
if image_url.startswith("base64://"):
return image_url.replace("base64://", "data:image/jpeg;base64,")
image_bs64 = base64.b64encode(
await asyncio.to_thread(Path(image_url).read_bytes)
).decode("utf-8")
return "data:image/jpeg;base64," + image_bs64
with open(image_url, "rb") as f:
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
return "data:image/jpeg;base64," + image_bs64
async def terminate(self) -> None:
if self.client:
+4 -4
View File
@@ -1,7 +1,6 @@
import asyncio
import os
import uuid
from pathlib import Path
from astrbot.core import logger
from astrbot.core.provider.entities import ProviderType
@@ -73,7 +72,7 @@ class GenieTTSProvider(TTSProvider):
try:
await loop.run_in_executor(None, _generate, path)
if await asyncio.to_thread(os.path.exists, path):
if os.path.exists(path):
return path
raise RuntimeError("Genie TTS did not save to file.")
@@ -110,8 +109,9 @@ class GenieTTSProvider(TTSProvider):
await loop.run_in_executor(None, _generate, path, text)
if await asyncio.to_thread(os.path.exists, path):
audio_data = await asyncio.to_thread(Path(path).read_bytes)
if os.path.exists(path):
with open(path, "rb") as f:
audio_data = f.read()
# Put (text, bytes) into queue so frontend can display text
await audio_queue.put((text, audio_data))
@@ -1,7 +1,6 @@
import asyncio
import os
import uuid
from pathlib import Path
import aiohttp
@@ -130,7 +129,8 @@ class ProviderGSVTTS(TTSProvider):
result = await self._make_request(endpoint, params)
if isinstance(result, bytes):
await asyncio.to_thread(Path(path).write_bytes, result)
with open(path, "wb") as f:
f.write(result)
return path
raise Exception(f"[GSV TTS] 合成失败,输入文本:{text},错误信息:{result}")
@@ -1,62 +1,59 @@
import asyncio
import os
import urllib.parse
import uuid
from pathlib import Path
import aiohttp
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from ..entities import ProviderType
from ..provider import TTSProvider
from ..register import register_provider_adapter
@register_provider_adapter(
"gsvi_tts_api",
"GSVI TTS API",
provider_type=ProviderType.TEXT_TO_SPEECH,
)
class ProviderGSVITTS(TTSProvider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
self.api_base = provider_config.get("api_base", "http://127.0.0.1:5000")
self.api_base = self.api_base.removesuffix("/")
self.character = provider_config.get("character")
self.emotion = provider_config.get("emotion")
async def get_audio(self, text: str) -> str:
temp_dir = get_astrbot_temp_path()
path = os.path.join(temp_dir, f"gsvi_tts_{uuid.uuid4()}.wav")
params = {"text": text}
if self.character:
params["character"] = self.character
if self.emotion:
params["emotion"] = self.emotion
query_parts = []
for key, value in params.items():
encoded_value = urllib.parse.quote(str(value))
query_parts.append(f"{key}={encoded_value}")
url = f"{self.api_base}/tts?{'&'.join(query_parts)}"
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
if response.status == 200:
await asyncio.to_thread(
Path(path).write_bytes, await response.read()
)
else:
error_text = await response.text()
raise Exception(
f"GSVI TTS API 请求失败,状态码: {response.status},错误: {error_text}",
)
return path
import os
import urllib.parse
import uuid
import aiohttp
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from ..entities import ProviderType
from ..provider import TTSProvider
from ..register import register_provider_adapter
@register_provider_adapter(
"gsvi_tts_api",
"GSVI TTS API",
provider_type=ProviderType.TEXT_TO_SPEECH,
)
class ProviderGSVITTS(TTSProvider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
self.api_base = provider_config.get("api_base", "http://127.0.0.1:5000")
self.api_base = self.api_base.removesuffix("/")
self.character = provider_config.get("character")
self.emotion = provider_config.get("emotion")
async def get_audio(self, text: str) -> str:
temp_dir = get_astrbot_temp_path()
path = os.path.join(temp_dir, f"gsvi_tts_{uuid.uuid4()}.wav")
params = {"text": text}
if self.character:
params["character"] = self.character
if self.emotion:
params["emotion"] = self.emotion
query_parts = []
for key, value in params.items():
encoded_value = urllib.parse.quote(str(value))
query_parts.append(f"{key}={encoded_value}")
url = f"{self.api_base}/tts?{'&'.join(query_parts)}"
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
if response.status == 200:
with open(path, "wb") as f:
f.write(await response.read())
else:
error_text = await response.text()
raise Exception(
f"GSVI TTS API 请求失败,状态码: {response.status},错误: {error_text}",
)
return path
@@ -1,9 +1,7 @@
import asyncio
import json
import os
import uuid
from collections.abc import AsyncIterator
from pathlib import Path
import aiohttp
@@ -157,7 +155,8 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
audio = await self._audio_play(audio_stream)
# 结果保存至文件
await asyncio.to_thread(Path(path).write_bytes, audio)
with open(path, "wb") as file:
file.write(audio)
return path
@@ -5,7 +5,6 @@ import json
import random
import re
from collections.abc import AsyncGenerator
from pathlib import Path
from typing import Any
import httpx
@@ -950,10 +949,9 @@ class ProviderOpenAIOfficial(Provider):
"""将图片转换为 base64"""
if image_url.startswith("base64://"):
return image_url.replace("base64://", "data:image/jpeg;base64,")
image_bs64 = base64.b64encode(
await asyncio.to_thread(Path(image_url).read_bytes)
).decode("utf-8")
return "data:image/jpeg;base64," + image_bs64
with open(image_url, "rb") as f:
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
return "data:image/jpeg;base64," + image_bs64
async def terminate(self):
if self.client:
@@ -1,7 +1,5 @@
import asyncio
import os
import uuid
from pathlib import Path
import httpx
from openai import NOT_GIVEN, AsyncOpenAI
@@ -56,10 +54,9 @@ class ProviderOpenAITTSAPI(TTSProvider):
response_format="wav",
input=text,
) as response:
audio_data = bytearray()
async for chunk in response.iter_bytes(chunk_size=1024):
audio_data.extend(chunk)
await asyncio.to_thread(Path(path).write_bytes, bytes(audio_data))
with open(path, "wb") as f:
async for chunk in response.iter_bytes(chunk_size=1024):
f.write(chunk)
return path
async def terminate(self):
@@ -53,12 +53,14 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
async def get_timestamped_path(self) -> str:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
temp_dir = Path(get_astrbot_temp_path())
await asyncio.to_thread(temp_dir.mkdir, parents=True, exist_ok=True)
temp_dir.mkdir(parents=True, exist_ok=True)
return str(temp_dir / timestamp)
async def _is_silk_file(self, file_path) -> bool:
silk_header = b"SILK"
file_header = (await asyncio.to_thread(Path(file_path).read_bytes))[:8]
with open(file_path, "rb") as f:
file_header = f.read(8)
if silk_header in file_header:
return True
return False
@@ -74,7 +76,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
await download_file(audio_url, path)
audio_url = path
if not await asyncio.to_thread(os.path.isfile, audio_url):
if not os.path.isfile(audio_url):
raise FileNotFoundError(f"文件不存在: {audio_url}")
if audio_url.endswith((".amr", ".silk")) or is_tencent:
@@ -4,7 +4,6 @@ import json
import os
import traceback
import uuid
from pathlib import Path
import aiohttp
@@ -101,9 +100,10 @@ class ProviderVolcengineTTS(TTSProvider):
f"volcengine_tts_{uuid.uuid4()}.mp3",
)
await asyncio.to_thread(
Path(file_path).write_bytes,
audio_data,
loop = asyncio.get_running_loop()
await loop.run_in_executor(
None,
lambda: open(file_path, "wb").write(audio_data),
)
return file_path
@@ -1,7 +1,5 @@
import asyncio
import os
import uuid
from pathlib import Path
from openai import NOT_GIVEN, AsyncOpenAI
@@ -46,7 +44,8 @@ class ProviderOpenAIWhisperAPI(STTProvider):
amr_header = b"#!AMR"
try:
file_header = (await asyncio.to_thread(Path(file_path).read_bytes))[:8]
with open(file_path, "rb") as f:
file_header = f.read(8)
except FileNotFoundError:
return None
@@ -74,7 +73,7 @@ class ProviderOpenAIWhisperAPI(STTProvider):
await download_file(audio_url, path)
audio_url = path
if not await asyncio.to_thread(os.path.exists, audio_url):
if not os.path.exists(audio_url):
raise FileNotFoundError(f"文件不存在: {audio_url}")
if audio_url.endswith(".amr") or audio_url.endswith(".silk") or is_tencent:
@@ -101,14 +100,13 @@ class ProviderOpenAIWhisperAPI(STTProvider):
audio_url = output_path
audio_bytes = await asyncio.to_thread(Path(audio_url).read_bytes)
result = await self.client.audio.transcriptions.create(
model=self.model_name,
file=("audio.wav", audio_bytes),
file=("audio.wav", open(audio_url, "rb")),
)
# remove temp file
if output_path and await asyncio.to_thread(os.path.exists, output_path):
if output_path and os.path.exists(output_path):
try:
os.remove(audio_url)
except Exception as e:
@@ -1,7 +1,6 @@
import asyncio
import os
import uuid
from pathlib import Path
from typing import cast
import whisper
@@ -43,7 +42,9 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
async def _is_silk_file(self, file_path) -> bool:
silk_header = b"SILK"
file_header = (await asyncio.to_thread(Path(file_path).read_bytes))[:8]
with open(file_path, "rb") as f:
file_header = f.read(8)
if silk_header in file_header:
return True
return False
@@ -65,7 +66,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
await download_file(audio_url, path)
audio_url = path
if not await asyncio.to_thread(os.path.exists, audio_url):
if not os.path.exists(audio_url):
raise FileNotFoundError(f"文件不存在: {audio_url}")
if audio_url.endswith(".amr") or audio_url.endswith(".silk") or is_tencent:
@@ -1,7 +1,5 @@
import asyncio
import os
import uuid
from pathlib import Path
import aiohttp
from xinference_client.client.restful.async_restful_client import (
@@ -104,8 +102,9 @@ class ProviderXinferenceSTT(STTProvider):
f"Failed to download audio from {audio_url}, status: {resp.status}",
)
return ""
elif await asyncio.to_thread(os.path.exists, audio_url):
audio_bytes = await asyncio.to_thread(Path(audio_url).read_bytes)
elif os.path.exists(audio_url):
with open(audio_url, "rb") as f:
audio_bytes = f.read()
else:
logger.error(f"File not found: {audio_url}")
return ""
@@ -144,7 +143,8 @@ class ProviderXinferenceSTT(STTProvider):
)
temp_files.extend([input_path, output_path])
await asyncio.to_thread(Path(input_path).write_bytes, audio_bytes)
with open(input_path, "wb") as f:
f.write(audio_bytes)
if conversion_type == "silk":
logger.info("Converting silk to wav ...")
@@ -153,7 +153,8 @@ class ProviderXinferenceSTT(STTProvider):
logger.info("Converting amr to wav ...")
await convert_to_pcm_wav(input_path, output_path)
audio_bytes = await asyncio.to_thread(Path(output_path).read_bytes)
with open(output_path, "rb") as f:
audio_bytes = f.read()
# 4. Transcribe
# 官方asyncCLient的客户端似乎实现有点问题,这里直接用aiohttp实现openai标准兼容请求,提交issue等待官方修复后再改回来
@@ -198,7 +199,7 @@ class ProviderXinferenceSTT(STTProvider):
# 5. Cleanup
for temp_file in temp_files:
try:
if await asyncio.to_thread(os.path.exists, temp_file):
if os.path.exists(temp_file):
os.remove(temp_file)
logger.debug(f"Removed temporary file: {temp_file}")
except Exception as e:
+2 -2
View File
@@ -5,7 +5,7 @@ import json
import os
import re
from dataclasses import dataclass
from datetime import UTC, datetime
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
@@ -19,7 +19,7 @@ _SKILL_NAME_RE = re.compile(r"[^a-zA-Z0-9._-]+")
def _now_iso() -> str:
return datetime.now(UTC).isoformat()
return datetime.now(timezone.utc).isoformat()
def _to_jsonable(model_like: Any) -> dict[str, Any]:
+2 -2
View File
@@ -7,7 +7,7 @@ import shutil
import tempfile
import zipfile
from dataclasses import dataclass
from datetime import UTC, datetime
from datetime import datetime, timezone
from pathlib import Path, PurePosixPath
from astrbot.core.utils.astrbot_path import (
@@ -175,7 +175,7 @@ class SkillManager:
def _save_sandbox_skills_cache(self, cache: dict) -> None:
cache["version"] = _SANDBOX_SKILLS_CACHE_VERSION
cache["updated_at"] = datetime.now(UTC).isoformat()
cache["updated_at"] = datetime.now(timezone.utc).isoformat()
with open(self.sandbox_skills_cache_path, "w", encoding="utf-8") as f:
json.dump(cache, f, ensure_ascii=False, indent=2)
+3 -3
View File
@@ -3,7 +3,7 @@ from __future__ import annotations
import enum
from collections.abc import AsyncGenerator, Awaitable, Callable
from dataclasses import dataclass, field
from typing import Any, Literal, TypeVar, overload
from typing import Any, Generic, Literal, TypeVar, overload
from .filter import HandlerFilter
from .star import star_map
@@ -11,7 +11,7 @@ from .star import star_map
T = TypeVar("T", bound="StarHandlerMetadata")
class StarHandlerRegistry[T: "StarHandlerMetadata"]:
class StarHandlerRegistry(Generic[T]):
def __init__(self) -> None:
self.star_handlers_map: dict[str, StarHandlerMetadata] = {}
self._handlers: list[StarHandlerMetadata] = []
@@ -227,7 +227,7 @@ H = TypeVar("H", bound=Callable[..., Any])
@dataclass
class StarHandlerMetadata[H: Callable[..., Any]]:
class StarHandlerMetadata(Generic[H]):
"""描述一个 Star 所注册的某一个 Handler。"""
event_type: EventType
+24 -32
View File
@@ -8,7 +8,6 @@ import logging
import os
import sys
import traceback
from pathlib import Path
from types import ModuleType
import yaml
@@ -189,7 +188,7 @@ class PluginManager:
如果 target_plugin None则检查所有插件的依赖
"""
plugin_dir = self.plugin_store_path
if not await asyncio.to_thread(os.path.exists, plugin_dir):
if not os.path.exists(plugin_dir):
return False
to_update = []
if target_plugin:
@@ -199,9 +198,7 @@ class PluginManager:
to_update.append(p.root_dir_name)
for p in to_update:
plugin_path = os.path.join(plugin_dir, p)
if await asyncio.to_thread(
os.path.exists, os.path.join(plugin_path, "requirements.txt")
):
if os.path.exists(os.path.join(plugin_path, "requirements.txt")):
pth = os.path.join(plugin_path, "requirements.txt")
logger.info(f"正在安装插件 {p} 所需的依赖库: {pth}")
try:
@@ -220,7 +217,7 @@ class PluginManager:
try:
return __import__(path, fromlist=[module_str])
except (ModuleNotFoundError, ImportError) as import_exc:
if await asyncio.to_thread(os.path.exists, requirements_path):
if os.path.exists(requirements_path):
try:
logger.info(
f"插件 {root_dir_name} 导入失败,尝试从已安装依赖恢复: {import_exc!s}"
@@ -654,19 +651,16 @@ class PluginManager:
plugin_dir_path,
self.conf_schema_fname,
)
if await asyncio.to_thread(os.path.exists, plugin_schema_path):
if os.path.exists(plugin_schema_path):
# 加载插件配置
plugin_schema_text = await asyncio.to_thread(
Path(plugin_schema_path).read_text,
encoding="utf-8",
)
plugin_config = AstrBotConfig(
config_path=os.path.join(
self.plugin_config_path,
f"{root_dir_name}_config.json",
),
schema=json.loads(plugin_schema_text),
)
with open(plugin_schema_path, encoding="utf-8") as f:
plugin_config = AstrBotConfig(
config_path=os.path.join(
self.plugin_config_path,
f"{root_dir_name}_config.json",
),
schema=json.loads(f.read()),
)
logo_path = os.path.join(plugin_dir_path, self.logo_fname)
if path in star_map:
@@ -842,7 +836,7 @@ class PluginManager:
metadata.activated = False
# Plugin logo path
if await asyncio.to_thread(os.path.exists, logo_path):
if os.path.exists(logo_path):
metadata.logo_path = logo_path
assert metadata.module_path, f"插件 {metadata.name} 模块路径为空"
@@ -961,7 +955,7 @@ class PluginManager:
except Exception:
logger.warning(traceback.format_exc())
if await asyncio.to_thread(os.path.exists, plugin_path):
if os.path.exists(plugin_path):
try:
remove_dir(plugin_path)
logger.warning(f"已清理安装失败的插件目录: {plugin_path}")
@@ -974,7 +968,7 @@ class PluginManager:
self.plugin_config_path,
f"{dir_name}_config.json",
)
if await asyncio.to_thread(os.path.exists, plugin_config_path):
if os.path.exists(plugin_config_path):
try:
os.remove(plugin_config_path)
logger.warning(f"已清理安装失败插件配置: {plugin_config_path}")
@@ -1106,14 +1100,13 @@ class PluginManager:
# Extract README.md content if exists
readme_content = None
readme_path = os.path.join(plugin_path, "README.md")
if not await asyncio.to_thread(os.path.exists, readme_path):
if not os.path.exists(readme_path):
readme_path = os.path.join(plugin_path, "readme.md")
if await asyncio.to_thread(os.path.exists, readme_path):
if os.path.exists(readme_path):
try:
readme_content = await asyncio.to_thread(
Path(readme_path).read_text, encoding="utf-8"
)
with open(readme_path, encoding="utf-8") as f:
readme_content = f.read()
except Exception as e:
logger.warning(
f"读取插件 {dir_name} 的 README.md 文件失败: {e!s}",
@@ -1218,7 +1211,7 @@ class PluginManager:
self._cleanup_plugin_state(dir_name)
plugin_path = os.path.join(self.plugin_store_path, dir_name)
if await asyncio.to_thread(os.path.exists, plugin_path):
if os.path.exists(plugin_path):
try:
remove_dir(plugin_path)
except Exception as e:
@@ -1505,14 +1498,13 @@ class PluginManager:
# Extract README.md content if exists
readme_content = None
readme_path = os.path.join(desti_dir, "README.md")
if not await asyncio.to_thread(os.path.exists, readme_path):
if not os.path.exists(readme_path):
readme_path = os.path.join(desti_dir, "readme.md")
if await asyncio.to_thread(os.path.exists, readme_path):
if os.path.exists(readme_path):
try:
readme_content = await asyncio.to_thread(
Path(readme_path).read_text, encoding="utf-8"
)
with open(readme_path, encoding="utf-8") as f:
readme_content = f.read()
except Exception as e:
logger.warning(f"读取插件 {dir_name} 的 README.md 文件失败: {e!s}")
+3 -3
View File
@@ -1,4 +1,4 @@
from datetime import UTC, datetime
from datetime import datetime, timezone
def normalize_datetime_utc(dt: datetime | None) -> datetime | None:
@@ -9,8 +9,8 @@ def normalize_datetime_utc(dt: datetime | None) -> datetime | None:
if dt is None:
return None
if dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None:
return dt.replace(tzinfo=UTC)
return dt.astimezone(UTC)
return dt.replace(tzinfo=timezone.utc)
return dt.astimezone(timezone.utc)
def to_utc_isoformat(dt: datetime | None) -> str | None:
+57 -96
View File
@@ -1,4 +1,3 @@
import asyncio
import base64
import logging
import os
@@ -9,7 +8,6 @@ import time
import uuid
import zipfile
from pathlib import Path
from typing import BinaryIO
import aiohttp
import certifi
@@ -19,8 +17,6 @@ from PIL import Image
from .astrbot_path import get_astrbot_data_path, get_astrbot_path, get_astrbot_temp_path
logger = logging.getLogger("astrbot")
_DOWNLOAD_READ_CHUNK_SIZE = 64 * 1024
_DOWNLOAD_FLUSH_THRESHOLD = 256 * 1024
def on_error(func, path, exc_info) -> None:
@@ -62,7 +58,8 @@ def save_temp_img(img: Image.Image | bytes) -> str:
if isinstance(img, Image.Image):
img.save(p)
else:
Path(p).write_bytes(img)
with open(p, "wb") as f:
f.write(img)
return p
@@ -86,13 +83,15 @@ async def download_image_by_url(
async with session.post(url, json=post_data) as resp:
if not path:
return save_temp_img(await resp.read())
await asyncio.to_thread(Path(path).write_bytes, await resp.read())
with open(path, "wb") as f:
f.write(await resp.read())
return path
else:
async with session.get(url) as resp:
if not path:
return save_temp_img(await resp.read())
await asyncio.to_thread(Path(path).write_bytes, await resp.read())
with open(path, "wb") as f:
f.write(await resp.read())
return path
except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
# 关闭SSL验证(仅在证书验证失败时作为fallback)
@@ -110,13 +109,15 @@ async def download_image_by_url(
async with session.post(url, json=post_data, ssl=ssl_context) as resp:
if not path:
return save_temp_img(await resp.read())
await asyncio.to_thread(Path(path).write_bytes, await resp.read())
with open(path, "wb") as f:
f.write(await resp.read())
return path
else:
async with session.get(url, ssl=ssl_context) as resp:
if not path:
return save_temp_img(await resp.read())
await asyncio.to_thread(Path(path).write_bytes, await resp.read())
with open(path, "wb") as f:
f.write(await resp.read())
return path
except Exception as e:
raise e
@@ -137,20 +138,28 @@ async def download_file(url: str, path: str, show_progress: bool = False) -> Non
if resp.status != 200:
raise Exception(f"下载文件失败: {resp.status}")
total_size = int(resp.headers.get("content-length", 0))
downloaded_size = 0
start_time = time.time()
if show_progress:
print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}")
file_obj = await asyncio.to_thread(Path(path).open, "wb")
try:
await _stream_to_file(
resp.content,
file_obj,
total_size=total_size,
start_time=start_time,
show_progress=show_progress,
)
finally:
await asyncio.to_thread(file_obj.close)
with open(path, "wb") as f:
while True:
chunk = await resp.content.read(8192)
if not chunk:
break
f.write(chunk)
downloaded_size += len(chunk)
if show_progress:
elapsed_time = (
time.time() - start_time
if time.time() - start_time > 0
else 1
)
speed = downloaded_size / 1024 / elapsed_time # KB/s
print(
f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s",
end="",
)
except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
# 关闭SSL验证(仅在证书验证失败时作为fallback)
logger.warning(
@@ -168,76 +177,32 @@ async def download_file(url: str, path: str, show_progress: bool = False) -> Non
async with aiohttp.ClientSession() as session:
async with session.get(url, ssl=ssl_context, timeout=120) as resp:
total_size = int(resp.headers.get("content-length", 0))
downloaded_size = 0
start_time = time.time()
if show_progress:
print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}")
file_obj = await asyncio.to_thread(Path(path).open, "wb")
try:
await _stream_to_file(
resp.content,
file_obj,
total_size=total_size,
start_time=start_time,
show_progress=show_progress,
)
finally:
await asyncio.to_thread(file_obj.close)
with open(path, "wb") as f:
while True:
chunk = await resp.content.read(8192)
if not chunk:
break
f.write(chunk)
downloaded_size += len(chunk)
if show_progress:
elapsed_time = time.time() - start_time
speed = downloaded_size / 1024 / elapsed_time # KB/s
print(
f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s",
end="",
)
if show_progress:
print()
async def _stream_to_file(
stream: aiohttp.StreamReader,
file_obj: BinaryIO,
*,
total_size: int,
start_time: float,
show_progress: bool,
) -> None:
"""Stream HTTP response into file with buffered thread-offloaded writes."""
downloaded_size = 0
known_total = total_size if total_size > 0 else None
buffered = bytearray()
try:
while True:
chunk = await stream.read(_DOWNLOAD_READ_CHUNK_SIZE)
if not chunk:
break
buffered.extend(chunk)
downloaded_size += len(chunk)
if len(buffered) >= _DOWNLOAD_FLUSH_THRESHOLD:
await asyncio.to_thread(file_obj.write, bytes(buffered))
buffered.clear()
if show_progress:
_print_download_progress(downloaded_size, known_total, start_time)
finally:
if buffered:
# Ensure buffered data is flushed even on cancellation.
await asyncio.shield(asyncio.to_thread(file_obj.write, bytes(buffered)))
def _print_download_progress(
downloaded_size: int, total_size: int | None, start_time: float
) -> None:
elapsed_time = max(time.time() - start_time, 1e-6)
speed = downloaded_size / 1024 / elapsed_time # KB/s
if total_size:
percent = downloaded_size / total_size
msg = f"\r下载进度: {percent:.2%} 速度: {speed:.2f} KB/s"
else:
msg = f"\r已下载: {downloaded_size} 字节 速度: {speed:.2f} KB/s"
print(msg, end="")
async def file_to_base64(file_path: str) -> str:
data_bytes = await asyncio.to_thread(Path(file_path).read_bytes)
base64_str = base64.b64encode(data_bytes).decode()
def file_to_base64(file_path: str) -> str:
with open(file_path, "rb") as f:
data_bytes = f.read()
base64_str = base64.b64encode(data_bytes).decode()
return "base64://" + base64_str
@@ -256,18 +221,17 @@ def get_local_ip_addresses():
async def get_dashboard_version():
# First check user data directory (manually updated / downloaded dashboard).
dist_dir = os.path.join(get_astrbot_data_path(), "dist")
if not await asyncio.to_thread(os.path.exists, dist_dir):
if not os.path.exists(dist_dir):
# Fall back to the dist bundled inside the installed wheel.
_bundled = Path(get_astrbot_path()) / "astrbot" / "dashboard" / "dist"
if await asyncio.to_thread(_bundled.exists):
if _bundled.exists():
dist_dir = str(_bundled)
if await asyncio.to_thread(os.path.exists, dist_dir):
if os.path.exists(dist_dir):
version_file = os.path.join(dist_dir, "assets", "version")
if await asyncio.to_thread(os.path.exists, version_file):
v = (
await asyncio.to_thread(Path(version_file).read_text, encoding="utf-8")
).strip()
return v
if os.path.exists(version_file):
with open(version_file, encoding="utf-8") as f:
v = f.read().strip()
return v
return None
@@ -280,12 +244,9 @@ async def download_dashboard(
) -> None:
"""下载管理面板文件"""
if path is None:
zip_path = (
await asyncio.to_thread(Path(get_astrbot_data_path()).absolute)
/ "dashboard.zip"
)
zip_path = Path(get_astrbot_data_path()).absolute() / "dashboard.zip"
else:
zip_path = await asyncio.to_thread(Path(path).absolute)
zip_path = Path(path).absolute()
if latest or len(str(version)) != 40:
ver_name = "latest" if latest else version
+6 -6
View File
@@ -108,7 +108,7 @@ async def convert_audio_to_opus(audio_path: str, output_path: str | None = None)
if process.returncode != 0:
# 清理可能已生成但无效的临时文件
if output_path and await asyncio.to_thread(os.path.exists, output_path):
if output_path and os.path.exists(output_path):
try:
os.remove(output_path)
logger.debug(
@@ -183,7 +183,7 @@ async def convert_video_format(
if process.returncode != 0:
# 清理可能已生成但无效的临时文件
if output_path and await asyncio.to_thread(os.path.exists, output_path):
if output_path and os.path.exists(output_path):
try:
os.remove(output_path)
logger.debug(
@@ -231,7 +231,7 @@ async def convert_audio_format(
if output_path is None:
temp_dir = Path(get_astrbot_temp_path())
await asyncio.to_thread(temp_dir.mkdir, parents=True, exist_ok=True)
temp_dir.mkdir(parents=True, exist_ok=True)
output_path = str(temp_dir / f"media_audio_{uuid.uuid4().hex}.{output_format}")
args = ["ffmpeg", "-y", "-i", audio_path]
@@ -249,7 +249,7 @@ async def convert_audio_format(
)
_, stderr = await process.communicate()
if process.returncode != 0:
if output_path and await asyncio.to_thread(os.path.exists, output_path):
if output_path and os.path.exists(output_path):
try:
os.remove(output_path)
except OSError as e:
@@ -287,7 +287,7 @@ async def extract_video_cover(
"""从视频中提取封面图(JPG)。"""
if output_path is None:
temp_dir = Path(get_astrbot_temp_path())
await asyncio.to_thread(temp_dir.mkdir, parents=True, exist_ok=True)
temp_dir.mkdir(parents=True, exist_ok=True)
output_path = str(temp_dir / f"media_cover_{uuid.uuid4().hex}.jpg")
try:
@@ -306,7 +306,7 @@ async def extract_video_cover(
)
_, stderr = await process.communicate()
if process.returncode != 0:
if output_path and await asyncio.to_thread(os.path.exists, output_path):
if output_path and os.path.exists(output_path):
try:
os.remove(output_path)
except OSError as e:
+5 -5
View File
@@ -71,11 +71,11 @@ class SessionController:
asyncio.create_task(self._holding(new_event, timeout)) # 开始新的 keep
async def _holding(self, event: asyncio.Event, timeout_seconds: float) -> None:
async def _holding(self, event: asyncio.Event, timeout: float) -> None:
"""等待事件结束或超时"""
try:
await asyncio.wait_for(event.wait(), timeout_seconds)
except TimeoutError:
await asyncio.wait_for(event.wait(), timeout)
except asyncio.TimeoutError:
if not self.future.done():
self.future.set_exception(TimeoutError("等待超时"))
except asyncio.CancelledError:
@@ -124,14 +124,14 @@ class SessionWaiter:
async def register_wait(
self,
handler: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]],
timeout_seconds: int = 30,
timeout: int = 30,
) -> Any:
"""等待外部输入并处理"""
self.handler = handler
USER_SESSIONS[self.session_id] = self
# 开始一个会话保持事件
self.session_controller.keep(timeout_seconds, reset_timeout=True)
self.session_controller.keep(timeout, reset_timeout=True)
try:
return await self.session_controller.future
+1 -1
View File
@@ -141,7 +141,7 @@ class TempDirCleaner:
self._stop_event.wait(),
timeout=self.CHECK_INTERVAL_SECONDS,
)
except TimeoutError:
except asyncio.TimeoutError:
continue
logger.info("TempDirCleaner stopped.")

Some files were not shown because too many files have changed in this diff Show More