@@ -46,8 +46,6 @@ jobs:
|
||||
include:
|
||||
- language: python
|
||||
build-mode: none
|
||||
- language: javascript-typescript
|
||||
build-mode: none
|
||||
# CodeQL supports the following values keywords for 'language': 'c-cpp', 'csharp', 'go', 'java-kotlin', 'javascript-typescript', 'python', 'ruby', 'swift'
|
||||
# Use `c-cpp` to analyze code written in C, C++ or both
|
||||
# Use 'java-kotlin' to analyze code written in Java, Kotlin or both
|
||||
|
||||
@@ -23,13 +23,12 @@ jobs:
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip uv
|
||||
uv sync --group dev
|
||||
python -m pip install --upgrade pip
|
||||
pip install pytest pytest-asyncio pytest-cov
|
||||
pip install --editable .
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
@@ -38,7 +37,7 @@ jobs:
|
||||
mkdir -p data/temp
|
||||
export TESTING=true
|
||||
export ZHIPU_API_KEY=${{ secrets.OPENAI_API_KEY }}
|
||||
uv run pytest --cov=astrbot -v -o log_cli=true -o log_level=DEBUG
|
||||
pytest --cov=astrbot -v -o log_cli=true -o log_level=DEBUG
|
||||
|
||||
- name: Upload results to Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
|
||||
@@ -13,23 +13,18 @@ jobs:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Setup pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
version: 10.28.2
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: '24.13.0'
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: dashboard/pnpm-lock.yaml
|
||||
|
||||
- name: Install and build
|
||||
- name: npm install, build
|
||||
run: |
|
||||
pnpm --dir dashboard install --frozen-lockfile
|
||||
pnpm --dir dashboard run typecheck
|
||||
pnpm --dir dashboard run build
|
||||
cd dashboard
|
||||
npm install pnpm -g
|
||||
pnpm install
|
||||
pnpm i --save-dev @types/markdown-it
|
||||
pnpm run build
|
||||
|
||||
- name: Inject Commit SHA
|
||||
id: get_sha
|
||||
|
||||
@@ -25,18 +25,6 @@ jobs:
|
||||
fetch-depth: 1
|
||||
fetch-tag: true
|
||||
|
||||
- name: Setup pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
version: 10.28.2
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: '24.13.0'
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: dashboard/pnpm-lock.yaml
|
||||
|
||||
- name: Check for new commits today
|
||||
if: github.event_name == 'schedule'
|
||||
id: check-commits
|
||||
@@ -58,10 +46,12 @@ jobs:
|
||||
|
||||
- name: Build Dashboard
|
||||
run: |
|
||||
pnpm --dir dashboard install --frozen-lockfile
|
||||
pnpm --dir dashboard run build
|
||||
mkdir -p dashboard/dist/assets
|
||||
echo $(git rev-parse HEAD) > dashboard/dist/assets/version
|
||||
cd dashboard
|
||||
npm install
|
||||
npm run build
|
||||
mkdir -p dist/assets
|
||||
echo $(git rev-parse HEAD) > dist/assets/version
|
||||
cd ..
|
||||
mkdir -p data
|
||||
cp -r dashboard/dist data/
|
||||
|
||||
@@ -133,18 +123,6 @@ jobs:
|
||||
fetch-depth: 1
|
||||
fetch-tag: true
|
||||
|
||||
- name: Setup pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
version: 10.28.2
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: '24.13.0'
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: dashboard/pnpm-lock.yaml
|
||||
|
||||
- name: Get latest tag (only on manual trigger)
|
||||
id: get-latest-tag
|
||||
if: github.event_name == 'workflow_dispatch'
|
||||
@@ -175,10 +153,12 @@ jobs:
|
||||
|
||||
- name: Build Dashboard
|
||||
run: |
|
||||
pnpm --dir dashboard install --frozen-lockfile
|
||||
pnpm --dir dashboard run build
|
||||
mkdir -p dashboard/dist/assets
|
||||
echo $(git rev-parse HEAD) > dashboard/dist/assets/version
|
||||
cd dashboard
|
||||
npm install
|
||||
npm run build
|
||||
mkdir -p dist/assets
|
||||
echo $(git rev-parse HEAD) > dist/assets/version
|
||||
cd ..
|
||||
mkdir -p data
|
||||
cp -r dashboard/dist data/
|
||||
|
||||
|
||||
@@ -18,29 +18,6 @@ permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
verify-core:
|
||||
name: Verify Core Quality Gate
|
||||
runs-on: ubuntu-24.04
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ inputs.ref || github.ref }}
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Install uv
|
||||
shell: bash
|
||||
run: python -m pip install uv
|
||||
|
||||
- name: Run local PR gate checks
|
||||
shell: bash
|
||||
run: make pr-test-neo
|
||||
|
||||
build-dashboard:
|
||||
name: Build Dashboard
|
||||
runs-on: ubuntu-24.04
|
||||
@@ -108,8 +85,7 @@ jobs:
|
||||
VERSION_TAG: ${{ steps.tag.outputs.tag }}
|
||||
shell: bash
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y rclone
|
||||
curl https://rclone.org/install.sh | sudo bash
|
||||
|
||||
mkdir -p ~/.config/rclone
|
||||
cat <<EOF > ~/.config/rclone/rclone.conf
|
||||
@@ -130,7 +106,6 @@ jobs:
|
||||
name: Publish GitHub Release
|
||||
runs-on: ubuntu-24.04
|
||||
needs:
|
||||
- verify-core
|
||||
- build-dashboard
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -251,7 +226,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.12"
|
||||
python-version: "3.10"
|
||||
|
||||
- name: Install uv
|
||||
shell: bash
|
||||
|
||||
@@ -8,7 +8,7 @@ ci:
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
# Ruff version.
|
||||
rev: v0.15.1
|
||||
rev: v0.14.1
|
||||
hooks:
|
||||
# Run the linter.
|
||||
- id: ruff-check
|
||||
@@ -22,4 +22,4 @@ repos:
|
||||
rev: v3.21.0
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
args: [--py312-plus]
|
||||
args: [--py310-plus]
|
||||
|
||||
+3
-2
@@ -13,9 +13,10 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
bash \
|
||||
ffmpeg \
|
||||
curl \
|
||||
gnupg \
|
||||
git \
|
||||
nodejs \
|
||||
npm \
|
||||
&& curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - \
|
||||
&& apt-get install -y --no-install-recommends nodejs \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
|
||||
<div>
|
||||
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
|
||||
<img src="https://img.shields.io/badge/python-3.12+-blue.svg" alt="python">
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
|
||||
<img src="https://deepwiki.com/badge.svg" href="https://deepwiki.com/AstrBotDevs/AstrBot">
|
||||
<a href="https://zread.ai/AstrBotDevs/AstrBot" target="_blank"><img src="https://img.shields.io/badge/Ask_Zread-_.svg?style=flat&color=00b0aa&labelColor=000000&logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iMTYiIGhlaWdodD0iMTYiIHZpZXdCb3g9IjAgMCAxNiAxNiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTQuOTYxNTYgMS42MDAxSDIuMjQxNTZDMS44ODgxIDEuNjAwMSAxLjYwMTU2IDEuODg2NjQgMS42MDE1NiAyLjI0MDFWNC45NjAxQzEuNjAxNTYgNS4zMTM1NiAxLjg4ODEgNS42MDAxIDIuMjQxNTYgNS42MDAxSDQuOTYxNTZDNS4zMTUwMiA1LjYwMDEgNS42MDE1NiA1LjMxMzU2IDUuNjAxNTYgNC45NjAxVjIuMjQwMUM1LjYwMTU2IDEuODg2NjQgNS4zMTUwMiAxLjYwMDEgNC45NjE1NiAxLjYwMDFaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00Ljk2MTU2IDEwLjM5OTlIMi4yNDE1NkMxLjg4ODEgMTAuMzk5OSAxLjYwMTU2IDEwLjY4NjQgMS42MDE1NiAxMS4wMzk5VjEzLjc1OTlDMS42MDE1NiAxNC4xMTM0IDEuODg4MSAxNC4zOTk5IDIuMjQxNTYgMTQuMzk5OUg0Ljk2MTU2QzUuMzE1MDIgMTQuMzk5OSA1LjYwMTU2IDE0LjExMzQgNS42MDE1NiAxMy43NTk5VjExLjAzOTlDNS42MDE1NiAxMC42ODY0IDUuMzE1MDIgMTAuMzk5OSA0Ljk2MTU2IDEwLjM5OTlaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik0xMy43NTg0IDEuNjAwMUgxMS4wMzg0QzEwLjY4NSAxLjYwMDEgMTAuMzk4NCAxLjg4NjY0IDEwLjM5ODQgMi4yNDAxVjQuOTYwMUMxMC4zOTg0IDUuMzEzNTYgMTAuNjg1IDUuNjAwMSAxMS4wMzg0IDUuNjAwMUgxMy43NTg0QzE0LjExMTkgNS42MDAxIDE0LjM5ODQgNS4zMTM1NiAxNC4zOTg0IDQuOTYwMVYyLjI0MDFDMTQuMzk4NCAxLjg4NjY0IDE0LjExMTkgMS42MDAxIDEzLjc1ODQgMS42MDAxWiIgZmlsbD0iI2ZmZiIvPgo8cGF0aCBkPSJNNCAxMkwxMiA0TDQgMTJaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDQiIHN0cm9rZT0iI2ZmZiIgc3Ryb2tlLXdpZHRoPSIxLjUiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIvPgo8L3N2Zz4K&logoColor=ffffff" alt="zread"/></a>
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?color=76bad9"/></a>
|
||||
|
||||
+1
-1
@@ -19,7 +19,7 @@
|
||||
|
||||
<div>
|
||||
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
|
||||
<img src="https://img.shields.io/badge/python-3.12+-blue.svg" alt="python">
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
|
||||
<img src="https://deepwiki.com/badge.svg" href="https://deepwiki.com/AstrBotDevs/AstrBot">
|
||||
<a href="https://zread.ai/AstrBotDevs/AstrBot" target="_blank"><img src="https://img.shields.io/badge/Ask_Zread-_.svg?style=flat&color=00b0aa&labelColor=000000&logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iMTYiIGhlaWdodD0iMTYiIHZpZXdCb3g9IjAgMCAxNiAxNiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTQuOTYxNTYgMS42MDAxSDIuMjQxNTZDMS44ODgxIDEuNjAwMSAxLjYwMTU2IDEuODg2NjQgMS42MDE1NiAyLjI0MDFWNC45NjAxQzEuNjAxNTYgNS4zMTM1NiAxLjg4ODEgNS42MDAxIDIuMjQxNTYgNS42MDAxSDQuOTYxNTZDNS4zMTUwMiA1LjYwMDEgNS42MDE1NiA1LjMxMzU2IDUuNjAxNTYgNC45NjAxVjIuMjQwMUM1LjYwMTU2IDEuODg2NjQgNS4zMTUwMiAxLjYwMDEgNC45NjE1NiAxLjYwMDFZIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00Ljk2MTU2IDEwLjM5OTlIMi4yNDE1NkMxLjg4ODEgMTAuMzk5OSAxLjYwMTU2IDEwLjY4NjQgMS42MDE1NiAxMS4wMzk5VjEzLjc1OTlDMS42MDE1NiAxNC4xMTM0IDEuODg4MSAxNC4zOTk5IDIuMjQxNTYgMTQuMzk5OUg0Ljk2MTU2QzUuMzE1MDIgMTQuMzk5OSA1LjYwMTU2IDE0LjExMzQgNS42MDE1NiAxMy43NTk5VjExLjAzOTlDNS42MDE1NiAxMC42ODY0IDUuMzE1MDIgMTAuMzk5OSA0Ljk2MTU2IDEwLjM5OTlaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik0xMy43NTg0IDEuNjAwMUgxMS4wMzg0QzEwLjY4NSAxLjYwMDEgMTAuMzk4NCAxLjg4NjY0IDEwLjM5ODQgMi4yNDAxVjQuOTYwMUMxMC4zOTg0IDUuMzEzNTYgMTAuNjg1IDUuNjAwMSAxMS4wMzg0IDUuNjAwMUgxMy43NTg0QzE0LjExMTkgNS42MDAxIDE0LjM5ODQgNS4zMTM1NiAxNC4zOTg0IDQuOTYwMVYyLjI0MDFDMTQuMzk4NCAxLjg4NjY0IDE0LjExMTkgMS42MDAxIDEzLjc1ODQgMS42MDAxWiIgZmlsbD0iI2ZmZiIvPgo8cGF0aCBkPSJNNCAxMkwxMiA0TDQgMTJaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDQiIHN0cm9rZT0iI2ZmZiIgc3Ryb2tlLXdpZHRoPSIxLjUiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIvPgo8L3N2Zz4K&logoColor=ffffff" alt="zread"/></a>
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?color=76bad9"/></a>
|
||||
|
||||
+1
-1
@@ -19,7 +19,7 @@
|
||||
|
||||
<div>
|
||||
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
|
||||
<img src="https://img.shields.io/badge/python-3.12+-blue.svg" alt="python">
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
|
||||
<img src="https://deepwiki.com/badge.svg" href="https://deepwiki.com/AstrBotDevs/AstrBot">
|
||||
<a href="https://zread.ai/AstrBotDevs/AstrBot" target="_blank"><img src="https://img.shields.io/badge/Ask_Zread-_.svg?style=flat&color=00b0aa&labelColor=000000&logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iMTYiIGhlaWdodD0iMTYiIHZpZXdCb3g9IjAgMCAxNiAxNiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTQuOTYxNTYgMS42MDAxSDIuMjQxNTZDMS44ODgxIDEuNjAwMSAxLjYwMTU2IDEuODg2NjQgMS42MDE1NiAyLjI0MDFWNC45NjAxQzEuNjAxNTYgNS4zMTM1NiAxLjg4ODEgNS42MDAxIDIuMjQxNTYgNS42MDAxSDQuOTYxNTZDNS4zMTUwMiA1LjYwMDEgNS42MDE1NiA1LjMxMzU2IDUuNjAxNTYgNC45NjAxVjIuMjQwMUM1LjYwMTU2IDEuODg2NjQgNS4zMTUwMiAxLjYwMDEgNC45NjE1NiAxLjYwMDFZIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00Ljk2MTU2IDEwLjM5OTlIMi4yNDE1NkMxLjg4ODEgMTAuMzk5OSAxLjYwMTU2IDEwLjY4NjQgMS42MDE1NiAxMS4wMzk5VjEzLjc1OTlDMS42MDE1NiAxNC4xMTM0IDEuODg4MSAxNC4zOTk5IDIuMjQxNTYgMTQuMzk5OUg0LjYxNTZDNS4zMTUwMiAxNC4zOTk5IDUuNjAxNTYgMTQuMTEzNCA1LjYwMTU2IDEzLjc1OTlWMTEuMDM5OUM1LjYwMTU2IDEwLjY4NjQgNS4zMTUwMiAxMC4zOTk5IDQuOTYxNTYgMTAuMzk5OVoiIGZpbGw9IiNmZmYiLz4KPHBhdGggZD0iTTEzLjc1ODQgMS42MDAxSDExLjAzODRDMTAuNjg1IDEuNjAwMSAxMC4zOTg0IDEuODg2NjQgMTAuMzk4NCAyLjI0MDFWNC45NjAxQzEwLjM5ODQgNS4zMTM1NiAxMC42ODUgNS42MDAxIDExLjAzODQgNS42MDAxSDEzLjc1ODRDMTQuMTExOSA1LjYwMDEgMTQuMzk4NCA1LjMxMzU2IDE0LjM5ODQgNC45NjAxVjIuMjQwMUMxNC4zOTg0IDEuODg2NjQgMTQuMTExOSAxLjYwMDEgMTMuNzU4NCAxLjYwMDFZIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDRMNCAxMlpFIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDQiIHN0cm9rZT0iI2ZmZiIgc3Ryb2tlLXdpZHRoPSIxLjUiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIvPgo8L3N2Zz4K&logoColor=ffffff" alt="zread"/></a>
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?color=76bad9"/></a>
|
||||
|
||||
+1
-1
@@ -19,7 +19,7 @@
|
||||
|
||||
<div>
|
||||
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
|
||||
<img src="https://img.shields.io/badge/python-3.12+-blue.svg" alt="python">
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
|
||||
<img src="https://deepwiki.com/badge.svg" href="https://deepwiki.com/AstrBotDevs/AstrBot">
|
||||
<a href="https://zread.ai/AstrBotDevs/AstrBot" target="_blank"><img src="https://img.shields.io/badge/Ask_Zread-_.svg?style=flat&color=00b0aa&labelColor=000000&logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iMTYiIGhlaWdodD0iMTYiIHZpZXdCb3g9IjAgMCAxNiAxNiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTQuOTYxNTYgMS42MDAxSDIuMjQxNTZDMS44ODgxIDEuNjAwMSAxLjYwMTU2IDEuODg2NjQgMS42MDE1NiAyLjI0MDFWNC45NjAxQzEuNjAxNTYgNS4zMTM1NiAxLjg4ODEgNS42MDAxIDIuMjQxNTYgNS42MDAxSDQuOTYxNTZDNS4zMTUwMiA1LjYwMDEgNS42MDE1NiA1LjMxMzU2IDUuNjAxNTYgNC45NjAxVjIuMjQwMUM1LjYwMTU2IDEuODg2NjQgNS4zMTUwMiAxLjYwMDEgNC45NjE1NiAxLjYwMDFZIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00Ljk2MTU2IDEwLjM5OTlIMi4yNDE1NkMxLjg4ODEgMTAuMzk5OSAxLjYwMTU2IDEwLjY4NjQgMS42MDE1NiAxMS4wMzk5VjEzLjc1OTlDMS42MDE1NiAxNC4xMTM0IDEuODg4MSAxNC4zOTk5IDIuMjQxNTYgMTQuMzk5OUg0Ljk2MTU2QzUuMzE1MDIgMTQuMzk5OSA1LjYwMTU2IDE0LjExMzQgNS42MDE1NiAxMy43NTk5VjExLjAzOTlDNS42MDE1NiAxMC42ODY0IDUuMzE1MDIgMTAuMzk5OSA0Ljk2MTU2IDEwLjM5OTlaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik0xMy43NTg0IDEuNjAwMUgxMS4wMzg0QzEwLjY4NSAxLjYwMDEgMTAuMzk4NCAxLjg4NjY0IDEwLjM5ODQgMi4yNDAxVjQuOTYwMUMxMC4zOTg0IDUuMzEzNTYgMTAuNjg1IDUuNjAwMSAxMS4wMzg0IDUuNjAwMUgxMy43NTg0QzE0LjExMTkgNS42MDAxIDE0LjM5ODQgNS4zMTM1NiAxNC4zOTg0IDQuOTYwMVYyLjI0MDFDMTQuMzk4NCAxLjg4NjY0IDE0LjExMTkgMS42MDAxIDEzLjczODQgMS42MDAxWiIgZmlsbD0iI2ZmZiIvPgo8cGF0aCBkPSJNNCAxMkwxMiA0TDQgMTJaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDQiIHN0cm9rZT0iI2ZmZiIgc3Ryb2tlLXdpZHRoPSIxLjUiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIvPgo8L3N2Zz4K&logoColor=ffffff" alt="zread"/></a>
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?color=76bad9"/></a>
|
||||
|
||||
+1
-1
@@ -19,7 +19,7 @@
|
||||
|
||||
<div>
|
||||
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
|
||||
<img src="https://img.shields.io/badge/python-3.12+-blue.svg" alt="python">
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
|
||||
<img src="https://deepwiki.com/badge.svg" href="https://deepwiki.com/AstrBotDevs/AstrBot">
|
||||
<a href="https://zread.ai/AstrBotDevs/AstrBot" target="_blank"><img src="https://img.shields.io/badge/Ask_Zread-_.svg?style=flat&color=00b0aa&labelColor=000000&logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iMTYiIGhlaWdodD0iMTYiIHZpZXdCb3g9IjAgMCAxNiAxNiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTQuOTYxNTYgMS42MDAxSDIuMjQxNTZDMS44ODgxIDEuNjAwMSAxLjYwMTU2IDEuODg2NjQgMS42MDE1NiAyLjI0MDFWNC45NjAxQzEuNjAxNTYgNS4zMTM1NiAxLjg4ODEgNS42MDAxIDIuMjQxNTYgNS42MDAxSDQuOTYxNTZDNS4zMTUwMiA1LjYwMDEgNS42MDE1NiA1LjMxMzU2IDUuNjAxNTYgNC45NjAxVjIuMjQwMUM1LjYwMTU2IDEuODg2NjQgNS4zMTUwMiAxLjYwMDEgNC45NjE1NiAxLjYwMDFaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00Ljk2MTU2IDEwLjM5OTlIMi4yNDE1NkMxLjg4ODEgMTAuMzk5OSAxLjYwMTU2IDEwLjY4NjQgMS42MDE1NiAxMS4wMzk5VjEzLjc1OTlDMS42MDE1NiAxNC4xMTM0IDEuODg4MSAxNC4zOTk5IDIuMjQxNTYgMTQuMzk5OUg0Ljk2MTU2QzUuMzE1MDIgMTQuMzk5OSA1LjYwMTU2IDE0LjExMzQgNS42MDE1NiAxMy43NTk5VjExLjAzOTlDNS42MDE1NiAxMC42ODY0IDUuMzE1MDIgMTAuMzk5OSA0Ljk2MTU2IDEwLjM5OTlaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik0xMy43NTg0IDEuNjAwMUgxMS4wMzg0QzEwLjY4NSAxLjYwMDEgMTAuMzk4NCAxLjg4NjY0IDEwLjM5ODQgMi4yNDAxVjQuOTYwMUMxMC4zOTg0IDUuMzEzNTYgMTAuNjg1IDUuNjAwMSAxMS4wMzg0IDUuNjAwMUgxMy43NTg0QzE0LjExMTkgNS42MDAxIDE0LjM5ODQgNS4zMTM1NiAxNC4zOTg0IDQuOTYwMVYyLjI0MDFDMTQuMzk4NCAxLjg4NjY0IDE0LjExMTkgMS42MDAxIDEzLjc1ODQgMS42MDAxWiIgZmlsbD0iI2ZmZiIvPgo8cGF0aCBkPSJNNCAxMkwxMiA0TDQgMTJaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDQiIHN0cm9rZT0iI2ZmZiIgc3Ryb2tlLXdpZHRoPSIxLjUiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIvPgo8L3N2Zz4K&logoColor=ffffff" alt="zread"/></a>
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?color=76bad9"/></a>
|
||||
|
||||
+1
-1
@@ -17,7 +17,7 @@
|
||||
|
||||
<div>
|
||||
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
|
||||
<img src="https://img.shields.io/badge/python-3.12+-blue.svg" alt="python">
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
|
||||
<img src="https://deepwiki.com/badge.svg" href="https://deepwiki.com/AstrBotDevs/AstrBot">
|
||||
<a href="https://zread.ai/AstrBotDevs/AstrBot" target="_blank"><img src="https://img.shields.io/badge/Ask_Zread-_.svg?style=flat&color=00b0aa&labelColor=000000&logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iMTYiIGhlaWdodD0iMTYiIHZpZXdCb3g9IjAgMCAxNiAxNiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTQuOTYxNTYgMS42MDAxSDIuMjQxNTZDMS44ODgxIDEuNjAwMSAxLjYwMTU2IDEuODg2NjQgMS42MDE1NiAyLjI0MDFWNC45NjAxQzEuNjAxNTYgNS4zMTM1NiAxLjg4ODEgNS42MDAxIDIuMjQxNTYgNS42MDAxSDQuOTYxNTZDNS4zMTUwMiA1LjYwMDEgNS42MDE1NiA1LjMxMzU2IDUuNjAxNTYgNC45NjAxVjIuMjQwMUM1LjYwMTU2IDEuODg2NjQgNS4zMTUwMiAxLjYwMDEgNC45NjE1NiAxLjYwMDFaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00Ljk2MTU2IDEwLjM5OTlIMi4yNDE1NkMxLjg4ODEgMTAuMzk5OSAxLjYwMTU2IDEwLjY4NjQgMS42MDE1NiAxMS4wMzk5VjEzLjc1OTlDMS42MDE1NiAxNC4xMTM0IDEuODg4MSAxNC4zOTk5IDIuMjQxNTYgMTQuMzk5OUg0Ljk2MTU2QzUuMzE1MDIgMTQuMzk5OSA1LjYwMTU2IDE0LjExMzQgNS42MDE1NiAxMy43NTk5VjExLjAzOTlDNS42MDE1NiAxMC42ODY0IDUuMzE1MDIgMTAuMzk5OSA0Ljk2MTU2IDEwLjM5OTlaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik0xMy43NTg0IDEuNjAwMUgxMS4wMzg0QzEwLjY4NSAxLjYwMDEgMTAuMzk4NCAxLjg4NjY0IDEwLjM5ODQgMi4yNDAxVjQuOTYwMUMxMC4zOTg0IDUuMzEzNTYgMTAuNjg1IDUuNjAwMSAxMS4wMzg0IDUuNjAwMUgxMy43NTg0QzE0LjExMTkgNS42MDAxIDE0LjM5ODQgNS4zMTM1NiAxNC4zOTg0IDQuOTYwMVYyLjI0MDFDMTQuMzk4NCAxLjg4NjY0IDE0LjExMTkgMS42MDAxIDEzLjc1ODQgMS42MDAxWiIgZmlsbD0iI2ZmZiIvPgo8cGF0aCBkPSJNNCAxMkwxMiA0TDQgMTJaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDQiIHN0cm9rZT0iI2ZmZiIgc3Ryb2tlLXdpZHRoPSIxLjUiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIvPgo8L3N2Zz4K&logoColor=ffffff" alt="zread"/></a>
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?color=76bad9"/></a>
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import shutil
|
||||
import tempfile
|
||||
from enum import StrEnum
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from zipfile import ZipFile
|
||||
@@ -12,7 +12,7 @@ import yaml
|
||||
from .version_comparator import VersionComparator
|
||||
|
||||
|
||||
class PluginStatus(StrEnum):
|
||||
class PluginStatus(str, Enum):
|
||||
INSTALLED = "installed"
|
||||
NEED_UPDATE = "needs-update"
|
||||
NOT_INSTALLED = "not-installed"
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from typing import Any, Generic
|
||||
|
||||
from .hooks import BaseAgentRunHooks
|
||||
from .run_context import TContext
|
||||
from .tool import FunctionTool
|
||||
|
||||
|
||||
@dataclass
|
||||
class Agent[TContext]:
|
||||
class Agent(Generic[TContext]):
|
||||
name: str
|
||||
instructions: str | None = None
|
||||
tools: list[str | FunctionTool] | None = None
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
from typing import Generic
|
||||
|
||||
from .agent import Agent
|
||||
from .run_context import TContext
|
||||
from .tool import FunctionTool
|
||||
|
||||
|
||||
class HandoffTool[TContext](FunctionTool):
|
||||
class HandoffTool(FunctionTool, Generic[TContext]):
|
||||
"""Handoff tool for delegating tasks to another agent."""
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
from typing import Generic
|
||||
|
||||
import mcp
|
||||
|
||||
from astrbot.core.agent.tool import FunctionTool
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
|
||||
from .run_context import ContextWrapper
|
||||
from .run_context import ContextWrapper, TContext
|
||||
|
||||
|
||||
class BaseAgentRunHooks[TContext]:
|
||||
class BaseAgentRunHooks(Generic[TContext]):
|
||||
async def on_agent_begin(self, run_context: ContextWrapper[TContext]) -> None: ...
|
||||
async def on_tool_start(
|
||||
self,
|
||||
|
||||
@@ -2,6 +2,7 @@ import asyncio
|
||||
import logging
|
||||
from contextlib import AsyncExitStack
|
||||
from datetime import timedelta
|
||||
from typing import Generic
|
||||
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
@@ -15,6 +16,7 @@ from astrbot import logger
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.utils.log_pipe import LogPipe
|
||||
|
||||
from .run_context import TContext
|
||||
from .tool import FunctionTool
|
||||
|
||||
try:
|
||||
@@ -99,7 +101,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
||||
return True, ""
|
||||
return False, f"HTTP {response.status}: {response.reason}"
|
||||
|
||||
except TimeoutError:
|
||||
except asyncio.TimeoutError:
|
||||
return False, f"Connection timeout: {timeout} seconds"
|
||||
except Exception as e:
|
||||
return False, f"{e!s}"
|
||||
@@ -358,7 +360,7 @@ class MCPClient:
|
||||
self.running_event.set()
|
||||
|
||||
|
||||
class MCPTool[TContext](FunctionTool):
|
||||
class MCPTool(FunctionTool, Generic[TContext]):
|
||||
"""A function tool that calls an MCP service."""
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -7,7 +7,7 @@ from astrbot.core.provider.entities import LLMResponse
|
||||
|
||||
from ..hooks import BaseAgentRunHooks
|
||||
from ..response import AgentResponse
|
||||
from ..run_context import ContextWrapper
|
||||
from ..run_context import ContextWrapper, TContext
|
||||
|
||||
|
||||
class AgentState(Enum):
|
||||
@@ -19,7 +19,7 @@ class AgentState(Enum):
|
||||
ERROR = auto() # Error state
|
||||
|
||||
|
||||
class BaseAgentRunner[TContext]:
|
||||
class BaseAgentRunner(T.Generic[TContext]):
|
||||
@abc.abstractmethod
|
||||
async def reset(
|
||||
self,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import base64
|
||||
import json
|
||||
import sys
|
||||
import typing as T
|
||||
from typing import override
|
||||
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot import logger
|
||||
@@ -18,6 +18,11 @@ from ...run_context import ContextWrapper, TContext
|
||||
from ..base import AgentResponse, AgentState, BaseAgentRunner
|
||||
from .coze_api_client import CozeAPIClient
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class CozeAgentRunner(BaseAgentRunner[TContext]):
|
||||
"""Coze Agent Runner"""
|
||||
@@ -246,7 +251,7 @@ class CozeAgentRunner(BaseAgentRunner[TContext]):
|
||||
conversation_id=conversation_id,
|
||||
auto_save_history=self.auto_save_history,
|
||||
stream=True,
|
||||
timeout_seconds=self.timeout,
|
||||
timeout=self.timeout,
|
||||
):
|
||||
event_type = chunk.get("event")
|
||||
data = chunk.get("data", {})
|
||||
|
||||
@@ -2,7 +2,6 @@ import asyncio
|
||||
import io
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
@@ -91,7 +90,7 @@ class CozeAPIClient:
|
||||
logger.debug(f"[Coze] 图片上传成功,file_id: {file_id}")
|
||||
return file_id
|
||||
|
||||
except TimeoutError:
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("文件上传超时")
|
||||
raise Exception("文件上传超时")
|
||||
except Exception as e:
|
||||
@@ -129,7 +128,7 @@ class CozeAPIClient:
|
||||
conversation_id: str | None = None,
|
||||
auto_save_history: bool = True,
|
||||
stream: bool = True,
|
||||
timeout_seconds: float = 120,
|
||||
timeout: float = 120,
|
||||
) -> AsyncGenerator[dict[str, Any], None]:
|
||||
"""发送聊天消息并返回流式响应
|
||||
|
||||
@@ -140,7 +139,7 @@ class CozeAPIClient:
|
||||
conversation_id: 会话ID
|
||||
auto_save_history: 是否自动保存历史
|
||||
stream: 是否流式响应
|
||||
timeout_seconds: 超时时间
|
||||
timeout: 超时时间
|
||||
|
||||
"""
|
||||
session = await self._ensure_session()
|
||||
@@ -167,7 +166,7 @@ class CozeAPIClient:
|
||||
url,
|
||||
json=payload,
|
||||
params=params,
|
||||
timeout=aiohttp.ClientTimeout(total=timeout_seconds),
|
||||
timeout=aiohttp.ClientTimeout(total=timeout),
|
||||
) as response:
|
||||
if response.status == 401:
|
||||
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
|
||||
@@ -204,8 +203,8 @@ class CozeAPIClient:
|
||||
except json.JSONDecodeError:
|
||||
event_data = {"content": data_str}
|
||||
|
||||
except TimeoutError:
|
||||
raise Exception(f"Coze API 流式请求超时 ({timeout_seconds}秒)")
|
||||
except asyncio.TimeoutError:
|
||||
raise Exception(f"Coze API 流式请求超时 ({timeout}秒)")
|
||||
except Exception as e:
|
||||
raise Exception(f"Coze API 流式请求失败: {e!s}")
|
||||
|
||||
@@ -237,7 +236,7 @@ class CozeAPIClient:
|
||||
except json.JSONDecodeError:
|
||||
raise Exception("Coze API 返回非JSON格式")
|
||||
|
||||
except TimeoutError:
|
||||
except asyncio.TimeoutError:
|
||||
raise Exception("Coze API 请求超时")
|
||||
except aiohttp.ClientError as e:
|
||||
raise Exception(f"Coze API 请求失败: {e!s}")
|
||||
@@ -295,7 +294,8 @@ if __name__ == "__main__":
|
||||
client = CozeAPIClient(api_key=api_key)
|
||||
|
||||
try:
|
||||
file_data = await asyncio.to_thread(Path("README.md").read_bytes)
|
||||
with open("README.md", "rb") as f:
|
||||
file_data = f.read()
|
||||
file_id = await client.upload_file(file_data)
|
||||
print(f"Uploaded file_id: {file_id}")
|
||||
async for event in client.chat_messages(
|
||||
|
||||
@@ -2,9 +2,9 @@ import asyncio
|
||||
import functools
|
||||
import queue
|
||||
import re
|
||||
import sys
|
||||
import threading
|
||||
import typing as T
|
||||
from typing import override
|
||||
|
||||
from dashscope import Application
|
||||
from dashscope.app.application_response import ApplicationResponse
|
||||
@@ -22,6 +22,11 @@ from ...response import AgentResponseData
|
||||
from ...run_context import ContextWrapper, TContext
|
||||
from ..base import AgentResponse, AgentState, BaseAgentRunner
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class DashscopeAgentRunner(BaseAgentRunner[TContext]):
|
||||
"""Dashscope Agent Runner"""
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import sys
|
||||
import typing as T
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import override
|
||||
from uuid import uuid4
|
||||
|
||||
import astrbot.core.message.components as Comp
|
||||
@@ -40,6 +40,11 @@ from .deerflow_stream_utils import (
|
||||
get_message_id,
|
||||
)
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class DeerFlowAgentRunner(BaseAgentRunner[TContext]):
|
||||
"""DeerFlow Agent Runner via LangGraph HTTP API."""
|
||||
@@ -373,9 +378,7 @@ class DeerFlowAgentRunner(BaseAgentRunner[TContext]):
|
||||
if thread_id:
|
||||
return thread_id
|
||||
|
||||
thread = await self.api_client.create_thread(
|
||||
timeout_seconds=min(30, self.timeout)
|
||||
)
|
||||
thread = await self.api_client.create_thread(timeout=min(30, self.timeout))
|
||||
thread_id = thread.get("thread_id", "")
|
||||
if not thread_id:
|
||||
raise Exception(
|
||||
@@ -636,7 +639,7 @@ class DeerFlowAgentRunner(BaseAgentRunner[TContext]):
|
||||
async for event in self.api_client.stream_run(
|
||||
thread_id=thread_id,
|
||||
payload=payload,
|
||||
timeout_seconds=self.timeout,
|
||||
timeout=self.timeout,
|
||||
):
|
||||
event_type = event.get("event")
|
||||
data = event.get("data")
|
||||
@@ -663,7 +666,7 @@ class DeerFlowAgentRunner(BaseAgentRunner[TContext]):
|
||||
|
||||
if event_type == "end":
|
||||
break
|
||||
except TimeoutError:
|
||||
except (asyncio.TimeoutError, TimeoutError):
|
||||
logger.warning(
|
||||
"DeerFlow stream timed out after %ss for thread_id=%s; returning partial result.",
|
||||
self.timeout,
|
||||
|
||||
@@ -139,7 +139,7 @@ class DeerFlowAPIClient:
|
||||
) -> None:
|
||||
await self.close()
|
||||
|
||||
async def create_thread(self, timeout_seconds: float = 20) -> dict[str, Any]:
|
||||
async def create_thread(self, timeout: float = 20) -> dict[str, Any]:
|
||||
session = self._get_session()
|
||||
url = f"{self.api_base}/api/langgraph/threads"
|
||||
payload = {"metadata": {}}
|
||||
@@ -147,7 +147,7 @@ class DeerFlowAPIClient:
|
||||
url,
|
||||
json=payload,
|
||||
headers=self.headers,
|
||||
timeout=timeout_seconds,
|
||||
timeout=timeout,
|
||||
proxy=self.proxy,
|
||||
) as resp:
|
||||
if resp.status not in (200, 201):
|
||||
@@ -161,7 +161,7 @@ class DeerFlowAPIClient:
|
||||
self,
|
||||
thread_id: str,
|
||||
payload: dict[str, Any],
|
||||
timeout_seconds: float = 120,
|
||||
timeout: float = 120,
|
||||
) -> AsyncGenerator[dict[str, Any], None]:
|
||||
session = self._get_session()
|
||||
url = f"{self.api_base}/api/langgraph/threads/{thread_id}/runs/stream"
|
||||
@@ -183,9 +183,9 @@ class DeerFlowAPIClient:
|
||||
# Use socket read timeout so active heartbeats/chunks can keep the stream alive.
|
||||
stream_timeout = ClientTimeout(
|
||||
total=None,
|
||||
connect=min(timeout_seconds, 30),
|
||||
sock_connect=min(timeout_seconds, 30),
|
||||
sock_read=timeout_seconds,
|
||||
connect=min(timeout, 30),
|
||||
sock_connect=min(timeout, 30),
|
||||
sock_read=timeout,
|
||||
)
|
||||
async with session.post(
|
||||
url,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import base64
|
||||
import os
|
||||
import sys
|
||||
import typing as T
|
||||
from typing import override
|
||||
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot.core import logger, sp
|
||||
@@ -19,6 +19,11 @@ from ...run_context import ContextWrapper, TContext
|
||||
from ..base import AgentResponse, AgentState, BaseAgentRunner
|
||||
from .dify_api_client import DifyAPIClient
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class DifyAgentRunner(BaseAgentRunner[TContext]):
|
||||
"""Dify Agent Runner"""
|
||||
@@ -171,7 +176,7 @@ class DifyAgentRunner(BaseAgentRunner[TContext]):
|
||||
user=session_id,
|
||||
conversation_id=conversation_id,
|
||||
files=files_payload,
|
||||
timeout_seconds=self.timeout,
|
||||
timeout=self.timeout,
|
||||
):
|
||||
logger.debug(f"dify resp chunk: {chunk}")
|
||||
if chunk["event"] == "message" or chunk["event"] == "agent_message":
|
||||
@@ -211,7 +216,7 @@ class DifyAgentRunner(BaseAgentRunner[TContext]):
|
||||
},
|
||||
user=session_id,
|
||||
files=files_payload,
|
||||
timeout_seconds=self.timeout,
|
||||
timeout=self.timeout,
|
||||
):
|
||||
logger.debug(f"dify workflow resp chunk: {chunk}")
|
||||
match chunk["event"]:
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import asyncio
|
||||
import codecs
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from aiohttp import ClientResponse, ClientSession, FormData
|
||||
@@ -49,20 +47,20 @@ class DifyAPIClient:
|
||||
response_mode: str = "streaming",
|
||||
conversation_id: str = "",
|
||||
files: list[dict[str, Any]] | None = None,
|
||||
timeout_seconds: float = 60,
|
||||
timeout: float = 60,
|
||||
) -> AsyncGenerator[dict[str, Any], None]:
|
||||
if files is None:
|
||||
files = []
|
||||
url = f"{self.api_base}/chat-messages"
|
||||
payload = locals()
|
||||
payload.pop("self")
|
||||
payload.pop("timeout_seconds")
|
||||
payload.pop("timeout")
|
||||
logger.info(f"chat_messages payload: {payload}")
|
||||
async with self.session.post(
|
||||
url,
|
||||
json=payload,
|
||||
headers=self.headers,
|
||||
timeout=timeout_seconds,
|
||||
timeout=timeout,
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
@@ -78,20 +76,20 @@ class DifyAPIClient:
|
||||
user: str,
|
||||
response_mode: str = "streaming",
|
||||
files: list[dict[str, Any]] | None = None,
|
||||
timeout_seconds: float = 60,
|
||||
timeout: float = 60,
|
||||
):
|
||||
if files is None:
|
||||
files = []
|
||||
url = f"{self.api_base}/workflows/run"
|
||||
payload = locals()
|
||||
payload.pop("self")
|
||||
payload.pop("timeout_seconds")
|
||||
payload.pop("timeout")
|
||||
logger.info(f"workflow_run payload: {payload}")
|
||||
async with self.session.post(
|
||||
url,
|
||||
json=payload,
|
||||
headers=self.headers,
|
||||
timeout=timeout_seconds,
|
||||
timeout=timeout,
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
@@ -136,13 +134,14 @@ class DifyAPIClient:
|
||||
# 使用文件路径
|
||||
import os
|
||||
|
||||
file_content = await asyncio.to_thread(Path(file_path).read_bytes)
|
||||
form.add_field(
|
||||
"file",
|
||||
file_content,
|
||||
filename=os.path.basename(file_path),
|
||||
content_type=mime_type or "application/octet-stream",
|
||||
)
|
||||
with open(file_path, "rb") as f:
|
||||
file_content = f.read()
|
||||
form.add_field(
|
||||
"file",
|
||||
file_content,
|
||||
filename=os.path.basename(file_path),
|
||||
content_type=mime_type or "application/octet-stream",
|
||||
)
|
||||
else:
|
||||
raise ValueError("file_path 和 file_data 不能同时为 None")
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import typing as T
|
||||
from dataclasses import dataclass, field
|
||||
from typing import override
|
||||
|
||||
from mcp.types import (
|
||||
BlobResourceContents,
|
||||
@@ -44,6 +44,11 @@ from ..run_context import ContextWrapper, TContext
|
||||
from ..tool_executor import BaseFunctionToolExecutor
|
||||
from .base import AgentResponse, AgentState, BaseAgentRunner
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _HandleFunctionToolsResult:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import copy
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from typing import Any
|
||||
from typing import Any, Generic
|
||||
|
||||
import jsonschema
|
||||
import mcp
|
||||
@@ -10,7 +10,7 @@ from pydantic.dataclasses import dataclass
|
||||
|
||||
from astrbot.core.message.message_event_result import MessageEventResult
|
||||
|
||||
from .run_context import ContextWrapper
|
||||
from .run_context import ContextWrapper, TContext
|
||||
|
||||
ParametersType = dict[str, Any]
|
||||
ToolExecResult = str | mcp.types.CallToolResult
|
||||
@@ -38,7 +38,7 @@ class ToolSchema:
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionTool[TContext](ToolSchema):
|
||||
class FunctionTool(ToolSchema, Generic[TContext]):
|
||||
"""A callable tool, for function calling."""
|
||||
|
||||
handler: (
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
from typing import Any, Generic
|
||||
|
||||
import mcp
|
||||
|
||||
from .run_context import ContextWrapper
|
||||
from .run_context import ContextWrapper, TContext
|
||||
from .tool import FunctionTool
|
||||
|
||||
|
||||
class BaseFunctionToolExecutor[TContext]:
|
||||
class BaseFunctionToolExecutor(Generic[TContext]):
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
|
||||
@@ -3,7 +3,6 @@ import re
|
||||
import time
|
||||
import traceback
|
||||
from collections.abc import AsyncGenerator
|
||||
from pathlib import Path
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.agent.message import Message
|
||||
@@ -510,7 +509,8 @@ async def _simulated_stream_tts(
|
||||
audio_path = await tts_provider.get_audio(text)
|
||||
|
||||
if audio_path:
|
||||
audio_data = await asyncio.to_thread(Path(audio_path).read_bytes)
|
||||
with open(audio_path, "rb") as f:
|
||||
audio_data = f.read()
|
||||
await audio_queue.put((text, audio_data))
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
|
||||
@@ -625,7 +625,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
exc_info=True,
|
||||
)
|
||||
yield None
|
||||
except TimeoutError:
|
||||
except asyncio.TimeoutError:
|
||||
raise Exception(
|
||||
f"tool {tool.name} execution timeout after {tool_call_timeout or run_context.tool_call_timeout} seconds.",
|
||||
)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
@@ -242,7 +241,7 @@ class SendMessageToUserTool(FunctionTool[AstrAgentContext]):
|
||||
|
||||
bool: indicates whether the file was downloaded from sandbox.
|
||||
"""
|
||||
if await asyncio.to_thread(os.path.exists, path):
|
||||
if os.path.exists(path):
|
||||
return path, False
|
||||
|
||||
# Try to check if the file exists in the sandbox
|
||||
|
||||
@@ -4,12 +4,11 @@
|
||||
导出格式为 JSON,这是数据库无关的方案,支持未来向 MySQL/PostgreSQL 迁移。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import zipfile
|
||||
from datetime import UTC, datetime
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
@@ -84,7 +83,7 @@ class AstrBotExporter:
|
||||
output_dir = get_astrbot_backups_path()
|
||||
|
||||
# 确保输出目录存在
|
||||
await asyncio.to_thread(Path(output_dir).mkdir, parents=True, exist_ok=True)
|
||||
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
zip_filename = f"astrbot_backup_{timestamp}.zip"
|
||||
@@ -161,10 +160,9 @@ class AstrBotExporter:
|
||||
# 3. 导出配置文件
|
||||
if progress_callback:
|
||||
await progress_callback("config", 0, 100, "正在导出配置文件...")
|
||||
config_content = await asyncio.to_thread(
|
||||
self._read_text_if_exists, self.config_path
|
||||
)
|
||||
if config_content is not None:
|
||||
if os.path.exists(self.config_path):
|
||||
with open(self.config_path, encoding="utf-8") as f:
|
||||
config_content = f.read()
|
||||
zf.writestr("config/cmd_config.json", config_content)
|
||||
self._add_checksum("config/cmd_config.json", config_content)
|
||||
if progress_callback:
|
||||
@@ -201,7 +199,7 @@ class AstrBotExporter:
|
||||
except Exception as e:
|
||||
logger.error(f"备份导出失败: {e}")
|
||||
# 清理失败的文件
|
||||
if await asyncio.to_thread(os.path.exists, zip_path):
|
||||
if os.path.exists(zip_path):
|
||||
os.remove(zip_path)
|
||||
raise
|
||||
|
||||
@@ -319,7 +317,7 @@ class AstrBotExporter:
|
||||
|
||||
for dir_name, dir_path in backup_directories.items():
|
||||
full_path = Path(dir_path)
|
||||
if not await asyncio.to_thread(full_path.exists):
|
||||
if not full_path.exists():
|
||||
logger.debug(f"目录不存在,跳过: {full_path}")
|
||||
continue
|
||||
|
||||
@@ -361,44 +359,17 @@ class AstrBotExporter:
|
||||
self, zf: zipfile.ZipFile, attachments: list[dict]
|
||||
) -> None:
|
||||
"""导出附件文件"""
|
||||
await asyncio.to_thread(self._export_attachments_sync, zf, attachments)
|
||||
|
||||
def _export_attachments_sync(
|
||||
self, zf: zipfile.ZipFile, attachments: list[dict]
|
||||
) -> None:
|
||||
"""在单个线程中批量导出附件,减少高频线程切换。"""
|
||||
for attachment in attachments:
|
||||
file_path = attachment.get("path", "")
|
||||
attachment_id = attachment.get("attachment_id")
|
||||
try:
|
||||
if not file_path:
|
||||
continue
|
||||
if not attachment_id:
|
||||
logger.warning(
|
||||
f"跳过附件导出:attachment_id 为空 (path={file_path})"
|
||||
)
|
||||
continue
|
||||
# 使用 attachment_id 作为文件名
|
||||
ext = os.path.splitext(file_path)[1]
|
||||
archive_path = f"files/attachments/{attachment_id}{ext}"
|
||||
zf.write(file_path, archive_path)
|
||||
except FileNotFoundError:
|
||||
# 和旧逻辑保持一致:缺失文件直接跳过。
|
||||
continue
|
||||
except OSError as e:
|
||||
logger.warning(
|
||||
f"导出附件失败 (path={file_path}, attachment_id={attachment_id or 'unknown'}): {e}"
|
||||
)
|
||||
file_path = attachment.get("path", "")
|
||||
if file_path and os.path.exists(file_path):
|
||||
# 使用 attachment_id 作为文件名
|
||||
attachment_id = attachment.get("attachment_id", "")
|
||||
ext = os.path.splitext(file_path)[1]
|
||||
archive_path = f"files/attachments/{attachment_id}{ext}"
|
||||
zf.write(file_path, archive_path)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"导出附件时发生非预期错误,已跳过 (path={file_path}, attachment_id={attachment_id or 'unknown'}): {e}"
|
||||
)
|
||||
|
||||
def _read_text_if_exists(self, file_path: str) -> str | None:
|
||||
"""Read text file when it exists in a single synchronous call."""
|
||||
if not os.path.exists(file_path):
|
||||
return None
|
||||
return Path(file_path).read_text(encoding="utf-8")
|
||||
logger.warning(f"导出附件失败: {e}")
|
||||
|
||||
def _model_to_dict(self, record: Any) -> dict:
|
||||
"""将 SQLModel 实例转换为字典
|
||||
@@ -475,7 +446,7 @@ class AstrBotExporter:
|
||||
manifest = {
|
||||
"version": BACKUP_MANIFEST_VERSION,
|
||||
"astrbot_version": VERSION,
|
||||
"exported_at": datetime.now(UTC).isoformat(),
|
||||
"exported_at": datetime.now(timezone.utc).isoformat(),
|
||||
"origin": "exported", # 标记备份来源:exported=本实例导出, uploaded=用户上传
|
||||
"schema_version": {
|
||||
"main_db": "v4",
|
||||
|
||||
@@ -7,13 +7,12 @@
|
||||
- 版本匹配时也需要用户确认
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import zipfile
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
@@ -365,7 +364,7 @@ class AstrBotImporter:
|
||||
"""
|
||||
result = ImportResult()
|
||||
|
||||
if not await asyncio.to_thread(os.path.exists, zip_path):
|
||||
if not os.path.exists(zip_path):
|
||||
result.add_error(f"备份文件不存在: {zip_path}")
|
||||
return result
|
||||
|
||||
@@ -447,13 +446,12 @@ class AstrBotImporter:
|
||||
try:
|
||||
config_content = zf.read("config/cmd_config.json")
|
||||
# 备份现有配置
|
||||
if await asyncio.to_thread(os.path.exists, self.config_path):
|
||||
if os.path.exists(self.config_path):
|
||||
backup_path = f"{self.config_path}.bak"
|
||||
shutil.copy2(self.config_path, backup_path)
|
||||
|
||||
await asyncio.to_thread(
|
||||
Path(self.config_path).write_bytes, config_content
|
||||
)
|
||||
with open(self.config_path, "wb") as f:
|
||||
f.write(config_content)
|
||||
result.imported_files["config"] = 1
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入配置文件失败: {e}")
|
||||
@@ -677,9 +675,9 @@ class AstrBotImporter:
|
||||
if isinstance(value, datetime):
|
||||
dt = value
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.replace(tzinfo=UTC)
|
||||
dt = dt.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
dt = dt.astimezone(UTC)
|
||||
dt = dt.astimezone(timezone.utc)
|
||||
return dt.isoformat()
|
||||
if isinstance(value, str):
|
||||
timestamp = value.strip()
|
||||
@@ -690,9 +688,9 @@ class AstrBotImporter:
|
||||
try:
|
||||
dt = datetime.fromisoformat(timestamp)
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.replace(tzinfo=UTC)
|
||||
dt = dt.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
dt = dt.astimezone(UTC)
|
||||
dt = dt.astimezone(timezone.utc)
|
||||
return dt.isoformat()
|
||||
except ValueError:
|
||||
return None
|
||||
@@ -755,8 +753,8 @@ class AstrBotImporter:
|
||||
if faiss_path in zf.namelist():
|
||||
try:
|
||||
target_path = kb_dir / "index.faiss"
|
||||
with zf.open(faiss_path) as src:
|
||||
await asyncio.to_thread(target_path.write_bytes, src.read())
|
||||
with zf.open(faiss_path) as src, open(target_path, "wb") as dst:
|
||||
dst.write(src.read())
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入知识库 {kb_id} 的 FAISS 索引失败: {e}")
|
||||
|
||||
@@ -768,8 +766,8 @@ class AstrBotImporter:
|
||||
rel_path = name[len(media_prefix) :]
|
||||
target_path = kb_dir / rel_path
|
||||
target_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with zf.open(name) as src:
|
||||
await asyncio.to_thread(target_path.write_bytes, src.read())
|
||||
with zf.open(name) as src, open(target_path, "wb") as dst:
|
||||
dst.write(src.read())
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入媒体文件 {name} 失败: {e}")
|
||||
|
||||
@@ -830,8 +828,8 @@ class AstrBotImporter:
|
||||
target_path = attachments_dir / os.path.basename(name)
|
||||
|
||||
target_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with zf.open(name) as src:
|
||||
await asyncio.to_thread(target_path.write_bytes, src.read())
|
||||
with zf.open(name) as src, open(target_path, "wb") as dst:
|
||||
dst.write(src.read())
|
||||
count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"导入附件 {name} 失败: {e}")
|
||||
@@ -887,15 +885,15 @@ class AstrBotImporter:
|
||||
continue
|
||||
|
||||
# 备份现有目录(如果存在)
|
||||
if await asyncio.to_thread(target_dir.exists):
|
||||
if target_dir.exists():
|
||||
backup_path = Path(f"{target_dir}.bak")
|
||||
if await asyncio.to_thread(backup_path.exists):
|
||||
if backup_path.exists():
|
||||
shutil.rmtree(backup_path)
|
||||
shutil.move(str(target_dir), str(backup_path))
|
||||
logger.debug(f"已备份现有目录 {target_dir} 到 {backup_path}")
|
||||
|
||||
# 创建目标目录
|
||||
await asyncio.to_thread(target_dir.mkdir, parents=True, exist_ok=True)
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 解压文件
|
||||
for name in dir_files:
|
||||
@@ -908,8 +906,8 @@ class AstrBotImporter:
|
||||
target_path = target_dir / rel_path
|
||||
target_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with zf.open(name) as src:
|
||||
await asyncio.to_thread(target_path.write_bytes, src.read())
|
||||
with zf.open(name) as src, open(target_path, "wb") as dst:
|
||||
dst.write(src.read())
|
||||
file_count += 1
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入文件 {name} 失败: {e}")
|
||||
|
||||
@@ -118,10 +118,10 @@ class BayContainerManager:
|
||||
|
||||
return f"http://127.0.0.1:{self._host_port}"
|
||||
|
||||
async def wait_healthy(self, timeout_seconds: int = HEALTH_TIMEOUT_S) -> None:
|
||||
async def wait_healthy(self, timeout: int = HEALTH_TIMEOUT_S) -> None:
|
||||
"""Block until Bay's ``/health`` endpoint returns 200."""
|
||||
url = f"http://127.0.0.1:{self._host_port}/health"
|
||||
deadline = asyncio.get_event_loop().time() + timeout_seconds
|
||||
deadline = asyncio.get_event_loop().time() + timeout
|
||||
last_error: str = ""
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
@@ -140,7 +140,7 @@ class BayContainerManager:
|
||||
await asyncio.sleep(HEALTH_POLL_INTERVAL_S)
|
||||
|
||||
raise TimeoutError(
|
||||
f"Bay did not become healthy within {timeout_seconds}s (last error: {last_error})"
|
||||
f"Bay did not become healthy within {timeout}s (last error: {last_error})"
|
||||
)
|
||||
|
||||
async def read_credentials(self) -> str:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
@@ -47,7 +46,8 @@ class MockShipyardSandboxClient:
|
||||
|
||||
try:
|
||||
# Read file content
|
||||
file_content = await asyncio.to_thread(Path(path).read_bytes)
|
||||
with open(path, "rb") as f:
|
||||
file_content = f.read()
|
||||
|
||||
# Create multipart form data
|
||||
data = aiohttp.FormData()
|
||||
@@ -88,7 +88,7 @@ class MockShipyardSandboxClient:
|
||||
"error": f"Connection error: {str(e)}",
|
||||
"message": "File upload failed",
|
||||
}
|
||||
except TimeoutError:
|
||||
except asyncio.TimeoutError:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "File upload timeout",
|
||||
|
||||
@@ -59,7 +59,7 @@ class LocalShellComponent(ShellComponent):
|
||||
command: str,
|
||||
cwd: str | None = None,
|
||||
env: dict[str, str] | None = None,
|
||||
timeout_seconds: int | None = 30,
|
||||
timeout: int | None = 30,
|
||||
shell: bool = True,
|
||||
background: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
@@ -87,7 +87,7 @@ class LocalShellComponent(ShellComponent):
|
||||
shell=shell,
|
||||
cwd=working_dir,
|
||||
env=run_env,
|
||||
timeout=timeout_seconds,
|
||||
timeout=timeout,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
@@ -106,14 +106,14 @@ class LocalPythonComponent(PythonComponent):
|
||||
self,
|
||||
code: str,
|
||||
kernel_id: str | None = None,
|
||||
timeout_seconds: int = 30,
|
||||
timeout: int = 30,
|
||||
silent: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
def _run() -> dict[str, Any]:
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[os.environ.get("PYTHON", sys.executable), "-c", code],
|
||||
timeout=timeout_seconds,
|
||||
timeout=timeout,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import shlex
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
from astrbot.api import logger
|
||||
@@ -35,11 +33,11 @@ class NeoPythonComponent(PythonComponent):
|
||||
self,
|
||||
code: str,
|
||||
kernel_id: str | None = None,
|
||||
timeout_seconds: int = 30,
|
||||
timeout: int = 30,
|
||||
silent: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
_ = kernel_id # Bay runtime does not expose kernel_id in current SDK.
|
||||
result = await self._sandbox.python.exec(code, timeout=timeout_seconds)
|
||||
result = await self._sandbox.python.exec(code, timeout=timeout)
|
||||
payload = _maybe_model_dump(result)
|
||||
|
||||
output_text = payload.get("output", "") or ""
|
||||
@@ -77,7 +75,7 @@ class NeoShellComponent(ShellComponent):
|
||||
command: str,
|
||||
cwd: str | None = None,
|
||||
env: dict[str, str] | None = None,
|
||||
timeout_seconds: int | None = 30,
|
||||
timeout: int | None = 30,
|
||||
shell: bool = True,
|
||||
background: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
@@ -101,7 +99,7 @@ class NeoShellComponent(ShellComponent):
|
||||
|
||||
result = await self._sandbox.shell.exec(
|
||||
run_command,
|
||||
timeout=timeout_seconds or 30,
|
||||
timeout=timeout or 30,
|
||||
cwd=cwd,
|
||||
)
|
||||
payload = _maybe_model_dump(result)
|
||||
@@ -194,7 +192,7 @@ class NeoBrowserComponent(BrowserComponent):
|
||||
async def exec(
|
||||
self,
|
||||
cmd: str,
|
||||
timeout_seconds: int = 30,
|
||||
timeout: int = 30,
|
||||
description: str | None = None,
|
||||
tags: str | None = None,
|
||||
learn: bool = False,
|
||||
@@ -202,7 +200,7 @@ class NeoBrowserComponent(BrowserComponent):
|
||||
) -> dict[str, Any]:
|
||||
result = await self._sandbox.browser.exec(
|
||||
cmd,
|
||||
timeout=timeout_seconds,
|
||||
timeout=timeout,
|
||||
description=description,
|
||||
tags=tags,
|
||||
learn=learn,
|
||||
@@ -213,7 +211,7 @@ class NeoBrowserComponent(BrowserComponent):
|
||||
async def exec_batch(
|
||||
self,
|
||||
commands: list[str],
|
||||
timeout_seconds: int = 60,
|
||||
timeout: int = 60,
|
||||
stop_on_error: bool = True,
|
||||
description: str | None = None,
|
||||
tags: str | None = None,
|
||||
@@ -222,7 +220,7 @@ class NeoBrowserComponent(BrowserComponent):
|
||||
) -> dict[str, Any]:
|
||||
result = await self._sandbox.browser.exec_batch(
|
||||
commands,
|
||||
timeout=timeout_seconds,
|
||||
timeout=timeout,
|
||||
stop_on_error=stop_on_error,
|
||||
description=description,
|
||||
tags=tags,
|
||||
@@ -234,7 +232,7 @@ class NeoBrowserComponent(BrowserComponent):
|
||||
async def run_skill(
|
||||
self,
|
||||
skill_key: str,
|
||||
timeout_seconds: int = 60,
|
||||
timeout: int = 60,
|
||||
stop_on_error: bool = True,
|
||||
include_trace: bool = False,
|
||||
description: str | None = None,
|
||||
@@ -242,7 +240,7 @@ class NeoBrowserComponent(BrowserComponent):
|
||||
) -> dict[str, Any]:
|
||||
result = await self._sandbox.browser.run_skill(
|
||||
skill_key=skill_key,
|
||||
timeout=timeout_seconds,
|
||||
timeout=timeout,
|
||||
stop_on_error=stop_on_error,
|
||||
include_trace=include_trace,
|
||||
description=description,
|
||||
@@ -470,7 +468,8 @@ class ShipyardNeoBooter(ComputerBooter):
|
||||
async def upload_file(self, path: str, file_name: str) -> dict:
|
||||
if self._sandbox is None:
|
||||
raise RuntimeError("ShipyardNeoBooter is not initialized.")
|
||||
content = await asyncio.to_thread(Path(path).read_bytes)
|
||||
with open(path, "rb") as f:
|
||||
content = f.read()
|
||||
remote_path = file_name.lstrip("/")
|
||||
await self._sandbox.filesystem.upload(remote_path, content)
|
||||
logger.info("[Computer] File uploaded to Neo sandbox: %s", remote_path)
|
||||
@@ -487,7 +486,8 @@ class ShipyardNeoBooter(ComputerBooter):
|
||||
local_dir = os.path.dirname(local_path)
|
||||
if local_dir:
|
||||
os.makedirs(local_dir, exist_ok=True)
|
||||
await asyncio.to_thread(Path(local_path).write_bytes, cast(bytes, content))
|
||||
with open(local_path, "wb") as f:
|
||||
f.write(cast(bytes, content))
|
||||
logger.info(
|
||||
"[Computer] File downloaded from Neo sandbox: %s -> %s",
|
||||
remote_path,
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
@@ -373,12 +372,12 @@ async def _sync_skills_to_sandbox(booter: ComputerBooter) -> None:
|
||||
splitting into `apply` and `scan` phases.
|
||||
"""
|
||||
skills_root = Path(get_astrbot_skills_path())
|
||||
if not await asyncio.to_thread(skills_root.is_dir):
|
||||
if not skills_root.is_dir():
|
||||
return
|
||||
local_skill_dirs = _list_local_skill_dirs(skills_root)
|
||||
|
||||
temp_dir = Path(get_astrbot_temp_path())
|
||||
await asyncio.to_thread(temp_dir.mkdir, parents=True, exist_ok=True)
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
zip_base = temp_dir / "skills_bundle"
|
||||
zip_path = zip_base.with_suffix(".zip")
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ class BrowserComponent(Protocol):
|
||||
async def exec(
|
||||
self,
|
||||
cmd: str,
|
||||
timeout_seconds: int = 30,
|
||||
timeout: int = 30,
|
||||
description: str | None = None,
|
||||
tags: str | None = None,
|
||||
learn: bool = False,
|
||||
@@ -23,7 +23,7 @@ class BrowserComponent(Protocol):
|
||||
async def exec_batch(
|
||||
self,
|
||||
commands: list[str],
|
||||
timeout_seconds: int = 60,
|
||||
timeout: int = 60,
|
||||
stop_on_error: bool = True,
|
||||
description: str | None = None,
|
||||
tags: str | None = None,
|
||||
@@ -36,7 +36,7 @@ class BrowserComponent(Protocol):
|
||||
async def run_skill(
|
||||
self,
|
||||
skill_key: str,
|
||||
timeout_seconds: int = 60,
|
||||
timeout: int = 60,
|
||||
stop_on_error: bool = True,
|
||||
include_trace: bool = False,
|
||||
description: str | None = None,
|
||||
|
||||
@@ -12,7 +12,7 @@ class PythonComponent(Protocol):
|
||||
self,
|
||||
code: str,
|
||||
kernel_id: str | None = None,
|
||||
timeout_seconds: int = 30,
|
||||
timeout: int = 30,
|
||||
silent: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Execute Python code"""
|
||||
|
||||
@@ -13,7 +13,7 @@ class ShellComponent(Protocol):
|
||||
command: str,
|
||||
cwd: str | None = None,
|
||||
env: dict[str, str] | None = None,
|
||||
timeout_seconds: int | None = 30,
|
||||
timeout: int | None = 30,
|
||||
shell: bool = True,
|
||||
background: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
|
||||
@@ -71,23 +71,19 @@ class BrowserExecTool(FunctionTool):
|
||||
self,
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
cmd: str,
|
||||
timeout_seconds: int = 30,
|
||||
timeout: int = 30,
|
||||
description: str | None = None,
|
||||
tags: str | None = None,
|
||||
learn: bool = False,
|
||||
include_trace: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> ToolExecResult:
|
||||
legacy_timeout = kwargs.pop("timeout", None)
|
||||
if legacy_timeout is not None:
|
||||
timeout_seconds = int(legacy_timeout)
|
||||
if err := _ensure_admin(context):
|
||||
return err
|
||||
try:
|
||||
browser = await _get_browser_component(context)
|
||||
result = await browser.exec(
|
||||
cmd=cmd,
|
||||
timeout_seconds=timeout_seconds,
|
||||
timeout=timeout,
|
||||
description=description,
|
||||
tags=tags,
|
||||
learn=learn,
|
||||
@@ -137,24 +133,20 @@ class BrowserBatchExecTool(FunctionTool):
|
||||
self,
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
commands: list[str],
|
||||
timeout_seconds: int = 60,
|
||||
timeout: int = 60,
|
||||
stop_on_error: bool = True,
|
||||
description: str | None = None,
|
||||
tags: str | None = None,
|
||||
learn: bool = False,
|
||||
include_trace: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> ToolExecResult:
|
||||
legacy_timeout = kwargs.pop("timeout", None)
|
||||
if legacy_timeout is not None:
|
||||
timeout_seconds = int(legacy_timeout)
|
||||
if err := _ensure_admin(context):
|
||||
return err
|
||||
try:
|
||||
browser = await _get_browser_component(context)
|
||||
result = await browser.exec_batch(
|
||||
commands=commands,
|
||||
timeout_seconds=timeout_seconds,
|
||||
timeout=timeout,
|
||||
stop_on_error=stop_on_error,
|
||||
description=description,
|
||||
tags=tags,
|
||||
@@ -189,23 +181,19 @@ class RunBrowserSkillTool(FunctionTool):
|
||||
self,
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
skill_key: str,
|
||||
timeout_seconds: int = 60,
|
||||
timeout: int = 60,
|
||||
stop_on_error: bool = True,
|
||||
include_trace: bool = False,
|
||||
description: str | None = None,
|
||||
tags: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ToolExecResult:
|
||||
legacy_timeout = kwargs.pop("timeout", None)
|
||||
if legacy_timeout is not None:
|
||||
timeout_seconds = int(legacy_timeout)
|
||||
if err := _ensure_admin(context):
|
||||
return err
|
||||
try:
|
||||
browser = await _get_browser_component(context)
|
||||
result = await browser.run_skill(
|
||||
skill_key=skill_key,
|
||||
timeout_seconds=timeout_seconds,
|
||||
timeout=timeout,
|
||||
stop_on_error=stop_on_error,
|
||||
include_trace=include_trace,
|
||||
description=description,
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
@@ -112,10 +111,10 @@ class FileUploadTool(FunctionTool):
|
||||
)
|
||||
try:
|
||||
# Check if file exists
|
||||
if not await asyncio.to_thread(os.path.exists, local_path):
|
||||
if not os.path.exists(local_path):
|
||||
return f"Error: File does not exist: {local_path}"
|
||||
|
||||
if not await asyncio.to_thread(os.path.isfile, local_path):
|
||||
if not os.path.isfile(local_path):
|
||||
return f"Error: Path is not a file: {local_path}"
|
||||
|
||||
# Use basename if sandbox_filename is not provided
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import UTC, datetime
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
@@ -192,7 +192,7 @@ class CronJobManager:
|
||||
job = await self.db.get_cron_job(job_id)
|
||||
if not job or not job.enabled:
|
||||
return
|
||||
start_time = datetime.now(UTC)
|
||||
start_time = datetime.now(timezone.utc)
|
||||
await self.db.update_cron_job(
|
||||
job_id, status="running", last_run_at=start_time, last_error=None
|
||||
)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from astrbot.api import logger, sp
|
||||
@@ -23,7 +22,7 @@ async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool:
|
||||
data_dir = get_astrbot_data_path()
|
||||
data_v3_db = os.path.join(data_dir, "data_v3.db")
|
||||
|
||||
if not await asyncio.to_thread(os.path.exists, data_v3_db):
|
||||
if not os.path.exists(data_v3_db):
|
||||
return False
|
||||
migration_done = await db_helper.get_preference(
|
||||
"global",
|
||||
|
||||
@@ -106,8 +106,8 @@ async def migration_platform_table(
|
||||
db_path=DB_PATH.replace("data_v4.db", "data_v3.db"),
|
||||
)
|
||||
secs_from_2023_4_10_to_now = (
|
||||
datetime.datetime.now(datetime.UTC)
|
||||
- datetime.datetime(2023, 4, 10, tzinfo=datetime.UTC)
|
||||
datetime.datetime.now(datetime.timezone.utc)
|
||||
- datetime.datetime(2023, 4, 10, tzinfo=datetime.timezone.utc)
|
||||
).total_seconds()
|
||||
offset_sec = int(secs_from_2023_4_10_to_now)
|
||||
logger.info(f"迁移旧平台数据,offset_sec: {offset_sec} 秒。")
|
||||
@@ -162,7 +162,7 @@ async def migration_platform_table(
|
||||
{
|
||||
"timestamp": datetime.datetime.fromtimestamp(
|
||||
bucket_end,
|
||||
tz=datetime.UTC,
|
||||
tz=datetime.timezone.utc,
|
||||
),
|
||||
"platform_id": platform_id,
|
||||
"platform_type": platform_type,
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from datetime import datetime, timezone
|
||||
from typing import TypedDict
|
||||
|
||||
from sqlmodel import JSON, Field, SQLModel, Text, UniqueConstraint
|
||||
|
||||
|
||||
class TimestampMixin(SQLModel):
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(UTC),
|
||||
sa_column_kwargs={"onupdate": lambda: datetime.now(UTC)},
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": lambda: datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
import threading
|
||||
import typing as T
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy import CursorResult, Row
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@@ -633,7 +633,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
"""Get an active API key by hash (not revoked, not expired)."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
now = datetime.now(UTC)
|
||||
now = datetime.now(timezone.utc)
|
||||
query = select(ApiKey).where(
|
||||
ApiKey.key_hash == key_hash,
|
||||
col(ApiKey.revoked_at).is_(None),
|
||||
@@ -650,7 +650,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
await session.execute(
|
||||
update(ApiKey)
|
||||
.where(col(ApiKey.key_id) == key_id)
|
||||
.values(last_used_at=datetime.now(UTC)),
|
||||
.values(last_used_at=datetime.now(timezone.utc)),
|
||||
)
|
||||
|
||||
async def revoke_api_key(self, key_id: str) -> bool:
|
||||
@@ -661,7 +661,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
query = (
|
||||
update(ApiKey)
|
||||
.where(col(ApiKey.key_id) == key_id)
|
||||
.values(revoked_at=datetime.now(UTC))
|
||||
.values(revoked_at=datetime.now(timezone.utc))
|
||||
)
|
||||
result = T.cast(CursorResult, await session.execute(query))
|
||||
return result.rowcount > 0
|
||||
@@ -1534,7 +1534,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
values: dict[str, T.Any] = {"updated_at": datetime.now(UTC)}
|
||||
values: dict[str, T.Any] = {"updated_at": datetime.now(timezone.utc)}
|
||||
if display_name is not None:
|
||||
values["display_name"] = display_name
|
||||
|
||||
@@ -1622,7 +1622,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
values: dict[str, T.Any] = {"updated_at": datetime.now(UTC)}
|
||||
values: dict[str, T.Any] = {"updated_at": datetime.now(timezone.utc)}
|
||||
if title is not None:
|
||||
values["title"] = title
|
||||
if emoji is not None:
|
||||
|
||||
@@ -28,17 +28,12 @@ class FileTokenService:
|
||||
await self._cleanup_expired_tokens()
|
||||
return file_token not in self.staged_files
|
||||
|
||||
async def register_file(
|
||||
self,
|
||||
file_path: str,
|
||||
timeout_seconds: float | None = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
async def register_file(self, file_path: str, timeout: float | None = None) -> str:
|
||||
"""向令牌服务注册一个文件。
|
||||
|
||||
Args:
|
||||
file_path(str): 文件路径
|
||||
timeout_seconds(float): 超时时间,单位秒(可选)
|
||||
timeout(float): 超时时间,单位秒(可选)
|
||||
|
||||
Returns:
|
||||
str: 一个单次令牌
|
||||
@@ -63,18 +58,15 @@ class FileTokenService:
|
||||
|
||||
async with self.lock:
|
||||
await self._cleanup_expired_tokens()
|
||||
legacy_timeout = kwargs.pop("timeout", None)
|
||||
if legacy_timeout is not None:
|
||||
timeout_seconds = float(legacy_timeout)
|
||||
|
||||
if not await asyncio.to_thread(os.path.exists, local_path):
|
||||
if not os.path.exists(local_path):
|
||||
raise FileNotFoundError(
|
||||
f"文件不存在: {local_path} (原始输入: {file_path})",
|
||||
)
|
||||
|
||||
file_token = str(uuid.uuid4())
|
||||
expire_time = time.time() + (
|
||||
timeout_seconds if timeout_seconds is not None else self.default_timeout
|
||||
timeout if timeout is not None else self.default_timeout
|
||||
)
|
||||
# 存储转换后的真实路径
|
||||
self.staged_files[file_token] = (local_path, expire_time)
|
||||
@@ -101,6 +93,6 @@ class FileTokenService:
|
||||
raise KeyError(f"无效或过期的文件 token: {file_token}")
|
||||
|
||||
file_path, _ = self.staged_files.pop(file_token)
|
||||
if not await asyncio.to_thread(os.path.exists, file_path):
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||
return file_path
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlmodel import Field, MetaData, SQLModel, Text, UniqueConstraint
|
||||
|
||||
@@ -40,10 +40,10 @@ class KnowledgeBase(BaseKBModel, table=True):
|
||||
top_k_dense: int | None = Field(default=50, nullable=True)
|
||||
top_k_sparse: int | None = Field(default=50, nullable=True)
|
||||
top_m_final: int | None = Field(default=5, nullable=True)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(UTC),
|
||||
sa_column_kwargs={"onupdate": datetime.now(UTC)},
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
doc_count: int = Field(default=0, nullable=False)
|
||||
chunk_count: int = Field(default=0, nullable=False)
|
||||
@@ -83,10 +83,10 @@ class KBDocument(BaseKBModel, table=True):
|
||||
file_path: str = Field(max_length=512, nullable=False)
|
||||
chunk_count: int = Field(default=0, nullable=False)
|
||||
media_count: int = Field(default=0, nullable=False)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(UTC),
|
||||
sa_column_kwargs={"onupdate": datetime.now(UTC)},
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
|
||||
@@ -117,4 +117,4 @@ class KBMedia(BaseKBModel, table=True):
|
||||
file_path: str = Field(max_length=512, nullable=False)
|
||||
file_size: int = Field(nullable=False)
|
||||
mime_type: str = Field(max_length=100, nullable=False)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
@@ -27,8 +27,7 @@ import json
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from enum import Enum
|
||||
|
||||
if sys.version_info >= (3, 14):
|
||||
from pydantic import BaseModel
|
||||
@@ -40,17 +39,7 @@ from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
|
||||
from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64
|
||||
|
||||
|
||||
def _absolute_path(path: str) -> str:
|
||||
return os.path.abspath(path)
|
||||
|
||||
|
||||
def _absolute_path_if_exists(path: str | None) -> str | None:
|
||||
if not path or not os.path.exists(path):
|
||||
return None
|
||||
return os.path.abspath(path)
|
||||
|
||||
|
||||
class ComponentType(StrEnum):
|
||||
class ComponentType(str, Enum):
|
||||
# Basic Segment Types
|
||||
Plain = "Plain" # plain text message
|
||||
Image = "Image" # image
|
||||
@@ -169,18 +158,18 @@ class Record(BaseMessageComponent):
|
||||
return self.file[8:]
|
||||
if self.file.startswith("http"):
|
||||
file_path = await download_image_by_url(self.file)
|
||||
return await asyncio.to_thread(_absolute_path, file_path)
|
||||
return os.path.abspath(file_path)
|
||||
if self.file.startswith("base64://"):
|
||||
bs64_data = self.file.removeprefix("base64://")
|
||||
image_bytes = base64.b64decode(bs64_data)
|
||||
file_path = os.path.join(
|
||||
get_astrbot_temp_path(), f"recordseg_{uuid.uuid4()}.jpg"
|
||||
)
|
||||
await asyncio.to_thread(Path(file_path).write_bytes, image_bytes)
|
||||
return await asyncio.to_thread(_absolute_path, file_path)
|
||||
local_path = await asyncio.to_thread(_absolute_path_if_exists, self.file)
|
||||
if local_path:
|
||||
return local_path
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
return os.path.abspath(file_path)
|
||||
if os.path.exists(self.file):
|
||||
return os.path.abspath(self.file)
|
||||
raise Exception(f"not a valid file: {self.file}")
|
||||
|
||||
async def convert_to_base64(self) -> str:
|
||||
@@ -194,17 +183,16 @@ class Record(BaseMessageComponent):
|
||||
if not self.file:
|
||||
raise Exception(f"not a valid file: {self.file}")
|
||||
if self.file.startswith("file:///"):
|
||||
bs64_data = await file_to_base64(self.file[8:])
|
||||
bs64_data = file_to_base64(self.file[8:])
|
||||
elif self.file.startswith("http"):
|
||||
file_path = await download_image_by_url(self.file)
|
||||
bs64_data = await file_to_base64(file_path)
|
||||
bs64_data = file_to_base64(file_path)
|
||||
elif self.file.startswith("base64://"):
|
||||
bs64_data = self.file
|
||||
elif os.path.exists(self.file):
|
||||
bs64_data = file_to_base64(self.file)
|
||||
else:
|
||||
try:
|
||||
bs64_data = await file_to_base64(self.file)
|
||||
except OSError as exc:
|
||||
raise Exception(f"not a valid file: {self.file}") from exc
|
||||
raise Exception(f"not a valid file: {self.file}")
|
||||
bs64_data = bs64_data.removeprefix("base64://")
|
||||
return bs64_data
|
||||
|
||||
@@ -268,15 +256,11 @@ class Video(BaseMessageComponent):
|
||||
get_astrbot_temp_path(), f"videoseg_{uuid.uuid4().hex}"
|
||||
)
|
||||
await download_file(url, video_file_path)
|
||||
local_path = await asyncio.to_thread(
|
||||
_absolute_path_if_exists, video_file_path
|
||||
)
|
||||
if local_path:
|
||||
return local_path
|
||||
if os.path.exists(video_file_path):
|
||||
return os.path.abspath(video_file_path)
|
||||
raise Exception(f"download failed: {url}")
|
||||
local_path = await asyncio.to_thread(_absolute_path_if_exists, url)
|
||||
if local_path:
|
||||
return local_path
|
||||
if os.path.exists(url):
|
||||
return os.path.abspath(url)
|
||||
raise Exception(f"not a valid file: {url}")
|
||||
|
||||
async def register_to_file_service(self) -> str:
|
||||
@@ -465,18 +449,18 @@ class Image(BaseMessageComponent):
|
||||
return url[8:]
|
||||
if url.startswith("http"):
|
||||
image_file_path = await download_image_by_url(url)
|
||||
return await asyncio.to_thread(_absolute_path, image_file_path)
|
||||
return os.path.abspath(image_file_path)
|
||||
if url.startswith("base64://"):
|
||||
bs64_data = url.removeprefix("base64://")
|
||||
image_bytes = base64.b64decode(bs64_data)
|
||||
image_file_path = os.path.join(
|
||||
get_astrbot_temp_path(), f"imgseg_{uuid.uuid4()}.jpg"
|
||||
)
|
||||
await asyncio.to_thread(Path(image_file_path).write_bytes, image_bytes)
|
||||
return await asyncio.to_thread(_absolute_path, image_file_path)
|
||||
local_path = await asyncio.to_thread(_absolute_path_if_exists, url)
|
||||
if local_path:
|
||||
return local_path
|
||||
with open(image_file_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
return os.path.abspath(image_file_path)
|
||||
if os.path.exists(url):
|
||||
return os.path.abspath(url)
|
||||
raise Exception(f"not a valid file: {url}")
|
||||
|
||||
async def convert_to_base64(self) -> str:
|
||||
@@ -491,17 +475,16 @@ class Image(BaseMessageComponent):
|
||||
if not url:
|
||||
raise ValueError("No valid file or URL provided")
|
||||
if url.startswith("file:///"):
|
||||
bs64_data = await file_to_base64(url[8:])
|
||||
bs64_data = file_to_base64(url[8:])
|
||||
elif url.startswith("http"):
|
||||
image_file_path = await download_image_by_url(url)
|
||||
bs64_data = await file_to_base64(image_file_path)
|
||||
bs64_data = file_to_base64(image_file_path)
|
||||
elif url.startswith("base64://"):
|
||||
bs64_data = url
|
||||
elif os.path.exists(url):
|
||||
bs64_data = file_to_base64(url)
|
||||
else:
|
||||
try:
|
||||
bs64_data = await file_to_base64(url)
|
||||
except OSError as exc:
|
||||
raise Exception(f"not a valid file: {url}") from exc
|
||||
raise Exception(f"not a valid file: {url}")
|
||||
bs64_data = bs64_data.removeprefix("base64://")
|
||||
return bs64_data
|
||||
|
||||
@@ -752,9 +735,8 @@ class File(BaseMessageComponent):
|
||||
):
|
||||
path = path[1:]
|
||||
|
||||
local_path = await asyncio.to_thread(_absolute_path_if_exists, path)
|
||||
if local_path:
|
||||
return local_path
|
||||
if os.path.exists(path):
|
||||
return os.path.abspath(path)
|
||||
|
||||
if self.url:
|
||||
await self._download_file()
|
||||
@@ -769,7 +751,7 @@ class File(BaseMessageComponent):
|
||||
and path[2] == ":"
|
||||
):
|
||||
path = path[1:]
|
||||
return await asyncio.to_thread(_absolute_path, path)
|
||||
return os.path.abspath(path)
|
||||
|
||||
return ""
|
||||
|
||||
@@ -785,7 +767,7 @@ class File(BaseMessageComponent):
|
||||
filename = f"fileseg_{uuid.uuid4().hex}"
|
||||
file_path = os.path.join(download_dir, filename)
|
||||
await download_file(self.url, file_path)
|
||||
self.file_ = await asyncio.to_thread(_absolute_path, file_path)
|
||||
self.file_ = os.path.abspath(file_path)
|
||||
|
||||
async def register_to_file_service(self) -> str:
|
||||
"""将文件注册到文件服务。
|
||||
|
||||
@@ -254,7 +254,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
"robotCode": robot_code,
|
||||
}
|
||||
temp_dir = Path(get_astrbot_temp_path())
|
||||
await asyncio.to_thread(temp_dir.mkdir, parents=True, exist_ok=True)
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
f_path = temp_dir / f"dingtalk_{uuid.uuid4()}.{ext}"
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
@@ -412,7 +412,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
form = aiohttp.FormData()
|
||||
form.add_field(
|
||||
"media",
|
||||
await asyncio.to_thread(media_file_path.read_bytes),
|
||||
media_file_path.read_bytes(),
|
||||
filename=media_file_path.name,
|
||||
content_type="application/octet-stream",
|
||||
)
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
import sys
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import override
|
||||
|
||||
import discord
|
||||
|
||||
from astrbot import logger
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
# Discord Bot客户端
|
||||
class DiscordBotClient(discord.Bot):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import re
|
||||
from typing import Any, cast, override
|
||||
import sys
|
||||
from typing import Any, cast
|
||||
|
||||
import discord
|
||||
from discord.abc import GuildChannel, Messageable, PrivateChannel
|
||||
@@ -26,6 +27,11 @@ from astrbot.core.star.star_handler import StarHandlerMetadata, star_handlers_re
|
||||
from .client import DiscordBotClient
|
||||
from .discord_platform_event import DiscordPlatformEvent
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
# 注册平台适配器
|
||||
@register_platform_adapter(
|
||||
|
||||
@@ -130,7 +130,7 @@ class KookPlatformAdapter(Platform):
|
||||
await asyncio.wait_for(
|
||||
self.client.wait_until_closed(), timeout=1.0
|
||||
)
|
||||
except TimeoutError:
|
||||
except asyncio.TimeoutError:
|
||||
# 正常超时,继续下一轮 while 检查
|
||||
continue
|
||||
|
||||
|
||||
@@ -171,7 +171,7 @@ class KookClient:
|
||||
# 处理不同类型的信令
|
||||
await self._handle_signal(data)
|
||||
|
||||
except TimeoutError:
|
||||
except asyncio.TimeoutError:
|
||||
# 超时检查,继续循环
|
||||
continue
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
@@ -362,14 +362,12 @@ class KookClient:
|
||||
b64_str = file_url.removeprefix("base64://")
|
||||
bytes_data = base64.b64decode(b64_str)
|
||||
|
||||
elif file_url.startswith("file://") or await asyncio.to_thread(
|
||||
os.path.exists, file_url
|
||||
):
|
||||
elif file_url.startswith("file://") or os.path.exists(file_url):
|
||||
file_url = file_url.removeprefix("file:///")
|
||||
file_url = file_url.removeprefix("file://")
|
||||
|
||||
try:
|
||||
target_path = await asyncio.to_thread(Path(file_url).resolve)
|
||||
target_path = Path(file_url).resolve()
|
||||
except Exception as exp:
|
||||
logger.error(f'[KOOK] 获取文件 "{file_url}" 绝对路径失败: "{exp}"')
|
||||
raise FileNotFoundError(
|
||||
|
||||
@@ -429,7 +429,7 @@ class LarkPlatformAdapter(Platform):
|
||||
|
||||
suffix = Path(file_name).suffix if file_name else default_suffix
|
||||
temp_dir = Path(get_astrbot_temp_path())
|
||||
await asyncio.to_thread(temp_dir.mkdir, parents=True, exist_ok=True)
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
temp_path = (
|
||||
temp_dir / f"lark_{message_type}_{file_name}_{uuid4().hex[:4]}{suffix}"
|
||||
)
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import lark_oapi as lark
|
||||
from lark_oapi.api.im.v1 import (
|
||||
@@ -138,7 +136,7 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
Returns:
|
||||
成功返回file_key,失败返回None
|
||||
"""
|
||||
if not path or not await asyncio.to_thread(os.path.exists, path):
|
||||
if not path or not os.path.exists(path):
|
||||
logger.error(f"[Lark] 文件不存在: {path}")
|
||||
return None
|
||||
|
||||
@@ -147,32 +145,36 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
return None
|
||||
|
||||
try:
|
||||
file_obj = BytesIO(await asyncio.to_thread(Path(path).read_bytes))
|
||||
body_builder = (
|
||||
CreateFileRequestBody.builder()
|
||||
.file_type(file_type)
|
||||
.file_name(os.path.basename(path))
|
||||
.file(file_obj)
|
||||
)
|
||||
if duration is not None:
|
||||
body_builder.duration(duration)
|
||||
with open(path, "rb") as file_obj:
|
||||
body_builder = (
|
||||
CreateFileRequestBody.builder()
|
||||
.file_type(file_type)
|
||||
.file_name(os.path.basename(path))
|
||||
.file(file_obj)
|
||||
)
|
||||
if duration is not None:
|
||||
body_builder.duration(duration)
|
||||
|
||||
request = (
|
||||
CreateFileRequest.builder().request_body(body_builder.build()).build()
|
||||
)
|
||||
response = await lark_client.im.v1.file.acreate(request)
|
||||
request = (
|
||||
CreateFileRequest.builder()
|
||||
.request_body(body_builder.build())
|
||||
.build()
|
||||
)
|
||||
response = await lark_client.im.v1.file.acreate(request)
|
||||
|
||||
if not response.success():
|
||||
logger.error(f"[Lark] 无法上传文件({response.code}): {response.msg}")
|
||||
return None
|
||||
if not response.success():
|
||||
logger.error(
|
||||
f"[Lark] 无法上传文件({response.code}): {response.msg}"
|
||||
)
|
||||
return None
|
||||
|
||||
if response.data is None:
|
||||
logger.error("[Lark] 上传文件成功但未返回数据(data is None)")
|
||||
return None
|
||||
if response.data is None:
|
||||
logger.error("[Lark] 上传文件成功但未返回数据(data is None)")
|
||||
return None
|
||||
|
||||
file_key = response.data.file_key
|
||||
logger.debug(f"[Lark] 文件上传成功: {file_key}")
|
||||
return file_key
|
||||
file_key = response.data.file_key
|
||||
logger.debug(f"[Lark] 文件上传成功: {file_key}")
|
||||
return file_key
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Lark] 无法打开或上传文件: {e}")
|
||||
@@ -205,9 +207,8 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
temp_dir,
|
||||
f"lark_image_{uuid.uuid4().hex[:8]}.jpg",
|
||||
)
|
||||
await asyncio.to_thread(
|
||||
Path(file_path).write_bytes, BytesIO(image_data).getvalue()
|
||||
)
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(BytesIO(image_data).getvalue())
|
||||
else:
|
||||
file_path = comp.file if comp.file else ""
|
||||
|
||||
@@ -216,9 +217,7 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
logger.error("[Lark] 图片路径为空,无法上传")
|
||||
continue
|
||||
try:
|
||||
image_file = BytesIO(
|
||||
await asyncio.to_thread(Path(file_path).read_bytes)
|
||||
)
|
||||
image_file = open(file_path, "rb")
|
||||
except Exception as e:
|
||||
logger.error(f"[Lark] 无法打开图片文件: {e}")
|
||||
continue
|
||||
@@ -413,9 +412,7 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
logger.error(f"[Lark] 无法获取音频文件路径: {e}")
|
||||
return
|
||||
|
||||
if not original_audio_path or not await asyncio.to_thread(
|
||||
os.path.exists, original_audio_path
|
||||
):
|
||||
if not original_audio_path or not os.path.exists(original_audio_path):
|
||||
logger.error(f"[Lark] 音频文件不存在: {original_audio_path}")
|
||||
return
|
||||
|
||||
@@ -445,9 +442,7 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
|
||||
# 清理转换后的临时音频文件
|
||||
if converted_audio_path and await asyncio.to_thread(
|
||||
os.path.exists, converted_audio_path
|
||||
):
|
||||
if converted_audio_path and os.path.exists(converted_audio_path):
|
||||
try:
|
||||
os.remove(converted_audio_path)
|
||||
logger.debug(f"[Lark] 已删除转换后的音频文件: {converted_audio_path}")
|
||||
@@ -490,9 +485,7 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
logger.error(f"[Lark] 无法获取视频文件路径: {e}")
|
||||
return
|
||||
|
||||
if not original_video_path or not await asyncio.to_thread(
|
||||
os.path.exists, original_video_path
|
||||
):
|
||||
if not original_video_path or not os.path.exists(original_video_path):
|
||||
logger.error(f"[Lark] 视频文件不存在: {original_video_path}")
|
||||
return
|
||||
|
||||
@@ -522,9 +515,7 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
|
||||
# 清理转换后的临时视频文件
|
||||
if converted_video_path and await asyncio.to_thread(
|
||||
os.path.exists, converted_video_path
|
||||
):
|
||||
if converted_video_path and os.path.exists(converted_video_path):
|
||||
try:
|
||||
os.remove(converted_video_path)
|
||||
logger.debug(f"[Lark] 已删除转换后的视频文件: {converted_video_path}")
|
||||
|
||||
@@ -161,7 +161,7 @@ class LineMessageEvent(AstrMessageEvent):
|
||||
try:
|
||||
video_path = await segment.convert_to_file_path()
|
||||
temp_dir = Path(get_astrbot_temp_path())
|
||||
await asyncio.to_thread(temp_dir.mkdir, parents=True, exist_ok=True)
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
thumb_path = temp_dir / f"line_video_preview_{uuid.uuid4().hex}.jpg"
|
||||
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
@@ -201,8 +201,8 @@ class LineMessageEvent(AstrMessageEvent):
|
||||
async def _resolve_file_size(segment: File) -> int:
|
||||
try:
|
||||
file_path = await segment.get_file(allow_return_url=False)
|
||||
if file_path and await asyncio.to_thread(os.path.exists, file_path):
|
||||
return int(await asyncio.to_thread(os.path.getsize, file_path))
|
||||
if file_path and os.path.exists(file_path):
|
||||
return int(os.path.getsize(file_path))
|
||||
except Exception as e:
|
||||
logger.debug("[LINE] resolve file size failed: %s", e)
|
||||
return 0
|
||||
|
||||
@@ -499,8 +499,7 @@ class MisskeyPlatformAdapter(Platform):
|
||||
# 清理临时文件
|
||||
if local_path and isinstance(local_path, str):
|
||||
data_temp = get_astrbot_temp_path()
|
||||
if local_path.startswith(data_temp) and await asyncio.to_thread(
|
||||
os.path.exists,
|
||||
if local_path.startswith(data_temp) and os.path.exists(
|
||||
local_path,
|
||||
):
|
||||
try:
|
||||
|
||||
@@ -3,7 +3,6 @@ import json
|
||||
import random
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Callable
|
||||
from pathlib import Path
|
||||
from typing import Any, NoReturn
|
||||
|
||||
try:
|
||||
@@ -556,19 +555,22 @@ class MisskeyAPI:
|
||||
form.add_field("folderId", str(folder_id))
|
||||
|
||||
try:
|
||||
file_bytes = await asyncio.to_thread(Path(file_path).read_bytes)
|
||||
f = open(file_path, "rb")
|
||||
except FileNotFoundError as e:
|
||||
logger.error(f"[Misskey API] 本地文件不存在: {file_path}")
|
||||
raise APIError(f"File not found: {file_path}") from e
|
||||
|
||||
form.add_field("file", file_bytes, filename=filename)
|
||||
async with self.session.post(url, data=form) as resp:
|
||||
result = await self._process_response(resp, "drive/files/create")
|
||||
file_id = FileIDExtractor.extract_file_id(result)
|
||||
logger.debug(
|
||||
f"[Misskey API] 本地文件上传成功: {filename} -> {file_id}",
|
||||
)
|
||||
return {"id": file_id, "raw": result}
|
||||
try:
|
||||
form.add_field("file", f, filename=filename)
|
||||
async with self.session.post(url, data=form) as resp:
|
||||
result = await self._process_response(resp, "drive/files/create")
|
||||
file_id = FileIDExtractor.extract_file_id(result)
|
||||
logger.debug(
|
||||
f"[Misskey API] 本地文件上传成功: {filename} -> {file_id}",
|
||||
)
|
||||
return {"id": file_id, "raw": result}
|
||||
finally:
|
||||
f.close()
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"[Misskey API] 文件上传网络错误: {e}")
|
||||
raise APIConnectionError(f"Upload failed: {e}") from e
|
||||
|
||||
@@ -339,7 +339,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
payload = {"file_type": file_type, "srv_send_msg": srv_send_msg}
|
||||
|
||||
# 处理文件数据
|
||||
if await asyncio.to_thread(os.path.exists, file_source):
|
||||
if os.path.exists(file_source):
|
||||
# 读取本地文件
|
||||
async with aiofiles.open(file_source, "rb") as f:
|
||||
file_content = await f.read()
|
||||
@@ -421,15 +421,15 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
plain_text += i.text
|
||||
elif isinstance(i, Image) and not image_base64:
|
||||
if i.file and i.file.startswith("file:///"):
|
||||
image_base64 = await file_to_base64(i.file[8:])
|
||||
image_base64 = file_to_base64(i.file[8:])
|
||||
image_file_path = i.file[8:]
|
||||
elif i.file and i.file.startswith("http"):
|
||||
image_file_path = await download_image_by_url(i.file)
|
||||
image_base64 = await file_to_base64(image_file_path)
|
||||
image_base64 = file_to_base64(image_file_path)
|
||||
elif i.file and i.file.startswith("base64://"):
|
||||
image_base64 = i.file
|
||||
elif i.file:
|
||||
image_base64 = await file_to_base64(i.file)
|
||||
image_base64 = file_to_base64(i.file)
|
||||
else:
|
||||
raise ValueError("Unsupported image file format")
|
||||
image_base64 = image_base64.removeprefix("base64://")
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import uuid
|
||||
from typing import cast, override
|
||||
from typing import cast
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from telegram import BotCommand, Update
|
||||
@@ -32,6 +33,11 @@ from astrbot.core.utils.media_utils import convert_audio_to_wav
|
||||
|
||||
from .tg_event import TelegramPlatformEvent
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
@register_platform_adapter("telegram", "telegram 适配器")
|
||||
class TelegramPlatformAdapter(Platform):
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import json
|
||||
import mimetypes
|
||||
import shutil
|
||||
@@ -140,15 +139,13 @@ async def parse_webchat_message_parts(
|
||||
continue
|
||||
|
||||
file_path = Path(str(path))
|
||||
if verify_media_path_exists and not await asyncio.to_thread(file_path.exists):
|
||||
if verify_media_path_exists and not file_path.exists():
|
||||
if strict:
|
||||
raise ValueError(f"file not found: {file_path!s}")
|
||||
continue
|
||||
|
||||
file_path_str = (
|
||||
str(await asyncio.to_thread(file_path.resolve))
|
||||
if verify_media_path_exists
|
||||
else str(file_path)
|
||||
str(file_path.resolve()) if verify_media_path_exists else str(file_path)
|
||||
)
|
||||
has_content = True
|
||||
if part_type == "image":
|
||||
@@ -369,7 +366,7 @@ async def message_chain_to_storage_message_parts(
|
||||
attachments_dir: str | Path,
|
||||
) -> list[dict]:
|
||||
target_dir = Path(attachments_dir)
|
||||
await asyncio.to_thread(target_dir.mkdir, parents=True, exist_ok=True)
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
parts: list[dict] = []
|
||||
for comp in message_chain.chain:
|
||||
@@ -445,9 +442,7 @@ async def _copy_file_to_attachment_part(
|
||||
display_name: str | None = None,
|
||||
) -> dict | None:
|
||||
src_path = Path(file_path)
|
||||
if not await asyncio.to_thread(src_path.exists) or not await asyncio.to_thread(
|
||||
src_path.is_file
|
||||
):
|
||||
if not src_path.exists() or not src_path.is_file():
|
||||
return None
|
||||
|
||||
suffix = src_path.suffix
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
@@ -82,9 +80,8 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
filename = f"{str(uuid.uuid4())}.jpg"
|
||||
path = os.path.join(attachments_dir, filename)
|
||||
image_base64 = await comp.convert_to_base64()
|
||||
await asyncio.to_thread(
|
||||
Path(path).write_bytes, base64.b64decode(image_base64)
|
||||
)
|
||||
with open(path, "wb") as f:
|
||||
f.write(base64.b64decode(image_base64))
|
||||
data = f"[IMAGE]{filename}"
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
@@ -99,9 +96,8 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
filename = f"{str(uuid.uuid4())}.wav"
|
||||
path = os.path.join(attachments_dir, filename)
|
||||
record_base64 = await comp.convert_to_base64()
|
||||
await asyncio.to_thread(
|
||||
Path(path).write_bytes, base64.b64decode(record_base64)
|
||||
)
|
||||
with open(path, "wb") as f:
|
||||
f.write(base64.b64decode(record_base64))
|
||||
data = f"[RECORD]{filename}"
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Callable
|
||||
from pathlib import Path
|
||||
from typing import Any, cast, override
|
||||
from typing import Any, cast
|
||||
|
||||
import quart
|
||||
from requests import Response
|
||||
@@ -33,6 +33,11 @@ from .wecom_event import WecomPlatformEvent
|
||||
from .wecom_kf import WeChatKF
|
||||
from .wecom_kf_message import WeChatKFMessage
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class WecomServer:
|
||||
def __init__(self, event_queue: asyncio.Queue, config: dict) -> None:
|
||||
@@ -341,7 +346,8 @@ class WecomPlatformAdapter(Platform):
|
||||
)
|
||||
temp_dir = get_astrbot_temp_path()
|
||||
path = os.path.join(temp_dir, f"wecom_{msg.media_id}.amr")
|
||||
await asyncio.to_thread(Path(path).write_bytes, resp.content)
|
||||
with open(path, "wb") as f:
|
||||
f.write(resp.content)
|
||||
|
||||
try:
|
||||
path_wav = os.path.join(temp_dir, f"wecom_{msg.media_id}.wav")
|
||||
@@ -396,7 +402,8 @@ class WecomPlatformAdapter(Platform):
|
||||
)
|
||||
temp_dir = get_astrbot_temp_path()
|
||||
path = os.path.join(temp_dir, f"weixinkefu_{media_id}.jpg")
|
||||
await asyncio.to_thread(Path(path).write_bytes, resp.content)
|
||||
with open(path, "wb") as f:
|
||||
f.write(resp.content)
|
||||
abm.message = [Image(file=path, url=path)]
|
||||
elif msgtype == "voice":
|
||||
media_id = msg.get("voice", {}).get("media_id", "")
|
||||
@@ -408,7 +415,8 @@ class WecomPlatformAdapter(Platform):
|
||||
|
||||
temp_dir = get_astrbot_temp_path()
|
||||
path = os.path.join(temp_dir, f"weixinkefu_{media_id}.amr")
|
||||
await asyncio.to_thread(Path(path).write_bytes, resp.content)
|
||||
with open(path, "wb") as f:
|
||||
f.write(resp.content)
|
||||
|
||||
try:
|
||||
path_wav = os.path.join(temp_dir, f"weixinkefu_{media_id}.wav")
|
||||
|
||||
@@ -12,13 +12,6 @@ from astrbot.core.utils.media_utils import convert_audio_to_amr
|
||||
from .wecom_kf_message import WeChatKFMessage
|
||||
|
||||
|
||||
def _upload_media_from_path(
|
||||
client: WeChatClient, media_type: str, file_path: str
|
||||
) -> dict:
|
||||
with open(file_path, "rb") as f:
|
||||
return client.media.upload(media_type, f)
|
||||
|
||||
|
||||
class WecomPlatformEvent(AstrMessageEvent):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -107,52 +100,45 @@ class WecomPlatformEvent(AstrMessageEvent):
|
||||
elif isinstance(comp, Image):
|
||||
img_path = await comp.convert_to_file_path()
|
||||
|
||||
try:
|
||||
response = await asyncio.to_thread(
|
||||
_upload_media_from_path,
|
||||
self.client,
|
||||
"image",
|
||||
img_path,
|
||||
with open(img_path, "rb") as f:
|
||||
try:
|
||||
response = self.client.media.upload("image", f)
|
||||
except Exception as e:
|
||||
logger.error(f"微信客服上传图片失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"微信客服上传图片失败: {e}"),
|
||||
)
|
||||
return
|
||||
logger.debug(f"微信客服上传图片返回: {response}")
|
||||
kf_message_api.send_image(
|
||||
user_id,
|
||||
self.get_self_id(),
|
||||
response["media_id"],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"微信客服上传图片失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"微信客服上传图片失败: {e}"),
|
||||
)
|
||||
return
|
||||
logger.debug(f"微信客服上传图片返回: {response}")
|
||||
kf_message_api.send_image(
|
||||
user_id,
|
||||
self.get_self_id(),
|
||||
response["media_id"],
|
||||
)
|
||||
elif isinstance(comp, Record):
|
||||
record_path = await comp.convert_to_file_path()
|
||||
record_path_amr = await convert_audio_to_amr(record_path)
|
||||
|
||||
try:
|
||||
try:
|
||||
response = await asyncio.to_thread(
|
||||
_upload_media_from_path,
|
||||
self.client,
|
||||
"voice",
|
||||
record_path_amr,
|
||||
with open(record_path_amr, "rb") as f:
|
||||
try:
|
||||
response = self.client.media.upload("voice", f)
|
||||
except Exception as e:
|
||||
logger.error(f"微信客服上传语音失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(
|
||||
f"微信客服上传语音失败: {e}"
|
||||
),
|
||||
)
|
||||
return
|
||||
logger.info(f"微信客服上传语音返回: {response}")
|
||||
kf_message_api.send_voice(
|
||||
user_id,
|
||||
self.get_self_id(),
|
||||
response["media_id"],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"微信客服上传语音失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"微信客服上传语音失败: {e}"),
|
||||
)
|
||||
return
|
||||
logger.info(f"微信客服上传语音返回: {response}")
|
||||
kf_message_api.send_voice(
|
||||
user_id,
|
||||
self.get_self_id(),
|
||||
response["media_id"],
|
||||
)
|
||||
finally:
|
||||
if record_path_amr != record_path and await asyncio.to_thread(
|
||||
os.path.exists,
|
||||
if record_path_amr != record_path and os.path.exists(
|
||||
record_path_amr,
|
||||
):
|
||||
try:
|
||||
@@ -162,47 +148,39 @@ class WecomPlatformEvent(AstrMessageEvent):
|
||||
elif isinstance(comp, File):
|
||||
file_path = await comp.get_file()
|
||||
|
||||
try:
|
||||
response = await asyncio.to_thread(
|
||||
_upload_media_from_path,
|
||||
self.client,
|
||||
"file",
|
||||
file_path,
|
||||
with open(file_path, "rb") as f:
|
||||
try:
|
||||
response = self.client.media.upload("file", f)
|
||||
except Exception as e:
|
||||
logger.error(f"微信客服上传文件失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"微信客服上传文件失败: {e}"),
|
||||
)
|
||||
return
|
||||
logger.debug(f"微信客服上传文件返回: {response}")
|
||||
kf_message_api.send_file(
|
||||
user_id,
|
||||
self.get_self_id(),
|
||||
response["media_id"],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"微信客服上传文件失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"微信客服上传文件失败: {e}"),
|
||||
)
|
||||
return
|
||||
logger.debug(f"微信客服上传文件返回: {response}")
|
||||
kf_message_api.send_file(
|
||||
user_id,
|
||||
self.get_self_id(),
|
||||
response["media_id"],
|
||||
)
|
||||
elif isinstance(comp, Video):
|
||||
video_path = await comp.convert_to_file_path()
|
||||
|
||||
try:
|
||||
response = await asyncio.to_thread(
|
||||
_upload_media_from_path,
|
||||
self.client,
|
||||
"video",
|
||||
video_path,
|
||||
with open(video_path, "rb") as f:
|
||||
try:
|
||||
response = self.client.media.upload("video", f)
|
||||
except Exception as e:
|
||||
logger.error(f"微信客服上传视频失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"微信客服上传视频失败: {e}"),
|
||||
)
|
||||
return
|
||||
logger.debug(f"微信客服上传视频返回: {response}")
|
||||
kf_message_api.send_video(
|
||||
user_id,
|
||||
self.get_self_id(),
|
||||
response["media_id"],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"微信客服上传视频失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"微信客服上传视频失败: {e}"),
|
||||
)
|
||||
return
|
||||
logger.debug(f"微信客服上传视频返回: {response}")
|
||||
kf_message_api.send_video(
|
||||
user_id,
|
||||
self.get_self_id(),
|
||||
response["media_id"],
|
||||
)
|
||||
else:
|
||||
logger.warning(f"还没实现这个消息类型的发送逻辑: {comp.type}。")
|
||||
else:
|
||||
@@ -221,52 +199,45 @@ class WecomPlatformEvent(AstrMessageEvent):
|
||||
elif isinstance(comp, Image):
|
||||
img_path = await comp.convert_to_file_path()
|
||||
|
||||
try:
|
||||
response = await asyncio.to_thread(
|
||||
_upload_media_from_path,
|
||||
self.client,
|
||||
"image",
|
||||
img_path,
|
||||
with open(img_path, "rb") as f:
|
||||
try:
|
||||
response = self.client.media.upload("image", f)
|
||||
except Exception as e:
|
||||
logger.error(f"企业微信上传图片失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"企业微信上传图片失败: {e}"),
|
||||
)
|
||||
return
|
||||
logger.debug(f"企业微信上传图片返回: {response}")
|
||||
self.client.message.send_image(
|
||||
message_obj.self_id,
|
||||
message_obj.session_id,
|
||||
response["media_id"],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"企业微信上传图片失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"企业微信上传图片失败: {e}"),
|
||||
)
|
||||
return
|
||||
logger.debug(f"企业微信上传图片返回: {response}")
|
||||
self.client.message.send_image(
|
||||
message_obj.self_id,
|
||||
message_obj.session_id,
|
||||
response["media_id"],
|
||||
)
|
||||
elif isinstance(comp, Record):
|
||||
record_path = await comp.convert_to_file_path()
|
||||
record_path_amr = await convert_audio_to_amr(record_path)
|
||||
|
||||
try:
|
||||
try:
|
||||
response = await asyncio.to_thread(
|
||||
_upload_media_from_path,
|
||||
self.client,
|
||||
"voice",
|
||||
record_path_amr,
|
||||
with open(record_path_amr, "rb") as f:
|
||||
try:
|
||||
response = self.client.media.upload("voice", f)
|
||||
except Exception as e:
|
||||
logger.error(f"企业微信上传语音失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(
|
||||
f"企业微信上传语音失败: {e}"
|
||||
),
|
||||
)
|
||||
return
|
||||
logger.info(f"企业微信上传语音返回: {response}")
|
||||
self.client.message.send_voice(
|
||||
message_obj.self_id,
|
||||
message_obj.session_id,
|
||||
response["media_id"],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"企业微信上传语音失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"企业微信上传语音失败: {e}"),
|
||||
)
|
||||
return
|
||||
logger.info(f"企业微信上传语音返回: {response}")
|
||||
self.client.message.send_voice(
|
||||
message_obj.self_id,
|
||||
message_obj.session_id,
|
||||
response["media_id"],
|
||||
)
|
||||
finally:
|
||||
if record_path_amr != record_path and await asyncio.to_thread(
|
||||
os.path.exists,
|
||||
if record_path_amr != record_path and os.path.exists(
|
||||
record_path_amr,
|
||||
):
|
||||
try:
|
||||
@@ -276,47 +247,39 @@ class WecomPlatformEvent(AstrMessageEvent):
|
||||
elif isinstance(comp, File):
|
||||
file_path = await comp.get_file()
|
||||
|
||||
try:
|
||||
response = await asyncio.to_thread(
|
||||
_upload_media_from_path,
|
||||
self.client,
|
||||
"file",
|
||||
file_path,
|
||||
with open(file_path, "rb") as f:
|
||||
try:
|
||||
response = self.client.media.upload("file", f)
|
||||
except Exception as e:
|
||||
logger.error(f"企业微信上传文件失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"企业微信上传文件失败: {e}"),
|
||||
)
|
||||
return
|
||||
logger.debug(f"企业微信上传文件返回: {response}")
|
||||
self.client.message.send_file(
|
||||
message_obj.self_id,
|
||||
message_obj.session_id,
|
||||
response["media_id"],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"企业微信上传文件失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"企业微信上传文件失败: {e}"),
|
||||
)
|
||||
return
|
||||
logger.debug(f"企业微信上传文件返回: {response}")
|
||||
self.client.message.send_file(
|
||||
message_obj.self_id,
|
||||
message_obj.session_id,
|
||||
response["media_id"],
|
||||
)
|
||||
elif isinstance(comp, Video):
|
||||
video_path = await comp.convert_to_file_path()
|
||||
|
||||
try:
|
||||
response = await asyncio.to_thread(
|
||||
_upload_media_from_path,
|
||||
self.client,
|
||||
"video",
|
||||
video_path,
|
||||
with open(video_path, "rb") as f:
|
||||
try:
|
||||
response = self.client.media.upload("video", f)
|
||||
except Exception as e:
|
||||
logger.error(f"企业微信上传视频失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"企业微信上传视频失败: {e}"),
|
||||
)
|
||||
return
|
||||
logger.debug(f"企业微信上传视频返回: {response}")
|
||||
self.client.message.send_video(
|
||||
message_obj.self_id,
|
||||
message_obj.session_id,
|
||||
response["media_id"],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"企业微信上传视频失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"企业微信上传视频失败: {e}"),
|
||||
)
|
||||
return
|
||||
logger.debug(f"企业微信上传视频返回: {response}")
|
||||
self.client.message.send_video(
|
||||
message_obj.self_id,
|
||||
message_obj.session_id,
|
||||
response["media_id"],
|
||||
)
|
||||
else:
|
||||
logger.warning(f"还没实现这个消息类型的发送逻辑: {comp.type}。")
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
提供常量定义、工具函数和辅助方法
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import secrets
|
||||
@@ -173,7 +174,7 @@ async def process_encrypted_image(
|
||||
response.raise_for_status()
|
||||
encrypted_data = await response.read()
|
||||
logger.info("图片下载成功,大小: %d 字节", len(encrypted_data))
|
||||
except (TimeoutError, aiohttp.ClientError) as e:
|
||||
except (aiohttp.ClientError, asyncio.TimeoutError) as e:
|
||||
error_msg = f"下载图片失败: {e!s}"
|
||||
logger.error(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import mimetypes
|
||||
@@ -104,9 +103,7 @@ class WecomAIBotWebhookClient:
|
||||
async def upload_media(
|
||||
self, file_path: Path, media_type: Literal["file", "voice"]
|
||||
) -> str:
|
||||
if not await asyncio.to_thread(file_path.exists) or not await asyncio.to_thread(
|
||||
file_path.is_file
|
||||
):
|
||||
if not file_path.exists() or not file_path.is_file():
|
||||
raise WecomAIBotWebhookError(f"文件不存在: {file_path}")
|
||||
|
||||
content_type = (
|
||||
@@ -115,7 +112,7 @@ class WecomAIBotWebhookClient:
|
||||
form = aiohttp.FormData()
|
||||
form.add_field(
|
||||
"media",
|
||||
await asyncio.to_thread(file_path.read_bytes),
|
||||
file_path.read_bytes(),
|
||||
filename=file_path.name,
|
||||
content_type=content_type,
|
||||
)
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Callable, Coroutine
|
||||
from pathlib import Path
|
||||
from typing import Any, cast, override
|
||||
from typing import Any, cast
|
||||
|
||||
import quart
|
||||
from requests import Response
|
||||
@@ -32,6 +32,11 @@ from astrbot.core.utils.webhook_utils import log_webhook_info
|
||||
|
||||
from .weixin_offacc_event import WeixinOfficialAccountPlatformEvent
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class WeixinOfficialAccountServer:
|
||||
def __init__(
|
||||
@@ -374,7 +379,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
) # wait for 180s
|
||||
logger.debug(f"Got future result: {result}")
|
||||
return result
|
||||
except TimeoutError:
|
||||
except asyncio.TimeoutError:
|
||||
logger.info(f"callback 处理消息超时: message_id={msg.id}")
|
||||
return create_reply("处理消息超时,请稍后再试。", msg)
|
||||
except Exception as e:
|
||||
@@ -463,7 +468,8 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
)
|
||||
temp_dir = get_astrbot_temp_path()
|
||||
path = os.path.join(temp_dir, f"weixin_offacc_{msg.media_id}.amr")
|
||||
await asyncio.to_thread(Path(path).write_bytes, resp.content)
|
||||
with open(path, "wb") as f:
|
||||
f.write(resp.content)
|
||||
|
||||
try:
|
||||
path_wav = os.path.join(
|
||||
|
||||
@@ -12,13 +12,6 @@ from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||
from astrbot.core.utils.media_utils import convert_audio_to_amr
|
||||
|
||||
|
||||
def _upload_media_from_path(
|
||||
client: WeChatClient, media_type: str, file_path: str
|
||||
) -> dict:
|
||||
with open(file_path, "rb") as f:
|
||||
return client.media.upload(media_type, f)
|
||||
|
||||
|
||||
class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -108,63 +101,24 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
||||
elif isinstance(comp, Image):
|
||||
img_path = await comp.convert_to_file_path()
|
||||
|
||||
try:
|
||||
response = await asyncio.to_thread(
|
||||
_upload_media_from_path,
|
||||
self.client,
|
||||
"image",
|
||||
img_path,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"微信公众平台上传图片失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"微信公众平台上传图片失败: {e}"),
|
||||
)
|
||||
return
|
||||
logger.debug(f"微信公众平台上传图片返回: {response}")
|
||||
|
||||
if active_send_mode:
|
||||
self.client.message.send_image(
|
||||
message_obj.sender.user_id,
|
||||
response["media_id"],
|
||||
)
|
||||
else:
|
||||
reply = ImageReply(
|
||||
media_id=response["media_id"],
|
||||
message=cast(dict, self.message_obj.raw_message)["message"],
|
||||
)
|
||||
xml = reply.render()
|
||||
future = cast(dict, self.message_obj.raw_message)["future"]
|
||||
assert isinstance(future, asyncio.Future)
|
||||
future.set_result(xml)
|
||||
|
||||
elif isinstance(comp, Record):
|
||||
record_path = await comp.convert_to_file_path()
|
||||
record_path_amr = await convert_audio_to_amr(record_path)
|
||||
|
||||
try:
|
||||
with open(img_path, "rb") as f:
|
||||
try:
|
||||
response = await asyncio.to_thread(
|
||||
_upload_media_from_path,
|
||||
self.client,
|
||||
"voice",
|
||||
record_path_amr,
|
||||
)
|
||||
response = self.client.media.upload("image", f)
|
||||
except Exception as e:
|
||||
logger.error(f"微信公众平台上传语音失败: {e}")
|
||||
logger.error(f"微信公众平台上传图片失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"微信公众平台上传语音失败: {e}"),
|
||||
MessageChain().message(f"微信公众平台上传图片失败: {e}"),
|
||||
)
|
||||
return
|
||||
logger.info(f"微信公众平台上传语音返回: {response}")
|
||||
logger.debug(f"微信公众平台上传图片返回: {response}")
|
||||
|
||||
if active_send_mode:
|
||||
self.client.message.send_voice(
|
||||
self.client.message.send_image(
|
||||
message_obj.sender.user_id,
|
||||
response["media_id"],
|
||||
)
|
||||
else:
|
||||
reply = VoiceReply(
|
||||
reply = ImageReply(
|
||||
media_id=response["media_id"],
|
||||
message=cast(dict, self.message_obj.raw_message)["message"],
|
||||
)
|
||||
@@ -172,9 +126,44 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
||||
future = cast(dict, self.message_obj.raw_message)["future"]
|
||||
assert isinstance(future, asyncio.Future)
|
||||
future.set_result(xml)
|
||||
|
||||
elif isinstance(comp, Record):
|
||||
record_path = await comp.convert_to_file_path()
|
||||
record_path_amr = await convert_audio_to_amr(record_path)
|
||||
|
||||
try:
|
||||
with open(record_path_amr, "rb") as f:
|
||||
try:
|
||||
response = self.client.media.upload("voice", f)
|
||||
except Exception as e:
|
||||
logger.error(f"微信公众平台上传语音失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(
|
||||
f"微信公众平台上传语音失败: {e}"
|
||||
),
|
||||
)
|
||||
return
|
||||
logger.info(f"微信公众平台上传语音返回: {response}")
|
||||
|
||||
if active_send_mode:
|
||||
self.client.message.send_voice(
|
||||
message_obj.sender.user_id,
|
||||
response["media_id"],
|
||||
)
|
||||
else:
|
||||
reply = VoiceReply(
|
||||
media_id=response["media_id"],
|
||||
message=cast(dict, self.message_obj.raw_message)[
|
||||
"message"
|
||||
],
|
||||
)
|
||||
xml = reply.render()
|
||||
future = cast(dict, self.message_obj.raw_message)["future"]
|
||||
assert isinstance(future, asyncio.Future)
|
||||
future.set_result(xml)
|
||||
finally:
|
||||
if record_path_amr != record_path and await asyncio.to_thread(
|
||||
os.path.exists, record_path_amr
|
||||
if record_path_amr != record_path and os.path.exists(
|
||||
record_path_amr
|
||||
):
|
||||
try:
|
||||
os.remove(record_path_amr)
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import enum
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from anthropic.types import Message as AnthropicMessage
|
||||
@@ -220,10 +218,9 @@ class ProviderRequest:
|
||||
"""将图片转换为 base64"""
|
||||
if image_url.startswith("base64://"):
|
||||
return image_url.replace("base64://", "data:image/jpeg;base64,")
|
||||
image_bs64 = base64.b64encode(
|
||||
await asyncio.to_thread(Path(image_url).read_bytes)
|
||||
).decode("utf-8")
|
||||
return "data:image/jpeg;base64," + image_bs64
|
||||
with open(image_url, "rb") as f:
|
||||
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
|
||||
return "data:image/jpeg;base64," + image_bs64
|
||||
return ""
|
||||
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ import threading
|
||||
import urllib.parse
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from types import MappingProxyType
|
||||
from typing import Any
|
||||
|
||||
@@ -199,7 +198,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
||||
return True, ""
|
||||
return False, f"HTTP {response.status}: {response.reason}"
|
||||
|
||||
except TimeoutError:
|
||||
except asyncio.TimeoutError:
|
||||
return False, f"连接超时: {timeout}秒"
|
||||
except Exception as e:
|
||||
return False, f"{e!s}"
|
||||
@@ -378,24 +377,15 @@ class FunctionToolManager:
|
||||
data_dir = get_astrbot_data_path()
|
||||
|
||||
mcp_json_file = os.path.join(data_dir, "mcp_server.json")
|
||||
if not await asyncio.to_thread(os.path.exists, mcp_json_file):
|
||||
if not os.path.exists(mcp_json_file):
|
||||
# 配置文件不存在错误处理
|
||||
config_text = json.dumps(DEFAULT_MCP_CONFIG, ensure_ascii=False, indent=4)
|
||||
await asyncio.to_thread(
|
||||
Path(mcp_json_file).write_text,
|
||||
config_text,
|
||||
encoding="utf-8",
|
||||
)
|
||||
with open(mcp_json_file, "w", encoding="utf-8") as f:
|
||||
json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4)
|
||||
logger.info(f"未找到 MCP 服务配置文件,已创建默认配置文件 {mcp_json_file}")
|
||||
return MCPInitSummary(total=0, success=0, failed=[])
|
||||
|
||||
mcp_json_content = await asyncio.to_thread(
|
||||
Path(mcp_json_file).read_text,
|
||||
encoding="utf-8",
|
||||
)
|
||||
mcp_server_json_obj: dict[str, dict] = json.loads(mcp_json_content)[
|
||||
"mcpServers"
|
||||
]
|
||||
with open(mcp_json_file, encoding="utf-8") as f:
|
||||
mcp_server_json_obj: dict[str, dict] = json.load(f)["mcpServers"]
|
||||
|
||||
init_timeout_value = _resolve_timeout(
|
||||
timeout=init_timeout,
|
||||
@@ -469,7 +459,7 @@ class FunctionToolManager:
|
||||
cfg: dict,
|
||||
*,
|
||||
shutdown_event: asyncio.Event | None = None,
|
||||
timeout_seconds: float,
|
||||
timeout: float,
|
||||
) -> None:
|
||||
"""Initialize MCP server with timeout and register task/event together.
|
||||
|
||||
@@ -479,7 +469,7 @@ class FunctionToolManager:
|
||||
async with self._runtime_lock:
|
||||
if name in self._mcp_server_runtime or name in self._mcp_starting:
|
||||
logger.warning(
|
||||
f"MCP 服务 {name} 已在运行,忽略本次启用请求(timeout={timeout_seconds:g})。"
|
||||
f"MCP 服务 {name} 已在运行,忽略本次启用请求(timeout={timeout:g})。"
|
||||
)
|
||||
self._log_safe_mcp_debug_config(cfg)
|
||||
return
|
||||
@@ -492,11 +482,11 @@ class FunctionToolManager:
|
||||
try:
|
||||
mcp_client = await asyncio.wait_for(
|
||||
self._init_mcp_client(name, cfg),
|
||||
timeout=timeout_seconds,
|
||||
timeout=timeout,
|
||||
)
|
||||
except TimeoutError as exc:
|
||||
except asyncio.TimeoutError as exc:
|
||||
raise MCPInitTimeoutError(
|
||||
f"MCP 服务 {name} 初始化超时({timeout_seconds:g} 秒)"
|
||||
f"MCP 服务 {name} 初始化超时({timeout:g} 秒)"
|
||||
) from exc
|
||||
except Exception:
|
||||
logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True)
|
||||
@@ -529,7 +519,7 @@ class FunctionToolManager:
|
||||
async def _shutdown_runtimes(
|
||||
self,
|
||||
runtimes: list[_MCPServerRuntime],
|
||||
timeout_seconds: float,
|
||||
timeout: float,
|
||||
*,
|
||||
strict: bool = True,
|
||||
) -> list[str]:
|
||||
@@ -548,9 +538,9 @@ class FunctionToolManager:
|
||||
try:
|
||||
results = await asyncio.wait_for(
|
||||
asyncio.gather(*lifecycle_tasks, return_exceptions=True),
|
||||
timeout=timeout_seconds,
|
||||
timeout=timeout,
|
||||
)
|
||||
except TimeoutError:
|
||||
except asyncio.TimeoutError:
|
||||
pending_names = [
|
||||
runtime.name
|
||||
for runtime in runtimes
|
||||
@@ -561,10 +551,10 @@ class FunctionToolManager:
|
||||
task.cancel()
|
||||
await asyncio.gather(*lifecycle_tasks, return_exceptions=True)
|
||||
if strict:
|
||||
raise MCPShutdownTimeoutError(pending_names, timeout_seconds)
|
||||
raise MCPShutdownTimeoutError(pending_names, timeout)
|
||||
logger.warning(
|
||||
"MCP 服务关闭超时(%s 秒),以下服务未完全关闭:%s",
|
||||
f"{timeout_seconds:g}",
|
||||
f"{timeout:g}",
|
||||
", ".join(pending_names),
|
||||
)
|
||||
return pending_names
|
||||
@@ -675,8 +665,7 @@ class FunctionToolManager:
|
||||
name: str,
|
||||
config: dict,
|
||||
shutdown_event: asyncio.Event | None = None,
|
||||
timeout_seconds: float | int | str | None = None,
|
||||
**kwargs: Any,
|
||||
timeout: float | int | str | None = None,
|
||||
) -> None:
|
||||
"""Enable a new MCP server and initialize it.
|
||||
|
||||
@@ -684,22 +673,18 @@ class FunctionToolManager:
|
||||
name: The name of the MCP server.
|
||||
config: Configuration for the MCP server.
|
||||
shutdown_event: Event to signal when the MCP client should shut down.
|
||||
timeout_seconds: Timeout in seconds for initialization.
|
||||
timeout: Timeout in seconds for initialization.
|
||||
Uses ASTRBOT_MCP_ENABLE_TIMEOUT by default (separate from init timeout).
|
||||
|
||||
Raises:
|
||||
MCPInitTimeoutError: If initialization does not complete within timeout.
|
||||
Exception: If there is an error during initialization.
|
||||
"""
|
||||
legacy_timeout = kwargs.pop("timeout", None)
|
||||
if legacy_timeout is not None:
|
||||
timeout_seconds = legacy_timeout
|
||||
|
||||
if timeout_seconds is None:
|
||||
if timeout is None:
|
||||
timeout_value = self._enable_timeout_default
|
||||
else:
|
||||
timeout_value = _resolve_timeout(
|
||||
timeout=timeout_seconds,
|
||||
timeout=timeout,
|
||||
env_name=ENABLE_MCP_TIMEOUT_ENV,
|
||||
default=self._enable_timeout_default,
|
||||
)
|
||||
@@ -707,45 +692,36 @@ class FunctionToolManager:
|
||||
name=name,
|
||||
cfg=config,
|
||||
shutdown_event=shutdown_event,
|
||||
timeout_seconds=timeout_value,
|
||||
timeout=timeout_value,
|
||||
)
|
||||
|
||||
async def disable_mcp_server(
|
||||
self,
|
||||
name: str | None = None,
|
||||
timeout_seconds: float = 10,
|
||||
**kwargs: Any,
|
||||
timeout: float = 10,
|
||||
) -> None:
|
||||
"""Disable an MCP server by its name.
|
||||
|
||||
Args:
|
||||
name (str): The name of the MCP server to disable. If None, ALL MCP servers will be disabled.
|
||||
timeout_seconds (int): Timeout.
|
||||
timeout (int): Timeout.
|
||||
|
||||
Raises:
|
||||
MCPShutdownTimeoutError: If shutdown does not complete within timeout.
|
||||
Only raised when disabling a specific server (name is not None).
|
||||
|
||||
"""
|
||||
legacy_timeout = kwargs.pop("timeout", None)
|
||||
if legacy_timeout is not None:
|
||||
timeout_seconds = float(legacy_timeout)
|
||||
|
||||
if name:
|
||||
async with self._runtime_lock:
|
||||
runtime = self._mcp_server_runtime.get(name)
|
||||
if runtime is None:
|
||||
return
|
||||
|
||||
await self._shutdown_runtimes(
|
||||
[runtime], timeout_seconds=timeout_seconds, strict=True
|
||||
)
|
||||
await self._shutdown_runtimes([runtime], timeout, strict=True)
|
||||
else:
|
||||
async with self._runtime_lock:
|
||||
runtimes = list(self._mcp_server_runtime.values())
|
||||
await self._shutdown_runtimes(
|
||||
runtimes, timeout_seconds=timeout_seconds, strict=False
|
||||
)
|
||||
await self._shutdown_runtimes(runtimes, timeout, strict=False)
|
||||
|
||||
def _warn_on_timeout_mismatch(
|
||||
self,
|
||||
|
||||
@@ -2,8 +2,7 @@ import abc
|
||||
import asyncio
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import TypeAlias, Union
|
||||
|
||||
from astrbot.core.agent.message import ContentPart, Message
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
@@ -16,9 +15,13 @@ from astrbot.core.provider.entities import (
|
||||
from astrbot.core.provider.register import provider_cls_map
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_path
|
||||
|
||||
type Providers = (
|
||||
"Provider" | "STTProvider" | "TTSProvider" | "EmbeddingProvider" | "RerankProvider"
|
||||
)
|
||||
Providers: TypeAlias = Union[
|
||||
"Provider",
|
||||
"STTProvider",
|
||||
"TTSProvider",
|
||||
"EmbeddingProvider",
|
||||
"RerankProvider",
|
||||
]
|
||||
|
||||
|
||||
class AbstractProvider(abc.ABC):
|
||||
@@ -185,13 +188,10 @@ class Provider(AbstractProvider):
|
||||
|
||||
return dicts
|
||||
|
||||
async def test(self, timeout_seconds: float = 45.0, **kwargs: Any) -> None:
|
||||
legacy_timeout = kwargs.pop("timeout", None)
|
||||
if legacy_timeout is not None:
|
||||
timeout_seconds = float(legacy_timeout)
|
||||
async def test(self, timeout: float = 45.0) -> None:
|
||||
await asyncio.wait_for(
|
||||
self.text_chat(prompt="REPLY `PONG` ONLY"),
|
||||
timeout=timeout_seconds,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
|
||||
@@ -268,9 +268,8 @@ class TTSProvider(AbstractProvider):
|
||||
# 调用原有的 get_audio 方法获取音频文件路径
|
||||
audio_path = await self.get_audio(accumulated_text)
|
||||
# 读取音频文件内容
|
||||
audio_data = await asyncio.to_thread(
|
||||
Path(audio_path).read_bytes
|
||||
)
|
||||
with open(audio_path, "rb") as f:
|
||||
audio_data = f.read()
|
||||
await audio_queue.put((accumulated_text, audio_data))
|
||||
except Exception:
|
||||
# 出错时也要发送 None 结束标记
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from pathlib import Path
|
||||
|
||||
import anthropic
|
||||
import httpx
|
||||
@@ -639,10 +637,11 @@ class ProviderAnthropic(Provider):
|
||||
except Exception:
|
||||
mime_type = "image/jpeg"
|
||||
return f"data:{mime_type};base64,{raw_base64}", mime_type
|
||||
image_bytes = await asyncio.to_thread(Path(image_url).read_bytes)
|
||||
mime_type = self._detect_image_mime_type(image_bytes)
|
||||
image_bs64 = base64.b64encode(image_bytes).decode("utf-8")
|
||||
return f"data:{mime_type};base64,{image_bs64}", mime_type
|
||||
with open(image_url, "rb") as f:
|
||||
image_bytes = f.read()
|
||||
mime_type = self._detect_image_mime_type(image_bytes)
|
||||
image_bs64 = base64.b64encode(image_bytes).decode("utf-8")
|
||||
return f"data:{mime_type};base64,{image_bs64}", mime_type
|
||||
return "", "image/jpeg"
|
||||
|
||||
def get_current_key(self) -> str:
|
||||
|
||||
@@ -3,7 +3,6 @@ import base64
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import aiohttp
|
||||
import dashscope
|
||||
@@ -60,7 +59,8 @@ class ProviderDashscopeTTSAPI(TTSProvider):
|
||||
)
|
||||
|
||||
path = os.path.join(temp_dir, f"dashscope_tts_{uuid.uuid4()}{ext}")
|
||||
await asyncio.to_thread(Path(path).write_bytes, audio_bytes)
|
||||
with open(path, "wb") as f:
|
||||
f.write(audio_bytes)
|
||||
return path
|
||||
|
||||
def _call_qwen_tts(self, model: str, text: str):
|
||||
@@ -129,7 +129,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
|
||||
) as response,
|
||||
):
|
||||
return await response.read()
|
||||
except (TimeoutError, aiohttp.ClientError, OSError) as e:
|
||||
except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as e:
|
||||
logging.exception(f"Failed to download audio from URL {url}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
@@ -1,129 +1,126 @@
|
||||
import asyncio
|
||||
import os
|
||||
import subprocess
|
||||
import uuid
|
||||
|
||||
import edge_tts
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
|
||||
|
||||
from ..entities import ProviderType
|
||||
from ..provider import TTSProvider
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
"""
|
||||
edge_tts 方式,能够免费、快速生成语音,使用需要先安装edge-tts库
|
||||
```
|
||||
pip install edge_tts
|
||||
```
|
||||
Windows 如果提示找不到指定文件,以管理员身份运行命令行窗口,然后再次运行 AstrBot
|
||||
"""
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
"edge_tts",
|
||||
"Microsoft Edge TTS",
|
||||
provider_type=ProviderType.TEXT_TO_SPEECH,
|
||||
)
|
||||
class ProviderEdgeTTS(TTSProvider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
|
||||
# 设置默认语音,如果没有指定则使用中文小萱
|
||||
self.voice = provider_config.get("edge-tts-voice", "zh-CN-XiaoxiaoNeural")
|
||||
self.rate = provider_config.get("rate")
|
||||
self.volume = provider_config.get("volume")
|
||||
self.pitch = provider_config.get("pitch")
|
||||
self.timeout = provider_config.get("timeout", 30)
|
||||
|
||||
self.proxy = os.getenv("https_proxy", None)
|
||||
|
||||
self.set_model("edge_tts")
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
temp_dir = get_astrbot_temp_path()
|
||||
mp3_path = os.path.join(temp_dir, f"edge_tts_temp_{uuid.uuid4()}.mp3")
|
||||
wav_path = os.path.join(temp_dir, f"edge_tts_{uuid.uuid4()}.wav")
|
||||
|
||||
# 构建 Edge TTS 参数
|
||||
kwargs = {"text": text, "voice": self.voice}
|
||||
if self.rate:
|
||||
kwargs["rate"] = self.rate
|
||||
if self.volume:
|
||||
kwargs["volume"] = self.volume
|
||||
if self.pitch:
|
||||
kwargs["pitch"] = self.pitch
|
||||
|
||||
try:
|
||||
communicate = edge_tts.Communicate(proxy=self.proxy, **kwargs)
|
||||
await communicate.save(mp3_path)
|
||||
|
||||
try:
|
||||
from pyffmpeg import FFmpeg
|
||||
|
||||
ff = FFmpeg()
|
||||
ff.convert(input_file=mp3_path, output_file=wav_path)
|
||||
except Exception as e:
|
||||
logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
|
||||
# use ffmpeg command line
|
||||
|
||||
# 使用ffmpeg将MP3转换为标准WAV格式
|
||||
p = await asyncio.create_subprocess_exec(
|
||||
"ffmpeg",
|
||||
"-y", # 覆盖输出文件
|
||||
"-i",
|
||||
mp3_path, # 输入文件
|
||||
"-acodec",
|
||||
"pcm_s16le", # 16位PCM编码
|
||||
"-ar",
|
||||
"24000", # 采样率24kHz (适合微信语音)
|
||||
"-ac",
|
||||
"1", # 单声道
|
||||
"-af",
|
||||
"apad=pad_dur=2", # 确保输出时长准确
|
||||
"-fflags",
|
||||
"+genpts", # 强制生成时间戳
|
||||
"-hide_banner", # 隐藏版本信息
|
||||
wav_path, # 输出文件
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
# 等待进程完成并获取输出
|
||||
stdout, stderr = await p.communicate()
|
||||
logger.info(f"[EdgeTTS] FFmpeg 标准输出: {stdout.decode().strip()}")
|
||||
logger.debug(f"FFmpeg错误输出: {stderr.decode().strip()}")
|
||||
logger.info(f"[EdgeTTS] 返回值(0代表成功): {p.returncode}")
|
||||
|
||||
os.remove(mp3_path)
|
||||
if (
|
||||
await asyncio.to_thread(os.path.exists, wav_path)
|
||||
and await asyncio.to_thread(os.path.getsize, wav_path) > 0
|
||||
):
|
||||
return wav_path
|
||||
logger.error("生成的WAV文件不存在或为空")
|
||||
raise RuntimeError("生成的WAV文件不存在或为空")
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error(
|
||||
f"FFmpeg 转换失败: {e.stderr.decode() if e.stderr else str(e)}",
|
||||
)
|
||||
try:
|
||||
if await asyncio.to_thread(os.path.exists, mp3_path):
|
||||
os.remove(mp3_path)
|
||||
except Exception:
|
||||
pass
|
||||
raise RuntimeError(f"FFmpeg 转换失败: {e!s}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"音频生成失败: {e!s}")
|
||||
try:
|
||||
if await asyncio.to_thread(os.path.exists, mp3_path):
|
||||
os.remove(mp3_path)
|
||||
except Exception:
|
||||
pass
|
||||
raise RuntimeError(f"音频生成失败: {e!s}")
|
||||
import asyncio
|
||||
import os
|
||||
import subprocess
|
||||
import uuid
|
||||
|
||||
import edge_tts
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
|
||||
|
||||
from ..entities import ProviderType
|
||||
from ..provider import TTSProvider
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
"""
|
||||
edge_tts 方式,能够免费、快速生成语音,使用需要先安装edge-tts库
|
||||
```
|
||||
pip install edge_tts
|
||||
```
|
||||
Windows 如果提示找不到指定文件,以管理员身份运行命令行窗口,然后再次运行 AstrBot
|
||||
"""
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
"edge_tts",
|
||||
"Microsoft Edge TTS",
|
||||
provider_type=ProviderType.TEXT_TO_SPEECH,
|
||||
)
|
||||
class ProviderEdgeTTS(TTSProvider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
|
||||
# 设置默认语音,如果没有指定则使用中文小萱
|
||||
self.voice = provider_config.get("edge-tts-voice", "zh-CN-XiaoxiaoNeural")
|
||||
self.rate = provider_config.get("rate")
|
||||
self.volume = provider_config.get("volume")
|
||||
self.pitch = provider_config.get("pitch")
|
||||
self.timeout = provider_config.get("timeout", 30)
|
||||
|
||||
self.proxy = os.getenv("https_proxy", None)
|
||||
|
||||
self.set_model("edge_tts")
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
temp_dir = get_astrbot_temp_path()
|
||||
mp3_path = os.path.join(temp_dir, f"edge_tts_temp_{uuid.uuid4()}.mp3")
|
||||
wav_path = os.path.join(temp_dir, f"edge_tts_{uuid.uuid4()}.wav")
|
||||
|
||||
# 构建 Edge TTS 参数
|
||||
kwargs = {"text": text, "voice": self.voice}
|
||||
if self.rate:
|
||||
kwargs["rate"] = self.rate
|
||||
if self.volume:
|
||||
kwargs["volume"] = self.volume
|
||||
if self.pitch:
|
||||
kwargs["pitch"] = self.pitch
|
||||
|
||||
try:
|
||||
communicate = edge_tts.Communicate(proxy=self.proxy, **kwargs)
|
||||
await communicate.save(mp3_path)
|
||||
|
||||
try:
|
||||
from pyffmpeg import FFmpeg
|
||||
|
||||
ff = FFmpeg()
|
||||
ff.convert(input_file=mp3_path, output_file=wav_path)
|
||||
except Exception as e:
|
||||
logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
|
||||
# use ffmpeg command line
|
||||
|
||||
# 使用ffmpeg将MP3转换为标准WAV格式
|
||||
p = await asyncio.create_subprocess_exec(
|
||||
"ffmpeg",
|
||||
"-y", # 覆盖输出文件
|
||||
"-i",
|
||||
mp3_path, # 输入文件
|
||||
"-acodec",
|
||||
"pcm_s16le", # 16位PCM编码
|
||||
"-ar",
|
||||
"24000", # 采样率24kHz (适合微信语音)
|
||||
"-ac",
|
||||
"1", # 单声道
|
||||
"-af",
|
||||
"apad=pad_dur=2", # 确保输出时长准确
|
||||
"-fflags",
|
||||
"+genpts", # 强制生成时间戳
|
||||
"-hide_banner", # 隐藏版本信息
|
||||
wav_path, # 输出文件
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
# 等待进程完成并获取输出
|
||||
stdout, stderr = await p.communicate()
|
||||
logger.info(f"[EdgeTTS] FFmpeg 标准输出: {stdout.decode().strip()}")
|
||||
logger.debug(f"FFmpeg错误输出: {stderr.decode().strip()}")
|
||||
logger.info(f"[EdgeTTS] 返回值(0代表成功): {p.returncode}")
|
||||
|
||||
os.remove(mp3_path)
|
||||
if os.path.exists(wav_path) and os.path.getsize(wav_path) > 0:
|
||||
return wav_path
|
||||
logger.error("生成的WAV文件不存在或为空")
|
||||
raise RuntimeError("生成的WAV文件不存在或为空")
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error(
|
||||
f"FFmpeg 转换失败: {e.stderr.decode() if e.stderr else str(e)}",
|
||||
)
|
||||
try:
|
||||
if os.path.exists(mp3_path):
|
||||
os.remove(mp3_path)
|
||||
except Exception:
|
||||
pass
|
||||
raise RuntimeError(f"FFmpeg 转换失败: {e!s}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"音频生成失败: {e!s}")
|
||||
try:
|
||||
if os.path.exists(mp3_path):
|
||||
os.remove(mp3_path)
|
||||
except Exception:
|
||||
pass
|
||||
raise RuntimeError(f"音频生成失败: {e!s}")
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Literal
|
||||
|
||||
import ormsgpack
|
||||
@@ -161,10 +159,9 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
||||
if response.status_code == 200 and response.headers.get(
|
||||
"content-type", ""
|
||||
).startswith("audio/"):
|
||||
audio_data = bytearray()
|
||||
async for chunk in response.aiter_bytes():
|
||||
audio_data.extend(chunk)
|
||||
await asyncio.to_thread(Path(path).write_bytes, bytes(audio_data))
|
||||
with open(path, "wb") as f:
|
||||
async for chunk in response.aiter_bytes():
|
||||
f.write(chunk)
|
||||
return path
|
||||
error_bytes = await response.aread()
|
||||
error_text = error_bytes.decode("utf-8", errors="replace")[:1024]
|
||||
|
||||
@@ -4,7 +4,6 @@ import json
|
||||
import logging
|
||||
import random
|
||||
from collections.abc import AsyncGenerator
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
from google import genai
|
||||
@@ -925,10 +924,9 @@ class ProviderGoogleGenAI(Provider):
|
||||
"""将图片转换为 base64"""
|
||||
if image_url.startswith("base64://"):
|
||||
return image_url.replace("base64://", "data:image/jpeg;base64,")
|
||||
image_bs64 = base64.b64encode(
|
||||
await asyncio.to_thread(Path(image_url).read_bytes)
|
||||
).decode("utf-8")
|
||||
return "data:image/jpeg;base64," + image_bs64
|
||||
with open(image_url, "rb") as f:
|
||||
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
|
||||
return "data:image/jpeg;base64," + image_bs64
|
||||
|
||||
async def terminate(self) -> None:
|
||||
if self.client:
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
@@ -73,7 +72,7 @@ class GenieTTSProvider(TTSProvider):
|
||||
try:
|
||||
await loop.run_in_executor(None, _generate, path)
|
||||
|
||||
if await asyncio.to_thread(os.path.exists, path):
|
||||
if os.path.exists(path):
|
||||
return path
|
||||
|
||||
raise RuntimeError("Genie TTS did not save to file.")
|
||||
@@ -110,8 +109,9 @@ class GenieTTSProvider(TTSProvider):
|
||||
|
||||
await loop.run_in_executor(None, _generate, path, text)
|
||||
|
||||
if await asyncio.to_thread(os.path.exists, path):
|
||||
audio_data = await asyncio.to_thread(Path(path).read_bytes)
|
||||
if os.path.exists(path):
|
||||
with open(path, "rb") as f:
|
||||
audio_data = f.read()
|
||||
|
||||
# Put (text, bytes) into queue so frontend can display text
|
||||
await audio_queue.put((text, audio_data))
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import aiohttp
|
||||
|
||||
@@ -130,7 +129,8 @@ class ProviderGSVTTS(TTSProvider):
|
||||
|
||||
result = await self._make_request(endpoint, params)
|
||||
if isinstance(result, bytes):
|
||||
await asyncio.to_thread(Path(path).write_bytes, result)
|
||||
with open(path, "wb") as f:
|
||||
f.write(result)
|
||||
return path
|
||||
raise Exception(f"[GSV TTS] 合成失败,输入文本:{text},错误信息:{result}")
|
||||
|
||||
|
||||
@@ -1,62 +1,59 @@
|
||||
import asyncio
|
||||
import os
|
||||
import urllib.parse
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import aiohttp
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
|
||||
|
||||
from ..entities import ProviderType
|
||||
from ..provider import TTSProvider
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
"gsvi_tts_api",
|
||||
"GSVI TTS API",
|
||||
provider_type=ProviderType.TEXT_TO_SPEECH,
|
||||
)
|
||||
class ProviderGSVITTS(TTSProvider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.api_base = provider_config.get("api_base", "http://127.0.0.1:5000")
|
||||
self.api_base = self.api_base.removesuffix("/")
|
||||
self.character = provider_config.get("character")
|
||||
self.emotion = provider_config.get("emotion")
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
temp_dir = get_astrbot_temp_path()
|
||||
path = os.path.join(temp_dir, f"gsvi_tts_{uuid.uuid4()}.wav")
|
||||
params = {"text": text}
|
||||
|
||||
if self.character:
|
||||
params["character"] = self.character
|
||||
if self.emotion:
|
||||
params["emotion"] = self.emotion
|
||||
|
||||
query_parts = []
|
||||
for key, value in params.items():
|
||||
encoded_value = urllib.parse.quote(str(value))
|
||||
query_parts.append(f"{key}={encoded_value}")
|
||||
|
||||
url = f"{self.api_base}/tts?{'&'.join(query_parts)}"
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as response:
|
||||
if response.status == 200:
|
||||
await asyncio.to_thread(
|
||||
Path(path).write_bytes, await response.read()
|
||||
)
|
||||
else:
|
||||
error_text = await response.text()
|
||||
raise Exception(
|
||||
f"GSVI TTS API 请求失败,状态码: {response.status},错误: {error_text}",
|
||||
)
|
||||
|
||||
return path
|
||||
import os
|
||||
import urllib.parse
|
||||
import uuid
|
||||
|
||||
import aiohttp
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
|
||||
|
||||
from ..entities import ProviderType
|
||||
from ..provider import TTSProvider
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
"gsvi_tts_api",
|
||||
"GSVI TTS API",
|
||||
provider_type=ProviderType.TEXT_TO_SPEECH,
|
||||
)
|
||||
class ProviderGSVITTS(TTSProvider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.api_base = provider_config.get("api_base", "http://127.0.0.1:5000")
|
||||
self.api_base = self.api_base.removesuffix("/")
|
||||
self.character = provider_config.get("character")
|
||||
self.emotion = provider_config.get("emotion")
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
temp_dir = get_astrbot_temp_path()
|
||||
path = os.path.join(temp_dir, f"gsvi_tts_{uuid.uuid4()}.wav")
|
||||
params = {"text": text}
|
||||
|
||||
if self.character:
|
||||
params["character"] = self.character
|
||||
if self.emotion:
|
||||
params["emotion"] = self.emotion
|
||||
|
||||
query_parts = []
|
||||
for key, value in params.items():
|
||||
encoded_value = urllib.parse.quote(str(value))
|
||||
query_parts.append(f"{key}={encoded_value}")
|
||||
|
||||
url = f"{self.api_base}/tts?{'&'.join(query_parts)}"
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as response:
|
||||
if response.status == 200:
|
||||
with open(path, "wb") as f:
|
||||
f.write(await response.read())
|
||||
else:
|
||||
error_text = await response.text()
|
||||
raise Exception(
|
||||
f"GSVI TTS API 请求失败,状态码: {response.status},错误: {error_text}",
|
||||
)
|
||||
|
||||
return path
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
from pathlib import Path
|
||||
|
||||
import aiohttp
|
||||
|
||||
@@ -157,7 +155,8 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
|
||||
audio = await self._audio_play(audio_stream)
|
||||
|
||||
# 结果保存至文件
|
||||
await asyncio.to_thread(Path(path).write_bytes, audio)
|
||||
with open(path, "wb") as file:
|
||||
file.write(audio)
|
||||
|
||||
return path
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ import json
|
||||
import random
|
||||
import re
|
||||
from collections.abc import AsyncGenerator
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
@@ -950,10 +949,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
"""将图片转换为 base64"""
|
||||
if image_url.startswith("base64://"):
|
||||
return image_url.replace("base64://", "data:image/jpeg;base64,")
|
||||
image_bs64 = base64.b64encode(
|
||||
await asyncio.to_thread(Path(image_url).read_bytes)
|
||||
).decode("utf-8")
|
||||
return "data:image/jpeg;base64," + image_bs64
|
||||
with open(image_url, "rb") as f:
|
||||
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
|
||||
return "data:image/jpeg;base64," + image_bs64
|
||||
|
||||
async def terminate(self):
|
||||
if self.client:
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
from openai import NOT_GIVEN, AsyncOpenAI
|
||||
@@ -56,10 +54,9 @@ class ProviderOpenAITTSAPI(TTSProvider):
|
||||
response_format="wav",
|
||||
input=text,
|
||||
) as response:
|
||||
audio_data = bytearray()
|
||||
async for chunk in response.iter_bytes(chunk_size=1024):
|
||||
audio_data.extend(chunk)
|
||||
await asyncio.to_thread(Path(path).write_bytes, bytes(audio_data))
|
||||
with open(path, "wb") as f:
|
||||
async for chunk in response.iter_bytes(chunk_size=1024):
|
||||
f.write(chunk)
|
||||
return path
|
||||
|
||||
async def terminate(self):
|
||||
|
||||
@@ -53,12 +53,14 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
|
||||
async def get_timestamped_path(self) -> str:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
temp_dir = Path(get_astrbot_temp_path())
|
||||
await asyncio.to_thread(temp_dir.mkdir, parents=True, exist_ok=True)
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
return str(temp_dir / timestamp)
|
||||
|
||||
async def _is_silk_file(self, file_path) -> bool:
|
||||
silk_header = b"SILK"
|
||||
file_header = (await asyncio.to_thread(Path(file_path).read_bytes))[:8]
|
||||
with open(file_path, "rb") as f:
|
||||
file_header = f.read(8)
|
||||
|
||||
if silk_header in file_header:
|
||||
return True
|
||||
return False
|
||||
@@ -74,7 +76,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
|
||||
await download_file(audio_url, path)
|
||||
audio_url = path
|
||||
|
||||
if not await asyncio.to_thread(os.path.isfile, audio_url):
|
||||
if not os.path.isfile(audio_url):
|
||||
raise FileNotFoundError(f"文件不存在: {audio_url}")
|
||||
|
||||
if audio_url.endswith((".amr", ".silk")) or is_tencent:
|
||||
|
||||
@@ -4,7 +4,6 @@ import json
|
||||
import os
|
||||
import traceback
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import aiohttp
|
||||
|
||||
@@ -101,9 +100,10 @@ class ProviderVolcengineTTS(TTSProvider):
|
||||
f"volcengine_tts_{uuid.uuid4()}.mp3",
|
||||
)
|
||||
|
||||
await asyncio.to_thread(
|
||||
Path(file_path).write_bytes,
|
||||
audio_data,
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
lambda: open(file_path, "wb").write(audio_data),
|
||||
)
|
||||
|
||||
return file_path
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
from openai import NOT_GIVEN, AsyncOpenAI
|
||||
|
||||
@@ -46,7 +44,8 @@ class ProviderOpenAIWhisperAPI(STTProvider):
|
||||
amr_header = b"#!AMR"
|
||||
|
||||
try:
|
||||
file_header = (await asyncio.to_thread(Path(file_path).read_bytes))[:8]
|
||||
with open(file_path, "rb") as f:
|
||||
file_header = f.read(8)
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
@@ -74,7 +73,7 @@ class ProviderOpenAIWhisperAPI(STTProvider):
|
||||
await download_file(audio_url, path)
|
||||
audio_url = path
|
||||
|
||||
if not await asyncio.to_thread(os.path.exists, audio_url):
|
||||
if not os.path.exists(audio_url):
|
||||
raise FileNotFoundError(f"文件不存在: {audio_url}")
|
||||
|
||||
if audio_url.endswith(".amr") or audio_url.endswith(".silk") or is_tencent:
|
||||
@@ -101,14 +100,13 @@ class ProviderOpenAIWhisperAPI(STTProvider):
|
||||
|
||||
audio_url = output_path
|
||||
|
||||
audio_bytes = await asyncio.to_thread(Path(audio_url).read_bytes)
|
||||
result = await self.client.audio.transcriptions.create(
|
||||
model=self.model_name,
|
||||
file=("audio.wav", audio_bytes),
|
||||
file=("audio.wav", open(audio_url, "rb")),
|
||||
)
|
||||
|
||||
# remove temp file
|
||||
if output_path and await asyncio.to_thread(os.path.exists, output_path):
|
||||
if output_path and os.path.exists(output_path):
|
||||
try:
|
||||
os.remove(audio_url)
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import whisper
|
||||
@@ -43,7 +42,9 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
|
||||
|
||||
async def _is_silk_file(self, file_path) -> bool:
|
||||
silk_header = b"SILK"
|
||||
file_header = (await asyncio.to_thread(Path(file_path).read_bytes))[:8]
|
||||
with open(file_path, "rb") as f:
|
||||
file_header = f.read(8)
|
||||
|
||||
if silk_header in file_header:
|
||||
return True
|
||||
return False
|
||||
@@ -65,7 +66,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
|
||||
await download_file(audio_url, path)
|
||||
audio_url = path
|
||||
|
||||
if not await asyncio.to_thread(os.path.exists, audio_url):
|
||||
if not os.path.exists(audio_url):
|
||||
raise FileNotFoundError(f"文件不存在: {audio_url}")
|
||||
|
||||
if audio_url.endswith(".amr") or audio_url.endswith(".silk") or is_tencent:
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import aiohttp
|
||||
from xinference_client.client.restful.async_restful_client import (
|
||||
@@ -104,8 +102,9 @@ class ProviderXinferenceSTT(STTProvider):
|
||||
f"Failed to download audio from {audio_url}, status: {resp.status}",
|
||||
)
|
||||
return ""
|
||||
elif await asyncio.to_thread(os.path.exists, audio_url):
|
||||
audio_bytes = await asyncio.to_thread(Path(audio_url).read_bytes)
|
||||
elif os.path.exists(audio_url):
|
||||
with open(audio_url, "rb") as f:
|
||||
audio_bytes = f.read()
|
||||
else:
|
||||
logger.error(f"File not found: {audio_url}")
|
||||
return ""
|
||||
@@ -144,7 +143,8 @@ class ProviderXinferenceSTT(STTProvider):
|
||||
)
|
||||
temp_files.extend([input_path, output_path])
|
||||
|
||||
await asyncio.to_thread(Path(input_path).write_bytes, audio_bytes)
|
||||
with open(input_path, "wb") as f:
|
||||
f.write(audio_bytes)
|
||||
|
||||
if conversion_type == "silk":
|
||||
logger.info("Converting silk to wav ...")
|
||||
@@ -153,7 +153,8 @@ class ProviderXinferenceSTT(STTProvider):
|
||||
logger.info("Converting amr to wav ...")
|
||||
await convert_to_pcm_wav(input_path, output_path)
|
||||
|
||||
audio_bytes = await asyncio.to_thread(Path(output_path).read_bytes)
|
||||
with open(output_path, "rb") as f:
|
||||
audio_bytes = f.read()
|
||||
|
||||
# 4. Transcribe
|
||||
# 官方asyncCLient的客户端似乎实现有点问题,这里直接用aiohttp实现openai标准兼容请求,提交issue等待官方修复后再改回来
|
||||
@@ -198,7 +199,7 @@ class ProviderXinferenceSTT(STTProvider):
|
||||
# 5. Cleanup
|
||||
for temp_file in temp_files:
|
||||
try:
|
||||
if await asyncio.to_thread(os.path.exists, temp_file):
|
||||
if os.path.exists(temp_file):
|
||||
os.remove(temp_file)
|
||||
logger.debug(f"Removed temporary file: {temp_file}")
|
||||
except Exception as e:
|
||||
|
||||
@@ -5,7 +5,7 @@ import json
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
@@ -19,7 +19,7 @@ _SKILL_NAME_RE = re.compile(r"[^a-zA-Z0-9._-]+")
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
return datetime.now(UTC).isoformat()
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _to_jsonable(model_like: Any) -> dict[str, Any]:
|
||||
|
||||
@@ -7,7 +7,7 @@ import shutil
|
||||
import tempfile
|
||||
import zipfile
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path, PurePosixPath
|
||||
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
@@ -175,7 +175,7 @@ class SkillManager:
|
||||
|
||||
def _save_sandbox_skills_cache(self, cache: dict) -> None:
|
||||
cache["version"] = _SANDBOX_SKILLS_CACHE_VERSION
|
||||
cache["updated_at"] = datetime.now(UTC).isoformat()
|
||||
cache["updated_at"] = datetime.now(timezone.utc).isoformat()
|
||||
with open(self.sandbox_skills_cache_path, "w", encoding="utf-8") as f:
|
||||
json.dump(cache, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import enum
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal, TypeVar, overload
|
||||
from typing import Any, Generic, Literal, TypeVar, overload
|
||||
|
||||
from .filter import HandlerFilter
|
||||
from .star import star_map
|
||||
@@ -11,7 +11,7 @@ from .star import star_map
|
||||
T = TypeVar("T", bound="StarHandlerMetadata")
|
||||
|
||||
|
||||
class StarHandlerRegistry[T: "StarHandlerMetadata"]:
|
||||
class StarHandlerRegistry(Generic[T]):
|
||||
def __init__(self) -> None:
|
||||
self.star_handlers_map: dict[str, StarHandlerMetadata] = {}
|
||||
self._handlers: list[StarHandlerMetadata] = []
|
||||
@@ -227,7 +227,7 @@ H = TypeVar("H", bound=Callable[..., Any])
|
||||
|
||||
|
||||
@dataclass
|
||||
class StarHandlerMetadata[H: Callable[..., Any]]:
|
||||
class StarHandlerMetadata(Generic[H]):
|
||||
"""描述一个 Star 所注册的某一个 Handler。"""
|
||||
|
||||
event_type: EventType
|
||||
|
||||
@@ -8,7 +8,6 @@ import logging
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
|
||||
import yaml
|
||||
@@ -189,7 +188,7 @@ class PluginManager:
|
||||
如果 target_plugin 为 None,则检查所有插件的依赖
|
||||
"""
|
||||
plugin_dir = self.plugin_store_path
|
||||
if not await asyncio.to_thread(os.path.exists, plugin_dir):
|
||||
if not os.path.exists(plugin_dir):
|
||||
return False
|
||||
to_update = []
|
||||
if target_plugin:
|
||||
@@ -199,9 +198,7 @@ class PluginManager:
|
||||
to_update.append(p.root_dir_name)
|
||||
for p in to_update:
|
||||
plugin_path = os.path.join(plugin_dir, p)
|
||||
if await asyncio.to_thread(
|
||||
os.path.exists, os.path.join(plugin_path, "requirements.txt")
|
||||
):
|
||||
if os.path.exists(os.path.join(plugin_path, "requirements.txt")):
|
||||
pth = os.path.join(plugin_path, "requirements.txt")
|
||||
logger.info(f"正在安装插件 {p} 所需的依赖库: {pth}")
|
||||
try:
|
||||
@@ -220,7 +217,7 @@ class PluginManager:
|
||||
try:
|
||||
return __import__(path, fromlist=[module_str])
|
||||
except (ModuleNotFoundError, ImportError) as import_exc:
|
||||
if await asyncio.to_thread(os.path.exists, requirements_path):
|
||||
if os.path.exists(requirements_path):
|
||||
try:
|
||||
logger.info(
|
||||
f"插件 {root_dir_name} 导入失败,尝试从已安装依赖恢复: {import_exc!s}"
|
||||
@@ -654,19 +651,16 @@ class PluginManager:
|
||||
plugin_dir_path,
|
||||
self.conf_schema_fname,
|
||||
)
|
||||
if await asyncio.to_thread(os.path.exists, plugin_schema_path):
|
||||
if os.path.exists(plugin_schema_path):
|
||||
# 加载插件配置
|
||||
plugin_schema_text = await asyncio.to_thread(
|
||||
Path(plugin_schema_path).read_text,
|
||||
encoding="utf-8",
|
||||
)
|
||||
plugin_config = AstrBotConfig(
|
||||
config_path=os.path.join(
|
||||
self.plugin_config_path,
|
||||
f"{root_dir_name}_config.json",
|
||||
),
|
||||
schema=json.loads(plugin_schema_text),
|
||||
)
|
||||
with open(plugin_schema_path, encoding="utf-8") as f:
|
||||
plugin_config = AstrBotConfig(
|
||||
config_path=os.path.join(
|
||||
self.plugin_config_path,
|
||||
f"{root_dir_name}_config.json",
|
||||
),
|
||||
schema=json.loads(f.read()),
|
||||
)
|
||||
logo_path = os.path.join(plugin_dir_path, self.logo_fname)
|
||||
|
||||
if path in star_map:
|
||||
@@ -842,7 +836,7 @@ class PluginManager:
|
||||
metadata.activated = False
|
||||
|
||||
# Plugin logo path
|
||||
if await asyncio.to_thread(os.path.exists, logo_path):
|
||||
if os.path.exists(logo_path):
|
||||
metadata.logo_path = logo_path
|
||||
|
||||
assert metadata.module_path, f"插件 {metadata.name} 模块路径为空"
|
||||
@@ -961,7 +955,7 @@ class PluginManager:
|
||||
except Exception:
|
||||
logger.warning(traceback.format_exc())
|
||||
|
||||
if await asyncio.to_thread(os.path.exists, plugin_path):
|
||||
if os.path.exists(plugin_path):
|
||||
try:
|
||||
remove_dir(plugin_path)
|
||||
logger.warning(f"已清理安装失败的插件目录: {plugin_path}")
|
||||
@@ -974,7 +968,7 @@ class PluginManager:
|
||||
self.plugin_config_path,
|
||||
f"{dir_name}_config.json",
|
||||
)
|
||||
if await asyncio.to_thread(os.path.exists, plugin_config_path):
|
||||
if os.path.exists(plugin_config_path):
|
||||
try:
|
||||
os.remove(plugin_config_path)
|
||||
logger.warning(f"已清理安装失败插件配置: {plugin_config_path}")
|
||||
@@ -1106,14 +1100,13 @@ class PluginManager:
|
||||
# Extract README.md content if exists
|
||||
readme_content = None
|
||||
readme_path = os.path.join(plugin_path, "README.md")
|
||||
if not await asyncio.to_thread(os.path.exists, readme_path):
|
||||
if not os.path.exists(readme_path):
|
||||
readme_path = os.path.join(plugin_path, "readme.md")
|
||||
|
||||
if await asyncio.to_thread(os.path.exists, readme_path):
|
||||
if os.path.exists(readme_path):
|
||||
try:
|
||||
readme_content = await asyncio.to_thread(
|
||||
Path(readme_path).read_text, encoding="utf-8"
|
||||
)
|
||||
with open(readme_path, encoding="utf-8") as f:
|
||||
readme_content = f.read()
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"读取插件 {dir_name} 的 README.md 文件失败: {e!s}",
|
||||
@@ -1218,7 +1211,7 @@ class PluginManager:
|
||||
self._cleanup_plugin_state(dir_name)
|
||||
|
||||
plugin_path = os.path.join(self.plugin_store_path, dir_name)
|
||||
if await asyncio.to_thread(os.path.exists, plugin_path):
|
||||
if os.path.exists(plugin_path):
|
||||
try:
|
||||
remove_dir(plugin_path)
|
||||
except Exception as e:
|
||||
@@ -1505,14 +1498,13 @@ class PluginManager:
|
||||
# Extract README.md content if exists
|
||||
readme_content = None
|
||||
readme_path = os.path.join(desti_dir, "README.md")
|
||||
if not await asyncio.to_thread(os.path.exists, readme_path):
|
||||
if not os.path.exists(readme_path):
|
||||
readme_path = os.path.join(desti_dir, "readme.md")
|
||||
|
||||
if await asyncio.to_thread(os.path.exists, readme_path):
|
||||
if os.path.exists(readme_path):
|
||||
try:
|
||||
readme_content = await asyncio.to_thread(
|
||||
Path(readme_path).read_text, encoding="utf-8"
|
||||
)
|
||||
with open(readme_path, encoding="utf-8") as f:
|
||||
readme_content = f.read()
|
||||
except Exception as e:
|
||||
logger.warning(f"读取插件 {dir_name} 的 README.md 文件失败: {e!s}")
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from datetime import UTC, datetime
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
def normalize_datetime_utc(dt: datetime | None) -> datetime | None:
|
||||
@@ -9,8 +9,8 @@ def normalize_datetime_utc(dt: datetime | None) -> datetime | None:
|
||||
if dt is None:
|
||||
return None
|
||||
if dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None:
|
||||
return dt.replace(tzinfo=UTC)
|
||||
return dt.astimezone(UTC)
|
||||
return dt.replace(tzinfo=timezone.utc)
|
||||
return dt.astimezone(timezone.utc)
|
||||
|
||||
|
||||
def to_utc_isoformat(dt: datetime | None) -> str | None:
|
||||
|
||||
+57
-96
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
@@ -9,7 +8,6 @@ import time
|
||||
import uuid
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from typing import BinaryIO
|
||||
|
||||
import aiohttp
|
||||
import certifi
|
||||
@@ -19,8 +17,6 @@ from PIL import Image
|
||||
from .astrbot_path import get_astrbot_data_path, get_astrbot_path, get_astrbot_temp_path
|
||||
|
||||
logger = logging.getLogger("astrbot")
|
||||
_DOWNLOAD_READ_CHUNK_SIZE = 64 * 1024
|
||||
_DOWNLOAD_FLUSH_THRESHOLD = 256 * 1024
|
||||
|
||||
|
||||
def on_error(func, path, exc_info) -> None:
|
||||
@@ -62,7 +58,8 @@ def save_temp_img(img: Image.Image | bytes) -> str:
|
||||
if isinstance(img, Image.Image):
|
||||
img.save(p)
|
||||
else:
|
||||
Path(p).write_bytes(img)
|
||||
with open(p, "wb") as f:
|
||||
f.write(img)
|
||||
return p
|
||||
|
||||
|
||||
@@ -86,13 +83,15 @@ async def download_image_by_url(
|
||||
async with session.post(url, json=post_data) as resp:
|
||||
if not path:
|
||||
return save_temp_img(await resp.read())
|
||||
await asyncio.to_thread(Path(path).write_bytes, await resp.read())
|
||||
with open(path, "wb") as f:
|
||||
f.write(await resp.read())
|
||||
return path
|
||||
else:
|
||||
async with session.get(url) as resp:
|
||||
if not path:
|
||||
return save_temp_img(await resp.read())
|
||||
await asyncio.to_thread(Path(path).write_bytes, await resp.read())
|
||||
with open(path, "wb") as f:
|
||||
f.write(await resp.read())
|
||||
return path
|
||||
except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
|
||||
# 关闭SSL验证(仅在证书验证失败时作为fallback)
|
||||
@@ -110,13 +109,15 @@ async def download_image_by_url(
|
||||
async with session.post(url, json=post_data, ssl=ssl_context) as resp:
|
||||
if not path:
|
||||
return save_temp_img(await resp.read())
|
||||
await asyncio.to_thread(Path(path).write_bytes, await resp.read())
|
||||
with open(path, "wb") as f:
|
||||
f.write(await resp.read())
|
||||
return path
|
||||
else:
|
||||
async with session.get(url, ssl=ssl_context) as resp:
|
||||
if not path:
|
||||
return save_temp_img(await resp.read())
|
||||
await asyncio.to_thread(Path(path).write_bytes, await resp.read())
|
||||
with open(path, "wb") as f:
|
||||
f.write(await resp.read())
|
||||
return path
|
||||
except Exception as e:
|
||||
raise e
|
||||
@@ -137,20 +138,28 @@ async def download_file(url: str, path: str, show_progress: bool = False) -> Non
|
||||
if resp.status != 200:
|
||||
raise Exception(f"下载文件失败: {resp.status}")
|
||||
total_size = int(resp.headers.get("content-length", 0))
|
||||
downloaded_size = 0
|
||||
start_time = time.time()
|
||||
if show_progress:
|
||||
print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}")
|
||||
file_obj = await asyncio.to_thread(Path(path).open, "wb")
|
||||
try:
|
||||
await _stream_to_file(
|
||||
resp.content,
|
||||
file_obj,
|
||||
total_size=total_size,
|
||||
start_time=start_time,
|
||||
show_progress=show_progress,
|
||||
)
|
||||
finally:
|
||||
await asyncio.to_thread(file_obj.close)
|
||||
with open(path, "wb") as f:
|
||||
while True:
|
||||
chunk = await resp.content.read(8192)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
downloaded_size += len(chunk)
|
||||
if show_progress:
|
||||
elapsed_time = (
|
||||
time.time() - start_time
|
||||
if time.time() - start_time > 0
|
||||
else 1
|
||||
)
|
||||
speed = downloaded_size / 1024 / elapsed_time # KB/s
|
||||
print(
|
||||
f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s",
|
||||
end="",
|
||||
)
|
||||
except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
|
||||
# 关闭SSL验证(仅在证书验证失败时作为fallback)
|
||||
logger.warning(
|
||||
@@ -168,76 +177,32 @@ async def download_file(url: str, path: str, show_progress: bool = False) -> Non
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, ssl=ssl_context, timeout=120) as resp:
|
||||
total_size = int(resp.headers.get("content-length", 0))
|
||||
downloaded_size = 0
|
||||
start_time = time.time()
|
||||
if show_progress:
|
||||
print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}")
|
||||
file_obj = await asyncio.to_thread(Path(path).open, "wb")
|
||||
try:
|
||||
await _stream_to_file(
|
||||
resp.content,
|
||||
file_obj,
|
||||
total_size=total_size,
|
||||
start_time=start_time,
|
||||
show_progress=show_progress,
|
||||
)
|
||||
finally:
|
||||
await asyncio.to_thread(file_obj.close)
|
||||
with open(path, "wb") as f:
|
||||
while True:
|
||||
chunk = await resp.content.read(8192)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
downloaded_size += len(chunk)
|
||||
if show_progress:
|
||||
elapsed_time = time.time() - start_time
|
||||
speed = downloaded_size / 1024 / elapsed_time # KB/s
|
||||
print(
|
||||
f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s",
|
||||
end="",
|
||||
)
|
||||
if show_progress:
|
||||
print()
|
||||
|
||||
|
||||
async def _stream_to_file(
|
||||
stream: aiohttp.StreamReader,
|
||||
file_obj: BinaryIO,
|
||||
*,
|
||||
total_size: int,
|
||||
start_time: float,
|
||||
show_progress: bool,
|
||||
) -> None:
|
||||
"""Stream HTTP response into file with buffered thread-offloaded writes."""
|
||||
downloaded_size = 0
|
||||
known_total = total_size if total_size > 0 else None
|
||||
buffered = bytearray()
|
||||
|
||||
try:
|
||||
while True:
|
||||
chunk = await stream.read(_DOWNLOAD_READ_CHUNK_SIZE)
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
buffered.extend(chunk)
|
||||
downloaded_size += len(chunk)
|
||||
|
||||
if len(buffered) >= _DOWNLOAD_FLUSH_THRESHOLD:
|
||||
await asyncio.to_thread(file_obj.write, bytes(buffered))
|
||||
buffered.clear()
|
||||
|
||||
if show_progress:
|
||||
_print_download_progress(downloaded_size, known_total, start_time)
|
||||
finally:
|
||||
if buffered:
|
||||
# Ensure buffered data is flushed even on cancellation.
|
||||
await asyncio.shield(asyncio.to_thread(file_obj.write, bytes(buffered)))
|
||||
|
||||
|
||||
def _print_download_progress(
|
||||
downloaded_size: int, total_size: int | None, start_time: float
|
||||
) -> None:
|
||||
elapsed_time = max(time.time() - start_time, 1e-6)
|
||||
speed = downloaded_size / 1024 / elapsed_time # KB/s
|
||||
|
||||
if total_size:
|
||||
percent = downloaded_size / total_size
|
||||
msg = f"\r下载进度: {percent:.2%} 速度: {speed:.2f} KB/s"
|
||||
else:
|
||||
msg = f"\r已下载: {downloaded_size} 字节 速度: {speed:.2f} KB/s"
|
||||
|
||||
print(msg, end="")
|
||||
|
||||
|
||||
async def file_to_base64(file_path: str) -> str:
|
||||
data_bytes = await asyncio.to_thread(Path(file_path).read_bytes)
|
||||
base64_str = base64.b64encode(data_bytes).decode()
|
||||
def file_to_base64(file_path: str) -> str:
|
||||
with open(file_path, "rb") as f:
|
||||
data_bytes = f.read()
|
||||
base64_str = base64.b64encode(data_bytes).decode()
|
||||
return "base64://" + base64_str
|
||||
|
||||
|
||||
@@ -256,18 +221,17 @@ def get_local_ip_addresses():
|
||||
async def get_dashboard_version():
|
||||
# First check user data directory (manually updated / downloaded dashboard).
|
||||
dist_dir = os.path.join(get_astrbot_data_path(), "dist")
|
||||
if not await asyncio.to_thread(os.path.exists, dist_dir):
|
||||
if not os.path.exists(dist_dir):
|
||||
# Fall back to the dist bundled inside the installed wheel.
|
||||
_bundled = Path(get_astrbot_path()) / "astrbot" / "dashboard" / "dist"
|
||||
if await asyncio.to_thread(_bundled.exists):
|
||||
if _bundled.exists():
|
||||
dist_dir = str(_bundled)
|
||||
if await asyncio.to_thread(os.path.exists, dist_dir):
|
||||
if os.path.exists(dist_dir):
|
||||
version_file = os.path.join(dist_dir, "assets", "version")
|
||||
if await asyncio.to_thread(os.path.exists, version_file):
|
||||
v = (
|
||||
await asyncio.to_thread(Path(version_file).read_text, encoding="utf-8")
|
||||
).strip()
|
||||
return v
|
||||
if os.path.exists(version_file):
|
||||
with open(version_file, encoding="utf-8") as f:
|
||||
v = f.read().strip()
|
||||
return v
|
||||
return None
|
||||
|
||||
|
||||
@@ -280,12 +244,9 @@ async def download_dashboard(
|
||||
) -> None:
|
||||
"""下载管理面板文件"""
|
||||
if path is None:
|
||||
zip_path = (
|
||||
await asyncio.to_thread(Path(get_astrbot_data_path()).absolute)
|
||||
/ "dashboard.zip"
|
||||
)
|
||||
zip_path = Path(get_astrbot_data_path()).absolute() / "dashboard.zip"
|
||||
else:
|
||||
zip_path = await asyncio.to_thread(Path(path).absolute)
|
||||
zip_path = Path(path).absolute()
|
||||
|
||||
if latest or len(str(version)) != 40:
|
||||
ver_name = "latest" if latest else version
|
||||
|
||||
@@ -108,7 +108,7 @@ async def convert_audio_to_opus(audio_path: str, output_path: str | None = None)
|
||||
|
||||
if process.returncode != 0:
|
||||
# 清理可能已生成但无效的临时文件
|
||||
if output_path and await asyncio.to_thread(os.path.exists, output_path):
|
||||
if output_path and os.path.exists(output_path):
|
||||
try:
|
||||
os.remove(output_path)
|
||||
logger.debug(
|
||||
@@ -183,7 +183,7 @@ async def convert_video_format(
|
||||
|
||||
if process.returncode != 0:
|
||||
# 清理可能已生成但无效的临时文件
|
||||
if output_path and await asyncio.to_thread(os.path.exists, output_path):
|
||||
if output_path and os.path.exists(output_path):
|
||||
try:
|
||||
os.remove(output_path)
|
||||
logger.debug(
|
||||
@@ -231,7 +231,7 @@ async def convert_audio_format(
|
||||
|
||||
if output_path is None:
|
||||
temp_dir = Path(get_astrbot_temp_path())
|
||||
await asyncio.to_thread(temp_dir.mkdir, parents=True, exist_ok=True)
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_path = str(temp_dir / f"media_audio_{uuid.uuid4().hex}.{output_format}")
|
||||
|
||||
args = ["ffmpeg", "-y", "-i", audio_path]
|
||||
@@ -249,7 +249,7 @@ async def convert_audio_format(
|
||||
)
|
||||
_, stderr = await process.communicate()
|
||||
if process.returncode != 0:
|
||||
if output_path and await asyncio.to_thread(os.path.exists, output_path):
|
||||
if output_path and os.path.exists(output_path):
|
||||
try:
|
||||
os.remove(output_path)
|
||||
except OSError as e:
|
||||
@@ -287,7 +287,7 @@ async def extract_video_cover(
|
||||
"""从视频中提取封面图(JPG)。"""
|
||||
if output_path is None:
|
||||
temp_dir = Path(get_astrbot_temp_path())
|
||||
await asyncio.to_thread(temp_dir.mkdir, parents=True, exist_ok=True)
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_path = str(temp_dir / f"media_cover_{uuid.uuid4().hex}.jpg")
|
||||
|
||||
try:
|
||||
@@ -306,7 +306,7 @@ async def extract_video_cover(
|
||||
)
|
||||
_, stderr = await process.communicate()
|
||||
if process.returncode != 0:
|
||||
if output_path and await asyncio.to_thread(os.path.exists, output_path):
|
||||
if output_path and os.path.exists(output_path):
|
||||
try:
|
||||
os.remove(output_path)
|
||||
except OSError as e:
|
||||
|
||||
@@ -71,11 +71,11 @@ class SessionController:
|
||||
|
||||
asyncio.create_task(self._holding(new_event, timeout)) # 开始新的 keep
|
||||
|
||||
async def _holding(self, event: asyncio.Event, timeout_seconds: float) -> None:
|
||||
async def _holding(self, event: asyncio.Event, timeout: float) -> None:
|
||||
"""等待事件结束或超时"""
|
||||
try:
|
||||
await asyncio.wait_for(event.wait(), timeout_seconds)
|
||||
except TimeoutError:
|
||||
await asyncio.wait_for(event.wait(), timeout)
|
||||
except asyncio.TimeoutError:
|
||||
if not self.future.done():
|
||||
self.future.set_exception(TimeoutError("等待超时"))
|
||||
except asyncio.CancelledError:
|
||||
@@ -124,14 +124,14 @@ class SessionWaiter:
|
||||
async def register_wait(
|
||||
self,
|
||||
handler: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]],
|
||||
timeout_seconds: int = 30,
|
||||
timeout: int = 30,
|
||||
) -> Any:
|
||||
"""等待外部输入并处理"""
|
||||
self.handler = handler
|
||||
USER_SESSIONS[self.session_id] = self
|
||||
|
||||
# 开始一个会话保持事件
|
||||
self.session_controller.keep(timeout_seconds, reset_timeout=True)
|
||||
self.session_controller.keep(timeout, reset_timeout=True)
|
||||
|
||||
try:
|
||||
return await self.session_controller.future
|
||||
|
||||
@@ -141,7 +141,7 @@ class TempDirCleaner:
|
||||
self._stop_event.wait(),
|
||||
timeout=self.CHECK_INTERVAL_SECONDS,
|
||||
)
|
||||
except TimeoutError:
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
logger.info("TempDirCleaner stopped.")
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user