Compare commits

..

1 Commits

Author SHA1 Message Date
Soulter 55e1431084 docs: update sponsors 2026-03-03 19:07:38 +08:00
451 changed files with 2540 additions and 38938 deletions
+5 -12
View File
@@ -3,8 +3,8 @@
### Modifications / 改动点 ### Modifications / 改动点
<!--Please summarize your changes: What core files were modified? What functionality was implemented?-->
<!--请总结你的改动:哪些核心文件被修改了?实现了什么功能?--> <!--请总结你的改动:哪些核心文件被修改了?实现了什么功能?-->
<!--Please summarize your changes: What core files were modified? What functionality was implemented?-->
- [x] This is NOT a breaking change. / 这不是一个破坏性变更。 - [x] This is NOT a breaking change. / 这不是一个破坏性变更。
<!-- If your changes is a breaking change, please uncheck the checkbox above --> <!-- If your changes is a breaking change, please uncheck the checkbox above -->
@@ -21,14 +21,7 @@
<!--If merged, your code will serve tens of thousands of users! Please double-check the following items before submitting.--> <!--If merged, your code will serve tens of thousands of users! Please double-check the following items before submitting.-->
<!--如果分支被合并,您的代码将服务于数万名用户!在提交前,请核查一下几点内容。--> <!--如果分支被合并,您的代码将服务于数万名用户!在提交前,请核查一下几点内容。-->
- [ ] 😊 If there are new features added in the PR, I have discussed it with the authors through issues/emails, etc. - [ ] 😊 如果 PR 中有新加入的功能,已经通过 Issue / 邮件等方式和作者讨论过。/ If there are new features added in the PR, I have discussed it with the authors through issues/emails, etc.
/ 如果 PR 中有新加入的功能,已经通过 Issue / 邮件等方式和作者讨论过。 - [ ] 👀 我的更改经过了良好的测试,**并已在上方提供了“验证步骤”和“运行截图”**。/ My changes have been well-tested, **and "Verification Steps" and "Screenshots" have been provided above**.
- [ ] 🤓 我确保没有引入新依赖库,或者引入了新依赖库的同时将其添加到了 `requirements.txt``pyproject.toml` 文件相应位置。/ I have ensured that no new dependencies are introduced, OR if new dependencies are introduced, they have been added to the appropriate locations in `requirements.txt` and `pyproject.toml`.
- [ ] 👀 My changes have been well-tested, **and "Verification Steps" and "Screenshots" have been provided above**. - [ ] 😮 我的更改没有引入恶意代码。/ My changes do not introduce malicious code.
/ 我的更改经过了良好的测试,**并已在上方提供了“验证步骤”和“运行截图”**。
- [ ] 🤓 I have ensured that no new dependencies are introduced, OR if new dependencies are introduced, they have been added to the appropriate locations in `requirements.txt` and `pyproject.toml`.
/ 我确保没有引入新依赖库,或者引入了新依赖库的同时将其添加到 `requirements.txt``pyproject.toml` 文件相应位置。
- [ ] 😮 My changes do not introduce malicious code.
/ 我的更改没有引入恶意代码。
-43
View File
@@ -1,43 +0,0 @@
name: release
on:
push:
tags:
- 'v*'
workflow_dispatch:
jobs:
build:
runs-on: ubuntu-latest # 运行环境
steps:
- name: checkout
uses: actions/checkout@v6
- name: nodejs installation
uses: actions/setup-node@v6
with:
node-version: "18"
- name: npm install
run: npm add -D vitepress
working-directory: './docs' # working-directory 指定 shell 命令运行目录
- name: npm run build
run: npm run docs:build
working-directory: './docs'
- name: scp
uses: appleboy/scp-action@v1.0.0
with:
host: ${{ secrets.HOST_NEKO }}
username: ${{ secrets.USERNAME }}
password: ${{ secrets.PASSWORDNEKO }}
source: 'docs/.vitepress/dist/*'
target: '/tmp/'
- name: script
uses: appleboy/ssh-action@v1.2.5
with:
host: ${{ secrets.HOST_NEKO }}
username: ${{ secrets.USERNAME }}
password: ${{ secrets.PASSWORDNEKO }}
script: |
mkdir -p /root/docker_data/caddy/caddy_data/static_site/abv4/
rm -rf /root/docker_data/caddy/caddy_data/static_site/abv4/*
mv /tmp/docs/.vitepress/dist/* /root/docker_data/caddy/caddy_data/static_site/abv4/
rm -rf /tmp/docs/
+1 -1
View File
@@ -45,7 +45,7 @@ jobs:
- name: Create GitHub Release - name: Create GitHub Release
if: github.event_name == 'push' if: github.event_name == 'push'
uses: ncipollo/release-action@v1.21.0 uses: ncipollo/release-action@v1
with: with:
tag: release-${{ github.sha }} tag: release-${{ github.sha }}
owner: AstrBotDevs owner: AstrBotDevs
+10 -10
View File
@@ -64,20 +64,20 @@ jobs:
echo "build_date=$build_date" >> $GITHUB_OUTPUT echo "build_date=$build_date" >> $GITHUB_OUTPUT
- name: Set QEMU - name: Set QEMU
uses: docker/setup-qemu-action@v4.0.0 uses: docker/setup-qemu-action@v3
- name: Set Docker Buildx - name: Set Docker Buildx
uses: docker/setup-buildx-action@v4.0.0 uses: docker/setup-buildx-action@v3
- name: Log in to DockerHub - name: Log in to DockerHub
uses: docker/login-action@v4.0.0 uses: docker/login-action@v3
with: with:
username: ${{ secrets.DOCKER_HUB_USERNAME }} username: ${{ secrets.DOCKER_HUB_USERNAME }}
password: ${{ secrets.DOCKER_HUB_PASSWORD }} password: ${{ secrets.DOCKER_HUB_PASSWORD }}
- name: Login to GitHub Container Registry - name: Login to GitHub Container Registry
if: env.HAS_GHCR_TOKEN == 'true' if: env.HAS_GHCR_TOKEN == 'true'
uses: docker/login-action@v4.0.0 uses: docker/login-action@v3
with: with:
registry: ghcr.io registry: ghcr.io
username: ${{ env.GHCR_OWNER }} username: ${{ env.GHCR_OWNER }}
@@ -98,7 +98,7 @@ jobs:
echo "EOF" >> $GITHUB_OUTPUT echo "EOF" >> $GITHUB_OUTPUT
- name: Build and Push Nightly Image - name: Build and Push Nightly Image
uses: docker/build-push-action@v7.0.0 uses: docker/build-push-action@v6
with: with:
context: . context: .
platforms: linux/amd64,linux/arm64 platforms: linux/amd64,linux/arm64
@@ -163,27 +163,27 @@ jobs:
cp -r dashboard/dist data/ cp -r dashboard/dist data/
- name: Set QEMU - name: Set QEMU
uses: docker/setup-qemu-action@v4.0.0 uses: docker/setup-qemu-action@v3
- name: Set Docker Buildx - name: Set Docker Buildx
uses: docker/setup-buildx-action@v4.0.0 uses: docker/setup-buildx-action@v3
- name: Log in to DockerHub - name: Log in to DockerHub
uses: docker/login-action@v4.0.0 uses: docker/login-action@v3
with: with:
username: ${{ secrets.DOCKER_HUB_USERNAME }} username: ${{ secrets.DOCKER_HUB_USERNAME }}
password: ${{ secrets.DOCKER_HUB_PASSWORD }} password: ${{ secrets.DOCKER_HUB_PASSWORD }}
- name: Login to GitHub Container Registry - name: Login to GitHub Container Registry
if: env.HAS_GHCR_TOKEN == 'true' if: env.HAS_GHCR_TOKEN == 'true'
uses: docker/login-action@v4.0.0 uses: docker/login-action@v3
with: with:
registry: ghcr.io registry: ghcr.io
username: ${{ env.GHCR_OWNER }} username: ${{ env.GHCR_OWNER }}
password: ${{ secrets.GHCR_GITHUB_TOKEN }} password: ${{ secrets.GHCR_GITHUB_TOKEN }}
- name: Build and Push Release Image - name: Build and Push Release Image
uses: docker/build-push-action@v7.0.0 uses: docker/build-push-action@v6
with: with:
context: . context: .
platforms: linux/amd64,linux/arm64 platforms: linux/amd64,linux/arm64
-53
View File
@@ -1,53 +0,0 @@
name: PR Title Check
on:
pull_request_target:
types: [opened, edited, reopened, synchronize]
jobs:
title-format:
runs-on: ubuntu-latest
permissions:
pull-requests: write
issues: write
steps:
- name: Validate PR title
uses: actions/github-script@v8
with:
script: |
const title = (context.payload.pull_request.title || "").trim();
// allow only:
// feat: xxx
// feat(scope): xxx
const pattern = /^(feat)(\([a-z0-9-]+\))?:\s.+$/i;
const isValid = pattern.test(title);
const isSameRepo =
context.payload.pull_request.head.repo.full_name === context.payload.repository.full_name;
if (!isValid) {
if (isSameRepo) {
try {
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.payload.pull_request.number,
body: [
"⚠️ PR title format check failed.",
"Required formats:",
"- `feat: xxx`",
"- `feat(scope): xxx`",
"Please update your PR title and push again."
].join("\n")
});
} catch (e) {
core.warning(`Failed to post PR title comment: ${e.message}`);
}
} else {
core.warning("Fork PR: comment permission is restricted; skip posting review comment.");
}
}
if (!isValid) {
core.setFailed("Invalid PR title. Expected format: feat: xxx or feat(scope): xxx.");
}
+2 -35
View File
@@ -50,7 +50,7 @@ jobs:
echo "tag=$tag" >> "$GITHUB_OUTPUT" echo "tag=$tag" >> "$GITHUB_OUTPUT"
- name: Setup pnpm - name: Setup pnpm
uses: pnpm/action-setup@v4.4.0 uses: pnpm/action-setup@v4
with: with:
version: 10.28.2 version: 10.28.2
@@ -184,8 +184,7 @@ jobs:
publish-pypi: publish-pypi:
name: Publish PyPI name: Publish PyPI
runs-on: ubuntu-24.04 runs-on: ubuntu-24.04
needs: needs: publish-release
- publish-release
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v6 uses: actions/checkout@v6
@@ -193,36 +192,6 @@ jobs:
fetch-depth: 0 fetch-depth: 0
ref: ${{ inputs.ref || github.ref }} ref: ${{ inputs.ref || github.ref }}
- name: Resolve tag
id: tag
shell: bash
run: |
if [ "${{ github.event_name }}" = "push" ]; then
tag="${GITHUB_REF_NAME}"
elif [ -n "${{ inputs.tag }}" ]; then
tag="${{ inputs.tag }}"
else
tag="$(git describe --tags --abbrev=0)"
fi
if [ -z "$tag" ]; then
echo "Failed to resolve tag." >&2
exit 1
fi
echo "tag=$tag" >> "$GITHUB_OUTPUT"
- name: Download dashboard artifact
uses: actions/download-artifact@v8
with:
name: Dashboard-${{ steps.tag.outputs.tag }}
path: dashboard-artifact
- name: Unpack dashboard dist into package tree
shell: bash
run: |
mkdir -p astrbot/dashboard/dist
unzip -q "dashboard-artifact/AstrBot-${{ steps.tag.outputs.tag }}-dashboard.zip" -d dashboard-artifact/unpacked
cp -r dashboard-artifact/unpacked/dist/. astrbot/dashboard/dist/
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v6 uses: actions/setup-python@v6
with: with:
@@ -234,8 +203,6 @@ jobs:
- name: Build package - name: Build package
shell: bash shell: bash
# Dashboard assets are already in astrbot/dashboard/dist/;
# ASTRBOT_BUILD_DASHBOARD is intentionally unset so the hatch hook skips npm.
run: uv build run: uv build
- name: Publish to PyPI - name: Publish to PyPI
-68
View File
@@ -1,68 +0,0 @@
name: sync wiki
on:
workflow_dispatch:
push:
branches:
- master
paths:
- '.github/workflows/sync-wiki.yml'
- 'docs/scripts/sync_docs_to_wiki.py'
- 'docs/tests/test_sync_docs_to_wiki.py'
- 'docs/zh/**'
- 'docs/en/**'
concurrency:
group: sync-wiki-${{ github.ref }}
cancel-in-progress: true
jobs:
sync:
runs-on: ubuntu-latest
permissions:
contents: read
steps:
- name: Validate manual ref
if: github.event_name == 'workflow_dispatch' && github.ref != 'refs/heads/master'
run: |
echo "This workflow only publishes from refs/heads/master. Re-run it from the master branch."
exit 1
- name: Check out docs repository
uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: '3.11'
- name: Run sync unit tests
working-directory: docs
run: python -m unittest discover -s tests -p 'test_sync_docs_to_wiki.py' -v
- name: Validate internal doc links
run: python docs/scripts/sync_docs_to_wiki.py --source-root docs --check-links-only
- name: Clone AstrBot wiki
env:
WIKI_TOKEN: ${{ secrets.ASTRBOT_WIKI_TOKEN }}
run: |
test -n "$WIKI_TOKEN"
git clone "https://x-access-token:${WIKI_TOKEN}@github.com/AstrBotDevs/AstrBot.wiki.git" wiki
- name: Generate wiki pages
run: python docs/scripts/sync_docs_to_wiki.py --source-root docs --wiki-root wiki
- name: Commit and push wiki changes
working-directory: wiki
run: |
git config user.name "github-actions[bot]"
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
git add .
if git diff --cached --quiet; then
echo "No wiki changes to push"
exit 0
fi
git commit -m "docs: sync wiki from AstrBot-1/docs"
git push
-2
View File
@@ -61,5 +61,3 @@ GenieData/
.codex/ .codex/
.opencode/ .opencode/
.kilocode/ .kilocode/
.worktrees/
+17 -34
View File
@@ -73,68 +73,57 @@ AstrBot is an open-source all-in-one Agent chatbot platform that integrates with
### One-Click Deployment ### One-Click Deployment
For users who want to quickly experience AstrBot, are familiar with command-line usage, and can install a `uv` environment on their own, we recommend the `uv` one-click deployment method ⚡️: For users who want to quickly experience AstrBot, we recommend using the one-click deployment method with `uv` ⚡️:
```bash ```bash
uv tool install astrbot uv tool install astrbot
astrbot init # Only execute this command for the first time to initialize the environment astrbot init # Only execute this command for the first time to initialize the environment
astrbot run astrbot
``` ```
> Requires [uv](https://docs.astral.sh/uv/) to be installed. > Requires [uv](https://docs.astral.sh/uv/) to be installed.
> [!NOTE]
> For macOS user: due to macOS security checks, the first run of the `astrbot` command may take longer (about 10-20s).
Update `astrbot`:
```bash
uv tool upgrade astrbot
```
### Docker Deployment ### Docker Deployment
For users familiar with containers and looking for a more stable, production-ready deployment method, we recommend deploying AstrBot with Docker / Docker Compose. For users who want a more stable and production-ready deployment, we recommend using Docker / Docker Compose to deploy AstrBot.
Please refer to the official documentation: [Deploy AstrBot with Docker](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot). Please refer to the official documentation: [Deploy AstrBot with Docker](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot).
### Deploy on RainYun ### Deploy on RainYun
For users who want one-click deployment and do not want to manage servers themselves, we recommend RainYun's one-click cloud deployment service ☁️: For users who want to deploy AstrBot with one-click and don't want to manage the server, we recommend using RainYun's one-click cloud deployment service ☁️:
[![Deploy on RainYun](https://rainyun-apps.cn-nb1.rains3.com/materials/deploy-on-rainyun-en.svg)](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0) [![Deploy on RainYun](https://rainyun-apps.cn-nb1.rains3.com/materials/deploy-on-rainyun-en.svg)](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
### Desktop Application Deployment ### Desktop Application (Tauri)
For users who want to use AstrBot on desktop and mainly use ChatUI, we recommend AstrBot App. For users who want to deploy AstrBot on their desktop, primarily using AstrBot ChatUI, rarely use AstrBot plugins, we recommend using the AstrBot App:
Visit [AstrBot-desktop](https://github.com/AstrBotDevs/AstrBot-desktop) to download and install; this method is designed for desktop usage and is not recommended for server scenarios. Desktop repository: [AstrBot-desktop](https://github.com/AstrBotDevs/AstrBot-desktop).
### Launcher Deployment Supports multiple system architectures, direct package installation, and out-of-the-box usage. A convenient one-click desktop deployment option for beginners.
For desktop users who also want fast deployment and isolated multi-instance usage, we recommend AstrBot Launcher. ### One-Click Launcher Deployment (AstrBot Launcher)
Visit [AstrBot Launcher](https://github.com/Raven95676/astrbot-launcher) to download and install. For users who want a quick deployment and multi-instance solution with environment isolation, we recommend using the AstrBot Launcher:
Visit the [AstrBot Launcher](https://github.com/Raven95676/astrbot-launcher) repository and install the package for your OS from the latest release.
A quick deployment and multi-instance solution with environment isolation.
### Deploy on Replit ### Deploy on Replit
Replit deployment is maintained by the community and is suitable for online demos and lightweight trials. Community-contributed deployment method.
[![Run on Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot) [![Run on Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot)
### AUR ### AUR
AUR deployment targets Arch Linux users who prefer installing AstrBot through the system package workflow.
Run the command below to install `astrbot-git`, then start AstrBot in your local environment.
```bash ```bash
yay -S astrbot-git yay -S astrbot-git
``` ```
**More deployment methods** **More deployment methods**: [BT-Panel Deployment](https://astrbot.app/deploy/astrbot/btpanel.html) | [1Panel Deployment](https://astrbot.app/deploy/astrbot/1panel.html) | [CasaOS Deployment](https://astrbot.app/deploy/astrbot/casaos.html) | [Manual Deployment](https://astrbot.app/deploy/astrbot/cli.html)
If you need panel-based management or deeper customization, see [BT-Panel Deployment](https://astrbot.app/deploy/astrbot/btpanel.html) for BT Panel app-store setup, [1Panel Deployment](https://astrbot.app/deploy/astrbot/1panel.html) for 1Panel app-market deployment, [CasaOS Deployment](https://astrbot.app/deploy/astrbot/casaos.html) for NAS/home-server visual deployment, and [Manual Deployment](https://astrbot.app/deploy/astrbot/cli.html) for fully custom source-based installation with `uv`.
## Supported Messaging Platforms ## Supported Messaging Platforms
@@ -201,7 +190,6 @@ Connect AstrBot to your favorite chat platform.
<img alt="sponsors" src="https://sponsors.astrbot.app/?v=1"> <img alt="sponsors" src="https://sponsors.astrbot.app/?v=1">
</p> </p>
## ❤️ Contributing ## ❤️ Contributing
Issues and Pull Requests are always welcome! Feel free to submit your changes to this project :) Issues and Pull Requests are always welcome! Feel free to submit your changes to this project :)
@@ -220,22 +208,17 @@ pip install pre-commit
pre-commit install pre-commit install
``` ```
## 🌍 Community ## 🌍 Community
### QQ Groups ### QQ Groups
- Group 9: 1076659624 (New)
- Group 10: 1078079676 (New)
- Group 1: 322154837 - Group 1: 322154837
- Group 3: 630166526 - Group 3: 630166526
- Group 5: 822130018 - Group 5: 822130018
- Group 6: 753075035 - Group 6: 753075035
- Group 7: 743746109 - Group 7: 743746109
- Group 8: 1030353265 - Group 8: 1030353265
- Developer Group: 975206796
- Developer Group(Chit-chat): 975206796
- Developer Group(Formal): 1039761811
### Discord Server ### Discord Server
+17 -29
View File
@@ -73,68 +73,57 @@ AstrBot est une plateforme de chatbot Agent tout-en-un open source qui s'intègr
### Déploiement en un clic ### Déploiement en un clic
Pour les utilisateurs qui veulent découvrir AstrBot rapidement, qui sont familiers avec la ligne de commande et peuvent installer eux-mêmes l'environnement `uv`, nous recommandons la méthode de déploiement en un clic avec `uv` ⚡️ : Pour les utilisateurs qui souhaitent découvrir AstrBot rapidement, nous recommandons la méthode de déploiement en un clic avec `uv` ⚡️ :
```bash ```bash
uv tool install astrbot uv tool install astrbot
astrbot init # Exécutez cette commande uniquement la première fois pour initialiser l'environnement astrbot init # Exécutez cette commande uniquement la première fois pour initialiser l'environnement
astrbot run astrbot
``` ```
> [uv](https://docs.astral.sh/uv/) doit être installé. > [uv](https://docs.astral.sh/uv/) doit être installé.
> [!NOTE]
> Pour les utilisateurs macOS : en raison des vérifications de sécurité de macOS, la première exécution de la commande `astrbot` peut prendre plus de temps (environ 10-20s).
Mettre à jour `astrbot` :
```bash
uv tool upgrade astrbot
```
### Déploiement Docker ### Déploiement Docker
Pour les utilisateurs familiers avec les conteneurs et qui souhaitent une méthode plus stable et adaptée à la production, nous recommandons de déployer AstrBot avec Docker / Docker Compose. Pour les utilisateurs qui veulent un déploiement plus stable et prêt pour la production, nous recommandons d'utiliser Docker / Docker Compose pour déployer AstrBot.
Veuillez consulter la documentation officielle [Déployer AstrBot avec Docker](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot). Veuillez consulter la documentation officielle : [Déployer AstrBot avec Docker](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot).
### Déployer sur RainYun ### Déployer sur RainYun
Pour les utilisateurs qui souhaitent déployer AstrBot en un clic sans gérer le serveur eux-mêmes, nous recommandons le service de déploiement cloud en un clic de RainYun ☁️ : Pour les utilisateurs qui souhaitent déployer AstrBot en un clic sans gérer le serveur, nous recommandons le service de déploiement cloud en un clic de RainYun ☁️ :
[![Deploy on RainYun](https://rainyun-apps.cn-nb1.rains3.com/materials/deploy-on-rainyun-en.svg)](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0) [![Deploy on RainYun](https://rainyun-apps.cn-nb1.rains3.com/materials/deploy-on-rainyun-en.svg)](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
### Déploiement de l'application de bureau ### Application de bureau (Tauri)
Pour les utilisateurs qui veulent utiliser AstrBot sur desktop et passer principalement par ChatUI, nous recommandons AstrBot App. Pour les utilisateurs qui veulent déployer AstrBot sur desktop, utilisent principalement AstrBot ChatUI et utilisent rarement les plugins AstrBot, nous recommandons AstrBot App :
Accédez à [AstrBot-desktop](https://github.com/AstrBotDevs/AstrBot-desktop) pour télécharger et installer l'application ; cette méthode est conçue pour un usage desktop et n'est pas recommandée pour les scénarios serveur. Dépôt de l'application de bureau : [AstrBot-desktop](https://github.com/AstrBotDevs/AstrBot-desktop).
### Déploiement avec le lanceur Prend en charge plusieurs architectures système, installation directe, prête à l'emploi. Solution de déploiement bureau en un clic, particulièrement adaptée aux débutants. Non recommandée pour les serveurs.
Également sur desktop, pour les utilisateurs qui souhaitent un déploiement rapide avec isolation d'environnement et multi-instances, nous recommandons AstrBot Launcher. ### Déploiement en un clic avec le lanceur (AstrBot Launcher)
Accédez à [AstrBot Launcher](https://github.com/Raven95676/astrbot-launcher) pour télécharger et installer. Pour les utilisateurs qui veulent une solution de déploiement rapide et multi-instances avec isolation d'environnement, nous recommandons d'utiliser AstrBot Launcher :
Accédez au dépôt [AstrBot Launcher](https://github.com/Raven95676/astrbot-launcher) et installez le package correspondant à votre système depuis la dernière release.
Une solution de déploiement rapide et multi-instances avec isolation d'environnement.
### Déployer sur Replit ### Déployer sur Replit
Le déploiement sur Replit est maintenu par la communauté et convient aux démonstrations en ligne et aux essais légers. Méthode de déploiement contribuée par la communauté.
[![Run on Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot) [![Run on Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot)
### AUR ### AUR
Le mode AUR s'adresse aux utilisateurs Arch Linux qui préfèrent installer AstrBot via le gestionnaire de paquets système.
Exécutez la commande ci-dessous pour installer `astrbot-git`, puis lancez AstrBot localement.
```bash ```bash
yay -S astrbot-git yay -S astrbot-git
``` ```
**Autres méthodes de déploiement** **Autres méthodes de déploiement** : [Déploiement BT-Panel](https://astrbot.app/deploy/astrbot/btpanel.html) | [Déploiement 1Panel](https://astrbot.app/deploy/astrbot/1panel.html) | [Déploiement CasaOS](https://astrbot.app/deploy/astrbot/casaos.html) | [Déploiement manuel](https://astrbot.app/deploy/astrbot/cli.html)
Si vous avez besoin d'une gestion par panneau ou d'une personnalisation plus poussée, consultez [Déploiement BT-Panel](https://astrbot.app/deploy/astrbot/btpanel.html) pour une installation via BT Panel, [Déploiement 1Panel](https://astrbot.app/deploy/astrbot/1panel.html) pour le marketplace 1Panel, [Déploiement CasaOS](https://astrbot.app/deploy/astrbot/casaos.html) pour un déploiement visuel sur NAS/serveur domestique, et [Déploiement manuel](https://astrbot.app/deploy/astrbot/cli.html) pour une installation complète depuis les sources avec `uv`.
## Plateformes de messagerie prises en charge ## Plateformes de messagerie prises en charge
@@ -222,7 +211,6 @@ pre-commit install
- Groupe 5 : 822130018 - Groupe 5 : 822130018
- Groupe 6 : 753075035 - Groupe 6 : 753075035
- Groupe développeurs : 975206796 - Groupe développeurs : 975206796
- Groupe développeurs (officiel) : 1039761811
### Serveur Discord ### Serveur Discord
+16 -28
View File
@@ -73,68 +73,57 @@ AstrBot は、主要なインスタントメッセージングアプリと統合
### ワンクリックデプロイ ### ワンクリックデプロイ
AstrBot を素早く試したいユーザーで、コマンドラインに慣れており `uv` 環境を自分でインストールできる場合は、`uv` ワンクリックデプロイをおすすめします ⚡️: AstrBot を素早く試したいユーザーは、`uv` を使ったワンクリックデプロイをおすすめします ⚡️:
```bash ```bash
uv tool install astrbot uv tool install astrbot
astrbot init # 初回のみ実行して環境を初期化します astrbot init # 初回のみ実行して環境を初期化します
astrbot run astrbot
``` ```
> [uv](https://docs.astral.sh/uv/) のインストールが必要です。 > [uv](https://docs.astral.sh/uv/) のインストールが必要です。
> [!NOTE]
> macOS ユーザーの場合:macOS のセキュリティチェックにより、`astrbot` コマンドの初回実行に時間がかかる場合があります(約 10〜20 秒)。
`astrbot` の更新:
```bash
uv tool upgrade astrbot
```
### Docker デプロイ ### Docker デプロイ
コンテナ運用に慣れており、より安定した本番向けのデプロイ方法を求めるユーザーには、Docker / Docker Compose で AstrBot デプロイをおすすめします。 より安定した本番向けのデプロイを求めるユーザーには、Docker / Docker Compose で AstrBot デプロイすることをおすすめします。
公式ドキュメント [Docker を使用した AstrBot のデプロイ](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) をご参照ください。 公式ドキュメント [Docker を使用した AstrBot のデプロイ](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) をご参照ください。
### 雨云でのデプロイ ### 雨云でのデプロイ
AstrBot をワンクリックでデプロイしたく、サーバーを自分で管理したくないユーザーには、雨云のワンクリッククラウドデプロイサービスをおすすめします ☁️: サーバー管理をせずに AstrBot をワンクリックでデプロイしたいユーザーには、雨云のワンクリッククラウドデプロイサービスをおすすめします ☁️:
[![Deploy on RainYun](https://rainyun-apps.cn-nb1.rains3.com/materials/deploy-on-rainyun-en.svg)](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0) [![Deploy on RainYun](https://rainyun-apps.cn-nb1.rains3.com/materials/deploy-on-rainyun-en.svg)](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
### デスクトップアプリのデプロイ ### デスクトップクライアント(Tauri
デスクトップで AstrBot を使い、主に ChatUI を入口として利用するユーザーには、AstrBot App をおすすめします デスクトップで AstrBot を使いたいユーザーで、主に AstrBot ChatUI を利用し、AstrBot プラグインの利用頻度が低い場合は、AstrBot App の利用をおすすめします:
[AstrBot-desktop](https://github.com/AstrBotDevs/AstrBot-desktop) からダウンロードしてインストールしてください。この方式はデスクトップ向けであり、サーバー用途には推奨されません デスクトップアプリのリポジトリ [AstrBot-desktop](https://github.com/AstrBotDevs/AstrBot-desktop)。
### ランチャーのデプロイ マルチシステムアーキテクチャに対応し、インストーラーですぐ利用可能。初心者にも使いやすいワンクリックのデスクトップデプロイ方式です。サーバー用途には推奨されません。
同じくデスクトップで、素早くデプロイしつつ環境を分離して多重起動したいユーザーには、AstrBot Launcher をおすすめします。 ### ランチャーによるワンクリックデプロイ(AstrBot Launcher
[AstrBot Launcher](https://github.com/Raven95676/astrbot-launcher) からダウンロードしてインストールしてください。 高速デプロイと環境分離されたマルチインスタンス運用を求めるユーザーには、AstrBot Launcher の利用をおすすめします:
[AstrBot Launcher](https://github.com/Raven95676/astrbot-launcher) リポジトリにアクセスし、最新リリースからお使いの OS 向けパッケージをインストールしてください。
高速デプロイと環境分離されたマルチインスタンス運用を実現できます。
### Replit でのデプロイ ### Replit でのデプロイ
Replit デプロイはコミュニティ提供の方式で、オンラインデモや軽量な試用に向いています コミュニティ貢献によるデプロイ方法
[![Run on Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot) [![Run on Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot)
### AUR ### AUR
AUR 方式は Arch Linux ユーザー向けで、システムのパッケージ運用に合わせて AstrBot を導入したい場合に適しています。
次のコマンドで `astrbot-git` をインストールし、ローカル環境で AstrBot を起動してください。
```bash ```bash
yay -S astrbot-git yay -S astrbot-git
``` ```
**その他のデプロイ方法** **その他のデプロイ方法**[宝塔パネルデプロイ](https://astrbot.app/deploy/astrbot/btpanel.html) | [1Panel デプロイ](https://astrbot.app/deploy/astrbot/1panel.html) | [CasaOS デプロイ](https://astrbot.app/deploy/astrbot/casaos.html) | [手動デプロイ](https://astrbot.app/deploy/astrbot/cli.html)
パネル操作での導入やより高度なカスタマイズが必要な場合は、[宝塔パネルデプロイ](https://astrbot.app/deploy/astrbot/btpanel.html)BT Panel 経由の導入)、[1Panel デプロイ](https://astrbot.app/deploy/astrbot/1panel.html)(1Panel アプリマーケット経由)、[CasaOS デプロイ](https://astrbot.app/deploy/astrbot/casaos.html)(NAS / ホームサーバー向け可視化導入)、[手動デプロイ](https://astrbot.app/deploy/astrbot/cli.html)`uv` とソースベースのフルカスタム導入)を参照してください。
## サポートされているメッセージプラットフォーム ## サポートされているメッセージプラットフォーム
@@ -223,7 +212,6 @@ pre-commit install
- 5群: 822130018 - 5群: 822130018
- 6群: 753075035 - 6群: 753075035
- 開発者群: 975206796 - 開発者群: 975206796
- 開発者群(正式): 1039761811
### Discord サーバー ### Discord サーバー
+17 -29
View File
@@ -73,68 +73,57 @@ AstrBot — это универсальная платформа Agent-чатб
### Развёртывание в один клик ### Развёртывание в один клик
Для пользователей, которые хотят быстро попробовать AstrBot, знакомы с командной строкой и могут самостоятельно установить окружение `uv`, мы рекомендуем использовать развёртывание в один клик через `uv` ⚡️: Для пользователей, которые хотят быстро попробовать AstrBot, мы рекомендуем использовать развёртывание в один клик через `uv` ⚡️:
```bash ```bash
uv tool install astrbot uv tool install astrbot
astrbot init # Выполните эту команду только при первом запуске для инициализации окружения astrbot init # Выполните эту команду только при первом запуске для инициализации окружения
astrbot run astrbot
``` ```
> Требуется установленный [uv](https://docs.astral.sh/uv/). > Требуется установленный [uv](https://docs.astral.sh/uv/).
> [!NOTE]
> Для пользователей macOS: из-за проверок безопасности macOS первый запуск команды `astrbot` может занять больше времени (около 10-20 секунд).
Обновить `astrbot`:
```bash
uv tool upgrade astrbot
```
### Развёртывание Docker ### Развёртывание Docker
Для пользователей, знакомых с контейнерами и которым нужен более стабильный и подходящий для production способ, мы рекомендуем разворачивать AstrBot через Docker / Docker Compose. Для пользователей, которым нужен более стабильный и готовый к production вариант, мы рекомендуем развёртывать AstrBot через Docker / Docker Compose.
См. официальную документацию [Развёртывание AstrBot с Docker](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot). См. официальную документацию: [Развёртывание AstrBot с Docker](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot).
### Развёртывание на RainYun ### Развёртывание на RainYun
Для пользователей, которые хотят развернуть AstrBot в один клик и не хотят самостоятельно управлять сервером, мы рекомендуем облачный сервис развёртывания в один клик от RainYun ☁️: Для пользователей, которые хотят развернуть AstrBot в один клик и не управлять сервером самостоятельно, мы рекомендуем облачный сервис развёртывания в один клик от RainYun ☁️:
[![Deploy on RainYun](https://rainyun-apps.cn-nb1.rains3.com/materials/deploy-on-rainyun-en.svg)](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0) [![Deploy on RainYun](https://rainyun-apps.cn-nb1.rains3.com/materials/deploy-on-rainyun-en.svg)](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
### Развёртывание десктопного приложения ### Десктопное приложение (Tauri)
Для пользователей, которые хотят использовать AstrBot на десктопе и в основном работают через ChatUI, мы рекомендуем AstrBot App. Для пользователей, которые хотят использовать AstrBot на десктопе, в основном работают с AstrBot ChatUI и редко используют плагины AstrBot, мы рекомендуем AstrBot App:
Перейдите в [AstrBot-desktop](https://github.com/AstrBotDevs/AstrBot-desktop), скачайте и установите приложение; этот вариант предназначен для десктопа и не рекомендуется для серверных сценариев. Репозиторий десктопного приложения: [AstrBot-desktop](https://github.com/AstrBotDevs/AstrBot-desktop).
### Развёртывание через лаунчер Поддерживает разные архитектуры систем, устанавливается напрямую и работает сразу после установки. Удобное настольное развёртывание в один клик для новичков. Не рекомендуется для серверных сценариев.
Также на десктопе, для пользователей, которым нужен быстрый запуск и мультиинстанс с изоляцией окружений, мы рекомендуем AstrBot Launcher. ### Установка в один клик через лаунчер (AstrBot Launcher)
Перейдите в [AstrBot Launcher](https://github.com/Raven95676/astrbot-launcher), чтобы скачать и установить. Для пользователей, которым нужно быстрое развёртывание и мультиинстанс с изоляцией окружений, мы рекомендуем использовать AstrBot Launcher:
Перейдите в репозиторий [AstrBot Launcher](https://github.com/Raven95676/astrbot-launcher), откройте Releases и установите пакет для вашей системы из последней версии.
Быстрое развёртывание и мультиинстанс-решение с изоляцией окружений.
### Развёртывание на Replit ### Развёртывание на Replit
Развёртывание через Replit поддерживается сообществом и подходит для онлайн-демо и лёгких тестовых запусков. Метод развёртывания от сообщества.
[![Run on Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot) [![Run on Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot)
### AUR ### AUR
AUR-вариант предназначен для пользователей Arch Linux, которым удобна установка через системный менеджер пакетов.
Выполните команду ниже для установки `astrbot-git`, затем запустите AstrBot локально.
```bash ```bash
yay -S astrbot-git yay -S astrbot-git
``` ```
**Другие способы развёртывания** **Другие способы развёртывания**: [Развёртывание BT-Panel](https://astrbot.app/deploy/astrbot/btpanel.html) | [Развёртывание 1Panel](https://astrbot.app/deploy/astrbot/1panel.html) | [Развёртывание CasaOS](https://astrbot.app/deploy/astrbot/casaos.html) | [Ручное развёртывание](https://astrbot.app/deploy/astrbot/cli.html)
Если вам нужна панельная установка или более глубокая кастомизация, смотрите [Развёртывание BT-Panel](https://astrbot.app/deploy/astrbot/btpanel.html) (установка через BT Panel), [Развёртывание 1Panel](https://astrbot.app/deploy/astrbot/1panel.html) (развёртывание через маркетплейс 1Panel), [Развёртывание CasaOS](https://astrbot.app/deploy/astrbot/casaos.html) (визуальный вариант для NAS и домашних серверов) и [Ручное развёртывание](https://astrbot.app/deploy/astrbot/cli.html) (полностью настраиваемая установка из исходников через `uv`).
## Поддерживаемые платформы обмена сообщениями ## Поддерживаемые платформы обмена сообщениями
@@ -222,7 +211,6 @@ pre-commit install
- Группа 5: 822130018 - Группа 5: 822130018
- Группа 6: 753075035 - Группа 6: 753075035
- Группа разработчиков: 975206796 - Группа разработчиков: 975206796
- Группа разработчиков (официальная): 1039761811
### Сервер Discord ### Сервер Discord
+17 -33
View File
@@ -73,30 +73,21 @@ AstrBot 是一個開源的一站式 Agent 聊天機器人平台,可接入主
### 一鍵部署 ### 一鍵部署
對於想快速體驗 AstrBot、且熟悉命令列並能自行安裝 `uv` 環境的使用者,我們推薦使用 `uv` 一鍵部署方式 ⚡️ 對於想快速體驗 AstrBot 的使用者,我們推薦使用 `uv` 一鍵部署方式 ⚡️
```bash ```bash
uv tool install astrbot uv tool install astrbot
astrbot init # 僅首次執行此命令以初始化環境 astrbot init # 僅首次執行此命令以初始化環境
astrbot run astrbot
``` ```
> 需要安裝 [uv](https://docs.astral.sh/uv/)。 > 需要安裝 [uv](https://docs.astral.sh/uv/)。
> [!NOTE]
> 對於 macOS 使用者:由於 macOS 安全性檢查,首次執行 `astrbot` 指令可能需要較長時間(約 10-20 秒)。
更新 `astrbot`
```bash
uv tool upgrade astrbot
```
### Docker 部署 ### Docker 部署
對於熟悉容器、希望獲得更穩定更適合正式環境部署方式的使用者,我們推薦使用 Docker / Docker Compose 部署 AstrBot。 對於希望獲得更穩定更適合正式環境部署方式的使用者,我們推薦使用 Docker / Docker Compose 部署 AstrBot。
請參官方文件 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot)。 請參官方文件 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot)。
### 在雨雲上部署 ### 在雨雲上部署
@@ -104,37 +95,35 @@ uv tool upgrade astrbot
[![Deploy on RainYun](https://rainyun-apps.cn-nb1.rains3.com/materials/deploy-on-rainyun-en.svg)](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0) [![Deploy on RainYun](https://rainyun-apps.cn-nb1.rains3.com/materials/deploy-on-rainyun-en.svg)](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
### 桌面客戶端部署 ### 桌面客戶端Tauri
對於希望在桌面端使用 AstrBot、以 ChatUI 為主要入口的使用者,我們推薦使用 AstrBot App 對於希望在桌面部署 AstrBot、以 AstrBot ChatUI 為主要使用方式、較少使用 AstrBot 外掛的使用者,我們推薦使用 AstrBot App
前往 [AstrBot-desktop](https://github.com/AstrBotDevs/AstrBot-desktop) 下載並安裝;此方式面向桌面使用,不建議伺服器場景 桌面應用倉庫 [AstrBot-desktop](https://github.com/AstrBotDevs/AstrBot-desktop)。
### 啟動器部署 支援多系統架構,安裝包直接安裝,開箱即用,最適合新手和懶人的一鍵桌面部署方案,不推薦伺服器場景。
同樣在桌面端,對於希望快速部署並實現環境隔離多開的使用者,我們推薦使用 AstrBot Launcher ### 啟動器一鍵部署(AstrBot Launcher
前往 [AstrBot Launcher](https://github.com/Raven95676/astrbot-launcher) 下載並安裝。 對於希望快速部署並實現環境隔離多開的使用者,我們推薦使用 AstrBot Launcher
進入 [AstrBot Launcher](https://github.com/Raven95676/astrbot-launcher) 倉庫,在 Releases 頁最新版本下找到對應的系統安裝包安裝即可。
一個快速部署和多開方案,實現環境隔離。
### 在 Replit 上部署 ### 在 Replit 上部署
Replit 部署由社群維護,適合線上示範與輕量試用情境 社群貢獻的部署方式
[![Run on Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot) [![Run on Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot)
### AUR ### AUR
AUR 方式面向 Arch Linux 使用者,適合希望透過系統套件管理器安裝 AstrBot 的場景。
在終端執行下方命令安裝 `astrbot-git` 套件,安裝完成後即可啟動使用。
```bash ```bash
yay -S astrbot-git yay -S astrbot-git
``` ```
**更多部署方式** **更多部署方式**[寶塔面板](https://astrbot.app/deploy/astrbot/btpanel.html) | [1Panel](https://astrbot.app/deploy/astrbot/1panel.html) | [CasaOS](https://astrbot.app/deploy/astrbot/casaos.html) | [手動部署](https://astrbot.app/deploy/astrbot/cli.html)
若你需要面板化或更高自訂程度的部署,可參考 [寶塔面板](https://astrbot.app/deploy/astrbot/btpanel.html)(BT Panel 應用商店安裝)、[1Panel](https://astrbot.app/deploy/astrbot/1panel.html)1Panel 應用商店安裝)、[CasaOS](https://astrbot.app/deploy/astrbot/casaos.html)(NAS / 家用伺服器可視化部署)與 [手動部署](https://astrbot.app/deploy/astrbot/cli.html)(基於原始碼與 `uv` 的完整自訂安裝)。
## 支援的訊息平台 ## 支援的訊息平台
@@ -217,16 +206,11 @@ pre-commit install
### QQ 群組 ### QQ 群組
- 9 群: 1076659624 (新)
- 10 群: 1078079676 (新)
- 1 群:322154837 - 1 群:322154837
- 3 群:630166526 - 3 群:630166526
- 5 群:822130018 - 5 群:822130018
- 6 群:753075035 - 6 群:753075035
- 7 群:743746109 - 開發者群:975206796
- 8 群:1030353265
- 開發者群(闲聊吹水):975206796
- 開發者群(正式):1039761811
### Discord 群組 ### Discord 群組
+17 -31
View File
@@ -73,30 +73,21 @@ AstrBot 是一个开源的一站式 Agentic 个人和群聊助手,可在 QQ、
### 一键部署 ### 一键部署
对于想快速体验 AstrBot、且熟悉命令行并能够自行安装 `uv` 环境的用户,我们推荐使用 `uv` 一键部署方式 ⚡️ 对于想快速体验 AstrBot 的用户,我们推荐使用 `uv` 一键部署方式 ⚡️
```bash ```bash
uv tool install astrbot uv tool install astrbot
astrbot init # 仅首次执行此命令以初始化环境 astrbot init # 仅首次执行此命令以初始化环境
astrbot run astrbot
``` ```
> 需要安装 [uv](https://docs.astral.sh/uv/)。 > 需要安装 [uv](https://docs.astral.sh/uv/)。
> [!NOTE]
> 对于 macOS 用户:由于 macOS 安全检查,首次运行 `astrbot` 命令可能需要较长时间(约 10-20 秒)。
更新 `astrbot`
```bash
uv tool upgrade astrbot
```
### Docker 部署 ### Docker 部署
对于熟悉容器、希望获得更稳定更适合生产环境部署方式的用户,我们推荐使用 Docker / Docker Compose 部署 AstrBot。 对于希望获得更稳定更适合生产环境部署方式的用户,我们推荐使用 Docker / Docker Compose 部署 AstrBot。
请参官方文档 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot)。 请参官方文档 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot)
### 在 雨云 上部署 ### 在 雨云 上部署
@@ -104,37 +95,35 @@ uv tool upgrade astrbot
[![Deploy on RainYun](https://rainyun-apps.cn-nb1.rains3.com/materials/deploy-on-rainyun-en.svg)](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0) [![Deploy on RainYun](https://rainyun-apps.cn-nb1.rains3.com/materials/deploy-on-rainyun-en.svg)](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
### 桌面客户端部署 ### 桌面客户端Tauri
对于希望在桌面端使用 AstrBot、以 ChatUI 为主要入口的用户,我们推荐使用 AstrBot App 对于希望在桌面部署 AstrBot、以 AstrBot ChatUI 为主要使用方式、较少使用 AstrBot 插件的用户,我们推荐使用 AstrBot App
前往 [AstrBot-desktop](https://github.com/AstrBotDevs/AstrBot-desktop) 下载并安装;该方式面向桌面使用,不推荐服务器场景 桌面应用仓库 [AstrBot-desktop](https://github.com/AstrBotDevs/AstrBot-desktop)。
### 启动器部署 支持多系统架构,安装包直接安装,开箱即用,最适合新手和懒人的一键桌面部署方案,不推荐服务器场景。
同样在桌面端,希望快速部署并实现环境隔离多开的用户,我们推荐使用 AstrBot Launcher ### 启动器一键部署(AstrBot Launcher
前往 [AstrBot Launcher](https://github.com/Raven95676/astrbot-launcher) 下载并安装。 对于希望快速部署并实现环境隔离多开的用户,我们推荐使用 AstrBot Launcher
进入 [AstrBot Launcher](https://github.com/Raven95676/astrbot-launcher) 仓库,在 Releases 页最新版本下找到对应的系统安装包安装即可。
一个快速部署和多开方案,实现环境隔离。
### 在 Replit 上部署 ### 在 Replit 上部署
Replit 部署由社区维护,适合在线演示和轻量试用场景 社区贡献的部署方式
[![Run on Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot) [![Run on Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot)
### AUR ### AUR
AUR 方式面向 Arch Linux 用户,适合希望通过系统包管理器安装 AstrBot 的场景。
在终端执行下方命令安装 `astrbot-git` 包,安装完成后即可启动使用。
```bash ```bash
yay -S astrbot-git yay -S astrbot-git
``` ```
**更多部署方式** **更多部署方式**[宝塔面板](https://astrbot.app/deploy/astrbot/btpanel.html) | [1Panel](https://astrbot.app/deploy/astrbot/1panel.html) | [CasaOS](https://astrbot.app/deploy/astrbot/casaos.html) | [手动部署](https://astrbot.app/deploy/astrbot/cli.html)
若你需要面板化或更高自定义部署,可参考 [宝塔面板](https://astrbot.app/deploy/astrbot/btpanel.html)(BT Panel 应用商店安装)、[1Panel](https://astrbot.app/deploy/astrbot/1panel.html)1Panel 应用商店安装)、[CasaOS](https://astrbot.app/deploy/astrbot/casaos.html)(NAS / 家庭服务器可视化部署)和 [手动部署](https://astrbot.app/deploy/astrbot/cli.html)(基于源码与 `uv` 的完整自定义安装)。
## 支持的消息平台 ## 支持的消息平台
@@ -218,16 +207,13 @@ pre-commit install
### QQ 群组 ### QQ 群组
- 9 群: 1076659624 (新)
- 10 群: 1078079676 (新)
- 1 群:322154837 - 1 群:322154837
- 3 群:630166526 - 3 群:630166526
- 5 群:822130018 - 5 群:822130018
- 6 群:753075035 - 6 群:753075035
- 7 群:743746109 - 7 群:743746109
- 8 群:1030353265 - 8 群:1030353265
- 开发者群(偏闲聊吹水)975206796 - 开发者群:975206796
- 开发者群(正式):1039761811
### Discord 频道 ### Discord 频道
+1 -1
View File
@@ -1 +1 @@
__version__ = "4.20.1" __version__ = "4.18.3"
+1 -15
View File
@@ -4,21 +4,7 @@ from astrbot.core.config import AstrBotConfig
from astrbot.core.config.default import DB_PATH from astrbot.core.config.default import DB_PATH
from astrbot.core.db.sqlite import SQLiteDatabase from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.file_token_service import FileTokenService from astrbot.core.file_token_service import FileTokenService
from astrbot.core.utils.pip_installer import ( from astrbot.core.utils.pip_installer import PipInstaller
DependencyConflictError as DependencyConflictError,
)
from astrbot.core.utils.pip_installer import (
PipInstaller,
)
from astrbot.core.utils.requirements_utils import (
RequirementsPrecheckFailed as RequirementsPrecheckFailed,
)
from astrbot.core.utils.requirements_utils import (
find_missing_requirements as find_missing_requirements,
)
from astrbot.core.utils.requirements_utils import (
find_missing_requirements_or_raise as find_missing_requirements_or_raise,
)
from astrbot.core.utils.shared_preferences import SharedPreferences from astrbot.core.utils.shared_preferences import SharedPreferences
from astrbot.core.utils.t2i.renderer import HtmlRenderer from astrbot.core.utils.t2i.renderer import HtmlRenderer
+1 -1
View File
@@ -62,4 +62,4 @@ class HandoffTool(FunctionTool, Generic[TContext]):
def default_description(self, agent_name: str | None) -> str: def default_description(self, agent_name: str | None) -> str:
agent_name = agent_name or "another" agent_name = agent_name or "another"
return f"Delegate tasks to {agent_name} agent to handle the request." return f"Delegate tasks to {self.name} agent to handle the request."
+6 -19
View File
@@ -144,14 +144,10 @@ class MCPClient:
cfg = _prepare_config(mcp_server_config.copy()) cfg = _prepare_config(mcp_server_config.copy())
def logging_callback( def logging_callback(msg: str) -> None:
msg: str | mcp.types.LoggingMessageNotificationParams,
) -> None:
# Handle MCP service error logs # Handle MCP service error logs
if isinstance(msg, mcp.types.LoggingMessageNotificationParams): print(f"MCP Server {name} Error: {msg}")
if msg.level in ("warning", "error", "critical", "alert", "emergency"): self.server_errlogs.append(msg)
log_msg = f"[{msg.level.upper()}] {str(msg.data)}"
self.server_errlogs.append(log_msg)
if "url" in cfg: if "url" in cfg:
success, error_msg = await _quick_test_mcp_connection(cfg) success, error_msg = await _quick_test_mcp_connection(cfg)
@@ -218,24 +214,15 @@ class MCPClient:
**cfg, **cfg,
) )
def callback(msg: str | mcp.types.LoggingMessageNotificationParams) -> None: def callback(msg: str) -> None:
# Handle MCP service error logs # Handle MCP service error logs
if isinstance(msg, mcp.types.LoggingMessageNotificationParams): self.server_errlogs.append(msg)
if msg.level in (
"warning",
"error",
"critical",
"alert",
"emergency",
):
log_msg = f"[{msg.level.upper()}] {str(msg.data)}"
self.server_errlogs.append(log_msg)
stdio_transport = await self.exit_stack.enter_async_context( stdio_transport = await self.exit_stack.enter_async_context(
mcp.stdio_client( mcp.stdio_client(
server_params, server_params,
errlog=LogPipe( errlog=LogPipe(
level=logging.INFO, level=logging.ERROR,
logger=logger, logger=logger,
identifier=f"MCPServer-{name}", identifier=f"MCPServer-{name}",
callback=callback, callback=callback,
@@ -302,7 +302,7 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
while True: while True:
try: try:
item_type, item_data = await asyncio.get_running_loop().run_in_executor( item_type, item_data = await asyncio.get_event_loop().run_in_executor(
None, response_queue.get, True, 1 None, response_queue.get, True, 1
) )
except queue.Empty: except queue.Empty:
@@ -388,7 +388,7 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
# 发起请求 # 发起请求
partial = functools.partial(Application.call, **payload) partial = functools.partial(Application.call, **payload)
response = await asyncio.get_running_loop().run_in_executor(None, partial) response = await asyncio.get_event_loop().run_in_executor(None, partial)
async for resp in self._handle_streaming_response(response, session_id): async for resp in self._handle_streaming_response(response, session_id):
yield resp yield resp
+8 -8
View File
@@ -390,9 +390,14 @@ async def _ensure_persona_and_skills(
persona_tools = None persona_tools = None
pid = a.get("persona_id") pid = a.get("persona_id")
if pid: if pid:
persona = plugin_context.persona_manager.get_persona_v3_by_id(pid) persona_tools = next(
if persona is not None: (
persona_tools = persona.get("tools") p.get("tools")
for p in plugin_context.persona_manager.personas_v3
if p["name"] == pid
),
None,
)
tools = a.get("tools", []) tools = a.get("tools", [])
if persona_tools is not None: if persona_tools is not None:
tools = persona_tools tools = persona_tools
@@ -773,14 +778,9 @@ def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None:
continue continue
mp = tool.handler_module_path mp = tool.handler_module_path
if not mp: if not mp:
# 没有 plugin 归属信息的工具(如 subagent transfer_to_*
# 不应受到会话插件过滤影响。
new_tool_set.add_tool(tool)
continue continue
plugin = star_map.get(mp) plugin = star_map.get(mp)
if not plugin: if not plugin:
# 无法解析插件归属时,保守保留工具,避免误过滤。
new_tool_set.add_tool(tool)
continue continue
if plugin.name in event.plugins_name or plugin.reserved: if plugin.name in event.plugins_name or plugin.reserved:
new_tool_set.add_tool(tool) new_tool_set.add_tool(tool)
+2 -20
View File
@@ -188,12 +188,7 @@ class KnowledgeBaseQueryTool(FunctionTool[AstrAgentContext]):
@dataclass @dataclass
class SendMessageToUserTool(FunctionTool[AstrAgentContext]): class SendMessageToUserTool(FunctionTool[AstrAgentContext]):
name: str = "send_message_to_user" name: str = "send_message_to_user"
description: str = ( description: str = "Directly send message to the user. Only use this tool when you need to proactively message the user. Otherwise you can directly output the reply in the conversation."
"Send message to the user. "
"Supports various message types including `plain`, `image`, `record`, `video`, `file`, and `mention_user`. "
"Use this tool to send media files (`image`, `record`, `video`, `file`), "
"or when you need to proactively message the user(such as cron job). For normal text replies, you can output directly."
)
parameters: dict = Field( parameters: dict = Field(
default_factory=lambda: { default_factory=lambda: {
@@ -209,7 +204,7 @@ class SendMessageToUserTool(FunctionTool[AstrAgentContext]):
"type": "string", "type": "string",
"description": ( "description": (
"Component type. One of: " "Component type. One of: "
"plain, image, record, video, file, mention_user. Record is voice message." "plain, image, record, file, mention_user"
), ),
}, },
"text": { "text": {
@@ -325,19 +320,6 @@ class SendMessageToUserTool(FunctionTool[AstrAgentContext]):
components.append(Comp.Record.fromURL(url=url)) components.append(Comp.Record.fromURL(url=url))
else: else:
return f"error: messages[{idx}] must include path or url for record component." return f"error: messages[{idx}] must include path or url for record component."
elif msg_type == "video":
path = msg.get("path")
url = msg.get("url")
if path:
(
local_path,
file_from_sandbox,
) = await self._resolve_path_from_sandbox(context, path)
components.append(Comp.Video.fromFileSystem(path=local_path))
elif url:
components.append(Comp.Video.fromURL(url=url))
else:
return f"error: messages[{idx}] must include path or url for video component."
elif msg_type == "file": elif msg_type == "file":
path = msg.get("path") path = msg.get("path")
url = msg.get("url") url = msg.get("url")
+2 -3
View File
@@ -121,12 +121,11 @@ class BayContainerManager:
async def wait_healthy(self, timeout: int = HEALTH_TIMEOUT_S) -> None: async def wait_healthy(self, timeout: int = HEALTH_TIMEOUT_S) -> None:
"""Block until Bay's ``/health`` endpoint returns 200.""" """Block until Bay's ``/health`` endpoint returns 200."""
url = f"http://127.0.0.1:{self._host_port}/health" url = f"http://127.0.0.1:{self._host_port}/health"
loop = asyncio.get_running_loop() deadline = asyncio.get_event_loop().time() + timeout
deadline = loop.time() + timeout
last_error: str = "" last_error: str = ""
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
while loop.time() < deadline: while asyncio.get_event_loop().time() < deadline:
try: try:
async with session.get( async with session.get(
url, timeout=aiohttp.ClientTimeout(total=3) url, timeout=aiohttp.ClientTimeout(total=3)
+8 -38
View File
@@ -1,7 +1,6 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import locale
import os import os
import shutil import shutil
import subprocess import subprocess
@@ -53,31 +52,6 @@ def _ensure_safe_path(path: str) -> str:
return abs_path return abs_path
def _decode_shell_output(output: bytes | None) -> str:
if output is None:
return ""
preferred = locale.getpreferredencoding(False) or "utf-8"
try:
return output.decode("utf-8")
except (LookupError, UnicodeDecodeError):
pass
if os.name == "nt":
for encoding in ("mbcs", "cp936", "gbk", "gb18030"):
try:
return output.decode(encoding)
except (LookupError, UnicodeDecodeError):
continue
try:
return output.decode(preferred)
except (LookupError, UnicodeDecodeError):
pass
return output.decode("utf-8", errors="replace")
@dataclass @dataclass
class LocalShellComponent(ShellComponent): class LocalShellComponent(ShellComponent):
async def exec( async def exec(
@@ -98,32 +72,28 @@ class LocalShellComponent(ShellComponent):
run_env.update({str(k): str(v) for k, v in env.items()}) run_env.update({str(k): str(v) for k, v in env.items()})
working_dir = _ensure_safe_path(cwd) if cwd else get_astrbot_root() working_dir = _ensure_safe_path(cwd) if cwd else get_astrbot_root()
if background: if background:
# `command` is intentionally executed through the current shell so proc = subprocess.Popen(
# local computer-use behavior matches existing tool semantics.
# Safety relies on `_is_safe_command()` and the allowed-root checks.
proc = subprocess.Popen( # noqa: S602 # nosemgrep: python.lang.security.audit.dangerous-subprocess-use-audit
command, command,
shell=shell, shell=shell,
cwd=working_dir, cwd=working_dir,
env=run_env, env=run_env,
stdout=subprocess.DEVNULL, stdout=subprocess.PIPE,
stderr=subprocess.DEVNULL, stderr=subprocess.PIPE,
text=True,
) )
return {"pid": proc.pid, "stdout": "", "stderr": "", "exit_code": None} return {"pid": proc.pid, "stdout": "", "stderr": "", "exit_code": None}
# `command` is intentionally executed through the current shell so result = subprocess.run(
# local computer-use behavior matches existing tool semantics.
# Safety relies on `_is_safe_command()` and the allowed-root checks.
result = subprocess.run( # noqa: S602 # nosemgrep: python.lang.security.audit.dangerous-subprocess-use-audit
command, command,
shell=shell, shell=shell,
cwd=working_dir, cwd=working_dir,
env=run_env, env=run_env,
timeout=timeout, timeout=timeout,
capture_output=True, capture_output=True,
text=True,
) )
return { return {
"stdout": _decode_shell_output(result.stdout), "stdout": result.stdout,
"stderr": _decode_shell_output(result.stderr), "stderr": result.stderr,
"exit_code": result.returncode, "exit_code": result.returncode,
} }
+6 -23
View File
@@ -213,25 +213,14 @@ def parse_description(text: str) -> str:
break break
if end_idx is None: if end_idx is None:
return "" return ""
for line in lines[1:end_idx]:
frontmatter = "\n".join(lines[1:end_idx]) if ":" not in line:
try: continue
import yaml key, value = line.split(":", 1)
except ImportError: if key.strip().lower() == "description":
return value.strip().strip('"').strip("'")
return "" return ""
try:
payload = yaml.safe_load(frontmatter) or dict()
except yaml.YAMLError:
return ""
if not isinstance(payload, dict):
return ""
description = payload.get("description", "")
if not isinstance(description, str):
return ""
return description.strip()
def load_managed_skills() -> list[str]: def load_managed_skills() -> list[str]:
if not managed_file.exists(): if not managed_file.exists():
@@ -433,12 +422,6 @@ async def get_booter(
) -> ComputerBooter: ) -> ComputerBooter:
config = context.get_config(umo=session_id) config = context.get_config(umo=session_id)
runtime = config.get("provider_settings", {}).get("computer_use_runtime", "local")
if runtime == "local":
return get_local_booter()
elif runtime == "none":
raise RuntimeError("Sandbox runtime is disabled by configuration.")
sandbox_cfg = config.get("provider_settings", {}).get("sandbox", {}) sandbox_cfg = config.get("provider_settings", {}).get("sandbox", {})
booter_type = sandbox_cfg.get("booter", "shipyard_neo") booter_type = sandbox_cfg.get("booter", "shipyard_neo")
+1 -4
View File
@@ -164,10 +164,7 @@ class CreateSkillPayloadTool(NeoSkillToolBase):
"type": "object", "type": "object",
"properties": { "properties": {
"payload": { "payload": {
"anyOf": [ "anyOf": [{"type": "object"}, {"type": "array"}],
{"type": "object"},
{"type": "array", "items": {"type": "object"}},
],
"description": ( "description": (
"Skill payload JSON. Typical schema: {skill_markdown, inputs, outputs, meta}. " "Skill payload JSON. Typical schema: {skill_markdown, inputs, outputs, meta}. "
"This only stores content and returns payload_ref; it does not create a candidate or release." "This only stores content and returns payload_ref; it does not create a candidate or release."
+12 -84
View File
@@ -5,7 +5,7 @@ from typing import Any, TypedDict
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.20.1" VERSION = "4.18.3"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db") DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
WEBHOOK_SUPPORTED_PLATFORMS = [ WEBHOOK_SUPPORTED_PLATFORMS = [
@@ -219,9 +219,6 @@ DEFAULT_CONFIG = {
"telegram": { "telegram": {
"pre_ack_emoji": {"enable": False, "emojis": ["✍️"]}, "pre_ack_emoji": {"enable": False, "emojis": ["✍️"]},
}, },
"discord": {
"pre_ack_emoji": {"enable": False, "emojis": ["🤔"]},
},
}, },
"wake_prefix": ["/"], "wake_prefix": ["/"],
"log_level": "INFO", "log_level": "INFO",
@@ -345,20 +342,14 @@ CONFIG_METADATA_2 = {
"企业微信智能机器人": { "企业微信智能机器人": {
"id": "wecom_ai_bot", "id": "wecom_ai_bot",
"type": "wecom_ai_bot", "type": "wecom_ai_bot",
"hint": "如果发现字段有异常,请重新创建",
"enable": True, "enable": True,
"wecom_ai_bot_connection_mode": "long_connection", # long_connection, webhook
"wecom_ai_bot_name": "",
"wecomaibot_ws_bot_id": "",
"wecomaibot_ws_secret": "",
"wecomaibot_token": "",
"wecomaibot_encoding_aes_key": "",
"wecomaibot_init_respond_text": "", "wecomaibot_init_respond_text": "",
"wecomaibot_friend_message_welcome_text": "", "wecomaibot_friend_message_welcome_text": "",
"wecom_ai_bot_name": "",
"msg_push_webhook_url": "", "msg_push_webhook_url": "",
"only_use_webhook_url_to_send": False, "only_use_webhook_url_to_send": False,
"wecomaibot_ws_url": "wss://openws.work.weixin.qq.com", "token": "",
"wecomaibot_heartbeat_interval": 30, "encoding_aes_key": "",
"unified_webhook_mode": True, "unified_webhook_mode": True,
"webhook_uuid": "", "webhook_uuid": "",
"callback_server_host": "0.0.0.0", "callback_server_host": "0.0.0.0",
@@ -463,6 +454,7 @@ CONFIG_METADATA_2 = {
"type": "kook", "type": "kook",
"enable": False, "enable": False,
"kook_bot_token": "", "kook_bot_token": "",
"kook_bot_nickname": "",
"kook_reconnect_delay": 1, "kook_reconnect_delay": 1,
"kook_max_reconnect_delay": 60, "kook_max_reconnect_delay": 60,
"kook_max_retry_delay": 60, "kook_max_retry_delay": 60,
@@ -740,13 +732,6 @@ CONFIG_METADATA_2 = {
"type": "string", "type": "string",
"hint": "请务必填写正确,否则无法使用一些指令。", "hint": "请务必填写正确,否则无法使用一些指令。",
}, },
"wecom_ai_bot_connection_mode": {
"description": "企业微信智能机器人连接模式",
"type": "string",
"options": ["webhook", "long_connection"],
"labels": ["Webhook 回调", "长连接"],
"hint": "Webhook 回调模式需要配置 Token/EncodingAESKey。长连接模式需要配置 BotID/Secret。",
},
"wecomaibot_init_respond_text": { "wecomaibot_init_respond_text": {
"description": "企业微信智能机器人初始响应文本", "description": "企业微信智能机器人初始响应文本",
"type": "string", "type": "string",
@@ -757,22 +742,6 @@ CONFIG_METADATA_2 = {
"type": "string", "type": "string",
"hint": "当用户当天进入智能机器人单聊会话,回复欢迎语,留空则不回复。", "hint": "当用户当天进入智能机器人单聊会话,回复欢迎语,留空则不回复。",
}, },
"wecomaibot_token": {
"description": "企业微信智能机器人 Token",
"type": "string",
"hint": "用于 Webhook 回调模式的身份验证。",
"condition": {
"wecom_ai_bot_connection_mode": "webhook",
},
},
"wecomaibot_encoding_aes_key": {
"description": "企业微信智能机器人 EncodingAESKey",
"type": "string",
"hint": "用于 Webhook 回调模式的消息加密解密。",
"condition": {
"wecom_ai_bot_connection_mode": "webhook",
},
},
"msg_push_webhook_url": { "msg_push_webhook_url": {
"description": "企业微信消息推送 Webhook URL", "description": "企业微信消息推送 Webhook URL",
"type": "string", "type": "string",
@@ -783,40 +752,6 @@ CONFIG_METADATA_2 = {
"type": "bool", "type": "bool",
"hint": "启用后,企业微信智能机器人的所有回复都改为通过消息推送 Webhook 发送。消息推送 Webhook 支持更多的消息类型(如图片、文件等)。", "hint": "启用后,企业微信智能机器人的所有回复都改为通过消息推送 Webhook 发送。消息推送 Webhook 支持更多的消息类型(如图片、文件等)。",
}, },
"wecomaibot_ws_bot_id": {
"description": "长连接 BotID",
"type": "string",
"hint": "企业微信智能机器人长连接模式凭证 BotID。",
"condition": {
"wecom_ai_bot_connection_mode": "long_connection",
},
},
"wecomaibot_ws_secret": {
"description": "长连接 Secret",
"type": "string",
"hint": "企业微信智能机器人长连接模式凭证 Secret。",
"condition": {
"wecom_ai_bot_connection_mode": "long_connection",
},
},
"wecomaibot_ws_url": {
"description": "长连接 WebSocket 地址",
"type": "string",
"invisible": True,
"hint": "默认值为 wss://openws.work.weixin.qq.com,一般无需修改。",
"condition": {
"wecom_ai_bot_connection_mode": "long_connection",
},
},
"wecomaibot_heartbeat_interval": {
"description": "长连接心跳间隔",
"type": "int",
"invisible": True,
"hint": "长连接模式心跳间隔(秒),建议 30 秒。",
"condition": {
"wecom_ai_bot_connection_mode": "long_connection",
},
},
"lark_bot_name": { "lark_bot_name": {
"description": "飞书机器人的名字", "description": "飞书机器人的名字",
"type": "string", "type": "string",
@@ -861,7 +796,7 @@ CONFIG_METADATA_2 = {
"unified_webhook_mode": { "unified_webhook_mode": {
"description": "统一 Webhook 模式", "description": "统一 Webhook 模式",
"type": "bool", "type": "bool",
"hint": "Webhook 模式下使用 AstrBot 统一 Webhook 入口,无需单独开启端口。回调地址为 /api/platform/webhook/{webhook_uuid}", "hint": "启用后,将使用 AstrBot 统一 Webhook 入口,无需单独开启端口。回调地址为 /api/platform/webhook/{webhook_uuid}",
}, },
"webhook_uuid": { "webhook_uuid": {
"invisible": True, "invisible": True,
@@ -874,6 +809,11 @@ CONFIG_METADATA_2 = {
"type": "string", "type": "string",
"hint": "必填项。从 KOOK 开发者平台获取的机器人 Token。", "hint": "必填项。从 KOOK 开发者平台获取的机器人 Token。",
}, },
"kook_bot_nickname": {
"description": "Bot Nickname",
"type": "string",
"hint": "可选项。若发送者昵称与此值一致,将忽略该消息以避免广播风暴。",
},
"kook_reconnect_delay": { "kook_reconnect_delay": {
"description": "重连延迟", "description": "重连延迟",
"type": "int", "type": "int",
@@ -1126,18 +1066,6 @@ CONFIG_METADATA_2 = {
"proxy": "", "proxy": "",
"custom_headers": {}, "custom_headers": {},
}, },
"MiniMax": {
"id": "minimax",
"provider": "minimax",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.minimaxi.com/v1",
"timeout": 120,
"proxy": "",
"custom_headers": {},
},
"xAI": { "xAI": {
"id": "xai", "id": "xai",
"provider": "xai", "provider": "xai",
@@ -1195,7 +1123,7 @@ CONFIG_METADATA_2 = {
"enable": True, "enable": True,
"key": [], "key": [],
"timeout": 120, "timeout": 120,
"api_base": "https://openrouter.ai/api/v1", "api_base": "https://openrouter.ai/v1",
"proxy": "", "proxy": "",
"custom_headers": {}, "custom_headers": {},
}, },
+1 -1
View File
@@ -332,7 +332,7 @@ class CronJobManager:
cron_job=cron_job_str cron_job=cron_job_str
) )
req.prompt = ( req.prompt = (
"You are now responding to a scheduled task. " "You are now responding to a scheduled task"
"Proceed according to your system instructions. " "Proceed according to your system instructions. "
"Output using same language as previous conversation." "Output using same language as previous conversation."
"After completing your task, summarize and output your actions and results." "After completing your task, summarize and output your actions and results."
-15
View File
@@ -33,18 +33,10 @@ class BaseDatabase(abc.ABC):
DATABASE_URL = "" DATABASE_URL = ""
def __init__(self) -> None: def __init__(self) -> None:
# SQLite only supports a single writer at a time. Without a busy
# timeout the driver raises "database is locked" instantly when a
# second write is attempted. Setting timeout=30 tells SQLite to
# wait up to 30 s for the lock, which is enough to ride out brief
# write bursts from concurrent agent/metrics/session operations.
is_sqlite = "sqlite" in self.DATABASE_URL
connect_args = {"timeout": 30} if is_sqlite else {}
self.engine = create_async_engine( self.engine = create_async_engine(
self.DATABASE_URL, self.DATABASE_URL,
echo=False, echo=False,
future=True, future=True,
connect_args=connect_args,
) )
self.AsyncSessionLocal = async_sessionmaker( self.AsyncSessionLocal = async_sessionmaker(
self.engine, self.engine,
@@ -655,13 +647,6 @@ class BaseDatabase(abc.ABC):
"""Get a Platform session by its ID.""" """Get a Platform session by its ID."""
... ...
@abc.abstractmethod
async def get_platform_sessions_by_ids(
self, session_ids: list[str]
) -> list[PlatformSession]:
"""Get platform sessions by IDs."""
...
@abc.abstractmethod @abc.abstractmethod
async def get_platform_sessions_by_creator( async def get_platform_sessions_by_creator(
self, self,
-15
View File
@@ -1417,21 +1417,6 @@ class SQLiteDatabase(BaseDatabase):
result = await session.execute(query) result = await session.execute(query)
return result.scalar_one_or_none() return result.scalar_one_or_none()
async def get_platform_sessions_by_ids(
self, session_ids: list[str]
) -> list[PlatformSession]:
"""Get platform sessions by IDs."""
if not session_ids:
return []
async with self.get_db() as session:
session: AsyncSession
query = select(PlatformSession).where(
col(PlatformSession.session_id).in_(session_ids)
)
result = await session.execute(query)
return list(result.scalars().all())
async def get_platform_sessions_by_creator( async def get_platform_sessions_by_creator(
self, self,
creator: str, creator: str,
+15 -41
View File
@@ -96,10 +96,10 @@ class Plain(BaseMessageComponent):
def __init__(self, text: str, convert: bool = True, **_) -> None: def __init__(self, text: str, convert: bool = True, **_) -> None:
super().__init__(text=text, convert=convert, **_) super().__init__(text=text, convert=convert, **_)
def toDict(self) -> dict: def toDict(self):
return {"type": "text", "data": {"text": self.text}} return {"type": "text", "data": {"text": self.text.strip()}}
async def to_dict(self) -> dict: async def to_dict(self):
return {"type": "text", "data": {"text": self.text}} return {"type": "text", "data": {"text": self.text}}
@@ -539,36 +539,13 @@ class Reply(BaseMessageComponent):
class Poke(BaseMessageComponent): class Poke(BaseMessageComponent):
type: ComponentType = ComponentType.Poke type: str = ComponentType.Poke
_type: str | int = "126" id: int | None = 0
id: int | str | None = 0 qq: int | None = 0
qq: int | str | None = 0 # deprecated: legacy field, kept for compatibility
def __init__(self, poke_type: str | int | None = None, **_) -> None: def __init__(self, type: str, **_) -> None:
# Backward compatible with old signature: Poke(type="poke", ...) type = f"Poke:{type}"
legacy_type = _.pop("type", None) super().__init__(type=type, **_)
if poke_type is None:
poke_type = legacy_type
if poke_type in (None, "", "poke", "Poke"):
poke_type = "126"
super().__init__(_type=str(poke_type), **_)
def target_id(self) -> str | None:
"""Return normalized target id, compatible with old `qq` field."""
for value in (self.id, self.qq):
if value is None:
continue
text = str(value).strip()
if text and text != "0":
return text
return None
def toDict(self):
target_id = self.target_id()
data = {"type": str(self._type or "126")}
if target_id:
data["id"] = target_id
return {"type": "poke", "data": data}
class Forward(BaseMessageComponent): class Forward(BaseMessageComponent):
@@ -699,24 +676,21 @@ class File(BaseMessageComponent):
if self.url: if self.url:
try: try:
# 检查是否有正在运行的 event loop loop = asyncio.get_event_loop()
asyncio.get_running_loop() if loop.is_running():
logger.warning( logger.warning(
"不可以在异步上下文中同步等待下载! " "不可以在异步上下文中同步等待下载! "
"这个警告通常发生于某些逻辑试图通过 <File>.file 获取文件消息段的文件内容。" "这个警告通常发生于某些逻辑试图通过 <File>.file 获取文件消息段的文件内容。"
"请使用 await get_file() 代替直接获取 <File>.file 字段", "请使用 await get_file() 代替直接获取 <File>.file 字段",
) )
return "" return ""
except RuntimeError: # 等待下载完成
# 没有运行中的 event loop,可以同步执行 loop.run_until_complete(self._download_file())
try:
# 使用 asyncio.run 安全地创建和关闭事件循环
asyncio.run(self._download_file())
except Exception:
logger.exception("文件下载失败")
if self.file_ and os.path.exists(self.file_): if self.file_ and os.path.exists(self.file_):
return os.path.abspath(self.file_) return os.path.abspath(self.file_)
except Exception as e:
logger.error(f"文件下载失败: {e}")
return "" return ""
+6 -17
View File
@@ -44,22 +44,6 @@ class PersonaManager:
raise ValueError(f"Persona with ID {persona_id} does not exist.") raise ValueError(f"Persona with ID {persona_id} does not exist.")
return persona return persona
def get_persona_v3_by_id(self, persona_id: str | None) -> Personality | None:
"""Resolve a v3 persona object by id.
- None/empty id returns None.
- "default" maps to in-memory DEFAULT_PERSONALITY.
- Otherwise search in personas_v3 by persona name.
"""
if not persona_id:
return None
if persona_id == "default":
return DEFAULT_PERSONALITY
return next(
(persona for persona in self.personas_v3 if persona["name"] == persona_id),
None,
)
async def get_default_persona_v3( async def get_default_persona_v3(
self, self,
umo: str | MessageSession | None = None, umo: str | MessageSession | None = None,
@@ -70,7 +54,12 @@ class PersonaManager:
"default_personality", "default_personality",
"default", "default",
) )
return self.get_persona_v3_by_id(default_persona_id) or DEFAULT_PERSONALITY if not default_persona_id or default_persona_id == "default":
return DEFAULT_PERSONALITY
try:
return next(p for p in self.personas_v3 if p["name"] == default_persona_id)
except Exception:
return DEFAULT_PERSONALITY
async def resolve_selected_persona( async def resolve_selected_persona(
self, self,
+1 -1
View File
@@ -28,7 +28,7 @@ class RespondStage(Stage):
Comp.At: lambda comp: bool(comp.qq) or bool(comp.name), # @ Comp.At: lambda comp: bool(comp.qq) or bool(comp.name), # @
Comp.Image: lambda comp: bool(comp.file), # 图片 Comp.Image: lambda comp: bool(comp.file), # 图片
Comp.Reply: lambda comp: bool(comp.id) and comp.sender_id is not None, # 回复 Comp.Reply: lambda comp: bool(comp.id) and comp.sender_id is not None, # 回复
Comp.Poke: lambda comp: comp.target_id() is not None, # 戳一戳 Comp.Poke: lambda comp: comp.id != 0 and comp.qq != 0, # 戳一戳
Comp.Node: lambda comp: bool(comp.content), # 转发节点 Comp.Node: lambda comp: bool(comp.content), # 转发节点
Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点 Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点
Comp.File: lambda comp: bool(comp.file_ or comp.url), Comp.File: lambda comp: bool(comp.file_ or comp.url),
@@ -5,7 +5,7 @@ import traceback
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from astrbot.core import file_token_service, html_renderer, logger from astrbot.core import file_token_service, html_renderer, logger
from astrbot.core.message.components import At, Image, Node, Plain, Record, Reply from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
from astrbot.core.message.message_event_result import ResultContentType from astrbot.core.message.message_event_result import ResultContentType
from astrbot.core.pipeline.content_safety_check.stage import ContentSafetyCheckStage from astrbot.core.pipeline.content_safety_check.stage import ContentSafetyCheckStage
from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.platform.astr_message_event import AstrMessageEvent
@@ -383,11 +383,8 @@ class ResultDecorateStage(Stage):
) )
result.chain = [node] result.chain = [node]
# at 回复 / 引用回复仅适用于纯文本或图文消息 has_plain = any(isinstance(item, Plain) for item in result.chain)
can_decorate = all( if has_plain:
isinstance(item, (Plain, Image)) for item in result.chain
)
if can_decorate:
# at 回复 # at 回复
if ( if (
self.reply_with_mention self.reply_with_mention
@@ -402,4 +399,5 @@ class ResultDecorateStage(Stage):
# 引用回复 # 引用回复
if self.reply_with_quote: if self.reply_with_quote:
if not any(isinstance(item, File) for item in result.chain):
result.chain.insert(0, Reply(id=event.message_obj.message_id)) result.chain.insert(0, Reply(id=event.message_obj.message_id))
@@ -6,7 +6,6 @@ from aiocqhttp import CQHttp, Event
from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import ( from astrbot.api.message_components import (
At,
BaseMessageComponent, BaseMessageComponent,
File, File,
Image, Image,
@@ -71,19 +70,11 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
"""解析成 OneBot json 格式""" """解析成 OneBot json 格式"""
ret = [] ret = []
for segment in message_chain.chain: for segment in message_chain.chain:
if isinstance(segment, At): if isinstance(segment, Plain):
# At 组件后插入一个空格,避免与后续文本粘连
d = await AiocqhttpMessageEvent._from_segment_to_dict(segment)
ret.append(d)
ret.append({"type": "text", "data": {"text": " "}})
elif isinstance(segment, Plain):
if not segment.text.strip(): if not segment.text.strip():
continue continue
d = await AiocqhttpMessageEvent._from_segment_to_dict(segment) d = await AiocqhttpMessageEvent._from_segment_to_dict(segment)
ret.append(d) ret.append(d)
else:
d = await AiocqhttpMessageEvent._from_segment_to_dict(segment)
ret.append(d)
return ret return ret
@classmethod @classmethod
@@ -191,7 +191,7 @@ class AiocqhttpAdapter(Platform):
if "sub_type" in event: if "sub_type" in event:
if event["sub_type"] == "poke" and "target_id" in event: if event["sub_type"] == "poke" and "target_id" in event:
abm.message.append(Poke(id=str(event["target_id"]))) abm.message.append(Poke(qq=str(event["target_id"]), type="poke"))
return abm return abm
@@ -11,7 +11,7 @@ from dingtalk_stream import AckMessage
from astrbot import logger from astrbot import logger
from astrbot.api.event import MessageChain from astrbot.api.event import MessageChain
from astrbot.api.message_components import At, File, Image, Plain, Record, Video from astrbot.api.message_components import At, Image, Plain, Record, Video
from astrbot.api.platform import ( from astrbot.api.platform import (
AstrBotMessage, AstrBotMessage,
MessageMember, MessageMember,
@@ -178,110 +178,29 @@ class DingtalkPlatformAdapter(Platform):
abm.session_id = abm.sender.user_id abm.session_id = abm.sender.user_id
message_type: str = cast(str, message.message_type) message_type: str = cast(str, message.message_type)
robot_code = cast(str, message.robot_code or "")
raw_content = cast(dict, message.extensions.get("content") or {})
if not isinstance(raw_content, dict):
raw_content = {}
match message_type: match message_type:
case "text": case "text":
abm.message_str = message.text.content.strip() abm.message_str = message.text.content.strip()
abm.message.append(Plain(abm.message_str)) abm.message.append(Plain(abm.message_str))
case "picture":
if not robot_code:
logger.error("钉钉图片消息解析失败: 回调中缺少 robotCode")
await self._remember_sender_binding(message, abm)
return abm
image_content = cast(
dingtalk_stream.ImageContent | None,
message.image_content,
)
download_code = cast(
str, (image_content.download_code if image_content else "") or ""
)
if not download_code:
logger.warning("钉钉图片消息缺少 downloadCode,已跳过")
else:
f_path = await self.download_ding_file(
download_code,
robot_code,
"jpg",
)
if f_path:
abm.message.append(Image.fromFileSystem(f_path))
else:
logger.warning("钉钉图片消息下载失败,无法解析为图片")
case "richText": case "richText":
rtc: dingtalk_stream.RichTextContent = cast( rtc: dingtalk_stream.RichTextContent = cast(
dingtalk_stream.RichTextContent, message.rich_text_content dingtalk_stream.RichTextContent, message.rich_text_content
) )
contents: list[dict] = cast(list[dict], rtc.rich_text_list) contents: list[dict] = cast(list[dict], rtc.rich_text_list)
plain_parts: list[str] = []
for content in contents: for content in contents:
plains = ""
if "text" in content: if "text" in content:
plain_text = cast(str, content.get("text") or "") plains += content["text"]
if plain_text: abm.message.append(Plain(plains))
plain_parts.append(plain_text)
abm.message.append(Plain(plain_text))
elif "type" in content and content["type"] == "picture": elif "type" in content and content["type"] == "picture":
download_code = cast(str, content.get("downloadCode") or "")
if not download_code:
logger.warning(
"钉钉富文本图片消息缺少 downloadCode,已跳过"
)
continue
if not robot_code:
logger.error(
"钉钉富文本图片消息解析失败: 回调中缺少 robotCode"
)
continue
f_path = await self.download_ding_file( f_path = await self.download_ding_file(
download_code, content["downloadCode"],
robot_code, cast(str, message.robot_code),
"jpg", "jpg",
) )
if f_path:
abm.message.append(Image.fromFileSystem(f_path)) abm.message.append(Image.fromFileSystem(f_path))
abm.message_str = "".join(plain_parts).strip() case "audio":
case "audio" | "voice": pass
download_code = cast(str, raw_content.get("downloadCode") or "")
if not download_code:
logger.warning("钉钉语音消息缺少 downloadCode,已跳过")
elif not robot_code:
logger.error("钉钉语音消息解析失败: 回调中缺少 robotCode")
else:
voice_ext = cast(str, raw_content.get("fileExtension") or "")
if not voice_ext:
voice_ext = "amr"
voice_ext = voice_ext.lstrip(".")
f_path = await self.download_ding_file(
download_code,
robot_code,
voice_ext,
)
if f_path:
abm.message.append(Record.fromFileSystem(f_path))
case "file":
download_code = cast(str, raw_content.get("downloadCode") or "")
if not download_code:
logger.warning("钉钉文件消息缺少 downloadCode,已跳过")
elif not robot_code:
logger.error("钉钉文件消息解析失败: 回调中缺少 robotCode")
else:
file_name = cast(str, raw_content.get("fileName") or "")
file_ext = Path(file_name).suffix.lstrip(".") if file_name else ""
if not file_ext:
file_ext = cast(str, raw_content.get("fileExtension") or "")
if not file_ext:
file_ext = "file"
f_path = await self.download_ding_file(
download_code,
robot_code,
file_ext,
)
if f_path:
if not file_name:
file_name = Path(f_path).name
abm.message.append(File(name=file_name, file=f_path))
await self._remember_sender_binding(message, abm) await self._remember_sender_binding(message, abm)
return abm # 别忘了返回转换后的消息对象 return abm # 别忘了返回转换后的消息对象
@@ -351,23 +270,13 @@ class DingtalkPlatformAdapter(Platform):
) )
return "" return ""
resp_data = await resp.json() resp_data = await resp.json()
download_url = cast( download_url = resp_data["data"]["downloadUrl"]
str,
(
resp_data.get("downloadUrl")
or resp_data.get("data", {}).get("downloadUrl")
or ""
),
)
if not download_url:
logger.error(f"下载钉钉文件失败: 未找到 downloadUrl, 响应: {resp_data}")
return ""
await download_file(download_url, str(f_path)) await download_file(download_url, str(f_path))
return str(f_path) return str(f_path)
async def get_access_token(self) -> str: async def get_access_token(self) -> str:
try: try:
access_token = await asyncio.get_running_loop().run_in_executor( access_token = await asyncio.get_event_loop().run_in_executor(
None, None,
self.client_.get_access_token, self.client_.get_access_token,
) )
@@ -632,28 +541,6 @@ class DingtalkPlatformAdapter(Platform):
self._safe_remove_file(cover_path) self._safe_remove_file(cover_path)
if converted_video: if converted_video:
self._safe_remove_file(video_path) self._safe_remove_file(video_path)
elif isinstance(segment, File):
try:
file_path = await segment.get_file()
if not file_path:
logger.warning("钉钉文件发送失败: 无法解析文件路径")
continue
media_id = await self.upload_media(file_path, "file")
if not media_id:
continue
file_name = segment.name or Path(file_path).name
file_type = Path(file_name).suffix.lstrip(".")
await send_message(
msg_key="sampleFile",
msg_param={
"mediaId": media_id,
"fileName": file_name,
"fileType": file_type,
},
)
except Exception as e:
logger.warning(f"钉钉文件发送失败: {e}")
continue
async def send_message_chain_to_group( async def send_message_chain_to_group(
self, self,
@@ -760,7 +647,7 @@ class DingtalkPlatformAdapter(Platform):
return return
logger.error(f"钉钉机器人启动失败: {e}") logger.error(f"钉钉机器人启动失败: {e}")
loop = asyncio.get_running_loop() loop = asyncio.get_event_loop()
await loop.run_in_executor(None, start_client, loop) await loop.run_in_executor(None, start_client, loop)
async def terminate(self) -> None: async def terminate(self) -> None:
@@ -13,28 +13,11 @@ from astrbot.api.platform import (
PlatformMetadata, PlatformMetadata,
register_platform_adapter, register_platform_adapter,
) )
from astrbot.core.message.components import File, Record, Video
from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.core.platform.astr_message_event import MessageSesion
from .kook_client import KookClient from .kook_client import KookClient
from .kook_config import KookConfig from .kook_config import KookConfig
from .kook_event import KookEvent from .kook_event import KookEvent
from .kook_types import (
ContainerModule,
FileModule,
HeaderModule,
ImageGroupModule,
KmarkdownElement,
KookCardMessageContainer,
KookChannelType,
KookMessageEventData,
KookMessageType,
KookModuleType,
PlainTextElement,
SectionModule,
)
KOOK_AT_SELECTOR_REGEX = re.compile(r"\(met\)([^()]+)\(met\)")
@register_platform_adapter( @register_platform_adapter(
@@ -74,26 +57,35 @@ class KookPlatformAdapter(Platform):
name="kook", description="KOOK 适配器", id=self.kook_config.id name="kook", description="KOOK 适配器", id=self.kook_config.id
) )
def _should_ignore_event_by_bot_nickname(self, author_id: str) -> bool: def _should_ignore_event_by_bot_nickname(self, payload: dict) -> bool:
return self.client.bot_id == author_id bot_nickname = self.kook_config.bot_nickname.strip()
if not bot_nickname:
return False
async def _on_received(self, event: KookMessageEventData): author = payload.get("extra", {}).get("author", {})
logger.debug( if not isinstance(author, dict):
f'[KOOK] 收到来自"{event.channel_type.name}"渠道的消息, 消息类型为: {event.type.name}({event.type.value})' return False
)
event_type = event.type author_nickname = author.get("nickname") or author.get("username") or ""
if event_type in (KookMessageType.KMARKDOWN, KookMessageType.CARD): if not isinstance(author_nickname, str):
if self._should_ignore_event_by_bot_nickname(event.author_id): author_nickname = str(author_nickname)
logger.debug("[KOOK] 收到来自机器人自身的消息, 忽略此消息")
return author_nickname.strip().casefold() == bot_nickname.casefold()
async def _on_received(self, data: dict):
logger.debug(f"KOOK 收到数据: {data}")
if "d" in data and data["s"] == 0:
payload = data["d"]
event_type = payload.get("type")
# 支持type=9(文本)和type=10(卡片)
if event_type in (9, 10):
if self._should_ignore_event_by_bot_nickname(payload):
return return
try: try:
abm = await self.convert_message(event) abm = await self.convert_message(payload)
await self.handle_msg(abm) await self.handle_msg(abm)
except Exception as e: except Exception as e:
logger.error(f"[KOOK] 消息处理异常: {e}") logger.error(f"[KOOK] 消息处理异常: {e}")
elif event_type == KookMessageType.SYSTEM:
logger.debug(f'[KOOK] 消息为系统通知, 通知类型为: "{event.extra.type}"')
logger.debug(f"[KOOK] 原始消息数据: {event.to_json()}")
async def run(self): async def run(self):
"""主运行循环""" """主运行循环"""
@@ -192,26 +184,18 @@ class KookPlatformAdapter(Platform):
logger.info("[KOOK] 资源清理完成") logger.info("[KOOK] 资源清理完成")
def _parse_kmarkdown_text_message( def _parse_kmarkdown_text_message(
self, data: KookMessageEventData, self_id: str self, data: dict, self_id: str
) -> tuple[list, str]: ) -> tuple[list, str]:
kmarkdown = data.extra.kmarkdown kmarkdown = data.get("extra", {}).get("kmarkdown", {})
content = data.content or "" content = data.get("content") or ""
if kmarkdown is None: raw_content = kmarkdown.get("raw_content") or content
logger.error(
f'[KOOK] 无法转换"{KookMessageType.KMARKDOWN.name}"消息, 消息中找不到kmarkdown字段'
)
logger.error(f"[KOOK] 原始消息内容: {data.to_json()}")
return [], ""
raw_content = kmarkdown.raw_content or content
if not isinstance(content, str): if not isinstance(content, str):
content = str(content) content = str(content)
if not isinstance(raw_content, str): if not isinstance(raw_content, str):
raw_content = str(raw_content) raw_content = str(raw_content)
# TODO 后面的pydantic类型替换,以后再来探索吧 :(
mention_name_map: dict[str, str] = {} mention_name_map: dict[str, str] = {}
mention_part = kmarkdown.mention_part mention_part = kmarkdown.get("mention_part", [])
if isinstance(mention_part, list): if isinstance(mention_part, list):
for item in mention_part: for item in mention_part:
if not isinstance(item, dict): if not isinstance(item, dict):
@@ -223,7 +207,7 @@ class KookPlatformAdapter(Platform):
components = [] components = []
cursor = 0 cursor = 0
for match in KOOK_AT_SELECTOR_REGEX.finditer(content): for match in re.finditer(r"\(met\)([^()]+)\(met\)", content):
if match.start() > cursor: if match.start() > cursor:
plain_text = content[cursor : match.start()] plain_text = content[cursor : match.start()]
if plain_text: if plain_text:
@@ -270,109 +254,77 @@ class KookPlatformAdapter(Platform):
return components, message_str return components, message_str
def _parse_card_message(self, data: KookMessageEventData) -> tuple[list, str]: def _parse_card_message(self, data: dict) -> tuple[list, str]:
content = data.content content = data.get("content", "[]")
if not isinstance(content, str): if not isinstance(content, str):
content = str(content) content = str(content)
card_list = json.loads(content)
card_list = KookCardMessageContainer.from_dict(json.loads(content))
text_parts: list[str] = [] text_parts: list[str] = []
images: list[str] = [] images: list[str] = []
files: list[tuple[KookModuleType, str, str]] = []
for card in card_list: for card in card_list:
for module in card.modules: if not isinstance(card, dict):
match module: continue
case SectionModule(): for module in card.get("modules", []):
if content := self._handle_section_text(module): if not isinstance(module, dict):
text_parts.append(content) continue
case ContainerModule() | ImageGroupModule(): module_type = module.get("type")
urls = self._handle_image_group(module) if module_type == "section":
images.extend(urls) section_text = module.get("text", {}).get("content", "")
text_parts.append(" [image]" * len(urls)) if section_text:
text_parts.append(str(section_text))
continue
case HeaderModule(): if module_type != "container":
text_parts.append(module.text.content) continue
case FileModule(): for element in module.get("elements", []):
files.append((module.type, module.title, module.src)) if not isinstance(element, dict):
text_parts.append(f" [{module.type.value}]") continue
if element.get("type") != "image":
continue
case _: image_src = element.get("src")
logger.debug(f"[KOOK] 跳过或未处理模块: {module.type}") if not isinstance(image_src, str):
logger.warning(
f'[KOOK] 处理卡片中的图片时发生错误,图片url "{image_src}" 应该为str类型, 而不是 "{type(image_src)}" '
)
continue
if not image_src.startswith(("http://", "https://")):
logger.warning(f"[KOOK] 屏蔽非http图片url: {image_src}")
continue
images.append(image_src)
text = "".join(text_parts) text = "".join(text_parts)
message = [] message = []
if text: if text:
for search in KOOK_AT_SELECTOR_REGEX.finditer(text):
search_text = search.group(1).strip()
if search_text == "all":
message.append(AtAll())
continue
message.append(At(qq=search_text))
text = text.replace(f"(met){search_text}(met)", "")
message.append(Plain(text=text)) message.append(Plain(text=text))
for img_url in images: for img_url in images:
message.append(Image(file=img_url)) message.append(Image(file=img_url))
for file in files:
file_type = file[0]
file_name = file[1]
file_url = file[2]
if file_type == KookModuleType.FILE:
message.append(File(name=file_name, file=file_url))
elif file_type == KookModuleType.VIDEO:
message.append(Video(file=file_url))
elif file_type == KookModuleType.AUDIO:
message.append(Record(file=file_url))
else:
logger.warning(f"[KOOK] 跳过未知文件类型: {file_type.name}")
return message, text return message, text
def _handle_section_text(self, module: SectionModule) -> str: async def convert_message(self, data: dict) -> AstrBotMessage:
"""专门处理 Section 里的文本提取"""
if isinstance(module.text, (KmarkdownElement, PlainTextElement)):
return module.text.content or ""
return ""
def _handle_image_group(
self, module: ContainerModule | ImageGroupModule
) -> list[str]:
"""专门处理图片组/容器里的合法 URL 提取"""
valid_urls = []
for el in module.elements:
image_src = el.src
if not el.src.startswith(("http://", "https://")):
logger.warning(f"[KOOK] 屏蔽非http图片url: {image_src}")
continue
valid_urls.append(el.src)
return valid_urls
async def convert_message(self, data: KookMessageEventData) -> AstrBotMessage:
abm = AstrBotMessage() abm = AstrBotMessage()
abm.raw_message = data.to_dict() abm.raw_message = data
abm.self_id = self.client.bot_id abm.self_id = self.client.bot_id
channel_type = data.channel_type channel_type = data.get("channel_type")
author_id = data.author_id author_id = data.get("author_id", "unknown")
# channel_type定义: https://developer.kookapp.cn/doc/event/event-introduction # channel_type定义: https://developer.kookapp.cn/doc/event/event-introduction
match channel_type: match channel_type:
case KookChannelType.GROUP: case "GROUP":
session_id = data.target_id or "unknown" session_id = data.get("target_id") or "unknown"
abm.type = MessageType.GROUP_MESSAGE abm.type = MessageType.GROUP_MESSAGE
abm.group_id = session_id abm.group_id = session_id
abm.session_id = session_id abm.session_id = session_id
case KookChannelType.PERSON: case "PERSON":
abm.type = MessageType.FRIEND_MESSAGE abm.type = MessageType.FRIEND_MESSAGE
abm.group_id = "" abm.group_id = ""
abm.session_id = data.author_id or "unknown" abm.session_id = data.get("author_id", "unknown")
case KookChannelType.BROADCAST: case "BROADCAST":
session_id = data.target_id or "unknown" session_id = data.get("target_id") or "unknown"
abm.type = MessageType.OTHER_MESSAGE abm.type = MessageType.OTHER_MESSAGE
abm.group_id = session_id abm.group_id = session_id
abm.session_id = session_id abm.session_id = session_id
@@ -381,25 +333,28 @@ class KookPlatformAdapter(Platform):
abm.sender = MessageMember( abm.sender = MessageMember(
user_id=author_id, user_id=author_id,
nickname=data.extra.author.username if data.extra.author else "unknown", nickname=data.get("extra", {}).get("author", {}).get("username", ""),
) )
abm.message_id = data.msg_id or "unknown" abm.message_id = data.get("msg_id", "unknown")
if data.type == KookMessageType.KMARKDOWN: # 普通文本消息
message, message_str = self._parse_kmarkdown_text_message(data, abm.self_id) if data.get("type") == 9:
message, message_str = self._parse_kmarkdown_text_message(
data, str(abm.self_id)
)
abm.message = message abm.message = message
abm.message_str = message_str abm.message_str = message_str
elif data.type == KookMessageType.CARD: # 卡片消息
elif data.get("type") == 10:
try: try:
abm.message, abm.message_str = self._parse_card_message(data) abm.message, abm.message_str = self._parse_card_message(data)
except Exception as exp: except Exception as exp:
logger.error(f"[KOOK] 卡片消息解析失败: {exp}") logger.error(f"[KOOK] 卡片消息解析失败: {exp}")
logger.error(f"[KOOK] 原始消息内容: {data.to_json()}")
abm.message_str = "[卡片消息解析失败]" abm.message_str = "[卡片消息解析失败]"
abm.message = [Plain(text="[卡片消息解析失败]")] abm.message = [Plain(text="[卡片消息解析失败]")]
else: else:
logger.warning(f'[KOOK] 不支持的kook消息类型: "{data.type.name}"') logger.warning(f'[KOOK] 不支持的kook消息类型: "{data.get("type")}"')
abm.message_str = "[不支持的消息类型]" abm.message_str = "[不支持的消息类型]"
abm.message = [Plain(text="[不支持的消息类型]")] abm.message = [Plain(text="[不支持的消息类型]")]
+55 -102
View File
@@ -1,5 +1,6 @@
import asyncio import asyncio
import base64 import base64
import json
import os import os
import random import random
import time import time
@@ -8,23 +9,13 @@ from pathlib import Path
import aiofiles import aiofiles
import aiohttp import aiohttp
import pydantic
import websockets import websockets
from astrbot import logger from astrbot import logger
from astrbot.core.platform.message_type import MessageType from astrbot.core.platform.message_type import MessageType
from .kook_config import KookConfig from .kook_config import KookConfig
from .kook_types import ( from .kook_types import KookApiPaths, KookMessageType
KookApiPaths,
KookGatewayIndexResponse,
KookHelloEventData,
KookMessageSignal,
KookMessageType,
KookResumeAckEventData,
KookUserMeResponse,
KookWebsocketEvent,
)
class KookClient: class KookClient:
@@ -32,8 +23,7 @@ class KookClient:
# 数据字段 # 数据字段
self.config = config self.config = config
self._bot_id = "" self._bot_id = ""
self._bot_username = "" self._bot_name = ""
self._bot_nickname = ""
# 资源字段 # 资源字段
self._http_client = aiohttp.ClientSession( self._http_client = aiohttp.ClientSession(
@@ -58,50 +48,37 @@ class KookClient:
return self._bot_id return self._bot_id
@property @property
def bot_nickname(self): def bot_name(self):
return self._bot_nickname return self._bot_name
@property async def get_bot_info(self) -> str:
def bot_username(self): """获取机器人账号ID"""
return self._bot_username
async def get_bot_info(self) -> None:
"""获取机器人账号信息"""
url = KookApiPaths.USER_ME url = KookApiPaths.USER_ME
try: try:
async with self._http_client.get(url) as resp: async with self._http_client.get(url) as resp:
if resp.status != 200: if resp.status != 200:
logger.error( logger.error(f"[KOOK] 获取机器人账号ID失败,状态码: {resp.status}")
f"[KOOK] 获取机器人账号信息失败,状态码: {resp.status} , {await resp.text()}" return ""
)
return
try:
resp_content = KookUserMeResponse.from_dict(await resp.json())
except pydantic.ValidationError as e:
logger.error(
f"[KOOK] 获取机器人账号信息失败, 响应数据格式错误: \n{e}"
)
logger.error(f"[KOOK] 响应内容: {await resp.text()}")
return
if not resp_content.success(): data = await resp.json()
logger.error( if data.get("code") != 0:
f"[KOOK] 获取机器人账号信息失败: {resp_content.model_dump_json()}" logger.error(f"[KOOK] 获取机器人账号ID失败: {data}")
) return ""
return
bot_id: str = resp_content.data.id bot_id: str = data["data"]["id"]
self._bot_id = bot_id self._bot_id = bot_id
logger.info(f"[KOOK] 获取机器人账号ID成功: {bot_id}") logger.info(f"[KOOK] 获取机器人账号ID成功: {bot_id}")
self._bot_nickname = resp_content.data.nickname bot_name: str = data["data"]["nickname"] or data["data"]["username"]
self._bot_username = resp_content.data.username self._bot_name = bot_name
logger.info(f"[KOOK] 获取机器人名称成功: {self._bot_nickname}") logger.info(f"[KOOK] 获取机器人名称成功: {self._bot_name}")
return bot_id
except Exception as e: except Exception as e:
logger.error(f"[KOOK] 获取机器人账号信息异常: {e}") logger.error(f"[KOOK] 获取机器人账号ID异常: {e}")
return ""
async def get_gateway_url(self, resume=False, sn=0, session_id=None) -> str | None: async def get_gateway_url(self, resume=False, sn=0, session_id=None):
"""获取网关连接地址""" """获取网关连接地址"""
url = KookApiPaths.GATEWAY_INDEX url = KookApiPaths.GATEWAY_INDEX
@@ -119,20 +96,14 @@ class KookClient:
logger.error(f"[KOOK] 获取gateway失败,状态码: {resp.status}") logger.error(f"[KOOK] 获取gateway失败,状态码: {resp.status}")
return None return None
resp_content = KookGatewayIndexResponse.from_dict(await resp.json()) data = await resp.json()
if not resp_content.success(): if data.get("code") != 0:
logger.error(f"[KOOK] 获取gateway失败: {resp_content}") logger.error(f"[KOOK] 获取gateway失败: {data}")
return None return None
gateway_url: str = resp_content.data.url gateway_url: str = data["data"]["url"]
logger.info(f"[KOOK] 获取gateway成功: {gateway_url.split('?')[0]}") logger.info(f"[KOOK] 获取gateway成功: {gateway_url.split('?')[0]}")
return gateway_url return gateway_url
except pydantic.ValidationError as e:
logger.error(f"[KOOK] 获取gateway失败, 响应数据格式错误: \n{e}")
logger.error(f"[KOOK] 原始响应内容: {await resp.text()}")
return None
except Exception as e: except Exception as e:
logger.error(f"[KOOK] 获取gateway异常: {e}") logger.error(f"[KOOK] 获取gateway异常: {e}")
return None return None
@@ -185,11 +156,7 @@ class KookClient:
try: try:
while self.running: while self.running:
try: try:
if self.ws is None: msg = await asyncio.wait_for(self.ws.recv(), timeout=10) # type: ignore
logger.error("[KOOK] WebSocket 对象丢失,结束监听流程。")
break
msg = await asyncio.wait_for(self.ws.recv(), timeout=10)
if isinstance(msg, bytes): if isinstance(msg, bytes):
try: try:
@@ -199,15 +166,10 @@ class KookClient:
continue continue
msg = msg.decode("utf-8") msg = msg.decode("utf-8")
event = KookWebsocketEvent.from_json(msg) data = json.loads(msg)
# 处理不同类型的信令 # 处理不同类型的信令
await self._handle_signal(event) await self._handle_signal(data)
except pydantic.ValidationError as e:
logger.error(f"[KOOK] 解析WebSocket事件数据格式失败: \n{e}")
logger.error(f"[KOOK] 原始响应内容: {msg}")
continue
except asyncio.TimeoutError: except asyncio.TimeoutError:
# 超时检查,继续循环 # 超时检查,继续循环
@@ -225,41 +187,38 @@ class KookClient:
self.running = False self.running = False
self._stop_event.set() self._stop_event.set()
async def _handle_signal(self, event: KookWebsocketEvent): async def _handle_signal(self, data):
"""处理不同类型的信令""" """处理不同类型的信令"""
data = event.data signal_type = data.get("s")
match event.signal: if signal_type == 0: # 事件消息
case KookMessageSignal.MESSAGE: # 更新消息序号
if event.sn is not None: if "sn" in data:
self.last_sn = event.sn self.last_sn = data["sn"]
await self.event_callback(data) await self.event_callback(data)
case KookMessageSignal.HELLO: elif signal_type == 1: # HELLO握手
assert isinstance(data, KookHelloEventData)
await self._handle_hello(data) await self._handle_hello(data)
case KookMessageSignal.RESUME_ACK: elif signal_type == 3: # PONG心跳响应
assert isinstance(data, KookResumeAckEventData) await self._handle_pong(data)
elif signal_type == 5: # RECONNECT重连指令
await self._handle_reconnect(data)
elif signal_type == 6: # RESUME ACK
await self._handle_resume_ack(data) await self._handle_resume_ack(data)
case KookMessageSignal.PONG: else:
await self._handle_pong() logger.debug(f"[KOOK] 未处理的信令类型: {signal_type}")
case KookMessageSignal.RECONNECT: async def _handle_hello(self, data):
await self._handle_reconnect()
case _:
logger.debug(
f"[KOOK] 未处理的信令类型: {event.signal.name}({event.signal.value})"
)
async def _handle_hello(self, data: KookHelloEventData):
"""处理HELLO握手""" """处理HELLO握手"""
code = data.code hello_data = data.get("d", {})
code = hello_data.get("code", 0)
if code == 0: if code == 0:
self.session_id = data.session_id self.session_id = hello_data.get("session_id")
logger.info(f"[KOOK] 握手成功,session_id: {self.session_id}") logger.info(f"[KOOK] 握手成功,session_id: {self.session_id}")
# TODO 重置重连延迟 # TODO 重置重连延迟
# self.reconnect_delay = 1 # self.reconnect_delay = 1
@@ -269,12 +228,12 @@ class KookClient:
logger.error("[KOOK] Token已过期,需要重新获取") logger.error("[KOOK] Token已过期,需要重新获取")
self.running = False self.running = False
async def _handle_pong(self): async def _handle_pong(self, data):
"""处理PONG心跳响应""" """处理PONG心跳响应"""
self.last_heartbeat_time = time.time() self.last_heartbeat_time = time.time()
self.heartbeat_failed_count = 0 self.heartbeat_failed_count = 0
async def _handle_reconnect(self): async def _handle_reconnect(self, data):
"""处理重连指令""" """处理重连指令"""
logger.warning("[KOOK] 收到重连指令") logger.warning("[KOOK] 收到重连指令")
# 清空本地状态 # 清空本地状态
@@ -282,9 +241,10 @@ class KookClient:
self.session_id = None self.session_id = None
self.running = False self.running = False
async def _handle_resume_ack(self, data: KookResumeAckEventData): async def _handle_resume_ack(self, data):
"""处理RESUME确认""" """处理RESUME确认"""
self.session_id = data.session_id resume_data = data.get("d", {})
self.session_id = resume_data.get("session_id")
logger.info(f"[KOOK] Resume成功,session_id: {self.session_id}") logger.info(f"[KOOK] Resume成功,session_id: {self.session_id}")
async def _heartbeat_loop(self): async def _heartbeat_loop(self):
@@ -332,16 +292,9 @@ class KookClient:
async def _send_ping(self): async def _send_ping(self):
"""发送心跳PING""" """发送心跳PING"""
if self.ws is None:
logger.warning("[KOOK] 尚未连接kook WebSocket服务器, 跳过发送心跳包流程")
return
try: try:
ping_data = KookWebsocketEvent( ping_data = {"s": 2, "sn": self.last_sn}
signal=KookMessageSignal.PING, await self.ws.send(json.dumps(ping_data)) # type: ignore
data=None,
sn=self.last_sn,
)
await self.ws.send(ping_data.to_json())
except Exception as e: except Exception as e:
logger.error(f"[KOOK] 发送心跳失败: {e}") logger.error(f"[KOOK] 发送心跳失败: {e}")
@@ -9,6 +9,7 @@ class KookConfig:
# 基础配置 # 基础配置
token: str token: str
bot_nickname: str = ""
enable: bool = False enable: bool = False
id: str = "kook" id: str = "kook"
@@ -40,6 +41,7 @@ class KookConfig:
# id=config_dict.get("id", "kook"), # id=config_dict.get("id", "kook"),
enable=config_dict.get("enable", False), enable=config_dict.get("enable", False),
token=config_dict.get("kook_bot_token", ""), token=config_dict.get("kook_bot_token", ""),
bot_nickname=config_dict.get("kook_bot_nickname", ""),
reconnect_delay=config_dict.get( reconnect_delay=config_dict.get(
"kook_reconnect_delay", "kook_reconnect_delay",
KookConfig.reconnect_delay, KookConfig.reconnect_delay,
@@ -27,7 +27,6 @@ from .kook_types import (
KookCardMessage, KookCardMessage,
KookCardMessageContainer, KookCardMessageContainer,
KookMessageType, KookMessageType,
KookModuleType,
OrderMessage, OrderMessage,
) )
@@ -112,7 +111,7 @@ class KookEvent(AstrMessageEvent):
KookCardMessage( KookCardMessage(
modules=[ modules=[
FileModule( FileModule(
type=KookModuleType.AUDIO, type="audio",
title=title, title=title,
src=url, src=url,
) )
@@ -183,7 +182,7 @@ class KookEvent(AstrMessageEvent):
if item.reply_id: if item.reply_id:
reply_id = item.reply_id reply_id = item.reply_id
if not item.text: if not item.text:
logger.debug(f'[Kook] 跳过空消息,类型为"{item.type.name}"') logger.debug(f'[Kook] 跳过空消息,类型为"{item.type}"')
continue continue
try: try:
await self.client.send_text( await self.client.send_text(
+55 -319
View File
@@ -1,8 +1,10 @@
import json import json
from enum import IntEnum, StrEnum from dataclasses import field
from typing import Annotated, Any, Literal from enum import IntEnum
from typing import Literal
from pydantic import BaseModel, ConfigDict, Field, model_validator from pydantic import BaseModel, ConfigDict
from pydantic.dataclasses import dataclass
class KookApiPaths: class KookApiPaths:
@@ -23,9 +25,8 @@ class KookApiPaths:
DIRECT_MESSAGE_CREATE = f"{BASE_URL}{API_VERSION_PATH}/direct-message/create" DIRECT_MESSAGE_CREATE = f"{BASE_URL}{API_VERSION_PATH}/direct-message/create"
# 定义参见kook事件结构文档: https://developer.kookapp.cn/doc/event/event-introduction
class KookMessageType(IntEnum): class KookMessageType(IntEnum):
"""定义参见kook事件结构文档: https://developer.kookapp.cn/doc/event/event-introduction"""
TEXT = 1 TEXT = 1
IMAGE = 2 IMAGE = 2
VIDEO = 3 VIDEO = 3
@@ -36,26 +37,6 @@ class KookMessageType(IntEnum):
SYSTEM = 255 SYSTEM = 255
class KookModuleType(StrEnum):
PLAIN_TEXT = "plain-text"
KMARKDOWN = "kmarkdown"
IMAGE = "image"
BUTTON = "button"
HEADER = "header"
SECTION = "section"
IMAGE_GROUP = "image-group"
CONTAINER = "container"
ACTION_GROUP = "action-group"
CONTEXT = "context"
DIVIDER = "divider"
FILE = "file"
AUDIO = "audio"
VIDEO = "video"
COUNTDOWN = "countdown"
INVITE = "invite"
CARD = "card"
ThemeType = Literal[ ThemeType = Literal[
"primary", "success", "danger", "warning", "info", "secondary", "none", "invisible" "primary", "success", "danger", "warning", "info", "secondary", "none", "invisible"
] ]
@@ -67,81 +48,43 @@ SectionMode = Literal["left", "right"]
CountdownMode = Literal["day", "hour", "second"] CountdownMode = Literal["day", "hour", "second"]
class KookBaseDataClass(BaseModel): class KookCardColor(str):
model_config = ConfigDict( """16 进制色值"""
extra="allow",
arbitrary_types_allowed=True,
populate_by_name=True,
)
@classmethod
def from_dict(cls, raw_data: dict):
return cls.model_validate(raw_data)
@classmethod
def from_json(cls, raw_data: str | bytes | bytearray):
return cls.model_validate_json(raw_data)
def to_dict(
self,
mode: Literal["json", "python"] | str = "python",
by_alias=True,
exclude_none=True,
exclude_unset=False,
) -> dict:
return self.model_dump(
by_alias=by_alias,
exclude_none=exclude_none,
mode=mode,
exclude_unset=exclude_unset,
)
def to_json(
self,
indent: int | None = None,
ensure_ascii=False,
by_alias=True,
exclude_none=True,
exclude_unset=False,
) -> str:
return self.model_dump_json(
indent=indent,
ensure_ascii=ensure_ascii,
by_alias=by_alias,
exclude_none=exclude_none,
exclude_unset=exclude_unset,
)
class KookCardModelBase(KookBaseDataClass): class KookCardModelBase:
"""卡片模块基类""" """卡片模块基类"""
type: str type: str
@dataclass
class PlainTextElement(KookCardModelBase): class PlainTextElement(KookCardModelBase):
content: str content: str
type: Literal[KookModuleType.PLAIN_TEXT] = KookModuleType.PLAIN_TEXT type: str = "plain-text"
emoji: bool = True emoji: bool = True
@dataclass
class KmarkdownElement(KookCardModelBase): class KmarkdownElement(KookCardModelBase):
content: str content: str
type: Literal[KookModuleType.KMARKDOWN] = KookModuleType.KMARKDOWN type: str = "kmarkdown"
@dataclass
class ImageElement(KookCardModelBase): class ImageElement(KookCardModelBase):
src: str src: str
type: Literal[KookModuleType.IMAGE] = KookModuleType.IMAGE type: str = "image"
alt: str = "" alt: str = ""
size: SizeType = "lg" size: SizeType = "lg"
circle: bool = False circle: bool = False
fallbackUrl: str | None = None fallbackUrl: str | None = None
@dataclass
class ButtonElement(KookCardModelBase): class ButtonElement(KookCardModelBase):
text: str text: str
type: Literal[KookModuleType.BUTTON] = KookModuleType.BUTTON type: str = "button"
theme: ThemeType = "primary" theme: ThemeType = "primary"
value: str = "" value: str = ""
"""当为 link 时,会跳转到 value 代表的链接; """当为 link 时,会跳转到 value 代表的链接;
@@ -153,88 +96,93 @@ class ButtonElement(KookCardModelBase):
AnyElement = PlainTextElement | KmarkdownElement | ImageElement | ButtonElement | str AnyElement = PlainTextElement | KmarkdownElement | ImageElement | ButtonElement | str
@dataclass
class ParagraphStructure(KookCardModelBase): class ParagraphStructure(KookCardModelBase):
fields: list[PlainTextElement | KmarkdownElement] fields: list[PlainTextElement | KmarkdownElement]
type: Literal["paragraph"] = "paragraph" type: str = "paragraph"
cols: int = 1 cols: int = 1
"""范围是 1-3 , 移动端忽略此参数""" """范围是 1-3 , 移动端忽略此参数"""
@dataclass
class HeaderModule(KookCardModelBase): class HeaderModule(KookCardModelBase):
text: PlainTextElement text: PlainTextElement
type: Literal[KookModuleType.HEADER] = KookModuleType.HEADER type: str = "header"
@dataclass
class SectionModule(KookCardModelBase): class SectionModule(KookCardModelBase):
text: PlainTextElement | KmarkdownElement | ParagraphStructure text: PlainTextElement | KmarkdownElement | ParagraphStructure
type: Literal[KookModuleType.SECTION] = KookModuleType.SECTION type: str = "section"
mode: SectionMode = "left" mode: SectionMode = "left"
accessory: ImageElement | ButtonElement | None = None accessory: ImageElement | ButtonElement | None = None
@dataclass
class ImageGroupModule(KookCardModelBase): class ImageGroupModule(KookCardModelBase):
"""1 到多张图片的组合""" """1 到多张图片的组合"""
elements: list[ImageElement] elements: list[ImageElement]
type: Literal[KookModuleType.IMAGE_GROUP] = KookModuleType.IMAGE_GROUP type: str = "image-group"
@dataclass
class ContainerModule(KookCardModelBase): class ContainerModule(KookCardModelBase):
"""1 到多张图片的组合,与图片组模块(ImageGroupModule)不同,图片并不会裁切为正方形。多张图片会纵向排列。""" """1 到多张图片的组合,与图片组模块(ImageGroupModule)不同,图片并不会裁切为正方形。多张图片会纵向排列。"""
elements: list[ImageElement] elements: list[ImageElement]
type: Literal[KookModuleType.CONTAINER] = KookModuleType.CONTAINER type: str = "container"
@dataclass
class ActionGroupModule(KookCardModelBase): class ActionGroupModule(KookCardModelBase):
"""用来放按钮的模块"""
elements: list[ButtonElement] elements: list[ButtonElement]
type: Literal[KookModuleType.ACTION_GROUP] = KookModuleType.ACTION_GROUP type: str = "action-group"
@dataclass
class ContextModule(KookCardModelBase): class ContextModule(KookCardModelBase):
elements: list[PlainTextElement | KmarkdownElement | ImageElement] elements: list[PlainTextElement | KmarkdownElement | ImageElement]
"""最多包含10个元素""" """最多包含10个元素"""
type: Literal[KookModuleType.CONTEXT] = KookModuleType.CONTEXT type: str = "context"
@dataclass
class DividerModule(KookCardModelBase): class DividerModule(KookCardModelBase):
"""展示分割线用的""" type: str = "divider"
type: Literal[KookModuleType.DIVIDER] = KookModuleType.DIVIDER
@dataclass
class FileModule(KookCardModelBase): class FileModule(KookCardModelBase):
src: str src: str
title: str = "" title: str = ""
type: Literal[KookModuleType.FILE, KookModuleType.AUDIO, KookModuleType.VIDEO] = ( type: Literal["file", "audio", "video"] = "file"
KookModuleType.FILE
)
cover: str | None = None cover: str | None = None
"""cover 仅音频有效, 是音频的封面图""" """cover 仅音频有效, 是音频的封面图"""
@dataclass
class CountdownModule(KookCardModelBase): class CountdownModule(KookCardModelBase):
"""startTime 和 endTime 为毫秒时间戳,startTime 和 endTime 不能小于服务器当前时间戳。""" """startTime 和 endTime 为毫秒时间戳,startTime 和 endTime 不能小于服务器当前时间戳。"""
endTime: int endTime: int
"""毫秒时间戳""" """毫秒时间戳"""
type: Literal[KookModuleType.COUNTDOWN] = KookModuleType.COUNTDOWN type: str = "countdown"
startTime: int | None = None startTime: int | None = None
"""毫秒时间戳, 仅当mode为second才有这个字段""" """毫秒时间戳, 仅当mode为second才有这个字段"""
mode: CountdownMode = "day" mode: CountdownMode = "day"
"""mode 主要是倒计时的样式""" """mode 主要是倒计时的样式"""
@dataclass
class InviteModule(KookCardModelBase): class InviteModule(KookCardModelBase):
code: str code: str
"""邀请链接或者邀请码""" """邀请链接或者邀请码"""
type: Literal[KookModuleType.INVITE] = KookModuleType.INVITE type: str = "invite"
# 所有模块的联合类型 # 所有模块的联合类型
AnyModule = Annotated[ AnyModule = (
HeaderModule HeaderModule
| SectionModule | SectionModule
| ImageGroupModule | ImageGroupModule
@@ -244,29 +192,34 @@ AnyModule = Annotated[
| DividerModule | DividerModule
| FileModule | FileModule
| CountdownModule | CountdownModule
| InviteModule, | InviteModule
Field(discriminator="type"), )
]
class KookCardMessage(KookBaseDataClass): class KookCardMessage(BaseModel):
"""卡片定义文档详见 : https://developer.kookapp.cn/doc/cardmessage """卡片定义文档详见 : https://developer.kookapp.cn/doc/cardmessage
此类型不能直接to_json后发送,因为kook要求卡片容器json顶层必须是**列表** 此类型不能直接to_json后发送,因为kook要求卡片容器json顶层必须是**列表**
若要发送卡片消息请使用KookCardMessageContainer 若要发送卡片消息请使用KookCardMessageContainer
""" """
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
type: Literal[KookModuleType.CARD] = KookModuleType.CARD type: str = "card"
theme: ThemeType | None = None theme: ThemeType | None = None
size: SizeType | None = None size: SizeType | None = None
color: str | None = None color: KookCardColor | None = None
"""16 进制色值""" modules: list[AnyModule] = field(default_factory=list)
modules: list[AnyModule] = Field(default_factory=list)
"""单个 card 模块数量不限制,但是一条消息中所有卡片的模块数量之和最多是 50""" """单个 card 模块数量不限制,但是一条消息中所有卡片的模块数量之和最多是 50"""
def add_module(self, module: AnyModule): def add_module(self, module: AnyModule):
self.modules.append(module) self.modules.append(module)
def to_dict(self, exclude_none: bool = True):
"""exclude_none:去掉值为 None 字段,保留结构"""
return self.model_dump(exclude_none=exclude_none)
def to_json(self, indent: int | None = None, ensure_ascii: bool = True):
return json.dumps(self.to_dict(), indent=indent, ensure_ascii=ensure_ascii)
class KookCardMessageContainer(list[KookCardMessage]): class KookCardMessageContainer(list[KookCardMessage]):
"""卡片消息容器(列表),此类型可以直接to_json后发送出去""" """卡片消息容器(列表),此类型可以直接to_json后发送出去"""
@@ -279,227 +232,10 @@ class KookCardMessageContainer(list[KookCardMessage]):
[i.to_dict() for i in self], indent=indent, ensure_ascii=ensure_ascii [i.to_dict() for i in self], indent=indent, ensure_ascii=ensure_ascii
) )
@classmethod
def from_dict(cls, raw_data: list[dict[str, Any]]):
return cls(KookCardMessage.from_dict(item) for item in raw_data)
@dataclass
class OrderMessage(BaseModel): class OrderMessage:
index: int index: int
text: str text: str
type: KookMessageType type: KookMessageType
reply_id: str | int = "" reply_id: str | int = ""
class KookMessageSignal(IntEnum):
"""KOOK WebSocket 信令类型
ws文档: https://developer.kookapp.cn/doc/websocket""" # noqa: W291
MESSAGE = 0
"""server->client 消息(s包含聊天和通知消息)"""
HELLO = 1
"""server->client 客户端连接 ws 时, 服务端返回握手结果"""
PING = 2
"""client->server 心跳,ping"""
PONG = 3
"""server->client 心跳,pong"""
RESUME = 4
"""client->server resume, 恢复会话"""
RECONNECT = 5
"""server->client reconnect, 要求客户端断开当前连接重新连接"""
RESUME_ACK = 6
"""server->client resume ack"""
class KookChannelType(StrEnum):
GROUP = "GROUP"
PERSON = "PERSON"
BROADCAST = "BROADCAST"
class KookAuthor(KookBaseDataClass):
id: str
username: str
identify_num: str
nickname: str
bot: bool
online: bool
avatar: str | None = None
vip_avatar: str | None = None
status: int
roles: list[int] = Field(default_factory=list)
class KookKMarkdown(KookBaseDataClass):
raw_content: str
mention_part: list[Any] = Field(default_factory=list)
mention_role_part: list[Any] = Field(default_factory=list)
class KookExtra(KookBaseDataClass):
type: int | str
code: str | None = None
body: dict[str, Any] | None = None
author: KookAuthor | None = None
kmarkdown: KookKMarkdown | None = None
last_msg_content: str | None = None
mention: list[str] = Field(default_factory=list)
mention_all: bool = False
mention_here: bool = False
class KookMessageEventData(KookBaseDataClass):
signal: Literal[KookMessageSignal.MESSAGE] = Field(
KookMessageSignal.MESSAGE, exclude=True
)
"""only for type hint"""
channel_type: KookChannelType
type: KookMessageType
target_id: str
author_id: str
content: str | dict[str, Any]
msg_id: str
msg_timestamp: int
nonce: str
from_type: int
extra: KookExtra
class KookHelloEventData(KookBaseDataClass):
signal: Literal[KookMessageSignal.HELLO] = Field(
KookMessageSignal.HELLO, exclude=True
)
"""only for type hint"""
code: int
session_id: str
class KookPingEventData(KookBaseDataClass):
signal: Literal[KookMessageSignal.PING] = Field(
KookMessageSignal.PING, exclude=True
)
"""only for type hint"""
class KookPongEventData(KookBaseDataClass):
signal: Literal[KookMessageSignal.PONG] = Field(
KookMessageSignal.PONG, exclude=True
)
"""only for type hint"""
class KookResumeEventData(KookBaseDataClass):
signal: Literal[KookMessageSignal.RESUME] = Field(
KookMessageSignal.RESUME, exclude=True
)
"""only for type hint"""
class KookReconnectEventData(KookBaseDataClass):
signal: Literal[KookMessageSignal.RECONNECT] = Field(
KookMessageSignal.RECONNECT, exclude=True
)
"""only for type hint"""
code: int
err: str
class KookResumeAckEventData(KookBaseDataClass):
signal: Literal[KookMessageSignal.RESUME_ACK] = Field(
KookMessageSignal.RESUME_ACK, exclude=True
)
"""only for type hint"""
session_id: str
class KookWebsocketEvent(KookBaseDataClass):
"""KOOK WebSocket 原始推送结构"""
signal: KookMessageSignal = Field(
..., validation_alias="s", serialization_alias="s"
)
"""信令类型"""
data: Annotated[
KookMessageEventData
| KookHelloEventData
| KookPingEventData
| KookPongEventData
| KookResumeEventData
| KookReconnectEventData
| KookResumeAckEventData
| None,
Field(discriminator="signal"),
] = Field(None, validation_alias="d", serialization_alias="d")
"""数据事件主体,对应原字段是'd'"""
sn: int | None = None
"""消息序号 , 用来确定消息顺序和ws重连时使用
详见ws连接流程文档: https://developer.kookapp.cn/doc/websocket#%E8%BF%9E%E6%8E%A5%E6%B5%81%E7%A8%8B""" # noqa: W291
@model_validator(mode="before")
@classmethod
def _inject_signal_into_data(cls, data: Any) -> Any:
"""在解析前,把外层的 s 同步到内层的 d 中,供 discriminator 使用"""
if isinstance(data, dict):
s_value = data.get("s")
d_value = data.get("d")
if s_value is not None and isinstance(d_value, dict):
d_value["signal"] = s_value
return data
class KookUserTag(KookBaseDataClass):
color: str
bg_color: str
text: str
class KookApiResponseBase(KookBaseDataClass):
code: int
message: str
data: Any
def success(self) -> bool:
return self.code == 0
class KookUserMeData(KookBaseDataClass):
"""USER_ME 接口返回的 'data' 字段主体"""
id: str
username: str
identify_num: str
nickname: str
bot: bool
online: bool
status: int
bot_status: int
avatar: str
vip_avatar: str | None = None
banner: str | None = None
roles: list[Any] = Field(default_factory=list)
is_vip: bool
vip_amp: bool
wealth_level: int
mobile_verified: bool
client_id: str
tag_info: KookUserTag | None = None
class KookUserMeResponse(KookApiResponseBase):
"""USER_ME 完整响应结构"""
data: KookUserMeData
class KookGatewayIndexData(KookBaseDataClass):
url: str
class KookGatewayIndexResponse(KookApiResponseBase):
"""USER_ME 完整响应结构"""
data: KookGatewayIndexData
@@ -34,7 +34,7 @@ from .server import LarkWebhookServer
@register_platform_adapter( @register_platform_adapter(
"lark", "飞书机器人官方 API 适配器", support_streaming_message=True "lark", "飞书机器人官方 API 适配器", support_streaming_message=False
) )
class LarkPlatformAdapter(Platform): class LarkPlatformAdapter(Platform):
def __init__( def __init__(
@@ -491,7 +491,7 @@ class LarkPlatformAdapter(Platform):
name="lark", name="lark",
description="飞书机器人官方 API 适配器", description="飞书机器人官方 API 适配器",
id=cast(str, self.config.get("id")), id=cast(str, self.config.get("id")),
support_streaming_message=True, support_streaming_message=False,
) )
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1) -> None: async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1) -> None:
@@ -1,4 +1,3 @@
import asyncio
import base64 import base64
import json import json
import os import os
@@ -6,14 +5,6 @@ import uuid
from io import BytesIO from io import BytesIO
import lark_oapi as lark import lark_oapi as lark
from lark_oapi.api.cardkit.v1 import (
ContentCardElementRequest,
ContentCardElementRequestBody,
CreateCardRequest,
CreateCardRequestBody,
SettingsCardRequest,
SettingsCardRequestBody,
)
from lark_oapi.api.im.v1 import ( from lark_oapi.api.im.v1 import (
CreateFileRequest, CreateFileRequest,
CreateFileRequestBody, CreateFileRequestBody,
@@ -37,7 +28,6 @@ from astrbot.core.utils.media_utils import (
convert_video_format, convert_video_format,
get_media_duration, get_media_duration,
) )
from astrbot.core.utils.metrics import Metric
class LarkMessageEvent(AstrMessageEvent): class LarkMessageEvent(AstrMessageEvent):
@@ -565,257 +555,15 @@ class LarkMessageEvent(AstrMessageEvent):
logger.error(f"发送飞书表情回应失败({response.code}): {response.msg}") logger.error(f"发送飞书表情回应失败({response.code}): {response.msg}")
return return
async def _create_streaming_card(self) -> str | None: async def send_streaming(self, generator, use_fallback: bool = False):
"""创建一个开启流式更新模式的卡片实体,返回 card_id。"""
if self.bot.cardkit is None:
logger.error("[Lark] API Client cardkit 模块未初始化")
return None
card_json = {
"schema": "2.0",
"header": {
"title": {"content": "", "tag": "plain_text"},
},
"config": {
"streaming_mode": True,
"summary": {"content": ""},
"streaming_config": {
"print_frequency_ms": {"default": 50},
"print_step": {"default": 2},
"print_strategy": "fast",
},
},
"body": {
"elements": [
{
"tag": "markdown",
"content": "",
"element_id": "markdown_1",
}
]
},
}
request = (
CreateCardRequest.builder()
.request_body(
CreateCardRequestBody.builder()
.type("card_json")
.data(json.dumps(card_json, ensure_ascii=False))
.build()
)
.build()
)
try:
response = await self.bot.cardkit.v1.card.acreate(request)
except Exception as e:
logger.error(f"[Lark] 创建流式卡片实体失败: {e}")
return None
if not response.success():
logger.error(
f"[Lark] 创建流式卡片实体失败({response.code}): {response.msg}"
)
return None
if response.data is None or not response.data.card_id:
logger.error("[Lark] 创建流式卡片实体成功但未返回 card_id")
return None
card_id = response.data.card_id
logger.debug(f"[Lark] 创建流式卡片实体成功: {card_id}")
return card_id
async def _send_card_message(
self,
card_id: str,
reply_message_id: str | None = None,
receive_id: str | None = None,
receive_id_type: str | None = None,
) -> bool:
"""将卡片实体作为 interactive 消息发送。"""
content = json.dumps(
{"type": "card", "data": {"card_id": card_id}},
ensure_ascii=False,
)
return await self._send_im_message(
self.bot,
content=content,
msg_type="interactive",
reply_message_id=reply_message_id,
receive_id=receive_id,
receive_id_type=receive_id_type,
)
async def _update_streaming_text(
self,
card_id: str,
content: str,
sequence: int,
) -> bool:
"""调用 CardKit 流式更新文本接口,向 markdown_1 组件推送全量文本。"""
if self.bot.cardkit is None:
logger.error("[Lark] API Client cardkit 模块未初始化")
return False
request = (
ContentCardElementRequest.builder()
.card_id(card_id)
.element_id("markdown_1")
.request_body(
ContentCardElementRequestBody.builder()
.content(content)
.sequence(sequence)
.uuid(str(uuid.uuid4()))
.build()
)
.build()
)
try:
response = await self.bot.cardkit.v1.card_element.acontent(request)
except Exception as e:
logger.debug(f"[Lark] 流式更新文本失败 (ignored): {e}")
return False
if not response.success():
logger.debug(f"[Lark] 流式更新文本失败({response.code}): {response.msg}")
return False
return True
async def _close_streaming_mode(
self,
card_id: str,
sequence: int,
) -> None:
"""关闭卡片的流式更新模式,使其可正常转发、摘要恢复。"""
if self.bot.cardkit is None:
logger.error("[Lark] API Client cardkit 模块未初始化")
return
settings_json = json.dumps(
{"config": {"streaming_mode": False}},
ensure_ascii=False,
)
request = (
SettingsCardRequest.builder()
.card_id(card_id)
.request_body(
SettingsCardRequestBody.builder()
.settings(settings_json)
.sequence(sequence)
.uuid(str(uuid.uuid4()))
.build()
)
.build()
)
try:
response = await self.bot.cardkit.v1.card.asettings(request)
except Exception as e:
logger.error(f"[Lark] 关闭流式模式失败: {e}")
return
if not response.success():
logger.error(f"[Lark] 关闭流式模式失败({response.code}): {response.msg}")
else:
logger.debug(f"[Lark] 流式模式已关闭: {card_id}")
async def _fallback_send_streaming(self, generator, use_fallback: bool = False):
"""回退到非流式发送:缓冲全部文本后一次性发送,并保留父类副作用。"""
buffer = None buffer = None
async for chain in generator: async for chain in generator:
if not buffer: if not buffer:
buffer = chain buffer = chain
else: else:
buffer.chain.extend(chain.chain) buffer.chain.extend(chain.chain)
if not buffer:
if buffer: return None
buffer.squash_plain() buffer.squash_plain()
await self.send(buffer) await self.send(buffer)
return await super().send_streaming(generator, use_fallback)
await Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
self._has_send_oper = True
async def send_streaming(self, generator, use_fallback: bool = False):
"""使用 CardKit 流式卡片实现打字机效果。
流程创建卡片实体 发送消息 流式更新文本 关闭流式模式
使用解耦发送循环LLM token 到达时只更新 buffer 并唤醒发送协程
发送频率由网络 RTT 自然限流
"""
# Step 1: 创建流式卡片实体
card_id = await self._create_streaming_card()
if not card_id:
logger.warning("[Lark] 无法创建流式卡片,回退到非流式发送")
await self._fallback_send_streaming(generator, use_fallback)
return
# Step 2: 发送卡片消息
sent = await self._send_card_message(
card_id,
reply_message_id=self.message_obj.message_id,
)
if not sent:
logger.error("[Lark] 发送流式卡片消息失败,回退到非流式发送")
await self._fallback_send_streaming(generator, use_fallback)
return
logger.info("[Lark] 流式输出: 使用 CardKit 流式卡片")
# Step 3: 解耦发送循环 (Event-driven, 参考 Telegram Draft 路径)
sequence = 0
delta = ""
last_sent = ""
done = False
text_changed = asyncio.Event()
async def _sender_loop() -> None:
"""信号驱动的文本发送循环,有新内容就发,RTT 自然限流。"""
nonlocal sequence, last_sent
while not done:
await text_changed.wait()
text_changed.clear()
snapshot = delta
if snapshot and snapshot != last_sent:
sequence += 1
ok = await self._update_streaming_text(card_id, snapshot, sequence)
if ok:
last_sent = snapshot
if delta != snapshot:
text_changed.set()
sender_task = asyncio.create_task(_sender_loop())
try:
async for chain in generator:
if not isinstance(chain, MessageChain):
continue
if chain.type == "break":
# 飞书卡片不支持分段,忽略 break
continue
for comp in chain.chain:
if isinstance(comp, Plain):
delta += comp.text
text_changed.set()
finally:
done = True
text_changed.set()
await sender_task
# Step 4: 必要时补发最终文本 + 关闭流式模式
if delta and delta != last_sent:
sequence += 1
await self._update_streaming_text(card_id, delta, sequence)
sequence += 1
await self._close_streaming_mode(card_id, sequence)
# Step 5: 内联父类 send_streaming 的副作用
await Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
self._has_send_oper = True
@@ -18,7 +18,7 @@ from botpy.types.message import MarkdownPayload, Media
from astrbot.api import logger from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import File, Image, Plain, Record, Video from astrbot.api.message_components import Image, Plain, Record
from astrbot.api.platform import AstrBotMessage, PlatformMetadata from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.io import download_image_by_url, file_to_base64 from astrbot.core.utils.io import download_image_by_url, file_to_base64
@@ -47,11 +47,6 @@ _patch_qq_botpy_formdata()
class QQOfficialMessageEvent(AstrMessageEvent): class QQOfficialMessageEvent(AstrMessageEvent):
MARKDOWN_NOT_ALLOWED_ERROR = "不允许发送原生 markdown" MARKDOWN_NOT_ALLOWED_ERROR = "不允许发送原生 markdown"
IMAGE_FILE_TYPE = 1
VIDEO_FILE_TYPE = 2
VOICE_FILE_TYPE = 3
FILE_FILE_TYPE = 4
STREAM_MARKDOWN_NEWLINE_ERROR = "流式消息md分片需要\\n结束"
def __init__( def __init__(
self, self,
@@ -70,71 +65,35 @@ class QQOfficialMessageEvent(AstrMessageEvent):
await self._post_send() await self._post_send()
async def send_streaming(self, generator, use_fallback: bool = False): async def send_streaming(self, generator, use_fallback: bool = False):
"""流式输出仅支持消息列表私聊C2C),其他消息源退化为普通发送""" """流式输出仅支持消息列表私聊"""
# 先标记事件层“已执行发送操作”,避免异常路径遗漏
await super().send_streaming(generator, use_fallback)
# QQ C2C 流式协议:开始/中间分片使用 state=1,结束分片使用 state=10
stream_payload = {"state": 1, "id": None, "index": 0, "reset": False} stream_payload = {"state": 1, "id": None, "index": 0, "reset": False}
last_edit_time = 0 # 上次发送分片的时间 last_edit_time = 0 # 上次编辑消息的时间
throttle_interval = 1 # 分片间最短间隔 (秒) throttle_interval = 1 # 编辑消息的间隔时间 (秒)
ret = None ret = None
source = (
self.message_obj.raw_message
) # 提前获取,避免 generator 为空时 NameError
try: try:
async for chain in generator: async for chain in generator:
source = self.message_obj.raw_message source = self.message_obj.raw_message
if not isinstance(source, botpy.message.C2CMessage):
# 非 C2C 场景:直接累积,最后统一发
if not self.send_buffer:
self.send_buffer = chain
else:
self.send_buffer.chain.extend(chain.chain)
continue
# ---- C2C 流式场景 ----
# tool_call break 信号:工具开始执行,先把已有 buffer 以 state=10 结束当前流式段
if chain.type == "break":
if self.send_buffer:
stream_payload["state"] = 10
ret = await self._post_send(stream=stream_payload)
ret_id = self._extract_response_message_id(ret)
if ret_id is not None:
stream_payload["id"] = ret_id
# 重置 stream_payload,为下一段流式做准备
stream_payload = {
"state": 1,
"id": None,
"index": 0,
"reset": False,
}
last_edit_time = 0
continue
# 累积内容
if not self.send_buffer: if not self.send_buffer:
self.send_buffer = chain self.send_buffer = chain
else: else:
self.send_buffer.chain.extend(chain.chain) self.send_buffer.chain.extend(chain.chain)
# 节流:按时间间隔发送中间分片 if isinstance(source, botpy.message.C2CMessage):
current_time = asyncio.get_running_loop().time() # 真流式传输
if current_time - last_edit_time >= throttle_interval: current_time = asyncio.get_event_loop().time()
time_since_last_edit = current_time - last_edit_time
if time_since_last_edit >= throttle_interval:
ret = cast( ret = cast(
message.Message, message.Message,
await self._post_send(stream=stream_payload), await self._post_send(stream=stream_payload),
) )
stream_payload["index"] += 1 stream_payload["index"] += 1
ret_id = self._extract_response_message_id(ret) stream_payload["id"] = ret["id"]
if ret_id is not None: last_edit_time = asyncio.get_event_loop().time()
stream_payload["id"] = ret_id
last_edit_time = asyncio.get_running_loop().time()
self.send_buffer = None # 清空已发送的分片,避免下次重复发送旧内容
if isinstance(source, botpy.message.C2CMessage): if isinstance(source, botpy.message.C2CMessage):
# 结束流式对话,发送 buffer 中剩余内容 # 结束流式对话,并且传输 buffer 中剩余的消息
stream_payload["state"] = 10 stream_payload["state"] = 10
ret = await self._post_send(stream=stream_payload) ret = await self._post_send(stream=stream_payload)
else: else:
@@ -142,22 +101,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
except Exception as e: except Exception as e:
logger.error(f"发送流式消息时出错: {e}", exc_info=True) logger.error(f"发送流式消息时出错: {e}", exc_info=True)
# 避免累计内容在异常后被整包重复发送:仅清理缓存,不做非流式整包兜底
# 如需兜底,应该只发送未发送 delta(后续可继续优化)
self.send_buffer = None self.send_buffer = None
return None return await super().send_streaming(generator, use_fallback)
@staticmethod
def _extract_response_message_id(ret) -> str | None:
"""兼容 qq-botpy 返回 Message 对象或 dict 两种形态。"""
if ret is None:
return None
if isinstance(ret, dict):
ret_id = ret.get("id")
return str(ret_id) if ret_id is not None else None
ret_id = getattr(ret, "id", None)
return str(ret_id) if ret_id is not None else None
async def _post_send(self, stream: dict | None = None): async def _post_send(self, stream: dict | None = None):
if not self.send_buffer: if not self.send_buffer:
@@ -180,37 +126,16 @@ class QQOfficialMessageEvent(AstrMessageEvent):
image_base64, image_base64,
image_path, image_path,
record_file_path, record_file_path,
video_file_source,
file_source,
file_name,
) = await QQOfficialMessageEvent._parse_to_qqofficial(self.send_buffer) ) = await QQOfficialMessageEvent._parse_to_qqofficial(self.send_buffer)
# C2C 流式仅用于文本分片,富媒体时降级为普通发送,避免平台侧流式校验报错。
if stream and (image_base64 or record_file_path):
logger.debug("[QQOfficial] 检测到富媒体,降级为非流式发送。")
stream = None
if ( if (
not plain_text not plain_text
and not image_base64 and not image_base64
and not image_path and not image_path
and not record_file_path and not record_file_path
and not video_file_source
and not file_source
): ):
return None return None
# QQ C2C 流式 API 说明:
# - 开始/中间分片(state=1):增量追加内容,不需要 \n(加了会导致强制换行)
# - 最终分片(state=10):结束流,content 必须以 \n 结尾(QQ API 要求)
if (
stream
and stream.get("state") == 10
and plain_text
and not plain_text.endswith("\n")
):
plain_text = plain_text + "\n"
payload: dict = { payload: dict = {
# "content": plain_text, # "content": plain_text,
"markdown": MarkdownPayload(content=plain_text) if plain_text else None, "markdown": MarkdownPayload(content=plain_text) if plain_text else None,
@@ -232,7 +157,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
if image_base64: if image_base64:
media = await self.upload_group_and_c2c_image( media = await self.upload_group_and_c2c_image(
image_base64, image_base64,
self.IMAGE_FILE_TYPE, 1,
group_openid=source.group_openid, group_openid=source.group_openid,
) )
payload["media"] = media payload["media"] = media
@@ -240,35 +165,11 @@ class QQOfficialMessageEvent(AstrMessageEvent):
payload.pop("markdown", None) payload.pop("markdown", None)
payload["content"] = plain_text or None payload["content"] = plain_text or None
if record_file_path: # group record msg if record_file_path: # group record msg
media = await self.upload_group_and_c2c_media( media = await self.upload_group_and_c2c_record(
record_file_path, record_file_path,
self.VOICE_FILE_TYPE, 3,
group_openid=source.group_openid, group_openid=source.group_openid,
) )
if media:
payload["media"] = media
payload["msg_type"] = 7
payload.pop("markdown", None)
payload["content"] = plain_text or None
if video_file_source:
media = await self.upload_group_and_c2c_media(
video_file_source,
self.VIDEO_FILE_TYPE,
group_openid=source.group_openid,
)
if media:
payload["media"] = media
payload["msg_type"] = 7
payload.pop("markdown", None)
payload["content"] = plain_text or None
if file_source:
media = await self.upload_group_and_c2c_media(
file_source,
self.FILE_FILE_TYPE,
file_name=file_name,
group_openid=source.group_openid,
)
if media:
payload["media"] = media payload["media"] = media
payload["msg_type"] = 7 payload["msg_type"] = 7
payload.pop("markdown", None) payload.pop("markdown", None)
@@ -280,14 +181,13 @@ class QQOfficialMessageEvent(AstrMessageEvent):
), ),
payload=payload, payload=payload,
plain_text=plain_text, plain_text=plain_text,
stream=stream,
) )
case botpy.message.C2CMessage(): case botpy.message.C2CMessage():
if image_base64: if image_base64:
media = await self.upload_group_and_c2c_image( media = await self.upload_group_and_c2c_image(
image_base64, image_base64,
self.IMAGE_FILE_TYPE, 1,
openid=source.author.user_openid, openid=source.author.user_openid,
) )
payload["media"] = media payload["media"] = media
@@ -295,35 +195,11 @@ class QQOfficialMessageEvent(AstrMessageEvent):
payload.pop("markdown", None) payload.pop("markdown", None)
payload["content"] = plain_text or None payload["content"] = plain_text or None
if record_file_path: # c2c record if record_file_path: # c2c record
media = await self.upload_group_and_c2c_media( media = await self.upload_group_and_c2c_record(
record_file_path, record_file_path,
self.VOICE_FILE_TYPE, 3,
openid=source.author.user_openid, openid=source.author.user_openid,
) )
if media:
payload["media"] = media
payload["msg_type"] = 7
payload.pop("markdown", None)
payload["content"] = plain_text or None
if video_file_source:
media = await self.upload_group_and_c2c_media(
video_file_source,
self.VIDEO_FILE_TYPE,
openid=source.author.user_openid,
)
if media:
payload["media"] = media
payload["msg_type"] = 7
payload.pop("markdown", None)
payload["content"] = plain_text or None
if file_source:
media = await self.upload_group_and_c2c_media(
file_source,
self.FILE_FILE_TYPE,
file_name=file_name,
openid=source.author.user_openid,
)
if media:
payload["media"] = media payload["media"] = media
payload["msg_type"] = 7 payload["msg_type"] = 7
payload.pop("markdown", None) payload.pop("markdown", None)
@@ -337,7 +213,6 @@ class QQOfficialMessageEvent(AstrMessageEvent):
), ),
payload=payload, payload=payload,
plain_text=plain_text, plain_text=plain_text,
stream=stream,
) )
else: else:
ret = await self._send_with_markdown_fallback( ret = await self._send_with_markdown_fallback(
@@ -347,7 +222,6 @@ class QQOfficialMessageEvent(AstrMessageEvent):
), ),
payload=payload, payload=payload,
plain_text=plain_text, plain_text=plain_text,
stream=stream,
) )
logger.debug(f"Message sent to C2C: {ret}") logger.debug(f"Message sent to C2C: {ret}")
@@ -363,7 +237,6 @@ class QQOfficialMessageEvent(AstrMessageEvent):
), ),
payload=payload, payload=payload,
plain_text=plain_text, plain_text=plain_text,
stream=stream,
) )
case botpy.message.DirectMessage(): case botpy.message.DirectMessage():
@@ -378,7 +251,6 @@ class QQOfficialMessageEvent(AstrMessageEvent):
), ),
payload=payload, payload=payload,
plain_text=plain_text, plain_text=plain_text,
stream=stream,
) )
case _: case _:
@@ -395,31 +267,10 @@ class QQOfficialMessageEvent(AstrMessageEvent):
send_func, send_func,
payload: dict, payload: dict,
plain_text: str, plain_text: str,
stream: dict | None = None,
): ):
try: try:
return await send_func(payload) return await send_func(payload)
except botpy.errors.ServerError as err: except botpy.errors.ServerError as err:
# QQ 流式 markdown 分片校验:内容必须以换行结尾。
# 某些边界场景服务端仍可能判定失败,这里做一次修正重试。
if stream and self.STREAM_MARKDOWN_NEWLINE_ERROR in str(err):
retry_payload = payload.copy()
markdown_payload = retry_payload.get("markdown")
if isinstance(markdown_payload, dict):
md_content = cast(str, markdown_payload.get("content", "") or "")
if md_content and not md_content.endswith("\n"):
retry_payload["markdown"] = {"content": md_content + "\n"}
content = cast(str | None, retry_payload.get("content"))
if content and not content.endswith("\n"):
retry_payload["content"] = content + "\n"
logger.warning(
"[QQOfficial] 流式 markdown 分片换行校验失败,已修正后重试一次。"
)
return await send_func(retry_payload)
if ( if (
self.MARKDOWN_NOT_ALLOWED_ERROR not in str(err) self.MARKDOWN_NOT_ALLOWED_ERROR not in str(err)
or not payload.get("markdown") or not payload.get("markdown")
@@ -431,14 +282,10 @@ class QQOfficialMessageEvent(AstrMessageEvent):
"[QQOfficial] markdown 发送被拒绝,回退到 content 模式重试。" "[QQOfficial] markdown 发送被拒绝,回退到 content 模式重试。"
) )
fallback_payload = payload.copy() fallback_payload = payload.copy()
fallback_payload.pop("markdown", None) fallback_payload["markdown"] = None
fallback_payload["content"] = plain_text fallback_payload["content"] = plain_text
if fallback_payload.get("msg_type") == 2: if fallback_payload.get("msg_type") == 2:
fallback_payload["msg_type"] = 0 fallback_payload["msg_type"] = 0
if stream:
fallback_content = cast(str, fallback_payload.get("content") or "")
if fallback_content and not fallback_content.endswith("\n"):
fallback_payload["content"] = fallback_content + "\n"
return await send_func(fallback_payload) return await send_func(fallback_payload)
async def upload_group_and_c2c_image( async def upload_group_and_c2c_image(
@@ -480,19 +327,16 @@ class QQOfficialMessageEvent(AstrMessageEvent):
ttl=result.get("ttl", 0), ttl=result.get("ttl", 0),
) )
async def upload_group_and_c2c_media( async def upload_group_and_c2c_record(
self, self,
file_source: str, file_source: str,
file_type: int, file_type: int,
srv_send_msg: bool = False, srv_send_msg: bool = False,
file_name: str | None = None,
**kwargs, **kwargs,
) -> Media | None: ) -> Media | None:
"""上传媒体文件""" """上传媒体文件"""
# 构建基础payload # 构建基础payload
payload = {"file_type": file_type, "srv_send_msg": srv_send_msg} payload = {"file_type": file_type, "srv_send_msg": srv_send_msg}
if file_name:
payload["file_name"] = file_name
# 处理文件数据 # 处理文件数据
if os.path.exists(file_source): if os.path.exists(file_source):
@@ -556,21 +400,13 @@ class QQOfficialMessageEvent(AstrMessageEvent):
) -> message.Message: ) -> message.Message:
payload = locals() payload = locals()
payload.pop("self", None) payload.pop("self", None)
# QQ API does not accept stream.id=None; remove it when not yet assigned
if "stream" in payload and payload["stream"] is not None:
stream_data = dict(payload["stream"])
if stream_data.get("id") is None:
stream_data.pop("id", None)
payload["stream"] = stream_data
route = Route("POST", "/v2/users/{openid}/messages", openid=openid) route = Route("POST", "/v2/users/{openid}/messages", openid=openid)
result = await self.bot.api._http.request(route, json=payload) result = await self.bot.api._http.request(route, json=payload)
if result is None:
logger.warning("[QQOfficial] post_c2c_message: API 返回 None,跳过本次发送")
return None
if not isinstance(result, dict): if not isinstance(result, dict):
logger.error(f"[QQOfficial] post_c2c_message: 响应不是 dict: {result}") raise RuntimeError(
return None f"Failed to post c2c message, response is not dict: {result}"
)
return message.Message(**result) return message.Message(**result)
@@ -580,9 +416,6 @@ class QQOfficialMessageEvent(AstrMessageEvent):
image_base64 = None # only one img supported image_base64 = None # only one img supported
image_file_path = None image_file_path = None
record_file_path = None record_file_path = None
video_file_source = None
file_source = None
file_name = None
for i in message.chain: for i in message.chain:
if isinstance(i, Plain): if isinstance(i, Plain):
plain_text += i.text plain_text += i.text
@@ -621,30 +454,6 @@ class QQOfficialMessageEvent(AstrMessageEvent):
except Exception as e: except Exception as e:
logger.error(f"处理语音时出错: {e}") logger.error(f"处理语音时出错: {e}")
record_file_path = None record_file_path = None
elif isinstance(i, Video) and not video_file_source:
if i.file.startswith("file:///"):
video_file_source = i.file[8:]
else:
video_file_source = i.file
elif isinstance(i, File) and not file_source:
file_name = i.name
if i.file_:
file_path = i.file_
if file_path.startswith("file:///"):
file_path = file_path[8:]
elif file_path.startswith("file://"):
file_path = file_path[7:]
file_source = file_path
elif i.url:
file_source = i.url
else: else:
logger.debug(f"qq_official 忽略 {i.type}") logger.debug(f"qq_official 忽略 {i.type}")
return ( return plain_text, image_base64, image_file_path, record_file_path
plain_text,
image_base64,
image_file_path,
record_file_path,
video_file_source,
file_source,
file_name,
)
@@ -3,10 +3,8 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import os import os
import random
import time import time
from types import SimpleNamespace from typing import cast
from typing import Any, cast
import botpy import botpy
import botpy.message import botpy.message
@@ -14,7 +12,7 @@ from botpy import Client
from astrbot import logger from astrbot import logger
from astrbot.api.event import MessageChain from astrbot.api.event import MessageChain
from astrbot.api.message_components import At, File, Image, Plain, Record, Video from astrbot.api.message_components import At, File, Image, Plain
from astrbot.api.platform import ( from astrbot.api.platform import (
AstrBotMessage, AstrBotMessage,
MessageMember, MessageMember,
@@ -48,7 +46,6 @@ class botClient(Client):
) )
abm.group_id = cast(str, message.group_openid) abm.group_id = cast(str, message.group_openid)
abm.session_id = abm.group_id abm.session_id = abm.group_id
self.platform.remember_session_scene(abm.session_id, "group")
self._commit(abm) self._commit(abm)
# 收到频道消息 # 收到频道消息
@@ -59,7 +56,6 @@ class botClient(Client):
) )
abm.group_id = message.channel_id abm.group_id = message.channel_id
abm.session_id = abm.group_id abm.session_id = abm.group_id
self.platform.remember_session_scene(abm.session_id, "channel")
self._commit(abm) self._commit(abm)
# 收到私聊消息 # 收到私聊消息
@@ -71,7 +67,6 @@ class botClient(Client):
MessageType.FRIEND_MESSAGE, MessageType.FRIEND_MESSAGE,
) )
abm.session_id = abm.sender.user_id abm.session_id = abm.sender.user_id
self.platform.remember_session_scene(abm.session_id, "friend")
self._commit(abm) self._commit(abm)
# 收到 C2C 消息 # 收到 C2C 消息
@@ -81,11 +76,9 @@ class botClient(Client):
MessageType.FRIEND_MESSAGE, MessageType.FRIEND_MESSAGE,
) )
abm.session_id = abm.sender.user_id abm.session_id = abm.sender.user_id
self.platform.remember_session_scene(abm.session_id, "friend")
self._commit(abm) self._commit(abm)
def _commit(self, abm: AstrBotMessage) -> None: def _commit(self, abm: AstrBotMessage) -> None:
self.platform.remember_session_message_id(abm.session_id, abm.message_id)
self.platform.commit_event( self.platform.commit_event(
QQOfficialMessageEvent( QQOfficialMessageEvent(
abm.message_str, abm.message_str,
@@ -131,9 +124,6 @@ class QQOfficialPlatformAdapter(Platform):
self.client.set_platform(self) self.client.set_platform(self)
self._session_last_message_id: dict[str, str] = {}
self._session_scene: dict[str, str] = {}
self.test_mode = os.environ.get("TEST_MODE", "off") == "on" self.test_mode = os.environ.get("TEST_MODE", "off") == "on"
async def send_by_session( async def send_by_session(
@@ -141,191 +131,14 @@ class QQOfficialPlatformAdapter(Platform):
session: MessageSesion, session: MessageSesion,
message_chain: MessageChain, message_chain: MessageChain,
) -> None: ) -> None:
await self._send_by_session_common(session, message_chain) raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session")
async def _send_by_session_common(
self,
session: MessageSesion,
message_chain: MessageChain,
) -> None:
(
plain_text,
image_base64,
image_path,
record_file_path,
video_file_source,
file_source,
file_name,
) = await QQOfficialMessageEvent._parse_to_qqofficial(message_chain)
if (
not plain_text
and not image_path
and not image_base64
and not record_file_path
and not video_file_source
and not file_source
):
return
msg_id = self._session_last_message_id.get(session.session_id)
if not msg_id:
logger.warning(
"[QQOfficial] No cached msg_id for session: %s, skip send_by_session",
session.session_id,
)
return
payload: dict[str, Any] = {"content": plain_text, "msg_id": msg_id}
ret: Any = None
send_helper = SimpleNamespace(bot=self.client)
if session.message_type == MessageType.GROUP_MESSAGE:
scene = self._session_scene.get(session.session_id)
if scene == "group":
payload["msg_seq"] = random.randint(1, 10000)
if image_base64:
media = await QQOfficialMessageEvent.upload_group_and_c2c_image(
send_helper, # type: ignore
image_base64,
QQOfficialMessageEvent.IMAGE_FILE_TYPE,
group_openid=session.session_id,
)
payload["media"] = media
payload["msg_type"] = 7
if record_file_path:
media = await QQOfficialMessageEvent.upload_group_and_c2c_media(
send_helper, # type: ignore
record_file_path,
QQOfficialMessageEvent.VOICE_FILE_TYPE,
group_openid=session.session_id,
)
if media:
payload["media"] = media
payload["msg_type"] = 7
if video_file_source:
media = await QQOfficialMessageEvent.upload_group_and_c2c_media(
send_helper, # type: ignore
video_file_source,
QQOfficialMessageEvent.VIDEO_FILE_TYPE,
group_openid=session.session_id,
)
if media:
payload["media"] = media
payload["msg_type"] = 7
payload.pop("msg_id", None)
if file_source:
media = await QQOfficialMessageEvent.upload_group_and_c2c_media(
send_helper, # type: ignore
file_source,
QQOfficialMessageEvent.FILE_FILE_TYPE,
file_name=file_name,
group_openid=session.session_id,
)
if media:
payload["media"] = media
payload["msg_type"] = 7
payload.pop("msg_id", None)
ret = await self.client.api.post_group_message(
group_openid=session.session_id,
**payload,
)
else:
if image_path:
payload["file_image"] = image_path
ret = await self.client.api.post_message(
channel_id=session.session_id,
**payload,
)
elif session.message_type == MessageType.FRIEND_MESSAGE:
payload["msg_seq"] = random.randint(1, 10000)
if image_base64:
media = await QQOfficialMessageEvent.upload_group_and_c2c_image(
send_helper, # type: ignore
image_base64,
QQOfficialMessageEvent.IMAGE_FILE_TYPE,
openid=session.session_id,
)
payload["media"] = media
payload["msg_type"] = 7
if record_file_path:
media = await QQOfficialMessageEvent.upload_group_and_c2c_media(
send_helper, # type: ignore
record_file_path,
QQOfficialMessageEvent.VOICE_FILE_TYPE,
openid=session.session_id,
)
if media:
payload["media"] = media
payload["msg_type"] = 7
if video_file_source:
media = await QQOfficialMessageEvent.upload_group_and_c2c_media(
send_helper, # type: ignore
video_file_source,
QQOfficialMessageEvent.VIDEO_FILE_TYPE,
openid=session.session_id,
)
if media:
payload["media"] = media
payload["msg_type"] = 7
# QQ API rejects msg_id for media (video/file) messages sent
# via the proactive tool-call path; remove it to avoid 越权 error.
payload.pop("msg_id", None)
if file_source:
media = await QQOfficialMessageEvent.upload_group_and_c2c_media(
send_helper, # type: ignore
file_source,
QQOfficialMessageEvent.FILE_FILE_TYPE,
file_name=file_name,
openid=session.session_id,
)
if media:
payload["media"] = media
payload["msg_type"] = 7
payload.pop("msg_id", None)
ret = await QQOfficialMessageEvent.post_c2c_message(
send_helper, # type: ignore
openid=session.session_id,
**payload,
)
else:
logger.warning(
"[QQOfficial] Unsupported message type for send_by_session: %s",
session.message_type,
)
return
sent_message_id = self._extract_message_id(ret)
if sent_message_id:
self.remember_session_message_id(session.session_id, sent_message_id)
await super().send_by_session(session, message_chain)
def remember_session_message_id(self, session_id: str, message_id: str) -> None:
if not session_id or not message_id:
return
self._session_last_message_id[session_id] = message_id
def remember_session_scene(self, session_id: str, scene: str) -> None:
if not session_id or not scene:
return
self._session_scene[session_id] = scene
def _extract_message_id(self, ret: Any) -> str | None:
if isinstance(ret, dict):
message_id = ret.get("id")
return str(message_id) if message_id else None
message_id = getattr(ret, "id", None)
if message_id:
return str(message_id)
return None
def meta(self) -> PlatformMetadata: def meta(self) -> PlatformMetadata:
return PlatformMetadata( return PlatformMetadata(
name="qq_official", name="qq_official",
description="QQ 机器人官方 API 适配器", description="QQ 机器人官方 API 适配器",
id=cast(str, self.config.get("id")), id=cast(str, self.config.get("id")),
support_proactive_message=True, support_proactive_message=False,
) )
@staticmethod @staticmethod
@@ -345,10 +158,7 @@ class QQOfficialPlatformAdapter(Platform):
return return
for attachment in attachments: for attachment in attachments:
content_type = cast( content_type = cast(str, getattr(attachment, "content_type", "") or "")
str,
getattr(attachment, "content_type", "") or "",
).lower()
url = QQOfficialPlatformAdapter._normalize_attachment_url( url = QQOfficialPlatformAdapter._normalize_attachment_url(
cast(str | None, getattr(attachment, "url", None)) cast(str | None, getattr(attachment, "url", None))
) )
@@ -364,74 +174,8 @@ class QQOfficialPlatformAdapter(Platform):
or getattr(attachment, "name", None) or getattr(attachment, "name", None)
or "attachment", or "attachment",
) )
ext = os.path.splitext(filename)[1].lower()
image_exts = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"}
audio_exts = {
".mp3",
".wav",
".ogg",
".m4a",
".amr",
".silk",
}
video_exts = {
".mp4",
".mov",
".avi",
".mkv",
".webm",
}
if content_type.startswith("audio") or ext in audio_exts:
msg.append(Record.fromURL(url))
elif content_type.startswith("video") or ext in video_exts:
msg.append(Video.fromURL(url))
elif content_type.startswith("image") or ext in image_exts:
msg.append(Image.fromURL(url))
else:
msg.append(File(name=filename, file=url, url=url)) msg.append(File(name=filename, file=url, url=url))
@staticmethod
def _parse_face_message(content: str) -> str:
"""Parse QQ official face message format and convert to readable text.
QQ official face message format:
<faceType=4,faceId="",ext="eyJ0ZXh0IjoiW+a7oeWktOmXruWPt10ifQ==">
The ext field contains base64-encoded JSON with a 'text' field
describing the emoji (e.g., '[满头问号]').
Args:
content: The message content that may contain face tags.
Returns:
Content with face tags replaced by readable emoji descriptions.
"""
import base64
import json
import re
def replace_face(match):
face_tag = match.group(0)
# Extract ext field from the face tag
ext_match = re.search(r'ext="([^"]*)"', face_tag)
if ext_match:
try:
ext_encoded = ext_match.group(1)
# Decode base64 and parse JSON
ext_decoded = base64.b64decode(ext_encoded).decode("utf-8")
ext_data = json.loads(ext_decoded)
emoji_text = ext_data.get("text", "")
if emoji_text:
return f"[表情:{emoji_text}]"
except Exception:
pass
# Fallback if parsing fails
return "[表情]"
# Match face tags: <faceType=...>
return re.sub(r"<faceType=\d+[^>]*>", replace_face, content)
@staticmethod @staticmethod
def _parse_from_qqofficial( def _parse_from_qqofficial(
message: botpy.message.Message message: botpy.message.Message
@@ -457,10 +201,7 @@ class QQOfficialPlatformAdapter(Platform):
abm.group_id = message.group_openid abm.group_id = message.group_openid
else: else:
abm.sender = MessageMember(message.author.user_openid, "") abm.sender = MessageMember(message.author.user_openid, "")
# Parse face messages to readable text abm.message_str = message.content.strip()
abm.message_str = QQOfficialPlatformAdapter._parse_face_message(
message.content.strip()
)
abm.self_id = "unknown_selfid" abm.self_id = "unknown_selfid"
msg.append(At(qq="qq_official")) msg.append(At(qq="qq_official"))
msg.append(Plain(abm.message_str)) msg.append(Plain(abm.message_str))
@@ -476,12 +217,10 @@ class QQOfficialPlatformAdapter(Platform):
else: else:
abm.self_id = "" abm.self_id = ""
plain_content = QQOfficialPlatformAdapter._parse_face_message( plain_content = message.content.replace(
message.content.replace(
"<@!" + str(abm.self_id) + ">", "<@!" + str(abm.self_id) + ">",
"", "",
).strip() ).strip()
)
QQOfficialPlatformAdapter._append_attachments(msg, message.attachments) QQOfficialPlatformAdapter._append_attachments(msg, message.attachments)
abm.message = msg abm.message = msg
@@ -1,5 +1,7 @@
import asyncio import asyncio
import logging import logging
import random
from types import SimpleNamespace
from typing import Any, cast from typing import Any, cast
import botpy import botpy
@@ -13,6 +15,7 @@ from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.utils.webhook_utils import log_webhook_info from astrbot.core.utils.webhook_utils import log_webhook_info
from ...register import register_platform_adapter from ...register import register_platform_adapter
from ..qqofficial.qqofficial_message_event import QQOfficialMessageEvent
from ..qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter from ..qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter
from .qo_webhook_event import QQOfficialWebhookMessageEvent from .qo_webhook_event import QQOfficialWebhookMessageEvent
from .qo_webhook_server import QQOfficialWebhook from .qo_webhook_server import QQOfficialWebhook
@@ -120,11 +123,95 @@ class QQOfficialWebhookPlatformAdapter(Platform):
session: MessageSesion, session: MessageSesion,
message_chain: MessageChain, message_chain: MessageChain,
) -> None: ) -> None:
await QQOfficialPlatformAdapter._send_by_session_common( (
cast(Any, self), plain_text,
session, image_base64,
message_chain, image_path,
record_file_path,
) = await QQOfficialMessageEvent._parse_to_qqofficial(message_chain)
if not plain_text and not image_path:
return
msg_id = self._session_last_message_id.get(session.session_id)
if not msg_id:
logger.warning(
"[QQOfficialWebhook] No cached msg_id for session: %s, skip send_by_session",
session.session_id,
) )
return
payload: dict[str, Any] = {"content": plain_text, "msg_id": msg_id}
ret: Any = None
send_helper = SimpleNamespace(bot=self.client)
if session.message_type == MessageType.GROUP_MESSAGE:
scene = self._session_scene.get(session.session_id)
if scene == "group":
payload["msg_seq"] = random.randint(1, 10000)
if image_base64:
media = await QQOfficialMessageEvent.upload_group_and_c2c_image(
send_helper, # type: ignore
image_base64,
1,
group_openid=session.session_id,
)
payload["media"] = media
payload["msg_type"] = 7
if record_file_path:
media = await QQOfficialMessageEvent.upload_group_and_c2c_record(
send_helper, # type: ignore
record_file_path,
3,
group_openid=session.session_id,
)
payload["media"] = media
payload["msg_type"] = 7
ret = await self.client.api.post_group_message(
group_openid=session.session_id,
**payload,
)
else:
if image_path:
payload["file_image"] = image_path
ret = await self.client.api.post_message(
channel_id=session.session_id,
**payload,
)
elif session.message_type == MessageType.FRIEND_MESSAGE:
payload["msg_seq"] = random.randint(1, 10000)
if image_base64:
media = await QQOfficialMessageEvent.upload_group_and_c2c_image(
send_helper, # type: ignore
image_base64,
1,
openid=session.session_id,
)
payload["media"] = media
payload["msg_type"] = 7
if record_file_path:
media = await QQOfficialMessageEvent.upload_group_and_c2c_record(
send_helper, # type: ignore
record_file_path,
3,
openid=session.session_id,
)
payload["media"] = media
payload["msg_type"] = 7
ret = await QQOfficialMessageEvent.post_c2c_message(
send_helper, # type: ignore
openid=session.session_id,
**payload,
)
else:
logger.warning(
"[QQOfficialWebhook] Unsupported message type for send_by_session: %s",
session.message_type,
)
return
sent_message_id = self._extract_message_id(ret)
if sent_message_id:
self.remember_session_message_id(session.session_id, sent_message_id)
await super().send_by_session(session, message_chain)
def remember_session_message_id(self, session_id: str, message_id: str) -> None: def remember_session_message_id(self, session_id: str, message_id: str) -> None:
if not session_id or not message_id: if not session_id or not message_id:
@@ -1,6 +1,5 @@
import asyncio import asyncio
import logging import logging
import time
from typing import cast from typing import cast
import quart import quart
@@ -40,9 +39,6 @@ class QQOfficialWebhook:
self.client = botpy_client self.client = botpy_client
self.event_queue = event_queue self.event_queue = event_queue
self.shutdown_event = asyncio.Event() self.shutdown_event = asyncio.Event()
# Deduplication cache for webhook retry callbacks.
self._seen_event_ids: dict[str, float] = {}
self._dedup_ttl: int = 60 # seconds
async def initialize(self) -> None: async def initialize(self) -> None:
logger.info("正在登录到 QQ 官方机器人...") logger.info("正在登录到 QQ 官方机器人...")
@@ -59,7 +55,7 @@ class QQOfficialWebhook:
max_async=1, max_async=1,
connect=bot_connect, connect=bot_connect,
dispatch=self.client.ws_dispatch, dispatch=self.client.ws_dispatch,
loop=asyncio.get_running_loop(), loop=asyncio.get_event_loop(),
api=self.api, api=self.api,
) )
@@ -110,22 +106,6 @@ class QQOfficialWebhook:
print(signed) print(signed)
return signed return signed
event_id = msg.get("id")
if event_id:
now = time.monotonic()
# Lazily evict expired entries to prevent unbounded growth.
expired = [
k
for k, ts in self._seen_event_ids.items()
if now - ts > self._dedup_ttl
]
for k in expired:
del self._seen_event_ids[k]
if event_id in self._seen_event_ids:
logger.debug(f"Duplicate webhook event {event_id!r}, skipping.")
return {"opcode": 12}
self._seen_event_ids[event_id] = now
if event and opcode == BotWebSocket.WS_DISPATCH_EVENT: if event and opcode == BotWebSocket.WS_DISPATCH_EVENT:
event = msg["t"].lower() event = msg["t"].lower()
try: try:
@@ -289,8 +289,8 @@ class TelegramPlatformAdapter(Platform):
else: else:
message.type = MessageType.GROUP_MESSAGE message.type = MessageType.GROUP_MESSAGE
message.group_id = str(update.message.chat.id) message.group_id = str(update.message.chat.id)
if update.message.is_topic_message and update.message.message_thread_id: if update.message.message_thread_id:
# Telegram Topic Group: include thread id to isolate per-topic sessions. # Topic Group
message.group_id += "#" + str(update.message.message_thread_id) message.group_id += "#" + str(update.message.message_thread_id)
message.session_id = message.group_id message.session_id = message.group_id
message.message_id = str(update.message.message_id) message.message_id = str(update.message.message_id)
@@ -1,7 +1,6 @@
import asyncio import asyncio
import os import os
import re import re
from collections.abc import Callable
from typing import Any, cast from typing import Any, cast
import telegramify_markdown import telegramify_markdown
@@ -22,17 +21,6 @@ from astrbot.api.message_components import (
Video, Video,
) )
from astrbot.api.platform import AstrBotMessage, MessageType, PlatformMetadata from astrbot.api.platform import AstrBotMessage, MessageType, PlatformMetadata
from astrbot.core.utils.metrics import Metric
def _is_gif(path: str) -> bool:
if path.lower().endswith(".gif"):
return True
try:
with open(path, "rb") as f:
return f.read(6) in (b"GIF87a", b"GIF89a")
except OSError:
return False
class TelegramPlatformEvent(AstrMessageEvent): class TelegramPlatformEvent(AstrMessageEvent):
@@ -46,20 +34,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
"word": re.compile(r"\s"), "word": re.compile(r"\s"),
} }
# sendMessageDraft 的 draft_id 类级递增计数器
_TELEGRAM_DRAFT_ID_MAX = 2_147_483_647
_next_draft_id: int = 0
@classmethod
def _allocate_draft_id(cls) -> int:
"""分配一个递增的 draft_id,溢出时归 1。"""
cls._next_draft_id = (
1
if cls._next_draft_id >= cls._TELEGRAM_DRAFT_ID_MAX
else cls._next_draft_id + 1
)
return cls._next_draft_id
# 消息类型到 chat action 的映射,用于优先级判断 # 消息类型到 chat action 的映射,用于优先级判断
ACTION_BY_TYPE: dict[type, str] = { ACTION_BY_TYPE: dict[type, str] = {
Record: ChatAction.UPLOAD_VOICE, Record: ChatAction.UPLOAD_VOICE,
@@ -288,6 +262,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
try: try:
md_text = telegramify_markdown.markdownify( md_text = telegramify_markdown.markdownify(
chunk, chunk,
normalize_whitespace=False,
) )
await client.send_message( await client.send_message(
text=md_text, text=md_text,
@@ -301,13 +276,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
await client.send_message(text=chunk, **cast(Any, payload)) await client.send_message(text=chunk, **cast(Any, payload))
elif isinstance(i, Image): elif isinstance(i, Image):
image_path = await i.convert_to_file_path() image_path = await i.convert_to_file_path()
if _is_gif(image_path): await client.send_photo(photo=image_path, **cast(Any, payload))
send_coro = client.send_animation
media_kwarg = {"animation": image_path}
else:
send_coro = client.send_photo
media_kwarg = {"photo": image_path}
await send_coro(**media_kwarg, **cast(Any, payload))
elif isinstance(i, File): elif isinstance(i, File):
path = await i.get_file() path = await i.get_file()
name = i.name or os.path.basename(path) name = i.name or os.path.basename(path)
@@ -370,125 +339,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
except Exception as e: except Exception as e:
logger.error(f"[Telegram] 添加反应失败: {e}") logger.error(f"[Telegram] 添加反应失败: {e}")
async def _send_message_draft(
self,
chat_id: str,
draft_id: int,
text: str,
message_thread_id: str | None = None,
parse_mode: str | None = None,
) -> None:
"""通过 Bot.send_message_draft 发送草稿消息(流式推送部分消息)。
API 仅支持私聊
Args:
chat_id: 目标私聊的 chat_id
draft_id: 草稿唯一标识非零整数相同 draft_id 的变更会以动画展示
text: 消息文本1-4096 字符
message_thread_id: 可选目标消息线程 ID
parse_mode: 可选消息文本的解析模式
"""
kwargs: dict[str, Any] = {}
if message_thread_id:
kwargs["message_thread_id"] = int(message_thread_id)
if parse_mode:
kwargs["parse_mode"] = parse_mode
try:
logger.debug(
f"[Telegram] sendMessageDraft: chat_id={chat_id}, draft_id={draft_id}, text_len={len(text)}"
)
await self.client.send_message_draft(
chat_id=int(chat_id),
draft_id=draft_id,
text=text,
**kwargs,
)
except Exception as e:
logger.warning(f"[Telegram] sendMessageDraft 失败: {e!s}")
async def _process_chain_items(
self,
chain: MessageChain,
payload: dict[str, Any],
user_name: str,
message_thread_id: str | None,
on_text: Callable[[str], None],
) -> None:
"""处理 MessageChain 中的各类组件,文本通过 on_text 回调追加,媒体直接发送。"""
for i in chain.chain:
if isinstance(i, Plain):
on_text(i.text)
elif isinstance(i, Image):
image_path = await i.convert_to_file_path()
if _is_gif(image_path):
action = ChatAction.UPLOAD_VIDEO
send_coro = self.client.send_animation
media_kwarg = {"animation": image_path}
else:
action = ChatAction.UPLOAD_PHOTO
send_coro = self.client.send_photo
media_kwarg = {"photo": image_path}
await self._send_media_with_action(
self.client,
action,
send_coro,
user_name=user_name,
**media_kwarg,
**cast(Any, payload),
)
elif isinstance(i, File):
path = await i.get_file()
name = i.name or os.path.basename(path)
await self._send_media_with_action(
self.client,
ChatAction.UPLOAD_DOCUMENT,
self.client.send_document,
user_name=user_name,
document=path,
filename=name,
**cast(Any, payload),
)
elif isinstance(i, Record):
path = await i.convert_to_file_path()
await self._send_voice_with_fallback(
self.client,
path,
payload,
caption=i.text or None,
user_name=user_name,
message_thread_id=message_thread_id,
use_media_action=True,
)
elif isinstance(i, Video):
path = await i.convert_to_file_path()
await self._send_media_with_action(
self.client,
ChatAction.UPLOAD_VIDEO,
self.client.send_video,
user_name=user_name,
video=path,
**cast(Any, payload),
)
else:
logger.warning(f"不支持的消息类型: {type(i)}")
async def _send_final_segment(self, delta: str, payload: dict[str, Any]) -> None:
"""将累积文本作为 MarkdownV2 真实消息发送,失败时回退到纯文本。"""
try:
markdown_text = telegramify_markdown.markdownify(
delta,
)
await self.client.send_message(
text=markdown_text,
parse_mode="MarkdownV2",
**cast(Any, payload),
)
except Exception as e:
logger.warning(f"Markdown转换失败,使用普通文本: {e!s}")
await self.client.send_message(text=delta, **cast(Any, payload))
async def send_streaming(self, generator, use_fallback: bool = False): async def send_streaming(self, generator, use_fallback: bool = False):
message_thread_id = None message_thread_id = None
@@ -506,137 +356,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
if message_thread_id: if message_thread_id:
payload["message_thread_id"] = message_thread_id payload["message_thread_id"] = message_thread_id
# sendMessageDraft 仅支持私聊(显式检查 FRIEND_MESSAGE
is_private = self.get_message_type() == MessageType.FRIEND_MESSAGE
if is_private:
logger.info("[Telegram] 流式输出: 使用 sendMessageDraft (私聊)")
await self._send_streaming_draft(
user_name, message_thread_id, payload, generator
)
else:
logger.info("[Telegram] 流式输出: 使用 edit_message_text fallback (群聊)")
await self._send_streaming_edit(
user_name, message_thread_id, payload, generator
)
# 内联父类 send_streaming 的副作用(避免传入已消费的 generator)
asyncio.create_task(
Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name),
)
self._has_send_oper = True
async def _send_streaming_draft(
self,
user_name: str,
message_thread_id: str | None,
payload: dict[str, Any],
generator,
) -> None:
"""使用 sendMessageDraft API 进行流式推送(私聊专用)。
流式过程中使用 sendMessageDraft 推送草稿动画
流式结束后发送一条真实消息保留最终内容draft 是临时的会消失
使用信号驱动的发送循环每次有新 token 到达时唤醒发送
发送频率由网络 RTT 自然限制最多一个请求 in-flight
"""
draft_id = self._allocate_draft_id()
delta = ""
last_sent_text = ""
done = False # 信号:生成器已结束
text_changed = asyncio.Event() # 有新 token 到达时触发
async def _draft_sender_loop() -> None:
"""信号驱动的草稿发送循环,有新内容就发,RTT 自然限流。"""
nonlocal last_sent_text
while not done:
await text_changed.wait()
text_changed.clear()
# 发送最新的缓冲区内容(MarkdownV2 渲染,与真实消息一致)
if delta and delta != last_sent_text:
draft_text = delta[: self.MAX_MESSAGE_LENGTH]
if draft_text != last_sent_text:
try:
md = telegramify_markdown.markdownify(
draft_text,
)
await self._send_message_draft(
user_name,
draft_id,
md,
message_thread_id,
parse_mode="MarkdownV2",
)
last_sent_text = draft_text
except Exception:
# markdownify 对未闭合语法可能失败,回退纯文本
try:
await self._send_message_draft(
user_name,
draft_id,
draft_text,
message_thread_id,
)
last_sent_text = draft_text
except Exception as e2:
logger.debug(
f"[Telegram] sendMessageDraft failed (ignored): {e2!s}"
)
sender_task = asyncio.create_task(_draft_sender_loop())
def _append_text(t: str) -> None:
nonlocal delta
delta += t
text_changed.set() # 唤醒发送循环
try:
async for chain in generator:
if not isinstance(chain, MessageChain):
continue
if chain.type == "break":
# 分割符:发送真实消息保留内容,重置缓冲区
if delta:
# 用 emoji 清空 draft 显示,避免 draft 和真实消息同时可见
await self._send_message_draft(
user_name,
draft_id,
"\u23f3",
message_thread_id,
)
await self._send_final_segment(delta, payload)
delta = ""
last_sent_text = ""
draft_id = self._allocate_draft_id()
continue
await self._process_chain_items(
chain, payload, user_name, message_thread_id, _append_text
)
finally:
done = True
text_changed.set() # 唤醒循环使其退出
await sender_task
# 流式结束:用 emoji 清空 draft,然后发真实消息持久化
if delta:
await self._send_message_draft(
user_name,
draft_id,
"\u23f3",
message_thread_id,
)
await self._send_final_segment(delta, payload)
async def _send_streaming_edit(
self,
user_name: str,
message_thread_id: str | None,
payload: dict[str, Any],
generator,
) -> None:
"""使用 send_message + edit_message_text 进行流式推送(群聊 fallback)。"""
delta = "" delta = ""
current_content = "" current_content = ""
message_id = None message_id = None
@@ -647,16 +366,10 @@ class TelegramPlatformEvent(AstrMessageEvent):
# 发送初始 typing 状态 # 发送初始 typing 状态
await self._ensure_typing(user_name, message_thread_id) await self._ensure_typing(user_name, message_thread_id)
last_chat_action_time = asyncio.get_running_loop().time() last_chat_action_time = asyncio.get_event_loop().time()
def _append_text(t: str) -> None:
nonlocal delta
delta += t
async for chain in generator: async for chain in generator:
if not isinstance(chain, MessageChain): if isinstance(chain, MessageChain):
continue
if chain.type == "break": if chain.type == "break":
# 分割符 # 分割符
if message_id: if message_id:
@@ -668,24 +381,78 @@ class TelegramPlatformEvent(AstrMessageEvent):
) )
except Exception as e: except Exception as e:
logger.warning(f"编辑消息失败(streaming-break): {e!s}") logger.warning(f"编辑消息失败(streaming-break): {e!s}")
message_id = None message_id = None # 重置消息 ID
delta = "" delta = "" # 重置 delta
continue continue
await self._process_chain_items( # 处理消息链中的每个组件
chain, payload, user_name, message_thread_id, _append_text for i in chain.chain:
if isinstance(i, Plain):
delta += i.text
elif isinstance(i, Image):
image_path = await i.convert_to_file_path()
await self._send_media_with_action(
self.client,
ChatAction.UPLOAD_PHOTO,
self.client.send_photo,
user_name=user_name,
photo=image_path,
**cast(Any, payload),
) )
continue
elif isinstance(i, File):
path = await i.get_file()
name = i.name or os.path.basename(path)
await self._send_media_with_action(
self.client,
ChatAction.UPLOAD_DOCUMENT,
self.client.send_document,
user_name=user_name,
document=path,
filename=name,
**cast(Any, payload),
)
continue
elif isinstance(i, Record):
path = await i.convert_to_file_path()
await self._send_voice_with_fallback(
self.client,
path,
payload,
caption=i.text or delta or None,
user_name=user_name,
message_thread_id=message_thread_id,
use_media_action=True,
)
continue
elif isinstance(i, Video):
path = await i.convert_to_file_path()
await self._send_media_with_action(
self.client,
ChatAction.UPLOAD_VIDEO,
self.client.send_video,
user_name=user_name,
video=path,
**cast(Any, payload),
)
continue
else:
logger.warning(f"不支持的消息类型: {type(i)}")
continue
# 编辑或发送消息 # Plain
if message_id and len(delta) <= self.MAX_MESSAGE_LENGTH: if message_id and len(delta) <= self.MAX_MESSAGE_LENGTH:
current_time = asyncio.get_running_loop().time() current_time = asyncio.get_event_loop().time()
time_since_last_edit = current_time - last_edit_time time_since_last_edit = current_time - last_edit_time
# 如果距离上次编辑的时间 >= 设定的间隔,等待一段时间
if time_since_last_edit >= throttle_interval: if time_since_last_edit >= throttle_interval:
current_time = asyncio.get_running_loop().time() # 发送 typing 状态(带节流)
current_time = asyncio.get_event_loop().time()
if current_time - last_chat_action_time >= chat_action_interval: if current_time - last_chat_action_time >= chat_action_interval:
await self._ensure_typing(user_name, message_thread_id) await self._ensure_typing(user_name, message_thread_id)
last_chat_action_time = current_time last_chat_action_time = current_time
# 编辑消息
try: try:
await self.client.edit_message_text( await self.client.edit_message_text(
text=delta, text=delta,
@@ -695,9 +462,13 @@ class TelegramPlatformEvent(AstrMessageEvent):
current_content = delta current_content = delta
except Exception as e: except Exception as e:
logger.warning(f"编辑消息失败(streaming): {e!s}") logger.warning(f"编辑消息失败(streaming): {e!s}")
last_edit_time = asyncio.get_running_loop().time() last_edit_time = (
asyncio.get_event_loop().time()
) # 更新上次编辑的时间
else: else:
current_time = asyncio.get_running_loop().time() # delta 长度一般不会大于 4096,因此这里直接发送
# 发送 typing 状态(带节流)
current_time = asyncio.get_event_loop().time()
if current_time - last_chat_action_time >= chat_action_interval: if current_time - last_chat_action_time >= chat_action_interval:
await self._ensure_typing(user_name, message_thread_id) await self._ensure_typing(user_name, message_thread_id)
last_chat_action_time = current_time last_chat_action_time = current_time
@@ -709,13 +480,16 @@ class TelegramPlatformEvent(AstrMessageEvent):
except Exception as e: except Exception as e:
logger.warning(f"发送消息失败(streaming): {e!s}") logger.warning(f"发送消息失败(streaming): {e!s}")
message_id = msg.message_id message_id = msg.message_id
last_edit_time = asyncio.get_running_loop().time() last_edit_time = (
asyncio.get_event_loop().time()
) # 记录初始消息发送时间
try: try:
if delta and current_content != delta: if delta and current_content != delta:
try: try:
markdown_text = telegramify_markdown.markdownify( markdown_text = telegramify_markdown.markdownify(
delta, delta,
normalize_whitespace=False,
) )
await self.client.edit_message_text( await self.client.edit_message_text(
text=markdown_text, text=markdown_text,
@@ -732,3 +506,5 @@ class TelegramPlatformEvent(AstrMessageEvent):
) )
except Exception as e: except Exception as e:
logger.warning(f"编辑消息失败(streaming): {e!s}") logger.warning(f"编辑消息失败(streaming): {e!s}")
return await super().send_streaming(generator, use_fallback)
@@ -200,7 +200,7 @@ class WecomPlatformAdapter(Platform):
return msg_list[-1] return msg_list[-1]
return None return None
msg_new = await asyncio.get_running_loop().run_in_executor( msg_new = await asyncio.get_event_loop().run_in_executor(
None, None,
get_latest_msg_item, get_latest_msg_item,
) )
@@ -261,7 +261,7 @@ class WecomPlatformAdapter(Platform):
@override @override
async def run(self) -> None: async def run(self) -> None:
loop = asyncio.get_running_loop() loop = asyncio.get_event_loop()
if self.kf_name: if self.kf_name:
try: try:
acc_list = ( acc_list = (
@@ -339,7 +339,7 @@ class WecomPlatformAdapter(Platform):
abm.session_id = abm.sender.user_id abm.session_id = abm.sender.user_id
abm.raw_message = msg abm.raw_message = msg
elif isinstance(msg, VoiceMessage): elif isinstance(msg, VoiceMessage):
resp: Response = await asyncio.get_running_loop().run_in_executor( resp: Response = await asyncio.get_event_loop().run_in_executor(
None, None,
self.client.media.download, self.client.media.download,
msg.media_id, msg.media_id,
@@ -395,7 +395,7 @@ class WecomPlatformAdapter(Platform):
abm.message_str = text abm.message_str = text
elif msgtype == "image": elif msgtype == "image":
media_id = msg.get("image", {}).get("media_id", "") media_id = msg.get("image", {}).get("media_id", "")
resp: Response = await asyncio.get_running_loop().run_in_executor( resp: Response = await asyncio.get_event_loop().run_in_executor(
None, None,
self.client.media.download, self.client.media.download,
media_id, media_id,
@@ -407,7 +407,7 @@ class WecomPlatformAdapter(Platform):
abm.message = [Image(file=path, url=path)] abm.message = [Image(file=path, url=path)]
elif msgtype == "voice": elif msgtype == "voice":
media_id = msg.get("voice", {}).get("media_id", "") media_id = msg.get("voice", {}).get("media_id", "")
resp: Response = await asyncio.get_running_loop().run_in_executor( resp: Response = await asyncio.get_event_loop().run_in_executor(
None, None,
self.client.media.download, self.client.media.download,
media_id, media_id,
@@ -1,5 +1,5 @@
"""企业微信智能机器人平台适配器 """企业微信智能机器人平台适配器
基于企业微信智能机器人 API 的消息平台适配器支持 HTTP 回调与长连接 基于企业微信智能机器人 API 的消息平台适配器支持 HTTP 回调
参考webchat_adapter.py的队列机制实现异步消息处理和流式响应 参考webchat_adapter.py的队列机制实现异步消息处理和流式响应
""" """
@@ -31,7 +31,6 @@ from .wecomai_api import (
WecomAIBotStreamMessageBuilder, WecomAIBotStreamMessageBuilder,
) )
from .wecomai_event import WecomAIBotMessageEvent from .wecomai_event import WecomAIBotMessageEvent
from .wecomai_long_connection import WecomAIBotLongConnectionClient
from .wecomai_queue_mgr import WecomAIQueueMgr from .wecomai_queue_mgr import WecomAIQueueMgr
from .wecomai_server import WecomAIBotServer from .wecomai_server import WecomAIBotServer
from .wecomai_utils import ( from .wecomai_utils import (
@@ -79,13 +78,8 @@ class WecomAIBotAdapter(Platform):
self.settings = platform_settings self.settings = platform_settings
# 初始化配置参数 # 初始化配置参数
self.connection_mode = self.config.get( self.token = self.config["token"]
"wecom_ai_bot_connection_mode", "webhook" self.encoding_aes_key = self.config["encoding_aes_key"]
)
self.token = self.config.get("token", self.config.get("wecomaibot_token", ""))
self.encoding_aes_key = self.config.get(
"encoding_aes_key", self.config.get("wecomaibot_encoding_aes_key", "")
)
self.port = int(self.config["port"]) self.port = int(self.config["port"])
self.host = self.config.get("callback_server_host", "0.0.0.0") self.host = self.config.get("callback_server_host", "0.0.0.0")
self.bot_name = self.config.get("wecom_ai_bot_name", "") self.bot_name = self.config.get("wecom_ai_bot_name", "")
@@ -102,46 +96,19 @@ class WecomAIBotAdapter(Platform):
self.only_use_webhook_url_to_send = bool( self.only_use_webhook_url_to_send = bool(
self.config.get("only_use_webhook_url_to_send", False), self.config.get("only_use_webhook_url_to_send", False),
) )
self.long_connection_bot_id = self.config.get(
"wecomaibot_ws_bot_id", self.config.get("long_connection_bot_id", "")
)
self.long_connection_secret = self.config.get(
"wecomaibot_ws_secret", self.config.get("long_connection_secret", "")
)
self.long_connection_ws_url = self.config.get(
"wecomaibot_ws_url",
"wss://openws.work.weixin.qq.com",
)
self.long_connection_heartbeat_interval = int(
self.config.get("wecomaibot_heartbeat_interval", 30),
)
# 平台元数据 # 平台元数据
self.metadata = PlatformMetadata( self.metadata = PlatformMetadata(
name="wecom_ai_bot", name="wecom_ai_bot",
description="企业微信智能机器人适配器,支持 HTTP 回调和长连接模式", description="企业微信智能机器人适配器,支持 HTTP 回调接收消息",
id=self.config.get("id", "wecom_ai_bot"), id=self.config.get("id", "wecom_ai_bot"),
support_proactive_message=bool(self.msg_push_webhook_url), support_proactive_message=bool(self.msg_push_webhook_url),
) )
self.api_client: WecomAIBotAPIClient | None = None # 初始化 API 客户端
self.server: WecomAIBotServer | None = None
self.long_connection_client: WecomAIBotLongConnectionClient | None = None
if self.connection_mode == "long_connection":
if not self.long_connection_bot_id or not self.long_connection_secret:
logger.warning(
"企业微信智能机器人长连接模式缺少 BotID 或 Secret,连接可能失败"
)
self.long_connection_client = WecomAIBotLongConnectionClient(
bot_id=self.long_connection_bot_id,
secret=self.long_connection_secret,
ws_url=self.long_connection_ws_url,
heartbeat_interval=self.long_connection_heartbeat_interval,
message_handler=self._process_long_connection_payload,
)
else:
self.api_client = WecomAIBotAPIClient(self.token, self.encoding_aes_key) self.api_client = WecomAIBotAPIClient(self.token, self.encoding_aes_key)
# 初始化 HTTP 服务器
self.server = WecomAIBotServer( self.server = WecomAIBotServer(
host=self.host, host=self.host,
port=self.port, port=self.port,
@@ -194,9 +161,6 @@ class WecomAIBotAdapter(Platform):
加密后的响应消息无需响应时返回 None 加密后的响应消息无需响应时返回 None
""" """
if not self.api_client:
logger.error("Webhook 消息处理失败: API 客户端未初始化")
return None
msgtype = message_data.get("msgtype") msgtype = message_data.get("msgtype")
if not msgtype: if not msgtype:
logger.warning(f"消息类型未知,忽略: {message_data}") logger.warning(f"消息类型未知,忽略: {message_data}")
@@ -356,98 +320,8 @@ class WecomAIBotAdapter(Platform):
logger.error("处理欢迎消息时发生异常: %s", e) logger.error("处理欢迎消息时发生异常: %s", e)
return None return None
async def _process_long_connection_payload(
self,
payload: dict[str, Any],
) -> None:
"""处理长连接回调消息。"""
cmd = payload.get("cmd")
headers = payload.get("headers") or {}
body = payload.get("body") or {}
req_id = headers.get("req_id")
if not isinstance(body, dict):
return
if cmd == "aibot_msg_callback":
session_id = self._extract_session_id(body)
stream_id = f"{session_id}_{generate_random_string(10)}"
await self._enqueue_message(
body, {"req_id": req_id or ""}, stream_id, session_id
)
self.queue_mgr.set_pending_response(
stream_id,
{
"req_id": req_id or "",
"connection_mode": "long_connection",
},
)
if self.initial_respond_text and req_id:
await self._send_long_connection_respond_msg(
req_id=req_id,
body={
"msgtype": "stream",
"stream": {
"id": stream_id,
"finish": False,
"content": self.initial_respond_text,
},
},
)
return
if cmd == "aibot_event_callback":
event = body.get("event") or {}
event_type = event.get("eventtype")
if (
event_type == "enter_chat"
and self.friend_message_welcome_text
and req_id
):
await self._send_long_connection_respond_welcome(req_id)
elif event_type == "disconnected_event":
logger.warning(
"[WecomAI][LongConn] 收到 disconnected_event,旧连接将被关闭"
)
async def _send_long_connection_respond_welcome(self, req_id: str) -> bool:
client = self.long_connection_client
if not client:
return False
return await client.send_command(
cmd="aibot_respond_welcome_msg",
req_id=req_id,
body={
"msgtype": "text",
"text": {
"content": self.friend_message_welcome_text,
},
},
)
async def _send_long_connection_respond_msg(
self,
req_id: str,
body: dict[str, Any],
) -> bool:
client = self.long_connection_client
if not client:
return False
return await client.send_command(
cmd="aibot_respond_msg",
req_id=req_id,
body=body,
)
def _extract_session_id(self, message_data: dict[str, Any]) -> str: def _extract_session_id(self, message_data: dict[str, Any]) -> str:
"""从消息数据中提取会话ID """从消息数据中提取会话ID"""
群聊使用 chatid单聊使用 userid
"""
chattype = message_data.get("chattype", "single")
if chattype == "group":
chat_id = message_data.get("chatid", "default_group")
return format_session_id("wecomai", chat_id)
else:
user_id = message_data.get("from", {}).get("userid", "default_user") user_id = message_data.get("from", {}).get("userid", "default_user")
return format_session_id("wecomai", user_id) return format_session_id("wecomai", user_id)
@@ -481,16 +355,15 @@ class WecomAIBotAdapter(Platform):
content = "" content = ""
image_base64 = [] image_base64 = []
_img_url_to_process: list[tuple[str, str | None]] = [] _img_url_to_process = []
msg_items = [] msg_items = []
if msgtype == WecomAIBotConstants.MSG_TYPE_TEXT: if msgtype == WecomAIBotConstants.MSG_TYPE_TEXT:
content = WecomAIBotMessageParser.parse_text_message(message_data) content = WecomAIBotMessageParser.parse_text_message(message_data)
elif msgtype == WecomAIBotConstants.MSG_TYPE_IMAGE: elif msgtype == WecomAIBotConstants.MSG_TYPE_IMAGE:
image_payload = message_data.get("image", {}) _img_url_to_process.append(
image_url = image_payload.get("url", "") WecomAIBotMessageParser.parse_image_message(message_data),
if image_url: )
_img_url_to_process.append((image_url, image_payload.get("aeskey")))
elif msgtype == WecomAIBotConstants.MSG_TYPE_MIXED: elif msgtype == WecomAIBotConstants.MSG_TYPE_MIXED:
# 提取混合消息中的文本内容 # 提取混合消息中的文本内容
msg_items = WecomAIBotMessageParser.parse_mixed_message(message_data) msg_items = WecomAIBotMessageParser.parse_mixed_message(message_data)
@@ -501,12 +374,9 @@ class WecomAIBotAdapter(Platform):
if text_content: if text_content:
text_parts.append(text_content) text_parts.append(text_content)
elif item.get("msgtype") == WecomAIBotConstants.MSG_TYPE_IMAGE: elif item.get("msgtype") == WecomAIBotConstants.MSG_TYPE_IMAGE:
image_payload = item.get("image", {}) image_url = item.get("image", {}).get("url", "")
image_url = image_payload.get("url", "")
if image_url: if image_url:
_img_url_to_process.append( _img_url_to_process.append(image_url)
(image_url, image_payload.get("aeskey"))
)
content = " ".join(text_parts) if text_parts else "" content = " ".join(text_parts) if text_parts else ""
else: else:
content = f"[{msgtype}消息]" content = f"[{msgtype}消息]"
@@ -514,8 +384,8 @@ class WecomAIBotAdapter(Platform):
# 并行处理图片下载和解密 # 并行处理图片下载和解密
if _img_url_to_process: if _img_url_to_process:
tasks = [ tasks = [
process_encrypted_image(url, aes_key or self.encoding_aes_key) process_encrypted_image(url, self.encoding_aes_key)
for url, aes_key in _img_url_to_process for url in _img_url_to_process
] ]
results = await asyncio.gather(*tasks) results = await asyncio.gather(*tasks)
for success, result in results: for success, result in results:
@@ -589,28 +459,13 @@ class WecomAIBotAdapter(Platform):
"""运行适配器,同时启动HTTP服务器和队列监听器""" """运行适配器,同时启动HTTP服务器和队列监听器"""
async def run_both() -> None: async def run_both() -> None:
if self.connection_mode == "long_connection":
if not self.long_connection_client:
raise RuntimeError("长连接客户端未初始化")
logger.info(
"启动企业微信智能机器人长连接模式: %s", self.long_connection_ws_url
)
await asyncio.gather(
self.long_connection_client.start(),
self.queue_listener.run(),
)
else:
# 如果启用统一 webhook 模式,则不启动独立服务器 # 如果启用统一 webhook 模式,则不启动独立服务器
webhook_uuid = self.config.get("webhook_uuid") webhook_uuid = self.config.get("webhook_uuid")
if self.unified_webhook_mode and webhook_uuid: if self.unified_webhook_mode and webhook_uuid:
log_webhook_info( log_webhook_info(f"{self.meta().id}(企业微信智能机器人)", webhook_uuid)
f"{self.meta().id}(企业微信智能机器人)", webhook_uuid
)
# 只运行队列监听器 # 只运行队列监听器
await self.queue_listener.run() await self.queue_listener.run()
else: else:
if not self.server:
raise RuntimeError("Webhook 服务器未初始化")
logger.info( logger.info(
"启动企业微信智能机器人适配器,监听 %s:%d", self.host, self.port "启动企业微信智能机器人适配器,监听 %s:%d", self.host, self.port
) )
@@ -624,8 +479,6 @@ class WecomAIBotAdapter(Platform):
async def webhook_callback(self, request: Any) -> Any: async def webhook_callback(self, request: Any) -> Any:
"""统一 Webhook 回调入口""" """统一 Webhook 回调入口"""
if self.connection_mode == "long_connection" or not self.server:
return "long_connection mode does not accept webhook callbacks", 400
# 根据请求方法分发到不同的处理函数 # 根据请求方法分发到不同的处理函数
if request.method == "GET": if request.method == "GET":
return await self.server.handle_verify(request) return await self.server.handle_verify(request)
@@ -636,9 +489,6 @@ class WecomAIBotAdapter(Platform):
"""终止适配器""" """终止适配器"""
logger.info("企业微信智能机器人适配器正在关闭...") logger.info("企业微信智能机器人适配器正在关闭...")
self.shutdown_event.set() self.shutdown_event.set()
if self.long_connection_client:
await self.long_connection_client.shutdown()
if self.server:
await self.server.shutdown() await self.server.shutdown()
def meta(self) -> PlatformMetadata: def meta(self) -> PlatformMetadata:
@@ -657,22 +507,17 @@ class WecomAIBotAdapter(Platform):
queue_mgr=self.queue_mgr, queue_mgr=self.queue_mgr,
webhook_client=self.webhook_client, webhook_client=self.webhook_client,
only_use_webhook_url_to_send=self.only_use_webhook_url_to_send, only_use_webhook_url_to_send=self.only_use_webhook_url_to_send,
long_connection_sender=self._send_long_connection_respond_msg,
) )
message_event.is_at_or_wake_command = (
True # 企业微信智能机器人默认消息都是 at 或唤醒命令
)
message_event.is_wake = True # 企业微信智能机器人消息默认当做唤醒命令处理
self.commit_event(message_event) self.commit_event(message_event)
except Exception as e: except Exception as e:
logger.error("处理消息时发生异常: %s", e) logger.error("处理消息时发生异常: %s", e)
def get_client(self) -> WecomAIBotAPIClient | None: def get_client(self) -> WecomAIBotAPIClient:
"""获取 API 客户端""" """获取 API 客户端"""
return self.api_client return self.api_client
def get_server(self) -> WecomAIBotServer | None: def get_server(self) -> WecomAIBotServer:
"""获取 HTTP 服务器实例""" """获取 HTTP 服务器实例"""
return self.server return self.server
@@ -1,7 +1,5 @@
"""企业微信智能机器人事件处理模块,处理消息事件的发送和接收""" """企业微信智能机器人事件处理模块,处理消息事件的发送和接收"""
from collections.abc import Awaitable, Callable
from astrbot.api import logger from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import At, Image, Plain from astrbot.api.message_components import At, Image, Plain
@@ -20,11 +18,10 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
message_obj, message_obj,
platform_meta, platform_meta,
session_id: str, session_id: str,
api_client: WecomAIBotAPIClient | None, api_client: WecomAIBotAPIClient,
queue_mgr: WecomAIQueueMgr, queue_mgr: WecomAIQueueMgr,
webhook_client: WecomAIBotWebhookClient | None = None, webhook_client: WecomAIBotWebhookClient | None = None,
only_use_webhook_url_to_send: bool = False, only_use_webhook_url_to_send: bool = False,
long_connection_sender: (Callable[[str, dict], Awaitable[bool]] | None) = None,
) -> None: ) -> None:
"""初始化消息事件 """初始化消息事件
@@ -41,7 +38,6 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
self.queue_mgr = queue_mgr self.queue_mgr = queue_mgr
self.webhook_client = webhook_client self.webhook_client = webhook_client
self.only_use_webhook_url_to_send = only_use_webhook_url_to_send self.only_use_webhook_url_to_send = only_use_webhook_url_to_send
self.long_connection_sender = long_connection_sender
async def _mark_stream_complete(self, stream_id: str) -> None: async def _mark_stream_complete(self, stream_id: str) -> None:
back_queue = self.queue_mgr.get_or_create_back_queue(stream_id) back_queue = self.queue_mgr.get_or_create_back_queue(stream_id)
@@ -121,18 +117,6 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
return data return data
@staticmethod
def _extract_plain_text_from_chain(message_chain: MessageChain | None) -> str:
if not message_chain:
return ""
plain_parts: list[str] = []
for comp in message_chain.chain:
if isinstance(comp, At):
plain_parts.append(f"@{comp.name} ")
elif isinstance(comp, Plain):
plain_parts.append(comp.text)
return "".join(plain_parts).strip()
async def send(self, message: MessageChain | None) -> None: async def send(self, message: MessageChain | None) -> None:
"""发送消息""" """发送消息"""
raw = self.message_obj.raw_message raw = self.message_obj.raw_message
@@ -140,44 +124,6 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
"wecom_ai_bot platform event raw_message should be a dict" "wecom_ai_bot platform event raw_message should be a dict"
) )
stream_id = raw.get("stream_id", self.session_id) stream_id = raw.get("stream_id", self.session_id)
pending_response = self.queue_mgr.get_pending_response(stream_id) or {}
connection_mode = pending_response.get("callback_params", {}).get(
"connection_mode"
)
req_id = pending_response.get("callback_params", {}).get("req_id")
if (
connection_mode == "long_connection"
and self.long_connection_sender
and isinstance(req_id, str)
and req_id
):
if self.only_use_webhook_url_to_send and self.webhook_client and message:
await self.webhook_client.send_message_chain(message)
await super().send(MessageChain([]))
return
if self.webhook_client and message:
await self.webhook_client.send_message_chain(
message,
unsupported_only=True,
)
content = self._extract_plain_text_from_chain(message)
await self.long_connection_sender(
req_id,
{
"msgtype": "stream",
"stream": {
"id": stream_id,
"finish": True,
"content": content,
},
},
)
await super().send(MessageChain([]))
return
if self.only_use_webhook_url_to_send and self.webhook_client and message: if self.only_use_webhook_url_to_send and self.webhook_client and message:
await self.webhook_client.send_message_chain(message) await self.webhook_client.send_message_chain(message)
await self._mark_stream_complete(stream_id) await self._mark_stream_complete(stream_id)
@@ -206,77 +152,8 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
"wecom_ai_bot platform event raw_message should be a dict" "wecom_ai_bot platform event raw_message should be a dict"
) )
stream_id = raw.get("stream_id", self.session_id) stream_id = raw.get("stream_id", self.session_id)
pending_response = self.queue_mgr.get_pending_response(stream_id) or {}
connection_mode = pending_response.get("callback_params", {}).get(
"connection_mode"
)
req_id = pending_response.get("callback_params", {}).get("req_id")
back_queue = self.queue_mgr.get_or_create_back_queue(stream_id) back_queue = self.queue_mgr.get_or_create_back_queue(stream_id)
if (
connection_mode == "long_connection"
and self.long_connection_sender
and isinstance(req_id, str)
and req_id
):
if self.only_use_webhook_url_to_send and self.webhook_client:
merged_chain = MessageChain([])
async for chain in generator:
merged_chain.chain.extend(chain.chain)
merged_chain.squash_plain()
await self.webhook_client.send_message_chain(merged_chain)
await self.long_connection_sender(
req_id,
{
"msgtype": "stream",
"stream": {
"id": stream_id,
"finish": True,
"content": "",
},
},
)
await super().send_streaming(generator, use_fallback)
return
increment_plain = ""
async for chain in generator:
if self.webhook_client:
await self.webhook_client.send_message_chain(
chain,
unsupported_only=True,
)
chain.squash_plain()
chunk_text = self._extract_plain_text_from_chain(chain)
if chunk_text:
increment_plain += chunk_text
await self.long_connection_sender(
req_id,
{
"msgtype": "stream",
"stream": {
"id": stream_id,
"finish": False,
"content": increment_plain,
},
},
)
await self.long_connection_sender(
req_id,
{
"msgtype": "stream",
"stream": {
"id": stream_id,
"finish": True,
"content": increment_plain,
},
},
)
await super().send_streaming(generator, use_fallback)
return
if self.only_use_webhook_url_to_send and self.webhook_client: if self.only_use_webhook_url_to_send and self.webhook_client:
merged_chain = MessageChain([]) merged_chain = MessageChain([])
async for chain in generator: async for chain in generator:
@@ -1,236 +0,0 @@
"""企业微信智能机器人长连接客户端。"""
import asyncio
import json
import uuid
from collections.abc import Awaitable, Callable
from typing import Any
import aiohttp
from astrbot.api import logger
class WecomAIBotLongConnectionClient:
"""企业微信智能机器人 WebSocket 长连接客户端。"""
def __init__(
self,
bot_id: str,
secret: str,
ws_url: str,
heartbeat_interval: int,
message_handler: Callable[[dict[str, Any]], Awaitable[None]],
) -> None:
self.bot_id = bot_id
self.secret = secret
self.ws_url = ws_url
self.heartbeat_interval = max(5, int(heartbeat_interval))
self.message_handler = message_handler
self._session: aiohttp.ClientSession | None = None
self._ws: aiohttp.ClientWebSocketResponse | None = None
self._shutdown_event = asyncio.Event()
self._send_lock = asyncio.Lock()
self._command_lock = asyncio.Lock()
self._response_waiters: dict[str, asyncio.Future[dict[str, Any]]] = {}
@staticmethod
def gen_req_id() -> str:
return uuid.uuid4().hex
async def start(self) -> None:
"""启动长连接并自动重连。"""
reconnect_delay = 1
while not self._shutdown_event.is_set():
try:
await self._run_once()
reconnect_delay = 1
except asyncio.CancelledError:
raise
except Exception as e:
logger.error("[WecomAI][LongConn] 长连接异常: %s", e)
if self._shutdown_event.is_set():
break
await asyncio.sleep(reconnect_delay)
reconnect_delay = min(reconnect_delay * 2, 30)
async def _run_once(self) -> None:
timeout = aiohttp.ClientTimeout(total=None, sock_connect=15, sock_read=None)
async with aiohttp.ClientSession(timeout=timeout) as session:
self._session = session
logger.info("[WecomAI][LongConn] 正在连接: %s", self.ws_url)
async with session.ws_connect(
self.ws_url, heartbeat=None, autoping=True
) as ws:
self._ws = ws
await self._subscribe()
logger.info("[WecomAI][LongConn] 订阅成功,已建立长连接")
heartbeat_task = asyncio.create_task(self._heartbeat_loop())
try:
while not self._shutdown_event.is_set():
message = await ws.receive()
if message.type == aiohttp.WSMsgType.TEXT:
await self._handle_text_message(message.data)
elif message.type in {
aiohttp.WSMsgType.CLOSED,
aiohttp.WSMsgType.CLOSE,
aiohttp.WSMsgType.ERROR,
}:
break
finally:
heartbeat_task.cancel()
try:
await heartbeat_task
except asyncio.CancelledError:
pass
self._ws = None
async def _subscribe(self) -> None:
"""发送 aibot_subscribe,并等待响应。"""
req_id = self.gen_req_id()
payload = {
"cmd": "aibot_subscribe",
"headers": {"req_id": req_id},
"body": {"bot_id": self.bot_id, "secret": self.secret},
}
await self._send_json(payload)
if not self._ws:
raise RuntimeError("WebSocket 未建立")
reply = await self._ws.receive(timeout=10)
if reply.type != aiohttp.WSMsgType.TEXT:
raise RuntimeError(f"订阅失败: 非文本响应 {reply.type}")
data = json.loads(reply.data)
if data.get("errcode") != 0:
raise RuntimeError(
f"订阅失败 errcode={data.get('errcode')} errmsg={data.get('errmsg')}"
)
async def _heartbeat_loop(self) -> None:
while not self._shutdown_event.is_set():
await asyncio.sleep(self.heartbeat_interval)
if self._shutdown_event.is_set():
break
try:
await self.send_command("ping", self.gen_req_id(), None)
except Exception as e:
logger.warning("[WecomAI][LongConn] 发送心跳失败: %s", e)
return
async def _handle_text_message(self, text: str) -> None:
try:
payload = json.loads(text)
except json.JSONDecodeError:
logger.warning("[WecomAI][LongConn] 收到非 JSON 消息: %s", text)
return
headers = payload.get("headers") or {}
req_id = headers.get("req_id")
if isinstance(req_id, str):
waiter = self._response_waiters.get(req_id)
if waiter and not waiter.done():
waiter.set_result(payload)
return
cmd = payload.get("cmd")
if cmd in {"aibot_msg_callback", "aibot_event_callback"}:
await self.message_handler(payload)
return
if payload.get("errcode") not in (None, 0):
logger.warning(
"[WecomAI][LongConn] 服务端返回错误: errcode=%s errmsg=%s",
payload.get("errcode"),
payload.get("errmsg"),
)
async def send_command(
self,
cmd: str,
req_id: str,
body: dict[str, Any] | None,
) -> bool:
"""发送长连接命令。"""
headers = {"req_id": req_id}
payload: dict[str, Any] = {"cmd": cmd, "headers": headers}
if body is not None:
payload["body"] = body
async with self._command_lock:
max_retries = 3
for attempt in range(max_retries + 1):
response = await self._send_and_wait_response(req_id, payload)
if not response:
if attempt < max_retries:
await asyncio.sleep(min(0.2 * (2**attempt), 2.0))
continue
return False
errcode = response.get("errcode")
if errcode in (0, None):
return True
if errcode == 6000 and attempt < max_retries:
backoff = min(0.2 * (2**attempt), 2.0)
logger.warning(
"[WecomAI][LongConn] 命令冲突(errcode=6000),将重试。cmd=%s req_id=%s attempt=%d",
cmd,
req_id,
attempt + 1,
)
await asyncio.sleep(backoff)
continue
logger.warning(
"[WecomAI][LongConn] 命令失败: cmd=%s req_id=%s errcode=%s errmsg=%s",
cmd,
req_id,
errcode,
response.get("errmsg"),
)
return False
return False
async def _send_and_wait_response(
self,
req_id: str,
payload: dict[str, Any],
timeout: float = 10.0,
) -> dict[str, Any] | None:
loop = asyncio.get_running_loop()
waiter: asyncio.Future[dict[str, Any]] = loop.create_future()
self._response_waiters[req_id] = waiter
try:
await self._send_json(payload)
return await asyncio.wait_for(waiter, timeout=timeout)
except TimeoutError:
logger.warning(
"[WecomAI][LongConn] 等待命令响应超时: cmd=%s req_id=%s",
payload.get("cmd"),
req_id,
)
return None
finally:
self._response_waiters.pop(req_id, None)
async def _send_json(self, payload: dict[str, Any]) -> None:
ws = self._ws
if ws is None or ws.closed:
raise RuntimeError("长连接尚未建立")
async with self._send_lock:
await ws.send_json(payload)
async def shutdown(self) -> None:
self._shutdown_event.set()
ws = self._ws
if ws is not None and not ws.closed:
await ws.close()
session = self._session
if session is not None and not session.closed:
await session.close()
@@ -4,7 +4,6 @@
""" """
import asyncio import asyncio
import time
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from typing import Any from typing import Any
@@ -83,7 +82,7 @@ class WecomAIQueueMgr:
del self.pending_responses[session_id] del self.pending_responses[session_id]
logger.debug(f"[WecomAI] 移除待处理响应: {session_id}") logger.debug(f"[WecomAI] 移除待处理响应: {session_id}")
if mark_finished: if mark_finished:
self.completed_streams[session_id] = time.monotonic() self.completed_streams[session_id] = asyncio.get_event_loop().time()
logger.debug(f"[WecomAI] 标记流已结束: {session_id}") logger.debug(f"[WecomAI] 标记流已结束: {session_id}")
def remove_queue(self, session_id: str): def remove_queue(self, session_id: str):
@@ -136,7 +135,7 @@ class WecomAIQueueMgr:
""" """
self.pending_responses[session_id] = { self.pending_responses[session_id] = {
"callback_params": callback_params, "callback_params": callback_params,
"timestamp": time.monotonic(), "timestamp": asyncio.get_event_loop().time(),
} }
logger.debug(f"[WecomAI] 设置待处理响应: {session_id}") logger.debug(f"[WecomAI] 设置待处理响应: {session_id}")
@@ -161,7 +160,7 @@ class WecomAIQueueMgr:
finished_at = self.completed_streams.get(session_id) finished_at = self.completed_streams.get(session_id)
if finished_at is None: if finished_at is None:
return False return False
if time.monotonic() - finished_at > max_age_seconds: if asyncio.get_event_loop().time() - finished_at > max_age_seconds:
self.completed_streams.pop(session_id, None) self.completed_streams.pop(session_id, None)
return False return False
return True return True
@@ -173,7 +172,7 @@ class WecomAIQueueMgr:
max_age_seconds: 最大存活时间 max_age_seconds: 最大存活时间
""" """
current_time = time.monotonic() current_time = asyncio.get_event_loop().time()
expired_sessions = [] expired_sessions = []
for session_id, response_data in self.pending_responses.items(): for session_id, response_data in self.pending_responses.items():
@@ -369,7 +369,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
if future: if future:
logger.debug(f"duplicate message id checked: {msg.id}") logger.debug(f"duplicate message id checked: {msg.id}")
else: else:
future = asyncio.get_running_loop().create_future() future = asyncio.get_event_loop().create_future()
self.wexin_event_workers[msg_id] = future self.wexin_event_workers[msg_id] = future
await self.convert_message(msg, future) await self.convert_message(msg, future)
# I love shield so much! # I love shield so much!
@@ -461,7 +461,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
elif msg.type == "voice": elif msg.type == "voice":
assert isinstance(msg, VoiceMessage) assert isinstance(msg, VoiceMessage)
resp: Response = await asyncio.get_running_loop().run_in_executor( resp: Response = await asyncio.get_event_loop().run_in_executor(
None, None,
self.client.media.download, self.client.media.download,
msg.media_id, msg.media_id,
+17 -23
View File
@@ -21,8 +21,8 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
DEFAULT_MCP_CONFIG = {"mcpServers": {}} DEFAULT_MCP_CONFIG = {"mcpServers": {}}
DEFAULT_MCP_INIT_TIMEOUT_SECONDS = 180.0 DEFAULT_MCP_INIT_TIMEOUT_SECONDS = 20.0
DEFAULT_ENABLE_MCP_TIMEOUT_SECONDS = 180.0 DEFAULT_ENABLE_MCP_TIMEOUT_SECONDS = 30.0
MCP_INIT_TIMEOUT_ENV = "ASTRBOT_MCP_INIT_TIMEOUT" MCP_INIT_TIMEOUT_ENV = "ASTRBOT_MCP_INIT_TIMEOUT"
ENABLE_MCP_TIMEOUT_ENV = "ASTRBOT_MCP_ENABLE_TIMEOUT" ENABLE_MCP_TIMEOUT_ENV = "ASTRBOT_MCP_ENABLE_TIMEOUT"
MAX_MCP_TIMEOUT_SECONDS = 300.0 MAX_MCP_TIMEOUT_SECONDS = 300.0
@@ -417,11 +417,9 @@ class FunctionToolManager:
for (name, cfg, _), result in zip(active_configs, results, strict=False): for (name, cfg, _), result in zip(active_configs, results, strict=False):
if isinstance(result, Exception): if isinstance(result, Exception):
if isinstance(result, MCPInitTimeoutError): if isinstance(result, MCPInitTimeoutError):
logger.error( logger.error(f"MCP 服务 {name} 初始化超时({timeout_display}秒)")
f"Connected to MCP server {name} timeout ({timeout_display} seconds)"
)
else: else:
logger.error(f"Failed to initialize MCP server {name}: {result}") logger.error(f"MCP 服务 {name} 初始化失败: {result}")
self._log_safe_mcp_debug_config(cfg) self._log_safe_mcp_debug_config(cfg)
failed_services.append(name) failed_services.append(name)
async with self._runtime_lock: async with self._runtime_lock:
@@ -432,18 +430,16 @@ class FunctionToolManager:
if failed_services: if failed_services:
logger.warning( logger.warning(
f"The following MCP services failed to initialize: {', '.join(failed_services)}. " f"以下 MCP 服务初始化失败: {', '.join(failed_services)}"
f"Please check the mcp_server.json file and server availability." f"请检查配置文件 mcp_server.json 和服务器可用性。"
) )
summary = MCPInitSummary( summary = MCPInitSummary(
total=len(active_configs), success=success_count, failed=failed_services total=len(active_configs), success=success_count, failed=failed_services
) )
logger.info( logger.info(f"MCP 服务初始化完成: {summary.success}/{summary.total} 成功")
f"MCP services initialization completed: {summary.success}/{summary.total} successful, {len(summary.failed)} failed."
)
if summary.total > 0 and summary.success == 0: if summary.total > 0 and summary.success == 0:
msg = "All MCP services failed to initialize, please check the mcp_server.json and server availability." msg = "全部 MCP 服务初始化失败,请检查 mcp_server.json 配置和服务器可用性。"
if raise_on_all_failed: if raise_on_all_failed:
raise MCPAllServicesFailedError(msg) raise MCPAllServicesFailedError(msg)
logger.error(msg) logger.error(msg)
@@ -465,7 +461,7 @@ class FunctionToolManager:
async with self._runtime_lock: async with self._runtime_lock:
if name in self._mcp_server_runtime or name in self._mcp_starting: if name in self._mcp_server_runtime or name in self._mcp_starting:
logger.warning( logger.warning(
f"Connected to MCP server {name}, ignoring this startup request (timeout={timeout:g})." f"MCP 服务 {name} 已在运行,忽略本次启用请求(timeout={timeout:g})。"
) )
self._log_safe_mcp_debug_config(cfg) self._log_safe_mcp_debug_config(cfg)
return return
@@ -482,10 +478,10 @@ class FunctionToolManager:
) )
except asyncio.TimeoutError as exc: except asyncio.TimeoutError as exc:
raise MCPInitTimeoutError( raise MCPInitTimeoutError(
f"Connected to MCP server {name} timeout ({timeout:g} seconds)" f"MCP 服务 {name} 初始化超时({timeout:g} 秒)"
) from exc ) from exc
except Exception: except Exception:
logger.error(f"Failed to initialize MCP client {name}", exc_info=True) logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True)
raise raise
finally: finally:
if mcp_client is None: if mcp_client is None:
@@ -495,9 +491,9 @@ class FunctionToolManager:
async def lifecycle() -> None: async def lifecycle() -> None:
try: try:
await shutdown_event.wait() await shutdown_event.wait()
logger.info(f"Received shutdown signal for MCP client {name}") logger.info(f"收到 MCP 客户端 {name} 终止信号")
except asyncio.CancelledError: except asyncio.CancelledError:
logger.debug(f"MCP client {name} task was cancelled") logger.debug(f"MCP 客户端 {name} 任务被取消")
raise raise
finally: finally:
await self._terminate_mcp_client(name) await self._terminate_mcp_client(name)
@@ -549,7 +545,7 @@ class FunctionToolManager:
if strict: if strict:
raise MCPShutdownTimeoutError(pending_names, timeout) raise MCPShutdownTimeoutError(pending_names, timeout)
logger.warning( logger.warning(
"MCP server shutdown timeout (%s seconds), the following servers were not fully closed: %s", "MCP 服务关闭超时(%s 秒),以下服务未完全关闭:%s",
f"{timeout:g}", f"{timeout:g}",
", ".join(pending_names), ", ".join(pending_names),
) )
@@ -572,9 +568,7 @@ class FunctionToolManager:
try: try:
await mcp_client.cleanup() await mcp_client.cleanup()
except Exception as cleanup_exc: # noqa: BLE001 - only log here except Exception as cleanup_exc: # noqa: BLE001 - only log here
logger.error( logger.error(f"清理 MCP 客户端资源 {name} 失败: {cleanup_exc}")
f"Failed to cleanup MCP client resources {name}: {cleanup_exc}"
)
async def _init_mcp_client(self, name: str, config: dict) -> MCPClient: async def _init_mcp_client(self, name: str, config: dict) -> MCPClient:
"""初始化单个MCP客户端""" """初始化单个MCP客户端"""
@@ -608,7 +602,7 @@ class FunctionToolManager:
) )
self.func_list.append(func_tool) self.func_list.append(func_tool)
logger.info(f"Connected to MCP server {name}, Tools: {tool_names}") logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}")
return mcp_client return mcp_client
async def _terminate_mcp_client(self, name: str) -> None: async def _terminate_mcp_client(self, name: str) -> None:
@@ -628,7 +622,7 @@ class FunctionToolManager:
async with self._runtime_lock: async with self._runtime_lock:
self._mcp_server_runtime.pop(name, None) self._mcp_server_runtime.pop(name, None)
self._mcp_starting.discard(name) self._mcp_starting.discard(name)
logger.info(f"Disconnected from MCP server {name}") logger.info(f"已关闭 MCP 服务 {name}")
return return
# Runtime missing but stale tools may still exist after failed flows. # Runtime missing but stale tools may still exist after failed flows.
+18 -20
View File
@@ -79,7 +79,6 @@ class ProviderManager:
self._provider_change_hooks: list[ self._provider_change_hooks: list[
Callable[[str, ProviderType, str | None], None] Callable[[str, ProviderType, str | None], None]
] = [] ] = []
self._mcp_init_task: asyncio.Task | None = None
def set_provider_change_callback( def set_provider_change_callback(
self, self,
@@ -331,16 +330,24 @@ class ProviderManager:
if not self.curr_tts_provider_inst and self.tts_provider_insts: if not self.curr_tts_provider_inst and self.tts_provider_insts:
self.curr_tts_provider_inst = self.tts_provider_insts[0] self.curr_tts_provider_inst = self.tts_provider_insts[0]
async def _init_mcp_clients_bg() -> None: # 初始化 MCP Client 连接(等待完成以确保工具可用)
try: strict_mcp_init = os.getenv("ASTRBOT_MCP_INIT_STRICT", "").strip().lower() in {
await self.llm_tools.init_mcp_clients() "1",
except Exception: "true",
logger.error("MCP init background task failed", exc_info=True) "yes",
"on",
if self._mcp_init_task is None or self._mcp_init_task.done(): }
self._mcp_init_task = asyncio.create_task( mcp_init_summary = await self.llm_tools.init_mcp_clients(
_init_mcp_clients_bg(), raise_on_all_failed=strict_mcp_init
name="provider-manager:mcp-init", )
if (
mcp_init_summary.total > 0
and mcp_init_summary.success == 0
and not strict_mcp_init
):
logger.warning(
"MCP 服务全部初始化失败,系统将继续启动(可设置 "
"ASTRBOT_MCP_INIT_STRICT=1 以在此场景下中止启动)。"
) )
def dynamic_import_provider(self, type: str) -> None: def dynamic_import_provider(self, type: str) -> None:
@@ -808,17 +815,8 @@ class ProviderManager:
config.save_config() config.save_config()
# load instance # load instance
await self.load_provider(new_config) await self.load_provider(new_config)
# sync in-memory config for API queries (e.g., embedding provider list)
self.providers_config = astrbot_config["provider"]
async def terminate(self) -> None: async def terminate(self) -> None:
if self._mcp_init_task and not self._mcp_init_task.done():
self._mcp_init_task.cancel()
try:
await self._mcp_init_task
except asyncio.CancelledError:
pass
for provider_inst in self.provider_insts: for provider_inst in self.provider_insts:
if hasattr(provider_inst, "terminate"): if hasattr(provider_inst, "terminate"):
await provider_inst.terminate() # type: ignore await provider_inst.terminate() # type: ignore
+1 -18
View File
@@ -281,24 +281,7 @@ class TTSProvider(AbstractProvider):
accumulated_text += text_part accumulated_text += text_part
async def test(self) -> None: async def test(self) -> None:
audio_path = await self.get_audio("hi") await self.get_audio("hi")
# 检查生成的音频文件是否有效
if not os.path.exists(audio_path):
raise Exception("TTS test failed: audio file was not created")
file_size = os.path.getsize(audio_path)
if file_size == 0:
raise Exception(
"TTS test failed: generated audio file is empty (0 bytes). "
"Please check your TTS provider configuration, especially required parameters like group_id for MiniMax."
)
# 清理测试文件
try:
os.remove(audio_path)
except Exception:
pass
class EmbeddingProvider(AbstractProvider): class EmbeddingProvider(AbstractProvider):
@@ -276,24 +276,9 @@ class ProviderAnthropic(Provider):
llm_response.id = completion.id llm_response.id = completion.id
llm_response.usage = self._extract_usage(completion.usage) llm_response.usage = self._extract_usage(completion.usage)
# Handle cases where completion only contains ThinkingBlock (e.g., MiniMax max_tokens) # TODO(Soulter): 处理 end_turn 情况
# When stop_reason='max_tokens', the model may return only thinking content
# This is valid and should not raise an exception
if not llm_response.completion_text and not llm_response.tools_call_args: if not llm_response.completion_text and not llm_response.tools_call_args:
# Guard clause: raise early if no valid content at all raise Exception(f"Anthropic API 返回的 completion 无法解析:{completion}")
if not llm_response.reasoning_content:
raise ValueError(
f"Anthropic API returned unparsable completion: "
f"no text, tool_use, or thinking content found. "
f"Completion: {completion}"
)
# We have reasoning content (ThinkingBlock) - this is valid
stop_reason = getattr(completion, "stop_reason", "unknown")
logger.debug(
f"Completion contains only ThinkingBlock (stop_reason={stop_reason})"
)
llm_response.completion_text = "" # Ensure empty string, not None
return llm_response return llm_response
@@ -20,7 +20,6 @@ from ..register import register_provider_adapter
TEMP_DIR = Path(get_astrbot_temp_path()) / "azure_tts" TEMP_DIR = Path(get_astrbot_temp_path()) / "azure_tts"
TEMP_DIR.mkdir(parents=True, exist_ok=True) TEMP_DIR.mkdir(parents=True, exist_ok=True)
AZURE_TTS_SUBSCRIPTION_KEY_PATTERN = r"^(?:[a-zA-Z0-9]{32}|[a-zA-Z0-9]{84})$"
class OTTSProvider: class OTTSProvider:
@@ -117,7 +116,7 @@ class AzureNativeProvider(TTSProvider):
"azure_tts_subscription_key", "azure_tts_subscription_key",
"", "",
).strip() ).strip()
if not re.fullmatch(AZURE_TTS_SUBSCRIPTION_KEY_PATTERN, self.subscription_key): if not re.fullmatch(r"^[a-zA-Z0-9]{32}$", self.subscription_key):
raise ValueError("无效的Azure订阅密钥") raise ValueError("无效的Azure订阅密钥")
self.region = provider_config.get("azure_tts_region", "eastus").strip() self.region = provider_config.get("azure_tts_region", "eastus").strip()
self.endpoint = ( self.endpoint = (
@@ -236,9 +235,9 @@ class AzureTTSProvider(TTSProvider):
raise ValueError(error_msg) from e raise ValueError(error_msg) from e
except KeyError as e: except KeyError as e:
raise ValueError(f"配置错误: 缺少必要参数 {e}") from e raise ValueError(f"配置错误: 缺少必要参数 {e}") from e
if re.fullmatch(AZURE_TTS_SUBSCRIPTION_KEY_PATTERN, key_value): if re.fullmatch(r"^[a-zA-Z0-9]{32}$", key_value):
return AzureNativeProvider(config, self.provider_settings) return AzureNativeProvider(config, self.provider_settings)
raise ValueError("订阅密钥格式无效,应为32位或84位字母数字或other[...]格式") raise ValueError("订阅密钥格式无效,应为32位字母数字或other[...]格式")
async def get_audio(self, text: str) -> str: async def get_audio(self, text: str) -> str:
if isinstance(self.provider, OTTSProvider): if isinstance(self.provider, OTTSProvider):
@@ -87,7 +87,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
model: str, model: str,
text: str, text: str,
) -> tuple[bytes | None, str]: ) -> tuple[bytes | None, str]:
loop = asyncio.get_running_loop() loop = asyncio.get_event_loop()
response = await loop.run_in_executor(None, self._call_qwen_tts, model, text) response = await loop.run_in_executor(None, self._call_qwen_tts, model, text)
audio_bytes = await self._extract_audio_from_response(response) audio_bytes = await self._extract_audio_from_response(response)
if not audio_bytes: if not audio_bytes:
@@ -143,7 +143,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
voice=self.voice, voice=self.voice,
format=AudioFormat.WAV_24000HZ_MONO_16BIT, format=AudioFormat.WAV_24000HZ_MONO_16BIT,
) )
loop = asyncio.get_running_loop() loop = asyncio.get_event_loop()
audio_bytes = await loop.run_in_executor( audio_bytes = await loop.run_in_executor(
None, None,
synthesizer.call, synthesizer.call,
+2 -2
View File
@@ -59,7 +59,7 @@ class GenieTTSProvider(TTSProvider):
filename = f"genie_tts_{uuid.uuid4()}.wav" filename = f"genie_tts_{uuid.uuid4()}.wav"
path = os.path.join(temp_dir, filename) path = os.path.join(temp_dir, filename)
loop = asyncio.get_running_loop() loop = asyncio.get_event_loop()
def _generate(save_path: str) -> None: def _generate(save_path: str) -> None:
assert genie is not None assert genie is not None
@@ -85,7 +85,7 @@ class GenieTTSProvider(TTSProvider):
text_queue: asyncio.Queue[str | None], text_queue: asyncio.Queue[str | None],
audio_queue: "asyncio.Queue[bytes | tuple[str, bytes] | None]", audio_queue: "asyncio.Queue[bytes | tuple[str, bytes] | None]",
) -> None: ) -> None:
loop = asyncio.get_running_loop() loop = asyncio.get_event_loop()
while True: while True:
text = await text_queue.get() text = await text_queue.get()
@@ -13,11 +13,3 @@ class ProviderGroq(ProviderOpenAIOfficial):
) -> None: ) -> None:
super().__init__(provider_config, provider_settings) super().__init__(provider_config, provider_settings)
self.reasoning_key = "reasoning" self.reasoning_key = "reasoning"
def _finally_convert_payload(self, payloads: dict) -> None:
"""Groq rejects assistant history items that include reasoning_content."""
super()._finally_convert_payload(payloads)
for message in payloads.get("messages", []):
if message.get("role") == "assistant":
message.pop("reasoning_content", None)
message.pop("reasoning", None)
@@ -154,14 +154,6 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
audio_stream = self._call_tts_stream(text) audio_stream = self._call_tts_stream(text)
audio = await self._audio_play(audio_stream) audio = await self._audio_play(audio_stream)
# 检查音频数据是否为空
if not audio or len(audio) == 0:
raise Exception(
"MiniMax TTS API returned empty audio data. "
"Please verify your configuration, especially the 'group_id' parameter. "
"You can find your group_id in Account Management -> Basic Information on the MiniMax platform."
)
# 结果保存至文件 # 结果保存至文件
with open(path, "wb") as file: with open(path, "wb") as file:
file.write(audio) file.write(audio)
@@ -169,4 +161,4 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
return path return path
except aiohttp.ClientError as e: except aiohttp.ClientError as e:
raise Exception(f"MiniMax TTS API request failed: {e!s}") raise e
@@ -40,46 +40,25 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
async def get_embedding(self, text: str) -> list[float]: async def get_embedding(self, text: str) -> list[float]:
"""获取文本的嵌入""" """获取文本的嵌入"""
kwargs = self._embedding_kwargs()
embedding = await self.client.embeddings.create( embedding = await self.client.embeddings.create(
input=text, input=text,
model=self.model, model=self.model,
**kwargs, dimensions=self.get_dim(),
) )
return embedding.data[0].embedding return embedding.data[0].embedding
async def get_embeddings(self, text: list[str]) -> list[list[float]]: async def get_embeddings(self, text: list[str]) -> list[list[float]]:
"""批量获取文本的嵌入""" """批量获取文本的嵌入"""
kwargs = self._embedding_kwargs()
embeddings = await self.client.embeddings.create( embeddings = await self.client.embeddings.create(
input=text, input=text,
model=self.model, model=self.model,
**kwargs, dimensions=self.get_dim(),
) )
return [item.embedding for item in embeddings.data] return [item.embedding for item in embeddings.data]
def _embedding_kwargs(self) -> dict:
"""构建嵌入请求的可选参数"""
kwargs = {}
if "embedding_dimensions" in self.provider_config:
try:
kwargs["dimensions"] = int(self.provider_config["embedding_dimensions"])
except (ValueError, TypeError):
logger.warning(
f"embedding_dimensions in embedding configs is not a valid integer: '{self.provider_config['embedding_dimensions']}', ignored."
)
return kwargs
def get_dim(self) -> int: def get_dim(self) -> int:
"""获取向量的维度""" """获取向量的维度"""
if "embedding_dimensions" in self.provider_config: return int(self.provider_config.get("embedding_dimensions", 1024))
try:
return int(self.provider_config["embedding_dimensions"])
except (ValueError, TypeError):
logger.warning(
f"embedding_dimensions in embedding configs is not a valid integer: '{self.provider_config['embedding_dimensions']}', ignored."
)
return 0
async def terminate(self): async def terminate(self):
if self.client: if self.client:
@@ -311,7 +311,7 @@ class ProviderOpenAIOfficial(Provider):
state.handle_chunk(chunk) state.handle_chunk(chunk)
except Exception as e: except Exception as e:
logger.warning("Saving chunk state error: " + str(e)) logger.warning("Saving chunk state error: " + str(e))
if not chunk.choices: if len(chunk.choices) == 0:
continue continue
delta = chunk.choices[0].delta delta = chunk.choices[0].delta
# logger.debug(f"chunk delta: {delta}") # logger.debug(f"chunk delta: {delta}")
@@ -322,7 +322,7 @@ class ProviderOpenAIOfficial(Provider):
if reasoning: if reasoning:
llm_response.reasoning_content = reasoning llm_response.reasoning_content = reasoning
_y = True _y = True
if delta and delta.content: if delta.content:
# Don't strip streaming chunks to preserve spaces between words # Don't strip streaming chunks to preserve spaces between words
completion_text = self._normalize_content(delta.content, strip=False) completion_text = self._normalize_content(delta.content, strip=False)
llm_response.result_chain = MessageChain( llm_response.result_chain = MessageChain(
@@ -345,7 +345,7 @@ class ProviderOpenAIOfficial(Provider):
) -> str: ) -> str:
"""Extract reasoning content from OpenAI ChatCompletion if available.""" """Extract reasoning content from OpenAI ChatCompletion if available."""
reasoning_text = "" reasoning_text = ""
if not completion.choices: if len(completion.choices) == 0:
return reasoning_text return reasoning_text
if isinstance(completion, ChatCompletion): if isinstance(completion, ChatCompletion):
choice = completion.choices[0] choice = completion.choices[0]
@@ -468,7 +468,7 @@ class ProviderOpenAIOfficial(Provider):
"""Parse OpenAI ChatCompletion into LLMResponse""" """Parse OpenAI ChatCompletion into LLMResponse"""
llm_response = LLMResponse("assistant") llm_response = LLMResponse("assistant")
if not completion.choices: if len(completion.choices) == 0:
raise Exception("API 返回的 completion 为空。") raise Exception("API 返回的 completion 为空。")
choice = completion.choices[0] choice = completion.choices[0]
@@ -629,7 +629,6 @@ class ProviderOpenAIOfficial(Provider):
# 最后一次不等待 # 最后一次不等待
if retry_cnt < max_retries - 1: if retry_cnt < max_retries - 1:
await asyncio.sleep(1) await asyncio.sleep(1)
if chosen_key in available_api_keys:
available_api_keys.remove(chosen_key) available_api_keys.remove(chosen_key)
if len(available_api_keys) > 0: if len(available_api_keys) > 0:
chosen_key = random.choice(available_api_keys) chosen_key = random.choice(available_api_keys)
@@ -16,7 +16,4 @@ class ProviderOpenRouter(ProviderOpenAIOfficial):
self.client._custom_headers["HTTP-Referer"] = ( # type: ignore self.client._custom_headers["HTTP-Referer"] = ( # type: ignore
"https://github.com/AstrBotDevs/AstrBot" "https://github.com/AstrBotDevs/AstrBot"
) )
self.client._custom_headers["X-OpenRouter-Title"] = "AstrBot" # type: ignore self.client._custom_headers["X-TITLE"] = "AstrBot" # type: ignore
self.client._custom_headers["X-OpenRouter-Categories"] = (
"general-chat,personal-agent" # type: ignore
)
@@ -43,7 +43,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
logger.info("下载或者加载 SenseVoice 模型中,这可能需要一些时间 ...") logger.info("下载或者加载 SenseVoice 模型中,这可能需要一些时间 ...")
# 将模型加载放到线程池中执行 # 将模型加载放到线程池中执行
self.model = await asyncio.get_running_loop().run_in_executor( self.model = await asyncio.get_event_loop().run_in_executor(
None, None,
lambda: SenseVoiceSmall(self.model_name, quantize=True, batch_size=16), lambda: SenseVoiceSmall(self.model_name, quantize=True, batch_size=16),
) )
@@ -88,7 +88,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
audio_url = output_path audio_url = output_path
# 使用 run_in_executor 来调用模型进行识别 # 使用 run_in_executor 来调用模型进行识别
loop = asyncio.get_running_loop() loop = asyncio.get_event_loop()
res = await loop.run_in_executor( res = await loop.run_in_executor(
None, # 使用默认的线程池 None, # 使用默认的线程池
lambda: cast(SenseVoiceSmall, self.model)( lambda: cast(SenseVoiceSmall, self.model)(
@@ -31,7 +31,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
self.model = None self.model = None
async def initialize(self) -> None: async def initialize(self) -> None:
loop = asyncio.get_running_loop() loop = asyncio.get_event_loop()
logger.info("下载或者加载 Whisper 模型中,这可能需要一些时间 ...") logger.info("下载或者加载 Whisper 模型中,这可能需要一些时间 ...")
self.model = await loop.run_in_executor( self.model = await loop.run_in_executor(
None, None,
@@ -50,7 +50,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
return False return False
async def get_text(self, audio_url: str) -> str: async def get_text(self, audio_url: str) -> str:
loop = asyncio.get_running_loop() loop = asyncio.get_event_loop()
is_tencent = False is_tencent = False
+14 -112
View File
@@ -3,7 +3,6 @@ from __future__ import annotations
import json import json
import os import os
import re import re
import shlex
import shutil import shutil
import tempfile import tempfile
import zipfile import zipfile
@@ -11,8 +10,6 @@ from dataclasses import dataclass
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path, PurePosixPath from pathlib import Path, PurePosixPath
import yaml
from astrbot.core.utils.astrbot_path import ( from astrbot.core.utils.astrbot_path import (
get_astrbot_data_path, get_astrbot_data_path,
get_astrbot_skills_path, get_astrbot_skills_path,
@@ -29,13 +26,6 @@ _SANDBOX_SKILLS_CACHE_VERSION = 1
_SKILL_NAME_RE = re.compile(r"^[A-Za-z0-9._-]+$") _SKILL_NAME_RE = re.compile(r"^[A-Za-z0-9._-]+$")
def _is_ignored_zip_entry(name: str) -> bool:
parts = PurePosixPath(name).parts
if not parts:
return True
return parts[0] == "__MACOSX"
@dataclass @dataclass
class SkillInfo: class SkillInfo:
name: str name: str
@@ -71,76 +61,18 @@ def _parse_frontmatter_description(text: str) -> str:
break break
if end_idx is None: if end_idx is None:
return "" return ""
for line in lines[1:end_idx]:
frontmatter = "\n".join(lines[1:end_idx]) if ":" not in line:
try: continue
payload = yaml.safe_load(frontmatter) or {} key, value = line.split(":", 1)
except yaml.YAMLError: if key.strip().lower() == "description":
return value.strip().strip('"').strip("'")
return "" return ""
if not isinstance(payload, dict):
return ""
description = payload.get("description", "")
if not isinstance(description, str):
return ""
return description.strip()
# Regex for sanitizing paths used in prompt examples — only allow # Regex for sanitizing paths used in prompt examples — only allow
# safe path characters to prevent prompt injection via crafted skill paths. # safe path characters to prevent prompt injection via crafted skill paths.
_SAFE_PATH_RE = re.compile(r"[^\w./ ,()'\-]", re.UNICODE) _SAFE_PATH_RE = re.compile(r"[^A-Za-z0-9_./ -]")
_WINDOWS_DRIVE_PATH_RE = re.compile(r"^[A-Za-z]:(?:/|\\)")
_WINDOWS_UNC_PATH_RE = re.compile(r"^(//|\\\\)[^/\\]+[/\\][^/\\]+")
_CONTROL_CHARS_RE = re.compile(r"[\x00-\x1F\x7F]")
def _is_windows_prompt_path(path: str) -> bool:
if os.name != "nt":
return False
return bool(_WINDOWS_DRIVE_PATH_RE.match(path) or _WINDOWS_UNC_PATH_RE.match(path))
def _sanitize_prompt_path_for_prompt(path: str) -> str:
if not path:
return ""
if _WINDOWS_DRIVE_PATH_RE.match(path) or _WINDOWS_UNC_PATH_RE.match(path):
path = path.replace("\\", "/")
drive_prefix = ""
if _WINDOWS_DRIVE_PATH_RE.match(path):
drive_prefix = path[:2]
path = path[2:]
path = path.replace("`", "")
path = _CONTROL_CHARS_RE.sub("", path)
sanitized = _SAFE_PATH_RE.sub("", path)
return f"{drive_prefix}{sanitized}"
def _sanitize_prompt_description(description: str) -> str:
description = description.replace("`", "")
description = _CONTROL_CHARS_RE.sub(" ", description)
description = " ".join(description.split())
return description
def _sanitize_skill_display_name(name: str) -> str:
if _SKILL_NAME_RE.fullmatch(name):
return name
return "<invalid_skill_name>"
def _build_skill_read_command_example(path: str) -> str:
if path == "<skills_root>/<skill_name>/SKILL.md":
return f"cat {path}"
if _is_windows_prompt_path(path):
command = "type"
path_arg = f'"{os.path.normpath(path)}"'
else:
command = "cat"
path_arg = shlex.quote(path)
return f"{command} {path_arg}"
def build_skills_prompt(skills: list[SkillInfo]) -> str: def build_skills_prompt(skills: list[SkillInfo]) -> str:
@@ -153,37 +85,16 @@ def build_skills_prompt(skills: list[SkillInfo]) -> str:
skills_lines: list[str] = [] skills_lines: list[str] = []
example_path = "" example_path = ""
for skill in skills: for skill in skills:
display_name = _sanitize_skill_display_name(skill.name)
description = skill.description or "No description" description = skill.description or "No description"
if skill.source_type == "sandbox_only":
description = _sanitize_prompt_description(description)
if not description:
description = "Read SKILL.md for details."
if skill.source_type == "sandbox_only":
rendered_path = (
f"{str(SANDBOX_WORKSPACE_ROOT)}/{str(SANDBOX_SKILLS_ROOT)}/"
f"{display_name}/SKILL.md"
)
else:
rendered_path = _sanitize_prompt_path_for_prompt(skill.path)
if not rendered_path:
rendered_path = "<skills_root>/<skill_name>/SKILL.md"
skills_lines.append( skills_lines.append(
f"- **{display_name}**: {description}\n File: `{rendered_path}`" f"- **{skill.name}**: {description}\n File: `{skill.path}`"
) )
if not example_path: if not example_path:
example_path = rendered_path example_path = skill.path
skills_block = "\n".join(skills_lines) skills_block = "\n".join(skills_lines)
# Sanitize example_path — it may originate from sandbox cache (untrusted) # Sanitize example_path — it may originate from sandbox cache (untrusted)
if example_path == "<skills_root>/<skill_name>/SKILL.md": example_path = _SAFE_PATH_RE.sub("", example_path) if example_path else ""
example_path = "<skills_root>/<skill_name>/SKILL.md"
else:
example_path = _sanitize_prompt_path_for_prompt(example_path)
example_path = example_path or "<skills_root>/<skill_name>/SKILL.md" example_path = example_path or "<skills_root>/<skill_name>/SKILL.md"
example_command = _build_skill_read_command_example(example_path)
return ( return (
"## Skills\n\n" "## Skills\n\n"
@@ -201,9 +112,8 @@ def build_skills_prompt(skills: list[SkillInfo]) -> str:
"*Never silently skip a matching skill* — either use it or briefly " "*Never silently skip a matching skill* — either use it or briefly "
"explain why you chose not to.\n" "explain why you chose not to.\n"
"3. **Mandatory grounding** — Before executing any skill you MUST " "3. **Mandatory grounding** — Before executing any skill you MUST "
"first read its `SKILL.md` by running a shell command compatible " "first read its `SKILL.md` by running a shell command with the "
"with the current runtime shell and using the **absolute path** " f"**absolute path** shown above (e.g. `cat {example_path}`). "
f"shown above (e.g. `{example_command}`). "
"Never rely on memory or assumptions about a skill's content.\n" "Never rely on memory or assumptions about a skill's content.\n"
"4. **Progressive disclosure** — Load only what is directly " "4. **Progressive disclosure** — Load only what is directly "
"referenced from `SKILL.md`:\n" "referenced from `SKILL.md`:\n"
@@ -491,11 +401,7 @@ class SkillManager:
raise ValueError("Uploaded file is not a valid zip archive.") raise ValueError("Uploaded file is not a valid zip archive.")
with zipfile.ZipFile(zip_path) as zf: with zipfile.ZipFile(zip_path) as zf:
names = [ names = [name.replace("\\", "/") for name in zf.namelist()]
name
for name in (entry.replace("\\", "/") for entry in zf.namelist())
if name and not _is_ignored_zip_entry(name)
]
file_names = [name for name in names if name and not name.endswith("/")] file_names = [name for name in names if name and not name.endswith("/")]
if not file_names: if not file_names:
raise ValueError("Zip archive is empty.") raise ValueError("Zip archive is empty.")
@@ -530,11 +436,7 @@ class SkillManager:
raise ValueError("SKILL.md not found in the skill folder.") raise ValueError("SKILL.md not found in the skill folder.")
with tempfile.TemporaryDirectory(dir=get_astrbot_temp_path()) as tmp_dir: with tempfile.TemporaryDirectory(dir=get_astrbot_temp_path()) as tmp_dir:
for member in zf.infolist(): zf.extractall(tmp_dir)
member_name = member.filename.replace("\\", "/")
if not member_name or _is_ignored_zip_entry(member_name):
continue
zf.extract(member, tmp_dir)
src_dir = Path(tmp_dir) / skill_name src_dir = Path(tmp_dir) / skill_name
if not src_dir.exists(): if not src_dir.exists():
raise ValueError("Skill folder not found after extraction.") raise ValueError("Skill folder not found after extraction.")
+1 -1
View File
@@ -15,4 +15,4 @@ class RegexFilter(HandlerFilter):
self.regex = re.compile(regex) self.regex = re.compile(regex)
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
return bool(self.regex.search(event.get_message_str().strip())) return bool(self.regex.match(event.get_message_str().strip()))
+8 -5
View File
@@ -2,7 +2,7 @@ from __future__ import annotations
import re import re
from collections.abc import AsyncGenerator, Awaitable, Callable from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any from typing import TYPE_CHECKING, Any
import docstring_parser import docstring_parser
@@ -15,6 +15,9 @@ from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES
from astrbot.core.provider.register import llm_tools from astrbot.core.provider.register import llm_tools
if TYPE_CHECKING:
from astrbot.core.astr_agent_context import AstrAgentContext
from ..filter.command import CommandFilter from ..filter.command import CommandFilter
from ..filter.command_group import CommandGroupFilter from ..filter.command_group import CommandGroupFilter
from ..filter.custom_filter import CustomFilterAnd, CustomFilterOr from ..filter.custom_filter import CustomFilterAnd, CustomFilterOr
@@ -616,7 +619,7 @@ class RegisteringAgent:
kwargs["registering_agent"] = self kwargs["registering_agent"] = self
return register_llm_tool(*args, **kwargs) return register_llm_tool(*args, **kwargs)
def __init__(self, agent: Agent[Any]) -> None: def __init__(self, agent: Agent[AstrAgentContext]) -> None:
self._agent = agent self._agent = agent
@@ -624,7 +627,7 @@ def register_agent(
name: str, name: str,
instruction: str, instruction: str,
tools: list[str | FunctionTool] | None = None, tools: list[str | FunctionTool] | None = None,
run_hooks: BaseAgentRunHooks[Any] | None = None, run_hooks: BaseAgentRunHooks[AstrAgentContext] | None = None,
): ):
"""注册一个 Agent """注册一个 Agent
@@ -638,12 +641,12 @@ def register_agent(
tools_ = tools or [] tools_ = tools or []
def decorator(awaitable: Callable[..., Awaitable[Any]]): def decorator(awaitable: Callable[..., Awaitable[Any]]):
AstrAgent = Agent[Any] AstrAgent = Agent[AstrAgentContext]
agent = AstrAgent( agent = AstrAgent(
name=name, name=name,
instructions=instruction, instructions=instruction,
tools=tools_, tools=tools_,
run_hooks=run_hooks or BaseAgentRunHooks[Any](), run_hooks=run_hooks or BaseAgentRunHooks[AstrAgentContext](),
) )
handoff_tool = HandoffTool(agent=agent) handoff_tool = HandoffTool(agent=agent)
handoff_tool.handler = awaitable handoff_tool.handler = awaitable
+9 -160
View File
@@ -1,14 +1,12 @@
"""插件的重载、启停、安装、卸载等操作。""" """插件的重载、启停、安装、卸载等操作。"""
import asyncio import asyncio
import contextlib
import functools import functools
import inspect import inspect
import json import json
import logging import logging
import os import os
import sys import sys
import tempfile
import traceback import traceback
from types import ModuleType from types import ModuleType
@@ -16,12 +14,7 @@ import yaml
from packaging.specifiers import InvalidSpecifier, SpecifierSet from packaging.specifiers import InvalidSpecifier, SpecifierSet
from packaging.version import InvalidVersion, Version from packaging.version import InvalidVersion, Version
from astrbot.core import ( from astrbot.core import logger, pip_installer, sp
DependencyConflictError,
logger,
pip_installer,
sp,
)
from astrbot.core.agent.handoff import FunctionTool, HandoffTool from astrbot.core.agent.handoff import FunctionTool, HandoffTool
from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.config.default import VERSION from astrbot.core.config.default import VERSION
@@ -31,13 +24,9 @@ from astrbot.core.utils.astrbot_path import (
get_astrbot_config_path, get_astrbot_config_path,
get_astrbot_path, get_astrbot_path,
get_astrbot_plugin_path, get_astrbot_plugin_path,
get_astrbot_temp_path,
) )
from astrbot.core.utils.io import remove_dir from astrbot.core.utils.io import remove_dir
from astrbot.core.utils.metrics import Metric from astrbot.core.utils.metrics import Metric
from astrbot.core.utils.requirements_utils import (
plan_missing_requirements_install,
)
from . import StarMetadata from . import StarMetadata
from .command_management import sync_command_configs from .command_management import sync_command_configs
@@ -59,97 +48,6 @@ class PluginVersionIncompatibleError(Exception):
"""Raised when plugin astrbot_version is incompatible with current AstrBot.""" """Raised when plugin astrbot_version is incompatible with current AstrBot."""
class PluginDependencyInstallError(Exception):
"""Raised when plugin dependency installation fails."""
def __init__(
self,
*,
plugin_label: str,
requirements_path: str,
error: Exception,
) -> None:
message = f"插件 {plugin_label} 依赖安装失败: {error!s}"
super().__init__(message)
self.plugin_label = plugin_label
self.requirements_path = requirements_path
self.error = error
@contextlib.contextmanager
def _temporary_filtered_requirements_file(
*,
install_lines: tuple[str, ...],
):
filtered_requirements_path: str | None = None
temp_dir = get_astrbot_temp_path()
try:
os.makedirs(temp_dir, exist_ok=True)
with tempfile.NamedTemporaryFile(
mode="w",
suffix="_plugin_requirements.txt",
delete=False,
dir=temp_dir,
encoding="utf-8",
) as filtered_requirements_file:
filtered_requirements_file.write("\n".join(install_lines) + "\n")
filtered_requirements_path = filtered_requirements_file.name
yield filtered_requirements_path
finally:
if filtered_requirements_path and os.path.exists(filtered_requirements_path):
try:
os.remove(filtered_requirements_path)
except OSError as exc:
logger.warning(
"删除临时插件依赖文件失败:%s(路径:%s",
exc,
filtered_requirements_path,
)
async def _install_requirements_with_precheck(
*,
plugin_label: str,
requirements_path: str,
) -> None:
install_plan = plan_missing_requirements_install(requirements_path)
if install_plan is None:
logger.info(
f"正在安装插件 {plugin_label} 的依赖库(缺失依赖预检查不可裁剪,回退到完整安装): "
f"{requirements_path}"
)
await pip_installer.install(requirements_path=requirements_path)
return
if not install_plan.missing_names:
logger.info(f"插件 {plugin_label} 的依赖已满足,跳过安装。")
return
if not install_plan.install_lines:
fallback_reason = install_plan.fallback_reason or "unknown reason"
logger.info(
"检测到插件 %s 缺失依赖,但无法安全裁剪 requirements,回退到完整安装: %s (%s)",
plugin_label,
requirements_path,
fallback_reason,
)
await pip_installer.install(requirements_path=requirements_path)
return
logger.info(
f"检测到插件 {plugin_label} 缺失依赖,正在按 requirements.txt 安装: "
f"{requirements_path} -> {sorted(install_plan.missing_names)}"
)
with _temporary_filtered_requirements_file(
install_lines=install_plan.install_lines,
) as filtered_requirements_path:
await pip_installer.install(requirements_path=filtered_requirements_path)
class PluginManager: class PluginManager:
def __init__(self, context: Context, config: AstrBotConfig) -> None: def __init__(self, context: Context, config: AstrBotConfig) -> None:
from .star_tools import StarTools from .star_tools import StarTools
@@ -300,36 +198,14 @@ class PluginManager:
to_update.append(p.root_dir_name) to_update.append(p.root_dir_name)
for p in to_update: for p in to_update:
plugin_path = os.path.join(plugin_dir, p) plugin_path = os.path.join(plugin_dir, p)
await self._ensure_plugin_requirements(plugin_path, p) if os.path.exists(os.path.join(plugin_path, "requirements.txt")):
return True pth = os.path.join(plugin_path, "requirements.txt")
logger.info(f"正在安装插件 {p} 所需的依赖库: {pth}")
async def _ensure_plugin_requirements(
self,
plugin_dir_path: str,
plugin_label: str,
) -> None:
requirements_path = os.path.join(plugin_dir_path, "requirements.txt")
if not os.path.exists(requirements_path):
return
try: try:
await _install_requirements_with_precheck( await pip_installer.install(requirements_path=pth)
plugin_label=plugin_label,
requirements_path=requirements_path,
)
except asyncio.CancelledError:
raise
except DependencyConflictError as e:
logger.error(f"插件 {plugin_label} 依赖冲突: {e!s}")
raise
except Exception as e: except Exception as e:
dependency_error = PluginDependencyInstallError( logger.error(f"更新插件 {p} 的依赖失败。Code: {e!s}")
plugin_label=plugin_label, return True
requirements_path=requirements_path,
error=e,
)
logger.exception(str(dependency_error))
raise dependency_error from e
async def _import_plugin_with_dependency_recovery( async def _import_plugin_with_dependency_recovery(
self, self,
@@ -546,7 +422,7 @@ class PluginManager:
root_dir_name: str, root_dir_name: str,
plugin_dir_path: str, plugin_dir_path: str,
reserved: bool, reserved: bool,
error: BaseException | str, error: Exception | str,
error_trace: str, error_trace: str,
) -> dict: ) -> dict:
record: dict = { record: dict = {
@@ -619,9 +495,6 @@ class PluginManager:
self._cleanup_plugin_state(dir_name) self._cleanup_plugin_state(dir_name)
plugin_path = os.path.join(self.plugin_store_path, dir_name)
await self._ensure_plugin_requirements(plugin_path, dir_name)
success, error = await self.load(specified_dir_name=dir_name) success, error = await self.load(specified_dir_name=dir_name)
if success: if success:
self.failed_plugin_dict.pop(dir_name, None) self.failed_plugin_dict.pop(dir_name, None)
@@ -1205,10 +1078,6 @@ class PluginManager:
# reload the plugin # reload the plugin
dir_name = os.path.basename(plugin_path) dir_name = os.path.basename(plugin_path)
await self._ensure_plugin_requirements(
plugin_path,
dir_name,
)
success, error_message = await self.load( success, error_message = await self.load(
specified_dir_name=dir_name, specified_dir_name=dir_name,
ignore_version_check=ignore_version_check, ignore_version_check=ignore_version_check,
@@ -1448,12 +1317,6 @@ class PluginManager:
raise Exception("该插件是 AstrBot 保留插件,无法更新。") raise Exception("该插件是 AstrBot 保留插件,无法更新。")
await self.updator.update(plugin, proxy=proxy) await self.updator.update(plugin, proxy=proxy)
if plugin.root_dir_name:
plugin_dir_path = os.path.join(self.plugin_store_path, plugin.root_dir_name)
await self._ensure_plugin_requirements(
plugin_dir_path,
plugin_name,
)
await self.reload(plugin_name) await self.reload(plugin_name)
async def turn_off_plugin(self, plugin_name: str) -> None: async def turn_off_plugin(self, plugin_name: str) -> None:
@@ -1511,23 +1374,10 @@ class PluginManager:
return return
if "__del__" in star_metadata.star_cls_type.__dict__: if "__del__" in star_metadata.star_cls_type.__dict__:
loop = asyncio.get_running_loop() asyncio.get_event_loop().run_in_executor(
future = loop.run_in_executor(
None, None,
star_metadata.star_cls.__del__, star_metadata.star_cls.__del__,
) )
def _log_del_exception(fut: asyncio.Future) -> None:
if fut.cancelled():
return
if (exc := fut.exception()) is not None:
logger.error(
"插件 %s 在 __del__ 中抛出了异常:%r",
star_metadata.name,
exc,
)
future.add_done_callback(_log_del_exception)
elif "terminate" in star_metadata.star_cls_type.__dict__: elif "terminate" in star_metadata.star_cls_type.__dict__:
await star_metadata.star_cls.terminate() await star_metadata.star_cls.terminate()
@@ -1625,7 +1475,6 @@ class PluginManager:
os.remove(zip_file_path) os.remove(zip_file_path)
except BaseException as e: except BaseException as e:
logger.warning(f"删除插件压缩包失败: {e!s}") logger.warning(f"删除插件压缩包失败: {e!s}")
await self._ensure_plugin_requirements(desti_dir, dir_name)
# await self.reload() # await self.reload()
success, error_message = await self.load( success, error_message = await self.load(
specified_dir_name=dir_name, specified_dir_name=dir_name,
+12 -18
View File
@@ -1,15 +1,12 @@
from __future__ import annotations from __future__ import annotations
import copy from typing import Any
from typing import TYPE_CHECKING, Any
from astrbot import logger from astrbot import logger
from astrbot.core.agent.agent import Agent from astrbot.core.agent.agent import Agent
from astrbot.core.agent.handoff import HandoffTool from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.provider.func_tool_manager import FunctionToolManager
if TYPE_CHECKING:
from astrbot.core.persona_mgr import PersonaManager from astrbot.core.persona_mgr import PersonaManager
from astrbot.core.provider.func_tool_manager import FunctionToolManager
class SubAgentOrchestrator: class SubAgentOrchestrator:
@@ -46,10 +43,11 @@ class SubAgentOrchestrator:
continue continue
persona_id = item.get("persona_id") persona_id = item.get("persona_id")
if persona_id is not None: persona_data = None
persona_id = str(persona_id).strip() or None if persona_id:
persona_data = self._persona_mgr.get_persona_v3_by_id(persona_id) try:
if persona_id and persona_data is None: persona_data = await self._persona_mgr.get_persona(persona_id)
except StopIteration:
logger.warning( logger.warning(
"SubAgent persona %s not found, fallback to inline prompt.", "SubAgent persona %s not found, fallback to inline prompt.",
persona_id, persona_id,
@@ -64,15 +62,11 @@ class SubAgentOrchestrator:
begin_dialogs = None begin_dialogs = None
if persona_data: if persona_data:
prompt = str(persona_data.get("prompt", "")).strip() instructions = persona_data.system_prompt or instructions
if prompt: begin_dialogs = persona_data.begin_dialogs
instructions = prompt tools = persona_data.tools
begin_dialogs = copy.deepcopy( if public_description == "" and persona_data.system_prompt:
persona_data.get("_begin_dialogs_processed") public_description = persona_data.system_prompt[:120]
)
tools = persona_data.get("tools")
if public_description == "" and prompt:
public_description = prompt[:120]
if tools is None: if tools is None:
tools = None tools = None
elif not isinstance(tools, list): elif not isinstance(tools, list):
+1 -1
View File
@@ -30,7 +30,7 @@ class CreateActiveCronTool(FunctionTool[AstrAgentContext]):
"properties": { "properties": {
"cron_expression": { "cron_expression": {
"type": "string", "type": "string",
"description": "Cron expression defining recurring schedule (e.g., '0 8 * * *' or '0 23 * * mon-fri'). Prefer named weekdays like 'mon-fri' or 'sat,sun' instead of numeric day-of-week ranges such as '1-5' to avoid ambiguity across cron implementations.", "description": "Cron expression defining recurring schedule (e.g., '0 8 * * *').",
}, },
"run_at": { "run_at": {
"type": "string", "type": "string",
+6 -16
View File
@@ -25,22 +25,12 @@ class UmopConfigRouter:
) )
self.umop_to_conf_id = sp_data self.umop_to_conf_id = sp_data
@staticmethod
def _split_umo(umo: str) -> tuple[str, str, str] | None:
"""将 UMO 拆分为 3 个部分,同时保留 session_id 中的 ':'"""
if not isinstance(umo, str):
return None
parts = umo.split(":", 2)
if len(parts) != 3:
return None
return parts[0], parts[1], parts[2]
def _is_umo_match(self, p1: str, p2: str) -> bool: def _is_umo_match(self, p1: str, p2: str) -> bool:
"""判断 p2 umo 是否逻辑包含于 p1 umo""" """判断 p2 umo 是否逻辑包含于 p1 umo"""
p1_ls = self._split_umo(p1) p1_ls = p1.split(":")
p2_ls = self._split_umo(p2) p2_ls = p2.split(":")
if p1_ls is None or p2_ls is None: if len(p1_ls) != 3 or len(p2_ls) != 3:
return False # 非法格式 return False # 非法格式
return all(p == "" or fnmatch.fnmatchcase(t, p) for p, t in zip(p1_ls, p2_ls)) return all(p == "" or fnmatch.fnmatchcase(t, p) for p, t in zip(p1_ls, p2_ls))
@@ -72,7 +62,7 @@ class UmopConfigRouter:
""" """
for part in new_routing: for part in new_routing:
if self._split_umo(part) is None: if not isinstance(part, str) or len(part.split(":")) != 3:
raise ValueError( raise ValueError(
"umop keys must be strings in the format [platform_id]:[message_type]:[session_id], with optional wildcards * or empty for all", "umop keys must be strings in the format [platform_id]:[message_type]:[session_id], with optional wildcards * or empty for all",
) )
@@ -91,7 +81,7 @@ class UmopConfigRouter:
ValueError: 如果 umo 格式不正确 ValueError: 如果 umo 格式不正确
""" """
if self._split_umo(umo) is None: if not isinstance(umo, str) or len(umo.split(":")) != 3:
raise ValueError( raise ValueError(
"umop must be a string in the format [platform_id]:[message_type]:[session_id], with optional wildcards * or empty for all", "umop must be a string in the format [platform_id]:[message_type]:[session_id], with optional wildcards * or empty for all",
) )
@@ -109,7 +99,7 @@ class UmopConfigRouter:
ValueError: umo 格式不正确时抛出 ValueError: umo 格式不正确时抛出
""" """
if self._split_umo(umo) is None: if not isinstance(umo, str) or len(umo.split(":")) != 3:
raise ValueError( raise ValueError(
"umop must be a string in the format [platform_id]:[message_type]:[session_id], with optional wildcards * or empty for all", "umop must be a string in the format [platform_id]:[message_type]:[session_id], with optional wildcards * or empty for all",
) )
-121
View File
@@ -1,121 +0,0 @@
import contextlib
import functools
import importlib.metadata as importlib_metadata
import logging
import os
from collections.abc import Iterator
from packaging.requirements import Requirement
from astrbot.core.utils.requirements_utils import (
canonicalize_distribution_name,
collect_installed_distribution_versions,
get_requirement_check_paths,
)
logger = logging.getLogger("astrbot")
def _resolve_core_dist_name(core_dist_name: str | None) -> str | None:
if core_dist_name:
try:
importlib_metadata.distribution(core_dist_name)
return core_dist_name
except importlib_metadata.PackageNotFoundError:
return None
try:
importlib_metadata.distribution("AstrBot")
return "AstrBot"
except importlib_metadata.PackageNotFoundError:
pass
if not __package__:
return None
top_pkg = __package__.split(".")[0]
for dist in importlib_metadata.distributions():
try:
top_level = dist.read_text("top_level.txt") or ""
except Exception:
continue
if top_pkg in top_level.splitlines():
if "Name" in dist.metadata:
return dist.metadata["Name"]
return None
@functools.cache
def _get_core_constraints(core_dist_name: str | None) -> tuple[str, ...]:
try:
resolved_core_dist_name = _resolve_core_dist_name(core_dist_name)
except Exception as exc:
logger.warning("解析核心分发名称失败: %s", exc)
return ()
if not resolved_core_dist_name:
return ()
try:
dist = importlib_metadata.distribution(resolved_core_dist_name)
except importlib_metadata.PackageNotFoundError:
return ()
except Exception as exc:
logger.warning("读取核心分发元数据失败 (%s): %s", resolved_core_dist_name, exc)
return ()
if not dist or not dist.requires:
return ()
installed = collect_installed_distribution_versions(get_requirement_check_paths())
if not installed:
return ()
constraints: list[str] = []
for req_str in dist.requires:
try:
req = Requirement(req_str)
if req.marker and not req.marker.evaluate():
continue
name = canonicalize_distribution_name(req.name)
if name in installed:
constraints.append(f"{name}=={installed[name]}")
except Exception:
continue
return tuple(constraints)
class CoreConstraintsProvider:
def __init__(self, core_dist_name: str | None) -> None:
self._core_dist_name = core_dist_name
@contextlib.contextmanager
def constraints_file(self) -> Iterator[str | None]:
constraints = _get_core_constraints(self._core_dist_name)
if not constraints:
yield None
return
path: str | None = None
try:
import tempfile
with tempfile.NamedTemporaryFile(
mode="w", suffix="_constraints.txt", delete=False, encoding="utf-8"
) as f:
f.write("\n".join(constraints))
path = f.name
logger.info("已启用核心依赖版本保护 (%d 个约束)", len(constraints))
except Exception as exc:
logger.warning("创建临时约束文件失败: %s", exc)
yield None
return
try:
yield path
finally:
if path and os.path.exists(path):
with contextlib.suppress(Exception):
os.remove(path)
+95 -427
View File
@@ -7,71 +7,21 @@ import io
import logging import logging
import os import os
import re import re
import shlex
import sys import sys
import threading import threading
from collections import deque from collections import deque
from dataclasses import dataclass
from urllib.parse import urlparse
from astrbot.core.utils.astrbot_path import get_astrbot_site_packages_path from astrbot.core.utils.astrbot_path import get_astrbot_site_packages_path
from astrbot.core.utils.core_constraints import CoreConstraintsProvider
from astrbot.core.utils.requirements_utils import (
canonicalize_distribution_name as _canonicalize_distribution_name,
)
from astrbot.core.utils.requirements_utils import (
extract_requirement_name,
extract_requirement_names,
parse_package_install_input,
)
from astrbot.core.utils.runtime_env import is_packaged_desktop_runtime from astrbot.core.utils.runtime_env import is_packaged_desktop_runtime
logger = logging.getLogger("astrbot") logger = logging.getLogger("astrbot")
_DISTLIB_FINDER_PATCH_ATTEMPTED = False _DISTLIB_FINDER_PATCH_ATTEMPTED = False
_SITE_PACKAGES_IMPORT_LOCK = threading.RLock() _SITE_PACKAGES_IMPORT_LOCK = threading.RLock()
_PIP_FAILURE_PATTERNS = {
"error_prefix": re.compile(r"^\s*error:", re.IGNORECASE),
"user_requested": re.compile(r"\bthe user requested\b", re.IGNORECASE),
"resolution_impossible": re.compile(r"\bresolutionimpossible\b", re.IGNORECASE),
"cannot_install": re.compile(r"\bcannot install\b", re.IGNORECASE),
"conflict": re.compile(r"\bconflict(?:ing|s)?\b", re.IGNORECASE),
"constraint": re.compile(r"\(constraint\)", re.IGNORECASE),
"dependency_detail": re.compile(r"\bdepends on\b", re.IGNORECASE),
}
_SENSITIVE_PIP_VALUE_KEYS = frozenset(
{"password", "passwd", "pass", "api_token", "token", "auth_token"}
)
_MAX_PIP_OUTPUT_LINES = 200
class DependencyConflictError(Exception): def _canonicalize_distribution_name(name: str) -> str:
"""Raised when pip encounters a dependency conflict.""" return re.sub(r"[-_.]+", "-", name).strip("-").lower()
def __init__(
self, message: str, errors: list[str], *, is_core_conflict: bool
) -> None:
super().__init__(message)
self.errors = errors
self.is_core_conflict = is_core_conflict
class PipInstallError(Exception):
"""Raised when pip install fails without a classified dependency conflict."""
def __init__(self, message: str, *, code: int) -> None:
super().__init__(message)
self.code = code
@dataclass
class PipConflictContext:
relevant_lines: list[str]
requested_lines: list[str]
dependency_detail_lines: list[str]
constraint_lines: list[str]
has_strong_conflict_signal: bool
has_contextual_conflict_signal: bool
def _get_pip_main(): def _get_pip_main():
@@ -91,12 +41,11 @@ def _get_pip_main():
return pip_main return pip_main
def _prepend_sys_path(path: str) -> None: def _run_pip_main_with_output(pip_main, args: list[str]) -> tuple[int, str]:
normalized_target = os.path.realpath(path) stream = io.StringIO()
sys.path[:] = [ with contextlib.redirect_stdout(stream), contextlib.redirect_stderr(stream):
item for item in sys.path if os.path.realpath(item) != normalized_target result_code = pip_main(args)
] return result_code, stream.getvalue()
sys.path.insert(0, normalized_target)
def _cleanup_added_root_handlers(original_handlers: list[logging.Handler]) -> None: def _cleanup_added_root_handlers(original_handlers: list[logging.Handler]) -> None:
@@ -110,258 +59,76 @@ def _cleanup_added_root_handlers(original_handlers: list[logging.Handler]) -> No
handler.close() handler.close()
def _get_trusted_host_for_index_url(index_url: str) -> str | None: def _prepend_sys_path(path: str) -> None:
parsed = urlparse(index_url if "://" in index_url else f"//{index_url}") normalized_target = os.path.realpath(path)
host = parsed.hostname sys.path[:] = [
if host == "mirrors.aliyun.com": item for item in sys.path if os.path.realpath(item) != normalized_target
return host ]
return None sys.path.insert(0, normalized_target)
def _normalize_sensitive_pip_key(raw_key: str) -> str: def _module_exists_in_site_packages(module_name: str, site_packages_path: str) -> bool:
return raw_key.lstrip("-").replace("-", "_").lower() base_path = os.path.join(site_packages_path, *module_name.split("."))
package_init = os.path.join(base_path, "__init__.py")
module_file = f"{base_path}.py"
return os.path.isfile(package_init) or os.path.isfile(module_file)
def _is_sensitive_pip_value_key(raw_key: str) -> bool: def _is_module_loaded_from_site_packages(
return _normalize_sensitive_pip_key(raw_key) in _SENSITIVE_PIP_VALUE_KEYS module_name: str,
site_packages_path: str,
) -> bool:
module = sys.modules.get(module_name)
if module is None:
try:
module = importlib.import_module(module_name)
except Exception:
return False
module_file = getattr(module, "__file__", None)
if not module_file:
return False
def _redact_url_credentials(raw_value: str) -> str: module_path = os.path.realpath(module_file)
"""Redact URL credentials and known inline secret values for safe logging.""" site_packages_real = os.path.realpath(site_packages_path)
parsed = urlparse(raw_value) try:
if parsed.netloc and "@" in parsed.netloc: return (
hostname = parsed.hostname or "" os.path.commonpath([module_path, site_packages_real]) == site_packages_real
port = f":{parsed.port}" if parsed.port else "" )
return parsed._replace(netloc=f"<redacted>@{hostname}{port}").geturl() except ValueError:
if raw_value.startswith("--"):
option, separator, _ = raw_value.partition("=")
if separator and _is_sensitive_pip_value_key(option):
return f"{option}=****"
return raw_value
key, separator, _ = raw_value.partition("=")
if separator and _is_sensitive_pip_value_key(key):
return f"{key}=****"
return raw_value
def _redact_pip_args_for_logging(args: list[str]) -> list[str]:
redacted_args: list[str] = []
redact_next_value = False
for arg in args:
if redact_next_value:
redacted_args.append("****")
redact_next_value = False
continue
if arg.startswith("--") and "=" in arg:
option, value = arg.split("=", 1)
if _is_sensitive_pip_value_key(option):
redacted_args.append(f"{option}=****")
else:
redacted_args.append(f"{option}={_redact_url_credentials(value)}")
continue
if arg.startswith("-i") and arg != "-i":
redacted_args.append(f"-i{_redact_url_credentials(arg[2:])}")
continue
if _is_sensitive_pip_value_key(arg):
redacted_args.append(arg)
redact_next_value = True
continue
redacted_args.append(_redact_url_credentials(arg))
return redacted_args
def _package_specs_override_index(package_specs: list[str]) -> bool:
for index, spec in enumerate(package_specs):
if spec == "--no-index":
return True
if spec in {"-i", "--index-url"}:
if index + 1 < len(package_specs):
return True
continue
if spec.startswith("--index-url="):
return True
if spec.startswith("-i") and spec != "-i":
return True
return False return False
class _StreamingLogWriter(io.TextIOBase): def _extract_requirement_name(raw_requirement: str) -> str | None:
def __init__(self, log_func, *, max_lines: int | None = None) -> None: line = raw_requirement.split("#", 1)[0].strip()
self._log_func = log_func if not line:
self._lines = deque(maxlen=max_lines or _MAX_PIP_OUTPUT_LINES) return None
self._buffer = "" if line.startswith(("-r", "--requirement", "-c", "--constraint")):
return None
def write(self, text: str) -> int: if line.startswith("-"):
if not text:
return 0
self._buffer += text.replace("\r\n", "\n").replace("\r", "\n")
while "\n" in self._buffer:
raw_line, self._buffer = self._buffer.split("\n", 1)
line = raw_line.rstrip("\r\n")
self._log_func(line)
self._lines.append(line)
return len(text)
def flush(self) -> None:
line = self._buffer.rstrip("\r\n")
if line:
self._log_func(line)
self._lines.append(line)
self._buffer = ""
@property
def lines(self) -> list[str]:
return list(self._lines)
def _run_pip_main_streaming(pip_main, args: list[str]) -> tuple[int, list[str]]:
stream = _StreamingLogWriter(logger.info, max_lines=_MAX_PIP_OUTPUT_LINES)
with (
contextlib.redirect_stdout(stream),
contextlib.redirect_stderr(stream),
):
result_code = pip_main(args)
stream.flush()
return result_code, stream.lines
def _matches_pip_failure_pattern(line: str, *pattern_names: str) -> bool:
names = pattern_names or tuple(_PIP_FAILURE_PATTERNS)
return any(_PIP_FAILURE_PATTERNS[name].search(line) for name in names)
def _normalize_conflict_detail_line(line: str) -> str:
stripped = line.strip()
if _matches_pip_failure_pattern(stripped, "user_requested"):
return re.sub(
r"^\s*The user requested\s+",
"",
stripped,
flags=re.IGNORECASE,
)
return stripped
def _build_pip_conflict_context(output_lines: list[str]) -> PipConflictContext | None:
matched_indices = [
index
for index, line in enumerate(output_lines)
if _matches_pip_failure_pattern(line)
]
if matched_indices:
relevant_index_set: set[int] = set()
for index in matched_indices:
start = max(0, index - 1)
end = min(len(output_lines), index + 2)
relevant_index_set.update(range(start, end))
relevant_output_lines = [
line
for index, line in enumerate(output_lines)
if index in relevant_index_set
]
else:
relevant_output_lines = output_lines[-5:]
if not relevant_output_lines:
return None return None
dependency_detail_lines = [ egg_match = re.search(r"#egg=([A-Za-z0-9_.-]+)", raw_requirement)
line.strip() if egg_match:
for line in relevant_output_lines return _canonicalize_distribution_name(egg_match.group(1))
if _matches_pip_failure_pattern(line, "dependency_detail")
]
requested_lines = [
line.strip()
for line in relevant_output_lines
if _matches_pip_failure_pattern(line, "user_requested")
and not _matches_pip_failure_pattern(line, "constraint")
]
if not requested_lines:
requested_lines = [
line
for line in dependency_detail_lines
if not _matches_pip_failure_pattern(line, "constraint")
]
constraint_lines = [
line.strip()
for line in relevant_output_lines
if _matches_pip_failure_pattern(line, "constraint")
]
has_strong_conflict_signal = any( candidate = re.split(r"[<>=!~;\s\[]", line, maxsplit=1)[0].strip()
_matches_pip_failure_pattern( if not candidate:
line,
"resolution_impossible",
"cannot_install",
)
for line in relevant_output_lines
)
has_contextual_conflict_signal = any(
_matches_pip_failure_pattern(line, "conflict") for line in relevant_output_lines
) and bool(dependency_detail_lines or requested_lines or constraint_lines)
return PipConflictContext(
relevant_lines=relevant_output_lines,
requested_lines=requested_lines,
dependency_detail_lines=dependency_detail_lines,
constraint_lines=constraint_lines,
has_strong_conflict_signal=has_strong_conflict_signal,
has_contextual_conflict_signal=has_contextual_conflict_signal,
)
def _classify_pip_failure(output_lines: list[str]) -> DependencyConflictError | None:
context = _build_pip_conflict_context(output_lines)
if context is None:
return None return None
return _canonicalize_distribution_name(candidate)
if (
not context.has_strong_conflict_signal
and not context.has_contextual_conflict_signal
and not (context.requested_lines and context.constraint_lines)
):
return None
is_core_conflict = bool(context.constraint_lines) def _extract_requirement_names(requirements_path: str) -> set[str]:
names: set[str] = set()
detail = "" try:
if context.constraint_lines and context.requested_lines: with open(requirements_path, encoding="utf-8") as requirements_file:
detail = ( for line in requirements_file:
" 冲突详情: " requirement_name = _extract_requirement_name(line)
f"{_normalize_conflict_detail_line(context.requested_lines[0])} vs " if requirement_name:
f"{_normalize_conflict_detail_line(context.constraint_lines[0])}" names.add(requirement_name)
) except Exception as exc:
elif len(context.dependency_detail_lines) >= 2: logger.warning("读取依赖文件失败,跳过冲突检测: %s", exc)
detail = ( return names
" 冲突详情: "
f"{_normalize_conflict_detail_line(context.dependency_detail_lines[0])} vs "
f"{_normalize_conflict_detail_line(context.dependency_detail_lines[1])}"
)
if is_core_conflict:
message = (
f"检测到核心依赖版本保护冲突。{detail}插件要求的依赖版本与 AstrBot 核心不兼容,"
"为了系统稳定,已阻止该降级行为。请联系插件作者或调整 requirements.txt。"
)
else:
message = f"检测到依赖冲突。{detail}"
return DependencyConflictError(
message,
context.relevant_lines,
is_core_conflict=is_core_conflict,
)
def _extract_top_level_modules( def _extract_top_level_modules(
@@ -388,11 +155,7 @@ def _collect_candidate_modules(
by_name: dict[str, list[importlib_metadata.Distribution]] = {} by_name: dict[str, list[importlib_metadata.Distribution]] = {}
try: try:
for distribution in importlib_metadata.distributions(path=[site_packages_path]): for distribution in importlib_metadata.distributions(path=[site_packages_path]):
distribution_name = ( distribution_name = distribution.metadata.get("Name")
distribution.metadata["Name"]
if "Name" in distribution.metadata
else None
)
if not distribution_name: if not distribution_name:
continue continue
canonical_name = _canonicalize_distribution_name(distribution_name) canonical_name = _canonicalize_distribution_name(distribution_name)
@@ -410,7 +173,7 @@ def _collect_candidate_modules(
for distribution in by_name.get(requirement_name, []): for distribution in by_name.get(requirement_name, []):
for dependency_line in distribution.requires or []: for dependency_line in distribution.requires or []:
dependency_name = extract_requirement_name(dependency_line) dependency_name = _extract_requirement_name(dependency_line)
if not dependency_name: if not dependency_name:
continue continue
if dependency_name in expanded_requirement_names: if dependency_name in expanded_requirement_names:
@@ -467,38 +230,6 @@ def _ensure_preferred_modules(
raise RuntimeError(conflict_message) raise RuntimeError(conflict_message)
def _module_exists_in_site_packages(module_name: str, site_packages_path: str) -> bool:
base_path = os.path.join(site_packages_path, *module_name.split("."))
package_init = os.path.join(base_path, "__init__.py")
module_file = f"{base_path}.py"
return os.path.isfile(package_init) or os.path.isfile(module_file)
def _is_module_loaded_from_site_packages(
module_name: str,
site_packages_path: str,
) -> bool:
module = sys.modules.get(module_name)
if module is None:
try:
module = importlib.import_module(module_name)
except Exception:
return False
module_file = getattr(module, "__file__", None)
if not module_file:
return False
module_path = os.path.realpath(module_file)
site_packages_real = os.path.realpath(site_packages_path)
try:
return (
os.path.commonpath([module_path, site_packages_real]) == site_packages_real
)
except ValueError:
return False
def _prefer_module_from_site_packages( def _prefer_module_from_site_packages(
module_name: str, site_packages_path: str module_name: str, site_packages_path: str
) -> bool: ) -> bool:
@@ -800,63 +531,9 @@ def _patch_distlib_finder_for_frozen_runtime() -> None:
class PipInstaller: class PipInstaller:
def __init__( def __init__(self, pip_install_arg: str, pypi_index_url: str | None = None) -> None:
self,
pip_install_arg: str,
pypi_index_url: str | None = None,
core_dist_name: str | None = "AstrBot",
) -> None:
self.pip_install_arg = pip_install_arg self.pip_install_arg = pip_install_arg
self.pypi_index_url = pypi_index_url self.pypi_index_url = pypi_index_url
self.core_dist_name = core_dist_name
self._core_constraints = CoreConstraintsProvider(core_dist_name)
def _build_pip_args(
self,
package_name: str | None,
requirements_path: str | None,
mirror: str | None,
) -> tuple[list[str], set[str]]:
args: list[str] = []
requested_requirements: set[str] = set()
normalized_requirements_path = (
requirements_path.strip() if requirements_path else ""
)
if package_name and normalized_requirements_path:
raise ValueError(
"package_name and requirements_path cannot be used together"
)
if package_name:
parsed_package = parse_package_install_input(package_name)
if parsed_package.specs:
args = ["install", *parsed_package.specs]
requested_requirements = set(parsed_package.requirement_names)
elif normalized_requirements_path:
args = ["install", "-r", normalized_requirements_path]
requested_requirements = extract_requirement_names(
normalized_requirements_path
)
if not args:
return [], requested_requirements
pip_install_args = (
shlex.split(self.pip_install_arg) if self.pip_install_arg else []
)
if not _package_specs_override_index([*args[1:], *pip_install_args]):
index_url = mirror or self.pypi_index_url or "https://pypi.org/simple"
trusted_host = _get_trusted_host_for_index_url(index_url)
if trusted_host:
args.extend(["--trusted-host", trusted_host])
args.extend(["-i", index_url])
if pip_install_args:
args.extend(pip_install_args)
return args, requested_requirements
async def install( async def install(
self, self,
@@ -864,37 +541,36 @@ class PipInstaller:
requirements_path: str | None = None, requirements_path: str | None = None,
mirror: str | None = None, mirror: str | None = None,
) -> None: ) -> None:
args, requested_requirements = self._build_pip_args( args = ["install"]
package_name, requirements_path, mirror requested_requirements: set[str] = set()
) if package_name:
if not args: args.append(package_name)
logger.info("Pip 包管理器跳过安装:未提供有效的包名或 requirements 文件。") requirement_name = _extract_requirement_name(package_name)
return if requirement_name:
requested_requirements.add(requirement_name)
elif requirements_path:
args.extend(["-r", requirements_path])
requested_requirements = _extract_requirement_names(requirements_path)
index_url = mirror or self.pypi_index_url or "https://pypi.org/simple"
args.extend(["--trusted-host", "mirrors.aliyun.com", "-i", index_url])
target_site_packages = None target_site_packages = None
if is_packaged_desktop_runtime(): if is_packaged_desktop_runtime():
target_site_packages = get_astrbot_site_packages_path() target_site_packages = get_astrbot_site_packages_path()
os.makedirs(target_site_packages, exist_ok=True) os.makedirs(target_site_packages, exist_ok=True)
_prepend_sys_path(target_site_packages) _prepend_sys_path(target_site_packages)
args.extend( args.extend(["--target", target_site_packages])
[ args.extend(["--upgrade", "--force-reinstall"])
"--target",
target_site_packages,
"--upgrade",
"--upgrade-strategy",
"only-if-needed",
]
)
with self._core_constraints.constraints_file() as constraints_file_path: if self.pip_install_arg:
if constraints_file_path: args.extend(self.pip_install_arg.split())
args.extend(["-c", constraints_file_path])
logger.info( logger.info(f"Pip 包管理器: pip {' '.join(args)}")
"Pip 包管理器 argv: %s", result_code = await self._run_pip_in_process(args)
["pip", *_redact_pip_args_for_logging(args)],
) if result_code != 0:
await self._run_pip_with_classification(args) raise Exception(f"安装失败,错误码:{result_code}")
if target_site_packages: if target_site_packages:
_prepend_sys_path(target_site_packages) _prepend_sys_path(target_site_packages)
@@ -913,7 +589,7 @@ class PipInstaller:
if not os.path.isdir(target_site_packages): if not os.path.isdir(target_site_packages):
return return
requested_requirements = extract_requirement_names(requirements_path) requested_requirements = _extract_requirement_names(requirements_path)
if not requested_requirements: if not requested_requirements:
return return
@@ -929,21 +605,13 @@ class PipInstaller:
_patch_distlib_finder_for_frozen_runtime() _patch_distlib_finder_for_frozen_runtime()
original_handlers = list(logging.getLogger().handlers) original_handlers = list(logging.getLogger().handlers)
try: result_code, output = await asyncio.to_thread(
result_code, output_lines = await asyncio.to_thread( _run_pip_main_with_output, pip_main, args
_run_pip_main_streaming, pip_main, args
) )
finally: for line in output.splitlines():
line = line.strip()
if line:
logger.info(line)
_cleanup_added_root_handlers(original_handlers) _cleanup_added_root_handlers(original_handlers)
if result_code != 0:
conflict = _classify_pip_failure(output_lines)
if conflict:
raise conflict
return result_code return result_code
async def _run_pip_with_classification(self, args: list[str]) -> None:
result_code = await self._run_pip_in_process(args)
if result_code != 0:
raise PipInstallError(f"安装失败,错误码:{result_code}", code=result_code)
-486
View File
@@ -1,486 +0,0 @@
import importlib.metadata as importlib_metadata
import logging
import os
import re
import shlex
import sys
from collections.abc import Iterable, Iterator, Sequence
from dataclasses import dataclass
from packaging.requirements import InvalidRequirement, Requirement
from packaging.specifiers import SpecifierSet
from packaging.version import InvalidVersion, Version
from astrbot.core.utils.astrbot_path import get_astrbot_site_packages_path
from astrbot.core.utils.runtime_env import is_packaged_desktop_runtime
logger = logging.getLogger("astrbot")
class RequirementsPrecheckFailed(Exception):
"""Raised when the pre-check of requirements fails."""
pass
@dataclass(frozen=True)
class ParsedPackageInput:
specs: tuple[str, ...]
requirement_names: frozenset[str]
@dataclass(frozen=True)
class MissingRequirementsPlan:
missing_names: frozenset[str]
install_lines: tuple[str, ...]
fallback_reason: str | None = None
def canonicalize_distribution_name(name: str) -> str:
return re.sub(r"[-_.]+", "-", name).strip("-").lower()
def strip_inline_requirement_comment(raw_input: str) -> str:
if raw_input.lstrip().startswith("#"):
return ""
return re.split(r"[ \t]+#", raw_input, maxsplit=1)[0].strip()
def _specifier_contains_version(specifier: SpecifierSet, version: str) -> bool:
try:
parsed_version = Version(version)
except InvalidVersion:
return False
return specifier.contains(parsed_version, prereleases=True)
def _looks_like_local_path_reference(token: str) -> bool:
candidate = token.strip()
if not candidate:
return False
return candidate in {".", ".."} or candidate.startswith(
("./", "../", "/", "~/", ".\\", "..\\", "\\")
)
def looks_like_direct_reference(token: str) -> bool:
candidate = token.strip()
if not candidate:
return False
return (
_looks_like_local_path_reference(candidate)
or candidate.startswith("git+")
or "://" in candidate
)
def extract_requirement_name(raw_requirement: str) -> str | None:
line = raw_requirement.split("#", 1)[0].strip()
if not line:
return None
if line.startswith(("-r", "--requirement", "-c", "--constraint")):
return None
egg_match = re.search(r"#egg=([A-Za-z0-9_.-]+)", raw_requirement)
if egg_match:
return canonicalize_distribution_name(egg_match.group(1))
if line.startswith("-"):
return None
candidate = re.split(r"[<>=!~;\s\[]", line, maxsplit=1)[0].strip()
if not candidate:
return None
return canonicalize_distribution_name(candidate)
def _parse_editable_or_direct_name(target: str) -> str | None:
name = extract_requirement_name(target)
if not name:
return None
if "#egg=" in target or not looks_like_direct_reference(target):
return name
return None
def _parse_requirement_name_and_spec(
line: str,
) -> tuple[str | None, SpecifierSet | None]:
if line.startswith(("-c", "--constraint")):
return None, None
try:
req = Requirement(line)
except InvalidRequirement:
tokens = shlex.split(line)
if not tokens:
return None, None
editable_target: str | None = None
if tokens[0] in {"-e", "--editable"} and len(tokens) > 1:
editable_target = tokens[1]
elif tokens[0].startswith("--editable="):
editable_target = tokens[0].split("=", 1)[1]
if editable_target:
name = _parse_editable_or_direct_name(editable_target)
return (name, None) if name else (None, None)
name = _parse_editable_or_direct_name(line)
return (name, None) if name else (None, None)
if req.marker and not req.marker.evaluate():
return None, None
return canonicalize_distribution_name(req.name), (req.specifier or None)
def _parse_requirement_line(
line: str,
) -> tuple[str, SpecifierSet | None] | None:
name, specifier = _parse_requirement_name_and_spec(line)
return (name, specifier) if name else None
def _extract_requirement_names_from_package_tokens(tokens: list[str]) -> frozenset[str]:
requirement_names: set[str] = set()
skip_next_for: str | None = None
for token in tokens:
if skip_next_for:
if skip_next_for == "editable":
name = _parse_editable_or_direct_name(token)
if name:
requirement_names.add(name)
skip_next_for = None
continue
if token in {"-e", "--editable"}:
skip_next_for = "editable"
continue
if token in {
"-i",
"--index-url",
"--extra-index-url",
"-f",
"--find-links",
"--trusted-host",
"-r",
"--requirement",
"-c",
"--constraint",
}:
skip_next_for = "option-value"
continue
if token.startswith(("--editable=",)):
editable_target = token.split("=", 1)[1]
name = _parse_editable_or_direct_name(editable_target)
if name:
requirement_names.add(name)
continue
if token.startswith(
(
"--index-url=",
"--extra-index-url=",
"--find-links=",
"--trusted-host=",
"--requirement=",
"--constraint=",
)
):
continue
if (
(token.startswith("-i") and token != "-i")
or (token.startswith("-f") and token != "-f")
or token == "--no-index"
):
continue
if token.startswith("-"):
continue
name, _ = _parse_requirement_name_and_spec(token)
if name:
requirement_names.add(name)
return frozenset(requirement_names)
def parse_package_install_input(raw_input: str) -> ParsedPackageInput:
specs: list[str] = []
requirement_names: set[str] = set()
normalized = raw_input.strip()
if not normalized:
return ParsedPackageInput(specs=(), requirement_names=frozenset())
for raw_line in normalized.splitlines():
line = strip_inline_requirement_comment(raw_line)
if not line:
continue
try:
Requirement(line)
except InvalidRequirement:
tokens = shlex.split(line)
if not tokens:
continue
specs.extend(tokens)
requirement_names.update(
_extract_requirement_names_from_package_tokens(tokens)
)
continue
specs.append(line)
name, _ = _parse_requirement_name_and_spec(line)
if name:
requirement_names.add(name)
return ParsedPackageInput(
specs=tuple(specs),
requirement_names=frozenset(requirement_names),
)
def _iter_requirement_lines(
requirements_path: str,
_visited: set[str] | None = None,
) -> Iterator[str]:
visited = _visited or set()
resolved_path = os.path.realpath(requirements_path)
if resolved_path in visited:
logger.warning(
"检测到循环依赖的 requirements 包含: %s,将跳过该文件", resolved_path
)
return
visited.add(resolved_path)
with open(resolved_path, encoding="utf-8") as f:
for raw_line in f:
line = strip_inline_requirement_comment(raw_line)
if not line:
continue
tokens = shlex.split(line)
if not tokens:
continue
nested: str | None = None
if tokens[0] in {"-r", "--requirement"} and len(tokens) > 1:
nested = tokens[1]
elif tokens[0].startswith("--requirement="):
nested = tokens[0].split("=", 1)[1]
if nested:
if not os.path.isabs(nested):
nested = os.path.join(os.path.dirname(resolved_path), nested)
yield from _iter_requirement_lines(nested, _visited=visited)
continue
yield line
def iter_requirements(
requirements_path: str | None = None,
lines: Iterable[str] | None = None,
) -> Iterator[tuple[str, SpecifierSet | None]]:
if lines is None:
if requirements_path is None:
raise ValueError("Either requirements_path or lines must be provided")
lines = _iter_requirement_lines(requirements_path)
for line in lines:
parsed = _parse_requirement_line(line)
if parsed is not None:
yield parsed
def extract_requirement_names(requirements_path: str) -> set[str]:
try:
return {
name for name, _ in iter_requirements(requirements_path=requirements_path)
}
except Exception as exc:
logger.warning("读取依赖文件失败,跳过冲突检测: %s", exc)
return set()
def get_requirement_check_paths() -> list[str]:
paths = list(sys.path)
if is_packaged_desktop_runtime():
target_site_packages = get_astrbot_site_packages_path()
if os.path.isdir(target_site_packages):
paths.insert(0, target_site_packages)
return paths
def _canonical_distribution_identity(distribution) -> tuple[str | None, str | None]:
distribution_name = (
distribution.metadata["Name"] if "Name" in distribution.metadata else None
)
if not distribution_name:
return None, None
return canonicalize_distribution_name(distribution_name), distribution.version
def collect_installed_distribution_versions(paths: list[str]) -> dict[str, str] | None:
installed: dict[str, str] = {}
try:
for distribution in importlib_metadata.distributions(path=paths):
distribution_name, version = _canonical_distribution_identity(distribution)
if not distribution_name or not version:
continue
installed.setdefault(distribution_name, version)
except Exception as exc:
logger.warning("读取已安装依赖失败,跳过缺失依赖预检查: %s", exc)
return None
return installed
def _load_requirement_lines_for_precheck(
requirements_path: str,
) -> tuple[bool, list[str] | None]:
try:
requirement_lines = list(_iter_requirement_lines(requirements_path))
except Exception as exc:
logger.warning(
"预检查缺失依赖失败,将回退到完整安装: %s (%s)",
requirements_path,
exc,
)
return False, None
fallback_line = next(
(
line
for line in requirement_lines
if (
(
line.startswith(("-e ", "--editable ", "--editable="))
and "#egg=" not in line
)
or (
_parse_requirement_line(line) is None
and looks_like_direct_reference(line)
)
)
),
None,
)
if fallback_line is not None:
logger.info(
"缺失依赖预检查发现无法安全裁剪的 option/direct-reference 行,将回退到完整安装: %s (%s)",
requirements_path,
fallback_line,
)
return False, None
return True, requirement_lines
def find_missing_requirements(requirements_path: str) -> set[str] | None:
can_precheck, requirement_lines = _load_requirement_lines_for_precheck(
requirements_path
)
if not can_precheck or requirement_lines is None:
return None
return find_missing_requirements_from_lines(requirement_lines)
def find_missing_requirements_from_lines(
requirement_lines: Sequence[str],
) -> set[str] | None:
required = list(iter_requirements(lines=requirement_lines))
if not required:
return set()
installed = collect_installed_distribution_versions(get_requirement_check_paths())
if installed is None:
return None
missing: set[str] = set()
for name, specifier in required:
installed_version = installed.get(name)
if not installed_version:
missing.add(name)
continue
if specifier and not _specifier_contains_version(specifier, installed_version):
missing.add(name)
return missing
def build_missing_requirements_install_lines(
requirements_path: str,
requirement_lines: Sequence[str],
missing_names: set[str] | frozenset[str],
) -> tuple[str, ...] | None:
wanted_names = set(missing_names)
install_lines: list[str] = []
for line in requirement_lines:
parsed = _parse_requirement_line(line)
if parsed is None:
if looks_like_direct_reference(line) or line.startswith(("-", "--")):
logger.debug(
"缺失依赖行筛选回退到完整安装:requirements 中包含无法安全裁剪的 option/direct-reference 行: %s (%s)",
requirements_path,
line,
)
return None
continue
name, _specifier = parsed
if name in wanted_names:
install_lines.append(line)
return tuple(install_lines)
def plan_missing_requirements_install(
requirements_path: str,
) -> MissingRequirementsPlan | None:
can_precheck, requirement_lines = _load_requirement_lines_for_precheck(
requirements_path
)
if not can_precheck or requirement_lines is None:
return None
missing = find_missing_requirements_from_lines(requirement_lines)
if missing is None:
return None
install_lines = build_missing_requirements_install_lines(
requirements_path,
requirement_lines,
missing,
)
if install_lines is None:
return None
if missing and not install_lines:
logger.warning(
"预检查缺失依赖成功,但无法映射到可安装 requirement 行,将回退到完整安装: %s -> %s",
requirements_path,
sorted(missing),
)
return MissingRequirementsPlan(
missing_names=frozenset(missing),
install_lines=(),
fallback_reason="unmapped missing requirement names",
)
return MissingRequirementsPlan(
missing_names=frozenset(missing),
install_lines=install_lines,
)
def find_missing_requirements_or_raise(requirements_path: str) -> set[str]:
missing = find_missing_requirements(requirements_path)
if missing is None:
raise RequirementsPrecheckFailed(f"预检查失败: {requirements_path}")
return missing
+1 -1
View File
@@ -7,4 +7,4 @@ def is_frozen_runtime() -> bool:
def is_packaged_desktop_runtime() -> bool: def is_packaged_desktop_runtime() -> bool:
return os.environ.get("ASTRBOT_DESKTOP_CLIENT") == "1" return is_frozen_runtime() and os.environ.get("ASTRBOT_DESKTOP_CLIENT") == "1"
+1 -27
View File
@@ -1,13 +1,9 @@
import asyncio import asyncio
import threading
import weakref
from collections import defaultdict from collections import defaultdict
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
class _PerLoopSessionLockManager: class SessionLockManager:
"""Per-event-loop session lock manager; keeps original simple semantics."""
def __init__(self) -> None: def __init__(self) -> None:
self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock) self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
self._lock_count: dict[str, int] = defaultdict(int) self._lock_count: dict[str, int] = defaultdict(int)
@@ -30,26 +26,4 @@ class _PerLoopSessionLockManager:
self._lock_count.pop(session_id, None) self._lock_count.pop(session_id, None)
class SessionLockManager:
"""Thread-safe session lock manager with per-event-loop isolation."""
def __init__(self) -> None:
self._state_guard = threading.Lock()
self._loop_managers: weakref.WeakKeyDictionary[
asyncio.AbstractEventLoop, _PerLoopSessionLockManager
] = weakref.WeakKeyDictionary()
def _get_loop_manager(self) -> _PerLoopSessionLockManager:
"""Get the lock manager for the current event loop."""
loop = asyncio.get_running_loop()
with self._state_guard:
return self._loop_managers.setdefault(loop, _PerLoopSessionLockManager())
@asynccontextmanager
async def acquire_lock(self, session_id: str):
manager = self._get_loop_manager()
async with manager.acquire_lock(session_id):
yield
session_lock_manager = SessionLockManager() session_lock_manager = SessionLockManager()
+1 -2
View File
@@ -82,8 +82,7 @@ class AuthRoute(Route):
def generate_jwt(self, username): def generate_jwt(self, username):
payload = { payload = {
"username": username, "username": username,
"exp": datetime.datetime.now(datetime.timezone.utc) "exp": datetime.datetime.utcnow() + datetime.timedelta(days=7),
+ datetime.timedelta(days=7),
} }
jwt_token = self.config["dashboard"].get("jwt_secret", None) jwt_token = self.config["dashboard"].get("jwt_secret", None)
if not jwt_token: if not jwt_token:
+1 -11
View File
@@ -977,17 +977,7 @@ class BackupRoute(Route):
if not jwt_secret: if not jwt_secret:
return Response().error("服务器配置错误").__dict__ return Response().error("服务器配置错误").__dict__
# Verify JWT token with strict security options jwt.decode(token, jwt_secret, algorithms=["HS256"])
jwt.decode(
token,
jwt_secret,
algorithms=["HS256"],
options={
"require": ["exp"], # Require expiration claim
"verify_signature": True, # Explicitly verify signature
"verify_exp": True, # Verify expiration
},
)
except jwt.ExpiredSignatureError: except jwt.ExpiredSignatureError:
return Response().error("Token 已过期,请刷新页面后重试").__dict__ return Response().error("Token 已过期,请刷新页面后重试").__dict__
except jwt.InvalidTokenError: except jwt.InvalidTokenError:
+22 -85
View File
@@ -36,20 +36,6 @@ async def track_conversation(convs: dict, conv_id: str):
convs.pop(conv_id, None) convs.pop(conv_id, None)
async def _poll_webchat_stream_result(back_queue, username: str):
try:
result = await asyncio.wait_for(back_queue.get(), timeout=1)
except asyncio.TimeoutError:
return None, False
except asyncio.CancelledError:
logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。")
return None, True
except Exception as e:
logger.error(f"WebChat stream error: {e}")
return None, False
return result, False
class ChatRoute(Route): class ChatRoute(Route):
def __init__( def __init__(
self, self,
@@ -65,7 +51,6 @@ class ChatRoute(Route):
"/chat/get_session": ("GET", self.get_session), "/chat/get_session": ("GET", self.get_session),
"/chat/stop": ("POST", self.stop_session), "/chat/stop": ("POST", self.stop_session),
"/chat/delete_session": ("GET", self.delete_webchat_session), "/chat/delete_session": ("GET", self.delete_webchat_session),
"/chat/batch_delete_sessions": ("POST", self.batch_delete_sessions),
"/chat/update_session_display_name": ( "/chat/update_session_display_name": (
"POST", "POST",
self.update_session_display_name, self.update_session_display_name,
@@ -357,12 +342,16 @@ class ChatRoute(Route):
async with track_conversation(self.running_convs, webchat_conv_id): async with track_conversation(self.running_convs, webchat_conv_id):
while True: while True:
result, should_break = await _poll_webchat_stream_result( try:
back_queue, username result = await asyncio.wait_for(back_queue.get(), timeout=1)
) except asyncio.TimeoutError:
if should_break: continue
except asyncio.CancelledError:
logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。")
client_disconnected = True client_disconnected = True
break except Exception as e:
logger.error(f"WebChat stream error: {e}")
if not result: if not result:
continue continue
@@ -589,9 +578,19 @@ class ChatRoute(Route):
return Response().ok(data={"stopped_count": stopped_count}).__dict__ return Response().ok(data={"stopped_count": stopped_count}).__dict__
async def _delete_session_internal(self, session, username: str) -> None: async def delete_webchat_session(self):
"""Delete a single session and all its related data.""" """Delete a Platform session and all its related data."""
session_id = session.session_id session_id = request.args.get("session_id")
if not session_id:
return Response().error("Missing key: session_id").__dict__
username = g.get("username", "guest")
# 验证会话是否存在且属于当前用户
session = await self.db.get_platform_session_by_id(session_id)
if not session:
return Response().error(f"Session {session_id} not found").__dict__
if session.creator != username:
return Response().error("Permission denied").__dict__
# 删除该会话下的所有对话 # 删除该会话下的所有对话
message_type = "GroupMessage" if session.is_group else "FriendMessage" message_type = "GroupMessage" if session.is_group else "FriendMessage"
@@ -633,70 +632,8 @@ class ChatRoute(Route):
# 删除会话 # 删除会话
await self.db.delete_platform_session(session_id) await self.db.delete_platform_session(session_id)
async def delete_webchat_session(self):
"""Delete a Platform session and all its related data."""
session_id = request.args.get("session_id")
if not session_id:
return Response().error("Missing key: session_id").__dict__
username = g.get("username", "guest")
session = await self.db.get_platform_session_by_id(session_id)
if not session:
return Response().error(f"Session {session_id} not found").__dict__
if session.creator != username:
return Response().error("Permission denied").__dict__
await self._delete_session_internal(session, username)
return Response().ok().__dict__ return Response().ok().__dict__
async def batch_delete_sessions(self):
"""Batch delete multiple Platform sessions."""
post_data = await request.json
if post_data is None:
return Response().error("Missing JSON body").__dict__
if not isinstance(post_data, dict):
return Response().error("Invalid JSON body: expected object").__dict__
session_ids = post_data.get("session_ids")
if not session_ids or not isinstance(session_ids, list):
return Response().error("Missing or invalid key: session_ids").__dict__
username = g.get("username", "guest")
sessions = await self.db.get_platform_sessions_by_ids(session_ids)
sessions_by_id = {session.session_id: session for session in sessions}
deleted_count = 0
failed_items = []
for sid in session_ids:
session = sessions_by_id.get(sid)
if not session:
failed_items.append({"session_id": sid, "reason": "not found"})
continue
if session.creator != username:
failed_items.append({"session_id": sid, "reason": "permission denied"})
continue
try:
await self._delete_session_internal(session, username)
deleted_count += 1
sessions_by_id.pop(sid, None)
except Exception:
logger.warning("Failed to delete session %s", sid)
failed_items.append({"session_id": sid, "reason": "internal_error"})
return (
Response()
.ok(
data={
"deleted_count": deleted_count,
"failed_count": len(failed_items),
"failed_items": failed_items,
}
)
.__dict__
)
def _extract_attachment_ids(self, history_list) -> list[str]: def _extract_attachment_ids(self, history_list) -> list[str]:
"""从消息历史中提取所有 attachment_id""" """从消息历史中提取所有 attachment_id"""
attachment_ids = [] attachment_ids = []
-2
View File
@@ -610,7 +610,6 @@ class ConfigRoute(Route):
try: try:
conf_id = self.acm.create_conf(name=name, config=config) conf_id = self.acm.create_conf(name=name, config=config)
await self.core_lifecycle.reload_pipeline_scheduler(conf_id)
return Response().ok(message="创建成功", data={"conf_id": conf_id}).__dict__ return Response().ok(message="创建成功", data={"conf_id": conf_id}).__dict__
except ValueError as e: except ValueError as e:
return Response().error(str(e)).__dict__ return Response().error(str(e)).__dict__
@@ -650,7 +649,6 @@ class ConfigRoute(Route):
try: try:
success = self.acm.delete_conf(conf_id) success = self.acm.delete_conf(conf_id)
if success: if success:
self.core_lifecycle.pipeline_scheduler_mapping.pop(conf_id, None)
return Response().ok(message="删除成功").__dict__ return Response().ok(message="删除成功").__dict__
return Response().error("删除失败").__dict__ return Response().error("删除失败").__dict__
except ValueError as e: except ValueError as e:
+1 -31
View File
@@ -5,8 +5,7 @@ import os
import ssl import ssl
import traceback import traceback
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timezone from datetime import datetime
from pathlib import Path
import aiohttp import aiohttp
import certifi import certifi
@@ -353,34 +352,6 @@ class PluginRoute(Route):
logger.warning(f"获取插件 Logo 失败: {e}") logger.warning(f"获取插件 Logo 失败: {e}")
return None return None
def _resolve_plugin_dir(self, plugin) -> Path | None:
if not plugin.root_dir_name:
return None
base_dir = Path(
self.plugin_manager.reserved_plugin_path
if plugin.reserved
else self.plugin_manager.plugin_store_path
)
plugin_dir = base_dir / plugin.root_dir_name
if not plugin_dir.is_dir():
return None
return plugin_dir
def _get_plugin_installed_at(self, plugin) -> str | None:
plugin_dir = self._resolve_plugin_dir(plugin)
if plugin_dir is None:
return None
try:
return datetime.fromtimestamp(
plugin_dir.stat().st_mtime,
timezone.utc,
).isoformat()
except OSError as exc:
logger.warning(f"获取插件安装时间失败 {plugin.name}: {exc!s}")
return None
async def get_plugins(self): async def get_plugins(self):
_plugin_resp = [] _plugin_resp = []
plugin_name = request.args.get("name") plugin_name = request.args.get("name")
@@ -406,7 +377,6 @@ class PluginRoute(Route):
"logo": f"/api/file/{logo_url}" if logo_url else None, "logo": f"/api/file/{logo_url}" if logo_url else None,
"support_platforms": plugin.support_platforms, "support_platforms": plugin.support_platforms,
"astrbot_version": plugin.astrbot_version, "astrbot_version": plugin.astrbot_version,
"installed_at": self._get_plugin_installed_at(plugin),
} }
# 检查是否为全空的幽灵插件 # 检查是否为全空的幽灵插件
if not any( if not any(
-110
View File
@@ -2,7 +2,6 @@ import os
import re import re
import shutil import shutil
import traceback import traceback
import uuid
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@@ -51,7 +50,6 @@ class SkillsRoute(Route):
self.routes = { self.routes = {
"/skills": ("GET", self.get_skills), "/skills": ("GET", self.get_skills),
"/skills/upload": ("POST", self.upload_skill), "/skills/upload": ("POST", self.upload_skill),
"/skills/batch-upload": ("POST", self.batch_upload_skills),
"/skills/download": ("GET", self.download_skill), "/skills/download": ("GET", self.download_skill),
"/skills/update": ("POST", self.update_skill), "/skills/update": ("POST", self.update_skill),
"/skills/delete": ("POST", self.delete_skill), "/skills/delete": ("POST", self.delete_skill),
@@ -190,114 +188,6 @@ class SkillsRoute(Route):
except Exception: except Exception:
logger.warning(f"Failed to remove temp skill file: {temp_path}") logger.warning(f"Failed to remove temp skill file: {temp_path}")
async def batch_upload_skills(self):
"""批量上传多个 skill ZIP 文件"""
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
try:
files = await request.files
file_list = files.getlist("files")
if not file_list:
return Response().error("No files provided").__dict__
succeeded = []
failed = []
skill_mgr = SkillManager()
temp_dir = get_astrbot_temp_path()
os.makedirs(temp_dir, exist_ok=True)
for file in file_list:
filename = os.path.basename(file.filename or "unknown.zip")
temp_path = None
try:
if not filename.lower().endswith(".zip"):
failed.append(
{
"filename": filename,
"error": "Only .zip files are supported",
}
)
continue
temp_path = os.path.join(
temp_dir, f"batch_{uuid.uuid4().hex}_{filename}"
)
await file.save(temp_path)
skill_name = skill_mgr.install_skill_from_zip(
temp_path, overwrite=True
)
succeeded.append({"filename": filename, "name": skill_name})
except Exception as e:
failed.append({"filename": filename, "error": str(e)})
finally:
if temp_path and os.path.exists(temp_path):
try:
os.remove(temp_path)
except Exception:
pass
if succeeded:
try:
await sync_skills_to_active_sandboxes()
except Exception:
logger.warning(
"Failed to sync uploaded skills to active sandboxes."
)
total = len(file_list)
success_count = len(succeeded)
if success_count == total:
message = f"All {total} skill(s) uploaded successfully."
return (
Response()
.ok(
{
"total": total,
"succeeded": succeeded,
"failed": failed,
},
message,
)
.__dict__
)
if success_count == 0:
message = f"Upload failed for all {total} file(s)."
resp = Response().error(message)
resp.data = {
"total": total,
"succeeded": succeeded,
"failed": failed,
}
return resp.__dict__
message = f"Partial success: {success_count}/{total} skill(s) uploaded."
return (
Response()
.ok(
{
"total": total,
"succeeded": succeeded,
"failed": failed,
},
message,
)
.__dict__
)
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(str(e)).__dict__
async def download_skill(self): async def download_skill(self):
try: try:
name = str(request.args.get("name") or "").strip() name = str(request.args.get("name") or "").strip()
+66 -183
View File
@@ -12,32 +12,6 @@ from .route import Response, Route, RouteContext
DEFAULT_MCP_CONFIG = {"mcpServers": {}} DEFAULT_MCP_CONFIG = {"mcpServers": {}}
class EmptyMcpServersError(ValueError):
"""Raised when mcpServers is empty."""
pass
def _extract_mcp_server_config(mcp_servers_value: object) -> dict:
"""Extract server configuration from user-submitted mcpServers field.
Raises:
ValueError: Invalid configuration
"""
if not isinstance(mcp_servers_value, dict):
raise ValueError("mcpServers must be a JSON object")
if not mcp_servers_value:
raise EmptyMcpServersError("mcpServers configuration cannot be empty")
key_0 = next(iter(mcp_servers_value))
extracted = mcp_servers_value[key_0]
if not isinstance(extracted, dict):
raise ValueError(
"Invalid mcpServers format. Ensure each key in mcpServers is a server name, "
"and each value is an object containing fields like command/url."
)
return extracted
class ToolsRoute(Route): class ToolsRoute(Route):
def __init__( def __init__(
self, self,
@@ -59,37 +33,13 @@ class ToolsRoute(Route):
self.register_routes() self.register_routes()
self.tool_mgr = self.core_lifecycle.provider_manager.llm_tools self.tool_mgr = self.core_lifecycle.provider_manager.llm_tools
def _rollback_mcp_server(self, name: str) -> bool:
try:
rollback_config = self.tool_mgr.load_mcp_config()
if name in rollback_config["mcpServers"]:
rollback_config["mcpServers"].pop(name)
return self.tool_mgr.save_mcp_config(rollback_config)
return True
except Exception:
logger.error(traceback.format_exc())
return False
async def get_mcp_servers(self): async def get_mcp_servers(self):
try: try:
config = self.tool_mgr.load_mcp_config() config = self.tool_mgr.load_mcp_config()
servers = [] servers = []
mcp_servers = config.get("mcpServers", {})
if not isinstance(mcp_servers, dict):
logger.warning(
f"Invalid MCP server config type: {type(mcp_servers).__name__}. Expected object/dict; skipped all MCP servers."
)
mcp_servers = {}
# 获取所有服务器并添加它们的工具列表 # 获取所有服务器并添加它们的工具列表
for name, server_config in mcp_servers.items(): for name, server_config in config["mcpServers"].items():
if not isinstance(server_config, dict):
logger.warning(
f"Invalid config for MCP server '{name}' (type: {type(server_config).__name__}); skipped."
)
continue
server_info = { server_info = {
"name": name, "name": name,
"active": server_config.get("active", True), "active": server_config.get("active", True),
@@ -115,7 +65,7 @@ class ToolsRoute(Route):
return Response().ok(servers).__dict__ return Response().ok(servers).__dict__
except Exception as e: except Exception as e:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return Response().error(f"Failed to get MCP server list: {e!s}").__dict__ return Response().error(f"获取 MCP 服务器列表失败: {e!s}").__dict__
async def add_mcp_server(self): async def add_mcp_server(self):
try: try:
@@ -125,7 +75,7 @@ class ToolsRoute(Route):
# 检查必填字段 # 检查必填字段
if not name: if not name:
return Response().error("Server name cannot be empty").__dict__ return Response().error("服务器名称不能为空").__dict__
# 移除特殊字段并检查配置是否有效 # 移除特殊字段并检查配置是否有效
has_valid_config = False has_valid_config = False
@@ -135,33 +85,21 @@ class ToolsRoute(Route):
for key, value in server_data.items(): for key, value in server_data.items():
if key not in ["name", "active", "tools", "errlogs"]: # 排除特殊字段 if key not in ["name", "active", "tools", "errlogs"]: # 排除特殊字段
if key == "mcpServers": if key == "mcpServers":
try: key_0 = list(server_data["mcpServers"].keys())[
server_config = _extract_mcp_server_config( 0
server_data["mcpServers"] ] # 不考虑为空的情况
) server_config = server_data["mcpServers"][key_0]
except ValueError as e:
return Response().error(f"{e!s}").__dict__
else: else:
server_config[key] = value server_config[key] = value
has_valid_config = True has_valid_config = True
if not has_valid_config: if not has_valid_config:
return ( return Response().error("必须提供有效的服务器配置").__dict__
Response()
.error("A valid server configuration is required")
.__dict__
)
config = self.tool_mgr.load_mcp_config() config = self.tool_mgr.load_mcp_config()
if name in config["mcpServers"]: if name in config["mcpServers"]:
return Response().error(f"Server {name} already exists").__dict__ return Response().error(f"服务器 {name} 已存在").__dict__
try:
await self.tool_mgr.test_mcp_server_connection(server_config)
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"MCP connection test failed: {e!s}").__dict__
config["mcpServers"][name] = server_config config["mcpServers"][name] = server_config
@@ -173,27 +111,17 @@ class ToolsRoute(Route):
timeout=30, timeout=30,
) )
except TimeoutError: except TimeoutError:
rollback_ok = self._rollback_mcp_server(name) return Response().error(f"启用 MCP 服务器 {name} 超时。").__dict__
err_msg = f"Timed out while enabling MCP server {name}."
if not rollback_ok:
err_msg += " Configuration rollback failed. Please check the config manually."
return Response().error(err_msg).__dict__
except Exception as e: except Exception as e:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
rollback_ok = self._rollback_mcp_server(name)
err_msg = f"Failed to enable MCP server {name}: {e!s}"
if not rollback_ok:
err_msg += " Configuration rollback failed. Please check the config manually."
return Response().error(err_msg).__dict__
return ( return (
Response() Response().error(f"启用 MCP 服务器 {name} 失败: {e!s}").__dict__
.ok(None, f"Successfully added MCP server {name}")
.__dict__
) )
return Response().error("Failed to save configuration").__dict__ return Response().ok(None, f"成功添加 MCP 服务器 {name}").__dict__
return Response().error("保存配置失败").__dict__
except Exception as e: except Exception as e:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return Response().error(f"Failed to add MCP server: {e!s}").__dict__ return Response().error(f"添加 MCP 服务器失败: {e!s}").__dict__
async def update_mcp_server(self): async def update_mcp_server(self):
try: try:
@@ -203,25 +131,23 @@ class ToolsRoute(Route):
old_name = server_data.get("oldName") or name old_name = server_data.get("oldName") or name
if not name: if not name:
return Response().error("Server name cannot be empty").__dict__ return Response().error("服务器名称不能为空").__dict__
config = self.tool_mgr.load_mcp_config() config = self.tool_mgr.load_mcp_config()
if old_name not in config["mcpServers"]: if old_name not in config["mcpServers"]:
return Response().error(f"Server {old_name} does not exist").__dict__ return Response().error(f"服务器 {old_name} 不存在").__dict__
is_rename = name != old_name is_rename = name != old_name
if name in config["mcpServers"] and is_rename: if name in config["mcpServers"] and is_rename:
return Response().error(f"Server {name} already exists").__dict__ return Response().error(f"服务器 {name} 已存在").__dict__
# 获取活动状态 # 获取活动状态
old_config = config["mcpServers"][old_name] active = server_data.get(
if isinstance(old_config, dict): "active",
old_active = old_config.get("active", True) config["mcpServers"][old_name].get("active", True),
else: )
old_active = True
active = server_data.get("active", old_active)
# 创建新的配置对象 # 创建新的配置对象
server_config = {"active": active} server_config = {"active": active}
@@ -239,19 +165,17 @@ class ToolsRoute(Route):
"oldName", "oldName",
]: # 排除特殊字段 ]: # 排除特殊字段
if key == "mcpServers": if key == "mcpServers":
try: key_0 = list(server_data["mcpServers"].keys())[
server_config = _extract_mcp_server_config( 0
server_data["mcpServers"] ] # 不考虑为空的情况
) server_config = server_data["mcpServers"][key_0]
except ValueError as e:
return Response().error(f"{e!s}").__dict__
else: else:
server_config[key] = value server_config[key] = value
only_update_active = False only_update_active = False
# 如果只更新活动状态,保留原始配置 # 如果只更新活动状态,保留原始配置
if only_update_active and isinstance(old_config, dict): if only_update_active:
for key, value in old_config.items(): for key, value in config["mcpServers"][old_name].items():
if key != "active": # 除了active之外的所有字段都保留 if key != "active": # 除了active之外的所有字段都保留
server_config[key] = value server_config[key] = value
@@ -276,7 +200,7 @@ class ToolsRoute(Route):
return ( return (
Response() Response()
.error( .error(
f"Timed out while disabling MCP server {old_name} before enabling: {e!s}" f"启用前停用 MCP 服务器时 {old_name} 超时: {e!s}"
) )
.__dict__ .__dict__
) )
@@ -285,7 +209,7 @@ class ToolsRoute(Route):
return ( return (
Response() Response()
.error( .error(
f"Failed to disable MCP server {old_name} before enabling: {e!s}" f"启用前停用 MCP 服务器时 {old_name} 失败: {e!s}"
) )
.__dict__ .__dict__
) )
@@ -297,15 +221,13 @@ class ToolsRoute(Route):
) )
except TimeoutError: except TimeoutError:
return ( return (
Response() Response().error(f"启用 MCP 服务器 {name} 超时。").__dict__
.error(f"Timed out while enabling MCP server {name}.")
.__dict__
) )
except Exception as e: except Exception as e:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return ( return (
Response() Response()
.error(f"Failed to enable MCP server {name}: {e!s}") .error(f"启用 MCP 服务器 {name} 失败: {e!s}")
.__dict__ .__dict__
) )
# 如果要停用服务器 # 如果要停用服务器
@@ -315,26 +237,22 @@ class ToolsRoute(Route):
except TimeoutError: except TimeoutError:
return ( return (
Response() Response()
.error(f"Timed out while disabling MCP server {old_name}.") .error(f"停用 MCP 服务器 {old_name} 超时。")
.__dict__ .__dict__
) )
except Exception as e: except Exception as e:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return ( return (
Response() Response()
.error(f"Failed to disable MCP server {old_name}: {e!s}") .error(f"停用 MCP 服务器 {old_name} 失败: {e!s}")
.__dict__ .__dict__
) )
return ( return Response().ok(None, f"成功更新 MCP 服务器 {name}").__dict__
Response() return Response().error("保存配置失败").__dict__
.ok(None, f"Successfully updated MCP server {name}")
.__dict__
)
return Response().error("Failed to save configuration").__dict__
except Exception as e: except Exception as e:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return Response().error(f"Failed to update MCP server: {e!s}").__dict__ return Response().error(f"更新 MCP 服务器失败: {e!s}").__dict__
async def delete_mcp_server(self): async def delete_mcp_server(self):
try: try:
@@ -342,12 +260,12 @@ class ToolsRoute(Route):
name = server_data.get("name", "") name = server_data.get("name", "")
if not name: if not name:
return Response().error("Server name cannot be empty").__dict__ return Response().error("服务器名称不能为空").__dict__
config = self.tool_mgr.load_mcp_config() config = self.tool_mgr.load_mcp_config()
if name not in config["mcpServers"]: if name not in config["mcpServers"]:
return Response().error(f"Server {name} does not exist").__dict__ return Response().error(f"服务器 {name} 不存在").__dict__
del config["mcpServers"][name] del config["mcpServers"][name]
@@ -357,76 +275,51 @@ class ToolsRoute(Route):
await self.tool_mgr.disable_mcp_server(name, timeout=10) await self.tool_mgr.disable_mcp_server(name, timeout=10)
except TimeoutError: except TimeoutError:
return ( return (
Response() Response().error(f"停用 MCP 服务器 {name} 超时。").__dict__
.error(f"Timed out while disabling MCP server {name}.")
.__dict__
) )
except Exception as e: except Exception as e:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return ( return (
Response() Response()
.error(f"Failed to disable MCP server {name}: {e!s}") .error(f"停用 MCP 服务器 {name} 失败: {e!s}")
.__dict__ .__dict__
) )
return ( return Response().ok(None, f"成功删除 MCP 服务器 {name}").__dict__
Response() return Response().error("保存配置失败").__dict__
.ok(None, f"Successfully deleted MCP server {name}")
.__dict__
)
return Response().error("Failed to save configuration").__dict__
except Exception as e: except Exception as e:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return Response().error(f"Failed to delete MCP server: {e!s}").__dict__ return Response().error(f"删除 MCP 服务器失败: {e!s}").__dict__
async def test_mcp_connection(self): async def test_mcp_connection(self):
"""Test MCP server connection.""" """测试 MCP 服务器连接"""
try: try:
server_data = await request.json server_data = await request.json
config = server_data.get("mcp_server_config", None) config = server_data.get("mcp_server_config", None)
if not isinstance(config, dict) or not config: if not isinstance(config, dict) or not config:
return Response().error("Invalid MCP server configuration").__dict__ return Response().error("无效的 MCP 服务器配置").__dict__
if "mcpServers" in config: if "mcpServers" in config:
mcp_servers = config["mcpServers"] keys = list(config["mcpServers"].keys())
if isinstance(mcp_servers, dict) and len(mcp_servers) > 1: if not keys:
return ( return Response().error("MCP 服务器配置不能为空").__dict__
Response() if len(keys) > 1:
.error( return Response().error("一次只能配置一个 MCP 服务器配置").__dict__
"Only one MCP server configuration can be tested at a time" config = config["mcpServers"][keys[0]]
)
.__dict__
)
try:
config = _extract_mcp_server_config(mcp_servers)
except EmptyMcpServersError:
return (
Response()
.error("MCP server configuration cannot be empty")
.__dict__
)
except ValueError as e:
return Response().error(f"{e!s}").__dict__
elif not config: elif not config:
return ( return Response().error("MCP 服务器配置不能为空").__dict__
Response()
.error("MCP server configuration cannot be empty")
.__dict__
)
tools_name = await self.tool_mgr.test_mcp_server_connection(config) tools_name = await self.tool_mgr.test_mcp_server_connection(config)
return ( return (
Response() Response().ok(data=tools_name, message="🎉 MCP 服务器可用!").__dict__
.ok(data=tools_name, message="🎉 MCP server is available!")
.__dict__
) )
except Exception as e: except Exception as e:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return Response().error(f"Failed to test MCP connection: {e!s}").__dict__ return Response().error(f"测试 MCP 连接失败: {e!s}").__dict__
async def get_tool_list(self): async def get_tool_list(self):
"""Get all registered tools.""" """获取所有注册的工具列表"""
try: try:
tools = self.tool_mgr.func_list tools = self.tool_mgr.func_list
tools_dict = [] tools_dict = []
@@ -456,44 +349,36 @@ class ToolsRoute(Route):
return Response().ok(data=tools_dict).__dict__ return Response().ok(data=tools_dict).__dict__
except Exception as e: except Exception as e:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return Response().error(f"Failed to get tool list: {e!s}").__dict__ return Response().error(f"获取工具列表失败: {e!s}").__dict__
async def toggle_tool(self): async def toggle_tool(self):
"""Activate or deactivate a specified tool.""" """启用或停用指定的工具"""
try: try:
data = await request.json data = await request.json
tool_name = data.get("name") tool_name = data.get("name")
action = data.get("activate") # True or False action = data.get("activate") # True or False
if not tool_name or action is None: if not tool_name or action is None:
return ( return Response().error("缺少必要参数: name 或 action").__dict__
Response()
.error("Missing required parameters: name or activate")
.__dict__
)
if action: if action:
try: try:
ok = self.tool_mgr.activate_llm_tool(tool_name, star_map=star_map) ok = self.tool_mgr.activate_llm_tool(tool_name, star_map=star_map)
except ValueError as e: except ValueError as e:
return Response().error(f"Failed to activate tool: {e!s}").__dict__ return Response().error(f"启用工具失败: {e!s}").__dict__
else: else:
ok = self.tool_mgr.deactivate_llm_tool(tool_name) ok = self.tool_mgr.deactivate_llm_tool(tool_name)
if ok: if ok:
return Response().ok(None, "Operation successful.").__dict__ return Response().ok(None, "操作成功。").__dict__
return ( return Response().error(f"工具 {tool_name} 不存在或操作失败。").__dict__
Response()
.error(f"Tool {tool_name} does not exist or the operation failed.")
.__dict__
)
except Exception as e: except Exception as e:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return Response().error(f"Failed to operate tool: {e!s}").__dict__ return Response().error(f"操作工具失败: {e!s}").__dict__
async def sync_provider(self): async def sync_provider(self):
"""Sync MCP provider configuration.""" """同步 MCP 提供者配置"""
try: try:
data = await request.json data = await request.json
provider_name = data.get("name") # modelscope, or others provider_name = data.get("name") # modelscope, or others
@@ -502,11 +387,9 @@ class ToolsRoute(Route):
access_token = data.get("access_token", "") access_token = data.get("access_token", "")
await self.tool_mgr.sync_modelscope_mcp_servers(access_token) await self.tool_mgr.sync_modelscope_mcp_servers(access_token)
case _: case _:
return ( return Response().error(f"未知: {provider_name}").__dict__
Response().error(f"Unknown provider: {provider_name}").__dict__
)
return Response().ok(message="Sync completed").__dict__ return Response().ok(message="同步成功").__dict__
except Exception as e: except Exception as e:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return Response().error(f"Sync failed: {e!s}").__dict__ return Response().error(f"同步失败: {e!s}").__dict__
-70
View File
@@ -1,70 +0,0 @@
## What's Changed
### 新增
- 集成 KOOK 平台适配器 ([#5658](https://github.com/AstrBotDevs/AstrBot/pull/5658))。
- 新增 Discord pre-react Emoji 支持 ([#5609](https://github.com/AstrBotDevs/AstrBot/pull/5609))。
- 新增 Telegram 支持 `sendMessageDraft` 流式实时输出 API ([#5726](https://github.com/AstrBotDevs/AstrBot/issues/5726))
- 支持在 Agent 运行时进行消息跟进能力,跟进的消息实时注入给 Agent ([#5484](https://github.com/AstrBotDevs/AstrBot/pull/5484))。
- 集成 DeerFlow Agent Runner 并优化流式处理 ([#5581](https://github.com/AstrBotDevs/AstrBot/pull/5581))。
- 新增 shell, ipython tool 中包含操作系统信息,提高 windows 下 tool call 成功率 ([#5677](https://github.com/AstrBotDevs/AstrBot/pull/5677))。
- Sandbox 支持 Shipyard-neo - 支持 Skills 自迭代 ([#5028](https://github.com/AstrBotDevs/AstrBot/pull/5028))。
- 新增 ChatUI WebSocket 传输模式选择,OpenAPI Chat API 支持 WebSocket 连接 ([#5410](https://github.com/AstrBotDevs/AstrBot/pull/5410))。
- 支持 Persona 自定义报错回复消息与兜底逻辑 ([#5547](https://github.com/AstrBotDevs/AstrBot/pull/5547))。
- 将 WebUI 静态文件打包至 wheel,并将 astrbot CLI 日志替换为英文 ([#5665](https://github.com/AstrBotDevs/AstrBot/pull/5665))。
- 增强聊天界面与移动端响应式体验 ([#5635](https://github.com/AstrBotDevs/AstrBot/pull/5635))。
- 优化插件失败处理逻辑与扩展列表交互体验 ([#5535](https://github.com/AstrBotDevs/AstrBot/pull/5535))。
### 修复
- 修复 MCP 初始化超时参数关键字不匹配的问题 ([#5743](https://github.com/AstrBotDevs/AstrBot/pull/5743))。
- 修复 MCP 工具竞态条件导致"completion 无法解析"错误 ([#5534](https://github.com/AstrBotDevs/AstrBot/pull/5534))。
- 修复 LINE 适配器中非 HTTPS URL 直接透传的问题 ([#5697](https://github.com/AstrBotDevs/AstrBot/pull/5697))。
- 修复 WebUI 侧边栏自定义状态不稳定的问题 ([#5670](https://github.com/AstrBotDevs/AstrBot/pull/5670))。
- 修复 KOOK 适配器收到消息和心跳响应时输出多余调试日志的问题。
- 修复 `DEMO_MODE` 环境变量未正确解析为布尔值的问题 ([#5676](https://github.com/AstrBotDevs/AstrBot/pull/5676))。
- 修复子 Agent 无法正确接收本地图片(参考图)路径的问题 ([#5579](https://github.com/AstrBotDevs/AstrBot/pull/5579))。
- 修复 `/model` 命令切换至不同 Provider 模型时产生误导性行为的问题 ([#5578](https://github.com/AstrBotDevs/AstrBot/pull/5578))。
- 修复对话记录中 UTC 时区偏移未处理导致时间戳异常的问题 ([#5580](https://github.com/AstrBotDevs/AstrBot/pull/5580))。
- 修复备份导入时重复平台统计数据导致异常的问题 ([#5594](https://github.com/AstrBotDevs/AstrBot/pull/5594))。
- 修复 `max_agent_step` 配置未应用到子 Agent 的问题 ([#5608](https://github.com/AstrBotDevs/AstrBot/pull/5608))。
- 修复插件列表排序和搜索过滤逻辑 ([#5559](https://github.com/AstrBotDevs/AstrBot/pull/5559))。
- 修复 `uv sync` 时未要求 Node.js 环境的问题。
---
## What's Changed (EN)
### New Features
- Integrated KOOK platform adapter ([#5658](https://github.com/AstrBotDevs/AstrBot/pull/5658)).
- Integrated DeerFlow Agent Runner with optimized streaming support ([#5581](https://github.com/AstrBotDevs/AstrBot/pull/5581)).
- feat(telegram): supports sendMessageDraft API ([#5726](https://github.com/AstrBotDevs/AstrBot/issues/5726))
- Integrated Neo skill self-iteration capability with full lifecycle management (candidate, release, deletion) via Shipyard Neo sandbox ([#5028](https://github.com/AstrBotDevs/AstrBot/pull/5028)).
- Added Discord pre-ack emoji support ([#5609](https://github.com/AstrBotDevs/AstrBot/pull/5609)).
- Added WebSocket transport mode selection for the chat interface ([#5410](https://github.com/AstrBotDevs/AstrBot/pull/5410)).
- Added OS information to tool descriptions with unit test coverage ([#5677](https://github.com/AstrBotDevs/AstrBot/pull/5677)).
- Added follow-up message handling in `ToolLoopAgentRunner` ([#5484](https://github.com/AstrBotDevs/AstrBot/pull/5484)).
- Added support for persona custom error reply messages with fallback logic ([#5547](https://github.com/AstrBotDevs/AstrBot/pull/5547)).
- Bundled WebUI static files into the wheel package and replaced astrbot CLI logs with English ([#5665](https://github.com/AstrBotDevs/AstrBot/pull/5665)).
- Optimized async IO performance and added benchmark coverage ([#5737](https://github.com/AstrBotDevs/AstrBot/pull/5737)).
- Refactored API key creation and added unit tests for open API routes.
- Improved error messaging for AI execution failures in agent runners.
- Enhanced chat interface and mobile responsiveness ([#5635](https://github.com/AstrBotDevs/AstrBot/pull/5635)).
- Improved plugin failure handling and extension list UX ([#5535](https://github.com/AstrBotDevs/AstrBot/pull/5535)).
### Bug Fixes
- Fixed MCP initialization timeout keyword mismatch ([#5743](https://github.com/AstrBotDevs/AstrBot/pull/5743)).
- Fixed MCP tools race condition causing `completion 无法解析` error ([#5534](https://github.com/AstrBotDevs/AstrBot/pull/5534)).
- Fixed LINE adapter allowing non-HTTPS URLs to pass through directly ([#5697](https://github.com/AstrBotDevs/AstrBot/pull/5697)).
- Fixed unstable sidebar customization state in WebUI ([#5670](https://github.com/AstrBotDevs/AstrBot/pull/5670)).
- Fixed excessive debug logging in KOOK adapter for received messages and heartbeat responses.
- Fixed `DEMO_MODE` environment variable not being parsed correctly as a boolean ([#5676](https://github.com/AstrBotDevs/AstrBot/pull/5676)).
- Fixed sub-agent failing to correctly receive local image (reference image) paths ([#5579](https://github.com/AstrBotDevs/AstrBot/pull/5579)).
- Fixed misleading behavior of the `/model` command when switching to a model from a different provider ([#5578](https://github.com/AstrBotDevs/AstrBot/pull/5578)).
- Fixed unhandled UTC timezone offset causing incorrect timestamps in conversation records ([#5580](https://github.com/AstrBotDevs/AstrBot/pull/5580)).
- Fixed backup import failure due to duplicate platform stats entries ([#5594](https://github.com/AstrBotDevs/AstrBot/pull/5594)).
- Fixed `max_agent_step` config not being applied to sub-agents ([#5608](https://github.com/AstrBotDevs/AstrBot/pull/5608)).
- Fixed plugin list sorting and search filtering logic ([#5559](https://github.com/AstrBotDevs/AstrBot/pull/5559)).
- Fixed missing Node.js environment requirement during `uv sync`.
-40
View File
@@ -1,40 +0,0 @@
## What's Changed
### 新增
- 新增技能 ZIP 批量上传能力 ([#5804](https://github.com/AstrBotDevs/AstrBot/pull/5804))。
### 修复
- 修复 MCP Server 配置异常时可能导致崩溃的问题 ([#5666](https://github.com/AstrBotDevs/AstrBot/pull/5666), [#5673](https://github.com/AstrBotDevs/AstrBot/pull/5673))。
- 修复钉钉适配器文本消息被忽略、无法主动发送文件的问题 ([#5921](https://github.com/AstrBotDevs/AstrBot/pull/5921))。
- 修复钉钉适配器无法接收图片与文件的问题 ([#5920](https://github.com/AstrBotDevs/AstrBot/pull/5920))。
- fix(provider): handle MiniMax ThinkingBlock when max_tokens reached ([#5913](https://github.com/AstrBotDevs/AstrBot/pull/5913))。
- 修复 OpenRouter `api_base` 配置错误的问题 ([#5911](https://github.com/AstrBotDevs/AstrBot/pull/5911))。
- 修复插件市场中按展示名搜索已安装插件不生效的问题 ([#5806](https://github.com/AstrBotDevs/AstrBot/pull/5806), [#5811](https://github.com/AstrBotDevs/AstrBot/pull/5811))。
- 修复仅图片响应未应用 `reply_with_quote``reply_with_mention` 的问题 ([#5219](https://github.com/AstrBotDevs/AstrBot/pull/5219))。
- 修复 `RegexFilter` 使用 `re.match` 导致匹配范围不正确的问题 ([#5368](https://github.com/AstrBotDevs/AstrBot/pull/5368))。
- 修复桌面运行环境检测依赖 frozen Python 的问题 ([#5859](https://github.com/AstrBotDevs/AstrBot/pull/5859))。
- 修复通过“创建新配置”创建平台机器人后找不到 pipeline scheduler 的问题 ([#5776](https://github.com/AstrBotDevs/AstrBot/pull/5776))。
---
## What's Changed (EN)
### New Features
- Added batch upload support for multiple skill ZIP files ([#5804](https://github.com/AstrBotDevs/AstrBot/pull/5804)).
### Bug Fixes
- Fixed potential crash on malformed MCP server config ([#5666](https://github.com/AstrBotDevs/AstrBot/pull/5666), [#5673](https://github.com/AstrBotDevs/AstrBot/pull/5673)).
- Fixed DingTalk adapter issue where text messages were ignored and files could not be sent proactively ([#5921](https://github.com/AstrBotDevs/AstrBot/pull/5921)).
- Fixed DingTalk adapter issue where image and file messages could not be received ([#5920](https://github.com/AstrBotDevs/AstrBot/pull/5920)).
- Fixed incorrect OpenRouter `api_base` configuration ([#5911](https://github.com/AstrBotDevs/AstrBot/pull/5911)).
- Fixed searching installed plugins by display name in extensions ([#5806](https://github.com/AstrBotDevs/AstrBot/pull/5806), [#5811](https://github.com/AstrBotDevs/AstrBot/pull/5811)).
- Fixed image-only responses not applying `reply_with_quote` and `reply_with_mention` ([#5219](https://github.com/AstrBotDevs/AstrBot/pull/5219)).
- Fixed `RegexFilter` using `re.match` instead of `re.search` for expected matching behavior ([#5368](https://github.com/AstrBotDevs/AstrBot/pull/5368)).
- Fixed desktop runtime detection requiring frozen Python ([#5859](https://github.com/AstrBotDevs/AstrBot/pull/5859)).
- Fixed missing pipeline scheduler after creating a platform bot via "create new config" ([#5776](https://github.com/AstrBotDevs/AstrBot/pull/5776)).
- fix(provider): handle MiniMax ThinkingBlock when max_tokens reached ([#5913](https://github.com/AstrBotDevs/AstrBot/pull/5913))
-9
View File
@@ -1,9 +0,0 @@
## What's Changed
### 新增
- 企业微信智能机器人支持长连接模式。[#5930](https://github.com/AstrBotDevs/AstrBot/pull/5930)
### New
- Wecom AI Bot supports long-connection mode(Websockets). [#5930](https://github.com/AstrBotDevs/AstrBot/pull/5930)
-43
View File
@@ -1,43 +0,0 @@
## What's Changed
### 新增
- Lark 适配器支持 CardKit 流式输出(飞书)([#5777](https://github.com/AstrBotDevs/AstrBot/pull/5777))。
- WebUI 已安装插件列表新增筛选与排序功能 ([#5923](https://github.com/AstrBotDevs/AstrBot/pull/5923))。
### 优化
- 启动时后台加载 MCP Server,不阻塞加载流程 ([#5993](https://github.com/AstrBotDevs/AstrBot/pull/5993))。
### 修复
- 部分情况下 MCP 页报错 500 导致查看不了 MCP 服务器 ([#5993](https://github.com/AstrBotDevs/AstrBot/pull/5993))。
- 修复 TTS Provider 测试:增加文件大小校验,并补充 MiniMax 空音频检测 ([#5999](https://github.com/AstrBotDevs/AstrBot/pull/5999))。
- 修复前端切换到 Chat 后又回到 Welcome 时,页面切换配置未正确持久化的问题 ([#5792](https://github.com/AstrBotDevs/AstrBot/pull/5792))。
- 修复 Azure TTS 不支持 84 位订阅密钥的问题 ([#5813](https://github.com/AstrBotDevs/AstrBot/pull/5813))。
### 文档
- 文档仓库迁移:将 `AstrBotDevs/AstrBot-docs` 内容迁移至 `AstrBotDevs/AstrBot` ([#5960](https://github.com/AstrBotDevs/AstrBot/pull/5960))。
---
## What's Changed (EN)
### New Features
- Added CardKit streaming output support for the Lark/Feishu adapter ([#5777](https://github.com/AstrBotDevs/AstrBot/pull/5777)).
- Added filtering and sorting for installed plugins in the WebUI ([#5923](https://github.com/AstrBotDevs/AstrBot/pull/5923)).
### Impprovement
- MCP Server now loads in the background during startup without blocking the loading process ([#5993](https://github.com/AstrBotDevs/AstrBot/pull/5993)).
### Bug Fixes
- Added file size validation in TTS provider tests and MiniMax empty-audio detection ([#5999](https://github.com/AstrBotDevs/AstrBot/pull/5999)).
- Fixed frontend state persistence when switching from Chat back to Welcome ([#5792](https://github.com/AstrBotDevs/AstrBot/pull/5792)).
- Fixed Azure TTS support for 84-character subscription keys ([#5813](https://github.com/AstrBotDevs/AstrBot/pull/5813)).
- Reverted the MCP stdio missing-command error wording change after the previous fix ([#5992](https://github.com/AstrBotDevs/AstrBot/pull/5992)).
### Documentation
- Migrated documentation content from `AstrBotDevs/AstrBot-docs` into `AstrBotDevs/AstrBot` ([#5960](https://github.com/AstrBotDevs/AstrBot/pull/5960)).
-64
View File
@@ -1,64 +0,0 @@
## What's Changed
### 新增
- 新增俄语翻译([#6081](https://github.com/AstrBotDevs/AstrBot/pull/6081))。
- QQ 官方 Bot 新增文件、语音、视频消息支持(含 WebSocket 模式)([#6063](https://github.com/AstrBotDevs/AstrBot/pull/6063))。
### 优化
- 优化 QQ 官方 Bot 的流式消息投递可靠性与主动媒体发送能力([#6131](https://github.com/AstrBotDevs/AstrBot/pull/6131))。
- 优化边界场景下 booter 选择逻辑与消息发送工具([#6064](https://github.com/AstrBotDevs/AstrBot/pull/6064))。
### 修复
- 修复 Dashboard README 对话框锚点导航失效([#6083](https://github.com/AstrBotDevs/AstrBot/pull/6083))。
- 优先使用具名 weekday 的 cron 示例,避免歧义([#6091](https://github.com/AstrBotDevs/AstrBot/pull/6091))。
- 修复插件市场安装后状态未及时刷新的问题([#6124](https://github.com/AstrBotDevs/AstrBot/pull/6124))。
- 修复插件依赖安装逻辑:仅安装缺失依赖([#6088](https://github.com/AstrBotDevs/AstrBot/pull/6088))。
- 移除 Telegram 适配器中已废弃的 `normalize_whitespace` 参数([#6044](https://github.com/AstrBotDevs/AstrBot/pull/6044))。
- 修复 Windows 本地 skill 文件读取问题([#6028](https://github.com/AstrBotDevs/AstrBot/pull/6028))。
- 修复 Discord pre-ack emoji 配置重启后不持久化的问题([#6031](https://github.com/AstrBotDevs/AstrBot/pull/6031))。
- 统一 WebUI 搜索框清空行为([#6017](https://github.com/AstrBotDevs/AstrBot/pull/6017))。
- 优化插件依赖自动安装流程与 Dashboard 安装体验([#5954](https://github.com/AstrBotDevs/AstrBot/pull/5954))。
### 文档
- 新增 Astrbook 和玖帕喵社区链接([#6135](https://github.com/AstrBotDevs/AstrBot/pull/6135))。
- 修正文档 `docker.md``napcat.md` 中的拼写错误([#6048](https://github.com/AstrBotDevs/AstrBot/pull/6048))。
- 在多语言 README 中补充官方开发群号,并改进配置元数据中的正则说明。
- 更新编辑链接模式并移除过时仓库引用。
---
## What's Changed (EN)
### New Features
- Added Russian translation support ([#6081](https://github.com/AstrBotDevs/AstrBot/pull/6081)).
- Added file, voice, and video message support for QQ Official Bot (including WebSocket mode) ([#6063](https://github.com/AstrBotDevs/AstrBot/pull/6063)).
### Improvements
- Improved streaming message delivery reliability and proactive media sending for QQ Official API ([#6131](https://github.com/AstrBotDevs/AstrBot/pull/6131)).
- Optimized booter selection logic in edge cases and message sending tooling ([#6064](https://github.com/AstrBotDevs/AstrBot/pull/6064)).
### Bug Fixes
- Fixed broken README dialog anchor navigation in the Dashboard ([#6083](https://github.com/AstrBotDevs/AstrBot/pull/6083)).
- Preferred named weekday cron examples to reduce ambiguity ([#6091](https://github.com/AstrBotDevs/AstrBot/pull/6091)).
- Fixed plugin market install-state refresh after installation ([#6124](https://github.com/AstrBotDevs/AstrBot/pull/6124)).
- Fixed plugin dependency installation logic to install only missing packages ([#6088](https://github.com/AstrBotDevs/AstrBot/pull/6088)).
- Removed deprecated `normalize_whitespace` parameter in the Telegram adapter ([#6044](https://github.com/AstrBotDevs/AstrBot/pull/6044)).
- Fixed local skill file reading issues on Windows ([#6028](https://github.com/AstrBotDevs/AstrBot/pull/6028)).
- Fixed Discord pre-ack emoji config not being persisted across restarts ([#6031](https://github.com/AstrBotDevs/AstrBot/pull/6031)).
- Unified WebUI search input clear behavior ([#6017](https://github.com/AstrBotDevs/AstrBot/pull/6017)).
- Improved plugin dependency auto-install flow and Dashboard installation experience ([#5954](https://github.com/AstrBotDevs/AstrBot/pull/5954)).
### Documentation
- Added Astrbook and Jiupa Miao community links ([#6135](https://github.com/AstrBotDevs/AstrBot/pull/6135)).
- Fixed typos in `docker.md` and `napcat.md` ([#6048](https://github.com/AstrBotDevs/AstrBot/pull/6048)).
- Added official developer group IDs to multilingual READMEs and improved regex description in config metadata.
- Updated edit-link patterns and removed obsolete repository references.
-93
View File
@@ -1,93 +0,0 @@
## What's Changed
### 新增
- 补充 MiniMax Provider。([#6318](https://github.com/AstrBotDevs/AstrBot/pull/6318)
- 新增 WebUI ChatUI 页面的会话批量删除功能。([#6160](https://github.com/AstrBotDevs/AstrBot/pull/6160)
- 新增 WebUI ChatUI 配置发送快捷键。([#6272](https://github.com/AstrBotDevs/AstrBot/pull/6272)
### 优化
- 优化 UMO 处理兼容性。([#5996](https://github.com/AstrBotDevs/AstrBot/pull/5996)
- 重构 `_extract_session_id`,改进聊天类型分支处理。(#5775
- 优化聊天组件行为,使用 `shiki` 进行代码块渲染。([#6286](https://github.com/AstrBotDevs/AstrBot/pull/6286)
- 优化 WebUI 主题配色与视觉体验。([#6263](https://github.com/AstrBotDevs/AstrBot/pull/6263)
- 优化 OneBot @ 组件后处理,避免消息文本解析空格问题。([#6238](https://github.com/AstrBotDevs/AstrBot/pull/6238)
### 修复
- 修复创建新 Provider 后未同步 `providers_config` 的问题。([#6388](https://github.com/AstrBotDevs/AstrBot/pull/6388)
- 修复 API 返回 `null choices` 时的 `TypeError`。([#6313](https://github.com/AstrBotDevs/AstrBot/pull/6313)
- 修复 QQ Webhook 重试回调重复触发的问题。([#6320](https://github.com/AstrBotDevs/AstrBot/pull/6320)
- 修复流式模式下 `delta``None` 导致工具调用时报错的问题。([#6365](https://github.com/AstrBotDevs/AstrBot/pull/6365)
- 修复模型服务链接说明文字错误。([#6296](https://github.com/AstrBotDevs/AstrBot/pull/6296)
- 修复 AI 在 tool-calling 模式设为 `skills-like` 时发送媒体失败的问题。([#6317](https://github.com/AstrBotDevs/AstrBot/pull/6317)
- 修复 Telegram 适配器中 GIF 被错误转成静态图的问题。([#6329](https://github.com/AstrBotDevs/AstrBot/pull/6329)
- 将 Provider 图标来源替换为 jsDelivr CDN 地址,修复部分环境下图标加载问题。([#6340](https://github.com/AstrBotDevs/AstrBot/pull/6340)
- 修复 QQ 官方表情消息未解析为可读文本的问题。([#6355](https://github.com/AstrBotDevs/AstrBot/pull/6355)
- 修复 WebChat 队列异常时流式结果页面崩溃的问题。([#6123](https://github.com/AstrBotDevs/AstrBot/pull/6123)
- 修复子代理 handoff 工具在插件过滤时丢失的问题。([#6155](https://github.com/AstrBotDevs/AstrBot/pull/6155)
- 修复 Cron 提示文案缺少空格及 `utcnow()` 的弃用警告问题。([#6192](https://github.com/AstrBotDevs/AstrBot/pull/6192)
- 修复 WebUI 启动时 Sidebar hash 导航抖动/定位问题。([#6159](https://github.com/AstrBotDevs/AstrBot/pull/6159)
- 修复启动重试过程中移除已移除 API Key 的 `ValueError` 报错。([#6193](https://github.com/AstrBotDevs/AstrBot/pull/6193)
- 修复 README 启动命令引用更新为 `astrbot run`。([#6189](https://github.com/AstrBotDevs/AstrBot/pull/6189)
- 修复 `Plain.toDict()``@` 提及场景下空白字符丢失的问题。([#6244](https://github.com/AstrBotDevs/AstrBot/pull/6244)
- 修复 provider 依赖重复定义问题。([#6247](https://github.com/AstrBotDevs/AstrBot/pull/6247)
- 修复 Telegram 中普通回复被误判为线程的处理问题。([#6174](https://github.com/AstrBotDevs/AstrBot/pull/6174)
### 其他
- 调整 `astrbot.service` 及 CI 配置,升级 GitHub Actions 版本。
---
## What's Changed (EN)
### New Features
- Added OpenRouter chat completion provider adapter with support for custom headers ([#6436](https://github.com/AstrBotDevs/AstrBot/pull/6436)).
- Added MiniMax provider ([#6318](https://github.com/AstrBotDevs/AstrBot/pull/6318)).
- Added batch conversation deletion in WebChat ([#6160](https://github.com/AstrBotDevs/AstrBot/pull/6160)).
- Added send shortcut settings and localization support for WebChat input ([#6272](https://github.com/AstrBotDevs/AstrBot/pull/6272)).
- Added local temporary directory binding in YAML config ([#6191](https://github.com/AstrBotDevs/AstrBot/pull/6191)).
### Improvements
- Improved UMO processing compatibility ([#5996](https://github.com/AstrBotDevs/AstrBot/pull/5996)).
- Refactored `_extract_session_id` for chat type handling (#5775).
- Improved chat component behavior and uses `shiki` for code-block rendering ([#6286](https://github.com/AstrBotDevs/AstrBot/pull/6286)).
- Improved WebUI theme color and visual behavior ([#6263](https://github.com/AstrBotDevs/AstrBot/pull/6263)).
- Improved OneBot `@` component spacing handling ([#6238](https://github.com/AstrBotDevs/AstrBot/pull/6238)).
- Improved PR checklist validation and closure messaging.
### Bug Fixes
- Fixed missing `providers_config` sync after creating new providers ([#6388](https://github.com/AstrBotDevs/AstrBot/pull/6388)).
- Fixed `TypeError` when API returns null choices ([#6313](https://github.com/AstrBotDevs/AstrBot/pull/6313)).
- Fixed repeated QQ webhook retry callbacks ([#6320](https://github.com/AstrBotDevs/AstrBot/pull/6320)).
- Fixed tool-calling streaming null `delta` handling to prevent `AttributeError` ([#6365](https://github.com/AstrBotDevs/AstrBot/pull/6365)).
- Fixed model service link wording in docs/config ([#6296](https://github.com/AstrBotDevs/AstrBot/pull/6296)).
- Fixed AI media sending failure when tool-calling mode is set to `skills-like` ([#6317](https://github.com/AstrBotDevs/AstrBot/pull/6317)).
- Fixed GIF being sent as static image in Telegram adapter ([#6329](https://github.com/AstrBotDevs/AstrBot/pull/6329)).
- Replaced npm registry URLs with jsDelivr CDN for provider icons ([#6340](https://github.com/AstrBotDevs/AstrBot/pull/6340)).
- Fixed QQ official face message parsing to readable text ([#6355](https://github.com/AstrBotDevs/AstrBot/pull/6355)).
- Fixed WebChat stream-result crash on queue errors ([#6123](https://github.com/AstrBotDevs/AstrBot/pull/6123)).
- Preserved subagent handoff tools during plugin filtering ([#6155](https://github.com/AstrBotDevs/AstrBot/pull/6155)).
- Fixed cron prompt spacing and deprecated `utcnow()` usage ([#6192](https://github.com/AstrBotDevs/AstrBot/pull/6192)).
- Fixed unstable sidebar hash navigation on startup ([#6159](https://github.com/AstrBotDevs/AstrBot/pull/6159)).
- Fixed `ValueError` in retry loop when removing an already removed API key ([#6193](https://github.com/AstrBotDevs/AstrBot/pull/6193)).
- Updated startup command to `astrbot run` across READMEs ([#6189](https://github.com/AstrBotDevs/AstrBot/pull/6189)).
- Preserved whitespace in `Plain.toDict()` for @ mentions ([#6244](https://github.com/AstrBotDevs/AstrBot/pull/6244)).
- Removed duplicate dependencies entries ([#6247](https://github.com/AstrBotDevs/AstrBot/pull/6247)).
- Fixed Telegram normal reply being treated as topic thread ([#6174](https://github.com/AstrBotDevs/AstrBot/pull/6174)).
### Documentation
- Updated `rainyun` backup/access documentation ([#6427](https://github.com/AstrBotDevs/AstrBot/pull/6427)).
- Updated `package.md` and platform docs, including Matrix and Wecom AI bot documentation.
- Fixed Discord invite link in community docs.
### Chores
- Updated PR templates/checklist workflow, repository service config, and automated checks.
- Refreshed repository automation and formatting maintenance, and removed obsolete changelog scripts.
-1
View File
@@ -37,7 +37,6 @@ services:
- DEFAULT_SHIP_MEMORY=512m - DEFAULT_SHIP_MEMORY=512m
volumes: volumes:
- ${PWD}/data/shipyard/bay_data:/app/data - ${PWD}/data/shipyard/bay_data:/app/data
- ${PWD}/data/temp:/AstrBot/data/temp # Bind the local temp directory to the sandbox so that the uploaded file can be accessed in the sandbox
- /var/run/docker.sock:/var/run/docker.sock:ro - /var/run/docker.sock:/var/run/docker.sock:ro
networks: networks:
- astrbot_network - astrbot_network
+2
View File
@@ -1,3 +1,5 @@
version: '3.8'
# 当接入 QQ NapCat 时,请使用这个 compose 文件一键部署: https://github.com/NapNeko/NapCat-Docker/blob/main/compose/astrbot.yml # 当接入 QQ NapCat 时,请使用这个 compose 文件一键部署: https://github.com/NapNeko/NapCat-Docker/blob/main/compose/astrbot.yml
services: services:
+8 -13
View File
@@ -17,17 +17,17 @@
"@tiptap/starter-kit": "2.1.7", "@tiptap/starter-kit": "2.1.7",
"@tiptap/vue-3": "2.1.7", "@tiptap/vue-3": "2.1.7",
"apexcharts": "3.42.0", "apexcharts": "3.42.0",
"axios": "1.13.5", "axios": ">=1.6.2 <1.10.0 || >1.10.0 <2.0.0",
"axios-mock-adapter": "^1.22.0", "axios-mock-adapter": "^1.22.0",
"chance": "1.1.11", "chance": "1.1.11",
"date-fns": "2.30.0", "date-fns": "2.30.0",
"dompurify": "^3.3.2", "dompurify": "^3.3.1",
"event-source-polyfill": "^1.0.31", "event-source-polyfill": "^1.0.31",
"highlight.js": "^11.11.1", "highlight.js": "^11.11.1",
"js-md5": "^0.8.3", "js-md5": "^0.8.3",
"katex": "^0.16.27", "katex": "^0.16.27",
"lodash": "4.17.23", "lodash": "4.17.21",
"markdown-it": "^14.1.1", "markdown-it": "^14.1.0",
"markstream-vue": "^0.0.6", "markstream-vue": "^0.0.6",
"mermaid": "^11.12.2", "mermaid": "^11.12.2",
"monaco-editor": "^0.52.2", "monaco-editor": "^0.52.2",
@@ -36,8 +36,9 @@
"remixicon": "3.5.0", "remixicon": "3.5.0",
"shiki": "^3.20.0", "shiki": "^3.20.0",
"stream-markdown": "^0.0.13", "stream-markdown": "^0.0.13",
"stream-monaco": "^0.0.17",
"vee-validate": "4.11.3", "vee-validate": "4.11.3",
"vite-plugin-vuetify": "2.1.3", "vite-plugin-vuetify": "1.0.2",
"vue": "3.3.4", "vue": "3.3.4",
"vue-i18n": "^11.1.5", "vue-i18n": "^11.1.5",
"vue-router": "4.2.4", "vue-router": "4.2.4",
@@ -53,7 +54,7 @@
"@types/dompurify": "^3.0.5", "@types/dompurify": "^3.0.5",
"@types/markdown-it": "^14.1.2", "@types/markdown-it": "^14.1.2",
"@types/node": "^20.5.7", "@types/node": "^20.5.7",
"@vitejs/plugin-vue": "5.2.4", "@vitejs/plugin-vue": "4.3.3",
"@vue/eslint-config-prettier": "8.0.0", "@vue/eslint-config-prettier": "8.0.0",
"@vue/eslint-config-typescript": "11.0.3", "@vue/eslint-config-typescript": "11.0.3",
"@vue/tsconfig": "^0.4.0", "@vue/tsconfig": "^0.4.0",
@@ -63,15 +64,9 @@
"sass": "1.66.1", "sass": "1.66.1",
"sass-loader": "13.3.2", "sass-loader": "13.3.2",
"typescript": "5.1.6", "typescript": "5.1.6",
"vite": "6.4.1", "vite": "4.4.9",
"vue-cli-plugin-vuetify": "2.5.8", "vue-cli-plugin-vuetify": "2.5.8",
"vue-tsc": "1.8.8", "vue-tsc": "1.8.8",
"vuetify-loader": "^2.0.0-alpha.9" "vuetify-loader": "^2.0.0-alpha.9"
},
"pnpm": {
"overrides": {
"immutable": "4.3.8",
"lodash-es": "4.17.23"
}
} }
} }
+271 -601
View File
File diff suppressed because it is too large Load Diff
+3 -69
View File
@@ -11,7 +11,6 @@
:currSessionId="currSessionId" :currSessionId="currSessionId"
:selectedProjectId="selectedProjectId" :selectedProjectId="selectedProjectId"
:transportMode="transportMode" :transportMode="transportMode"
:sendShortcut="sendShortcut"
:isDark="isDark" :isDark="isDark"
:chatboxMode="chatboxMode" :chatboxMode="chatboxMode"
:isMobile="isMobile" :isMobile="isMobile"
@@ -21,7 +20,6 @@
@selectConversation="handleSelectConversation" @selectConversation="handleSelectConversation"
@editTitle="showEditTitleDialog" @editTitle="showEditTitleDialog"
@deleteConversation="handleDeleteConversation" @deleteConversation="handleDeleteConversation"
@batchDeleteConversations="handleBatchDeleteConversations"
@closeMobileSidebar="closeMobileSidebar" @closeMobileSidebar="closeMobileSidebar"
@toggleTheme="toggleTheme" @toggleTheme="toggleTheme"
@toggleFullscreen="toggleFullscreen" @toggleFullscreen="toggleFullscreen"
@@ -30,7 +28,6 @@
@editProject="showEditProjectDialog" @editProject="showEditProjectDialog"
@deleteProject="handleDeleteProject" @deleteProject="handleDeleteProject"
@updateTransportMode="setTransportMode" @updateTransportMode="setTransportMode"
@updateSendShortcut="setSendShortcut"
/> />
<!-- 右侧聊天内容区域 --> <!-- 右侧聊天内容区域 -->
@@ -74,14 +71,13 @@
:stagedImagesUrl="stagedImagesUrl" :stagedImagesUrl="stagedImagesUrl"
:stagedAudioUrl="stagedAudioUrl" :stagedAudioUrl="stagedAudioUrl"
:stagedFiles="stagedNonImageFiles" :stagedFiles="stagedNonImageFiles"
:disabled="false" :disabled="isStreaming"
:is-running="isStreaming || isConvRunning" :is-running="isStreaming || isConvRunning"
:enableStreaming="enableStreaming" :enableStreaming="enableStreaming"
:isRecording="isRecording" :isRecording="isRecording"
:session-id="currSessionId || null" :session-id="currSessionId || null"
:current-session="getCurrentSession" :current-session="getCurrentSession"
:replyTo="replyTo" :replyTo="replyTo"
:send-shortcut="sendShortcut"
@send="handleSendMessage" @send="handleSendMessage"
@stop="handleStopMessage" @stop="handleStopMessage"
@toggleStreaming="toggleStreaming" @toggleStreaming="toggleStreaming"
@@ -106,14 +102,13 @@
:stagedImagesUrl="stagedImagesUrl" :stagedImagesUrl="stagedImagesUrl"
:stagedAudioUrl="stagedAudioUrl" :stagedAudioUrl="stagedAudioUrl"
:stagedFiles="stagedNonImageFiles" :stagedFiles="stagedNonImageFiles"
:disabled="false" :disabled="isStreaming"
:is-running="isStreaming || isConvRunning" :is-running="isStreaming || isConvRunning"
:enableStreaming="enableStreaming" :enableStreaming="enableStreaming"
:isRecording="isRecording" :isRecording="isRecording"
:session-id="currSessionId || null" :session-id="currSessionId || null"
:current-session="getCurrentSession" :current-session="getCurrentSession"
:replyTo="replyTo" :replyTo="replyTo"
:send-shortcut="sendShortcut"
@send="handleSendMessage" @send="handleSendMessage"
@stop="handleStopMessage" @stop="handleStopMessage"
@toggleStreaming="toggleStreaming" @toggleStreaming="toggleStreaming"
@@ -137,14 +132,13 @@
:stagedImagesUrl="stagedImagesUrl" :stagedImagesUrl="stagedImagesUrl"
:stagedAudioUrl="stagedAudioUrl" :stagedAudioUrl="stagedAudioUrl"
:stagedFiles="stagedNonImageFiles" :stagedFiles="stagedNonImageFiles"
:disabled="false" :disabled="isStreaming"
:is-running="isStreaming || isConvRunning" :is-running="isStreaming || isConvRunning"
:enableStreaming="enableStreaming" :enableStreaming="enableStreaming"
:isRecording="isRecording" :isRecording="isRecording"
:session-id="currSessionId || null" :session-id="currSessionId || null"
:current-session="getCurrentSession" :current-session="getCurrentSession"
:replyTo="replyTo" :replyTo="replyTo"
:send-shortcut="sendShortcut"
@send="handleSendMessage" @send="handleSendMessage"
@stop="handleStopMessage" @stop="handleStopMessage"
@toggleStreaming="toggleStreaming" @toggleStreaming="toggleStreaming"
@@ -226,13 +220,10 @@ import { useMediaHandling } from '@/composables/useMediaHandling';
import { useProjects } from '@/composables/useProjects'; import { useProjects } from '@/composables/useProjects';
import type { Project } from '@/components/chat/ProjectList.vue'; import type { Project } from '@/components/chat/ProjectList.vue';
import { useRecording } from '@/composables/useRecording'; import { useRecording } from '@/composables/useRecording';
import { useToast } from '@/utils/toast';
interface Props { interface Props {
chatboxMode?: boolean; chatboxMode?: boolean;
} }
type SendShortcut = 'enter' | 'shift_enter';
const SEND_SHORTCUT_STORAGE_KEY = 'chat_send_shortcut';
const props = withDefaults(defineProps<Props>(), { const props = withDefaults(defineProps<Props>(), {
chatboxMode: false chatboxMode: false
@@ -242,7 +233,6 @@ const router = useRouter();
const route = useRoute(); const route = useRoute();
const { t } = useI18n(); const { t } = useI18n();
const { tm } = useModuleI18n('features/chat'); const { tm } = useModuleI18n('features/chat');
const { warning: toastWarning } = useToast();
const theme = useTheme(); const theme = useTheme();
const customizer = useCustomizerStore(); const customizer = useCustomizerStore();
@@ -267,7 +257,6 @@ const {
getSessions, getSessions,
newSession, newSession,
deleteSession: deleteSessionFn, deleteSession: deleteSessionFn,
batchDeleteSessions,
showEditTitleDialog, showEditTitleDialog,
saveTitle, saveTitle,
updateSessionTitle, updateSessionTitle,
@@ -341,18 +330,6 @@ interface ReplyInfo {
const replyTo = ref<ReplyInfo | null>(null); const replyTo = ref<ReplyInfo | null>(null);
const isDark = computed(() => useCustomizerStore().uiTheme === 'PurpleThemeDark'); const isDark = computed(() => useCustomizerStore().uiTheme === 'PurpleThemeDark');
const sendShortcut = ref<SendShortcut>('shift_enter');
function setSendShortcut(mode: SendShortcut) {
sendShortcut.value = mode;
localStorage.setItem(SEND_SHORTCUT_STORAGE_KEY, mode);
}
function focusChatInput() {
nextTick(() => {
chatInputRef.value?.focusInput?.();
});
}
// //
function checkMobile() { function checkMobile() {
@@ -511,7 +488,6 @@ async function handleSelectConversation(sessionIds: string[]) {
nextTick(() => { nextTick(() => {
messageList.value?.scrollToBottom(); messageList.value?.scrollToBottom();
}); });
focusChatInput();
} }
function handleNewChat() { function handleNewChat() {
@@ -521,7 +497,6 @@ function handleNewChat() {
// 退 // 退
selectedProjectId.value = null; selectedProjectId.value = null;
projectSessions.value = []; projectSessions.value = [];
focusChatInput();
} }
async function handleDeleteConversation(sessionId: string) { async function handleDeleteConversation(sessionId: string) {
@@ -535,33 +510,6 @@ async function handleDeleteConversation(sessionId: string) {
} }
} }
async function handleBatchDeleteConversations(sessionIds: string[]) {
try {
const result = await batchDeleteSessions(sessionIds);
//
if (result.currentSessionDeleted) {
messages.value = [];
}
//
if (result.failed_count > 0) {
toastWarning(
tm('batch.partialFailure', { failed: result.failed_count, total: sessionIds.length })
);
}
//
if (selectedProjectId.value) {
const sessions = await getProjectSessions(selectedProjectId.value);
projectSessions.value = sessions;
}
} catch (err) {
console.error('Batch delete sessions failed:', err);
toastWarning(tm('batch.requestFailed'));
}
}
async function handleSelectProject(projectId: string) { async function handleSelectProject(projectId: string) {
selectedProjectId.value = projectId; selectedProjectId.value = projectId;
const sessions = await getProjectSessions(projectId); const sessions = await getProjectSessions(projectId);
@@ -679,11 +627,6 @@ async function handleSendMessage() {
const selectedProviderId = selection?.providerId || ''; const selectedProviderId = selection?.providerId || '';
const selectedModelName = selection?.modelName || ''; const selectedModelName = selection?.modelName || '';
//
nextTick(() => {
messageList.value?.scrollToBottom();
});
await sendMsg( await sendMsg(
promptToSend, promptToSend,
filesToSend, filesToSend,
@@ -693,11 +636,6 @@ async function handleSendMessage() {
replyToSend replyToSend
); );
//
nextTick(() => {
messageList.value?.scrollToBottom();
});
// //
if (isCreatingNewSession && currentProjectId && currSessionId.value) { if (isCreatingNewSession && currentProjectId && currSessionId.value) {
await addSessionToProject(currSessionId.value, currentProjectId); await addSessionToProject(currSessionId.value, currentProjectId);
@@ -756,10 +694,6 @@ watch(sessions, (newSessions) => {
}); });
onMounted(() => { onMounted(() => {
const storedShortcut = localStorage.getItem(SEND_SHORTCUT_STORAGE_KEY);
if (storedShortcut === 'enter' || storedShortcut === 'shift_enter') {
sendShortcut.value = storedShortcut;
}
checkMobile(); checkMobile();
window.addEventListener('resize', checkMobile); window.addEventListener('resize', checkMobile);
getSessions(); getSessions();
+28 -43
View File
@@ -15,7 +15,7 @@
<transition name="fade"> <transition name="fade">
<div v-if="isDragging" class="drop-overlay"> <div v-if="isDragging" class="drop-overlay">
<div class="drop-overlay-content"> <div class="drop-overlay-content">
<v-icon size="48" color="primary">mdi-cloud-upload</v-icon> <v-icon size="48" color="deep-purple">mdi-cloud-upload</v-icon>
<span class="drop-text">{{ tm('input.dropToUpload') }}</span> <span class="drop-text">{{ tm('input.dropToUpload') }}</span>
</div> </div>
</div> </div>
@@ -41,7 +41,7 @@
<!-- Settings Menu --> <!-- Settings Menu -->
<StyledMenu offset="8" location="top start" :close-on-content-click="false"> <StyledMenu offset="8" location="top start" :close-on-content-click="false">
<template v-slot:activator="{ props: activatorProps }"> <template v-slot:activator="{ props: activatorProps }">
<v-btn v-bind="activatorProps" icon="mdi-plus" variant="text" color="primary" /> <v-btn v-bind="activatorProps" icon="mdi-plus" variant="text" color="deep-purple" />
</template> </template>
<!-- Upload Files --> <!-- Upload Files -->
@@ -87,7 +87,7 @@
{{ tm('voice.liveMode') }} {{ tm('voice.liveMode') }}
</v-tooltip> </v-tooltip>
</v-btn> --> </v-btn> -->
<v-btn @click="handleRecordClick" icon variant="text" :color="isRecording ? 'error' : 'primary'" <v-btn @click="handleRecordClick" icon variant="text" :color="isRecording ? 'error' : 'deep-purple'"
class="record-btn"> class="record-btn">
<v-icon :icon="isRecording ? 'mdi-stop-circle' : 'mdi-microphone'" variant="text" <v-icon :icon="isRecording ? 'mdi-stop-circle' : 'mdi-microphone'" variant="text"
plain></v-icon> plain></v-icon>
@@ -95,13 +95,13 @@
{{ isRecording ? tm('voice.speaking') : tm('voice.startRecording') }} {{ isRecording ? tm('voice.speaking') : tm('voice.startRecording') }}
</v-tooltip> </v-tooltip>
</v-btn> </v-btn>
<v-btn icon v-if="isRunning && !canSend" @click="$emit('stop')" variant="tonal" color="primary" class="send-btn"> <v-btn icon v-if="isRunning" @click="$emit('stop')" variant="tonal" color="deep-purple" class="send-btn">
<v-icon icon="mdi-stop" variant="text" plain></v-icon> <v-icon icon="mdi-stop" variant="text" plain></v-icon>
<v-tooltip activator="parent" location="top"> <v-tooltip activator="parent" location="top">
{{ tm('input.stopGenerating') }} {{ tm('input.stopGenerating') }}
</v-tooltip> </v-tooltip>
</v-btn> </v-btn>
<v-btn v-else @click="$emit('send')" icon="mdi-send" variant="tonal" color="primary" <v-btn v-else @click="$emit('send')" icon="mdi-send" variant="tonal" color="deep-purple"
:disabled="!canSend" class="send-btn" /> :disabled="!canSend" class="send-btn" />
</div> </div>
</div> </div>
@@ -117,7 +117,7 @@
</div> </div>
<div v-if="stagedAudioUrl" class="audio-preview"> <div v-if="stagedAudioUrl" class="audio-preview">
<v-chip color="primary" variant="tonal" class="audio-chip"> <v-chip color="deep-purple-lighten-4" class="audio-chip">
<v-icon start icon="mdi-microphone" size="small"></v-icon> <v-icon start icon="mdi-microphone" size="small"></v-icon>
{{ tm('voice.recording') }} {{ tm('voice.recording') }}
</v-chip> </v-chip>
@@ -126,7 +126,7 @@
</div> </div>
<div v-for="(file, index) in stagedFiles" :key="'file-' + index" class="file-preview"> <div v-for="(file, index) in stagedFiles" :key="'file-' + index" class="file-preview">
<v-chip color="primary" variant="tonal" class="file-chip"> <v-chip color="blue-grey-lighten-4" class="file-chip">
<v-icon start icon="mdi-file-document-outline" size="small"></v-icon> <v-icon start icon="mdi-file-document-outline" size="small"></v-icon>
<span class="file-name-preview">{{ file.original_name }}</span> <span class="file-name-preview">{{ file.original_name }}</span>
</v-chip> </v-chip>
@@ -173,7 +173,6 @@ interface Props {
currentSession?: Session | null; currentSession?: Session | null;
configId?: string | null; configId?: string | null;
replyTo?: ReplyInfo | null; replyTo?: ReplyInfo | null;
sendShortcut?: 'enter' | 'shift_enter';
} }
const props = withDefaults(defineProps<Props>(), { const props = withDefaults(defineProps<Props>(), {
@@ -181,8 +180,7 @@ const props = withDefaults(defineProps<Props>(), {
currentSession: null, currentSession: null,
configId: null, configId: null,
stagedFiles: () => [], stagedFiles: () => [],
replyTo: null, replyTo: null
sendShortcut: 'shift_enter'
}); });
const emit = defineEmits<{ const emit = defineEmits<{
@@ -255,8 +253,21 @@ watch(localPrompt, () => {
}); });
function handleKeyDown(e: KeyboardEvent) { function handleKeyDown(e: KeyboardEvent) {
const isEnter = e.key === 'Enter'; // Enter
if (!isEnter) { // Shift+Enter Ctrl+Enter / Cmd+Enter
if (e.keyCode === 13 && (e.shiftKey || e.ctrlKey || e.metaKey)) {
e.preventDefault();
if (localPrompt.value.trim() === '/astr_live_dev') {
emit('openLiveMode');
localPrompt.value = '';
return;
}
if (canSend.value) {
emit('send');
}
return;
}
// Ctrl+B // Ctrl+B
if (e.ctrlKey && e.keyCode === 66) { if (e.ctrlKey && e.keyCode === 66) {
e.preventDefault(); e.preventDefault();
@@ -269,26 +280,6 @@ function handleKeyDown(e: KeyboardEvent) {
} }
}, ctrlKeyLongPressThreshold); }, ctrlKeyLongPressThreshold);
} }
return;
}
const isSendHotkey =
e.ctrlKey ||
e.metaKey ||
(props.sendShortcut === 'enter' ? !e.shiftKey : e.shiftKey);
if (isSendHotkey) {
e.preventDefault();
if (localPrompt.value.trim() === '/astr_live_dev') {
emit('openLiveMode');
localPrompt.value = '';
return;
}
if (canSend.value) {
emit('send');
}
return;
}
} }
function handleKeyUp(e: KeyboardEvent) { function handleKeyUp(e: KeyboardEvent) {
@@ -373,11 +364,6 @@ function getCurrentSelection() {
return providerModelMenuRef.value?.getCurrentSelection(); return providerModelMenuRef.value?.getCurrentSelection();
} }
function focusInput() {
if (!inputField.value) return;
inputField.value.focus();
}
onMounted(() => { onMounted(() => {
if (inputField.value) { if (inputField.value) {
inputField.value.addEventListener('paste', handlePaste); inputField.value.addEventListener('paste', handlePaste);
@@ -393,8 +379,7 @@ onBeforeUnmount(() => {
}); });
defineExpose({ defineExpose({
getCurrentSelection, getCurrentSelection
focusInput
}); });
</script> </script>
@@ -414,8 +399,8 @@ defineExpose({
left: 0; left: 0;
right: 0; right: 0;
bottom: 0; bottom: 0;
background-color: rgba(var(--v-theme-primary), 0.12); background-color: rgba(103, 58, 183, 0.15);
border: 2px dashed rgba(var(--v-theme-primary), 0.45); border: 2px dashed rgba(103, 58, 183, 0.5);
border-radius: 24px; border-radius: 24px;
display: flex; display: flex;
align-items: center; align-items: center;
@@ -434,7 +419,7 @@ defineExpose({
.drop-text { .drop-text {
font-size: 16px; font-size: 16px;
font-weight: 500; font-weight: 500;
color: rgb(var(--v-theme-primary)); color: #673ab7;
} }
/* Fade transition for drop overlay */ /* Fade transition for drop overlay */
@@ -454,7 +439,7 @@ defineExpose({
justify-content: space-between; justify-content: space-between;
padding: 8px 16px; padding: 8px 16px;
margin: 8px 8px 0 8px; margin: 8px 8px 0 8px;
background-color: rgba(var(--v-theme-primary), 0.06); background-color: rgba(103, 58, 183, 0.06);
border-radius: 12px; border-radius: 12px;
gap: 8px; gap: 8px;
max-height: 500px; max-height: 500px;

Some files were not shown because too many files have changed in this diff Show More