Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 349ca05e26 |
@@ -21,23 +21,7 @@
|
|||||||
<!--If merged, your code will serve tens of thousands of users! Please double-check the following items before submitting.-->
|
<!--If merged, your code will serve tens of thousands of users! Please double-check the following items before submitting.-->
|
||||||
<!--如果分支被合并,您的代码将服务于数万名用户!在提交前,请核查一下几点内容。-->
|
<!--如果分支被合并,您的代码将服务于数万名用户!在提交前,请核查一下几点内容。-->
|
||||||
|
|
||||||
- [ ] 😊 如果 PR 中有新加入的功能,已经通过 Issue / 邮件等方式和作者讨论过。
|
- [ ] 😊 如果 PR 中有新加入的功能,已经通过 Issue / 邮件等方式和作者讨论过。/ If there are new features added in the PR, I have discussed it with the authors through issues/emails, etc.
|
||||||
/ If there are new features added in the PR, I have discussed it with the authors through issues/emails, etc.
|
- [ ] 👀 我的更改经过了良好的测试,**并已在上方提供了“验证步骤”和“运行截图”**。/ My changes have been well-tested, **and "Verification Steps" and "Screenshots" have been provided above**.
|
||||||
|
- [ ] 🤓 我确保没有引入新依赖库,或者引入了新依赖库的同时将其添加到了 `requirements.txt` 和 `pyproject.toml` 文件相应位置。/ I have ensured that no new dependencies are introduced, OR if new dependencies are introduced, they have been added to the appropriate locations in `requirements.txt` and `pyproject.toml`.
|
||||||
- [ ] 👀 我的更改经过了良好的测试,**并已在上方提供了“验证步骤”和“运行截图”**。
|
- [ ] 😮 我的更改没有引入恶意代码。/ My changes do not introduce malicious code.
|
||||||
/ My changes have been well-tested, **and "Verification Steps" and "Screenshots" have been provided above**.
|
|
||||||
|
|
||||||
- [ ] 🤓 我确保没有引入新依赖库,或者引入了新依赖库的同时将其添加到 `requirements.txt` 和 `pyproject.toml` 文件相应位置。
|
|
||||||
/ I have ensured that no new dependencies are introduced, OR if new dependencies are introduced, they have been added to the appropriate locations in `requirements.txt` and `pyproject.toml`.
|
|
||||||
|
|
||||||
- [ ] 😮 我的更改没有引入恶意代码。
|
|
||||||
/ My changes do not introduce malicious code.
|
|
||||||
|
|
||||||
- [ ] ⚠️ 我已认真阅读并理解以上所有内容,确保本次提交符合规范。
|
|
||||||
/ I have read and understood all the above and confirm this PR follows the rules.
|
|
||||||
|
|
||||||
- [ ] 🚀 我确保本次开发**基于 dev 分支**,并将代码合并至**开发分支**(除非极其紧急,才允许合并到主分支)。
|
|
||||||
/ I confirm that this development is **based on the dev branch** and will be merged into the **development branch**, unless it is extremely urgent to merge into the main branch.
|
|
||||||
|
|
||||||
- [ ] ⚠️ 我**没有**认真阅读以上内容,直接提交。
|
|
||||||
/ I **did not** read the above carefully before submitting.
|
|
||||||
|
|||||||
@@ -1,43 +0,0 @@
|
|||||||
name: release
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
tags:
|
|
||||||
- 'v*'
|
|
||||||
workflow_dispatch:
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
build:
|
|
||||||
runs-on: ubuntu-latest # 运行环境
|
|
||||||
steps:
|
|
||||||
- name: checkout
|
|
||||||
uses: actions/checkout@v6
|
|
||||||
- name: nodejs installation
|
|
||||||
uses: actions/setup-node@v6
|
|
||||||
with:
|
|
||||||
node-version: "18"
|
|
||||||
- name: npm install
|
|
||||||
run: npm add -D vitepress
|
|
||||||
working-directory: './docs' # working-directory 指定 shell 命令运行目录
|
|
||||||
- name: npm run build
|
|
||||||
run: npm run docs:build
|
|
||||||
working-directory: './docs'
|
|
||||||
- name: scp
|
|
||||||
uses: appleboy/scp-action@v1.0.0
|
|
||||||
with:
|
|
||||||
host: ${{ secrets.HOST_NEKO }}
|
|
||||||
username: ${{ secrets.USERNAME }}
|
|
||||||
password: ${{ secrets.PASSWORDNEKO }}
|
|
||||||
source: 'docs/.vitepress/dist/*'
|
|
||||||
target: '/tmp/'
|
|
||||||
- name: script
|
|
||||||
uses: appleboy/ssh-action@v1.2.5
|
|
||||||
with:
|
|
||||||
host: ${{ secrets.HOST_NEKO }}
|
|
||||||
username: ${{ secrets.USERNAME }}
|
|
||||||
password: ${{ secrets.PASSWORDNEKO }}
|
|
||||||
script: |
|
|
||||||
mkdir -p /root/docker_data/caddy/caddy_data/static_site/abv4/
|
|
||||||
rm -rf /root/docker_data/caddy/caddy_data/static_site/abv4/*
|
|
||||||
mv /tmp/docs/.vitepress/dist/* /root/docker_data/caddy/caddy_data/static_site/abv4/
|
|
||||||
rm -rf /tmp/docs/
|
|
||||||
@@ -45,7 +45,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Create GitHub Release
|
- name: Create GitHub Release
|
||||||
if: github.event_name == 'push'
|
if: github.event_name == 'push'
|
||||||
uses: ncipollo/release-action@v1.20.0
|
uses: ncipollo/release-action@v1
|
||||||
with:
|
with:
|
||||||
tag: release-${{ github.sha }}
|
tag: release-${{ github.sha }}
|
||||||
owner: AstrBotDevs
|
owner: AstrBotDevs
|
||||||
|
|||||||
@@ -64,20 +64,20 @@ jobs:
|
|||||||
echo "build_date=$build_date" >> $GITHUB_OUTPUT
|
echo "build_date=$build_date" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
- name: Set QEMU
|
- name: Set QEMU
|
||||||
uses: docker/setup-qemu-action@v4.0.0
|
uses: docker/setup-qemu-action@v3
|
||||||
|
|
||||||
- name: Set Docker Buildx
|
- name: Set Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v4.0.0
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
- name: Log in to DockerHub
|
- name: Log in to DockerHub
|
||||||
uses: docker/login-action@v4.0.0
|
uses: docker/login-action@v3
|
||||||
with:
|
with:
|
||||||
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||||
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
|
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
|
||||||
|
|
||||||
- name: Login to GitHub Container Registry
|
- name: Login to GitHub Container Registry
|
||||||
if: env.HAS_GHCR_TOKEN == 'true'
|
if: env.HAS_GHCR_TOKEN == 'true'
|
||||||
uses: docker/login-action@v4.0.0
|
uses: docker/login-action@v3
|
||||||
with:
|
with:
|
||||||
registry: ghcr.io
|
registry: ghcr.io
|
||||||
username: ${{ env.GHCR_OWNER }}
|
username: ${{ env.GHCR_OWNER }}
|
||||||
@@ -98,7 +98,7 @@ jobs:
|
|||||||
echo "EOF" >> $GITHUB_OUTPUT
|
echo "EOF" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
- name: Build and Push Nightly Image
|
- name: Build and Push Nightly Image
|
||||||
uses: docker/build-push-action@v7.0.0
|
uses: docker/build-push-action@v6
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
platforms: linux/amd64,linux/arm64
|
platforms: linux/amd64,linux/arm64
|
||||||
@@ -163,27 +163,27 @@ jobs:
|
|||||||
cp -r dashboard/dist data/
|
cp -r dashboard/dist data/
|
||||||
|
|
||||||
- name: Set QEMU
|
- name: Set QEMU
|
||||||
uses: docker/setup-qemu-action@v4.0.0
|
uses: docker/setup-qemu-action@v3
|
||||||
|
|
||||||
- name: Set Docker Buildx
|
- name: Set Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v4.0.0
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
- name: Log in to DockerHub
|
- name: Log in to DockerHub
|
||||||
uses: docker/login-action@v4.0.0
|
uses: docker/login-action@v3
|
||||||
with:
|
with:
|
||||||
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||||
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
|
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
|
||||||
|
|
||||||
- name: Login to GitHub Container Registry
|
- name: Login to GitHub Container Registry
|
||||||
if: env.HAS_GHCR_TOKEN == 'true'
|
if: env.HAS_GHCR_TOKEN == 'true'
|
||||||
uses: docker/login-action@v4.0.0
|
uses: docker/login-action@v3
|
||||||
with:
|
with:
|
||||||
registry: ghcr.io
|
registry: ghcr.io
|
||||||
username: ${{ env.GHCR_OWNER }}
|
username: ${{ env.GHCR_OWNER }}
|
||||||
password: ${{ secrets.GHCR_GITHUB_TOKEN }}
|
password: ${{ secrets.GHCR_GITHUB_TOKEN }}
|
||||||
|
|
||||||
- name: Build and Push Release Image
|
- name: Build and Push Release Image
|
||||||
uses: docker/build-push-action@v7.0.0
|
uses: docker/build-push-action@v6
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
platforms: linux/amd64,linux/arm64
|
platforms: linux/amd64,linux/arm64
|
||||||
|
|||||||
@@ -1,45 +0,0 @@
|
|||||||
name: PR Checklist Check
|
|
||||||
|
|
||||||
on:
|
|
||||||
pull_request_target:
|
|
||||||
types: [opened, edited, reopened, synchronize]
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
check:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
|
|
||||||
permissions:
|
|
||||||
pull-requests: write
|
|
||||||
issues: write
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Check checklist
|
|
||||||
id: check
|
|
||||||
uses: actions/github-script@v7
|
|
||||||
with:
|
|
||||||
script: |
|
|
||||||
const body = context.payload.pull_request.body || "";
|
|
||||||
const regex = /-\s*\[\s*x\s*\].*没有.*认真阅读/i;
|
|
||||||
const bad = regex.test(body);
|
|
||||||
core.setOutput("bad", bad);
|
|
||||||
|
|
||||||
- name: Close PR
|
|
||||||
if: steps.check.outputs.bad == 'true'
|
|
||||||
uses: actions/github-script@v7
|
|
||||||
with:
|
|
||||||
script: |
|
|
||||||
const pr = context.payload.pull_request;
|
|
||||||
|
|
||||||
await github.rest.issues.createComment({
|
|
||||||
owner: context.repo.owner,
|
|
||||||
repo: context.repo.repo,
|
|
||||||
issue_number: pr.number,
|
|
||||||
body: `检测到你勾选了“我没有认真阅读”,PR 已关闭。`
|
|
||||||
});
|
|
||||||
|
|
||||||
await github.rest.pulls.update({
|
|
||||||
owner: context.repo.owner,
|
|
||||||
repo: context.repo.repo,
|
|
||||||
pull_number: pr.number,
|
|
||||||
state: "closed"
|
|
||||||
});
|
|
||||||
@@ -50,7 +50,7 @@ jobs:
|
|||||||
echo "tag=$tag" >> "$GITHUB_OUTPUT"
|
echo "tag=$tag" >> "$GITHUB_OUTPUT"
|
||||||
|
|
||||||
- name: Setup pnpm
|
- name: Setup pnpm
|
||||||
uses: pnpm/action-setup@v4.4.0
|
uses: pnpm/action-setup@v4
|
||||||
with:
|
with:
|
||||||
version: 10.28.2
|
version: 10.28.2
|
||||||
|
|
||||||
|
|||||||
@@ -1,68 +0,0 @@
|
|||||||
name: sync wiki
|
|
||||||
|
|
||||||
on:
|
|
||||||
workflow_dispatch:
|
|
||||||
push:
|
|
||||||
branches:
|
|
||||||
- master
|
|
||||||
paths:
|
|
||||||
- '.github/workflows/sync-wiki.yml'
|
|
||||||
- 'docs/scripts/sync_docs_to_wiki.py'
|
|
||||||
- 'docs/tests/test_sync_docs_to_wiki.py'
|
|
||||||
- 'docs/zh/**'
|
|
||||||
- 'docs/en/**'
|
|
||||||
|
|
||||||
concurrency:
|
|
||||||
group: sync-wiki-${{ github.ref }}
|
|
||||||
cancel-in-progress: true
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
sync:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
permissions:
|
|
||||||
contents: read
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Validate manual ref
|
|
||||||
if: github.event_name == 'workflow_dispatch' && github.ref != 'refs/heads/master'
|
|
||||||
run: |
|
|
||||||
echo "This workflow only publishes from refs/heads/master. Re-run it from the master branch."
|
|
||||||
exit 1
|
|
||||||
|
|
||||||
- name: Check out docs repository
|
|
||||||
uses: actions/checkout@v6
|
|
||||||
|
|
||||||
- name: Set up Python
|
|
||||||
uses: actions/setup-python@v6
|
|
||||||
with:
|
|
||||||
python-version: '3.11'
|
|
||||||
|
|
||||||
- name: Run sync unit tests
|
|
||||||
working-directory: docs
|
|
||||||
run: python -m unittest discover -s tests -p 'test_sync_docs_to_wiki.py' -v
|
|
||||||
|
|
||||||
- name: Validate internal doc links
|
|
||||||
run: python docs/scripts/sync_docs_to_wiki.py --source-root docs --check-links-only
|
|
||||||
|
|
||||||
- name: Clone AstrBot wiki
|
|
||||||
env:
|
|
||||||
WIKI_TOKEN: ${{ secrets.ASTRBOT_WIKI_TOKEN }}
|
|
||||||
run: |
|
|
||||||
test -n "$WIKI_TOKEN"
|
|
||||||
git clone "https://x-access-token:${WIKI_TOKEN}@github.com/AstrBotDevs/AstrBot.wiki.git" wiki
|
|
||||||
|
|
||||||
- name: Generate wiki pages
|
|
||||||
run: python docs/scripts/sync_docs_to_wiki.py --source-root docs --wiki-root wiki
|
|
||||||
|
|
||||||
- name: Commit and push wiki changes
|
|
||||||
working-directory: wiki
|
|
||||||
run: |
|
|
||||||
git config user.name "github-actions[bot]"
|
|
||||||
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
|
|
||||||
git add .
|
|
||||||
if git diff --cached --quiet; then
|
|
||||||
echo "No wiki changes to push"
|
|
||||||
exit 0
|
|
||||||
fi
|
|
||||||
git commit -m "docs: sync wiki from AstrBot-1/docs"
|
|
||||||
git push
|
|
||||||
@@ -61,5 +61,3 @@ GenieData/
|
|||||||
.codex/
|
.codex/
|
||||||
.opencode/
|
.opencode/
|
||||||
.kilocode/
|
.kilocode/
|
||||||
.worktrees/
|
|
||||||
|
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ For users who want to quickly experience AstrBot, are familiar with command-line
|
|||||||
```bash
|
```bash
|
||||||
uv tool install astrbot
|
uv tool install astrbot
|
||||||
astrbot init # Only execute this command for the first time to initialize the environment
|
astrbot init # Only execute this command for the first time to initialize the environment
|
||||||
astrbot run
|
astrbot
|
||||||
```
|
```
|
||||||
|
|
||||||
> Requires [uv](https://docs.astral.sh/uv/) to be installed.
|
> Requires [uv](https://docs.astral.sh/uv/) to be installed.
|
||||||
@@ -234,8 +234,7 @@ pre-commit install
|
|||||||
- Group 7: 743746109
|
- Group 7: 743746109
|
||||||
- Group 8: 1030353265
|
- Group 8: 1030353265
|
||||||
|
|
||||||
- Developer Group(Chit-chat): 975206796
|
- Developer Group: 975206796
|
||||||
- Developer Group(Formal): 1039761811
|
|
||||||
|
|
||||||
### Discord Server
|
### Discord Server
|
||||||
|
|
||||||
|
|||||||
+1
-2
@@ -78,7 +78,7 @@ Pour les utilisateurs qui veulent découvrir AstrBot rapidement, qui sont famili
|
|||||||
```bash
|
```bash
|
||||||
uv tool install astrbot
|
uv tool install astrbot
|
||||||
astrbot init # Exécutez cette commande uniquement la première fois pour initialiser l'environnement
|
astrbot init # Exécutez cette commande uniquement la première fois pour initialiser l'environnement
|
||||||
astrbot run
|
astrbot
|
||||||
```
|
```
|
||||||
|
|
||||||
> [uv](https://docs.astral.sh/uv/) doit être installé.
|
> [uv](https://docs.astral.sh/uv/) doit être installé.
|
||||||
@@ -222,7 +222,6 @@ pre-commit install
|
|||||||
- Groupe 5 : 822130018
|
- Groupe 5 : 822130018
|
||||||
- Groupe 6 : 753075035
|
- Groupe 6 : 753075035
|
||||||
- Groupe développeurs : 975206796
|
- Groupe développeurs : 975206796
|
||||||
- Groupe développeurs (officiel) : 1039761811
|
|
||||||
|
|
||||||
### Serveur Discord
|
### Serveur Discord
|
||||||
|
|
||||||
|
|||||||
+1
-2
@@ -78,7 +78,7 @@ AstrBot を素早く試したいユーザーで、コマンドラインに慣れ
|
|||||||
```bash
|
```bash
|
||||||
uv tool install astrbot
|
uv tool install astrbot
|
||||||
astrbot init # 初回のみ実行して環境を初期化します
|
astrbot init # 初回のみ実行して環境を初期化します
|
||||||
astrbot run
|
astrbot
|
||||||
```
|
```
|
||||||
|
|
||||||
> [uv](https://docs.astral.sh/uv/) のインストールが必要です。
|
> [uv](https://docs.astral.sh/uv/) のインストールが必要です。
|
||||||
@@ -223,7 +223,6 @@ pre-commit install
|
|||||||
- 5群: 822130018
|
- 5群: 822130018
|
||||||
- 6群: 753075035
|
- 6群: 753075035
|
||||||
- 開発者群: 975206796
|
- 開発者群: 975206796
|
||||||
- 開発者群(正式): 1039761811
|
|
||||||
|
|
||||||
### Discord サーバー
|
### Discord サーバー
|
||||||
|
|
||||||
|
|||||||
+1
-2
@@ -78,7 +78,7 @@ AstrBot — это универсальная платформа Agent-чатб
|
|||||||
```bash
|
```bash
|
||||||
uv tool install astrbot
|
uv tool install astrbot
|
||||||
astrbot init # Выполните эту команду только при первом запуске для инициализации окружения
|
astrbot init # Выполните эту команду только при первом запуске для инициализации окружения
|
||||||
astrbot run
|
astrbot
|
||||||
```
|
```
|
||||||
|
|
||||||
> Требуется установленный [uv](https://docs.astral.sh/uv/).
|
> Требуется установленный [uv](https://docs.astral.sh/uv/).
|
||||||
@@ -222,7 +222,6 @@ pre-commit install
|
|||||||
- Группа 5: 822130018
|
- Группа 5: 822130018
|
||||||
- Группа 6: 753075035
|
- Группа 6: 753075035
|
||||||
- Группа разработчиков: 975206796
|
- Группа разработчиков: 975206796
|
||||||
- Группа разработчиков (официальная): 1039761811
|
|
||||||
|
|
||||||
### Сервер Discord
|
### Сервер Discord
|
||||||
|
|
||||||
|
|||||||
+2
-3
@@ -78,7 +78,7 @@ AstrBot 是一個開源的一站式 Agent 聊天機器人平台,可接入主
|
|||||||
```bash
|
```bash
|
||||||
uv tool install astrbot
|
uv tool install astrbot
|
||||||
astrbot init # 僅首次執行此命令以初始化環境
|
astrbot init # 僅首次執行此命令以初始化環境
|
||||||
astrbot run
|
astrbot
|
||||||
```
|
```
|
||||||
|
|
||||||
> 需要安裝 [uv](https://docs.astral.sh/uv/)。
|
> 需要安裝 [uv](https://docs.astral.sh/uv/)。
|
||||||
@@ -225,8 +225,7 @@ pre-commit install
|
|||||||
- 6 群:753075035
|
- 6 群:753075035
|
||||||
- 7 群:743746109
|
- 7 群:743746109
|
||||||
- 8 群:1030353265
|
- 8 群:1030353265
|
||||||
- 開發者群(闲聊吹水):975206796
|
- 開發者群:975206796
|
||||||
- 開發者群(正式):1039761811
|
|
||||||
|
|
||||||
### Discord 群組
|
### Discord 群組
|
||||||
|
|
||||||
|
|||||||
+2
-3
@@ -78,7 +78,7 @@ AstrBot 是一个开源的一站式 Agentic 个人和群聊助手,可在 QQ、
|
|||||||
```bash
|
```bash
|
||||||
uv tool install astrbot
|
uv tool install astrbot
|
||||||
astrbot init # 仅首次执行此命令以初始化环境
|
astrbot init # 仅首次执行此命令以初始化环境
|
||||||
astrbot run
|
astrbot
|
||||||
```
|
```
|
||||||
|
|
||||||
> 需要安装 [uv](https://docs.astral.sh/uv/)。
|
> 需要安装 [uv](https://docs.astral.sh/uv/)。
|
||||||
@@ -226,8 +226,7 @@ pre-commit install
|
|||||||
- 6 群:753075035
|
- 6 群:753075035
|
||||||
- 7 群:743746109
|
- 7 群:743746109
|
||||||
- 8 群:1030353265
|
- 8 群:1030353265
|
||||||
- 开发者群(偏闲聊吹水):975206796
|
- 开发者群:975206796
|
||||||
- 开发者群(正式):1039761811
|
|
||||||
|
|
||||||
### Discord 频道
|
### Discord 频道
|
||||||
|
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = "4.20.0"
|
__version__ = "4.19.2"
|
||||||
|
|||||||
@@ -4,21 +4,7 @@ from astrbot.core.config import AstrBotConfig
|
|||||||
from astrbot.core.config.default import DB_PATH
|
from astrbot.core.config.default import DB_PATH
|
||||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||||
from astrbot.core.file_token_service import FileTokenService
|
from astrbot.core.file_token_service import FileTokenService
|
||||||
from astrbot.core.utils.pip_installer import (
|
from astrbot.core.utils.pip_installer import PipInstaller
|
||||||
DependencyConflictError as DependencyConflictError,
|
|
||||||
)
|
|
||||||
from astrbot.core.utils.pip_installer import (
|
|
||||||
PipInstaller,
|
|
||||||
)
|
|
||||||
from astrbot.core.utils.requirements_utils import (
|
|
||||||
RequirementsPrecheckFailed as RequirementsPrecheckFailed,
|
|
||||||
)
|
|
||||||
from astrbot.core.utils.requirements_utils import (
|
|
||||||
find_missing_requirements as find_missing_requirements,
|
|
||||||
)
|
|
||||||
from astrbot.core.utils.requirements_utils import (
|
|
||||||
find_missing_requirements_or_raise as find_missing_requirements_or_raise,
|
|
||||||
)
|
|
||||||
from astrbot.core.utils.shared_preferences import SharedPreferences
|
from astrbot.core.utils.shared_preferences import SharedPreferences
|
||||||
from astrbot.core.utils.t2i.renderer import HtmlRenderer
|
from astrbot.core.utils.t2i.renderer import HtmlRenderer
|
||||||
|
|
||||||
|
|||||||
@@ -144,14 +144,10 @@ class MCPClient:
|
|||||||
|
|
||||||
cfg = _prepare_config(mcp_server_config.copy())
|
cfg = _prepare_config(mcp_server_config.copy())
|
||||||
|
|
||||||
def logging_callback(
|
def logging_callback(msg: str) -> None:
|
||||||
msg: str | mcp.types.LoggingMessageNotificationParams,
|
|
||||||
) -> None:
|
|
||||||
# Handle MCP service error logs
|
# Handle MCP service error logs
|
||||||
if isinstance(msg, mcp.types.LoggingMessageNotificationParams):
|
print(f"MCP Server {name} Error: {msg}")
|
||||||
if msg.level in ("warning", "error", "critical", "alert", "emergency"):
|
self.server_errlogs.append(msg)
|
||||||
log_msg = f"[{msg.level.upper()}] {str(msg.data)}"
|
|
||||||
self.server_errlogs.append(log_msg)
|
|
||||||
|
|
||||||
if "url" in cfg:
|
if "url" in cfg:
|
||||||
success, error_msg = await _quick_test_mcp_connection(cfg)
|
success, error_msg = await _quick_test_mcp_connection(cfg)
|
||||||
@@ -218,24 +214,15 @@ class MCPClient:
|
|||||||
**cfg,
|
**cfg,
|
||||||
)
|
)
|
||||||
|
|
||||||
def callback(msg: str | mcp.types.LoggingMessageNotificationParams) -> None:
|
def callback(msg: str) -> None:
|
||||||
# Handle MCP service error logs
|
# Handle MCP service error logs
|
||||||
if isinstance(msg, mcp.types.LoggingMessageNotificationParams):
|
self.server_errlogs.append(msg)
|
||||||
if msg.level in (
|
|
||||||
"warning",
|
|
||||||
"error",
|
|
||||||
"critical",
|
|
||||||
"alert",
|
|
||||||
"emergency",
|
|
||||||
):
|
|
||||||
log_msg = f"[{msg.level.upper()}] {str(msg.data)}"
|
|
||||||
self.server_errlogs.append(log_msg)
|
|
||||||
|
|
||||||
stdio_transport = await self.exit_stack.enter_async_context(
|
stdio_transport = await self.exit_stack.enter_async_context(
|
||||||
mcp.stdio_client(
|
mcp.stdio_client(
|
||||||
server_params,
|
server_params,
|
||||||
errlog=LogPipe(
|
errlog=LogPipe(
|
||||||
level=logging.INFO,
|
level=logging.ERROR,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
identifier=f"MCPServer-{name}",
|
identifier=f"MCPServer-{name}",
|
||||||
callback=callback,
|
callback=callback,
|
||||||
|
|||||||
@@ -302,7 +302,7 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
item_type, item_data = await asyncio.get_running_loop().run_in_executor(
|
item_type, item_data = await asyncio.get_event_loop().run_in_executor(
|
||||||
None, response_queue.get, True, 1
|
None, response_queue.get, True, 1
|
||||||
)
|
)
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
@@ -388,7 +388,7 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
|
|
||||||
# 发起请求
|
# 发起请求
|
||||||
partial = functools.partial(Application.call, **payload)
|
partial = functools.partial(Application.call, **payload)
|
||||||
response = await asyncio.get_running_loop().run_in_executor(None, partial)
|
response = await asyncio.get_event_loop().run_in_executor(None, partial)
|
||||||
|
|
||||||
async for resp in self._handle_streaming_response(response, session_id):
|
async for resp in self._handle_streaming_response(response, session_id):
|
||||||
yield resp
|
yield resp
|
||||||
|
|||||||
@@ -326,7 +326,6 @@ async def run_live_agent(
|
|||||||
|
|
||||||
# 创建队列
|
# 创建队列
|
||||||
text_queue: asyncio.Queue[str | None] = asyncio.Queue()
|
text_queue: asyncio.Queue[str | None] = asyncio.Queue()
|
||||||
delta_queue: asyncio.Queue[str | None] = asyncio.Queue()
|
|
||||||
# audio_queue stored bytes or (text, bytes)
|
# audio_queue stored bytes or (text, bytes)
|
||||||
audio_queue: asyncio.Queue[bytes | tuple[str, bytes] | None] = asyncio.Queue()
|
audio_queue: asyncio.Queue[bytes | tuple[str, bytes] | None] = asyncio.Queue()
|
||||||
|
|
||||||
@@ -335,7 +334,6 @@ async def run_live_agent(
|
|||||||
_run_agent_feeder(
|
_run_agent_feeder(
|
||||||
agent_runner,
|
agent_runner,
|
||||||
text_queue,
|
text_queue,
|
||||||
delta_queue,
|
|
||||||
max_step,
|
max_step,
|
||||||
show_tool_use,
|
show_tool_use,
|
||||||
show_tool_call_result,
|
show_tool_call_result,
|
||||||
@@ -355,63 +353,32 @@ async def run_live_agent(
|
|||||||
|
|
||||||
# 3. 主循环:从 audio_queue 读取音频并 yield
|
# 3. 主循环:从 audio_queue 读取音频并 yield
|
||||||
try:
|
try:
|
||||||
delta_done = False
|
while True:
|
||||||
audio_done = False
|
queue_item = await audio_queue.get()
|
||||||
while not (delta_done and audio_done):
|
|
||||||
task_sources: dict[asyncio.Task, str] = {}
|
|
||||||
if not delta_done:
|
|
||||||
task = asyncio.create_task(delta_queue.get())
|
|
||||||
task_sources[task] = "delta"
|
|
||||||
if not audio_done:
|
|
||||||
task = asyncio.create_task(audio_queue.get())
|
|
||||||
task_sources[task] = "audio"
|
|
||||||
|
|
||||||
done, pending = await asyncio.wait(
|
if queue_item is None:
|
||||||
list(task_sources),
|
break
|
||||||
return_when=asyncio.FIRST_COMPLETED,
|
|
||||||
)
|
|
||||||
|
|
||||||
for task in pending:
|
text = None
|
||||||
task.cancel()
|
if isinstance(queue_item, tuple):
|
||||||
if pending:
|
text, audio_data = queue_item
|
||||||
await asyncio.gather(*pending, return_exceptions=True)
|
else:
|
||||||
|
audio_data = queue_item
|
||||||
|
|
||||||
for task in done:
|
if not first_chunk_received:
|
||||||
source = task_sources[task]
|
# 记录首帧延迟(从开始处理到收到第一个音频块)
|
||||||
queue_item = task.result()
|
tts_first_frame_time = time.time() - tts_start_time
|
||||||
if source == "delta":
|
first_chunk_received = True
|
||||||
if queue_item is None:
|
|
||||||
delta_done = True
|
|
||||||
continue
|
|
||||||
yield MessageChain(
|
|
||||||
chain=[Plain(queue_item)], type="live_text_delta"
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if queue_item is None:
|
# 将音频数据封装为 MessageChain
|
||||||
audio_done = True
|
import base64
|
||||||
continue
|
|
||||||
|
|
||||||
text = None
|
audio_b64 = base64.b64encode(audio_data).decode("utf-8")
|
||||||
if isinstance(queue_item, tuple):
|
comps: list[BaseMessageComponent] = [Plain(audio_b64)]
|
||||||
text, audio_data = queue_item
|
if text:
|
||||||
else:
|
comps.append(Json(data={"text": text}))
|
||||||
audio_data = queue_item
|
chain = MessageChain(chain=comps, type="audio_chunk")
|
||||||
|
yield chain
|
||||||
if not first_chunk_received:
|
|
||||||
# 记录首帧延迟(从开始处理到收到第一个音频块)
|
|
||||||
tts_first_frame_time = time.time() - tts_start_time
|
|
||||||
first_chunk_received = True
|
|
||||||
|
|
||||||
# 将音频数据封装为 MessageChain
|
|
||||||
import base64
|
|
||||||
|
|
||||||
audio_b64 = base64.b64encode(audio_data).decode("utf-8")
|
|
||||||
comps: list[BaseMessageComponent] = [Plain(audio_b64)]
|
|
||||||
if text:
|
|
||||||
comps.append(Json(data={"text": text}))
|
|
||||||
chain = MessageChain(chain=comps, type="audio_chunk")
|
|
||||||
yield chain
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[Live Agent] 运行时发生错误: {e}", exc_info=True)
|
logger.error(f"[Live Agent] 运行时发生错误: {e}", exc_info=True)
|
||||||
@@ -454,7 +421,6 @@ async def run_live_agent(
|
|||||||
async def _run_agent_feeder(
|
async def _run_agent_feeder(
|
||||||
agent_runner: AgentRunner,
|
agent_runner: AgentRunner,
|
||||||
text_queue: asyncio.Queue,
|
text_queue: asyncio.Queue,
|
||||||
delta_queue: asyncio.Queue,
|
|
||||||
max_step: int,
|
max_step: int,
|
||||||
show_tool_use: bool,
|
show_tool_use: bool,
|
||||||
show_tool_call_result: bool,
|
show_tool_call_result: bool,
|
||||||
@@ -474,13 +440,9 @@ async def _run_agent_feeder(
|
|||||||
if chain is None:
|
if chain is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if chain.type == "reasoning":
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 提取文本
|
# 提取文本
|
||||||
text = chain.get_plain_text()
|
text = chain.get_plain_text()
|
||||||
if text:
|
if text:
|
||||||
await delta_queue.put(text)
|
|
||||||
buffer += text
|
buffer += text
|
||||||
|
|
||||||
# 分句逻辑:匹配标点符号
|
# 分句逻辑:匹配标点符号
|
||||||
@@ -515,7 +477,6 @@ async def _run_agent_feeder(
|
|||||||
finally:
|
finally:
|
||||||
# 发送结束信号
|
# 发送结束信号
|
||||||
await text_queue.put(None)
|
await text_queue.put(None)
|
||||||
await delta_queue.put(None)
|
|
||||||
|
|
||||||
|
|
||||||
async def _safe_tts_stream_wrapper(
|
async def _safe_tts_stream_wrapper(
|
||||||
|
|||||||
@@ -778,14 +778,9 @@ def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None:
|
|||||||
continue
|
continue
|
||||||
mp = tool.handler_module_path
|
mp = tool.handler_module_path
|
||||||
if not mp:
|
if not mp:
|
||||||
# 没有 plugin 归属信息的工具(如 subagent transfer_to_*)
|
|
||||||
# 不应受到会话插件过滤影响。
|
|
||||||
new_tool_set.add_tool(tool)
|
|
||||||
continue
|
continue
|
||||||
plugin = star_map.get(mp)
|
plugin = star_map.get(mp)
|
||||||
if not plugin:
|
if not plugin:
|
||||||
# 无法解析插件归属时,保守保留工具,避免误过滤。
|
|
||||||
new_tool_set.add_tool(tool)
|
|
||||||
continue
|
continue
|
||||||
if plugin.name in event.plugins_name or plugin.reserved:
|
if plugin.name in event.plugins_name or plugin.reserved:
|
||||||
new_tool_set.add_tool(tool)
|
new_tool_set.add_tool(tool)
|
||||||
|
|||||||
@@ -188,12 +188,7 @@ class KnowledgeBaseQueryTool(FunctionTool[AstrAgentContext]):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class SendMessageToUserTool(FunctionTool[AstrAgentContext]):
|
class SendMessageToUserTool(FunctionTool[AstrAgentContext]):
|
||||||
name: str = "send_message_to_user"
|
name: str = "send_message_to_user"
|
||||||
description: str = (
|
description: str = "Directly send message to the user. Only use this tool when you need to proactively message the user. Otherwise you can directly output the reply in the conversation."
|
||||||
"Send message to the user. "
|
|
||||||
"Supports various message types including `plain`, `image`, `record`, `video`, `file`, and `mention_user`. "
|
|
||||||
"Use this tool to send media files (`image`, `record`, `video`, `file`), "
|
|
||||||
"or when you need to proactively message the user(such as cron job). For normal text replies, you can output directly."
|
|
||||||
)
|
|
||||||
|
|
||||||
parameters: dict = Field(
|
parameters: dict = Field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
@@ -209,7 +204,7 @@ class SendMessageToUserTool(FunctionTool[AstrAgentContext]):
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
"description": (
|
"description": (
|
||||||
"Component type. One of: "
|
"Component type. One of: "
|
||||||
"plain, image, record, video, file, mention_user. Record is voice message."
|
"plain, image, record, file, mention_user"
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
"text": {
|
"text": {
|
||||||
@@ -325,19 +320,6 @@ class SendMessageToUserTool(FunctionTool[AstrAgentContext]):
|
|||||||
components.append(Comp.Record.fromURL(url=url))
|
components.append(Comp.Record.fromURL(url=url))
|
||||||
else:
|
else:
|
||||||
return f"error: messages[{idx}] must include path or url for record component."
|
return f"error: messages[{idx}] must include path or url for record component."
|
||||||
elif msg_type == "video":
|
|
||||||
path = msg.get("path")
|
|
||||||
url = msg.get("url")
|
|
||||||
if path:
|
|
||||||
(
|
|
||||||
local_path,
|
|
||||||
file_from_sandbox,
|
|
||||||
) = await self._resolve_path_from_sandbox(context, path)
|
|
||||||
components.append(Comp.Video.fromFileSystem(path=local_path))
|
|
||||||
elif url:
|
|
||||||
components.append(Comp.Video.fromURL(url=url))
|
|
||||||
else:
|
|
||||||
return f"error: messages[{idx}] must include path or url for video component."
|
|
||||||
elif msg_type == "file":
|
elif msg_type == "file":
|
||||||
path = msg.get("path")
|
path = msg.get("path")
|
||||||
url = msg.get("url")
|
url = msg.get("url")
|
||||||
|
|||||||
@@ -121,12 +121,11 @@ class BayContainerManager:
|
|||||||
async def wait_healthy(self, timeout: int = HEALTH_TIMEOUT_S) -> None:
|
async def wait_healthy(self, timeout: int = HEALTH_TIMEOUT_S) -> None:
|
||||||
"""Block until Bay's ``/health`` endpoint returns 200."""
|
"""Block until Bay's ``/health`` endpoint returns 200."""
|
||||||
url = f"http://127.0.0.1:{self._host_port}/health"
|
url = f"http://127.0.0.1:{self._host_port}/health"
|
||||||
loop = asyncio.get_running_loop()
|
deadline = asyncio.get_event_loop().time() + timeout
|
||||||
deadline = loop.time() + timeout
|
|
||||||
last_error: str = ""
|
last_error: str = ""
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
while loop.time() < deadline:
|
while asyncio.get_event_loop().time() < deadline:
|
||||||
try:
|
try:
|
||||||
async with session.get(
|
async with session.get(
|
||||||
url, timeout=aiohttp.ClientTimeout(total=3)
|
url, timeout=aiohttp.ClientTimeout(total=3)
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import locale
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
@@ -53,31 +52,6 @@ def _ensure_safe_path(path: str) -> str:
|
|||||||
return abs_path
|
return abs_path
|
||||||
|
|
||||||
|
|
||||||
def _decode_shell_output(output: bytes | None) -> str:
|
|
||||||
if output is None:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
preferred = locale.getpreferredencoding(False) or "utf-8"
|
|
||||||
try:
|
|
||||||
return output.decode("utf-8")
|
|
||||||
except (LookupError, UnicodeDecodeError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
if os.name == "nt":
|
|
||||||
for encoding in ("mbcs", "cp936", "gbk", "gb18030"):
|
|
||||||
try:
|
|
||||||
return output.decode(encoding)
|
|
||||||
except (LookupError, UnicodeDecodeError):
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
return output.decode(preferred)
|
|
||||||
except (LookupError, UnicodeDecodeError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
return output.decode("utf-8", errors="replace")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LocalShellComponent(ShellComponent):
|
class LocalShellComponent(ShellComponent):
|
||||||
async def exec(
|
async def exec(
|
||||||
@@ -98,32 +72,28 @@ class LocalShellComponent(ShellComponent):
|
|||||||
run_env.update({str(k): str(v) for k, v in env.items()})
|
run_env.update({str(k): str(v) for k, v in env.items()})
|
||||||
working_dir = _ensure_safe_path(cwd) if cwd else get_astrbot_root()
|
working_dir = _ensure_safe_path(cwd) if cwd else get_astrbot_root()
|
||||||
if background:
|
if background:
|
||||||
# `command` is intentionally executed through the current shell so
|
proc = subprocess.Popen(
|
||||||
# local computer-use behavior matches existing tool semantics.
|
|
||||||
# Safety relies on `_is_safe_command()` and the allowed-root checks.
|
|
||||||
proc = subprocess.Popen( # noqa: S602 # nosemgrep: python.lang.security.audit.dangerous-subprocess-use-audit
|
|
||||||
command,
|
command,
|
||||||
shell=shell,
|
shell=shell,
|
||||||
cwd=working_dir,
|
cwd=working_dir,
|
||||||
env=run_env,
|
env=run_env,
|
||||||
stdout=subprocess.DEVNULL,
|
stdout=subprocess.PIPE,
|
||||||
stderr=subprocess.DEVNULL,
|
stderr=subprocess.PIPE,
|
||||||
|
text=True,
|
||||||
)
|
)
|
||||||
return {"pid": proc.pid, "stdout": "", "stderr": "", "exit_code": None}
|
return {"pid": proc.pid, "stdout": "", "stderr": "", "exit_code": None}
|
||||||
# `command` is intentionally executed through the current shell so
|
result = subprocess.run(
|
||||||
# local computer-use behavior matches existing tool semantics.
|
|
||||||
# Safety relies on `_is_safe_command()` and the allowed-root checks.
|
|
||||||
result = subprocess.run( # noqa: S602 # nosemgrep: python.lang.security.audit.dangerous-subprocess-use-audit
|
|
||||||
command,
|
command,
|
||||||
shell=shell,
|
shell=shell,
|
||||||
cwd=working_dir,
|
cwd=working_dir,
|
||||||
env=run_env,
|
env=run_env,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
capture_output=True,
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
"stdout": _decode_shell_output(result.stdout),
|
"stdout": result.stdout,
|
||||||
"stderr": _decode_shell_output(result.stderr),
|
"stderr": result.stderr,
|
||||||
"exit_code": result.returncode,
|
"exit_code": result.returncode,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -422,12 +422,6 @@ async def get_booter(
|
|||||||
) -> ComputerBooter:
|
) -> ComputerBooter:
|
||||||
config = context.get_config(umo=session_id)
|
config = context.get_config(umo=session_id)
|
||||||
|
|
||||||
runtime = config.get("provider_settings", {}).get("computer_use_runtime", "local")
|
|
||||||
if runtime == "local":
|
|
||||||
return get_local_booter()
|
|
||||||
elif runtime == "none":
|
|
||||||
raise RuntimeError("Sandbox runtime is disabled by configuration.")
|
|
||||||
|
|
||||||
sandbox_cfg = config.get("provider_settings", {}).get("sandbox", {})
|
sandbox_cfg = config.get("provider_settings", {}).get("sandbox", {})
|
||||||
booter_type = sandbox_cfg.get("booter", "shipyard_neo")
|
booter_type = sandbox_cfg.get("booter", "shipyard_neo")
|
||||||
|
|
||||||
|
|||||||
@@ -164,10 +164,7 @@ class CreateSkillPayloadTool(NeoSkillToolBase):
|
|||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"payload": {
|
"payload": {
|
||||||
"anyOf": [
|
"anyOf": [{"type": "object"}, {"type": "array"}],
|
||||||
{"type": "object"},
|
|
||||||
{"type": "array", "items": {"type": "object"}},
|
|
||||||
],
|
|
||||||
"description": (
|
"description": (
|
||||||
"Skill payload JSON. Typical schema: {skill_markdown, inputs, outputs, meta}. "
|
"Skill payload JSON. Typical schema: {skill_markdown, inputs, outputs, meta}. "
|
||||||
"This only stores content and returns payload_ref; it does not create a candidate or release."
|
"This only stores content and returns payload_ref; it does not create a candidate or release."
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from typing import Any, TypedDict
|
|||||||
|
|
||||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
|
|
||||||
VERSION = "4.20.0"
|
VERSION = "4.19.2"
|
||||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||||
|
|
||||||
WEBHOOK_SUPPORTED_PLATFORMS = [
|
WEBHOOK_SUPPORTED_PLATFORMS = [
|
||||||
@@ -219,9 +219,6 @@ DEFAULT_CONFIG = {
|
|||||||
"telegram": {
|
"telegram": {
|
||||||
"pre_ack_emoji": {"enable": False, "emojis": ["✍️"]},
|
"pre_ack_emoji": {"enable": False, "emojis": ["✍️"]},
|
||||||
},
|
},
|
||||||
"discord": {
|
|
||||||
"pre_ack_emoji": {"enable": False, "emojis": ["🤔"]},
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
"wake_prefix": ["/"],
|
"wake_prefix": ["/"],
|
||||||
"log_level": "INFO",
|
"log_level": "INFO",
|
||||||
@@ -345,20 +342,14 @@ CONFIG_METADATA_2 = {
|
|||||||
"企业微信智能机器人": {
|
"企业微信智能机器人": {
|
||||||
"id": "wecom_ai_bot",
|
"id": "wecom_ai_bot",
|
||||||
"type": "wecom_ai_bot",
|
"type": "wecom_ai_bot",
|
||||||
"hint": "如果发现字段有异常,请重新创建",
|
|
||||||
"enable": True,
|
"enable": True,
|
||||||
"wecom_ai_bot_connection_mode": "long_connection", # long_connection, webhook
|
|
||||||
"wecom_ai_bot_name": "",
|
|
||||||
"wecomaibot_ws_bot_id": "",
|
|
||||||
"wecomaibot_ws_secret": "",
|
|
||||||
"wecomaibot_token": "",
|
|
||||||
"wecomaibot_encoding_aes_key": "",
|
|
||||||
"wecomaibot_init_respond_text": "",
|
"wecomaibot_init_respond_text": "",
|
||||||
"wecomaibot_friend_message_welcome_text": "",
|
"wecomaibot_friend_message_welcome_text": "",
|
||||||
|
"wecom_ai_bot_name": "",
|
||||||
"msg_push_webhook_url": "",
|
"msg_push_webhook_url": "",
|
||||||
"only_use_webhook_url_to_send": False,
|
"only_use_webhook_url_to_send": False,
|
||||||
"wecomaibot_ws_url": "wss://openws.work.weixin.qq.com",
|
"token": "",
|
||||||
"wecomaibot_heartbeat_interval": 30,
|
"encoding_aes_key": "",
|
||||||
"unified_webhook_mode": True,
|
"unified_webhook_mode": True,
|
||||||
"webhook_uuid": "",
|
"webhook_uuid": "",
|
||||||
"callback_server_host": "0.0.0.0",
|
"callback_server_host": "0.0.0.0",
|
||||||
@@ -741,13 +732,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "请务必填写正确,否则无法使用一些指令。",
|
"hint": "请务必填写正确,否则无法使用一些指令。",
|
||||||
},
|
},
|
||||||
"wecom_ai_bot_connection_mode": {
|
|
||||||
"description": "企业微信智能机器人连接模式",
|
|
||||||
"type": "string",
|
|
||||||
"options": ["webhook", "long_connection"],
|
|
||||||
"labels": ["Webhook 回调", "长连接"],
|
|
||||||
"hint": "Webhook 回调模式需要配置 Token/EncodingAESKey。长连接模式需要配置 BotID/Secret。",
|
|
||||||
},
|
|
||||||
"wecomaibot_init_respond_text": {
|
"wecomaibot_init_respond_text": {
|
||||||
"description": "企业微信智能机器人初始响应文本",
|
"description": "企业微信智能机器人初始响应文本",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@@ -758,22 +742,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "当用户当天进入智能机器人单聊会话,回复欢迎语,留空则不回复。",
|
"hint": "当用户当天进入智能机器人单聊会话,回复欢迎语,留空则不回复。",
|
||||||
},
|
},
|
||||||
"wecomaibot_token": {
|
|
||||||
"description": "企业微信智能机器人 Token",
|
|
||||||
"type": "string",
|
|
||||||
"hint": "用于 Webhook 回调模式的身份验证。",
|
|
||||||
"condition": {
|
|
||||||
"wecom_ai_bot_connection_mode": "webhook",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"wecomaibot_encoding_aes_key": {
|
|
||||||
"description": "企业微信智能机器人 EncodingAESKey",
|
|
||||||
"type": "string",
|
|
||||||
"hint": "用于 Webhook 回调模式的消息加密解密。",
|
|
||||||
"condition": {
|
|
||||||
"wecom_ai_bot_connection_mode": "webhook",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"msg_push_webhook_url": {
|
"msg_push_webhook_url": {
|
||||||
"description": "企业微信消息推送 Webhook URL",
|
"description": "企业微信消息推送 Webhook URL",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@@ -784,40 +752,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"type": "bool",
|
"type": "bool",
|
||||||
"hint": "启用后,企业微信智能机器人的所有回复都改为通过消息推送 Webhook 发送。消息推送 Webhook 支持更多的消息类型(如图片、文件等)。",
|
"hint": "启用后,企业微信智能机器人的所有回复都改为通过消息推送 Webhook 发送。消息推送 Webhook 支持更多的消息类型(如图片、文件等)。",
|
||||||
},
|
},
|
||||||
"wecomaibot_ws_bot_id": {
|
|
||||||
"description": "长连接 BotID",
|
|
||||||
"type": "string",
|
|
||||||
"hint": "企业微信智能机器人长连接模式凭证 BotID。",
|
|
||||||
"condition": {
|
|
||||||
"wecom_ai_bot_connection_mode": "long_connection",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"wecomaibot_ws_secret": {
|
|
||||||
"description": "长连接 Secret",
|
|
||||||
"type": "string",
|
|
||||||
"hint": "企业微信智能机器人长连接模式凭证 Secret。",
|
|
||||||
"condition": {
|
|
||||||
"wecom_ai_bot_connection_mode": "long_connection",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"wecomaibot_ws_url": {
|
|
||||||
"description": "长连接 WebSocket 地址",
|
|
||||||
"type": "string",
|
|
||||||
"invisible": True,
|
|
||||||
"hint": "默认值为 wss://openws.work.weixin.qq.com,一般无需修改。",
|
|
||||||
"condition": {
|
|
||||||
"wecom_ai_bot_connection_mode": "long_connection",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"wecomaibot_heartbeat_interval": {
|
|
||||||
"description": "长连接心跳间隔",
|
|
||||||
"type": "int",
|
|
||||||
"invisible": True,
|
|
||||||
"hint": "长连接模式心跳间隔(秒),建议 30 秒。",
|
|
||||||
"condition": {
|
|
||||||
"wecom_ai_bot_connection_mode": "long_connection",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"lark_bot_name": {
|
"lark_bot_name": {
|
||||||
"description": "飞书机器人的名字",
|
"description": "飞书机器人的名字",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@@ -862,7 +796,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"unified_webhook_mode": {
|
"unified_webhook_mode": {
|
||||||
"description": "统一 Webhook 模式",
|
"description": "统一 Webhook 模式",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
"hint": "Webhook 模式下使用 AstrBot 统一 Webhook 入口,无需单独开启端口。回调地址为 /api/platform/webhook/{webhook_uuid}。",
|
"hint": "启用后,将使用 AstrBot 统一 Webhook 入口,无需单独开启端口。回调地址为 /api/platform/webhook/{webhook_uuid}。",
|
||||||
},
|
},
|
||||||
"webhook_uuid": {
|
"webhook_uuid": {
|
||||||
"invisible": True,
|
"invisible": True,
|
||||||
@@ -1132,18 +1066,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"proxy": "",
|
"proxy": "",
|
||||||
"custom_headers": {},
|
"custom_headers": {},
|
||||||
},
|
},
|
||||||
"MiniMax": {
|
|
||||||
"id": "minimax",
|
|
||||||
"provider": "minimax",
|
|
||||||
"type": "openai_chat_completion",
|
|
||||||
"provider_type": "chat_completion",
|
|
||||||
"enable": True,
|
|
||||||
"key": [],
|
|
||||||
"api_base": "https://api.minimaxi.com/v1",
|
|
||||||
"timeout": 120,
|
|
||||||
"proxy": "",
|
|
||||||
"custom_headers": {},
|
|
||||||
},
|
|
||||||
"xAI": {
|
"xAI": {
|
||||||
"id": "xai",
|
"id": "xai",
|
||||||
"provider": "xai",
|
"provider": "xai",
|
||||||
|
|||||||
@@ -332,9 +332,9 @@ class CronJobManager:
|
|||||||
cron_job=cron_job_str
|
cron_job=cron_job_str
|
||||||
)
|
)
|
||||||
req.prompt = (
|
req.prompt = (
|
||||||
"You are now responding to a scheduled task. "
|
"You are now responding to a scheduled task"
|
||||||
"Proceed according to your system instructions. "
|
"Proceed according to your system instructions. "
|
||||||
"Output using same language as previous conversation. "
|
"Output using same language as previous conversation."
|
||||||
"After completing your task, summarize and output your actions and results."
|
"After completing your task, summarize and output your actions and results."
|
||||||
)
|
)
|
||||||
if not req.func_tool:
|
if not req.func_tool:
|
||||||
|
|||||||
@@ -647,13 +647,6 @@ class BaseDatabase(abc.ABC):
|
|||||||
"""Get a Platform session by its ID."""
|
"""Get a Platform session by its ID."""
|
||||||
...
|
...
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
async def get_platform_sessions_by_ids(
|
|
||||||
self, session_ids: list[str]
|
|
||||||
) -> list[PlatformSession]:
|
|
||||||
"""Get platform sessions by IDs."""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def get_platform_sessions_by_creator(
|
async def get_platform_sessions_by_creator(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -1417,21 +1417,6 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
result = await session.execute(query)
|
result = await session.execute(query)
|
||||||
return result.scalar_one_or_none()
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
async def get_platform_sessions_by_ids(
|
|
||||||
self, session_ids: list[str]
|
|
||||||
) -> list[PlatformSession]:
|
|
||||||
"""Get platform sessions by IDs."""
|
|
||||||
if not session_ids:
|
|
||||||
return []
|
|
||||||
|
|
||||||
async with self.get_db() as session:
|
|
||||||
session: AsyncSession
|
|
||||||
query = select(PlatformSession).where(
|
|
||||||
col(PlatformSession.session_id).in_(session_ids)
|
|
||||||
)
|
|
||||||
result = await session.execute(query)
|
|
||||||
return list(result.scalars().all())
|
|
||||||
|
|
||||||
async def get_platform_sessions_by_creator(
|
async def get_platform_sessions_by_creator(
|
||||||
self,
|
self,
|
||||||
creator: str,
|
creator: str,
|
||||||
|
|||||||
@@ -96,10 +96,10 @@ class Plain(BaseMessageComponent):
|
|||||||
def __init__(self, text: str, convert: bool = True, **_) -> None:
|
def __init__(self, text: str, convert: bool = True, **_) -> None:
|
||||||
super().__init__(text=text, convert=convert, **_)
|
super().__init__(text=text, convert=convert, **_)
|
||||||
|
|
||||||
def toDict(self) -> dict:
|
def toDict(self):
|
||||||
return {"type": "text", "data": {"text": self.text}}
|
return {"type": "text", "data": {"text": self.text.strip()}}
|
||||||
|
|
||||||
async def to_dict(self) -> dict:
|
async def to_dict(self):
|
||||||
return {"type": "text", "data": {"text": self.text}}
|
return {"type": "text", "data": {"text": self.text}}
|
||||||
|
|
||||||
|
|
||||||
@@ -699,24 +699,21 @@ class File(BaseMessageComponent):
|
|||||||
|
|
||||||
if self.url:
|
if self.url:
|
||||||
try:
|
try:
|
||||||
# 检查是否有正在运行的 event loop
|
loop = asyncio.get_event_loop()
|
||||||
asyncio.get_running_loop()
|
if loop.is_running():
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"不可以在异步上下文中同步等待下载! "
|
"不可以在异步上下文中同步等待下载! "
|
||||||
"这个警告通常发生于某些逻辑试图通过 <File>.file 获取文件消息段的文件内容。"
|
"这个警告通常发生于某些逻辑试图通过 <File>.file 获取文件消息段的文件内容。"
|
||||||
"请使用 await get_file() 代替直接获取 <File>.file 字段",
|
"请使用 await get_file() 代替直接获取 <File>.file 字段",
|
||||||
)
|
)
|
||||||
return ""
|
return ""
|
||||||
except RuntimeError:
|
# 等待下载完成
|
||||||
# 没有运行中的 event loop,可以同步执行
|
loop.run_until_complete(self._download_file())
|
||||||
try:
|
|
||||||
# 使用 asyncio.run 安全地创建和关闭事件循环
|
|
||||||
asyncio.run(self._download_file())
|
|
||||||
except Exception:
|
|
||||||
logger.exception("文件下载失败")
|
|
||||||
|
|
||||||
if self.file_ and os.path.exists(self.file_):
|
if self.file_ and os.path.exists(self.file_):
|
||||||
return os.path.abspath(self.file_)
|
return os.path.abspath(self.file_)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"文件下载失败: {e}")
|
||||||
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from aiocqhttp import CQHttp, Event
|
|||||||
|
|
||||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
from astrbot.api.message_components import (
|
from astrbot.api.message_components import (
|
||||||
At,
|
|
||||||
BaseMessageComponent,
|
BaseMessageComponent,
|
||||||
File,
|
File,
|
||||||
Image,
|
Image,
|
||||||
@@ -71,19 +70,11 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
|||||||
"""解析成 OneBot json 格式"""
|
"""解析成 OneBot json 格式"""
|
||||||
ret = []
|
ret = []
|
||||||
for segment in message_chain.chain:
|
for segment in message_chain.chain:
|
||||||
if isinstance(segment, At):
|
if isinstance(segment, Plain):
|
||||||
# At 组件后插入一个空格,避免与后续文本粘连
|
|
||||||
d = await AiocqhttpMessageEvent._from_segment_to_dict(segment)
|
|
||||||
ret.append(d)
|
|
||||||
ret.append({"type": "text", "data": {"text": " "}})
|
|
||||||
elif isinstance(segment, Plain):
|
|
||||||
if not segment.text.strip():
|
if not segment.text.strip():
|
||||||
continue
|
continue
|
||||||
d = await AiocqhttpMessageEvent._from_segment_to_dict(segment)
|
d = await AiocqhttpMessageEvent._from_segment_to_dict(segment)
|
||||||
ret.append(d)
|
ret.append(d)
|
||||||
else:
|
|
||||||
d = await AiocqhttpMessageEvent._from_segment_to_dict(segment)
|
|
||||||
ret.append(d)
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -367,7 +367,7 @@ class DingtalkPlatformAdapter(Platform):
|
|||||||
|
|
||||||
async def get_access_token(self) -> str:
|
async def get_access_token(self) -> str:
|
||||||
try:
|
try:
|
||||||
access_token = await asyncio.get_running_loop().run_in_executor(
|
access_token = await asyncio.get_event_loop().run_in_executor(
|
||||||
None,
|
None,
|
||||||
self.client_.get_access_token,
|
self.client_.get_access_token,
|
||||||
)
|
)
|
||||||
@@ -760,7 +760,7 @@ class DingtalkPlatformAdapter(Platform):
|
|||||||
return
|
return
|
||||||
logger.error(f"钉钉机器人启动失败: {e}")
|
logger.error(f"钉钉机器人启动失败: {e}")
|
||||||
|
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_event_loop()
|
||||||
await loop.run_in_executor(None, start_client, loop)
|
await loop.run_in_executor(None, start_client, loop)
|
||||||
|
|
||||||
async def terminate(self) -> None:
|
async def terminate(self) -> None:
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ from .server import LarkWebhookServer
|
|||||||
|
|
||||||
|
|
||||||
@register_platform_adapter(
|
@register_platform_adapter(
|
||||||
"lark", "飞书机器人官方 API 适配器", support_streaming_message=True
|
"lark", "飞书机器人官方 API 适配器", support_streaming_message=False
|
||||||
)
|
)
|
||||||
class LarkPlatformAdapter(Platform):
|
class LarkPlatformAdapter(Platform):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -491,7 +491,7 @@ class LarkPlatformAdapter(Platform):
|
|||||||
name="lark",
|
name="lark",
|
||||||
description="飞书机器人官方 API 适配器",
|
description="飞书机器人官方 API 适配器",
|
||||||
id=cast(str, self.config.get("id")),
|
id=cast(str, self.config.get("id")),
|
||||||
support_streaming_message=True,
|
support_streaming_message=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1) -> None:
|
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1) -> None:
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import asyncio
|
|
||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
@@ -6,14 +5,6 @@ import uuid
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
import lark_oapi as lark
|
import lark_oapi as lark
|
||||||
from lark_oapi.api.cardkit.v1 import (
|
|
||||||
ContentCardElementRequest,
|
|
||||||
ContentCardElementRequestBody,
|
|
||||||
CreateCardRequest,
|
|
||||||
CreateCardRequestBody,
|
|
||||||
SettingsCardRequest,
|
|
||||||
SettingsCardRequestBody,
|
|
||||||
)
|
|
||||||
from lark_oapi.api.im.v1 import (
|
from lark_oapi.api.im.v1 import (
|
||||||
CreateFileRequest,
|
CreateFileRequest,
|
||||||
CreateFileRequestBody,
|
CreateFileRequestBody,
|
||||||
@@ -37,7 +28,6 @@ from astrbot.core.utils.media_utils import (
|
|||||||
convert_video_format,
|
convert_video_format,
|
||||||
get_media_duration,
|
get_media_duration,
|
||||||
)
|
)
|
||||||
from astrbot.core.utils.metrics import Metric
|
|
||||||
|
|
||||||
|
|
||||||
class LarkMessageEvent(AstrMessageEvent):
|
class LarkMessageEvent(AstrMessageEvent):
|
||||||
@@ -565,257 +555,15 @@ class LarkMessageEvent(AstrMessageEvent):
|
|||||||
logger.error(f"发送飞书表情回应失败({response.code}): {response.msg}")
|
logger.error(f"发送飞书表情回应失败({response.code}): {response.msg}")
|
||||||
return
|
return
|
||||||
|
|
||||||
async def _create_streaming_card(self) -> str | None:
|
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||||
"""创建一个开启流式更新模式的卡片实体,返回 card_id。"""
|
|
||||||
if self.bot.cardkit is None:
|
|
||||||
logger.error("[Lark] API Client cardkit 模块未初始化")
|
|
||||||
return None
|
|
||||||
|
|
||||||
card_json = {
|
|
||||||
"schema": "2.0",
|
|
||||||
"header": {
|
|
||||||
"title": {"content": "", "tag": "plain_text"},
|
|
||||||
},
|
|
||||||
"config": {
|
|
||||||
"streaming_mode": True,
|
|
||||||
"summary": {"content": ""},
|
|
||||||
"streaming_config": {
|
|
||||||
"print_frequency_ms": {"default": 50},
|
|
||||||
"print_step": {"default": 2},
|
|
||||||
"print_strategy": "fast",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"body": {
|
|
||||||
"elements": [
|
|
||||||
{
|
|
||||||
"tag": "markdown",
|
|
||||||
"content": "",
|
|
||||||
"element_id": "markdown_1",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
request = (
|
|
||||||
CreateCardRequest.builder()
|
|
||||||
.request_body(
|
|
||||||
CreateCardRequestBody.builder()
|
|
||||||
.type("card_json")
|
|
||||||
.data(json.dumps(card_json, ensure_ascii=False))
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await self.bot.cardkit.v1.card.acreate(request)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[Lark] 创建流式卡片实体失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
if not response.success():
|
|
||||||
logger.error(
|
|
||||||
f"[Lark] 创建流式卡片实体失败({response.code}): {response.msg}"
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
if response.data is None or not response.data.card_id:
|
|
||||||
logger.error("[Lark] 创建流式卡片实体成功但未返回 card_id")
|
|
||||||
return None
|
|
||||||
|
|
||||||
card_id = response.data.card_id
|
|
||||||
logger.debug(f"[Lark] 创建流式卡片实体成功: {card_id}")
|
|
||||||
return card_id
|
|
||||||
|
|
||||||
async def _send_card_message(
|
|
||||||
self,
|
|
||||||
card_id: str,
|
|
||||||
reply_message_id: str | None = None,
|
|
||||||
receive_id: str | None = None,
|
|
||||||
receive_id_type: str | None = None,
|
|
||||||
) -> bool:
|
|
||||||
"""将卡片实体作为 interactive 消息发送。"""
|
|
||||||
content = json.dumps(
|
|
||||||
{"type": "card", "data": {"card_id": card_id}},
|
|
||||||
ensure_ascii=False,
|
|
||||||
)
|
|
||||||
return await self._send_im_message(
|
|
||||||
self.bot,
|
|
||||||
content=content,
|
|
||||||
msg_type="interactive",
|
|
||||||
reply_message_id=reply_message_id,
|
|
||||||
receive_id=receive_id,
|
|
||||||
receive_id_type=receive_id_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _update_streaming_text(
|
|
||||||
self,
|
|
||||||
card_id: str,
|
|
||||||
content: str,
|
|
||||||
sequence: int,
|
|
||||||
) -> bool:
|
|
||||||
"""调用 CardKit 流式更新文本接口,向 markdown_1 组件推送全量文本。"""
|
|
||||||
if self.bot.cardkit is None:
|
|
||||||
logger.error("[Lark] API Client cardkit 模块未初始化")
|
|
||||||
return False
|
|
||||||
|
|
||||||
request = (
|
|
||||||
ContentCardElementRequest.builder()
|
|
||||||
.card_id(card_id)
|
|
||||||
.element_id("markdown_1")
|
|
||||||
.request_body(
|
|
||||||
ContentCardElementRequestBody.builder()
|
|
||||||
.content(content)
|
|
||||||
.sequence(sequence)
|
|
||||||
.uuid(str(uuid.uuid4()))
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await self.bot.cardkit.v1.card_element.acontent(request)
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"[Lark] 流式更新文本失败 (ignored): {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
if not response.success():
|
|
||||||
logger.debug(f"[Lark] 流式更新文本失败({response.code}): {response.msg}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def _close_streaming_mode(
|
|
||||||
self,
|
|
||||||
card_id: str,
|
|
||||||
sequence: int,
|
|
||||||
) -> None:
|
|
||||||
"""关闭卡片的流式更新模式,使其可正常转发、摘要恢复。"""
|
|
||||||
if self.bot.cardkit is None:
|
|
||||||
logger.error("[Lark] API Client cardkit 模块未初始化")
|
|
||||||
return
|
|
||||||
|
|
||||||
settings_json = json.dumps(
|
|
||||||
{"config": {"streaming_mode": False}},
|
|
||||||
ensure_ascii=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
request = (
|
|
||||||
SettingsCardRequest.builder()
|
|
||||||
.card_id(card_id)
|
|
||||||
.request_body(
|
|
||||||
SettingsCardRequestBody.builder()
|
|
||||||
.settings(settings_json)
|
|
||||||
.sequence(sequence)
|
|
||||||
.uuid(str(uuid.uuid4()))
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await self.bot.cardkit.v1.card.asettings(request)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[Lark] 关闭流式模式失败: {e}")
|
|
||||||
return
|
|
||||||
|
|
||||||
if not response.success():
|
|
||||||
logger.error(f"[Lark] 关闭流式模式失败({response.code}): {response.msg}")
|
|
||||||
else:
|
|
||||||
logger.debug(f"[Lark] 流式模式已关闭: {card_id}")
|
|
||||||
|
|
||||||
async def _fallback_send_streaming(self, generator, use_fallback: bool = False):
|
|
||||||
"""回退到非流式发送:缓冲全部文本后一次性发送,并保留父类副作用。"""
|
|
||||||
buffer = None
|
buffer = None
|
||||||
async for chain in generator:
|
async for chain in generator:
|
||||||
if not buffer:
|
if not buffer:
|
||||||
buffer = chain
|
buffer = chain
|
||||||
else:
|
else:
|
||||||
buffer.chain.extend(chain.chain)
|
buffer.chain.extend(chain.chain)
|
||||||
|
if not buffer:
|
||||||
if buffer:
|
return None
|
||||||
buffer.squash_plain()
|
buffer.squash_plain()
|
||||||
await self.send(buffer)
|
await self.send(buffer)
|
||||||
|
return await super().send_streaming(generator, use_fallback)
|
||||||
await Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
|
|
||||||
self._has_send_oper = True
|
|
||||||
|
|
||||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
|
||||||
"""使用 CardKit 流式卡片实现打字机效果。
|
|
||||||
|
|
||||||
流程:创建卡片实体 → 发送消息 → 流式更新文本 → 关闭流式模式。
|
|
||||||
使用解耦发送循环,LLM token 到达时只更新 buffer 并唤醒发送协程,
|
|
||||||
发送频率由网络 RTT 自然限流。
|
|
||||||
"""
|
|
||||||
# Step 1: 创建流式卡片实体
|
|
||||||
card_id = await self._create_streaming_card()
|
|
||||||
if not card_id:
|
|
||||||
logger.warning("[Lark] 无法创建流式卡片,回退到非流式发送")
|
|
||||||
await self._fallback_send_streaming(generator, use_fallback)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Step 2: 发送卡片消息
|
|
||||||
sent = await self._send_card_message(
|
|
||||||
card_id,
|
|
||||||
reply_message_id=self.message_obj.message_id,
|
|
||||||
)
|
|
||||||
if not sent:
|
|
||||||
logger.error("[Lark] 发送流式卡片消息失败,回退到非流式发送")
|
|
||||||
await self._fallback_send_streaming(generator, use_fallback)
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info("[Lark] 流式输出: 使用 CardKit 流式卡片")
|
|
||||||
|
|
||||||
# Step 3: 解耦发送循环 (Event-driven, 参考 Telegram Draft 路径)
|
|
||||||
sequence = 0
|
|
||||||
delta = ""
|
|
||||||
last_sent = ""
|
|
||||||
done = False
|
|
||||||
text_changed = asyncio.Event()
|
|
||||||
|
|
||||||
async def _sender_loop() -> None:
|
|
||||||
"""信号驱动的文本发送循环,有新内容就发,RTT 自然限流。"""
|
|
||||||
nonlocal sequence, last_sent
|
|
||||||
while not done:
|
|
||||||
await text_changed.wait()
|
|
||||||
text_changed.clear()
|
|
||||||
snapshot = delta
|
|
||||||
if snapshot and snapshot != last_sent:
|
|
||||||
sequence += 1
|
|
||||||
ok = await self._update_streaming_text(card_id, snapshot, sequence)
|
|
||||||
if ok:
|
|
||||||
last_sent = snapshot
|
|
||||||
if delta != snapshot:
|
|
||||||
text_changed.set()
|
|
||||||
|
|
||||||
sender_task = asyncio.create_task(_sender_loop())
|
|
||||||
|
|
||||||
try:
|
|
||||||
async for chain in generator:
|
|
||||||
if not isinstance(chain, MessageChain):
|
|
||||||
continue
|
|
||||||
|
|
||||||
if chain.type == "break":
|
|
||||||
# 飞书卡片不支持分段,忽略 break
|
|
||||||
continue
|
|
||||||
|
|
||||||
for comp in chain.chain:
|
|
||||||
if isinstance(comp, Plain):
|
|
||||||
delta += comp.text
|
|
||||||
text_changed.set()
|
|
||||||
finally:
|
|
||||||
done = True
|
|
||||||
text_changed.set()
|
|
||||||
await sender_task
|
|
||||||
|
|
||||||
# Step 4: 必要时补发最终文本 + 关闭流式模式
|
|
||||||
if delta and delta != last_sent:
|
|
||||||
sequence += 1
|
|
||||||
await self._update_streaming_text(card_id, delta, sequence)
|
|
||||||
|
|
||||||
sequence += 1
|
|
||||||
await self._close_streaming_mode(card_id, sequence)
|
|
||||||
|
|
||||||
# Step 5: 内联父类 send_streaming 的副作用
|
|
||||||
await Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
|
|
||||||
self._has_send_oper = True
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from botpy.types.message import MarkdownPayload, Media
|
|||||||
|
|
||||||
from astrbot.api import logger
|
from astrbot.api import logger
|
||||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
from astrbot.api.message_components import File, Image, Plain, Record, Video
|
from astrbot.api.message_components import Image, Plain, Record
|
||||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||||
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
|
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
|
||||||
from astrbot.core.utils.io import download_image_by_url, file_to_base64
|
from astrbot.core.utils.io import download_image_by_url, file_to_base64
|
||||||
@@ -47,11 +47,6 @@ _patch_qq_botpy_formdata()
|
|||||||
|
|
||||||
class QQOfficialMessageEvent(AstrMessageEvent):
|
class QQOfficialMessageEvent(AstrMessageEvent):
|
||||||
MARKDOWN_NOT_ALLOWED_ERROR = "不允许发送原生 markdown"
|
MARKDOWN_NOT_ALLOWED_ERROR = "不允许发送原生 markdown"
|
||||||
IMAGE_FILE_TYPE = 1
|
|
||||||
VIDEO_FILE_TYPE = 2
|
|
||||||
VOICE_FILE_TYPE = 3
|
|
||||||
FILE_FILE_TYPE = 4
|
|
||||||
STREAM_MARKDOWN_NEWLINE_ERROR = "流式消息md分片需要\\n结束"
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -70,71 +65,35 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
await self._post_send()
|
await self._post_send()
|
||||||
|
|
||||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||||
"""流式输出仅支持消息列表私聊(C2C),其他消息源退化为普通发送"""
|
"""流式输出仅支持消息列表私聊"""
|
||||||
# 先标记事件层“已执行发送操作”,避免异常路径遗漏
|
|
||||||
await super().send_streaming(generator, use_fallback)
|
|
||||||
# QQ C2C 流式协议:开始/中间分片使用 state=1,结束分片使用 state=10
|
|
||||||
stream_payload = {"state": 1, "id": None, "index": 0, "reset": False}
|
stream_payload = {"state": 1, "id": None, "index": 0, "reset": False}
|
||||||
last_edit_time = 0 # 上次发送分片的时间
|
last_edit_time = 0 # 上次编辑消息的时间
|
||||||
throttle_interval = 1 # 分片间最短间隔 (秒)
|
throttle_interval = 1 # 编辑消息的间隔时间 (秒)
|
||||||
ret = None
|
ret = None
|
||||||
source = (
|
|
||||||
self.message_obj.raw_message
|
|
||||||
) # 提前获取,避免 generator 为空时 NameError
|
|
||||||
try:
|
try:
|
||||||
async for chain in generator:
|
async for chain in generator:
|
||||||
source = self.message_obj.raw_message
|
source = self.message_obj.raw_message
|
||||||
|
|
||||||
if not isinstance(source, botpy.message.C2CMessage):
|
|
||||||
# 非 C2C 场景:直接累积,最后统一发
|
|
||||||
if not self.send_buffer:
|
|
||||||
self.send_buffer = chain
|
|
||||||
else:
|
|
||||||
self.send_buffer.chain.extend(chain.chain)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# ---- C2C 流式场景 ----
|
|
||||||
|
|
||||||
# tool_call break 信号:工具开始执行,先把已有 buffer 以 state=10 结束当前流式段
|
|
||||||
if chain.type == "break":
|
|
||||||
if self.send_buffer:
|
|
||||||
stream_payload["state"] = 10
|
|
||||||
ret = await self._post_send(stream=stream_payload)
|
|
||||||
ret_id = self._extract_response_message_id(ret)
|
|
||||||
if ret_id is not None:
|
|
||||||
stream_payload["id"] = ret_id
|
|
||||||
# 重置 stream_payload,为下一段流式做准备
|
|
||||||
stream_payload = {
|
|
||||||
"state": 1,
|
|
||||||
"id": None,
|
|
||||||
"index": 0,
|
|
||||||
"reset": False,
|
|
||||||
}
|
|
||||||
last_edit_time = 0
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 累积内容
|
|
||||||
if not self.send_buffer:
|
if not self.send_buffer:
|
||||||
self.send_buffer = chain
|
self.send_buffer = chain
|
||||||
else:
|
else:
|
||||||
self.send_buffer.chain.extend(chain.chain)
|
self.send_buffer.chain.extend(chain.chain)
|
||||||
|
|
||||||
# 节流:按时间间隔发送中间分片
|
if isinstance(source, botpy.message.C2CMessage):
|
||||||
current_time = asyncio.get_running_loop().time()
|
# 真流式传输
|
||||||
if current_time - last_edit_time >= throttle_interval:
|
current_time = asyncio.get_event_loop().time()
|
||||||
ret = cast(
|
time_since_last_edit = current_time - last_edit_time
|
||||||
message.Message,
|
|
||||||
await self._post_send(stream=stream_payload),
|
if time_since_last_edit >= throttle_interval:
|
||||||
)
|
ret = cast(
|
||||||
stream_payload["index"] += 1
|
message.Message,
|
||||||
ret_id = self._extract_response_message_id(ret)
|
await self._post_send(stream=stream_payload),
|
||||||
if ret_id is not None:
|
)
|
||||||
stream_payload["id"] = ret_id
|
stream_payload["index"] += 1
|
||||||
last_edit_time = asyncio.get_running_loop().time()
|
stream_payload["id"] = ret["id"]
|
||||||
self.send_buffer = None # 清空已发送的分片,避免下次重复发送旧内容
|
last_edit_time = asyncio.get_event_loop().time()
|
||||||
|
|
||||||
if isinstance(source, botpy.message.C2CMessage):
|
if isinstance(source, botpy.message.C2CMessage):
|
||||||
# 结束流式对话,发送 buffer 中剩余内容
|
# 结束流式对话,并且传输 buffer 中剩余的消息
|
||||||
stream_payload["state"] = 10
|
stream_payload["state"] = 10
|
||||||
ret = await self._post_send(stream=stream_payload)
|
ret = await self._post_send(stream=stream_payload)
|
||||||
else:
|
else:
|
||||||
@@ -142,22 +101,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"发送流式消息时出错: {e}", exc_info=True)
|
logger.error(f"发送流式消息时出错: {e}", exc_info=True)
|
||||||
# 避免累计内容在异常后被整包重复发送:仅清理缓存,不做非流式整包兜底
|
|
||||||
# 如需兜底,应该只发送未发送 delta(后续可继续优化)
|
|
||||||
self.send_buffer = None
|
self.send_buffer = None
|
||||||
|
|
||||||
return None
|
return await super().send_streaming(generator, use_fallback)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _extract_response_message_id(ret) -> str | None:
|
|
||||||
"""兼容 qq-botpy 返回 Message 对象或 dict 两种形态。"""
|
|
||||||
if ret is None:
|
|
||||||
return None
|
|
||||||
if isinstance(ret, dict):
|
|
||||||
ret_id = ret.get("id")
|
|
||||||
return str(ret_id) if ret_id is not None else None
|
|
||||||
ret_id = getattr(ret, "id", None)
|
|
||||||
return str(ret_id) if ret_id is not None else None
|
|
||||||
|
|
||||||
async def _post_send(self, stream: dict | None = None):
|
async def _post_send(self, stream: dict | None = None):
|
||||||
if not self.send_buffer:
|
if not self.send_buffer:
|
||||||
@@ -180,37 +126,16 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
image_base64,
|
image_base64,
|
||||||
image_path,
|
image_path,
|
||||||
record_file_path,
|
record_file_path,
|
||||||
video_file_source,
|
|
||||||
file_source,
|
|
||||||
file_name,
|
|
||||||
) = await QQOfficialMessageEvent._parse_to_qqofficial(self.send_buffer)
|
) = await QQOfficialMessageEvent._parse_to_qqofficial(self.send_buffer)
|
||||||
|
|
||||||
# C2C 流式仅用于文本分片,富媒体时降级为普通发送,避免平台侧流式校验报错。
|
|
||||||
if stream and (image_base64 or record_file_path):
|
|
||||||
logger.debug("[QQOfficial] 检测到富媒体,降级为非流式发送。")
|
|
||||||
stream = None
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
not plain_text
|
not plain_text
|
||||||
and not image_base64
|
and not image_base64
|
||||||
and not image_path
|
and not image_path
|
||||||
and not record_file_path
|
and not record_file_path
|
||||||
and not video_file_source
|
|
||||||
and not file_source
|
|
||||||
):
|
):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# QQ C2C 流式 API 说明:
|
|
||||||
# - 开始/中间分片(state=1):增量追加内容,不需要 \n(加了会导致强制换行)
|
|
||||||
# - 最终分片(state=10):结束流,content 必须以 \n 结尾(QQ API 要求)
|
|
||||||
if (
|
|
||||||
stream
|
|
||||||
and stream.get("state") == 10
|
|
||||||
and plain_text
|
|
||||||
and not plain_text.endswith("\n")
|
|
||||||
):
|
|
||||||
plain_text = plain_text + "\n"
|
|
||||||
|
|
||||||
payload: dict = {
|
payload: dict = {
|
||||||
# "content": plain_text,
|
# "content": plain_text,
|
||||||
"markdown": MarkdownPayload(content=plain_text) if plain_text else None,
|
"markdown": MarkdownPayload(content=plain_text) if plain_text else None,
|
||||||
@@ -232,7 +157,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
if image_base64:
|
if image_base64:
|
||||||
media = await self.upload_group_and_c2c_image(
|
media = await self.upload_group_and_c2c_image(
|
||||||
image_base64,
|
image_base64,
|
||||||
self.IMAGE_FILE_TYPE,
|
1,
|
||||||
group_openid=source.group_openid,
|
group_openid=source.group_openid,
|
||||||
)
|
)
|
||||||
payload["media"] = media
|
payload["media"] = media
|
||||||
@@ -240,39 +165,15 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
payload.pop("markdown", None)
|
payload.pop("markdown", None)
|
||||||
payload["content"] = plain_text or None
|
payload["content"] = plain_text or None
|
||||||
if record_file_path: # group record msg
|
if record_file_path: # group record msg
|
||||||
media = await self.upload_group_and_c2c_media(
|
media = await self.upload_group_and_c2c_record(
|
||||||
record_file_path,
|
record_file_path,
|
||||||
self.VOICE_FILE_TYPE,
|
3,
|
||||||
group_openid=source.group_openid,
|
group_openid=source.group_openid,
|
||||||
)
|
)
|
||||||
if media:
|
payload["media"] = media
|
||||||
payload["media"] = media
|
payload["msg_type"] = 7
|
||||||
payload["msg_type"] = 7
|
payload.pop("markdown", None)
|
||||||
payload.pop("markdown", None)
|
payload["content"] = plain_text or None
|
||||||
payload["content"] = plain_text or None
|
|
||||||
if video_file_source:
|
|
||||||
media = await self.upload_group_and_c2c_media(
|
|
||||||
video_file_source,
|
|
||||||
self.VIDEO_FILE_TYPE,
|
|
||||||
group_openid=source.group_openid,
|
|
||||||
)
|
|
||||||
if media:
|
|
||||||
payload["media"] = media
|
|
||||||
payload["msg_type"] = 7
|
|
||||||
payload.pop("markdown", None)
|
|
||||||
payload["content"] = plain_text or None
|
|
||||||
if file_source:
|
|
||||||
media = await self.upload_group_and_c2c_media(
|
|
||||||
file_source,
|
|
||||||
self.FILE_FILE_TYPE,
|
|
||||||
file_name=file_name,
|
|
||||||
group_openid=source.group_openid,
|
|
||||||
)
|
|
||||||
if media:
|
|
||||||
payload["media"] = media
|
|
||||||
payload["msg_type"] = 7
|
|
||||||
payload.pop("markdown", None)
|
|
||||||
payload["content"] = plain_text or None
|
|
||||||
ret = await self._send_with_markdown_fallback(
|
ret = await self._send_with_markdown_fallback(
|
||||||
send_func=lambda retry_payload: self.bot.api.post_group_message(
|
send_func=lambda retry_payload: self.bot.api.post_group_message(
|
||||||
group_openid=source.group_openid, # type: ignore
|
group_openid=source.group_openid, # type: ignore
|
||||||
@@ -280,14 +181,13 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
),
|
),
|
||||||
payload=payload,
|
payload=payload,
|
||||||
plain_text=plain_text,
|
plain_text=plain_text,
|
||||||
stream=stream,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
case botpy.message.C2CMessage():
|
case botpy.message.C2CMessage():
|
||||||
if image_base64:
|
if image_base64:
|
||||||
media = await self.upload_group_and_c2c_image(
|
media = await self.upload_group_and_c2c_image(
|
||||||
image_base64,
|
image_base64,
|
||||||
self.IMAGE_FILE_TYPE,
|
1,
|
||||||
openid=source.author.user_openid,
|
openid=source.author.user_openid,
|
||||||
)
|
)
|
||||||
payload["media"] = media
|
payload["media"] = media
|
||||||
@@ -295,39 +195,15 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
payload.pop("markdown", None)
|
payload.pop("markdown", None)
|
||||||
payload["content"] = plain_text or None
|
payload["content"] = plain_text or None
|
||||||
if record_file_path: # c2c record
|
if record_file_path: # c2c record
|
||||||
media = await self.upload_group_and_c2c_media(
|
media = await self.upload_group_and_c2c_record(
|
||||||
record_file_path,
|
record_file_path,
|
||||||
self.VOICE_FILE_TYPE,
|
3,
|
||||||
openid=source.author.user_openid,
|
openid=source.author.user_openid,
|
||||||
)
|
)
|
||||||
if media:
|
payload["media"] = media
|
||||||
payload["media"] = media
|
payload["msg_type"] = 7
|
||||||
payload["msg_type"] = 7
|
payload.pop("markdown", None)
|
||||||
payload.pop("markdown", None)
|
payload["content"] = plain_text or None
|
||||||
payload["content"] = plain_text or None
|
|
||||||
if video_file_source:
|
|
||||||
media = await self.upload_group_and_c2c_media(
|
|
||||||
video_file_source,
|
|
||||||
self.VIDEO_FILE_TYPE,
|
|
||||||
openid=source.author.user_openid,
|
|
||||||
)
|
|
||||||
if media:
|
|
||||||
payload["media"] = media
|
|
||||||
payload["msg_type"] = 7
|
|
||||||
payload.pop("markdown", None)
|
|
||||||
payload["content"] = plain_text or None
|
|
||||||
if file_source:
|
|
||||||
media = await self.upload_group_and_c2c_media(
|
|
||||||
file_source,
|
|
||||||
self.FILE_FILE_TYPE,
|
|
||||||
file_name=file_name,
|
|
||||||
openid=source.author.user_openid,
|
|
||||||
)
|
|
||||||
if media:
|
|
||||||
payload["media"] = media
|
|
||||||
payload["msg_type"] = 7
|
|
||||||
payload.pop("markdown", None)
|
|
||||||
payload["content"] = plain_text or None
|
|
||||||
if stream:
|
if stream:
|
||||||
ret = await self._send_with_markdown_fallback(
|
ret = await self._send_with_markdown_fallback(
|
||||||
send_func=lambda retry_payload: self.post_c2c_message(
|
send_func=lambda retry_payload: self.post_c2c_message(
|
||||||
@@ -337,7 +213,6 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
),
|
),
|
||||||
payload=payload,
|
payload=payload,
|
||||||
plain_text=plain_text,
|
plain_text=plain_text,
|
||||||
stream=stream,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
ret = await self._send_with_markdown_fallback(
|
ret = await self._send_with_markdown_fallback(
|
||||||
@@ -347,7 +222,6 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
),
|
),
|
||||||
payload=payload,
|
payload=payload,
|
||||||
plain_text=plain_text,
|
plain_text=plain_text,
|
||||||
stream=stream,
|
|
||||||
)
|
)
|
||||||
logger.debug(f"Message sent to C2C: {ret}")
|
logger.debug(f"Message sent to C2C: {ret}")
|
||||||
|
|
||||||
@@ -363,7 +237,6 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
),
|
),
|
||||||
payload=payload,
|
payload=payload,
|
||||||
plain_text=plain_text,
|
plain_text=plain_text,
|
||||||
stream=stream,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
case botpy.message.DirectMessage():
|
case botpy.message.DirectMessage():
|
||||||
@@ -378,7 +251,6 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
),
|
),
|
||||||
payload=payload,
|
payload=payload,
|
||||||
plain_text=plain_text,
|
plain_text=plain_text,
|
||||||
stream=stream,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
@@ -395,31 +267,10 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
send_func,
|
send_func,
|
||||||
payload: dict,
|
payload: dict,
|
||||||
plain_text: str,
|
plain_text: str,
|
||||||
stream: dict | None = None,
|
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
return await send_func(payload)
|
return await send_func(payload)
|
||||||
except botpy.errors.ServerError as err:
|
except botpy.errors.ServerError as err:
|
||||||
# QQ 流式 markdown 分片校验:内容必须以换行结尾。
|
|
||||||
# 某些边界场景服务端仍可能判定失败,这里做一次修正重试。
|
|
||||||
if stream and self.STREAM_MARKDOWN_NEWLINE_ERROR in str(err):
|
|
||||||
retry_payload = payload.copy()
|
|
||||||
|
|
||||||
markdown_payload = retry_payload.get("markdown")
|
|
||||||
if isinstance(markdown_payload, dict):
|
|
||||||
md_content = cast(str, markdown_payload.get("content", "") or "")
|
|
||||||
if md_content and not md_content.endswith("\n"):
|
|
||||||
retry_payload["markdown"] = {"content": md_content + "\n"}
|
|
||||||
|
|
||||||
content = cast(str | None, retry_payload.get("content"))
|
|
||||||
if content and not content.endswith("\n"):
|
|
||||||
retry_payload["content"] = content + "\n"
|
|
||||||
|
|
||||||
logger.warning(
|
|
||||||
"[QQOfficial] 流式 markdown 分片换行校验失败,已修正后重试一次。"
|
|
||||||
)
|
|
||||||
return await send_func(retry_payload)
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.MARKDOWN_NOT_ALLOWED_ERROR not in str(err)
|
self.MARKDOWN_NOT_ALLOWED_ERROR not in str(err)
|
||||||
or not payload.get("markdown")
|
or not payload.get("markdown")
|
||||||
@@ -431,14 +282,10 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
"[QQOfficial] markdown 发送被拒绝,回退到 content 模式重试。"
|
"[QQOfficial] markdown 发送被拒绝,回退到 content 模式重试。"
|
||||||
)
|
)
|
||||||
fallback_payload = payload.copy()
|
fallback_payload = payload.copy()
|
||||||
fallback_payload.pop("markdown", None)
|
fallback_payload["markdown"] = None
|
||||||
fallback_payload["content"] = plain_text
|
fallback_payload["content"] = plain_text
|
||||||
if fallback_payload.get("msg_type") == 2:
|
if fallback_payload.get("msg_type") == 2:
|
||||||
fallback_payload["msg_type"] = 0
|
fallback_payload["msg_type"] = 0
|
||||||
if stream:
|
|
||||||
fallback_content = cast(str, fallback_payload.get("content") or "")
|
|
||||||
if fallback_content and not fallback_content.endswith("\n"):
|
|
||||||
fallback_payload["content"] = fallback_content + "\n"
|
|
||||||
return await send_func(fallback_payload)
|
return await send_func(fallback_payload)
|
||||||
|
|
||||||
async def upload_group_and_c2c_image(
|
async def upload_group_and_c2c_image(
|
||||||
@@ -480,19 +327,16 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
ttl=result.get("ttl", 0),
|
ttl=result.get("ttl", 0),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def upload_group_and_c2c_media(
|
async def upload_group_and_c2c_record(
|
||||||
self,
|
self,
|
||||||
file_source: str,
|
file_source: str,
|
||||||
file_type: int,
|
file_type: int,
|
||||||
srv_send_msg: bool = False,
|
srv_send_msg: bool = False,
|
||||||
file_name: str | None = None,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Media | None:
|
) -> Media | None:
|
||||||
"""上传媒体文件"""
|
"""上传媒体文件"""
|
||||||
# 构建基础payload
|
# 构建基础payload
|
||||||
payload = {"file_type": file_type, "srv_send_msg": srv_send_msg}
|
payload = {"file_type": file_type, "srv_send_msg": srv_send_msg}
|
||||||
if file_name:
|
|
||||||
payload["file_name"] = file_name
|
|
||||||
|
|
||||||
# 处理文件数据
|
# 处理文件数据
|
||||||
if os.path.exists(file_source):
|
if os.path.exists(file_source):
|
||||||
@@ -556,21 +400,13 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
) -> message.Message:
|
) -> message.Message:
|
||||||
payload = locals()
|
payload = locals()
|
||||||
payload.pop("self", None)
|
payload.pop("self", None)
|
||||||
# QQ API does not accept stream.id=None; remove it when not yet assigned
|
|
||||||
if "stream" in payload and payload["stream"] is not None:
|
|
||||||
stream_data = dict(payload["stream"])
|
|
||||||
if stream_data.get("id") is None:
|
|
||||||
stream_data.pop("id", None)
|
|
||||||
payload["stream"] = stream_data
|
|
||||||
route = Route("POST", "/v2/users/{openid}/messages", openid=openid)
|
route = Route("POST", "/v2/users/{openid}/messages", openid=openid)
|
||||||
result = await self.bot.api._http.request(route, json=payload)
|
result = await self.bot.api._http.request(route, json=payload)
|
||||||
|
|
||||||
if result is None:
|
|
||||||
logger.warning("[QQOfficial] post_c2c_message: API 返回 None,跳过本次发送")
|
|
||||||
return None
|
|
||||||
if not isinstance(result, dict):
|
if not isinstance(result, dict):
|
||||||
logger.error(f"[QQOfficial] post_c2c_message: 响应不是 dict: {result}")
|
raise RuntimeError(
|
||||||
return None
|
f"Failed to post c2c message, response is not dict: {result}"
|
||||||
|
)
|
||||||
|
|
||||||
return message.Message(**result)
|
return message.Message(**result)
|
||||||
|
|
||||||
@@ -580,9 +416,6 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
image_base64 = None # only one img supported
|
image_base64 = None # only one img supported
|
||||||
image_file_path = None
|
image_file_path = None
|
||||||
record_file_path = None
|
record_file_path = None
|
||||||
video_file_source = None
|
|
||||||
file_source = None
|
|
||||||
file_name = None
|
|
||||||
for i in message.chain:
|
for i in message.chain:
|
||||||
if isinstance(i, Plain):
|
if isinstance(i, Plain):
|
||||||
plain_text += i.text
|
plain_text += i.text
|
||||||
@@ -621,30 +454,6 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"处理语音时出错: {e}")
|
logger.error(f"处理语音时出错: {e}")
|
||||||
record_file_path = None
|
record_file_path = None
|
||||||
elif isinstance(i, Video) and not video_file_source:
|
|
||||||
if i.file.startswith("file:///"):
|
|
||||||
video_file_source = i.file[8:]
|
|
||||||
else:
|
|
||||||
video_file_source = i.file
|
|
||||||
elif isinstance(i, File) and not file_source:
|
|
||||||
file_name = i.name
|
|
||||||
if i.file_:
|
|
||||||
file_path = i.file_
|
|
||||||
if file_path.startswith("file:///"):
|
|
||||||
file_path = file_path[8:]
|
|
||||||
elif file_path.startswith("file://"):
|
|
||||||
file_path = file_path[7:]
|
|
||||||
file_source = file_path
|
|
||||||
elif i.url:
|
|
||||||
file_source = i.url
|
|
||||||
else:
|
else:
|
||||||
logger.debug(f"qq_official 忽略 {i.type}")
|
logger.debug(f"qq_official 忽略 {i.type}")
|
||||||
return (
|
return plain_text, image_base64, image_file_path, record_file_path
|
||||||
plain_text,
|
|
||||||
image_base64,
|
|
||||||
image_file_path,
|
|
||||||
record_file_path,
|
|
||||||
video_file_source,
|
|
||||||
file_source,
|
|
||||||
file_name,
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -3,10 +3,8 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
|
||||||
import time
|
import time
|
||||||
from types import SimpleNamespace
|
from typing import cast
|
||||||
from typing import Any, cast
|
|
||||||
|
|
||||||
import botpy
|
import botpy
|
||||||
import botpy.message
|
import botpy.message
|
||||||
@@ -14,7 +12,7 @@ from botpy import Client
|
|||||||
|
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
from astrbot.api.event import MessageChain
|
from astrbot.api.event import MessageChain
|
||||||
from astrbot.api.message_components import At, File, Image, Plain, Record, Video
|
from astrbot.api.message_components import At, File, Image, Plain
|
||||||
from astrbot.api.platform import (
|
from astrbot.api.platform import (
|
||||||
AstrBotMessage,
|
AstrBotMessage,
|
||||||
MessageMember,
|
MessageMember,
|
||||||
@@ -48,7 +46,6 @@ class botClient(Client):
|
|||||||
)
|
)
|
||||||
abm.group_id = cast(str, message.group_openid)
|
abm.group_id = cast(str, message.group_openid)
|
||||||
abm.session_id = abm.group_id
|
abm.session_id = abm.group_id
|
||||||
self.platform.remember_session_scene(abm.session_id, "group")
|
|
||||||
self._commit(abm)
|
self._commit(abm)
|
||||||
|
|
||||||
# 收到频道消息
|
# 收到频道消息
|
||||||
@@ -59,7 +56,6 @@ class botClient(Client):
|
|||||||
)
|
)
|
||||||
abm.group_id = message.channel_id
|
abm.group_id = message.channel_id
|
||||||
abm.session_id = abm.group_id
|
abm.session_id = abm.group_id
|
||||||
self.platform.remember_session_scene(abm.session_id, "channel")
|
|
||||||
self._commit(abm)
|
self._commit(abm)
|
||||||
|
|
||||||
# 收到私聊消息
|
# 收到私聊消息
|
||||||
@@ -71,7 +67,6 @@ class botClient(Client):
|
|||||||
MessageType.FRIEND_MESSAGE,
|
MessageType.FRIEND_MESSAGE,
|
||||||
)
|
)
|
||||||
abm.session_id = abm.sender.user_id
|
abm.session_id = abm.sender.user_id
|
||||||
self.platform.remember_session_scene(abm.session_id, "friend")
|
|
||||||
self._commit(abm)
|
self._commit(abm)
|
||||||
|
|
||||||
# 收到 C2C 消息
|
# 收到 C2C 消息
|
||||||
@@ -81,11 +76,9 @@ class botClient(Client):
|
|||||||
MessageType.FRIEND_MESSAGE,
|
MessageType.FRIEND_MESSAGE,
|
||||||
)
|
)
|
||||||
abm.session_id = abm.sender.user_id
|
abm.session_id = abm.sender.user_id
|
||||||
self.platform.remember_session_scene(abm.session_id, "friend")
|
|
||||||
self._commit(abm)
|
self._commit(abm)
|
||||||
|
|
||||||
def _commit(self, abm: AstrBotMessage) -> None:
|
def _commit(self, abm: AstrBotMessage) -> None:
|
||||||
self.platform.remember_session_message_id(abm.session_id, abm.message_id)
|
|
||||||
self.platform.commit_event(
|
self.platform.commit_event(
|
||||||
QQOfficialMessageEvent(
|
QQOfficialMessageEvent(
|
||||||
abm.message_str,
|
abm.message_str,
|
||||||
@@ -131,9 +124,6 @@ class QQOfficialPlatformAdapter(Platform):
|
|||||||
|
|
||||||
self.client.set_platform(self)
|
self.client.set_platform(self)
|
||||||
|
|
||||||
self._session_last_message_id: dict[str, str] = {}
|
|
||||||
self._session_scene: dict[str, str] = {}
|
|
||||||
|
|
||||||
self.test_mode = os.environ.get("TEST_MODE", "off") == "on"
|
self.test_mode = os.environ.get("TEST_MODE", "off") == "on"
|
||||||
|
|
||||||
async def send_by_session(
|
async def send_by_session(
|
||||||
@@ -141,191 +131,14 @@ class QQOfficialPlatformAdapter(Platform):
|
|||||||
session: MessageSesion,
|
session: MessageSesion,
|
||||||
message_chain: MessageChain,
|
message_chain: MessageChain,
|
||||||
) -> None:
|
) -> None:
|
||||||
await self._send_by_session_common(session, message_chain)
|
raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session")
|
||||||
|
|
||||||
async def _send_by_session_common(
|
|
||||||
self,
|
|
||||||
session: MessageSesion,
|
|
||||||
message_chain: MessageChain,
|
|
||||||
) -> None:
|
|
||||||
(
|
|
||||||
plain_text,
|
|
||||||
image_base64,
|
|
||||||
image_path,
|
|
||||||
record_file_path,
|
|
||||||
video_file_source,
|
|
||||||
file_source,
|
|
||||||
file_name,
|
|
||||||
) = await QQOfficialMessageEvent._parse_to_qqofficial(message_chain)
|
|
||||||
if (
|
|
||||||
not plain_text
|
|
||||||
and not image_path
|
|
||||||
and not image_base64
|
|
||||||
and not record_file_path
|
|
||||||
and not video_file_source
|
|
||||||
and not file_source
|
|
||||||
):
|
|
||||||
return
|
|
||||||
|
|
||||||
msg_id = self._session_last_message_id.get(session.session_id)
|
|
||||||
if not msg_id:
|
|
||||||
logger.warning(
|
|
||||||
"[QQOfficial] No cached msg_id for session: %s, skip send_by_session",
|
|
||||||
session.session_id,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
payload: dict[str, Any] = {"content": plain_text, "msg_id": msg_id}
|
|
||||||
ret: Any = None
|
|
||||||
send_helper = SimpleNamespace(bot=self.client)
|
|
||||||
|
|
||||||
if session.message_type == MessageType.GROUP_MESSAGE:
|
|
||||||
scene = self._session_scene.get(session.session_id)
|
|
||||||
if scene == "group":
|
|
||||||
payload["msg_seq"] = random.randint(1, 10000)
|
|
||||||
if image_base64:
|
|
||||||
media = await QQOfficialMessageEvent.upload_group_and_c2c_image(
|
|
||||||
send_helper, # type: ignore
|
|
||||||
image_base64,
|
|
||||||
QQOfficialMessageEvent.IMAGE_FILE_TYPE,
|
|
||||||
group_openid=session.session_id,
|
|
||||||
)
|
|
||||||
payload["media"] = media
|
|
||||||
payload["msg_type"] = 7
|
|
||||||
if record_file_path:
|
|
||||||
media = await QQOfficialMessageEvent.upload_group_and_c2c_media(
|
|
||||||
send_helper, # type: ignore
|
|
||||||
record_file_path,
|
|
||||||
QQOfficialMessageEvent.VOICE_FILE_TYPE,
|
|
||||||
group_openid=session.session_id,
|
|
||||||
)
|
|
||||||
if media:
|
|
||||||
payload["media"] = media
|
|
||||||
payload["msg_type"] = 7
|
|
||||||
if video_file_source:
|
|
||||||
media = await QQOfficialMessageEvent.upload_group_and_c2c_media(
|
|
||||||
send_helper, # type: ignore
|
|
||||||
video_file_source,
|
|
||||||
QQOfficialMessageEvent.VIDEO_FILE_TYPE,
|
|
||||||
group_openid=session.session_id,
|
|
||||||
)
|
|
||||||
if media:
|
|
||||||
payload["media"] = media
|
|
||||||
payload["msg_type"] = 7
|
|
||||||
payload.pop("msg_id", None)
|
|
||||||
if file_source:
|
|
||||||
media = await QQOfficialMessageEvent.upload_group_and_c2c_media(
|
|
||||||
send_helper, # type: ignore
|
|
||||||
file_source,
|
|
||||||
QQOfficialMessageEvent.FILE_FILE_TYPE,
|
|
||||||
file_name=file_name,
|
|
||||||
group_openid=session.session_id,
|
|
||||||
)
|
|
||||||
if media:
|
|
||||||
payload["media"] = media
|
|
||||||
payload["msg_type"] = 7
|
|
||||||
payload.pop("msg_id", None)
|
|
||||||
ret = await self.client.api.post_group_message(
|
|
||||||
group_openid=session.session_id,
|
|
||||||
**payload,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if image_path:
|
|
||||||
payload["file_image"] = image_path
|
|
||||||
ret = await self.client.api.post_message(
|
|
||||||
channel_id=session.session_id,
|
|
||||||
**payload,
|
|
||||||
)
|
|
||||||
|
|
||||||
elif session.message_type == MessageType.FRIEND_MESSAGE:
|
|
||||||
payload["msg_seq"] = random.randint(1, 10000)
|
|
||||||
if image_base64:
|
|
||||||
media = await QQOfficialMessageEvent.upload_group_and_c2c_image(
|
|
||||||
send_helper, # type: ignore
|
|
||||||
image_base64,
|
|
||||||
QQOfficialMessageEvent.IMAGE_FILE_TYPE,
|
|
||||||
openid=session.session_id,
|
|
||||||
)
|
|
||||||
payload["media"] = media
|
|
||||||
payload["msg_type"] = 7
|
|
||||||
if record_file_path:
|
|
||||||
media = await QQOfficialMessageEvent.upload_group_and_c2c_media(
|
|
||||||
send_helper, # type: ignore
|
|
||||||
record_file_path,
|
|
||||||
QQOfficialMessageEvent.VOICE_FILE_TYPE,
|
|
||||||
openid=session.session_id,
|
|
||||||
)
|
|
||||||
if media:
|
|
||||||
payload["media"] = media
|
|
||||||
payload["msg_type"] = 7
|
|
||||||
if video_file_source:
|
|
||||||
media = await QQOfficialMessageEvent.upload_group_and_c2c_media(
|
|
||||||
send_helper, # type: ignore
|
|
||||||
video_file_source,
|
|
||||||
QQOfficialMessageEvent.VIDEO_FILE_TYPE,
|
|
||||||
openid=session.session_id,
|
|
||||||
)
|
|
||||||
if media:
|
|
||||||
payload["media"] = media
|
|
||||||
payload["msg_type"] = 7
|
|
||||||
# QQ API rejects msg_id for media (video/file) messages sent
|
|
||||||
# via the proactive tool-call path; remove it to avoid 越权 error.
|
|
||||||
payload.pop("msg_id", None)
|
|
||||||
if file_source:
|
|
||||||
media = await QQOfficialMessageEvent.upload_group_and_c2c_media(
|
|
||||||
send_helper, # type: ignore
|
|
||||||
file_source,
|
|
||||||
QQOfficialMessageEvent.FILE_FILE_TYPE,
|
|
||||||
file_name=file_name,
|
|
||||||
openid=session.session_id,
|
|
||||||
)
|
|
||||||
if media:
|
|
||||||
payload["media"] = media
|
|
||||||
payload["msg_type"] = 7
|
|
||||||
payload.pop("msg_id", None)
|
|
||||||
|
|
||||||
ret = await QQOfficialMessageEvent.post_c2c_message(
|
|
||||||
send_helper, # type: ignore
|
|
||||||
openid=session.session_id,
|
|
||||||
**payload,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"[QQOfficial] Unsupported message type for send_by_session: %s",
|
|
||||||
session.message_type,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
sent_message_id = self._extract_message_id(ret)
|
|
||||||
if sent_message_id:
|
|
||||||
self.remember_session_message_id(session.session_id, sent_message_id)
|
|
||||||
await super().send_by_session(session, message_chain)
|
|
||||||
|
|
||||||
def remember_session_message_id(self, session_id: str, message_id: str) -> None:
|
|
||||||
if not session_id or not message_id:
|
|
||||||
return
|
|
||||||
self._session_last_message_id[session_id] = message_id
|
|
||||||
|
|
||||||
def remember_session_scene(self, session_id: str, scene: str) -> None:
|
|
||||||
if not session_id or not scene:
|
|
||||||
return
|
|
||||||
self._session_scene[session_id] = scene
|
|
||||||
|
|
||||||
def _extract_message_id(self, ret: Any) -> str | None:
|
|
||||||
if isinstance(ret, dict):
|
|
||||||
message_id = ret.get("id")
|
|
||||||
return str(message_id) if message_id else None
|
|
||||||
message_id = getattr(ret, "id", None)
|
|
||||||
if message_id:
|
|
||||||
return str(message_id)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def meta(self) -> PlatformMetadata:
|
def meta(self) -> PlatformMetadata:
|
||||||
return PlatformMetadata(
|
return PlatformMetadata(
|
||||||
name="qq_official",
|
name="qq_official",
|
||||||
description="QQ 机器人官方 API 适配器",
|
description="QQ 机器人官方 API 适配器",
|
||||||
id=cast(str, self.config.get("id")),
|
id=cast(str, self.config.get("id")),
|
||||||
support_proactive_message=True,
|
support_proactive_message=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -345,10 +158,7 @@ class QQOfficialPlatformAdapter(Platform):
|
|||||||
return
|
return
|
||||||
|
|
||||||
for attachment in attachments:
|
for attachment in attachments:
|
||||||
content_type = cast(
|
content_type = cast(str, getattr(attachment, "content_type", "") or "")
|
||||||
str,
|
|
||||||
getattr(attachment, "content_type", "") or "",
|
|
||||||
).lower()
|
|
||||||
url = QQOfficialPlatformAdapter._normalize_attachment_url(
|
url = QQOfficialPlatformAdapter._normalize_attachment_url(
|
||||||
cast(str | None, getattr(attachment, "url", None))
|
cast(str | None, getattr(attachment, "url", None))
|
||||||
)
|
)
|
||||||
@@ -364,73 +174,7 @@ class QQOfficialPlatformAdapter(Platform):
|
|||||||
or getattr(attachment, "name", None)
|
or getattr(attachment, "name", None)
|
||||||
or "attachment",
|
or "attachment",
|
||||||
)
|
)
|
||||||
ext = os.path.splitext(filename)[1].lower()
|
msg.append(File(name=filename, file=url, url=url))
|
||||||
image_exts = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"}
|
|
||||||
audio_exts = {
|
|
||||||
".mp3",
|
|
||||||
".wav",
|
|
||||||
".ogg",
|
|
||||||
".m4a",
|
|
||||||
".amr",
|
|
||||||
".silk",
|
|
||||||
}
|
|
||||||
video_exts = {
|
|
||||||
".mp4",
|
|
||||||
".mov",
|
|
||||||
".avi",
|
|
||||||
".mkv",
|
|
||||||
".webm",
|
|
||||||
}
|
|
||||||
|
|
||||||
if content_type.startswith("audio") or ext in audio_exts:
|
|
||||||
msg.append(Record.fromURL(url))
|
|
||||||
elif content_type.startswith("video") or ext in video_exts:
|
|
||||||
msg.append(Video.fromURL(url))
|
|
||||||
elif content_type.startswith("image") or ext in image_exts:
|
|
||||||
msg.append(Image.fromURL(url))
|
|
||||||
else:
|
|
||||||
msg.append(File(name=filename, file=url, url=url))
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _parse_face_message(content: str) -> str:
|
|
||||||
"""Parse QQ official face message format and convert to readable text.
|
|
||||||
|
|
||||||
QQ official face message format:
|
|
||||||
<faceType=4,faceId="",ext="eyJ0ZXh0IjoiW+a7oeWktOmXruWPt10ifQ==">
|
|
||||||
|
|
||||||
The ext field contains base64-encoded JSON with a 'text' field
|
|
||||||
describing the emoji (e.g., '[满头问号]').
|
|
||||||
|
|
||||||
Args:
|
|
||||||
content: The message content that may contain face tags.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Content with face tags replaced by readable emoji descriptions.
|
|
||||||
"""
|
|
||||||
import base64
|
|
||||||
import json
|
|
||||||
import re
|
|
||||||
|
|
||||||
def replace_face(match):
|
|
||||||
face_tag = match.group(0)
|
|
||||||
# Extract ext field from the face tag
|
|
||||||
ext_match = re.search(r'ext="([^"]*)"', face_tag)
|
|
||||||
if ext_match:
|
|
||||||
try:
|
|
||||||
ext_encoded = ext_match.group(1)
|
|
||||||
# Decode base64 and parse JSON
|
|
||||||
ext_decoded = base64.b64decode(ext_encoded).decode("utf-8")
|
|
||||||
ext_data = json.loads(ext_decoded)
|
|
||||||
emoji_text = ext_data.get("text", "")
|
|
||||||
if emoji_text:
|
|
||||||
return f"[表情:{emoji_text}]"
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
# Fallback if parsing fails
|
|
||||||
return "[表情]"
|
|
||||||
|
|
||||||
# Match face tags: <faceType=...>
|
|
||||||
return re.sub(r"<faceType=\d+[^>]*>", replace_face, content)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _parse_from_qqofficial(
|
def _parse_from_qqofficial(
|
||||||
@@ -457,10 +201,7 @@ class QQOfficialPlatformAdapter(Platform):
|
|||||||
abm.group_id = message.group_openid
|
abm.group_id = message.group_openid
|
||||||
else:
|
else:
|
||||||
abm.sender = MessageMember(message.author.user_openid, "")
|
abm.sender = MessageMember(message.author.user_openid, "")
|
||||||
# Parse face messages to readable text
|
abm.message_str = message.content.strip()
|
||||||
abm.message_str = QQOfficialPlatformAdapter._parse_face_message(
|
|
||||||
message.content.strip()
|
|
||||||
)
|
|
||||||
abm.self_id = "unknown_selfid"
|
abm.self_id = "unknown_selfid"
|
||||||
msg.append(At(qq="qq_official"))
|
msg.append(At(qq="qq_official"))
|
||||||
msg.append(Plain(abm.message_str))
|
msg.append(Plain(abm.message_str))
|
||||||
@@ -476,12 +217,10 @@ class QQOfficialPlatformAdapter(Platform):
|
|||||||
else:
|
else:
|
||||||
abm.self_id = ""
|
abm.self_id = ""
|
||||||
|
|
||||||
plain_content = QQOfficialPlatformAdapter._parse_face_message(
|
plain_content = message.content.replace(
|
||||||
message.content.replace(
|
"<@!" + str(abm.self_id) + ">",
|
||||||
"<@!" + str(abm.self_id) + ">",
|
"",
|
||||||
"",
|
).strip()
|
||||||
).strip()
|
|
||||||
)
|
|
||||||
|
|
||||||
QQOfficialPlatformAdapter._append_attachments(msg, message.attachments)
|
QQOfficialPlatformAdapter._append_attachments(msg, message.attachments)
|
||||||
abm.message = msg
|
abm.message = msg
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import random
|
||||||
|
from types import SimpleNamespace
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
import botpy
|
import botpy
|
||||||
@@ -13,6 +15,7 @@ from astrbot.core.platform.astr_message_event import MessageSesion
|
|||||||
from astrbot.core.utils.webhook_utils import log_webhook_info
|
from astrbot.core.utils.webhook_utils import log_webhook_info
|
||||||
|
|
||||||
from ...register import register_platform_adapter
|
from ...register import register_platform_adapter
|
||||||
|
from ..qqofficial.qqofficial_message_event import QQOfficialMessageEvent
|
||||||
from ..qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter
|
from ..qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter
|
||||||
from .qo_webhook_event import QQOfficialWebhookMessageEvent
|
from .qo_webhook_event import QQOfficialWebhookMessageEvent
|
||||||
from .qo_webhook_server import QQOfficialWebhook
|
from .qo_webhook_server import QQOfficialWebhook
|
||||||
@@ -120,11 +123,95 @@ class QQOfficialWebhookPlatformAdapter(Platform):
|
|||||||
session: MessageSesion,
|
session: MessageSesion,
|
||||||
message_chain: MessageChain,
|
message_chain: MessageChain,
|
||||||
) -> None:
|
) -> None:
|
||||||
await QQOfficialPlatformAdapter._send_by_session_common(
|
(
|
||||||
cast(Any, self),
|
plain_text,
|
||||||
session,
|
image_base64,
|
||||||
message_chain,
|
image_path,
|
||||||
)
|
record_file_path,
|
||||||
|
) = await QQOfficialMessageEvent._parse_to_qqofficial(message_chain)
|
||||||
|
if not plain_text and not image_path:
|
||||||
|
return
|
||||||
|
|
||||||
|
msg_id = self._session_last_message_id.get(session.session_id)
|
||||||
|
if not msg_id:
|
||||||
|
logger.warning(
|
||||||
|
"[QQOfficialWebhook] No cached msg_id for session: %s, skip send_by_session",
|
||||||
|
session.session_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
payload: dict[str, Any] = {"content": plain_text, "msg_id": msg_id}
|
||||||
|
ret: Any = None
|
||||||
|
send_helper = SimpleNamespace(bot=self.client)
|
||||||
|
if session.message_type == MessageType.GROUP_MESSAGE:
|
||||||
|
scene = self._session_scene.get(session.session_id)
|
||||||
|
if scene == "group":
|
||||||
|
payload["msg_seq"] = random.randint(1, 10000)
|
||||||
|
if image_base64:
|
||||||
|
media = await QQOfficialMessageEvent.upload_group_and_c2c_image(
|
||||||
|
send_helper, # type: ignore
|
||||||
|
image_base64,
|
||||||
|
1,
|
||||||
|
group_openid=session.session_id,
|
||||||
|
)
|
||||||
|
payload["media"] = media
|
||||||
|
payload["msg_type"] = 7
|
||||||
|
if record_file_path:
|
||||||
|
media = await QQOfficialMessageEvent.upload_group_and_c2c_record(
|
||||||
|
send_helper, # type: ignore
|
||||||
|
record_file_path,
|
||||||
|
3,
|
||||||
|
group_openid=session.session_id,
|
||||||
|
)
|
||||||
|
payload["media"] = media
|
||||||
|
payload["msg_type"] = 7
|
||||||
|
ret = await self.client.api.post_group_message(
|
||||||
|
group_openid=session.session_id,
|
||||||
|
**payload,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if image_path:
|
||||||
|
payload["file_image"] = image_path
|
||||||
|
ret = await self.client.api.post_message(
|
||||||
|
channel_id=session.session_id,
|
||||||
|
**payload,
|
||||||
|
)
|
||||||
|
elif session.message_type == MessageType.FRIEND_MESSAGE:
|
||||||
|
payload["msg_seq"] = random.randint(1, 10000)
|
||||||
|
if image_base64:
|
||||||
|
media = await QQOfficialMessageEvent.upload_group_and_c2c_image(
|
||||||
|
send_helper, # type: ignore
|
||||||
|
image_base64,
|
||||||
|
1,
|
||||||
|
openid=session.session_id,
|
||||||
|
)
|
||||||
|
payload["media"] = media
|
||||||
|
payload["msg_type"] = 7
|
||||||
|
if record_file_path:
|
||||||
|
media = await QQOfficialMessageEvent.upload_group_and_c2c_record(
|
||||||
|
send_helper, # type: ignore
|
||||||
|
record_file_path,
|
||||||
|
3,
|
||||||
|
openid=session.session_id,
|
||||||
|
)
|
||||||
|
payload["media"] = media
|
||||||
|
payload["msg_type"] = 7
|
||||||
|
ret = await QQOfficialMessageEvent.post_c2c_message(
|
||||||
|
send_helper, # type: ignore
|
||||||
|
openid=session.session_id,
|
||||||
|
**payload,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"[QQOfficialWebhook] Unsupported message type for send_by_session: %s",
|
||||||
|
session.message_type,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
sent_message_id = self._extract_message_id(ret)
|
||||||
|
if sent_message_id:
|
||||||
|
self.remember_session_message_id(session.session_id, sent_message_id)
|
||||||
|
await super().send_by_session(session, message_chain)
|
||||||
|
|
||||||
def remember_session_message_id(self, session_id: str, message_id: str) -> None:
|
def remember_session_message_id(self, session_id: str, message_id: str) -> None:
|
||||||
if not session_id or not message_id:
|
if not session_id or not message_id:
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import time
|
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
import quart
|
import quart
|
||||||
@@ -40,9 +39,6 @@ class QQOfficialWebhook:
|
|||||||
self.client = botpy_client
|
self.client = botpy_client
|
||||||
self.event_queue = event_queue
|
self.event_queue = event_queue
|
||||||
self.shutdown_event = asyncio.Event()
|
self.shutdown_event = asyncio.Event()
|
||||||
# Deduplication cache for webhook retry callbacks.
|
|
||||||
self._seen_event_ids: dict[str, float] = {}
|
|
||||||
self._dedup_ttl: int = 60 # seconds
|
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
logger.info("正在登录到 QQ 官方机器人...")
|
logger.info("正在登录到 QQ 官方机器人...")
|
||||||
@@ -59,7 +55,7 @@ class QQOfficialWebhook:
|
|||||||
max_async=1,
|
max_async=1,
|
||||||
connect=bot_connect,
|
connect=bot_connect,
|
||||||
dispatch=self.client.ws_dispatch,
|
dispatch=self.client.ws_dispatch,
|
||||||
loop=asyncio.get_running_loop(),
|
loop=asyncio.get_event_loop(),
|
||||||
api=self.api,
|
api=self.api,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -110,22 +106,6 @@ class QQOfficialWebhook:
|
|||||||
print(signed)
|
print(signed)
|
||||||
return signed
|
return signed
|
||||||
|
|
||||||
event_id = msg.get("id")
|
|
||||||
if event_id:
|
|
||||||
now = time.monotonic()
|
|
||||||
# Lazily evict expired entries to prevent unbounded growth.
|
|
||||||
expired = [
|
|
||||||
k
|
|
||||||
for k, ts in self._seen_event_ids.items()
|
|
||||||
if now - ts > self._dedup_ttl
|
|
||||||
]
|
|
||||||
for k in expired:
|
|
||||||
del self._seen_event_ids[k]
|
|
||||||
if event_id in self._seen_event_ids:
|
|
||||||
logger.debug(f"Duplicate webhook event {event_id!r}, skipping.")
|
|
||||||
return {"opcode": 12}
|
|
||||||
self._seen_event_ids[event_id] = now
|
|
||||||
|
|
||||||
if event and opcode == BotWebSocket.WS_DISPATCH_EVENT:
|
if event and opcode == BotWebSocket.WS_DISPATCH_EVENT:
|
||||||
event = msg["t"].lower()
|
event = msg["t"].lower()
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -289,8 +289,8 @@ class TelegramPlatformAdapter(Platform):
|
|||||||
else:
|
else:
|
||||||
message.type = MessageType.GROUP_MESSAGE
|
message.type = MessageType.GROUP_MESSAGE
|
||||||
message.group_id = str(update.message.chat.id)
|
message.group_id = str(update.message.chat.id)
|
||||||
if update.message.is_topic_message and update.message.message_thread_id:
|
if update.message.message_thread_id:
|
||||||
# Telegram Topic Group: include thread id to isolate per-topic sessions.
|
# Topic Group
|
||||||
message.group_id += "#" + str(update.message.message_thread_id)
|
message.group_id += "#" + str(update.message.message_thread_id)
|
||||||
message.session_id = message.group_id
|
message.session_id = message.group_id
|
||||||
message.message_id = str(update.message.message_id)
|
message.message_id = str(update.message.message_id)
|
||||||
|
|||||||
@@ -25,16 +25,6 @@ from astrbot.api.platform import AstrBotMessage, MessageType, PlatformMetadata
|
|||||||
from astrbot.core.utils.metrics import Metric
|
from astrbot.core.utils.metrics import Metric
|
||||||
|
|
||||||
|
|
||||||
def _is_gif(path: str) -> bool:
|
|
||||||
if path.lower().endswith(".gif"):
|
|
||||||
return True
|
|
||||||
try:
|
|
||||||
with open(path, "rb") as f:
|
|
||||||
return f.read(6) in (b"GIF87a", b"GIF89a")
|
|
||||||
except OSError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class TelegramPlatformEvent(AstrMessageEvent):
|
class TelegramPlatformEvent(AstrMessageEvent):
|
||||||
# Telegram 的最大消息长度限制
|
# Telegram 的最大消息长度限制
|
||||||
MAX_MESSAGE_LENGTH = 4096
|
MAX_MESSAGE_LENGTH = 4096
|
||||||
@@ -288,6 +278,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
try:
|
try:
|
||||||
md_text = telegramify_markdown.markdownify(
|
md_text = telegramify_markdown.markdownify(
|
||||||
chunk,
|
chunk,
|
||||||
|
normalize_whitespace=False,
|
||||||
)
|
)
|
||||||
await client.send_message(
|
await client.send_message(
|
||||||
text=md_text,
|
text=md_text,
|
||||||
@@ -301,13 +292,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
await client.send_message(text=chunk, **cast(Any, payload))
|
await client.send_message(text=chunk, **cast(Any, payload))
|
||||||
elif isinstance(i, Image):
|
elif isinstance(i, Image):
|
||||||
image_path = await i.convert_to_file_path()
|
image_path = await i.convert_to_file_path()
|
||||||
if _is_gif(image_path):
|
await client.send_photo(photo=image_path, **cast(Any, payload))
|
||||||
send_coro = client.send_animation
|
|
||||||
media_kwarg = {"animation": image_path}
|
|
||||||
else:
|
|
||||||
send_coro = client.send_photo
|
|
||||||
media_kwarg = {"photo": image_path}
|
|
||||||
await send_coro(**media_kwarg, **cast(Any, payload))
|
|
||||||
elif isinstance(i, File):
|
elif isinstance(i, File):
|
||||||
path = await i.get_file()
|
path = await i.get_file()
|
||||||
name = i.name or os.path.basename(path)
|
name = i.name or os.path.basename(path)
|
||||||
@@ -422,20 +407,12 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
on_text(i.text)
|
on_text(i.text)
|
||||||
elif isinstance(i, Image):
|
elif isinstance(i, Image):
|
||||||
image_path = await i.convert_to_file_path()
|
image_path = await i.convert_to_file_path()
|
||||||
if _is_gif(image_path):
|
|
||||||
action = ChatAction.UPLOAD_VIDEO
|
|
||||||
send_coro = self.client.send_animation
|
|
||||||
media_kwarg = {"animation": image_path}
|
|
||||||
else:
|
|
||||||
action = ChatAction.UPLOAD_PHOTO
|
|
||||||
send_coro = self.client.send_photo
|
|
||||||
media_kwarg = {"photo": image_path}
|
|
||||||
await self._send_media_with_action(
|
await self._send_media_with_action(
|
||||||
self.client,
|
self.client,
|
||||||
action,
|
ChatAction.UPLOAD_PHOTO,
|
||||||
send_coro,
|
self.client.send_photo,
|
||||||
user_name=user_name,
|
user_name=user_name,
|
||||||
**media_kwarg,
|
photo=image_path,
|
||||||
**cast(Any, payload),
|
**cast(Any, payload),
|
||||||
)
|
)
|
||||||
elif isinstance(i, File):
|
elif isinstance(i, File):
|
||||||
@@ -479,6 +456,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
try:
|
try:
|
||||||
markdown_text = telegramify_markdown.markdownify(
|
markdown_text = telegramify_markdown.markdownify(
|
||||||
delta,
|
delta,
|
||||||
|
normalize_whitespace=False,
|
||||||
)
|
)
|
||||||
await self.client.send_message(
|
await self.client.send_message(
|
||||||
text=markdown_text,
|
text=markdown_text,
|
||||||
@@ -559,6 +537,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
try:
|
try:
|
||||||
md = telegramify_markdown.markdownify(
|
md = telegramify_markdown.markdownify(
|
||||||
draft_text,
|
draft_text,
|
||||||
|
normalize_whitespace=False,
|
||||||
)
|
)
|
||||||
await self._send_message_draft(
|
await self._send_message_draft(
|
||||||
user_name,
|
user_name,
|
||||||
@@ -647,7 +626,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
|
|
||||||
# 发送初始 typing 状态
|
# 发送初始 typing 状态
|
||||||
await self._ensure_typing(user_name, message_thread_id)
|
await self._ensure_typing(user_name, message_thread_id)
|
||||||
last_chat_action_time = asyncio.get_running_loop().time()
|
last_chat_action_time = asyncio.get_event_loop().time()
|
||||||
|
|
||||||
def _append_text(t: str) -> None:
|
def _append_text(t: str) -> None:
|
||||||
nonlocal delta
|
nonlocal delta
|
||||||
@@ -678,11 +657,11 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
|
|
||||||
# 编辑或发送消息
|
# 编辑或发送消息
|
||||||
if message_id and len(delta) <= self.MAX_MESSAGE_LENGTH:
|
if message_id and len(delta) <= self.MAX_MESSAGE_LENGTH:
|
||||||
current_time = asyncio.get_running_loop().time()
|
current_time = asyncio.get_event_loop().time()
|
||||||
time_since_last_edit = current_time - last_edit_time
|
time_since_last_edit = current_time - last_edit_time
|
||||||
|
|
||||||
if time_since_last_edit >= throttle_interval:
|
if time_since_last_edit >= throttle_interval:
|
||||||
current_time = asyncio.get_running_loop().time()
|
current_time = asyncio.get_event_loop().time()
|
||||||
if current_time - last_chat_action_time >= chat_action_interval:
|
if current_time - last_chat_action_time >= chat_action_interval:
|
||||||
await self._ensure_typing(user_name, message_thread_id)
|
await self._ensure_typing(user_name, message_thread_id)
|
||||||
last_chat_action_time = current_time
|
last_chat_action_time = current_time
|
||||||
@@ -695,9 +674,9 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
current_content = delta
|
current_content = delta
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"编辑消息失败(streaming): {e!s}")
|
logger.warning(f"编辑消息失败(streaming): {e!s}")
|
||||||
last_edit_time = asyncio.get_running_loop().time()
|
last_edit_time = asyncio.get_event_loop().time()
|
||||||
else:
|
else:
|
||||||
current_time = asyncio.get_running_loop().time()
|
current_time = asyncio.get_event_loop().time()
|
||||||
if current_time - last_chat_action_time >= chat_action_interval:
|
if current_time - last_chat_action_time >= chat_action_interval:
|
||||||
await self._ensure_typing(user_name, message_thread_id)
|
await self._ensure_typing(user_name, message_thread_id)
|
||||||
last_chat_action_time = current_time
|
last_chat_action_time = current_time
|
||||||
@@ -709,13 +688,14 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"发送消息失败(streaming): {e!s}")
|
logger.warning(f"发送消息失败(streaming): {e!s}")
|
||||||
message_id = msg.message_id
|
message_id = msg.message_id
|
||||||
last_edit_time = asyncio.get_running_loop().time()
|
last_edit_time = asyncio.get_event_loop().time()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if delta and current_content != delta:
|
if delta and current_content != delta:
|
||||||
try:
|
try:
|
||||||
markdown_text = telegramify_markdown.markdownify(
|
markdown_text = telegramify_markdown.markdownify(
|
||||||
delta,
|
delta,
|
||||||
|
normalize_whitespace=False,
|
||||||
)
|
)
|
||||||
await self.client.edit_message_text(
|
await self.client.edit_message_text(
|
||||||
text=markdown_text,
|
text=markdown_text,
|
||||||
|
|||||||
@@ -200,7 +200,7 @@ class WecomPlatformAdapter(Platform):
|
|||||||
return msg_list[-1]
|
return msg_list[-1]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
msg_new = await asyncio.get_running_loop().run_in_executor(
|
msg_new = await asyncio.get_event_loop().run_in_executor(
|
||||||
None,
|
None,
|
||||||
get_latest_msg_item,
|
get_latest_msg_item,
|
||||||
)
|
)
|
||||||
@@ -261,7 +261,7 @@ class WecomPlatformAdapter(Platform):
|
|||||||
|
|
||||||
@override
|
@override
|
||||||
async def run(self) -> None:
|
async def run(self) -> None:
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_event_loop()
|
||||||
if self.kf_name:
|
if self.kf_name:
|
||||||
try:
|
try:
|
||||||
acc_list = (
|
acc_list = (
|
||||||
@@ -339,7 +339,7 @@ class WecomPlatformAdapter(Platform):
|
|||||||
abm.session_id = abm.sender.user_id
|
abm.session_id = abm.sender.user_id
|
||||||
abm.raw_message = msg
|
abm.raw_message = msg
|
||||||
elif isinstance(msg, VoiceMessage):
|
elif isinstance(msg, VoiceMessage):
|
||||||
resp: Response = await asyncio.get_running_loop().run_in_executor(
|
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
||||||
None,
|
None,
|
||||||
self.client.media.download,
|
self.client.media.download,
|
||||||
msg.media_id,
|
msg.media_id,
|
||||||
@@ -395,7 +395,7 @@ class WecomPlatformAdapter(Platform):
|
|||||||
abm.message_str = text
|
abm.message_str = text
|
||||||
elif msgtype == "image":
|
elif msgtype == "image":
|
||||||
media_id = msg.get("image", {}).get("media_id", "")
|
media_id = msg.get("image", {}).get("media_id", "")
|
||||||
resp: Response = await asyncio.get_running_loop().run_in_executor(
|
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
||||||
None,
|
None,
|
||||||
self.client.media.download,
|
self.client.media.download,
|
||||||
media_id,
|
media_id,
|
||||||
@@ -407,7 +407,7 @@ class WecomPlatformAdapter(Platform):
|
|||||||
abm.message = [Image(file=path, url=path)]
|
abm.message = [Image(file=path, url=path)]
|
||||||
elif msgtype == "voice":
|
elif msgtype == "voice":
|
||||||
media_id = msg.get("voice", {}).get("media_id", "")
|
media_id = msg.get("voice", {}).get("media_id", "")
|
||||||
resp: Response = await asyncio.get_running_loop().run_in_executor(
|
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
||||||
None,
|
None,
|
||||||
self.client.media.download,
|
self.client.media.download,
|
||||||
media_id,
|
media_id,
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
"""企业微信智能机器人平台适配器
|
"""企业微信智能机器人平台适配器
|
||||||
基于企业微信智能机器人 API 的消息平台适配器,支持 HTTP 回调与长连接
|
基于企业微信智能机器人 API 的消息平台适配器,支持 HTTP 回调
|
||||||
参考webchat_adapter.py的队列机制,实现异步消息处理和流式响应
|
参考webchat_adapter.py的队列机制,实现异步消息处理和流式响应
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -31,7 +31,6 @@ from .wecomai_api import (
|
|||||||
WecomAIBotStreamMessageBuilder,
|
WecomAIBotStreamMessageBuilder,
|
||||||
)
|
)
|
||||||
from .wecomai_event import WecomAIBotMessageEvent
|
from .wecomai_event import WecomAIBotMessageEvent
|
||||||
from .wecomai_long_connection import WecomAIBotLongConnectionClient
|
|
||||||
from .wecomai_queue_mgr import WecomAIQueueMgr
|
from .wecomai_queue_mgr import WecomAIQueueMgr
|
||||||
from .wecomai_server import WecomAIBotServer
|
from .wecomai_server import WecomAIBotServer
|
||||||
from .wecomai_utils import (
|
from .wecomai_utils import (
|
||||||
@@ -79,13 +78,8 @@ class WecomAIBotAdapter(Platform):
|
|||||||
self.settings = platform_settings
|
self.settings = platform_settings
|
||||||
|
|
||||||
# 初始化配置参数
|
# 初始化配置参数
|
||||||
self.connection_mode = self.config.get(
|
self.token = self.config["token"]
|
||||||
"wecom_ai_bot_connection_mode", "webhook"
|
self.encoding_aes_key = self.config["encoding_aes_key"]
|
||||||
)
|
|
||||||
self.token = self.config.get("token", self.config.get("wecomaibot_token", ""))
|
|
||||||
self.encoding_aes_key = self.config.get(
|
|
||||||
"encoding_aes_key", self.config.get("wecomaibot_encoding_aes_key", "")
|
|
||||||
)
|
|
||||||
self.port = int(self.config["port"])
|
self.port = int(self.config["port"])
|
||||||
self.host = self.config.get("callback_server_host", "0.0.0.0")
|
self.host = self.config.get("callback_server_host", "0.0.0.0")
|
||||||
self.bot_name = self.config.get("wecom_ai_bot_name", "")
|
self.bot_name = self.config.get("wecom_ai_bot_name", "")
|
||||||
@@ -102,52 +96,25 @@ class WecomAIBotAdapter(Platform):
|
|||||||
self.only_use_webhook_url_to_send = bool(
|
self.only_use_webhook_url_to_send = bool(
|
||||||
self.config.get("only_use_webhook_url_to_send", False),
|
self.config.get("only_use_webhook_url_to_send", False),
|
||||||
)
|
)
|
||||||
self.long_connection_bot_id = self.config.get(
|
|
||||||
"wecomaibot_ws_bot_id", self.config.get("long_connection_bot_id", "")
|
|
||||||
)
|
|
||||||
self.long_connection_secret = self.config.get(
|
|
||||||
"wecomaibot_ws_secret", self.config.get("long_connection_secret", "")
|
|
||||||
)
|
|
||||||
self.long_connection_ws_url = self.config.get(
|
|
||||||
"wecomaibot_ws_url",
|
|
||||||
"wss://openws.work.weixin.qq.com",
|
|
||||||
)
|
|
||||||
self.long_connection_heartbeat_interval = int(
|
|
||||||
self.config.get("wecomaibot_heartbeat_interval", 30),
|
|
||||||
)
|
|
||||||
|
|
||||||
# 平台元数据
|
# 平台元数据
|
||||||
self.metadata = PlatformMetadata(
|
self.metadata = PlatformMetadata(
|
||||||
name="wecom_ai_bot",
|
name="wecom_ai_bot",
|
||||||
description="企业微信智能机器人适配器,支持 HTTP 回调和长连接模式",
|
description="企业微信智能机器人适配器,支持 HTTP 回调接收消息",
|
||||||
id=self.config.get("id", "wecom_ai_bot"),
|
id=self.config.get("id", "wecom_ai_bot"),
|
||||||
support_proactive_message=bool(self.msg_push_webhook_url),
|
support_proactive_message=bool(self.msg_push_webhook_url),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.api_client: WecomAIBotAPIClient | None = None
|
# 初始化 API 客户端
|
||||||
self.server: WecomAIBotServer | None = None
|
self.api_client = WecomAIBotAPIClient(self.token, self.encoding_aes_key)
|
||||||
self.long_connection_client: WecomAIBotLongConnectionClient | None = None
|
|
||||||
|
|
||||||
if self.connection_mode == "long_connection":
|
# 初始化 HTTP 服务器
|
||||||
if not self.long_connection_bot_id or not self.long_connection_secret:
|
self.server = WecomAIBotServer(
|
||||||
logger.warning(
|
host=self.host,
|
||||||
"企业微信智能机器人长连接模式缺少 BotID 或 Secret,连接可能失败"
|
port=self.port,
|
||||||
)
|
api_client=self.api_client,
|
||||||
self.long_connection_client = WecomAIBotLongConnectionClient(
|
message_handler=self._process_message,
|
||||||
bot_id=self.long_connection_bot_id,
|
)
|
||||||
secret=self.long_connection_secret,
|
|
||||||
ws_url=self.long_connection_ws_url,
|
|
||||||
heartbeat_interval=self.long_connection_heartbeat_interval,
|
|
||||||
message_handler=self._process_long_connection_payload,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.api_client = WecomAIBotAPIClient(self.token, self.encoding_aes_key)
|
|
||||||
self.server = WecomAIBotServer(
|
|
||||||
host=self.host,
|
|
||||||
port=self.port,
|
|
||||||
api_client=self.api_client,
|
|
||||||
message_handler=self._process_message,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 事件循环和关闭信号
|
# 事件循环和关闭信号
|
||||||
self.shutdown_event = asyncio.Event()
|
self.shutdown_event = asyncio.Event()
|
||||||
@@ -194,9 +161,6 @@ class WecomAIBotAdapter(Platform):
|
|||||||
加密后的响应消息,无需响应时返回 None
|
加密后的响应消息,无需响应时返回 None
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if not self.api_client:
|
|
||||||
logger.error("Webhook 消息处理失败: API 客户端未初始化")
|
|
||||||
return None
|
|
||||||
msgtype = message_data.get("msgtype")
|
msgtype = message_data.get("msgtype")
|
||||||
if not msgtype:
|
if not msgtype:
|
||||||
logger.warning(f"消息类型未知,忽略: {message_data}")
|
logger.warning(f"消息类型未知,忽略: {message_data}")
|
||||||
@@ -356,100 +320,10 @@ class WecomAIBotAdapter(Platform):
|
|||||||
logger.error("处理欢迎消息时发生异常: %s", e)
|
logger.error("处理欢迎消息时发生异常: %s", e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _process_long_connection_payload(
|
|
||||||
self,
|
|
||||||
payload: dict[str, Any],
|
|
||||||
) -> None:
|
|
||||||
"""处理长连接回调消息。"""
|
|
||||||
cmd = payload.get("cmd")
|
|
||||||
headers = payload.get("headers") or {}
|
|
||||||
body = payload.get("body") or {}
|
|
||||||
req_id = headers.get("req_id")
|
|
||||||
if not isinstance(body, dict):
|
|
||||||
return
|
|
||||||
|
|
||||||
if cmd == "aibot_msg_callback":
|
|
||||||
session_id = self._extract_session_id(body)
|
|
||||||
stream_id = f"{session_id}_{generate_random_string(10)}"
|
|
||||||
await self._enqueue_message(
|
|
||||||
body, {"req_id": req_id or ""}, stream_id, session_id
|
|
||||||
)
|
|
||||||
self.queue_mgr.set_pending_response(
|
|
||||||
stream_id,
|
|
||||||
{
|
|
||||||
"req_id": req_id or "",
|
|
||||||
"connection_mode": "long_connection",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.initial_respond_text and req_id:
|
|
||||||
await self._send_long_connection_respond_msg(
|
|
||||||
req_id=req_id,
|
|
||||||
body={
|
|
||||||
"msgtype": "stream",
|
|
||||||
"stream": {
|
|
||||||
"id": stream_id,
|
|
||||||
"finish": False,
|
|
||||||
"content": self.initial_respond_text,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
if cmd == "aibot_event_callback":
|
|
||||||
event = body.get("event") or {}
|
|
||||||
event_type = event.get("eventtype")
|
|
||||||
if (
|
|
||||||
event_type == "enter_chat"
|
|
||||||
and self.friend_message_welcome_text
|
|
||||||
and req_id
|
|
||||||
):
|
|
||||||
await self._send_long_connection_respond_welcome(req_id)
|
|
||||||
elif event_type == "disconnected_event":
|
|
||||||
logger.warning(
|
|
||||||
"[WecomAI][LongConn] 收到 disconnected_event,旧连接将被关闭"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _send_long_connection_respond_welcome(self, req_id: str) -> bool:
|
|
||||||
client = self.long_connection_client
|
|
||||||
if not client:
|
|
||||||
return False
|
|
||||||
return await client.send_command(
|
|
||||||
cmd="aibot_respond_welcome_msg",
|
|
||||||
req_id=req_id,
|
|
||||||
body={
|
|
||||||
"msgtype": "text",
|
|
||||||
"text": {
|
|
||||||
"content": self.friend_message_welcome_text,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _send_long_connection_respond_msg(
|
|
||||||
self,
|
|
||||||
req_id: str,
|
|
||||||
body: dict[str, Any],
|
|
||||||
) -> bool:
|
|
||||||
client = self.long_connection_client
|
|
||||||
if not client:
|
|
||||||
return False
|
|
||||||
return await client.send_command(
|
|
||||||
cmd="aibot_respond_msg",
|
|
||||||
req_id=req_id,
|
|
||||||
body=body,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _extract_session_id(self, message_data: dict[str, Any]) -> str:
|
def _extract_session_id(self, message_data: dict[str, Any]) -> str:
|
||||||
"""从消息数据中提取会话ID
|
"""从消息数据中提取会话ID"""
|
||||||
群聊使用 chatid,单聊使用 userid
|
user_id = message_data.get("from", {}).get("userid", "default_user")
|
||||||
"""
|
return format_session_id("wecomai", user_id)
|
||||||
chattype = message_data.get("chattype", "single")
|
|
||||||
if chattype == "group":
|
|
||||||
chat_id = message_data.get("chatid", "default_group")
|
|
||||||
return format_session_id("wecomai", chat_id)
|
|
||||||
else:
|
|
||||||
user_id = message_data.get("from", {}).get("userid", "default_user")
|
|
||||||
return format_session_id("wecomai", user_id)
|
|
||||||
|
|
||||||
async def _enqueue_message(
|
async def _enqueue_message(
|
||||||
self,
|
self,
|
||||||
@@ -481,16 +355,15 @@ class WecomAIBotAdapter(Platform):
|
|||||||
content = ""
|
content = ""
|
||||||
image_base64 = []
|
image_base64 = []
|
||||||
|
|
||||||
_img_url_to_process: list[tuple[str, str | None]] = []
|
_img_url_to_process = []
|
||||||
msg_items = []
|
msg_items = []
|
||||||
|
|
||||||
if msgtype == WecomAIBotConstants.MSG_TYPE_TEXT:
|
if msgtype == WecomAIBotConstants.MSG_TYPE_TEXT:
|
||||||
content = WecomAIBotMessageParser.parse_text_message(message_data)
|
content = WecomAIBotMessageParser.parse_text_message(message_data)
|
||||||
elif msgtype == WecomAIBotConstants.MSG_TYPE_IMAGE:
|
elif msgtype == WecomAIBotConstants.MSG_TYPE_IMAGE:
|
||||||
image_payload = message_data.get("image", {})
|
_img_url_to_process.append(
|
||||||
image_url = image_payload.get("url", "")
|
WecomAIBotMessageParser.parse_image_message(message_data),
|
||||||
if image_url:
|
)
|
||||||
_img_url_to_process.append((image_url, image_payload.get("aeskey")))
|
|
||||||
elif msgtype == WecomAIBotConstants.MSG_TYPE_MIXED:
|
elif msgtype == WecomAIBotConstants.MSG_TYPE_MIXED:
|
||||||
# 提取混合消息中的文本内容
|
# 提取混合消息中的文本内容
|
||||||
msg_items = WecomAIBotMessageParser.parse_mixed_message(message_data)
|
msg_items = WecomAIBotMessageParser.parse_mixed_message(message_data)
|
||||||
@@ -501,12 +374,9 @@ class WecomAIBotAdapter(Platform):
|
|||||||
if text_content:
|
if text_content:
|
||||||
text_parts.append(text_content)
|
text_parts.append(text_content)
|
||||||
elif item.get("msgtype") == WecomAIBotConstants.MSG_TYPE_IMAGE:
|
elif item.get("msgtype") == WecomAIBotConstants.MSG_TYPE_IMAGE:
|
||||||
image_payload = item.get("image", {})
|
image_url = item.get("image", {}).get("url", "")
|
||||||
image_url = image_payload.get("url", "")
|
|
||||||
if image_url:
|
if image_url:
|
||||||
_img_url_to_process.append(
|
_img_url_to_process.append(image_url)
|
||||||
(image_url, image_payload.get("aeskey"))
|
|
||||||
)
|
|
||||||
content = " ".join(text_parts) if text_parts else ""
|
content = " ".join(text_parts) if text_parts else ""
|
||||||
else:
|
else:
|
||||||
content = f"[{msgtype}消息]"
|
content = f"[{msgtype}消息]"
|
||||||
@@ -514,8 +384,8 @@ class WecomAIBotAdapter(Platform):
|
|||||||
# 并行处理图片下载和解密
|
# 并行处理图片下载和解密
|
||||||
if _img_url_to_process:
|
if _img_url_to_process:
|
||||||
tasks = [
|
tasks = [
|
||||||
process_encrypted_image(url, aes_key or self.encoding_aes_key)
|
process_encrypted_image(url, self.encoding_aes_key)
|
||||||
for url, aes_key in _img_url_to_process
|
for url in _img_url_to_process
|
||||||
]
|
]
|
||||||
results = await asyncio.gather(*tasks)
|
results = await asyncio.gather(*tasks)
|
||||||
for success, result in results:
|
for success, result in results:
|
||||||
@@ -589,43 +459,26 @@ class WecomAIBotAdapter(Platform):
|
|||||||
"""运行适配器,同时启动HTTP服务器和队列监听器"""
|
"""运行适配器,同时启动HTTP服务器和队列监听器"""
|
||||||
|
|
||||||
async def run_both() -> None:
|
async def run_both() -> None:
|
||||||
if self.connection_mode == "long_connection":
|
# 如果启用统一 webhook 模式,则不启动独立服务器
|
||||||
if not self.long_connection_client:
|
webhook_uuid = self.config.get("webhook_uuid")
|
||||||
raise RuntimeError("长连接客户端未初始化")
|
if self.unified_webhook_mode and webhook_uuid:
|
||||||
|
log_webhook_info(f"{self.meta().id}(企业微信智能机器人)", webhook_uuid)
|
||||||
|
# 只运行队列监听器
|
||||||
|
await self.queue_listener.run()
|
||||||
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
"启动企业微信智能机器人长连接模式: %s", self.long_connection_ws_url
|
"启动企业微信智能机器人适配器,监听 %s:%d", self.host, self.port
|
||||||
)
|
)
|
||||||
|
# 同时运行HTTP服务器和队列监听器
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
self.long_connection_client.start(),
|
self.server.start_server(),
|
||||||
self.queue_listener.run(),
|
self.queue_listener.run(),
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
# 如果启用统一 webhook 模式,则不启动独立服务器
|
|
||||||
webhook_uuid = self.config.get("webhook_uuid")
|
|
||||||
if self.unified_webhook_mode and webhook_uuid:
|
|
||||||
log_webhook_info(
|
|
||||||
f"{self.meta().id}(企业微信智能机器人)", webhook_uuid
|
|
||||||
)
|
|
||||||
# 只运行队列监听器
|
|
||||||
await self.queue_listener.run()
|
|
||||||
else:
|
|
||||||
if not self.server:
|
|
||||||
raise RuntimeError("Webhook 服务器未初始化")
|
|
||||||
logger.info(
|
|
||||||
"启动企业微信智能机器人适配器,监听 %s:%d", self.host, self.port
|
|
||||||
)
|
|
||||||
# 同时运行HTTP服务器和队列监听器
|
|
||||||
await asyncio.gather(
|
|
||||||
self.server.start_server(),
|
|
||||||
self.queue_listener.run(),
|
|
||||||
)
|
|
||||||
|
|
||||||
return run_both()
|
return run_both()
|
||||||
|
|
||||||
async def webhook_callback(self, request: Any) -> Any:
|
async def webhook_callback(self, request: Any) -> Any:
|
||||||
"""统一 Webhook 回调入口"""
|
"""统一 Webhook 回调入口"""
|
||||||
if self.connection_mode == "long_connection" or not self.server:
|
|
||||||
return "long_connection mode does not accept webhook callbacks", 400
|
|
||||||
# 根据请求方法分发到不同的处理函数
|
# 根据请求方法分发到不同的处理函数
|
||||||
if request.method == "GET":
|
if request.method == "GET":
|
||||||
return await self.server.handle_verify(request)
|
return await self.server.handle_verify(request)
|
||||||
@@ -636,10 +489,7 @@ class WecomAIBotAdapter(Platform):
|
|||||||
"""终止适配器"""
|
"""终止适配器"""
|
||||||
logger.info("企业微信智能机器人适配器正在关闭...")
|
logger.info("企业微信智能机器人适配器正在关闭...")
|
||||||
self.shutdown_event.set()
|
self.shutdown_event.set()
|
||||||
if self.long_connection_client:
|
await self.server.shutdown()
|
||||||
await self.long_connection_client.shutdown()
|
|
||||||
if self.server:
|
|
||||||
await self.server.shutdown()
|
|
||||||
|
|
||||||
def meta(self) -> PlatformMetadata:
|
def meta(self) -> PlatformMetadata:
|
||||||
"""获取平台元数据"""
|
"""获取平台元数据"""
|
||||||
@@ -657,22 +507,17 @@ class WecomAIBotAdapter(Platform):
|
|||||||
queue_mgr=self.queue_mgr,
|
queue_mgr=self.queue_mgr,
|
||||||
webhook_client=self.webhook_client,
|
webhook_client=self.webhook_client,
|
||||||
only_use_webhook_url_to_send=self.only_use_webhook_url_to_send,
|
only_use_webhook_url_to_send=self.only_use_webhook_url_to_send,
|
||||||
long_connection_sender=self._send_long_connection_respond_msg,
|
|
||||||
)
|
)
|
||||||
message_event.is_at_or_wake_command = (
|
|
||||||
True # 企业微信智能机器人默认消息都是 at 或唤醒命令
|
|
||||||
)
|
|
||||||
message_event.is_wake = True # 企业微信智能机器人消息默认当做唤醒命令处理
|
|
||||||
|
|
||||||
self.commit_event(message_event)
|
self.commit_event(message_event)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("处理消息时发生异常: %s", e)
|
logger.error("处理消息时发生异常: %s", e)
|
||||||
|
|
||||||
def get_client(self) -> WecomAIBotAPIClient | None:
|
def get_client(self) -> WecomAIBotAPIClient:
|
||||||
"""获取 API 客户端"""
|
"""获取 API 客户端"""
|
||||||
return self.api_client
|
return self.api_client
|
||||||
|
|
||||||
def get_server(self) -> WecomAIBotServer | None:
|
def get_server(self) -> WecomAIBotServer:
|
||||||
"""获取 HTTP 服务器实例"""
|
"""获取 HTTP 服务器实例"""
|
||||||
return self.server
|
return self.server
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
"""企业微信智能机器人事件处理模块,处理消息事件的发送和接收"""
|
"""企业微信智能机器人事件处理模块,处理消息事件的发送和接收"""
|
||||||
|
|
||||||
from collections.abc import Awaitable, Callable
|
|
||||||
|
|
||||||
from astrbot.api import logger
|
from astrbot.api import logger
|
||||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
from astrbot.api.message_components import At, Image, Plain
|
from astrbot.api.message_components import At, Image, Plain
|
||||||
@@ -20,11 +18,10 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
|
|||||||
message_obj,
|
message_obj,
|
||||||
platform_meta,
|
platform_meta,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
api_client: WecomAIBotAPIClient | None,
|
api_client: WecomAIBotAPIClient,
|
||||||
queue_mgr: WecomAIQueueMgr,
|
queue_mgr: WecomAIQueueMgr,
|
||||||
webhook_client: WecomAIBotWebhookClient | None = None,
|
webhook_client: WecomAIBotWebhookClient | None = None,
|
||||||
only_use_webhook_url_to_send: bool = False,
|
only_use_webhook_url_to_send: bool = False,
|
||||||
long_connection_sender: (Callable[[str, dict], Awaitable[bool]] | None) = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""初始化消息事件
|
"""初始化消息事件
|
||||||
|
|
||||||
@@ -41,7 +38,6 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
|
|||||||
self.queue_mgr = queue_mgr
|
self.queue_mgr = queue_mgr
|
||||||
self.webhook_client = webhook_client
|
self.webhook_client = webhook_client
|
||||||
self.only_use_webhook_url_to_send = only_use_webhook_url_to_send
|
self.only_use_webhook_url_to_send = only_use_webhook_url_to_send
|
||||||
self.long_connection_sender = long_connection_sender
|
|
||||||
|
|
||||||
async def _mark_stream_complete(self, stream_id: str) -> None:
|
async def _mark_stream_complete(self, stream_id: str) -> None:
|
||||||
back_queue = self.queue_mgr.get_or_create_back_queue(stream_id)
|
back_queue = self.queue_mgr.get_or_create_back_queue(stream_id)
|
||||||
@@ -121,18 +117,6 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _extract_plain_text_from_chain(message_chain: MessageChain | None) -> str:
|
|
||||||
if not message_chain:
|
|
||||||
return ""
|
|
||||||
plain_parts: list[str] = []
|
|
||||||
for comp in message_chain.chain:
|
|
||||||
if isinstance(comp, At):
|
|
||||||
plain_parts.append(f"@{comp.name} ")
|
|
||||||
elif isinstance(comp, Plain):
|
|
||||||
plain_parts.append(comp.text)
|
|
||||||
return "".join(plain_parts).strip()
|
|
||||||
|
|
||||||
async def send(self, message: MessageChain | None) -> None:
|
async def send(self, message: MessageChain | None) -> None:
|
||||||
"""发送消息"""
|
"""发送消息"""
|
||||||
raw = self.message_obj.raw_message
|
raw = self.message_obj.raw_message
|
||||||
@@ -140,44 +124,6 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
|
|||||||
"wecom_ai_bot platform event raw_message should be a dict"
|
"wecom_ai_bot platform event raw_message should be a dict"
|
||||||
)
|
)
|
||||||
stream_id = raw.get("stream_id", self.session_id)
|
stream_id = raw.get("stream_id", self.session_id)
|
||||||
pending_response = self.queue_mgr.get_pending_response(stream_id) or {}
|
|
||||||
connection_mode = pending_response.get("callback_params", {}).get(
|
|
||||||
"connection_mode"
|
|
||||||
)
|
|
||||||
req_id = pending_response.get("callback_params", {}).get("req_id")
|
|
||||||
|
|
||||||
if (
|
|
||||||
connection_mode == "long_connection"
|
|
||||||
and self.long_connection_sender
|
|
||||||
and isinstance(req_id, str)
|
|
||||||
and req_id
|
|
||||||
):
|
|
||||||
if self.only_use_webhook_url_to_send and self.webhook_client and message:
|
|
||||||
await self.webhook_client.send_message_chain(message)
|
|
||||||
await super().send(MessageChain([]))
|
|
||||||
return
|
|
||||||
|
|
||||||
if self.webhook_client and message:
|
|
||||||
await self.webhook_client.send_message_chain(
|
|
||||||
message,
|
|
||||||
unsupported_only=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
content = self._extract_plain_text_from_chain(message)
|
|
||||||
await self.long_connection_sender(
|
|
||||||
req_id,
|
|
||||||
{
|
|
||||||
"msgtype": "stream",
|
|
||||||
"stream": {
|
|
||||||
"id": stream_id,
|
|
||||||
"finish": True,
|
|
||||||
"content": content,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
await super().send(MessageChain([]))
|
|
||||||
return
|
|
||||||
|
|
||||||
if self.only_use_webhook_url_to_send and self.webhook_client and message:
|
if self.only_use_webhook_url_to_send and self.webhook_client and message:
|
||||||
await self.webhook_client.send_message_chain(message)
|
await self.webhook_client.send_message_chain(message)
|
||||||
await self._mark_stream_complete(stream_id)
|
await self._mark_stream_complete(stream_id)
|
||||||
@@ -206,77 +152,8 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
|
|||||||
"wecom_ai_bot platform event raw_message should be a dict"
|
"wecom_ai_bot platform event raw_message should be a dict"
|
||||||
)
|
)
|
||||||
stream_id = raw.get("stream_id", self.session_id)
|
stream_id = raw.get("stream_id", self.session_id)
|
||||||
pending_response = self.queue_mgr.get_pending_response(stream_id) or {}
|
|
||||||
connection_mode = pending_response.get("callback_params", {}).get(
|
|
||||||
"connection_mode"
|
|
||||||
)
|
|
||||||
req_id = pending_response.get("callback_params", {}).get("req_id")
|
|
||||||
back_queue = self.queue_mgr.get_or_create_back_queue(stream_id)
|
back_queue = self.queue_mgr.get_or_create_back_queue(stream_id)
|
||||||
|
|
||||||
if (
|
|
||||||
connection_mode == "long_connection"
|
|
||||||
and self.long_connection_sender
|
|
||||||
and isinstance(req_id, str)
|
|
||||||
and req_id
|
|
||||||
):
|
|
||||||
if self.only_use_webhook_url_to_send and self.webhook_client:
|
|
||||||
merged_chain = MessageChain([])
|
|
||||||
async for chain in generator:
|
|
||||||
merged_chain.chain.extend(chain.chain)
|
|
||||||
merged_chain.squash_plain()
|
|
||||||
await self.webhook_client.send_message_chain(merged_chain)
|
|
||||||
await self.long_connection_sender(
|
|
||||||
req_id,
|
|
||||||
{
|
|
||||||
"msgtype": "stream",
|
|
||||||
"stream": {
|
|
||||||
"id": stream_id,
|
|
||||||
"finish": True,
|
|
||||||
"content": "",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
await super().send_streaming(generator, use_fallback)
|
|
||||||
return
|
|
||||||
|
|
||||||
increment_plain = ""
|
|
||||||
async for chain in generator:
|
|
||||||
if self.webhook_client:
|
|
||||||
await self.webhook_client.send_message_chain(
|
|
||||||
chain,
|
|
||||||
unsupported_only=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
chain.squash_plain()
|
|
||||||
chunk_text = self._extract_plain_text_from_chain(chain)
|
|
||||||
if chunk_text:
|
|
||||||
increment_plain += chunk_text
|
|
||||||
await self.long_connection_sender(
|
|
||||||
req_id,
|
|
||||||
{
|
|
||||||
"msgtype": "stream",
|
|
||||||
"stream": {
|
|
||||||
"id": stream_id,
|
|
||||||
"finish": False,
|
|
||||||
"content": increment_plain,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
await self.long_connection_sender(
|
|
||||||
req_id,
|
|
||||||
{
|
|
||||||
"msgtype": "stream",
|
|
||||||
"stream": {
|
|
||||||
"id": stream_id,
|
|
||||||
"finish": True,
|
|
||||||
"content": increment_plain,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
await super().send_streaming(generator, use_fallback)
|
|
||||||
return
|
|
||||||
|
|
||||||
if self.only_use_webhook_url_to_send and self.webhook_client:
|
if self.only_use_webhook_url_to_send and self.webhook_client:
|
||||||
merged_chain = MessageChain([])
|
merged_chain = MessageChain([])
|
||||||
async for chain in generator:
|
async for chain in generator:
|
||||||
|
|||||||
@@ -1,236 +0,0 @@
|
|||||||
"""企业微信智能机器人长连接客户端。"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import uuid
|
|
||||||
from collections.abc import Awaitable, Callable
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
|
|
||||||
from astrbot.api import logger
|
|
||||||
|
|
||||||
|
|
||||||
class WecomAIBotLongConnectionClient:
|
|
||||||
"""企业微信智能机器人 WebSocket 长连接客户端。"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
bot_id: str,
|
|
||||||
secret: str,
|
|
||||||
ws_url: str,
|
|
||||||
heartbeat_interval: int,
|
|
||||||
message_handler: Callable[[dict[str, Any]], Awaitable[None]],
|
|
||||||
) -> None:
|
|
||||||
self.bot_id = bot_id
|
|
||||||
self.secret = secret
|
|
||||||
self.ws_url = ws_url
|
|
||||||
self.heartbeat_interval = max(5, int(heartbeat_interval))
|
|
||||||
self.message_handler = message_handler
|
|
||||||
|
|
||||||
self._session: aiohttp.ClientSession | None = None
|
|
||||||
self._ws: aiohttp.ClientWebSocketResponse | None = None
|
|
||||||
self._shutdown_event = asyncio.Event()
|
|
||||||
self._send_lock = asyncio.Lock()
|
|
||||||
self._command_lock = asyncio.Lock()
|
|
||||||
self._response_waiters: dict[str, asyncio.Future[dict[str, Any]]] = {}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def gen_req_id() -> str:
|
|
||||||
return uuid.uuid4().hex
|
|
||||||
|
|
||||||
async def start(self) -> None:
|
|
||||||
"""启动长连接并自动重连。"""
|
|
||||||
reconnect_delay = 1
|
|
||||||
while not self._shutdown_event.is_set():
|
|
||||||
try:
|
|
||||||
await self._run_once()
|
|
||||||
reconnect_delay = 1
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("[WecomAI][LongConn] 长连接异常: %s", e)
|
|
||||||
if self._shutdown_event.is_set():
|
|
||||||
break
|
|
||||||
await asyncio.sleep(reconnect_delay)
|
|
||||||
reconnect_delay = min(reconnect_delay * 2, 30)
|
|
||||||
|
|
||||||
async def _run_once(self) -> None:
|
|
||||||
timeout = aiohttp.ClientTimeout(total=None, sock_connect=15, sock_read=None)
|
|
||||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
|
||||||
self._session = session
|
|
||||||
logger.info("[WecomAI][LongConn] 正在连接: %s", self.ws_url)
|
|
||||||
async with session.ws_connect(
|
|
||||||
self.ws_url, heartbeat=None, autoping=True
|
|
||||||
) as ws:
|
|
||||||
self._ws = ws
|
|
||||||
await self._subscribe()
|
|
||||||
logger.info("[WecomAI][LongConn] 订阅成功,已建立长连接")
|
|
||||||
|
|
||||||
heartbeat_task = asyncio.create_task(self._heartbeat_loop())
|
|
||||||
try:
|
|
||||||
while not self._shutdown_event.is_set():
|
|
||||||
message = await ws.receive()
|
|
||||||
if message.type == aiohttp.WSMsgType.TEXT:
|
|
||||||
await self._handle_text_message(message.data)
|
|
||||||
elif message.type in {
|
|
||||||
aiohttp.WSMsgType.CLOSED,
|
|
||||||
aiohttp.WSMsgType.CLOSE,
|
|
||||||
aiohttp.WSMsgType.ERROR,
|
|
||||||
}:
|
|
||||||
break
|
|
||||||
finally:
|
|
||||||
heartbeat_task.cancel()
|
|
||||||
try:
|
|
||||||
await heartbeat_task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
self._ws = None
|
|
||||||
|
|
||||||
async def _subscribe(self) -> None:
|
|
||||||
"""发送 aibot_subscribe,并等待响应。"""
|
|
||||||
req_id = self.gen_req_id()
|
|
||||||
payload = {
|
|
||||||
"cmd": "aibot_subscribe",
|
|
||||||
"headers": {"req_id": req_id},
|
|
||||||
"body": {"bot_id": self.bot_id, "secret": self.secret},
|
|
||||||
}
|
|
||||||
await self._send_json(payload)
|
|
||||||
|
|
||||||
if not self._ws:
|
|
||||||
raise RuntimeError("WebSocket 未建立")
|
|
||||||
|
|
||||||
reply = await self._ws.receive(timeout=10)
|
|
||||||
if reply.type != aiohttp.WSMsgType.TEXT:
|
|
||||||
raise RuntimeError(f"订阅失败: 非文本响应 {reply.type}")
|
|
||||||
|
|
||||||
data = json.loads(reply.data)
|
|
||||||
if data.get("errcode") != 0:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"订阅失败 errcode={data.get('errcode')} errmsg={data.get('errmsg')}"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _heartbeat_loop(self) -> None:
|
|
||||||
while not self._shutdown_event.is_set():
|
|
||||||
await asyncio.sleep(self.heartbeat_interval)
|
|
||||||
if self._shutdown_event.is_set():
|
|
||||||
break
|
|
||||||
try:
|
|
||||||
await self.send_command("ping", self.gen_req_id(), None)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning("[WecomAI][LongConn] 发送心跳失败: %s", e)
|
|
||||||
return
|
|
||||||
|
|
||||||
async def _handle_text_message(self, text: str) -> None:
|
|
||||||
try:
|
|
||||||
payload = json.loads(text)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
logger.warning("[WecomAI][LongConn] 收到非 JSON 消息: %s", text)
|
|
||||||
return
|
|
||||||
|
|
||||||
headers = payload.get("headers") or {}
|
|
||||||
req_id = headers.get("req_id")
|
|
||||||
if isinstance(req_id, str):
|
|
||||||
waiter = self._response_waiters.get(req_id)
|
|
||||||
if waiter and not waiter.done():
|
|
||||||
waiter.set_result(payload)
|
|
||||||
return
|
|
||||||
|
|
||||||
cmd = payload.get("cmd")
|
|
||||||
if cmd in {"aibot_msg_callback", "aibot_event_callback"}:
|
|
||||||
await self.message_handler(payload)
|
|
||||||
return
|
|
||||||
|
|
||||||
if payload.get("errcode") not in (None, 0):
|
|
||||||
logger.warning(
|
|
||||||
"[WecomAI][LongConn] 服务端返回错误: errcode=%s errmsg=%s",
|
|
||||||
payload.get("errcode"),
|
|
||||||
payload.get("errmsg"),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_command(
|
|
||||||
self,
|
|
||||||
cmd: str,
|
|
||||||
req_id: str,
|
|
||||||
body: dict[str, Any] | None,
|
|
||||||
) -> bool:
|
|
||||||
"""发送长连接命令。"""
|
|
||||||
headers = {"req_id": req_id}
|
|
||||||
payload: dict[str, Any] = {"cmd": cmd, "headers": headers}
|
|
||||||
if body is not None:
|
|
||||||
payload["body"] = body
|
|
||||||
|
|
||||||
async with self._command_lock:
|
|
||||||
max_retries = 3
|
|
||||||
for attempt in range(max_retries + 1):
|
|
||||||
response = await self._send_and_wait_response(req_id, payload)
|
|
||||||
if not response:
|
|
||||||
if attempt < max_retries:
|
|
||||||
await asyncio.sleep(min(0.2 * (2**attempt), 2.0))
|
|
||||||
continue
|
|
||||||
return False
|
|
||||||
|
|
||||||
errcode = response.get("errcode")
|
|
||||||
if errcode in (0, None):
|
|
||||||
return True
|
|
||||||
|
|
||||||
if errcode == 6000 and attempt < max_retries:
|
|
||||||
backoff = min(0.2 * (2**attempt), 2.0)
|
|
||||||
logger.warning(
|
|
||||||
"[WecomAI][LongConn] 命令冲突(errcode=6000),将重试。cmd=%s req_id=%s attempt=%d",
|
|
||||||
cmd,
|
|
||||||
req_id,
|
|
||||||
attempt + 1,
|
|
||||||
)
|
|
||||||
await asyncio.sleep(backoff)
|
|
||||||
continue
|
|
||||||
|
|
||||||
logger.warning(
|
|
||||||
"[WecomAI][LongConn] 命令失败: cmd=%s req_id=%s errcode=%s errmsg=%s",
|
|
||||||
cmd,
|
|
||||||
req_id,
|
|
||||||
errcode,
|
|
||||||
response.get("errmsg"),
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def _send_and_wait_response(
|
|
||||||
self,
|
|
||||||
req_id: str,
|
|
||||||
payload: dict[str, Any],
|
|
||||||
timeout: float = 10.0,
|
|
||||||
) -> dict[str, Any] | None:
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
waiter: asyncio.Future[dict[str, Any]] = loop.create_future()
|
|
||||||
self._response_waiters[req_id] = waiter
|
|
||||||
try:
|
|
||||||
await self._send_json(payload)
|
|
||||||
return await asyncio.wait_for(waiter, timeout=timeout)
|
|
||||||
except TimeoutError:
|
|
||||||
logger.warning(
|
|
||||||
"[WecomAI][LongConn] 等待命令响应超时: cmd=%s req_id=%s",
|
|
||||||
payload.get("cmd"),
|
|
||||||
req_id,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
finally:
|
|
||||||
self._response_waiters.pop(req_id, None)
|
|
||||||
|
|
||||||
async def _send_json(self, payload: dict[str, Any]) -> None:
|
|
||||||
ws = self._ws
|
|
||||||
if ws is None or ws.closed:
|
|
||||||
raise RuntimeError("长连接尚未建立")
|
|
||||||
async with self._send_lock:
|
|
||||||
await ws.send_json(payload)
|
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
|
||||||
self._shutdown_event.set()
|
|
||||||
ws = self._ws
|
|
||||||
if ws is not None and not ws.closed:
|
|
||||||
await ws.close()
|
|
||||||
|
|
||||||
session = self._session
|
|
||||||
if session is not None and not session.closed:
|
|
||||||
await session.close()
|
|
||||||
@@ -4,7 +4,6 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -83,7 +82,7 @@ class WecomAIQueueMgr:
|
|||||||
del self.pending_responses[session_id]
|
del self.pending_responses[session_id]
|
||||||
logger.debug(f"[WecomAI] 移除待处理响应: {session_id}")
|
logger.debug(f"[WecomAI] 移除待处理响应: {session_id}")
|
||||||
if mark_finished:
|
if mark_finished:
|
||||||
self.completed_streams[session_id] = time.monotonic()
|
self.completed_streams[session_id] = asyncio.get_event_loop().time()
|
||||||
logger.debug(f"[WecomAI] 标记流已结束: {session_id}")
|
logger.debug(f"[WecomAI] 标记流已结束: {session_id}")
|
||||||
|
|
||||||
def remove_queue(self, session_id: str):
|
def remove_queue(self, session_id: str):
|
||||||
@@ -136,7 +135,7 @@ class WecomAIQueueMgr:
|
|||||||
"""
|
"""
|
||||||
self.pending_responses[session_id] = {
|
self.pending_responses[session_id] = {
|
||||||
"callback_params": callback_params,
|
"callback_params": callback_params,
|
||||||
"timestamp": time.monotonic(),
|
"timestamp": asyncio.get_event_loop().time(),
|
||||||
}
|
}
|
||||||
logger.debug(f"[WecomAI] 设置待处理响应: {session_id}")
|
logger.debug(f"[WecomAI] 设置待处理响应: {session_id}")
|
||||||
|
|
||||||
@@ -161,7 +160,7 @@ class WecomAIQueueMgr:
|
|||||||
finished_at = self.completed_streams.get(session_id)
|
finished_at = self.completed_streams.get(session_id)
|
||||||
if finished_at is None:
|
if finished_at is None:
|
||||||
return False
|
return False
|
||||||
if time.monotonic() - finished_at > max_age_seconds:
|
if asyncio.get_event_loop().time() - finished_at > max_age_seconds:
|
||||||
self.completed_streams.pop(session_id, None)
|
self.completed_streams.pop(session_id, None)
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
@@ -173,7 +172,7 @@ class WecomAIQueueMgr:
|
|||||||
max_age_seconds: 最大存活时间(秒)
|
max_age_seconds: 最大存活时间(秒)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
current_time = time.monotonic()
|
current_time = asyncio.get_event_loop().time()
|
||||||
expired_sessions = []
|
expired_sessions = []
|
||||||
|
|
||||||
for session_id, response_data in self.pending_responses.items():
|
for session_id, response_data in self.pending_responses.items():
|
||||||
|
|||||||
@@ -369,7 +369,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
|||||||
if future:
|
if future:
|
||||||
logger.debug(f"duplicate message id checked: {msg.id}")
|
logger.debug(f"duplicate message id checked: {msg.id}")
|
||||||
else:
|
else:
|
||||||
future = asyncio.get_running_loop().create_future()
|
future = asyncio.get_event_loop().create_future()
|
||||||
self.wexin_event_workers[msg_id] = future
|
self.wexin_event_workers[msg_id] = future
|
||||||
await self.convert_message(msg, future)
|
await self.convert_message(msg, future)
|
||||||
# I love shield so much!
|
# I love shield so much!
|
||||||
@@ -461,7 +461,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
|||||||
elif msg.type == "voice":
|
elif msg.type == "voice":
|
||||||
assert isinstance(msg, VoiceMessage)
|
assert isinstance(msg, VoiceMessage)
|
||||||
|
|
||||||
resp: Response = await asyncio.get_running_loop().run_in_executor(
|
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
||||||
None,
|
None,
|
||||||
self.client.media.download,
|
self.client.media.download,
|
||||||
msg.media_id,
|
msg.media_id,
|
||||||
|
|||||||
@@ -21,8 +21,8 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
|||||||
|
|
||||||
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
|
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
|
||||||
|
|
||||||
DEFAULT_MCP_INIT_TIMEOUT_SECONDS = 180.0
|
DEFAULT_MCP_INIT_TIMEOUT_SECONDS = 20.0
|
||||||
DEFAULT_ENABLE_MCP_TIMEOUT_SECONDS = 180.0
|
DEFAULT_ENABLE_MCP_TIMEOUT_SECONDS = 30.0
|
||||||
MCP_INIT_TIMEOUT_ENV = "ASTRBOT_MCP_INIT_TIMEOUT"
|
MCP_INIT_TIMEOUT_ENV = "ASTRBOT_MCP_INIT_TIMEOUT"
|
||||||
ENABLE_MCP_TIMEOUT_ENV = "ASTRBOT_MCP_ENABLE_TIMEOUT"
|
ENABLE_MCP_TIMEOUT_ENV = "ASTRBOT_MCP_ENABLE_TIMEOUT"
|
||||||
MAX_MCP_TIMEOUT_SECONDS = 300.0
|
MAX_MCP_TIMEOUT_SECONDS = 300.0
|
||||||
@@ -417,11 +417,9 @@ class FunctionToolManager:
|
|||||||
for (name, cfg, _), result in zip(active_configs, results, strict=False):
|
for (name, cfg, _), result in zip(active_configs, results, strict=False):
|
||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
if isinstance(result, MCPInitTimeoutError):
|
if isinstance(result, MCPInitTimeoutError):
|
||||||
logger.error(
|
logger.error(f"MCP 服务 {name} 初始化超时({timeout_display}秒)")
|
||||||
f"Connected to MCP server {name} timeout ({timeout_display} seconds)"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
logger.error(f"Failed to initialize MCP server {name}: {result}")
|
logger.error(f"MCP 服务 {name} 初始化失败: {result}")
|
||||||
self._log_safe_mcp_debug_config(cfg)
|
self._log_safe_mcp_debug_config(cfg)
|
||||||
failed_services.append(name)
|
failed_services.append(name)
|
||||||
async with self._runtime_lock:
|
async with self._runtime_lock:
|
||||||
@@ -432,18 +430,16 @@ class FunctionToolManager:
|
|||||||
|
|
||||||
if failed_services:
|
if failed_services:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"The following MCP services failed to initialize: {', '.join(failed_services)}. "
|
f"以下 MCP 服务初始化失败: {', '.join(failed_services)}。"
|
||||||
f"Please check the mcp_server.json file and server availability."
|
f"请检查配置文件 mcp_server.json 和服务器可用性。"
|
||||||
)
|
)
|
||||||
|
|
||||||
summary = MCPInitSummary(
|
summary = MCPInitSummary(
|
||||||
total=len(active_configs), success=success_count, failed=failed_services
|
total=len(active_configs), success=success_count, failed=failed_services
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(f"MCP 服务初始化完成: {summary.success}/{summary.total} 成功")
|
||||||
f"MCP services initialization completed: {summary.success}/{summary.total} successful, {len(summary.failed)} failed."
|
|
||||||
)
|
|
||||||
if summary.total > 0 and summary.success == 0:
|
if summary.total > 0 and summary.success == 0:
|
||||||
msg = "All MCP services failed to initialize, please check the mcp_server.json and server availability."
|
msg = "全部 MCP 服务初始化失败,请检查 mcp_server.json 配置和服务器可用性。"
|
||||||
if raise_on_all_failed:
|
if raise_on_all_failed:
|
||||||
raise MCPAllServicesFailedError(msg)
|
raise MCPAllServicesFailedError(msg)
|
||||||
logger.error(msg)
|
logger.error(msg)
|
||||||
@@ -465,7 +461,7 @@ class FunctionToolManager:
|
|||||||
async with self._runtime_lock:
|
async with self._runtime_lock:
|
||||||
if name in self._mcp_server_runtime or name in self._mcp_starting:
|
if name in self._mcp_server_runtime or name in self._mcp_starting:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Connected to MCP server {name}, ignoring this startup request (timeout={timeout:g})."
|
f"MCP 服务 {name} 已在运行,忽略本次启用请求(timeout={timeout:g})。"
|
||||||
)
|
)
|
||||||
self._log_safe_mcp_debug_config(cfg)
|
self._log_safe_mcp_debug_config(cfg)
|
||||||
return
|
return
|
||||||
@@ -482,10 +478,10 @@ class FunctionToolManager:
|
|||||||
)
|
)
|
||||||
except asyncio.TimeoutError as exc:
|
except asyncio.TimeoutError as exc:
|
||||||
raise MCPInitTimeoutError(
|
raise MCPInitTimeoutError(
|
||||||
f"Connected to MCP server {name} timeout ({timeout:g} seconds)"
|
f"MCP 服务 {name} 初始化超时({timeout:g} 秒)"
|
||||||
) from exc
|
) from exc
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.error(f"Failed to initialize MCP client {name}", exc_info=True)
|
logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True)
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
if mcp_client is None:
|
if mcp_client is None:
|
||||||
@@ -495,9 +491,9 @@ class FunctionToolManager:
|
|||||||
async def lifecycle() -> None:
|
async def lifecycle() -> None:
|
||||||
try:
|
try:
|
||||||
await shutdown_event.wait()
|
await shutdown_event.wait()
|
||||||
logger.info(f"Received shutdown signal for MCP client {name}")
|
logger.info(f"收到 MCP 客户端 {name} 终止信号")
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.debug(f"MCP client {name} task was cancelled")
|
logger.debug(f"MCP 客户端 {name} 任务被取消")
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
await self._terminate_mcp_client(name)
|
await self._terminate_mcp_client(name)
|
||||||
@@ -549,7 +545,7 @@ class FunctionToolManager:
|
|||||||
if strict:
|
if strict:
|
||||||
raise MCPShutdownTimeoutError(pending_names, timeout)
|
raise MCPShutdownTimeoutError(pending_names, timeout)
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"MCP server shutdown timeout (%s seconds), the following servers were not fully closed: %s",
|
"MCP 服务关闭超时(%s 秒),以下服务未完全关闭:%s",
|
||||||
f"{timeout:g}",
|
f"{timeout:g}",
|
||||||
", ".join(pending_names),
|
", ".join(pending_names),
|
||||||
)
|
)
|
||||||
@@ -572,9 +568,7 @@ class FunctionToolManager:
|
|||||||
try:
|
try:
|
||||||
await mcp_client.cleanup()
|
await mcp_client.cleanup()
|
||||||
except Exception as cleanup_exc: # noqa: BLE001 - only log here
|
except Exception as cleanup_exc: # noqa: BLE001 - only log here
|
||||||
logger.error(
|
logger.error(f"清理 MCP 客户端资源 {name} 失败: {cleanup_exc}")
|
||||||
f"Failed to cleanup MCP client resources {name}: {cleanup_exc}"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _init_mcp_client(self, name: str, config: dict) -> MCPClient:
|
async def _init_mcp_client(self, name: str, config: dict) -> MCPClient:
|
||||||
"""初始化单个MCP客户端"""
|
"""初始化单个MCP客户端"""
|
||||||
@@ -608,7 +602,7 @@ class FunctionToolManager:
|
|||||||
)
|
)
|
||||||
self.func_list.append(func_tool)
|
self.func_list.append(func_tool)
|
||||||
|
|
||||||
logger.info(f"Connected to MCP server {name}, Tools: {tool_names}")
|
logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}")
|
||||||
return mcp_client
|
return mcp_client
|
||||||
|
|
||||||
async def _terminate_mcp_client(self, name: str) -> None:
|
async def _terminate_mcp_client(self, name: str) -> None:
|
||||||
@@ -628,7 +622,7 @@ class FunctionToolManager:
|
|||||||
async with self._runtime_lock:
|
async with self._runtime_lock:
|
||||||
self._mcp_server_runtime.pop(name, None)
|
self._mcp_server_runtime.pop(name, None)
|
||||||
self._mcp_starting.discard(name)
|
self._mcp_starting.discard(name)
|
||||||
logger.info(f"Disconnected from MCP server {name}")
|
logger.info(f"已关闭 MCP 服务 {name}")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Runtime missing but stale tools may still exist after failed flows.
|
# Runtime missing but stale tools may still exist after failed flows.
|
||||||
|
|||||||
@@ -79,7 +79,6 @@ class ProviderManager:
|
|||||||
self._provider_change_hooks: list[
|
self._provider_change_hooks: list[
|
||||||
Callable[[str, ProviderType, str | None], None]
|
Callable[[str, ProviderType, str | None], None]
|
||||||
] = []
|
] = []
|
||||||
self._mcp_init_task: asyncio.Task | None = None
|
|
||||||
|
|
||||||
def set_provider_change_callback(
|
def set_provider_change_callback(
|
||||||
self,
|
self,
|
||||||
@@ -331,16 +330,24 @@ class ProviderManager:
|
|||||||
if not self.curr_tts_provider_inst and self.tts_provider_insts:
|
if not self.curr_tts_provider_inst and self.tts_provider_insts:
|
||||||
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||||||
|
|
||||||
async def _init_mcp_clients_bg() -> None:
|
# 初始化 MCP Client 连接(等待完成以确保工具可用)
|
||||||
try:
|
strict_mcp_init = os.getenv("ASTRBOT_MCP_INIT_STRICT", "").strip().lower() in {
|
||||||
await self.llm_tools.init_mcp_clients()
|
"1",
|
||||||
except Exception:
|
"true",
|
||||||
logger.error("MCP init background task failed", exc_info=True)
|
"yes",
|
||||||
|
"on",
|
||||||
if self._mcp_init_task is None or self._mcp_init_task.done():
|
}
|
||||||
self._mcp_init_task = asyncio.create_task(
|
mcp_init_summary = await self.llm_tools.init_mcp_clients(
|
||||||
_init_mcp_clients_bg(),
|
raise_on_all_failed=strict_mcp_init
|
||||||
name="provider-manager:mcp-init",
|
)
|
||||||
|
if (
|
||||||
|
mcp_init_summary.total > 0
|
||||||
|
and mcp_init_summary.success == 0
|
||||||
|
and not strict_mcp_init
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
"MCP 服务全部初始化失败,系统将继续启动(可设置 "
|
||||||
|
"ASTRBOT_MCP_INIT_STRICT=1 以在此场景下中止启动)。"
|
||||||
)
|
)
|
||||||
|
|
||||||
def dynamic_import_provider(self, type: str) -> None:
|
def dynamic_import_provider(self, type: str) -> None:
|
||||||
@@ -808,17 +815,8 @@ class ProviderManager:
|
|||||||
config.save_config()
|
config.save_config()
|
||||||
# load instance
|
# load instance
|
||||||
await self.load_provider(new_config)
|
await self.load_provider(new_config)
|
||||||
# sync in-memory config for API queries (e.g., embedding provider list)
|
|
||||||
self.providers_config = astrbot_config["provider"]
|
|
||||||
|
|
||||||
async def terminate(self) -> None:
|
async def terminate(self) -> None:
|
||||||
if self._mcp_init_task and not self._mcp_init_task.done():
|
|
||||||
self._mcp_init_task.cancel()
|
|
||||||
try:
|
|
||||||
await self._mcp_init_task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
for provider_inst in self.provider_insts:
|
for provider_inst in self.provider_insts:
|
||||||
if hasattr(provider_inst, "terminate"):
|
if hasattr(provider_inst, "terminate"):
|
||||||
await provider_inst.terminate() # type: ignore
|
await provider_inst.terminate() # type: ignore
|
||||||
|
|||||||
@@ -281,24 +281,7 @@ class TTSProvider(AbstractProvider):
|
|||||||
accumulated_text += text_part
|
accumulated_text += text_part
|
||||||
|
|
||||||
async def test(self) -> None:
|
async def test(self) -> None:
|
||||||
audio_path = await self.get_audio("hi")
|
await self.get_audio("hi")
|
||||||
|
|
||||||
# 检查生成的音频文件是否有效
|
|
||||||
if not os.path.exists(audio_path):
|
|
||||||
raise Exception("TTS test failed: audio file was not created")
|
|
||||||
|
|
||||||
file_size = os.path.getsize(audio_path)
|
|
||||||
if file_size == 0:
|
|
||||||
raise Exception(
|
|
||||||
"TTS test failed: generated audio file is empty (0 bytes). "
|
|
||||||
"Please check your TTS provider configuration, especially required parameters like group_id for MiniMax."
|
|
||||||
)
|
|
||||||
|
|
||||||
# 清理测试文件
|
|
||||||
try:
|
|
||||||
os.remove(audio_path)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingProvider(AbstractProvider):
|
class EmbeddingProvider(AbstractProvider):
|
||||||
|
|||||||
@@ -276,24 +276,9 @@ class ProviderAnthropic(Provider):
|
|||||||
llm_response.id = completion.id
|
llm_response.id = completion.id
|
||||||
llm_response.usage = self._extract_usage(completion.usage)
|
llm_response.usage = self._extract_usage(completion.usage)
|
||||||
|
|
||||||
# Handle cases where completion only contains ThinkingBlock (e.g., MiniMax max_tokens)
|
# TODO(Soulter): 处理 end_turn 情况
|
||||||
# When stop_reason='max_tokens', the model may return only thinking content
|
|
||||||
# This is valid and should not raise an exception
|
|
||||||
if not llm_response.completion_text and not llm_response.tools_call_args:
|
if not llm_response.completion_text and not llm_response.tools_call_args:
|
||||||
# Guard clause: raise early if no valid content at all
|
raise Exception(f"Anthropic API 返回的 completion 无法解析:{completion}。")
|
||||||
if not llm_response.reasoning_content:
|
|
||||||
raise ValueError(
|
|
||||||
f"Anthropic API returned unparsable completion: "
|
|
||||||
f"no text, tool_use, or thinking content found. "
|
|
||||||
f"Completion: {completion}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# We have reasoning content (ThinkingBlock) - this is valid
|
|
||||||
stop_reason = getattr(completion, "stop_reason", "unknown")
|
|
||||||
logger.debug(
|
|
||||||
f"Completion contains only ThinkingBlock (stop_reason={stop_reason})"
|
|
||||||
)
|
|
||||||
llm_response.completion_text = "" # Ensure empty string, not None
|
|
||||||
|
|
||||||
return llm_response
|
return llm_response
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ from ..register import register_provider_adapter
|
|||||||
|
|
||||||
TEMP_DIR = Path(get_astrbot_temp_path()) / "azure_tts"
|
TEMP_DIR = Path(get_astrbot_temp_path()) / "azure_tts"
|
||||||
TEMP_DIR.mkdir(parents=True, exist_ok=True)
|
TEMP_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
AZURE_TTS_SUBSCRIPTION_KEY_PATTERN = r"^(?:[a-zA-Z0-9]{32}|[a-zA-Z0-9]{84})$"
|
|
||||||
|
|
||||||
|
|
||||||
class OTTSProvider:
|
class OTTSProvider:
|
||||||
@@ -117,7 +116,7 @@ class AzureNativeProvider(TTSProvider):
|
|||||||
"azure_tts_subscription_key",
|
"azure_tts_subscription_key",
|
||||||
"",
|
"",
|
||||||
).strip()
|
).strip()
|
||||||
if not re.fullmatch(AZURE_TTS_SUBSCRIPTION_KEY_PATTERN, self.subscription_key):
|
if not re.fullmatch(r"^[a-zA-Z0-9]{32}$", self.subscription_key):
|
||||||
raise ValueError("无效的Azure订阅密钥")
|
raise ValueError("无效的Azure订阅密钥")
|
||||||
self.region = provider_config.get("azure_tts_region", "eastus").strip()
|
self.region = provider_config.get("azure_tts_region", "eastus").strip()
|
||||||
self.endpoint = (
|
self.endpoint = (
|
||||||
@@ -236,9 +235,9 @@ class AzureTTSProvider(TTSProvider):
|
|||||||
raise ValueError(error_msg) from e
|
raise ValueError(error_msg) from e
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
raise ValueError(f"配置错误: 缺少必要参数 {e}") from e
|
raise ValueError(f"配置错误: 缺少必要参数 {e}") from e
|
||||||
if re.fullmatch(AZURE_TTS_SUBSCRIPTION_KEY_PATTERN, key_value):
|
if re.fullmatch(r"^[a-zA-Z0-9]{32}$", key_value):
|
||||||
return AzureNativeProvider(config, self.provider_settings)
|
return AzureNativeProvider(config, self.provider_settings)
|
||||||
raise ValueError("订阅密钥格式无效,应为32位或84位字母数字或other[...]格式")
|
raise ValueError("订阅密钥格式无效,应为32位字母数字或other[...]格式")
|
||||||
|
|
||||||
async def get_audio(self, text: str) -> str:
|
async def get_audio(self, text: str) -> str:
|
||||||
if isinstance(self.provider, OTTSProvider):
|
if isinstance(self.provider, OTTSProvider):
|
||||||
|
|||||||
@@ -87,7 +87,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
|
|||||||
model: str,
|
model: str,
|
||||||
text: str,
|
text: str,
|
||||||
) -> tuple[bytes | None, str]:
|
) -> tuple[bytes | None, str]:
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_event_loop()
|
||||||
response = await loop.run_in_executor(None, self._call_qwen_tts, model, text)
|
response = await loop.run_in_executor(None, self._call_qwen_tts, model, text)
|
||||||
audio_bytes = await self._extract_audio_from_response(response)
|
audio_bytes = await self._extract_audio_from_response(response)
|
||||||
if not audio_bytes:
|
if not audio_bytes:
|
||||||
@@ -143,7 +143,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
|
|||||||
voice=self.voice,
|
voice=self.voice,
|
||||||
format=AudioFormat.WAV_24000HZ_MONO_16BIT,
|
format=AudioFormat.WAV_24000HZ_MONO_16BIT,
|
||||||
)
|
)
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_event_loop()
|
||||||
audio_bytes = await loop.run_in_executor(
|
audio_bytes = await loop.run_in_executor(
|
||||||
None,
|
None,
|
||||||
synthesizer.call,
|
synthesizer.call,
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ class GenieTTSProvider(TTSProvider):
|
|||||||
filename = f"genie_tts_{uuid.uuid4()}.wav"
|
filename = f"genie_tts_{uuid.uuid4()}.wav"
|
||||||
path = os.path.join(temp_dir, filename)
|
path = os.path.join(temp_dir, filename)
|
||||||
|
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
def _generate(save_path: str) -> None:
|
def _generate(save_path: str) -> None:
|
||||||
assert genie is not None
|
assert genie is not None
|
||||||
@@ -85,7 +85,7 @@ class GenieTTSProvider(TTSProvider):
|
|||||||
text_queue: asyncio.Queue[str | None],
|
text_queue: asyncio.Queue[str | None],
|
||||||
audio_queue: "asyncio.Queue[bytes | tuple[str, bytes] | None]",
|
audio_queue: "asyncio.Queue[bytes | tuple[str, bytes] | None]",
|
||||||
) -> None:
|
) -> None:
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
text = await text_queue.get()
|
text = await text_queue.get()
|
||||||
|
|||||||
@@ -13,11 +13,3 @@ class ProviderGroq(ProviderOpenAIOfficial):
|
|||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(provider_config, provider_settings)
|
super().__init__(provider_config, provider_settings)
|
||||||
self.reasoning_key = "reasoning"
|
self.reasoning_key = "reasoning"
|
||||||
|
|
||||||
def _finally_convert_payload(self, payloads: dict) -> None:
|
|
||||||
"""Groq rejects assistant history items that include reasoning_content."""
|
|
||||||
super()._finally_convert_payload(payloads)
|
|
||||||
for message in payloads.get("messages", []):
|
|
||||||
if message.get("role") == "assistant":
|
|
||||||
message.pop("reasoning_content", None)
|
|
||||||
message.pop("reasoning", None)
|
|
||||||
|
|||||||
@@ -154,14 +154,6 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
|
|||||||
audio_stream = self._call_tts_stream(text)
|
audio_stream = self._call_tts_stream(text)
|
||||||
audio = await self._audio_play(audio_stream)
|
audio = await self._audio_play(audio_stream)
|
||||||
|
|
||||||
# 检查音频数据是否为空
|
|
||||||
if not audio or len(audio) == 0:
|
|
||||||
raise Exception(
|
|
||||||
"MiniMax TTS API returned empty audio data. "
|
|
||||||
"Please verify your configuration, especially the 'group_id' parameter. "
|
|
||||||
"You can find your group_id in Account Management -> Basic Information on the MiniMax platform."
|
|
||||||
)
|
|
||||||
|
|
||||||
# 结果保存至文件
|
# 结果保存至文件
|
||||||
with open(path, "wb") as file:
|
with open(path, "wb") as file:
|
||||||
file.write(audio)
|
file.write(audio)
|
||||||
@@ -169,4 +161,4 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
|
|||||||
return path
|
return path
|
||||||
|
|
||||||
except aiohttp.ClientError as e:
|
except aiohttp.ClientError as e:
|
||||||
raise Exception(f"MiniMax TTS API request failed: {e!s}")
|
raise e
|
||||||
|
|||||||
@@ -311,7 +311,7 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
state.handle_chunk(chunk)
|
state.handle_chunk(chunk)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Saving chunk state error: " + str(e))
|
logger.warning("Saving chunk state error: " + str(e))
|
||||||
if not chunk.choices:
|
if len(chunk.choices) == 0:
|
||||||
continue
|
continue
|
||||||
delta = chunk.choices[0].delta
|
delta = chunk.choices[0].delta
|
||||||
# logger.debug(f"chunk delta: {delta}")
|
# logger.debug(f"chunk delta: {delta}")
|
||||||
@@ -322,7 +322,7 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
if reasoning:
|
if reasoning:
|
||||||
llm_response.reasoning_content = reasoning
|
llm_response.reasoning_content = reasoning
|
||||||
_y = True
|
_y = True
|
||||||
if delta and delta.content:
|
if delta.content:
|
||||||
# Don't strip streaming chunks to preserve spaces between words
|
# Don't strip streaming chunks to preserve spaces between words
|
||||||
completion_text = self._normalize_content(delta.content, strip=False)
|
completion_text = self._normalize_content(delta.content, strip=False)
|
||||||
llm_response.result_chain = MessageChain(
|
llm_response.result_chain = MessageChain(
|
||||||
@@ -345,7 +345,7 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
) -> str:
|
) -> str:
|
||||||
"""Extract reasoning content from OpenAI ChatCompletion if available."""
|
"""Extract reasoning content from OpenAI ChatCompletion if available."""
|
||||||
reasoning_text = ""
|
reasoning_text = ""
|
||||||
if not completion.choices:
|
if len(completion.choices) == 0:
|
||||||
return reasoning_text
|
return reasoning_text
|
||||||
if isinstance(completion, ChatCompletion):
|
if isinstance(completion, ChatCompletion):
|
||||||
choice = completion.choices[0]
|
choice = completion.choices[0]
|
||||||
@@ -468,7 +468,7 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
"""Parse OpenAI ChatCompletion into LLMResponse"""
|
"""Parse OpenAI ChatCompletion into LLMResponse"""
|
||||||
llm_response = LLMResponse("assistant")
|
llm_response = LLMResponse("assistant")
|
||||||
|
|
||||||
if not completion.choices:
|
if len(completion.choices) == 0:
|
||||||
raise Exception("API 返回的 completion 为空。")
|
raise Exception("API 返回的 completion 为空。")
|
||||||
choice = completion.choices[0]
|
choice = completion.choices[0]
|
||||||
|
|
||||||
@@ -629,8 +629,7 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
# 最后一次不等待
|
# 最后一次不等待
|
||||||
if retry_cnt < max_retries - 1:
|
if retry_cnt < max_retries - 1:
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
if chosen_key in available_api_keys:
|
available_api_keys.remove(chosen_key)
|
||||||
available_api_keys.remove(chosen_key)
|
|
||||||
if len(available_api_keys) > 0:
|
if len(available_api_keys) > 0:
|
||||||
chosen_key = random.choice(available_api_keys)
|
chosen_key = random.choice(available_api_keys)
|
||||||
return (
|
return (
|
||||||
|
|||||||
@@ -16,7 +16,4 @@ class ProviderOpenRouter(ProviderOpenAIOfficial):
|
|||||||
self.client._custom_headers["HTTP-Referer"] = ( # type: ignore
|
self.client._custom_headers["HTTP-Referer"] = ( # type: ignore
|
||||||
"https://github.com/AstrBotDevs/AstrBot"
|
"https://github.com/AstrBotDevs/AstrBot"
|
||||||
)
|
)
|
||||||
self.client._custom_headers["X-OpenRouter-Title"] = "AstrBot" # type: ignore
|
self.client._custom_headers["X-TITLE"] = "AstrBot" # type: ignore
|
||||||
self.client._custom_headers["X-OpenRouter-Categories"] = (
|
|
||||||
"general-chat,personal-agent" # type: ignore
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
|
|||||||
logger.info("下载或者加载 SenseVoice 模型中,这可能需要一些时间 ...")
|
logger.info("下载或者加载 SenseVoice 模型中,这可能需要一些时间 ...")
|
||||||
|
|
||||||
# 将模型加载放到线程池中执行
|
# 将模型加载放到线程池中执行
|
||||||
self.model = await asyncio.get_running_loop().run_in_executor(
|
self.model = await asyncio.get_event_loop().run_in_executor(
|
||||||
None,
|
None,
|
||||||
lambda: SenseVoiceSmall(self.model_name, quantize=True, batch_size=16),
|
lambda: SenseVoiceSmall(self.model_name, quantize=True, batch_size=16),
|
||||||
)
|
)
|
||||||
@@ -88,7 +88,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
|
|||||||
audio_url = output_path
|
audio_url = output_path
|
||||||
|
|
||||||
# 使用 run_in_executor 来调用模型进行识别
|
# 使用 run_in_executor 来调用模型进行识别
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_event_loop()
|
||||||
res = await loop.run_in_executor(
|
res = await loop.run_in_executor(
|
||||||
None, # 使用默认的线程池
|
None, # 使用默认的线程池
|
||||||
lambda: cast(SenseVoiceSmall, self.model)(
|
lambda: cast(SenseVoiceSmall, self.model)(
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
|
|||||||
self.model = None
|
self.model = None
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_event_loop()
|
||||||
logger.info("下载或者加载 Whisper 模型中,这可能需要一些时间 ...")
|
logger.info("下载或者加载 Whisper 模型中,这可能需要一些时间 ...")
|
||||||
self.model = await loop.run_in_executor(
|
self.model = await loop.run_in_executor(
|
||||||
None,
|
None,
|
||||||
@@ -50,7 +50,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
async def get_text(self, audio_url: str) -> str:
|
async def get_text(self, audio_url: str) -> str:
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
is_tencent = False
|
is_tencent = False
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ from __future__ import annotations
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import shlex
|
|
||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
import zipfile
|
import zipfile
|
||||||
@@ -80,59 +79,7 @@ def _parse_frontmatter_description(text: str) -> str:
|
|||||||
|
|
||||||
# Regex for sanitizing paths used in prompt examples — only allow
|
# Regex for sanitizing paths used in prompt examples — only allow
|
||||||
# safe path characters to prevent prompt injection via crafted skill paths.
|
# safe path characters to prevent prompt injection via crafted skill paths.
|
||||||
_SAFE_PATH_RE = re.compile(r"[^\w./ ,()'\-]", re.UNICODE)
|
_SAFE_PATH_RE = re.compile(r"[^A-Za-z0-9_./ -]")
|
||||||
_WINDOWS_DRIVE_PATH_RE = re.compile(r"^[A-Za-z]:(?:/|\\)")
|
|
||||||
_WINDOWS_UNC_PATH_RE = re.compile(r"^(//|\\\\)[^/\\]+[/\\][^/\\]+")
|
|
||||||
_CONTROL_CHARS_RE = re.compile(r"[\x00-\x1F\x7F]")
|
|
||||||
|
|
||||||
|
|
||||||
def _is_windows_prompt_path(path: str) -> bool:
|
|
||||||
if os.name != "nt":
|
|
||||||
return False
|
|
||||||
return bool(_WINDOWS_DRIVE_PATH_RE.match(path) or _WINDOWS_UNC_PATH_RE.match(path))
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_prompt_path_for_prompt(path: str) -> str:
|
|
||||||
if not path:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
if _WINDOWS_DRIVE_PATH_RE.match(path) or _WINDOWS_UNC_PATH_RE.match(path):
|
|
||||||
path = path.replace("\\", "/")
|
|
||||||
|
|
||||||
drive_prefix = ""
|
|
||||||
if _WINDOWS_DRIVE_PATH_RE.match(path):
|
|
||||||
drive_prefix = path[:2]
|
|
||||||
path = path[2:]
|
|
||||||
|
|
||||||
path = path.replace("`", "")
|
|
||||||
path = _CONTROL_CHARS_RE.sub("", path)
|
|
||||||
sanitized = _SAFE_PATH_RE.sub("", path)
|
|
||||||
return f"{drive_prefix}{sanitized}"
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_prompt_description(description: str) -> str:
|
|
||||||
description = description.replace("`", "")
|
|
||||||
description = _CONTROL_CHARS_RE.sub(" ", description)
|
|
||||||
description = " ".join(description.split())
|
|
||||||
return description
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_skill_display_name(name: str) -> str:
|
|
||||||
if _SKILL_NAME_RE.fullmatch(name):
|
|
||||||
return name
|
|
||||||
return "<invalid_skill_name>"
|
|
||||||
|
|
||||||
|
|
||||||
def _build_skill_read_command_example(path: str) -> str:
|
|
||||||
if path == "<skills_root>/<skill_name>/SKILL.md":
|
|
||||||
return f"cat {path}"
|
|
||||||
if _is_windows_prompt_path(path):
|
|
||||||
command = "type"
|
|
||||||
path_arg = f'"{path}"'
|
|
||||||
else:
|
|
||||||
command = "cat"
|
|
||||||
path_arg = shlex.quote(path)
|
|
||||||
return f"{command} {path_arg}"
|
|
||||||
|
|
||||||
|
|
||||||
def build_skills_prompt(skills: list[SkillInfo]) -> str:
|
def build_skills_prompt(skills: list[SkillInfo]) -> str:
|
||||||
@@ -145,37 +92,16 @@ def build_skills_prompt(skills: list[SkillInfo]) -> str:
|
|||||||
skills_lines: list[str] = []
|
skills_lines: list[str] = []
|
||||||
example_path = ""
|
example_path = ""
|
||||||
for skill in skills:
|
for skill in skills:
|
||||||
display_name = _sanitize_skill_display_name(skill.name)
|
|
||||||
|
|
||||||
description = skill.description or "No description"
|
description = skill.description or "No description"
|
||||||
if skill.source_type == "sandbox_only":
|
|
||||||
description = _sanitize_prompt_description(description)
|
|
||||||
if not description:
|
|
||||||
description = "Read SKILL.md for details."
|
|
||||||
|
|
||||||
if skill.source_type == "sandbox_only":
|
|
||||||
rendered_path = (
|
|
||||||
f"{str(SANDBOX_WORKSPACE_ROOT)}/{str(SANDBOX_SKILLS_ROOT)}/"
|
|
||||||
f"{display_name}/SKILL.md"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
rendered_path = _sanitize_prompt_path_for_prompt(skill.path)
|
|
||||||
if not rendered_path:
|
|
||||||
rendered_path = "<skills_root>/<skill_name>/SKILL.md"
|
|
||||||
|
|
||||||
skills_lines.append(
|
skills_lines.append(
|
||||||
f"- **{display_name}**: {description}\n File: `{rendered_path}`"
|
f"- **{skill.name}**: {description}\n File: `{skill.path}`"
|
||||||
)
|
)
|
||||||
if not example_path:
|
if not example_path:
|
||||||
example_path = rendered_path
|
example_path = skill.path
|
||||||
skills_block = "\n".join(skills_lines)
|
skills_block = "\n".join(skills_lines)
|
||||||
# Sanitize example_path — it may originate from sandbox cache (untrusted)
|
# Sanitize example_path — it may originate from sandbox cache (untrusted)
|
||||||
if example_path == "<skills_root>/<skill_name>/SKILL.md":
|
example_path = _SAFE_PATH_RE.sub("", example_path) if example_path else ""
|
||||||
example_path = "<skills_root>/<skill_name>/SKILL.md"
|
example_path = example_path or "<skills_root>/<skill_name>/SKILL.md"
|
||||||
else:
|
|
||||||
example_path = _sanitize_prompt_path_for_prompt(example_path)
|
|
||||||
example_path = example_path or "<skills_root>/<skill_name>/SKILL.md"
|
|
||||||
example_command = _build_skill_read_command_example(example_path)
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
"## Skills\n\n"
|
"## Skills\n\n"
|
||||||
@@ -193,9 +119,8 @@ def build_skills_prompt(skills: list[SkillInfo]) -> str:
|
|||||||
"*Never silently skip a matching skill* — either use it or briefly "
|
"*Never silently skip a matching skill* — either use it or briefly "
|
||||||
"explain why you chose not to.\n"
|
"explain why you chose not to.\n"
|
||||||
"3. **Mandatory grounding** — Before executing any skill you MUST "
|
"3. **Mandatory grounding** — Before executing any skill you MUST "
|
||||||
"first read its `SKILL.md` by running a shell command compatible "
|
"first read its `SKILL.md` by running a shell command with the "
|
||||||
"with the current runtime shell and using the **absolute path** "
|
f"**absolute path** shown above (e.g. `cat {example_path}`). "
|
||||||
f"shown above (e.g. `{example_command}`). "
|
|
||||||
"Never rely on memory or assumptions about a skill's content.\n"
|
"Never rely on memory or assumptions about a skill's content.\n"
|
||||||
"4. **Progressive disclosure** — Load only what is directly "
|
"4. **Progressive disclosure** — Load only what is directly "
|
||||||
"referenced from `SKILL.md`:\n"
|
"referenced from `SKILL.md`:\n"
|
||||||
|
|||||||
@@ -1,14 +1,12 @@
|
|||||||
"""插件的重载、启停、安装、卸载等操作。"""
|
"""插件的重载、启停、安装、卸载等操作。"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import contextlib
|
|
||||||
import functools
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
|
||||||
import traceback
|
import traceback
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
|
|
||||||
@@ -16,12 +14,7 @@ import yaml
|
|||||||
from packaging.specifiers import InvalidSpecifier, SpecifierSet
|
from packaging.specifiers import InvalidSpecifier, SpecifierSet
|
||||||
from packaging.version import InvalidVersion, Version
|
from packaging.version import InvalidVersion, Version
|
||||||
|
|
||||||
from astrbot.core import (
|
from astrbot.core import logger, pip_installer, sp
|
||||||
DependencyConflictError,
|
|
||||||
logger,
|
|
||||||
pip_installer,
|
|
||||||
sp,
|
|
||||||
)
|
|
||||||
from astrbot.core.agent.handoff import FunctionTool, HandoffTool
|
from astrbot.core.agent.handoff import FunctionTool, HandoffTool
|
||||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||||
from astrbot.core.config.default import VERSION
|
from astrbot.core.config.default import VERSION
|
||||||
@@ -31,13 +24,9 @@ from astrbot.core.utils.astrbot_path import (
|
|||||||
get_astrbot_config_path,
|
get_astrbot_config_path,
|
||||||
get_astrbot_path,
|
get_astrbot_path,
|
||||||
get_astrbot_plugin_path,
|
get_astrbot_plugin_path,
|
||||||
get_astrbot_temp_path,
|
|
||||||
)
|
)
|
||||||
from astrbot.core.utils.io import remove_dir
|
from astrbot.core.utils.io import remove_dir
|
||||||
from astrbot.core.utils.metrics import Metric
|
from astrbot.core.utils.metrics import Metric
|
||||||
from astrbot.core.utils.requirements_utils import (
|
|
||||||
plan_missing_requirements_install,
|
|
||||||
)
|
|
||||||
|
|
||||||
from . import StarMetadata
|
from . import StarMetadata
|
||||||
from .command_management import sync_command_configs
|
from .command_management import sync_command_configs
|
||||||
@@ -59,97 +48,6 @@ class PluginVersionIncompatibleError(Exception):
|
|||||||
"""Raised when plugin astrbot_version is incompatible with current AstrBot."""
|
"""Raised when plugin astrbot_version is incompatible with current AstrBot."""
|
||||||
|
|
||||||
|
|
||||||
class PluginDependencyInstallError(Exception):
|
|
||||||
"""Raised when plugin dependency installation fails."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
plugin_label: str,
|
|
||||||
requirements_path: str,
|
|
||||||
error: Exception,
|
|
||||||
) -> None:
|
|
||||||
message = f"插件 {plugin_label} 依赖安装失败: {error!s}"
|
|
||||||
super().__init__(message)
|
|
||||||
self.plugin_label = plugin_label
|
|
||||||
self.requirements_path = requirements_path
|
|
||||||
self.error = error
|
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def _temporary_filtered_requirements_file(
|
|
||||||
*,
|
|
||||||
install_lines: tuple[str, ...],
|
|
||||||
):
|
|
||||||
filtered_requirements_path: str | None = None
|
|
||||||
temp_dir = get_astrbot_temp_path()
|
|
||||||
|
|
||||||
try:
|
|
||||||
os.makedirs(temp_dir, exist_ok=True)
|
|
||||||
with tempfile.NamedTemporaryFile(
|
|
||||||
mode="w",
|
|
||||||
suffix="_plugin_requirements.txt",
|
|
||||||
delete=False,
|
|
||||||
dir=temp_dir,
|
|
||||||
encoding="utf-8",
|
|
||||||
) as filtered_requirements_file:
|
|
||||||
filtered_requirements_file.write("\n".join(install_lines) + "\n")
|
|
||||||
filtered_requirements_path = filtered_requirements_file.name
|
|
||||||
|
|
||||||
yield filtered_requirements_path
|
|
||||||
finally:
|
|
||||||
if filtered_requirements_path and os.path.exists(filtered_requirements_path):
|
|
||||||
try:
|
|
||||||
os.remove(filtered_requirements_path)
|
|
||||||
except OSError as exc:
|
|
||||||
logger.warning(
|
|
||||||
"删除临时插件依赖文件失败:%s(路径:%s)",
|
|
||||||
exc,
|
|
||||||
filtered_requirements_path,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def _install_requirements_with_precheck(
|
|
||||||
*,
|
|
||||||
plugin_label: str,
|
|
||||||
requirements_path: str,
|
|
||||||
) -> None:
|
|
||||||
install_plan = plan_missing_requirements_install(requirements_path)
|
|
||||||
|
|
||||||
if install_plan is None:
|
|
||||||
logger.info(
|
|
||||||
f"正在安装插件 {plugin_label} 的依赖库(缺失依赖预检查不可裁剪,回退到完整安装): "
|
|
||||||
f"{requirements_path}"
|
|
||||||
)
|
|
||||||
await pip_installer.install(requirements_path=requirements_path)
|
|
||||||
return
|
|
||||||
|
|
||||||
if not install_plan.missing_names:
|
|
||||||
logger.info(f"插件 {plugin_label} 的依赖已满足,跳过安装。")
|
|
||||||
return
|
|
||||||
|
|
||||||
if not install_plan.install_lines:
|
|
||||||
fallback_reason = install_plan.fallback_reason or "unknown reason"
|
|
||||||
logger.info(
|
|
||||||
"检测到插件 %s 缺失依赖,但无法安全裁剪 requirements,回退到完整安装: %s (%s)",
|
|
||||||
plugin_label,
|
|
||||||
requirements_path,
|
|
||||||
fallback_reason,
|
|
||||||
)
|
|
||||||
await pip_installer.install(requirements_path=requirements_path)
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"检测到插件 {plugin_label} 缺失依赖,正在按 requirements.txt 安装: "
|
|
||||||
f"{requirements_path} -> {sorted(install_plan.missing_names)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
with _temporary_filtered_requirements_file(
|
|
||||||
install_lines=install_plan.install_lines,
|
|
||||||
) as filtered_requirements_path:
|
|
||||||
await pip_installer.install(requirements_path=filtered_requirements_path)
|
|
||||||
|
|
||||||
|
|
||||||
class PluginManager:
|
class PluginManager:
|
||||||
def __init__(self, context: Context, config: AstrBotConfig) -> None:
|
def __init__(self, context: Context, config: AstrBotConfig) -> None:
|
||||||
from .star_tools import StarTools
|
from .star_tools import StarTools
|
||||||
@@ -300,37 +198,15 @@ class PluginManager:
|
|||||||
to_update.append(p.root_dir_name)
|
to_update.append(p.root_dir_name)
|
||||||
for p in to_update:
|
for p in to_update:
|
||||||
plugin_path = os.path.join(plugin_dir, p)
|
plugin_path = os.path.join(plugin_dir, p)
|
||||||
await self._ensure_plugin_requirements(plugin_path, p)
|
if os.path.exists(os.path.join(plugin_path, "requirements.txt")):
|
||||||
|
pth = os.path.join(plugin_path, "requirements.txt")
|
||||||
|
logger.info(f"正在安装插件 {p} 所需的依赖库: {pth}")
|
||||||
|
try:
|
||||||
|
await pip_installer.install(requirements_path=pth)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"更新插件 {p} 的依赖失败。Code: {e!s}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def _ensure_plugin_requirements(
|
|
||||||
self,
|
|
||||||
plugin_dir_path: str,
|
|
||||||
plugin_label: str,
|
|
||||||
) -> None:
|
|
||||||
requirements_path = os.path.join(plugin_dir_path, "requirements.txt")
|
|
||||||
if not os.path.exists(requirements_path):
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
await _install_requirements_with_precheck(
|
|
||||||
plugin_label=plugin_label,
|
|
||||||
requirements_path=requirements_path,
|
|
||||||
)
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
raise
|
|
||||||
except DependencyConflictError as e:
|
|
||||||
logger.error(f"插件 {plugin_label} 依赖冲突: {e!s}")
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
dependency_error = PluginDependencyInstallError(
|
|
||||||
plugin_label=plugin_label,
|
|
||||||
requirements_path=requirements_path,
|
|
||||||
error=e,
|
|
||||||
)
|
|
||||||
logger.exception(str(dependency_error))
|
|
||||||
raise dependency_error from e
|
|
||||||
|
|
||||||
async def _import_plugin_with_dependency_recovery(
|
async def _import_plugin_with_dependency_recovery(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
@@ -546,7 +422,7 @@ class PluginManager:
|
|||||||
root_dir_name: str,
|
root_dir_name: str,
|
||||||
plugin_dir_path: str,
|
plugin_dir_path: str,
|
||||||
reserved: bool,
|
reserved: bool,
|
||||||
error: BaseException | str,
|
error: Exception | str,
|
||||||
error_trace: str,
|
error_trace: str,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
record: dict = {
|
record: dict = {
|
||||||
@@ -619,9 +495,6 @@ class PluginManager:
|
|||||||
|
|
||||||
self._cleanup_plugin_state(dir_name)
|
self._cleanup_plugin_state(dir_name)
|
||||||
|
|
||||||
plugin_path = os.path.join(self.plugin_store_path, dir_name)
|
|
||||||
await self._ensure_plugin_requirements(plugin_path, dir_name)
|
|
||||||
|
|
||||||
success, error = await self.load(specified_dir_name=dir_name)
|
success, error = await self.load(specified_dir_name=dir_name)
|
||||||
if success:
|
if success:
|
||||||
self.failed_plugin_dict.pop(dir_name, None)
|
self.failed_plugin_dict.pop(dir_name, None)
|
||||||
@@ -1205,10 +1078,6 @@ class PluginManager:
|
|||||||
|
|
||||||
# reload the plugin
|
# reload the plugin
|
||||||
dir_name = os.path.basename(plugin_path)
|
dir_name = os.path.basename(plugin_path)
|
||||||
await self._ensure_plugin_requirements(
|
|
||||||
plugin_path,
|
|
||||||
dir_name,
|
|
||||||
)
|
|
||||||
success, error_message = await self.load(
|
success, error_message = await self.load(
|
||||||
specified_dir_name=dir_name,
|
specified_dir_name=dir_name,
|
||||||
ignore_version_check=ignore_version_check,
|
ignore_version_check=ignore_version_check,
|
||||||
@@ -1448,12 +1317,6 @@ class PluginManager:
|
|||||||
raise Exception("该插件是 AstrBot 保留插件,无法更新。")
|
raise Exception("该插件是 AstrBot 保留插件,无法更新。")
|
||||||
|
|
||||||
await self.updator.update(plugin, proxy=proxy)
|
await self.updator.update(plugin, proxy=proxy)
|
||||||
if plugin.root_dir_name:
|
|
||||||
plugin_dir_path = os.path.join(self.plugin_store_path, plugin.root_dir_name)
|
|
||||||
await self._ensure_plugin_requirements(
|
|
||||||
plugin_dir_path,
|
|
||||||
plugin_name,
|
|
||||||
)
|
|
||||||
await self.reload(plugin_name)
|
await self.reload(plugin_name)
|
||||||
|
|
||||||
async def turn_off_plugin(self, plugin_name: str) -> None:
|
async def turn_off_plugin(self, plugin_name: str) -> None:
|
||||||
@@ -1511,23 +1374,10 @@ class PluginManager:
|
|||||||
return
|
return
|
||||||
|
|
||||||
if "__del__" in star_metadata.star_cls_type.__dict__:
|
if "__del__" in star_metadata.star_cls_type.__dict__:
|
||||||
loop = asyncio.get_running_loop()
|
asyncio.get_event_loop().run_in_executor(
|
||||||
future = loop.run_in_executor(
|
|
||||||
None,
|
None,
|
||||||
star_metadata.star_cls.__del__,
|
star_metadata.star_cls.__del__,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _log_del_exception(fut: asyncio.Future) -> None:
|
|
||||||
if fut.cancelled():
|
|
||||||
return
|
|
||||||
if (exc := fut.exception()) is not None:
|
|
||||||
logger.error(
|
|
||||||
"插件 %s 在 __del__ 中抛出了异常:%r",
|
|
||||||
star_metadata.name,
|
|
||||||
exc,
|
|
||||||
)
|
|
||||||
|
|
||||||
future.add_done_callback(_log_del_exception)
|
|
||||||
elif "terminate" in star_metadata.star_cls_type.__dict__:
|
elif "terminate" in star_metadata.star_cls_type.__dict__:
|
||||||
await star_metadata.star_cls.terminate()
|
await star_metadata.star_cls.terminate()
|
||||||
|
|
||||||
@@ -1625,7 +1475,6 @@ class PluginManager:
|
|||||||
os.remove(zip_file_path)
|
os.remove(zip_file_path)
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
logger.warning(f"删除插件压缩包失败: {e!s}")
|
logger.warning(f"删除插件压缩包失败: {e!s}")
|
||||||
await self._ensure_plugin_requirements(desti_dir, dir_name)
|
|
||||||
# await self.reload()
|
# await self.reload()
|
||||||
success, error_message = await self.load(
|
success, error_message = await self.load(
|
||||||
specified_dir_name=dir_name,
|
specified_dir_name=dir_name,
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ class CreateActiveCronTool(FunctionTool[AstrAgentContext]):
|
|||||||
"properties": {
|
"properties": {
|
||||||
"cron_expression": {
|
"cron_expression": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Cron expression defining recurring schedule (e.g., '0 8 * * *' or '0 23 * * mon-fri'). Prefer named weekdays like 'mon-fri' or 'sat,sun' instead of numeric day-of-week ranges such as '1-5' to avoid ambiguity across cron implementations.",
|
"description": "Cron expression defining recurring schedule (e.g., '0 8 * * *').",
|
||||||
},
|
},
|
||||||
"run_at": {
|
"run_at": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
|
|||||||
@@ -25,22 +25,12 @@ class UmopConfigRouter:
|
|||||||
)
|
)
|
||||||
self.umop_to_conf_id = sp_data
|
self.umop_to_conf_id = sp_data
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _split_umo(umo: str) -> tuple[str, str, str] | None:
|
|
||||||
"""将 UMO 拆分为 3 个部分,同时保留 session_id 中的 ':'"""
|
|
||||||
if not isinstance(umo, str):
|
|
||||||
return None
|
|
||||||
parts = umo.split(":", 2)
|
|
||||||
if len(parts) != 3:
|
|
||||||
return None
|
|
||||||
return parts[0], parts[1], parts[2]
|
|
||||||
|
|
||||||
def _is_umo_match(self, p1: str, p2: str) -> bool:
|
def _is_umo_match(self, p1: str, p2: str) -> bool:
|
||||||
"""判断 p2 umo 是否逻辑包含于 p1 umo"""
|
"""判断 p2 umo 是否逻辑包含于 p1 umo"""
|
||||||
p1_ls = self._split_umo(p1)
|
p1_ls = p1.split(":")
|
||||||
p2_ls = self._split_umo(p2)
|
p2_ls = p2.split(":")
|
||||||
|
|
||||||
if p1_ls is None or p2_ls is None:
|
if len(p1_ls) != 3 or len(p2_ls) != 3:
|
||||||
return False # 非法格式
|
return False # 非法格式
|
||||||
|
|
||||||
return all(p == "" or fnmatch.fnmatchcase(t, p) for p, t in zip(p1_ls, p2_ls))
|
return all(p == "" or fnmatch.fnmatchcase(t, p) for p, t in zip(p1_ls, p2_ls))
|
||||||
@@ -72,7 +62,7 @@ class UmopConfigRouter:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
for part in new_routing:
|
for part in new_routing:
|
||||||
if self._split_umo(part) is None:
|
if not isinstance(part, str) or len(part.split(":")) != 3:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"umop keys must be strings in the format [platform_id]:[message_type]:[session_id], with optional wildcards * or empty for all",
|
"umop keys must be strings in the format [platform_id]:[message_type]:[session_id], with optional wildcards * or empty for all",
|
||||||
)
|
)
|
||||||
@@ -91,7 +81,7 @@ class UmopConfigRouter:
|
|||||||
ValueError: 如果 umo 格式不正确
|
ValueError: 如果 umo 格式不正确
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if self._split_umo(umo) is None:
|
if not isinstance(umo, str) or len(umo.split(":")) != 3:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"umop must be a string in the format [platform_id]:[message_type]:[session_id], with optional wildcards * or empty for all",
|
"umop must be a string in the format [platform_id]:[message_type]:[session_id], with optional wildcards * or empty for all",
|
||||||
)
|
)
|
||||||
@@ -109,7 +99,7 @@ class UmopConfigRouter:
|
|||||||
ValueError: 当 umo 格式不正确时抛出
|
ValueError: 当 umo 格式不正确时抛出
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self._split_umo(umo) is None:
|
if not isinstance(umo, str) or len(umo.split(":")) != 3:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"umop must be a string in the format [platform_id]:[message_type]:[session_id], with optional wildcards * or empty for all",
|
"umop must be a string in the format [platform_id]:[message_type]:[session_id], with optional wildcards * or empty for all",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,121 +0,0 @@
|
|||||||
import contextlib
|
|
||||||
import functools
|
|
||||||
import importlib.metadata as importlib_metadata
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from collections.abc import Iterator
|
|
||||||
|
|
||||||
from packaging.requirements import Requirement
|
|
||||||
|
|
||||||
from astrbot.core.utils.requirements_utils import (
|
|
||||||
canonicalize_distribution_name,
|
|
||||||
collect_installed_distribution_versions,
|
|
||||||
get_requirement_check_paths,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger("astrbot")
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_core_dist_name(core_dist_name: str | None) -> str | None:
|
|
||||||
if core_dist_name:
|
|
||||||
try:
|
|
||||||
importlib_metadata.distribution(core_dist_name)
|
|
||||||
return core_dist_name
|
|
||||||
except importlib_metadata.PackageNotFoundError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
importlib_metadata.distribution("AstrBot")
|
|
||||||
return "AstrBot"
|
|
||||||
except importlib_metadata.PackageNotFoundError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if not __package__:
|
|
||||||
return None
|
|
||||||
|
|
||||||
top_pkg = __package__.split(".")[0]
|
|
||||||
for dist in importlib_metadata.distributions():
|
|
||||||
try:
|
|
||||||
top_level = dist.read_text("top_level.txt") or ""
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
if top_pkg in top_level.splitlines():
|
|
||||||
if "Name" in dist.metadata:
|
|
||||||
return dist.metadata["Name"]
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
@functools.cache
|
|
||||||
def _get_core_constraints(core_dist_name: str | None) -> tuple[str, ...]:
|
|
||||||
try:
|
|
||||||
resolved_core_dist_name = _resolve_core_dist_name(core_dist_name)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("解析核心分发名称失败: %s", exc)
|
|
||||||
return ()
|
|
||||||
|
|
||||||
if not resolved_core_dist_name:
|
|
||||||
return ()
|
|
||||||
|
|
||||||
try:
|
|
||||||
dist = importlib_metadata.distribution(resolved_core_dist_name)
|
|
||||||
except importlib_metadata.PackageNotFoundError:
|
|
||||||
return ()
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("读取核心分发元数据失败 (%s): %s", resolved_core_dist_name, exc)
|
|
||||||
return ()
|
|
||||||
|
|
||||||
if not dist or not dist.requires:
|
|
||||||
return ()
|
|
||||||
|
|
||||||
installed = collect_installed_distribution_versions(get_requirement_check_paths())
|
|
||||||
if not installed:
|
|
||||||
return ()
|
|
||||||
|
|
||||||
constraints: list[str] = []
|
|
||||||
for req_str in dist.requires:
|
|
||||||
try:
|
|
||||||
req = Requirement(req_str)
|
|
||||||
if req.marker and not req.marker.evaluate():
|
|
||||||
continue
|
|
||||||
name = canonicalize_distribution_name(req.name)
|
|
||||||
if name in installed:
|
|
||||||
constraints.append(f"{name}=={installed[name]}")
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
|
|
||||||
return tuple(constraints)
|
|
||||||
|
|
||||||
|
|
||||||
class CoreConstraintsProvider:
|
|
||||||
def __init__(self, core_dist_name: str | None) -> None:
|
|
||||||
self._core_dist_name = core_dist_name
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def constraints_file(self) -> Iterator[str | None]:
|
|
||||||
constraints = _get_core_constraints(self._core_dist_name)
|
|
||||||
if not constraints:
|
|
||||||
yield None
|
|
||||||
return
|
|
||||||
|
|
||||||
path: str | None = None
|
|
||||||
try:
|
|
||||||
import tempfile
|
|
||||||
|
|
||||||
with tempfile.NamedTemporaryFile(
|
|
||||||
mode="w", suffix="_constraints.txt", delete=False, encoding="utf-8"
|
|
||||||
) as f:
|
|
||||||
f.write("\n".join(constraints))
|
|
||||||
path = f.name
|
|
||||||
logger.info("已启用核心依赖版本保护 (%d 个约束)", len(constraints))
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("创建临时约束文件失败: %s", exc)
|
|
||||||
yield None
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
yield path
|
|
||||||
finally:
|
|
||||||
if path and os.path.exists(path):
|
|
||||||
with contextlib.suppress(Exception):
|
|
||||||
os.remove(path)
|
|
||||||
+103
-435
@@ -7,71 +7,21 @@ import io
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import shlex
|
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from dataclasses import dataclass
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
from astrbot.core.utils.astrbot_path import get_astrbot_site_packages_path
|
from astrbot.core.utils.astrbot_path import get_astrbot_site_packages_path
|
||||||
from astrbot.core.utils.core_constraints import CoreConstraintsProvider
|
|
||||||
from astrbot.core.utils.requirements_utils import (
|
|
||||||
canonicalize_distribution_name as _canonicalize_distribution_name,
|
|
||||||
)
|
|
||||||
from astrbot.core.utils.requirements_utils import (
|
|
||||||
extract_requirement_name,
|
|
||||||
extract_requirement_names,
|
|
||||||
parse_package_install_input,
|
|
||||||
)
|
|
||||||
from astrbot.core.utils.runtime_env import is_packaged_desktop_runtime
|
from astrbot.core.utils.runtime_env import is_packaged_desktop_runtime
|
||||||
|
|
||||||
logger = logging.getLogger("astrbot")
|
logger = logging.getLogger("astrbot")
|
||||||
|
|
||||||
_DISTLIB_FINDER_PATCH_ATTEMPTED = False
|
_DISTLIB_FINDER_PATCH_ATTEMPTED = False
|
||||||
_SITE_PACKAGES_IMPORT_LOCK = threading.RLock()
|
_SITE_PACKAGES_IMPORT_LOCK = threading.RLock()
|
||||||
_PIP_FAILURE_PATTERNS = {
|
|
||||||
"error_prefix": re.compile(r"^\s*error:", re.IGNORECASE),
|
|
||||||
"user_requested": re.compile(r"\bthe user requested\b", re.IGNORECASE),
|
|
||||||
"resolution_impossible": re.compile(r"\bresolutionimpossible\b", re.IGNORECASE),
|
|
||||||
"cannot_install": re.compile(r"\bcannot install\b", re.IGNORECASE),
|
|
||||||
"conflict": re.compile(r"\bconflict(?:ing|s)?\b", re.IGNORECASE),
|
|
||||||
"constraint": re.compile(r"\(constraint\)", re.IGNORECASE),
|
|
||||||
"dependency_detail": re.compile(r"\bdepends on\b", re.IGNORECASE),
|
|
||||||
}
|
|
||||||
_SENSITIVE_PIP_VALUE_KEYS = frozenset(
|
|
||||||
{"password", "passwd", "pass", "api_token", "token", "auth_token"}
|
|
||||||
)
|
|
||||||
_MAX_PIP_OUTPUT_LINES = 200
|
|
||||||
|
|
||||||
|
|
||||||
class DependencyConflictError(Exception):
|
def _canonicalize_distribution_name(name: str) -> str:
|
||||||
"""Raised when pip encounters a dependency conflict."""
|
return re.sub(r"[-_.]+", "-", name).strip("-").lower()
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, message: str, errors: list[str], *, is_core_conflict: bool
|
|
||||||
) -> None:
|
|
||||||
super().__init__(message)
|
|
||||||
self.errors = errors
|
|
||||||
self.is_core_conflict = is_core_conflict
|
|
||||||
|
|
||||||
|
|
||||||
class PipInstallError(Exception):
|
|
||||||
"""Raised when pip install fails without a classified dependency conflict."""
|
|
||||||
|
|
||||||
def __init__(self, message: str, *, code: int) -> None:
|
|
||||||
super().__init__(message)
|
|
||||||
self.code = code
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class PipConflictContext:
|
|
||||||
relevant_lines: list[str]
|
|
||||||
requested_lines: list[str]
|
|
||||||
dependency_detail_lines: list[str]
|
|
||||||
constraint_lines: list[str]
|
|
||||||
has_strong_conflict_signal: bool
|
|
||||||
has_contextual_conflict_signal: bool
|
|
||||||
|
|
||||||
|
|
||||||
def _get_pip_main():
|
def _get_pip_main():
|
||||||
@@ -91,12 +41,11 @@ def _get_pip_main():
|
|||||||
return pip_main
|
return pip_main
|
||||||
|
|
||||||
|
|
||||||
def _prepend_sys_path(path: str) -> None:
|
def _run_pip_main_with_output(pip_main, args: list[str]) -> tuple[int, str]:
|
||||||
normalized_target = os.path.realpath(path)
|
stream = io.StringIO()
|
||||||
sys.path[:] = [
|
with contextlib.redirect_stdout(stream), contextlib.redirect_stderr(stream):
|
||||||
item for item in sys.path if os.path.realpath(item) != normalized_target
|
result_code = pip_main(args)
|
||||||
]
|
return result_code, stream.getvalue()
|
||||||
sys.path.insert(0, normalized_target)
|
|
||||||
|
|
||||||
|
|
||||||
def _cleanup_added_root_handlers(original_handlers: list[logging.Handler]) -> None:
|
def _cleanup_added_root_handlers(original_handlers: list[logging.Handler]) -> None:
|
||||||
@@ -110,258 +59,76 @@ def _cleanup_added_root_handlers(original_handlers: list[logging.Handler]) -> No
|
|||||||
handler.close()
|
handler.close()
|
||||||
|
|
||||||
|
|
||||||
def _get_trusted_host_for_index_url(index_url: str) -> str | None:
|
def _prepend_sys_path(path: str) -> None:
|
||||||
parsed = urlparse(index_url if "://" in index_url else f"//{index_url}")
|
normalized_target = os.path.realpath(path)
|
||||||
host = parsed.hostname
|
sys.path[:] = [
|
||||||
if host == "mirrors.aliyun.com":
|
item for item in sys.path if os.path.realpath(item) != normalized_target
|
||||||
return host
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_sensitive_pip_key(raw_key: str) -> str:
|
|
||||||
return raw_key.lstrip("-").replace("-", "_").lower()
|
|
||||||
|
|
||||||
|
|
||||||
def _is_sensitive_pip_value_key(raw_key: str) -> bool:
|
|
||||||
return _normalize_sensitive_pip_key(raw_key) in _SENSITIVE_PIP_VALUE_KEYS
|
|
||||||
|
|
||||||
|
|
||||||
def _redact_url_credentials(raw_value: str) -> str:
|
|
||||||
"""Redact URL credentials and known inline secret values for safe logging."""
|
|
||||||
parsed = urlparse(raw_value)
|
|
||||||
if parsed.netloc and "@" in parsed.netloc:
|
|
||||||
hostname = parsed.hostname or ""
|
|
||||||
port = f":{parsed.port}" if parsed.port else ""
|
|
||||||
return parsed._replace(netloc=f"<redacted>@{hostname}{port}").geturl()
|
|
||||||
|
|
||||||
if raw_value.startswith("--"):
|
|
||||||
option, separator, _ = raw_value.partition("=")
|
|
||||||
if separator and _is_sensitive_pip_value_key(option):
|
|
||||||
return f"{option}=****"
|
|
||||||
return raw_value
|
|
||||||
|
|
||||||
key, separator, _ = raw_value.partition("=")
|
|
||||||
if separator and _is_sensitive_pip_value_key(key):
|
|
||||||
return f"{key}=****"
|
|
||||||
|
|
||||||
return raw_value
|
|
||||||
|
|
||||||
|
|
||||||
def _redact_pip_args_for_logging(args: list[str]) -> list[str]:
|
|
||||||
redacted_args: list[str] = []
|
|
||||||
redact_next_value = False
|
|
||||||
|
|
||||||
for arg in args:
|
|
||||||
if redact_next_value:
|
|
||||||
redacted_args.append("****")
|
|
||||||
redact_next_value = False
|
|
||||||
continue
|
|
||||||
|
|
||||||
if arg.startswith("--") and "=" in arg:
|
|
||||||
option, value = arg.split("=", 1)
|
|
||||||
if _is_sensitive_pip_value_key(option):
|
|
||||||
redacted_args.append(f"{option}=****")
|
|
||||||
else:
|
|
||||||
redacted_args.append(f"{option}={_redact_url_credentials(value)}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if arg.startswith("-i") and arg != "-i":
|
|
||||||
redacted_args.append(f"-i{_redact_url_credentials(arg[2:])}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if _is_sensitive_pip_value_key(arg):
|
|
||||||
redacted_args.append(arg)
|
|
||||||
redact_next_value = True
|
|
||||||
continue
|
|
||||||
|
|
||||||
redacted_args.append(_redact_url_credentials(arg))
|
|
||||||
|
|
||||||
return redacted_args
|
|
||||||
|
|
||||||
|
|
||||||
def _package_specs_override_index(package_specs: list[str]) -> bool:
|
|
||||||
for index, spec in enumerate(package_specs):
|
|
||||||
if spec == "--no-index":
|
|
||||||
return True
|
|
||||||
if spec in {"-i", "--index-url"}:
|
|
||||||
if index + 1 < len(package_specs):
|
|
||||||
return True
|
|
||||||
continue
|
|
||||||
if spec.startswith("--index-url="):
|
|
||||||
return True
|
|
||||||
if spec.startswith("-i") and spec != "-i":
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class _StreamingLogWriter(io.TextIOBase):
|
|
||||||
def __init__(self, log_func, *, max_lines: int | None = None) -> None:
|
|
||||||
self._log_func = log_func
|
|
||||||
self._lines = deque(maxlen=max_lines or _MAX_PIP_OUTPUT_LINES)
|
|
||||||
self._buffer = ""
|
|
||||||
|
|
||||||
def write(self, text: str) -> int:
|
|
||||||
if not text:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
self._buffer += text.replace("\r\n", "\n").replace("\r", "\n")
|
|
||||||
while "\n" in self._buffer:
|
|
||||||
raw_line, self._buffer = self._buffer.split("\n", 1)
|
|
||||||
line = raw_line.rstrip("\r\n")
|
|
||||||
self._log_func(line)
|
|
||||||
self._lines.append(line)
|
|
||||||
return len(text)
|
|
||||||
|
|
||||||
def flush(self) -> None:
|
|
||||||
line = self._buffer.rstrip("\r\n")
|
|
||||||
if line:
|
|
||||||
self._log_func(line)
|
|
||||||
self._lines.append(line)
|
|
||||||
self._buffer = ""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def lines(self) -> list[str]:
|
|
||||||
return list(self._lines)
|
|
||||||
|
|
||||||
|
|
||||||
def _run_pip_main_streaming(pip_main, args: list[str]) -> tuple[int, list[str]]:
|
|
||||||
stream = _StreamingLogWriter(logger.info, max_lines=_MAX_PIP_OUTPUT_LINES)
|
|
||||||
with (
|
|
||||||
contextlib.redirect_stdout(stream),
|
|
||||||
contextlib.redirect_stderr(stream),
|
|
||||||
):
|
|
||||||
result_code = pip_main(args)
|
|
||||||
stream.flush()
|
|
||||||
return result_code, stream.lines
|
|
||||||
|
|
||||||
|
|
||||||
def _matches_pip_failure_pattern(line: str, *pattern_names: str) -> bool:
|
|
||||||
names = pattern_names or tuple(_PIP_FAILURE_PATTERNS)
|
|
||||||
return any(_PIP_FAILURE_PATTERNS[name].search(line) for name in names)
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_conflict_detail_line(line: str) -> str:
|
|
||||||
stripped = line.strip()
|
|
||||||
if _matches_pip_failure_pattern(stripped, "user_requested"):
|
|
||||||
return re.sub(
|
|
||||||
r"^\s*The user requested\s+",
|
|
||||||
"",
|
|
||||||
stripped,
|
|
||||||
flags=re.IGNORECASE,
|
|
||||||
)
|
|
||||||
return stripped
|
|
||||||
|
|
||||||
|
|
||||||
def _build_pip_conflict_context(output_lines: list[str]) -> PipConflictContext | None:
|
|
||||||
matched_indices = [
|
|
||||||
index
|
|
||||||
for index, line in enumerate(output_lines)
|
|
||||||
if _matches_pip_failure_pattern(line)
|
|
||||||
]
|
]
|
||||||
if matched_indices:
|
sys.path.insert(0, normalized_target)
|
||||||
relevant_index_set: set[int] = set()
|
|
||||||
for index in matched_indices:
|
|
||||||
start = max(0, index - 1)
|
|
||||||
end = min(len(output_lines), index + 2)
|
|
||||||
relevant_index_set.update(range(start, end))
|
|
||||||
relevant_output_lines = [
|
|
||||||
line
|
|
||||||
for index, line in enumerate(output_lines)
|
|
||||||
if index in relevant_index_set
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
relevant_output_lines = output_lines[-5:]
|
|
||||||
|
|
||||||
if not relevant_output_lines:
|
|
||||||
|
def _module_exists_in_site_packages(module_name: str, site_packages_path: str) -> bool:
|
||||||
|
base_path = os.path.join(site_packages_path, *module_name.split("."))
|
||||||
|
package_init = os.path.join(base_path, "__init__.py")
|
||||||
|
module_file = f"{base_path}.py"
|
||||||
|
return os.path.isfile(package_init) or os.path.isfile(module_file)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_module_loaded_from_site_packages(
|
||||||
|
module_name: str,
|
||||||
|
site_packages_path: str,
|
||||||
|
) -> bool:
|
||||||
|
module = sys.modules.get(module_name)
|
||||||
|
if module is None:
|
||||||
|
try:
|
||||||
|
module = importlib.import_module(module_name)
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
module_file = getattr(module, "__file__", None)
|
||||||
|
if not module_file:
|
||||||
|
return False
|
||||||
|
|
||||||
|
module_path = os.path.realpath(module_file)
|
||||||
|
site_packages_real = os.path.realpath(site_packages_path)
|
||||||
|
try:
|
||||||
|
return (
|
||||||
|
os.path.commonpath([module_path, site_packages_real]) == site_packages_real
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_requirement_name(raw_requirement: str) -> str | None:
|
||||||
|
line = raw_requirement.split("#", 1)[0].strip()
|
||||||
|
if not line:
|
||||||
|
return None
|
||||||
|
if line.startswith(("-r", "--requirement", "-c", "--constraint")):
|
||||||
|
return None
|
||||||
|
if line.startswith("-"):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
dependency_detail_lines = [
|
egg_match = re.search(r"#egg=([A-Za-z0-9_.-]+)", raw_requirement)
|
||||||
line.strip()
|
if egg_match:
|
||||||
for line in relevant_output_lines
|
return _canonicalize_distribution_name(egg_match.group(1))
|
||||||
if _matches_pip_failure_pattern(line, "dependency_detail")
|
|
||||||
]
|
|
||||||
requested_lines = [
|
|
||||||
line.strip()
|
|
||||||
for line in relevant_output_lines
|
|
||||||
if _matches_pip_failure_pattern(line, "user_requested")
|
|
||||||
and not _matches_pip_failure_pattern(line, "constraint")
|
|
||||||
]
|
|
||||||
if not requested_lines:
|
|
||||||
requested_lines = [
|
|
||||||
line
|
|
||||||
for line in dependency_detail_lines
|
|
||||||
if not _matches_pip_failure_pattern(line, "constraint")
|
|
||||||
]
|
|
||||||
constraint_lines = [
|
|
||||||
line.strip()
|
|
||||||
for line in relevant_output_lines
|
|
||||||
if _matches_pip_failure_pattern(line, "constraint")
|
|
||||||
]
|
|
||||||
|
|
||||||
has_strong_conflict_signal = any(
|
candidate = re.split(r"[<>=!~;\s\[]", line, maxsplit=1)[0].strip()
|
||||||
_matches_pip_failure_pattern(
|
if not candidate:
|
||||||
line,
|
|
||||||
"resolution_impossible",
|
|
||||||
"cannot_install",
|
|
||||||
)
|
|
||||||
for line in relevant_output_lines
|
|
||||||
)
|
|
||||||
|
|
||||||
has_contextual_conflict_signal = any(
|
|
||||||
_matches_pip_failure_pattern(line, "conflict") for line in relevant_output_lines
|
|
||||||
) and bool(dependency_detail_lines or requested_lines or constraint_lines)
|
|
||||||
|
|
||||||
return PipConflictContext(
|
|
||||||
relevant_lines=relevant_output_lines,
|
|
||||||
requested_lines=requested_lines,
|
|
||||||
dependency_detail_lines=dependency_detail_lines,
|
|
||||||
constraint_lines=constraint_lines,
|
|
||||||
has_strong_conflict_signal=has_strong_conflict_signal,
|
|
||||||
has_contextual_conflict_signal=has_contextual_conflict_signal,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _classify_pip_failure(output_lines: list[str]) -> DependencyConflictError | None:
|
|
||||||
context = _build_pip_conflict_context(output_lines)
|
|
||||||
if context is None:
|
|
||||||
return None
|
return None
|
||||||
|
return _canonicalize_distribution_name(candidate)
|
||||||
|
|
||||||
if (
|
|
||||||
not context.has_strong_conflict_signal
|
|
||||||
and not context.has_contextual_conflict_signal
|
|
||||||
and not (context.requested_lines and context.constraint_lines)
|
|
||||||
):
|
|
||||||
return None
|
|
||||||
|
|
||||||
is_core_conflict = bool(context.constraint_lines)
|
def _extract_requirement_names(requirements_path: str) -> set[str]:
|
||||||
|
names: set[str] = set()
|
||||||
detail = ""
|
try:
|
||||||
if context.constraint_lines and context.requested_lines:
|
with open(requirements_path, encoding="utf-8") as requirements_file:
|
||||||
detail = (
|
for line in requirements_file:
|
||||||
" 冲突详情: "
|
requirement_name = _extract_requirement_name(line)
|
||||||
f"{_normalize_conflict_detail_line(context.requested_lines[0])} vs "
|
if requirement_name:
|
||||||
f"{_normalize_conflict_detail_line(context.constraint_lines[0])}。"
|
names.add(requirement_name)
|
||||||
)
|
except Exception as exc:
|
||||||
elif len(context.dependency_detail_lines) >= 2:
|
logger.warning("读取依赖文件失败,跳过冲突检测: %s", exc)
|
||||||
detail = (
|
return names
|
||||||
" 冲突详情: "
|
|
||||||
f"{_normalize_conflict_detail_line(context.dependency_detail_lines[0])} vs "
|
|
||||||
f"{_normalize_conflict_detail_line(context.dependency_detail_lines[1])}。"
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_core_conflict:
|
|
||||||
message = (
|
|
||||||
f"检测到核心依赖版本保护冲突。{detail}插件要求的依赖版本与 AstrBot 核心不兼容,"
|
|
||||||
"为了系统稳定,已阻止该降级行为。请联系插件作者或调整 requirements.txt。"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
message = f"检测到依赖冲突。{detail}"
|
|
||||||
|
|
||||||
return DependencyConflictError(
|
|
||||||
message,
|
|
||||||
context.relevant_lines,
|
|
||||||
is_core_conflict=is_core_conflict,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_top_level_modules(
|
def _extract_top_level_modules(
|
||||||
@@ -388,11 +155,7 @@ def _collect_candidate_modules(
|
|||||||
by_name: dict[str, list[importlib_metadata.Distribution]] = {}
|
by_name: dict[str, list[importlib_metadata.Distribution]] = {}
|
||||||
try:
|
try:
|
||||||
for distribution in importlib_metadata.distributions(path=[site_packages_path]):
|
for distribution in importlib_metadata.distributions(path=[site_packages_path]):
|
||||||
distribution_name = (
|
distribution_name = distribution.metadata.get("Name")
|
||||||
distribution.metadata["Name"]
|
|
||||||
if "Name" in distribution.metadata
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
if not distribution_name:
|
if not distribution_name:
|
||||||
continue
|
continue
|
||||||
canonical_name = _canonicalize_distribution_name(distribution_name)
|
canonical_name = _canonicalize_distribution_name(distribution_name)
|
||||||
@@ -410,7 +173,7 @@ def _collect_candidate_modules(
|
|||||||
|
|
||||||
for distribution in by_name.get(requirement_name, []):
|
for distribution in by_name.get(requirement_name, []):
|
||||||
for dependency_line in distribution.requires or []:
|
for dependency_line in distribution.requires or []:
|
||||||
dependency_name = extract_requirement_name(dependency_line)
|
dependency_name = _extract_requirement_name(dependency_line)
|
||||||
if not dependency_name:
|
if not dependency_name:
|
||||||
continue
|
continue
|
||||||
if dependency_name in expanded_requirement_names:
|
if dependency_name in expanded_requirement_names:
|
||||||
@@ -467,38 +230,6 @@ def _ensure_preferred_modules(
|
|||||||
raise RuntimeError(conflict_message)
|
raise RuntimeError(conflict_message)
|
||||||
|
|
||||||
|
|
||||||
def _module_exists_in_site_packages(module_name: str, site_packages_path: str) -> bool:
|
|
||||||
base_path = os.path.join(site_packages_path, *module_name.split("."))
|
|
||||||
package_init = os.path.join(base_path, "__init__.py")
|
|
||||||
module_file = f"{base_path}.py"
|
|
||||||
return os.path.isfile(package_init) or os.path.isfile(module_file)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_module_loaded_from_site_packages(
|
|
||||||
module_name: str,
|
|
||||||
site_packages_path: str,
|
|
||||||
) -> bool:
|
|
||||||
module = sys.modules.get(module_name)
|
|
||||||
if module is None:
|
|
||||||
try:
|
|
||||||
module = importlib.import_module(module_name)
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
module_file = getattr(module, "__file__", None)
|
|
||||||
if not module_file:
|
|
||||||
return False
|
|
||||||
|
|
||||||
module_path = os.path.realpath(module_file)
|
|
||||||
site_packages_real = os.path.realpath(site_packages_path)
|
|
||||||
try:
|
|
||||||
return (
|
|
||||||
os.path.commonpath([module_path, site_packages_real]) == site_packages_real
|
|
||||||
)
|
|
||||||
except ValueError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _prefer_module_from_site_packages(
|
def _prefer_module_from_site_packages(
|
||||||
module_name: str, site_packages_path: str
|
module_name: str, site_packages_path: str
|
||||||
) -> bool:
|
) -> bool:
|
||||||
@@ -800,63 +531,9 @@ def _patch_distlib_finder_for_frozen_runtime() -> None:
|
|||||||
|
|
||||||
|
|
||||||
class PipInstaller:
|
class PipInstaller:
|
||||||
def __init__(
|
def __init__(self, pip_install_arg: str, pypi_index_url: str | None = None) -> None:
|
||||||
self,
|
|
||||||
pip_install_arg: str,
|
|
||||||
pypi_index_url: str | None = None,
|
|
||||||
core_dist_name: str | None = "AstrBot",
|
|
||||||
) -> None:
|
|
||||||
self.pip_install_arg = pip_install_arg
|
self.pip_install_arg = pip_install_arg
|
||||||
self.pypi_index_url = pypi_index_url
|
self.pypi_index_url = pypi_index_url
|
||||||
self.core_dist_name = core_dist_name
|
|
||||||
self._core_constraints = CoreConstraintsProvider(core_dist_name)
|
|
||||||
|
|
||||||
def _build_pip_args(
|
|
||||||
self,
|
|
||||||
package_name: str | None,
|
|
||||||
requirements_path: str | None,
|
|
||||||
mirror: str | None,
|
|
||||||
) -> tuple[list[str], set[str]]:
|
|
||||||
args: list[str] = []
|
|
||||||
requested_requirements: set[str] = set()
|
|
||||||
normalized_requirements_path = (
|
|
||||||
requirements_path.strip() if requirements_path else ""
|
|
||||||
)
|
|
||||||
|
|
||||||
if package_name and normalized_requirements_path:
|
|
||||||
raise ValueError(
|
|
||||||
"package_name and requirements_path cannot be used together"
|
|
||||||
)
|
|
||||||
|
|
||||||
if package_name:
|
|
||||||
parsed_package = parse_package_install_input(package_name)
|
|
||||||
if parsed_package.specs:
|
|
||||||
args = ["install", *parsed_package.specs]
|
|
||||||
requested_requirements = set(parsed_package.requirement_names)
|
|
||||||
elif normalized_requirements_path:
|
|
||||||
args = ["install", "-r", normalized_requirements_path]
|
|
||||||
requested_requirements = extract_requirement_names(
|
|
||||||
normalized_requirements_path
|
|
||||||
)
|
|
||||||
|
|
||||||
if not args:
|
|
||||||
return [], requested_requirements
|
|
||||||
|
|
||||||
pip_install_args = (
|
|
||||||
shlex.split(self.pip_install_arg) if self.pip_install_arg else []
|
|
||||||
)
|
|
||||||
|
|
||||||
if not _package_specs_override_index([*args[1:], *pip_install_args]):
|
|
||||||
index_url = mirror or self.pypi_index_url or "https://pypi.org/simple"
|
|
||||||
trusted_host = _get_trusted_host_for_index_url(index_url)
|
|
||||||
if trusted_host:
|
|
||||||
args.extend(["--trusted-host", trusted_host])
|
|
||||||
args.extend(["-i", index_url])
|
|
||||||
|
|
||||||
if pip_install_args:
|
|
||||||
args.extend(pip_install_args)
|
|
||||||
|
|
||||||
return args, requested_requirements
|
|
||||||
|
|
||||||
async def install(
|
async def install(
|
||||||
self,
|
self,
|
||||||
@@ -864,37 +541,36 @@ class PipInstaller:
|
|||||||
requirements_path: str | None = None,
|
requirements_path: str | None = None,
|
||||||
mirror: str | None = None,
|
mirror: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
args, requested_requirements = self._build_pip_args(
|
args = ["install"]
|
||||||
package_name, requirements_path, mirror
|
requested_requirements: set[str] = set()
|
||||||
)
|
if package_name:
|
||||||
if not args:
|
args.append(package_name)
|
||||||
logger.info("Pip 包管理器跳过安装:未提供有效的包名或 requirements 文件。")
|
requirement_name = _extract_requirement_name(package_name)
|
||||||
return
|
if requirement_name:
|
||||||
|
requested_requirements.add(requirement_name)
|
||||||
|
elif requirements_path:
|
||||||
|
args.extend(["-r", requirements_path])
|
||||||
|
requested_requirements = _extract_requirement_names(requirements_path)
|
||||||
|
|
||||||
|
index_url = mirror or self.pypi_index_url or "https://pypi.org/simple"
|
||||||
|
args.extend(["--trusted-host", "mirrors.aliyun.com", "-i", index_url])
|
||||||
|
|
||||||
target_site_packages = None
|
target_site_packages = None
|
||||||
if is_packaged_desktop_runtime():
|
if is_packaged_desktop_runtime():
|
||||||
target_site_packages = get_astrbot_site_packages_path()
|
target_site_packages = get_astrbot_site_packages_path()
|
||||||
os.makedirs(target_site_packages, exist_ok=True)
|
os.makedirs(target_site_packages, exist_ok=True)
|
||||||
_prepend_sys_path(target_site_packages)
|
_prepend_sys_path(target_site_packages)
|
||||||
args.extend(
|
args.extend(["--target", target_site_packages])
|
||||||
[
|
args.extend(["--upgrade", "--force-reinstall"])
|
||||||
"--target",
|
|
||||||
target_site_packages,
|
|
||||||
"--upgrade",
|
|
||||||
"--upgrade-strategy",
|
|
||||||
"only-if-needed",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
with self._core_constraints.constraints_file() as constraints_file_path:
|
if self.pip_install_arg:
|
||||||
if constraints_file_path:
|
args.extend(self.pip_install_arg.split())
|
||||||
args.extend(["-c", constraints_file_path])
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(f"Pip 包管理器: pip {' '.join(args)}")
|
||||||
"Pip 包管理器 argv: %s",
|
result_code = await self._run_pip_in_process(args)
|
||||||
["pip", *_redact_pip_args_for_logging(args)],
|
|
||||||
)
|
if result_code != 0:
|
||||||
await self._run_pip_with_classification(args)
|
raise Exception(f"安装失败,错误码:{result_code}")
|
||||||
|
|
||||||
if target_site_packages:
|
if target_site_packages:
|
||||||
_prepend_sys_path(target_site_packages)
|
_prepend_sys_path(target_site_packages)
|
||||||
@@ -913,7 +589,7 @@ class PipInstaller:
|
|||||||
if not os.path.isdir(target_site_packages):
|
if not os.path.isdir(target_site_packages):
|
||||||
return
|
return
|
||||||
|
|
||||||
requested_requirements = extract_requirement_names(requirements_path)
|
requested_requirements = _extract_requirement_names(requirements_path)
|
||||||
if not requested_requirements:
|
if not requested_requirements:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -929,21 +605,13 @@ class PipInstaller:
|
|||||||
_patch_distlib_finder_for_frozen_runtime()
|
_patch_distlib_finder_for_frozen_runtime()
|
||||||
|
|
||||||
original_handlers = list(logging.getLogger().handlers)
|
original_handlers = list(logging.getLogger().handlers)
|
||||||
try:
|
result_code, output = await asyncio.to_thread(
|
||||||
result_code, output_lines = await asyncio.to_thread(
|
_run_pip_main_with_output, pip_main, args
|
||||||
_run_pip_main_streaming, pip_main, args
|
)
|
||||||
)
|
for line in output.splitlines():
|
||||||
finally:
|
line = line.strip()
|
||||||
_cleanup_added_root_handlers(original_handlers)
|
if line:
|
||||||
|
logger.info(line)
|
||||||
if result_code != 0:
|
|
||||||
conflict = _classify_pip_failure(output_lines)
|
|
||||||
if conflict:
|
|
||||||
raise conflict
|
|
||||||
|
|
||||||
|
_cleanup_added_root_handlers(original_handlers)
|
||||||
return result_code
|
return result_code
|
||||||
|
|
||||||
async def _run_pip_with_classification(self, args: list[str]) -> None:
|
|
||||||
result_code = await self._run_pip_in_process(args)
|
|
||||||
if result_code != 0:
|
|
||||||
raise PipInstallError(f"安装失败,错误码:{result_code}", code=result_code)
|
|
||||||
|
|||||||
@@ -1,486 +0,0 @@
|
|||||||
import importlib.metadata as importlib_metadata
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import shlex
|
|
||||||
import sys
|
|
||||||
from collections.abc import Iterable, Iterator, Sequence
|
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
from packaging.requirements import InvalidRequirement, Requirement
|
|
||||||
from packaging.specifiers import SpecifierSet
|
|
||||||
from packaging.version import InvalidVersion, Version
|
|
||||||
|
|
||||||
from astrbot.core.utils.astrbot_path import get_astrbot_site_packages_path
|
|
||||||
from astrbot.core.utils.runtime_env import is_packaged_desktop_runtime
|
|
||||||
|
|
||||||
logger = logging.getLogger("astrbot")
|
|
||||||
|
|
||||||
|
|
||||||
class RequirementsPrecheckFailed(Exception):
|
|
||||||
"""Raised when the pre-check of requirements fails."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class ParsedPackageInput:
|
|
||||||
specs: tuple[str, ...]
|
|
||||||
requirement_names: frozenset[str]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class MissingRequirementsPlan:
|
|
||||||
missing_names: frozenset[str]
|
|
||||||
install_lines: tuple[str, ...]
|
|
||||||
fallback_reason: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def canonicalize_distribution_name(name: str) -> str:
|
|
||||||
return re.sub(r"[-_.]+", "-", name).strip("-").lower()
|
|
||||||
|
|
||||||
|
|
||||||
def strip_inline_requirement_comment(raw_input: str) -> str:
|
|
||||||
if raw_input.lstrip().startswith("#"):
|
|
||||||
return ""
|
|
||||||
return re.split(r"[ \t]+#", raw_input, maxsplit=1)[0].strip()
|
|
||||||
|
|
||||||
|
|
||||||
def _specifier_contains_version(specifier: SpecifierSet, version: str) -> bool:
|
|
||||||
try:
|
|
||||||
parsed_version = Version(version)
|
|
||||||
except InvalidVersion:
|
|
||||||
return False
|
|
||||||
return specifier.contains(parsed_version, prereleases=True)
|
|
||||||
|
|
||||||
|
|
||||||
def _looks_like_local_path_reference(token: str) -> bool:
|
|
||||||
candidate = token.strip()
|
|
||||||
if not candidate:
|
|
||||||
return False
|
|
||||||
return candidate in {".", ".."} or candidate.startswith(
|
|
||||||
("./", "../", "/", "~/", ".\\", "..\\", "\\")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def looks_like_direct_reference(token: str) -> bool:
|
|
||||||
candidate = token.strip()
|
|
||||||
if not candidate:
|
|
||||||
return False
|
|
||||||
return (
|
|
||||||
_looks_like_local_path_reference(candidate)
|
|
||||||
or candidate.startswith("git+")
|
|
||||||
or "://" in candidate
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def extract_requirement_name(raw_requirement: str) -> str | None:
|
|
||||||
line = raw_requirement.split("#", 1)[0].strip()
|
|
||||||
if not line:
|
|
||||||
return None
|
|
||||||
if line.startswith(("-r", "--requirement", "-c", "--constraint")):
|
|
||||||
return None
|
|
||||||
|
|
||||||
egg_match = re.search(r"#egg=([A-Za-z0-9_.-]+)", raw_requirement)
|
|
||||||
if egg_match:
|
|
||||||
return canonicalize_distribution_name(egg_match.group(1))
|
|
||||||
|
|
||||||
if line.startswith("-"):
|
|
||||||
return None
|
|
||||||
|
|
||||||
candidate = re.split(r"[<>=!~;\s\[]", line, maxsplit=1)[0].strip()
|
|
||||||
if not candidate:
|
|
||||||
return None
|
|
||||||
return canonicalize_distribution_name(candidate)
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_editable_or_direct_name(target: str) -> str | None:
|
|
||||||
name = extract_requirement_name(target)
|
|
||||||
if not name:
|
|
||||||
return None
|
|
||||||
if "#egg=" in target or not looks_like_direct_reference(target):
|
|
||||||
return name
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_requirement_name_and_spec(
|
|
||||||
line: str,
|
|
||||||
) -> tuple[str | None, SpecifierSet | None]:
|
|
||||||
if line.startswith(("-c", "--constraint")):
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
try:
|
|
||||||
req = Requirement(line)
|
|
||||||
except InvalidRequirement:
|
|
||||||
tokens = shlex.split(line)
|
|
||||||
if not tokens:
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
editable_target: str | None = None
|
|
||||||
if tokens[0] in {"-e", "--editable"} and len(tokens) > 1:
|
|
||||||
editable_target = tokens[1]
|
|
||||||
elif tokens[0].startswith("--editable="):
|
|
||||||
editable_target = tokens[0].split("=", 1)[1]
|
|
||||||
|
|
||||||
if editable_target:
|
|
||||||
name = _parse_editable_or_direct_name(editable_target)
|
|
||||||
return (name, None) if name else (None, None)
|
|
||||||
|
|
||||||
name = _parse_editable_or_direct_name(line)
|
|
||||||
return (name, None) if name else (None, None)
|
|
||||||
|
|
||||||
if req.marker and not req.marker.evaluate():
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
return canonicalize_distribution_name(req.name), (req.specifier or None)
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_requirement_line(
|
|
||||||
line: str,
|
|
||||||
) -> tuple[str, SpecifierSet | None] | None:
|
|
||||||
name, specifier = _parse_requirement_name_and_spec(line)
|
|
||||||
return (name, specifier) if name else None
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_requirement_names_from_package_tokens(tokens: list[str]) -> frozenset[str]:
|
|
||||||
requirement_names: set[str] = set()
|
|
||||||
skip_next_for: str | None = None
|
|
||||||
|
|
||||||
for token in tokens:
|
|
||||||
if skip_next_for:
|
|
||||||
if skip_next_for == "editable":
|
|
||||||
name = _parse_editable_or_direct_name(token)
|
|
||||||
if name:
|
|
||||||
requirement_names.add(name)
|
|
||||||
skip_next_for = None
|
|
||||||
continue
|
|
||||||
|
|
||||||
if token in {"-e", "--editable"}:
|
|
||||||
skip_next_for = "editable"
|
|
||||||
continue
|
|
||||||
|
|
||||||
if token in {
|
|
||||||
"-i",
|
|
||||||
"--index-url",
|
|
||||||
"--extra-index-url",
|
|
||||||
"-f",
|
|
||||||
"--find-links",
|
|
||||||
"--trusted-host",
|
|
||||||
"-r",
|
|
||||||
"--requirement",
|
|
||||||
"-c",
|
|
||||||
"--constraint",
|
|
||||||
}:
|
|
||||||
skip_next_for = "option-value"
|
|
||||||
continue
|
|
||||||
|
|
||||||
if token.startswith(("--editable=",)):
|
|
||||||
editable_target = token.split("=", 1)[1]
|
|
||||||
name = _parse_editable_or_direct_name(editable_target)
|
|
||||||
if name:
|
|
||||||
requirement_names.add(name)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if token.startswith(
|
|
||||||
(
|
|
||||||
"--index-url=",
|
|
||||||
"--extra-index-url=",
|
|
||||||
"--find-links=",
|
|
||||||
"--trusted-host=",
|
|
||||||
"--requirement=",
|
|
||||||
"--constraint=",
|
|
||||||
)
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
|
|
||||||
if (
|
|
||||||
(token.startswith("-i") and token != "-i")
|
|
||||||
or (token.startswith("-f") and token != "-f")
|
|
||||||
or token == "--no-index"
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
|
|
||||||
if token.startswith("-"):
|
|
||||||
continue
|
|
||||||
|
|
||||||
name, _ = _parse_requirement_name_and_spec(token)
|
|
||||||
if name:
|
|
||||||
requirement_names.add(name)
|
|
||||||
|
|
||||||
return frozenset(requirement_names)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_package_install_input(raw_input: str) -> ParsedPackageInput:
|
|
||||||
specs: list[str] = []
|
|
||||||
requirement_names: set[str] = set()
|
|
||||||
normalized = raw_input.strip()
|
|
||||||
if not normalized:
|
|
||||||
return ParsedPackageInput(specs=(), requirement_names=frozenset())
|
|
||||||
|
|
||||||
for raw_line in normalized.splitlines():
|
|
||||||
line = strip_inline_requirement_comment(raw_line)
|
|
||||||
if not line:
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
Requirement(line)
|
|
||||||
except InvalidRequirement:
|
|
||||||
tokens = shlex.split(line)
|
|
||||||
if not tokens:
|
|
||||||
continue
|
|
||||||
specs.extend(tokens)
|
|
||||||
requirement_names.update(
|
|
||||||
_extract_requirement_names_from_package_tokens(tokens)
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
specs.append(line)
|
|
||||||
name, _ = _parse_requirement_name_and_spec(line)
|
|
||||||
if name:
|
|
||||||
requirement_names.add(name)
|
|
||||||
|
|
||||||
return ParsedPackageInput(
|
|
||||||
specs=tuple(specs),
|
|
||||||
requirement_names=frozenset(requirement_names),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _iter_requirement_lines(
|
|
||||||
requirements_path: str,
|
|
||||||
_visited: set[str] | None = None,
|
|
||||||
) -> Iterator[str]:
|
|
||||||
visited = _visited or set()
|
|
||||||
resolved_path = os.path.realpath(requirements_path)
|
|
||||||
if resolved_path in visited:
|
|
||||||
logger.warning(
|
|
||||||
"检测到循环依赖的 requirements 包含: %s,将跳过该文件", resolved_path
|
|
||||||
)
|
|
||||||
return
|
|
||||||
visited.add(resolved_path)
|
|
||||||
|
|
||||||
with open(resolved_path, encoding="utf-8") as f:
|
|
||||||
for raw_line in f:
|
|
||||||
line = strip_inline_requirement_comment(raw_line)
|
|
||||||
if not line:
|
|
||||||
continue
|
|
||||||
|
|
||||||
tokens = shlex.split(line)
|
|
||||||
if not tokens:
|
|
||||||
continue
|
|
||||||
|
|
||||||
nested: str | None = None
|
|
||||||
if tokens[0] in {"-r", "--requirement"} and len(tokens) > 1:
|
|
||||||
nested = tokens[1]
|
|
||||||
elif tokens[0].startswith("--requirement="):
|
|
||||||
nested = tokens[0].split("=", 1)[1]
|
|
||||||
|
|
||||||
if nested:
|
|
||||||
if not os.path.isabs(nested):
|
|
||||||
nested = os.path.join(os.path.dirname(resolved_path), nested)
|
|
||||||
yield from _iter_requirement_lines(nested, _visited=visited)
|
|
||||||
continue
|
|
||||||
|
|
||||||
yield line
|
|
||||||
|
|
||||||
|
|
||||||
def iter_requirements(
|
|
||||||
requirements_path: str | None = None,
|
|
||||||
lines: Iterable[str] | None = None,
|
|
||||||
) -> Iterator[tuple[str, SpecifierSet | None]]:
|
|
||||||
if lines is None:
|
|
||||||
if requirements_path is None:
|
|
||||||
raise ValueError("Either requirements_path or lines must be provided")
|
|
||||||
lines = _iter_requirement_lines(requirements_path)
|
|
||||||
|
|
||||||
for line in lines:
|
|
||||||
parsed = _parse_requirement_line(line)
|
|
||||||
if parsed is not None:
|
|
||||||
yield parsed
|
|
||||||
|
|
||||||
|
|
||||||
def extract_requirement_names(requirements_path: str) -> set[str]:
|
|
||||||
try:
|
|
||||||
return {
|
|
||||||
name for name, _ in iter_requirements(requirements_path=requirements_path)
|
|
||||||
}
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("读取依赖文件失败,跳过冲突检测: %s", exc)
|
|
||||||
return set()
|
|
||||||
|
|
||||||
|
|
||||||
def get_requirement_check_paths() -> list[str]:
|
|
||||||
paths = list(sys.path)
|
|
||||||
if is_packaged_desktop_runtime():
|
|
||||||
target_site_packages = get_astrbot_site_packages_path()
|
|
||||||
if os.path.isdir(target_site_packages):
|
|
||||||
paths.insert(0, target_site_packages)
|
|
||||||
return paths
|
|
||||||
|
|
||||||
|
|
||||||
def _canonical_distribution_identity(distribution) -> tuple[str | None, str | None]:
|
|
||||||
distribution_name = (
|
|
||||||
distribution.metadata["Name"] if "Name" in distribution.metadata else None
|
|
||||||
)
|
|
||||||
if not distribution_name:
|
|
||||||
return None, None
|
|
||||||
return canonicalize_distribution_name(distribution_name), distribution.version
|
|
||||||
|
|
||||||
|
|
||||||
def collect_installed_distribution_versions(paths: list[str]) -> dict[str, str] | None:
|
|
||||||
installed: dict[str, str] = {}
|
|
||||||
try:
|
|
||||||
for distribution in importlib_metadata.distributions(path=paths):
|
|
||||||
distribution_name, version = _canonical_distribution_identity(distribution)
|
|
||||||
if not distribution_name or not version:
|
|
||||||
continue
|
|
||||||
installed.setdefault(distribution_name, version)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("读取已安装依赖失败,跳过缺失依赖预检查: %s", exc)
|
|
||||||
return None
|
|
||||||
return installed
|
|
||||||
|
|
||||||
|
|
||||||
def _load_requirement_lines_for_precheck(
|
|
||||||
requirements_path: str,
|
|
||||||
) -> tuple[bool, list[str] | None]:
|
|
||||||
try:
|
|
||||||
requirement_lines = list(_iter_requirement_lines(requirements_path))
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning(
|
|
||||||
"预检查缺失依赖失败,将回退到完整安装: %s (%s)",
|
|
||||||
requirements_path,
|
|
||||||
exc,
|
|
||||||
)
|
|
||||||
return False, None
|
|
||||||
|
|
||||||
fallback_line = next(
|
|
||||||
(
|
|
||||||
line
|
|
||||||
for line in requirement_lines
|
|
||||||
if (
|
|
||||||
(
|
|
||||||
line.startswith(("-e ", "--editable ", "--editable="))
|
|
||||||
and "#egg=" not in line
|
|
||||||
)
|
|
||||||
or (
|
|
||||||
_parse_requirement_line(line) is None
|
|
||||||
and looks_like_direct_reference(line)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
if fallback_line is not None:
|
|
||||||
logger.info(
|
|
||||||
"缺失依赖预检查发现无法安全裁剪的 option/direct-reference 行,将回退到完整安装: %s (%s)",
|
|
||||||
requirements_path,
|
|
||||||
fallback_line,
|
|
||||||
)
|
|
||||||
return False, None
|
|
||||||
|
|
||||||
return True, requirement_lines
|
|
||||||
|
|
||||||
|
|
||||||
def find_missing_requirements(requirements_path: str) -> set[str] | None:
|
|
||||||
can_precheck, requirement_lines = _load_requirement_lines_for_precheck(
|
|
||||||
requirements_path
|
|
||||||
)
|
|
||||||
if not can_precheck or requirement_lines is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return find_missing_requirements_from_lines(requirement_lines)
|
|
||||||
|
|
||||||
|
|
||||||
def find_missing_requirements_from_lines(
|
|
||||||
requirement_lines: Sequence[str],
|
|
||||||
) -> set[str] | None:
|
|
||||||
|
|
||||||
required = list(iter_requirements(lines=requirement_lines))
|
|
||||||
if not required:
|
|
||||||
return set()
|
|
||||||
|
|
||||||
installed = collect_installed_distribution_versions(get_requirement_check_paths())
|
|
||||||
if installed is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
missing: set[str] = set()
|
|
||||||
for name, specifier in required:
|
|
||||||
installed_version = installed.get(name)
|
|
||||||
if not installed_version:
|
|
||||||
missing.add(name)
|
|
||||||
continue
|
|
||||||
if specifier and not _specifier_contains_version(specifier, installed_version):
|
|
||||||
missing.add(name)
|
|
||||||
|
|
||||||
return missing
|
|
||||||
|
|
||||||
|
|
||||||
def build_missing_requirements_install_lines(
|
|
||||||
requirements_path: str,
|
|
||||||
requirement_lines: Sequence[str],
|
|
||||||
missing_names: set[str] | frozenset[str],
|
|
||||||
) -> tuple[str, ...] | None:
|
|
||||||
wanted_names = set(missing_names)
|
|
||||||
install_lines: list[str] = []
|
|
||||||
for line in requirement_lines:
|
|
||||||
parsed = _parse_requirement_line(line)
|
|
||||||
if parsed is None:
|
|
||||||
if looks_like_direct_reference(line) or line.startswith(("-", "--")):
|
|
||||||
logger.debug(
|
|
||||||
"缺失依赖行筛选回退到完整安装:requirements 中包含无法安全裁剪的 option/direct-reference 行: %s (%s)",
|
|
||||||
requirements_path,
|
|
||||||
line,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
continue
|
|
||||||
|
|
||||||
name, _specifier = parsed
|
|
||||||
if name in wanted_names:
|
|
||||||
install_lines.append(line)
|
|
||||||
|
|
||||||
return tuple(install_lines)
|
|
||||||
|
|
||||||
|
|
||||||
def plan_missing_requirements_install(
|
|
||||||
requirements_path: str,
|
|
||||||
) -> MissingRequirementsPlan | None:
|
|
||||||
can_precheck, requirement_lines = _load_requirement_lines_for_precheck(
|
|
||||||
requirements_path
|
|
||||||
)
|
|
||||||
if not can_precheck or requirement_lines is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
missing = find_missing_requirements_from_lines(requirement_lines)
|
|
||||||
if missing is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
install_lines = build_missing_requirements_install_lines(
|
|
||||||
requirements_path,
|
|
||||||
requirement_lines,
|
|
||||||
missing,
|
|
||||||
)
|
|
||||||
if install_lines is None:
|
|
||||||
return None
|
|
||||||
if missing and not install_lines:
|
|
||||||
logger.warning(
|
|
||||||
"预检查缺失依赖成功,但无法映射到可安装 requirement 行,将回退到完整安装: %s -> %s",
|
|
||||||
requirements_path,
|
|
||||||
sorted(missing),
|
|
||||||
)
|
|
||||||
return MissingRequirementsPlan(
|
|
||||||
missing_names=frozenset(missing),
|
|
||||||
install_lines=(),
|
|
||||||
fallback_reason="unmapped missing requirement names",
|
|
||||||
)
|
|
||||||
|
|
||||||
return MissingRequirementsPlan(
|
|
||||||
missing_names=frozenset(missing),
|
|
||||||
install_lines=install_lines,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def find_missing_requirements_or_raise(requirements_path: str) -> set[str]:
|
|
||||||
missing = find_missing_requirements(requirements_path)
|
|
||||||
if missing is None:
|
|
||||||
raise RequirementsPrecheckFailed(f"预检查失败: {requirements_path}")
|
|
||||||
return missing
|
|
||||||
@@ -1,13 +1,9 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import threading
|
|
||||||
import weakref
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
|
||||||
class _PerLoopSessionLockManager:
|
class SessionLockManager:
|
||||||
"""Per-event-loop session lock manager; keeps original simple semantics."""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
||||||
self._lock_count: dict[str, int] = defaultdict(int)
|
self._lock_count: dict[str, int] = defaultdict(int)
|
||||||
@@ -30,26 +26,4 @@ class _PerLoopSessionLockManager:
|
|||||||
self._lock_count.pop(session_id, None)
|
self._lock_count.pop(session_id, None)
|
||||||
|
|
||||||
|
|
||||||
class SessionLockManager:
|
|
||||||
"""Thread-safe session lock manager with per-event-loop isolation."""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._state_guard = threading.Lock()
|
|
||||||
self._loop_managers: weakref.WeakKeyDictionary[
|
|
||||||
asyncio.AbstractEventLoop, _PerLoopSessionLockManager
|
|
||||||
] = weakref.WeakKeyDictionary()
|
|
||||||
|
|
||||||
def _get_loop_manager(self) -> _PerLoopSessionLockManager:
|
|
||||||
"""Get the lock manager for the current event loop."""
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
with self._state_guard:
|
|
||||||
return self._loop_managers.setdefault(loop, _PerLoopSessionLockManager())
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def acquire_lock(self, session_id: str):
|
|
||||||
manager = self._get_loop_manager()
|
|
||||||
async with manager.acquire_lock(session_id):
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
session_lock_manager = SessionLockManager()
|
session_lock_manager = SessionLockManager()
|
||||||
|
|||||||
@@ -82,8 +82,7 @@ class AuthRoute(Route):
|
|||||||
def generate_jwt(self, username):
|
def generate_jwt(self, username):
|
||||||
payload = {
|
payload = {
|
||||||
"username": username,
|
"username": username,
|
||||||
"exp": datetime.datetime.now(datetime.timezone.utc)
|
"exp": datetime.datetime.utcnow() + datetime.timedelta(days=7),
|
||||||
+ datetime.timedelta(days=7),
|
|
||||||
}
|
}
|
||||||
jwt_token = self.config["dashboard"].get("jwt_secret", None)
|
jwt_token = self.config["dashboard"].get("jwt_secret", None)
|
||||||
if not jwt_token:
|
if not jwt_token:
|
||||||
|
|||||||
@@ -977,17 +977,7 @@ class BackupRoute(Route):
|
|||||||
if not jwt_secret:
|
if not jwt_secret:
|
||||||
return Response().error("服务器配置错误").__dict__
|
return Response().error("服务器配置错误").__dict__
|
||||||
|
|
||||||
# Verify JWT token with strict security options
|
jwt.decode(token, jwt_secret, algorithms=["HS256"])
|
||||||
jwt.decode(
|
|
||||||
token,
|
|
||||||
jwt_secret,
|
|
||||||
algorithms=["HS256"],
|
|
||||||
options={
|
|
||||||
"require": ["exp"], # Require expiration claim
|
|
||||||
"verify_signature": True, # Explicitly verify signature
|
|
||||||
"verify_exp": True, # Verify expiration
|
|
||||||
},
|
|
||||||
)
|
|
||||||
except jwt.ExpiredSignatureError:
|
except jwt.ExpiredSignatureError:
|
||||||
return Response().error("Token 已过期,请刷新页面后重试").__dict__
|
return Response().error("Token 已过期,请刷新页面后重试").__dict__
|
||||||
except jwt.InvalidTokenError:
|
except jwt.InvalidTokenError:
|
||||||
|
|||||||
@@ -36,20 +36,6 @@ async def track_conversation(convs: dict, conv_id: str):
|
|||||||
convs.pop(conv_id, None)
|
convs.pop(conv_id, None)
|
||||||
|
|
||||||
|
|
||||||
async def _poll_webchat_stream_result(back_queue, username: str):
|
|
||||||
try:
|
|
||||||
result = await asyncio.wait_for(back_queue.get(), timeout=1)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
return None, False
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。")
|
|
||||||
return None, True
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"WebChat stream error: {e}")
|
|
||||||
return None, False
|
|
||||||
return result, False
|
|
||||||
|
|
||||||
|
|
||||||
class ChatRoute(Route):
|
class ChatRoute(Route):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -65,7 +51,6 @@ class ChatRoute(Route):
|
|||||||
"/chat/get_session": ("GET", self.get_session),
|
"/chat/get_session": ("GET", self.get_session),
|
||||||
"/chat/stop": ("POST", self.stop_session),
|
"/chat/stop": ("POST", self.stop_session),
|
||||||
"/chat/delete_session": ("GET", self.delete_webchat_session),
|
"/chat/delete_session": ("GET", self.delete_webchat_session),
|
||||||
"/chat/batch_delete_sessions": ("POST", self.batch_delete_sessions),
|
|
||||||
"/chat/update_session_display_name": (
|
"/chat/update_session_display_name": (
|
||||||
"POST",
|
"POST",
|
||||||
self.update_session_display_name,
|
self.update_session_display_name,
|
||||||
@@ -357,12 +342,16 @@ class ChatRoute(Route):
|
|||||||
|
|
||||||
async with track_conversation(self.running_convs, webchat_conv_id):
|
async with track_conversation(self.running_convs, webchat_conv_id):
|
||||||
while True:
|
while True:
|
||||||
result, should_break = await _poll_webchat_stream_result(
|
try:
|
||||||
back_queue, username
|
result = await asyncio.wait_for(back_queue.get(), timeout=1)
|
||||||
)
|
except asyncio.TimeoutError:
|
||||||
if should_break:
|
continue
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。")
|
||||||
client_disconnected = True
|
client_disconnected = True
|
||||||
break
|
except Exception as e:
|
||||||
|
logger.error(f"WebChat stream error: {e}")
|
||||||
|
|
||||||
if not result:
|
if not result:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -589,9 +578,19 @@ class ChatRoute(Route):
|
|||||||
|
|
||||||
return Response().ok(data={"stopped_count": stopped_count}).__dict__
|
return Response().ok(data={"stopped_count": stopped_count}).__dict__
|
||||||
|
|
||||||
async def _delete_session_internal(self, session, username: str) -> None:
|
async def delete_webchat_session(self):
|
||||||
"""Delete a single session and all its related data."""
|
"""Delete a Platform session and all its related data."""
|
||||||
session_id = session.session_id
|
session_id = request.args.get("session_id")
|
||||||
|
if not session_id:
|
||||||
|
return Response().error("Missing key: session_id").__dict__
|
||||||
|
username = g.get("username", "guest")
|
||||||
|
|
||||||
|
# 验证会话是否存在且属于当前用户
|
||||||
|
session = await self.db.get_platform_session_by_id(session_id)
|
||||||
|
if not session:
|
||||||
|
return Response().error(f"Session {session_id} not found").__dict__
|
||||||
|
if session.creator != username:
|
||||||
|
return Response().error("Permission denied").__dict__
|
||||||
|
|
||||||
# 删除该会话下的所有对话
|
# 删除该会话下的所有对话
|
||||||
message_type = "GroupMessage" if session.is_group else "FriendMessage"
|
message_type = "GroupMessage" if session.is_group else "FriendMessage"
|
||||||
@@ -633,70 +632,8 @@ class ChatRoute(Route):
|
|||||||
# 删除会话
|
# 删除会话
|
||||||
await self.db.delete_platform_session(session_id)
|
await self.db.delete_platform_session(session_id)
|
||||||
|
|
||||||
async def delete_webchat_session(self):
|
|
||||||
"""Delete a Platform session and all its related data."""
|
|
||||||
session_id = request.args.get("session_id")
|
|
||||||
if not session_id:
|
|
||||||
return Response().error("Missing key: session_id").__dict__
|
|
||||||
username = g.get("username", "guest")
|
|
||||||
|
|
||||||
session = await self.db.get_platform_session_by_id(session_id)
|
|
||||||
if not session:
|
|
||||||
return Response().error(f"Session {session_id} not found").__dict__
|
|
||||||
if session.creator != username:
|
|
||||||
return Response().error("Permission denied").__dict__
|
|
||||||
|
|
||||||
await self._delete_session_internal(session, username)
|
|
||||||
|
|
||||||
return Response().ok().__dict__
|
return Response().ok().__dict__
|
||||||
|
|
||||||
async def batch_delete_sessions(self):
|
|
||||||
"""Batch delete multiple Platform sessions."""
|
|
||||||
post_data = await request.json
|
|
||||||
if post_data is None:
|
|
||||||
return Response().error("Missing JSON body").__dict__
|
|
||||||
if not isinstance(post_data, dict):
|
|
||||||
return Response().error("Invalid JSON body: expected object").__dict__
|
|
||||||
|
|
||||||
session_ids = post_data.get("session_ids")
|
|
||||||
if not session_ids or not isinstance(session_ids, list):
|
|
||||||
return Response().error("Missing or invalid key: session_ids").__dict__
|
|
||||||
|
|
||||||
username = g.get("username", "guest")
|
|
||||||
sessions = await self.db.get_platform_sessions_by_ids(session_ids)
|
|
||||||
sessions_by_id = {session.session_id: session for session in sessions}
|
|
||||||
deleted_count = 0
|
|
||||||
failed_items = []
|
|
||||||
|
|
||||||
for sid in session_ids:
|
|
||||||
session = sessions_by_id.get(sid)
|
|
||||||
if not session:
|
|
||||||
failed_items.append({"session_id": sid, "reason": "not found"})
|
|
||||||
continue
|
|
||||||
if session.creator != username:
|
|
||||||
failed_items.append({"session_id": sid, "reason": "permission denied"})
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
await self._delete_session_internal(session, username)
|
|
||||||
deleted_count += 1
|
|
||||||
sessions_by_id.pop(sid, None)
|
|
||||||
except Exception:
|
|
||||||
logger.warning("Failed to delete session %s", sid)
|
|
||||||
failed_items.append({"session_id": sid, "reason": "internal_error"})
|
|
||||||
|
|
||||||
return (
|
|
||||||
Response()
|
|
||||||
.ok(
|
|
||||||
data={
|
|
||||||
"deleted_count": deleted_count,
|
|
||||||
"failed_count": len(failed_items),
|
|
||||||
"failed_items": failed_items,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
.__dict__
|
|
||||||
)
|
|
||||||
|
|
||||||
def _extract_attachment_ids(self, history_list) -> list[str]:
|
def _extract_attachment_ids(self, history_list) -> list[str]:
|
||||||
"""从消息历史中提取所有 attachment_id"""
|
"""从消息历史中提取所有 attachment_id"""
|
||||||
attachment_ids = []
|
attachment_ids = []
|
||||||
|
|||||||
@@ -130,51 +130,21 @@ class LiveChatRoute(Route):
|
|||||||
|
|
||||||
async def live_chat_ws(self) -> None:
|
async def live_chat_ws(self) -> None:
|
||||||
"""Legacy Live Chat WebSocket 处理器(默认 ct=live)"""
|
"""Legacy Live Chat WebSocket 处理器(默认 ct=live)"""
|
||||||
token = websocket.args.get("token")
|
await self._unified_ws_loop(force_ct="live")
|
||||||
if not token:
|
|
||||||
await websocket.close(1008, "Missing authentication token")
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
jwt_secret = self.config["dashboard"].get("jwt_secret")
|
|
||||||
payload = jwt.decode(token, jwt_secret, algorithms=["HS256"])
|
|
||||||
username = payload["username"]
|
|
||||||
except jwt.ExpiredSignatureError:
|
|
||||||
await websocket.close(1008, "Token expired")
|
|
||||||
return
|
|
||||||
except jwt.InvalidTokenError:
|
|
||||||
await websocket.close(1008, "Invalid token")
|
|
||||||
return
|
|
||||||
|
|
||||||
await self.run_ws_session(username=username, force_ct="live")
|
|
||||||
|
|
||||||
async def unified_chat_ws(self) -> None:
|
async def unified_chat_ws(self) -> None:
|
||||||
"""Unified Chat WebSocket 处理器(支持 ct=live/chat)"""
|
"""Unified Chat WebSocket 处理器(支持 ct=live/chat)"""
|
||||||
token = websocket.args.get("token")
|
await self._unified_ws_loop(force_ct=None)
|
||||||
if not token:
|
|
||||||
await websocket.close(1008, "Missing authentication token")
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
jwt_secret = self.config["dashboard"].get("jwt_secret")
|
|
||||||
payload = jwt.decode(token, jwt_secret, algorithms=["HS256"])
|
|
||||||
username = payload["username"]
|
|
||||||
except jwt.ExpiredSignatureError:
|
|
||||||
await websocket.close(1008, "Token expired")
|
|
||||||
return
|
|
||||||
except jwt.InvalidTokenError:
|
|
||||||
await websocket.close(1008, "Invalid token")
|
|
||||||
return
|
|
||||||
|
|
||||||
await self.run_ws_session(username=username, force_ct=None)
|
|
||||||
|
|
||||||
async def _unified_ws_loop(self, force_ct: str | None = None) -> None:
|
async def _unified_ws_loop(self, force_ct: str | None = None) -> None:
|
||||||
"""统一 WebSocket 循环"""
|
"""统一 WebSocket 循环"""
|
||||||
# Keep the legacy entry point for internal call sites.
|
# WebSocket 不能通过 header 传递 token,需要从 query 参数获取
|
||||||
|
# 注意:WebSocket 上下文使用 websocket.args 而不是 request.args
|
||||||
token = websocket.args.get("token")
|
token = websocket.args.get("token")
|
||||||
if not token:
|
if not token:
|
||||||
await websocket.close(1008, "Missing authentication token")
|
await websocket.close(1008, "Missing authentication token")
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
jwt_secret = self.config["dashboard"].get("jwt_secret")
|
jwt_secret = self.config["dashboard"].get("jwt_secret")
|
||||||
payload = jwt.decode(token, jwt_secret, algorithms=["HS256"])
|
payload = jwt.decode(token, jwt_secret, algorithms=["HS256"])
|
||||||
@@ -185,10 +155,7 @@ class LiveChatRoute(Route):
|
|||||||
except jwt.InvalidTokenError:
|
except jwt.InvalidTokenError:
|
||||||
await websocket.close(1008, "Invalid token")
|
await websocket.close(1008, "Invalid token")
|
||||||
return
|
return
|
||||||
await self.run_ws_session(username=username, force_ct=force_ct)
|
|
||||||
|
|
||||||
async def run_ws_session(self, username: str, force_ct: str | None = None) -> None:
|
|
||||||
"""Run a live/unified websocket session for an authenticated username."""
|
|
||||||
session_id = f"webchat_live!{username}!{uuid.uuid4()}"
|
session_id = f"webchat_live!{username}!{uuid.uuid4()}"
|
||||||
live_session = LiveChatSession(session_id, username)
|
live_session = LiveChatSession(session_id, username)
|
||||||
self.sessions[session_id] = live_session
|
self.sessions[session_id] = live_session
|
||||||
@@ -723,16 +690,6 @@ class LiveChatRoute(Route):
|
|||||||
|
|
||||||
elif msg_type == "end_speaking":
|
elif msg_type == "end_speaking":
|
||||||
# 结束说话
|
# 结束说话
|
||||||
if session.is_processing:
|
|
||||||
await websocket.send_json(
|
|
||||||
{
|
|
||||||
"t": "error",
|
|
||||||
"data": "Session is busy",
|
|
||||||
"code": "PROCESSING_ERROR",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
stamp = message.get("stamp")
|
stamp = message.get("stamp")
|
||||||
if not stamp:
|
if not stamp:
|
||||||
logger.warning("[Live Chat] end_speaking 缺少 stamp")
|
logger.warning("[Live Chat] end_speaking 缺少 stamp")
|
||||||
@@ -746,59 +703,45 @@ class LiveChatRoute(Route):
|
|||||||
# 处理音频:STT -> LLM -> TTS
|
# 处理音频:STT -> LLM -> TTS
|
||||||
await self._process_audio(session, audio_path, assemble_duration)
|
await self._process_audio(session, audio_path, assemble_duration)
|
||||||
|
|
||||||
elif msg_type == "text_input":
|
|
||||||
if session.is_processing:
|
|
||||||
await websocket.send_json(
|
|
||||||
{
|
|
||||||
"t": "error",
|
|
||||||
"data": "Session is busy",
|
|
||||||
"code": "PROCESSING_ERROR",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
user_text = message.get("text")
|
|
||||||
if not isinstance(user_text, str):
|
|
||||||
user_text = message.get("message")
|
|
||||||
|
|
||||||
if not isinstance(user_text, str) or not user_text.strip():
|
|
||||||
await websocket.send_json(
|
|
||||||
{
|
|
||||||
"t": "error",
|
|
||||||
"data": "message must be non-empty text",
|
|
||||||
"code": "INVALID_MESSAGE_FORMAT",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
await self._process_live_user_text(
|
|
||||||
session,
|
|
||||||
user_text=user_text.strip(),
|
|
||||||
initial_metrics={"input_type": "text"},
|
|
||||||
processing_start_time=time.time(),
|
|
||||||
)
|
|
||||||
|
|
||||||
elif msg_type == "interrupt":
|
elif msg_type == "interrupt":
|
||||||
# 用户打断
|
# 用户打断
|
||||||
session.should_interrupt = True
|
session.should_interrupt = True
|
||||||
logger.info(f"[Live Chat] 用户打断: {session.username}")
|
logger.info(f"[Live Chat] 用户打断: {session.username}")
|
||||||
|
|
||||||
async def _process_live_user_text(
|
async def _process_audio(
|
||||||
self,
|
self, session: LiveChatSession, audio_path: str, assemble_duration: float
|
||||||
session: LiveChatSession,
|
|
||||||
user_text: str,
|
|
||||||
initial_metrics: dict[str, Any] | None = None,
|
|
||||||
processing_start_time: float | None = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""处理 Live 用户文本:走 run_live_agent pipeline 并回传流式 TTS."""
|
"""处理音频:STT -> LLM -> 流式 TTS"""
|
||||||
try:
|
try:
|
||||||
if initial_metrics:
|
# 发送 WAV 组装耗时
|
||||||
await websocket.send_json({"t": "metrics", "data": initial_metrics})
|
await websocket.send_json(
|
||||||
|
{"t": "metrics", "data": {"wav_assemble_time": assemble_duration}}
|
||||||
|
)
|
||||||
|
wav_assembly_finish_time = time.time()
|
||||||
|
|
||||||
processing_start = processing_start_time or time.time()
|
|
||||||
session.is_processing = True
|
session.is_processing = True
|
||||||
session.should_interrupt = False
|
session.should_interrupt = False
|
||||||
|
|
||||||
|
# 1. STT - 语音转文字
|
||||||
|
ctx = self.plugin_manager.context
|
||||||
|
stt_provider = ctx.provider_manager.stt_provider_insts[0]
|
||||||
|
|
||||||
|
if not stt_provider:
|
||||||
|
logger.error("[Live Chat] STT Provider 未配置")
|
||||||
|
await websocket.send_json({"t": "error", "data": "语音识别服务未配置"})
|
||||||
|
return
|
||||||
|
|
||||||
|
await websocket.send_json(
|
||||||
|
{"t": "metrics", "data": {"stt": stt_provider.meta().type}}
|
||||||
|
)
|
||||||
|
|
||||||
|
user_text = await stt_provider.get_text(audio_path)
|
||||||
|
if not user_text:
|
||||||
|
logger.warning("[Live Chat] STT 识别结果为空")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(f"[Live Chat] STT 结果: {user_text}")
|
||||||
|
|
||||||
await websocket.send_json(
|
await websocket.send_json(
|
||||||
{
|
{
|
||||||
"t": "user_msg",
|
"t": "user_msg",
|
||||||
@@ -818,6 +761,7 @@ class LiveChatRoute(Route):
|
|||||||
"action_type": "live", # 标记为 live mode
|
"action_type": "live", # 标记为 live mode
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# 将消息放入队列
|
||||||
await queue.put((session.username, cid, payload))
|
await queue.put((session.username, cid, payload))
|
||||||
|
|
||||||
# 3. 等待响应并流式发送 TTS 音频
|
# 3. 等待响应并流式发送 TTS 音频
|
||||||
@@ -832,9 +776,11 @@ class LiveChatRoute(Route):
|
|||||||
# 用户打断,停止处理
|
# 用户打断,停止处理
|
||||||
logger.info("[Live Chat] 检测到用户打断")
|
logger.info("[Live Chat] 检测到用户打断")
|
||||||
await websocket.send_json({"t": "stop_play"})
|
await websocket.send_json({"t": "stop_play"})
|
||||||
|
# 保存消息并标记为被打断
|
||||||
await self._save_interrupted_message(
|
await self._save_interrupted_message(
|
||||||
session, user_text, bot_text
|
session, user_text, bot_text
|
||||||
)
|
)
|
||||||
|
# 清空队列中未处理的消息
|
||||||
while not back_queue.empty():
|
while not back_queue.empty():
|
||||||
try:
|
try:
|
||||||
back_queue.get_nowait()
|
back_queue.get_nowait()
|
||||||
@@ -859,7 +805,6 @@ class LiveChatRoute(Route):
|
|||||||
|
|
||||||
result_type = result.get("type")
|
result_type = result.get("type")
|
||||||
result_chain_type = result.get("chain_type")
|
result_chain_type = result.get("chain_type")
|
||||||
result_streaming = bool(result.get("streaming", False))
|
|
||||||
data = result.get("data", "")
|
data = result.get("data", "")
|
||||||
|
|
||||||
if result_chain_type == "agent_stats":
|
if result_chain_type == "agent_stats":
|
||||||
@@ -882,41 +827,29 @@ class LiveChatRoute(Route):
|
|||||||
if result_chain_type == "tts_stats":
|
if result_chain_type == "tts_stats":
|
||||||
try:
|
try:
|
||||||
stats = json.loads(data)
|
stats = json.loads(data)
|
||||||
await websocket.send_json({"t": "metrics", "data": stats})
|
await websocket.send_json(
|
||||||
|
{
|
||||||
|
"t": "metrics",
|
||||||
|
"data": stats,
|
||||||
|
}
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[Live Chat] 解析 TTSStats 失败: {e}")
|
logger.error(f"[Live Chat] 解析 TTSStats 失败: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if result_chain_type == "live_text_delta":
|
|
||||||
if data:
|
|
||||||
await websocket.send_json(
|
|
||||||
{
|
|
||||||
"t": "bot_delta_chunk",
|
|
||||||
"data": {"text": data},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if result_type == "plain":
|
if result_type == "plain":
|
||||||
if (
|
# 普通文本消息
|
||||||
result_streaming
|
|
||||||
and data
|
|
||||||
and result_chain_type != "reasoning"
|
|
||||||
):
|
|
||||||
await websocket.send_json(
|
|
||||||
{
|
|
||||||
"t": "bot_delta_chunk",
|
|
||||||
"data": {"text": data},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
bot_text += data
|
bot_text += data
|
||||||
|
|
||||||
elif result_type == "audio_chunk":
|
elif result_type == "audio_chunk":
|
||||||
|
# 流式音频数据
|
||||||
if not audio_playing:
|
if not audio_playing:
|
||||||
audio_playing = True
|
audio_playing = True
|
||||||
logger.debug("[Live Chat] 开始播放音频流")
|
logger.debug("[Live Chat] 开始播放音频流")
|
||||||
|
|
||||||
|
# Calculate latency from wav assembly finish to first audio chunk
|
||||||
speak_to_first_frame_latency = (
|
speak_to_first_frame_latency = (
|
||||||
time.time() - processing_start
|
time.time() - wav_assembly_finish_time
|
||||||
)
|
)
|
||||||
await websocket.send_json(
|
await websocket.send_json(
|
||||||
{
|
{
|
||||||
@@ -936,15 +869,19 @@ class LiveChatRoute(Route):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 发送音频数据给前端
|
||||||
await websocket.send_json(
|
await websocket.send_json(
|
||||||
{
|
{
|
||||||
"t": "response",
|
"t": "response",
|
||||||
"data": data,
|
"data": data, # base64 编码的音频数据
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
elif result_type in ["complete", "end"]:
|
elif result_type in ["complete", "end"]:
|
||||||
|
# 处理完成
|
||||||
logger.info(f"[Live Chat] Bot 回复完成: {bot_text}")
|
logger.info(f"[Live Chat] Bot 回复完成: {bot_text}")
|
||||||
|
|
||||||
|
# 如果没有音频流,发送 bot 消息文本
|
||||||
if not audio_playing:
|
if not audio_playing:
|
||||||
await websocket.send_json(
|
await websocket.send_json(
|
||||||
{
|
{
|
||||||
@@ -956,8 +893,11 @@ class LiveChatRoute(Route):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 发送结束标记
|
||||||
await websocket.send_json({"t": "end"})
|
await websocket.send_json({"t": "end"})
|
||||||
wav_to_tts_duration = time.time() - processing_start
|
|
||||||
|
# 发送总耗时
|
||||||
|
wav_to_tts_duration = time.time() - wav_assembly_finish_time
|
||||||
await websocket.send_json(
|
await websocket.send_json(
|
||||||
{
|
{
|
||||||
"t": "metrics",
|
"t": "metrics",
|
||||||
@@ -969,65 +909,13 @@ class LiveChatRoute(Route):
|
|||||||
webchat_queue_mgr.remove_back_queue(message_id)
|
webchat_queue_mgr.remove_back_queue(message_id)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[Live Chat] 处理文本失败: {e}", exc_info=True)
|
logger.error(f"[Live Chat] 处理音频失败: {e}", exc_info=True)
|
||||||
await websocket.send_json({"t": "error", "data": f"处理失败: {str(e)}"})
|
await websocket.send_json({"t": "error", "data": f"处理失败: {str(e)}"})
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
session.is_processing = False
|
session.is_processing = False
|
||||||
session.should_interrupt = False
|
session.should_interrupt = False
|
||||||
|
|
||||||
async def _process_audio(
|
|
||||||
self, session: LiveChatSession, audio_path: str, assemble_duration: float
|
|
||||||
) -> None:
|
|
||||||
"""处理音频:STT -> LLM -> 流式 TTS"""
|
|
||||||
try:
|
|
||||||
await websocket.send_json(
|
|
||||||
{
|
|
||||||
"t": "metrics",
|
|
||||||
"data": {
|
|
||||||
"wav_assemble_time": assemble_duration,
|
|
||||||
"input_type": "audio",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
wav_assembly_finish_time = time.time()
|
|
||||||
|
|
||||||
# 1. STT - 语音转文字
|
|
||||||
ctx = self.plugin_manager.context
|
|
||||||
stt_provider = ctx.provider_manager.stt_provider_insts[0]
|
|
||||||
|
|
||||||
if not stt_provider:
|
|
||||||
logger.error("[Live Chat] STT Provider 未配置")
|
|
||||||
await websocket.send_json({"t": "error", "data": "语音识别服务未配置"})
|
|
||||||
return
|
|
||||||
|
|
||||||
await websocket.send_json(
|
|
||||||
{
|
|
||||||
"t": "metrics",
|
|
||||||
"data": {
|
|
||||||
"stt": stt_provider.meta().type,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
user_text = await stt_provider.get_text(audio_path)
|
|
||||||
if not user_text:
|
|
||||||
logger.warning("[Live Chat] STT 识别结果为空")
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info(f"[Live Chat] STT 结果: {user_text}")
|
|
||||||
|
|
||||||
await self._process_live_user_text(
|
|
||||||
session,
|
|
||||||
user_text=user_text,
|
|
||||||
initial_metrics=None,
|
|
||||||
processing_start_time=wav_assembly_finish_time,
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[Live Chat] 处理音频失败: {e}", exc_info=True)
|
|
||||||
await websocket.send_json({"t": "error", "data": f"处理失败: {str(e)}"})
|
|
||||||
|
|
||||||
async def _save_interrupted_message(
|
async def _save_interrupted_message(
|
||||||
self, session: LiveChatSession, user_text: str, bot_text: str
|
self, session: LiveChatSession, user_text: str, bot_text: str
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ from astrbot.core.utils.datetime_utils import to_utc_isoformat
|
|||||||
|
|
||||||
from .api_key import ALL_OPEN_API_SCOPES
|
from .api_key import ALL_OPEN_API_SCOPES
|
||||||
from .chat import ChatRoute
|
from .chat import ChatRoute
|
||||||
from .live_chat import LiveChatRoute
|
|
||||||
from .route import Response, Route, RouteContext
|
from .route import Response, Route, RouteContext
|
||||||
|
|
||||||
|
|
||||||
@@ -30,14 +29,12 @@ class OpenApiRoute(Route):
|
|||||||
db: BaseDatabase,
|
db: BaseDatabase,
|
||||||
core_lifecycle: AstrBotCoreLifecycle,
|
core_lifecycle: AstrBotCoreLifecycle,
|
||||||
chat_route: ChatRoute,
|
chat_route: ChatRoute,
|
||||||
live_chat_route: LiveChatRoute,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(context)
|
super().__init__(context)
|
||||||
self.db = db
|
self.db = db
|
||||||
self.core_lifecycle = core_lifecycle
|
self.core_lifecycle = core_lifecycle
|
||||||
self.platform_manager = core_lifecycle.platform_manager
|
self.platform_manager = core_lifecycle.platform_manager
|
||||||
self.chat_route = chat_route
|
self.chat_route = chat_route
|
||||||
self.live_chat_route = live_chat_route
|
|
||||||
|
|
||||||
self.routes = {
|
self.routes = {
|
||||||
"/v1/chat": ("POST", self.chat_send),
|
"/v1/chat": ("POST", self.chat_send),
|
||||||
@@ -49,7 +46,6 @@ class OpenApiRoute(Route):
|
|||||||
}
|
}
|
||||||
self.register_routes()
|
self.register_routes()
|
||||||
self.app.websocket("/api/v1/chat/ws")(self.chat_ws)
|
self.app.websocket("/api/v1/chat/ws")(self.chat_ws)
|
||||||
self.app.websocket("/api/v1/live/ws")(self.live_ws)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _resolve_open_username(
|
def _resolve_open_username(
|
||||||
@@ -538,39 +534,6 @@ class OpenApiRoute(Route):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("Open API WS connection closed: %s", e)
|
logger.debug("Open API WS connection closed: %s", e)
|
||||||
|
|
||||||
async def live_ws(self) -> None:
|
|
||||||
authed, auth_err = await self._authenticate_chat_ws_api_key()
|
|
||||||
if not authed:
|
|
||||||
await self._send_chat_ws_error(auth_err or "Unauthorized", "UNAUTHORIZED")
|
|
||||||
await websocket.close(1008, auth_err or "Unauthorized")
|
|
||||||
return
|
|
||||||
|
|
||||||
username, username_err = self._resolve_open_username(
|
|
||||||
websocket.args.get("username")
|
|
||||||
)
|
|
||||||
if username_err or not username:
|
|
||||||
await self._send_chat_ws_error(
|
|
||||||
username_err or "Invalid username",
|
|
||||||
"BAD_USER",
|
|
||||||
)
|
|
||||||
await websocket.close(1008, username_err or "Invalid username")
|
|
||||||
return
|
|
||||||
|
|
||||||
ct = websocket.args.get("ct")
|
|
||||||
force_ct = ct.strip() if isinstance(ct, str) and ct.strip() else "live"
|
|
||||||
if force_ct not in {"live", "chat"}:
|
|
||||||
await self._send_chat_ws_error(
|
|
||||||
"ct must be 'live' or 'chat'",
|
|
||||||
"INVALID_MESSAGE",
|
|
||||||
)
|
|
||||||
await websocket.close(1008, "Invalid ct")
|
|
||||||
return
|
|
||||||
|
|
||||||
await self.live_chat_route.run_ws_session(
|
|
||||||
username=username,
|
|
||||||
force_ct=force_ct,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def upload_file(self):
|
async def upload_file(self):
|
||||||
return await self.chat_route.post_file()
|
return await self.chat_route.post_file()
|
||||||
|
|
||||||
|
|||||||
@@ -5,8 +5,7 @@ import os
|
|||||||
import ssl
|
import ssl
|
||||||
import traceback
|
import traceback
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import certifi
|
import certifi
|
||||||
@@ -353,34 +352,6 @@ class PluginRoute(Route):
|
|||||||
logger.warning(f"获取插件 Logo 失败: {e}")
|
logger.warning(f"获取插件 Logo 失败: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _resolve_plugin_dir(self, plugin) -> Path | None:
|
|
||||||
if not plugin.root_dir_name:
|
|
||||||
return None
|
|
||||||
|
|
||||||
base_dir = Path(
|
|
||||||
self.plugin_manager.reserved_plugin_path
|
|
||||||
if plugin.reserved
|
|
||||||
else self.plugin_manager.plugin_store_path
|
|
||||||
)
|
|
||||||
plugin_dir = base_dir / plugin.root_dir_name
|
|
||||||
if not plugin_dir.is_dir():
|
|
||||||
return None
|
|
||||||
return plugin_dir
|
|
||||||
|
|
||||||
def _get_plugin_installed_at(self, plugin) -> str | None:
|
|
||||||
plugin_dir = self._resolve_plugin_dir(plugin)
|
|
||||||
if plugin_dir is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
return datetime.fromtimestamp(
|
|
||||||
plugin_dir.stat().st_mtime,
|
|
||||||
timezone.utc,
|
|
||||||
).isoformat()
|
|
||||||
except OSError as exc:
|
|
||||||
logger.warning(f"获取插件安装时间失败 {plugin.name}: {exc!s}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def get_plugins(self):
|
async def get_plugins(self):
|
||||||
_plugin_resp = []
|
_plugin_resp = []
|
||||||
plugin_name = request.args.get("name")
|
plugin_name = request.args.get("name")
|
||||||
@@ -406,7 +377,6 @@ class PluginRoute(Route):
|
|||||||
"logo": f"/api/file/{logo_url}" if logo_url else None,
|
"logo": f"/api/file/{logo_url}" if logo_url else None,
|
||||||
"support_platforms": plugin.support_platforms,
|
"support_platforms": plugin.support_platforms,
|
||||||
"astrbot_version": plugin.astrbot_version,
|
"astrbot_version": plugin.astrbot_version,
|
||||||
"installed_at": self._get_plugin_installed_at(plugin),
|
|
||||||
}
|
}
|
||||||
# 检查是否为全空的幽灵插件
|
# 检查是否为全空的幽灵插件
|
||||||
if not any(
|
if not any(
|
||||||
|
|||||||
@@ -12,32 +12,6 @@ from .route import Response, Route, RouteContext
|
|||||||
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
|
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
|
||||||
|
|
||||||
|
|
||||||
class EmptyMcpServersError(ValueError):
|
|
||||||
"""Raised when mcpServers is empty."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_mcp_server_config(mcp_servers_value: object) -> dict:
|
|
||||||
"""Extract server configuration from user-submitted mcpServers field.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: Invalid configuration
|
|
||||||
"""
|
|
||||||
if not isinstance(mcp_servers_value, dict):
|
|
||||||
raise ValueError("mcpServers must be a JSON object")
|
|
||||||
if not mcp_servers_value:
|
|
||||||
raise EmptyMcpServersError("mcpServers configuration cannot be empty")
|
|
||||||
key_0 = next(iter(mcp_servers_value))
|
|
||||||
extracted = mcp_servers_value[key_0]
|
|
||||||
if not isinstance(extracted, dict):
|
|
||||||
raise ValueError(
|
|
||||||
"Invalid mcpServers format. Ensure each key in mcpServers is a server name, "
|
|
||||||
"and each value is an object containing fields like command/url."
|
|
||||||
)
|
|
||||||
return extracted
|
|
||||||
|
|
||||||
|
|
||||||
class ToolsRoute(Route):
|
class ToolsRoute(Route):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -59,37 +33,13 @@ class ToolsRoute(Route):
|
|||||||
self.register_routes()
|
self.register_routes()
|
||||||
self.tool_mgr = self.core_lifecycle.provider_manager.llm_tools
|
self.tool_mgr = self.core_lifecycle.provider_manager.llm_tools
|
||||||
|
|
||||||
def _rollback_mcp_server(self, name: str) -> bool:
|
|
||||||
try:
|
|
||||||
rollback_config = self.tool_mgr.load_mcp_config()
|
|
||||||
if name in rollback_config["mcpServers"]:
|
|
||||||
rollback_config["mcpServers"].pop(name)
|
|
||||||
return self.tool_mgr.save_mcp_config(rollback_config)
|
|
||||||
return True
|
|
||||||
except Exception:
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def get_mcp_servers(self):
|
async def get_mcp_servers(self):
|
||||||
try:
|
try:
|
||||||
config = self.tool_mgr.load_mcp_config()
|
config = self.tool_mgr.load_mcp_config()
|
||||||
servers = []
|
servers = []
|
||||||
mcp_servers = config.get("mcpServers", {})
|
|
||||||
|
|
||||||
if not isinstance(mcp_servers, dict):
|
|
||||||
logger.warning(
|
|
||||||
f"Invalid MCP server config type: {type(mcp_servers).__name__}. Expected object/dict; skipped all MCP servers."
|
|
||||||
)
|
|
||||||
mcp_servers = {}
|
|
||||||
|
|
||||||
# 获取所有服务器并添加它们的工具列表
|
# 获取所有服务器并添加它们的工具列表
|
||||||
for name, server_config in mcp_servers.items():
|
for name, server_config in config["mcpServers"].items():
|
||||||
if not isinstance(server_config, dict):
|
|
||||||
logger.warning(
|
|
||||||
f"Invalid config for MCP server '{name}' (type: {type(server_config).__name__}); skipped."
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
server_info = {
|
server_info = {
|
||||||
"name": name,
|
"name": name,
|
||||||
"active": server_config.get("active", True),
|
"active": server_config.get("active", True),
|
||||||
@@ -115,7 +65,7 @@ class ToolsRoute(Route):
|
|||||||
return Response().ok(servers).__dict__
|
return Response().ok(servers).__dict__
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return Response().error(f"Failed to get MCP server list: {e!s}").__dict__
|
return Response().error(f"获取 MCP 服务器列表失败: {e!s}").__dict__
|
||||||
|
|
||||||
async def add_mcp_server(self):
|
async def add_mcp_server(self):
|
||||||
try:
|
try:
|
||||||
@@ -125,7 +75,7 @@ class ToolsRoute(Route):
|
|||||||
|
|
||||||
# 检查必填字段
|
# 检查必填字段
|
||||||
if not name:
|
if not name:
|
||||||
return Response().error("Server name cannot be empty").__dict__
|
return Response().error("服务器名称不能为空").__dict__
|
||||||
|
|
||||||
# 移除特殊字段并检查配置是否有效
|
# 移除特殊字段并检查配置是否有效
|
||||||
has_valid_config = False
|
has_valid_config = False
|
||||||
@@ -135,33 +85,21 @@ class ToolsRoute(Route):
|
|||||||
for key, value in server_data.items():
|
for key, value in server_data.items():
|
||||||
if key not in ["name", "active", "tools", "errlogs"]: # 排除特殊字段
|
if key not in ["name", "active", "tools", "errlogs"]: # 排除特殊字段
|
||||||
if key == "mcpServers":
|
if key == "mcpServers":
|
||||||
try:
|
key_0 = list(server_data["mcpServers"].keys())[
|
||||||
server_config = _extract_mcp_server_config(
|
0
|
||||||
server_data["mcpServers"]
|
] # 不考虑为空的情况
|
||||||
)
|
server_config = server_data["mcpServers"][key_0]
|
||||||
except ValueError as e:
|
|
||||||
return Response().error(f"{e!s}").__dict__
|
|
||||||
else:
|
else:
|
||||||
server_config[key] = value
|
server_config[key] = value
|
||||||
has_valid_config = True
|
has_valid_config = True
|
||||||
|
|
||||||
if not has_valid_config:
|
if not has_valid_config:
|
||||||
return (
|
return Response().error("必须提供有效的服务器配置").__dict__
|
||||||
Response()
|
|
||||||
.error("A valid server configuration is required")
|
|
||||||
.__dict__
|
|
||||||
)
|
|
||||||
|
|
||||||
config = self.tool_mgr.load_mcp_config()
|
config = self.tool_mgr.load_mcp_config()
|
||||||
|
|
||||||
if name in config["mcpServers"]:
|
if name in config["mcpServers"]:
|
||||||
return Response().error(f"Server {name} already exists").__dict__
|
return Response().error(f"服务器 {name} 已存在").__dict__
|
||||||
|
|
||||||
try:
|
|
||||||
await self.tool_mgr.test_mcp_server_connection(server_config)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
return Response().error(f"MCP connection test failed: {e!s}").__dict__
|
|
||||||
|
|
||||||
config["mcpServers"][name] = server_config
|
config["mcpServers"][name] = server_config
|
||||||
|
|
||||||
@@ -173,27 +111,17 @@ class ToolsRoute(Route):
|
|||||||
timeout=30,
|
timeout=30,
|
||||||
)
|
)
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
rollback_ok = self._rollback_mcp_server(name)
|
return Response().error(f"启用 MCP 服务器 {name} 超时。").__dict__
|
||||||
err_msg = f"Timed out while enabling MCP server {name}."
|
|
||||||
if not rollback_ok:
|
|
||||||
err_msg += " Configuration rollback failed. Please check the config manually."
|
|
||||||
return Response().error(err_msg).__dict__
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
rollback_ok = self._rollback_mcp_server(name)
|
return (
|
||||||
err_msg = f"Failed to enable MCP server {name}: {e!s}"
|
Response().error(f"启用 MCP 服务器 {name} 失败: {e!s}").__dict__
|
||||||
if not rollback_ok:
|
)
|
||||||
err_msg += " Configuration rollback failed. Please check the config manually."
|
return Response().ok(None, f"成功添加 MCP 服务器 {name}").__dict__
|
||||||
return Response().error(err_msg).__dict__
|
return Response().error("保存配置失败").__dict__
|
||||||
return (
|
|
||||||
Response()
|
|
||||||
.ok(None, f"Successfully added MCP server {name}")
|
|
||||||
.__dict__
|
|
||||||
)
|
|
||||||
return Response().error("Failed to save configuration").__dict__
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return Response().error(f"Failed to add MCP server: {e!s}").__dict__
|
return Response().error(f"添加 MCP 服务器失败: {e!s}").__dict__
|
||||||
|
|
||||||
async def update_mcp_server(self):
|
async def update_mcp_server(self):
|
||||||
try:
|
try:
|
||||||
@@ -203,25 +131,23 @@ class ToolsRoute(Route):
|
|||||||
old_name = server_data.get("oldName") or name
|
old_name = server_data.get("oldName") or name
|
||||||
|
|
||||||
if not name:
|
if not name:
|
||||||
return Response().error("Server name cannot be empty").__dict__
|
return Response().error("服务器名称不能为空").__dict__
|
||||||
|
|
||||||
config = self.tool_mgr.load_mcp_config()
|
config = self.tool_mgr.load_mcp_config()
|
||||||
|
|
||||||
if old_name not in config["mcpServers"]:
|
if old_name not in config["mcpServers"]:
|
||||||
return Response().error(f"Server {old_name} does not exist").__dict__
|
return Response().error(f"服务器 {old_name} 不存在").__dict__
|
||||||
|
|
||||||
is_rename = name != old_name
|
is_rename = name != old_name
|
||||||
|
|
||||||
if name in config["mcpServers"] and is_rename:
|
if name in config["mcpServers"] and is_rename:
|
||||||
return Response().error(f"Server {name} already exists").__dict__
|
return Response().error(f"服务器 {name} 已存在").__dict__
|
||||||
|
|
||||||
# 获取活动状态
|
# 获取活动状态
|
||||||
old_config = config["mcpServers"][old_name]
|
active = server_data.get(
|
||||||
if isinstance(old_config, dict):
|
"active",
|
||||||
old_active = old_config.get("active", True)
|
config["mcpServers"][old_name].get("active", True),
|
||||||
else:
|
)
|
||||||
old_active = True
|
|
||||||
active = server_data.get("active", old_active)
|
|
||||||
|
|
||||||
# 创建新的配置对象
|
# 创建新的配置对象
|
||||||
server_config = {"active": active}
|
server_config = {"active": active}
|
||||||
@@ -239,19 +165,17 @@ class ToolsRoute(Route):
|
|||||||
"oldName",
|
"oldName",
|
||||||
]: # 排除特殊字段
|
]: # 排除特殊字段
|
||||||
if key == "mcpServers":
|
if key == "mcpServers":
|
||||||
try:
|
key_0 = list(server_data["mcpServers"].keys())[
|
||||||
server_config = _extract_mcp_server_config(
|
0
|
||||||
server_data["mcpServers"]
|
] # 不考虑为空的情况
|
||||||
)
|
server_config = server_data["mcpServers"][key_0]
|
||||||
except ValueError as e:
|
|
||||||
return Response().error(f"{e!s}").__dict__
|
|
||||||
else:
|
else:
|
||||||
server_config[key] = value
|
server_config[key] = value
|
||||||
only_update_active = False
|
only_update_active = False
|
||||||
|
|
||||||
# 如果只更新活动状态,保留原始配置
|
# 如果只更新活动状态,保留原始配置
|
||||||
if only_update_active and isinstance(old_config, dict):
|
if only_update_active:
|
||||||
for key, value in old_config.items():
|
for key, value in config["mcpServers"][old_name].items():
|
||||||
if key != "active": # 除了active之外的所有字段都保留
|
if key != "active": # 除了active之外的所有字段都保留
|
||||||
server_config[key] = value
|
server_config[key] = value
|
||||||
|
|
||||||
@@ -276,7 +200,7 @@ class ToolsRoute(Route):
|
|||||||
return (
|
return (
|
||||||
Response()
|
Response()
|
||||||
.error(
|
.error(
|
||||||
f"Timed out while disabling MCP server {old_name} before enabling: {e!s}"
|
f"启用前停用 MCP 服务器时 {old_name} 超时: {e!s}"
|
||||||
)
|
)
|
||||||
.__dict__
|
.__dict__
|
||||||
)
|
)
|
||||||
@@ -285,7 +209,7 @@ class ToolsRoute(Route):
|
|||||||
return (
|
return (
|
||||||
Response()
|
Response()
|
||||||
.error(
|
.error(
|
||||||
f"Failed to disable MCP server {old_name} before enabling: {e!s}"
|
f"启用前停用 MCP 服务器时 {old_name} 失败: {e!s}"
|
||||||
)
|
)
|
||||||
.__dict__
|
.__dict__
|
||||||
)
|
)
|
||||||
@@ -297,15 +221,13 @@ class ToolsRoute(Route):
|
|||||||
)
|
)
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
return (
|
return (
|
||||||
Response()
|
Response().error(f"启用 MCP 服务器 {name} 超时。").__dict__
|
||||||
.error(f"Timed out while enabling MCP server {name}.")
|
|
||||||
.__dict__
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return (
|
return (
|
||||||
Response()
|
Response()
|
||||||
.error(f"Failed to enable MCP server {name}: {e!s}")
|
.error(f"启用 MCP 服务器 {name} 失败: {e!s}")
|
||||||
.__dict__
|
.__dict__
|
||||||
)
|
)
|
||||||
# 如果要停用服务器
|
# 如果要停用服务器
|
||||||
@@ -315,26 +237,22 @@ class ToolsRoute(Route):
|
|||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
return (
|
return (
|
||||||
Response()
|
Response()
|
||||||
.error(f"Timed out while disabling MCP server {old_name}.")
|
.error(f"停用 MCP 服务器 {old_name} 超时。")
|
||||||
.__dict__
|
.__dict__
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return (
|
return (
|
||||||
Response()
|
Response()
|
||||||
.error(f"Failed to disable MCP server {old_name}: {e!s}")
|
.error(f"停用 MCP 服务器 {old_name} 失败: {e!s}")
|
||||||
.__dict__
|
.__dict__
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return Response().ok(None, f"成功更新 MCP 服务器 {name}").__dict__
|
||||||
Response()
|
return Response().error("保存配置失败").__dict__
|
||||||
.ok(None, f"Successfully updated MCP server {name}")
|
|
||||||
.__dict__
|
|
||||||
)
|
|
||||||
return Response().error("Failed to save configuration").__dict__
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return Response().error(f"Failed to update MCP server: {e!s}").__dict__
|
return Response().error(f"更新 MCP 服务器失败: {e!s}").__dict__
|
||||||
|
|
||||||
async def delete_mcp_server(self):
|
async def delete_mcp_server(self):
|
||||||
try:
|
try:
|
||||||
@@ -342,12 +260,12 @@ class ToolsRoute(Route):
|
|||||||
name = server_data.get("name", "")
|
name = server_data.get("name", "")
|
||||||
|
|
||||||
if not name:
|
if not name:
|
||||||
return Response().error("Server name cannot be empty").__dict__
|
return Response().error("服务器名称不能为空").__dict__
|
||||||
|
|
||||||
config = self.tool_mgr.load_mcp_config()
|
config = self.tool_mgr.load_mcp_config()
|
||||||
|
|
||||||
if name not in config["mcpServers"]:
|
if name not in config["mcpServers"]:
|
||||||
return Response().error(f"Server {name} does not exist").__dict__
|
return Response().error(f"服务器 {name} 不存在").__dict__
|
||||||
|
|
||||||
del config["mcpServers"][name]
|
del config["mcpServers"][name]
|
||||||
|
|
||||||
@@ -357,76 +275,51 @@ class ToolsRoute(Route):
|
|||||||
await self.tool_mgr.disable_mcp_server(name, timeout=10)
|
await self.tool_mgr.disable_mcp_server(name, timeout=10)
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
return (
|
return (
|
||||||
Response()
|
Response().error(f"停用 MCP 服务器 {name} 超时。").__dict__
|
||||||
.error(f"Timed out while disabling MCP server {name}.")
|
|
||||||
.__dict__
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return (
|
return (
|
||||||
Response()
|
Response()
|
||||||
.error(f"Failed to disable MCP server {name}: {e!s}")
|
.error(f"停用 MCP 服务器 {name} 失败: {e!s}")
|
||||||
.__dict__
|
.__dict__
|
||||||
)
|
)
|
||||||
return (
|
return Response().ok(None, f"成功删除 MCP 服务器 {name}").__dict__
|
||||||
Response()
|
return Response().error("保存配置失败").__dict__
|
||||||
.ok(None, f"Successfully deleted MCP server {name}")
|
|
||||||
.__dict__
|
|
||||||
)
|
|
||||||
return Response().error("Failed to save configuration").__dict__
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return Response().error(f"Failed to delete MCP server: {e!s}").__dict__
|
return Response().error(f"删除 MCP 服务器失败: {e!s}").__dict__
|
||||||
|
|
||||||
async def test_mcp_connection(self):
|
async def test_mcp_connection(self):
|
||||||
"""Test MCP server connection."""
|
"""测试 MCP 服务器连接"""
|
||||||
try:
|
try:
|
||||||
server_data = await request.json
|
server_data = await request.json
|
||||||
config = server_data.get("mcp_server_config", None)
|
config = server_data.get("mcp_server_config", None)
|
||||||
|
|
||||||
if not isinstance(config, dict) or not config:
|
if not isinstance(config, dict) or not config:
|
||||||
return Response().error("Invalid MCP server configuration").__dict__
|
return Response().error("无效的 MCP 服务器配置").__dict__
|
||||||
|
|
||||||
if "mcpServers" in config:
|
if "mcpServers" in config:
|
||||||
mcp_servers = config["mcpServers"]
|
keys = list(config["mcpServers"].keys())
|
||||||
if isinstance(mcp_servers, dict) and len(mcp_servers) > 1:
|
if not keys:
|
||||||
return (
|
return Response().error("MCP 服务器配置不能为空").__dict__
|
||||||
Response()
|
if len(keys) > 1:
|
||||||
.error(
|
return Response().error("一次只能配置一个 MCP 服务器配置").__dict__
|
||||||
"Only one MCP server configuration can be tested at a time"
|
config = config["mcpServers"][keys[0]]
|
||||||
)
|
|
||||||
.__dict__
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
config = _extract_mcp_server_config(mcp_servers)
|
|
||||||
except EmptyMcpServersError:
|
|
||||||
return (
|
|
||||||
Response()
|
|
||||||
.error("MCP server configuration cannot be empty")
|
|
||||||
.__dict__
|
|
||||||
)
|
|
||||||
except ValueError as e:
|
|
||||||
return Response().error(f"{e!s}").__dict__
|
|
||||||
elif not config:
|
elif not config:
|
||||||
return (
|
return Response().error("MCP 服务器配置不能为空").__dict__
|
||||||
Response()
|
|
||||||
.error("MCP server configuration cannot be empty")
|
|
||||||
.__dict__
|
|
||||||
)
|
|
||||||
|
|
||||||
tools_name = await self.tool_mgr.test_mcp_server_connection(config)
|
tools_name = await self.tool_mgr.test_mcp_server_connection(config)
|
||||||
return (
|
return (
|
||||||
Response()
|
Response().ok(data=tools_name, message="🎉 MCP 服务器可用!").__dict__
|
||||||
.ok(data=tools_name, message="🎉 MCP server is available!")
|
|
||||||
.__dict__
|
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return Response().error(f"Failed to test MCP connection: {e!s}").__dict__
|
return Response().error(f"测试 MCP 连接失败: {e!s}").__dict__
|
||||||
|
|
||||||
async def get_tool_list(self):
|
async def get_tool_list(self):
|
||||||
"""Get all registered tools."""
|
"""获取所有注册的工具列表"""
|
||||||
try:
|
try:
|
||||||
tools = self.tool_mgr.func_list
|
tools = self.tool_mgr.func_list
|
||||||
tools_dict = []
|
tools_dict = []
|
||||||
@@ -456,44 +349,36 @@ class ToolsRoute(Route):
|
|||||||
return Response().ok(data=tools_dict).__dict__
|
return Response().ok(data=tools_dict).__dict__
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return Response().error(f"Failed to get tool list: {e!s}").__dict__
|
return Response().error(f"获取工具列表失败: {e!s}").__dict__
|
||||||
|
|
||||||
async def toggle_tool(self):
|
async def toggle_tool(self):
|
||||||
"""Activate or deactivate a specified tool."""
|
"""启用或停用指定的工具"""
|
||||||
try:
|
try:
|
||||||
data = await request.json
|
data = await request.json
|
||||||
tool_name = data.get("name")
|
tool_name = data.get("name")
|
||||||
action = data.get("activate") # True or False
|
action = data.get("activate") # True or False
|
||||||
|
|
||||||
if not tool_name or action is None:
|
if not tool_name or action is None:
|
||||||
return (
|
return Response().error("缺少必要参数: name 或 action").__dict__
|
||||||
Response()
|
|
||||||
.error("Missing required parameters: name or activate")
|
|
||||||
.__dict__
|
|
||||||
)
|
|
||||||
|
|
||||||
if action:
|
if action:
|
||||||
try:
|
try:
|
||||||
ok = self.tool_mgr.activate_llm_tool(tool_name, star_map=star_map)
|
ok = self.tool_mgr.activate_llm_tool(tool_name, star_map=star_map)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return Response().error(f"Failed to activate tool: {e!s}").__dict__
|
return Response().error(f"启用工具失败: {e!s}").__dict__
|
||||||
else:
|
else:
|
||||||
ok = self.tool_mgr.deactivate_llm_tool(tool_name)
|
ok = self.tool_mgr.deactivate_llm_tool(tool_name)
|
||||||
|
|
||||||
if ok:
|
if ok:
|
||||||
return Response().ok(None, "Operation successful.").__dict__
|
return Response().ok(None, "操作成功。").__dict__
|
||||||
return (
|
return Response().error(f"工具 {tool_name} 不存在或操作失败。").__dict__
|
||||||
Response()
|
|
||||||
.error(f"Tool {tool_name} does not exist or the operation failed.")
|
|
||||||
.__dict__
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return Response().error(f"Failed to operate tool: {e!s}").__dict__
|
return Response().error(f"操作工具失败: {e!s}").__dict__
|
||||||
|
|
||||||
async def sync_provider(self):
|
async def sync_provider(self):
|
||||||
"""Sync MCP provider configuration."""
|
"""同步 MCP 提供者配置"""
|
||||||
try:
|
try:
|
||||||
data = await request.json
|
data = await request.json
|
||||||
provider_name = data.get("name") # modelscope, or others
|
provider_name = data.get("name") # modelscope, or others
|
||||||
@@ -502,11 +387,9 @@ class ToolsRoute(Route):
|
|||||||
access_token = data.get("access_token", "")
|
access_token = data.get("access_token", "")
|
||||||
await self.tool_mgr.sync_modelscope_mcp_servers(access_token)
|
await self.tool_mgr.sync_modelscope_mcp_servers(access_token)
|
||||||
case _:
|
case _:
|
||||||
return (
|
return Response().error(f"未知: {provider_name}").__dict__
|
||||||
Response().error(f"Unknown provider: {provider_name}").__dict__
|
|
||||||
)
|
|
||||||
|
|
||||||
return Response().ok(message="Sync completed").__dict__
|
return Response().ok(message="同步成功").__dict__
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return Response().error(f"Sync failed: {e!s}").__dict__
|
return Response().error(f"同步失败: {e!s}").__dict__
|
||||||
|
|||||||
@@ -115,13 +115,11 @@ class AstrBotDashboard:
|
|||||||
self.ar = AuthRoute(self.context)
|
self.ar = AuthRoute(self.context)
|
||||||
self.api_key_route = ApiKeyRoute(self.context, db)
|
self.api_key_route = ApiKeyRoute(self.context, db)
|
||||||
self.chat_route = ChatRoute(self.context, db, core_lifecycle)
|
self.chat_route = ChatRoute(self.context, db, core_lifecycle)
|
||||||
self.live_chat_route = LiveChatRoute(self.context, db, core_lifecycle)
|
|
||||||
self.open_api_route = OpenApiRoute(
|
self.open_api_route = OpenApiRoute(
|
||||||
self.context,
|
self.context,
|
||||||
db,
|
db,
|
||||||
core_lifecycle,
|
core_lifecycle,
|
||||||
self.chat_route,
|
self.chat_route,
|
||||||
self.live_chat_route,
|
|
||||||
)
|
)
|
||||||
self.chatui_project_route = ChatUIProjectRoute(self.context, db)
|
self.chatui_project_route = ChatUIProjectRoute(self.context, db)
|
||||||
self.tools_root = ToolsRoute(self.context, core_lifecycle)
|
self.tools_root = ToolsRoute(self.context, core_lifecycle)
|
||||||
@@ -140,6 +138,7 @@ class AstrBotDashboard:
|
|||||||
self.kb_route = KnowledgeBaseRoute(self.context, core_lifecycle)
|
self.kb_route = KnowledgeBaseRoute(self.context, core_lifecycle)
|
||||||
self.platform_route = PlatformRoute(self.context, core_lifecycle)
|
self.platform_route = PlatformRoute(self.context, core_lifecycle)
|
||||||
self.backup_route = BackupRoute(self.context, db, core_lifecycle)
|
self.backup_route = BackupRoute(self.context, db, core_lifecycle)
|
||||||
|
self.live_chat_route = LiveChatRoute(self.context, db, core_lifecycle)
|
||||||
|
|
||||||
self.app.add_url_rule(
|
self.app.add_url_rule(
|
||||||
"/api/plug/<path:subpath>",
|
"/api/plug/<path:subpath>",
|
||||||
@@ -245,7 +244,6 @@ class AstrBotDashboard:
|
|||||||
scope_map = {
|
scope_map = {
|
||||||
"/api/v1/chat": "chat",
|
"/api/v1/chat": "chat",
|
||||||
"/api/v1/chat/ws": "chat",
|
"/api/v1/chat/ws": "chat",
|
||||||
"/api/v1/live/ws": "chat",
|
|
||||||
"/api/v1/chat/sessions": "chat",
|
"/api/v1/chat/sessions": "chat",
|
||||||
"/api/v1/configs": "config",
|
"/api/v1/configs": "config",
|
||||||
"/api/v1/file": "file",
|
"/api/v1/file": "file",
|
||||||
|
|||||||
@@ -1,40 +0,0 @@
|
|||||||
## What's Changed
|
|
||||||
|
|
||||||
### 新增
|
|
||||||
|
|
||||||
- 新增技能 ZIP 批量上传能力 ([#5804](https://github.com/AstrBotDevs/AstrBot/pull/5804))。
|
|
||||||
|
|
||||||
### 修复
|
|
||||||
|
|
||||||
- 修复 MCP Server 配置异常时可能导致崩溃的问题 ([#5666](https://github.com/AstrBotDevs/AstrBot/pull/5666), [#5673](https://github.com/AstrBotDevs/AstrBot/pull/5673))。
|
|
||||||
- 修复钉钉适配器文本消息被忽略、无法主动发送文件的问题 ([#5921](https://github.com/AstrBotDevs/AstrBot/pull/5921))。
|
|
||||||
- 修复钉钉适配器无法接收图片与文件的问题 ([#5920](https://github.com/AstrBotDevs/AstrBot/pull/5920))。
|
|
||||||
- fix(provider): handle MiniMax ThinkingBlock when max_tokens reached ([#5913](https://github.com/AstrBotDevs/AstrBot/pull/5913))。
|
|
||||||
- 修复 OpenRouter `api_base` 配置错误的问题 ([#5911](https://github.com/AstrBotDevs/AstrBot/pull/5911))。
|
|
||||||
- 修复插件市场中按展示名搜索已安装插件不生效的问题 ([#5806](https://github.com/AstrBotDevs/AstrBot/pull/5806), [#5811](https://github.com/AstrBotDevs/AstrBot/pull/5811))。
|
|
||||||
- 修复仅图片响应未应用 `reply_with_quote` 与 `reply_with_mention` 的问题 ([#5219](https://github.com/AstrBotDevs/AstrBot/pull/5219))。
|
|
||||||
- 修复 `RegexFilter` 使用 `re.match` 导致匹配范围不正确的问题 ([#5368](https://github.com/AstrBotDevs/AstrBot/pull/5368))。
|
|
||||||
- 修复桌面运行环境检测依赖 frozen Python 的问题 ([#5859](https://github.com/AstrBotDevs/AstrBot/pull/5859))。
|
|
||||||
- 修复通过“创建新配置”创建平台机器人后找不到 pipeline scheduler 的问题 ([#5776](https://github.com/AstrBotDevs/AstrBot/pull/5776))。
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## What's Changed (EN)
|
|
||||||
|
|
||||||
### New Features
|
|
||||||
|
|
||||||
- Added batch upload support for multiple skill ZIP files ([#5804](https://github.com/AstrBotDevs/AstrBot/pull/5804)).
|
|
||||||
|
|
||||||
### Bug Fixes
|
|
||||||
|
|
||||||
- Fixed potential crash on malformed MCP server config ([#5666](https://github.com/AstrBotDevs/AstrBot/pull/5666), [#5673](https://github.com/AstrBotDevs/AstrBot/pull/5673)).
|
|
||||||
- Fixed DingTalk adapter issue where text messages were ignored and files could not be sent proactively ([#5921](https://github.com/AstrBotDevs/AstrBot/pull/5921)).
|
|
||||||
- Fixed DingTalk adapter issue where image and file messages could not be received ([#5920](https://github.com/AstrBotDevs/AstrBot/pull/5920)).
|
|
||||||
- Fixed incorrect OpenRouter `api_base` configuration ([#5911](https://github.com/AstrBotDevs/AstrBot/pull/5911)).
|
|
||||||
- Fixed searching installed plugins by display name in extensions ([#5806](https://github.com/AstrBotDevs/AstrBot/pull/5806), [#5811](https://github.com/AstrBotDevs/AstrBot/pull/5811)).
|
|
||||||
- Fixed image-only responses not applying `reply_with_quote` and `reply_with_mention` ([#5219](https://github.com/AstrBotDevs/AstrBot/pull/5219)).
|
|
||||||
- Fixed `RegexFilter` using `re.match` instead of `re.search` for expected matching behavior ([#5368](https://github.com/AstrBotDevs/AstrBot/pull/5368)).
|
|
||||||
- Fixed desktop runtime detection requiring frozen Python ([#5859](https://github.com/AstrBotDevs/AstrBot/pull/5859)).
|
|
||||||
- Fixed missing pipeline scheduler after creating a platform bot via "create new config" ([#5776](https://github.com/AstrBotDevs/AstrBot/pull/5776)).
|
|
||||||
- fix(provider): handle MiniMax ThinkingBlock when max_tokens reached ([#5913](https://github.com/AstrBotDevs/AstrBot/pull/5913))
|
|
||||||
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
## What's Changed
|
|
||||||
|
|
||||||
### 新增
|
|
||||||
|
|
||||||
- 企业微信智能机器人支持长连接模式。[#5930](https://github.com/AstrBotDevs/AstrBot/pull/5930)
|
|
||||||
|
|
||||||
### New
|
|
||||||
|
|
||||||
- Wecom AI Bot supports long-connection mode(Websockets). [#5930](https://github.com/AstrBotDevs/AstrBot/pull/5930)
|
|
||||||
@@ -1,43 +0,0 @@
|
|||||||
## What's Changed
|
|
||||||
|
|
||||||
### 新增
|
|
||||||
|
|
||||||
- Lark 适配器支持 CardKit 流式输出(飞书)([#5777](https://github.com/AstrBotDevs/AstrBot/pull/5777))。
|
|
||||||
- WebUI 已安装插件列表新增筛选与排序功能 ([#5923](https://github.com/AstrBotDevs/AstrBot/pull/5923))。
|
|
||||||
|
|
||||||
### 优化
|
|
||||||
- 启动时后台加载 MCP Server,不阻塞加载流程 ([#5993](https://github.com/AstrBotDevs/AstrBot/pull/5993))。
|
|
||||||
|
|
||||||
### 修复
|
|
||||||
|
|
||||||
- 部分情况下 MCP 页报错 500 导致查看不了 MCP 服务器 ([#5993](https://github.com/AstrBotDevs/AstrBot/pull/5993))。
|
|
||||||
- 修复 TTS Provider 测试:增加文件大小校验,并补充 MiniMax 空音频检测 ([#5999](https://github.com/AstrBotDevs/AstrBot/pull/5999))。
|
|
||||||
- 修复前端切换到 Chat 后又回到 Welcome 时,页面切换配置未正确持久化的问题 ([#5792](https://github.com/AstrBotDevs/AstrBot/pull/5792))。
|
|
||||||
- 修复 Azure TTS 不支持 84 位订阅密钥的问题 ([#5813](https://github.com/AstrBotDevs/AstrBot/pull/5813))。
|
|
||||||
|
|
||||||
### 文档
|
|
||||||
|
|
||||||
- 文档仓库迁移:将 `AstrBotDevs/AstrBot-docs` 内容迁移至 `AstrBotDevs/AstrBot` ([#5960](https://github.com/AstrBotDevs/AstrBot/pull/5960))。
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## What's Changed (EN)
|
|
||||||
|
|
||||||
### New Features
|
|
||||||
|
|
||||||
- Added CardKit streaming output support for the Lark/Feishu adapter ([#5777](https://github.com/AstrBotDevs/AstrBot/pull/5777)).
|
|
||||||
- Added filtering and sorting for installed plugins in the WebUI ([#5923](https://github.com/AstrBotDevs/AstrBot/pull/5923)).
|
|
||||||
|
|
||||||
### Impprovement
|
|
||||||
- MCP Server now loads in the background during startup without blocking the loading process ([#5993](https://github.com/AstrBotDevs/AstrBot/pull/5993)).
|
|
||||||
|
|
||||||
### Bug Fixes
|
|
||||||
|
|
||||||
- Added file size validation in TTS provider tests and MiniMax empty-audio detection ([#5999](https://github.com/AstrBotDevs/AstrBot/pull/5999)).
|
|
||||||
- Fixed frontend state persistence when switching from Chat back to Welcome ([#5792](https://github.com/AstrBotDevs/AstrBot/pull/5792)).
|
|
||||||
- Fixed Azure TTS support for 84-character subscription keys ([#5813](https://github.com/AstrBotDevs/AstrBot/pull/5813)).
|
|
||||||
- Reverted the MCP stdio missing-command error wording change after the previous fix ([#5992](https://github.com/AstrBotDevs/AstrBot/pull/5992)).
|
|
||||||
|
|
||||||
### Documentation
|
|
||||||
|
|
||||||
- Migrated documentation content from `AstrBotDevs/AstrBot-docs` into `AstrBotDevs/AstrBot` ([#5960](https://github.com/AstrBotDevs/AstrBot/pull/5960)).
|
|
||||||
@@ -1,64 +0,0 @@
|
|||||||
## What's Changed
|
|
||||||
|
|
||||||
### 新增
|
|
||||||
|
|
||||||
- 新增俄语翻译([#6081](https://github.com/AstrBotDevs/AstrBot/pull/6081))。
|
|
||||||
- QQ 官方 Bot 新增文件、语音、视频消息支持(含 WebSocket 模式)([#6063](https://github.com/AstrBotDevs/AstrBot/pull/6063))。
|
|
||||||
|
|
||||||
### 优化
|
|
||||||
|
|
||||||
- 优化 QQ 官方 Bot 的流式消息投递可靠性与主动媒体发送能力([#6131](https://github.com/AstrBotDevs/AstrBot/pull/6131))。
|
|
||||||
- 优化边界场景下 booter 选择逻辑与消息发送工具([#6064](https://github.com/AstrBotDevs/AstrBot/pull/6064))。
|
|
||||||
|
|
||||||
### 修复
|
|
||||||
|
|
||||||
- 修复 Dashboard README 对话框锚点导航失效([#6083](https://github.com/AstrBotDevs/AstrBot/pull/6083))。
|
|
||||||
- 优先使用具名 weekday 的 cron 示例,避免歧义([#6091](https://github.com/AstrBotDevs/AstrBot/pull/6091))。
|
|
||||||
- 修复插件市场安装后状态未及时刷新的问题([#6124](https://github.com/AstrBotDevs/AstrBot/pull/6124))。
|
|
||||||
- 修复插件依赖安装逻辑:仅安装缺失依赖([#6088](https://github.com/AstrBotDevs/AstrBot/pull/6088))。
|
|
||||||
- 移除 Telegram 适配器中已废弃的 `normalize_whitespace` 参数([#6044](https://github.com/AstrBotDevs/AstrBot/pull/6044))。
|
|
||||||
- 修复 Windows 本地 skill 文件读取问题([#6028](https://github.com/AstrBotDevs/AstrBot/pull/6028))。
|
|
||||||
- 修复 Discord pre-ack emoji 配置重启后不持久化的问题([#6031](https://github.com/AstrBotDevs/AstrBot/pull/6031))。
|
|
||||||
- 统一 WebUI 搜索框清空行为([#6017](https://github.com/AstrBotDevs/AstrBot/pull/6017))。
|
|
||||||
- 优化插件依赖自动安装流程与 Dashboard 安装体验([#5954](https://github.com/AstrBotDevs/AstrBot/pull/5954))。
|
|
||||||
|
|
||||||
|
|
||||||
### 文档
|
|
||||||
|
|
||||||
- 新增 Astrbook 和玖帕喵社区链接([#6135](https://github.com/AstrBotDevs/AstrBot/pull/6135))。
|
|
||||||
- 修正文档 `docker.md` 与 `napcat.md` 中的拼写错误([#6048](https://github.com/AstrBotDevs/AstrBot/pull/6048))。
|
|
||||||
- 在多语言 README 中补充官方开发群号,并改进配置元数据中的正则说明。
|
|
||||||
- 更新编辑链接模式并移除过时仓库引用。
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## What's Changed (EN)
|
|
||||||
|
|
||||||
### New Features
|
|
||||||
|
|
||||||
- Added Russian translation support ([#6081](https://github.com/AstrBotDevs/AstrBot/pull/6081)).
|
|
||||||
- Added file, voice, and video message support for QQ Official Bot (including WebSocket mode) ([#6063](https://github.com/AstrBotDevs/AstrBot/pull/6063)).
|
|
||||||
|
|
||||||
### Improvements
|
|
||||||
|
|
||||||
- Improved streaming message delivery reliability and proactive media sending for QQ Official API ([#6131](https://github.com/AstrBotDevs/AstrBot/pull/6131)).
|
|
||||||
- Optimized booter selection logic in edge cases and message sending tooling ([#6064](https://github.com/AstrBotDevs/AstrBot/pull/6064)).
|
|
||||||
|
|
||||||
### Bug Fixes
|
|
||||||
|
|
||||||
- Fixed broken README dialog anchor navigation in the Dashboard ([#6083](https://github.com/AstrBotDevs/AstrBot/pull/6083)).
|
|
||||||
- Preferred named weekday cron examples to reduce ambiguity ([#6091](https://github.com/AstrBotDevs/AstrBot/pull/6091)).
|
|
||||||
- Fixed plugin market install-state refresh after installation ([#6124](https://github.com/AstrBotDevs/AstrBot/pull/6124)).
|
|
||||||
- Fixed plugin dependency installation logic to install only missing packages ([#6088](https://github.com/AstrBotDevs/AstrBot/pull/6088)).
|
|
||||||
- Removed deprecated `normalize_whitespace` parameter in the Telegram adapter ([#6044](https://github.com/AstrBotDevs/AstrBot/pull/6044)).
|
|
||||||
- Fixed local skill file reading issues on Windows ([#6028](https://github.com/AstrBotDevs/AstrBot/pull/6028)).
|
|
||||||
- Fixed Discord pre-ack emoji config not being persisted across restarts ([#6031](https://github.com/AstrBotDevs/AstrBot/pull/6031)).
|
|
||||||
- Unified WebUI search input clear behavior ([#6017](https://github.com/AstrBotDevs/AstrBot/pull/6017)).
|
|
||||||
- Improved plugin dependency auto-install flow and Dashboard installation experience ([#5954](https://github.com/AstrBotDevs/AstrBot/pull/5954)).
|
|
||||||
|
|
||||||
### Documentation
|
|
||||||
|
|
||||||
- Added Astrbook and Jiupa Miao community links ([#6135](https://github.com/AstrBotDevs/AstrBot/pull/6135)).
|
|
||||||
- Fixed typos in `docker.md` and `napcat.md` ([#6048](https://github.com/AstrBotDevs/AstrBot/pull/6048)).
|
|
||||||
- Added official developer group IDs to multilingual READMEs and improved regex description in config metadata.
|
|
||||||
- Updated edit-link patterns and removed obsolete repository references.
|
|
||||||
@@ -37,7 +37,6 @@ services:
|
|||||||
- DEFAULT_SHIP_MEMORY=512m
|
- DEFAULT_SHIP_MEMORY=512m
|
||||||
volumes:
|
volumes:
|
||||||
- ${PWD}/data/shipyard/bay_data:/app/data
|
- ${PWD}/data/shipyard/bay_data:/app/data
|
||||||
- ${PWD}/data/temp:/AstrBot/data/temp # Bind the local temp directory to the sandbox so that the uploaded file can be accessed in the sandbox
|
|
||||||
- /var/run/docker.sock:/var/run/docker.sock:ro
|
- /var/run/docker.sock:/var/run/docker.sock:ro
|
||||||
networks:
|
networks:
|
||||||
- astrbot_network
|
- astrbot_network
|
||||||
|
|||||||
+8
-13
@@ -17,17 +17,17 @@
|
|||||||
"@tiptap/starter-kit": "2.1.7",
|
"@tiptap/starter-kit": "2.1.7",
|
||||||
"@tiptap/vue-3": "2.1.7",
|
"@tiptap/vue-3": "2.1.7",
|
||||||
"apexcharts": "3.42.0",
|
"apexcharts": "3.42.0",
|
||||||
"axios": "1.13.5",
|
"axios": ">=1.6.2 <1.10.0 || >1.10.0 <2.0.0",
|
||||||
"axios-mock-adapter": "^1.22.0",
|
"axios-mock-adapter": "^1.22.0",
|
||||||
"chance": "1.1.11",
|
"chance": "1.1.11",
|
||||||
"date-fns": "2.30.0",
|
"date-fns": "2.30.0",
|
||||||
"dompurify": "^3.3.2",
|
"dompurify": "^3.3.1",
|
||||||
"event-source-polyfill": "^1.0.31",
|
"event-source-polyfill": "^1.0.31",
|
||||||
"highlight.js": "^11.11.1",
|
"highlight.js": "^11.11.1",
|
||||||
"js-md5": "^0.8.3",
|
"js-md5": "^0.8.3",
|
||||||
"katex": "^0.16.27",
|
"katex": "^0.16.27",
|
||||||
"lodash": "4.17.23",
|
"lodash": "4.17.21",
|
||||||
"markdown-it": "^14.1.1",
|
"markdown-it": "^14.1.0",
|
||||||
"markstream-vue": "^0.0.6",
|
"markstream-vue": "^0.0.6",
|
||||||
"mermaid": "^11.12.2",
|
"mermaid": "^11.12.2",
|
||||||
"monaco-editor": "^0.52.2",
|
"monaco-editor": "^0.52.2",
|
||||||
@@ -36,8 +36,9 @@
|
|||||||
"remixicon": "3.5.0",
|
"remixicon": "3.5.0",
|
||||||
"shiki": "^3.20.0",
|
"shiki": "^3.20.0",
|
||||||
"stream-markdown": "^0.0.13",
|
"stream-markdown": "^0.0.13",
|
||||||
|
"stream-monaco": "^0.0.17",
|
||||||
"vee-validate": "4.11.3",
|
"vee-validate": "4.11.3",
|
||||||
"vite-plugin-vuetify": "2.1.3",
|
"vite-plugin-vuetify": "1.0.2",
|
||||||
"vue": "3.3.4",
|
"vue": "3.3.4",
|
||||||
"vue-i18n": "^11.1.5",
|
"vue-i18n": "^11.1.5",
|
||||||
"vue-router": "4.2.4",
|
"vue-router": "4.2.4",
|
||||||
@@ -53,7 +54,7 @@
|
|||||||
"@types/dompurify": "^3.0.5",
|
"@types/dompurify": "^3.0.5",
|
||||||
"@types/markdown-it": "^14.1.2",
|
"@types/markdown-it": "^14.1.2",
|
||||||
"@types/node": "^20.5.7",
|
"@types/node": "^20.5.7",
|
||||||
"@vitejs/plugin-vue": "5.2.4",
|
"@vitejs/plugin-vue": "4.3.3",
|
||||||
"@vue/eslint-config-prettier": "8.0.0",
|
"@vue/eslint-config-prettier": "8.0.0",
|
||||||
"@vue/eslint-config-typescript": "11.0.3",
|
"@vue/eslint-config-typescript": "11.0.3",
|
||||||
"@vue/tsconfig": "^0.4.0",
|
"@vue/tsconfig": "^0.4.0",
|
||||||
@@ -63,15 +64,9 @@
|
|||||||
"sass": "1.66.1",
|
"sass": "1.66.1",
|
||||||
"sass-loader": "13.3.2",
|
"sass-loader": "13.3.2",
|
||||||
"typescript": "5.1.6",
|
"typescript": "5.1.6",
|
||||||
"vite": "6.4.1",
|
"vite": "4.4.9",
|
||||||
"vue-cli-plugin-vuetify": "2.5.8",
|
"vue-cli-plugin-vuetify": "2.5.8",
|
||||||
"vue-tsc": "1.8.8",
|
"vue-tsc": "1.8.8",
|
||||||
"vuetify-loader": "^2.0.0-alpha.9"
|
"vuetify-loader": "^2.0.0-alpha.9"
|
||||||
},
|
|
||||||
"pnpm": {
|
|
||||||
"overrides": {
|
|
||||||
"immutable": "4.3.8",
|
|
||||||
"lodash-es": "4.17.23"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Generated
+271
-601
File diff suppressed because it is too large
Load Diff
@@ -11,7 +11,6 @@
|
|||||||
:currSessionId="currSessionId"
|
:currSessionId="currSessionId"
|
||||||
:selectedProjectId="selectedProjectId"
|
:selectedProjectId="selectedProjectId"
|
||||||
:transportMode="transportMode"
|
:transportMode="transportMode"
|
||||||
:sendShortcut="sendShortcut"
|
|
||||||
:isDark="isDark"
|
:isDark="isDark"
|
||||||
:chatboxMode="chatboxMode"
|
:chatboxMode="chatboxMode"
|
||||||
:isMobile="isMobile"
|
:isMobile="isMobile"
|
||||||
@@ -21,7 +20,6 @@
|
|||||||
@selectConversation="handleSelectConversation"
|
@selectConversation="handleSelectConversation"
|
||||||
@editTitle="showEditTitleDialog"
|
@editTitle="showEditTitleDialog"
|
||||||
@deleteConversation="handleDeleteConversation"
|
@deleteConversation="handleDeleteConversation"
|
||||||
@batchDeleteConversations="handleBatchDeleteConversations"
|
|
||||||
@closeMobileSidebar="closeMobileSidebar"
|
@closeMobileSidebar="closeMobileSidebar"
|
||||||
@toggleTheme="toggleTheme"
|
@toggleTheme="toggleTheme"
|
||||||
@toggleFullscreen="toggleFullscreen"
|
@toggleFullscreen="toggleFullscreen"
|
||||||
@@ -30,7 +28,6 @@
|
|||||||
@editProject="showEditProjectDialog"
|
@editProject="showEditProjectDialog"
|
||||||
@deleteProject="handleDeleteProject"
|
@deleteProject="handleDeleteProject"
|
||||||
@updateTransportMode="setTransportMode"
|
@updateTransportMode="setTransportMode"
|
||||||
@updateSendShortcut="setSendShortcut"
|
|
||||||
/>
|
/>
|
||||||
|
|
||||||
<!-- 右侧聊天内容区域 -->
|
<!-- 右侧聊天内容区域 -->
|
||||||
@@ -74,14 +71,13 @@
|
|||||||
:stagedImagesUrl="stagedImagesUrl"
|
:stagedImagesUrl="stagedImagesUrl"
|
||||||
:stagedAudioUrl="stagedAudioUrl"
|
:stagedAudioUrl="stagedAudioUrl"
|
||||||
:stagedFiles="stagedNonImageFiles"
|
:stagedFiles="stagedNonImageFiles"
|
||||||
:disabled="false"
|
:disabled="isStreaming"
|
||||||
:is-running="isStreaming || isConvRunning"
|
:is-running="isStreaming || isConvRunning"
|
||||||
:enableStreaming="enableStreaming"
|
:enableStreaming="enableStreaming"
|
||||||
:isRecording="isRecording"
|
:isRecording="isRecording"
|
||||||
:session-id="currSessionId || null"
|
:session-id="currSessionId || null"
|
||||||
:current-session="getCurrentSession"
|
:current-session="getCurrentSession"
|
||||||
:replyTo="replyTo"
|
:replyTo="replyTo"
|
||||||
:send-shortcut="sendShortcut"
|
|
||||||
@send="handleSendMessage"
|
@send="handleSendMessage"
|
||||||
@stop="handleStopMessage"
|
@stop="handleStopMessage"
|
||||||
@toggleStreaming="toggleStreaming"
|
@toggleStreaming="toggleStreaming"
|
||||||
@@ -106,14 +102,13 @@
|
|||||||
:stagedImagesUrl="stagedImagesUrl"
|
:stagedImagesUrl="stagedImagesUrl"
|
||||||
:stagedAudioUrl="stagedAudioUrl"
|
:stagedAudioUrl="stagedAudioUrl"
|
||||||
:stagedFiles="stagedNonImageFiles"
|
:stagedFiles="stagedNonImageFiles"
|
||||||
:disabled="false"
|
:disabled="isStreaming"
|
||||||
:is-running="isStreaming || isConvRunning"
|
:is-running="isStreaming || isConvRunning"
|
||||||
:enableStreaming="enableStreaming"
|
:enableStreaming="enableStreaming"
|
||||||
:isRecording="isRecording"
|
:isRecording="isRecording"
|
||||||
:session-id="currSessionId || null"
|
:session-id="currSessionId || null"
|
||||||
:current-session="getCurrentSession"
|
:current-session="getCurrentSession"
|
||||||
:replyTo="replyTo"
|
:replyTo="replyTo"
|
||||||
:send-shortcut="sendShortcut"
|
|
||||||
@send="handleSendMessage"
|
@send="handleSendMessage"
|
||||||
@stop="handleStopMessage"
|
@stop="handleStopMessage"
|
||||||
@toggleStreaming="toggleStreaming"
|
@toggleStreaming="toggleStreaming"
|
||||||
@@ -137,14 +132,13 @@
|
|||||||
:stagedImagesUrl="stagedImagesUrl"
|
:stagedImagesUrl="stagedImagesUrl"
|
||||||
:stagedAudioUrl="stagedAudioUrl"
|
:stagedAudioUrl="stagedAudioUrl"
|
||||||
:stagedFiles="stagedNonImageFiles"
|
:stagedFiles="stagedNonImageFiles"
|
||||||
:disabled="false"
|
:disabled="isStreaming"
|
||||||
:is-running="isStreaming || isConvRunning"
|
:is-running="isStreaming || isConvRunning"
|
||||||
:enableStreaming="enableStreaming"
|
:enableStreaming="enableStreaming"
|
||||||
:isRecording="isRecording"
|
:isRecording="isRecording"
|
||||||
:session-id="currSessionId || null"
|
:session-id="currSessionId || null"
|
||||||
:current-session="getCurrentSession"
|
:current-session="getCurrentSession"
|
||||||
:replyTo="replyTo"
|
:replyTo="replyTo"
|
||||||
:send-shortcut="sendShortcut"
|
|
||||||
@send="handleSendMessage"
|
@send="handleSendMessage"
|
||||||
@stop="handleStopMessage"
|
@stop="handleStopMessage"
|
||||||
@toggleStreaming="toggleStreaming"
|
@toggleStreaming="toggleStreaming"
|
||||||
@@ -226,13 +220,10 @@ import { useMediaHandling } from '@/composables/useMediaHandling';
|
|||||||
import { useProjects } from '@/composables/useProjects';
|
import { useProjects } from '@/composables/useProjects';
|
||||||
import type { Project } from '@/components/chat/ProjectList.vue';
|
import type { Project } from '@/components/chat/ProjectList.vue';
|
||||||
import { useRecording } from '@/composables/useRecording';
|
import { useRecording } from '@/composables/useRecording';
|
||||||
import { useToast } from '@/utils/toast';
|
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
chatboxMode?: boolean;
|
chatboxMode?: boolean;
|
||||||
}
|
}
|
||||||
type SendShortcut = 'enter' | 'shift_enter';
|
|
||||||
const SEND_SHORTCUT_STORAGE_KEY = 'chat_send_shortcut';
|
|
||||||
|
|
||||||
const props = withDefaults(defineProps<Props>(), {
|
const props = withDefaults(defineProps<Props>(), {
|
||||||
chatboxMode: false
|
chatboxMode: false
|
||||||
@@ -242,7 +233,6 @@ const router = useRouter();
|
|||||||
const route = useRoute();
|
const route = useRoute();
|
||||||
const { t } = useI18n();
|
const { t } = useI18n();
|
||||||
const { tm } = useModuleI18n('features/chat');
|
const { tm } = useModuleI18n('features/chat');
|
||||||
const { warning: toastWarning } = useToast();
|
|
||||||
const theme = useTheme();
|
const theme = useTheme();
|
||||||
const customizer = useCustomizerStore();
|
const customizer = useCustomizerStore();
|
||||||
|
|
||||||
@@ -267,7 +257,6 @@ const {
|
|||||||
getSessions,
|
getSessions,
|
||||||
newSession,
|
newSession,
|
||||||
deleteSession: deleteSessionFn,
|
deleteSession: deleteSessionFn,
|
||||||
batchDeleteSessions,
|
|
||||||
showEditTitleDialog,
|
showEditTitleDialog,
|
||||||
saveTitle,
|
saveTitle,
|
||||||
updateSessionTitle,
|
updateSessionTitle,
|
||||||
@@ -341,18 +330,6 @@ interface ReplyInfo {
|
|||||||
const replyTo = ref<ReplyInfo | null>(null);
|
const replyTo = ref<ReplyInfo | null>(null);
|
||||||
|
|
||||||
const isDark = computed(() => useCustomizerStore().uiTheme === 'PurpleThemeDark');
|
const isDark = computed(() => useCustomizerStore().uiTheme === 'PurpleThemeDark');
|
||||||
const sendShortcut = ref<SendShortcut>('shift_enter');
|
|
||||||
|
|
||||||
function setSendShortcut(mode: SendShortcut) {
|
|
||||||
sendShortcut.value = mode;
|
|
||||||
localStorage.setItem(SEND_SHORTCUT_STORAGE_KEY, mode);
|
|
||||||
}
|
|
||||||
|
|
||||||
function focusChatInput() {
|
|
||||||
nextTick(() => {
|
|
||||||
chatInputRef.value?.focusInput?.();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检测是否为手机端
|
// 检测是否为手机端
|
||||||
function checkMobile() {
|
function checkMobile() {
|
||||||
@@ -511,7 +488,6 @@ async function handleSelectConversation(sessionIds: string[]) {
|
|||||||
nextTick(() => {
|
nextTick(() => {
|
||||||
messageList.value?.scrollToBottom();
|
messageList.value?.scrollToBottom();
|
||||||
});
|
});
|
||||||
focusChatInput();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function handleNewChat() {
|
function handleNewChat() {
|
||||||
@@ -521,7 +497,6 @@ function handleNewChat() {
|
|||||||
// 退出项目视图
|
// 退出项目视图
|
||||||
selectedProjectId.value = null;
|
selectedProjectId.value = null;
|
||||||
projectSessions.value = [];
|
projectSessions.value = [];
|
||||||
focusChatInput();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async function handleDeleteConversation(sessionId: string) {
|
async function handleDeleteConversation(sessionId: string) {
|
||||||
@@ -535,33 +510,6 @@ async function handleDeleteConversation(sessionId: string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async function handleBatchDeleteConversations(sessionIds: string[]) {
|
|
||||||
try {
|
|
||||||
const result = await batchDeleteSessions(sessionIds);
|
|
||||||
|
|
||||||
// 仅在当前会话成功删除时清除信息
|
|
||||||
if (result.currentSessionDeleted) {
|
|
||||||
messages.value = [];
|
|
||||||
}
|
|
||||||
|
|
||||||
// 失败处理
|
|
||||||
if (result.failed_count > 0) {
|
|
||||||
toastWarning(
|
|
||||||
tm('batch.partialFailure', { failed: result.failed_count, total: sessionIds.length })
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果在项目视图中,刷新项目会话列表
|
|
||||||
if (selectedProjectId.value) {
|
|
||||||
const sessions = await getProjectSessions(selectedProjectId.value);
|
|
||||||
projectSessions.value = sessions;
|
|
||||||
}
|
|
||||||
} catch (err) {
|
|
||||||
console.error('Batch delete sessions failed:', err);
|
|
||||||
toastWarning(tm('batch.requestFailed'));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async function handleSelectProject(projectId: string) {
|
async function handleSelectProject(projectId: string) {
|
||||||
selectedProjectId.value = projectId;
|
selectedProjectId.value = projectId;
|
||||||
const sessions = await getProjectSessions(projectId);
|
const sessions = await getProjectSessions(projectId);
|
||||||
@@ -679,11 +627,6 @@ async function handleSendMessage() {
|
|||||||
const selectedProviderId = selection?.providerId || '';
|
const selectedProviderId = selection?.providerId || '';
|
||||||
const selectedModelName = selection?.modelName || '';
|
const selectedModelName = selection?.modelName || '';
|
||||||
|
|
||||||
// 点击发送后立即将消息区滚到底部,确保用户看到最新消息
|
|
||||||
nextTick(() => {
|
|
||||||
messageList.value?.scrollToBottom();
|
|
||||||
});
|
|
||||||
|
|
||||||
await sendMsg(
|
await sendMsg(
|
||||||
promptToSend,
|
promptToSend,
|
||||||
filesToSend,
|
filesToSend,
|
||||||
@@ -693,11 +636,6 @@ async function handleSendMessage() {
|
|||||||
replyToSend
|
replyToSend
|
||||||
);
|
);
|
||||||
|
|
||||||
// 发送流程结束后再兜底一次,处理异步渲染场景
|
|
||||||
nextTick(() => {
|
|
||||||
messageList.value?.scrollToBottom();
|
|
||||||
});
|
|
||||||
|
|
||||||
// 如果在项目中创建了新会话,将其添加到项目
|
// 如果在项目中创建了新会话,将其添加到项目
|
||||||
if (isCreatingNewSession && currentProjectId && currSessionId.value) {
|
if (isCreatingNewSession && currentProjectId && currSessionId.value) {
|
||||||
await addSessionToProject(currSessionId.value, currentProjectId);
|
await addSessionToProject(currSessionId.value, currentProjectId);
|
||||||
@@ -756,10 +694,6 @@ watch(sessions, (newSessions) => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
onMounted(() => {
|
onMounted(() => {
|
||||||
const storedShortcut = localStorage.getItem(SEND_SHORTCUT_STORAGE_KEY);
|
|
||||||
if (storedShortcut === 'enter' || storedShortcut === 'shift_enter') {
|
|
||||||
sendShortcut.value = storedShortcut;
|
|
||||||
}
|
|
||||||
checkMobile();
|
checkMobile();
|
||||||
window.addEventListener('resize', checkMobile);
|
window.addEventListener('resize', checkMobile);
|
||||||
getSessions();
|
getSessions();
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
<transition name="fade">
|
<transition name="fade">
|
||||||
<div v-if="isDragging" class="drop-overlay">
|
<div v-if="isDragging" class="drop-overlay">
|
||||||
<div class="drop-overlay-content">
|
<div class="drop-overlay-content">
|
||||||
<v-icon size="48" color="primary">mdi-cloud-upload</v-icon>
|
<v-icon size="48" color="deep-purple">mdi-cloud-upload</v-icon>
|
||||||
<span class="drop-text">{{ tm('input.dropToUpload') }}</span>
|
<span class="drop-text">{{ tm('input.dropToUpload') }}</span>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -41,7 +41,7 @@
|
|||||||
<!-- Settings Menu -->
|
<!-- Settings Menu -->
|
||||||
<StyledMenu offset="8" location="top start" :close-on-content-click="false">
|
<StyledMenu offset="8" location="top start" :close-on-content-click="false">
|
||||||
<template v-slot:activator="{ props: activatorProps }">
|
<template v-slot:activator="{ props: activatorProps }">
|
||||||
<v-btn v-bind="activatorProps" icon="mdi-plus" variant="text" color="primary" />
|
<v-btn v-bind="activatorProps" icon="mdi-plus" variant="text" color="deep-purple" />
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
<!-- Upload Files -->
|
<!-- Upload Files -->
|
||||||
@@ -87,7 +87,7 @@
|
|||||||
{{ tm('voice.liveMode') }}
|
{{ tm('voice.liveMode') }}
|
||||||
</v-tooltip>
|
</v-tooltip>
|
||||||
</v-btn> -->
|
</v-btn> -->
|
||||||
<v-btn @click="handleRecordClick" icon variant="text" :color="isRecording ? 'error' : 'primary'"
|
<v-btn @click="handleRecordClick" icon variant="text" :color="isRecording ? 'error' : 'deep-purple'"
|
||||||
class="record-btn">
|
class="record-btn">
|
||||||
<v-icon :icon="isRecording ? 'mdi-stop-circle' : 'mdi-microphone'" variant="text"
|
<v-icon :icon="isRecording ? 'mdi-stop-circle' : 'mdi-microphone'" variant="text"
|
||||||
plain></v-icon>
|
plain></v-icon>
|
||||||
@@ -95,13 +95,13 @@
|
|||||||
{{ isRecording ? tm('voice.speaking') : tm('voice.startRecording') }}
|
{{ isRecording ? tm('voice.speaking') : tm('voice.startRecording') }}
|
||||||
</v-tooltip>
|
</v-tooltip>
|
||||||
</v-btn>
|
</v-btn>
|
||||||
<v-btn icon v-if="isRunning && !canSend" @click="$emit('stop')" variant="tonal" color="primary" class="send-btn">
|
<v-btn icon v-if="isRunning" @click="$emit('stop')" variant="tonal" color="deep-purple" class="send-btn">
|
||||||
<v-icon icon="mdi-stop" variant="text" plain></v-icon>
|
<v-icon icon="mdi-stop" variant="text" plain></v-icon>
|
||||||
<v-tooltip activator="parent" location="top">
|
<v-tooltip activator="parent" location="top">
|
||||||
{{ tm('input.stopGenerating') }}
|
{{ tm('input.stopGenerating') }}
|
||||||
</v-tooltip>
|
</v-tooltip>
|
||||||
</v-btn>
|
</v-btn>
|
||||||
<v-btn v-else @click="$emit('send')" icon="mdi-send" variant="tonal" color="primary"
|
<v-btn v-else @click="$emit('send')" icon="mdi-send" variant="tonal" color="deep-purple"
|
||||||
:disabled="!canSend" class="send-btn" />
|
:disabled="!canSend" class="send-btn" />
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -117,7 +117,7 @@
|
|||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div v-if="stagedAudioUrl" class="audio-preview">
|
<div v-if="stagedAudioUrl" class="audio-preview">
|
||||||
<v-chip color="primary" variant="tonal" class="audio-chip">
|
<v-chip color="deep-purple-lighten-4" class="audio-chip">
|
||||||
<v-icon start icon="mdi-microphone" size="small"></v-icon>
|
<v-icon start icon="mdi-microphone" size="small"></v-icon>
|
||||||
{{ tm('voice.recording') }}
|
{{ tm('voice.recording') }}
|
||||||
</v-chip>
|
</v-chip>
|
||||||
@@ -126,7 +126,7 @@
|
|||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div v-for="(file, index) in stagedFiles" :key="'file-' + index" class="file-preview">
|
<div v-for="(file, index) in stagedFiles" :key="'file-' + index" class="file-preview">
|
||||||
<v-chip color="primary" variant="tonal" class="file-chip">
|
<v-chip color="blue-grey-lighten-4" class="file-chip">
|
||||||
<v-icon start icon="mdi-file-document-outline" size="small"></v-icon>
|
<v-icon start icon="mdi-file-document-outline" size="small"></v-icon>
|
||||||
<span class="file-name-preview">{{ file.original_name }}</span>
|
<span class="file-name-preview">{{ file.original_name }}</span>
|
||||||
</v-chip>
|
</v-chip>
|
||||||
@@ -173,7 +173,6 @@ interface Props {
|
|||||||
currentSession?: Session | null;
|
currentSession?: Session | null;
|
||||||
configId?: string | null;
|
configId?: string | null;
|
||||||
replyTo?: ReplyInfo | null;
|
replyTo?: ReplyInfo | null;
|
||||||
sendShortcut?: 'enter' | 'shift_enter';
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const props = withDefaults(defineProps<Props>(), {
|
const props = withDefaults(defineProps<Props>(), {
|
||||||
@@ -181,8 +180,7 @@ const props = withDefaults(defineProps<Props>(), {
|
|||||||
currentSession: null,
|
currentSession: null,
|
||||||
configId: null,
|
configId: null,
|
||||||
stagedFiles: () => [],
|
stagedFiles: () => [],
|
||||||
replyTo: null,
|
replyTo: null
|
||||||
sendShortcut: 'shift_enter'
|
|
||||||
});
|
});
|
||||||
|
|
||||||
const emit = defineEmits<{
|
const emit = defineEmits<{
|
||||||
@@ -255,29 +253,9 @@ watch(localPrompt, () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
function handleKeyDown(e: KeyboardEvent) {
|
function handleKeyDown(e: KeyboardEvent) {
|
||||||
const isEnter = e.key === 'Enter';
|
// Enter 插入换行(桌面和手机端均如此,发送通过右下角发送按鈕)
|
||||||
if (!isEnter) {
|
// Shift+Enter 发送(Ctrl+Enter / Cmd+Enter 也保留)
|
||||||
// Ctrl+B 录音
|
if (e.keyCode === 13 && (e.shiftKey || e.ctrlKey || e.metaKey)) {
|
||||||
if (e.ctrlKey && e.keyCode === 66) {
|
|
||||||
e.preventDefault();
|
|
||||||
if (ctrlKeyDown.value) return;
|
|
||||||
|
|
||||||
ctrlKeyDown.value = true;
|
|
||||||
ctrlKeyTimer.value = window.setTimeout(() => {
|
|
||||||
if (ctrlKeyDown.value && !props.isRecording) {
|
|
||||||
emit('startRecording');
|
|
||||||
}
|
|
||||||
}, ctrlKeyLongPressThreshold);
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const isSendHotkey =
|
|
||||||
e.ctrlKey ||
|
|
||||||
e.metaKey ||
|
|
||||||
(props.sendShortcut === 'enter' ? !e.shiftKey : e.shiftKey);
|
|
||||||
|
|
||||||
if (isSendHotkey) {
|
|
||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
if (localPrompt.value.trim() === '/astr_live_dev') {
|
if (localPrompt.value.trim() === '/astr_live_dev') {
|
||||||
emit('openLiveMode');
|
emit('openLiveMode');
|
||||||
@@ -289,6 +267,19 @@ function handleKeyDown(e: KeyboardEvent) {
|
|||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Ctrl+B 录音
|
||||||
|
if (e.ctrlKey && e.keyCode === 66) {
|
||||||
|
e.preventDefault();
|
||||||
|
if (ctrlKeyDown.value) return;
|
||||||
|
|
||||||
|
ctrlKeyDown.value = true;
|
||||||
|
ctrlKeyTimer.value = window.setTimeout(() => {
|
||||||
|
if (ctrlKeyDown.value && !props.isRecording) {
|
||||||
|
emit('startRecording');
|
||||||
|
}
|
||||||
|
}, ctrlKeyLongPressThreshold);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function handleKeyUp(e: KeyboardEvent) {
|
function handleKeyUp(e: KeyboardEvent) {
|
||||||
@@ -373,11 +364,6 @@ function getCurrentSelection() {
|
|||||||
return providerModelMenuRef.value?.getCurrentSelection();
|
return providerModelMenuRef.value?.getCurrentSelection();
|
||||||
}
|
}
|
||||||
|
|
||||||
function focusInput() {
|
|
||||||
if (!inputField.value) return;
|
|
||||||
inputField.value.focus();
|
|
||||||
}
|
|
||||||
|
|
||||||
onMounted(() => {
|
onMounted(() => {
|
||||||
if (inputField.value) {
|
if (inputField.value) {
|
||||||
inputField.value.addEventListener('paste', handlePaste);
|
inputField.value.addEventListener('paste', handlePaste);
|
||||||
@@ -393,8 +379,7 @@ onBeforeUnmount(() => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
defineExpose({
|
defineExpose({
|
||||||
getCurrentSelection,
|
getCurrentSelection
|
||||||
focusInput
|
|
||||||
});
|
});
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
@@ -414,8 +399,8 @@ defineExpose({
|
|||||||
left: 0;
|
left: 0;
|
||||||
right: 0;
|
right: 0;
|
||||||
bottom: 0;
|
bottom: 0;
|
||||||
background-color: rgba(var(--v-theme-primary), 0.12);
|
background-color: rgba(103, 58, 183, 0.15);
|
||||||
border: 2px dashed rgba(var(--v-theme-primary), 0.45);
|
border: 2px dashed rgba(103, 58, 183, 0.5);
|
||||||
border-radius: 24px;
|
border-radius: 24px;
|
||||||
display: flex;
|
display: flex;
|
||||||
align-items: center;
|
align-items: center;
|
||||||
@@ -434,7 +419,7 @@ defineExpose({
|
|||||||
.drop-text {
|
.drop-text {
|
||||||
font-size: 16px;
|
font-size: 16px;
|
||||||
font-weight: 500;
|
font-weight: 500;
|
||||||
color: rgb(var(--v-theme-primary));
|
color: #673ab7;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Fade transition for drop overlay */
|
/* Fade transition for drop overlay */
|
||||||
@@ -454,7 +439,7 @@ defineExpose({
|
|||||||
justify-content: space-between;
|
justify-content: space-between;
|
||||||
padding: 8px 16px;
|
padding: 8px 16px;
|
||||||
margin: 8px 8px 0 8px;
|
margin: 8px 8px 0 8px;
|
||||||
background-color: rgba(var(--v-theme-primary), 0.06);
|
background-color: rgba(103, 58, 183, 0.06);
|
||||||
border-radius: 12px;
|
border-radius: 12px;
|
||||||
gap: 8px;
|
gap: 8px;
|
||||||
max-height: 500px;
|
max-height: 500px;
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
'mobile-sidebar-open': isMobile && mobileMenuOpen,
|
'mobile-sidebar-open': isMobile && mobileMenuOpen,
|
||||||
'mobile-sidebar': isMobile
|
'mobile-sidebar': isMobile
|
||||||
}"
|
}"
|
||||||
:style="{ backgroundColor: sidebarCollapsed && !isMobile ? 'rgb(var(--v-theme-surface))' : 'rgb(var(--v-theme-mcpCardBg))' }">
|
:style="{ 'background-color': isDark ? sidebarCollapsed ? '#1e1e1e' : '#2d2d2d' : sidebarCollapsed ? '#ffffff' : '#f1f4f9' }">
|
||||||
|
|
||||||
<div class="sidebar-collapse-btn-container" v-if="!isMobile">
|
<div class="sidebar-collapse-btn-container" v-if="!isMobile">
|
||||||
<v-btn icon class="sidebar-collapse-btn" @click="toggleSidebar" variant="text" color="deep-purple">
|
<v-btn icon class="sidebar-collapse-btn" @click="toggleSidebar" variant="text" color="deep-purple">
|
||||||
@@ -21,31 +21,12 @@
|
|||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div style="padding: 8px; opacity: 0.6;">
|
<div style="padding: 8px; opacity: 0.6;">
|
||||||
<div class="new-chat-row" v-if="!sidebarCollapsed || isMobile">
|
<v-btn block variant="text" class="new-chat-btn" @click="$emit('newChat')" :disabled="!currSessionId && !selectedProjectId"
|
||||||
<v-btn block variant="text" class="new-chat-btn" @click="$emit('newChat')" :disabled="!currSessionId && !selectedProjectId"
|
v-if="!sidebarCollapsed || isMobile" prepend-icon="mdi-square-edit-outline">{{ tm('actions.newChat') }}</v-btn>
|
||||||
prepend-icon="mdi-square-edit-outline">{{ tm('actions.newChat') }}</v-btn>
|
<v-btn icon="mdi-square-edit-outline" rounded="xl" @click="$emit('newChat')" :disabled="!currSessionId && !selectedProjectId"
|
||||||
<v-btn v-if="sessions.length > 0" icon size="small" variant="text" @click="toggleBatchMode"
|
|
||||||
:color="batchMode ? 'primary' : undefined">
|
|
||||||
<v-icon>mdi-checkbox-multiple-marked-outline</v-icon>
|
|
||||||
</v-btn>
|
|
||||||
</div>
|
|
||||||
<v-btn icon="mdi-square-edit-outline" rounded="xl" @click="$emit('newChat')" :disabled="!currSessionId && !selectedProjectId"
|
|
||||||
v-if="sidebarCollapsed && !isMobile" elevation="0"></v-btn>
|
v-if="sidebarCollapsed && !isMobile" elevation="0"></v-btn>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- Batch action bar -->
|
|
||||||
<div v-if="batchMode && (!sidebarCollapsed || isMobile)" class="batch-action-bar">
|
|
||||||
<v-btn size="x-small" variant="text" @click="toggleSelectAll">
|
|
||||||
{{ isAllSelected ? tm('batch.deselectAll') : tm('batch.selectAll') }}
|
|
||||||
</v-btn>
|
|
||||||
<span class="batch-selected-count">{{ tm('batch.selected', { count: batchSelected.length }) }}</span>
|
|
||||||
<v-spacer />
|
|
||||||
<v-btn size="x-small" variant="text" color="error" :disabled="batchSelected.length === 0"
|
|
||||||
@click="handleBatchDelete">
|
|
||||||
{{ tm('batch.delete') }}
|
|
||||||
</v-btn>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<!-- 项目列表组件 -->
|
<!-- 项目列表组件 -->
|
||||||
<ProjectList
|
<ProjectList
|
||||||
v-if="!sidebarCollapsed || isMobile"
|
v-if="!sidebarCollapsed || isMobile"
|
||||||
@@ -60,34 +41,19 @@
|
|||||||
v-if="!sidebarCollapsed || isMobile">
|
v-if="!sidebarCollapsed || isMobile">
|
||||||
<v-card v-if="sessions.length > 0" flat style="background-color: transparent;">
|
<v-card v-if="sessions.length > 0" flat style="background-color: transparent;">
|
||||||
<v-list density="compact" nav class="conversation-list"
|
<v-list density="compact" nav class="conversation-list"
|
||||||
style="background-color: transparent;" :selected="batchMode ? [] : selectedSessions"
|
style="background-color: transparent;" :selected="selectedSessions"
|
||||||
@update:selected="handleListSelect">
|
@update:selected="$emit('selectConversation', $event)">
|
||||||
<v-list-item v-for="item in sessions" :key="item.session_id" :value="item.session_id"
|
<v-list-item v-for="item in sessions" :key="item.session_id" :value="item.session_id"
|
||||||
rounded="lg" class="conversation-item" active-color="secondary"
|
rounded="lg" class="conversation-item" active-color="secondary">
|
||||||
@click="batchMode ? toggleBatchItem(item.session_id) : undefined">
|
|
||||||
|
|
||||||
<template v-slot:prepend>
|
|
||||||
<div class="batch-checkbox-slot" :class="{ 'batch-checkbox-slot--active': batchMode }">
|
|
||||||
<v-checkbox-btn
|
|
||||||
:model-value="batchSelected.includes(item.session_id)"
|
|
||||||
@update:model-value="toggleBatchItem(item.session_id)"
|
|
||||||
@click.stop
|
|
||||||
density="compact"
|
|
||||||
hide-details
|
|
||||||
class="batch-checkbox"
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
</template>
|
|
||||||
|
|
||||||
<v-list-item-title v-if="!sidebarCollapsed || isMobile" class="conversation-title"
|
<v-list-item-title v-if="!sidebarCollapsed || isMobile" class="conversation-title"
|
||||||
:style="{ color: 'rgb(var(--v-theme-primaryText))' }">
|
:style="{ color: isDark ? '#ffffff' : '#000000' }">
|
||||||
{{ item.display_name || tm('conversation.newConversation') }}
|
{{ item.display_name || tm('conversation.newConversation') }}
|
||||||
</v-list-item-title>
|
</v-list-item-title>
|
||||||
<!-- <v-list-item-subtitle v-if="!sidebarCollapsed || isMobile" class="timestamp">
|
<!-- <v-list-item-subtitle v-if="!sidebarCollapsed || isMobile" class="timestamp">
|
||||||
{{ new Date(item.updated_at).toLocaleString() }}
|
{{ new Date(item.updated_at).toLocaleString() }}
|
||||||
</v-list-item-subtitle> -->
|
</v-list-item-subtitle> -->
|
||||||
|
|
||||||
<template v-if="!batchMode && (!sidebarCollapsed || isMobile)" v-slot:append>
|
<template v-if="!sidebarCollapsed || isMobile" v-slot:append>
|
||||||
<div class="conversation-actions">
|
<div class="conversation-actions">
|
||||||
<v-btn icon="mdi-pencil" size="x-small" variant="text"
|
<v-btn icon="mdi-pencil" size="x-small" variant="text"
|
||||||
class="edit-title-btn"
|
class="edit-title-btn"
|
||||||
@@ -132,52 +98,16 @@
|
|||||||
</v-btn>
|
</v-btn>
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
<!-- 语言切换(分组) -->
|
<!-- 语言切换 -->
|
||||||
<v-menu
|
<v-list-item class="styled-menu-item">
|
||||||
:open-on-hover="!isMobile"
|
<template v-slot:prepend>
|
||||||
:open-on-click="isMobile"
|
<v-icon>mdi-translate</v-icon>
|
||||||
:open-delay="!isMobile ? 60 : 0"
|
|
||||||
:close-delay="!isMobile ? 120 : 0"
|
|
||||||
:location="isMobile ? 'bottom' : 'end center'"
|
|
||||||
offset="8"
|
|
||||||
close-on-content-click
|
|
||||||
>
|
|
||||||
<template v-slot:activator="{ props: languageMenuProps }">
|
|
||||||
<v-list-item
|
|
||||||
v-bind="languageMenuProps"
|
|
||||||
class="styled-menu-item chat-settings-group-trigger"
|
|
||||||
rounded="md"
|
|
||||||
>
|
|
||||||
<template v-slot:prepend>
|
|
||||||
<v-icon>mdi-translate</v-icon>
|
|
||||||
</template>
|
|
||||||
<v-list-item-title>{{ t('core.common.language') }}</v-list-item-title>
|
|
||||||
<template v-slot:append>
|
|
||||||
<span class="chat-settings-group-current">{{ currentLanguage?.flag }}</span>
|
|
||||||
<v-icon size="18" class="chat-settings-group-arrow">mdi-chevron-right</v-icon>
|
|
||||||
</template>
|
|
||||||
</v-list-item>
|
|
||||||
</template>
|
</template>
|
||||||
|
<v-list-item-title>{{ t('core.common.language') }}</v-list-item-title>
|
||||||
<v-card class="styled-menu-card" style="min-width: 180px;" elevation="8" rounded="lg">
|
<template v-slot:append>
|
||||||
<v-list density="compact" class="styled-menu-list pa-1">
|
<LanguageSwitcher variant="chatbox" />
|
||||||
<v-list-item
|
</template>
|
||||||
v-for="lang in languages"
|
</v-list-item>
|
||||||
:key="lang.code"
|
|
||||||
:value="lang.code"
|
|
||||||
@click="changeLanguage(lang.code)"
|
|
||||||
:class="{ 'styled-menu-item-active': currentLocale === lang.code }"
|
|
||||||
class="styled-menu-item"
|
|
||||||
rounded="md"
|
|
||||||
>
|
|
||||||
<template v-slot:prepend>
|
|
||||||
<span class="language-flag">{{ lang.flag }}</span>
|
|
||||||
</template>
|
|
||||||
<v-list-item-title>{{ lang.name }}</v-list-item-title>
|
|
||||||
</v-list-item>
|
|
||||||
</v-list>
|
|
||||||
</v-card>
|
|
||||||
</v-menu>
|
|
||||||
|
|
||||||
<!-- 主题切换 -->
|
<!-- 主题切换 -->
|
||||||
<v-list-item class="styled-menu-item" @click="$emit('toggleTheme')">
|
<v-list-item class="styled-menu-item" @click="$emit('toggleTheme')">
|
||||||
@@ -187,93 +117,26 @@
|
|||||||
<v-list-item-title>{{ isDark ? tm('modes.lightMode') : tm('modes.darkMode') }}</v-list-item-title>
|
<v-list-item-title>{{ isDark ? tm('modes.lightMode') : tm('modes.darkMode') }}</v-list-item-title>
|
||||||
</v-list-item>
|
</v-list-item>
|
||||||
|
|
||||||
<!-- 通信传输模式(分组) -->
|
<!-- 通信传输模式 -->
|
||||||
<v-menu
|
<v-list-item class="styled-menu-item">
|
||||||
:open-on-hover="!isMobile"
|
<template v-slot:prepend>
|
||||||
:open-on-click="isMobile"
|
<v-icon>mdi-lan-connect</v-icon>
|
||||||
:open-delay="!isMobile ? 60 : 0"
|
|
||||||
:close-delay="!isMobile ? 120 : 0"
|
|
||||||
:location="isMobile ? 'bottom' : 'end center'"
|
|
||||||
offset="8"
|
|
||||||
close-on-content-click
|
|
||||||
>
|
|
||||||
<template v-slot:activator="{ props: transportMenuProps }">
|
|
||||||
<v-list-item
|
|
||||||
v-bind="transportMenuProps"
|
|
||||||
class="styled-menu-item chat-settings-group-trigger"
|
|
||||||
rounded="md"
|
|
||||||
>
|
|
||||||
<template v-slot:prepend>
|
|
||||||
<v-icon>mdi-lan-connect</v-icon>
|
|
||||||
</template>
|
|
||||||
<v-list-item-title>{{ tm('transport.title') }}</v-list-item-title>
|
|
||||||
<template v-slot:append>
|
|
||||||
<span class="chat-settings-group-current chat-settings-transport-current">{{ currentTransportLabel }}</span>
|
|
||||||
<v-icon size="18" class="chat-settings-group-arrow">mdi-chevron-right</v-icon>
|
|
||||||
</template>
|
|
||||||
</v-list-item>
|
|
||||||
</template>
|
</template>
|
||||||
|
<v-list-item-title>{{ tm('transport.title') }}</v-list-item-title>
|
||||||
<v-card class="styled-menu-card" style="min-width: 220px;" elevation="8" rounded="lg">
|
<template v-slot:append>
|
||||||
<v-list density="compact" class="styled-menu-list pa-1">
|
<v-select
|
||||||
<v-list-item
|
:model-value="transportMode"
|
||||||
v-for="opt in transportOptions"
|
:items="transportOptions"
|
||||||
:key="opt.value"
|
item-title="label"
|
||||||
:value="opt.value"
|
item-value="value"
|
||||||
@click="handleTransportModeChange(opt.value)"
|
density="compact"
|
||||||
:class="{ 'styled-menu-item-active': transportMode === opt.value }"
|
variant="underlined"
|
||||||
class="styled-menu-item"
|
hide-details
|
||||||
rounded="md"
|
class="transport-mode-select"
|
||||||
>
|
@update:model-value="handleTransportModeChange"
|
||||||
<v-list-item-title>{{ opt.label }}</v-list-item-title>
|
/>
|
||||||
</v-list-item>
|
|
||||||
</v-list>
|
|
||||||
</v-card>
|
|
||||||
</v-menu>
|
|
||||||
|
|
||||||
<!-- 发送快捷键(分组) -->
|
|
||||||
<v-menu
|
|
||||||
:open-on-hover="!isMobile"
|
|
||||||
:open-on-click="isMobile"
|
|
||||||
:open-delay="!isMobile ? 60 : 0"
|
|
||||||
:close-delay="!isMobile ? 120 : 0"
|
|
||||||
:location="isMobile ? 'bottom' : 'end center'"
|
|
||||||
offset="8"
|
|
||||||
close-on-content-click
|
|
||||||
>
|
|
||||||
<template v-slot:activator="{ props: sendShortcutMenuProps }">
|
|
||||||
<v-list-item
|
|
||||||
v-bind="sendShortcutMenuProps"
|
|
||||||
class="styled-menu-item chat-settings-group-trigger"
|
|
||||||
rounded="md"
|
|
||||||
>
|
|
||||||
<template v-slot:prepend>
|
|
||||||
<v-icon>mdi-keyboard-outline</v-icon>
|
|
||||||
</template>
|
|
||||||
<v-list-item-title>{{ tm('shortcuts.sendKey.title') }}</v-list-item-title>
|
|
||||||
<template v-slot:append>
|
|
||||||
<span class="chat-settings-group-current chat-settings-transport-current">{{ currentSendShortcutLabel }}</span>
|
|
||||||
<v-icon size="18" class="chat-settings-group-arrow">mdi-chevron-right</v-icon>
|
|
||||||
</template>
|
|
||||||
</v-list-item>
|
|
||||||
</template>
|
</template>
|
||||||
|
</v-list-item>
|
||||||
<v-card class="styled-menu-card" style="min-width: 220px;" elevation="8" rounded="lg">
|
|
||||||
<v-list density="compact" class="styled-menu-list pa-1">
|
|
||||||
<v-list-item
|
|
||||||
v-for="opt in sendShortcutOptions"
|
|
||||||
:key="opt.value"
|
|
||||||
:value="opt.value"
|
|
||||||
@click="handleSendShortcutChange(opt.value)"
|
|
||||||
:class="{ 'styled-menu-item-active': props.sendShortcut === opt.value }"
|
|
||||||
class="styled-menu-item"
|
|
||||||
rounded="md"
|
|
||||||
>
|
|
||||||
<v-list-item-title>{{ opt.label }}</v-list-item-title>
|
|
||||||
</v-list-item>
|
|
||||||
</v-list>
|
|
||||||
</v-card>
|
|
||||||
</v-menu>
|
|
||||||
|
|
||||||
<!-- 全屏/退出全屏 -->
|
<!-- 全屏/退出全屏 -->
|
||||||
<v-list-item class="styled-menu-item" @click="$emit('toggleFullscreen')">
|
<v-list-item class="styled-menu-item" @click="$emit('toggleFullscreen')">
|
||||||
@@ -299,16 +162,15 @@
|
|||||||
</template>
|
</template>
|
||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, computed } from 'vue';
|
import { ref } from 'vue';
|
||||||
import { useI18n, useModuleI18n } from '@/i18n/composables';
|
import { useI18n, useModuleI18n } from '@/i18n/composables';
|
||||||
import type { Session } from '@/composables/useSessions';
|
import type { Session } from '@/composables/useSessions';
|
||||||
import { askForConfirmation, useConfirmDialog } from '@/utils/confirmDialog';
|
import { askForConfirmation, useConfirmDialog } from '@/utils/confirmDialog';
|
||||||
|
import LanguageSwitcher from '@/components/shared/LanguageSwitcher.vue';
|
||||||
import StyledMenu from '@/components/shared/StyledMenu.vue';
|
import StyledMenu from '@/components/shared/StyledMenu.vue';
|
||||||
import ProviderConfigDialog from '@/components/chat/ProviderConfigDialog.vue';
|
import ProviderConfigDialog from '@/components/chat/ProviderConfigDialog.vue';
|
||||||
import ProjectList from '@/components/chat/ProjectList.vue';
|
import ProjectList from '@/components/chat/ProjectList.vue';
|
||||||
import type { Project } from '@/components/chat/ProjectList.vue';
|
import type { Project } from '@/components/chat/ProjectList.vue';
|
||||||
import { useLanguageSwitcher } from '@/i18n/composables';
|
|
||||||
import type { Locale } from '@/i18n/types';
|
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
sessions: Session[];
|
sessions: Session[];
|
||||||
@@ -321,7 +183,6 @@ interface Props {
|
|||||||
isMobile: boolean;
|
isMobile: boolean;
|
||||||
mobileMenuOpen: boolean;
|
mobileMenuOpen: boolean;
|
||||||
projects?: Project[];
|
projects?: Project[];
|
||||||
sendShortcut: 'enter' | 'shift_enter';
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const props = withDefaults(defineProps<Props>(), {
|
const props = withDefaults(defineProps<Props>(), {
|
||||||
@@ -333,7 +194,6 @@ const emit = defineEmits<{
|
|||||||
selectConversation: [sessionIds: string[]];
|
selectConversation: [sessionIds: string[]];
|
||||||
editTitle: [sessionId: string, title: string];
|
editTitle: [sessionId: string, title: string];
|
||||||
deleteConversation: [sessionId: string];
|
deleteConversation: [sessionId: string];
|
||||||
batchDeleteConversations: [sessionIds: string[]];
|
|
||||||
closeMobileSidebar: [];
|
closeMobileSidebar: [];
|
||||||
toggleTheme: [];
|
toggleTheme: [];
|
||||||
toggleFullscreen: [];
|
toggleFullscreen: [];
|
||||||
@@ -342,7 +202,6 @@ const emit = defineEmits<{
|
|||||||
editProject: [project: Project];
|
editProject: [project: Project];
|
||||||
deleteProject: [projectId: string];
|
deleteProject: [projectId: string];
|
||||||
updateTransportMode: [mode: 'sse' | 'websocket'];
|
updateTransportMode: [mode: 'sse' | 'websocket'];
|
||||||
updateSendShortcut: [mode: 'enter' | 'shift_enter'];
|
|
||||||
}>();
|
}>();
|
||||||
|
|
||||||
const { t } = useI18n();
|
const { t } = useI18n();
|
||||||
@@ -352,84 +211,10 @@ const confirmDialog = useConfirmDialog();
|
|||||||
|
|
||||||
const sidebarCollapsed = ref(true);
|
const sidebarCollapsed = ref(true);
|
||||||
const showProviderConfigDialog = ref(false);
|
const showProviderConfigDialog = ref(false);
|
||||||
|
|
||||||
// Batch mode state
|
|
||||||
const batchMode = ref(false);
|
|
||||||
const batchSelected = ref<string[]>([]);
|
|
||||||
|
|
||||||
const isAllSelected = computed(() =>
|
|
||||||
props.sessions.length > 0 && batchSelected.value.length === props.sessions.length
|
|
||||||
);
|
|
||||||
|
|
||||||
function toggleBatchMode() {
|
|
||||||
batchMode.value = !batchMode.value;
|
|
||||||
batchSelected.value = [];
|
|
||||||
}
|
|
||||||
|
|
||||||
function toggleBatchItem(sessionId: string) {
|
|
||||||
const idx = batchSelected.value.indexOf(sessionId);
|
|
||||||
if (idx >= 0) {
|
|
||||||
batchSelected.value.splice(idx, 1);
|
|
||||||
} else {
|
|
||||||
batchSelected.value.push(sessionId);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function toggleSelectAll() {
|
|
||||||
if (isAllSelected.value) {
|
|
||||||
batchSelected.value = [];
|
|
||||||
} else {
|
|
||||||
batchSelected.value = props.sessions.map(s => s.session_id);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async function handleBatchDelete() {
|
|
||||||
const count = batchSelected.value.length;
|
|
||||||
if (count === 0) return;
|
|
||||||
const message = tm('batch.confirmDelete', { count });
|
|
||||||
if (await askForConfirmation(message, confirmDialog)) {
|
|
||||||
emit('batchDeleteConversations', [...batchSelected.value]);
|
|
||||||
batchSelected.value = [];
|
|
||||||
batchMode.value = false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function handleListSelect(sessionIds: string[]) {
|
|
||||||
if (!batchMode.value) {
|
|
||||||
emit('selectConversation', sessionIds);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
const transportOptions = [
|
const transportOptions = [
|
||||||
{ label: tm('transport.sse'), value: 'sse' as const },
|
{ label: tm('transport.sse'), value: 'sse' as const },
|
||||||
{ label: tm('transport.websocket'), value: 'websocket' as const }
|
{ label: tm('transport.websocket'), value: 'websocket' as const }
|
||||||
];
|
];
|
||||||
const sendShortcutOptions = [
|
|
||||||
{ label: tm('shortcuts.sendKey.enterToSend'), value: 'enter' as const },
|
|
||||||
{ label: tm('shortcuts.sendKey.shiftEnterToSend'), value: 'shift_enter' as const }
|
|
||||||
];
|
|
||||||
|
|
||||||
// Language switcher
|
|
||||||
const { languageOptions, currentLanguage, switchLanguage, locale } = useLanguageSwitcher();
|
|
||||||
const languages = computed(() =>
|
|
||||||
languageOptions.value.map(lang => ({
|
|
||||||
code: lang.value,
|
|
||||||
name: lang.label,
|
|
||||||
flag: lang.flag
|
|
||||||
}))
|
|
||||||
);
|
|
||||||
const currentLocale = computed(() => locale.value);
|
|
||||||
const changeLanguage = async (langCode: string) => {
|
|
||||||
await switchLanguage(langCode as Locale);
|
|
||||||
};
|
|
||||||
|
|
||||||
const currentTransportLabel = computed(() => {
|
|
||||||
const found = transportOptions.find(opt => opt.value === props.transportMode);
|
|
||||||
return found?.label ?? '';
|
|
||||||
});
|
|
||||||
const currentSendShortcutLabel = computed(() => {
|
|
||||||
const found = sendShortcutOptions.find(opt => opt.value === props.sendShortcut);
|
|
||||||
return found?.label ?? '';
|
|
||||||
});
|
|
||||||
|
|
||||||
// 从 localStorage 读取侧边栏折叠状态
|
// 从 localStorage 读取侧边栏折叠状态
|
||||||
const savedCollapsedState = localStorage.getItem('sidebarCollapsed');
|
const savedCollapsedState = localStorage.getItem('sidebarCollapsed');
|
||||||
@@ -457,12 +242,6 @@ function handleTransportModeChange(mode: string | null) {
|
|||||||
emit('updateTransportMode', mode);
|
emit('updateTransportMode', mode);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function handleSendShortcutChange(mode: string | null) {
|
|
||||||
if (mode === 'enter' || mode === 'shift_enter') {
|
|
||||||
emit('updateSendShortcut', mode);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<style scoped>
|
<style scoped>
|
||||||
@@ -531,7 +310,7 @@ function handleSendShortcutChange(mode: string | null) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
.conversation-item:hover {
|
.conversation-item:hover {
|
||||||
background-color: rgba(var(--v-theme-primary), 0.05);
|
background-color: rgba(103, 58, 183, 0.05);
|
||||||
}
|
}
|
||||||
|
|
||||||
.conversation-item:hover .conversation-actions {
|
.conversation-item:hover .conversation-actions {
|
||||||
@@ -623,74 +402,7 @@ function handleSendShortcutChange(mode: string | null) {
|
|||||||
justify-content: center;
|
justify-content: center;
|
||||||
}
|
}
|
||||||
|
|
||||||
.chat-settings-group-trigger :deep(.v-list-item__append) {
|
.transport-mode-select {
|
||||||
display: flex;
|
min-width: 120px;
|
||||||
align-items: center;
|
|
||||||
gap: 6px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.chat-settings-group-current {
|
|
||||||
font-size: 14px;
|
|
||||||
line-height: 1;
|
|
||||||
opacity: 0.8;
|
|
||||||
}
|
|
||||||
|
|
||||||
.chat-settings-transport-current {
|
|
||||||
font-size: 12px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.chat-settings-group-arrow {
|
|
||||||
opacity: 0.7;
|
|
||||||
}
|
|
||||||
|
|
||||||
.language-flag {
|
|
||||||
font-size: 16px;
|
|
||||||
margin-right: 8px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.new-chat-row {
|
|
||||||
display: flex;
|
|
||||||
align-items: center;
|
|
||||||
gap: 4px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.new-chat-row .new-chat-btn {
|
|
||||||
flex: 1;
|
|
||||||
min-width: 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
.batch-action-bar {
|
|
||||||
display: flex;
|
|
||||||
align-items: center;
|
|
||||||
padding: 4px 12px;
|
|
||||||
gap: 4px;
|
|
||||||
flex-shrink: 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
.batch-selected-count {
|
|
||||||
font-size: 12px;
|
|
||||||
opacity: 0.7;
|
|
||||||
white-space: nowrap;
|
|
||||||
}
|
|
||||||
|
|
||||||
.batch-checkbox {
|
|
||||||
flex: none;
|
|
||||||
transition: opacity 0.2s ease, transform 0.2s ease;
|
|
||||||
}
|
|
||||||
|
|
||||||
.batch-checkbox-slot {
|
|
||||||
width: 0;
|
|
||||||
opacity: 0;
|
|
||||||
overflow: hidden;
|
|
||||||
pointer-events: none;
|
|
||||||
transform: translateX(-8px);
|
|
||||||
transition: width 0.2s ease, opacity 0.2s ease, transform 0.2s ease;
|
|
||||||
}
|
|
||||||
|
|
||||||
.batch-checkbox-slot--active {
|
|
||||||
width: 28px;
|
|
||||||
opacity: 1;
|
|
||||||
pointer-events: auto;
|
|
||||||
transform: translateX(0);
|
|
||||||
}
|
}
|
||||||
</style>
|
</style>
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -180,7 +180,7 @@
|
|||||||
|
|
||||||
<script>
|
<script>
|
||||||
import { useI18n, useModuleI18n } from '@/i18n/composables';
|
import { useI18n, useModuleI18n } from '@/i18n/composables';
|
||||||
import { enableKatex, enableMermaid, MarkdownCodeBlockNode, setCustomComponents } from 'markstream-vue'
|
import { enableKatex, enableMermaid, setCustomComponents } from 'markstream-vue'
|
||||||
import 'markstream-vue/index.css'
|
import 'markstream-vue/index.css'
|
||||||
import 'katex/dist/katex.min.css'
|
import 'katex/dist/katex.min.css'
|
||||||
import 'highlight.js/styles/github.css';
|
import 'highlight.js/styles/github.css';
|
||||||
@@ -194,11 +194,8 @@ import ActionRef from './message_list_comps/ActionRef.vue';
|
|||||||
enableKatex();
|
enableKatex();
|
||||||
enableMermaid();
|
enableMermaid();
|
||||||
|
|
||||||
// 注册 message-list 专用组件:引用节点 + Shiki 代码块渲染
|
// 注册自定义 ref 组件
|
||||||
setCustomComponents('message-list', {
|
setCustomComponents('message-list', { ref: RefNode });
|
||||||
ref: RefNode,
|
|
||||||
code_block: MarkdownCodeBlockNode
|
|
||||||
});
|
|
||||||
|
|
||||||
export default {
|
export default {
|
||||||
name: 'MessageList',
|
name: 'MessageList',
|
||||||
|
|||||||
@@ -22,7 +22,7 @@
|
|||||||
v-model:prompt="prompt"
|
v-model:prompt="prompt"
|
||||||
:stagedImagesUrl="stagedImagesUrl"
|
:stagedImagesUrl="stagedImagesUrl"
|
||||||
:stagedAudioUrl="stagedAudioUrl"
|
:stagedAudioUrl="stagedAudioUrl"
|
||||||
:disabled="false"
|
:disabled="isStreaming"
|
||||||
:is-running="isStreaming || isConvRunning"
|
:is-running="isStreaming || isConvRunning"
|
||||||
:enableStreaming="enableStreaming"
|
:enableStreaming="enableStreaming"
|
||||||
:isRecording="isRecording"
|
:isRecording="isRecording"
|
||||||
|
|||||||
@@ -63,9 +63,8 @@
|
|||||||
<!-- Text (Markdown) -->
|
<!-- Text (Markdown) -->
|
||||||
<MarkdownRender
|
<MarkdownRender
|
||||||
v-else-if="renderPart.part.type === 'plain' && renderPart.part.text && renderPart.part.text.trim()"
|
v-else-if="renderPart.part.type === 'plain' && renderPart.part.text && renderPart.part.text.trim()"
|
||||||
:key="`${renderPart.key}-${isDark ? 'dark' : 'light'}`"
|
|
||||||
custom-id="message-list" :custom-html-tags="['ref']" :content="renderPart.part.text" :typewriter="false"
|
custom-id="message-list" :custom-html-tags="['ref']" :content="renderPart.part.text" :typewriter="false"
|
||||||
class="markdown-content" :is-dark="isDark" />
|
class="markdown-content" :is-dark="isDark" :monacoOptions="{ theme: isDark ? 'vs-dark' : 'vs-light' }" />
|
||||||
|
|
||||||
<!-- Image -->
|
<!-- Image -->
|
||||||
<div v-else-if="renderPart.part.type === 'image' && renderPart.part.embedded_url" class="embedded-images">
|
<div v-else-if="renderPart.part.type === 'image' && renderPart.part.embedded_url" class="embedded-images">
|
||||||
|
|||||||
@@ -9,7 +9,7 @@
|
|||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
<div v-if="isExpanded" class="reasoning-content animate-fade-in">
|
<div v-if="isExpanded" class="reasoning-content animate-fade-in">
|
||||||
<MarkdownRender :key="`reasoning-${isDark ? 'dark' : 'light'}`" :content="reasoning" class="reasoning-text markdown-content"
|
<MarkdownRender :content="reasoning" class="reasoning-text markdown-content"
|
||||||
:typewriter="false" :is-dark="isDark" :style="isDark ? { opacity: '0.85' } : {}" />
|
:typewriter="false" :is-dark="isDark" :style="isDark ? { opacity: '0.85' } : {}" />
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
<template>
|
<template>
|
||||||
<v-chip v-if="domain" class="ref-chip" size="x-small" variant="flat"
|
<v-chip v-if="domain" class="ref-chip" size="x-small" variant="flat"
|
||||||
:style="chipStyle" :href="url"
|
:style="{ backgroundColor: isDark ? '#303030' : '#f4f4f4', color: isDark ? '#999' : '#666' }" :href="url"
|
||||||
target="_blank" clickable>
|
target="_blank" clickable>
|
||||||
<v-icon start size="x-small" color>mdi-link-variant</v-icon>
|
<v-icon start size="x-small" color>mdi-link-variant</v-icon>
|
||||||
<span>{{ domain }}</span>
|
<span>{{ domain }}</span>
|
||||||
|
|
||||||
</v-chip>
|
</v-chip>
|
||||||
<span v-else class="ref-fallback" :style="fallbackStyle">{{ 'site' }}</span>
|
<span v-else class="ref-fallback" :style="{ color: isDark ? '#999' : '#666' }">{{ 'site' }}</span>
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
<script setup>
|
<script setup>
|
||||||
@@ -46,15 +46,6 @@ const domain = computed(() => {
|
|||||||
return ''
|
return ''
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
const chipStyle = computed(() => ({
|
|
||||||
backgroundColor: isDark ? 'rgba(var(--v-theme-on-surface), 0.08)' : 'rgba(var(--v-theme-on-surface), 0.04)',
|
|
||||||
color: isDark ? 'rgba(var(--v-theme-on-surface), 0.62)' : 'rgba(var(--v-theme-on-surface), 0.72)'
|
|
||||||
}))
|
|
||||||
|
|
||||||
const fallbackStyle = computed(() => ({
|
|
||||||
color: isDark ? 'rgba(var(--v-theme-on-surface), 0.62)' : 'rgba(var(--v-theme-on-surface), 0.72)'
|
|
||||||
}))
|
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<style scoped>
|
<style scoped>
|
||||||
|
|||||||
@@ -300,10 +300,6 @@ export default {
|
|||||||
this.loadingGettingServers = true;
|
this.loadingGettingServers = true;
|
||||||
axios.get('/api/tools/mcp/servers')
|
axios.get('/api/tools/mcp/servers')
|
||||||
.then(response => {
|
.then(response => {
|
||||||
if (response.data.status === 'error') {
|
|
||||||
this.showError(response.data.message || this.tm('messages.getServersError', { error: 'Unknown error' }));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
this.mcpServers = response.data.data || [];
|
this.mcpServers = response.data.data || [];
|
||||||
this.mcpServers.forEach(server => {
|
this.mcpServers.forEach(server => {
|
||||||
if (!this.mcpServerUpdateLoaders[server.name]) {
|
if (!this.mcpServerUpdateLoaders[server.name]) {
|
||||||
@@ -376,10 +372,6 @@ export default {
|
|||||||
axios.post(endpoint, serverData)
|
axios.post(endpoint, serverData)
|
||||||
.then(response => {
|
.then(response => {
|
||||||
this.loading = false;
|
this.loading = false;
|
||||||
if (response.data.status === 'error') {
|
|
||||||
this.showError(response.data.message || this.tm('messages.saveError', { error: 'Unknown error' }));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
this.showMcpServerDialog = false;
|
this.showMcpServerDialog = false;
|
||||||
this.addServerDialogMessage = '';
|
this.addServerDialogMessage = '';
|
||||||
this.getServers();
|
this.getServers();
|
||||||
|
|||||||
@@ -1,97 +0,0 @@
|
|||||||
<script setup>
|
|
||||||
const props = defineProps({
|
|
||||||
modelValue: {
|
|
||||||
type: String,
|
|
||||||
required: true,
|
|
||||||
},
|
|
||||||
items: {
|
|
||||||
type: Array,
|
|
||||||
required: true,
|
|
||||||
},
|
|
||||||
label: {
|
|
||||||
type: String,
|
|
||||||
required: true,
|
|
||||||
},
|
|
||||||
order: {
|
|
||||||
type: String,
|
|
||||||
default: "desc",
|
|
||||||
},
|
|
||||||
ascendingLabel: {
|
|
||||||
type: String,
|
|
||||||
default: "Ascending",
|
|
||||||
},
|
|
||||||
descendingLabel: {
|
|
||||||
type: String,
|
|
||||||
default: "Descending",
|
|
||||||
},
|
|
||||||
showOrder: {
|
|
||||||
type: Boolean,
|
|
||||||
default: false,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
const emit = defineEmits(["update:modelValue", "update:order"]);
|
|
||||||
|
|
||||||
const updateSortBy = (value) => {
|
|
||||||
emit("update:modelValue", value);
|
|
||||||
};
|
|
||||||
|
|
||||||
const toggleOrder = () => {
|
|
||||||
emit("update:order", props.order === "desc" ? "asc" : "desc");
|
|
||||||
};
|
|
||||||
</script>
|
|
||||||
|
|
||||||
<template>
|
|
||||||
<div class="plugin-sort-control">
|
|
||||||
<v-select
|
|
||||||
:model-value="modelValue"
|
|
||||||
:items="items"
|
|
||||||
density="compact"
|
|
||||||
variant="outlined"
|
|
||||||
hide-details
|
|
||||||
:label="label"
|
|
||||||
class="plugin-sort-control__select"
|
|
||||||
@update:model-value="updateSortBy"
|
|
||||||
>
|
|
||||||
<template #prepend-inner>
|
|
||||||
<v-icon size="small">mdi-sort</v-icon>
|
|
||||||
</template>
|
|
||||||
</v-select>
|
|
||||||
|
|
||||||
<v-btn
|
|
||||||
v-if="showOrder"
|
|
||||||
icon
|
|
||||||
variant="text"
|
|
||||||
density="compact"
|
|
||||||
@click="toggleOrder"
|
|
||||||
>
|
|
||||||
<v-icon>{{
|
|
||||||
order === "desc" ? "mdi-arrow-down-thin" : "mdi-arrow-up-thin"
|
|
||||||
}}</v-icon>
|
|
||||||
<v-tooltip activator="parent" location="top">
|
|
||||||
{{ order === "desc" ? descendingLabel : ascendingLabel }}
|
|
||||||
</v-tooltip>
|
|
||||||
</v-btn>
|
|
||||||
</div>
|
|
||||||
</template>
|
|
||||||
|
|
||||||
<style scoped>
|
|
||||||
.plugin-sort-control {
|
|
||||||
display: flex;
|
|
||||||
align-items: center;
|
|
||||||
gap: 8px;
|
|
||||||
flex-wrap: wrap;
|
|
||||||
}
|
|
||||||
|
|
||||||
.plugin-sort-control__select {
|
|
||||||
min-width: 180px;
|
|
||||||
max-width: 220px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.plugin-sort-control__select :deep(.v-field__input),
|
|
||||||
.plugin-sort-control__select :deep(.v-field-label),
|
|
||||||
.plugin-sort-control__select :deep(.v-select__selection-text),
|
|
||||||
.plugin-sort-control__select :deep(.v-field__prepend-inner) {
|
|
||||||
font-size: 0.875rem;
|
|
||||||
}
|
|
||||||
</style>
|
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { computed } from 'vue';
|
import { computed } from 'vue';
|
||||||
import { useModuleI18n } from '@/i18n/composables';
|
import { useModuleI18n } from '@/i18n/composables';
|
||||||
import { normalizeTextInput } from '@/utils/inputValue';
|
|
||||||
|
|
||||||
const { tm } = useModuleI18n('features/command');
|
const { tm } = useModuleI18n('features/command');
|
||||||
|
|
||||||
@@ -53,7 +52,6 @@ const statusItems = [
|
|||||||
{ title: tm('filters.disabled'), value: 'disabled' },
|
{ title: tm('filters.disabled'), value: 'disabled' },
|
||||||
{ title: tm('filters.conflict'), value: 'conflict' }
|
{ title: tm('filters.conflict'), value: 'conflict' }
|
||||||
];
|
];
|
||||||
|
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<template>
|
<template>
|
||||||
@@ -110,11 +108,10 @@ const statusItems = [
|
|||||||
<div style="min-width: 200px; max-width: 350px; flex: 1; border: 1px solid #B9B9B9; border-radius: 16px;">
|
<div style="min-width: 200px; max-width: 350px; flex: 1; border: 1px solid #B9B9B9; border-radius: 16px;">
|
||||||
<v-text-field
|
<v-text-field
|
||||||
:model-value="searchQuery"
|
:model-value="searchQuery"
|
||||||
@update:model-value="emit('update:searchQuery', normalizeTextInput($event))"
|
@update:model-value="emit('update:searchQuery', $event)"
|
||||||
density="compact"
|
density="compact"
|
||||||
:label="tm('search.placeholder')"
|
:label="tm('search.placeholder')"
|
||||||
prepend-inner-icon="mdi-magnify"
|
prepend-inner-icon="mdi-magnify"
|
||||||
clearable
|
|
||||||
variant="solo-filled"
|
variant="solo-filled"
|
||||||
flat
|
flat
|
||||||
hide-details
|
hide-details
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
*/
|
*/
|
||||||
import { ref, computed, type Ref } from 'vue';
|
import { ref, computed, type Ref } from 'vue';
|
||||||
import type { CommandItem, FilterState } from '../types';
|
import type { CommandItem, FilterState } from '../types';
|
||||||
import { normalizeTextInput } from '@/utils/inputValue';
|
|
||||||
|
|
||||||
export function useCommandFilters(commands: Ref<CommandItem[]>) {
|
export function useCommandFilters(commands: Ref<CommandItem[]>) {
|
||||||
// 过滤状态
|
// 过滤状态
|
||||||
@@ -96,7 +95,7 @@ export function useCommandFilters(commands: Ref<CommandItem[]>) {
|
|||||||
* 过滤后的指令列表(支持层级结构)
|
* 过滤后的指令列表(支持层级结构)
|
||||||
*/
|
*/
|
||||||
const filteredCommands = computed(() => {
|
const filteredCommands = computed(() => {
|
||||||
const query = normalizeTextInput(searchQuery.value).toLowerCase();
|
const query = searchQuery.value.toLowerCase();
|
||||||
const conflictCmds: CommandItem[] = [];
|
const conflictCmds: CommandItem[] = [];
|
||||||
const normalCmds: CommandItem[] = [];
|
const normalCmds: CommandItem[] = [];
|
||||||
|
|
||||||
@@ -185,3 +184,4 @@ export function useCommandFilters(commands: Ref<CommandItem[]>) {
|
|||||||
isGroupExpanded
|
isGroupExpanded
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,6 @@
|
|||||||
import { computed, onActivated, onMounted, ref, watch} from 'vue';
|
import { computed, onActivated, onMounted, ref, watch} from 'vue';
|
||||||
import axios from 'axios';
|
import axios from 'axios';
|
||||||
import { useModuleI18n } from '@/i18n/composables';
|
import { useModuleI18n } from '@/i18n/composables';
|
||||||
import { normalizeTextInput } from '@/utils/inputValue';
|
|
||||||
|
|
||||||
// Composables
|
// Composables
|
||||||
import { useComponentData } from './composables/useComponentData';
|
import { useComponentData } from './composables/useComponentData';
|
||||||
@@ -84,7 +83,7 @@ const {
|
|||||||
} = useCommandActions(toast, () => fetchCommands(tm('messages.loadFailed')));
|
} = useCommandActions(toast, () => fetchCommands(tm('messages.loadFailed')));
|
||||||
|
|
||||||
const filteredTools = computed(() => {
|
const filteredTools = computed(() => {
|
||||||
const query = normalizeTextInput(toolSearch.value).trim().toLowerCase();
|
const query = toolSearch.value.trim().toLowerCase();
|
||||||
if (!query) return tools.value;
|
if (!query) return tools.value;
|
||||||
return tools.value.filter(tool =>
|
return tools.value.filter(tool =>
|
||||||
tool.name?.toLowerCase().includes(query) ||
|
tool.name?.toLowerCase().includes(query) ||
|
||||||
@@ -254,8 +253,7 @@ watch(viewMode, async (mode) => {
|
|||||||
<div class="d-flex flex-wrap align-center ga-3 mb-4">
|
<div class="d-flex flex-wrap align-center ga-3 mb-4">
|
||||||
<div style="min-width: 240px; max-width: 380px; flex: 1;">
|
<div style="min-width: 240px; max-width: 380px; flex: 1;">
|
||||||
<v-text-field
|
<v-text-field
|
||||||
:model-value="toolSearch"
|
v-model="toolSearch"
|
||||||
@update:model-value="toolSearch = normalizeTextInput($event)"
|
|
||||||
prepend-inner-icon="mdi-magnify"
|
prepend-inner-icon="mdi-magnify"
|
||||||
:label="tmTool('functionTools.search')"
|
:label="tmTool('functionTools.search')"
|
||||||
variant="outlined"
|
variant="outlined"
|
||||||
|
|||||||
@@ -7,7 +7,6 @@
|
|||||||
v-model="modelSearchProxy"
|
v-model="modelSearchProxy"
|
||||||
density="compact"
|
density="compact"
|
||||||
prepend-inner-icon="mdi-magnify"
|
prepend-inner-icon="mdi-magnify"
|
||||||
clearable
|
|
||||||
hide-details
|
hide-details
|
||||||
variant="solo-filled"
|
variant="solo-filled"
|
||||||
flat
|
flat
|
||||||
@@ -162,7 +161,6 @@
|
|||||||
|
|
||||||
<script setup>
|
<script setup>
|
||||||
import { computed } from 'vue'
|
import { computed } from 'vue'
|
||||||
import { normalizeTextInput } from '@/utils/inputValue'
|
|
||||||
|
|
||||||
const props = defineProps({
|
const props = defineProps({
|
||||||
entries: {
|
entries: {
|
||||||
@@ -224,7 +222,7 @@ const emit = defineEmits([
|
|||||||
|
|
||||||
const modelSearchProxy = computed({
|
const modelSearchProxy = computed({
|
||||||
get: () => props.modelSearch,
|
get: () => props.modelSearch,
|
||||||
set: (val) => emit('update:modelSearch', normalizeTextInput(val))
|
set: (val) => emit('update:modelSearch', val)
|
||||||
})
|
})
|
||||||
|
|
||||||
const isProviderTesting = (providerId) => props.testingProviders.includes(providerId)
|
const isProviderTesting = (providerId) => props.testingProviders.includes(providerId)
|
||||||
|
|||||||
@@ -12,7 +12,7 @@
|
|||||||
>
|
>
|
||||||
<v-icon
|
<v-icon
|
||||||
size="18"
|
size="18"
|
||||||
:color="props.variant === 'default' ? 'rgb(var(--v-theme-primary))' : undefined"
|
:color="props.variant === 'default' ? (useCustomizerStore().uiTheme === 'PurpleTheme' ? '#5e35b1' : '#d7c5fa') : undefined"
|
||||||
>
|
>
|
||||||
mdi-translate
|
mdi-translate
|
||||||
</v-icon>
|
</v-icon>
|
||||||
@@ -42,6 +42,7 @@
|
|||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { computed } from 'vue'
|
import { computed } from 'vue'
|
||||||
import { useI18n, useLanguageSwitcher } from '@/i18n/composables'
|
import { useI18n, useLanguageSwitcher } from '@/i18n/composables'
|
||||||
|
import { useCustomizerStore } from '@/stores/customizer'
|
||||||
import type { Locale } from '@/i18n/types'
|
import type { Locale } from '@/i18n/types'
|
||||||
import StyledMenu from '@/components/shared/StyledMenu.vue'
|
import StyledMenu from '@/components/shared/StyledMenu.vue'
|
||||||
|
|
||||||
@@ -89,7 +90,7 @@ const changeLanguage = async (langCode: string) => {
|
|||||||
|
|
||||||
.language-switcher--default:hover {
|
.language-switcher--default:hover {
|
||||||
transform: scale(1.05);
|
transform: scale(1.05);
|
||||||
background: rgba(var(--v-theme-primary), 0.08) !important;
|
background: rgba(94, 53, 177, 0.08) !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Header变体样式 - 完全继承Vuetify和action-btn的默认样式 */
|
/* Header变体样式 - 完全继承Vuetify和action-btn的默认样式 */
|
||||||
@@ -102,4 +103,8 @@ const changeLanguage = async (langCode: string) => {
|
|||||||
/* 继承action-btn样式,与工具栏主题按钮保持一致 */
|
/* 继承action-btn样式,与工具栏主题按钮保持一致 */
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* 深色模式下的悬停效果(仅对default变体) */
|
||||||
|
:deep(.v-theme--PurpleThemeDark) .language-switcher--default:hover {
|
||||||
|
background: rgba(114, 46, 209, 0.12) !important;
|
||||||
|
}
|
||||||
</style>
|
</style>
|
||||||
@@ -6,11 +6,11 @@
|
|||||||
</div>
|
</div>
|
||||||
<div class="logo-text">
|
<div class="logo-text">
|
||||||
<h2
|
<h2
|
||||||
:style="{ color: 'rgb(var(--v-theme-primary))' }"
|
:style="{color: useCustomizerStore().uiTheme === 'PurpleTheme' ? '#5e35b1' : '#d7c5fa'}"
|
||||||
v-html="formatTitle(title || t('core.header.logoTitle'))"
|
v-html="formatTitle(title || t('core.header.logoTitle'))"
|
||||||
></h2>
|
></h2>
|
||||||
<!-- 父子组件传递css变量可能会出错,暂时使用十六进制颜色值 -->
|
<!-- 父子组件传递css变量可能会出错,暂时使用十六进制颜色值 -->
|
||||||
<h4 :style="{ color: 'rgba(var(--v-theme-on-surface), 0.72)' }"
|
<h4 :style="{color: useCustomizerStore().uiTheme === 'PurpleTheme' ? '#000000aa' : '#ffffffcc'}"
|
||||||
class="hint-text">{{ subtitle || t('core.header.accountDialog.title') }}</h4>
|
class="hint-text">{{ subtitle || t('core.header.accountDialog.title') }}</h4>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -18,6 +18,7 @@
|
|||||||
</template>
|
</template>
|
||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
|
import { useCustomizerStore } from "@/stores/customizer";
|
||||||
import { useI18n } from '@/i18n/composables';
|
import { useI18n } from '@/i18n/composables';
|
||||||
|
|
||||||
const { t } = useI18n();
|
const { t } = useI18n();
|
||||||
|
|||||||
@@ -48,24 +48,6 @@ const loading = ref(false);
|
|||||||
const isEmpty = ref(false);
|
const isEmpty = ref(false);
|
||||||
const copyFeedbackTimer = ref(null);
|
const copyFeedbackTimer = ref(null);
|
||||||
const lastRequestId = ref(0);
|
const lastRequestId = ref(0);
|
||||||
const scrollContainer = ref(null);
|
|
||||||
|
|
||||||
function slugifyHeading(text, slugCounts) {
|
|
||||||
const base = (text || "")
|
|
||||||
.trim()
|
|
||||||
.toLowerCase()
|
|
||||||
.normalize("NFKD")
|
|
||||||
.replace(/[\u0300-\u036f]/g, "")
|
|
||||||
.replace(/[^\p{Letter}\p{Number}\s-]/gu, "")
|
|
||||||
.replace(/\s+/g, "-")
|
|
||||||
.replace(/-+/g, "-");
|
|
||||||
|
|
||||||
if (!base) return "";
|
|
||||||
|
|
||||||
const count = slugCounts.get(base) || 0;
|
|
||||||
slugCounts.set(base, count + 1);
|
|
||||||
return count === 0 ? base : `${base}-${count}`;
|
|
||||||
}
|
|
||||||
|
|
||||||
onUnmounted(() => {
|
onUnmounted(() => {
|
||||||
if (copyFeedbackTimer.value) clearTimeout(copyFeedbackTimer.value);
|
if (copyFeedbackTimer.value) clearTimeout(copyFeedbackTimer.value);
|
||||||
@@ -171,18 +153,6 @@ const renderedHtml = computed(() => {
|
|||||||
// 3. 后处理方案:完全隔离,安全性最高
|
// 3. 后处理方案:完全隔离,安全性最高
|
||||||
const tempDiv = document.createElement("div");
|
const tempDiv = document.createElement("div");
|
||||||
tempDiv.innerHTML = cleanHtml;
|
tempDiv.innerHTML = cleanHtml;
|
||||||
|
|
||||||
const slugCounts = new Map();
|
|
||||||
tempDiv.querySelectorAll("h1, h2, h3, h4, h5, h6").forEach((heading) => {
|
|
||||||
if (heading.id) {
|
|
||||||
slugCounts.set(heading.id, (slugCounts.get(heading.id) || 0) + 1);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const slug = slugifyHeading(heading.textContent, slugCounts);
|
|
||||||
if (slug) heading.id = slug;
|
|
||||||
});
|
|
||||||
|
|
||||||
tempDiv.querySelectorAll("a").forEach((link) => {
|
tempDiv.querySelectorAll("a").forEach((link) => {
|
||||||
const href = link.getAttribute("href");
|
const href = link.getAttribute("href");
|
||||||
// 强制所有外部链接使用安全的 _blank 策略
|
// 强制所有外部链接使用安全的 _blank 策略
|
||||||
@@ -281,35 +251,18 @@ watch(
|
|||||||
|
|
||||||
function handleContainerClick(event) {
|
function handleContainerClick(event) {
|
||||||
const btn = event.target.closest(".copy-code-btn");
|
const btn = event.target.closest(".copy-code-btn");
|
||||||
if (btn) {
|
if (!btn) return;
|
||||||
const code = btn.closest(".code-block-wrapper")?.querySelector("code");
|
const code = btn.closest(".code-block-wrapper")?.querySelector("code");
|
||||||
if (code) {
|
if (code) {
|
||||||
if (navigator.clipboard?.writeText) {
|
if (navigator.clipboard?.writeText) {
|
||||||
navigator.clipboard
|
navigator.clipboard
|
||||||
.writeText(code.textContent)
|
.writeText(code.textContent)
|
||||||
.then(() => showCopyFeedback(btn, true))
|
.then(() => showCopyFeedback(btn, true))
|
||||||
.catch(() => tryFallbackCopy(code.textContent, btn));
|
.catch(() => tryFallbackCopy(code.textContent, btn));
|
||||||
} else {
|
} else {
|
||||||
tryFallbackCopy(code.textContent, btn);
|
tryFallbackCopy(code.textContent, btn);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const anchor = event.target.closest('a[href^="#"]');
|
|
||||||
if (!anchor) return;
|
|
||||||
|
|
||||||
const rawHref = anchor.getAttribute("href");
|
|
||||||
const targetId = rawHref ? decodeURIComponent(rawHref.slice(1)) : "";
|
|
||||||
if (!targetId) return;
|
|
||||||
|
|
||||||
const target = scrollContainer.value?.querySelector(
|
|
||||||
`#${CSS.escape(targetId)}`,
|
|
||||||
);
|
|
||||||
if (!target) return;
|
|
||||||
|
|
||||||
event.preventDefault();
|
|
||||||
target.scrollIntoView({ behavior: "smooth", block: "start" });
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function tryFallbackCopy(text, btn) {
|
function tryFallbackCopy(text, btn) {
|
||||||
@@ -373,7 +326,7 @@ const showActionArea = computed(() => {
|
|||||||
<v-icon>mdi-close</v-icon>
|
<v-icon>mdi-close</v-icon>
|
||||||
</v-btn>
|
</v-btn>
|
||||||
</v-card-title>
|
</v-card-title>
|
||||||
<v-card-text ref="scrollContainer" style="overflow-y: auto">
|
<v-card-text style="overflow-y: auto">
|
||||||
<div v-if="showActionArea" class="d-flex justify-space-between mb-4">
|
<div v-if="showActionArea" class="d-flex justify-space-between mb-4">
|
||||||
<v-btn
|
<v-btn
|
||||||
v-if="modeConfig.showGithubButton && repoUrl"
|
v-if="modeConfig.showGithubButton && repoUrl"
|
||||||
@@ -483,7 +436,6 @@ const showActionArea = computed(() => {
|
|||||||
margin-bottom: 16px;
|
margin-bottom: 16px;
|
||||||
font-weight: 600;
|
font-weight: 600;
|
||||||
line-height: 1.25;
|
line-height: 1.25;
|
||||||
scroll-margin-top: 12px;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
:deep(.markdown-body h1) {
|
:deep(.markdown-body h1) {
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user