diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml
index 962612d5a..77eeb3be6 100644
--- a/.github/ISSUE_TEMPLATE/bug-report.yml
+++ b/.github/ISSUE_TEMPLATE/bug-report.yml
@@ -1,46 +1,44 @@
-name: '🐛 报告 Bug'
+name: '🐛 Report Bug / 报告 Bug'
title: '[Bug]'
-description: 提交报告帮助我们改进。
+description: Submit bug report to help us improve. / 提交报告帮助我们改进。
labels: [ 'bug' ]
body:
- type: markdown
attributes:
value: |
- 感谢您抽出时间报告问题!请准确解释您的问题。如果可能,请提供一个可复现的片段(这有助于更快地解决问题)。请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。
+ Thank you for taking the time to report this issue! Please describe your problem accurately. If possible, please provide a reproducible snippet (this will help resolve the issue more quickly). Please note that issues that are not detailed or have no logs will be closed immediately. Thank you for your understanding. / 感谢您抽出时间报告问题!请准确解释您的问题。如果可能,请提供一个可复现的片段(这有助于更快地解决问题)。请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。
- type: textarea
attributes:
- label: 发生了什么
- description: 描述你遇到的异常
+ label: What happened / 发生了什么
+ description: Description
placeholder: >
- 一个清晰且具体的描述这个异常是什么。请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。
+ Please provide a clear and specific description of what this exception is. Please note that issues that are not detailed or have no logs will be closed immediately. Thank you for your understanding. / 一个清晰且具体的描述这个异常是什么。请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。
validations:
required: true
- type: textarea
attributes:
- label: 如何复现?
+ label: Reproduce / 如何复现?
description: >
- 复现该问题的步骤
+ The steps to reproduce the issue. / 复现该问题的步骤
placeholder: >
- 如: 1. 打开 '...'
+ Example: 1. Open '...'
validations:
required: true
- type: textarea
attributes:
- label: AstrBot 版本、部署方式(如 Windows Docker Desktop 部署)、使用的提供商、使用的消息平台适配器
- description: >
- 请提供您的 AstrBot 版本和部署方式。
+ label: AstrBot version, deployment method (e.g., Windows Docker Desktop deployment), provider used, and messaging platform used. / AstrBot 版本、部署方式(如 Windows Docker Desktop 部署)、使用的提供商、使用的消息平台适配器
placeholder: >
- 如: 3.1.8 Docker, 3.1.7 Windows启动器
+ Example: 4.5.7 Docker, 3.1.7 Windows Launcher
validations:
required: true
- type: dropdown
attributes:
- label: 操作系统
+ label: OS
description: |
- 你在哪个操作系统上遇到了这个问题?
+ On which operating system did you encounter this problem? / 你在哪个操作系统上遇到了这个问题?
multiple: false
options:
- 'Windows'
@@ -53,30 +51,30 @@ body:
- type: textarea
attributes:
- label: 报错日志
+ label: Logs / 报错日志
description: >
- 如报错日志、截图等。请提供完整的 Debug 级别的日志,不要介意它很长!请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。
+ Please provide complete Debug-level logs, such as error logs and screenshots. Don't worry if they're long! Please note that issues with insufficient details or no logs will be closed immediately. Thank you for your understanding. / 如报错日志、截图等。请提供完整的 Debug 级别的日志,不要介意它很长!请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。
placeholder: >
- 请提供完整的报错日志或截图。
+ Please provide a complete error log or screenshot. / 请提供完整的报错日志或截图。
validations:
required: true
- type: checkboxes
attributes:
- label: 你愿意提交 PR 吗?
+ label: Are you willing to submit a PR? / 你愿意提交 PR 吗?
description: >
- 这不是必需的,但我们很乐意在贡献过程中为您提供指导特别是如果你已经很好地理解了如何实现修复。
+ This is not required, but we would be happy to provide guidance during the contribution process, especially if you already have a good understanding of how to implement the fix. / 这不是必需的,但我们很乐意在贡献过程中为您提供指导特别是如果你已经很好地理解了如何实现修复。
options:
- - label: 是的,我愿意提交 PR!
+ - label: Yes!
- type: checkboxes
attributes:
label: Code of Conduct
options:
- label: >
- 我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。
+ I have read and agree to abide by the project's [Code of Conduct](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。
required: true
- type: markdown
attributes:
- value: "感谢您填写我们的表单!"
+ value: "Thank you for filling out our form! / 感谢您填写我们的表单!"
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index e58a301ea..70bb8f30c 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -1,44 +1,25 @@
-
-
-
-fixes #XYZ
-
----
-
-### Motivation / 动机
-
-
-
+
+
### Modifications / 改动点
-### Verification Steps / 验证步骤
-
-
-
+- [x] This is NOT a breaking change. / 这不是一个破坏性变更。
+
### Screenshots or Test Results / 运行截图或测试结果
-
-
-### Compatibility & Breaking Changes / 兼容性与破坏性变更
-
-
-
-
-- [ ] 这是一个破坏性变更 (Breaking Change)。/ This is a breaking change.
-- [ ] 这不是一个破坏性变更。/ This is NOT a breaking change.
+
---
### Checklist / 检查清单
-
+
- [ ] 😊 如果 PR 中有新加入的功能,已经通过 Issue / 邮件等方式和作者讨论过。/ If there are new features added in the PR, I have discussed it with the authors through issues/emails, etc.
- [ ] 👀 我的更改经过了良好的测试,**并已在上方提供了“验证步骤”和“运行截图”**。/ My changes have been well-tested, **and "Verification Steps" and "Screenshots" have been provided above**.
diff --git a/.github/workflows/auto_release.yml b/.github/workflows/auto_release.yml
index 04d20c2da..f13f3ae51 100644
--- a/.github/workflows/auto_release.yml
+++ b/.github/workflows/auto_release.yml
@@ -13,7 +13,7 @@ jobs:
contents: write
steps:
- name: Checkout repository
- uses: actions/checkout@v5
+ uses: actions/checkout@v6
- name: Dashboard Build
run: |
@@ -70,7 +70,7 @@ jobs:
needs: build-and-publish-to-github-release
steps:
- name: Checkout repository
- uses: actions/checkout@v5
+ uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v6
diff --git a/.github/workflows/code-format.yml b/.github/workflows/code-format.yml
index d5260e879..a183f1bb2 100644
--- a/.github/workflows/code-format.yml
+++ b/.github/workflows/code-format.yml
@@ -12,7 +12,7 @@ jobs:
steps:
- name: Checkout code
- uses: actions/checkout@v5
+ uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v6
diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml
index 85cde14f5..5aeef1eff 100644
--- a/.github/workflows/codeql.yml
+++ b/.github/workflows/codeql.yml
@@ -56,7 +56,7 @@ jobs:
# your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
steps:
- name: Checkout repository
- uses: actions/checkout@v5
+ uses: actions/checkout@v6
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
diff --git a/.github/workflows/coverage_test.yml b/.github/workflows/coverage_test.yml
index 1fb2e2024..6ae8c7b9b 100644
--- a/.github/workflows/coverage_test.yml
+++ b/.github/workflows/coverage_test.yml
@@ -17,7 +17,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
- uses: actions/checkout@v5
+ uses: actions/checkout@v6
with:
fetch-depth: 0
diff --git a/.github/workflows/dashboard_ci.yml b/.github/workflows/dashboard_ci.yml
index f02085f84..f403da773 100644
--- a/.github/workflows/dashboard_ci.yml
+++ b/.github/workflows/dashboard_ci.yml
@@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
- uses: actions/checkout@v5
+ uses: actions/checkout@v6
- name: Setup Node.js
uses: actions/setup-node@v6
@@ -36,7 +36,7 @@ jobs:
zip -r dist.zip dist
- name: Archive production artifacts
- uses: actions/upload-artifact@v5
+ uses: actions/upload-artifact@v6
with:
name: dist-without-markdown
path: |
diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml
index ecc098c3d..0d1550e1b 100644
--- a/.github/workflows/docker-image.yml
+++ b/.github/workflows/docker-image.yml
@@ -3,18 +3,125 @@ name: Docker Image CI/CD
on:
push:
tags:
- - 'v*'
+ - "v*"
+ schedule:
+ # Run at 00:00 UTC every day
+ - cron: "0 0 * * *"
workflow_dispatch:
jobs:
- publish-docker:
+ build-nightly-image:
+ if: github.event_name == 'schedule'
runs-on: ubuntu-latest
+ env:
+ DOCKER_HUB_USERNAME: ${{ secrets.DOCKER_HUB_USERNAME }}
+ GHCR_OWNER: soulter
+ HAS_GHCR_TOKEN: ${{ secrets.GHCR_GITHUB_TOKEN != '' }}
steps:
- - name: Pull The Codes
- uses: actions/checkout@v5
+ - name: Checkout
+ uses: actions/checkout@v6
with:
- fetch-depth: 0 # Must be 0 so we can fetch tags
+ fetch-depth: 1
+ fetch-tag: true
+
+ - name: Check for new commits today
+ if: github.event_name == 'schedule'
+ id: check-commits
+ run: |
+ # Get commits from the last 24 hours
+ commits=$(git log --since="24 hours ago" --oneline)
+ if [ -z "$commits" ]; then
+ echo "No commits in the last 24 hours, skipping build"
+ echo "has_commits=false" >> $GITHUB_OUTPUT
+ else
+ echo "Found commits in the last 24 hours:"
+ echo "$commits"
+ echo "has_commits=true" >> $GITHUB_OUTPUT
+ fi
+
+ - name: Exit if no commits
+ if: github.event_name == 'schedule' && steps.check-commits.outputs.has_commits == 'false'
+ run: exit 0
+
+ - name: Build Dashboard
+ run: |
+ cd dashboard
+ npm install
+ npm run build
+ mkdir -p dist/assets
+ echo $(git rev-parse HEAD) > dist/assets/version
+ cd ..
+ mkdir -p data
+ cp -r dashboard/dist data/
+
+ - name: Determine test image tags
+ id: test-meta
+ run: |
+ short_sha=$(echo "${GITHUB_SHA}" | cut -c1-12)
+ build_date=$(date +%Y%m%d)
+ echo "short_sha=$short_sha" >> $GITHUB_OUTPUT
+ echo "build_date=$build_date" >> $GITHUB_OUTPUT
+
+ - name: Set QEMU
+ uses: docker/setup-qemu-action@v3
+
+ - name: Set Docker Buildx
+ uses: docker/setup-buildx-action@v3
+
+ - name: Log in to DockerHub
+ uses: docker/login-action@v3
+ with:
+ username: ${{ secrets.DOCKER_HUB_USERNAME }}
+ password: ${{ secrets.DOCKER_HUB_PASSWORD }}
+
+ - name: Login to GitHub Container Registry
+ if: env.HAS_GHCR_TOKEN == 'true'
+ uses: docker/login-action@v3
+ with:
+ registry: ghcr.io
+ username: ${{ env.GHCR_OWNER }}
+ password: ${{ secrets.GHCR_GITHUB_TOKEN }}
+
+ - name: Build nightly image tags list
+ id: test-tags
+ run: |
+ TAGS="${{ env.DOCKER_HUB_USERNAME }}/astrbot:nightly-latest
+ ${{ env.DOCKER_HUB_USERNAME }}/astrbot:nightly-${{ steps.test-meta.outputs.build_date }}-${{ steps.test-meta.outputs.short_sha }}"
+ if [ "${{ env.HAS_GHCR_TOKEN }}" = "true" ]; then
+ TAGS="$TAGS
+ ghcr.io/${{ env.GHCR_OWNER }}/astrbot:nightly-latest
+ ghcr.io/${{ env.GHCR_OWNER }}/astrbot:nightly-${{ steps.test-meta.outputs.build_date }}-${{ steps.test-meta.outputs.short_sha }}"
+ fi
+ echo "tags<> $GITHUB_OUTPUT
+ echo "$TAGS" >> $GITHUB_OUTPUT
+ echo "EOF" >> $GITHUB_OUTPUT
+
+ - name: Build and Push Nightly Image
+ uses: docker/build-push-action@v6
+ with:
+ context: .
+ platforms: linux/amd64,linux/arm64
+ push: true
+ tags: ${{ steps.test-tags.outputs.tags }}
+
+ - name: Post build notifications
+ run: echo "Test Docker image has been built and pushed successfully"
+
+ build-release-image:
+ if: github.event_name == 'workflow_dispatch' || (github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v'))
+ runs-on: ubuntu-latest
+ env:
+ DOCKER_HUB_USERNAME: ${{ secrets.DOCKER_HUB_USERNAME }}
+ GHCR_OWNER: soulter
+ HAS_GHCR_TOKEN: ${{ secrets.GHCR_GITHUB_TOKEN != '' }}
+
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v6
+ with:
+ fetch-depth: 1
+ fetch-tag: true
- name: Get latest tag (only on manual trigger)
id: get-latest-tag
@@ -27,21 +134,22 @@ jobs:
if: github.event_name == 'workflow_dispatch'
run: git checkout ${{ steps.get-latest-tag.outputs.latest_tag }}
- - name: Check if version is pre-release
- id: check-prerelease
+ - name: Compute release metadata
+ id: release-meta
run: |
- if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
+ if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then
version="${{ steps.get-latest-tag.outputs.latest_tag }}"
else
- version="${{ github.ref_name }}"
+ version="${GITHUB_REF#refs/tags/}"
fi
if [[ "$version" == *"beta"* ]] || [[ "$version" == *"alpha"* ]]; then
echo "is_prerelease=true" >> $GITHUB_OUTPUT
- echo "Version $version is a pre-release, will not push latest tag"
+ echo "Version $version marked as pre-release"
else
echo "is_prerelease=false" >> $GITHUB_OUTPUT
- echo "Version $version is a stable release, will push latest tag"
+ echo "Version $version marked as stable"
fi
+ echo "version=$version" >> $GITHUB_OUTPUT
- name: Build Dashboard
run: |
@@ -67,23 +175,24 @@ jobs:
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
- name: Login to GitHub Container Registry
+ if: env.HAS_GHCR_TOKEN == 'true'
uses: docker/login-action@v3
with:
registry: ghcr.io
- username: Soulter
+ username: ${{ env.GHCR_OWNER }}
password: ${{ secrets.GHCR_GITHUB_TOKEN }}
- - name: Build and Push Docker to DockerHub and Github GHCR
+ - name: Build and Push Release Image
uses: docker/build-push-action@v6
with:
context: .
platforms: linux/amd64,linux/arm64
push: true
tags: |
- ${{ steps.check-prerelease.outputs.is_prerelease == 'false' && format('{0}/astrbot:latest', secrets.DOCKER_HUB_USERNAME) || '' }}
- ${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
- ${{ steps.check-prerelease.outputs.is_prerelease == 'false' && 'ghcr.io/soulter/astrbot:latest' || '' }}
- ghcr.io/soulter/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
+ ${{ steps.release-meta.outputs.is_prerelease == 'false' && format('{0}/astrbot:latest', env.DOCKER_HUB_USERNAME) || '' }}
+ ${{ steps.release-meta.outputs.is_prerelease == 'false' && env.HAS_GHCR_TOKEN == 'true' && format('ghcr.io/{0}/astrbot:latest', env.GHCR_OWNER) || '' }}
+ ${{ format('{0}/astrbot:{1}', env.DOCKER_HUB_USERNAME, steps.release-meta.outputs.version) }}
+ ${{ env.HAS_GHCR_TOKEN == 'true' && format('ghcr.io/{0}/astrbot:{1}', env.GHCR_OWNER, steps.release-meta.outputs.version) || '' }}
- name: Post build notifications
- run: echo "Docker image has been built and pushed successfully"
+ run: echo "Release Docker image has been built and pushed successfully"
diff --git a/.github/workflows/smoke_test.yml b/.github/workflows/smoke_test.yml
new file mode 100644
index 000000000..15571867f
--- /dev/null
+++ b/.github/workflows/smoke_test.yml
@@ -0,0 +1,58 @@
+name: Smoke Test
+
+on:
+ push:
+ branches:
+ - master
+ paths-ignore:
+ - 'README*.md'
+ - 'changelogs/**'
+ - 'dashboard/**'
+ pull_request:
+ workflow_dispatch:
+
+jobs:
+ smoke-test:
+ name: Run smoke tests
+ runs-on: ubuntu-latest
+ timeout-minutes: 10
+
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v6
+ with:
+ fetch-depth: 0
+
+ - name: Set up Python
+ uses: actions/setup-python@v6
+ with:
+ python-version: '3.12'
+
+ - name: Install UV package manager
+ run: |
+ pip install uv
+
+ - name: Install dependencies
+ run: |
+ uv sync
+ timeout-minutes: 15
+
+ - name: Run smoke tests
+ run: |
+ uv run main.py &
+ APP_PID=$!
+
+ echo "Waiting for application to start..."
+ for i in {1..60}; do
+ if curl -f http://localhost:6185 > /dev/null 2>&1; then
+ echo "Application started successfully!"
+ kill $APP_PID
+ exit 0
+ fi
+ sleep 1
+ done
+
+ echo "Application failed to start within 30 seconds"
+ kill $APP_PID 2>/dev/null || true
+ exit 1
+ timeout-minutes: 2
diff --git a/.gitignore b/.gitignore
index 0934fa257..4ad57cf91 100644
--- a/.gitignore
+++ b/.gitignore
@@ -35,6 +35,7 @@ dashboard/dist/
dashboard/src-tauri/target
package-lock.json
package.json
+yarn.lock
# Operating System
**/.DS_Store
@@ -48,4 +49,6 @@ astrbot.lock
chroma
venv/*
pytest.ini
-build/
\ No newline at end of file
+build/
+AGENTS.md
+IFLOW.md
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 000000000..47404d563
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,90 @@
+# CONTRIBUTING
+
+## 贡献指南
+
+首先,感谢您花时间做出贡献!❤️
+
+所有类型的贡献都受到鼓励和重视。有关不同的帮助方式和处理方式的详细信息,请参阅[目录](#目录)。在做出贡献之前,请确保阅读相关部分。这将使我们维护人员的工作变得更加容易,并为所有参与者带来顺畅的体验。社区期待您的贡献。🎉
+
+### 目录
+
+- [报告问题](#报告问题)
+- [提交代码更改](#提交代码更改)
+
+### 报告问题
+
+如果您在使用 AstrBot 时遇到任何问题,请按照以下步骤报告:
+
+1. **检查现有问题**:在提交新问题之前,请先检查 [Issues](https://github.com/AstrBotDevs/AstrBot/issues) 中是否已经存在类似的问题。
+2. **创建新问题**:如果没有类似的问题,请创建一个新问题。请确保提供以下信息:
+ - 问题的简要描述
+ - 重现问题的步骤
+ - 预期结果和实际结果
+ - 相关日志或错误消息
+
+### 提交代码更改
+
+#### 分支命名
+
+我们使用 `fix/` 前缀来修复错误,使用 `feat/` 前缀来添加新功能。对于 `fix/` 分支,请使用简短的描述,或者直接使用 Issue 编号。例如:`fix/1234` 或者 `fix/1234-login-typo`。对于 `feat/` 分支,请使用简短的描述,例如:`feat/add-user-profile`。
+
+#### PR 描述
+
+- 请使用英文描述您的 PR。
+- 标题请使用 `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` 等语义化前缀,并简要描述更改内容。如:`fix: correct login page typo`。
+
+#### 代码规范
+
+##### Core
+
+我们使用 Ruff 作为代码格式化和静态分析工具。在提交代码之前,请运行以下命令以确保代码符合规范:
+
+```bash
+ruff format .
+ruff check .
+```
+
+如果您使用 VSCode,可以安装 `Ruff` 插件。
+
+
+## Contributing Guide
+
+First off, thanks for taking the time to contribute! ❤️
+
+All types of contributions are encouraged and valued. See the [Table of Contents](#table-of-contents) for different ways to help and details about how this project handles them. Please make sure to read the relevant section before making your contribution. It will make it a lot easier for us maintainers and smooth out the experience for all involved. The community looks forward to your contributions. 🎉
+
+### Table of Contents
+
+- [Reporting Issues](#reporting-issues)
+- [Pull Requests](#pull-requests)
+
+### Reporting Issues
+
+If you encounter any issues while using AstrBot, please follow these steps to report them:
+1. **Check Existing Issues**: Before submitting a new issue, please check if a similar issue already exists in the [Issues](https://github.com/AstrBotDevs/AstrBot/issues) section of the repository.
+2. **Create a New Issue**: If no similar issue exists, please create a new issue. Make sure to provide the following information:
+ - A brief description of the issue
+ - Steps to reproduce the issue
+ - Expected and actual results
+ - Relevant logs or error messages
+
+### Pull Requests
+
+#### Branch Naming
+
+We use the `fix/` prefix for bug fixes and the `feat/` prefix for new features. For `fix/` branches, please use a short description or directly use the Issue number, e.g., `fix/1234` or `fix/1234-login-typo`. For `feat/` branches, please use a short description, e.g., `feat/add-user-profile`.
+
+#### PR Description
+- Please use English to describe your PR.
+- Use semantic prefixes like `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` in the title, followed by a brief description of the changes, e.g., `fix: correct login page typo`.
+
+#### Code Style
+
+##### Core
+
+We use Ruff as our code formatter and static analysis tool. Before submitting your code, please run the following commands to ensure your code adheres to the style guidelines:
+
+```bash
+ruff format .
+ruff check .
+```
diff --git a/README.md b/README.md
index 5083f0264..46254b2b4 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,13 @@

-
-
-
+
+
English |
+
日本語 |
+
繁體中文 |
+
Français |
+
Русский
-AstrBot 是一个开源的一站式 Agent 聊天机器人平台及开发框架。
+AstrBot 是一个开源的一站式 Agent 聊天机器人平台,可接入主流即时通讯软件,为个人、开发者和团队打造可靠、可扩展的对话式智能基础设施。无论是个人 AI 伙伴、智能客服、自动化助手,还是企业知识库,AstrBot 都能在你的即时通讯软件平台的工作流中快速构建生产可用的 AI 应用。
+
+
## 主要功能
-1. **大模型对话**。支持接入多种大模型服务。支持多模态、工具调用、MCP、原生知识库、人设等功能。
-2. **多消息平台支持**。支持接入 QQ、企业微信、微信公众号、飞书、Telegram、钉钉、Discord、KOOK 等平台。支持速率限制、白名单、百度内容审核。
-3. **Agent**。完善适配的 Agentic 能力。支持多轮工具调用、内置沙盒代码执行器、网页搜索等功能。
-4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,社区插件生态丰富。
-5. **WebUI**。可视化配置和管理机器人,功能齐全。
+1. 💯 免费 & 开源。
+1. ✨ AI 大模型对话,多模态,Agent,MCP,知识库,人格设定。
+2. 🤖 支持接入 Dify、阿里云百炼、Coze 等智能体平台。
+2. 🌐 多平台,支持 QQ、企业微信、飞书、钉钉、微信公众号、Telegram、Slack 以及[更多](#支持的消息平台)。
+3. 📦 插件扩展,已有近 800 个插件可一键安装。
+5. 💻 WebUI 支持。
+6. 🌐 国际化(i18n)支持。
-## 部署方式
+## 快速开始
#### Docker 部署(推荐 🥳)
@@ -50,6 +56,12 @@ AstrBot 是一个开源的一站式 Agent 聊天机器人平台及开发框架
请参阅官方文档 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) 。
+#### uv 部署
+
+```bash
+uvx astrbot
+```
+
#### 宝塔面板部署
AstrBot 与宝塔面板合作,已上架至宝塔面板。
@@ -101,24 +113,6 @@ uv run main.py
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
-## 🌍 社区
-
-### QQ 群组
-
-- 1 群:322154837
-- 3 群:630166526
-- 5 群:822130018
-- 6 群:753075035
-- 开发者群:975206796
-
-### Telegram 群组
-
-
-
-### Discord 群组
-
-
-
## 支持的消息平台
**官方维护**
@@ -205,6 +199,25 @@ pip install pre-commit
pre-commit install
```
+## 🌍 社区
+
+### QQ 群组
+
+- 1 群:322154837
+- 3 群:630166526
+- 5 群:822130018
+- 6 群:753075035
+- 7 群:743746109
+- 开发者群:975206796
+
+### Telegram 群组
+
+
+
+### Discord 群组
+
+
+
## ❤️ Special Thanks
特别感谢所有 Contributors 和插件开发者对 AstrBot 的贡献 ❤️
@@ -230,4 +243,10 @@ pre-commit install
+
+
_私は、高性能ですから!_
+
+
+
-
-
+
-_✨ Easy-to-use Multi-platform LLM Chatbot & Development Framework ✨_
+
+
-
-[](https://github.com/AstrBotDevs/AstrBot/releases/latest)
-
-
-
-[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
-
-[](https://codecov.io/gh/AstrBotDevs/AstrBot)
-
-
Documentation |
-
Issue Tracking
+
-AstrBot is a loosely coupled, asynchronous chatbot and development framework that supports multi-platform deployment, featuring an easy-to-use plugin system and comprehensive Large Language Model (LLM) integration capabilities.
+
-## ✨ Key Features
+
-1. **LLM Conversations** - Supports various LLMs including OpenAI API, Google Gemini, Llama, Deepseek, ChatGLM, etc. Enables local model deployment via Ollama/LLMTuner. Features multi-turn dialogues, personality contexts, multimodal capabilities (image understanding), and speech-to-text (Whisper).
-2. **Multi-platform Integration** - Supports QQ (OneBot), QQ Channels, WeChat (Gewechat), Feishu, and Telegram. Planned support for DingTalk, Discord, WhatsApp, and Xiaomi Smart Speakers. Includes rate limiting, whitelisting, keyword filtering, and Baidu content moderation.
-3. **Agent Capabilities** - Native support for code execution, natural language TODO lists, web search. Integrates with [Dify Platform](https://dify.ai/) for easy access to Dify assistants/knowledge bases/workflows.
-4. **Plugin System** - Optimized plugin mechanism with minimal development effort. Supports multiple installed plugins.
-5. **Web Dashboard** - Visual configuration management, plugin controls, logging, and WebChat interface for direct LLM interaction.
-6. **High Stability & Modularity** - Event bus and pipeline architecture ensures high modularization and loose coupling.
+
-> [!TIP]
-> Dashboard Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
-> Username: `astrbot`, Password: `astrbot` (LLM not configured for chat page)
+
中文 |
+
日本語 |
+
繁體中文 |
+
Français |
+
Русский
-## ✨ Deployment
+
Documentation |
+
Blog |
+
Roadmap |
+
Issue Tracker
+
-#### Docker Deployment
+AstrBot is an open-source all-in-one Agent chatbot platform that integrates with mainstream instant messaging apps. It provides reliable and scalable conversational AI infrastructure for individuals, developers, and teams. Whether you're building a personal AI companion, intelligent customer service, automation assistant, or enterprise knowledge base, AstrBot enables you to quickly build production-ready AI applications within your IM platform workflows.
-See docs: [Deploy with Docker](https://astrbot.app/deploy/astrbot/docker.html#docker-deployment)
+
-#### Windows Installer
+## Key Features
-Requires Python (>3.10). See docs: [Windows Installer Guide](https://astrbot.app/deploy/astrbot/windows.html)
+1. 💯 Free & Open Source.
+2. ✨ AI LLM Conversations, Multimodal, Agent, MCP, Knowledge Base, Persona Settings.
+3. 🤖 Supports integration with Dify, Alibaba Cloud Bailian, Coze and other agent platforms.
+4. 🌐 Multi-Platform: QQ, WeChat Work, Feishu, DingTalk, WeChat Official Accounts, Telegram, Slack, and [more](#supported-messaging-platforms).
+5. 📦 Plugin Extensions with nearly 800 plugins available for one-click installation.
+6. 💻 WebUI Support.
+7. 🌐 Internationalization (i18n) Support.
-#### Replit Deployment
+## Quick Start
+
+#### Docker Deployment (Recommended 🥳)
+
+We recommend deploying AstrBot using Docker or Docker Compose.
+
+Please refer to the official documentation: [Deploy AstrBot with Docker](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot).
+
+#### uv Deployment
+
+```bash
+uvx astrbot
+```
+
+#### BT-Panel Deployment
+
+AstrBot has partnered with BT-Panel and is now available in their marketplace.
+
+Please refer to the official documentation: [BT-Panel Deployment](https://astrbot.app/deploy/astrbot/btpanel.html).
+
+#### 1Panel Deployment
+
+AstrBot has been officially listed on the 1Panel marketplace.
+
+Please refer to the official documentation: [1Panel Deployment](https://astrbot.app/deploy/astrbot/1panel.html).
+
+#### Deploy on RainYun
+
+AstrBot has been officially listed on RainYun's cloud application platform with one-click deployment.
+
+[](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
+
+#### Deploy on Replit
+
+Community-contributed deployment method.
[](https://repl.it/github/AstrBotDevs/AstrBot)
+#### Windows One-Click Installer
+
+Please refer to the official documentation: [Deploy AstrBot with Windows One-Click Installer](https://astrbot.app/deploy/astrbot/windows.html).
+
#### CasaOS Deployment
-Community-contributed method.
-See docs: [CasaOS Deployment](https://astrbot.app/deploy/astrbot/casaos.html)
+Community-contributed deployment method.
+
+Please refer to the official documentation: [CasaOS Deployment](https://astrbot.app/deploy/astrbot/casaos.html).
#### Manual Deployment
-See docs: [Source Code Deployment](https://astrbot.app/deploy/astrbot/cli.html)
+First, install uv:
-## ⚡ Platform Support
+```bash
+pip install uv
+```
-| Platform | Status | Details | Message Types |
-| -------------------------------------------------------------- | ------ | ------------------- | ------------------- |
-| QQ (Official Bot) | ✔ | Private/Group chats | Text, Images |
-| QQ (OneBot) | ✔ | Private/Group chats | Text, Images, Voice |
-| WeChat (Personal) | ✔ | Private/Group chats | Text, Images, Voice |
-| [Telegram](https://github.com/AstrBotDevs/AstrBot_plugin_telegram) | ✔ | Private/Group chats | Text, Images |
-| [WeChat Work](https://github.com/AstrBotDevs/AstrBot_plugin_wecom) | ✔ | Private chats | Text, Images, Voice |
-| Feishu | ✔ | Group chats | Text, Images |
-| WeChat Open Platform | 🚧 | Planned | - |
-| Discord | 🚧 | Planned | - |
-| WhatsApp | 🚧 | Planned | - |
-| Xiaomi Speakers | 🚧 | Planned | - |
+Install AstrBot via Git Clone:
-## Provider Support Status
+```bash
+git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
+uv run main.py
+```
-| Name | Support | Type | Notes |
-|---------------------------|---------|------------------------|-----------------------------------------------------------------------|
-| OpenAI API | ✔ | Text Generation | Supports all OpenAI API-compatible services including DeepSeek, Google Gemini, GLM, Moonshot, Alibaba Cloud Bailian, Silicon Flow, xAI, etc. |
-| Claude API | ✔ | Text Generation | |
-| Google Gemini API | ✔ | Text Generation | |
-| Dify | ✔ | LLMOps | |
-| DashScope (Alibaba Cloud) | ✔ | LLMOps | |
-| Ollama | ✔ | Model Loader | Local deployment for open-source LLMs (DeepSeek, Llama, etc.) |
-| LM Studio | ✔ | Model Loader | Local deployment for open-source LLMs (DeepSeek, Llama, etc.) |
-| LLMTuner | ✔ | Model Loader | Local loading of fine-tuned models (e.g. LoRA) |
-| OneAPI | ✔ | LLM Distribution | |
-| Whisper | ✔ | Speech-to-Text | Supports API and local deployment |
-| SenseVoice | ✔ | Speech-to-Text | Local deployment |
-| OpenAI TTS API | ✔ | Text-to-Speech | |
-| Fishaudio | ✔ | Text-to-Speech | Project involving GPT-Sovits author |
+Or refer to the official documentation: [Deploy AstrBot from Source](https://astrbot.app/deploy/astrbot/cli.html).
-# 🦌 Roadmap
+## Supported Messaging Platforms
-> [!TIP]
-> Suggestions welcome via Issues <3
+**Officially Maintained**
-- [ ] Ensure feature parity across all platform adapters
-- [ ] Optimize plugin APIs
-- [ ] Add default TTS services (e.g., GPT-Sovits)
-- [ ] Enhance chat features with persistent memory
-- [ ] i18n Planning
+- QQ (Official Platform & OneBot)
+- Telegram
+- WeChat Work Application & WeChat Work Intelligent Bot
+- WeChat Customer Service & WeChat Official Accounts
+- Feishu (Lark)
+- DingTalk
+- Slack
+- Discord
+- Satori
+- Misskey
+- WhatsApp (Coming Soon)
+- LINE (Coming Soon)
-## ❤️ Contributions
+**Community Maintained**
-All Issues/PRs welcome! Simply submit your changes to this project :)
+- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
+- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
+- [Bilibili Direct Messages](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
+- [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11)
-For major features, please discuss via Issues first.
+## Supported Model Services
-## 🌟 Support
+**LLM Services**
-- Star this project!
-- Support via [Afdian](https://afdian.com/a/soulter)
-- WeChat support: [QR Code](https://drive.soulter.top/f/pYfA/d903f4fa49a496fda3f16d2be9e023b5.png)
+- OpenAI and Compatible Services
+- Anthropic
+- Google Gemini
+- Moonshot AI
+- Zhipu AI
+- DeepSeek
+- Ollama (Self-hosted)
+- LM Studio (Self-hosted)
+- [CompShare](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74)
+- [302.AI](https://share.302.ai/rr1M3l)
+- [TokenPony](https://www.tokenpony.cn/3YPyf)
+- [SiliconFlow](https://docs.siliconflow.cn/cn/usecases/use-siliconcloud-in-astrbot)
+- [PPIO Cloud](https://ppio.com/user/register?invited_by=AIOONE)
+- ModelScope
+- OneAPI
-## ✨ Demos
+**LLMOps Platforms**
-> [!NOTE]
-> Code executor file I/O currently tested with Napcat(QQ)/Lagrange(QQ)
+- Dify
+- Alibaba Cloud Bailian Applications
+- Coze
-
+**Speech-to-Text Services**
-
+- OpenAI Whisper
+- SenseVoice
-_✨ Docker-based Sandboxed Code Executor (Beta) ✨_
+**Text-to-Speech Services**
-
+- OpenAI TTS
+- Gemini TTS
+- GPT-Sovits-Inference
+- GPT-Sovits
+- FishAudio
+- Edge TTS
+- Alibaba Cloud Bailian TTS
+- Azure TTS
+- Minimax TTS
+- Volcano Engine TTS
-_✨ Multimodal Input, Web Search, Text-to-Image ✨_
+## ❤️ Contributing
-
+Issues and Pull Requests are always welcome! Feel free to submit your changes to this project :)
-_✨ Natural Language TODO Lists ✨_
+### How to Contribute
-
-
+You can contribute by reviewing issues or helping with pull request reviews. Any issues or PRs are welcome to encourage community participation. Of course, these are just suggestions—you can contribute in any way you like. For adding new features, please discuss through an Issue first.
-_✨ Plugin System Showcase ✨_
+### Development Environment
-
+AstrBot uses `ruff` for code formatting and linting.
-_✨ Web Dashboard ✨_
+```bash
+git clone https://github.com/AstrBotDevs/AstrBot
+pip install pre-commit
+pre-commit install
+```
-
+## 🌍 Community
-_✨ Built-in Web Chat Interface ✨_
+### QQ Groups
-
+- Group 1: 322154837
+- Group 3: 630166526
+- Group 5: 822130018
+- Group 6: 753075035
+- Developer Group: 975206796
+
+### Telegram Group
+
+
+
+### Discord Server
+
+
+
+## ❤️ Special Thanks
+
+Special thanks to all Contributors and plugin developers for their contributions to AstrBot ❤️
+
+
+
+
+
+Additionally, the birth of this project would not have been possible without the help of the following open-source projects:
+
+- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - The amazing cat framework
## ⭐ Star History
-> [!TIP]
-> If this project helps you, please give it a star <3
+> [!TIP]
+> If this project has helped you in your life or work, or if you're interested in its future development, please give the project a Star. It's the driving force behind maintaining this open-source project <3
-
-[](https://star-history.com/#AstrBotDevs/AstrBot&Date)
+
+[](https://star-history.com/#astrbotdevs/astrbot&Date)
-## Disclaimer
-
-1. Licensed under `AGPL-v3`.
-2. WeChat integration uses [Gewechat](https://github.com/Devo919/Gewechat). Use at your own risk with non-critical accounts.
-3. Users must comply with local laws and regulations.
-
-
-
+
_私は、高性能ですから!_
-
diff --git a/README_fr.md b/README_fr.md
new file mode 100644
index 000000000..8f658c9a0
--- /dev/null
+++ b/README_fr.md
@@ -0,0 +1,248 @@
+
+
+
+
+
+
+AstrBot est une plateforme de chatbot Agent tout-en-un open source qui s'intègre aux principales applications de messagerie instantanée. Elle fournit une infrastructure d'IA conversationnelle fiable et évolutive pour les particuliers, les développeurs et les équipes. Que vous construisiez un compagnon IA personnel, un service client intelligent, un assistant d'automatisation ou une base de connaissances d'entreprise, AstrBot vous permet de créer rapidement des applications d'IA prêtes pour la production dans les flux de travail de votre plateforme de messagerie.
+
+
+
+## Fonctionnalités principales
+
+1. 💯 Gratuit & Open Source.
+2. ✨ Conversations avec LLM IA, Multimodal, Agent, MCP, Base de connaissances, Paramètres de personnalité.
+3. 🤖 Prise en charge de l'intégration avec Dify, Alibaba Cloud Bailian, Coze et autres plateformes d'agents.
+4. 🌐 Multi-plateforme : QQ, WeChat Work, Feishu, DingTalk, Comptes officiels WeChat, Telegram, Slack, et [plus encore](#plateformes-de-messagerie-prises-en-charge).
+5. 📦 Extensions de plugins avec près de 800 plugins disponibles pour une installation en un clic.
+6. 💻 Support WebUI.
+7. 🌐 Support de l'internationalisation (i18n).
+
+## Démarrage rapide
+
+#### Déploiement Docker (Recommandé 🥳)
+
+Nous recommandons de déployer AstrBot en utilisant Docker ou Docker Compose.
+
+Veuillez consulter la documentation officielle : [Déployer AstrBot avec Docker](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot).
+
+#### Déploiement uv
+
+```bash
+uvx astrbot
+```
+
+#### Déploiement BT-Panel
+
+AstrBot s'est associé à BT-Panel et est maintenant disponible sur leur marketplace.
+
+Veuillez consulter la documentation officielle : [Déploiement BT-Panel](https://astrbot.app/deploy/astrbot/btpanel.html).
+
+#### Déploiement 1Panel
+
+AstrBot a été officiellement listé sur le marketplace 1Panel.
+
+Veuillez consulter la documentation officielle : [Déploiement 1Panel](https://astrbot.app/deploy/astrbot/1panel.html).
+
+#### Déployer sur RainYun
+
+AstrBot a été officiellement listé sur la plateforme d'applications cloud de RainYun avec un déploiement en un clic.
+
+[](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
+
+#### Déployer sur Replit
+
+Méthode de déploiement contribuée par la communauté.
+
+[](https://repl.it/github/AstrBotDevs/AstrBot)
+
+#### Installateur Windows en un clic
+
+Veuillez consulter la documentation officielle : [Déployer AstrBot avec l'installateur Windows en un clic](https://astrbot.app/deploy/astrbot/windows.html).
+
+#### Déploiement CasaOS
+
+Méthode de déploiement contribuée par la communauté.
+
+Veuillez consulter la documentation officielle : [Déploiement CasaOS](https://astrbot.app/deploy/astrbot/casaos.html).
+
+#### Déploiement manuel
+
+Tout d'abord, installez uv :
+
+```bash
+pip install uv
+```
+
+Installez AstrBot via Git Clone :
+
+```bash
+git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
+uv run main.py
+```
+
+Ou consultez la documentation officielle : [Déployer AstrBot depuis les sources](https://astrbot.app/deploy/astrbot/cli.html).
+
+## Plateformes de messagerie prises en charge
+
+**Maintenues officiellement**
+
+- QQ (Plateforme officielle & OneBot)
+- Telegram
+- Application WeChat Work & Bot intelligent WeChat Work
+- Service client WeChat & Comptes officiels WeChat
+- Feishu (Lark)
+- DingTalk
+- Slack
+- Discord
+- Satori
+- Misskey
+- WhatsApp (Bientôt disponible)
+- LINE (Bientôt disponible)
+
+**Maintenues par la communauté**
+
+- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
+- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
+- [Messages directs Bilibili](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
+- [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11)
+
+## Services de modèles pris en charge
+
+**Services LLM**
+
+- OpenAI et services compatibles
+- Anthropic
+- Google Gemini
+- Moonshot AI
+- Zhipu AI
+- DeepSeek
+- Ollama (Auto-hébergé)
+- LM Studio (Auto-hébergé)
+- [CompShare](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74)
+- [302.AI](https://share.302.ai/rr1M3l)
+- [TokenPony](https://www.tokenpony.cn/3YPyf)
+- [SiliconFlow](https://docs.siliconflow.cn/cn/usecases/use-siliconcloud-in-astrbot)
+- [PPIO Cloud](https://ppio.com/user/register?invited_by=AIOONE)
+- ModelScope
+- OneAPI
+
+**Plateformes LLMOps**
+
+- Dify
+- Applications Alibaba Cloud Bailian
+- Coze
+
+**Services de reconnaissance vocale**
+
+- OpenAI Whisper
+- SenseVoice
+
+**Services de synthèse vocale**
+
+- OpenAI TTS
+- Gemini TTS
+- GPT-Sovits-Inference
+- GPT-Sovits
+- FishAudio
+- Edge TTS
+- Alibaba Cloud Bailian TTS
+- Azure TTS
+- Minimax TTS
+- Volcano Engine TTS
+
+## ❤️ Contribuer
+
+Les Issues et Pull Requests sont toujours les bienvenues ! N'hésitez pas à soumettre vos modifications à ce projet :)
+
+### Comment contribuer
+
+Vous pouvez contribuer en examinant les issues ou en aidant à la revue des pull requests. Toutes les issues ou PRs sont les bienvenues pour encourager la participation de la communauté. Bien sûr, ce ne sont que des suggestions - vous pouvez contribuer de la manière que vous souhaitez. Pour l'ajout de nouvelles fonctionnalités, veuillez d'abord en discuter via une Issue.
+
+### Environnement de développement
+
+AstrBot utilise `ruff` pour le formatage et le linting du code.
+
+```bash
+git clone https://github.com/AstrBotDevs/AstrBot
+pip install pre-commit
+pre-commit install
+```
+
+## 🌍 Communauté
+
+### Groupes QQ
+
+- Groupe 1 : 322154837
+- Groupe 3 : 630166526
+- Groupe 5 : 822130018
+- Groupe 6 : 753075035
+- Groupe développeurs : 975206796
+
+### Groupe Telegram
+
+
+
+### Serveur Discord
+
+
+
+## ❤️ Remerciements spéciaux
+
+Un grand merci à tous les contributeurs et développeurs de plugins pour leurs contributions à AstrBot ❤️
+
+
+
+
+
+De plus, la naissance de ce projet n'aurait pas été possible sans l'aide des projets open source suivants :
+
+- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - L'incroyable framework chat
+
+## ⭐ Historique des étoiles
+
+> [!TIP]
+> Si ce projet vous a aidé dans votre vie ou votre travail, ou si vous êtes intéressé par son développement futur, veuillez donner une étoile au projet. C'est la force motrice derrière la maintenance de ce projet open source <3
+
+
+
+[](https://star-history.com/#astrbotdevs/astrbot&Date)
+
+
+
+
+
+_私は、高性能ですから!_
+
diff --git a/README_ja.md b/README_ja.md
index 735d270bd..d94bf83b7 100644
--- a/README_ja.md
+++ b/README_ja.md
@@ -1,167 +1,247 @@
-
-
-
+
-_✨ 簡単に使えるマルチプラットフォーム LLM チャットボットおよび開発フレームワーク ✨_
+
+
-
-[](https://github.com/AstrBotDevs/AstrBot/releases/latest)
-
-
-
-[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
-
-[](https://codecov.io/gh/AstrBotDevs/AstrBot)
-
-
ドキュメントを見る |
-
問題を報告する
+
-AstrBot は、疎結合、非同期、複数のメッセージプラットフォームに対応したデプロイ、使いやすいプラグインシステム、および包括的な大規模言語モデル(LLM)接続機能を備えたチャットボットおよび開発フレームワークです。
+
-## ✨ 主な機能
+
-1. **大規模言語モデルの対話**。OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM など、さまざまな大規模言語モデルをサポートし、Ollama、LLMTuner を介してローカルにデプロイされた大規模モデルをサポートします。多輪対話、人格シナリオ、多モーダル機能を備え、画像理解、音声からテキストへの変換(Whisper)をサポートします。
-2. **複数のメッセージプラットフォームの接続**。QQ(OneBot)、QQ チャンネル、Feishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。
-3. **エージェント**。一部のエージェント機能をネイティブにサポートし、コードエグゼキューター、自然言語タスク、ウェブ検索などを提供します。[Dify プラットフォーム](https://dify.ai/)と連携し、Dify スマートアシスタント、ナレッジベース、Dify ワークフローを簡単に接続できます。
-4. **プラグインの拡張**。深く最適化されたプラグインメカニズムを備え、[プラグインの開発](https://astrbot.app/dev/plugin.html)をサポートし、機能を拡張できます。複数のプラグインのインストールをサポートします。
-5. **ビジュアル管理パネル**。設定の視覚的な変更、プラグイン管理、ログの表示などをサポートし、設定の難易度を低減します。WebChat を統合し、パネル上で大規模モデルと対話できます。
-6. **高い安定性と高いモジュール性**。イベントバスとパイプラインに基づくアーキテクチャ設計により、高度にモジュール化され、低結合です。
+
-> [!TIP]
-> 管理パネルのオンラインデモを体験する: [https://demo.astrbot.app/](https://demo.astrbot.app/)
->
-> ユーザー名: `astrbot`, パスワード: `astrbot`。LLM が設定されていないため、チャットページで大規模モデルを使用することはできません。(デモのログインパスワードを変更しないでください 😭)
+
中文 |
+
English |
+
繁體中文 |
+
Français |
+
Русский
-## ✨ 使用方法
+
ドキュメント |
+
Blog |
+
ロードマップ |
+
Issue
+
-#### Docker デプロイ
+AstrBot は、主要なインスタントメッセージングアプリと統合できるオープンソースのオールインワン Agent チャットボットプラットフォームです。個人、開発者、チームに信頼性が高くスケーラブルな会話型 AI インフラストラクチャを提供します。パーソナル AI コンパニオン、インテリジェントカスタマーサービス、オートメーションアシスタント、エンタープライズナレッジベースなど、AstrBot を使用すると、IM プラットフォームのワークフロー内で本番環境対応の AI アプリケーションを迅速に構築できます。
-公式ドキュメント [Docker を使用して AstrBot をデプロイする](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) を参照してください。
+
-#### Windows ワンクリックインストーラーのデプロイ
+## 主な機能
-コンピュータに Python(>3.10)がインストールされている必要があります。公式ドキュメント [Windows ワンクリックインストーラーを使用して AstrBot をデプロイする](https://astrbot.app/deploy/astrbot/windows.html) を参照してください。
+1. 💯 無料 & オープンソース。
+2. ✨ AI 大規模言語モデル対話、マルチモーダル、Agent、MCP、ナレッジベース、ペルソナ設定。
+3. 🤖 Dify、Alibaba Cloud 百炼、Coze などの Agent プラットフォームとの統合をサポート。
+4. 🌐 マルチプラットフォーム:QQ、WeChat Work、Feishu、DingTalk、WeChat 公式アカウント、Telegram、Slack、[その他](#サポートされているメッセージプラットフォーム)。
+5. 📦 約800個のプラグインをワンクリックでインストール可能なプラグイン拡張機能。
+6. 💻 WebUI サポート。
+7. 🌐 国際化(i18n)サポート。
-#### Replit デプロイ
+## クイックスタート
+
+#### Docker デプロイ(推奨 🥳)
+
+Docker / Docker Compose を使用した AstrBot のデプロイを推奨します。
+
+公式ドキュメント [Docker を使用した AstrBot のデプロイ](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) をご参照ください。
+
+#### uv デプロイ
+
+```bash
+uvx astrbot
+```
+
+#### 宝塔パネルデプロイ
+
+AstrBot は宝塔パネルと提携し、宝塔パネルに公開されています。
+
+公式ドキュメント [宝塔パネルデプロイ](https://astrbot.app/deploy/astrbot/btpanel.html) をご参照ください。
+
+#### 1Panel デプロイ
+
+AstrBot は 1Panel 公式により 1Panel パネルに公開されています。
+
+公式ドキュメント [1Panel デプロイ](https://astrbot.app/deploy/astrbot/1panel.html) をご参照ください。
+
+#### 雨云でのデプロイ
+
+AstrBot は雨云公式によりクラウドアプリケーションプラットフォームに公開され、ワンクリックでデプロイ可能です。
+
+[](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
+
+#### Replit でのデプロイ
+
+コミュニティ貢献によるデプロイ方法。
[](https://repl.it/github/AstrBotDevs/AstrBot)
+#### Windows ワンクリックインストーラーデプロイ
+
+公式ドキュメント [Windows ワンクリックインストーラーを使用した AstrBot のデプロイ](https://astrbot.app/deploy/astrbot/windows.html) をご参照ください。
+
#### CasaOS デプロイ
-コミュニティが提供するデプロイ方法です。
+コミュニティ貢献によるデプロイ方法。
-公式ドキュメント [ソースコードを使用して AstrBot をデプロイする](https://astrbot.app/deploy/astrbot/casaos.html) を参照してください。
+公式ドキュメント [CasaOS デプロイ](https://astrbot.app/deploy/astrbot/casaos.html) をご参照ください。
#### 手動デプロイ
-公式ドキュメント [ソースコードを使用して AstrBot をデプロイする](https://astrbot.app/deploy/astrbot/cli.html) を参照してください。
+まず uv をインストールします:
-## ⚡ メッセージプラットフォームのサポート状況
+```bash
+pip install uv
+```
-| プラットフォーム | サポート状況 | 詳細 | メッセージタイプ |
-| -------- | ------- | ------- | ------ |
-| QQ(公式ロボットインターフェース) | ✔ | プライベートチャット、グループチャット、QQ チャンネルプライベートチャット、グループチャット | テキスト、画像 |
-| QQ(OneBot) | ✔ | プライベートチャット、グループチャット | テキスト、画像、音声 |
-| WeChat(個人アカウント) | ✔ | WeChat 個人アカウントのプライベートチャット、グループチャット | テキスト、画像、音声 |
-| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | プライベートチャット、グループチャット | テキスト、画像 |
-| [WeChat(企業 WeChat)](https://github.com/Soulter/astrbot_plugin_wecom) | ✔ | プライベートチャット | テキスト、画像、音声 |
-| Feishu | ✔ | グループチャット | テキスト、画像 |
-| WeChat 対話オープンプラットフォーム | 🚧 | 計画中 | - |
-| Discord | 🚧 | 計画中 | - |
-| WhatsApp | 🚧 | 計画中 | - |
-| Xiaoai 音響 | 🚧 | 計画中 | - |
+Git Clone で AstrBot をインストール:
-# 🦌 今後のロードマップ
+```bash
+git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
+uv run main.py
+```
-> [!TIP]
-> Issue でさらに多くの提案を歓迎します <3
+または、公式ドキュメント [ソースコードから AstrBot をデプロイ](https://astrbot.app/deploy/astrbot/cli.html) をご参照ください。
-- [ ] 現在のすべてのプラットフォームアダプターの機能の一貫性を確保し、改善する
-- [ ] プラグインインターフェースの最適化
-- [ ] GPT-Sovits などの TTS サービスをデフォルトでサポート
-- [ ] "チャット強化" 部分を完成させ、永続的な記憶をサポート
-- [ ] i18n の計画
+## サポートされているメッセージプラットフォーム
-## ❤️ 貢献
+**公式メンテナンス**
-Issue や Pull Request を歓迎します!このプロジェクトに変更を加えるだけです :)
+- QQ (公式プラットフォーム & OneBot)
+- Telegram
+- WeChat Work アプリケーション & WeChat Work インテリジェントボット
+- WeChat カスタマーサービス & WeChat 公式アカウント
+- Feishu (Lark)
+- DingTalk
+- Slack
+- Discord
+- Satori
+- Misskey
+- WhatsApp (近日対応予定)
+- LINE (近日対応予定)
-新機能の追加については、まず Issue で議論してください。
+**コミュニティメンテナンス**
-## 🌟 サポート
+- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
+- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
+- [Bilibili ダイレクトメッセージ](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
+- [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11)
-- このプロジェクトに Star を付けてください!
-- [愛発電](https://afdian.com/a/soulter)で私をサポートしてください!
-- [WeChat](https://drive.soulter.top/f/pYfA/d903f4fa49a496fda3f16d2be9e023b5.png)で私をサポートしてください~
+## サポートされているモデルサービス
-## ✨ デモ
+**大規模言語モデルサービス**
-> [!NOTE]
-> コードエグゼキューターのファイル入力/出力は現在 Napcat(QQ)、Lagrange(QQ) でのみテストされています
+- OpenAI および互換サービス
+- Anthropic
+- Google Gemini
+- Moonshot AI
+- 智谱 AI
+- DeepSeek
+- Ollama (セルフホスト)
+- LM Studio (セルフホスト)
+- [優云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74)
+- [302.AI](https://share.302.ai/rr1M3l)
+- [小馬算力](https://www.tokenpony.cn/3YPyf)
+- [硅基流動](https://docs.siliconflow.cn/cn/usercases/use-siliconcloud-in-astrbot)
+- [PPIO 派欧云](https://ppio.com/user/register?invited_by=AIOONE)
+- ModelScope
+- OneAPI
-
+**LLMOps プラットフォーム**
-
+- Dify
+- Alibaba Cloud 百炼アプリケーション
+- Coze
-_✨ Docker ベースのサンドボックス化されたコードエグゼキューター(ベータテスト中)✨_
+**音声認識サービス**
-
+- OpenAI Whisper
+- SenseVoice
-_✨ 多モーダル、ウェブ検索、長文の画像変換(設定可能)✨_
+**音声合成サービス**
-
+- OpenAI TTS
+- Gemini TTS
+- GPT-Sovits-Inference
+- GPT-Sovits
+- FishAudio
+- Edge TTS
+- Alibaba Cloud 百炼 TTS
+- Azure TTS
+- Minimax TTS
+- Volcano Engine TTS
-_✨ 自然言語タスク ✨_
+## ❤️ コントリビューション
-
-
+Issue や Pull Request は大歓迎です!このプロジェクトに変更を送信してください :)
-_✨ プラグインシステム - 一部のプラグインの展示 ✨_
+### コントリビュート方法
-
+Issue を確認したり、PR(プルリクエスト)のレビューを手伝うことで貢献できます。どんな Issue や PR への参加も歓迎され、コミュニティ貢献を促進します。もちろん、これらは提案に過ぎず、どんな方法でも貢献できます。新機能の追加については、まず Issue で議論してください。
-_✨ 管理パネル ✨_
+### 開発環境
-
+AstrBot はコードのフォーマットとチェックに `ruff` を使用しています。
-_✨ 内蔵 Web Chat、オンラインでボットと対話 ✨_
+```bash
+git clone https://github.com/AstrBotDevs/AstrBot
+pip install pre-commit
+pre-commit install
+```
-
+## 🌍 コミュニティ
+
+### QQ グループ
+
+- 1群: 322154837
+- 3群: 630166526
+- 5群: 822130018
+- 6群: 753075035
+- 開発者群: 975206796
+
+### Telegram グループ
+
+
+
+### Discord サーバー
+
+
+
+## ❤️ Special Thanks
+
+AstrBot への貢献をしていただいたすべてのコントリビューターとプラグイン開発者に特別な感謝を ❤️
+
+
+
+
+
+また、このプロジェクトの誕生は以下のオープンソースプロジェクトの助けなしには実現できませんでした:
+
+- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - 素晴らしい猫猫フレームワーク
## ⭐ Star History
> [!TIP]
-> このプロジェクトがあなたの生活や仕事に役立った場合、またはこのプロジェクトの将来の発展に関心がある場合は、プロジェクトに Star を付けてください。これはこのオープンソースプロジェクトを維持するためのモチベーションです <3
+> このプロジェクトがあなたの生活や仕事に役立ったり、このプロジェクトの今後の発展に関心がある場合は、プロジェクトに Star をください。これがこのオープンソースプロジェクトを維持する原動力です <3
-[](https://star-history.com/#soulter/astrbot&Date)
+[](https://star-history.com/#astrbotdevs/astrbot&Date)
-## スポンサー
-
-[
](https://api.gitsponsors.com/api/badge/link?p=XEpbdGxlitw/RbcwiTX93UMzNK/jgDYC8NiSzamIPMoKvG2lBFmyXhSS/b0hFoWlBBMX2L5X5CxTDsUdyvcIEHTOfnkXz47UNOZvMwyt5CzbYpq0SEzsSV1OJF1cCo90qC/ZyYKYOWedal3MhZ3ikw==)
-
-## 免責事項
-
-1. このプロジェクトは `AGPL-v3` オープンソースライセンスの下で保護されています。
-2. このプロジェクトを使用する際は、現地の法律および規制を遵守してください。
-
-
+
_私は、高性能ですから!_
diff --git a/README_ru.md b/README_ru.md
new file mode 100644
index 000000000..ea8e9b6bf
--- /dev/null
+++ b/README_ru.md
@@ -0,0 +1,248 @@
+
+
+
+
+
+
+AstrBot — это универсальная платформа Agent-чатботов с открытым исходным кодом, которая интегрируется с основными приложениями для обмена мгновенными сообщениями. Она предоставляет надёжную и масштабируемую инфраструктуру разговорного ИИ для частных лиц, разработчиков и команд. Будь то персональный ИИ-компаньон, интеллектуальная служба поддержки, автоматизированный помощник или корпоративная база знаний — AstrBot позволяет быстро создавать готовые к использованию ИИ-приложения в рабочих процессах вашей платформы обмена сообщениями.
+
+
+
+## Основные возможности
+
+1. 💯 Бесплатно и с открытым исходным кодом.
+2. ✨ ИИ-диалоги с LLM, мультимодальность, Agent, MCP, база знаний, настройки личности.
+3. 🤖 Поддержка интеграции с Dify, Alibaba Cloud Bailian, Coze и другими платформами агентов.
+4. 🌐 Мультиплатформенность: QQ, WeChat Work, Feishu, DingTalk, официальные аккаунты WeChat, Telegram, Slack и [другие](#поддерживаемые-платформы-обмена-сообщениями).
+5. 📦 Расширения плагинов с почти 800 плагинами, доступными для установки в один клик.
+6. 💻 Поддержка WebUI.
+7. 🌐 Поддержка интернационализации (i18n).
+
+## Быстрый старт
+
+#### Развёртывание Docker (Рекомендуется 🥳)
+
+Мы рекомендуем развёртывать AstrBot с помощью Docker или Docker Compose.
+
+См. официальную документацию: [Развёртывание AstrBot с Docker](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot).
+
+#### Развёртывание uv
+
+```bash
+uvx astrbot
+```
+
+#### Развёртывание BT-Panel
+
+AstrBot в партнёрстве с BT-Panel теперь доступен на их маркетплейсе.
+
+См. официальную документацию: [Развёртывание BT-Panel](https://astrbot.app/deploy/astrbot/btpanel.html).
+
+#### Развёртывание 1Panel
+
+AstrBot официально размещён на маркетплейсе 1Panel.
+
+См. официальную документацию: [Развёртывание 1Panel](https://astrbot.app/deploy/astrbot/1panel.html).
+
+#### Развёртывание на RainYun
+
+AstrBot официально размещён на облачной платформе приложений RainYun с развёртыванием в один клик.
+
+[](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
+
+#### Развёртывание на Replit
+
+Метод развёртывания от сообщества.
+
+[](https://repl.it/github/AstrBotDevs/AstrBot)
+
+#### Установщик Windows в один клик
+
+См. официальную документацию: [Развёртывание AstrBot с установщиком Windows в один клик](https://astrbot.app/deploy/astrbot/windows.html).
+
+#### Развёртывание CasaOS
+
+Метод развёртывания от сообщества.
+
+См. официальную документацию: [Развёртывание CasaOS](https://astrbot.app/deploy/astrbot/casaos.html).
+
+#### Ручное развёртывание
+
+Сначала установите uv:
+
+```bash
+pip install uv
+```
+
+Установите AstrBot через Git Clone:
+
+```bash
+git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
+uv run main.py
+```
+
+Или см. официальную документацию: [Развёртывание AstrBot из исходного кода](https://astrbot.app/deploy/astrbot/cli.html).
+
+## Поддерживаемые платформы обмена сообщениями
+
+**Официально поддерживаемые**
+
+- QQ (Официальная платформа и OneBot)
+- Telegram
+- Приложение WeChat Work и интеллектуальный бот WeChat Work
+- Служба поддержки WeChat и официальные аккаунты WeChat
+- Feishu (Lark)
+- DingTalk
+- Slack
+- Discord
+- Satori
+- Misskey
+- WhatsApp (Скоро)
+- LINE (Скоро)
+
+**Поддерживаемые сообществом**
+
+- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
+- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
+- [Личные сообщения Bilibili](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
+- [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11)
+
+## Поддерживаемые сервисы моделей
+
+**Сервисы LLM**
+
+- OpenAI и совместимые сервисы
+- Anthropic
+- Google Gemini
+- Moonshot AI
+- Zhipu AI
+- DeepSeek
+- Ollama (Самостоятельное размещение)
+- LM Studio (Самостоятельное размещение)
+- [CompShare](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74)
+- [302.AI](https://share.302.ai/rr1M3l)
+- [TokenPony](https://www.tokenpony.cn/3YPyf)
+- [SiliconFlow](https://docs.siliconflow.cn/cn/usecases/use-siliconcloud-in-astrbot)
+- [PPIO Cloud](https://ppio.com/user/register?invited_by=AIOONE)
+- ModelScope
+- OneAPI
+
+**Платформы LLMOps**
+
+- Dify
+- Приложения Alibaba Cloud Bailian
+- Coze
+
+**Сервисы распознавания речи**
+
+- OpenAI Whisper
+- SenseVoice
+
+**Сервисы синтеза речи**
+
+- OpenAI TTS
+- Gemini TTS
+- GPT-Sovits-Inference
+- GPT-Sovits
+- FishAudio
+- Edge TTS
+- Alibaba Cloud Bailian TTS
+- Azure TTS
+- Minimax TTS
+- Volcano Engine TTS
+
+## ❤️ Вклад в проект
+
+Issues и Pull Request всегда приветствуются! Не стесняйтесь отправлять свои изменения в этот проект :)
+
+### Как внести вклад
+
+Вы можете внести вклад, просматривая issues или помогая с ревью pull request. Любые issues или PR приветствуются для поощрения участия сообщества. Конечно, это лишь предложения — вы можете вносить вклад любым удобным для вас способом. Для добавления новых функций сначала обсудите это через Issue.
+
+### Среда разработки
+
+AstrBot использует `ruff` для форматирования и линтинга кода.
+
+```bash
+git clone https://github.com/AstrBotDevs/AstrBot
+pip install pre-commit
+pre-commit install
+```
+
+## 🌍 Сообщество
+
+### Группы QQ
+
+- Группа 1: 322154837
+- Группа 3: 630166526
+- Группа 5: 822130018
+- Группа 6: 753075035
+- Группа разработчиков: 975206796
+
+### Группа Telegram
+
+
+
+### Сервер Discord
+
+
+
+## ❤️ Особая благодарность
+
+Особая благодарность всем контрибьюторам и разработчикам плагинов за их вклад в AstrBot ❤️
+
+
+
+
+
+Кроме того, рождение этого проекта было бы невозможно без помощи следующих проектов с открытым исходным кодом:
+
+- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - Замечательный кошачий фреймворк
+
+## ⭐ История звёзд
+
+> [!TIP]
+> Если этот проект помог вам в жизни или работе, или если вас интересует его будущее развитие, пожалуйста, поставьте проекту звезду. Это движущая сила поддержки этого проекта с открытым исходным кодом <3
+
+
+
+[](https://star-history.com/#astrbotdevs/astrbot&Date)
+
+
+
+
+
+_私は、高性能ですから!_
+
diff --git a/README_zh-TW.md b/README_zh-TW.md
new file mode 100644
index 000000000..5f77ab7ce
--- /dev/null
+++ b/README_zh-TW.md
@@ -0,0 +1,248 @@
+
+
+
+
+
+
+AstrBot 是一個開源的一站式 Agent 聊天機器人平台,可接入主流即時通訊軟體,為個人、開發者和團隊打造可靠、可擴展的對話式智慧基礎設施。無論是個人 AI 夥伴、智慧客服、自動化助手,還是企業知識庫,AstrBot 都能在您的即時通訊軟體平台的工作流程中快速構建生產可用的 AI 應用程式。
+
+
+
+## 主要功能
+
+1. 💯 免費 & 開源。
+2. ✨ AI 大型模型對話,多模態,Agent,MCP,知識庫,人格設定。
+3. 🤖 支援接入 Dify、阿里雲百煉、Coze 等智慧體平台。
+4. 🌐 多平台:QQ、企業微信、飛書、釘釘、微信公眾號、Telegram、Slack 以及[更多](#支援的訊息平台)。
+5. 📦 外掛擴充,已有近 800 個外掛可一鍵安裝。
+6. 💻 WebUI 支援。
+7. 🌐 國際化(i18n)支援。
+
+## 快速開始
+
+#### Docker 部署(推薦 🥳)
+
+推薦使用 Docker / Docker Compose 方式部署 AstrBot。
+
+請參閱官方文件 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot)。
+
+#### uv 部署
+
+```bash
+uvx astrbot
+```
+
+#### 寶塔面板部署
+
+AstrBot 與寶塔面板合作,已上架至寶塔面板。
+
+請參閱官方文件 [寶塔面板部署](https://astrbot.app/deploy/astrbot/btpanel.html)。
+
+#### 1Panel 部署
+
+AstrBot 已由 1Panel 官方上架至 1Panel 面板。
+
+請參閱官方文件 [1Panel 部署](https://astrbot.app/deploy/astrbot/1panel.html)。
+
+#### 在雨雲上部署
+
+AstrBot 已由雨雲官方上架至雲端應用程式平台,可一鍵部署。
+
+[](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
+
+#### 在 Replit 上部署
+
+社群貢獻的部署方式。
+
+[](https://repl.it/github/AstrBotDevs/AstrBot)
+
+#### Windows 一鍵安裝器部署
+
+請參閱官方文件 [使用 Windows 一鍵安裝器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html)。
+
+#### CasaOS 部署
+
+社群貢獻的部署方式。
+
+請參閱官方文件 [CasaOS 部署](https://astrbot.app/deploy/astrbot/casaos.html)。
+
+#### 手動部署
+
+首先安裝 uv:
+
+```bash
+pip install uv
+```
+
+透過 Git Clone 安裝 AstrBot:
+
+```bash
+git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
+uv run main.py
+```
+
+或者請參閱官方文件 [透過原始碼部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html)。
+
+## 支援的訊息平台
+
+**官方維護**
+
+- QQ(官方平台 & OneBot)
+- Telegram
+- 企微應用 & 企微智慧機器人
+- 微信客服 & 微信公眾號
+- 飛書
+- 釘釘
+- Slack
+- Discord
+- Satori
+- Misskey
+- Whatsapp(即將支援)
+- LINE(即將支援)
+
+**社群維護**
+
+- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
+- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
+- [Bilibili 私訊](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
+- [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11)
+
+## 支援的模型服務
+
+**大型模型服務**
+
+- OpenAI 及相容服務
+- Anthropic
+- Google Gemini
+- Moonshot AI
+- 智譜 AI
+- DeepSeek
+- Ollama(本機部署)
+- LM Studio(本機部署)
+- [優雲智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74)
+- [302.AI](https://share.302.ai/rr1M3l)
+- [小馬算力](https://www.tokenpony.cn/3YPyf)
+- [矽基流動](https://docs.siliconflow.cn/cn/usercases/use-siliconcloud-in-astrbot)
+- [PPIO 派歐雲](https://ppio.com/user/register?invited_by=AIOONE)
+- ModelScope
+- OneAPI
+
+**LLMOps 平台**
+
+- Dify
+- 阿里雲百煉應用
+- Coze
+
+**語音轉文字服務**
+
+- OpenAI Whisper
+- SenseVoice
+
+**文字轉語音服務**
+
+- OpenAI TTS
+- Gemini TTS
+- GPT-Sovits-Inference
+- GPT-Sovits
+- FishAudio
+- Edge TTS
+- 阿里雲百煉 TTS
+- Azure TTS
+- Minimax TTS
+- 火山引擎 TTS
+
+## ❤️ 貢獻
+
+歡迎任何 Issues/Pull Requests!只需要將您的變更提交到此專案 :)
+
+### 如何貢獻
+
+您可以透過檢視問題或協助審核 PR(拉取請求)來貢獻。任何問題或 PR 都歡迎參與,以促進社群貢獻。當然,這些只是建議,您可以以任何方式進行貢獻。對於新功能的新增,請先透過 Issue 討論。
+
+### 開發環境
+
+AstrBot 使用 `ruff` 進行程式碼格式化和檢查。
+
+```bash
+git clone https://github.com/AstrBotDevs/AstrBot
+pip install pre-commit
+pre-commit install
+```
+
+## 🌍 社群
+
+### QQ 群組
+
+- 1 群:322154837
+- 3 群:630166526
+- 5 群:822130018
+- 6 群:753075035
+- 開發者群:975206796
+
+### Telegram 群組
+
+
+
+### Discord 群組
+
+
+
+## ❤️ Special Thanks
+
+特別感謝所有 Contributors 和外掛開發者對 AstrBot 的貢獻 ❤️
+
+
+
+
+
+此外,本專案的誕生離不開以下開源專案的幫助:
+
+- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - 偉大的貓貓框架
+
+## ⭐ Star History
+
+> [!TIP]
+> 如果本專案對您的生活 / 工作產生了幫助,或者您關注本專案的未來發展,請給專案 Star,這是我們維護這個開源專案的動力 <3
+
+
+
+[](https://star-history.com/#astrbotdevs/astrbot&Date)
+
+
+
+
+
+_私は、高性能ですから!_
+
diff --git a/astrbot/api/all.py b/astrbot/api/all.py
index 2463dbc2b..df3e1170f 100644
--- a/astrbot/api/all.py
+++ b/astrbot/api/all.py
@@ -36,7 +36,8 @@ from astrbot.core.star.config import *
# provider
-from astrbot.core.provider import Provider, Personality, ProviderMetaData
+from astrbot.core.provider import Provider, ProviderMetaData
+from astrbot.core.db.po import Personality
# platform
from astrbot.core.platform import (
diff --git a/astrbot/api/provider/__init__.py b/astrbot/api/provider/__init__.py
index 2008c7bcf..f62b340f8 100644
--- a/astrbot/api/provider/__init__.py
+++ b/astrbot/api/provider/__init__.py
@@ -1,4 +1,5 @@
-from astrbot.core.provider import Personality, Provider, STTProvider
+from astrbot.core.db.po import Personality
+from astrbot.core.provider import Provider, STTProvider
from astrbot.core.provider.entities import (
LLMResponse,
ProviderMetaData,
diff --git a/astrbot/cli/__init__.py b/astrbot/cli/__init__.py
index 8d1eee0b1..7332367ed 100644
--- a/astrbot/cli/__init__.py
+++ b/astrbot/cli/__init__.py
@@ -1 +1 @@
-__version__ = "3.5.23"
+__version__ = "4.9.2"
diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py
index 05980b212..c5ff123b2 100644
--- a/astrbot/core/agent/mcp_client.py
+++ b/astrbot/core/agent/mcp_client.py
@@ -4,6 +4,14 @@ from contextlib import AsyncExitStack
from datetime import timedelta
from typing import Generic
+from tenacity import (
+ before_sleep_log,
+ retry,
+ retry_if_exception_type,
+ stop_after_attempt,
+ wait_exponential,
+)
+
from astrbot import logger
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.utils.log_pipe import LogPipe
@@ -12,21 +20,24 @@ from .run_context import TContext
from .tool import FunctionTool
try:
+ import anyio
import mcp
from mcp.client.sse import sse_client
except (ModuleNotFoundError, ImportError):
- logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
+ logger.warning(
+ "Warning: Missing 'mcp' dependency, MCP services will be unavailable."
+ )
try:
from mcp.client.streamable_http import streamablehttp_client
except (ModuleNotFoundError, ImportError):
logger.warning(
- "警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。",
+ "Warning: Missing 'mcp' dependency or MCP library version too old, Streamable HTTP connection unavailable.",
)
def _prepare_config(config: dict) -> dict:
- """准备配置,处理嵌套格式"""
+ """Prepare configuration, handle nested format"""
if config.get("mcpServers"):
first_key = next(iter(config["mcpServers"]))
config = config["mcpServers"][first_key]
@@ -35,7 +46,7 @@ def _prepare_config(config: dict) -> dict:
async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
- """快速测试 MCP 服务器可达性"""
+ """Quick test MCP server connectivity"""
import aiohttp
cfg = _prepare_config(config.copy())
@@ -50,7 +61,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
elif "type" in cfg:
transport_type = cfg["type"]
else:
- raise Exception("MCP 连接配置缺少 transport 或 type 字段")
+ raise Exception("MCP connection config missing transport or type field")
async with aiohttp.ClientSession() as session:
if transport_type == "streamable_http":
@@ -91,7 +102,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
return False, f"HTTP {response.status}: {response.reason}"
except asyncio.TimeoutError:
- return False, f"连接超时: {timeout}秒"
+ return False, f"Connection timeout: {timeout} seconds"
except Exception as e:
return False, f"{e!s}"
@@ -101,6 +112,7 @@ class MCPClient:
# Initialize session and client objects
self.session: mcp.ClientSession | None = None
self.exit_stack = AsyncExitStack()
+ self._old_exit_stacks: list[AsyncExitStack] = [] # Track old stacks for cleanup
self.name: str | None = None
self.active: bool = True
@@ -108,22 +120,32 @@ class MCPClient:
self.server_errlogs: list[str] = []
self.running_event = asyncio.Event()
- async def connect_to_server(self, mcp_server_config: dict, name: str):
- """连接到 MCP 服务器
+ # Store connection config for reconnection
+ self._mcp_server_config: dict | None = None
+ self._server_name: str | None = None
+ self._reconnect_lock = asyncio.Lock() # Lock for thread-safe reconnection
+ self._reconnecting: bool = False # For logging and debugging
- 如果 `url` 参数存在:
- 1. 当 transport 指定为 `streamable_http` 时,使用 Streamable HTTP 连接方式。
- 1. 当 transport 指定为 `sse` 时,使用 SSE 连接方式。
- 2. 如果没有指定,默认使用 SSE 的方式连接到 MCP 服务。
+ async def connect_to_server(self, mcp_server_config: dict, name: str):
+ """Connect to MCP server
+
+ If `url` parameter exists:
+ 1. When transport is specified as `streamable_http`, use Streamable HTTP connection.
+ 2. When transport is specified as `sse`, use SSE connection.
+ 3. If not specified, default to SSE connection to MCP service.
Args:
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
"""
+ # Store config for reconnection
+ self._mcp_server_config = mcp_server_config
+ self._server_name = name
+
cfg = _prepare_config(mcp_server_config.copy())
def logging_callback(msg: str):
- # 处理 MCP 服务的错误日志
+ # Handle MCP service error logs
print(f"MCP Server {name} Error: {msg}")
self.server_errlogs.append(msg)
@@ -137,7 +159,7 @@ class MCPClient:
elif "type" in cfg:
transport_type = cfg["type"]
else:
- raise Exception("MCP 连接配置缺少 transport 或 type 字段")
+ raise Exception("MCP connection config missing transport or type field")
if transport_type != "streamable_http":
# SSE transport method
@@ -193,7 +215,7 @@ class MCPClient:
)
def callback(msg: str):
- # 处理 MCP 服务的错误日志
+ # Handle MCP service error logs
self.server_errlogs.append(msg)
stdio_transport = await self.exit_stack.enter_async_context(
@@ -222,10 +244,120 @@ class MCPClient:
self.tools = response.tools
return response
+ async def _reconnect(self) -> None:
+ """Reconnect to the MCP server using the stored configuration.
+
+ Uses asyncio.Lock to ensure thread-safe reconnection in concurrent environments.
+
+ Raises:
+ Exception: raised when reconnection fails
+ """
+ async with self._reconnect_lock:
+ # Check if already reconnecting (useful for logging)
+ if self._reconnecting:
+ logger.debug(
+ f"MCP Client {self._server_name} is already reconnecting, skipping"
+ )
+ return
+
+ if not self._mcp_server_config or not self._server_name:
+ raise Exception("Cannot reconnect: missing connection configuration")
+
+ self._reconnecting = True
+ try:
+ logger.info(
+ f"Attempting to reconnect to MCP server {self._server_name}..."
+ )
+
+ # Save old exit_stack for later cleanup (don't close it now to avoid cancel scope issues)
+ if self.exit_stack:
+ self._old_exit_stacks.append(self.exit_stack)
+
+ # Mark old session as invalid
+ self.session = None
+
+ # Create new exit stack for new connection
+ self.exit_stack = AsyncExitStack()
+
+ # Reconnect using stored config
+ await self.connect_to_server(self._mcp_server_config, self._server_name)
+ await self.list_tools_and_save()
+
+ logger.info(
+ f"Successfully reconnected to MCP server {self._server_name}"
+ )
+ except Exception as e:
+ logger.error(
+ f"Failed to reconnect to MCP server {self._server_name}: {e}"
+ )
+ raise
+ finally:
+ self._reconnecting = False
+
+ async def call_tool_with_reconnect(
+ self,
+ tool_name: str,
+ arguments: dict,
+ read_timeout_seconds: timedelta,
+ ) -> mcp.types.CallToolResult:
+ """Call MCP tool with automatic reconnection on failure, max 2 retries.
+
+ Args:
+ tool_name: tool name
+ arguments: tool arguments
+ read_timeout_seconds: read timeout
+
+ Returns:
+ MCP tool call result
+
+ Raises:
+ ValueError: MCP session is not available
+ anyio.ClosedResourceError: raised after reconnection failure
+ """
+
+ @retry(
+ retry=retry_if_exception_type(anyio.ClosedResourceError),
+ stop=stop_after_attempt(2),
+ wait=wait_exponential(multiplier=1, min=1, max=3),
+ before_sleep=before_sleep_log(logger, logging.WARNING),
+ reraise=True,
+ )
+ async def _call_with_retry():
+ if not self.session:
+ raise ValueError("MCP session is not available for MCP function tools.")
+
+ try:
+ return await self.session.call_tool(
+ name=tool_name,
+ arguments=arguments,
+ read_timeout_seconds=read_timeout_seconds,
+ )
+ except anyio.ClosedResourceError:
+ logger.warning(
+ f"MCP tool {tool_name} call failed (ClosedResourceError), attempting to reconnect..."
+ )
+ # Attempt to reconnect
+ await self._reconnect()
+ # Reraise the exception to trigger tenacity retry
+ raise
+
+ return await _call_with_retry()
+
async def cleanup(self):
- """Clean up resources"""
- await self.exit_stack.aclose()
- self.running_event.set() # Set the running event to indicate cleanup is done
+ """Clean up resources including old exit stacks from reconnections"""
+ # Close current exit stack
+ try:
+ await self.exit_stack.aclose()
+ except Exception as e:
+ logger.debug(f"Error closing current exit stack: {e}")
+
+ # Don't close old exit stacks as they may be in different task contexts
+ # They will be garbage collected naturally
+ # Just clear the list to release references
+ self._old_exit_stacks.clear()
+
+ # Set running_event first to unblock any waiting tasks
+ self.running_event.set()
class MCPTool(FunctionTool, Generic[TContext]):
@@ -246,14 +378,8 @@ class MCPTool(FunctionTool, Generic[TContext]):
async def call(
self, context: ContextWrapper[TContext], **kwargs
) -> mcp.types.CallToolResult:
- session = self.mcp_client.session
- if not session:
- raise ValueError("MCP session is not available for MCP function tools.")
- res = await session.call_tool(
- name=self.mcp_tool.name,
+ return await self.mcp_client.call_tool_with_reconnect(
+ tool_name=self.mcp_tool.name,
arguments=kwargs,
- read_timeout_seconds=timedelta(
- seconds=context.tool_call_timeout,
- ),
+ read_timeout_seconds=timedelta(seconds=context.tool_call_timeout),
)
- return res
diff --git a/astrbot/core/agent/message.py b/astrbot/core/agent/message.py
index 11128c0f6..d69bc6a81 100644
--- a/astrbot/core/agent/message.py
+++ b/astrbot/core/agent/message.py
@@ -3,7 +3,7 @@
from typing import Any, ClassVar, Literal, cast
-from pydantic import BaseModel, GetCoreSchemaHandler
+from pydantic import BaseModel, GetCoreSchemaHandler, model_serializer, model_validator
from pydantic_core import core_schema
@@ -76,7 +76,7 @@ class ImageURLPart(ContentPart):
"""The ID of the image, to allow LLMs to distinguish different images."""
type: str = "image_url"
- image_url: str
+ image_url: ImageURL
class AudioURLPart(ContentPart):
@@ -119,6 +119,15 @@ class ToolCall(BaseModel):
"""The ID of the tool call."""
function: FunctionBody
"""The function body of the tool call."""
+ extra_content: dict[str, Any] | None = None
+ """Extra metadata for the tool call."""
+
+ @model_serializer(mode="wrap")
+ def serialize(self, handler):
+ data = handler(self)
+ if self.extra_content is None:
+ data.pop("extra_content", None)
+ return data
class ToolCallPart(BaseModel):
@@ -138,22 +147,39 @@ class Message(BaseModel):
"tool",
]
- content: str | list[ContentPart]
+ content: str | list[ContentPart] | None = None
"""The content of the message."""
+ tool_calls: list[ToolCall] | list[dict] | None = None
+ """The tool calls of the message."""
+
+ tool_call_id: str | None = None
+ """The ID of the tool call."""
+
+ @model_validator(mode="after")
+ def check_content_required(self):
+ # assistant + tool_calls is not None: allow content to be None
+ if self.role == "assistant" and self.tool_calls is not None:
+ return self
+
+ # other all cases: content is required
+ if self.content is None:
+ raise ValueError(
+ "content is required unless role='assistant' and tool_calls is not None"
+ )
+ return self
+
class AssistantMessageSegment(Message):
"""A message segment from the assistant."""
role: Literal["assistant"] = "assistant"
- tool_calls: list[ToolCall] | list[dict] | None = None
class ToolCallMessageSegment(Message):
"""A message segment representing a tool call."""
role: Literal["tool"] = "tool"
- tool_call_id: str
class UserMessageSegment(Message):
diff --git a/astrbot/core/agent/response.py b/astrbot/core/agent/response.py
index 3f3430c87..9e61fa8c7 100644
--- a/astrbot/core/agent/response.py
+++ b/astrbot/core/agent/response.py
@@ -1,7 +1,8 @@
import typing as T
-from dataclasses import dataclass
+from dataclasses import dataclass, field
from astrbot.core.message.message_event_result import MessageChain
+from astrbot.core.provider.entities import TokenUsage
class AgentResponseData(T.TypedDict):
@@ -12,3 +13,23 @@ class AgentResponseData(T.TypedDict):
class AgentResponse:
type: str
data: AgentResponseData
+
+
+@dataclass
+class AgentStats:
+ token_usage: TokenUsage = field(default_factory=TokenUsage)
+ start_time: float = 0.0
+ end_time: float = 0.0
+ time_to_first_token: float = 0.0
+
+ @property
+ def duration(self) -> float:
+ return self.end_time - self.start_time
+
+ def to_dict(self) -> dict:
+ return {
+ "token_usage": self.token_usage.__dict__,
+ "start_time": self.start_time,
+ "end_time": self.end_time,
+ "time_to_first_token": self.time_to_first_token,
+ }
diff --git a/astrbot/core/agent/run_context.py b/astrbot/core/agent/run_context.py
index 395817679..687ad22e5 100644
--- a/astrbot/core/agent/run_context.py
+++ b/astrbot/core/agent/run_context.py
@@ -1,8 +1,11 @@
-from dataclasses import dataclass
from typing import Any, Generic
+from pydantic import Field
+from pydantic.dataclasses import dataclass
from typing_extensions import TypeVar
+from .message import Message
+
TContext = TypeVar("TContext", default=Any)
@@ -11,6 +14,8 @@ class ContextWrapper(Generic[TContext]):
"""A context for running an agent, which can be used to pass additional data or state."""
context: TContext
+ messages: list[Message] = Field(default_factory=list)
+ """This field stores the llm message context for the agent run, agent runners will maintain this field automatically."""
tool_call_timeout: int = 60 # Default tool call timeout in seconds
diff --git a/astrbot/core/agent/runners/base.py b/astrbot/core/agent/runners/base.py
index c7cd36d96..21e796433 100644
--- a/astrbot/core/agent/runners/base.py
+++ b/astrbot/core/agent/runners/base.py
@@ -2,13 +2,12 @@ import abc
import typing as T
from enum import Enum, auto
-from astrbot.core.provider import Provider
+from astrbot import logger
from astrbot.core.provider.entities import LLMResponse
from ..hooks import BaseAgentRunHooks
from ..response import AgentResponse
from ..run_context import ContextWrapper, TContext
-from ..tool_executor import BaseFunctionToolExecutor
class AgentState(Enum):
@@ -24,9 +23,7 @@ class BaseAgentRunner(T.Generic[TContext]):
@abc.abstractmethod
async def reset(
self,
- provider: Provider,
run_context: ContextWrapper[TContext],
- tool_executor: BaseFunctionToolExecutor[TContext],
agent_hooks: BaseAgentRunHooks[TContext],
**kwargs: T.Any,
) -> None:
@@ -40,6 +37,13 @@ class BaseAgentRunner(T.Generic[TContext]):
"""Process a single step of the agent."""
...
+ @abc.abstractmethod
+ async def step_until_done(
+ self, max_step: int
+ ) -> T.AsyncGenerator[AgentResponse, None]:
+ """Process steps until the agent is done."""
+ ...
+
@abc.abstractmethod
def done(self) -> bool:
"""Check if the agent has completed its task.
@@ -53,3 +57,9 @@ class BaseAgentRunner(T.Generic[TContext]):
This method should be called after the agent is done.
"""
...
+
+ def _transition_state(self, new_state: AgentState) -> None:
+ """Transition the agent state."""
+ if self._state != new_state:
+ logger.debug(f"Agent state transition: {self._state} -> {new_state}")
+ self._state = new_state
diff --git a/astrbot/core/agent/runners/coze/coze_agent_runner.py b/astrbot/core/agent/runners/coze/coze_agent_runner.py
new file mode 100644
index 000000000..a8300bb71
--- /dev/null
+++ b/astrbot/core/agent/runners/coze/coze_agent_runner.py
@@ -0,0 +1,367 @@
+import base64
+import json
+import sys
+import typing as T
+
+import astrbot.core.message.components as Comp
+from astrbot import logger
+from astrbot.core import sp
+from astrbot.core.message.message_event_result import MessageChain
+from astrbot.core.provider.entities import (
+ LLMResponse,
+ ProviderRequest,
+)
+
+from ...hooks import BaseAgentRunHooks
+from ...response import AgentResponseData
+from ...run_context import ContextWrapper, TContext
+from ..base import AgentResponse, AgentState, BaseAgentRunner
+from .coze_api_client import CozeAPIClient
+
+if sys.version_info >= (3, 12):
+ from typing import override
+else:
+ from typing_extensions import override
+
+
+class CozeAgentRunner(BaseAgentRunner[TContext]):
+ """Coze Agent Runner"""
+
+ @override
+ async def reset(
+ self,
+ request: ProviderRequest,
+ run_context: ContextWrapper[TContext],
+ agent_hooks: BaseAgentRunHooks[TContext],
+ provider_config: dict,
+ **kwargs: T.Any,
+ ) -> None:
+ self.req = request
+ self.streaming = kwargs.get("streaming", False)
+ self.final_llm_resp = None
+ self._state = AgentState.IDLE
+ self.agent_hooks = agent_hooks
+ self.run_context = run_context
+
+ self.api_key = provider_config.get("coze_api_key", "")
+ if not self.api_key:
+ raise Exception("Coze API Key 不能为空。")
+ self.bot_id = provider_config.get("bot_id", "")
+ if not self.bot_id:
+ raise Exception("Coze Bot ID 不能为空。")
+ self.api_base: str = provider_config.get("coze_api_base", "https://api.coze.cn")
+
+ if not isinstance(self.api_base, str) or not self.api_base.startswith(
+ ("http://", "https://"),
+ ):
+ raise Exception(
+ "Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头。",
+ )
+
+ self.timeout = provider_config.get("timeout", 120)
+ if isinstance(self.timeout, str):
+ self.timeout = int(self.timeout)
+ self.auto_save_history = provider_config.get("auto_save_history", True)
+
+ # 创建 API 客户端
+ self.api_client = CozeAPIClient(api_key=self.api_key, api_base=self.api_base)
+
+ # 会话相关缓存
+ self.file_id_cache: dict[str, dict[str, str]] = {}
+
+ @override
+ async def step(self):
+ """
+ 执行 Coze Agent 的一个步骤
+ """
+ if not self.req:
+ raise ValueError("Request is not set. Please call reset() first.")
+
+ if self._state == AgentState.IDLE:
+ try:
+ await self.agent_hooks.on_agent_begin(self.run_context)
+ except Exception as e:
+ logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True)
+
+ # 开始处理,转换到运行状态
+ self._transition_state(AgentState.RUNNING)
+
+ try:
+ # 执行 Coze 请求并处理结果
+ async for response in self._execute_coze_request():
+ yield response
+ except Exception as e:
+ logger.error(f"Coze 请求失败:{str(e)}")
+ self._transition_state(AgentState.ERROR)
+ self.final_llm_resp = LLMResponse(
+ role="err", completion_text=f"Coze 请求失败:{str(e)}"
+ )
+ yield AgentResponse(
+ type="err",
+ data=AgentResponseData(
+ chain=MessageChain().message(f"Coze 请求失败:{str(e)}")
+ ),
+ )
+ finally:
+ await self.api_client.close()
+
+ @override
+ async def step_until_done(
+ self, max_step: int = 30
+ ) -> T.AsyncGenerator[AgentResponse, None]:
+ while not self.done():
+ async for resp in self.step():
+ yield resp
+
+ async def _execute_coze_request(self):
+ """执行 Coze 请求的核心逻辑"""
+ prompt = self.req.prompt or ""
+ session_id = self.req.session_id or "unknown"
+ image_urls = self.req.image_urls or []
+ contexts = self.req.contexts or []
+ system_prompt = self.req.system_prompt
+
+ # 用户ID参数
+ user_id = session_id
+
+ # 获取或创建会话ID
+ conversation_id = await sp.get_async(
+ scope="umo",
+ scope_id=user_id,
+ key="coze_conversation_id",
+ default="",
+ )
+
+ # 构建消息
+ additional_messages = []
+
+ if system_prompt:
+ if not self.auto_save_history or not conversation_id:
+ additional_messages.append(
+ {
+ "role": "system",
+ "content": system_prompt,
+ "content_type": "text",
+ },
+ )
+
+ # 处理历史上下文
+ if not self.auto_save_history and contexts:
+ for ctx in contexts:
+ if isinstance(ctx, dict) and "role" in ctx and "content" in ctx:
+ # 处理上下文中的图片
+ content = ctx["content"]
+ if isinstance(content, list):
+ # 多模态内容,需要处理图片
+ processed_content = []
+ for item in content:
+ if isinstance(item, dict):
+ if item.get("type") == "text":
+ processed_content.append(item)
+ elif item.get("type") == "image_url":
+ # 处理图片上传
+ try:
+ image_data = item.get("image_url", {})
+ url = image_data.get("url", "")
+ if url:
+ file_id = (
+ await self._download_and_upload_image(
+ url, session_id
+ )
+ )
+ processed_content.append(
+ {
+ "type": "file",
+ "file_id": file_id,
+ "file_url": url,
+ }
+ )
+ except Exception as e:
+ logger.warning(f"处理上下文图片失败: {e}")
+ continue
+
+ if processed_content:
+ additional_messages.append(
+ {
+ "role": ctx["role"],
+ "content": processed_content,
+ "content_type": "object_string",
+ }
+ )
+ else:
+ # 纯文本内容
+ additional_messages.append(
+ {
+ "role": ctx["role"],
+ "content": content,
+ "content_type": "text",
+ }
+ )
+
+ # 构建当前消息
+ if prompt or image_urls:
+ if image_urls:
+ # 多模态
+ object_string_content = []
+ if prompt:
+ object_string_content.append({"type": "text", "text": prompt})
+
+ for url in image_urls:
+ # the url is a base64 string
+ try:
+ image_data = base64.b64decode(url)
+ file_id = await self.api_client.upload_file(image_data)
+ object_string_content.append(
+ {
+ "type": "image",
+ "file_id": file_id,
+ }
+ )
+ except Exception as e:
+ logger.warning(f"处理图片失败 {url}: {e}")
+ continue
+
+ if object_string_content:
+ content = json.dumps(object_string_content, ensure_ascii=False)
+ additional_messages.append(
+ {
+ "role": "user",
+ "content": content,
+ "content_type": "object_string",
+ }
+ )
+ elif prompt:
+ # 纯文本
+ additional_messages.append(
+ {
+ "role": "user",
+ "content": prompt,
+ "content_type": "text",
+ },
+ )
+
+ # 执行 Coze API 请求
+ accumulated_content = ""
+ message_started = False
+
+ async for chunk in self.api_client.chat_messages(
+ bot_id=self.bot_id,
+ user_id=user_id,
+ additional_messages=additional_messages,
+ conversation_id=conversation_id,
+ auto_save_history=self.auto_save_history,
+ stream=True,
+ timeout=self.timeout,
+ ):
+ event_type = chunk.get("event")
+ data = chunk.get("data", {})
+
+ if event_type == "conversation.chat.created":
+ if isinstance(data, dict) and "conversation_id" in data:
+ await sp.put_async(
+ scope="umo",
+ scope_id=user_id,
+ key="coze_conversation_id",
+ value=data["conversation_id"],
+ )
+
+ if event_type == "conversation.message.delta":
+ # 增量消息
+ content = data.get("content", "")
+ if not content and "delta" in data:
+ content = data["delta"].get("content", "")
+ if not content and "text" in data:
+ content = data.get("text", "")
+
+ if content:
+ accumulated_content += content
+ message_started = True
+
+ # 如果是流式响应,发送增量数据
+ if self.streaming:
+ yield AgentResponse(
+ type="streaming_delta",
+ data=AgentResponseData(
+ chain=MessageChain().message(content)
+ ),
+ )
+
+ elif event_type == "conversation.message.completed":
+ # 消息完成
+ logger.debug("Coze message completed")
+ message_started = True
+
+ elif event_type == "conversation.chat.completed":
+ # 对话完成
+ logger.debug("Coze chat completed")
+ break
+
+ elif event_type == "error":
+ # 错误处理
+ error_msg = data.get("msg", "未知错误")
+ error_code = data.get("code", "UNKNOWN")
+ logger.error(f"Coze 出现错误: {error_code} - {error_msg}")
+ raise Exception(f"Coze 出现错误: {error_code} - {error_msg}")
+
+ if not message_started and not accumulated_content:
+ logger.warning("Coze 未返回任何内容")
+ accumulated_content = ""
+
+ # 创建最终响应
+ chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
+ self.final_llm_resp = LLMResponse(role="assistant", result_chain=chain)
+ self._transition_state(AgentState.DONE)
+
+ try:
+ await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp)
+ except Exception as e:
+ logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
+
+ # 返回最终结果
+ yield AgentResponse(
+ type="llm_result",
+ data=AgentResponseData(chain=chain),
+ )
+
+ async def _download_and_upload_image(
+ self,
+ image_url: str,
+ session_id: str | None = None,
+ ) -> str:
+ """下载图片并上传到 Coze,返回 file_id"""
+ import hashlib
+
+ # 计算哈希实现缓存
+ cache_key = hashlib.md5(image_url.encode("utf-8")).hexdigest()
+
+ if session_id:
+ if session_id not in self.file_id_cache:
+ self.file_id_cache[session_id] = {}
+
+ if cache_key in self.file_id_cache[session_id]:
+ file_id = self.file_id_cache[session_id][cache_key]
+ logger.debug(f"[Coze] 使用缓存的 file_id: {file_id}")
+ return file_id
+
+ try:
+ image_data = await self.api_client.download_image(image_url)
+ file_id = await self.api_client.upload_file(image_data)
+
+ if session_id:
+ self.file_id_cache[session_id][cache_key] = file_id
+ logger.debug(f"[Coze] 图片上传成功并缓存,file_id: {file_id}")
+
+ return file_id
+
+ except Exception as e:
+ logger.error(f"处理图片失败 {image_url}: {e!s}")
+ raise Exception(f"处理图片失败: {e!s}")
+
+ @override
+ def done(self) -> bool:
+ """检查 Agent 是否已完成工作"""
+ return self._state in (AgentState.DONE, AgentState.ERROR)
+
+ @override
+ def get_final_llm_resp(self) -> LLMResponse | None:
+ return self.final_llm_resp
diff --git a/astrbot/core/provider/sources/coze_api_client.py b/astrbot/core/agent/runners/coze/coze_api_client.py
similarity index 100%
rename from astrbot/core/provider/sources/coze_api_client.py
rename to astrbot/core/agent/runners/coze/coze_api_client.py
diff --git a/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py b/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py
new file mode 100644
index 000000000..7a095a60b
--- /dev/null
+++ b/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py
@@ -0,0 +1,403 @@
+import asyncio
+import functools
+import queue
+import re
+import sys
+import threading
+import typing as T
+
+from dashscope import Application
+from dashscope.app.application_response import ApplicationResponse
+
+import astrbot.core.message.components as Comp
+from astrbot.core import logger, sp
+from astrbot.core.message.message_event_result import MessageChain
+from astrbot.core.provider.entities import (
+ LLMResponse,
+ ProviderRequest,
+)
+
+from ...hooks import BaseAgentRunHooks
+from ...response import AgentResponseData
+from ...run_context import ContextWrapper, TContext
+from ..base import AgentResponse, AgentState, BaseAgentRunner
+
+if sys.version_info >= (3, 12):
+ from typing import override
+else:
+ from typing_extensions import override
+
+
+class DashscopeAgentRunner(BaseAgentRunner[TContext]):
+ """Dashscope Agent Runner"""
+
+ @override
+ async def reset(
+ self,
+ request: ProviderRequest,
+ run_context: ContextWrapper[TContext],
+ agent_hooks: BaseAgentRunHooks[TContext],
+ provider_config: dict,
+ **kwargs: T.Any,
+ ) -> None:
+ self.req = request
+ self.streaming = kwargs.get("streaming", False)
+ self.final_llm_resp = None
+ self._state = AgentState.IDLE
+ self.agent_hooks = agent_hooks
+ self.run_context = run_context
+
+ self.api_key = provider_config.get("dashscope_api_key", "")
+ if not self.api_key:
+ raise Exception("阿里云百炼 API Key 不能为空。")
+ self.app_id = provider_config.get("dashscope_app_id", "")
+ if not self.app_id:
+ raise Exception("阿里云百炼 APP ID 不能为空。")
+ self.dashscope_app_type = provider_config.get("dashscope_app_type", "")
+ if not self.dashscope_app_type:
+ raise Exception("阿里云百炼 APP 类型不能为空。")
+
+ self.variables: dict = provider_config.get("variables", {}) or {}
+ self.rag_options: dict = provider_config.get("rag_options", {})
+ self.output_reference = self.rag_options.get("output_reference", False)
+ self.rag_options = self.rag_options.copy()
+ self.rag_options.pop("output_reference", None)
+
+ self.timeout = provider_config.get("timeout", 120)
+ if isinstance(self.timeout, str):
+ self.timeout = int(self.timeout)
+
+ def has_rag_options(self):
+ """判断是否有 RAG 选项
+
+ Returns:
+ bool: 是否有 RAG 选项
+
+ """
+ if self.rag_options and (
+ len(self.rag_options.get("pipeline_ids", [])) > 0
+ or len(self.rag_options.get("file_ids", [])) > 0
+ ):
+ return True
+ return False
+
+ @override
+ async def step(self):
+ """
+ 执行 Dashscope Agent 的一个步骤
+ """
+ if not self.req:
+ raise ValueError("Request is not set. Please call reset() first.")
+
+ if self._state == AgentState.IDLE:
+ try:
+ await self.agent_hooks.on_agent_begin(self.run_context)
+ except Exception as e:
+ logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True)
+
+ # 开始处理,转换到运行状态
+ self._transition_state(AgentState.RUNNING)
+
+ try:
+ # 执行 Dashscope 请求并处理结果
+ async for response in self._execute_dashscope_request():
+ yield response
+ except Exception as e:
+ logger.error(f"阿里云百炼请求失败:{str(e)}")
+ self._transition_state(AgentState.ERROR)
+ self.final_llm_resp = LLMResponse(
+ role="err", completion_text=f"阿里云百炼请求失败:{str(e)}"
+ )
+ yield AgentResponse(
+ type="err",
+ data=AgentResponseData(
+ chain=MessageChain().message(f"阿里云百炼请求失败:{str(e)}")
+ ),
+ )
+
+ @override
+ async def step_until_done(
+ self, max_step: int = 30
+ ) -> T.AsyncGenerator[AgentResponse, None]:
+ while not self.done():
+ async for resp in self.step():
+ yield resp
+
+ def _consume_sync_generator(
+ self, response: T.Any, response_queue: queue.Queue
+ ) -> None:
+ """在线程中消费同步generator,将结果放入队列
+
+ Args:
+ response: 同步generator对象
+ response_queue: 用于传递数据的队列
+
+ """
+ try:
+ if self.streaming:
+ for chunk in response:
+ response_queue.put(("data", chunk))
+ else:
+ response_queue.put(("data", response))
+ except Exception as e:
+ response_queue.put(("error", e))
+ finally:
+ response_queue.put(("done", None))
+
+ async def _process_stream_chunk(
+ self, chunk: ApplicationResponse, output_text: str
+ ) -> tuple[str, list | None, AgentResponse | None]:
+ """处理流式响应的单个chunk
+
+ Args:
+ chunk: Dashscope响应chunk
+ output_text: 当前累积的输出文本
+
+ Returns:
+ (更新后的output_text, doc_references, AgentResponse或None)
+
+ """
+ logger.debug(f"dashscope stream chunk: {chunk}")
+
+ if chunk.status_code != 200:
+ logger.error(
+ f"阿里云百炼请求失败: request_id={chunk.request_id}, code={chunk.status_code}, message={chunk.message}, 请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code",
+ )
+ self._transition_state(AgentState.ERROR)
+ error_msg = (
+ f"阿里云百炼请求失败: message={chunk.message} code={chunk.status_code}"
+ )
+ self.final_llm_resp = LLMResponse(
+ role="err",
+ result_chain=MessageChain().message(error_msg),
+ )
+ return (
+ output_text,
+ None,
+ AgentResponse(
+ type="err",
+ data=AgentResponseData(chain=MessageChain().message(error_msg)),
+ ),
+ )
+
+ chunk_text = chunk.output.get("text", "") or ""
+ # RAG 引用脚标格式化
+ chunk_text = re.sub(r"
[\[(\d+)\]]", r"[\1]", chunk_text)
+
+ response = None
+ if chunk_text:
+ output_text += chunk_text
+ response = AgentResponse(
+ type="streaming_delta",
+ data=AgentResponseData(chain=MessageChain().message(chunk_text)),
+ )
+
+ # 获取文档引用
+ doc_references = chunk.output.get("doc_references", None)
+
+ return output_text, doc_references, response
+
+ def _format_doc_references(self, doc_references: list) -> str:
+ """格式化文档引用为文本
+
+ Args:
+ doc_references: 文档引用列表
+
+ Returns:
+ 格式化后的引用文本
+
+ """
+ ref_parts = []
+ for ref in doc_references:
+ ref_title = (
+ ref.get("title", "") if ref.get("title") else ref.get("doc_name", "")
+ )
+ ref_parts.append(f"{ref['index_id']}. {ref_title}\n")
+ ref_str = "".join(ref_parts)
+ return f"\n\n回答来源:\n{ref_str}"
+
+ async def _build_request_payload(
+ self, prompt: str, session_id: str, contexts: list, system_prompt: str
+ ) -> dict:
+ """构建请求payload
+
+ Args:
+ prompt: 用户输入
+ session_id: 会话ID
+ contexts: 上下文列表
+ system_prompt: 系统提示词
+
+ Returns:
+ 请求payload字典
+
+ """
+ conversation_id = await sp.get_async(
+ scope="umo",
+ scope_id=session_id,
+ key="dashscope_conversation_id",
+ default="",
+ )
+ # 获得会话变量
+ payload_vars = self.variables.copy()
+ session_var = await sp.get_async(
+ scope="umo",
+ scope_id=session_id,
+ key="session_variables",
+ default={},
+ )
+ payload_vars.update(session_var)
+
+ if (
+ self.dashscope_app_type in ["agent", "dialog-workflow"]
+ and not self.has_rag_options()
+ ):
+ # 支持多轮对话的
+ p = {
+ "app_id": self.app_id,
+ "api_key": self.api_key,
+ "prompt": prompt,
+ "biz_params": payload_vars or None,
+ "stream": self.streaming,
+ "incremental_output": True,
+ }
+ if conversation_id:
+ p["session_id"] = conversation_id
+ return p
+ else:
+ # 不支持多轮对话的
+ payload = {
+ "app_id": self.app_id,
+ "prompt": prompt,
+ "api_key": self.api_key,
+ "biz_params": payload_vars or None,
+ "stream": self.streaming,
+ "incremental_output": True,
+ }
+ if self.rag_options:
+ payload["rag_options"] = self.rag_options
+ return payload
+
+ async def _handle_streaming_response(
+ self, response: T.Any, session_id: str
+ ) -> T.AsyncGenerator[AgentResponse, None]:
+ """处理流式响应
+
+ Args:
+ response: Dashscope 流式响应 generator
+
+ Yields:
+ AgentResponse 对象
+
+ """
+ response_queue = queue.Queue()
+ consumer_thread = threading.Thread(
+ target=self._consume_sync_generator,
+ args=(response, response_queue),
+ daemon=True,
+ )
+ consumer_thread.start()
+
+ output_text = ""
+ doc_references = None
+
+ while True:
+ try:
+ item_type, item_data = await asyncio.get_event_loop().run_in_executor(
+ None, response_queue.get, True, 1
+ )
+ except queue.Empty:
+ continue
+
+ if item_type == "done":
+ break
+ elif item_type == "error":
+ raise item_data
+ elif item_type == "data":
+ chunk = item_data
+ assert isinstance(chunk, ApplicationResponse)
+
+ (
+ output_text,
+ chunk_doc_refs,
+ response,
+ ) = await self._process_stream_chunk(chunk, output_text)
+
+ if response:
+ if response.type == "err":
+ yield response
+ return
+ yield response
+
+ if chunk_doc_refs:
+ doc_references = chunk_doc_refs
+
+ if chunk.output.session_id:
+ await sp.put_async(
+ scope="umo",
+ scope_id=session_id,
+ key="dashscope_conversation_id",
+ value=chunk.output.session_id,
+ )
+
+ # 添加 RAG 引用
+ if self.output_reference and doc_references:
+ ref_text = self._format_doc_references(doc_references)
+ output_text += ref_text
+
+ if self.streaming:
+ yield AgentResponse(
+ type="streaming_delta",
+ data=AgentResponseData(chain=MessageChain().message(ref_text)),
+ )
+
+ # 创建最终响应
+ chain = MessageChain(chain=[Comp.Plain(output_text)])
+ self.final_llm_resp = LLMResponse(role="assistant", result_chain=chain)
+ self._transition_state(AgentState.DONE)
+
+ try:
+ await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp)
+ except Exception as e:
+ logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
+
+ # 返回最终结果
+ yield AgentResponse(
+ type="llm_result",
+ data=AgentResponseData(chain=chain),
+ )
+
+ async def _execute_dashscope_request(self):
+ """执行 Dashscope 请求的核心逻辑"""
+ prompt = self.req.prompt or ""
+ session_id = self.req.session_id or "unknown"
+ image_urls = self.req.image_urls or []
+ contexts = self.req.contexts or []
+ system_prompt = self.req.system_prompt
+
+ # 检查图片输入
+ if image_urls:
+ logger.warning("阿里云百炼暂不支持图片输入,将自动忽略图片内容。")
+
+ # 构建请求payload
+ payload = await self._build_request_payload(
+ prompt, session_id, contexts, system_prompt
+ )
+
+ if not self.streaming:
+ payload["incremental_output"] = False
+
+ # 发起请求
+ partial = functools.partial(Application.call, **payload)
+ response = await asyncio.get_event_loop().run_in_executor(None, partial)
+
+ async for resp in self._handle_streaming_response(response, session_id):
+ yield resp
+
+ @override
+ def done(self) -> bool:
+ """检查 Agent 是否已完成工作"""
+ return self._state in (AgentState.DONE, AgentState.ERROR)
+
+ @override
+ def get_final_llm_resp(self) -> LLMResponse | None:
+ return self.final_llm_resp
diff --git a/astrbot/core/agent/runners/dify/dify_agent_runner.py b/astrbot/core/agent/runners/dify/dify_agent_runner.py
new file mode 100644
index 000000000..d9a8b7cd6
--- /dev/null
+++ b/astrbot/core/agent/runners/dify/dify_agent_runner.py
@@ -0,0 +1,336 @@
+import base64
+import os
+import sys
+import typing as T
+
+import astrbot.core.message.components as Comp
+from astrbot.core import logger, sp
+from astrbot.core.message.message_event_result import MessageChain
+from astrbot.core.provider.entities import (
+ LLMResponse,
+ ProviderRequest,
+)
+from astrbot.core.utils.astrbot_path import get_astrbot_data_path
+from astrbot.core.utils.io import download_file
+
+from ...hooks import BaseAgentRunHooks
+from ...response import AgentResponseData
+from ...run_context import ContextWrapper, TContext
+from ..base import AgentResponse, AgentState, BaseAgentRunner
+from .dify_api_client import DifyAPIClient
+
+if sys.version_info >= (3, 12):
+ from typing import override
+else:
+ from typing_extensions import override
+
+
+class DifyAgentRunner(BaseAgentRunner[TContext]):
+ """Dify Agent Runner"""
+
+ @override
+ async def reset(
+ self,
+ request: ProviderRequest,
+ run_context: ContextWrapper[TContext],
+ agent_hooks: BaseAgentRunHooks[TContext],
+ provider_config: dict,
+ **kwargs: T.Any,
+ ) -> None:
+ self.req = request
+ self.streaming = kwargs.get("streaming", False)
+ self.final_llm_resp = None
+ self._state = AgentState.IDLE
+ self.agent_hooks = agent_hooks
+ self.run_context = run_context
+
+ self.api_key = provider_config.get("dify_api_key", "")
+ self.api_base = provider_config.get("dify_api_base", "https://api.dify.ai/v1")
+ self.api_type = provider_config.get("dify_api_type", "chat")
+ self.workflow_output_key = provider_config.get(
+ "dify_workflow_output_key",
+ "astrbot_wf_output",
+ )
+ self.dify_query_input_key = provider_config.get(
+ "dify_query_input_key",
+ "astrbot_text_query",
+ )
+ self.variables: dict = provider_config.get("variables", {}) or {}
+ self.timeout = provider_config.get("timeout", 60)
+ if isinstance(self.timeout, str):
+ self.timeout = int(self.timeout)
+
+ self.api_client = DifyAPIClient(self.api_key, self.api_base)
+
+ @override
+ async def step(self):
+ """
+ 执行 Dify Agent 的一个步骤
+ """
+ if not self.req:
+ raise ValueError("Request is not set. Please call reset() first.")
+
+ if self._state == AgentState.IDLE:
+ try:
+ await self.agent_hooks.on_agent_begin(self.run_context)
+ except Exception as e:
+ logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True)
+
+ # 开始处理,转换到运行状态
+ self._transition_state(AgentState.RUNNING)
+
+ try:
+ # 执行 Dify 请求并处理结果
+ async for response in self._execute_dify_request():
+ yield response
+ except Exception as e:
+ logger.error(f"Dify 请求失败:{str(e)}")
+ self._transition_state(AgentState.ERROR)
+ self.final_llm_resp = LLMResponse(
+ role="err", completion_text=f"Dify 请求失败:{str(e)}"
+ )
+ yield AgentResponse(
+ type="err",
+ data=AgentResponseData(
+ chain=MessageChain().message(f"Dify 请求失败:{str(e)}")
+ ),
+ )
+ finally:
+ await self.api_client.close()
+
+ @override
+ async def step_until_done(
+ self, max_step: int = 30
+ ) -> T.AsyncGenerator[AgentResponse, None]:
+ while not self.done():
+ async for resp in self.step():
+ yield resp
+
+ async def _execute_dify_request(self):
+ """执行 Dify 请求的核心逻辑"""
+ prompt = self.req.prompt or ""
+ session_id = self.req.session_id or "unknown"
+ image_urls = self.req.image_urls or []
+ system_prompt = self.req.system_prompt
+
+ conversation_id = await sp.get_async(
+ scope="umo",
+ scope_id=session_id,
+ key="dify_conversation_id",
+ default="",
+ )
+ result = ""
+
+ # 处理图片上传
+ files_payload = []
+ for image_url in image_urls:
+ # image_url is a base64 string
+ try:
+ image_data = base64.b64decode(image_url)
+ file_response = await self.api_client.file_upload(
+ file_data=image_data,
+ user=session_id,
+ mime_type="image/png",
+ file_name="image.png",
+ )
+ logger.debug(f"Dify 上传图片响应:{file_response}")
+ if "id" not in file_response:
+ logger.warning(
+ f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。"
+ )
+ continue
+ files_payload.append(
+ {
+ "type": "image",
+ "transfer_method": "local_file",
+ "upload_file_id": file_response["id"],
+ }
+ )
+ except Exception as e:
+ logger.warning(f"上传图片失败:{e}")
+ continue
+
+ # 获得会话变量
+ payload_vars = self.variables.copy()
+ # 动态变量
+ session_var = await sp.get_async(
+ scope="umo",
+ scope_id=session_id,
+ key="session_variables",
+ default={},
+ )
+ payload_vars.update(session_var)
+ payload_vars["system_prompt"] = system_prompt
+
+ # 处理不同的 API 类型
+ match self.api_type:
+ case "chat" | "agent" | "chatflow":
+ if not prompt:
+ prompt = "请描述这张图片。"
+
+ async for chunk in self.api_client.chat_messages(
+ inputs={
+ **payload_vars,
+ },
+ query=prompt,
+ user=session_id,
+ conversation_id=conversation_id,
+ files=files_payload,
+ timeout=self.timeout,
+ ):
+ logger.debug(f"dify resp chunk: {chunk}")
+ if chunk["event"] == "message" or chunk["event"] == "agent_message":
+ result += chunk["answer"]
+ if not conversation_id:
+ await sp.put_async(
+ scope="umo",
+ scope_id=session_id,
+ key="dify_conversation_id",
+ value=chunk["conversation_id"],
+ )
+ conversation_id = chunk["conversation_id"]
+
+ # 如果是流式响应,发送增量数据
+ if self.streaming and chunk["answer"]:
+ yield AgentResponse(
+ type="streaming_delta",
+ data=AgentResponseData(
+ chain=MessageChain().message(chunk["answer"])
+ ),
+ )
+ elif chunk["event"] == "message_end":
+ logger.debug("Dify message end")
+ break
+ elif chunk["event"] == "error":
+ logger.error(f"Dify 出现错误:{chunk}")
+ raise Exception(
+ f"Dify 出现错误 status: {chunk['status']} message: {chunk['message']}"
+ )
+
+ case "workflow":
+ async for chunk in self.api_client.workflow_run(
+ inputs={
+ self.dify_query_input_key: prompt,
+ "astrbot_session_id": session_id,
+ **payload_vars,
+ },
+ user=session_id,
+ files=files_payload,
+ timeout=self.timeout,
+ ):
+ logger.debug(f"dify workflow resp chunk: {chunk}")
+ match chunk["event"]:
+ case "workflow_started":
+ logger.info(
+ f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行。"
+ )
+ case "node_finished":
+ logger.debug(
+ f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束。"
+ )
+ case "text_chunk":
+ if self.streaming and chunk["data"]["text"]:
+ yield AgentResponse(
+ type="streaming_delta",
+ data=AgentResponseData(
+ chain=MessageChain().message(
+ chunk["data"]["text"]
+ )
+ ),
+ )
+ case "workflow_finished":
+ logger.info(
+ f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束"
+ )
+ logger.debug(f"Dify 工作流结果:{chunk}")
+ if chunk["data"]["error"]:
+ logger.error(
+ f"Dify 工作流出现错误:{chunk['data']['error']}"
+ )
+ raise Exception(
+ f"Dify 工作流出现错误:{chunk['data']['error']}"
+ )
+ if self.workflow_output_key not in chunk["data"]["outputs"]:
+ raise Exception(
+ f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}"
+ )
+ result = chunk
+ case _:
+ raise Exception(f"未知的 Dify API 类型:{self.api_type}")
+
+ if not result:
+ logger.warning("Dify 请求结果为空,请查看 Debug 日志。")
+
+ # 解析结果
+ chain = await self.parse_dify_result(result)
+
+ # 创建最终响应
+ self.final_llm_resp = LLMResponse(role="assistant", result_chain=chain)
+ self._transition_state(AgentState.DONE)
+
+ try:
+ await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp)
+ except Exception as e:
+ logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
+
+ # 返回最终结果
+ yield AgentResponse(
+ type="llm_result",
+ data=AgentResponseData(chain=chain),
+ )
+
+ async def parse_dify_result(self, chunk: dict | str) -> MessageChain:
+ """解析 Dify 的响应结果"""
+ if isinstance(chunk, str):
+ # Chat
+ return MessageChain(chain=[Comp.Plain(chunk)])
+
+ async def parse_file(item: dict):
+ match item["type"]:
+ case "image":
+ return Comp.Image(file=item["url"], url=item["url"])
+ case "audio":
+ # 仅支持 wav
+ temp_dir = os.path.join(get_astrbot_data_path(), "temp")
+ path = os.path.join(temp_dir, f"{item['filename']}.wav")
+ await download_file(item["url"], path)
+ return Comp.Image(file=item["url"], url=item["url"])
+ case "video":
+ return Comp.Video(file=item["url"])
+ case _:
+ return Comp.File(name=item["filename"], file=item["url"])
+
+ output = chunk["data"]["outputs"][self.workflow_output_key]
+ chains = []
+ if isinstance(output, str):
+ # 纯文本输出
+ chains.append(Comp.Plain(output))
+ elif isinstance(output, list):
+ # 主要适配 Dify 的 HTTP 请求结点的多模态输出
+ for item in output:
+ # handle Array[File]
+ if (
+ not isinstance(item, dict)
+ or item.get("dify_model_identity", "") != "__dify__file__"
+ ):
+ chains.append(Comp.Plain(str(output)))
+ break
+ else:
+ chains.append(Comp.Plain(str(output)))
+
+ # scan file
+ files = chunk["data"].get("files", [])
+ for item in files:
+ comp = await parse_file(item)
+ chains.append(comp)
+
+ return MessageChain(chain=chains)
+
+ @override
+ def done(self) -> bool:
+ """检查 Agent 是否已完成工作"""
+ return self._state in (AgentState.DONE, AgentState.ERROR)
+
+ @override
+ def get_final_llm_resp(self) -> LLMResponse | None:
+ return self.final_llm_resp
diff --git a/astrbot/core/utils/dify_api_client.py b/astrbot/core/agent/runners/dify/dify_api_client.py
similarity index 71%
rename from astrbot/core/utils/dify_api_client.py
rename to astrbot/core/agent/runners/dify/dify_api_client.py
index ea8ff9dff..d9c6556cf 100644
--- a/astrbot/core/utils/dify_api_client.py
+++ b/astrbot/core/agent/runners/dify/dify_api_client.py
@@ -3,7 +3,7 @@ import json
from collections.abc import AsyncGenerator
from typing import Any
-from aiohttp import ClientResponse, ClientSession
+from aiohttp import ClientResponse, ClientSession, FormData
from astrbot.core import logger
@@ -101,21 +101,59 @@ class DifyAPIClient:
async def file_upload(
self,
- file_path: str,
user: str,
+ file_path: str | None = None,
+ file_data: bytes | None = None,
+ file_name: str | None = None,
+ mime_type: str | None = None,
) -> dict[str, Any]:
+ """Upload a file to Dify. Must provide either file_path or file_data.
+
+ Args:
+ user: The user ID.
+ file_path: The path to the file to upload.
+ file_data: The file data in bytes.
+ file_name: Optional file name when using file_data.
+ Returns:
+ A dictionary containing the uploaded file information.
+ """
url = f"{self.api_base}/files/upload"
- with open(file_path, "rb") as f:
- payload = {
- "user": user,
- "file": f,
- }
- async with self.session.post(
- url,
- data=payload,
- headers=self.headers,
- ) as resp:
- return await resp.json() # {"id": "xxx", ...}
+
+ form = FormData()
+ form.add_field("user", user)
+
+ if file_data is not None:
+ # 使用 bytes 数据
+ form.add_field(
+ "file",
+ file_data,
+ filename=file_name or "uploaded_file",
+ content_type=mime_type or "application/octet-stream",
+ )
+ elif file_path is not None:
+ # 使用文件路径
+ import os
+
+ with open(file_path, "rb") as f:
+ file_content = f.read()
+ form.add_field(
+ "file",
+ file_content,
+ filename=os.path.basename(file_path),
+ content_type=mime_type or "application/octet-stream",
+ )
+ else:
+ raise ValueError("file_path 和 file_data 不能同时为 None")
+
+ async with self.session.post(
+ url,
+ data=form,
+ headers=self.headers, # 不包含 Content-Type,让 aiohttp 自动设置
+ ) as resp:
+ if resp.status != 200 and resp.status != 201:
+ text = await resp.text()
+ raise Exception(f"Dify 文件上传失败:{resp.status}. {text}")
+ return await resp.json() # {"id": "xxx", ...}
async def close(self):
await self.session.close()
diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py
index 23071d446..069de144f 100644
--- a/astrbot/core/agent/runners/tool_loop_agent_runner.py
+++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py
@@ -1,4 +1,5 @@
import sys
+import time
import traceback
import typing as T
@@ -12,6 +13,7 @@ from mcp.types import (
)
from astrbot import logger
+from astrbot.core.message.components import Json
from astrbot.core.message.message_event_result import (
MessageChain,
)
@@ -23,8 +25,8 @@ from astrbot.core.provider.entities import (
from astrbot.core.provider.provider import Provider
from ..hooks import BaseAgentRunHooks
-from ..message import AssistantMessageSegment, ToolCallMessageSegment
-from ..response import AgentResponseData
+from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
+from ..response import AgentResponseData, AgentStats
from ..run_context import ContextWrapper, TContext
from ..tool_executor import BaseFunctionToolExecutor
from .base import AgentResponse, AgentState, BaseAgentRunner
@@ -55,11 +57,22 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
self.agent_hooks = agent_hooks
self.run_context = run_context
- def _transition_state(self, new_state: AgentState) -> None:
- """转换 Agent 状态"""
- if self._state != new_state:
- logger.debug(f"Agent state transition: {self._state} -> {new_state}")
- self._state = new_state
+ messages = []
+ # append existing messages in the run context
+ for msg in request.contexts:
+ messages.append(Message.model_validate(msg))
+ if request.prompt is not None:
+ m = await request.assemble_context()
+ messages.append(Message.model_validate(m))
+ if request.system_prompt:
+ messages.insert(
+ 0,
+ Message(role="system", content=request.system_prompt),
+ )
+ self.run_context.messages = messages
+
+ self.stats = AgentStats()
+ self.stats.start_time = time.time()
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
"""Yields chunks *and* a final LLMResponse."""
@@ -89,22 +102,38 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
llm_resp_result = None
async for llm_response in self._iter_llm_responses():
- assert isinstance(llm_response, LLMResponse)
if llm_response.is_chunk:
+ # update ttft
+ if self.stats.time_to_first_token == 0:
+ self.stats.time_to_first_token = time.time() - self.stats.start_time
+
if llm_response.result_chain:
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(chain=llm_response.result_chain),
)
- else:
+ elif llm_response.completion_text:
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(
chain=MessageChain().message(llm_response.completion_text),
),
)
+ elif llm_response.reasoning_content:
+ yield AgentResponse(
+ type="streaming_delta",
+ data=AgentResponseData(
+ chain=MessageChain(type="reasoning").message(
+ llm_response.reasoning_content,
+ ),
+ ),
+ )
continue
llm_resp_result = llm_response
+
+ if not llm_response.is_chunk and llm_response.usage:
+ # only count the token usage of the final response for computation purpose
+ self.stats.token_usage += llm_response.usage
break # got final response
if not llm_resp_result:
@@ -116,6 +145,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
if llm_resp.role == "err":
# 如果 LLM 响应错误,转换到错误状态
self.final_llm_resp = llm_resp
+ self.stats.end_time = time.time()
self._transition_state(AgentState.ERROR)
yield AgentResponse(
type="err",
@@ -130,6 +160,14 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
# 如果没有工具调用,转换到完成状态
self.final_llm_resp = llm_resp
self._transition_state(AgentState.DONE)
+ self.stats.end_time = time.time()
+ # record the final assistant message
+ self.run_context.messages.append(
+ Message(
+ role="assistant",
+ content=llm_resp.completion_text or "",
+ ),
+ )
try:
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
except Exception as e:
@@ -152,19 +190,19 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
# 如果有工具调用,还需处理工具调用
if llm_resp.tools_call_name:
tool_call_result_blocks = []
- for tool_call_name in llm_resp.tools_call_name:
- yield AgentResponse(
- type="tool_call",
- data=AgentResponseData(
- chain=MessageChain().message(f"🔨 调用工具: {tool_call_name}"),
- ),
- )
async for result in self._handle_function_tools(self.req, llm_resp):
if isinstance(result, list):
tool_call_result_blocks = result
elif isinstance(result, MessageChain):
+ if result.type is None:
+ # should not happen
+ continue
+ if result.type == "tool_direct_result":
+ ar_type = "tool_call_result"
+ else:
+ ar_type = result.type
yield AgentResponse(
- type="tool_call_result",
+ type=ar_type,
data=AgentResponseData(chain=result),
)
# 将结果添加到上下文中
@@ -175,8 +213,23 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
),
tool_calls_result=tool_call_result_blocks,
)
+ # record the assistant message with tool calls
+ self.run_context.messages.extend(
+ tool_calls_result.to_openai_messages_model()
+ )
+
self.req.append_tool_calls_result(tool_calls_result)
+ async def step_until_done(
+ self, max_step: int
+ ) -> T.AsyncGenerator[AgentResponse, None]:
+ """Process steps until the agent is done."""
+ step_count = 0
+ while not self.done() and step_count < max_step:
+ step_count += 1
+ async for resp in self.step():
+ yield resp
+
async def _handle_function_tools(
self,
req: ProviderRequest,
@@ -192,6 +245,19 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
llm_response.tools_call_args,
llm_response.tools_call_ids,
):
+ yield MessageChain(
+ type="tool_call",
+ chain=[
+ Json(
+ data={
+ "id": func_tool_id,
+ "name": func_tool_name,
+ "args": func_tool_args,
+ "ts": time.time(),
+ }
+ )
+ ],
+ )
try:
if not req.func_tool:
return
@@ -265,7 +331,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
content=res.content[0].text,
),
)
- yield MessageChain().message(res.content[0].text)
elif isinstance(res.content[0], ImageContent):
tool_call_result_blocks.append(
ToolCallMessageSegment(
@@ -287,7 +352,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
content=resource.text,
),
)
- yield MessageChain().message(resource.text)
elif (
isinstance(resource, BlobResourceContents)
and resource.mimeType
@@ -311,7 +375,22 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
content="返回的数据类型不受支持",
),
)
- yield MessageChain().message("返回的数据类型不受支持。")
+
+ # yield the last tool call result
+ if tool_call_result_blocks:
+ last_tcr_content = str(tool_call_result_blocks[-1].content)
+ yield MessageChain(
+ type="tool_call_result",
+ chain=[
+ Json(
+ data={
+ "id": func_tool_id,
+ "ts": time.time(),
+ "result": last_tcr_content,
+ }
+ )
+ ],
+ )
elif resp is None:
# Tool 直接请求发送消息给用户
@@ -321,6 +400,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
f"{func_tool_name} 没有没有返回值或者将结果直接发送给用户,此工具调用不会被记录到历史中。"
)
self._transition_state(AgentState.DONE)
+ self.stats.end_time = time.time()
else:
# 不应该出现其他类型
logger.warning(
diff --git a/astrbot/core/agent/tool.py b/astrbot/core/agent/tool.py
index ae240d2e0..7f30f44ef 100644
--- a/astrbot/core/agent/tool.py
+++ b/astrbot/core/agent/tool.py
@@ -1,15 +1,18 @@
-from collections.abc import Awaitable, Callable
+from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any, Generic
import jsonschema
import mcp
from deprecated import deprecated
-from pydantic import model_validator
+from pydantic import Field, model_validator
from pydantic.dataclasses import dataclass
+from astrbot.core.message.message_event_result import MessageEventResult
+
from .run_context import ContextWrapper, TContext
ParametersType = dict[str, Any]
+ToolExecResult = str | mcp.types.CallToolResult
@dataclass
@@ -37,7 +40,10 @@ class ToolSchema:
class FunctionTool(ToolSchema, Generic[TContext]):
"""A callable tool, for function calling."""
- handler: Callable[..., Awaitable[Any]] | None = None
+ handler: (
+ Callable[..., Awaitable[str | None] | AsyncGenerator[MessageEventResult, None]]
+ | None
+ ) = None
"""a callable that implements the tool's functionality. It should be an async function."""
handler_module_path: str | None = None
@@ -55,15 +61,14 @@ class FunctionTool(ToolSchema, Generic[TContext]):
def __repr__(self):
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description})"
- async def call(
- self, context: ContextWrapper[TContext], **kwargs
- ) -> str | mcp.types.CallToolResult:
+ async def call(self, context: ContextWrapper[TContext], **kwargs) -> ToolExecResult:
"""Run the tool with the given arguments. The handler field has priority."""
raise NotImplementedError(
"FunctionTool.call() must be implemented by subclasses or set a handler."
)
+@dataclass
class ToolSet:
"""A set of function tools that can be used in function calling.
@@ -71,8 +76,7 @@ class ToolSet:
convert the tools to different API formats (OpenAI, Anthropic, Google GenAI).
"""
- def __init__(self, tools: list[FunctionTool] | None = None):
- self.tools: list[FunctionTool] = tools or []
+ tools: list[FunctionTool] = Field(default_factory=list)
def empty(self) -> bool:
"""Check if the tool set is empty."""
diff --git a/astrbot/core/astr_agent_context.py b/astrbot/core/astr_agent_context.py
index 28b242253..9c6451cc7 100644
--- a/astrbot/core/astr_agent_context.py
+++ b/astrbot/core/astr_agent_context.py
@@ -1,14 +1,21 @@
-from dataclasses import dataclass
+from pydantic import Field
+from pydantic.dataclasses import dataclass
+from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.platform.astr_message_event import AstrMessageEvent
-from astrbot.core.provider import Provider
-from astrbot.core.provider.entities import ProviderRequest
+from astrbot.core.star.context import Context
@dataclass
class AstrAgentContext:
- provider: Provider
- first_provider_request: ProviderRequest
- curr_provider_request: ProviderRequest
- streaming: bool
+ __pydantic_config__ = {"arbitrary_types_allowed": True}
+
+ context: Context
+ """The star context instance"""
event: AstrMessageEvent
+ """The message event associated with the agent context."""
+ extra: dict[str, str] = Field(default_factory=dict)
+ """Customized extra data."""
+
+
+AgentContextWrapper = ContextWrapper[AstrAgentContext]
diff --git a/astrbot/core/astr_agent_hooks.py b/astrbot/core/astr_agent_hooks.py
new file mode 100644
index 000000000..f394fc947
--- /dev/null
+++ b/astrbot/core/astr_agent_hooks.py
@@ -0,0 +1,36 @@
+from typing import Any
+
+from mcp.types import CallToolResult
+
+from astrbot.core.agent.hooks import BaseAgentRunHooks
+from astrbot.core.agent.run_context import ContextWrapper
+from astrbot.core.agent.tool import FunctionTool
+from astrbot.core.astr_agent_context import AstrAgentContext
+from astrbot.core.pipeline.context_utils import call_event_hook
+from astrbot.core.star.star_handler import EventType
+
+
+class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
+ async def on_agent_done(self, run_context, llm_response):
+ # 执行事件钩子
+ await call_event_hook(
+ run_context.context.event,
+ EventType.OnLLMResponseEvent,
+ llm_response,
+ )
+
+ async def on_tool_end(
+ self,
+ run_context: ContextWrapper[AstrAgentContext],
+ tool: FunctionTool[Any],
+ tool_args: dict | None,
+ tool_result: CallToolResult | None,
+ ):
+ run_context.context.event.clear_result()
+
+
+class EmptyAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
+ pass
+
+
+MAIN_AGENT_HOOKS = MainAgentHooks()
diff --git a/astrbot/core/astr_agent_run_util.py b/astrbot/core/astr_agent_run_util.py
new file mode 100644
index 000000000..5421a14c0
--- /dev/null
+++ b/astrbot/core/astr_agent_run_util.py
@@ -0,0 +1,115 @@
+import traceback
+from collections.abc import AsyncGenerator
+
+from astrbot.core import logger
+from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
+from astrbot.core.astr_agent_context import AstrAgentContext
+from astrbot.core.message.components import Json
+from astrbot.core.message.message_event_result import (
+ MessageChain,
+ MessageEventResult,
+ ResultContentType,
+)
+from astrbot.core.provider.entities import LLMResponse
+
+AgentRunner = ToolLoopAgentRunner[AstrAgentContext]
+
+
+async def run_agent(
+ agent_runner: AgentRunner,
+ max_step: int = 30,
+ show_tool_use: bool = True,
+ stream_to_general: bool = False,
+ show_reasoning: bool = False,
+) -> AsyncGenerator[MessageChain | None, None]:
+ step_idx = 0
+ astr_event = agent_runner.run_context.context.event
+ while step_idx < max_step:
+ step_idx += 1
+ try:
+ async for resp in agent_runner.step():
+ if astr_event.is_stopped():
+ return
+ if resp.type == "tool_call_result":
+ msg_chain = resp.data["chain"]
+ if msg_chain.type == "tool_direct_result":
+ # tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
+ await astr_event.send(msg_chain)
+ continue
+ if astr_event.get_platform_id() == "webchat":
+ await astr_event.send(msg_chain)
+ # 对于其他情况,暂时先不处理
+ continue
+ elif resp.type == "tool_call":
+ if agent_runner.streaming:
+ # 用来标记流式响应需要分节
+ yield MessageChain(chain=[], type="break")
+
+ if astr_event.get_platform_name() == "webchat":
+ await astr_event.send(resp.data["chain"])
+ elif show_tool_use:
+ json_comp = resp.data["chain"].chain[0]
+ if isinstance(json_comp, Json):
+ m = f"🔨 调用工具: {json_comp.data.get('name')}"
+ else:
+ m = "🔨 调用工具..."
+ chain = MessageChain(type="tool_call").message(m)
+ await astr_event.send(chain)
+ continue
+
+ if stream_to_general and resp.type == "streaming_delta":
+ continue
+
+ if stream_to_general or not agent_runner.streaming:
+ content_typ = (
+ ResultContentType.LLM_RESULT
+ if resp.type == "llm_result"
+ else ResultContentType.GENERAL_RESULT
+ )
+ astr_event.set_result(
+ MessageEventResult(
+ chain=resp.data["chain"].chain,
+ result_content_type=content_typ,
+ ),
+ )
+ yield
+ astr_event.clear_result()
+ elif resp.type == "streaming_delta":
+ chain = resp.data["chain"]
+ if chain.type == "reasoning" and not show_reasoning:
+ # display the reasoning content only when configured
+ continue
+ yield resp.data["chain"] # MessageChain
+ if agent_runner.done():
+ # send agent stats to webchat
+ if astr_event.get_platform_name() == "webchat":
+ await astr_event.send(
+ MessageChain(
+ type="agent_stats",
+ chain=[Json(data=agent_runner.stats.to_dict())],
+ )
+ )
+
+ break
+
+ except Exception as e:
+ logger.error(traceback.format_exc())
+
+ err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在平台日志查看和分享错误详情。\n"
+
+ error_llm_response = LLMResponse(
+ role="err",
+ completion_text=err_msg,
+ )
+ try:
+ await agent_runner.agent_hooks.on_agent_done(
+ agent_runner.run_context, error_llm_response
+ )
+ except Exception:
+ logger.exception("Error in on_agent_done hook")
+
+ if agent_runner.streaming:
+ yield MessageChain().message(err_msg)
+ else:
+ astr_event.set_result(MessageEventResult().message(err_msg))
+ return
diff --git a/astrbot/core/astr_agent_tool_exec.py b/astrbot/core/astr_agent_tool_exec.py
new file mode 100644
index 000000000..ed08e90a9
--- /dev/null
+++ b/astrbot/core/astr_agent_tool_exec.py
@@ -0,0 +1,250 @@
+import asyncio
+import inspect
+import traceback
+import typing as T
+
+import mcp
+
+from astrbot import logger
+from astrbot.core.agent.handoff import HandoffTool
+from astrbot.core.agent.mcp_client import MCPTool
+from astrbot.core.agent.run_context import ContextWrapper
+from astrbot.core.agent.tool import FunctionTool, ToolSet
+from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
+from astrbot.core.astr_agent_context import AstrAgentContext
+from astrbot.core.message.message_event_result import (
+ CommandResult,
+ MessageChain,
+ MessageEventResult,
+)
+from astrbot.core.provider.register import llm_tools
+
+
+class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
+ @classmethod
+ async def execute(cls, tool, run_context, **tool_args):
+ """执行函数调用。
+
+ Args:
+ event (AstrMessageEvent): 事件对象, 当 origin 为 local 时必须提供。
+ **kwargs: 函数调用的参数。
+
+ Returns:
+ AsyncGenerator[None | mcp.types.CallToolResult, None]
+
+ """
+ if isinstance(tool, HandoffTool):
+ async for r in cls._execute_handoff(tool, run_context, **tool_args):
+ yield r
+ return
+
+ elif isinstance(tool, MCPTool):
+ async for r in cls._execute_mcp(tool, run_context, **tool_args):
+ yield r
+ return
+
+ else:
+ async for r in cls._execute_local(tool, run_context, **tool_args):
+ yield r
+ return
+
+ @classmethod
+ async def _execute_handoff(
+ cls,
+ tool: HandoffTool,
+ run_context: ContextWrapper[AstrAgentContext],
+ **tool_args,
+ ):
+ input_ = tool_args.get("input")
+
+ # make toolset for the agent
+ tools = tool.agent.tools
+ if tools:
+ toolset = ToolSet()
+ for t in tools:
+ if isinstance(t, str):
+ _t = llm_tools.get_func(t)
+ if _t:
+ toolset.add_tool(_t)
+ elif isinstance(t, FunctionTool):
+ toolset.add_tool(t)
+ else:
+ toolset = None
+
+ ctx = run_context.context.context
+ event = run_context.context.event
+ umo = event.unified_msg_origin
+ prov_id = await ctx.get_current_chat_provider_id(umo)
+ llm_resp = await ctx.tool_loop_agent(
+ event=event,
+ chat_provider_id=prov_id,
+ prompt=input_,
+ system_prompt=tool.agent.instructions,
+ tools=toolset,
+ max_steps=30,
+ run_hooks=tool.agent.run_hooks,
+ )
+ yield mcp.types.CallToolResult(
+ content=[mcp.types.TextContent(type="text", text=llm_resp.completion_text)]
+ )
+
+ @classmethod
+ async def _execute_local(
+ cls,
+ tool: FunctionTool,
+ run_context: ContextWrapper[AstrAgentContext],
+ **tool_args,
+ ):
+ event = run_context.context.event
+ if not event:
+ raise ValueError("Event must be provided for local function tools.")
+
+ is_override_call = False
+ for ty in type(tool).mro():
+ if "call" in ty.__dict__ and ty.__dict__["call"] is not FunctionTool.call:
+ is_override_call = True
+ break
+
+ # 检查 tool 下有没有 run 方法
+ if not tool.handler and not hasattr(tool, "run") and not is_override_call:
+ raise ValueError("Tool must have a valid handler or override 'run' method.")
+
+ awaitable = None
+ method_name = ""
+ if tool.handler:
+ awaitable = tool.handler
+ method_name = "decorator_handler"
+ elif is_override_call:
+ awaitable = tool.call
+ method_name = "call"
+ elif hasattr(tool, "run"):
+ awaitable = getattr(tool, "run")
+ method_name = "run"
+ if awaitable is None:
+ raise ValueError("Tool must have a valid handler or override 'run' method.")
+
+ wrapper = call_local_llm_tool(
+ context=run_context,
+ handler=awaitable,
+ method_name=method_name,
+ **tool_args,
+ )
+ while True:
+ try:
+ resp = await asyncio.wait_for(
+ anext(wrapper),
+ timeout=run_context.tool_call_timeout,
+ )
+ if resp is not None:
+ if isinstance(resp, mcp.types.CallToolResult):
+ yield resp
+ else:
+ text_content = mcp.types.TextContent(
+ type="text",
+ text=str(resp),
+ )
+ yield mcp.types.CallToolResult(content=[text_content])
+ else:
+ # NOTE: Tool 在这里直接请求发送消息给用户
+ # TODO: 是否需要判断 event.get_result() 是否为空?
+ # 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
+ if res := run_context.context.event.get_result():
+ if res.chain:
+ try:
+ await event.send(
+ MessageChain(
+ chain=res.chain,
+ type="tool_direct_result",
+ )
+ )
+ except Exception as e:
+ logger.error(
+ f"Tool 直接发送消息失败: {e}",
+ exc_info=True,
+ )
+ yield None
+ except asyncio.TimeoutError:
+ raise Exception(
+ f"tool {tool.name} execution timeout after {run_context.tool_call_timeout} seconds.",
+ )
+ except StopAsyncIteration:
+ break
+
+ @classmethod
+ async def _execute_mcp(
+ cls,
+ tool: FunctionTool,
+ run_context: ContextWrapper[AstrAgentContext],
+ **tool_args,
+ ):
+ res = await tool.call(run_context, **tool_args)
+ if not res:
+ return
+ yield res
+
+
+async def call_local_llm_tool(
+ context: ContextWrapper[AstrAgentContext],
+ handler: T.Callable[
+ ...,
+ T.Awaitable[MessageEventResult | mcp.types.CallToolResult | str | None]
+ | T.AsyncGenerator[MessageEventResult | CommandResult | str | None, None],
+ ],
+ method_name: str,
+ *args,
+ **kwargs,
+) -> T.AsyncGenerator[T.Any, None]:
+ """执行本地 LLM 工具的处理函数并处理其返回结果"""
+ ready_to_call = None # 一个协程或者异步生成器
+
+ trace_ = None
+
+ event = context.context.event
+
+ try:
+ if method_name == "run" or method_name == "decorator_handler":
+ ready_to_call = handler(event, *args, **kwargs)
+ elif method_name == "call":
+ ready_to_call = handler(context, *args, **kwargs)
+ else:
+ raise ValueError(f"未知的方法名: {method_name}")
+ except ValueError as e:
+ logger.error(f"调用本地 LLM 工具时出错: {e}", exc_info=True)
+ except TypeError:
+ logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
+ except Exception as e:
+ trace_ = traceback.format_exc()
+ logger.error(f"调用本地 LLM 工具时出错: {e}\n{trace_}")
+
+ if not ready_to_call:
+ return
+
+ if inspect.isasyncgen(ready_to_call):
+ _has_yielded = False
+ try:
+ async for ret in ready_to_call:
+ # 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
+ # 返回值只能是 MessageEventResult 或者 None(无返回值)
+ _has_yielded = True
+ if isinstance(ret, (MessageEventResult, CommandResult)):
+ # 如果返回值是 MessageEventResult, 设置结果并继续
+ event.set_result(ret)
+ yield
+ else:
+ # 如果返回值是 None, 则不设置结果并继续
+ # 继续执行后续阶段
+ yield ret
+ if not _has_yielded:
+ # 如果这个异步生成器没有执行到 yield 分支
+ yield
+ except Exception as e:
+ logger.error(f"Previous Error: {trace_}")
+ raise e
+ elif inspect.iscoroutine(ready_to_call):
+ # 如果只是一个协程, 直接执行
+ ret = await ready_to_call
+ if isinstance(ret, (MessageEventResult, CommandResult)):
+ event.set_result(ret)
+ yield
+ else:
+ yield ret
diff --git a/astrbot/core/config/astrbot_config.py b/astrbot/core/config/astrbot_config.py
index 786d29c81..9477eabaa 100644
--- a/astrbot/core/config/astrbot_config.py
+++ b/astrbot/core/config/astrbot_config.py
@@ -24,6 +24,10 @@ class AstrBotConfig(dict):
- 如果传入了 schema,将会通过 schema 解析出 default_config,此时传入的 default_config 会被忽略。
"""
+ config_path: str
+ default_config: dict
+ schema: dict | None
+
def __init__(
self,
config_path: str = ASTRBOT_CONFIG_PATH,
diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py
index 1021d81b5..327191db6 100644
--- a/astrbot/core/config/default.py
+++ b/astrbot/core/config/default.py
@@ -4,9 +4,18 @@ import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
-VERSION = "4.5.6"
+VERSION = "4.9.2"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
+WEBHOOK_SUPPORTED_PLATFORMS = [
+ "qq_official_webhook",
+ "weixin_official_account",
+ "wecom",
+ "wecom_ai_bot",
+ "slack",
+ "lark",
+]
+
# 默认配置
DEFAULT_CONFIG = {
"config_version": 2,
@@ -34,7 +43,15 @@ DEFAULT_CONFIG = {
"interval": "1.5,3.5",
"log_base": 2.6,
"words_count_threshold": 150,
+ "split_mode": "regex", # regex 或 words
"regex": ".*?[。?!~…]+|.+$",
+ "split_words": [
+ "。",
+ "?",
+ "!",
+ "~",
+ "…",
+ ], # 当 split_mode 为 words 时使用
"content_cleanup_rule": "",
},
"no_permission_reply": True,
@@ -68,9 +85,19 @@ DEFAULT_CONFIG = {
"dequeue_context_length": 1,
"streaming_response": False,
"show_tool_use_status": False,
- "streaming_segmented": False,
+ "agent_runner_type": "local",
+ "dify_agent_runner_provider_id": "",
+ "coze_agent_runner_provider_id": "",
+ "dashscope_agent_runner_provider_id": "",
+ "unsupported_streaming_strategy": "realtime_segmenting",
+ "reachability_check": False,
"max_agent_step": 30,
"tool_call_timeout": 60,
+ "file_extract": {
+ "enable": False,
+ "provider": "moonshotai",
+ "moonshotai_api_key": "",
+ },
},
"provider_stt_settings": {
"enable": False,
@@ -81,11 +108,13 @@ DEFAULT_CONFIG = {
"provider_id": "",
"dual_output": False,
"use_file_service": False,
+ "trigger_probability": 1.0,
},
"provider_ltm_settings": {
"group_icl_enable": False,
"group_message_max_cnt": 300,
"image_caption": False,
+ "image_caption_provider_id": "",
"active_reply": {
"enable": False,
"method": "possibility_reply",
@@ -137,10 +166,21 @@ DEFAULT_CONFIG = {
"kb_names": [], # 默认知识库名称列表
"kb_fusion_top_k": 20, # 知识库检索融合阶段返回结果数量
"kb_final_top_k": 5, # 知识库检索最终返回结果数量
+ "kb_agentic_mode": False,
+ "disable_builtin_commands": False,
}
-# 配置项的中文描述、值类型
+"""
+AstrBot v3 时代的配置元数据,目前仅承担以下功能:
+
+1. 保存配置时,配置项的类型验证
+2. WebUI 展示提供商和平台适配器模版
+
+WebUI 的配置文件在 `CONFIG_METADATA_3` 中。
+
+未来将会逐步淘汰此配置元数据。
+"""
CONFIG_METADATA_2 = {
"platform_group": {
"metadata": {
@@ -164,10 +204,12 @@ CONFIG_METADATA_2 = {
"appid": "",
"secret": "",
"is_sandbox": False,
+ "unified_webhook_mode": True,
+ "webhook_uuid": "",
"callback_server_host": "0.0.0.0",
"port": 6196,
},
- "QQ 个人号(OneBot v11)": {
+ "OneBot v11": {
"id": "default",
"type": "aiocqhttp",
"enable": False,
@@ -194,6 +236,8 @@ CONFIG_METADATA_2 = {
"token": "",
"encoding_aes_key": "",
"api_base_url": "https://api.weixin.qq.com/cgi-bin/",
+ "unified_webhook_mode": True,
+ "webhook_uuid": "",
"callback_server_host": "0.0.0.0",
"port": 6194,
"active_send_mode": False,
@@ -208,6 +252,8 @@ CONFIG_METADATA_2 = {
"encoding_aes_key": "",
"kf_name": "",
"api_base_url": "https://qyapi.weixin.qq.com/cgi-bin/",
+ "unified_webhook_mode": True,
+ "webhook_uuid": "",
"callback_server_host": "0.0.0.0",
"port": 6195,
},
@@ -220,6 +266,8 @@ CONFIG_METADATA_2 = {
"wecom_ai_bot_name": "",
"token": "",
"encoding_aes_key": "",
+ "unified_webhook_mode": True,
+ "webhook_uuid": "",
"callback_server_host": "0.0.0.0",
"port": 6198,
},
@@ -231,6 +279,10 @@ CONFIG_METADATA_2 = {
"app_id": "",
"app_secret": "",
"domain": "https://open.feishu.cn",
+ "lark_connection_mode": "socket", # webhook, socket
+ "webhook_uuid": "",
+ "lark_encrypt_key": "",
+ "lark_verification_token": "",
},
"钉钉(DingTalk)": {
"id": "dingtalk",
@@ -287,6 +339,8 @@ CONFIG_METADATA_2 = {
"app_token": "",
"signing_secret": "",
"slack_connection_mode": "socket", # webhook, socket
+ "unified_webhook_mode": True,
+ "webhook_uuid": "",
"slack_webhook_host": "0.0.0.0",
"slack_webhook_port": 6197,
"slack_webhook_path": "/astrbot-slack-webhook/callback",
@@ -322,6 +376,28 @@ CONFIG_METADATA_2 = {
# "type": "string",
# "options": ["fullscreen", "embedded"],
# },
+ "lark_connection_mode": {
+ "description": "订阅方式",
+ "type": "string",
+ "options": ["socket", "webhook"],
+ "labels": ["长连接模式", "推送至服务器模式"],
+ },
+ "lark_encrypt_key": {
+ "description": "Encrypt Key",
+ "type": "string",
+ "hint": "用于解密飞书回调数据的加密密钥",
+ "condition": {
+ "lark_connection_mode": "webhook",
+ },
+ },
+ "lark_verification_token": {
+ "description": "Verification Token",
+ "type": "string",
+ "hint": "用于验证飞书回调请求的令牌",
+ "condition": {
+ "lark_connection_mode": "webhook",
+ },
+ },
"is_sandbox": {
"description": "沙箱模式",
"type": "bool",
@@ -366,16 +442,28 @@ CONFIG_METADATA_2 = {
"description": "Slack Webhook Host",
"type": "string",
"hint": "Only valid when Slack connection mode is `webhook`.",
+ "condition": {
+ "slack_connection_mode": "webhook",
+ "unified_webhook_mode": False,
+ },
},
"slack_webhook_port": {
"description": "Slack Webhook Port",
"type": "int",
"hint": "Only valid when Slack connection mode is `webhook`.",
+ "condition": {
+ "slack_connection_mode": "webhook",
+ "unified_webhook_mode": False,
+ },
},
"slack_webhook_path": {
"description": "Slack Webhook Path",
"type": "string",
"hint": "Only valid when Slack connection mode is `webhook`.",
+ "condition": {
+ "slack_connection_mode": "webhook",
+ "unified_webhook_mode": False,
+ },
},
"active_send_mode": {
"description": "是否换用主动发送接口",
@@ -566,6 +654,33 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "可选的 Discord 活动名称。留空则不设置活动。",
},
+ "port": {
+ "description": "回调服务器端口",
+ "type": "int",
+ "hint": "回调服务器端口。留空则不启用回调服务器。",
+ "condition": {
+ "unified_webhook_mode": False,
+ },
+ },
+ "callback_server_host": {
+ "description": "回调服务器主机",
+ "type": "string",
+ "hint": "回调服务器主机。留空则不启用回调服务器。",
+ "condition": {
+ "unified_webhook_mode": False,
+ },
+ },
+ "unified_webhook_mode": {
+ "description": "统一 Webhook 模式",
+ "type": "bool",
+ "hint": "启用后,将使用 AstrBot 统一 Webhook 入口,无需单独开启端口。回调地址为 /api/platform/webhook/{webhook_uuid}。",
+ },
+ "webhook_uuid": {
+ "invisible": True,
+ "description": "Webhook UUID",
+ "type": "string",
+ "hint": "统一 Webhook 模式下的唯一标识符,创建平台时自动生成。",
+ },
},
},
"platform_settings": {
@@ -633,7 +748,7 @@ CONFIG_METADATA_2 = {
},
"words_count_threshold": {
"type": "int",
- "hint": "超过这个字数的消息不会被分段回复。默认为 150",
+ "hint": "分段回复的字数上限。只有字数小于此值的消息才会被分段,超过此值的长消息将直接发送(不分段)。默认为 150",
},
"regex": {
"type": "string",
@@ -831,7 +946,7 @@ CONFIG_METADATA_2 = {
"api_base": "https://generativelanguage.googleapis.com/v1beta/openai/",
"timeout": 120,
"model_config": {
- "model": "gemini-1.5-flash",
+ "model": "gemini-3-flash-preview",
"temperature": 0.4,
},
"custom_headers": {},
@@ -848,7 +963,7 @@ CONFIG_METADATA_2 = {
"api_base": "https://generativelanguage.googleapis.com/",
"timeout": 120,
"model_config": {
- "model": "gemini-2.0-flash-exp",
+ "model": "gemini-3-flash-preview",
"temperature": 0.4,
},
"gm_resp_image_modal": False,
@@ -861,9 +976,7 @@ CONFIG_METADATA_2 = {
"sexually_explicit": "BLOCK_MEDIUM_AND_ABOVE",
"dangerous_content": "BLOCK_MEDIUM_AND_ABOVE",
},
- "gm_thinking_config": {
- "budget": 0,
- },
+ "gm_thinking_config": {"budget": 0, "level": "HIGH"},
"modalities": ["text", "image", "tool_use"],
},
"DeepSeek": {
@@ -880,6 +993,23 @@ CONFIG_METADATA_2 = {
"custom_extra_body": {},
"modalities": ["text", "tool_use"],
},
+ "Groq": {
+ "id": "groq_default",
+ "provider": "groq",
+ "type": "groq_chat_completion",
+ "provider_type": "chat_completion",
+ "enable": True,
+ "key": [],
+ "api_base": "https://api.groq.com/openai/v1",
+ "timeout": 120,
+ "model_config": {
+ "model": "openai/gpt-oss-20b",
+ "temperature": 0.4,
+ },
+ "custom_headers": {},
+ "custom_extra_body": {},
+ "modalities": ["text", "tool_use"],
+ },
"302.AI": {
"id": "302ai",
"provider": "302ai",
@@ -993,7 +1123,7 @@ CONFIG_METADATA_2 = {
"id": "dify_app_default",
"provider": "dify",
"type": "dify",
- "provider_type": "chat_completion",
+ "provider_type": "agent_runner",
"enable": True,
"dify_api_type": "chat",
"dify_api_key": "",
@@ -1007,20 +1137,20 @@ CONFIG_METADATA_2 = {
"Coze": {
"id": "coze",
"provider": "coze",
- "provider_type": "chat_completion",
+ "provider_type": "agent_runner",
"type": "coze",
"enable": True,
"coze_api_key": "",
"bot_id": "",
"coze_api_base": "https://api.coze.cn",
"timeout": 60,
- "auto_save_history": True,
+ # "auto_save_history": True,
},
"阿里云百炼应用": {
"id": "dashscope",
"provider": "dashscope",
"type": "dashscope",
- "provider_type": "chat_completion",
+ "provider_type": "agent_runner",
"enable": True,
"dashscope_app_type": "agent",
"dashscope_api_key": "",
@@ -1069,7 +1199,7 @@ CONFIG_METADATA_2 = {
"api_base": "",
"model": "whisper-1",
},
- "Whisper(本地加载)": {
+ "Whisper(Local)": {
"hint": "启用前请 pip 安装 openai-whisper 库(N卡用户大约下载 2GB,主要是 torch 和 cuda,CPU 用户大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
"provider": "openai",
"type": "openai_whisper_selfhost",
@@ -1078,7 +1208,7 @@ CONFIG_METADATA_2 = {
"id": "whisper_selfhost",
"model": "tiny",
},
- "SenseVoice(本地加载)": {
+ "SenseVoice(Local)": {
"hint": "启用前请 pip 安装 funasr、funasr_onnx、torchaudio、torch、modelscope、jieba 库(默认使用CPU,大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
"type": "sensevoice_stt_selfhost",
"provider": "sensevoice",
@@ -1113,7 +1243,7 @@ CONFIG_METADATA_2 = {
"pitch": "+0Hz",
"timeout": 20,
},
- "GSV TTS(本地加载)": {
+ "GSV TTS(Local)": {
"id": "gsv_tts",
"enable": False,
"provider": "gpt_sovits",
@@ -1290,6 +1420,19 @@ CONFIG_METADATA_2 = {
"timeout": 20,
"launch_model_if_not_running": False,
},
+ "阿里云百炼重排序": {
+ "id": "bailian_rerank",
+ "type": "bailian_rerank",
+ "provider": "bailian",
+ "provider_type": "rerank",
+ "enable": True,
+ "rerank_api_key": "",
+ "rerank_api_base": "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank",
+ "rerank_model": "qwen3-rerank",
+ "timeout": 30,
+ "return_documents": False,
+ "instruct": "",
+ },
"Xinference STT": {
"id": "xinference_stt",
"type": "xinference_stt",
@@ -1324,6 +1467,16 @@ CONFIG_METADATA_2 = {
"description": "重排序模型名称",
"type": "string",
},
+ "return_documents": {
+ "description": "是否在排序结果中返回文档原文",
+ "type": "bool",
+ "hint": "默认值false,以减少网络传输开销。",
+ },
+ "instruct": {
+ "description": "自定义排序任务类型说明",
+ "type": "string",
+ "hint": "仅在使用 qwen3-rerank 模型时生效。建议使用英文撰写。",
+ },
"launch_model_if_not_running": {
"description": "模型未运行时自动启动",
"type": "bool",
@@ -1664,13 +1817,24 @@ CONFIG_METADATA_2 = {
},
},
"gm_thinking_config": {
- "description": "Gemini思考设置",
+ "description": "Thinking Config",
"type": "object",
"items": {
"budget": {
- "description": "思考预算",
+ "description": "Thinking Budget",
"type": "int",
- "hint": "模型应该生成的思考Token的数量,设为0关闭思考。除gemini-2.5-flash外的模型会静默忽略此参数。",
+ "hint": "Guides the model on the specific number of thinking tokens to use for reasoning. See: https://ai.google.dev/gemini-api/docs/thinking#set-budget",
+ },
+ "level": {
+ "description": "Thinking Level",
+ "type": "string",
+ "hint": "Recommended for Gemini 3 models and onwards, lets you control reasoning behavior.See: https://ai.google.dev/gemini-api/docs/thinking#thinking-levels",
+ "options": [
+ "MINIMAL",
+ "LOW",
+ "MEDIUM",
+ "HIGH",
+ ],
},
},
},
@@ -1866,7 +2030,6 @@ CONFIG_METADATA_2 = {
"enable": {
"description": "启用",
"type": "bool",
- "hint": "是否启用。",
},
"key": {
"description": "API Key",
@@ -1993,17 +2156,41 @@ CONFIG_METADATA_2 = {
"show_tool_use_status": {
"type": "bool",
},
- "streaming_segmented": {
- "type": "bool",
+ "unsupported_streaming_strategy": {
+ "type": "string",
+ },
+ "agent_runner_type": {
+ "type": "string",
+ },
+ "dify_agent_runner_provider_id": {
+ "type": "string",
+ },
+ "coze_agent_runner_provider_id": {
+ "type": "string",
+ },
+ "dashscope_agent_runner_provider_id": {
+ "type": "string",
},
"max_agent_step": {
- "description": "工具调用轮数上限",
"type": "int",
},
"tool_call_timeout": {
- "description": "工具调用超时时间(秒)",
"type": "int",
},
+ "file_extract": {
+ "type": "object",
+ "items": {
+ "enable": {
+ "type": "bool",
+ },
+ "provider": {
+ "type": "string",
+ },
+ "moonshotai_api_key": {
+ "type": "string",
+ },
+ },
+ },
},
},
"provider_stt_settings": {
@@ -2032,6 +2219,9 @@ CONFIG_METADATA_2 = {
"use_file_service": {
"type": "bool",
},
+ "trigger_probability": {
+ "type": "float",
+ },
},
},
"provider_ltm_settings": {
@@ -2046,6 +2236,9 @@ CONFIG_METADATA_2 = {
"image_caption": {
"type": "bool",
},
+ "image_caption_provider_id": {
+ "type": "string",
+ },
"image_caption_prompt": {
"type": "string",
},
@@ -2129,39 +2322,93 @@ CONFIG_METADATA_2 = {
"kb_names": {"type": "list", "items": {"type": "string"}},
"kb_fusion_top_k": {"type": "int", "default": 20},
"kb_final_top_k": {"type": "int", "default": 5},
+ "kb_agentic_mode": {"type": "bool"},
},
},
}
+"""
+v4.7.0 之后,name, description, hint 等字段已经实现 i18n 国际化。国际化资源文件位于:
+
+- dashboard/src/i18n/locales/en-US/features/config-metadata.json
+- dashboard/src/i18n/locales/zh-CN/features/config-metadata.json
+
+如果在此文件中添加了新的配置字段,请务必同步更新上述两个国际化资源文件。
+"""
CONFIG_METADATA_3 = {
"ai_group": {
"name": "AI 配置",
"metadata": {
- "ai": {
- "description": "模型",
+ "agent_runner": {
+ "description": "Agent 执行方式",
+ "hint": "选择 AI 对话的执行器,默认为 AstrBot 内置 Agent 执行器,可使用 AstrBot 内的知识库、人格、工具调用功能。如果不打算接入 Dify 或 Coze 等第三方 Agent 执行器,不需要修改此节。",
"type": "object",
"items": {
"provider_settings.enable": {
- "description": "启用大语言模型聊天",
+ "description": "启用",
"type": "bool",
+ "hint": "AI 对话总开关",
},
+ "provider_settings.agent_runner_type": {
+ "description": "执行器",
+ "type": "string",
+ "options": ["local", "dify", "coze", "dashscope"],
+ "labels": ["内置 Agent", "Dify", "Coze", "阿里云百炼应用"],
+ "condition": {
+ "provider_settings.enable": True,
+ },
+ },
+ "provider_settings.coze_agent_runner_provider_id": {
+ "description": "Coze Agent 执行器提供商 ID",
+ "type": "string",
+ "_special": "select_agent_runner_provider:coze",
+ "condition": {
+ "provider_settings.agent_runner_type": "coze",
+ "provider_settings.enable": True,
+ },
+ },
+ "provider_settings.dify_agent_runner_provider_id": {
+ "description": "Dify Agent 执行器提供商 ID",
+ "type": "string",
+ "_special": "select_agent_runner_provider:dify",
+ "condition": {
+ "provider_settings.agent_runner_type": "dify",
+ "provider_settings.enable": True,
+ },
+ },
+ "provider_settings.dashscope_agent_runner_provider_id": {
+ "description": "阿里云百炼应用 Agent 执行器提供商 ID",
+ "type": "string",
+ "_special": "select_agent_runner_provider:dashscope",
+ "condition": {
+ "provider_settings.agent_runner_type": "dashscope",
+ "provider_settings.enable": True,
+ },
+ },
+ },
+ },
+ "ai": {
+ "description": "模型",
+ "hint": "当使用非内置 Agent 执行器时,默认聊天模型和默认图片转述模型可能会无效,但某些插件会依赖此配置项来调用 AI 能力。",
+ "type": "object",
+ "items": {
"provider_settings.default_provider_id": {
"description": "默认聊天模型",
"type": "string",
"_special": "select_provider",
- "hint": "留空时使用第一个模型。",
+ "hint": "留空时使用第一个模型",
},
"provider_settings.default_image_caption_provider_id": {
"description": "默认图片转述模型",
"type": "string",
"_special": "select_provider",
- "hint": "留空代表不使用。可用于不支持视觉模态的聊天模型。",
+ "hint": "留空代表不使用,可用于非多模态模型",
},
"provider_stt_settings.enable": {
"description": "启用语音转文本",
"type": "bool",
- "hint": "STT 总开关。",
+ "hint": "STT 总开关",
},
"provider_stt_settings.provider_id": {
"description": "默认语音转文本模型",
@@ -2175,22 +2422,32 @@ CONFIG_METADATA_3 = {
"provider_tts_settings.enable": {
"description": "启用文本转语音",
"type": "bool",
- "hint": "TTS 总开关。当关闭时,会话启用 TTS 也不会生效。",
+ "hint": "TTS 总开关",
},
"provider_tts_settings.provider_id": {
"description": "默认文本转语音模型",
"type": "string",
- "hint": "用户也可使用 /provider 单独选择会话的 TTS 模型。",
"_special": "select_provider_tts",
"condition": {
"provider_tts_settings.enable": True,
},
},
+ "provider_tts_settings.trigger_probability": {
+ "description": "TTS 触发概率",
+ "type": "float",
+ "slider": {"min": 0, "max": 1, "step": 0.05},
+ "condition": {
+ "provider_tts_settings.enable": True,
+ },
+ },
"provider_settings.image_caption_prompt": {
"description": "图片转述提示词",
"type": "text",
},
},
+ "condition": {
+ "provider_settings.enable": True,
+ },
},
"persona": {
"description": "人格",
@@ -2202,6 +2459,10 @@ CONFIG_METADATA_3 = {
"_special": "select_persona",
},
},
+ "condition": {
+ "provider_settings.agent_runner_type": "local",
+ "provider_settings.enable": True,
+ },
},
"knowledgebase": {
"description": "知识库",
@@ -2224,6 +2485,15 @@ CONFIG_METADATA_3 = {
"type": "int",
"hint": "从知识库中检索到的结果数量,越大可能获得越多相关信息,但也可能引入噪音。建议根据实际需求调整",
},
+ "kb_agentic_mode": {
+ "description": "Agentic 知识库检索",
+ "type": "bool",
+ "hint": "启用后,知识库检索将作为 LLM Tool,由模型自主决定何时调用知识库进行查询。需要模型支持函数调用能力。",
+ },
+ },
+ "condition": {
+ "provider_settings.agent_runner_type": "local",
+ "provider_settings.enable": True,
},
},
"websearch": {
@@ -2261,7 +2531,41 @@ CONFIG_METADATA_3 = {
"type": "bool",
},
},
+ "condition": {
+ "provider_settings.agent_runner_type": "local",
+ "provider_settings.enable": True,
+ },
},
+ # "file_extract": {
+ # "description": "文档解析能力 [beta]",
+ # "type": "object",
+ # "items": {
+ # "provider_settings.file_extract.enable": {
+ # "description": "启用文档解析能力",
+ # "type": "bool",
+ # },
+ # "provider_settings.file_extract.provider": {
+ # "description": "文档解析提供商",
+ # "type": "string",
+ # "options": ["moonshotai"],
+ # "condition": {
+ # "provider_settings.file_extract.enable": True,
+ # },
+ # },
+ # "provider_settings.file_extract.moonshotai_api_key": {
+ # "description": "Moonshot AI API Key",
+ # "type": "string",
+ # "condition": {
+ # "provider_settings.file_extract.provider": "moonshotai",
+ # "provider_settings.file_extract.enable": True,
+ # },
+ # },
+ # },
+ # "condition": {
+ # "provider_settings.agent_runner_type": "local",
+ # "provider_settings.enable": True,
+ # },
+ # },
"others": {
"description": "其他配置",
"type": "object",
@@ -2269,54 +2573,83 @@ CONFIG_METADATA_3 = {
"provider_settings.display_reasoning_text": {
"description": "显示思考内容",
"type": "bool",
+ "condition": {
+ "provider_settings.agent_runner_type": "local",
+ },
},
"provider_settings.identifier": {
"description": "用户识别",
"type": "bool",
+ "hint": "启用后,会在提示词前包含用户 ID 信息。",
},
"provider_settings.group_name_display": {
"description": "显示群名称",
"type": "bool",
- "hint": "启用后,在支持的平台(aiocqhttp)上会在 prompt 中包含群名称信息。",
+ "hint": "启用后,在支持的平台(OneBot v11)上会在提示词前包含群名称信息。",
},
"provider_settings.datetime_system_prompt": {
"description": "现实世界时间感知",
"type": "bool",
+ "hint": "启用后,会在系统提示词中附带当前时间信息。",
+ "condition": {
+ "provider_settings.agent_runner_type": "local",
+ },
},
"provider_settings.show_tool_use_status": {
"description": "输出函数调用状态",
"type": "bool",
+ "condition": {
+ "provider_settings.agent_runner_type": "local",
+ },
},
"provider_settings.max_agent_step": {
"description": "工具调用轮数上限",
"type": "int",
+ "condition": {
+ "provider_settings.agent_runner_type": "local",
+ },
},
"provider_settings.tool_call_timeout": {
"description": "工具调用超时时间(秒)",
"type": "int",
+ "condition": {
+ "provider_settings.agent_runner_type": "local",
+ },
},
"provider_settings.streaming_response": {
- "description": "流式回复",
+ "description": "流式输出",
"type": "bool",
},
- "provider_settings.streaming_segmented": {
- "description": "不支持流式回复的平台采取分段输出",
- "type": "bool",
+ "provider_settings.unsupported_streaming_strategy": {
+ "description": "不支持流式回复的平台",
+ "type": "string",
+ "options": ["realtime_segmenting", "turn_off"],
+ "hint": "选择在不支持流式回复的平台上的处理方式。实时分段回复会在系统接收流式响应检测到诸如标点符号等分段点时,立即发送当前已接收的内容",
+ "labels": ["实时分段回复", "关闭流式回复"],
+ "condition": {
+ "provider_settings.streaming_response": True,
+ },
},
"provider_settings.max_context_length": {
"description": "最多携带对话轮数",
"type": "int",
- "hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条。-1 为不限制。",
+ "hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制",
+ "condition": {
+ "provider_settings.agent_runner_type": "local",
+ },
},
"provider_settings.dequeue_context_length": {
"description": "丢弃对话轮数",
"type": "int",
- "hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数。",
+ "hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数",
+ "condition": {
+ "provider_settings.agent_runner_type": "local",
+ },
},
"provider_settings.wake_prefix": {
"description": "LLM 聊天额外唤醒前缀 ",
"type": "string",
- "hint": "如果唤醒前缀为 `/`, 额外聊天唤醒前缀为 `chat`,则需要 `/chat` 才会触发 LLM 请求。默认为空。",
+ "hint": "如果唤醒前缀为 /, 额外聊天唤醒前缀为 chat,则需要 /chat 才会触发 LLM 请求",
},
"provider_settings.prompt_prefix": {
"description": "用户提示词",
@@ -2327,6 +2660,14 @@ CONFIG_METADATA_3 = {
"description": "开启 TTS 时同时输出语音和文字内容",
"type": "bool",
},
+ "provider_settings.reachability_check": {
+ "description": "提供商可达性检测",
+ "type": "bool",
+ "hint": "/provider 命令列出模型时是否并发检测连通性。开启后会主动调用模型测试连通性,可能产生额外 token 消耗。",
+ },
+ },
+ "condition": {
+ "provider_settings.enable": True,
},
},
},
@@ -2377,6 +2718,11 @@ CONFIG_METADATA_3 = {
"description": "只 @ 机器人是否触发等待",
"type": "bool",
},
+ "disable_builtin_commands": {
+ "description": "禁用自带指令",
+ "type": "bool",
+ "hint": "禁用所有 AstrBot 的自带指令,如 help, provider, model 等。",
+ },
},
},
"whitelist": {
@@ -2591,9 +2937,26 @@ CONFIG_METADATA_3 = {
"description": "分段回复字数阈值",
"type": "int",
},
+ "platform_settings.segmented_reply.split_mode": {
+ "description": "分段模式",
+ "type": "string",
+ "options": ["regex", "words"],
+ "labels": ["正则表达式", "分段词列表"],
+ },
"platform_settings.segmented_reply.regex": {
"description": "分段正则表达式",
"type": "string",
+ "condition": {
+ "platform_settings.segmented_reply.split_mode": "regex",
+ },
+ },
+ "platform_settings.segmented_reply.split_words": {
+ "description": "分段词列表",
+ "type": "list",
+ "hint": "检测到列表中的任意词时进行分段,如:。、?、!等",
+ "condition": {
+ "platform_settings.segmented_reply.split_mode": "words",
+ },
},
"platform_settings.segmented_reply.content_cleanup_rule": {
"description": "内容过滤正则表达式",
@@ -2617,7 +2980,16 @@ CONFIG_METADATA_3 = {
"provider_ltm_settings.image_caption": {
"description": "自动理解图片",
"type": "bool",
- "hint": "需要设置默认图片转述模型。",
+ "hint": "需要设置群聊图片转述模型。",
+ },
+ "provider_ltm_settings.image_caption_provider_id": {
+ "description": "群聊图片转述模型",
+ "type": "string",
+ "_special": "select_provider",
+ "hint": "用于群聊上下文感知的图片理解,与默认图片转述模型分开配置。",
+ "condition": {
+ "provider_ltm_settings.image_caption": True,
+ },
},
"provider_ltm_settings.active_reply.enable": {
"description": "主动回复",
@@ -2635,6 +3007,7 @@ CONFIG_METADATA_3 = {
"description": "回复概率",
"type": "float",
"hint": "0.0-1.0 之间的数值",
+ "slider": {"min": 0, "max": 1, "step": 0.05},
"condition": {
"provider_ltm_settings.active_reply.enable": True,
},
diff --git a/astrbot/core/config/i18n_utils.py b/astrbot/core/config/i18n_utils.py
new file mode 100644
index 000000000..aa441c0c1
--- /dev/null
+++ b/astrbot/core/config/i18n_utils.py
@@ -0,0 +1,111 @@
+"""
+配置元数据国际化工具
+
+提供配置元数据的国际化键转换功能
+"""
+
+from typing import Any
+
+
+class ConfigMetadataI18n:
+ """配置元数据国际化转换器"""
+
+ @staticmethod
+ def _get_i18n_key(group: str, section: str, field: str, attr: str) -> str:
+ """
+ 生成国际化键
+
+ Args:
+ group: 配置组,如 'ai_group', 'platform_group'
+ section: 配置节,如 'agent_runner', 'general'
+ field: 字段名,如 'enable', 'default_provider'
+ attr: 属性类型,如 'description', 'hint', 'labels'
+
+ Returns:
+ 国际化键,格式如: 'ai_group.agent_runner.enable.description'
+ """
+ if field:
+ return f"{group}.{section}.{field}.{attr}"
+ else:
+ return f"{group}.{section}.{attr}"
+
+ @staticmethod
+ def convert_to_i18n_keys(metadata: dict[str, Any]) -> dict[str, Any]:
+ """
+ 将配置元数据转换为使用国际化键
+
+ Args:
+ metadata: 原始配置元数据字典
+
+ Returns:
+ 使用国际化键的配置元数据字典
+ """
+ result = {}
+
+ for group_key, group_data in metadata.items():
+ group_result = {
+ "name": f"{group_key}.name",
+ "metadata": {},
+ }
+
+ for section_key, section_data in group_data.get("metadata", {}).items():
+ section_result = {
+ "description": f"{group_key}.{section_key}.description",
+ "type": section_data.get("type"),
+ }
+
+ # 复制其他属性
+ for key in ["items", "condition", "_special", "invisible"]:
+ if key in section_data:
+ section_result[key] = section_data[key]
+
+ # 处理 hint
+ if "hint" in section_data:
+ section_result["hint"] = f"{group_key}.{section_key}.hint"
+
+ # 处理 items 中的字段
+ if "items" in section_data and isinstance(section_data["items"], dict):
+ items_result = {}
+ for field_key, field_data in section_data["items"].items():
+ # 处理嵌套的点号字段名(如 provider_settings.enable)
+ field_name = field_key
+
+ field_result = {}
+
+ # 复制基本属性
+ for attr in [
+ "type",
+ "condition",
+ "_special",
+ "invisible",
+ "options",
+ "slider",
+ ]:
+ if attr in field_data:
+ field_result[attr] = field_data[attr]
+
+ # 转换文本属性为国际化键
+ if "description" in field_data:
+ field_result["description"] = (
+ f"{group_key}.{section_key}.{field_name}.description"
+ )
+
+ if "hint" in field_data:
+ field_result["hint"] = (
+ f"{group_key}.{section_key}.{field_name}.hint"
+ )
+
+ if "labels" in field_data:
+ field_result["labels"] = (
+ f"{group_key}.{section_key}.{field_name}.labels"
+ )
+
+ items_result[field_key] = field_result
+
+ section_result["items"] = items_result
+
+ group_result["metadata"][section_key] = section_result
+
+ result[group_key] = group_result
+
+ return result
diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py
index 2a6ac4273..5a8672837 100644
--- a/astrbot/core/core_lifecycle.py
+++ b/astrbot/core/core_lifecycle.py
@@ -16,12 +16,12 @@ import time
import traceback
from asyncio import Queue
-from astrbot.core import LogBroker, logger, sp
+from astrbot.api import logger, sp
+from astrbot.core import LogBroker
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.config.default import VERSION
from astrbot.core.conversation_mgr import ConversationManager
from astrbot.core.db import BaseDatabase
-from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
from astrbot.core.persona_mgr import PersonaManager
from astrbot.core.pipeline.scheduler import PipelineContext, PipelineScheduler
@@ -33,6 +33,7 @@ from astrbot.core.star.context import Context
from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map
from astrbot.core.umop_config_router import UmopConfigRouter
from astrbot.core.updator import AstrBotUpdator
+from astrbot.core.utils.migra_helper import migra
from . import astrbot_config, html_renderer
from .event_bus import EventBus
@@ -96,11 +97,16 @@ class AstrBotCoreLifecycle:
sp=sp,
)
- # 4.5 to 4.6 migration for umop_config_router
+ # apply migration
try:
- await migrate_45_to_46(self.astrbot_config_mgr, self.umop_config_router)
+ await migra(
+ self.db,
+ self.astrbot_config_mgr,
+ self.umop_config_router,
+ self.astrbot_config_mgr,
+ )
except Exception as e:
- logger.error(f"Migration from version 4.5 to 4.6 failed: {e!s}")
+ logger.error(f"AstrBot migration failed: {e!s}")
logger.error(traceback.format_exc())
# 初始化事件队列
@@ -191,7 +197,7 @@ class AstrBotCoreLifecycle:
# 把插件中注册的所有协程函数注册到事件总线中并执行
extra_tasks = []
for task in self.star_context._register_tasks:
- extra_tasks.append(asyncio.create_task(task, name=task.__name__))
+ extra_tasks.append(asyncio.create_task(task, name=task.__name__)) # type: ignore
tasks_ = [event_bus_task, *extra_tasks]
for task in tasks_:
diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py
index c62e49289..192c7b263 100644
--- a/astrbot/core/db/__init__.py
+++ b/astrbot/core/db/__init__.py
@@ -5,14 +5,16 @@ from contextlib import asynccontextmanager
from dataclasses import dataclass
from deprecated import deprecated
-from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
-from sqlalchemy.orm import sessionmaker
+from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from astrbot.core.db.po import (
Attachment,
+ CommandConfig,
+ CommandConflict,
ConversationV2,
Persona,
PlatformMessageHistory,
+ PlatformSession,
PlatformStat,
Preference,
Stats,
@@ -31,7 +33,7 @@ class BaseDatabase(abc.ABC):
echo=False,
future=True,
)
- self.AsyncSessionLocal = sessionmaker(
+ self.AsyncSessionLocal = async_sessionmaker(
self.engine,
class_=AsyncSession,
expire_on_commit=False,
@@ -172,7 +174,7 @@ class BaseDatabase(abc.ABC):
content: dict,
sender_id: str | None = None,
sender_name: str | None = None,
- ) -> None:
+ ) -> PlatformMessageHistory:
"""Insert a new platform message history record."""
...
@@ -183,7 +185,7 @@ class BaseDatabase(abc.ABC):
user_id: str,
offset_sec: int = 86400,
) -> None:
- """Delete platform message history records older than the specified offset."""
+ """Delete platform message history records newer than the specified offset."""
...
@abc.abstractmethod
@@ -197,6 +199,14 @@ class BaseDatabase(abc.ABC):
"""Get platform message history for a specific user."""
...
+ @abc.abstractmethod
+ async def get_platform_message_history_by_id(
+ self,
+ message_id: int,
+ ) -> PlatformMessageHistory | None:
+ """Get a platform message history record by its ID."""
+ ...
+
@abc.abstractmethod
async def insert_attachment(
self,
@@ -212,6 +222,27 @@ class BaseDatabase(abc.ABC):
"""Get an attachment by its ID."""
...
+ @abc.abstractmethod
+ async def get_attachments(self, attachment_ids: list[str]) -> list[Attachment]:
+ """Get multiple attachments by their IDs."""
+ ...
+
+ @abc.abstractmethod
+ async def delete_attachment(self, attachment_id: str) -> bool:
+ """Delete an attachment by its ID.
+
+ Returns True if the attachment was deleted, False if it was not found.
+ """
+ ...
+
+ @abc.abstractmethod
+ async def delete_attachments(self, attachment_ids: list[str]) -> int:
+ """Delete multiple attachments by their IDs.
+
+ Returns the number of attachments deleted.
+ """
+ ...
+
@abc.abstractmethod
async def insert_persona(
self,
@@ -285,6 +316,76 @@ class BaseDatabase(abc.ABC):
"""Clear all preferences for a specific scope ID."""
...
+ @abc.abstractmethod
+ async def get_command_configs(self) -> list[CommandConfig]:
+ """Get all stored command configurations."""
+ ...
+
+ @abc.abstractmethod
+ async def get_command_config(self, handler_full_name: str) -> CommandConfig | None:
+ """Fetch a single command configuration by handler."""
+ ...
+
+ @abc.abstractmethod
+ async def upsert_command_config(
+ self,
+ handler_full_name: str,
+ plugin_name: str,
+ module_path: str,
+ original_command: str,
+ *,
+ resolved_command: str | None = None,
+ enabled: bool | None = None,
+ keep_original_alias: bool | None = None,
+ conflict_key: str | None = None,
+ resolution_strategy: str | None = None,
+ note: str | None = None,
+ extra_data: dict | None = None,
+ auto_managed: bool | None = None,
+ ) -> CommandConfig:
+ """Create or update a command configuration."""
+ ...
+
+ @abc.abstractmethod
+ async def delete_command_config(self, handler_full_name: str) -> None:
+ """Delete a single command configuration."""
+ ...
+
+ @abc.abstractmethod
+ async def delete_command_configs(self, handler_full_names: list[str]) -> None:
+ """Bulk delete command configurations."""
+ ...
+
+ @abc.abstractmethod
+ async def list_command_conflicts(
+ self,
+ status: str | None = None,
+ ) -> list[CommandConflict]:
+ """List recorded command conflict entries."""
+ ...
+
+ @abc.abstractmethod
+ async def upsert_command_conflict(
+ self,
+ conflict_key: str,
+ handler_full_name: str,
+ plugin_name: str,
+ *,
+ status: str | None = None,
+ resolution: str | None = None,
+ resolved_command: str | None = None,
+ note: str | None = None,
+ extra_data: dict | None = None,
+ auto_generated: bool | None = None,
+ ) -> CommandConflict:
+ """Create or update a conflict record."""
+ ...
+
+ @abc.abstractmethod
+ async def delete_command_conflicts(self, ids: list[int]) -> None:
+ """Delete conflict records."""
+ ...
+
# @abc.abstractmethod
# async def insert_llm_message(
# self,
@@ -313,3 +414,51 @@ class BaseDatabase(abc.ABC):
) -> tuple[list[dict], int]:
"""Get paginated session conversations with joined conversation and persona details, support search and platform filter."""
...
+
+ # ====
+ # Platform Session Management
+ # ====
+
+ @abc.abstractmethod
+ async def create_platform_session(
+ self,
+ creator: str,
+ platform_id: str = "webchat",
+ session_id: str | None = None,
+ display_name: str | None = None,
+ is_group: int = 0,
+ ) -> PlatformSession:
+ """Create a new Platform session."""
+ ...
+
+ @abc.abstractmethod
+ async def get_platform_session_by_id(
+ self, session_id: str
+ ) -> PlatformSession | None:
+ """Get a Platform session by its ID."""
+ ...
+
+ @abc.abstractmethod
+ async def get_platform_sessions_by_creator(
+ self,
+ creator: str,
+ platform_id: str | None = None,
+ page: int = 1,
+ page_size: int = 20,
+ ) -> list[PlatformSession]:
+ """Get all Platform sessions for a specific creator (username) and optionally platform."""
+ ...
+
+ @abc.abstractmethod
+ async def update_platform_session(
+ self,
+ session_id: str,
+ display_name: str | None = None,
+ ) -> None:
+ """Update a Platform session's updated_at timestamp and optionally display_name."""
+ ...
+
+ @abc.abstractmethod
+ async def delete_platform_session(self, session_id: str) -> None:
+ """Delete a Platform session by its ID."""
+ ...
diff --git a/astrbot/core/db/migration/migra_3_to_4.py b/astrbot/core/db/migration/migra_3_to_4.py
index a75c60a1b..66b72d5cb 100644
--- a/astrbot/core/db/migration/migra_3_to_4.py
+++ b/astrbot/core/db/migration/migra_3_to_4.py
@@ -70,6 +70,7 @@ async def migration_conversation_table(
logger.info(
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
)
+ continue
if ":" not in conv.user_id:
continue
session = MessageSesion.from_str(session_str=conv.user_id)
@@ -207,6 +208,7 @@ async def migration_webchat_data(
logger.info(
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
)
+ continue
if ":" in conv.user_id:
continue
platform_id = "webchat"
diff --git a/astrbot/core/db/migration/migra_webchat_session.py b/astrbot/core/db/migration/migra_webchat_session.py
new file mode 100644
index 000000000..ff0b5ca6f
--- /dev/null
+++ b/astrbot/core/db/migration/migra_webchat_session.py
@@ -0,0 +1,131 @@
+"""Migration script for WebChat sessions.
+
+This migration creates PlatformSession from existing platform_message_history records.
+
+Changes:
+- Creates platform_sessions table
+- Adds platform_id field (default: 'webchat')
+- Adds display_name field
+- Session_id format: {platform_id}_{uuid}
+"""
+
+from sqlalchemy import func, select
+from sqlmodel import col
+
+from astrbot.api import logger, sp
+from astrbot.core.db import BaseDatabase
+from astrbot.core.db.po import ConversationV2, PlatformMessageHistory, PlatformSession
+
+
+async def migrate_webchat_session(db_helper: BaseDatabase):
+ """Create PlatformSession records from platform_message_history.
+
+ This migration extracts all unique user_ids from platform_message_history
+ where platform_id='webchat' and creates corresponding PlatformSession records.
+ """
+ # 检查是否已经完成迁移
+ migration_done = await db_helper.get_preference(
+ "global", "global", "migration_done_webchat_session_1"
+ )
+ if migration_done:
+ return
+
+ logger.info("开始执行数据库迁移(WebChat 会话迁移)...")
+
+ try:
+ async with db_helper.get_db() as session:
+ # 从 platform_message_history 创建 PlatformSession
+ query = (
+ select(
+ col(PlatformMessageHistory.user_id),
+ col(PlatformMessageHistory.sender_name),
+ func.min(PlatformMessageHistory.created_at).label("earliest"),
+ func.max(PlatformMessageHistory.updated_at).label("latest"),
+ )
+ .where(col(PlatformMessageHistory.platform_id) == "webchat")
+ .where(col(PlatformMessageHistory.sender_id) != "bot")
+ .group_by(col(PlatformMessageHistory.user_id))
+ )
+
+ result = await session.execute(query)
+ webchat_users = result.all()
+
+ if not webchat_users:
+ logger.info("没有找到需要迁移的 WebChat 数据")
+ await sp.put_async(
+ "global", "global", "migration_done_webchat_session_1", True
+ )
+ return
+
+ logger.info(f"找到 {len(webchat_users)} 个 WebChat 会话需要迁移")
+
+ # 检查已存在的会话
+ existing_query = select(col(PlatformSession.session_id))
+ existing_result = await session.execute(existing_query)
+ existing_session_ids = {row[0] for row in existing_result.fetchall()}
+
+ # 查询 Conversations 表中的 title,用于设置 display_name
+ # 对于每个 user_id,对应的 conversation user_id 格式为: webchat:FriendMessage:webchat!astrbot!{user_id}
+ user_ids_to_query = [
+ f"webchat:FriendMessage:webchat!astrbot!{user_id}"
+ for user_id, _, _, _ in webchat_users
+ ]
+ conv_query = select(
+ col(ConversationV2.user_id), col(ConversationV2.title)
+ ).where(col(ConversationV2.user_id).in_(user_ids_to_query))
+ conv_result = await session.execute(conv_query)
+ # 创建 user_id -> title 的映射字典
+ title_map = {
+ user_id.replace("webchat:FriendMessage:webchat!astrbot!", ""): title
+ for user_id, title in conv_result.fetchall()
+ }
+
+ # 批量创建 PlatformSession 记录
+ sessions_to_add = []
+ skipped_count = 0
+
+ for user_id, sender_name, created_at, updated_at in webchat_users:
+ # user_id 就是 webchat_conv_id (session_id)
+ session_id = user_id
+
+ # sender_name 通常是 username,但可能为 None
+ creator = sender_name if sender_name else "guest"
+
+ # 检查是否已经存在该会话
+ if session_id in existing_session_ids:
+ logger.debug(f"会话 {session_id} 已存在,跳过")
+ skipped_count += 1
+ continue
+
+ # 从 Conversations 表中获取 display_name
+ display_name = title_map.get(user_id)
+
+ # 创建新的 PlatformSession(保留原有的时间戳)
+ new_session = PlatformSession(
+ session_id=session_id,
+ platform_id="webchat",
+ creator=creator,
+ is_group=0,
+ created_at=created_at,
+ updated_at=updated_at,
+ display_name=display_name,
+ )
+ sessions_to_add.append(new_session)
+
+ # 批量插入
+ if sessions_to_add:
+ session.add_all(sessions_to_add)
+ await session.commit()
+
+ logger.info(
+ f"WebChat 会话迁移完成!成功迁移: {len(sessions_to_add)}, 跳过: {skipped_count}",
+ )
+ else:
+ logger.info("没有新会话需要迁移")
+
+ # 标记迁移完成
+ await sp.put_async("global", "global", "migration_done_webchat_session_1", True)
+
+ except Exception as e:
+ logger.error(f"迁移过程中发生错误: {e}", exc_info=True)
+ raise
diff --git a/astrbot/core/db/migration/sqlite_v3.py b/astrbot/core/db/migration/sqlite_v3.py
index a301028d1..b1a780d48 100644
--- a/astrbot/core/db/migration/sqlite_v3.py
+++ b/astrbot/core/db/migration/sqlite_v3.py
@@ -127,7 +127,7 @@ class SQLiteDatabase:
conn.text_factory = str
return conn
- def _exec_sql(self, sql: str, params: tuple = None):
+ def _exec_sql(self, sql: str, params: tuple | None = None):
conn = self.conn
try:
c = self.conn.cursor()
@@ -224,9 +224,11 @@ class SQLiteDatabase:
c.close()
- return Stats(platform, [], [])
+ return Stats(platform)
- def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
+ def get_conversation_by_user_id(
+ self, user_id: str, cid: str
+ ) -> Conversation | None:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
@@ -258,7 +260,7 @@ class SQLiteDatabase:
(user_id, cid, history, updated_at, created_at),
)
- def get_conversations(self, user_id: str) -> tuple:
+ def get_conversations(self, user_id: str) -> list[Conversation]:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
diff --git a/astrbot/core/db/po.py b/astrbot/core/db/po.py
index 1e7245976..64bcf4ce3 100644
--- a/astrbot/core/db/po.py
+++ b/astrbot/core/db/po.py
@@ -3,13 +3,7 @@ from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import TypedDict
-from sqlmodel import (
- JSON,
- Field,
- SQLModel,
- Text,
- UniqueConstraint,
-)
+from sqlmodel import JSON, Field, SQLModel, Text, UniqueConstraint
class PlatformStat(SQLModel, table=True):
@@ -18,7 +12,7 @@ class PlatformStat(SQLModel, table=True):
Note: In astrbot v4, we moved `platform` table to here.
"""
- __tablename__ = "platform_stats"
+ __tablename__: str = "platform_stats"
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
timestamp: datetime = Field(nullable=False)
@@ -37,9 +31,10 @@ class PlatformStat(SQLModel, table=True):
class ConversationV2(SQLModel, table=True):
- __tablename__ = "conversations"
+ __tablename__: str = "conversations"
- inner_conversation_id: int = Field(
+ inner_conversation_id: int | None = Field(
+ default=None,
primary_key=True,
sa_column_kwargs={"autoincrement": True},
)
@@ -74,7 +69,7 @@ class Persona(SQLModel, table=True):
It can be used to customize the behavior of LLMs.
"""
- __tablename__ = "personas"
+ __tablename__: str = "personas"
id: int | None = Field(
primary_key=True,
@@ -104,7 +99,7 @@ class Persona(SQLModel, table=True):
class Preference(SQLModel, table=True):
"""This class represents preferences for bots."""
- __tablename__ = "preferences"
+ __tablename__: str = "preferences"
id: int | None = Field(
default=None,
@@ -140,7 +135,7 @@ class PlatformMessageHistory(SQLModel, table=True):
or platform-specific messages.
"""
- __tablename__ = "platform_message_history"
+ __tablename__: str = "platform_message_history"
id: int | None = Field(
primary_key=True,
@@ -161,13 +156,55 @@ class PlatformMessageHistory(SQLModel, table=True):
)
+class PlatformSession(SQLModel, table=True):
+ """Platform session table for managing user sessions across different platforms.
+
+ A session represents a chat window for a specific user on a specific platform.
+ Each session can have multiple conversations (对话) associated with it.
+ """
+
+ __tablename__: str = "platform_sessions"
+
+ inner_id: int | None = Field(
+ primary_key=True,
+ sa_column_kwargs={"autoincrement": True},
+ default=None,
+ )
+ session_id: str = Field(
+ max_length=100,
+ nullable=False,
+ unique=True,
+ default_factory=lambda: str(uuid.uuid4()),
+ )
+ platform_id: str = Field(default="webchat", nullable=False)
+ """Platform identifier (e.g., 'webchat', 'qq', 'discord')"""
+ creator: str = Field(nullable=False)
+ """Username of the session creator"""
+ display_name: str | None = Field(default=None, max_length=255)
+ """Display name for the session"""
+ is_group: int = Field(default=0, nullable=False)
+ """0 for private chat, 1 for group chat (not implemented yet)"""
+ created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
+ updated_at: datetime = Field(
+ default_factory=lambda: datetime.now(timezone.utc),
+ sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
+ )
+
+ __table_args__ = (
+ UniqueConstraint(
+ "session_id",
+ name="uix_platform_session_id",
+ ),
+ )
+
+
class Attachment(SQLModel, table=True):
"""This class represents attachments for messages in AstrBot.
Attachments can be images, files, or other media types.
"""
- __tablename__ = "attachments"
+ __tablename__: str = "attachments"
inner_attachment_id: int | None = Field(
primary_key=True,
@@ -197,6 +234,65 @@ class Attachment(SQLModel, table=True):
)
+class CommandConfig(SQLModel, table=True):
+ """Per-command configuration overrides for dashboard management."""
+
+ __tablename__ = "command_configs" # type: ignore
+
+ handler_full_name: str = Field(
+ primary_key=True,
+ max_length=512,
+ )
+ plugin_name: str = Field(nullable=False, max_length=255)
+ module_path: str = Field(nullable=False, max_length=255)
+ original_command: str = Field(nullable=False, max_length=255)
+ resolved_command: str | None = Field(default=None, max_length=255)
+ enabled: bool = Field(default=True, nullable=False)
+ keep_original_alias: bool = Field(default=False, nullable=False)
+ conflict_key: str | None = Field(default=None, max_length=255)
+ resolution_strategy: str | None = Field(default=None, max_length=64)
+ note: str | None = Field(default=None, sa_type=Text)
+ extra_data: dict | None = Field(default=None, sa_type=JSON)
+ auto_managed: bool = Field(default=False, nullable=False)
+ created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
+ updated_at: datetime = Field(
+ default_factory=lambda: datetime.now(timezone.utc),
+ sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
+ )
+
+
+class CommandConflict(SQLModel, table=True):
+ """Conflict tracking for duplicated command names."""
+
+ __tablename__ = "command_conflicts" # type: ignore
+
+ id: int | None = Field(
+ default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
+ )
+ conflict_key: str = Field(nullable=False, max_length=255)
+ handler_full_name: str = Field(nullable=False, max_length=512)
+ plugin_name: str = Field(nullable=False, max_length=255)
+ status: str = Field(default="pending", max_length=32)
+ resolution: str | None = Field(default=None, max_length=64)
+ resolved_command: str | None = Field(default=None, max_length=255)
+ note: str | None = Field(default=None, sa_type=Text)
+ extra_data: dict | None = Field(default=None, sa_type=JSON)
+ auto_generated: bool = Field(default=False, nullable=False)
+ created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
+ updated_at: datetime = Field(
+ default_factory=lambda: datetime.now(timezone.utc),
+ sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
+ )
+
+ __table_args__ = (
+ UniqueConstraint(
+ "conflict_key",
+ "handler_full_name",
+ name="uix_conflict_handler",
+ ),
+ )
+
+
@dataclass
class Conversation:
"""LLM 对话类
@@ -225,17 +321,17 @@ class Personality(TypedDict):
在 v4.0.0 版本及之后,推荐使用上面的 Persona 类。并且, mood_imitation_dialogs 字段已被废弃。
"""
- prompt: str = ""
- name: str = ""
- begin_dialogs: list[str] = []
- mood_imitation_dialogs: list[str] = []
+ prompt: str
+ name: str
+ begin_dialogs: list[str]
+ mood_imitation_dialogs: list[str]
"""情感模拟对话预设。在 v4.0.0 版本及之后,已被废弃。"""
- tools: list[str] | None = None
+ tools: list[str] | None
"""工具列表。None 表示使用所有工具,空列表表示不使用任何工具"""
# cache
- _begin_dialogs_processed: list[dict] = []
- _mood_imitation_dialogs_processed: str = ""
+ _begin_dialogs_processed: list[dict]
+ _mood_imitation_dialogs_processed: str
# ====
diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py
index 457a4ab3f..fa3ca9a76 100644
--- a/astrbot/core/db/sqlite.py
+++ b/astrbot/core/db/sqlite.py
@@ -1,17 +1,22 @@
import asyncio
import threading
import typing as T
-from datetime import datetime, timedelta
+from collections.abc import Awaitable, Callable
+from datetime import datetime, timedelta, timezone
+from sqlalchemy import CursorResult
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import col, delete, desc, func, or_, select, text, update
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import (
Attachment,
+ CommandConfig,
+ CommandConflict,
ConversationV2,
Persona,
PlatformMessageHistory,
+ PlatformSession,
PlatformStat,
Preference,
SQLModel,
@@ -24,6 +29,7 @@ from astrbot.core.db.po import (
)
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
+TxResult = T.TypeVar("TxResult")
class SQLiteDatabase(BaseDatabase):
@@ -104,8 +110,8 @@ class SQLiteDatabase(BaseDatabase):
text("""
SELECT * FROM platform_stats
WHERE timestamp >= :start_time
- ORDER BY timestamp DESC
GROUP BY platform_id
+ ORDER BY timestamp DESC
"""),
{"start_time": start_time},
)
@@ -412,7 +418,7 @@ class SQLiteDatabase(BaseDatabase):
user_id,
offset_sec=86400,
):
- """Delete platform message history records older than the specified offset."""
+ """Delete platform message history records newer than the specified offset."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
@@ -422,7 +428,7 @@ class SQLiteDatabase(BaseDatabase):
delete(PlatformMessageHistory).where(
col(PlatformMessageHistory.platform_id) == platform_id,
col(PlatformMessageHistory.user_id) == user_id,
- col(PlatformMessageHistory.created_at) < cutoff_time,
+ col(PlatformMessageHistory.created_at) >= cutoff_time,
),
)
@@ -448,6 +454,18 @@ class SQLiteDatabase(BaseDatabase):
result = await session.execute(query.offset(offset).limit(page_size))
return result.scalars().all()
+ async def get_platform_message_history_by_id(
+ self, message_id: int
+ ) -> PlatformMessageHistory | None:
+ """Get a platform message history record by its ID."""
+ async with self.get_db() as session:
+ session: AsyncSession
+ query = select(PlatformMessageHistory).where(
+ PlatformMessageHistory.id == message_id
+ )
+ result = await session.execute(query)
+ return result.scalar_one_or_none()
+
async def insert_attachment(self, path, type, mime_type):
"""Insert a new attachment record."""
async with self.get_db() as session:
@@ -469,6 +487,48 @@ class SQLiteDatabase(BaseDatabase):
result = await session.execute(query)
return result.scalar_one_or_none()
+ async def get_attachments(self, attachment_ids: list[str]) -> list:
+ """Get multiple attachments by their IDs."""
+ if not attachment_ids:
+ return []
+ async with self.get_db() as session:
+ session: AsyncSession
+ query = select(Attachment).where(
+ col(Attachment.attachment_id).in_(attachment_ids)
+ )
+ result = await session.execute(query)
+ return list(result.scalars().all())
+
+ async def delete_attachment(self, attachment_id: str) -> bool:
+ """Delete an attachment by its ID.
+
+ Returns True if the attachment was deleted, False if it was not found.
+ """
+ async with self.get_db() as session:
+ session: AsyncSession
+ async with session.begin():
+ query = delete(Attachment).where(
+ col(Attachment.attachment_id) == attachment_id
+ )
+ result = T.cast(CursorResult, await session.execute(query))
+ return result.rowcount > 0
+
+ async def delete_attachments(self, attachment_ids: list[str]) -> int:
+ """Delete multiple attachments by their IDs.
+
+ Returns the number of attachments deleted.
+ """
+ if not attachment_ids:
+ return 0
+ async with self.get_db() as session:
+ session: AsyncSession
+ async with session.begin():
+ query = delete(Attachment).where(
+ col(Attachment.attachment_id).in_(attachment_ids)
+ )
+ result = T.cast(CursorResult, await session.execute(query))
+ return result.rowcount
+
async def insert_persona(
self,
persona_id,
@@ -614,6 +674,242 @@ class SQLiteDatabase(BaseDatabase):
)
await session.commit()
+ # ====
+ # Command Configuration & Conflict Tracking
+ # ====
+
+ async def _run_in_tx(
+ self,
+ fn: Callable[[AsyncSession], Awaitable[TxResult]],
+ ) -> TxResult:
+ async with self.get_db() as session:
+ session: AsyncSession
+ async with session.begin():
+ return await fn(session)
+
+ @staticmethod
+ def _apply_updates(model, **updates) -> None:
+ for field, value in updates.items():
+ if value is not None:
+ setattr(model, field, value)
+
+ @staticmethod
+ def _new_command_config(
+ handler_full_name: str,
+ plugin_name: str,
+ module_path: str,
+ original_command: str,
+ *,
+ resolved_command: str | None = None,
+ enabled: bool | None = None,
+ keep_original_alias: bool | None = None,
+ conflict_key: str | None = None,
+ resolution_strategy: str | None = None,
+ note: str | None = None,
+ extra_data: dict | None = None,
+ auto_managed: bool | None = None,
+ ) -> CommandConfig:
+ return CommandConfig(
+ handler_full_name=handler_full_name,
+ plugin_name=plugin_name,
+ module_path=module_path,
+ original_command=original_command,
+ resolved_command=resolved_command,
+ enabled=True if enabled is None else enabled,
+ keep_original_alias=False
+ if keep_original_alias is None
+ else keep_original_alias,
+ conflict_key=conflict_key or original_command,
+ resolution_strategy=resolution_strategy,
+ note=note,
+ extra_data=extra_data,
+ auto_managed=bool(auto_managed),
+ )
+
+ @staticmethod
+ def _new_command_conflict(
+ conflict_key: str,
+ handler_full_name: str,
+ plugin_name: str,
+ *,
+ status: str | None = None,
+ resolution: str | None = None,
+ resolved_command: str | None = None,
+ note: str | None = None,
+ extra_data: dict | None = None,
+ auto_generated: bool | None = None,
+ ) -> CommandConflict:
+ return CommandConflict(
+ conflict_key=conflict_key,
+ handler_full_name=handler_full_name,
+ plugin_name=plugin_name,
+ status=status or "pending",
+ resolution=resolution,
+ resolved_command=resolved_command,
+ note=note,
+ extra_data=extra_data,
+ auto_generated=bool(auto_generated),
+ )
+
+ async def get_command_configs(self) -> list[CommandConfig]:
+ async with self.get_db() as session:
+ session: AsyncSession
+ result = await session.execute(select(CommandConfig))
+ return list(result.scalars().all())
+
+ async def get_command_config(
+ self,
+ handler_full_name: str,
+ ) -> CommandConfig | None:
+ async with self.get_db() as session:
+ session: AsyncSession
+ return await session.get(CommandConfig, handler_full_name)
+
+ async def upsert_command_config(
+ self,
+ handler_full_name: str,
+ plugin_name: str,
+ module_path: str,
+ original_command: str,
+ *,
+ resolved_command: str | None = None,
+ enabled: bool | None = None,
+ keep_original_alias: bool | None = None,
+ conflict_key: str | None = None,
+ resolution_strategy: str | None = None,
+ note: str | None = None,
+ extra_data: dict | None = None,
+ auto_managed: bool | None = None,
+ ) -> CommandConfig:
+ async def _op(session: AsyncSession) -> CommandConfig:
+ config = await session.get(CommandConfig, handler_full_name)
+ if not config:
+ config = self._new_command_config(
+ handler_full_name,
+ plugin_name,
+ module_path,
+ original_command,
+ resolved_command=resolved_command,
+ enabled=enabled,
+ keep_original_alias=keep_original_alias,
+ conflict_key=conflict_key,
+ resolution_strategy=resolution_strategy,
+ note=note,
+ extra_data=extra_data,
+ auto_managed=auto_managed,
+ )
+ session.add(config)
+ else:
+ self._apply_updates(
+ config,
+ plugin_name=plugin_name,
+ module_path=module_path,
+ original_command=original_command,
+ resolved_command=resolved_command,
+ enabled=enabled,
+ keep_original_alias=keep_original_alias,
+ conflict_key=conflict_key,
+ resolution_strategy=resolution_strategy,
+ note=note,
+ extra_data=extra_data,
+ auto_managed=auto_managed,
+ )
+ await session.flush()
+ await session.refresh(config)
+ return config
+
+ return await self._run_in_tx(_op)
+
+ async def delete_command_config(self, handler_full_name: str) -> None:
+ await self.delete_command_configs([handler_full_name])
+
+ async def delete_command_configs(self, handler_full_names: list[str]) -> None:
+ if not handler_full_names:
+ return
+
+ async def _op(session: AsyncSession) -> None:
+ await session.execute(
+ delete(CommandConfig).where(
+ col(CommandConfig.handler_full_name).in_(handler_full_names),
+ ),
+ )
+
+ await self._run_in_tx(_op)
+
+ async def list_command_conflicts(
+ self,
+ status: str | None = None,
+ ) -> list[CommandConflict]:
+ async with self.get_db() as session:
+ session: AsyncSession
+ query = select(CommandConflict)
+ if status:
+ query = query.where(CommandConflict.status == status)
+ result = await session.execute(query)
+ return list(result.scalars().all())
+
+ async def upsert_command_conflict(
+ self,
+ conflict_key: str,
+ handler_full_name: str,
+ plugin_name: str,
+ *,
+ status: str | None = None,
+ resolution: str | None = None,
+ resolved_command: str | None = None,
+ note: str | None = None,
+ extra_data: dict | None = None,
+ auto_generated: bool | None = None,
+ ) -> CommandConflict:
+ async def _op(session: AsyncSession) -> CommandConflict:
+ result = await session.execute(
+ select(CommandConflict).where(
+ CommandConflict.conflict_key == conflict_key,
+ CommandConflict.handler_full_name == handler_full_name,
+ ),
+ )
+ record = result.scalar_one_or_none()
+ if not record:
+ record = self._new_command_conflict(
+ conflict_key,
+ handler_full_name,
+ plugin_name,
+ status=status,
+ resolution=resolution,
+ resolved_command=resolved_command,
+ note=note,
+ extra_data=extra_data,
+ auto_generated=auto_generated,
+ )
+ session.add(record)
+ else:
+ self._apply_updates(
+ record,
+ plugin_name=plugin_name,
+ status=status,
+ resolution=resolution,
+ resolved_command=resolved_command,
+ note=note,
+ extra_data=extra_data,
+ auto_generated=auto_generated,
+ )
+ await session.flush()
+ await session.refresh(record)
+ return record
+
+ return await self._run_in_tx(_op)
+
+ async def delete_command_conflicts(self, ids: list[int]) -> None:
+ if not ids:
+ return
+
+ async def _op(session: AsyncSession) -> None:
+ await session.execute(
+ delete(CommandConflict).where(col(CommandConflict.id).in_(ids)),
+ )
+
+ await self._run_in_tx(_op)
+
# ====
# Deprecated Methods
# ====
@@ -709,3 +1005,101 @@ class SQLiteDatabase(BaseDatabase):
t.start()
t.join()
return result
+
+ # ====
+ # Platform Session Management
+ # ====
+
+ async def create_platform_session(
+ self,
+ creator: str,
+ platform_id: str = "webchat",
+ session_id: str | None = None,
+ display_name: str | None = None,
+ is_group: int = 0,
+ ) -> PlatformSession:
+ """Create a new Platform session."""
+ kwargs = {}
+ if session_id:
+ kwargs["session_id"] = session_id
+
+ async with self.get_db() as session:
+ session: AsyncSession
+ async with session.begin():
+ new_session = PlatformSession(
+ creator=creator,
+ platform_id=platform_id,
+ display_name=display_name,
+ is_group=is_group,
+ **kwargs,
+ )
+ session.add(new_session)
+ await session.flush()
+ await session.refresh(new_session)
+ return new_session
+
+ async def get_platform_session_by_id(
+ self, session_id: str
+ ) -> PlatformSession | None:
+ """Get a Platform session by its ID."""
+ async with self.get_db() as session:
+ session: AsyncSession
+ query = select(PlatformSession).where(
+ PlatformSession.session_id == session_id,
+ )
+ result = await session.execute(query)
+ return result.scalar_one_or_none()
+
+ async def get_platform_sessions_by_creator(
+ self,
+ creator: str,
+ platform_id: str | None = None,
+ page: int = 1,
+ page_size: int = 20,
+ ) -> list[PlatformSession]:
+ """Get all Platform sessions for a specific creator (username) and optionally platform."""
+ async with self.get_db() as session:
+ session: AsyncSession
+ offset = (page - 1) * page_size
+ query = select(PlatformSession).where(PlatformSession.creator == creator)
+
+ if platform_id:
+ query = query.where(PlatformSession.platform_id == platform_id)
+
+ query = (
+ query.order_by(desc(PlatformSession.updated_at))
+ .offset(offset)
+ .limit(page_size)
+ )
+ result = await session.execute(query)
+ return list(result.scalars().all())
+
+ async def update_platform_session(
+ self,
+ session_id: str,
+ display_name: str | None = None,
+ ) -> None:
+ """Update a Platform session's updated_at timestamp and optionally display_name."""
+ async with self.get_db() as session:
+ session: AsyncSession
+ async with session.begin():
+ values: dict[str, T.Any] = {"updated_at": datetime.now(timezone.utc)}
+ if display_name is not None:
+ values["display_name"] = display_name
+
+ await session.execute(
+ update(PlatformSession)
+ .where(col(PlatformSession.session_id) == session_id)
+ .values(**values),
+ )
+
+ async def delete_platform_session(self, session_id: str) -> None:
+ """Delete a Platform session by its ID."""
+ async with self.get_db() as session:
+ session: AsyncSession
+ async with session.begin():
+ await session.execute(
+ delete(PlatformSession).where(
+ col(PlatformSession.session_id) == session_id,
+ ),
+ )
diff --git a/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py b/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py
index 24f1c323c..564454cb1 100644
--- a/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py
+++ b/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py
@@ -90,4 +90,6 @@ class EmbeddingStorage:
path (str): 保存索引的路径
"""
+ if self.index is None:
+ return
faiss.write_index(self.index, self.path)
diff --git a/astrbot/core/event_bus.py b/astrbot/core/event_bus.py
index 749df753e..0017e65fa 100644
--- a/astrbot/core/event_bus.py
+++ b/astrbot/core/event_bus.py
@@ -27,7 +27,7 @@ class EventBus:
self,
event_queue: Queue,
pipeline_scheduler_mapping: dict[str, PipelineScheduler],
- astrbot_config_mgr: AstrBotConfigManager = None,
+ astrbot_config_mgr: AstrBotConfigManager,
):
self.event_queue = event_queue # 事件队列
# abconf uuid -> scheduler
@@ -40,6 +40,11 @@ class EventBus:
conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin)
self._print_event(event, conf_info["name"])
scheduler = self.pipeline_scheduler_mapping.get(conf_info["id"])
+ if not scheduler:
+ logger.error(
+ f"PipelineScheduler not found for id: {conf_info['id']}, event ignored."
+ )
+ continue
asyncio.create_task(scheduler.execute(event))
def _print_event(self, event: AstrMessageEvent, conf_name: str):
diff --git a/astrbot/core/exceptions.py b/astrbot/core/exceptions.py
new file mode 100644
index 000000000..e637d4930
--- /dev/null
+++ b/astrbot/core/exceptions.py
@@ -0,0 +1,9 @@
+from __future__ import annotations
+
+
+class AstrBotError(Exception):
+ """Base exception for all AstrBot errors."""
+
+
+class ProviderNotFoundError(AstrBotError):
+ """Raised when a specified provider is not found."""
diff --git a/astrbot/core/knowledge_base/kb_helper.py b/astrbot/core/knowledge_base/kb_helper.py
index b03b00369..4adfb60b8 100644
--- a/astrbot/core/knowledge_base/kb_helper.py
+++ b/astrbot/core/knowledge_base/kb_helper.py
@@ -1,4 +1,7 @@
+import asyncio
import json
+import re
+import time
import uuid
from pathlib import Path
@@ -8,12 +11,98 @@ from astrbot.core import logger
from astrbot.core.db.vec_db.base import BaseVecDB
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
from astrbot.core.provider.manager import ProviderManager
-from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider
+from astrbot.core.provider.provider import (
+ EmbeddingProvider,
+ RerankProvider,
+)
+from astrbot.core.provider.provider import (
+ Provider as LLMProvider,
+)
from .chunking.base import BaseChunker
+from .chunking.recursive import RecursiveCharacterChunker
from .kb_db_sqlite import KBSQLiteDatabase
from .models import KBDocument, KBMedia, KnowledgeBase
+from .parsers.url_parser import extract_text_from_url
from .parsers.util import select_parser
+from .prompts import TEXT_REPAIR_SYSTEM_PROMPT
+
+
+class RateLimiter:
+ """一个简单的速率限制器"""
+
+ def __init__(self, max_rpm: int):
+ self.max_per_minute = max_rpm
+ self.interval = 60.0 / max_rpm if max_rpm > 0 else 0
+ self.last_call_time = 0
+
+ async def __aenter__(self):
+ if self.interval == 0:
+ return
+
+ now = time.monotonic()
+ elapsed = now - self.last_call_time
+
+ if elapsed < self.interval:
+ await asyncio.sleep(self.interval - elapsed)
+
+ self.last_call_time = time.monotonic()
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ pass
+
+
+async def _repair_and_translate_chunk_with_retry(
+ chunk: str,
+ repair_llm_service: LLMProvider,
+ rate_limiter: RateLimiter,
+ max_retries: int = 2,
+) -> list[str]:
+ """
+ Repairs, translates, and optionally re-chunks a single text chunk using the small LLM, with rate limiting.
+ """
+ # 为了防止 LLM 上下文污染,在 user_prompt 中也加入明确的指令
+ user_prompt = f"""IGNORE ALL PREVIOUS INSTRUCTIONS. Your ONLY task is to process the following text chunk according to the system prompt provided.
+
+Text chunk to process:
+---
+{chunk}
+---
+"""
+ for attempt in range(max_retries + 1):
+ try:
+ async with rate_limiter:
+ response = await repair_llm_service.text_chat(
+ prompt=user_prompt, system_prompt=TEXT_REPAIR_SYSTEM_PROMPT
+ )
+
+ llm_output = response.completion_text
+
+ if "
" in llm_output:
+ return [] # Signal to discard this chunk
+
+ # More robust regex to handle potential LLM formatting errors (spaces, newlines in tags)
+ matches = re.findall(
+ r"<\s*repaired_text\s*>\s*(.*?)\s*<\s*/\s*repaired_text\s*>",
+ llm_output,
+ re.DOTALL,
+ )
+
+ if matches:
+ # Further cleaning to ensure no empty strings are returned
+ return [m.strip() for m in matches if m.strip()]
+ else:
+ # If no valid tags and not explicitly discarded, discard it to be safe.
+ return []
+ except Exception as e:
+ logger.warning(
+ f" - LLM call failed on attempt {attempt + 1}/{max_retries + 1}. Error: {str(e)}"
+ )
+
+ logger.error(
+ f" - Failed to process chunk after {max_retries + 1} attempts. Using original text."
+ )
+ return [chunk]
class KBHelper:
@@ -100,7 +189,7 @@ class KBHelper:
async def upload_document(
self,
file_name: str,
- file_content: bytes,
+ file_content: bytes | None,
file_type: str,
chunk_size: int = 512,
chunk_overlap: int = 50,
@@ -108,6 +197,7 @@ class KBHelper:
tasks_limit: int = 3,
max_retries: int = 3,
progress_callback=None,
+ pre_chunked_text: list[str] | None = None,
) -> KBDocument:
"""上传并处理文档(带原子性保证和失败清理)
@@ -130,46 +220,63 @@ class KBHelper:
await self._ensure_vec_db()
doc_id = str(uuid.uuid4())
media_paths: list[Path] = []
+ file_size = 0
# file_path = self.kb_files_dir / f"{doc_id}.{file_type}"
# async with aiofiles.open(file_path, "wb") as f:
# await f.write(file_content)
try:
- # 阶段1: 解析文档
- if progress_callback:
- await progress_callback("parsing", 0, 100)
-
- parser = await select_parser(f".{file_type}")
- parse_result = await parser.parse(file_content, file_name)
- text_content = parse_result.text
- media_items = parse_result.media
-
- if progress_callback:
- await progress_callback("parsing", 100, 100)
-
- # 保存媒体文件
+ chunks_text = []
saved_media = []
- for media_item in media_items:
- media = await self._save_media(
- doc_id=doc_id,
- media_type=media_item.media_type,
- file_name=media_item.file_name,
- content=media_item.content,
- mime_type=media_item.mime_type,
+
+ if pre_chunked_text is not None:
+ # 如果提供了预分块文本,直接使用
+ chunks_text = pre_chunked_text
+ file_size = sum(len(chunk) for chunk in chunks_text)
+ logger.info(f"使用预分块文本进行上传,共 {len(chunks_text)} 个块。")
+ else:
+ # 否则,执行标准的文件解析和分块流程
+ if file_content is None:
+ raise ValueError(
+ "当未提供 pre_chunked_text 时,file_content 不能为空。"
+ )
+
+ file_size = len(file_content)
+
+ # 阶段1: 解析文档
+ if progress_callback:
+ await progress_callback("parsing", 0, 100)
+
+ parser = await select_parser(f".{file_type}")
+ parse_result = await parser.parse(file_content, file_name)
+ text_content = parse_result.text
+ media_items = parse_result.media
+
+ if progress_callback:
+ await progress_callback("parsing", 100, 100)
+
+ # 保存媒体文件
+ for media_item in media_items:
+ media = await self._save_media(
+ doc_id=doc_id,
+ media_type=media_item.media_type,
+ file_name=media_item.file_name,
+ content=media_item.content,
+ mime_type=media_item.mime_type,
+ )
+ saved_media.append(media)
+ media_paths.append(Path(media.file_path))
+
+ # 阶段2: 分块
+ if progress_callback:
+ await progress_callback("chunking", 0, 100)
+
+ chunks_text = await self.chunker.chunk(
+ text_content,
+ chunk_size=chunk_size,
+ chunk_overlap=chunk_overlap,
)
- saved_media.append(media)
- media_paths.append(Path(media.file_path))
-
- # 阶段2: 分块
- if progress_callback:
- await progress_callback("chunking", 0, 100)
-
- chunks_text = await self.chunker.chunk(
- text_content,
- chunk_size=chunk_size,
- chunk_overlap=chunk_overlap,
- )
contents = []
metadatas = []
for idx, chunk_text in enumerate(chunks_text):
@@ -205,7 +312,7 @@ class KBHelper:
kb_id=self.kb.kb_id,
doc_name=file_name,
file_type=file_type,
- file_size=len(file_content),
+ file_size=file_size,
# file_path=str(file_path),
file_path="",
chunk_count=len(chunks_text),
@@ -359,3 +466,177 @@ class KBHelper:
)
return media
+
+ async def upload_from_url(
+ self,
+ url: str,
+ chunk_size: int = 512,
+ chunk_overlap: int = 50,
+ batch_size: int = 32,
+ tasks_limit: int = 3,
+ max_retries: int = 3,
+ progress_callback=None,
+ enable_cleaning: bool = False,
+ cleaning_provider_id: str | None = None,
+ ) -> KBDocument:
+ """从 URL 上传并处理文档(带原子性保证和失败清理)
+ Args:
+ url: 要提取内容的网页 URL
+ chunk_size: 文本块大小
+ chunk_overlap: 文本块重叠大小
+ batch_size: 批处理大小
+ tasks_limit: 并发任务限制
+ max_retries: 最大重试次数
+ progress_callback: 进度回调函数,接收参数 (stage, current, total)
+ - stage: 当前阶段 ('extracting', 'cleaning', 'parsing', 'chunking', 'embedding')
+ - current: 当前进度
+ - total: 总数
+ Returns:
+ KBDocument: 上传的文档对象
+ Raises:
+ ValueError: 如果 URL 为空或无法提取内容
+ IOError: 如果网络请求失败
+ """
+ # 获取 Tavily API 密钥
+ config = self.prov_mgr.acm.default_conf
+ tavily_keys = config.get("provider_settings", {}).get(
+ "websearch_tavily_key", []
+ )
+ if not tavily_keys:
+ raise ValueError(
+ "Error: Tavily API key is not configured in provider_settings."
+ )
+
+ # 阶段1: 从 URL 提取内容
+ if progress_callback:
+ await progress_callback("extracting", 0, 100)
+
+ try:
+ text_content = await extract_text_from_url(url, tavily_keys)
+ except Exception as e:
+ logger.error(f"Failed to extract content from URL {url}: {e}")
+ raise OSError(f"Failed to extract content from URL {url}: {e}") from e
+
+ if not text_content:
+ raise ValueError(f"No content extracted from URL: {url}")
+
+ if progress_callback:
+ await progress_callback("extracting", 100, 100)
+
+ # 阶段2: (可选)清洗内容并分块
+ final_chunks = await self._clean_and_rechunk_content(
+ content=text_content,
+ url=url,
+ progress_callback=progress_callback,
+ enable_cleaning=enable_cleaning,
+ cleaning_provider_id=cleaning_provider_id,
+ chunk_size=chunk_size,
+ chunk_overlap=chunk_overlap,
+ )
+
+ if enable_cleaning and not final_chunks:
+ raise ValueError(
+ "内容清洗后未提取到有效文本。请尝试关闭内容清洗功能,或更换更高性能的LLM模型后重试。"
+ )
+
+ # 创建一个虚拟文件名
+ file_name = url.split("/")[-1] or f"document_from_{url}"
+ if not Path(file_name).suffix:
+ file_name += ".url"
+
+ # 复用现有的 upload_document 方法,但传入预分块文本
+ return await self.upload_document(
+ file_name=file_name,
+ file_content=None,
+ file_type="url", # 使用 'url' 作为特殊文件类型
+ chunk_size=chunk_size,
+ chunk_overlap=chunk_overlap,
+ batch_size=batch_size,
+ tasks_limit=tasks_limit,
+ max_retries=max_retries,
+ progress_callback=progress_callback,
+ pre_chunked_text=final_chunks,
+ )
+
+ async def _clean_and_rechunk_content(
+ self,
+ content: str,
+ url: str,
+ progress_callback=None,
+ enable_cleaning: bool = False,
+ cleaning_provider_id: str | None = None,
+ repair_max_rpm: int = 60,
+ chunk_size: int = 512,
+ chunk_overlap: int = 50,
+ ) -> list[str]:
+ """
+ 对从 URL 获取的内容进行清洗、修复、翻译和重新分块。
+ """
+ if not enable_cleaning:
+ # 如果不启用清洗,则使用从前端传递的参数进行分块
+ logger.info(
+ f"内容清洗未启用,使用指定参数进行分块: chunk_size={chunk_size}, chunk_overlap={chunk_overlap}"
+ )
+ return await self.chunker.chunk(
+ content, chunk_size=chunk_size, chunk_overlap=chunk_overlap
+ )
+
+ if not cleaning_provider_id:
+ logger.warning(
+ "启用了内容清洗,但未提供 cleaning_provider_id,跳过清洗并使用默认分块。"
+ )
+ return await self.chunker.chunk(content)
+
+ if progress_callback:
+ await progress_callback("cleaning", 0, 100)
+
+ try:
+ # 获取指定的 LLM Provider
+ llm_provider = await self.prov_mgr.get_provider_by_id(cleaning_provider_id)
+ if not llm_provider or not isinstance(llm_provider, LLMProvider):
+ raise ValueError(
+ f"无法找到 ID 为 {cleaning_provider_id} 的 LLM Provider 或类型不正确"
+ )
+
+ # 初步分块
+ # 优化分隔符,优先按段落分割,以获得更高质量的文本块
+ text_splitter = RecursiveCharacterChunker(
+ chunk_size=chunk_size,
+ chunk_overlap=chunk_overlap,
+ separators=["\n\n", "\n", " "], # 优先使用段落分隔符
+ )
+ initial_chunks = await text_splitter.chunk(content)
+ logger.info(f"初步分块完成,生成 {len(initial_chunks)} 个块用于修复。")
+
+ # 并发处理所有块
+ rate_limiter = RateLimiter(repair_max_rpm)
+ tasks = [
+ _repair_and_translate_chunk_with_retry(
+ chunk, llm_provider, rate_limiter
+ )
+ for chunk in initial_chunks
+ ]
+
+ repaired_results = await asyncio.gather(*tasks, return_exceptions=True)
+
+ final_chunks = []
+ for i, result in enumerate(repaired_results):
+ if isinstance(result, Exception):
+ logger.warning(f"块 {i} 处理异常: {str(result)}. 回退到原始块。")
+ final_chunks.append(initial_chunks[i])
+ elif isinstance(result, list):
+ final_chunks.extend(result)
+
+ logger.info(
+ f"文本修复完成: {len(initial_chunks)} 个原始块 -> {len(final_chunks)} 个最终块。"
+ )
+
+ if progress_callback:
+ await progress_callback("cleaning", 100, 100)
+
+ return final_chunks
+
+ except Exception as e:
+ logger.error(f"使用 Provider '{cleaning_provider_id}' 清洗内容失败: {e}")
+ # 清洗失败,返回默认分块结果,保证流程不中断
+ return await self.chunker.chunk(content)
diff --git a/astrbot/core/knowledge_base/kb_mgr.py b/astrbot/core/knowledge_base/kb_mgr.py
index f7e07fe15..2219cc00b 100644
--- a/astrbot/core/knowledge_base/kb_mgr.py
+++ b/astrbot/core/knowledge_base/kb_mgr.py
@@ -8,7 +8,7 @@ from astrbot.core.provider.manager import ProviderManager
from .chunking.recursive import RecursiveCharacterChunker
from .kb_db_sqlite import KBSQLiteDatabase
from .kb_helper import KBHelper
-from .models import KnowledgeBase
+from .models import KBDocument, KnowledgeBase
from .retrieval.manager import RetrievalManager, RetrievalResult
from .retrieval.rank_fusion import RankFusion
from .retrieval.sparse_retriever import SparseRetriever
@@ -284,3 +284,47 @@ class KnowledgeBaseManager:
await self.kb_db.close()
except Exception as e:
logger.error(f"关闭知识库元数据数据库失败: {e}")
+
+ async def upload_from_url(
+ self,
+ kb_id: str,
+ url: str,
+ chunk_size: int = 512,
+ chunk_overlap: int = 50,
+ batch_size: int = 32,
+ tasks_limit: int = 3,
+ max_retries: int = 3,
+ progress_callback=None,
+ ) -> KBDocument:
+ """从 URL 上传文档到指定的知识库
+
+ Args:
+ kb_id: 知识库 ID
+ url: 要提取内容的网页 URL
+ chunk_size: 文本块大小
+ chunk_overlap: 文本块重叠大小
+ batch_size: 批处理大小
+ tasks_limit: 并发任务限制
+ max_retries: 最大重试次数
+ progress_callback: 进度回调函数
+
+ Returns:
+ KBDocument: 上传的文档对象
+
+ Raises:
+ ValueError: 如果知识库不存在或 URL 为空
+ IOError: 如果网络请求失败
+ """
+ kb_helper = await self.get_kb(kb_id)
+ if not kb_helper:
+ raise ValueError(f"Knowledge base with id {kb_id} not found.")
+
+ return await kb_helper.upload_from_url(
+ url=url,
+ chunk_size=chunk_size,
+ chunk_overlap=chunk_overlap,
+ batch_size=batch_size,
+ tasks_limit=tasks_limit,
+ max_retries=max_retries,
+ progress_callback=progress_callback,
+ )
diff --git a/astrbot/core/knowledge_base/parsers/url_parser.py b/astrbot/core/knowledge_base/parsers/url_parser.py
new file mode 100644
index 000000000..f68e2e0c4
--- /dev/null
+++ b/astrbot/core/knowledge_base/parsers/url_parser.py
@@ -0,0 +1,103 @@
+import asyncio
+
+import aiohttp
+
+
+class URLExtractor:
+ """URL 内容提取器,封装了 Tavily API 调用和密钥管理"""
+
+ def __init__(self, tavily_keys: list[str]):
+ """
+ 初始化 URL 提取器
+
+ Args:
+ tavily_keys: Tavily API 密钥列表
+ """
+ if not tavily_keys:
+ raise ValueError("Error: Tavily API keys are not configured.")
+
+ self.tavily_keys = tavily_keys
+ self.tavily_key_index = 0
+ self.tavily_key_lock = asyncio.Lock()
+
+ async def _get_tavily_key(self) -> str:
+ """并发安全的从列表中获取并轮换Tavily API密钥。"""
+ async with self.tavily_key_lock:
+ key = self.tavily_keys[self.tavily_key_index]
+ self.tavily_key_index = (self.tavily_key_index + 1) % len(self.tavily_keys)
+ return key
+
+ async def extract_text_from_url(self, url: str) -> str:
+ """
+ 使用 Tavily API 从 URL 提取主要文本内容。
+ 这是 web_searcher 插件中 tavily_extract_web_page 方法的简化版本,
+ 专门为知识库模块设计,不依赖 AstrMessageEvent。
+
+ Args:
+ url: 要提取内容的网页 URL
+
+ Returns:
+ 提取的文本内容
+
+ Raises:
+ ValueError: 如果 URL 为空或 API 密钥未配置
+ IOError: 如果请求失败或返回错误
+ """
+ if not url:
+ raise ValueError("Error: url must be a non-empty string.")
+
+ tavily_key = await self._get_tavily_key()
+ api_url = "https://api.tavily.com/extract"
+ headers = {
+ "Authorization": f"Bearer {tavily_key}",
+ "Content-Type": "application/json",
+ }
+
+ payload = {
+ "urls": [url],
+ "extract_depth": "basic", # 使用基础提取深度
+ }
+
+ try:
+ async with aiohttp.ClientSession(trust_env=True) as session:
+ async with session.post(
+ api_url,
+ json=payload,
+ headers=headers,
+ timeout=30.0, # 增加超时时间,因为内容提取可能需要更长时间
+ ) as response:
+ if response.status != 200:
+ reason = await response.text()
+ raise OSError(
+ f"Tavily web extraction failed: {reason}, status: {response.status}"
+ )
+
+ data = await response.json()
+ results = data.get("results", [])
+
+ if not results:
+ raise ValueError(f"No content extracted from URL: {url}")
+
+ # 返回第一个结果的内容
+ return results[0].get("raw_content", "")
+
+ except aiohttp.ClientError as e:
+ raise OSError(f"Failed to fetch URL {url}: {e}") from e
+ except Exception as e:
+ raise OSError(f"Failed to extract content from URL {url}: {e}") from e
+
+
+# 为了向后兼容,提供一个简单的函数接口
+async def extract_text_from_url(url: str, tavily_keys: list[str]) -> str:
+ """
+ 简单的函数接口,用于从 URL 提取文本内容
+
+ Args:
+ url: 要提取内容的网页 URL
+ tavily_keys: Tavily API 密钥列表
+
+ Returns:
+ 提取的文本内容
+ """
+ extractor = URLExtractor(tavily_keys)
+ return await extractor.extract_text_from_url(url)
diff --git a/astrbot/core/knowledge_base/prompts.py b/astrbot/core/knowledge_base/prompts.py
new file mode 100644
index 000000000..7874fa5f6
--- /dev/null
+++ b/astrbot/core/knowledge_base/prompts.py
@@ -0,0 +1,65 @@
+TEXT_REPAIR_SYSTEM_PROMPT = """You are a meticulous digital archivist. Your mission is to reconstruct a clean, readable article from raw, noisy text chunks.
+
+**Core Task:**
+1. **Analyze:** Examine the text chunk to separate "signal" (substantive information) from "noise" (UI elements, ads, navigation, footers).
+2. **Process:** Clean and repair the signal. **Do not translate it.** Keep the original language.
+
+**Crucial Rules:**
+- **NEVER discard a chunk if it contains ANY valuable information.** Your primary duty is to salvage content.
+- **If a chunk contains multiple distinct topics, split them.** Enclose each topic in its own `
` tag.
+- Your output MUST be ONLY `... ` tags or a single ` ` tag.
+
+---
+**Example 1: Chunk with Noise and Signal**
+
+*Input Chunk:*
+"Home | About | Products | **The Llama is a domesticated South American camelid.** | © 2025 ACME Corp."
+
+*Your Thought Process:*
+1. "Home | About | Products..." and "© 2025 ACME Corp." are noise.
+2. "The Llama is a domesticated..." is the signal.
+3. I must extract the signal and wrap it.
+
+*Your Output:*
+
+The Llama is a domesticated South American camelid.
+
+
+---
+**Example 2: Chunk with ONLY Noise**
+
+*Input Chunk:*
+"Next Page > | Subscribe to our newsletter | Follow us on X"
+
+*Your Thought Process:*
+1. This entire chunk is noise. There is no signal.
+2. I must discard this.
+
+*Your Output:*
+
+
+---
+**Example 3: Chunk with Multiple Topics (Requires Splitting)**
+
+*Input Chunk:*
+"## Chapter 1: The Sun
+The Sun is the star at the center of the Solar System.
+
+## Chapter 2: The Moon
+The Moon is Earth's only natural satellite."
+
+*Your Thought Process:*
+1. This chunk contains two distinct topics.
+2. I must process them separately to maintain semantic integrity.
+3. I will create two `` blocks.
+
+*Your Output:*
+
+## Chapter 1: The Sun
+The Sun is the star at the center of the Solar System.
+
+
+## Chapter 2: The Moon
+The Moon is Earth's only natural satellite.
+
+"""
diff --git a/astrbot/core/knowledge_base/retrieval/manager.py b/astrbot/core/knowledge_base/retrieval/manager.py
index 9a42cd6cd..746406e90 100644
--- a/astrbot/core/knowledge_base/retrieval/manager.py
+++ b/astrbot/core/knowledge_base/retrieval/manager.py
@@ -166,7 +166,11 @@ class RetrievalManager:
# 5. Rerank
first_rerank = None
for kb_id in kb_ids:
- vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
+ vec_db = kb_options[kb_id]["vec_db"]
+ if not isinstance(vec_db, FaissVecDB):
+ logger.warning(f"vec_db for kb_id {kb_id} is not FaissVecDB")
+ continue
+
rerank_pi = kb_options[kb_id]["rerank_provider_id"]
if (
vec_db
diff --git a/astrbot/core/log.py b/astrbot/core/log.py
index 376f5ffd6..806ebcebb 100644
--- a/astrbot/core/log.py
+++ b/astrbot/core/log.py
@@ -24,6 +24,7 @@ import asyncio
import logging
import os
import sys
+import time
from asyncio import Queue
from collections import deque
@@ -148,7 +149,7 @@ class LogQueueHandler(logging.Handler):
self.log_broker.publish(
{
"level": record.levelname,
- "time": record.asctime,
+ "time": time.time(),
"data": log_entry,
},
)
diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py
index 43e3bf0e3..050e36521 100644
--- a/astrbot/core/message/components.py
+++ b/astrbot/core/message/components.py
@@ -66,6 +66,9 @@ class ComponentType(str, Enum):
class BaseMessageComponent(BaseModel):
type: ComponentType
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
def toDict(self):
data = {}
for k, v in self.__dict__.items():
@@ -551,7 +554,7 @@ class Node(BaseMessageComponent):
id: int | None = 0 # 忽略
name: str | None = "" # qq昵称
uin: str | None = "0" # qq号
- content: list[BaseMessageComponent] | None = []
+ content: list[BaseMessageComponent] = []
seq: str | list | None = "" # 忽略
time: int | None = 0 # 忽略
@@ -615,7 +618,7 @@ class Nodes(BaseMessageComponent):
ret["messages"].append(d)
return ret
- async def to_dict(self):
+ async def to_dict(self) -> dict:
"""将 Nodes 转换为字典格式,适用于 OneBot JSON 格式"""
ret = {"messages": []}
for node in self.nodes:
@@ -626,12 +629,11 @@ class Nodes(BaseMessageComponent):
class Json(BaseMessageComponent):
type = ComponentType.Json
- data: str | dict
- resid: int | None = 0
+ data: dict
- def __init__(self, data, **_):
- if isinstance(data, dict):
- data = json.dumps(data)
+ def __init__(self, data: str | dict, **_):
+ if isinstance(data, str):
+ data = json.loads(data)
super().__init__(data=data, **_)
@@ -714,15 +716,23 @@ class File(BaseMessageComponent):
if self.url:
await self._download_file()
- return os.path.abspath(self.file_)
+ if self.file_:
+ return os.path.abspath(self.file_)
return ""
async def _download_file(self):
"""下载文件"""
+ if not self.url:
+ raise ValueError("Download failed: No URL provided in File component.")
download_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(download_dir, exist_ok=True)
- file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
+ if self.name:
+ name, ext = os.path.splitext(self.name)
+ filename = f"{name}_{uuid.uuid4().hex[:8]}{ext}"
+ else:
+ filename = f"{uuid.uuid4().hex}"
+ file_path = os.path.join(download_dir, filename)
await download_file(self.url, file_path)
self.file_ = os.path.abspath(file_path)
diff --git a/astrbot/core/persona_mgr.py b/astrbot/core/persona_mgr.py
index 5d1743ab9..b2d2c6be1 100644
--- a/astrbot/core/persona_mgr.py
+++ b/astrbot/core/persona_mgr.py
@@ -98,8 +98,8 @@ class PersonaManager:
self,
persona_id: str,
system_prompt: str,
- begin_dialogs: list[str] = None,
- tools: list[str] = None,
+ begin_dialogs: list[str] | None = None,
+ tools: list[str] | None = None,
) -> Persona:
"""创建新的 persona。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
if await self.db.get_persona_by_id(persona_id):
diff --git a/astrbot/core/pipeline/content_safety_check/stage.py b/astrbot/core/pipeline/content_safety_check/stage.py
index c477cc23a..b089c48e0 100644
--- a/astrbot/core/pipeline/content_safety_check/stage.py
+++ b/astrbot/core/pipeline/content_safety_check/stage.py
@@ -24,7 +24,7 @@ class ContentSafetyCheckStage(Stage):
self,
event: AstrMessageEvent,
check_text: str | None = None,
- ) -> None | AsyncGenerator[None, None]:
+ ) -> AsyncGenerator[None, None]:
"""检查内容安全"""
text = check_text if check_text else event.get_message_str()
ok, info = self.strategy_selector.check(text)
diff --git a/astrbot/core/pipeline/context.py b/astrbot/core/pipeline/context.py
index 44186764e..a6cd567e0 100644
--- a/astrbot/core/pipeline/context.py
+++ b/astrbot/core/pipeline/context.py
@@ -3,7 +3,7 @@ from dataclasses import dataclass
from astrbot.core.config import AstrBotConfig
from astrbot.core.star import PluginManager
-from .context_utils import call_event_hook, call_handler, call_local_llm_tool
+from .context_utils import call_event_hook, call_handler
@dataclass
@@ -15,4 +15,3 @@ class PipelineContext:
astrbot_config_id: str
call_handler = call_handler
call_event_hook = call_event_hook
- call_local_llm_tool = call_local_llm_tool
diff --git a/astrbot/core/pipeline/context_utils.py b/astrbot/core/pipeline/context_utils.py
index 371816b6e..1f5ba43a0 100644
--- a/astrbot/core/pipeline/context_utils.py
+++ b/astrbot/core/pipeline/context_utils.py
@@ -3,8 +3,6 @@ import traceback
import typing as T
from astrbot import logger
-from astrbot.core.agent.run_context import ContextWrapper
-from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.message.message_event_result import CommandResult, MessageEventResult
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.star import star_map
@@ -13,7 +11,7 @@ from astrbot.core.star.star_handler import EventType, star_handlers_registry
async def call_handler(
event: AstrMessageEvent,
- handler: T.Callable[..., T.Awaitable[T.Any]],
+ handler: T.Callable[..., T.Awaitable[T.Any] | T.AsyncGenerator[T.Any, None]],
*args,
**kwargs,
) -> T.AsyncGenerator[T.Any, None]:
@@ -93,6 +91,7 @@ async def call_event_hook(
)
for handler in handlers:
try:
+ assert inspect.iscoroutinefunction(handler.handler)
logger.debug(
f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}",
)
@@ -107,66 +106,3 @@ async def call_event_hook(
return True
return event.is_stopped()
-
-
-async def call_local_llm_tool(
- context: ContextWrapper[AstrAgentContext],
- handler: T.Callable[..., T.Awaitable[T.Any]],
- method_name: str,
- *args,
- **kwargs,
-) -> T.AsyncGenerator[T.Any, None]:
- """执行本地 LLM 工具的处理函数并处理其返回结果"""
- ready_to_call = None # 一个协程或者异步生成器
-
- trace_ = None
-
- event = context.context.event
-
- try:
- if method_name == "run" or method_name == "decorator_handler":
- ready_to_call = handler(event, *args, **kwargs)
- elif method_name == "call":
- ready_to_call = handler(context, *args, **kwargs)
- else:
- raise ValueError(f"未知的方法名: {method_name}")
- except ValueError as e:
- logger.error(f"调用本地 LLM 工具时出错: {e}", exc_info=True)
- except TypeError:
- logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
- except Exception as e:
- trace_ = traceback.format_exc()
- logger.error(f"调用本地 LLM 工具时出错: {e}\n{trace_}")
-
- if not ready_to_call:
- return
-
- if inspect.isasyncgen(ready_to_call):
- _has_yielded = False
- try:
- async for ret in ready_to_call:
- # 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
- # 返回值只能是 MessageEventResult 或者 None(无返回值)
- _has_yielded = True
- if isinstance(ret, (MessageEventResult, CommandResult)):
- # 如果返回值是 MessageEventResult, 设置结果并继续
- event.set_result(ret)
- yield
- else:
- # 如果返回值是 None, 则不设置结果并继续
- # 继续执行后续阶段
- yield ret
- if not _has_yielded:
- # 如果这个异步生成器没有执行到 yield 分支
- yield
- except Exception as e:
- logger.error(f"Previous Error: {trace_}")
- raise e
- elif inspect.iscoroutine(ready_to_call):
- # 如果只是一个协程, 直接执行
- ret = await ready_to_call
- if isinstance(ret, (MessageEventResult, CommandResult)):
- event.set_result(ret)
- yield
- else:
- yield ret
diff --git a/astrbot/core/pipeline/process_stage/method/agent_request.py b/astrbot/core/pipeline/process_stage/method/agent_request.py
new file mode 100644
index 000000000..f6f81631e
--- /dev/null
+++ b/astrbot/core/pipeline/process_stage/method/agent_request.py
@@ -0,0 +1,48 @@
+from collections.abc import AsyncGenerator
+
+from astrbot.core import logger
+from astrbot.core.platform.astr_message_event import AstrMessageEvent
+from astrbot.core.star.session_llm_manager import SessionServiceManager
+
+from ...context import PipelineContext
+from ..stage import Stage
+from .agent_sub_stages.internal import InternalAgentSubStage
+from .agent_sub_stages.third_party import ThirdPartyAgentSubStage
+
+
+class AgentRequestSubStage(Stage):
+ async def initialize(self, ctx: PipelineContext) -> None:
+ self.ctx = ctx
+ self.config = ctx.astrbot_config
+
+ self.bot_wake_prefixs: list[str] = self.config["wake_prefix"]
+ self.prov_wake_prefix: str = self.config["provider_settings"]["wake_prefix"]
+ for bwp in self.bot_wake_prefixs:
+ if self.prov_wake_prefix.startswith(bwp):
+ logger.info(
+ f"识别 LLM 聊天额外唤醒前缀 {self.prov_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。",
+ )
+ self.prov_wake_prefix = self.prov_wake_prefix[len(bwp) :]
+
+ agent_runner_type = self.config["provider_settings"]["agent_runner_type"]
+ if agent_runner_type == "local":
+ self.agent_sub_stage = InternalAgentSubStage()
+ else:
+ self.agent_sub_stage = ThirdPartyAgentSubStage()
+ await self.agent_sub_stage.initialize(ctx)
+
+ async def process(self, event: AstrMessageEvent) -> AsyncGenerator[None, None]:
+ if not self.ctx.astrbot_config["provider_settings"]["enable"]:
+ logger.debug(
+ "This pipeline does not enable AI capability, skip processing."
+ )
+ return
+
+ if not SessionServiceManager.should_process_llm_request(event):
+ logger.debug(
+ f"The session {event.unified_msg_origin} has disabled AI capability, skipping processing."
+ )
+ return
+
+ async for resp in self.agent_sub_stage.process(event, self.prov_wake_prefix):
+ yield resp
diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py
new file mode 100644
index 000000000..7e3305f55
--- /dev/null
+++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py
@@ -0,0 +1,523 @@
+"""本地 Agent 模式的 LLM 调用 Stage"""
+
+import asyncio
+import copy
+import json
+from collections.abc import AsyncGenerator
+
+from astrbot.core import logger
+from astrbot.core.agent.tool import ToolSet
+from astrbot.core.astr_agent_context import AstrAgentContext
+from astrbot.core.conversation_mgr import Conversation
+from astrbot.core.message.components import File, Image, Reply
+from astrbot.core.message.message_event_result import (
+ MessageChain,
+ MessageEventResult,
+ ResultContentType,
+)
+from astrbot.core.platform.astr_message_event import AstrMessageEvent
+from astrbot.core.provider import Provider
+from astrbot.core.provider.entities import (
+ LLMResponse,
+ ProviderRequest,
+)
+from astrbot.core.star.star_handler import EventType, star_map
+from astrbot.core.utils.file_extract import extract_file_moonshotai
+from astrbot.core.utils.metrics import Metric
+from astrbot.core.utils.session_lock import session_lock_manager
+
+from .....astr_agent_context import AgentContextWrapper
+from .....astr_agent_hooks import MAIN_AGENT_HOOKS
+from .....astr_agent_run_util import AgentRunner, run_agent
+from .....astr_agent_tool_exec import FunctionToolExecutor
+from ....context import PipelineContext, call_event_hook
+from ...stage import Stage
+from ...utils import KNOWLEDGE_BASE_QUERY_TOOL, retrieve_knowledge_base
+
+
+class InternalAgentSubStage(Stage):
+ async def initialize(self, ctx: PipelineContext) -> None:
+ self.ctx = ctx
+ conf = ctx.astrbot_config
+ settings = conf["provider_settings"]
+ self.max_context_length = settings["max_context_length"] # int
+ self.dequeue_context_length: int = min(
+ max(1, settings["dequeue_context_length"]),
+ self.max_context_length - 1,
+ )
+ self.streaming_response: bool = settings["streaming_response"]
+ self.unsupported_streaming_strategy: str = settings[
+ "unsupported_streaming_strategy"
+ ]
+ self.max_step: int = settings.get("max_agent_step", 30)
+ self.tool_call_timeout: int = settings.get("tool_call_timeout", 60)
+ if isinstance(self.max_step, bool): # workaround: #2622
+ self.max_step = 30
+ self.show_tool_use: bool = settings.get("show_tool_use_status", True)
+ self.show_reasoning = settings.get("display_reasoning_text", False)
+ self.kb_agentic_mode: bool = conf.get("kb_agentic_mode", False)
+
+ file_extract_conf: dict = settings.get("file_extract", {})
+ self.file_extract_enabled: bool = file_extract_conf.get("enable", False)
+ self.file_extract_prov: str = file_extract_conf.get("provider", "moonshotai")
+ self.file_extract_msh_api_key: str = file_extract_conf.get(
+ "moonshotai_api_key", ""
+ )
+
+ self.conv_manager = ctx.plugin_manager.context.conversation_manager
+
+ def _select_provider(self, event: AstrMessageEvent):
+ """选择使用的 LLM 提供商"""
+ sel_provider = event.get_extra("selected_provider")
+ _ctx = self.ctx.plugin_manager.context
+ if sel_provider and isinstance(sel_provider, str):
+ provider = _ctx.get_provider_by_id(sel_provider)
+ if not provider:
+ logger.error(f"未找到指定的提供商: {sel_provider}。")
+ return provider
+
+ return _ctx.get_using_provider(umo=event.unified_msg_origin)
+
+ async def _get_session_conv(self, event: AstrMessageEvent) -> Conversation:
+ umo = event.unified_msg_origin
+ conv_mgr = self.conv_manager
+
+ # 获取对话上下文
+ cid = await conv_mgr.get_curr_conversation_id(umo)
+ if not cid:
+ cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
+ conversation = await conv_mgr.get_conversation(umo, cid)
+ if not conversation:
+ cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
+ conversation = await conv_mgr.get_conversation(umo, cid)
+ if not conversation:
+ raise RuntimeError("无法创建新的对话。")
+ return conversation
+
+ async def _apply_kb(
+ self,
+ event: AstrMessageEvent,
+ req: ProviderRequest,
+ ):
+ """Apply knowledge base context to the provider request"""
+ if not self.kb_agentic_mode:
+ if req.prompt is None:
+ return
+ try:
+ kb_result = await retrieve_knowledge_base(
+ query=req.prompt,
+ umo=event.unified_msg_origin,
+ context=self.ctx.plugin_manager.context,
+ )
+ if not kb_result:
+ return
+ if req.system_prompt is not None:
+ req.system_prompt += (
+ f"\n\n[Related Knowledge Base Results]:\n{kb_result}"
+ )
+ except Exception as e:
+ logger.error(f"Error occurred while retrieving knowledge base: {e}")
+ else:
+ if req.func_tool is None:
+ req.func_tool = ToolSet()
+ req.func_tool.add_tool(KNOWLEDGE_BASE_QUERY_TOOL)
+
+ async def _apply_file_extract(
+ self,
+ event: AstrMessageEvent,
+ req: ProviderRequest,
+ ):
+ """Apply file extract to the provider request"""
+ file_paths = []
+ file_names = []
+ for comp in event.message_obj.message:
+ if isinstance(comp, File):
+ file_paths.append(await comp.get_file())
+ file_names.append(comp.name)
+ elif isinstance(comp, Reply) and comp.chain:
+ for reply_comp in comp.chain:
+ if isinstance(reply_comp, File):
+ file_paths.append(await reply_comp.get_file())
+ file_names.append(reply_comp.name)
+ if not file_paths:
+ return
+ if not req.prompt:
+ req.prompt = "总结一下文件里面讲了什么?"
+ if self.file_extract_prov == "moonshotai":
+ if not self.file_extract_msh_api_key:
+ logger.error("Moonshot AI API key for file extract is not set")
+ return
+ file_contents = await asyncio.gather(
+ *[
+ extract_file_moonshotai(file_path, self.file_extract_msh_api_key)
+ for file_path in file_paths
+ ]
+ )
+ else:
+ logger.error(f"Unsupported file extract provider: {self.file_extract_prov}")
+ return
+
+ # add file extract results to contexts
+ for file_content, file_name in zip(file_contents, file_names):
+ req.contexts.append(
+ {
+ "role": "system",
+ "content": f"File Extract Results of user uploaded files:\n{file_content}\nFile Name: {file_name or 'Unknown'}",
+ },
+ )
+
+ def _truncate_contexts(
+ self,
+ contexts: list[dict],
+ ) -> list[dict]:
+ """截断上下文列表,确保不超过最大长度"""
+ if self.max_context_length == -1:
+ return contexts
+
+ if len(contexts) // 2 <= self.max_context_length:
+ return contexts
+
+ truncated_contexts = contexts[
+ -(self.max_context_length - self.dequeue_context_length + 1) * 2 :
+ ]
+ # 找到第一个role 为 user 的索引,确保上下文格式正确
+ index = next(
+ (
+ i
+ for i, item in enumerate(truncated_contexts)
+ if item.get("role") == "user"
+ ),
+ None,
+ )
+ if index is not None and index > 0:
+ truncated_contexts = truncated_contexts[index:]
+
+ return truncated_contexts
+
+ def _modalities_fix(
+ self,
+ provider: Provider,
+ req: ProviderRequest,
+ ):
+ """检查提供商的模态能力,清理请求中的不支持内容"""
+ if req.image_urls:
+ provider_cfg = provider.provider_config.get("modalities", ["image"])
+ if "image" not in provider_cfg:
+ logger.debug(f"用户设置提供商 {provider} 不支持图像,清空图像列表。")
+ req.image_urls = []
+ if req.func_tool:
+ provider_cfg = provider.provider_config.get("modalities", ["tool_use"])
+ # 如果模型不支持工具使用,但请求中包含工具列表,则清空。
+ if "tool_use" not in provider_cfg:
+ logger.debug(
+ f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。",
+ )
+ req.func_tool = None
+
+ def _plugin_tool_fix(
+ self,
+ event: AstrMessageEvent,
+ req: ProviderRequest,
+ ):
+ """根据事件中的插件设置,过滤请求中的工具列表"""
+ if event.plugins_name is not None and req.func_tool:
+ new_tool_set = ToolSet()
+ for tool in req.func_tool.tools:
+ mp = tool.handler_module_path
+ if not mp:
+ continue
+ plugin = star_map.get(mp)
+ if not plugin:
+ continue
+ if plugin.name in event.plugins_name or plugin.reserved:
+ new_tool_set.add_tool(tool)
+ req.func_tool = new_tool_set
+
+ async def _handle_webchat(
+ self,
+ event: AstrMessageEvent,
+ req: ProviderRequest,
+ prov: Provider,
+ ):
+ """处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
+ if not req.conversation:
+ return
+ conversation = await self.conv_manager.get_conversation(
+ event.unified_msg_origin,
+ req.conversation.cid,
+ )
+ if conversation and not req.conversation.title:
+ messages = json.loads(conversation.history)
+ latest_pair = messages[-2:]
+ if not latest_pair:
+ return
+ content = latest_pair[0].get("content", "")
+ if isinstance(content, list):
+ # 多模态
+ text_parts = []
+ for item in content:
+ if isinstance(item, dict):
+ if item.get("type") == "text":
+ text_parts.append(item.get("text", ""))
+ elif item.get("type") == "image":
+ text_parts.append("[图片]")
+ elif isinstance(item, str):
+ text_parts.append(item)
+ cleaned_text = "User: " + " ".join(text_parts).strip()
+ elif isinstance(content, str):
+ cleaned_text = "User: " + content.strip()
+ else:
+ return
+ logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}")
+ llm_resp = await prov.text_chat(
+ system_prompt="You are expert in summarizing user's query.",
+ prompt=(
+ f"Please summarize the following query of user:\n"
+ f"{cleaned_text}\n"
+ "Only output the summary within 10 words, DO NOT INCLUDE any other text."
+ "You must use the same language as the user."
+ "If you think the dialog is too short to summarize, only output a special mark: ``"
+ ),
+ )
+ if llm_resp and llm_resp.completion_text:
+ title = llm_resp.completion_text.strip()
+ if not title or "" in title:
+ return
+ await self.conv_manager.update_conversation_title(
+ unified_msg_origin=event.unified_msg_origin,
+ title=title,
+ conversation_id=req.conversation.cid,
+ )
+
+ async def _save_to_history(
+ self,
+ event: AstrMessageEvent,
+ req: ProviderRequest,
+ llm_response: LLMResponse | None,
+ ):
+ if (
+ not req
+ or not req.conversation
+ or not llm_response
+ or llm_response.role != "assistant"
+ ):
+ return
+
+ if not llm_response.completion_text and not req.tool_calls_result:
+ logger.debug("LLM 响应为空,不保存记录。")
+ return
+
+ if req.contexts is None:
+ req.contexts = []
+
+ # 历史上下文
+ messages = copy.deepcopy(req.contexts)
+ # 这一轮对话请求的用户输入
+ messages.append(await req.assemble_context())
+ # 这一轮对话的 LLM 响应
+ if req.tool_calls_result:
+ if not isinstance(req.tool_calls_result, list):
+ messages.extend(req.tool_calls_result.to_openai_messages())
+ elif isinstance(req.tool_calls_result, list):
+ for tcr in req.tool_calls_result:
+ messages.extend(tcr.to_openai_messages())
+ messages.append({"role": "assistant", "content": llm_response.completion_text})
+ messages = list(filter(lambda item: "_no_save" not in item, messages))
+ await self.conv_manager.update_conversation(
+ event.unified_msg_origin,
+ req.conversation.cid,
+ history=messages,
+ )
+
+ def _fix_messages(self, messages: list[dict]) -> list[dict]:
+ """验证并且修复上下文"""
+ fixed_messages = []
+ for message in messages:
+ if message.get("role") == "tool":
+ # tool block 前面必须要有 user 和 assistant block
+ if len(fixed_messages) < 2:
+ # 这种情况可能是上下文被截断导致的
+ # 我们直接将之前的上下文都清空
+ fixed_messages = []
+ else:
+ fixed_messages.append(message)
+ else:
+ fixed_messages.append(message)
+ return fixed_messages
+
+ async def process(
+ self, event: AstrMessageEvent, provider_wake_prefix: str
+ ) -> AsyncGenerator[None, None]:
+ req: ProviderRequest | None = None
+
+ provider = self._select_provider(event)
+ if provider is None:
+ return
+ if not isinstance(provider, Provider):
+ logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。")
+ return
+
+ streaming_response = self.streaming_response
+ if (enable_streaming := event.get_extra("enable_streaming")) is not None:
+ streaming_response = bool(enable_streaming)
+
+ logger.debug("ready to request llm provider")
+ async with session_lock_manager.acquire_lock(event.unified_msg_origin):
+ logger.debug("acquired session lock for llm request")
+ if event.get_extra("provider_request"):
+ req = event.get_extra("provider_request")
+ assert isinstance(req, ProviderRequest), (
+ "provider_request 必须是 ProviderRequest 类型。"
+ )
+
+ if req.conversation:
+ req.contexts = json.loads(req.conversation.history)
+
+ else:
+ req = ProviderRequest()
+ req.prompt = ""
+ req.image_urls = []
+ if sel_model := event.get_extra("selected_model"):
+ req.model = sel_model
+ if provider_wake_prefix and not event.message_str.startswith(
+ provider_wake_prefix
+ ):
+ return
+
+ req.prompt = event.message_str[len(provider_wake_prefix) :]
+ # func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
+ # req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
+ for comp in event.message_obj.message:
+ if isinstance(comp, Image):
+ image_path = await comp.convert_to_file_path()
+ req.image_urls.append(image_path)
+
+ conversation = await self._get_session_conv(event)
+ req.conversation = conversation
+ req.contexts = json.loads(conversation.history)
+
+ event.set_extra("provider_request", req)
+
+ # fix contexts json str
+ if isinstance(req.contexts, str):
+ req.contexts = json.loads(req.contexts)
+
+ # apply file extract
+ if self.file_extract_enabled:
+ try:
+ await self._apply_file_extract(event, req)
+ except Exception as e:
+ logger.error(f"Error occurred while applying file extract: {e}")
+
+ if not req.prompt and not req.image_urls:
+ return
+
+ # call event hook
+ if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
+ return
+
+ # apply knowledge base feature
+ await self._apply_kb(event, req)
+
+ # truncate contexts to fit max length
+ if req.contexts:
+ req.contexts = self._truncate_contexts(req.contexts)
+ self._fix_messages(req.contexts)
+
+ # session_id
+ if not req.session_id:
+ req.session_id = event.unified_msg_origin
+
+ # check provider modalities, if provider does not support image/tool_use, clear them in request.
+ self._modalities_fix(provider, req)
+
+ # filter tools, only keep tools from this pipeline's selected plugins
+ self._plugin_tool_fix(event, req)
+
+ stream_to_general = (
+ self.unsupported_streaming_strategy == "turn_off"
+ and not event.platform_meta.support_streaming_message
+ )
+ # 备份 req.contexts
+ backup_contexts = copy.deepcopy(req.contexts)
+
+ # run agent
+ agent_runner = AgentRunner()
+ logger.debug(
+ f"handle provider[id: {provider.provider_config['id']}] request: {req}",
+ )
+ astr_agent_ctx = AstrAgentContext(
+ context=self.ctx.plugin_manager.context,
+ event=event,
+ )
+ await agent_runner.reset(
+ provider=provider,
+ request=req,
+ run_context=AgentContextWrapper(
+ context=astr_agent_ctx,
+ tool_call_timeout=self.tool_call_timeout,
+ ),
+ tool_executor=FunctionToolExecutor(),
+ agent_hooks=MAIN_AGENT_HOOKS,
+ streaming=streaming_response,
+ )
+
+ if streaming_response and not stream_to_general:
+ # 流式响应
+ event.set_result(
+ MessageEventResult()
+ .set_result_content_type(ResultContentType.STREAMING_RESULT)
+ .set_async_stream(
+ run_agent(
+ agent_runner,
+ self.max_step,
+ self.show_tool_use,
+ show_reasoning=self.show_reasoning,
+ ),
+ ),
+ )
+ yield
+ if agent_runner.done():
+ if final_llm_resp := agent_runner.get_final_llm_resp():
+ if final_llm_resp.completion_text:
+ chain = (
+ MessageChain()
+ .message(final_llm_resp.completion_text)
+ .chain
+ )
+ elif final_llm_resp.result_chain:
+ chain = final_llm_resp.result_chain.chain
+ else:
+ chain = MessageChain().chain
+ event.set_result(
+ MessageEventResult(
+ chain=chain,
+ result_content_type=ResultContentType.STREAMING_FINISH,
+ ),
+ )
+ else:
+ async for _ in run_agent(
+ agent_runner,
+ self.max_step,
+ self.show_tool_use,
+ stream_to_general,
+ show_reasoning=self.show_reasoning,
+ ):
+ yield
+
+ # 恢复备份的 contexts
+ req.contexts = backup_contexts
+
+ await self._save_to_history(event, req, agent_runner.get_final_llm_resp())
+
+ # 异步处理 WebChat 特殊情况
+ if event.get_platform_name() == "webchat":
+ asyncio.create_task(self._handle_webchat(event, req, provider))
+
+ asyncio.create_task(
+ Metric.upload(
+ llm_tick=1,
+ model_name=agent_runner.provider.get_model(),
+ provider_type=agent_runner.provider.meta().type,
+ ),
+ )
diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py
new file mode 100644
index 000000000..b590bd77e
--- /dev/null
+++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py
@@ -0,0 +1,205 @@
+import asyncio
+from collections.abc import AsyncGenerator
+from typing import TYPE_CHECKING
+
+from astrbot.core import astrbot_config, logger
+from astrbot.core.agent.runners.coze.coze_agent_runner import CozeAgentRunner
+from astrbot.core.agent.runners.dashscope.dashscope_agent_runner import (
+ DashscopeAgentRunner,
+)
+from astrbot.core.agent.runners.dify.dify_agent_runner import DifyAgentRunner
+from astrbot.core.message.components import Image
+from astrbot.core.message.message_event_result import (
+ MessageChain,
+ MessageEventResult,
+ ResultContentType,
+)
+
+if TYPE_CHECKING:
+ from astrbot.core.agent.runners.base import BaseAgentRunner
+from astrbot.core.platform.astr_message_event import AstrMessageEvent
+from astrbot.core.provider.entities import (
+ ProviderRequest,
+)
+from astrbot.core.star.star_handler import EventType
+from astrbot.core.utils.metrics import Metric
+
+from .....astr_agent_context import AgentContextWrapper, AstrAgentContext
+from .....astr_agent_hooks import MAIN_AGENT_HOOKS
+from ....context import PipelineContext, call_event_hook
+from ...stage import Stage
+
+AGENT_RUNNER_TYPE_KEY = {
+ "dify": "dify_agent_runner_provider_id",
+ "coze": "coze_agent_runner_provider_id",
+ "dashscope": "dashscope_agent_runner_provider_id",
+}
+
+
+async def run_third_party_agent(
+ runner: "BaseAgentRunner",
+ stream_to_general: bool = False,
+) -> AsyncGenerator[MessageChain | None, None]:
+ """
+ 运行第三方 agent runner 并转换响应格式
+ 类似于 run_agent 函数,但专门处理第三方 agent runner
+ """
+ try:
+ async for resp in runner.step_until_done(max_step=30): # type: ignore[misc]
+ if resp.type == "streaming_delta":
+ if stream_to_general:
+ continue
+ yield resp.data["chain"]
+ elif resp.type == "llm_result":
+ if stream_to_general:
+ yield resp.data["chain"]
+ except Exception as e:
+ logger.error(f"Third party agent runner error: {e}")
+ err_msg = (
+ f"\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n"
+ f"错误信息: {e!s}\n\n请在平台日志查看和分享错误详情。\n"
+ )
+ yield MessageChain().message(err_msg)
+
+
+class ThirdPartyAgentSubStage(Stage):
+ async def initialize(self, ctx: PipelineContext) -> None:
+ self.ctx = ctx
+ self.conf = ctx.astrbot_config
+ self.runner_type = self.conf["provider_settings"]["agent_runner_type"]
+ self.prov_id = self.conf["provider_settings"].get(
+ AGENT_RUNNER_TYPE_KEY.get(self.runner_type, ""),
+ "",
+ )
+ settings = ctx.astrbot_config["provider_settings"]
+ self.streaming_response: bool = settings["streaming_response"]
+ self.unsupported_streaming_strategy: str = settings[
+ "unsupported_streaming_strategy"
+ ]
+
+ async def process(
+ self, event: AstrMessageEvent, provider_wake_prefix: str
+ ) -> AsyncGenerator[None, None]:
+ req: ProviderRequest | None = None
+
+ if provider_wake_prefix and not event.message_str.startswith(
+ provider_wake_prefix
+ ):
+ return
+
+ self.prov_cfg: dict = next(
+ (p for p in astrbot_config["provider"] if p["id"] == self.prov_id),
+ {},
+ )
+ if not self.prov_id:
+ logger.error("没有填写 Agent Runner 提供商 ID,请前往配置页面配置。")
+ return
+ if not self.prov_cfg:
+ logger.error(
+ f"Agent Runner 提供商 {self.prov_id} 配置不存在,请前往配置页面修改配置。"
+ )
+ return
+
+ # make provider request
+ req = ProviderRequest()
+ req.session_id = event.unified_msg_origin
+ req.prompt = event.message_str[len(provider_wake_prefix) :]
+ for comp in event.message_obj.message:
+ if isinstance(comp, Image):
+ image_path = await comp.convert_to_base64()
+ req.image_urls.append(image_path)
+
+ if not req.prompt and not req.image_urls:
+ return
+
+ # call event hook
+ if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
+ return
+
+ if self.runner_type == "dify":
+ runner = DifyAgentRunner[AstrAgentContext]()
+ elif self.runner_type == "coze":
+ runner = CozeAgentRunner[AstrAgentContext]()
+ elif self.runner_type == "dashscope":
+ runner = DashscopeAgentRunner[AstrAgentContext]()
+ else:
+ raise ValueError(
+ f"Unsupported third party agent runner type: {self.runner_type}",
+ )
+
+ astr_agent_ctx = AstrAgentContext(
+ context=self.ctx.plugin_manager.context,
+ event=event,
+ )
+
+ streaming_response = self.streaming_response
+ if (enable_streaming := event.get_extra("enable_streaming")) is not None:
+ streaming_response = bool(enable_streaming)
+
+ stream_to_general = (
+ self.unsupported_streaming_strategy == "turn_off"
+ and not event.platform_meta.support_streaming_message
+ )
+
+ await runner.reset(
+ request=req,
+ run_context=AgentContextWrapper(
+ context=astr_agent_ctx,
+ tool_call_timeout=60,
+ ),
+ agent_hooks=MAIN_AGENT_HOOKS,
+ provider_config=self.prov_cfg,
+ streaming=streaming_response,
+ )
+
+ if streaming_response and not stream_to_general:
+ # 流式响应
+ event.set_result(
+ MessageEventResult()
+ .set_result_content_type(ResultContentType.STREAMING_RESULT)
+ .set_async_stream(
+ run_third_party_agent(
+ runner,
+ stream_to_general=False,
+ ),
+ ),
+ )
+ yield
+ if runner.done():
+ final_resp = runner.get_final_llm_resp()
+ if final_resp and final_resp.result_chain:
+ event.set_result(
+ MessageEventResult(
+ chain=final_resp.result_chain.chain or [],
+ result_content_type=ResultContentType.STREAMING_FINISH,
+ ),
+ )
+ else:
+ # 非流式响应或转换为普通响应
+ async for _ in run_third_party_agent(
+ runner,
+ stream_to_general=stream_to_general,
+ ):
+ yield
+
+ final_resp = runner.get_final_llm_resp()
+
+ if not final_resp or not final_resp.result_chain:
+ logger.warning("Agent Runner 未返回最终结果。")
+ return
+
+ event.set_result(
+ MessageEventResult(
+ chain=final_resp.result_chain.chain or [],
+ result_content_type=ResultContentType.LLM_RESULT,
+ ),
+ )
+ yield
+
+ asyncio.create_task(
+ Metric.upload(
+ llm_tick=1,
+ model_name=self.runner_type,
+ provider_type=self.runner_type,
+ ),
+ )
diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py
deleted file mode 100644
index 5974cd519..000000000
--- a/astrbot/core/pipeline/process_stage/method/llm_request.py
+++ /dev/null
@@ -1,727 +0,0 @@
-"""本地 Agent 模式的 LLM 调用 Stage"""
-
-import asyncio
-import copy
-import json
-import traceback
-from collections.abc import AsyncGenerator
-from typing import Any
-
-from mcp.types import CallToolResult
-
-from astrbot.core import logger
-from astrbot.core.agent.handoff import HandoffTool
-from astrbot.core.agent.hooks import BaseAgentRunHooks
-from astrbot.core.agent.mcp_client import MCPTool
-from astrbot.core.agent.run_context import ContextWrapper
-from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
-from astrbot.core.agent.tool import FunctionTool, ToolSet
-from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
-from astrbot.core.astr_agent_context import AstrAgentContext
-from astrbot.core.conversation_mgr import Conversation
-from astrbot.core.message.components import Image
-from astrbot.core.message.message_event_result import (
- MessageChain,
- MessageEventResult,
- ResultContentType,
-)
-from astrbot.core.platform.astr_message_event import AstrMessageEvent
-from astrbot.core.provider import Provider
-from astrbot.core.provider.entities import (
- LLMResponse,
- ProviderRequest,
-)
-from astrbot.core.provider.register import llm_tools
-from astrbot.core.star.session_llm_manager import SessionServiceManager
-from astrbot.core.star.star_handler import EventType, star_map
-from astrbot.core.utils.metrics import Metric
-
-from ...context import PipelineContext, call_event_hook, call_local_llm_tool
-from ..stage import Stage
-from ..utils import inject_kb_context
-
-try:
- import mcp
-except (ModuleNotFoundError, ImportError):
- logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
-
-
-AgentContextWrapper = ContextWrapper[AstrAgentContext]
-AgentRunner = ToolLoopAgentRunner[AstrAgentContext]
-
-
-class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
- @classmethod
- async def execute(cls, tool, run_context, **tool_args):
- """执行函数调用。
-
- Args:
- event (AstrMessageEvent): 事件对象, 当 origin 为 local 时必须提供。
- **kwargs: 函数调用的参数。
-
- Returns:
- AsyncGenerator[None | mcp.types.CallToolResult, None]
-
- """
- if isinstance(tool, HandoffTool):
- async for r in cls._execute_handoff(tool, run_context, **tool_args):
- yield r
- return
-
- elif isinstance(tool, MCPTool):
- async for r in cls._execute_mcp(tool, run_context, **tool_args):
- yield r
- return
-
- else:
- async for r in cls._execute_local(tool, run_context, **tool_args):
- yield r
- return
-
- @classmethod
- async def _execute_handoff(
- cls,
- tool: HandoffTool,
- run_context: ContextWrapper[AstrAgentContext],
- **tool_args,
- ):
- input_ = tool_args.get("input", "agent")
- agent_runner = AgentRunner()
-
- # make toolset for the agent
- tools = tool.agent.tools
- if tools:
- toolset = ToolSet()
- for t in tools:
- if isinstance(t, str):
- _t = llm_tools.get_func(t)
- if _t:
- toolset.add_tool(_t)
- elif isinstance(t, FunctionTool):
- toolset.add_tool(t)
- else:
- toolset = None
-
- request = ProviderRequest(
- prompt=input_,
- system_prompt=tool.description or "",
- image_urls=[], # 暂时不传递原始 agent 的上下文
- contexts=[], # 暂时不传递原始 agent 的上下文
- func_tool=toolset,
- )
- astr_agent_ctx = AstrAgentContext(
- provider=run_context.context.provider,
- first_provider_request=run_context.context.first_provider_request,
- curr_provider_request=request,
- streaming=run_context.context.streaming,
- event=run_context.context.event,
- )
-
- event = run_context.context.event
-
- logger.debug(f"正在将任务委托给 Agent: {tool.agent.name}, input: {input_}")
- await event.send(
- MessageChain().message("✨ 正在将任务委托给 Agent: " + tool.agent.name),
- )
-
- await agent_runner.reset(
- provider=run_context.context.provider,
- request=request,
- run_context=AgentContextWrapper(
- context=astr_agent_ctx,
- tool_call_timeout=run_context.tool_call_timeout,
- ),
- tool_executor=FunctionToolExecutor(),
- agent_hooks=tool.agent.run_hooks or BaseAgentRunHooks[AstrAgentContext](),
- streaming=run_context.context.streaming,
- )
-
- async for _ in run_agent(agent_runner, 15, True):
- pass
-
- if agent_runner.done():
- llm_response = agent_runner.get_final_llm_resp()
-
- if not llm_response:
- text_content = mcp.types.TextContent(
- type="text",
- text=f"error when deligate task to {tool.agent.name}",
- )
- yield mcp.types.CallToolResult(content=[text_content])
- return
-
- logger.debug(
- f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}",
- )
-
- result = (
- f"Agent {tool.agent.name} respond with: {llm_response.completion_text}\n\n"
- "Note: If the result is error or need user provide more information, please provide more information to the agent(you can ask user for more information first)."
- )
-
- text_content = mcp.types.TextContent(
- type="text",
- text=result,
- )
- yield mcp.types.CallToolResult(content=[text_content])
- else:
- text_content = mcp.types.TextContent(
- type="text",
- text=f"error when deligate task to {tool.agent.name}",
- )
- yield mcp.types.CallToolResult(content=[text_content])
- return
-
- @classmethod
- async def _execute_local(
- cls,
- tool: FunctionTool,
- run_context: ContextWrapper[AstrAgentContext],
- **tool_args,
- ):
- event = run_context.context.event
- if not event:
- raise ValueError("Event must be provided for local function tools.")
-
- is_override_call = False
- for ty in type(tool).mro():
- if "call" in ty.__dict__ and ty.__dict__["call"] is not FunctionTool.call:
- logger.debug(f"Found call in: {ty}")
- is_override_call = True
- break
-
- # 检查 tool 下有没有 run 方法
- if not tool.handler and not hasattr(tool, "run") and not is_override_call:
- raise ValueError("Tool must have a valid handler or override 'run' method.")
-
- awaitable = None
- method_name = ""
- if tool.handler:
- awaitable = tool.handler
- method_name = "decorator_handler"
- elif is_override_call:
- awaitable = tool.call
- method_name = "call"
- elif hasattr(tool, "run"):
- awaitable = getattr(tool, "run")
- method_name = "run"
- if awaitable is None:
- raise ValueError("Tool must have a valid handler or override 'run' method.")
-
- wrapper = call_local_llm_tool(
- context=run_context,
- handler=awaitable,
- method_name=method_name,
- **tool_args,
- )
- while True:
- try:
- resp = await asyncio.wait_for(
- anext(wrapper),
- timeout=run_context.tool_call_timeout,
- )
- if resp is not None:
- if isinstance(resp, mcp.types.CallToolResult):
- yield resp
- else:
- text_content = mcp.types.TextContent(
- type="text",
- text=str(resp),
- )
- yield mcp.types.CallToolResult(content=[text_content])
- else:
- # NOTE: Tool 在这里直接请求发送消息给用户
- # TODO: 是否需要判断 event.get_result() 是否为空?
- # 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
- if res := run_context.context.event.get_result():
- if res.chain:
- try:
- await event.send(
- MessageChain(
- chain=res.chain,
- type="tool_direct_result",
- )
- )
- except Exception as e:
- logger.error(
- f"Tool 直接发送消息失败: {e}",
- exc_info=True,
- )
- yield None
- except asyncio.TimeoutError:
- raise Exception(
- f"tool {tool.name} execution timeout after {run_context.tool_call_timeout} seconds.",
- )
- except StopAsyncIteration:
- break
-
- @classmethod
- async def _execute_mcp(
- cls,
- tool: FunctionTool,
- run_context: ContextWrapper[AstrAgentContext],
- **tool_args,
- ):
- res = await tool.call(run_context, **tool_args)
- if not res:
- return
- yield res
-
-
-class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
- async def on_agent_done(self, run_context, llm_response):
- # 执行事件钩子
- await call_event_hook(
- run_context.context.event,
- EventType.OnLLMResponseEvent,
- llm_response,
- )
-
- async def on_tool_end(
- self,
- run_context: ContextWrapper[AstrAgentContext],
- tool: FunctionTool[Any],
- tool_args: dict | None,
- tool_result: CallToolResult | None,
- ):
- run_context.context.event.clear_result()
-
-
-MAIN_AGENT_HOOKS = MainAgentHooks()
-
-
-async def run_agent(
- agent_runner: AgentRunner,
- max_step: int = 30,
- show_tool_use: bool = True,
-) -> AsyncGenerator[MessageChain, None]:
- step_idx = 0
- astr_event = agent_runner.run_context.context.event
- while step_idx < max_step:
- step_idx += 1
- try:
- async for resp in agent_runner.step():
- if astr_event.is_stopped():
- return
- if resp.type == "tool_call_result":
- msg_chain = resp.data["chain"]
- if msg_chain.type == "tool_direct_result":
- # tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
- resp.data["chain"].type = "tool_call_result"
- await astr_event.send(resp.data["chain"])
- continue
- # 对于其他情况,暂时先不处理
- continue
- elif resp.type == "tool_call":
- if agent_runner.streaming:
- # 用来标记流式响应需要分节
- yield MessageChain(chain=[], type="break")
- if show_tool_use or astr_event.get_platform_name() == "webchat":
- resp.data["chain"].type = "tool_call"
- await astr_event.send(resp.data["chain"])
- continue
-
- if not agent_runner.streaming:
- content_typ = (
- ResultContentType.LLM_RESULT
- if resp.type == "llm_result"
- else ResultContentType.GENERAL_RESULT
- )
- astr_event.set_result(
- MessageEventResult(
- chain=resp.data["chain"].chain,
- result_content_type=content_typ,
- ),
- )
- yield
- astr_event.clear_result()
- elif resp.type == "streaming_delta":
- yield resp.data["chain"] # MessageChain
- if agent_runner.done():
- break
-
- except Exception as e:
- logger.error(traceback.format_exc())
- err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n"
- if agent_runner.streaming:
- yield MessageChain().message(err_msg)
- else:
- astr_event.set_result(MessageEventResult().message(err_msg))
- return
-
-
-class LLMRequestSubStage(Stage):
- async def initialize(self, ctx: PipelineContext) -> None:
- self.ctx = ctx
- conf = ctx.astrbot_config
- settings = conf["provider_settings"]
- self.bot_wake_prefixs: list[str] = conf["wake_prefix"] # list
- self.provider_wake_prefix: str = settings["wake_prefix"] # str
- self.max_context_length = settings["max_context_length"] # int
- self.dequeue_context_length: int = min(
- max(1, settings["dequeue_context_length"]),
- self.max_context_length - 1,
- )
- self.streaming_response: bool = settings["streaming_response"]
- self.max_step: int = settings.get("max_agent_step", 30)
- self.tool_call_timeout: int = settings.get("tool_call_timeout", 60)
- if isinstance(self.max_step, bool): # workaround: #2622
- self.max_step = 30
- self.show_tool_use: bool = settings.get("show_tool_use_status", True)
-
- for bwp in self.bot_wake_prefixs:
- if self.provider_wake_prefix.startswith(bwp):
- logger.info(
- f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。",
- )
- self.provider_wake_prefix = self.provider_wake_prefix[len(bwp) :]
-
- self.conv_manager = ctx.plugin_manager.context.conversation_manager
-
- def _select_provider(self, event: AstrMessageEvent):
- """选择使用的 LLM 提供商"""
- sel_provider = event.get_extra("selected_provider")
- _ctx = self.ctx.plugin_manager.context
- if sel_provider and isinstance(sel_provider, str):
- provider = _ctx.get_provider_by_id(sel_provider)
- if not provider:
- logger.error(f"未找到指定的提供商: {sel_provider}。")
- return provider
-
- return _ctx.get_using_provider(umo=event.unified_msg_origin)
-
- async def _get_session_conv(self, event: AstrMessageEvent) -> Conversation:
- umo = event.unified_msg_origin
- conv_mgr = self.conv_manager
-
- # 获取对话上下文
- cid = await conv_mgr.get_curr_conversation_id(umo)
- if not cid:
- cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
- conversation = await conv_mgr.get_conversation(umo, cid)
- if not conversation:
- cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
- conversation = await conv_mgr.get_conversation(umo, cid)
- if not conversation:
- raise RuntimeError("无法创建新的对话。")
- return conversation
-
- async def process(
- self,
- event: AstrMessageEvent,
- _nested: bool = False,
- ) -> None | AsyncGenerator[None, None]:
- req: ProviderRequest | None = None
-
- if not self.ctx.astrbot_config["provider_settings"]["enable"]:
- logger.debug("未启用 LLM 能力,跳过处理。")
- return
-
- # 检查会话级别的LLM启停状态
- if not SessionServiceManager.should_process_llm_request(event):
- logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM,跳过处理。")
- return
-
- provider = self._select_provider(event)
- if provider is None:
- return
- if not isinstance(provider, Provider):
- logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。")
- return
-
- streaming_response = self.streaming_response
- if (enable_streaming := event.get_extra("enable_streaming")) is not None:
- streaming_response = bool(enable_streaming)
-
- if event.get_extra("provider_request"):
- req = event.get_extra("provider_request")
- assert isinstance(req, ProviderRequest), (
- "provider_request 必须是 ProviderRequest 类型。"
- )
-
- if req.conversation:
- req.contexts = json.loads(req.conversation.history)
-
- else:
- req = ProviderRequest(prompt="", image_urls=[])
- if sel_model := event.get_extra("selected_model"):
- req.model = sel_model
- if self.provider_wake_prefix:
- if not event.message_str.startswith(self.provider_wake_prefix):
- return
- req.prompt = event.message_str[len(self.provider_wake_prefix) :]
- # func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
- # req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
- for comp in event.message_obj.message:
- if isinstance(comp, Image):
- image_path = await comp.convert_to_file_path()
- req.image_urls.append(image_path)
-
- conversation = await self._get_session_conv(event)
- req.conversation = conversation
- req.contexts = json.loads(conversation.history)
-
- event.set_extra("provider_request", req)
-
- if not req.prompt and not req.image_urls:
- return
-
- # 应用知识库
- try:
- await inject_kb_context(
- umo=event.unified_msg_origin,
- p_ctx=self.ctx,
- req=req,
- )
- except Exception as e:
- logger.error(f"调用知识库时遇到问题: {e}")
-
- # 执行请求 LLM 前事件钩子。
- if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
- return
-
- if isinstance(req.contexts, str):
- req.contexts = json.loads(req.contexts)
-
- # max context length
- if (
- self.max_context_length != -1 # -1 为不限制
- and len(req.contexts) // 2 > self.max_context_length
- ):
- logger.debug("上下文长度超过限制,将截断。")
- req.contexts = req.contexts[
- -(self.max_context_length - self.dequeue_context_length + 1) * 2 :
- ]
- # 找到第一个role 为 user 的索引,确保上下文格式正确
- index = next(
- (
- i
- for i, item in enumerate(req.contexts)
- if item.get("role") == "user"
- ),
- None,
- )
- if index is not None and index > 0:
- req.contexts = req.contexts[index:]
-
- # session_id
- if not req.session_id:
- req.session_id = event.unified_msg_origin
-
- # fix messages
- req.contexts = self.fix_messages(req.contexts)
-
- # check provider modalities
- # 如果提供商不支持图像/工具使用,但请求中包含图像/工具列表,则清空。图片转述等的检测和调用发生在这之前,因此这里可以这样处理。
- if req.image_urls:
- provider_cfg = provider.provider_config.get("modalities", ["image"])
- if "image" not in provider_cfg:
- logger.debug(f"用户设置提供商 {provider} 不支持图像,清空图像列表。")
- req.image_urls = []
- if req.func_tool:
- provider_cfg = provider.provider_config.get("modalities", ["tool_use"])
- # 如果模型不支持工具使用,但请求中包含工具列表,则清空。
- if "tool_use" not in provider_cfg:
- logger.debug(
- f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。",
- )
- req.func_tool = None
- # 插件可用性设置
- if event.plugins_name is not None and req.func_tool:
- new_tool_set = ToolSet()
- for tool in req.func_tool.tools:
- mp = tool.handler_module_path
- if not mp:
- continue
- plugin = star_map.get(mp)
- if not plugin:
- continue
- if plugin.name in event.plugins_name or plugin.reserved:
- new_tool_set.add_tool(tool)
- req.func_tool = new_tool_set
-
- # 备份 req.contexts
- backup_contexts = copy.deepcopy(req.contexts)
-
- # run agent
- agent_runner = AgentRunner()
- logger.debug(
- f"handle provider[id: {provider.provider_config['id']}] request: {req}",
- )
- astr_agent_ctx = AstrAgentContext(
- provider=provider,
- first_provider_request=req,
- curr_provider_request=req,
- streaming=streaming_response,
- event=event,
- )
- await agent_runner.reset(
- provider=provider,
- request=req,
- run_context=AgentContextWrapper(
- context=astr_agent_ctx,
- tool_call_timeout=self.tool_call_timeout,
- ),
- tool_executor=FunctionToolExecutor(),
- agent_hooks=MAIN_AGENT_HOOKS,
- streaming=streaming_response,
- )
-
- if streaming_response:
- # 流式响应
- event.set_result(
- MessageEventResult()
- .set_result_content_type(ResultContentType.STREAMING_RESULT)
- .set_async_stream(
- run_agent(agent_runner, self.max_step, self.show_tool_use),
- ),
- )
- yield
- if agent_runner.done():
- if final_llm_resp := agent_runner.get_final_llm_resp():
- if final_llm_resp.completion_text:
- chain = (
- MessageChain().message(final_llm_resp.completion_text).chain
- )
- elif final_llm_resp.result_chain:
- chain = final_llm_resp.result_chain.chain
- else:
- chain = MessageChain().chain
- event.set_result(
- MessageEventResult(
- chain=chain,
- result_content_type=ResultContentType.STREAMING_FINISH,
- ),
- )
- else:
- async for _ in run_agent(agent_runner, self.max_step, self.show_tool_use):
- yield
-
- # 恢复备份的 contexts
- req.contexts = backup_contexts
-
- await self._save_to_history(event, req, agent_runner.get_final_llm_resp())
-
- # 异步处理 WebChat 特殊情况
- if event.get_platform_name() == "webchat":
- asyncio.create_task(self._handle_webchat(event, req, provider))
-
- asyncio.create_task(
- Metric.upload(
- llm_tick=1,
- model_name=agent_runner.provider.get_model(),
- provider_type=agent_runner.provider.meta().type,
- ),
- )
-
- async def _handle_webchat(
- self,
- event: AstrMessageEvent,
- req: ProviderRequest,
- prov: Provider,
- ):
- """处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
- if not req.conversation:
- return
- conversation = await self.conv_manager.get_conversation(
- event.unified_msg_origin,
- req.conversation.cid,
- )
- if conversation and not req.conversation.title:
- messages = json.loads(conversation.history)
- latest_pair = messages[-2:]
- if not latest_pair:
- return
- content = latest_pair[0].get("content", "")
- if isinstance(content, list):
- # 多模态
- text_parts = []
- for item in content:
- if isinstance(item, dict):
- if item.get("type") == "text":
- text_parts.append(item.get("text", ""))
- elif item.get("type") == "image":
- text_parts.append("[图片]")
- elif isinstance(item, str):
- text_parts.append(item)
- cleaned_text = "User: " + " ".join(text_parts).strip()
- elif isinstance(content, str):
- cleaned_text = "User: " + content.strip()
- else:
- return
- logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}")
- llm_resp = await prov.text_chat(
- system_prompt="You are expert in summarizing user's query.",
- prompt=(
- f"Please summarize the following query of user:\n"
- f"{cleaned_text}\n"
- "Only output the summary within 10 words, DO NOT INCLUDE any other text."
- "You must use the same language as the user."
- "If you think the dialog is too short to summarize, only output a special mark: ``"
- ),
- )
- if llm_resp and llm_resp.completion_text:
- logger.debug(
- f"WebChat 对话标题生成响应: {llm_resp.completion_text.strip()}",
- )
- title = llm_resp.completion_text.strip()
- if not title or "" in title:
- return
- await self.conv_manager.update_conversation_title(
- unified_msg_origin=event.unified_msg_origin,
- title=title,
- conversation_id=req.conversation.cid,
- )
-
- async def _save_to_history(
- self,
- event: AstrMessageEvent,
- req: ProviderRequest,
- llm_response: LLMResponse | None,
- ):
- if (
- not req
- or not req.conversation
- or not llm_response
- or llm_response.role != "assistant"
- ):
- return
-
- if not llm_response.completion_text and not req.tool_calls_result:
- logger.debug("LLM 响应为空,不保存记录。")
- return
-
- # 历史上下文
- messages = copy.deepcopy(req.contexts)
- # 这一轮对话请求的用户输入
- messages.append(await req.assemble_context())
- # 这一轮对话的 LLM 响应
- if req.tool_calls_result:
- if not isinstance(req.tool_calls_result, list):
- messages.extend(req.tool_calls_result.to_openai_messages())
- elif isinstance(req.tool_calls_result, list):
- for tcr in req.tool_calls_result:
- messages.extend(tcr.to_openai_messages())
- messages.append({"role": "assistant", "content": llm_response.completion_text})
- messages = list(filter(lambda item: "_no_save" not in item, messages))
- await self.conv_manager.update_conversation(
- event.unified_msg_origin,
- req.conversation.cid,
- history=messages,
- )
-
- def fix_messages(self, messages: list[dict]) -> list[dict]:
- """验证并且修复上下文"""
- fixed_messages = []
- for message in messages:
- if message.get("role") == "tool":
- # tool block 前面必须要有 user 和 assistant block
- if len(fixed_messages) < 2:
- # 这种情况可能是上下文被截断导致的
- # 我们直接将之前的上下文都清空
- fixed_messages = []
- else:
- fixed_messages.append(message)
- else:
- fixed_messages.append(message)
- return fixed_messages
diff --git a/astrbot/core/pipeline/process_stage/method/star_request.py b/astrbot/core/pipeline/process_stage/method/star_request.py
index ff8120b16..8a79b96c9 100644
--- a/astrbot/core/pipeline/process_stage/method/star_request.py
+++ b/astrbot/core/pipeline/process_stage/method/star_request.py
@@ -16,7 +16,6 @@ from ..stage import Stage
class StarRequestSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
- self.curr_provider = ctx.plugin_manager.context.get_using_provider()
self.prompt_prefix = ctx.astrbot_config["provider_settings"]["prompt_prefix"]
self.identifier = ctx.astrbot_config["provider_settings"]["identifier"]
self.ctx = ctx
@@ -24,7 +23,7 @@ class StarRequestSubStage(Stage):
async def process(
self,
event: AstrMessageEvent,
- ) -> None | AsyncGenerator[None, None]:
+ ) -> AsyncGenerator[Any, None]:
activated_handlers: list[StarHandlerMetadata] = event.get_extra(
"activated_handlers",
)
diff --git a/astrbot/core/pipeline/process_stage/stage.py b/astrbot/core/pipeline/process_stage/stage.py
index 9f0b5f92a..076f7f12a 100644
--- a/astrbot/core/pipeline/process_stage/stage.py
+++ b/astrbot/core/pipeline/process_stage/stage.py
@@ -1,13 +1,12 @@
from collections.abc import AsyncGenerator
-from astrbot.core import logger
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider.entities import ProviderRequest
from astrbot.core.star.star_handler import StarHandlerMetadata
from ..context import PipelineContext
from ..stage import Stage, register_stage
-from .method.llm_request import LLMRequestSubStage
+from .method.agent_request import AgentRequestSubStage
from .method.star_request import StarRequestSubStage
@@ -17,9 +16,12 @@ class ProcessStage(Stage):
self.ctx = ctx
self.config = ctx.astrbot_config
self.plugin_manager = ctx.plugin_manager
- self.llm_request_sub_stage = LLMRequestSubStage()
- await self.llm_request_sub_stage.initialize(ctx)
+ # initialize agent sub stage
+ self.agent_sub_stage = AgentRequestSubStage()
+ await self.agent_sub_stage.initialize(ctx)
+
+ # initialize star request sub stage
self.star_request_sub_stage = StarRequestSubStage()
await self.star_request_sub_stage.initialize(ctx)
@@ -39,7 +41,7 @@ class ProcessStage(Stage):
# Handler 的 LLM 请求
event.set_extra("provider_request", resp)
_t = False
- async for _ in self.llm_request_sub_stage.process(event):
+ async for _ in self.agent_sub_stage.process(event):
_t = True
yield
if not _t:
@@ -58,14 +60,7 @@ class ProcessStage(Stage):
):
# 是否有过发送操作 and 是否是被 @ 或者通过唤醒前缀
if (
- event.get_result() and not event.get_result().is_stopped()
+ event.get_result() and not event.is_stopped()
) or not event.get_result():
- # 事件没有终止传播
- provider = self.ctx.plugin_manager.context.get_using_provider()
-
- if not provider:
- logger.info("未找到可用的 LLM 提供商,请先前往配置服务提供商。")
- return
-
- async for _ in self.llm_request_sub_stage.process(event):
+ async for _ in self.agent_sub_stage.process(event):
yield
diff --git a/astrbot/core/pipeline/process_stage/utils.py b/astrbot/core/pipeline/process_stage/utils.py
index b1168aa0a..24e052e1e 100644
--- a/astrbot/core/pipeline/process_stage/utils.py
+++ b/astrbot/core/pipeline/process_stage/utils.py
@@ -1,23 +1,64 @@
+from pydantic import Field
+from pydantic.dataclasses import dataclass
+
from astrbot.api import logger, sp
-from astrbot.core.provider.entities import ProviderRequest
-
-from ..context import PipelineContext
+from astrbot.core.agent.run_context import ContextWrapper
+from astrbot.core.agent.tool import FunctionTool, ToolExecResult
+from astrbot.core.astr_agent_context import AstrAgentContext
+from astrbot.core.star.context import Context
-async def inject_kb_context(
+@dataclass
+class KnowledgeBaseQueryTool(FunctionTool[AstrAgentContext]):
+ name: str = "astr_kb_search"
+ description: str = (
+ "Query the knowledge base for facts or relevant context. "
+ "Use this tool when the user's question requires factual information, "
+ "definitions, background knowledge, or previously indexed content. "
+ "Only send short keywords or a concise question as the query."
+ )
+ parameters: dict = Field(
+ default_factory=lambda: {
+ "type": "object",
+ "properties": {
+ "query": {
+ "type": "string",
+ "description": "A concise keyword query for the knowledge base.",
+ },
+ },
+ "required": ["query"],
+ }
+ )
+
+ async def call(
+ self, context: ContextWrapper[AstrAgentContext], **kwargs
+ ) -> ToolExecResult:
+ query = kwargs.get("query", "")
+ if not query:
+ return "error: Query parameter is empty."
+ result = await retrieve_knowledge_base(
+ query=kwargs.get("query", ""),
+ umo=context.context.event.unified_msg_origin,
+ context=context.context.context,
+ )
+ if not result:
+ return "No relevant knowledge found."
+ return result
+
+
+async def retrieve_knowledge_base(
+ query: str,
umo: str,
- p_ctx: PipelineContext,
- req: ProviderRequest,
-) -> None:
+ context: Context,
+) -> str | None:
"""Inject knowledge base context into the provider request
Args:
umo: Unique message object (session ID)
p_ctx: Pipeline context
- req: Provider request
-
"""
- kb_mgr = p_ctx.plugin_manager.context.kb_manager
+ kb_mgr = context.kb_manager
+ config = context.get_config(umo=umo)
# 1. 优先读取会话级配置
session_config = await sp.session_get(umo, "kb_config", default={})
@@ -54,18 +95,18 @@ async def inject_kb_context(
logger.debug(f"[知识库] 使用会话级配置,知识库数量: {len(kb_names)}")
else:
- kb_names = p_ctx.astrbot_config.get("kb_names", [])
- top_k = p_ctx.astrbot_config.get("kb_final_top_k", 5)
+ kb_names = config.get("kb_names", [])
+ top_k = config.get("kb_final_top_k", 5)
logger.debug(f"[知识库] 使用全局配置,知识库数量: {len(kb_names)}")
- top_k_fusion = p_ctx.astrbot_config.get("kb_fusion_top_k", 20)
+ top_k_fusion = config.get("kb_fusion_top_k", 20)
if not kb_names:
return
logger.debug(f"[知识库] 开始检索知识库,数量: {len(kb_names)}, top_k={top_k}")
kb_context = await kb_mgr.retrieve(
- query=req.prompt,
+ query=query,
kb_names=kb_names,
top_k_fusion=top_k_fusion,
top_m_final=top_k,
@@ -78,4 +119,7 @@ async def inject_kb_context(
if formatted:
results = kb_context.get("results", [])
logger.debug(f"[知识库] 为会话 {umo} 注入了 {len(results)} 条相关知识块")
- req.system_prompt = f"{formatted}\n\n{req.system_prompt or ''}"
+ return formatted
+
+
+KNOWLEDGE_BASE_QUERY_TOOL = KnowledgeBaseQueryTool()
diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py
index f20445594..60ab168b3 100644
--- a/astrbot/core/pipeline/respond/stage.py
+++ b/astrbot/core/pipeline/respond/stage.py
@@ -10,7 +10,6 @@ from astrbot.core.message.message_event_result import MessageChain, ResultConten
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.star_handler import EventType
from astrbot.core.utils.path_util import path_Mapping
-from astrbot.core.utils.session_lock import session_lock_manager
from ..context import PipelineContext, call_event_hook
from ..stage import Stage, register_stage
@@ -118,7 +117,9 @@ class RespondStage(Stage):
if not self.enable_seg:
return False
- if self.only_llm_result and not event.get_result().is_llm_result():
+ if (result := event.get_result()) is None:
+ return False
+ if self.only_llm_result and not result.is_llm_result():
return False
if event.get_platform_name() in [
@@ -157,7 +158,11 @@ class RespondStage(Stage):
result = event.get_result()
if result is None:
return
+ if event.get_extra("_streaming_finished", False):
+ # prevent some plugin make result content type to LLM_RESULT after streaming finished, lead to send again
+ return
if result.result_content_type == ResultContentType.STREAMING_FINISH:
+ event.set_extra("_streaming_finished", True)
return
logger.info(
@@ -169,12 +174,15 @@ class RespondStage(Stage):
logger.warning("async_stream 为空,跳过发送。")
return
# 流式结果直接交付平台适配器处理
- use_fallback = self.config.get("provider_settings", {}).get(
- "streaming_segmented",
- False,
+ realtime_segmenting = (
+ self.config.get("provider_settings", {}).get(
+ "unsupported_streaming_strategy",
+ "realtime_segmenting",
+ )
+ == "realtime_segmenting"
)
logger.info(f"应用流式输出({event.get_platform_id()})")
- await event.send_streaming(result.async_stream, use_fallback)
+ await event.send_streaming(result.async_stream, realtime_segmenting)
return
if len(result.chain) > 0:
# 检查路径映射
@@ -183,7 +191,7 @@ class RespondStage(Stage):
if isinstance(component, Comp.File) and component.file:
# 支持 File 消息段的路径映射。
component.file = path_Mapping(mappings, component.file)
- event.get_result().chain[idx] = component
+ result.chain[idx] = component
# 检查消息链是否为空
try:
@@ -218,21 +226,20 @@ class RespondStage(Stage):
f"实际消息链为空, 跳过发送阶段。header_chain: {header_comps}, actual_chain: {result.chain}",
)
return
- async with session_lock_manager.acquire_lock(event.unified_msg_origin):
- for comp in result.chain:
- i = await self._calc_comp_interval(comp)
- await asyncio.sleep(i)
- try:
- if comp.type in need_separately:
- await event.send(MessageChain([comp]))
- else:
- await event.send(MessageChain([*header_comps, comp]))
- header_comps.clear()
- except Exception as e:
- logger.error(
- f"发送消息链失败: chain = {MessageChain([comp])}, error = {e}",
- exc_info=True,
- )
+ for comp in result.chain:
+ i = await self._calc_comp_interval(comp)
+ await asyncio.sleep(i)
+ try:
+ if comp.type in need_separately:
+ await event.send(MessageChain([comp]))
+ else:
+ await event.send(MessageChain([*header_comps, comp]))
+ header_comps.clear()
+ except Exception as e:
+ logger.error(
+ f"发送消息链失败: chain = {MessageChain([comp])}, error = {e}",
+ exc_info=True,
+ )
else:
if all(
comp.type in {ComponentType.Reply, ComponentType.At}
diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py
index 5dfb52f6f..7647ef022 100644
--- a/astrbot/core/pipeline/result_decorate/stage.py
+++ b/astrbot/core/pipeline/result_decorate/stage.py
@@ -1,3 +1,4 @@
+import random
import re
import time
import traceback
@@ -6,6 +7,7 @@ from collections.abc import AsyncGenerator
from astrbot.core import file_token_service, html_renderer, logger
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
from astrbot.core.message.message_event_result import ResultContentType
+from astrbot.core.pipeline.content_safety_check.stage import ContentSafetyCheckStage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.message_type import MessageType
from astrbot.core.star.session_llm_manager import SessionServiceManager
@@ -41,6 +43,18 @@ class ResultDecorateStage(Stage):
"forward_threshold"
]
+ trigger_probability = ctx.astrbot_config["provider_tts_settings"].get(
+ "trigger_probability",
+ 1,
+ )
+ try:
+ self.tts_trigger_probability = max(
+ 0.0,
+ min(float(trigger_probability), 1.0),
+ )
+ except (TypeError, ValueError):
+ self.tts_trigger_probability = 1.0
+
# 分段回复
self.words_count_threshold = int(
ctx.astrbot_config["platform_settings"]["segmented_reply"][
@@ -53,7 +67,22 @@ class ResultDecorateStage(Stage):
self.only_llm_result = ctx.astrbot_config["platform_settings"][
"segmented_reply"
]["only_llm_result"]
+ self.split_mode = ctx.astrbot_config["platform_settings"][
+ "segmented_reply"
+ ].get("split_mode", "regex")
self.regex = ctx.astrbot_config["platform_settings"]["segmented_reply"]["regex"]
+ self.split_words = ctx.astrbot_config["platform_settings"][
+ "segmented_reply"
+ ].get("split_words", ["。", "?", "!", "~", "…"])
+ if self.split_words:
+ escaped_words = sorted(
+ [re.escape(word) for word in self.split_words], key=len, reverse=True
+ )
+ self.split_words_pattern = re.compile(
+ f"(.*?({'|'.join(escaped_words)})|.+$)", re.DOTALL
+ )
+ else:
+ self.split_words_pattern = None
self.content_cleanup_rule = ctx.astrbot_config["platform_settings"][
"segmented_reply"
]["content_cleanup_rule"]
@@ -69,6 +98,28 @@ class ResultDecorateStage(Stage):
self.content_safe_check_stage = stage_cls()
await self.content_safe_check_stage.initialize(ctx)
+ def _split_text_by_words(self, text: str) -> list[str]:
+ """使用分段词列表分段文本"""
+ if not self.split_words_pattern:
+ return [text]
+
+ segments = self.split_words_pattern.findall(text)
+ result = []
+ for seg in segments:
+ if isinstance(seg, tuple):
+ content = seg[0]
+ if not isinstance(content, str):
+ continue
+ for word in self.split_words:
+ if content.endswith(word):
+ content = content[: -len(word)]
+ break
+ if content.strip():
+ result.append(content)
+ elif seg and seg.strip():
+ result.append(seg)
+ return result if result else [text]
+
async def process(
self,
event: AstrMessageEvent,
@@ -93,11 +144,13 @@ class ResultDecorateStage(Stage):
for comp in result.chain:
if isinstance(comp, Plain):
text += comp.text
- async for _ in self.content_safe_check_stage.process(
- event,
- check_text=text,
- ):
- yield
+
+ if isinstance(self.content_safe_check_stage, ContentSafetyCheckStage):
+ async for _ in self.content_safe_check_stage.process(
+ event,
+ check_text=text,
+ ):
+ yield
# 发送消息前事件钩子
handlers = star_handlers_registry.get_handlers_by_event_type(
@@ -114,7 +167,8 @@ class ResultDecorateStage(Stage):
"启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作",
)
await handler.handler(event)
- if event.get_result() is None or not event.get_result().chain:
+
+ if (result := event.get_result()) is None or not result.chain:
logger.debug(
f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name} 将消息结果清空。",
)
@@ -161,11 +215,27 @@ class ResultDecorateStage(Stage):
# 不分段回复
new_chain.append(comp)
continue
- split_response = re.findall(
- self.regex,
- comp.text,
- re.DOTALL | re.MULTILINE,
- )
+
+ # 根据 split_mode 选择分段方式
+ if self.split_mode == "words":
+ split_response = self._split_text_by_words(comp.text)
+ else: # regex 模式
+ try:
+ split_response = re.findall(
+ self.regex,
+ comp.text,
+ re.DOTALL | re.MULTILINE,
+ )
+ except re.error:
+ logger.error(
+ f"分段回复正则表达式错误,使用默认分段方式: {traceback.format_exc()}",
+ )
+ split_response = re.findall(
+ r".*?[。?!~…]+|.+$",
+ comp.text,
+ re.DOTALL | re.MULTILINE,
+ )
+
if not split_response:
new_chain.append(comp)
continue
@@ -189,7 +259,14 @@ class ResultDecorateStage(Stage):
and result.is_llm_result()
and SessionServiceManager.should_process_tts_request(event)
):
- if not tts_provider:
+ should_tts = self.tts_trigger_probability >= 1.0 or (
+ self.tts_trigger_probability > 0.0
+ and random.random() <= self.tts_trigger_probability
+ )
+
+ if not should_tts:
+ logger.debug("跳过 TTS:触发概率未命中。")
+ elif not tts_provider:
logger.warning(
f"会话 {event.unified_msg_origin} 未配置文本转语音模型。",
)
diff --git a/astrbot/core/pipeline/scheduler.py b/astrbot/core/pipeline/scheduler.py
index 5c461a1e1..5fb3034f5 100644
--- a/astrbot/core/pipeline/scheduler.py
+++ b/astrbot/core/pipeline/scheduler.py
@@ -2,6 +2,10 @@ from collections.abc import AsyncGenerator
from astrbot.core import logger
from astrbot.core.platform import AstrMessageEvent
+from astrbot.core.platform.sources.webchat.webchat_event import WebChatMessageEvent
+from astrbot.core.platform.sources.wecom_ai_bot.wecomai_event import (
+ WecomAIBotMessageEvent,
+)
from . import STAGES_ORDER
from .context import PipelineContext
@@ -78,7 +82,7 @@ class PipelineScheduler:
await self._process_stages(event)
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
- if event.get_platform_name() in ["webchat", "wecom_ai_bot"]:
+ if isinstance(event, (WebChatMessageEvent, WecomAIBotMessageEvent)):
await event.send(None)
logger.debug("pipeline 执行完毕。")
diff --git a/astrbot/core/pipeline/waking_check/stage.py b/astrbot/core/pipeline/waking_check/stage.py
index 814919115..1efda7c84 100644
--- a/astrbot/core/pipeline/waking_check/stage.py
+++ b/astrbot/core/pipeline/waking_check/stage.py
@@ -50,6 +50,9 @@ class WakingCheckStage(Stage):
"ignore_at_all",
False,
)
+ self.disable_builtin_commands = self.ctx.astrbot_config.get(
+ "disable_builtin_commands", False
+ )
async def process(
self,
@@ -131,6 +134,13 @@ class WakingCheckStage(Stage):
EventType.AdapterMessageEvent,
plugins_name=event.plugins_name,
):
+ if (
+ self.disable_builtin_commands
+ and handler.handler_module_path == "packages.builtin_commands.main"
+ ):
+ logger.debug("skipping builtin command")
+ continue
+
# filter 需满足 AND 逻辑关系
passed = True
permission_not_pass = False
diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py
index 6402aeaed..f6eda07a9 100644
--- a/astrbot/core/platform/astr_message_event.py
+++ b/astrbot/core/platform/astr_message_event.py
@@ -153,7 +153,9 @@ class AstrMessageEvent(abc.ABC):
def get_sender_name(self) -> str:
"""获取消息发送者的名称。(可能会返回空字符串)"""
- return self.message_obj.sender.nickname
+ if isinstance(self.message_obj.sender.nickname, str):
+ return self.message_obj.sender.nickname
+ return ""
def set_extra(self, key, value):
"""设置额外的信息。"""
@@ -270,7 +272,7 @@ class AstrMessageEvent(abc.ABC):
"""
self.call_llm = call_llm
- def get_result(self) -> MessageEventResult:
+ def get_result(self) -> MessageEventResult | None:
"""获取消息事件的结果。"""
return self._result
@@ -320,7 +322,7 @@ class AstrMessageEvent(abc.ABC):
self,
prompt: str,
func_tool_manager=None,
- session_id: str = None,
+ session_id: str = "",
image_urls: list[str] | None = None,
contexts: list | None = None,
system_prompt: str = "",
diff --git a/astrbot/core/platform/astrbot_message.py b/astrbot/core/platform/astrbot_message.py
index 0ada18506..253963322 100644
--- a/astrbot/core/platform/astrbot_message.py
+++ b/astrbot/core/platform/astrbot_message.py
@@ -54,7 +54,7 @@ class AstrBotMessage:
self_id: str # 机器人的识别id
session_id: str # 会话id。取决于 unique_session 的设置。
message_id: str # 消息id
- group: Group # 群组
+ group: Group | None # 群组
sender: MessageMember # 发送者
message: list[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
message_str: str # 最直观的纯文本消息字符串
@@ -78,7 +78,7 @@ class AstrBotMessage:
return ""
@group_id.setter
- def group_id(self, value: str):
+ def group_id(self, value: str | None):
"""设置 group_id"""
if value:
if self.group:
diff --git a/astrbot/core/platform/manager.py b/astrbot/core/platform/manager.py
index 9ff892025..f4313f642 100644
--- a/astrbot/core/platform/manager.py
+++ b/astrbot/core/platform/manager.py
@@ -5,8 +5,9 @@ from asyncio import Queue
from astrbot.core import logger
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map
+from astrbot.core.utils.webhook_utils import ensure_platform_webhook_config
-from .platform import Platform
+from .platform import Platform, PlatformStatus
from .register import platform_cls_map
from .sources.webchat.webchat_adapter import WebChatAdapter
@@ -16,8 +17,9 @@ class PlatformManager:
self.platform_insts: list[Platform] = []
"""加载的 Platform 的实例"""
- self._inst_map = {}
+ self._inst_map: dict[str, dict] = {}
+ self.astrbot_config = config
self.platforms_config = config["platform"]
self.settings = config["platform_settings"]
"""NOTE: 这里是 default 的配置文件,以保证最大的兼容性;
@@ -29,6 +31,8 @@ class PlatformManager:
"""初始化所有平台适配器"""
for platform in self.platforms_config:
try:
+ if ensure_platform_webhook_config(platform):
+ self.astrbot_config.save_config()
await self.load_platform(platform)
except Exception as e:
logger.error(f"初始化 {platform} 平台适配器失败: {e}")
@@ -37,7 +41,10 @@ class PlatformManager:
webchat_inst = WebChatAdapter({}, self.settings, self.event_queue)
self.platform_insts.append(webchat_inst)
asyncio.create_task(
- self._task_wrapper(asyncio.create_task(webchat_inst.run(), name="webchat")),
+ self._task_wrapper(
+ asyncio.create_task(webchat_inst.run(), name="webchat"),
+ platform=webchat_inst,
+ ),
)
async def load_platform(self, platform_config: dict):
@@ -107,7 +114,7 @@ class PlatformManager:
)
except (ImportError, ModuleNotFoundError) as e:
logger.error(
- f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。",
+ f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->平台日志->安装Pip库 中安装依赖库。",
)
except Exception as e:
logger.error(f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。")
@@ -131,6 +138,7 @@ class PlatformManager:
inst.run(),
name=f"platform_{platform_config['type']}_{platform_config['id']}",
),
+ platform=inst,
),
)
handlers = star_handlers_registry.get_handlers_by_event_type(
@@ -145,17 +153,28 @@ class PlatformManager:
except Exception:
logger.error(traceback.format_exc())
- async def _task_wrapper(self, task: asyncio.Task):
+ async def _task_wrapper(self, task: asyncio.Task, platform: Platform | None = None):
+ # 设置平台状态为运行中
+ if platform:
+ platform.status = PlatformStatus.RUNNING
+
try:
await task
except asyncio.CancelledError:
- pass
+ if platform:
+ platform.status = PlatformStatus.STOPPED
except Exception as e:
+ error_msg = str(e)
+ tb_str = traceback.format_exc()
logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}")
- for line in traceback.format_exc().split("\n"):
+ for line in tb_str.split("\n"):
logger.error(f"| {line}")
logger.error("-------")
+ # 记录错误到平台实例
+ if platform:
+ platform.record_error(error_msg, tb_str)
+
async def reload(self, platform_config: dict):
await self.terminate_platform(platform_config["id"])
if platform_config["enable"]:
@@ -172,9 +191,9 @@ class PlatformManager:
logger.info(f"正在尝试终止 {platform_id} 平台适配器 ...")
# client_id = self._inst_map.pop(platform_id, None)
- info = self._inst_map.pop(platform_id, None)
+ info = self._inst_map.pop(platform_id)
client_id = info["client_id"]
- inst = info["inst"]
+ inst: Platform = info["inst"]
try:
self.platform_insts.remove(
next(
@@ -196,3 +215,46 @@ class PlatformManager:
def get_insts(self):
return self.platform_insts
+
+ def get_all_stats(self) -> dict:
+ """获取所有平台的统计信息
+
+ Returns:
+ 包含所有平台统计信息的字典
+ """
+ stats_list = []
+ total_errors = 0
+ running_count = 0
+ error_count = 0
+
+ for inst in self.platform_insts:
+ try:
+ stat = inst.get_stats()
+ stats_list.append(stat)
+ total_errors += stat.get("error_count", 0)
+ if stat.get("status") == PlatformStatus.RUNNING.value:
+ running_count += 1
+ elif stat.get("status") == PlatformStatus.ERROR.value:
+ error_count += 1
+ except Exception as e:
+ # 如果获取统计信息失败,记录基本信息
+ logger.warning(f"获取平台统计信息失败: {e}")
+ stats_list.append(
+ {
+ "id": getattr(inst, "config", {}).get("id", "unknown"),
+ "type": "unknown",
+ "status": "unknown",
+ "error_count": 0,
+ "last_error": None,
+ }
+ )
+
+ return {
+ "platforms": stats_list,
+ "summary": {
+ "total": len(stats_list),
+ "running": running_count,
+ "error": error_count,
+ "total_errors": total_errors,
+ },
+ }
diff --git a/astrbot/core/platform/platform.py b/astrbot/core/platform/platform.py
index 3f36e17f3..c2e55fb63 100644
--- a/astrbot/core/platform/platform.py
+++ b/astrbot/core/platform/platform.py
@@ -1,7 +1,10 @@
import abc
import uuid
from asyncio import Queue
-from collections.abc import Awaitable
+from collections.abc import Coroutine
+from dataclasses import dataclass, field
+from datetime import datetime
+from enum import Enum
from typing import Any
from astrbot.core.message.message_event_result import MessageChain
@@ -12,15 +15,100 @@ from .message_session import MessageSesion
from .platform_metadata import PlatformMetadata
+class PlatformStatus(Enum):
+ """平台运行状态"""
+
+ PENDING = "pending" # 待启动
+ RUNNING = "running" # 运行中
+ ERROR = "error" # 发生错误
+ STOPPED = "stopped" # 已停止
+
+
+@dataclass
+class PlatformError:
+ """平台错误信息"""
+
+ message: str
+ timestamp: datetime = field(default_factory=datetime.now)
+ traceback: str | None = None
+
+
class Platform(abc.ABC):
- def __init__(self, event_queue: Queue):
+ def __init__(self, config: dict, event_queue: Queue):
super().__init__()
+ # 平台配置
+ self.config = config
# 维护了消息平台的事件队列,EventBus 会从这里取出事件并处理。
self._event_queue = event_queue
self.client_self_id = uuid.uuid4().hex
+ # 平台运行状态
+ self._status: PlatformStatus = PlatformStatus.PENDING
+ self._errors: list[PlatformError] = []
+ self._started_at: datetime | None = None
+
+ @property
+ def status(self) -> PlatformStatus:
+ """获取平台运行状态"""
+ return self._status
+
+ @status.setter
+ def status(self, value: PlatformStatus):
+ """设置平台运行状态"""
+ self._status = value
+ if value == PlatformStatus.RUNNING and self._started_at is None:
+ self._started_at = datetime.now()
+
+ @property
+ def errors(self) -> list[PlatformError]:
+ """获取错误列表"""
+ return self._errors
+
+ @property
+ def last_error(self) -> PlatformError | None:
+ """获取最近的错误"""
+ return self._errors[-1] if self._errors else None
+
+ def record_error(self, message: str, traceback_str: str | None = None):
+ """记录一个错误"""
+ self._errors.append(PlatformError(message=message, traceback=traceback_str))
+ self._status = PlatformStatus.ERROR
+
+ def clear_errors(self):
+ """清除错误记录"""
+ self._errors.clear()
+ if self._status == PlatformStatus.ERROR:
+ self._status = PlatformStatus.RUNNING
+
+ def unified_webhook(self) -> bool:
+ """是否正在使用统一 Webhook 模式"""
+ return bool(
+ self.config.get("unified_webhook_mode", False)
+ and self.config.get("webhook_uuid")
+ )
+
+ def get_stats(self) -> dict:
+ """获取平台统计信息"""
+ meta = self.meta()
+ return {
+ "id": meta.id or self.config.get("id"),
+ "type": meta.name,
+ "display_name": meta.adapter_display_name or meta.name,
+ "status": self._status.value,
+ "started_at": self._started_at.isoformat() if self._started_at else None,
+ "error_count": len(self._errors),
+ "last_error": {
+ "message": self.last_error.message,
+ "timestamp": self.last_error.timestamp.isoformat(),
+ "traceback": self.last_error.traceback,
+ }
+ if self.last_error
+ else None,
+ "unified_webhook": self.unified_webhook(),
+ }
+
@abc.abstractmethod
- def run(self) -> Awaitable[Any]:
+ def run(self) -> Coroutine[Any, Any, None]:
"""得到一个平台的运行实例,需要返回一个协程对象。"""
raise NotImplementedError
@@ -36,7 +124,7 @@ class Platform(abc.ABC):
self,
session: MessageSesion,
message_chain: MessageChain,
- ) -> Awaitable[Any]:
+ ) -> None:
"""通过会话发送消息。该方法旨在让插件能够直接通过**可持久化的会话数据**发送消息,而不需要保存 event 对象。
异步方法。
@@ -49,3 +137,20 @@ class Platform(abc.ABC):
def get_client(self):
"""获取平台的客户端对象。"""
+
+ async def webhook_callback(self, request: Any) -> Any:
+ """统一 Webhook 回调入口。
+
+ 支持统一 Webhook 模式的平台需要实现此方法。
+ 当 Dashboard 收到 /api/platform/webhook/{uuid} 请求时,会调用此方法。
+
+ Args:
+ request: Quart 请求对象
+
+ Returns:
+ 响应内容,格式取决于具体平台的要求
+
+ Raises:
+ NotImplementedError: 平台未实现统一 Webhook 模式
+ """
+ raise NotImplementedError(f"平台 {self.meta().name} 未实现统一 Webhook 模式")
diff --git a/astrbot/core/platform/platform_metadata.py b/astrbot/core/platform/platform_metadata.py
index d75811245..06455aac4 100644
--- a/astrbot/core/platform/platform_metadata.py
+++ b/astrbot/core/platform/platform_metadata.py
@@ -7,7 +7,7 @@ class PlatformMetadata:
"""平台的名称,即平台的类型,如 aiocqhttp, discord, slack"""
description: str
"""平台的描述"""
- id: str | None = None
+ id: str
"""平台的唯一标识符,用于配置中识别特定平台"""
default_config_tmpl: dict | None = None
@@ -16,3 +16,6 @@ class PlatformMetadata:
"""显示在 WebUI 配置页中的平台名称,如空则是 name"""
logo_path: str | None = None
"""平台适配器的 logo 文件路径(相对于插件目录)"""
+
+ support_streaming_message: bool = True
+ """平台是否支持真实流式传输"""
diff --git a/astrbot/core/platform/register.py b/astrbot/core/platform/register.py
index 0c6267492..5f550ecd1 100644
--- a/astrbot/core/platform/register.py
+++ b/astrbot/core/platform/register.py
@@ -14,6 +14,7 @@ def register_platform_adapter(
default_config_tmpl: dict | None = None,
adapter_display_name: str | None = None,
logo_path: str | None = None,
+ support_streaming_message: bool = True,
):
"""用于注册平台适配器的带参装饰器。
@@ -39,9 +40,11 @@ def register_platform_adapter(
pm = PlatformMetadata(
name=adapter_name,
description=desc,
+ id=adapter_name,
default_config_tmpl=default_config_tmpl,
adapter_display_name=adapter_display_name,
logo_path=logo_path,
+ support_streaming_message=support_streaming_message,
)
platform_registry.append(pm)
platform_cls_map[adapter_name] = cls
diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py
index ce8fd56df..293b462d3 100644
--- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py
+++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py
@@ -70,16 +70,18 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
bot: CQHttp,
event: Event | None,
is_group: bool,
- session_id: str,
+ session_id: str | None,
messages: list[dict],
):
# session_id 必须是纯数字字符串
- session_id = int(session_id) if session_id.isdigit() else None
+ session_id_int = (
+ int(session_id) if session_id and session_id.isdigit() else None
+ )
- if is_group and isinstance(session_id, int):
- await bot.send_group_msg(group_id=session_id, message=messages)
- elif not is_group and isinstance(session_id, int):
- await bot.send_private_msg(user_id=session_id, message=messages)
+ if is_group and isinstance(session_id_int, int):
+ await bot.send_group_msg(group_id=session_id_int, message=messages)
+ elif not is_group and isinstance(session_id_int, int):
+ await bot.send_private_msg(user_id=session_id_int, message=messages)
elif isinstance(event, Event): # 最后兜底
await bot.send(event=event, message=messages)
else:
diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py
index 81deead13..52dd21d56 100644
--- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py
+++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py
@@ -4,7 +4,7 @@ import logging
import time
import uuid
from collections.abc import Awaitable
-from typing import Any
+from typing import Any, cast
from aiocqhttp import CQHttp, Event
from aiocqhttp.exceptions import ActionFailed
@@ -29,6 +29,7 @@ from .aiocqhttp_message_event import AiocqhttpMessageEvent
@register_platform_adapter(
"aiocqhttp",
"适用于 OneBot V11 标准的消息平台适配器,支持反向 WebSockets。",
+ support_streaming_message=False,
)
class AiocqhttpAdapter(Platform):
def __init__(
@@ -37,9 +38,8 @@ class AiocqhttpAdapter(Platform):
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
- super().__init__(event_queue)
+ super().__init__(platform_config, event_queue)
- self.config = platform_config
self.settings = platform_settings
self.unique_session = platform_settings["unique_session"]
self.host = platform_config["ws_reverse_host"]
@@ -48,7 +48,8 @@ class AiocqhttpAdapter(Platform):
self.metadata = PlatformMetadata(
name="aiocqhttp",
description="适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
- id=self.config.get("id"),
+ id=cast(str, self.config.get("id")),
+ support_streaming_message=False,
)
self.bot = CQHttp(
@@ -126,7 +127,9 @@ class AiocqhttpAdapter(Platform):
"""OneBot V11 请求类事件"""
abm = AstrBotMessage()
abm.self_id = str(event.self_id)
- abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id)
+ abm.sender = MessageMember(
+ user_id=str(event.user_id), nickname=str(event.user_id)
+ )
abm.type = MessageType.OTHER_MESSAGE
if event.get("group_id"):
abm.type = MessageType.GROUP_MESSAGE
@@ -152,7 +155,9 @@ class AiocqhttpAdapter(Platform):
"""OneBot V11 通知类事件"""
abm = AstrBotMessage()
abm.self_id = str(event.self_id)
- abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id)
+ abm.sender = MessageMember(
+ user_id=str(event.user_id), nickname=str(event.user_id)
+ )
abm.type = MessageType.OTHER_MESSAGE
if event.get("group_id"):
abm.group_id = str(event.group_id)
@@ -191,6 +196,7 @@ class AiocqhttpAdapter(Platform):
@param event: 事件对象
@param get_reply: 是否获取回复消息。这个参数是为了防止多个回复嵌套。
"""
+ assert event.sender is not None
abm = AstrBotMessage()
abm.self_id = str(event.self_id)
abm.sender = MessageMember(
@@ -200,6 +206,7 @@ class AiocqhttpAdapter(Platform):
if event["message_type"] == "group":
abm.type = MessageType.GROUP_MESSAGE
abm.group_id = str(event.group_id)
+ abm.group = Group(str(event.group_id))
abm.group.group_name = event.get("group_name", "N/A")
elif event["message_type"] == "private":
abm.type = MessageType.FRIEND_MESSAGE
@@ -225,7 +232,7 @@ class AiocqhttpAdapter(Platform):
await self.bot.send(event, err)
except BaseException as e:
logger.error(f"回复消息失败: {e}")
- return None
+ raise ValueError(err)
# 按消息段类型类型适配
for t, m_group in itertools.groupby(event.message, key=lambda x: x["type"]):
@@ -244,7 +251,13 @@ class AiocqhttpAdapter(Platform):
if m["data"].get("url") and m["data"].get("url").startswith("http"):
# Lagrange
logger.info("guessing lagrange")
- file_name = m["data"].get("file_name", "file")
+ # 检查多个可能的文件名字段
+ file_name = (
+ m["data"].get("file_name", "")
+ or m["data"].get("name", "")
+ or m["data"].get("file", "")
+ or "file"
+ )
abm.message.append(File(name=file_name, url=m["data"]["url"]))
else:
try:
@@ -263,7 +276,14 @@ class AiocqhttpAdapter(Platform):
)
if ret and "url" in ret:
file_url = ret["url"] # https
- a = File(name="", url=file_url)
+ # 优先从 API 返回值获取文件名,其次从原始消息数据获取
+ file_name = (
+ ret.get("file_name", "")
+ or ret.get("name", "")
+ or m["data"].get("file", "")
+ or m["data"].get("file_name", "")
+ )
+ a = File(name=file_name, url=file_url)
abm.message.append(a)
else:
logger.error(f"获取文件失败: {ret}")
@@ -401,7 +421,7 @@ class AiocqhttpAdapter(Platform):
async def shutdown_trigger_placeholder(self):
await self.shutdown_event.wait()
- logger.info("aiocqhttp 适配器已被优雅地关闭")
+ logger.info("aiocqhttp 适配器已被关闭")
def meta(self) -> PlatformMetadata:
return self.metadata
diff --git a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py
index 43d231771..6f9e25df4 100644
--- a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py
+++ b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py
@@ -2,6 +2,7 @@ import asyncio
import os
import threading
import uuid
+from typing import cast
import aiohttp
import dingtalk_stream
@@ -37,7 +38,9 @@ class MyEventHandler(dingtalk_stream.EventHandler):
return AckMessage.STATUS_OK, "OK"
-@register_platform_adapter("dingtalk", "钉钉机器人官方 API 适配器")
+@register_platform_adapter(
+ "dingtalk", "钉钉机器人官方 API 适配器", support_streaming_message=False
+)
class DingtalkPlatformAdapter(Platform):
def __init__(
self,
@@ -45,21 +48,21 @@ class DingtalkPlatformAdapter(Platform):
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
- super().__init__(event_queue)
-
- self.config = platform_config
+ super().__init__(platform_config, event_queue)
self.unique_session = platform_settings["unique_session"]
self.client_id = platform_config["client_id"]
self.client_secret = platform_config["client_secret"]
+ outer_self = self
+
class AstrCallbackClient(dingtalk_stream.ChatbotHandler):
- async def process(self_, message: dingtalk_stream.CallbackMessage):
+ async def process(self, message: dingtalk_stream.CallbackMessage):
logger.debug(f"dingtalk: {message.data}")
im = dingtalk_stream.ChatbotMessage.from_dict(message.data)
- abm = await self.convert_msg(im)
- await self.handle_msg(abm)
+ abm = await outer_self.convert_msg(im)
+ await outer_self.handle_msg(abm)
return AckMessage.STATUS_OK, "OK"
@@ -73,6 +76,15 @@ class DingtalkPlatformAdapter(Platform):
self.client,
)
self.client_ = client # 用于 websockets 的 client
+ self._shutdown_event: threading.Event | None = None
+
+ def _id_to_sid(self, dingtalk_id: str | None) -> str:
+ if not dingtalk_id:
+ return dingtalk_id or "unknown"
+ prefix = "$:LWCP_v1:$"
+ if dingtalk_id.startswith(prefix):
+ return dingtalk_id[len(prefix) :]
+ return dingtalk_id or "unknown"
async def send_by_session(
self,
@@ -85,7 +97,8 @@ class DingtalkPlatformAdapter(Platform):
return PlatformMetadata(
name="dingtalk",
description="钉钉机器人官方 API 适配器",
- id=self.config.get("id"),
+ id=cast(str, self.config.get("id")),
+ support_streaming_message=False,
)
async def convert_msg(
@@ -95,26 +108,26 @@ class DingtalkPlatformAdapter(Platform):
abm = AstrBotMessage()
abm.message = []
abm.message_str = ""
- abm.timestamp = int(message.create_at / 1000)
+ abm.timestamp = int(cast(int, message.create_at) / 1000)
abm.type = (
MessageType.GROUP_MESSAGE
if message.conversation_type == "2"
else MessageType.FRIEND_MESSAGE
)
abm.sender = MessageMember(
- user_id=message.sender_id,
+ user_id=self._id_to_sid(message.sender_id),
nickname=message.sender_nick,
)
- abm.self_id = message.chatbot_user_id
- abm.message_id = message.message_id
+ abm.self_id = self._id_to_sid(message.chatbot_user_id)
+ abm.message_id = cast(str, message.message_id)
abm.raw_message = message
if abm.type == MessageType.GROUP_MESSAGE:
# 处理所有被 @ 的用户(包括机器人自己,因 at_users 已包含)
if message.at_users:
for user in message.at_users:
- if user.dingtalk_id:
- abm.message.append(At(qq=user.dingtalk_id))
+ if id := self._id_to_sid(user.dingtalk_id):
+ abm.message.append(At(qq=id))
abm.group_id = message.conversation_id
if self.unique_session:
abm.session_id = abm.sender.user_id
@@ -123,14 +136,16 @@ class DingtalkPlatformAdapter(Platform):
else:
abm.session_id = abm.sender.user_id
- message_type: str = message.message_type
+ message_type: str = cast(str, message.message_type)
match message_type:
case "text":
abm.message_str = message.text.content.strip()
abm.message.append(Plain(abm.message_str))
case "richText":
- rtc: dingtalk_stream.RichTextContent = message.rich_text_content
- contents: list[dict] = rtc.rich_text_list
+ rtc: dingtalk_stream.RichTextContent = cast(
+ dingtalk_stream.RichTextContent, message.rich_text_content
+ )
+ contents: list[dict] = cast(list[dict], rtc.rich_text_list)
for content in contents:
plains = ""
if "text" in content:
@@ -139,7 +154,7 @@ class DingtalkPlatformAdapter(Platform):
elif "type" in content and content["type"] == "picture":
f_path = await self.download_ding_file(
content["downloadCode"],
- message.robot_code,
+ cast(str, message.robot_code),
"jpg",
)
abm.message.append(Image.fromFileSystem(f_path))
@@ -184,7 +199,7 @@ class DingtalkPlatformAdapter(Platform):
logger.error(
f"下载钉钉文件失败: {resp.status}, {await resp.text()}",
)
- return None
+ return ""
resp_data = await resp.json()
download_url = resp_data["data"]["downloadUrl"]
await download_file(download_url, f_path)
@@ -204,7 +219,7 @@ class DingtalkPlatformAdapter(Platform):
logger.error(
f"获取钉钉机器人 access_token 失败: {resp.status}, {await resp.text()}",
)
- return None
+ return ""
return (await resp.json())["data"]["accessToken"]
async def handle_msg(self, abm: AstrBotMessage):
@@ -230,7 +245,7 @@ class DingtalkPlatformAdapter(Platform):
task.result()
except Exception as e:
if "Graceful shutdown" in str(e):
- logger.info("钉钉适配器已被优雅地关闭")
+ logger.info("钉钉适配器已被关闭")
return
logger.error(f"钉钉机器人启动失败: {e}")
@@ -239,11 +254,13 @@ class DingtalkPlatformAdapter(Platform):
async def terminate(self):
def monkey_patch_close():
- raise Exception("Graceful shutdown")
+ raise KeyboardInterrupt("Graceful shutdown")
- self.client_.open_connection = monkey_patch_close
- await self.client_.websocket.close(code=1000, reason="Graceful shutdown")
- self._shutdown_event.set()
+ if self.client_.websocket is not None:
+ self.client_.open_connection = monkey_patch_close
+ await self.client_.websocket.close(code=1000, reason="Graceful shutdown")
+ if self._shutdown_event is not None:
+ self._shutdown_event.set()
def get_client(self):
return self.client
diff --git a/astrbot/core/platform/sources/dingtalk/dingtalk_event.py b/astrbot/core/platform/sources/dingtalk/dingtalk_event.py
index a1cd9c1aa..d520189d8 100644
--- a/astrbot/core/platform/sources/dingtalk/dingtalk_event.py
+++ b/astrbot/core/platform/sources/dingtalk/dingtalk_event.py
@@ -1,4 +1,5 @@
import asyncio
+from typing import cast
import dingtalk_stream
@@ -32,7 +33,7 @@ class DingtalkMessageEvent(AstrMessageEvent):
client.reply_markdown,
segment.text,
segment.text,
- self.message_obj.raw_message,
+ cast(dingtalk_stream.ChatbotMessage, self.message_obj.raw_message),
)
elif isinstance(segment, Comp.Image):
markdown_str = ""
@@ -53,7 +54,9 @@ class DingtalkMessageEvent(AstrMessageEvent):
client.reply_markdown,
"😄",
markdown_str,
- self.message_obj.raw_message,
+ cast(
+ dingtalk_stream.ChatbotMessage, self.message_obj.raw_message
+ ),
)
logger.debug(f"send image: {ret}")
diff --git a/astrbot/core/platform/sources/discord/client.py b/astrbot/core/platform/sources/discord/client.py
index 5d29e3429..ac0610f2a 100644
--- a/astrbot/core/platform/sources/discord/client.py
+++ b/astrbot/core/platform/sources/discord/client.py
@@ -1,4 +1,5 @@
import sys
+from collections.abc import Awaitable, Callable
import discord
@@ -27,13 +28,16 @@ class DiscordBotClient(discord.Bot):
super().__init__(intents=intents, proxy=proxy)
# 回调函数
- self.on_message_received = None
- self.on_ready_once_callback = None
+ self.on_message_received: Callable[[dict], Awaitable[None]] | None = None
+ self.on_ready_once_callback: Callable[[], Awaitable[None]] | None = None
self._ready_once_fired = False
- @override
async def on_ready(self):
"""当机器人成功连接并准备就绪时触发"""
+ if self.user is None:
+ logger.error("[Discord] 客户端未正确加载用户信息 (self.user is None)")
+ return
+
logger.info(f"[Discord] 已作为 {self.user} (ID: {self.user.id}) 登录")
logger.info("[Discord] 客户端已准备就绪。")
@@ -49,6 +53,9 @@ class DiscordBotClient(discord.Bot):
def _create_message_data(self, message: discord.Message) -> dict:
"""从 discord.Message 创建数据字典"""
+ if self.user is None:
+ raise RuntimeError("Bot is not ready: self.user is None")
+
is_mentioned = self.user in message.mentions
return {
"message": message,
@@ -66,6 +73,12 @@ class DiscordBotClient(discord.Bot):
def _create_interaction_data(self, interaction: discord.Interaction) -> dict:
"""从 discord.Interaction 创建数据字典"""
+ if self.user is None:
+ raise RuntimeError("Bot is not ready: self.user is None")
+
+ if interaction.user is None:
+ raise ValueError("Interaction received without a valid user")
+
return {
"interaction": interaction,
"bot_id": str(self.user.id),
@@ -80,7 +93,6 @@ class DiscordBotClient(discord.Bot):
"type": "interaction",
}
- @override
async def on_message(self, message: discord.Message):
"""当接收到消息时触发"""
if message.author.bot:
diff --git a/astrbot/core/platform/sources/discord/components.py b/astrbot/core/platform/sources/discord/components.py
index d3e69e763..f875652a0 100644
--- a/astrbot/core/platform/sources/discord/components.py
+++ b/astrbot/core/platform/sources/discord/components.py
@@ -97,8 +97,8 @@ class DiscordView(BaseMessageComponent):
def __init__(
self,
- components: list[BaseMessageComponent] = None,
- timeout: float = None,
+ components: list[BaseMessageComponent] | None = None,
+ timeout: float | None = None,
):
self.components = components or []
self.timeout = timeout
diff --git a/astrbot/core/platform/sources/discord/discord_platform_adapter.py b/astrbot/core/platform/sources/discord/discord_platform_adapter.py
index 2752f3a9b..50aa0fe6f 100644
--- a/astrbot/core/platform/sources/discord/discord_platform_adapter.py
+++ b/astrbot/core/platform/sources/discord/discord_platform_adapter.py
@@ -1,10 +1,10 @@
import asyncio
import re
import sys
-from typing import Any
+from typing import Any, cast
import discord
-from discord.abc import Messageable
+from discord.abc import GuildChannel, Messageable, PrivateChannel
from discord.channel import DMChannel
from astrbot import logger
@@ -34,7 +34,9 @@ else:
# 注册平台适配器
-@register_platform_adapter("discord", "Discord 适配器 (基于 Pycord)")
+@register_platform_adapter(
+ "discord", "Discord 适配器 (基于 Pycord)", support_streaming_message=False
+)
class DiscordPlatformAdapter(Platform):
def __init__(
self,
@@ -42,10 +44,9 @@ class DiscordPlatformAdapter(Platform):
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
- super().__init__(event_queue)
- self.config = platform_config
+ super().__init__(platform_config, event_queue)
self.settings = platform_settings
- self.client_self_id = None
+ self.client_self_id: str | None = None
self.registered_handlers = []
# 指令注册相关
self.enable_command_register = self.config.get("discord_command_register", True)
@@ -61,6 +62,12 @@ class DiscordPlatformAdapter(Platform):
message_chain: MessageChain,
):
"""通过会话发送消息"""
+ if self.client.user is None:
+ logger.error(
+ "[Discord] 客户端未就绪 (self.client.user is None),无法发送消息"
+ )
+ return
+
# 创建一个 message_obj 以便在 event 中使用
message_obj = AstrBotMessage()
if "_" in session.session_id:
@@ -88,7 +95,7 @@ class DiscordPlatformAdapter(Platform):
user_id=str(self.client_self_id),
nickname=self.client.user.display_name,
)
- message_obj.self_id = self.client_self_id
+ message_obj.self_id = cast(str, self.client_self_id)
message_obj.session_id = session.session_id
message_obj.message = message_chain.chain
@@ -109,8 +116,9 @@ class DiscordPlatformAdapter(Platform):
return PlatformMetadata(
"discord",
"Discord 适配器",
- id=self.config.get("id"),
+ id=cast(str, self.config.get("id")),
default_config_tmpl=self.config,
+ support_streaming_message=False,
)
@override
@@ -158,7 +166,7 @@ class DiscordPlatformAdapter(Platform):
def _get_message_type(
self,
- channel: Messageable,
+ channel: Messageable | GuildChannel | PrivateChannel,
guild_id: int | None = None,
) -> MessageType:
"""根据 channel 对象和 guild_id 判断消息类型"""
@@ -168,13 +176,15 @@ class DiscordPlatformAdapter(Platform):
return MessageType.FRIEND_MESSAGE
return MessageType.GROUP_MESSAGE
- def _get_channel_id(self, channel: Messageable) -> str:
+ def _get_channel_id(
+ self, channel: Messageable | GuildChannel | PrivateChannel
+ ) -> str:
"""根据 channel 对象获取ID"""
return str(getattr(channel, "id", None))
def _convert_message_to_abm(self, data: dict) -> AstrBotMessage:
"""将普通消息转换为 AstrBotMessage"""
- message: discord.Message = data["message"]
+ message = data["message"]
content = message.content
@@ -231,7 +241,7 @@ class DiscordPlatformAdapter(Platform):
)
abm.message = message_chain
abm.raw_message = message
- abm.self_id = self.client_self_id
+ abm.self_id = cast(str, self.client_self_id)
abm.session_id = str(message.channel.id)
abm.message_id = str(message.id)
return abm
@@ -252,32 +262,52 @@ class DiscordPlatformAdapter(Platform):
interaction_followup_webhook=followup_webhook,
)
+ if self.client.user is None:
+ logger.error(
+ "[Discord] 客户端未就绪 (self.client.user is None),无法处理消息"
+ )
+ return
+
# 检查是否为斜杠指令
is_slash_command = message_event.interaction_followup_webhook is not None
+ # 1. 优先处理斜杠指令
+ if is_slash_command:
+ message_event.is_wake = True
+ message_event.is_at_or_wake_command = True
+ self.commit_event(message_event)
+ return
+
+ # 2. 处理普通消息(提及检测)
+ # 确保 raw_message 是 discord.Message 类型,以便静态检查通过
+ raw_message = message.raw_message
+ if not isinstance(raw_message, discord.Message):
+ logger.warning(
+ f"[Discord] 收到非 Message 类型的消息: {type(raw_message)},已忽略。"
+ )
+ return
+
# 检查是否被@(User Mention 或 Bot 拥有的 Role Mention)
is_mention = False
+
# User Mention
- if (
- self.client
- and self.client.user
- and hasattr(message.raw_message, "mentions")
- ):
- if self.client.user in message.raw_message.mentions:
- is_mention = True
+ # 此时 Pylance 知道 raw_message 是 discord.Message,具有 mentions 属性
+ if self.client.user in raw_message.mentions:
+ is_mention = True
+
# Role Mention(Bot 拥有的角色被提及)
- if not is_mention and hasattr(message.raw_message, "role_mentions"):
+ if not is_mention and raw_message.role_mentions:
bot_member = None
- if hasattr(message.raw_message, "guild") and message.raw_message.guild:
+ if raw_message.guild:
try:
- bot_member = message.raw_message.guild.get_member(
+ bot_member = raw_message.guild.get_member(
self.client.user.id,
)
except Exception:
bot_member = None
if bot_member and hasattr(bot_member, "roles"):
bot_roles = set(bot_member.roles)
- mentioned_roles = set(message.raw_message.role_mentions)
+ mentioned_roles = set(raw_message.role_mentions)
if (
bot_roles
and mentioned_roles
@@ -285,8 +315,8 @@ class DiscordPlatformAdapter(Platform):
):
is_mention = True
- # 如果是斜杠指令或被@的消息,设置为唤醒状态
- if is_slash_command or is_mention:
+ # 如果是被@的消息,设置为唤醒状态
+ if is_mention:
message_event.is_wake = True
message_event.is_at_or_wake_command = True
@@ -422,7 +452,7 @@ class DiscordPlatformAdapter(Platform):
)
abm.message = [Plain(text=message_str_for_filter)]
abm.raw_message = ctx.interaction
- abm.self_id = self.client_self_id
+ abm.self_id = cast(str, self.client_self_id)
abm.session_id = str(ctx.channel_id)
abm.message_id = str(ctx.interaction.id)
@@ -435,7 +465,7 @@ class DiscordPlatformAdapter(Platform):
def _extract_command_info(
event_filter: Any,
handler_metadata: StarHandlerMetadata,
- ) -> tuple[str, str, CommandFilter] | None:
+ ) -> tuple[str, str, CommandFilter | None] | None:
"""从事件过滤器中提取指令信息"""
cmd_name = None
# is_group = False
diff --git a/astrbot/core/platform/sources/discord/discord_platform_event.py b/astrbot/core/platform/sources/discord/discord_platform_event.py
index 82eb9f144..053018225 100644
--- a/astrbot/core/platform/sources/discord/discord_platform_event.py
+++ b/astrbot/core/platform/sources/discord/discord_platform_event.py
@@ -4,8 +4,10 @@ import binascii
from collections.abc import AsyncGenerator
from io import BytesIO
from pathlib import Path
+from typing import cast
import discord
+from discord.types.interactions import ComponentInteractionData
from astrbot import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
@@ -85,6 +87,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
channel = await self._get_channel()
if not channel:
return
+ if not isinstance(channel, discord.abc.Messageable):
+ logger.error(f"[Discord] 频道 {channel.id} 不是可发送消息的类型")
+ return
await channel.send(**kwargs)
except Exception as e:
@@ -107,7 +112,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)
- async def _get_channel(self) -> discord.abc.Messageable | None:
+ async def _get_channel(
+ self,
+ ) -> discord.Thread | discord.abc.GuildChannel | discord.abc.PrivateChannel | None:
"""获取当前事件对应的频道对象"""
try:
channel_id = int(self.session_id)
@@ -121,7 +128,13 @@ class DiscordPlatformEvent(AstrMessageEvent):
async def _parse_to_discord(
self,
message: MessageChain,
- ) -> tuple[str, list[discord.File], discord.ui.View | None, list[discord.Embed]]:
+ ) -> tuple[
+ str,
+ list[discord.File],
+ discord.ui.View | None,
+ list[discord.Embed],
+ str | int | None,
+ ]:
"""将 MessageChain 解析为 Discord 发送所需的内容"""
content_parts = []
files = []
@@ -261,7 +274,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
self.message_obj.raw_message,
"add_reaction",
):
- await self.message_obj.raw_message.add_reaction(emoji)
+ await cast(discord.Message, self.message_obj.raw_message).add_reaction(
+ emoji
+ )
except Exception as e:
logger.error(f"[Discord] 添加反应失败: {e}")
@@ -270,7 +285,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
return (
hasattr(self.message_obj, "raw_message")
and hasattr(self.message_obj.raw_message, "type")
- and self.message_obj.raw_message.type
+ and cast(discord.Interaction, self.message_obj.raw_message).type
== discord.InteractionType.application_command
)
@@ -279,14 +294,18 @@ class DiscordPlatformEvent(AstrMessageEvent):
return (
hasattr(self.message_obj, "raw_message")
and hasattr(self.message_obj.raw_message, "type")
- and self.message_obj.raw_message.type == discord.InteractionType.component
+ and cast(discord.Interaction, self.message_obj.raw_message).type
+ == discord.InteractionType.component
)
def get_interaction_custom_id(self) -> str:
"""获取交互组件的custom_id"""
if self.is_button_interaction():
try:
- return self.message_obj.raw_message.data.get("custom_id", "")
+ return cast(
+ ComponentInteractionData,
+ cast(discord.Interaction, self.message_obj.raw_message).data,
+ ).get("custom_id", "")
except Exception:
pass
return ""
@@ -299,7 +318,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
):
return any(
mention.id == int(self.message_obj.self_id)
- for mention in self.message_obj.raw_message.mentions
+ for mention in cast(
+ discord.Message, self.message_obj.raw_message
+ ).mentions
)
return False
@@ -309,5 +330,5 @@ class DiscordPlatformEvent(AstrMessageEvent):
self.message_obj.raw_message,
"clean_content",
):
- return self.message_obj.raw_message.clean_content
+ return cast(discord.Message, self.message_obj.raw_message).clean_content
return self.message_str
diff --git a/astrbot/core/platform/sources/lark/lark_adapter.py b/astrbot/core/platform/sources/lark/lark_adapter.py
index b59dbaca4..08df1f359 100644
--- a/astrbot/core/platform/sources/lark/lark_adapter.py
+++ b/astrbot/core/platform/sources/lark/lark_adapter.py
@@ -2,10 +2,17 @@ import asyncio
import base64
import json
import re
+import time
import uuid
+from typing import Any, cast
import lark_oapi as lark
-from lark_oapi.api.im.v1 import *
+from lark_oapi.api.im.v1 import (
+ CreateMessageRequest,
+ CreateMessageRequestBody,
+ GetMessageResourceRequest,
+)
+from lark_oapi.api.im.v1.processor import P2ImMessageReceiveV1Processor
import astrbot.api.message_components as Comp
from astrbot import logger
@@ -18,12 +25,16 @@ from astrbot.api.platform import (
PlatformMetadata,
)
from astrbot.core.platform.astr_message_event import MessageSesion
+from astrbot.core.utils.webhook_utils import log_webhook_info
from ...register import register_platform_adapter
from .lark_event import LarkMessageEvent
+from .server import LarkWebhookServer
-@register_platform_adapter("lark", "飞书机器人官方 API 适配器")
+@register_platform_adapter(
+ "lark", "飞书机器人官方 API 适配器", support_streaming_message=False
+)
class LarkPlatformAdapter(Platform):
def __init__(
self,
@@ -31,9 +42,7 @@ class LarkPlatformAdapter(Platform):
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
- super().__init__(event_queue)
-
- self.config = platform_config
+ super().__init__(platform_config, event_queue)
self.unique_session = platform_settings["unique_session"]
@@ -42,9 +51,13 @@ class LarkPlatformAdapter(Platform):
self.domain = platform_config.get("domain", lark.FEISHU_DOMAIN)
self.bot_name = platform_config.get("lark_bot_name", "astrbot")
+ # socket or webhook
+ self.connection_mode = platform_config.get("lark_connection_mode", "socket")
+
if not self.bot_name:
logger.warning("未设置飞书机器人名称,@ 机器人可能得不到回复。")
+ # 初始化 WebSocket 长连接相关配置
async def on_msg_event_recv(event: lark.im.v1.P2ImMessageReceiveV1):
await self.convert_msg(event)
@@ -57,6 +70,8 @@ class LarkPlatformAdapter(Platform):
.build()
)
+ self.do_v2_msg_event = do_v2_msg_event
+
self.client = lark.ws.Client(
app_id=self.appid,
app_secret=self.appsecret,
@@ -66,14 +81,56 @@ class LarkPlatformAdapter(Platform):
)
self.lark_api = (
- lark.Client.builder().app_id(self.appid).app_secret(self.appsecret).build()
+ lark.Client.builder()
+ .app_id(self.appid)
+ .app_secret(self.appsecret)
+ .log_level(lark.LogLevel.ERROR)
+ .domain(self.domain)
+ .build()
)
+ self.webhook_server = None
+ if self.connection_mode == "webhook":
+ self.webhook_server = LarkWebhookServer(platform_config, event_queue)
+ self.webhook_server.set_callback(self.handle_webhook_event)
+
+ self.event_id_timestamps: dict[str, float] = {}
+
+ def _clean_expired_events(self):
+ """清理超过 30 分钟的事件记录"""
+ current_time = time.time()
+ expired_keys = [
+ event_id
+ for event_id, timestamp in self.event_id_timestamps.items()
+ if current_time - timestamp > 1800
+ ]
+ for event_id in expired_keys:
+ del self.event_id_timestamps[event_id]
+
+ def _is_duplicate_event(self, event_id: str) -> bool:
+ """检查事件是否重复
+
+ Args:
+ event_id: 事件ID
+
+ Returns:
+ True 表示重复事件,False 表示新事件
+ """
+ self._clean_expired_events()
+ if event_id in self.event_id_timestamps:
+ return True
+ self.event_id_timestamps[event_id] = time.time()
+ return False
+
async def send_by_session(
self,
session: MessageSesion,
message_chain: MessageChain,
):
+ if self.lark_api.im is None:
+ logger.error("[Lark] API Client im 模块未初始化,无法发送消息")
+ return
+
res = await LarkMessageEvent._convert_to_lark(message_chain, self.lark_api)
wrapped = {
"zh_cn": {
@@ -114,13 +171,25 @@ class LarkPlatformAdapter(Platform):
return PlatformMetadata(
name="lark",
description="飞书机器人官方 API 适配器",
- id=self.config.get("id"),
+ id=cast(str, self.config.get("id")),
+ support_streaming_message=False,
)
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):
+ if event.event is None:
+ logger.debug("[Lark] 收到空事件(event.event is None)")
+ return
message = event.event.message
+ if message is None:
+ logger.debug("[Lark] 事件中没有消息体(message is None)")
+ return
+
abm = AstrBotMessage()
- abm.timestamp = int(message.create_time) / 1000
+
+ if message.create_time:
+ abm.timestamp = int(message.create_time) // 1000
+ else:
+ abm.timestamp = int(time.time())
abm.message = []
abm.type = (
MessageType.GROUP_MESSAGE
@@ -135,14 +204,28 @@ class LarkPlatformAdapter(Platform):
at_list = {}
if message.mentions:
for m in message.mentions:
- at_list[m.key] = Comp.At(qq=m.id.open_id, name=m.name)
- if m.name == self.bot_name:
- abm.self_id = m.id.open_id
+ if m.id is None:
+ continue
+ # 飞书 open_id 可能是 None,这里做个防护
+ open_id = m.id.open_id if m.id.open_id else ""
+ at_list[m.key] = Comp.At(qq=open_id, name=m.name)
- content_json_b = json.loads(message.content)
+ if m.name == self.bot_name:
+ if m.id.open_id is not None:
+ abm.self_id = m.id.open_id
+
+ if message.content is None:
+ logger.warning("[Lark] 消息内容为空")
+ return
+
+ try:
+ content_json_b = json.loads(message.content)
+ except json.JSONDecodeError:
+ logger.error(f"[Lark] 解析消息内容失败: {message.content}")
+ return
if message.message_type == "text":
- message_str_raw = content_json_b["text"] # 带有 @ 的消息
+ message_str_raw = content_json_b.get("text", "") # 带有 @ 的消息
at_pattern = r"(@_user_\d+)" # 可以根据需求修改正则
# at_users = re.findall(at_pattern, message_str_raw)
# 拆分文本,去掉AT符号部分
@@ -167,27 +250,47 @@ class LarkPlatformAdapter(Platform):
content_json_b = _ls
elif message.message_type == "image":
content_json_b = [
- {"tag": "img", "image_key": content_json_b["image_key"], "style": []},
+ {
+ "tag": "img",
+ "image_key": content_json_b.get("image_key"),
+ "style": [],
+ },
]
if message.message_type in ("post", "image"):
for comp in content_json_b:
- if comp["tag"] == "at":
- abm.message.append(at_list[comp["user_id"]])
- elif comp["tag"] == "text" and comp["text"].strip():
+ if comp.get("tag") == "at":
+ user_id = comp.get("user_id")
+ if user_id in at_list:
+ abm.message.append(at_list[user_id])
+ elif comp.get("tag") == "text" and comp.get("text", "").strip():
abm.message.append(Comp.Plain(comp["text"].strip()))
- elif comp["tag"] == "img":
- image_key = comp["image_key"]
+ elif comp.get("tag") == "img":
+ image_key = comp.get("image_key")
+ if not image_key:
+ continue
+
request = (
GetMessageResourceRequest.builder()
- .message_id(message.message_id)
+ .message_id(cast(str, message.message_id))
.file_key(image_key)
.type("image")
.build()
)
+
+ if self.lark_api.im is None:
+ logger.error("[Lark] API Client im 模块未初始化")
+ continue
+
response = await self.lark_api.im.v1.message_resource.aget(request)
if not response.success():
logger.error(f"无法下载飞书图片: {image_key}")
+ continue
+
+ if response.file is None:
+ logger.error(f"飞书图片响应中不包含文件流: {image_key}")
+ continue
+
image_bytes = response.file.read()
image_base64 = base64.b64encode(image_bytes).decode()
abm.message.append(Comp.Image.fromBase64(image_base64))
@@ -195,6 +298,19 @@ class LarkPlatformAdapter(Platform):
for comp in abm.message:
if isinstance(comp, Comp.Plain):
abm.message_str += comp.text
+
+ if message.message_id is None:
+ logger.error("[Lark] 消息缺少 message_id")
+ return
+
+ if (
+ event.event.sender is None
+ or event.event.sender.sender_id is None
+ or event.event.sender.sender_id.open_id is None
+ ):
+ logger.error("[Lark] 消息发送者信息不完整")
+ return
+
abm.message_id = message.message_id
abm.raw_message = message
abm.sender = MessageMember(
@@ -226,13 +342,61 @@ class LarkPlatformAdapter(Platform):
self._event_queue.put_nowait(event)
+ async def handle_webhook_event(self, event_data: dict):
+ """处理 Webhook 事件
+
+ Args:
+ event_data: Webhook 事件数据
+ """
+ try:
+ header = event_data.get("header", {})
+ event_id = header.get("event_id", "")
+ if event_id and self._is_duplicate_event(event_id):
+ logger.debug(f"[Lark Webhook] 跳过重复事件: {event_id}")
+ return
+ event_type = header.get("event_type", "")
+ if event_type == "im.message.receive_v1":
+ processor = P2ImMessageReceiveV1Processor(self.do_v2_msg_event)
+ data = (processor.type())(event_data)
+ processor.do(data)
+ else:
+ logger.debug(f"[Lark Webhook] 未处理的事件类型: {event_type}")
+ except Exception as e:
+ logger.error(f"[Lark Webhook] 处理事件失败: {e}", exc_info=True)
+
async def run(self):
- # self.client.start()
- await self.client._connect()
+ if self.connection_mode == "webhook":
+ # Webhook 模式
+ if self.webhook_server is None:
+ logger.error("[Lark] Webhook 模式已启用,但 webhook_server 未初始化")
+ return
+
+ webhook_uuid = self.config.get("webhook_uuid")
+ if webhook_uuid:
+ log_webhook_info(f"{self.meta().id}(飞书 Webhook)", webhook_uuid)
+ else:
+ logger.warning("[Lark] Webhook 模式已启用,但未配置 webhook_uuid")
+ else:
+ # 长连接模式
+ await self.client._connect()
+
+ async def webhook_callback(self, request: Any) -> Any:
+ """统一 Webhook 回调入口"""
+ if not self.webhook_server:
+ return {"error": "Webhook server not initialized"}, 500
+
+ return await self.webhook_server.handle_callback(request)
async def terminate(self):
- await self.client._disconnect()
- logger.info("飞书(Lark) 适配器已被优雅地关闭")
+ if self.connection_mode == "socket":
+ await self.client._disconnect()
+ logger.info("飞书(Lark) 适配器已关闭")
- def get_client(self) -> lark.Client:
+ def get_client(self) -> lark.ws.Client:
return self.client
+
+ def unified_webhook(self) -> bool:
+ return bool(
+ self.config.get("lark_connection_mode", "") == "webhook"
+ and self.config.get("webhook_uuid")
+ )
diff --git a/astrbot/core/platform/sources/lark/lark_event.py b/astrbot/core/platform/sources/lark/lark_event.py
index 04204d35e..7b7d20b38 100644
--- a/astrbot/core/platform/sources/lark/lark_event.py
+++ b/astrbot/core/platform/sources/lark/lark_event.py
@@ -5,7 +5,15 @@ import uuid
from io import BytesIO
import lark_oapi as lark
-from lark_oapi.api.im.v1 import *
+from lark_oapi.api.im.v1 import (
+ CreateImageRequest,
+ CreateImageRequestBody,
+ CreateMessageReactionRequest,
+ CreateMessageReactionRequestBody,
+ Emoji,
+ ReplyMessageRequest,
+ ReplyMessageRequestBody,
+)
from astrbot import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
@@ -44,7 +52,7 @@ class LarkMessageEvent(AstrMessageEvent):
file_path = comp.file.replace("file:///", "")
elif comp.file and comp.file.startswith("http"):
image_file_path = await download_image_by_url(comp.file)
- file_path = image_file_path
+ file_path = image_file_path if image_file_path else ""
elif comp.file and comp.file.startswith("base64://"):
base64_str = comp.file.removeprefix("base64://")
image_data = base64.b64decode(base64_str)
@@ -54,10 +62,17 @@ class LarkMessageEvent(AstrMessageEvent):
with open(file_path, "wb") as f:
f.write(BytesIO(image_data).getvalue())
else:
- file_path = comp.file
+ file_path = comp.file if comp.file else ""
if image_file is None:
- image_file = open(file_path, "rb")
+ if not file_path:
+ logger.error("[Lark] 图片路径为空,无法上传")
+ continue
+ try:
+ image_file = open(file_path, "rb")
+ except Exception as e:
+ logger.error(f"[Lark] 无法打开图片文件: {e}")
+ continue
request = (
CreateImageRequest.builder()
@@ -69,9 +84,20 @@ class LarkMessageEvent(AstrMessageEvent):
)
.build()
)
+
+ if lark_client.im is None:
+ logger.error("[Lark] API Client im 模块未初始化,无法上传图片")
+ continue
+
response = await lark_client.im.v1.image.acreate(request)
if not response.success():
logger.error(f"无法上传飞书图片({response.code}): {response.msg}")
+ continue
+
+ if response.data is None:
+ logger.error("[Lark] 上传图片成功但未返回数据(data is None)")
+ continue
+
image_key = response.data.image_key
logger.debug(image_key)
ret.append(_stage)
@@ -107,6 +133,10 @@ class LarkMessageEvent(AstrMessageEvent):
.build()
)
+ if self.bot.im is None:
+ logger.error("[Lark] API Client im 模块未初始化,无法回复消息")
+ return
+
response = await self.bot.im.v1.message.areply(request)
if not response.success():
@@ -115,6 +145,10 @@ class LarkMessageEvent(AstrMessageEvent):
await super().send(message)
async def react(self, emoji: str):
+ if self.bot.im is None:
+ logger.error("[Lark] API Client im 模块未初始化,无法发送表情")
+ return
+
request = (
CreateMessageReactionRequest.builder()
.message_id(self.message_obj.message_id)
@@ -125,6 +159,7 @@ class LarkMessageEvent(AstrMessageEvent):
)
.build()
)
+
response = await self.bot.im.v1.message_reaction.acreate(request)
if not response.success():
logger.error(f"发送飞书表情回应失败({response.code}): {response.msg}")
diff --git a/astrbot/core/platform/sources/lark/server.py b/astrbot/core/platform/sources/lark/server.py
new file mode 100644
index 000000000..3921eb8be
--- /dev/null
+++ b/astrbot/core/platform/sources/lark/server.py
@@ -0,0 +1,206 @@
+"""飞书(Lark) Webhook 服务器实现
+
+实现飞书事件订阅的 Webhook 模式,支持:
+1. 请求 URL 验证 (challenge 验证)
+2. 事件加密/解密 (AES-256-CBC)
+3. 签名校验 (SHA256)
+4. 事件接收和处理
+"""
+
+import asyncio
+import base64
+import hashlib
+import json
+from collections.abc import Awaitable, Callable
+
+from Crypto.Cipher import AES
+
+from astrbot.api import logger
+
+
+class AESCipher:
+ """AES 加密/解密工具类"""
+
+ def __init__(self, key: str):
+ self.bs = AES.block_size
+ self.key = hashlib.sha256(self.str_to_bytes(key)).digest()
+
+ @staticmethod
+ def str_to_bytes(data):
+ u_type = type(b"".decode("utf8"))
+ if isinstance(data, u_type):
+ return data.encode("utf8")
+ return data
+
+ @staticmethod
+ def _unpad(s):
+ return s[: -ord(s[len(s) - 1 :])]
+
+ def decrypt(self, enc):
+ iv = enc[: AES.block_size]
+ cipher = AES.new(self.key, AES.MODE_CBC, iv)
+ return self._unpad(cipher.decrypt(enc[AES.block_size :]))
+
+ def decrypt_string(self, enc):
+ enc = base64.b64decode(enc)
+ return self.decrypt(enc).decode("utf8")
+
+
+class LarkWebhookServer:
+ """飞书 Webhook 服务器
+
+ 仅支持统一 Webhook 模式
+ """
+
+ def __init__(self, config: dict, event_queue: asyncio.Queue):
+ """初始化 Webhook 服务器
+
+ Args:
+ config: 飞书配置
+ event_queue: 事件队列
+ """
+ self.app_id = config["app_id"]
+ self.app_secret = config["app_secret"]
+ self.encrypt_key = config.get("lark_encrypt_key", "")
+ self.verification_token = config.get("lark_verification_token", "")
+
+ self.event_queue = event_queue
+ self.callback: Callable[[dict], Awaitable[None]] | None = None
+
+ # 初始化加密工具
+ self.cipher = None
+ if self.encrypt_key:
+ self.cipher = AESCipher(self.encrypt_key)
+
+ def verify_signature(
+ self,
+ timestamp: str,
+ nonce: str,
+ encrypt_key: str,
+ body: bytes,
+ signature: str,
+ ) -> bool:
+ """验证签名
+
+ Args:
+ timestamp: 请求时间戳
+ nonce: 随机数
+ encrypt_key: 加密密钥
+ body: 请求体
+ signature: 签名
+
+ Returns:
+ 签名是否有效
+ """
+ # 拼接字符串: timestamp + nonce + encrypt_key + body
+ bytes_b1 = (timestamp + nonce + encrypt_key).encode("utf-8")
+ bytes_b = bytes_b1 + body
+ h = hashlib.sha256(bytes_b)
+ calculated_signature = h.hexdigest()
+ return calculated_signature == signature
+
+ def decrypt_event(self, encrypted_data: str) -> dict:
+ """解密事件数据
+
+ Args:
+ encrypted_data: 加密的事件数据
+
+ Returns:
+ 解密后的事件字典
+ """
+ if not self.cipher:
+ raise ValueError("未配置 encrypt_key,无法解密事件")
+
+ decrypted_str = self.cipher.decrypt_string(encrypted_data)
+ return json.loads(decrypted_str)
+
+ async def handle_challenge(self, event_data: dict) -> dict:
+ """处理 challenge 验证请求
+
+ Args:
+ event_data: 事件数据
+
+ Returns:
+ 包含 challenge 的响应
+ """
+ challenge = event_data.get("challenge", "")
+ logger.info(f"[Lark Webhook] 收到 challenge 验证请求: {challenge}")
+
+ return {"challenge": challenge}
+
+ async def handle_callback(self, request) -> tuple[dict, int] | dict:
+ """处理 webhook 回调,可被统一 webhook 入口复用
+
+ Args:
+ request: Quart 请求对象
+
+ Returns:
+ 响应数据
+ """
+ # 获取原始请求体
+ body = await request.get_data()
+
+ try:
+ event_data = await request.json
+ except Exception as e:
+ logger.error(f"[Lark Webhook] 解析请求体失败: {e}")
+ return {"error": "Invalid JSON"}, 400
+
+ if not event_data:
+ logger.error("[Lark Webhook] 请求体为空")
+ return {"error": "Empty request body"}, 400
+
+ # 如果配置了 encrypt_key,进行签名验证
+ if self.encrypt_key:
+ timestamp = request.headers.get("X-Lark-Request-Timestamp", "")
+ nonce = request.headers.get("X-Lark-Request-Nonce", "")
+ signature = request.headers.get("X-Lark-Signature", "")
+
+ if timestamp and nonce and signature:
+ if not self.verify_signature(
+ timestamp, nonce, self.encrypt_key, body, signature
+ ):
+ logger.error("[Lark Webhook] 签名验证失败")
+ return {"error": "Invalid signature"}, 401
+
+ # 检查是否是加密事件
+ if "encrypt" in event_data:
+ try:
+ event_data = self.decrypt_event(event_data["encrypt"])
+ logger.debug(f"[Lark Webhook] 解密后的事件: {event_data}")
+ except Exception as e:
+ logger.error(f"[Lark Webhook] 解密事件失败: {e}")
+ return {"error": "Decryption failed"}, 400
+
+ # 验证 token
+ if self.verification_token:
+ header = event_data.get("header", {})
+ if header:
+ token = header.get("token", "")
+ else:
+ token = event_data.get("token", "")
+ if token != self.verification_token:
+ logger.error("[Lark Webhook] Verification Token 不匹配。")
+ return {"error": "Invalid verification token"}, 401
+
+ # 处理 URL 验证 (challenge)
+ if event_data.get("type") == "url_verification":
+ return await self.handle_challenge(event_data)
+
+ # 调用回调函数处理事件
+ if self.callback:
+ try:
+ await self.callback(event_data)
+ except Exception as e:
+ logger.error(f"[Lark Webhook] 处理事件回调失败: {e}", exc_info=True)
+ return {"error": "Event processing failed"}, 500
+
+ return {}
+
+ def set_callback(self, callback: Callable[[dict], Awaitable[None]]):
+ """设置事件回调函数
+
+ Args:
+ callback: 处理事件的异步函数
+ """
+ self.callback = callback
diff --git a/astrbot/core/platform/sources/misskey/misskey_adapter.py b/astrbot/core/platform/sources/misskey/misskey_adapter.py
index 0a553dc6f..7f3db3062 100644
--- a/astrbot/core/platform/sources/misskey/misskey_adapter.py
+++ b/astrbot/core/platform/sources/misskey/misskey_adapter.py
@@ -1,7 +1,6 @@
import asyncio
import os
import random
-from collections.abc import Awaitable
from typing import Any
import astrbot.api.message_components as Comp
@@ -45,7 +44,9 @@ MAX_FILE_UPLOAD_COUNT = 16
DEFAULT_UPLOAD_CONCURRENCY = 3
-@register_platform_adapter("misskey", "Misskey 平台适配器")
+@register_platform_adapter(
+ "misskey", "Misskey 平台适配器", support_streaming_message=False
+)
class MisskeyPlatformAdapter(Platform):
def __init__(
self,
@@ -53,8 +54,7 @@ class MisskeyPlatformAdapter(Platform):
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
- super().__init__(event_queue)
- self.config = platform_config or {}
+ super().__init__(platform_config or {}, event_queue)
self.settings = platform_settings or {}
self.instance_url = self.config.get("misskey_instance_url", "")
self.access_token = self.config.get("misskey_token", "")
@@ -120,6 +120,7 @@ class MisskeyPlatformAdapter(Platform):
description="Misskey 平台适配器",
id=self.config.get("id", "misskey"),
default_config_tmpl=default_config,
+ support_streaming_message=False,
)
async def run(self):
@@ -201,7 +202,7 @@ class MisskeyPlatformAdapter(Platform):
if not isinstance(message.raw_message, dict):
message.raw_message = {}
message.raw_message["poll"] = poll
- message.poll = poll
+ message.__setattr__("poll", poll)
except Exception:
pass
@@ -370,7 +371,7 @@ class MisskeyPlatformAdapter(Platform):
self,
session: MessageSession,
message_chain: MessageChain,
- ) -> Awaitable[Any]:
+ ) -> None:
if not self.api:
logger.error("[Misskey] API 客户端未初始化")
return await super().send_by_session(session, message_chain)
diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py
index fe1496644..d693c4206 100644
--- a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py
+++ b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py
@@ -3,6 +3,7 @@ import base64
import os
import random
import uuid
+from typing import cast
import aiofiles
import botpy
@@ -60,7 +61,10 @@ class QQOfficialMessageEvent(AstrMessageEvent):
time_since_last_edit = current_time - last_edit_time
if time_since_last_edit >= throttle_interval:
- ret = await self._post_send(stream=stream_payload)
+ ret = cast(
+ message.Message,
+ await self._post_send(stream=stream_payload),
+ )
stream_payload["index"] += 1
stream_payload["id"] = ret["id"]
last_edit_time = asyncio.get_event_loop().time()
@@ -69,6 +73,8 @@ class QQOfficialMessageEvent(AstrMessageEvent):
# 结束流式对话,并且传输 buffer 中剩余的消息
stream_payload["state"] = 10
ret = await self._post_send(stream=stream_payload)
+ else:
+ ret = await self._post_send()
except Exception as e:
logger.error(f"发送流式消息时出错: {e}", exc_info=True)
@@ -81,7 +87,8 @@ class QQOfficialMessageEvent(AstrMessageEvent):
return None
source = self.message_obj.raw_message
- assert isinstance(
+
+ if not isinstance(
source,
(
botpy.message.Message,
@@ -89,7 +96,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
botpy.message.DirectMessage,
botpy.message.C2CMessage,
),
- )
+ ):
+ logger.warning(f"[QQOfficial] 不支持的消息源类型: {type(source)}")
+ return None
(
plain_text,
@@ -106,7 +115,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
):
return None
- payload = {
+ payload: dict = {
"content": plain_text,
"msg_id": self.message_obj.message_id,
}
@@ -116,8 +125,12 @@ class QQOfficialMessageEvent(AstrMessageEvent):
ret = None
- match type(source):
- case botpy.message.GroupMessage:
+ match source:
+ case botpy.message.GroupMessage():
+ if not source.group_openid:
+ logger.error("[QQOfficial] GroupMessage 缺少 group_openid")
+ return None
+
if image_base64:
media = await self.upload_group_and_c2c_image(
image_base64,
@@ -138,7 +151,8 @@ class QQOfficialMessageEvent(AstrMessageEvent):
group_openid=source.group_openid,
**payload,
)
- case botpy.message.C2CMessage:
+
+ case botpy.message.C2CMessage():
if image_base64:
media = await self.upload_group_and_c2c_image(
image_base64,
@@ -167,18 +181,23 @@ class QQOfficialMessageEvent(AstrMessageEvent):
**payload,
)
logger.debug(f"Message sent to C2C: {ret}")
- case botpy.message.Message:
+
+ case botpy.message.Message():
if image_path:
payload["file_image"] = image_path
ret = await self.bot.api.post_message(
channel_id=source.channel_id,
**payload,
)
- case botpy.message.DirectMessage:
+
+ case botpy.message.DirectMessage():
if image_path:
payload["file_image"] = image_path
ret = await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
+ case _:
+ pass
+
await super().send(self.send_buffer)
self.send_buffer = None
@@ -196,18 +215,33 @@ class QQOfficialMessageEvent(AstrMessageEvent):
"file_type": file_type,
"srv_send_msg": False,
}
+
+ result = None
if "openid" in kwargs:
payload["openid"] = kwargs["openid"]
route = Route("POST", "/v2/users/{openid}/files", openid=kwargs["openid"])
- return await self.bot.api._http.request(route, json=payload)
- if "group_openid" in kwargs:
+ result = await self.bot.api._http.request(route, json=payload)
+ elif "group_openid" in kwargs:
payload["group_openid"] = kwargs["group_openid"]
route = Route(
"POST",
"/v2/groups/{group_openid}/files",
group_openid=kwargs["group_openid"],
)
- return await self.bot.api._http.request(route, json=payload)
+ result = await self.bot.api._http.request(route, json=payload)
+ else:
+ raise ValueError("Invalid upload parameters")
+
+ if not isinstance(result, dict):
+ raise RuntimeError(
+ f"Failed to upload image, response is not dict: {result}"
+ )
+
+ return Media(
+ file_uuid=result["file_uuid"],
+ file_info=result["file_info"],
+ ttl=result.get("ttl", 0),
+ )
async def upload_group_and_c2c_record(
self,
@@ -250,11 +284,14 @@ class QQOfficialMessageEvent(AstrMessageEvent):
result = await self.bot.api._http.request(route, json=payload)
if result:
+ if not isinstance(result, dict):
+ logger.error(f"上传文件响应格式错误: {result}")
+ return None
+
return Media(
- file_uuid=result.get("file_uuid"),
- file_info=result.get("file_info"),
+ file_uuid=result["file_uuid"],
+ file_info=result["file_info"],
ttl=result.get("ttl", 0),
- file_id=result.get("id", ""),
)
except Exception as e:
logger.error(f"上传请求错误: {e}")
@@ -271,7 +308,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
message_reference: message.Reference | None = None,
media: message.Media | None = None,
msg_id: str | None = None,
- msg_seq: str = 1,
+ msg_seq: int | None = 1,
event_id: str | None = None,
markdown: message.MarkdownPayload | None = None,
keyboard: message.Keyboard | None = None,
@@ -280,7 +317,14 @@ class QQOfficialMessageEvent(AstrMessageEvent):
payload = locals()
payload.pop("self", None)
route = Route("POST", "/v2/users/{openid}/messages", openid=openid)
- return await self.bot.api._http.request(route, json=payload)
+ result = await self.bot.api._http.request(route, json=payload)
+
+ if not isinstance(result, dict):
+ raise RuntimeError(
+ f"Failed to post c2c message, response is not dict: {result}"
+ )
+
+ return message.Message(**result)
@staticmethod
async def _parse_to_qqofficial(message: MessageChain):
@@ -300,8 +344,10 @@ class QQOfficialMessageEvent(AstrMessageEvent):
image_base64 = file_to_base64(image_file_path)
elif i.file and i.file.startswith("base64://"):
image_base64 = i.file
- else:
+ elif i.file:
image_base64 = file_to_base64(i.file)
+ else:
+ raise ValueError("Unsupported image file format")
image_base64 = image_base64.removeprefix("base64://")
elif isinstance(i, Record):
if i.file:
diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py
index 96be734fd..2a1bcda47 100644
--- a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py
+++ b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py
@@ -4,6 +4,7 @@ import asyncio
import logging
import os
import time
+from typing import cast
import botpy
import botpy.message
@@ -44,7 +45,9 @@ class botClient(Client):
MessageType.GROUP_MESSAGE,
)
abm.session_id = (
- abm.sender.user_id if self.platform.unique_session else message.group_openid
+ abm.sender.user_id
+ if self.platform.unique_session
+ else cast(str, message.group_openid)
)
self._commit(abm)
@@ -97,13 +100,11 @@ class QQOfficialPlatformAdapter(Platform):
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
- super().__init__(event_queue)
-
- self.config = platform_config
+ super().__init__(platform_config, event_queue)
self.appid = platform_config["appid"]
self.secret = platform_config["secret"]
- self.unique_session = platform_settings["unique_session"]
+ self.unique_session: bool = platform_settings["unique_session"]
qq_group = platform_config["enable_group_c2c"]
guild_dm = platform_config["enable_guild_direct_message"]
@@ -139,12 +140,15 @@ class QQOfficialPlatformAdapter(Platform):
return PlatformMetadata(
name="qq_official",
description="QQ 机器人官方 API 适配器",
- id=self.config.get("id"),
+ id=cast(str, self.config.get("id")),
)
@staticmethod
def _parse_from_qqofficial(
- message: botpy.message.Message | botpy.message.GroupMessage,
+ message: botpy.message.Message
+ | botpy.message.GroupMessage
+ | botpy.message.DirectMessage
+ | botpy.message.C2CMessage,
message_type: MessageType,
):
abm = AstrBotMessage()
@@ -152,7 +156,7 @@ class QQOfficialPlatformAdapter(Platform):
abm.timestamp = int(time.time())
abm.raw_message = message
abm.message_id = message.id
- abm.tag = "qq_official"
+ # abm.tag = "qq_official"
msg: list[BaseMessageComponent] = []
if isinstance(message, botpy.message.GroupMessage) or isinstance(
@@ -182,9 +186,9 @@ class QQOfficialPlatformAdapter(Platform):
message,
botpy.message.DirectMessage,
):
- try:
+ if isinstance(message, botpy.message.Message):
abm.self_id = str(message.mentions[0].id)
- except BaseException as _:
+ else:
abm.self_id = ""
plain_content = message.content.replace(
diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py
index 2b8c0b420..63b6726fe 100644
--- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py
+++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py
@@ -1,5 +1,6 @@
import asyncio
import logging
+from typing import Any, cast
import botpy
import botpy.message
@@ -11,6 +12,7 @@ from astrbot import logger
from astrbot.api.event import MessageChain
from astrbot.api.platform import AstrBotMessage, MessageType, Platform, PlatformMetadata
from astrbot.core.platform.astr_message_event import MessageSesion
+from astrbot.core.utils.webhook_utils import log_webhook_info
from ...register import register_platform_adapter
from ..qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter
@@ -34,7 +36,9 @@ class botClient(Client):
MessageType.GROUP_MESSAGE,
)
abm.session_id = (
- abm.sender.user_id if self.platform.unique_session else message.group_openid
+ abm.sender.user_id
+ if self.platform.unique_session
+ else cast(str, message.group_openid)
)
self._commit(abm)
@@ -87,13 +91,12 @@ class QQOfficialWebhookPlatformAdapter(Platform):
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
- super().__init__(event_queue)
-
- self.config = platform_config
+ super().__init__(platform_config, event_queue)
self.appid = platform_config["appid"]
self.secret = platform_config["secret"]
self.unique_session = platform_settings["unique_session"]
+ self.unified_webhook_mode = platform_config.get("unified_webhook_mode", False)
intents = botpy.Intents(
public_messages=True,
@@ -106,6 +109,7 @@ class QQOfficialWebhookPlatformAdapter(Platform):
timeout=20,
)
self.client.set_platform(self)
+ self.webhook_helper = None
async def send_by_session(
self,
@@ -118,7 +122,7 @@ class QQOfficialWebhookPlatformAdapter(Platform):
return PlatformMetadata(
name="qq_official_webhook",
description="QQ 机器人官方 API 适配器",
- id=self.config.get("id"),
+ id=cast(str, self.config.get("id")),
)
async def run(self):
@@ -128,16 +132,37 @@ class QQOfficialWebhookPlatformAdapter(Platform):
self.client,
)
await self.webhook_helper.initialize()
- await self.webhook_helper.start_polling()
+
+ # 如果启用统一 webhook 模式,则不启动独立服务器
+ webhook_uuid = self.config.get("webhook_uuid")
+ if self.unified_webhook_mode and webhook_uuid:
+ log_webhook_info(f"{self.meta().id}(QQ 官方机器人 Webhook)", webhook_uuid)
+ # 保持运行状态,等待 shutdown
+ await self.webhook_helper.shutdown_event.wait()
+ else:
+ await self.webhook_helper.start_polling()
def get_client(self) -> botClient:
return self.client
+ async def webhook_callback(self, request: Any) -> Any:
+ """统一 Webhook 回调入口"""
+ if not self.webhook_helper:
+ return {"error": "Webhook helper not initialized"}, 500
+
+ # 复用 webhook_helper 的回调处理逻辑
+ return await self.webhook_helper.handle_callback(request)
+
async def terminate(self):
- self.webhook_helper.shutdown_event.set()
+ if self.webhook_helper:
+ self.webhook_helper.shutdown_event.set()
await self.client.close()
- try:
- await self.webhook_helper.server.shutdown()
- except Exception as _:
- pass
+ if self.webhook_helper and not self.unified_webhook_mode:
+ try:
+ await self.webhook_helper.server.shutdown()
+ except Exception as exc:
+ logger.warning(
+ f"Exception occurred during QQOfficialWebhook server shutdown: {exc}",
+ exc_info=True,
+ )
logger.info("QQ 机器人官方 API 适配器已经被优雅地关闭")
diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py
index 65b7c701a..2eda11a6c 100644
--- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py
+++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py
@@ -1,5 +1,6 @@
import asyncio
import logging
+from typing import cast
import quart
from botpy import BotAPI, BotHttp, BotWebSocket, Client, ConnectionSession, Token
@@ -78,7 +79,19 @@ class QQOfficialWebhook:
return response
async def callback(self):
- msg: dict = await quart.request.json
+ """内部服务器的回调入口"""
+ return await self.handle_callback(quart.request)
+
+ async def handle_callback(self, request) -> dict:
+ """处理 webhook 回调,可被统一 webhook 入口复用
+
+ Args:
+ request: Quart 请求对象
+
+ Returns:
+ 响应数据
+ """
+ msg: dict = await request.json
logger.debug(f"收到 qq_official_webhook 回调: {msg}")
event = msg.get("t")
@@ -87,7 +100,7 @@ class QQOfficialWebhook:
if opcode == 13:
# validation
- signed = await self.webhook_validation(data)
+ signed = await self.webhook_validation(cast(dict, data))
print(signed)
return signed
diff --git a/astrbot/core/platform/sources/satori/satori_adapter.py b/astrbot/core/platform/sources/satori/satori_adapter.py
index b5751ebd2..46f9a4e0f 100644
--- a/astrbot/core/platform/sources/satori/satori_adapter.py
+++ b/astrbot/core/platform/sources/satori/satori_adapter.py
@@ -29,8 +29,7 @@ from astrbot.core.platform.astr_message_event import MessageSession
@register_platform_adapter(
- "satori",
- "Satori 协议适配器",
+ "satori", "Satori 协议适配器", support_streaming_message=False
)
class SatoriPlatformAdapter(Platform):
def __init__(
@@ -39,8 +38,7 @@ class SatoriPlatformAdapter(Platform):
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
- super().__init__(event_queue)
- self.config = platform_config
+ super().__init__(platform_config, event_queue)
self.settings = platform_settings
self.api_base_url = self.config.get(
@@ -60,6 +58,7 @@ class SatoriPlatformAdapter(Platform):
name="satori",
description="Satori 通用协议适配器",
id=self.config["id"],
+ support_streaming_message=False,
)
self.ws: ClientConnection | None = None
diff --git a/astrbot/core/platform/sources/slack/client.py b/astrbot/core/platform/sources/slack/client.py
index 0411f73a4..fbdc71759 100644
--- a/astrbot/core/platform/sources/slack/client.py
+++ b/astrbot/core/platform/sources/slack/client.py
@@ -4,9 +4,11 @@ import hmac
import json
import logging
from collections.abc import Callable
+from typing import cast
from quart import Quart, Response, request
from slack_sdk.socket_mode.aiohttp import SocketModeClient
+from slack_sdk.socket_mode.async_client import AsyncBaseSocketModeClient
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from slack_sdk.web.async_client import AsyncWebClient
@@ -47,51 +49,62 @@ class SlackWebhookClient:
@self.app.route(self.path, methods=["POST"])
async def slack_events():
- """处理 Slack 事件"""
- try:
- # 获取请求体和头部
- body = await request.get_data()
- event_data = json.loads(body.decode("utf-8"))
-
- # Verify Slack request signature
- timestamp = request.headers.get("X-Slack-Request-Timestamp")
- signature = request.headers.get("X-Slack-Signature")
- if not timestamp or not signature:
- return Response("Missing headers", status=400)
- # Calculate the HMAC signature
- sig_basestring = f"v0:{timestamp}:{body.decode('utf-8')}"
- my_signature = (
- "v0="
- + hmac.new(
- self.signing_secret.encode("utf-8"),
- sig_basestring.encode("utf-8"),
- hashlib.sha256,
- ).hexdigest()
- )
- # Verify the signature
- if not hmac.compare_digest(my_signature, signature):
- logger.warning("Slack request signature verification failed")
- return Response("Invalid signature", status=400)
- logger.info(f"Received Slack event: {event_data}")
-
- # 处理 URL 验证事件
- if event_data.get("type") == "url_verification":
- return {"challenge": event_data.get("challenge")}
- # 处理事件
- if self.event_handler and event_data.get("type") == "event_callback":
- await self.event_handler(event_data)
-
- return Response("", status=200)
-
- except Exception as e:
- logger.error(f"处理 Slack 事件时出错: {e}")
- return Response("Internal Server Error", status=500)
+ """内部服务器的 POST 回调入口"""
+ return await self.handle_callback(request)
@self.app.route("/health", methods=["GET"])
async def health_check():
"""健康检查端点"""
return {"status": "ok", "service": "slack-webhook"}
+ async def handle_callback(self, req):
+ """处理 Slack 回调请求,可被统一 webhook 入口复用
+
+ Args:
+ req: Quart 请求对象
+
+ Returns:
+ Response 对象或字典
+ """
+ try:
+ # 获取请求体和头部
+ body = cast(bytes, await req.get_data())
+ event_data = json.loads(body.decode("utf-8"))
+
+ # Verify Slack request signature
+ timestamp = req.headers.get("X-Slack-Request-Timestamp")
+ signature = req.headers.get("X-Slack-Signature")
+ if not timestamp or not signature:
+ return Response("Missing headers", status=400)
+ # Calculate the HMAC signature
+ sig_basestring = f"v0:{timestamp}:{body.decode('utf-8')}"
+ my_signature = (
+ "v0="
+ + hmac.new(
+ self.signing_secret.encode("utf-8"),
+ sig_basestring.encode("utf-8"),
+ hashlib.sha256,
+ ).hexdigest()
+ )
+ # Verify the signature
+ if not hmac.compare_digest(my_signature, signature):
+ logger.warning("Slack request signature verification failed")
+ return Response("Invalid signature", status=400)
+ logger.info(f"Received Slack event: {event_data}")
+
+ # 处理 URL 验证事件
+ if event_data.get("type") == "url_verification":
+ return {"challenge": event_data.get("challenge")}
+ # 处理事件
+ if self.event_handler and event_data.get("type") == "event_callback":
+ await self.event_handler(event_data)
+
+ return Response("", status=200)
+
+ except Exception as e:
+ logger.error(f"处理 Slack 事件时出错: {e}")
+ return Response("Internal Server Error", status=500)
+
async def start(self):
"""启动 Webhook 服务器"""
logger.info(
@@ -128,9 +141,14 @@ class SlackSocketClient:
self.event_handler = event_handler
self.socket_client = None
- async def _handle_events(self, _: SocketModeClient, req: SocketModeRequest):
+ async def _handle_events(
+ self, _: AsyncBaseSocketModeClient, req: SocketModeRequest
+ ):
"""处理 Socket Mode 事件"""
try:
+ if self.socket_client is None:
+ raise RuntimeError("Socket client is not initialized")
+
# 确认收到事件
response = SocketModeResponse(envelope_id=req.envelope_id)
await self.socket_client.send_socket_mode_response(response)
diff --git a/astrbot/core/platform/sources/slack/slack_adapter.py b/astrbot/core/platform/sources/slack/slack_adapter.py
index 6bb5a505e..ed838b0a9 100644
--- a/astrbot/core/platform/sources/slack/slack_adapter.py
+++ b/astrbot/core/platform/sources/slack/slack_adapter.py
@@ -3,8 +3,7 @@ import base64
import re
import time
import uuid
-from collections.abc import Awaitable
-from typing import Any
+from typing import Any, cast
import aiohttp
from slack_sdk.socket_mode.request import SocketModeRequest
@@ -21,6 +20,7 @@ from astrbot.api.platform import (
PlatformMetadata,
)
from astrbot.core.platform.astr_message_event import MessageSesion
+from astrbot.core.utils.webhook_utils import log_webhook_info
from ...register import register_platform_adapter
from .client import SlackSocketClient, SlackWebhookClient
@@ -30,6 +30,7 @@ from .slack_event import SlackMessageEvent
@register_platform_adapter(
"slack",
"适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。",
+ support_streaming_message=False,
)
class SlackAdapter(Platform):
def __init__(
@@ -38,9 +39,7 @@ class SlackAdapter(Platform):
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
- super().__init__(event_queue)
-
- self.config = platform_config
+ super().__init__(platform_config, event_queue)
self.settings = platform_settings
self.unique_session = platform_settings.get("unique_session", False)
@@ -48,6 +47,7 @@ class SlackAdapter(Platform):
self.app_token = platform_config.get("app_token")
self.signing_secret = platform_config.get("signing_secret")
self.connection_mode = platform_config.get("slack_connection_mode", "socket")
+ self.unified_webhook_mode = platform_config.get("unified_webhook_mode", False)
self.webhook_host = platform_config.get("slack_webhook_host", "0.0.0.0")
self.webhook_port = platform_config.get("slack_webhook_port", 3000)
self.webhook_path = platform_config.get(
@@ -67,7 +67,8 @@ class SlackAdapter(Platform):
self.metadata = PlatformMetadata(
name="slack",
description="适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。",
- id=self.config.get("id"),
+ id=cast(str, self.config.get("id")),
+ support_streaming_message=False,
)
# 初始化 Slack Web Client
@@ -116,13 +117,13 @@ class SlackAdapter(Platform):
logger.debug(f"[slack] RawMessage {event}")
abm = AstrBotMessage()
- abm.self_id = self.bot_self_id
+ abm.self_id = cast(str, self.bot_self_id)
# 获取用户信息
user_id = event.get("user", "")
try:
user_info = await self.web_client.users_info(user=user_id)
- user_data = user_info["user"]
+ user_data = cast(dict, user_info["user"])
user_name = user_data.get("real_name") or user_data.get("name", user_id)
except Exception:
user_name = user_id
@@ -133,7 +134,7 @@ class SlackAdapter(Platform):
channel_id = event.get("channel", "")
try:
channel_info = await self.web_client.conversations_info(channel=channel_id)
- is_im = channel_info["channel"]["is_im"]
+ is_im = cast(dict, channel_info["channel"])["is_im"]
if is_im:
abm.type = MessageType.FRIEND_MESSAGE
@@ -176,7 +177,7 @@ class SlackAdapter(Platform):
for mention in mentions:
try:
mentioned_user = await self.web_client.users_info(user=mention)
- user_data = mentioned_user["user"]
+ user_data = cast(dict, mentioned_user["user"])
user_name = user_data.get("real_name") or user_data.get(
"name",
mention,
@@ -327,7 +328,7 @@ class SlackAdapter(Platform):
)
raise Exception(f"下载文件失败: {resp.status}")
- async def run(self) -> Awaitable[Any]:
+ async def run(self) -> None:
self.bot_self_id = await self.get_bot_user_id()
logger.info(f"Slack auth test OK. Bot ID: {self.bot_self_id}")
@@ -359,10 +360,17 @@ class SlackAdapter(Platform):
self._handle_webhook_event,
)
- logger.info(
- f"Slack 适配器 (Webhook Mode) 启动中,监听 {self.webhook_host}:{self.webhook_port}{self.webhook_path}...",
- )
- await self.webhook_client.start()
+ # 如果启用统一 webhook 模式,则不启动独立服务器
+ webhook_uuid = self.config.get("webhook_uuid")
+ if self.unified_webhook_mode and webhook_uuid:
+ log_webhook_info(f"{self.meta().id}(Slack)", webhook_uuid)
+ # 保持运行状态,等待 shutdown
+ await self.webhook_client.shutdown_event.wait()
+ else:
+ logger.info(
+ f"Slack 适配器 (Webhook Mode) 启动中,监听 {self.webhook_host}:{self.webhook_port}{self.webhook_path}...",
+ )
+ await self.webhook_client.start()
else:
raise ValueError(
@@ -389,12 +397,19 @@ class SlackAdapter(Platform):
if abm:
await self.handle_msg(abm)
+ async def webhook_callback(self, request: Any) -> Any:
+ """统一 Webhook 回调入口"""
+ if self.connection_mode != "webhook" or not self.webhook_client:
+ return {"error": "Slack adapter is not in webhook mode"}, 400
+
+ return await self.webhook_client.handle_callback(request)
+
async def terminate(self):
if self.socket_client:
await self.socket_client.stop()
if self.webhook_client:
await self.webhook_client.stop()
- logger.info("Slack 适配器已被优雅地关闭")
+ logger.info("Slack 适配器已被关闭")
def meta(self) -> PlatformMetadata:
return self.metadata
@@ -412,3 +427,10 @@ class SlackAdapter(Platform):
def get_client(self):
return self.web_client
+
+ def unified_webhook(self) -> bool:
+ return bool(
+ self.config.get("unified_webhook_mode", False)
+ and self.config.get("slack_connection_mode", "") == "webhook"
+ and self.config.get("webhook_uuid")
+ )
diff --git a/astrbot/core/platform/sources/slack/slack_event.py b/astrbot/core/platform/sources/slack/slack_event.py
index c918abbac..822e6fdeb 100644
--- a/astrbot/core/platform/sources/slack/slack_event.py
+++ b/astrbot/core/platform/sources/slack/slack_event.py
@@ -1,6 +1,7 @@
import asyncio
import re
-from collections.abc import AsyncGenerator
+from collections.abc import AsyncGenerator, Iterable
+from typing import cast
from slack_sdk.web.async_client import AsyncWebClient
@@ -31,14 +32,14 @@ class SlackMessageEvent(AstrMessageEvent):
async def _from_segment_to_slack_block(
segment: BaseMessageComponent,
web_client: AsyncWebClient,
- ) -> dict:
+ ) -> dict | None:
"""将消息段转换为 Slack 块格式"""
if isinstance(segment, Plain):
return {"type": "section", "text": {"type": "mrkdwn", "text": segment.text}}
if isinstance(segment, Image):
# upload file
url = segment.url or segment.file
- if url.startswith("http"):
+ if url and url.startswith("http"):
return {
"type": "image",
"image_url": url,
@@ -55,7 +56,7 @@ class SlackMessageEvent(AstrMessageEvent):
"type": "section",
"text": {"type": "mrkdwn", "text": "图片上传失败"},
}
- image_url = response["files"][0]["url_private"]
+ image_url = cast(list, response["files"])[0]["url_private"]
logger.debug(f"Slack file upload response: {response}")
return {
"type": "image",
@@ -77,7 +78,7 @@ class SlackMessageEvent(AstrMessageEvent):
"type": "section",
"text": {"type": "mrkdwn", "text": "文件上传失败"},
}
- file_url = response["files"][0]["permalink"]
+ file_url = cast(list, response["files"])[0]["permalink"]
return {
"type": "section",
"text": {
@@ -85,7 +86,6 @@ class SlackMessageEvent(AstrMessageEvent):
"text": f"文件: <{file_url}|{segment.name or '文件'}>",
},
}
- return {"type": "section", "text": {"type": "mrkdwn", "text": str(segment)}}
@staticmethod
async def _parse_slack_blocks(
@@ -115,7 +115,8 @@ class SlackMessageEvent(AstrMessageEvent):
segment,
web_client,
)
- blocks.append(block)
+ if block:
+ blocks.append(block)
# 如果最后还有文本内容
if text_content.strip():
@@ -225,10 +226,10 @@ class SlackMessageEvent(AstrMessageEvent):
)
members = []
- for member_id in members_response["members"]:
+ for member_id in cast(Iterable, members_response["members"]):
try:
user_info = await self.web_client.users_info(user=member_id)
- user_data = user_info["user"]
+ user_data = cast(dict, user_info["user"])
members.append(
MessageMember(
user_id=member_id,
@@ -240,7 +241,7 @@ class SlackMessageEvent(AstrMessageEvent):
# 如果获取用户信息失败,使用默认信息
members.append(MessageMember(user_id=member_id, nickname=member_id))
- channel_data = channel_info["channel"]
+ channel_data = cast(dict, channel_info["channel"])
return Group(
group_id=channel_id,
group_name=channel_data.get("name", ""),
diff --git a/astrbot/core/platform/sources/telegram/tg_adapter.py b/astrbot/core/platform/sources/telegram/tg_adapter.py
index 88a9f7dc6..218d13bdc 100644
--- a/astrbot/core/platform/sources/telegram/tg_adapter.py
+++ b/astrbot/core/platform/sources/telegram/tg_adapter.py
@@ -42,8 +42,7 @@ class TelegramPlatformAdapter(Platform):
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
- super().__init__(event_queue)
- self.config = platform_config
+ super().__init__(platform_config, event_queue)
self.settings = platform_settings
self.client_self_id = uuid.uuid4().hex[:8]
@@ -381,7 +380,9 @@ class TelegramPlatformAdapter(Platform):
f"Telegram document file_path is None, cannot save the file {file_name}.",
)
else:
- message.message.append(Comp.File(file=file_path, name=file_name))
+ message.message.append(
+ Comp.File(file=file_path, name=file_name, url=file_path)
+ )
elif update.message.video:
file = await update.message.video.get_file()
@@ -423,6 +424,6 @@ class TelegramPlatformAdapter(Platform):
if self.application.updater is not None:
await self.application.updater.stop()
- logger.info("Telegram 适配器已被优雅地关闭")
+ logger.info("Telegram 适配器已被关闭")
except Exception as e:
logger.error(f"Telegram 适配器关闭时出错: {e}")
diff --git a/astrbot/core/platform/sources/telegram/tg_event.py b/astrbot/core/platform/sources/telegram/tg_event.py
index 34fd86ad9..5faba6803 100644
--- a/astrbot/core/platform/sources/telegram/tg_event.py
+++ b/astrbot/core/platform/sources/telegram/tg_event.py
@@ -1,6 +1,7 @@
import asyncio
import os
import re
+from typing import Any, cast
import telegramify_markdown
from telegram import ReactionTypeCustomEmoji, ReactionTypeEmoji
@@ -17,8 +18,6 @@ from astrbot.api.message_components import (
Reply,
)
from astrbot.api.platform import AstrBotMessage, MessageType, PlatformMetadata
-from astrbot.core.utils.astrbot_path import get_astrbot_data_path
-from astrbot.core.utils.io import download_file
class TelegramPlatformEvent(AstrMessageEvent):
@@ -97,7 +96,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
"chat_id": user_name,
}
if has_reply:
- payload["reply_to_message_id"] = reply_message_id
+ payload["reply_to_message_id"] = str(reply_message_id)
if message_thread_id:
payload["message_thread_id"] = message_thread_id
@@ -110,33 +109,30 @@ class TelegramPlatformEvent(AstrMessageEvent):
try:
md_text = telegramify_markdown.markdownify(
chunk,
- max_line_length=None,
normalize_whitespace=False,
)
await client.send_message(
text=md_text,
parse_mode="MarkdownV2",
- **payload,
+ **cast(Any, payload),
)
except Exception as e:
logger.warning(
f"MarkdownV2 send failed: {e}. Using plain text instead.",
)
- await client.send_message(text=chunk, **payload)
+ await client.send_message(text=chunk, **cast(Any, payload))
elif isinstance(i, Image):
image_path = await i.convert_to_file_path()
- await client.send_photo(photo=image_path, **payload)
+ await client.send_photo(photo=image_path, **cast(Any, payload))
elif isinstance(i, File):
- if i.file.startswith("https://"):
- temp_dir = os.path.join(get_astrbot_data_path(), "temp")
- path = os.path.join(temp_dir, i.name)
- await download_file(i.file, path)
- i.file = path
-
- await client.send_document(document=i.file, filename=i.name, **payload)
+ path = await i.get_file()
+ name = i.name or os.path.basename(path)
+ await client.send_document(
+ document=path, filename=name, **cast(Any, payload)
+ )
elif isinstance(i, Record):
path = await i.convert_to_file_path()
- await client.send_voice(voice=path, **payload)
+ await client.send_voice(voice=path, **cast(Any, payload))
async def send(self, message: MessageChain):
if self.get_message_type() == MessageType.GROUP_MESSAGE:
@@ -204,6 +200,15 @@ class TelegramPlatformEvent(AstrMessageEvent):
if isinstance(chain, MessageChain):
if chain.type == "break":
# 分割符
+ if message_id:
+ try:
+ await self.client.edit_message_text(
+ text=delta,
+ chat_id=payload["chat_id"],
+ message_id=message_id,
+ )
+ except Exception as e:
+ logger.warning(f"编辑消息失败(streaming-break): {e!s}")
message_id = None # 重置消息 ID
delta = "" # 重置 delta
continue
@@ -214,24 +219,23 @@ class TelegramPlatformEvent(AstrMessageEvent):
delta += i.text
elif isinstance(i, Image):
image_path = await i.convert_to_file_path()
- await self.client.send_photo(photo=image_path, **payload)
+ await self.client.send_photo(
+ photo=image_path, **cast(Any, payload)
+ )
continue
elif isinstance(i, File):
- if i.file.startswith("https://"):
- temp_dir = os.path.join(get_astrbot_data_path(), "temp")
- path = os.path.join(temp_dir, i.name)
- await download_file(i.file, path)
- i.file = path
+ path = await i.get_file()
+ name = i.name or os.path.basename(path)
await self.client.send_document(
- document=i.file,
- filename=i.name,
- **payload,
+ document=path,
+ filename=name,
+ **cast(Any, payload),
)
continue
elif isinstance(i, Record):
path = await i.convert_to_file_path()
- await self.client.send_voice(voice=path, **payload)
+ await self.client.send_voice(voice=path, **cast(Any, payload))
continue
else:
logger.warning(f"不支持的消息类型: {type(i)}")
@@ -260,7 +264,9 @@ class TelegramPlatformEvent(AstrMessageEvent):
else:
# delta 长度一般不会大于 4096,因此这里直接发送
try:
- msg = await self.client.send_message(text=delta, **payload)
+ msg = await self.client.send_message(
+ text=delta, **cast(Any, payload)
+ )
current_content = delta
except Exception as e:
logger.warning(f"发送消息失败(streaming): {e!s}")
@@ -274,7 +280,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
try:
markdown_text = telegramify_markdown.markdownify(
delta,
- max_line_length=None,
normalize_whitespace=False,
)
await self.client.edit_message_text(
diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py
index ff5482f58..084d7860d 100644
--- a/astrbot/core/platform/sources/webchat/webchat_adapter.py
+++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py
@@ -2,11 +2,13 @@ import asyncio
import os
import time
import uuid
-from collections.abc import Awaitable, Callable
+from collections.abc import Callable, Coroutine
from typing import Any
from astrbot import logger
-from astrbot.core.message.components import Image, Plain, Record
+from astrbot.core import db_helper
+from astrbot.core.db.po import PlatformMessageHistory
+from astrbot.core.message.components import File, Image, Plain, Record, Reply, Video
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.platform import (
AstrBotMessage,
@@ -74,9 +76,8 @@ class WebChatAdapter(Platform):
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
- super().__init__(event_queue)
+ super().__init__(platform_config, event_queue)
- self.config = platform_config
self.settings = platform_settings
self.unique_session = platform_settings["unique_session"]
self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
@@ -96,6 +97,92 @@ class WebChatAdapter(Platform):
await WebChatMessageEvent._send(message_chain, session.session_id)
await super().send_by_session(session, message_chain)
+ async def _get_message_history(
+ self, message_id: int
+ ) -> PlatformMessageHistory | None:
+ return await db_helper.get_platform_message_history_by_id(message_id)
+
+ async def _parse_message_parts(
+ self,
+ message_parts: list,
+ depth: int = 0,
+ max_depth: int = 1,
+ ) -> tuple[list, list[str]]:
+ """解析消息段列表,返回消息组件列表和纯文本列表
+
+ Args:
+ message_parts: 消息段列表
+ depth: 当前递归深度
+ max_depth: 最大递归深度(用于处理 reply)
+
+ Returns:
+ tuple[list, list[str]]: (消息组件列表, 纯文本列表)
+ """
+ components = []
+ text_parts = []
+
+ for part in message_parts:
+ part_type = part.get("type")
+ if part_type == "plain":
+ text = part.get("text", "")
+ components.append(Plain(text))
+ text_parts.append(text)
+ elif part_type == "reply":
+ message_id = part.get("message_id")
+ reply_chain = []
+ reply_message_str = ""
+ sender_id = None
+ sender_name = None
+
+ # recursively get the content of the referenced message
+ if depth < max_depth and message_id:
+ history = await self._get_message_history(message_id)
+ if history and history.content:
+ reply_parts = history.content.get("message", [])
+ if isinstance(reply_parts, list):
+ (
+ reply_chain,
+ reply_text_parts,
+ ) = await self._parse_message_parts(
+ reply_parts,
+ depth=depth + 1,
+ max_depth=max_depth,
+ )
+ reply_message_str = "".join(reply_text_parts)
+ sender_id = history.sender_id
+ sender_name = history.sender_name
+
+ components.append(
+ Reply(
+ id=message_id,
+ chain=reply_chain,
+ message_str=reply_message_str,
+ sender_id=sender_id,
+ sender_nickname=sender_name,
+ )
+ )
+ elif part_type == "image":
+ path = part.get("path")
+ if path:
+ components.append(Image.fromFileSystem(path))
+ elif part_type == "record":
+ path = part.get("path")
+ if path:
+ components.append(Record.fromFileSystem(path))
+ elif part_type == "file":
+ path = part.get("path")
+ if path:
+ filename = part.get("filename") or (
+ os.path.basename(path) if path else "file"
+ )
+ components.append(File(name=filename, file=path))
+ elif part_type == "video":
+ path = part.get("path")
+ if path:
+ components.append(Video.fromFileSystem(path))
+
+ return components, text_parts
+
async def convert_message(self, data: tuple) -> AstrBotMessage:
username, cid, payload = data
@@ -108,40 +195,19 @@ class WebChatAdapter(Platform):
abm.session_id = f"webchat!{username}!{cid}"
abm.message_id = str(uuid.uuid4())
- abm.message = []
- if payload["message"]:
- abm.message.append(Plain(payload["message"]))
- if payload["image_url"]:
- if isinstance(payload["image_url"], list):
- for img in payload["image_url"]:
- abm.message.append(
- Image.fromFileSystem(os.path.join(self.imgs_dir, img)),
- )
- else:
- abm.message.append(
- Image.fromFileSystem(
- os.path.join(self.imgs_dir, payload["image_url"]),
- ),
- )
- if payload["audio_url"]:
- if isinstance(payload["audio_url"], list):
- for audio in payload["audio_url"]:
- path = os.path.join(self.imgs_dir, audio)
- abm.message.append(Record(file=path, path=path))
- else:
- path = os.path.join(self.imgs_dir, payload["audio_url"])
- abm.message.append(Record(file=path, path=path))
+ # 处理消息段列表
+ message_parts = payload.get("message", [])
+ abm.message, message_str_parts = await self._parse_message_parts(message_parts)
logger.debug(f"WebChatAdapter: {abm.message}")
- message_str = payload["message"]
abm.timestamp = int(time.time())
- abm.message_str = message_str
+ abm.message_str = "".join(message_str_parts)
abm.raw_message = data
return abm
- def run(self) -> Awaitable[Any]:
+ def run(self) -> Coroutine[Any, Any, None]:
async def callback(data: tuple):
abm = await self.convert_message(data)
await self.handle_msg(abm)
diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py
index 4d4d3b59e..2e529bb1d 100644
--- a/astrbot/core/platform/sources/webchat/webchat_event.py
+++ b/astrbot/core/platform/sources/webchat/webchat_event.py
@@ -1,12 +1,13 @@
import base64
+import json
import os
+import shutil
import uuid
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
-from astrbot.api.message_components import Image, Plain, Record
+from astrbot.api.message_components import File, Image, Json, Plain, Record
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
-from astrbot.core.utils.io import download_image_by_url
from .webchat_queue_mgr import webchat_queue_mgr
@@ -19,7 +20,9 @@ class WebChatMessageEvent(AstrMessageEvent):
os.makedirs(imgs_dir, exist_ok=True)
@staticmethod
- async def _send(message: MessageChain, session_id: str, streaming: bool = False):
+ async def _send(
+ message: MessageChain | None, session_id: str, streaming: bool = False
+ ) -> str | None:
cid = session_id.split("!")[-1]
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
if not message:
@@ -30,7 +33,7 @@ class WebChatMessageEvent(AstrMessageEvent):
"streaming": False,
}, # end means this request is finished
)
- return ""
+ return
data = ""
for comp in message.chain:
@@ -39,61 +42,62 @@ class WebChatMessageEvent(AstrMessageEvent):
await web_chat_back_queue.put(
{
"type": "plain",
- "cid": cid,
"data": data,
"streaming": streaming,
"chain_type": message.type,
},
)
+ elif isinstance(comp, Json):
+ await web_chat_back_queue.put(
+ {
+ "type": "plain",
+ "data": json.dumps(comp.data, ensure_ascii=False),
+ "streaming": streaming,
+ "chain_type": message.type,
+ },
+ )
elif isinstance(comp, Image):
# save image to local
- filename = str(uuid.uuid4()) + ".jpg"
+ filename = f"{str(uuid.uuid4())}.jpg"
path = os.path.join(imgs_dir, filename)
- if comp.file and comp.file.startswith("file:///"):
- ph = comp.file[8:]
- with open(path, "wb") as f:
- with open(ph, "rb") as f2:
- f.write(f2.read())
- elif comp.file.startswith("base64://"):
- base64_str = comp.file[9:]
- image_data = base64.b64decode(base64_str)
- with open(path, "wb") as f:
- f.write(image_data)
- elif comp.file and comp.file.startswith("http"):
- await download_image_by_url(comp.file, path=path)
- else:
- with open(path, "wb") as f:
- with open(comp.file, "rb") as f2:
- f.write(f2.read())
+ image_base64 = await comp.convert_to_base64()
+ with open(path, "wb") as f:
+ f.write(base64.b64decode(image_base64))
data = f"[IMAGE]{filename}"
await web_chat_back_queue.put(
{
"type": "image",
- "cid": cid,
"data": data,
"streaming": streaming,
},
)
elif isinstance(comp, Record):
# save record to local
- filename = str(uuid.uuid4()) + ".wav"
+ filename = f"{str(uuid.uuid4())}.wav"
path = os.path.join(imgs_dir, filename)
- if comp.file and comp.file.startswith("file:///"):
- ph = comp.file[8:]
- with open(path, "wb") as f:
- with open(ph, "rb") as f2:
- f.write(f2.read())
- elif comp.file and comp.file.startswith("http"):
- await download_image_by_url(comp.file, path=path)
- else:
- with open(path, "wb") as f:
- with open(comp.file, "rb") as f2:
- f.write(f2.read())
+ record_base64 = await comp.convert_to_base64()
+ with open(path, "wb") as f:
+ f.write(base64.b64decode(record_base64))
data = f"[RECORD]{filename}"
await web_chat_back_queue.put(
{
"type": "record",
- "cid": cid,
+ "data": data,
+ "streaming": streaming,
+ },
+ )
+ elif isinstance(comp, File):
+ # save file to local
+ file_path = await comp.get_file()
+ original_name = comp.name or os.path.basename(file_path)
+ ext = os.path.splitext(original_name)[1] or ""
+ filename = f"{uuid.uuid4()!s}{ext}"
+ dest_path = os.path.join(imgs_dir, filename)
+ shutil.copy2(file_path, dest_path)
+ data = f"[FILE]{filename}|{original_name}"
+ await web_chat_back_queue.put(
+ {
+ "type": "file",
"data": data,
"streaming": streaming,
},
@@ -103,39 +107,46 @@ class WebChatMessageEvent(AstrMessageEvent):
return data
- async def send(self, message: MessageChain):
+ async def send(self, message: MessageChain | None):
await WebChatMessageEvent._send(message, session_id=self.session_id)
- await super().send(message)
+ await super().send(MessageChain([]))
async def send_streaming(self, generator, use_fallback: bool = False):
final_data = ""
+ reasoning_content = ""
cid = self.session_id.split("!")[-1]
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
async for chain in generator:
- if chain.type == "break" and final_data:
- # 分割符
- await web_chat_back_queue.put(
- {
- "type": "break", # break means a segment end
- "data": final_data,
- "streaming": True,
- "cid": cid,
- },
- )
- final_data = ""
- continue
- final_data += await WebChatMessageEvent._send(
+ # if chain.type == "break" and final_data:
+ # # 分割符
+ # await web_chat_back_queue.put(
+ # {
+ # "type": "break", # break means a segment end
+ # "data": final_data,
+ # "streaming": True,
+ # },
+ # )
+ # final_data = ""
+ # continue
+
+ r = await WebChatMessageEvent._send(
chain,
session_id=self.session_id,
streaming=True,
)
+ if not r:
+ continue
+ if chain.type == "reasoning":
+ reasoning_content += chain.get_plain_text()
+ else:
+ final_data += r
await web_chat_back_queue.put(
{
"type": "complete", # complete means we return the final result
"data": final_data,
+ "reasoning": reasoning_content,
"streaming": True,
- "cid": cid,
},
)
await super().send_streaming(generator, use_fallback)
diff --git a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py
index 165375cd5..4c9a9d36b 100644
--- a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py
+++ b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py
@@ -4,6 +4,7 @@ import json
import os
import time
import traceback
+from typing import cast
import aiohttp
import anyio
@@ -32,7 +33,9 @@ except ImportError as e:
)
-@register_platform_adapter("wechatpadpro", "WeChatPadPro 消息平台适配器")
+@register_platform_adapter(
+ "wechatpadpro", "WeChatPadPro 消息平台适配器", support_streaming_message=False
+)
class WeChatPadProAdapter(Platform):
def __init__(
self,
@@ -40,10 +43,9 @@ class WeChatPadProAdapter(Platform):
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
- super().__init__(event_queue)
+ super().__init__(platform_config, event_queue)
self._shutdown_event = None
self.wxnewpass = None
- self.config = platform_config
self.settings = platform_settings
self.unique_session = platform_settings.get("unique_session", False)
@@ -51,6 +53,7 @@ class WeChatPadProAdapter(Platform):
name="wechatpadpro",
description="WeChatPadPro 消息平台适配器",
id=self.config.get("id", "wechatpadpro"),
+ support_streaming_message=False,
)
# 保存配置信息
@@ -67,7 +70,7 @@ class WeChatPadProAdapter(Platform):
)
self.base_url = f"http://{self.host}:{self.port}"
self.auth_key = None # 用于保存生成的授权码
- self.wxid = None # 用于保存登录成功后的 wxid
+ self.wxid: str | None = None # 用于保存登录成功后的 wxid
self.credentials_file = os.path.join(
get_astrbot_data_path(),
"wechatpadpro_credentials.json",
@@ -396,7 +399,7 @@ class WeChatPadProAdapter(Platform):
)
await asyncio.sleep(5)
- async def handle_websocket_message(self, message: str):
+ async def handle_websocket_message(self, message: str | bytes):
"""处理从 WebSocket 接收到的消息。"""
logger.debug(f"收到 WebSocket 消息: {message}")
try:
@@ -428,10 +431,13 @@ class WeChatPadProAdapter(Platform):
async def convert_message(self, raw_message: dict) -> AstrBotMessage | None:
"""将 WeChatPadPro 原始消息转换为 AstrBotMessage。"""
+ if self.wxid is None:
+ logger.error("WeChatPadPro 适配器未登录或未获取到 wxid,无法处理消息。")
+ return None
abm = AstrBotMessage()
abm.raw_message = raw_message
abm.message_id = str(raw_message.get("msg_id"))
- abm.timestamp = raw_message.get("create_time")
+ abm.timestamp = cast(int, raw_message.get("create_time"))
abm.self_id = self.wxid
if int(time.time()) - abm.timestamp > 180:
@@ -444,7 +450,7 @@ class WeChatPadProAdapter(Platform):
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
content = raw_message.get("content", {}).get("str", "")
push_content = raw_message.get("push_content", "")
- msg_type = raw_message.get("msg_type")
+ msg_type = cast(int, raw_message.get("msg_type"))
abm.message_str = ""
abm.message = []
@@ -572,7 +578,7 @@ class WeChatPadProAdapter(Platform):
from_user_name: str,
to_user_name: str,
msg_id: int,
- ):
+ ) -> dict | None:
"""下载原始图片。"""
url = f"{self.base_url}/message/GetMsgBigImg"
params = {"key": self.auth_key}
@@ -723,12 +729,15 @@ class WeChatPadProAdapter(Platform):
# 图片消息
from_user_name = raw_message.get("from_user_name", {}).get("str", "")
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
- msg_id = raw_message.get("msg_id")
+ msg_id = cast(int, raw_message.get("msg_id"))
image_resp = await self._download_raw_image(
from_user_name,
to_user_name,
msg_id,
)
+ if image_resp is None:
+ logger.error(f"下载图片失败: msg_id={msg_id}")
+ return
image_bs64_data = (
image_resp.get("Data", {}).get("Data", {}).get("Buffer", None)
)
@@ -769,6 +778,9 @@ class WeChatPadProAdapter(Platform):
bufid = 0
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
new_msg_id = raw_message.get("new_msg_id")
+ if new_msg_id is None:
+ logger.error("语音消息缺少 new_msg_id")
+ return
data_parser = GeweDataParser(
content=content,
is_private_chat=(abm.type != MessageType.GROUP_MESSAGE),
@@ -776,6 +788,9 @@ class WeChatPadProAdapter(Platform):
)
voicemsg = data_parser._format_to_xml().find("voicemsg")
+ if voicemsg is None:
+ logger.error("无法从 XML 解析 voicemsg 节点")
+ return
bufid = voicemsg.get("bufid") or "0"
length = int(voicemsg.get("length") or 0)
voice_resp = await self.download_voice(
@@ -784,6 +799,9 @@ class WeChatPadProAdapter(Platform):
bufid=bufid,
length=length,
)
+ if voice_resp is None:
+ logger.error(f"下载语音失败: new_msg_id={new_msg_id}")
+ return
voice_bs64_data = voice_resp.get("Data", {}).get("Base64", None)
if voice_bs64_data:
voice_bs64_data = base64.b64decode(voice_bs64_data)
@@ -825,7 +843,8 @@ class WeChatPadProAdapter(Platform):
try:
if self.ws_handle_task:
self.ws_handle_task.cancel()
- self._shutdown_event.set()
+ if self._shutdown_event is not None:
+ self._shutdown_event.set()
except Exception:
pass
@@ -892,8 +911,8 @@ class WeChatPadProAdapter(Platform):
async def get_contact_details_list(
self,
- room_wx_id_list: list[str] = None,
- user_names: list[str] = None,
+ room_wx_id_list: list[str] | None = None,
+ user_names: list[str] | None = None,
) -> dict | None:
"""获取联系人详情列表。"""
if room_wx_id_list is None:
diff --git a/astrbot/core/platform/sources/wecom/wecom_adapter.py b/astrbot/core/platform/sources/wecom/wecom_adapter.py
index ffd5ec8ee..44ed75117 100644
--- a/astrbot/core/platform/sources/wecom/wecom_adapter.py
+++ b/astrbot/core/platform/sources/wecom/wecom_adapter.py
@@ -2,6 +2,8 @@ import asyncio
import os
import sys
import uuid
+from collections.abc import Awaitable, Callable
+from typing import Any, cast
import quart
from requests import Response
@@ -24,6 +26,7 @@ from astrbot.api.platform import (
from astrbot.core import logger
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
+from astrbot.core.utils.webhook_utils import log_webhook_info
from .wecom_event import WecomPlatformEvent
from .wecom_kf import WeChatKF
@@ -38,7 +41,7 @@ else:
class WecomServer:
def __init__(self, event_queue: asyncio.Queue, config: dict):
self.server = quart.Quart(__name__)
- self.port = int(config.get("port"))
+ self.port = int(cast(str, config.get("port")))
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
self.server.add_url_rule(
"/callback/command",
@@ -58,12 +61,24 @@ class WecomServer:
config["corpid"].strip(),
)
- self.callback = None
+ self.callback: Callable[[BaseMessage], Awaitable[None]] | None = None
self.shutdown_event = asyncio.Event()
async def verify(self):
- logger.info(f"验证请求有效性: {quart.request.args}")
- args = quart.request.args
+ """内部服务器的 GET 验证入口"""
+ return await self.handle_verify(quart.request)
+
+ async def handle_verify(self, request) -> str:
+ """处理验证请求,可被统一 webhook 入口复用
+
+ Args:
+ request: Quart 请求对象
+
+ Returns:
+ 验证响应
+ """
+ logger.info(f"验证请求有效性: {request.args}")
+ args = request.args
try:
echo_str = self.crypto.check_signature(
args.get("msg_signature"),
@@ -78,17 +93,29 @@ class WecomServer:
raise
async def callback_command(self):
- data = await quart.request.get_data()
- msg_signature = quart.request.args.get("msg_signature")
- timestamp = quart.request.args.get("timestamp")
- nonce = quart.request.args.get("nonce")
+ """内部服务器的 POST 回调入口"""
+ return await self.handle_callback(quart.request)
+
+ async def handle_callback(self, request) -> str:
+ """处理回调请求,可被统一 webhook 入口复用
+
+ Args:
+ request: Quart 请求对象
+
+ Returns:
+ 响应内容
+ """
+ data = await request.get_data()
+ msg_signature = request.args.get("msg_signature")
+ timestamp = request.args.get("timestamp")
+ nonce = request.args.get("nonce")
try:
xml = self.crypto.decrypt_message(data, msg_signature, timestamp, nonce)
except InvalidSignatureException:
logger.error("解密失败,签名异常,请检查配置。")
raise
else:
- msg = parse_message(xml)
+ msg = cast(BaseMessage, parse_message(xml))
logger.info(f"解析成功: {msg}")
if self.callback:
@@ -110,7 +137,7 @@ class WecomServer:
await self.shutdown_event.wait()
-@register_platform_adapter("wecom", "wecom 适配器")
+@register_platform_adapter("wecom", "wecom 适配器", support_streaming_message=False)
class WecomPlatformAdapter(Platform):
def __init__(
self,
@@ -118,14 +145,14 @@ class WecomPlatformAdapter(Platform):
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
- super().__init__(event_queue)
- self.config = platform_config
+ super().__init__(platform_config, event_queue)
self.settingss = platform_settings
self.client_self_id = uuid.uuid4().hex[:8]
self.api_base_url = platform_config.get(
"api_base_url",
"https://qyapi.weixin.qq.com/cgi-bin/",
)
+ self.unified_webhook_mode = platform_config.get("unified_webhook_mode", False)
if not self.api_base_url:
self.api_base_url = "https://qyapi.weixin.qq.com/cgi-bin/"
@@ -150,10 +177,10 @@ class WecomPlatformAdapter(Platform):
# inject
self.wechat_kf_api = WeChatKF(client=self.client)
self.wechat_kf_message_api = WeChatKFMessage(self.client)
- self.client.kf = self.wechat_kf_api
- self.client.kf_message = self.wechat_kf_message_api
+ self.client.__setattr__("kf", self.wechat_kf_api)
+ self.client.__setattr__("kf_message", self.wechat_kf_message_api)
- self.client.API_BASE_URL = self.api_base_url
+ self.client.__setattr__("API_BASE_URL", self.api_base_url)
async def callback(msg: BaseMessage):
if msg.type == "unknown" and msg._data["Event"] == "kf_msg_or_event":
@@ -196,6 +223,7 @@ class WecomPlatformAdapter(Platform):
"wecom",
"wecom 适配器",
id=self.config.get("id", "wecom"),
+ support_streaming_message=False,
)
@override
@@ -231,41 +259,53 @@ class WecomPlatformAdapter(Platform):
)
except Exception as e:
logger.error(e)
- await self.server.start_polling()
+
+ # 如果启用统一 webhook 模式,则不启动独立服务器
+ webhook_uuid = self.config.get("webhook_uuid")
+ if self.unified_webhook_mode and webhook_uuid:
+ log_webhook_info(f"{self.meta().id}(企业微信)", webhook_uuid)
+ # 保持运行状态,等待 shutdown
+ await self.server.shutdown_event.wait()
+ else:
+ await self.server.start_polling()
+
+ async def webhook_callback(self, request: Any) -> Any:
+ """统一 Webhook 回调入口"""
+ # 根据请求方法分发到不同的处理函数
+ if request.method == "GET":
+ return await self.server.handle_verify(request)
+ else:
+ return await self.server.handle_callback(request)
async def convert_message(self, msg: BaseMessage) -> AstrBotMessage | None:
abm = AstrBotMessage()
- if msg.type == "text":
- assert isinstance(msg, TextMessage)
+ if isinstance(msg, TextMessage):
abm.message_str = msg.content
abm.self_id = str(msg.agent)
abm.message = [Plain(msg.content)]
abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember(
- msg.source,
- msg.source,
+ cast(str, msg.source),
+ cast(str, msg.source),
)
- abm.message_id = msg.id
- abm.timestamp = msg.time
+ abm.message_id = str(msg.id)
+ abm.timestamp = int(cast(int | str, msg.time))
abm.session_id = abm.sender.user_id
abm.raw_message = msg
- elif msg.type == "image":
- assert isinstance(msg, ImageMessage)
+ elif isinstance(msg, ImageMessage):
abm.message_str = "[图片]"
abm.self_id = str(msg.agent)
abm.message = [Image(file=msg.image, url=msg.image)]
abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember(
- msg.source,
- msg.source,
+ cast(str, msg.source),
+ cast(str, msg.source),
)
- abm.message_id = msg.id
- abm.timestamp = msg.time
+ abm.message_id = str(msg.id)
+ abm.timestamp = int(cast(int | str, msg.time))
abm.session_id = abm.sender.user_id
abm.raw_message = msg
- elif msg.type == "voice":
- assert isinstance(msg, VoiceMessage)
-
+ elif isinstance(msg, VoiceMessage):
resp: Response = await asyncio.get_event_loop().run_in_executor(
None,
self.client.media.download,
@@ -292,11 +332,11 @@ class WecomPlatformAdapter(Platform):
abm.message = [Record(file=path_wav, url=path_wav)]
abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember(
- msg.source,
- msg.source,
+ cast(str, msg.source),
+ cast(str, msg.source),
)
- abm.message_id = msg.id
- abm.timestamp = msg.time
+ abm.message_id = str(msg.id)
+ abm.timestamp = int(cast(int | str, msg.time))
abm.session_id = abm.sender.user_id
abm.raw_message = msg
else:
@@ -308,7 +348,7 @@ class WecomPlatformAdapter(Platform):
async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None:
msgtype = msg.get("msgtype")
- external_userid = msg.get("external_userid")
+ external_userid = cast(str, msg.get("external_userid"))
abm = AstrBotMessage()
abm.raw_message = msg
abm.raw_message["_wechat_kf_flag"] = None # 方便处理
@@ -382,4 +422,4 @@ class WecomPlatformAdapter(Platform):
await self.server.server.shutdown()
except Exception as _:
pass
- logger.info("企业微信 适配器已被优雅地关闭")
+ logger.info("企业微信 适配器已被关闭")
diff --git a/astrbot/core/platform/sources/wecom/wecom_event.py b/astrbot/core/platform/sources/wecom/wecom_event.py
index ba9ad9a49..0b5dae272 100644
--- a/astrbot/core/platform/sources/wecom/wecom_event.py
+++ b/astrbot/core/platform/sources/wecom/wecom_event.py
@@ -16,7 +16,7 @@ try:
import pydub
except Exception:
logger.warning(
- "检测到 pydub 库未安装,企业微信将无法语音收发。如需使用语音,请前往管理面板 -> 控制台 -> 安装 Pip 库安装 pydub。",
+ "检测到 pydub 库未安装,企业微信将无法语音收发。如需使用语音,请前往管理面板 -> 平台日志 -> 安装 Pip 库安装 pydub。",
)
@@ -93,10 +93,10 @@ class WecomPlatformEvent(AstrMessageEvent):
if is_wechat_kf:
# 微信客服
kf_message_api = getattr(self.client, "kf_message", None)
- if not kf_message_api:
+ if not isinstance(kf_message_api, WeChatKFMessage):
logger.warning("未找到微信客服发送消息方法。")
return
- assert isinstance(kf_message_api, WeChatKFMessage)
+
user_id = self.get_sender_id()
for comp in message.chain:
if isinstance(comp, Plain):
diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py
index 29ac02653..70581e7ea 100644
--- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py
+++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py
@@ -22,6 +22,7 @@ from astrbot.api.platform import (
PlatformMetadata,
)
from astrbot.core.platform.astr_message_event import MessageSesion
+from astrbot.core.utils.webhook_utils import log_webhook_info
from ...register import register_platform_adapter
from .wecomai_api import (
@@ -30,7 +31,7 @@ from .wecomai_api import (
WecomAIBotStreamMessageBuilder,
)
from .wecomai_event import WecomAIBotMessageEvent
-from .wecomai_queue_mgr import WecomAIQueueMgr, wecomai_queue_mgr
+from .wecomai_queue_mgr import WecomAIQueueMgr
from .wecomai_server import WecomAIBotServer
from .wecomai_utils import (
WecomAIBotConstants,
@@ -103,9 +104,7 @@ class WecomAIBotAdapter(Platform):
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
- super().__init__(event_queue)
-
- self.config = platform_config
+ super().__init__(platform_config, event_queue)
self.settings = platform_settings
# 初始化配置参数
@@ -122,6 +121,7 @@ class WecomAIBotAdapter(Platform):
"wecomaibot_friend_message_welcome_text",
"",
)
+ self.unified_webhook_mode = self.config.get("unified_webhook_mode", False)
# 平台元数据
self.metadata = PlatformMetadata(
@@ -144,9 +144,12 @@ class WecomAIBotAdapter(Platform):
# 事件循环和关闭信号
self.shutdown_event = asyncio.Event()
+ # 队列管理器
+ self.queue_mgr = WecomAIQueueMgr()
+
# 队列监听器
self.queue_listener = WecomAIQueueListener(
- wecomai_queue_mgr,
+ self.queue_mgr,
self._handle_queued_message,
)
@@ -189,7 +192,7 @@ class WecomAIBotAdapter(Platform):
stream_id,
session_id,
)
- wecomai_queue_mgr.set_pending_response(stream_id, callback_params)
+ self.queue_mgr.set_pending_response(stream_id, callback_params)
resp = WecomAIBotStreamMessageBuilder.make_text_stream(
stream_id,
@@ -207,7 +210,7 @@ class WecomAIBotAdapter(Platform):
elif msgtype == "stream":
# wechat server is requesting for updates of a stream
stream_id = message_data["stream"]["id"]
- if not wecomai_queue_mgr.has_back_queue(stream_id):
+ if not self.queue_mgr.has_back_queue(stream_id):
logger.error(f"Cannot find back queue for stream_id: {stream_id}")
# 返回结束标志,告诉微信服务器流已结束
@@ -222,7 +225,7 @@ class WecomAIBotAdapter(Platform):
callback_params["timestamp"],
)
return resp
- queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
+ queue = self.queue_mgr.get_or_create_back_queue(stream_id)
if queue.empty():
logger.debug(
f"No new messages in back queue for stream_id: {stream_id}",
@@ -242,10 +245,9 @@ class WecomAIBotAdapter(Platform):
elif msg["type"] == "end":
# stream end
finish = True
- wecomai_queue_mgr.remove_queues(stream_id)
+ self.queue_mgr.remove_queues(stream_id)
break
- else:
- pass
+
logger.debug(
f"Aggregated content: {latest_plain_content}, image: {len(image_base64)}, finish: {finish}",
)
@@ -313,8 +315,8 @@ class WecomAIBotAdapter(Platform):
session_id: str,
):
"""将消息放入队列进行异步处理"""
- input_queue = wecomai_queue_mgr.get_or_create_queue(stream_id)
- _ = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
+ input_queue = self.queue_mgr.get_or_create_queue(stream_id)
+ _ = self.queue_mgr.get_or_create_back_queue(stream_id)
message_payload = {
"message_data": message_data,
"callback_params": callback_params,
@@ -423,17 +425,34 @@ class WecomAIBotAdapter(Platform):
def run(self) -> Awaitable[Any]:
"""运行适配器,同时启动HTTP服务器和队列监听器"""
- logger.info("启动企业微信智能机器人适配器,监听 %s:%d", self.host, self.port)
async def run_both():
- # 同时运行HTTP服务器和队列监听器
- await asyncio.gather(
- self.server.start_server(),
- self.queue_listener.run(),
- )
+ # 如果启用统一 webhook 模式,则不启动独立服务器
+ webhook_uuid = self.config.get("webhook_uuid")
+ if self.unified_webhook_mode and webhook_uuid:
+ log_webhook_info(f"{self.meta().id}(企业微信智能机器人)", webhook_uuid)
+ # 只运行队列监听器
+ await self.queue_listener.run()
+ else:
+ logger.info(
+ "启动企业微信智能机器人适配器,监听 %s:%d", self.host, self.port
+ )
+ # 同时运行HTTP服务器和队列监听器
+ await asyncio.gather(
+ self.server.start_server(),
+ self.queue_listener.run(),
+ )
return run_both()
+ async def webhook_callback(self, request: Any) -> Any:
+ """统一 Webhook 回调入口"""
+ # 根据请求方法分发到不同的处理函数
+ if request.method == "GET":
+ return await self.server.handle_verify(request)
+ else:
+ return await self.server.handle_callback(request)
+
async def terminate(self):
"""终止适配器"""
logger.info("企业微信智能机器人适配器正在关闭...")
@@ -453,6 +472,7 @@ class WecomAIBotAdapter(Platform):
platform_meta=self.meta(),
session_id=message.session_id,
api_client=self.api_client,
+ queue_mgr=self.queue_mgr,
)
self.commit_event(message_event)
diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py
index 130182b48..fd11d7ceb 100644
--- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py
+++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py
@@ -8,7 +8,7 @@ from astrbot.api.message_components import (
)
from .wecomai_api import WecomAIBotAPIClient
-from .wecomai_queue_mgr import wecomai_queue_mgr
+from .wecomai_queue_mgr import WecomAIQueueMgr
class WecomAIBotMessageEvent(AstrMessageEvent):
@@ -21,6 +21,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
platform_meta,
session_id: str,
api_client: WecomAIBotAPIClient,
+ queue_mgr: WecomAIQueueMgr,
):
"""初始化消息事件
@@ -34,14 +35,16 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
"""
super().__init__(message_str, message_obj, platform_meta, session_id)
self.api_client = api_client
+ self.queue_mgr = queue_mgr
@staticmethod
async def _send(
- message_chain: MessageChain,
+ message_chain: MessageChain | None,
stream_id: str,
+ queue_mgr: WecomAIQueueMgr,
streaming: bool = False,
):
- back_queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
+ back_queue = queue_mgr.get_or_create_back_queue(stream_id)
if not message_chain:
await back_queue.put(
@@ -87,15 +90,15 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
return data
- async def send(self, message: MessageChain):
+ async def send(self, message: MessageChain | None):
"""发送消息"""
raw = self.message_obj.raw_message
assert isinstance(raw, dict), (
"wecom_ai_bot platform event raw_message should be a dict"
)
stream_id = raw.get("stream_id", self.session_id)
- await WecomAIBotMessageEvent._send(message, stream_id)
- await super().send(message)
+ await WecomAIBotMessageEvent._send(message, stream_id, self.queue_mgr)
+ await super().send(MessageChain([]))
async def send_streaming(self, generator, use_fallback=False):
"""流式发送消息,参考webchat的send_streaming设计"""
@@ -105,7 +108,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
"wecom_ai_bot platform event raw_message should be a dict"
)
stream_id = raw.get("stream_id", self.session_id)
- back_queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
+ back_queue = self.queue_mgr.get_or_create_back_queue(stream_id)
# 企业微信智能机器人不支持增量发送,因此我们需要在这里将增量内容累积起来,积累发送
increment_plain = ""
@@ -134,6 +137,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
final_data += await WecomAIBotMessageEvent._send(
chain,
stream_id=stream_id,
+ queue_mgr=self.queue_mgr,
streaming=True,
)
diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py
index eb3455292..3a982bdf7 100644
--- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py
+++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py
@@ -151,7 +151,3 @@ class WecomAIQueueMgr:
"output_queues": len(self.back_queues),
"pending_responses": len(self.pending_responses),
}
-
-
-# 全局队列管理器实例
-wecomai_queue_mgr = WecomAIQueueMgr()
diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py
index 35acd9066..5cbdd1130 100644
--- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py
+++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py
@@ -59,8 +59,19 @@ class WecomAIBotServer:
)
async def verify_url(self):
- """验证回调 URL"""
- args = quart.request.args
+ """内部服务器的 GET 验证入口"""
+ return await self.handle_verify(quart.request)
+
+ async def handle_verify(self, request):
+ """处理 URL 验证请求,可被统一 webhook 入口复用
+
+ Args:
+ request: Quart 请求对象
+
+ Returns:
+ 验证响应元组 (content, status_code, headers)
+ """
+ args = request.args
msg_signature = args.get("msg_signature")
timestamp = args.get("timestamp")
nonce = args.get("nonce")
@@ -81,8 +92,19 @@ class WecomAIBotServer:
return result, 200, {"Content-Type": "text/plain"}
async def handle_message(self):
- """处理消息回调"""
- args = quart.request.args
+ """内部服务器的 POST 消息回调入口"""
+ return await self.handle_callback(quart.request)
+
+ async def handle_callback(self, request):
+ """处理消息回调,可被统一 webhook 入口复用
+
+ Args:
+ request: Quart 请求对象
+
+ Returns:
+ 响应元组 (content, status_code, headers)
+ """
+ args = request.args
msg_signature = args.get("msg_signature")
timestamp = args.get("timestamp")
nonce = args.get("nonce")
@@ -102,7 +124,7 @@ class WecomAIBotServer:
try:
# 获取请求体
- post_data = await quart.request.get_data()
+ post_data = await request.get_data()
# 确保 post_data 是 bytes 类型
if isinstance(post_data, str):
diff --git a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py
index f44b06e90..d12285d68 100644
--- a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py
+++ b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py
@@ -1,6 +1,8 @@
import asyncio
import sys
import uuid
+from collections.abc import Awaitable, Callable
+from typing import Any, cast
import quart
from requests import Response
@@ -22,6 +24,7 @@ from astrbot.api.platform import (
)
from astrbot.core import logger
from astrbot.core.platform.astr_message_event import MessageSesion
+from astrbot.core.utils.webhook_utils import log_webhook_info
from .weixin_offacc_event import WeixinOfficialAccountPlatformEvent
@@ -31,10 +34,10 @@ else:
from typing_extensions import override
-class WecomServer:
+class WeixinOfficialAccountServer:
def __init__(self, event_queue: asyncio.Queue, config: dict):
self.server = quart.Quart(__name__)
- self.port = int(config.get("port"))
+ self.port = int(cast(int | str, config.get("port")))
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
self.token = config.get("token")
self.encoding_aes_key = config.get("encoding_aes_key")
@@ -53,13 +56,25 @@ class WecomServer:
self.event_queue = event_queue
- self.callback = None
+ self.callback: Callable[[BaseMessage], Awaitable[None]] | None = None
self.shutdown_event = asyncio.Event()
async def verify(self):
- logger.info(f"验证请求有效性: {quart.request.args}")
+ """内部服务器的 GET 验证入口"""
+ return await self.handle_verify(quart.request)
- args = quart.request.args
+ async def handle_verify(self, request) -> str:
+ """处理验证请求,可被统一 webhook 入口复用
+
+ Args:
+ request: Quart 请求对象
+
+ Returns:
+ 验证响应
+ """
+ logger.info(f"验证请求有效性: {request.args}")
+
+ args = request.args
if not args.get("signature", None):
logger.error("未知的响应,请检查回调地址是否填写正确。")
return "err"
@@ -77,10 +92,22 @@ class WecomServer:
return "err"
async def callback_command(self):
- data = await quart.request.get_data()
- msg_signature = quart.request.args.get("msg_signature")
- timestamp = quart.request.args.get("timestamp")
- nonce = quart.request.args.get("nonce")
+ """内部服务器的 POST 回调入口"""
+ return await self.handle_callback(quart.request)
+
+ async def handle_callback(self, request) -> str:
+ """处理回调请求,可被统一 webhook 入口复用
+
+ Args:
+ request: Quart 请求对象
+
+ Returns:
+ 响应内容
+ """
+ data = await request.get_data()
+ msg_signature = request.args.get("msg_signature")
+ timestamp = request.args.get("timestamp")
+ nonce = request.args.get("nonce")
try:
xml = self.crypto.decrypt_message(data, msg_signature, timestamp, nonce)
except InvalidSignatureException:
@@ -88,6 +115,9 @@ class WecomServer:
raise
else:
msg = parse_message(xml)
+ if not msg:
+ logger.error("解析失败。msg为None。")
+ raise
logger.info(f"解析成功: {msg}")
if self.callback:
@@ -113,7 +143,9 @@ class WecomServer:
await self.shutdown_event.wait()
-@register_platform_adapter("weixin_official_account", "微信公众平台 适配器")
+@register_platform_adapter(
+ "weixin_official_account", "微信公众平台 适配器", support_streaming_message=False
+)
class WeixinOfficialAccountPlatformAdapter(Platform):
def __init__(
self,
@@ -121,8 +153,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
- super().__init__(event_queue)
- self.config = platform_config
+ super().__init__(platform_config, event_queue)
self.settingss = platform_settings
self.client_self_id = uuid.uuid4().hex[:8]
self.api_base_url = platform_config.get(
@@ -130,6 +161,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
"https://api.weixin.qq.com/cgi-bin/",
)
self.active_send_mode = self.config.get("active_send_mode", False)
+ self.unified_webhook_mode = platform_config.get("unified_webhook_mode", False)
if not self.api_base_url:
self.api_base_url = "https://api.weixin.qq.com/cgi-bin/"
@@ -141,14 +173,14 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
if not self.api_base_url.endswith("/"):
self.api_base_url += "/"
- self.server = WecomServer(self._event_queue, self.config)
+ self.server = WeixinOfficialAccountServer(self._event_queue, self.config)
self.client = WeChatClient(
self.config["appid"].strip(),
self.config["secret"].strip(),
)
- self.client.API_BASE_URL = self.api_base_url
+ self.client.__setattr__("API_BASE_URL", self.api_base_url)
# 微信公众号必须 5 秒内进行回复,否则会重试 3 次,我们需要对其进行消息排重
# msgid -> Future
@@ -160,11 +192,11 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
await self.convert_message(msg, None)
else:
if msg.id in self.wexin_event_workers:
- future = self.wexin_event_workers[msg.id]
+ future = self.wexin_event_workers[str(cast(str | int, msg.id))]
logger.debug(f"duplicate message id checked: {msg.id}")
else:
future = asyncio.get_event_loop().create_future()
- self.wexin_event_workers[msg.id] = future
+ self.wexin_event_workers[str(cast(str | int, msg.id))] = future
await self.convert_message(msg, future)
# I love shield so much!
result = await asyncio.wait_for(
@@ -172,7 +204,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
60,
) # wait for 60s
logger.debug(f"Got future result: {result}")
- self.wexin_event_workers.pop(msg.id, None)
+ self.wexin_event_workers.pop(str(cast(str | int, msg.id)), None)
return result # xml. see weixin_offacc_event.py
except asyncio.TimeoutError:
pass
@@ -195,42 +227,58 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
"weixin_official_account",
"微信公众平台 适配器",
id=self.config.get("id", "weixin_official_account"),
+ support_streaming_message=False,
)
@override
async def run(self):
- await self.server.start_polling()
+ # 如果启用统一 webhook 模式,则不启动独立服务器
+ webhook_uuid = self.config.get("webhook_uuid")
+ if self.unified_webhook_mode and webhook_uuid:
+ log_webhook_info(f"{self.meta().id}(微信公众平台)", webhook_uuid)
+ # 保持运行状态,等待 shutdown
+ await self.server.shutdown_event.wait()
+ else:
+ await self.server.start_polling()
+
+ async def webhook_callback(self, request: Any) -> Any:
+ """统一 Webhook 回调入口"""
+ # 根据请求方法分发到不同的处理函数
+ if request.method == "GET":
+ return await self.server.handle_verify(request)
+ else:
+ return await self.server.handle_callback(request)
async def convert_message(
self,
msg,
- future: asyncio.Future = None,
+ future: asyncio.Future | None = None,
) -> AstrBotMessage | None:
abm = AstrBotMessage()
if isinstance(msg, TextMessage):
- abm.message_str = msg.content
+ abm.message_str = cast(str, msg.content)
abm.self_id = str(msg.target)
- abm.message = [Plain(msg.content)]
+ abm.message = [Plain(cast(str, msg.content))]
abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember(
- msg.source,
- msg.source,
+ cast(str, msg.source),
+ cast(str, msg.source),
)
- abm.message_id = msg.id
- abm.timestamp = msg.time
+ abm.message_id = str(cast(str | int, msg.id))
+ abm.timestamp = cast(int, msg.time)
abm.session_id = abm.sender.user_id
elif msg.type == "image":
assert isinstance(msg, ImageMessage)
abm.message_str = "[图片]"
abm.self_id = str(msg.target)
- abm.message = [Image(file=msg.image, url=msg.image)]
+ abm.message = [Image(file=cast(str, msg.image), url=cast(str, msg.image))]
abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember(
- msg.source,
- msg.source,
+ cast(str, msg.source),
+ cast(str, msg.source),
)
- abm.message_id = msg.id
- abm.timestamp = msg.time
+ abm.message_id = str(cast(str | int, msg.id))
+ abm.timestamp = cast(int, msg.time)
abm.session_id = abm.sender.user_id
elif msg.type == "voice":
assert isinstance(msg, VoiceMessage)
@@ -262,15 +310,16 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
abm.message = [Record(file=path_wav, url=path_wav)]
abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember(
- msg.source,
- msg.source,
+ cast(str, msg.source),
+ cast(str, msg.source),
)
- abm.message_id = msg.id
- abm.timestamp = msg.time
+ abm.message_id = str(cast(str | int, msg.id))
+ abm.timestamp = cast(int, msg.time)
abm.session_id = abm.sender.user_id
else:
logger.warning(f"暂未实现的事件: {msg.type}")
- future.set_result(None)
+ if future:
+ future.set_result(None)
return
# 很不优雅 :(
abm.raw_message = {
@@ -300,4 +349,4 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
await self.server.server.shutdown()
except Exception as _:
pass
- logger.info("微信公众平台 适配器已被优雅地关闭")
+ logger.info("微信公众平台 适配器已被关闭")
diff --git a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py
index d138fc80c..c1f137a41 100644
--- a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py
+++ b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py
@@ -1,5 +1,6 @@
import asyncio
import uuid
+from typing import cast
from wechatpy import WeChatClient
from wechatpy.replies import ImageReply, TextReply, VoiceReply
@@ -13,7 +14,7 @@ try:
import pydub
except Exception:
logger.warning(
- "检测到 pydub 库未安装,微信公众平台将无法语音收发。如需使用语音,请前往管理面板 -> 控制台 -> 安装 Pip 库安装 pydub。",
+ "检测到 pydub 库未安装,微信公众平台将无法语音收发。如需使用语音,请前往管理面板 -> 平台日志 -> 安装 Pip 库安装 pydub。",
)
@@ -85,7 +86,9 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
async def send(self, message: MessageChain):
message_obj = self.message_obj
- active_send_mode = message_obj.raw_message.get("active_send_mode", False)
+ active_send_mode = cast(dict, message_obj.raw_message).get(
+ "active_send_mode", False
+ )
for comp in message.chain:
if isinstance(comp, Plain):
# Split long text messages if needed
@@ -96,10 +99,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
else:
reply = TextReply(
content=chunk,
- message=self.message_obj.raw_message["message"],
+ message=cast(dict, self.message_obj.raw_message)["message"],
)
xml = reply.render()
- future = self.message_obj.raw_message["future"]
+ future = cast(dict, self.message_obj.raw_message)["future"]
assert isinstance(future, asyncio.Future)
future.set_result(xml)
await asyncio.sleep(0.5) # Avoid sending too fast
@@ -125,10 +128,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
else:
reply = ImageReply(
media_id=response["media_id"],
- message=self.message_obj.raw_message["message"],
+ message=cast(dict, self.message_obj.raw_message)["message"],
)
xml = reply.render()
- future = self.message_obj.raw_message["future"]
+ future = cast(dict, self.message_obj.raw_message)["future"]
assert isinstance(future, asyncio.Future)
future.set_result(xml)
@@ -160,10 +163,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
else:
reply = VoiceReply(
media_id=response["media_id"],
- message=self.message_obj.raw_message["message"],
+ message=cast(dict, self.message_obj.raw_message)["message"],
)
xml = reply.render()
- future = self.message_obj.raw_message["future"]
+ future = cast(dict, self.message_obj.raw_message)["future"]
assert isinstance(future, asyncio.Future)
future.set_result(xml)
diff --git a/astrbot/core/platform_message_history_mgr.py b/astrbot/core/platform_message_history_mgr.py
index 0e079e893..d6d524698 100644
--- a/astrbot/core/platform_message_history_mgr.py
+++ b/astrbot/core/platform_message_history_mgr.py
@@ -10,12 +10,12 @@ class PlatformMessageHistoryManager:
self,
platform_id: str,
user_id: str,
- content: list[dict], # TODO: parse from message chain
+ content: dict, # TODO: parse from message chain
sender_id: str | None = None,
sender_name: str | None = None,
- ):
+ ) -> PlatformMessageHistory:
"""Insert a new platform message history record."""
- await self.db.insert_platform_message_history(
+ return await self.db.insert_platform_message_history(
platform_id=platform_id,
user_id=user_id,
content=content,
diff --git a/astrbot/core/provider/__init__.py b/astrbot/core/provider/__init__.py
index abbe08234..812e02171 100644
--- a/astrbot/core/provider/__init__.py
+++ b/astrbot/core/provider/__init__.py
@@ -1,4 +1,4 @@
from .entities import ProviderMetaData
-from .provider import Personality, Provider, STTProvider
+from .provider import Provider, STTProvider
-__all__ = ["Personality", "Provider", "ProviderMetaData", "STTProvider"]
+__all__ = ["Provider", "ProviderMetaData", "STTProvider"]
diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py
index 2f1e84419..d13e9b56a 100644
--- a/astrbot/core/provider/entities.py
+++ b/astrbot/core/provider/entities.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import base64
import enum
import json
@@ -30,18 +32,31 @@ class ProviderType(enum.Enum):
@dataclass
-class ProviderMetaData:
- type: str
- """提供商适配器名称,如 openai, ollama"""
- desc: str = ""
- """提供商适配器描述"""
- provider_type: ProviderType = ProviderType.CHAT_COMPLETION
- cls_type: Any = None
+class ProviderMeta:
+ """The basic metadata of a provider instance."""
+ id: str
+ """the unique id of the provider instance that user configured"""
+ model: str | None
+ """the model name of the provider instance currently used"""
+ type: str
+ """the name of the provider adapter, such as openai, ollama"""
+ provider_type: ProviderType = ProviderType.CHAT_COMPLETION
+ """the capability type of the provider adapter"""
+
+
+@dataclass
+class ProviderMetaData(ProviderMeta):
+ """The metadata of a provider adapter for registration."""
+
+ desc: str = ""
+ """the short description of the provider adapter"""
+ cls_type: Any = None
+ """the class type of the provider adapter"""
default_config_tmpl: dict | None = None
- """平台的默认配置模板"""
+ """the default configuration template of the provider adapter"""
provider_display_name: str | None = None
- """显示在 WebUI 配置页中的提供商名称,如空则是 type"""
+ """the display name of the provider shown in the WebUI configuration page; if empty, the type is used"""
@dataclass
@@ -60,12 +75,20 @@ class ToolCallsResult:
]
return ret
+ def to_openai_messages_model(
+ self,
+ ) -> list[AssistantMessageSegment | ToolCallMessageSegment]:
+ return [
+ self.tool_calls_info,
+ *self.tool_calls_result,
+ ]
+
@dataclass
class ProviderRequest:
- prompt: str
+ prompt: str | None = None
"""提示词"""
- session_id: str = ""
+ session_id: str | None = ""
"""会话 ID"""
image_urls: list[str] = field(default_factory=list)
"""图片 URL 列表"""
@@ -178,28 +201,70 @@ class ProviderRequest:
return ""
+@dataclass
+class TokenUsage:
+ input_other: int = 0
+ """The number of input tokens, excluding cached tokens."""
+ input_cached: int = 0
+ """The number of input cached tokens."""
+ output: int = 0
+ """The number of output tokens."""
+
+ @property
+ def total(self) -> int:
+ return self.input_other + self.input_cached + self.output
+
+ @property
+ def input(self) -> int:
+ return self.input_other + self.input_cached
+
+ def __add__(self, other: TokenUsage) -> TokenUsage:
+ return TokenUsage(
+ input_other=self.input_other + other.input_other,
+ input_cached=self.input_cached + other.input_cached,
+ output=self.output + other.output,
+ )
+
+ def __sub__(self, other: TokenUsage) -> TokenUsage:
+ return TokenUsage(
+ input_other=self.input_other - other.input_other,
+ input_cached=self.input_cached - other.input_cached,
+ output=self.output - other.output,
+ )
+
+
@dataclass
class LLMResponse:
role: str
- """角色, assistant, tool, err"""
+ """The role of the message, e.g., assistant, tool, err"""
result_chain: MessageChain | None = None
- """返回的消息链"""
+ """A chain of message components representing the text completion from LLM."""
tools_call_args: list[dict[str, Any]] = field(default_factory=list)
- """工具调用参数"""
+ """Tool call arguments."""
tools_call_name: list[str] = field(default_factory=list)
- """工具调用名称"""
+ """Tool call names."""
tools_call_ids: list[str] = field(default_factory=list)
- """工具调用 ID"""
+ """Tool call IDs."""
+ tools_call_extra_content: dict[str, dict[str, Any]] = field(default_factory=dict)
+ """Tool call extra content. tool_call_id -> extra_content dict"""
+ reasoning_content: str = ""
+ """The reasoning content extracted from the LLM, if any."""
raw_completion: (
ChatCompletion | GenerateContentResponse | AnthropicMessage | None
) = None
- _new_record: dict[str, Any] | None = None
+ """The raw completion response from the LLM provider."""
_completion_text: str = ""
+ """The plain text of the completion."""
is_chunk: bool = False
- """是否是流式输出的单个 Chunk"""
+ """Indicates if the response is a chunked response."""
+
+ id: str | None = None
+ """The ID of the response. For chunked responses, it's the ID of the chunk; for non-chunked responses, it's the ID of the response."""
+ usage: TokenUsage | None = None
+ """The usage of the response. For chunked responses, it's the usage of the chunk; for non-chunked responses, it's the usage of the response."""
def __init__(
self,
@@ -209,12 +274,14 @@ class LLMResponse:
tools_call_args: list[dict[str, Any]] | None = None,
tools_call_name: list[str] | None = None,
tools_call_ids: list[str] | None = None,
+ tools_call_extra_content: dict[str, dict[str, Any]] | None = None,
raw_completion: ChatCompletion
| GenerateContentResponse
| AnthropicMessage
| None = None,
- _new_record: dict[str, Any] | None = None,
is_chunk: bool = False,
+ id: str | None = None,
+ usage: TokenUsage | None = None,
):
"""初始化 LLMResponse
@@ -233,6 +300,8 @@ class LLMResponse:
tools_call_name = []
if tools_call_ids is None:
tools_call_ids = []
+ if tools_call_extra_content is None:
+ tools_call_extra_content = {}
self.role = role
self.completion_text = completion_text
@@ -240,8 +309,8 @@ class LLMResponse:
self.tools_call_args = tools_call_args
self.tools_call_name = tools_call_name
self.tools_call_ids = tools_call_ids
+ self.tools_call_extra_content = tools_call_extra_content
self.raw_completion = raw_completion
- self._new_record = _new_record
self.is_chunk = is_chunk
@property
@@ -266,16 +335,19 @@ class LLMResponse:
"""Convert to OpenAI tool calls format. Deprecated, use to_openai_to_calls_model instead."""
ret = []
for idx, tool_call_arg in enumerate(self.tools_call_args):
- ret.append(
- {
- "id": self.tools_call_ids[idx],
- "function": {
- "name": self.tools_call_name[idx],
- "arguments": json.dumps(tool_call_arg),
- },
- "type": "function",
+ payload = {
+ "id": self.tools_call_ids[idx],
+ "function": {
+ "name": self.tools_call_name[idx],
+ "arguments": json.dumps(tool_call_arg),
},
- )
+ "type": "function",
+ }
+ if self.tools_call_extra_content.get(self.tools_call_ids[idx]):
+ payload["extra_content"] = self.tools_call_extra_content[
+ self.tools_call_ids[idx]
+ ]
+ ret.append(payload)
return ret
def to_openai_to_calls_model(self) -> list[ToolCall]:
@@ -289,6 +361,10 @@ class LLMResponse:
name=self.tools_call_name[idx],
arguments=json.dumps(tool_call_arg),
),
+ # the extra_content will not serialize if it's None when calling ToolCall.model_dump()
+ extra_content=self.tools_call_extra_content.get(
+ self.tools_call_ids[idx]
+ ),
),
)
return ret
diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py
index 36aad2ae9..7aad86bdd 100644
--- a/astrbot/core/provider/func_tool_manager.py
+++ b/astrbot/core/provider/func_tool_manager.py
@@ -1,9 +1,10 @@
from __future__ import annotations
import asyncio
+import copy
import json
import os
-from collections.abc import Awaitable, Callable
+from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any
import aiohttp
@@ -24,7 +25,16 @@ SUPPORTED_TYPES = [
"boolean",
] # json schema 支持的数据类型
-
+PY_TO_JSON_TYPE = {
+ "int": "number",
+ "float": "number",
+ "bool": "boolean",
+ "str": "string",
+ "dict": "object",
+ "list": "array",
+ "tuple": "array",
+ "set": "array",
+}
# alias
FuncTool = FunctionTool
@@ -106,19 +116,18 @@ class FunctionToolManager:
def spec_to_func(
self,
name: str,
- func_args: list,
+ func_args: list[dict],
desc: str,
- handler: Callable[..., Awaitable[Any]],
+ handler: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
) -> FuncTool:
params = {
"type": "object", # hard-coded here
"properties": {},
}
for param in func_args:
- params["properties"][param["name"]] = {
- "type": param["type"],
- "description": param["description"],
- }
+ p = copy.deepcopy(param)
+ p.pop("name", None)
+ params["properties"][param["name"]] = p
return FuncTool(
name=name,
parameters=params,
@@ -131,7 +140,7 @@ class FunctionToolManager:
name: str,
func_args: list,
desc: str,
- handler: Callable[..., Awaitable[Any]],
+ handler: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
) -> None:
"""添加函数调用工具
@@ -271,19 +280,22 @@ class FunctionToolManager:
async def _terminate_mcp_client(self, name: str) -> None:
"""关闭并清理MCP客户端"""
if name in self.mcp_client_dict:
+ client = self.mcp_client_dict[name]
try:
# 关闭MCP连接
- await self.mcp_client_dict[name].cleanup()
- self.mcp_client_dict.pop(name)
+ await client.cleanup()
except Exception as e:
logger.error(f"清空 MCP 客户端资源 {name}: {e}。")
- # 移除关联的FuncTool
- self.func_list = [
- f
- for f in self.func_list
- if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
- ]
- logger.info(f"已关闭 MCP 服务 {name}")
+ finally:
+ # Remove client from dict after cleanup attempt (successful or not)
+ self.mcp_client_dict.pop(name, None)
+ # 移除关联的FuncTool
+ self.func_list = [
+ f
+ for f in self.func_list
+ if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
+ ]
+ logger.info(f"已关闭 MCP 服务 {name}")
@staticmethod
async def test_mcp_server_connection(config: dict) -> list[str]:
diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py
index 5fc5a4b5e..be8edc282 100644
--- a/astrbot/core/provider/manager.py
+++ b/astrbot/core/provider/manager.py
@@ -1,7 +1,8 @@
import asyncio
import traceback
+from typing import Protocol, runtime_checkable
-from astrbot.core import logger, sp
+from astrbot.core import astrbot_config, logger, sp
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.db import BaseDatabase
@@ -10,6 +11,7 @@ from .entities import ProviderType
from .provider import (
EmbeddingProvider,
Provider,
+ Providers,
RerankProvider,
STTProvider,
TTSProvider,
@@ -17,6 +19,11 @@ from .provider import (
from .register import llm_tools, provider_cls_map
+@runtime_checkable
+class HasInitialize(Protocol):
+ async def initialize(self) -> None: ...
+
+
class ProviderManager:
def __init__(
self,
@@ -24,6 +31,7 @@ class ProviderManager:
db_helper: BaseDatabase,
persona_mgr: PersonaManager,
):
+ self.reload_lock = asyncio.Lock()
self.persona_mgr = persona_mgr
self.acm = acm
config = acm.confs["default"]
@@ -47,7 +55,7 @@ class ProviderManager:
"""加载的 Rerank Provider 的实例"""
self.inst_map: dict[
str,
- Provider | STTProvider | TTSProvider | EmbeddingProvider | RerankProvider,
+ Providers,
] = {}
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
self.llm_tools = llm_tools
@@ -122,15 +130,13 @@ class ProviderManager:
self.curr_provider_inst = prov
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
- async def get_provider_by_id(self, provider_id: str) -> Provider | None:
+ async def get_provider_by_id(self, provider_id: str) -> Providers | None:
"""根据提供商 ID 获取提供商实例"""
return self.inst_map.get(provider_id)
def get_using_provider(
- self,
- provider_type: ProviderType,
- umo=None,
- ) -> Provider | STTProvider | TTSProvider | None:
+ self, provider_type: ProviderType, umo=None
+ ) -> Providers | None:
"""获取正在使用的提供商实例。
Args:
@@ -190,7 +196,6 @@ class ProviderManager:
logger.error(traceback.format_exc())
logger.error(e)
- # 设置默认提供商
selected_provider_id = sp.get(
"curr_provider",
self.provider_settings.get("default_provider_id"),
@@ -209,15 +214,37 @@ class ProviderManager:
scope="global",
scope_id="global",
)
- self.curr_provider_inst = self.inst_map.get(selected_provider_id)
+
+ temp_provider = (
+ self.inst_map.get(selected_provider_id)
+ if isinstance(selected_provider_id, str)
+ else None
+ )
+ self.curr_provider_inst = (
+ temp_provider if isinstance(temp_provider, Provider) else None
+ )
if not self.curr_provider_inst and self.provider_insts:
self.curr_provider_inst = self.provider_insts[0]
- self.curr_stt_provider_inst = self.inst_map.get(selected_stt_provider_id)
+ temp_stt = (
+ self.inst_map.get(selected_stt_provider_id)
+ if isinstance(selected_stt_provider_id, str)
+ else None
+ )
+ self.curr_stt_provider_inst = (
+ temp_stt if isinstance(temp_stt, STTProvider) else None
+ )
if not self.curr_stt_provider_inst and self.stt_provider_insts:
self.curr_stt_provider_inst = self.stt_provider_insts[0]
- self.curr_tts_provider_inst = self.inst_map.get(selected_tts_provider_id)
+ temp_tts = (
+ self.inst_map.get(selected_tts_provider_id)
+ if isinstance(selected_tts_provider_id, str)
+ else None
+ )
+ self.curr_tts_provider_inst = (
+ temp_tts if isinstance(temp_tts, TTSProvider) else None
+ )
if not self.curr_tts_provider_inst and self.tts_provider_insts:
self.curr_tts_provider_inst = self.tts_provider_insts[0]
@@ -226,6 +253,9 @@ class ProviderManager:
async def load_provider(self, provider_config: dict):
if not provider_config["enable"]:
+ logger.info(f"Provider {provider_config['id']} is disabled, skipping")
+ return
+ if provider_config.get("provider_type", "") == "agent_runner":
return
logger.info(
@@ -241,18 +271,12 @@ class ProviderManager:
)
case "zhipu_chat_completion":
from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
+ case "groq_chat_completion":
+ from .sources.groq_source import ProviderGroq as ProviderGroq
case "anthropic_chat_completion":
from .sources.anthropic_source import (
ProviderAnthropic as ProviderAnthropic,
)
- case "dify":
- from .sources.dify_source import ProviderDify as ProviderDify
- case "coze":
- from .sources.coze_source import ProviderCoze as ProviderCoze
- case "dashscope":
- from .sources.dashscope_source import (
- ProviderDashscope as ProviderDashscope,
- )
case "googlegenai_chat_completion":
from .sources.gemini_source import (
ProviderGoogleGenAI as ProviderGoogleGenAI,
@@ -329,6 +353,10 @@ class ProviderManager:
from .sources.xinference_rerank_source import (
XinferenceRerankProvider as XinferenceRerankProvider,
)
+ case "bailian_rerank":
+ from .sources.bailian_rerank_source import (
+ BailianRerankProvider as BailianRerankProvider,
+ )
except (ImportError, ModuleNotFoundError) as e:
logger.critical(
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。",
@@ -354,74 +382,105 @@ class ProviderManager:
logger.error(f"无法找到 {provider_metadata.type} 的类")
return
- if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
- # STT 任务
- inst = cls_type(provider_config, self.provider_settings)
+ provider_metadata.id = provider_config["id"]
- if getattr(inst, "initialize", None):
- await inst.initialize()
+ match provider_metadata.provider_type:
+ case ProviderType.SPEECH_TO_TEXT:
+ # STT 任务
+ if not issubclass(cls_type, STTProvider):
+ raise TypeError(
+ f"Provider class {cls_type} is not a subclass of STTProvider"
+ )
+ inst = cls_type(provider_config, self.provider_settings)
- self.stt_provider_insts.append(inst)
- if (
- self.provider_stt_settings.get("provider_id")
- == provider_config["id"]
- ):
- self.curr_stt_provider_inst = inst
- logger.info(
- f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。",
+ if isinstance(inst, HasInitialize):
+ await inst.initialize()
+
+ self.stt_provider_insts.append(inst)
+ if (
+ self.provider_stt_settings.get("provider_id")
+ == provider_config["id"]
+ ):
+ self.curr_stt_provider_inst = inst
+ logger.info(
+ f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。",
+ )
+ if not self.curr_stt_provider_inst:
+ self.curr_stt_provider_inst = inst
+
+ case ProviderType.TEXT_TO_SPEECH:
+ # TTS 任务
+ if not issubclass(cls_type, TTSProvider):
+ raise TypeError(
+ f"Provider class {cls_type} is not a subclass of TTSProvider"
+ )
+ inst = cls_type(provider_config, self.provider_settings)
+
+ if isinstance(inst, HasInitialize):
+ await inst.initialize()
+
+ self.tts_provider_insts.append(inst)
+ if (
+ self.provider_settings.get("provider_id")
+ == provider_config["id"]
+ ):
+ self.curr_tts_provider_inst = inst
+ logger.info(
+ f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。",
+ )
+ if not self.curr_tts_provider_inst:
+ self.curr_tts_provider_inst = inst
+
+ case ProviderType.CHAT_COMPLETION:
+ # 文本生成任务
+ if not issubclass(cls_type, Provider):
+ raise TypeError(
+ f"Provider class {cls_type} is not a subclass of Provider"
+ )
+ inst = cls_type(
+ provider_config,
+ self.provider_settings,
)
- if not self.curr_stt_provider_inst:
- self.curr_stt_provider_inst = inst
- elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
- # TTS 任务
- inst = cls_type(provider_config, self.provider_settings)
+ if isinstance(inst, HasInitialize):
+ await inst.initialize()
- if getattr(inst, "initialize", None):
- await inst.initialize()
+ self.provider_insts.append(inst)
+ if (
+ self.provider_settings.get("default_provider_id")
+ == provider_config["id"]
+ ):
+ self.curr_provider_inst = inst
+ logger.info(
+ f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。",
+ )
+ if not self.curr_provider_inst:
+ self.curr_provider_inst = inst
- self.tts_provider_insts.append(inst)
- if self.provider_settings.get("provider_id") == provider_config["id"]:
- self.curr_tts_provider_inst = inst
- logger.info(
- f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。",
+ case ProviderType.EMBEDDING:
+ if not issubclass(cls_type, EmbeddingProvider):
+ raise TypeError(
+ f"Provider class {cls_type} is not a subclass of EmbeddingProvider"
+ )
+ inst = cls_type(provider_config, self.provider_settings)
+ if isinstance(inst, HasInitialize):
+ await inst.initialize()
+ self.embedding_provider_insts.append(inst)
+ case ProviderType.RERANK:
+ if not issubclass(cls_type, RerankProvider):
+ raise TypeError(
+ f"Provider class {cls_type} is not a subclass of RerankProvider"
+ )
+ inst = cls_type(provider_config, self.provider_settings)
+ if isinstance(inst, HasInitialize):
+ await inst.initialize()
+ self.rerank_provider_insts.append(inst)
+ case _:
+ # 未知供应商抛出异常,确保inst初始化
+ # Should be unreachable
+ raise Exception(
+ f"未知的提供商类型:{provider_metadata.provider_type}"
)
- if not self.curr_tts_provider_inst:
- self.curr_tts_provider_inst = inst
-
- elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
- # 文本生成任务
- inst = cls_type(
- provider_config,
- self.provider_settings,
- self.selected_default_persona,
- )
-
- if getattr(inst, "initialize", None):
- await inst.initialize()
-
- self.provider_insts.append(inst)
- if (
- self.provider_settings.get("default_provider_id")
- == provider_config["id"]
- ):
- self.curr_provider_inst = inst
- logger.info(
- f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。",
- )
- if not self.curr_provider_inst:
- self.curr_provider_inst = inst
-
- elif provider_metadata.provider_type == ProviderType.EMBEDDING:
- inst = cls_type(provider_config, self.provider_settings)
- if getattr(inst, "initialize", None):
- await inst.initialize()
- self.embedding_provider_insts.append(inst)
- elif provider_metadata.provider_type == ProviderType.RERANK:
- inst = cls_type(provider_config, self.provider_settings)
- if getattr(inst, "initialize", None):
- await inst.initialize()
- self.rerank_provider_insts.append(inst)
self.inst_map[provider_config["id"]] = inst
except Exception as e:
@@ -433,40 +492,46 @@ class ProviderManager:
)
async def reload(self, provider_config: dict):
- await self.terminate_provider(provider_config["id"])
- if provider_config["enable"]:
- await self.load_provider(provider_config)
+ async with self.reload_lock:
+ await self.terminate_provider(provider_config["id"])
+ if provider_config["enable"]:
+ await self.load_provider(provider_config)
- # 和配置文件保持同步
- config_ids = [provider["id"] for provider in self.providers_config]
- logger.debug(f"providers in user's config: {config_ids}")
- for key in list(self.inst_map.keys()):
- if key not in config_ids:
- await self.terminate_provider(key)
+ # 和配置文件保持同步
+ self.providers_config = astrbot_config["provider"]
+ config_ids = [provider["id"] for provider in self.providers_config]
+ logger.info(f"providers in user's config: {config_ids}")
+ for key in list(self.inst_map.keys()):
+ if key not in config_ids:
+ await self.terminate_provider(key)
- if len(self.provider_insts) == 0:
- self.curr_provider_inst = None
- elif self.curr_provider_inst is None and len(self.provider_insts) > 0:
- self.curr_provider_inst = self.provider_insts[0]
- logger.info(
- f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。",
- )
+ if len(self.provider_insts) == 0:
+ self.curr_provider_inst = None
+ elif self.curr_provider_inst is None and len(self.provider_insts) > 0:
+ self.curr_provider_inst = self.provider_insts[0]
+ logger.info(
+ f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。",
+ )
- if len(self.stt_provider_insts) == 0:
- self.curr_stt_provider_inst = None
- elif self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0:
- self.curr_stt_provider_inst = self.stt_provider_insts[0]
- logger.info(
- f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。",
- )
+ if len(self.stt_provider_insts) == 0:
+ self.curr_stt_provider_inst = None
+ elif (
+ self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0
+ ):
+ self.curr_stt_provider_inst = self.stt_provider_insts[0]
+ logger.info(
+ f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。",
+ )
- if len(self.tts_provider_insts) == 0:
- self.curr_tts_provider_inst = None
- elif self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0:
- self.curr_tts_provider_inst = self.tts_provider_insts[0]
- logger.info(
- f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。",
- )
+ if len(self.tts_provider_insts) == 0:
+ self.curr_tts_provider_inst = None
+ elif (
+ self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0
+ ):
+ self.curr_tts_provider_inst = self.tts_provider_insts[0]
+ logger.info(
+ f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。",
+ )
def get_insts(self):
return self.provider_insts
diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py
index 7ab8f00ba..7f21a2ee1 100644
--- a/astrbot/core/provider/provider.py
+++ b/astrbot/core/provider/provider.py
@@ -1,26 +1,27 @@
import abc
import asyncio
+import os
from collections.abc import AsyncGenerator
-from dataclasses import dataclass
+from typing import TypeAlias, Union
from astrbot.core.agent.message import Message
from astrbot.core.agent.tool import ToolSet
-from astrbot.core.db.po import Personality
from astrbot.core.provider.entities import (
LLMResponse,
- ProviderType,
+ ProviderMeta,
RerankResult,
ToolCallsResult,
)
from astrbot.core.provider.register import provider_cls_map
+from astrbot.core.utils.astrbot_path import get_astrbot_path
-
-@dataclass
-class ProviderMeta:
- id: str
- model: str
- type: str
- provider_type: ProviderType
+Providers: TypeAlias = Union[
+ "Provider",
+ "STTProvider",
+ "TTSProvider",
+ "EmbeddingProvider",
+ "RerankProvider",
+]
class AbstractProvider(abc.ABC):
@@ -43,15 +44,23 @@ class AbstractProvider(abc.ABC):
"""Get the provider metadata"""
provider_type_name = self.provider_config["type"]
meta_data = provider_cls_map.get(provider_type_name)
- provider_type = meta_data.provider_type if meta_data else None
- if provider_type is None:
- raise ValueError(f"Cannot find provider type: {provider_type_name}")
- return ProviderMeta(
- id=self.provider_config["id"],
+ if not meta_data:
+ raise ValueError(f"Provider type {provider_type_name} not registered")
+ meta = ProviderMeta(
+ id=self.provider_config.get("id", "default"),
model=self.get_model(),
type=provider_type_name,
- provider_type=provider_type,
+ provider_type=meta_data.provider_type,
)
+ return meta
+
+ async def test(self):
+ """test the provider is a
+
+ raises:
+ Exception: if the provider is not available
+ """
+ ...
class Provider(AbstractProvider):
@@ -61,15 +70,10 @@ class Provider(AbstractProvider):
self,
provider_config: dict,
provider_settings: dict,
- default_persona: Personality | None = None,
) -> None:
super().__init__(provider_config)
-
self.provider_settings = provider_settings
- self.curr_personality = default_persona
- """维护了当前的使用的 persona,即人格。可能为 None"""
-
@abc.abstractmethod
def get_current_key(self) -> str:
raise NotImplementedError
@@ -147,7 +151,9 @@ class Provider(AbstractProvider):
- 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
"""
- ...
+ if False: # pragma: no cover - make this an async generator for typing
+ yield None # type: ignore
+ raise NotImplementedError()
async def pop_record(self, context: list):
"""弹出 context 第一条非系统提示词对话记录"""
@@ -180,6 +186,12 @@ class Provider(AbstractProvider):
return dicts
+ async def test(self, timeout: float = 45.0):
+ await asyncio.wait_for(
+ self.text_chat(prompt="REPLY `PONG` ONLY"),
+ timeout=timeout,
+ )
+
class STTProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
@@ -192,6 +204,14 @@ class STTProvider(AbstractProvider):
"""获取音频的文本"""
raise NotImplementedError
+ async def test(self):
+ sample_audio_path = os.path.join(
+ get_astrbot_path(),
+ "samples",
+ "stt_health_check.wav",
+ )
+ await self.get_text(sample_audio_path)
+
class TTSProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
@@ -204,6 +224,9 @@ class TTSProvider(AbstractProvider):
"""获取文本的音频,返回音频文件路径"""
raise NotImplementedError
+ async def test(self):
+ await self.get_audio("hi")
+
class EmbeddingProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
@@ -226,6 +249,9 @@ class EmbeddingProvider(AbstractProvider):
"""获取向量的维度"""
...
+ async def test(self):
+ await self.get_embedding("astrbot")
+
async def get_embeddings_batch(
self,
texts: list[str],
@@ -309,3 +335,8 @@ class RerankProvider(AbstractProvider):
) -> list[RerankResult]:
"""获取查询和文档的重排序分数"""
...
+
+ async def test(self):
+ result = await self.rerank("Apple", documents=["apple", "banana"])
+ if not result:
+ raise Exception("Rerank provider test failed, no results returned")
diff --git a/astrbot/core/provider/register.py b/astrbot/core/provider/register.py
index 1aead54df..3ad83784e 100644
--- a/astrbot/core/provider/register.py
+++ b/astrbot/core/provider/register.py
@@ -36,6 +36,8 @@ def register_provider_adapter(
default_config_tmpl["id"] = provider_type_name
pm = ProviderMetaData(
+ id="default", # will be replaced when instantiated
+ model=None,
type=provider_type_name,
desc=desc,
provider_type=provider_type,
diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py
index 77c85cef4..7e33f40d9 100644
--- a/astrbot/core/provider/sources/anthropic_source.py
+++ b/astrbot/core/provider/sources/anthropic_source.py
@@ -6,10 +6,12 @@ from mimetypes import guess_type
import anthropic
from anthropic import AsyncAnthropic
from anthropic.types import Message
+from anthropic.types.message_delta_usage import MessageDeltaUsage
+from anthropic.types.usage import Usage
from astrbot import logger
from astrbot.api.provider import Provider
-from astrbot.core.provider.entities import LLMResponse
+from astrbot.core.provider.entities import LLMResponse, TokenUsage
from astrbot.core.provider.func_tool_manager import ToolSet
from astrbot.core.utils.io import download_image_by_url
@@ -25,12 +27,10 @@ class ProviderAnthropic(Provider):
self,
provider_config,
provider_settings,
- default_persona=None,
) -> None:
super().__init__(
provider_config,
provider_settings,
- default_persona,
)
self.chosen_api_key: str = ""
@@ -109,6 +109,22 @@ class ProviderAnthropic(Provider):
return system_prompt, new_messages
+ def _extract_usage(self, usage: Usage) -> TokenUsage:
+ # https://docs.claude.com/en/docs/build-with-claude/prompt-caching#tracking-cache-performance
+ return TokenUsage(
+ input_other=usage.input_tokens or 0,
+ input_cached=usage.cache_read_input_tokens or 0,
+ output=usage.output_tokens,
+ )
+
+ def _update_usage(self, token_usage: TokenUsage, usage: MessageDeltaUsage) -> None:
+ if usage.input_tokens is not None:
+ token_usage.input_other = usage.input_tokens
+ if usage.cache_read_input_tokens is not None:
+ token_usage.input_cached = usage.cache_read_input_tokens
+ if usage.output_tokens is not None:
+ token_usage.output = usage.output_tokens
+
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
if tools:
if tool_list := tools.get_func_desc_anthropic_style():
@@ -133,6 +149,10 @@ class ProviderAnthropic(Provider):
llm_response.tools_call_args.append(content_block.input)
llm_response.tools_call_name.append(content_block.name)
llm_response.tools_call_ids.append(content_block.id)
+
+ llm_response.id = completion.id
+ llm_response.usage = self._extract_usage(completion.usage)
+
# TODO(Soulter): 处理 end_turn 情况
if not llm_response.completion_text and not llm_response.tools_call_args:
raise Exception(f"Anthropic API 返回的 completion 无法解析:{completion}。")
@@ -154,9 +174,16 @@ class ProviderAnthropic(Provider):
final_text = ""
final_tool_calls = []
+ id = None
+ usage = TokenUsage()
+
async with self.client.messages.stream(**payloads) as stream:
assert isinstance(stream, anthropic.AsyncMessageStream)
async for event in stream:
+ if event.type == "message_start":
+ # the usage contains input token usage
+ id = event.message.id
+ usage = self._extract_usage(event.message.usage)
if event.type == "content_block_start":
if event.content_block.type == "text":
# 文本块开始
@@ -164,6 +191,8 @@ class ProviderAnthropic(Provider):
role="assistant",
completion_text="",
is_chunk=True,
+ usage=usage,
+ id=id,
)
elif event.content_block.type == "tool_use":
# 工具使用块开始,初始化缓冲区
@@ -181,6 +210,8 @@ class ProviderAnthropic(Provider):
role="assistant",
completion_text=event.delta.text,
is_chunk=True,
+ usage=usage,
+ id=id,
)
elif event.delta.type == "input_json_delta":
# 工具调用参数增量
@@ -217,6 +248,8 @@ class ProviderAnthropic(Provider):
tools_call_name=[tool_info["name"]],
tools_call_ids=[tool_info["id"]],
is_chunk=True,
+ usage=usage,
+ id=id,
)
except json.JSONDecodeError:
# JSON 解析失败,跳过这个工具调用
@@ -225,11 +258,17 @@ class ProviderAnthropic(Provider):
# 清理缓冲区
del tool_use_buffer[event.index]
+ elif event.type == "message_delta":
+ if event.usage:
+ self._update_usage(usage, event.usage)
+
# 返回最终的完整结果
final_response = LLMResponse(
role="assistant",
completion_text=final_text,
is_chunk=False,
+ usage=usage,
+ id=id,
)
if final_tool_calls:
@@ -292,7 +331,7 @@ class ProviderAnthropic(Provider):
try:
llm_response = await self._query(payloads, func_tool)
except Exception as e:
- logger.error(f"发生了错误。Provider 配置如下: {model_config}")
+ # logger.error(f"发生了错误。Provider 配置如下: {model_config}")
raise e
return llm_response
diff --git a/astrbot/core/provider/sources/azure_tts_source.py b/astrbot/core/provider/sources/azure_tts_source.py
index e85d91793..2ccf146ca 100644
--- a/astrbot/core/provider/sources/azure_tts_source.py
+++ b/astrbot/core/provider/sources/azure_tts_source.py
@@ -29,15 +29,24 @@ class OTTSProvider:
self.last_sync_time = 0
self.timeout = Timeout(10.0)
self.retry_count = 3
- self.client = None
+ self._client: AsyncClient | None = None
+
+ @property
+ def client(self) -> AsyncClient:
+ if self._client is None:
+ raise RuntimeError(
+ "Client not initialized. Please use 'async with' context."
+ )
+ return self._client
async def __aenter__(self):
- self.client = AsyncClient(timeout=self.timeout)
+ self._client = AsyncClient(timeout=self.timeout)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
- if self.client:
- await self.client.aclose()
+ if self._client:
+ await self._client.aclose()
+ self._client = None
async def _sync_time(self):
try:
@@ -90,6 +99,7 @@ class OTTSProvider:
if attempt == self.retry_count - 1:
raise RuntimeError(f"OTTS请求失败: {e!s}") from e
await asyncio.sleep(0.5 * (attempt + 1))
+ raise RuntimeError("OTTS未返回音频文件")
class AzureNativeProvider(TTSProvider):
@@ -105,7 +115,7 @@ class AzureNativeProvider(TTSProvider):
self.endpoint = (
f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1"
)
- self.client = None
+ self._client: AsyncClient | None = None
self.token = None
self.token_expire = 0
self.voice_params = {
@@ -116,8 +126,16 @@ class AzureNativeProvider(TTSProvider):
"volume": provider_config.get("azure_tts_volume", "100"),
}
+ @property
+ def client(self) -> AsyncClient:
+ if self._client is None:
+ raise RuntimeError(
+ "Client not initialized. Please use 'async with' context."
+ )
+ return self._client
+
async def __aenter__(self):
- self.client = AsyncClient(
+ self._client = AsyncClient(
headers={
"User-Agent": f"AstrBot/{VERSION}",
"Content-Type": "application/ssml+xml",
@@ -127,8 +145,9 @@ class AzureNativeProvider(TTSProvider):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
- if self.client:
- await self.client.aclose()
+ if self._client:
+ await self._client.aclose()
+ self._client = None
async def _refresh_token(self):
token_url = (
@@ -181,8 +200,11 @@ class AzureTTSProvider(TTSProvider):
key_value = provider_config.get("azure_tts_subscription_key", "")
self.provider = self._parse_provider(key_value, provider_config)
- def _parse_provider(self, key_value: str, config: dict) -> TTSProvider:
+ def _parse_provider(
+ self, key_value: str, config: dict
+ ) -> OTTSProvider | AzureNativeProvider:
if key_value.lower().startswith("other["):
+ json_str = ""
try:
match = re.match(r"other\[(.*)\]", key_value, re.DOTALL)
if not match:
diff --git a/astrbot/core/provider/sources/bailian_rerank_source.py b/astrbot/core/provider/sources/bailian_rerank_source.py
new file mode 100644
index 000000000..9e079d4a9
--- /dev/null
+++ b/astrbot/core/provider/sources/bailian_rerank_source.py
@@ -0,0 +1,240 @@
+import os
+
+import aiohttp
+
+from astrbot import logger
+
+from ..entities import ProviderType, RerankResult
+from ..provider import RerankProvider
+from ..register import register_provider_adapter
+
+
+class BailianRerankError(Exception):
+ """百炼重排序服务异常基类"""
+
+ pass
+
+
+class BailianAPIError(BailianRerankError):
+ """百炼API返回错误"""
+
+ pass
+
+
+class BailianNetworkError(BailianRerankError):
+ """百炼网络请求错误"""
+
+ pass
+
+
+@register_provider_adapter(
+ "bailian_rerank", "阿里云百炼文本排序适配器", provider_type=ProviderType.RERANK
+)
+class BailianRerankProvider(RerankProvider):
+ """阿里云百炼文本重排序适配器."""
+
+ def __init__(self, provider_config: dict, provider_settings: dict) -> None:
+ super().__init__(provider_config, provider_settings)
+ self.provider_config = provider_config
+ self.provider_settings = provider_settings
+
+ # API配置
+ self.api_key = provider_config.get("rerank_api_key") or os.getenv(
+ "DASHSCOPE_API_KEY", ""
+ )
+ if not self.api_key:
+ raise ValueError("阿里云百炼 API Key 不能为空。")
+
+ self.model = provider_config.get("rerank_model", "qwen3-rerank")
+ self.timeout = provider_config.get("timeout", 30)
+ self.return_documents = provider_config.get("return_documents", False)
+ self.instruct = provider_config.get("instruct", "")
+
+ self.base_url = provider_config.get(
+ "rerank_api_base",
+ "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank",
+ )
+
+ # 设置HTTP客户端
+ headers = {
+ "Authorization": f"Bearer {self.api_key}",
+ "Content-Type": "application/json",
+ }
+
+ self.client = aiohttp.ClientSession(
+ headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
+ )
+
+ # 设置模型名称
+ self.set_model(self.model)
+
+ logger.info(f"AstrBot 百炼 Rerank 初始化完成。模型: {self.model}")
+
+ def _build_payload(
+ self, query: str, documents: list[str], top_n: int | None
+ ) -> dict:
+ """构建请求载荷
+
+ Args:
+ query: 查询文本
+ documents: 文档列表
+ top_n: 返回前N个结果,如果为None则返回所有结果
+
+ Returns:
+ 请求载荷字典
+ """
+ base = {"model": self.model, "input": {"query": query, "documents": documents}}
+
+ params = {
+ k: v
+ for k, v in [
+ ("top_n", top_n if top_n is not None and top_n > 0 else None),
+ ("return_documents", True if self.return_documents else None),
+ (
+ "instruct",
+ self.instruct
+ if self.instruct and self.model == "qwen3-rerank"
+ else None,
+ ),
+ ]
+ if v is not None
+ }
+
+ if params:
+ base["parameters"] = params
+
+ return base
+
+ def _parse_results(self, data: dict) -> list[RerankResult]:
+ """解析API响应结果
+
+ Args:
+ data: API响应数据
+
+ Returns:
+ 重排序结果列表
+
+ Raises:
+ BailianAPIError: API返回错误
+ KeyError: 结果缺少必要字段
+ """
+ # 检查响应状态
+ if data.get("code", "200") != "200":
+ raise BailianAPIError(
+ f"百炼 API 错误: {data.get('code')} – {data.get('message', '')}"
+ )
+
+ results = data.get("output", {}).get("results", [])
+ if not results:
+ logger.warning(f"百炼 Rerank 返回空结果: {data}")
+ return []
+
+ # 转换为RerankResult对象,使用.get()避免KeyError
+ rerank_results = []
+ for idx, result in enumerate(results):
+ try:
+ index = result.get("index", idx)
+ relevance_score = result.get("relevance_score", 0.0)
+
+ if relevance_score is None:
+ logger.warning(f"结果 {idx} 缺少 relevance_score,使用默认值 0.0")
+ relevance_score = 0.0
+
+ rerank_result = RerankResult(
+ index=index, relevance_score=relevance_score
+ )
+ rerank_results.append(rerank_result)
+ except Exception as e:
+ logger.warning(f"解析结果 {idx} 时出错: {e}, result={result}")
+ continue
+
+ return rerank_results
+
+ def _log_usage(self, data: dict) -> None:
+ """记录使用量信息
+
+ Args:
+ data: API响应数据
+ """
+ tokens = data.get("usage", {}).get("total_tokens", 0)
+ if tokens > 0:
+ logger.debug(f"百炼 Rerank 消耗 Token: {tokens}")
+
+ async def rerank(
+ self,
+ query: str,
+ documents: list[str],
+ top_n: int | None = None,
+ ) -> list[RerankResult]:
+ """
+ 对文档进行重排序
+
+ Args:
+ query: 查询文本
+ documents: 待排序的文档列表
+ top_n: 返回前N个结果,如果为None则使用配置中的默认值
+
+ Returns:
+ 重排序结果列表
+ """
+ if not self.client:
+ logger.error("百炼 Rerank 客户端会话已关闭,返回空结果")
+ return []
+
+ if not documents:
+ logger.warning("文档列表为空,返回空结果")
+ return []
+
+ if not query.strip():
+ logger.warning("查询文本为空,返回空结果")
+ return []
+
+ # 检查限制
+ if len(documents) > 500:
+ logger.warning(
+ f"文档数量({len(documents)})超过限制(500),将截断前500个文档"
+ )
+ documents = documents[:500]
+
+ try:
+ # 构建请求载荷,如果top_n为None则返回所有重排序结果
+ payload = self._build_payload(query, documents, top_n)
+
+ logger.debug(
+ f"百炼 Rerank 请求: query='{query[:50]}...', 文档数量={len(documents)}"
+ )
+
+ # 发送请求
+ async with self.client.post(self.base_url, json=payload) as response:
+ response.raise_for_status()
+ response_data = await response.json()
+
+ # 解析结果并记录使用量
+ results = self._parse_results(response_data)
+ self._log_usage(response_data)
+
+ logger.debug(f"百炼 Rerank 成功返回 {len(results)} 个结果")
+
+ return results
+
+ except aiohttp.ClientError as e:
+ error_msg = f"网络请求失败: {e}"
+ logger.error(f"百炼 Rerank 网络请求失败: {e}")
+ raise BailianNetworkError(error_msg) from e
+ except BailianRerankError:
+ raise
+ except Exception as e:
+ error_msg = f"重排序失败: {e}"
+ logger.error(f"百炼 Rerank 处理失败: {e}")
+ raise BailianRerankError(error_msg) from e
+
+ async def terminate(self) -> None:
+ """关闭HTTP客户端会话."""
+ if self.client:
+ logger.info("关闭 百炼 Rerank 客户端会话")
+ try:
+ await self.client.close()
+ except Exception as e:
+ logger.error(f"关闭 百炼 Rerank 客户端时出错: {e}")
+ finally:
+ self.client = None
diff --git a/astrbot/core/provider/sources/coze_source.py b/astrbot/core/provider/sources/coze_source.py
deleted file mode 100644
index 23a8b3b76..000000000
--- a/astrbot/core/provider/sources/coze_source.py
+++ /dev/null
@@ -1,652 +0,0 @@
-import base64
-import hashlib
-import json
-import os
-from collections.abc import AsyncGenerator
-
-import astrbot.core.message.components as Comp
-from astrbot import logger
-from astrbot.api.provider import Provider
-from astrbot.core.message.message_event_result import MessageChain
-from astrbot.core.provider.entities import LLMResponse
-
-from ..register import register_provider_adapter
-from .coze_api_client import CozeAPIClient
-
-
-@register_provider_adapter("coze", "Coze (扣子) 智能体适配器")
-class ProviderCoze(Provider):
- def __init__(
- self,
- provider_config,
- provider_settings,
- default_persona=None,
- ) -> None:
- super().__init__(
- provider_config,
- provider_settings,
- default_persona,
- )
- self.api_key = provider_config.get("coze_api_key", "")
- if not self.api_key:
- raise Exception("Coze API Key 不能为空。")
- self.bot_id = provider_config.get("bot_id", "")
- if not self.bot_id:
- raise Exception("Coze Bot ID 不能为空。")
- self.api_base: str = provider_config.get("coze_api_base", "https://api.coze.cn")
-
- if not isinstance(self.api_base, str) or not self.api_base.startswith(
- ("http://", "https://"),
- ):
- raise Exception(
- "Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头。",
- )
-
- self.timeout = provider_config.get("timeout", 120)
- if isinstance(self.timeout, str):
- self.timeout = int(self.timeout)
- self.auto_save_history = provider_config.get("auto_save_history", True)
- self.conversation_ids: dict[str, str] = {}
- self.file_id_cache: dict[str, dict[str, str]] = {}
-
- # 创建 API 客户端
- self.api_client = CozeAPIClient(api_key=self.api_key, api_base=self.api_base)
-
- def _generate_cache_key(self, data: str, is_base64: bool = False) -> str:
- """生成统一的缓存键
-
- Args:
- data: 图片数据或路径
- is_base64: 是否是 base64 数据
-
- Returns:
- str: 缓存键
-
- """
- try:
- if is_base64 and data.startswith("data:image/"):
- try:
- header, encoded = data.split(",", 1)
- image_bytes = base64.b64decode(encoded)
- cache_key = hashlib.md5(image_bytes).hexdigest()
- return cache_key
- except Exception:
- cache_key = hashlib.md5(encoded.encode("utf-8")).hexdigest()
- return cache_key
- elif data.startswith(("http://", "https://")):
- # URL图片,使用URL作为缓存键
- cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
- return cache_key
- else:
- clean_path = (
- data.split("_")[0]
- if "_" in data and len(data.split("_")) >= 3
- else data
- )
-
- if os.path.exists(clean_path):
- with open(clean_path, "rb") as f:
- file_content = f.read()
- cache_key = hashlib.md5(file_content).hexdigest()
- return cache_key
- cache_key = hashlib.md5(clean_path.encode("utf-8")).hexdigest()
- return cache_key
-
- except Exception as e:
- cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
- logger.debug(f"[Coze] 异常文件缓存键: {cache_key}, error={e}")
- return cache_key
-
- async def _upload_file(
- self,
- file_data: bytes,
- session_id: str | None = None,
- cache_key: str | None = None,
- ) -> str:
- """上传文件到 Coze 并返回 file_id"""
- # 使用 API 客户端上传文件
- file_id = await self.api_client.upload_file(file_data)
-
- # 缓存 file_id
- if session_id and cache_key:
- if session_id not in self.file_id_cache:
- self.file_id_cache[session_id] = {}
- self.file_id_cache[session_id][cache_key] = file_id
- logger.debug(f"[Coze] 图片上传成功并缓存,file_id: {file_id}")
-
- return file_id
-
- async def _download_and_upload_image(
- self,
- image_url: str,
- session_id: str | None = None,
- ) -> str:
- """下载图片并上传到 Coze,返回 file_id"""
- # 计算哈希实现缓存
- cache_key = self._generate_cache_key(image_url) if session_id else None
-
- if session_id and cache_key:
- if session_id not in self.file_id_cache:
- self.file_id_cache[session_id] = {}
-
- if cache_key in self.file_id_cache[session_id]:
- file_id = self.file_id_cache[session_id][cache_key]
- return file_id
-
- try:
- image_data = await self.api_client.download_image(image_url)
-
- file_id = await self._upload_file(image_data, session_id, cache_key)
-
- if session_id and cache_key:
- self.file_id_cache[session_id][cache_key] = file_id
-
- return file_id
-
- except Exception as e:
- logger.error(f"处理图片失败 {image_url}: {e!s}")
- raise Exception(f"处理图片失败: {e!s}")
-
- async def _process_context_images(
- self,
- content: str | list,
- session_id: str,
- ) -> str:
- """处理上下文中的图片内容,将 base64 图片上传并替换为 file_id"""
- try:
- if isinstance(content, str):
- return content
-
- processed_content = []
- if session_id not in self.file_id_cache:
- self.file_id_cache[session_id] = {}
-
- for item in content:
- if not isinstance(item, dict):
- processed_content.append(item)
- continue
- if item.get("type") == "text":
- processed_content.append(item)
- elif item.get("type") == "image_url":
- # 处理图片逻辑
- if "file_id" in item:
- # 已经有 file_id
- logger.debug(f"[Coze] 图片已有file_id: {item['file_id']}")
- processed_content.append(item)
- else:
- # 获取图片数据
- image_data = ""
- if "image_url" in item and isinstance(item["image_url"], dict):
- image_data = item["image_url"].get("url", "")
- elif "data" in item:
- image_data = item.get("data", "")
- elif "url" in item:
- image_data = item.get("url", "")
-
- if not image_data:
- continue
- # 计算哈希用于缓存
- cache_key = self._generate_cache_key(
- image_data,
- is_base64=image_data.startswith("data:image/"),
- )
-
- # 检查缓存
- if cache_key in self.file_id_cache[session_id]:
- file_id = self.file_id_cache[session_id][cache_key]
- processed_content.append(
- {"type": "image", "file_id": file_id},
- )
- else:
- # 上传图片并缓存
- if image_data.startswith("data:image/"):
- # base64 处理
- _, encoded = image_data.split(",", 1)
- image_bytes = base64.b64decode(encoded)
- file_id = await self._upload_file(
- image_bytes,
- session_id,
- cache_key,
- )
- elif image_data.startswith(("http://", "https://")):
- # URL 图片
- file_id = await self._download_and_upload_image(
- image_data,
- session_id,
- )
- # 为URL图片也添加缓存
- self.file_id_cache[session_id][cache_key] = file_id
- elif os.path.exists(image_data):
- # 本地文件
- with open(image_data, "rb") as f:
- image_bytes = f.read()
- file_id = await self._upload_file(
- image_bytes,
- session_id,
- cache_key,
- )
- else:
- logger.warning(
- f"无法处理的图片格式: {image_data[:50]}...",
- )
- continue
-
- processed_content.append(
- {"type": "image", "file_id": file_id},
- )
-
- result = json.dumps(processed_content, ensure_ascii=False)
- return result
- except Exception as e:
- logger.error(f"处理上下文图片失败: {e!s}")
- if isinstance(content, str):
- return content
- return json.dumps(content, ensure_ascii=False)
-
- async def text_chat(
- self,
- prompt: str,
- session_id=None,
- image_urls=None,
- func_tool=None,
- contexts=None,
- system_prompt=None,
- tool_calls_result=None,
- model=None,
- **kwargs,
- ) -> LLMResponse:
- """文本对话, 内部使用流式接口实现非流式
-
- Args:
- prompt (str): 用户提示词
- session_id (str): 会话ID
- image_urls (List[str]): 图片URL列表
- func_tool (FuncCall): 函数调用工具(不支持)
- contexts (List): 上下文列表
- system_prompt (str): 系统提示语
- tool_calls_result (ToolCallsResult | List[ToolCallsResult]): 工具调用结果(不支持)
- model (str): 模型名称(不支持)
-
- Returns:
- LLMResponse: LLM响应对象
-
- """
- accumulated_content = ""
- final_response = None
-
- async for llm_response in self.text_chat_stream(
- prompt=prompt,
- session_id=session_id,
- image_urls=image_urls,
- func_tool=func_tool,
- contexts=contexts,
- system_prompt=system_prompt,
- tool_calls_result=tool_calls_result,
- model=model,
- **kwargs,
- ):
- if llm_response.is_chunk:
- if llm_response.completion_text:
- accumulated_content += llm_response.completion_text
- else:
- final_response = llm_response
-
- if final_response:
- return final_response
-
- if accumulated_content:
- chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
- return LLMResponse(role="assistant", result_chain=chain)
- return LLMResponse(role="assistant", completion_text="")
-
- async def text_chat_stream(
- self,
- prompt: str,
- session_id=None,
- image_urls=None,
- func_tool=None,
- contexts=None,
- system_prompt=None,
- tool_calls_result=None,
- model=None,
- **kwargs,
- ) -> AsyncGenerator[LLMResponse, None]:
- """流式对话接口"""
- # 用户ID参数(参考文档, 可以自定义)
- user_id = session_id or kwargs.get("user", "default_user")
-
- # 获取或创建会话ID
- conversation_id = self.conversation_ids.get(user_id)
-
- # 构建消息
- additional_messages = []
-
- if system_prompt:
- if not self.auto_save_history or not conversation_id:
- additional_messages.append(
- {
- "role": "system",
- "content": system_prompt,
- "content_type": "text",
- },
- )
-
- contexts = self._ensure_message_to_dicts(contexts)
- if not self.auto_save_history and contexts:
- # 如果关闭了自动保存历史,传入上下文
- for ctx in contexts:
- if isinstance(ctx, dict) and "role" in ctx and "content" in ctx:
- content = ctx["content"]
- content_type = ctx.get("content_type", "text")
-
- # 处理可能包含图片的上下文
- if (
- content_type == "object_string"
- or (isinstance(content, str) and content.startswith("["))
- or (
- isinstance(content, list)
- and any(
- isinstance(item, dict)
- and item.get("type") == "image_url"
- for item in content
- )
- )
- ):
- processed_content = await self._process_context_images(
- content,
- user_id,
- )
- additional_messages.append(
- {
- "role": ctx["role"],
- "content": processed_content,
- "content_type": "object_string",
- },
- )
- else:
- # 纯文本
- additional_messages.append(
- {
- "role": ctx["role"],
- "content": (
- content
- if isinstance(content, str)
- else json.dumps(content, ensure_ascii=False)
- ),
- "content_type": "text",
- },
- )
- else:
- logger.info(f"[Coze] 跳过格式不正确的上下文: {ctx}")
-
- if prompt or image_urls:
- if image_urls:
- # 多模态
- object_string_content = []
- if prompt:
- object_string_content.append({"type": "text", "text": prompt})
-
- for url in image_urls:
- try:
- if url.startswith(("http://", "https://")):
- # 网络图片
- file_id = await self._download_and_upload_image(
- url,
- user_id,
- )
- else:
- # 本地文件或 base64
- if url.startswith("data:image/"):
- # base64
- _, encoded = url.split(",", 1)
- image_data = base64.b64decode(encoded)
- cache_key = self._generate_cache_key(
- url,
- is_base64=True,
- )
- file_id = await self._upload_file(
- image_data,
- user_id,
- cache_key,
- )
- # 本地文件
- elif os.path.exists(url):
- with open(url, "rb") as f:
- image_data = f.read()
- # 用文件路径和修改时间来缓存
- file_stat = os.stat(url)
- cache_key = self._generate_cache_key(
- f"{url}_{file_stat.st_mtime}_{file_stat.st_size}",
- is_base64=False,
- )
- file_id = await self._upload_file(
- image_data,
- user_id,
- cache_key,
- )
- else:
- logger.warning(f"图片文件不存在: {url}")
- continue
-
- object_string_content.append(
- {
- "type": "image",
- "file_id": file_id,
- },
- )
- except Exception as e:
- logger.error(f"处理图片失败 {url}: {e!s}")
- continue
-
- if object_string_content:
- content = json.dumps(object_string_content, ensure_ascii=False)
- additional_messages.append(
- {
- "role": "user",
- "content": content,
- "content_type": "object_string",
- },
- )
- # 纯文本
- elif prompt:
- additional_messages.append(
- {
- "role": "user",
- "content": prompt,
- "content_type": "text",
- },
- )
-
- try:
- accumulated_content = ""
- message_started = False
-
- async for chunk in self.api_client.chat_messages(
- bot_id=self.bot_id,
- user_id=user_id,
- additional_messages=additional_messages,
- conversation_id=conversation_id,
- auto_save_history=self.auto_save_history,
- stream=True,
- timeout=self.timeout,
- ):
- event_type = chunk.get("event")
- data = chunk.get("data", {})
-
- if event_type == "conversation.chat.created":
- if isinstance(data, dict) and "conversation_id" in data:
- self.conversation_ids[user_id] = data["conversation_id"]
-
- elif event_type == "conversation.message.delta":
- if isinstance(data, dict):
- content = data.get("content", "")
- if not content and "delta" in data:
- content = data["delta"].get("content", "")
- if not content and "text" in data:
- content = data.get("text", "")
-
- if content:
- message_started = True
- accumulated_content += content
- yield LLMResponse(
- role="assistant",
- completion_text=content,
- is_chunk=True,
- )
-
- elif event_type == "conversation.message.completed":
- if isinstance(data, dict):
- msg_type = data.get("type")
- if msg_type == "answer" and data.get("role") == "assistant":
- final_content = data.get("content", "")
- if not accumulated_content and final_content:
- chain = MessageChain(chain=[Comp.Plain(final_content)])
- yield LLMResponse(
- role="assistant",
- result_chain=chain,
- is_chunk=False,
- )
-
- elif event_type == "conversation.chat.completed":
- if accumulated_content:
- chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
- yield LLMResponse(
- role="assistant",
- result_chain=chain,
- is_chunk=False,
- )
- break
-
- elif event_type == "done":
- break
-
- elif event_type == "error":
- error_msg = (
- data.get("message", "未知错误")
- if isinstance(data, dict)
- else str(data)
- )
- logger.error(f"Coze 流式响应错误: {error_msg}")
- yield LLMResponse(
- role="err",
- completion_text=f"Coze 错误: {error_msg}",
- is_chunk=False,
- )
- break
-
- if not message_started and not accumulated_content:
- yield LLMResponse(
- role="assistant",
- completion_text="LLM 未响应任何内容。",
- is_chunk=False,
- )
- elif message_started and accumulated_content:
- chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
- yield LLMResponse(
- role="assistant",
- result_chain=chain,
- is_chunk=False,
- )
-
- except Exception as e:
- logger.error(f"Coze 流式请求失败: {e!s}")
- yield LLMResponse(
- role="err",
- completion_text=f"Coze 流式请求失败: {e!s}",
- is_chunk=False,
- )
-
- async def forget(self, session_id: str):
- """清空指定会话的上下文"""
- user_id = session_id
- conversation_id = self.conversation_ids.get(user_id)
-
- if user_id in self.file_id_cache:
- self.file_id_cache.pop(user_id, None)
-
- if not conversation_id:
- return True
-
- try:
- response = await self.api_client.clear_context(conversation_id)
-
- if "code" in response and response["code"] == 0:
- self.conversation_ids.pop(user_id, None)
- return True
- logger.warning(f"清空 Coze 会话上下文失败: {response}")
- return False
-
- except Exception as e:
- logger.error(f"清空 Coze 会话失败: {e!s}")
- return False
-
- async def get_current_key(self):
- """获取当前API Key"""
- return self.api_key
-
- async def set_key(self, key: str):
- """设置新的API Key"""
- raise NotImplementedError("Coze 适配器不支持设置 API Key。")
-
- async def get_models(self):
- """获取可用模型列表"""
- return [f"bot_{self.bot_id}"]
-
- def get_model(self):
- """获取当前模型"""
- return f"bot_{self.bot_id}"
-
- def set_model(self, model: str):
- """设置模型(在Coze中是Bot ID)"""
- if model.startswith("bot_"):
- self.bot_id = model[4:]
- else:
- self.bot_id = model
-
- async def get_human_readable_context(
- self,
- session_id: str,
- page: int = 1,
- page_size: int = 10,
- ):
- """获取人类可读的上下文历史"""
- user_id = session_id
- conversation_id = self.conversation_ids.get(user_id)
-
- if not conversation_id:
- return []
-
- try:
- data = await self.api_client.get_message_list(
- conversation_id=conversation_id,
- order="desc",
- limit=page_size,
- offset=(page - 1) * page_size,
- )
-
- if data.get("code") != 0:
- logger.warning(f"获取 Coze 消息历史失败: {data}")
- return []
-
- messages = data.get("data", {}).get("messages", [])
-
- readable_history = []
- for msg in messages:
- role = msg.get("role", "unknown")
- content = msg.get("content", "")
- msg_type = msg.get("type", "")
-
- if role == "user":
- readable_history.append(f"用户: {content}")
- elif role == "assistant" and msg_type == "answer":
- readable_history.append(f"助手: {content}")
-
- return readable_history
-
- except Exception as e:
- logger.error(f"获取 Coze 消息历史失败: {e!s}")
- return []
-
- async def terminate(self):
- """清理资源"""
- await self.api_client.close()
diff --git a/astrbot/core/provider/sources/dashscope_source.py b/astrbot/core/provider/sources/dashscope_source.py
deleted file mode 100644
index 9b262c001..000000000
--- a/astrbot/core/provider/sources/dashscope_source.py
+++ /dev/null
@@ -1,209 +0,0 @@
-import asyncio
-import functools
-import re
-
-from dashscope import Application
-from dashscope.app.application_response import ApplicationResponse
-
-from astrbot.core import logger, sp
-from astrbot.core.message.message_event_result import MessageChain
-
-from .. import Personality, Provider
-from ..entities import LLMResponse
-from ..register import register_provider_adapter
-from .openai_source import ProviderOpenAIOfficial
-
-
-@register_provider_adapter("dashscope", "Dashscope APP 适配器。")
-class ProviderDashscope(ProviderOpenAIOfficial):
- def __init__(
- self,
- provider_config: dict,
- provider_settings: dict,
- default_persona: Personality | None = None,
- ) -> None:
- Provider.__init__(
- self,
- provider_config,
- provider_settings,
- default_persona,
- )
- self.api_key = provider_config.get("dashscope_api_key", "")
- if not self.api_key:
- raise Exception("阿里云百炼 API Key 不能为空。")
- self.app_id = provider_config.get("dashscope_app_id", "")
- if not self.app_id:
- raise Exception("阿里云百炼 APP ID 不能为空。")
- self.dashscope_app_type = provider_config.get("dashscope_app_type", "")
- if not self.dashscope_app_type:
- raise Exception("阿里云百炼 APP 类型不能为空。")
- self.model_name = "dashscope"
- self.variables: dict = provider_config.get("variables", {})
- self.rag_options: dict = provider_config.get("rag_options", {})
- self.output_reference = self.rag_options.get("output_reference", False)
- self.rag_options = self.rag_options.copy()
- self.rag_options.pop("output_reference", None)
-
- self.timeout = provider_config.get("timeout", 120)
- if isinstance(self.timeout, str):
- self.timeout = int(self.timeout)
-
- def has_rag_options(self):
- """判断是否有 RAG 选项
-
- Returns:
- bool: 是否有 RAG 选项
-
- """
- if self.rag_options and (
- len(self.rag_options.get("pipeline_ids", [])) > 0
- or len(self.rag_options.get("file_ids", [])) > 0
- ):
- return True
- return False
-
- async def text_chat(
- self,
- prompt: str,
- session_id=None,
- image_urls=None,
- func_tool=None,
- contexts=None,
- system_prompt=None,
- model=None,
- **kwargs,
- ) -> LLMResponse:
- if image_urls is None:
- image_urls = []
- if contexts is None:
- contexts = []
- # 获得会话变量
- payload_vars = self.variables.copy()
- # 动态变量
- session_var = await sp.session_get(session_id, "session_variables", default={})
- payload_vars.update(session_var)
-
- if (
- self.dashscope_app_type in ["agent", "dialog-workflow"]
- and not self.has_rag_options()
- ):
- # 支持多轮对话的
- new_record = {"role": "user", "content": prompt}
- if image_urls:
- logger.warning("阿里云百炼暂不支持图片输入,将自动忽略图片内容。")
- contexts_no_img = await self._remove_image_from_context(contexts)
- context_query = [*contexts_no_img, new_record]
- if system_prompt:
- context_query.insert(0, {"role": "system", "content": system_prompt})
- for part in context_query:
- if "_no_save" in part:
- del part["_no_save"]
- # 调用阿里云百炼 API
- payload = {
- "app_id": self.app_id,
- "api_key": self.api_key,
- "messages": context_query,
- "biz_params": payload_vars or None,
- }
- partial = functools.partial(
- Application.call,
- **payload,
- )
- response = await asyncio.get_event_loop().run_in_executor(None, partial)
- else:
- # 不支持多轮对话的
- # 调用阿里云百炼 API
- payload = {
- "app_id": self.app_id,
- "prompt": prompt,
- "api_key": self.api_key,
- "biz_params": payload_vars or None,
- }
- if self.rag_options:
- payload["rag_options"] = self.rag_options
- partial = functools.partial(
- Application.call,
- **payload,
- )
- response = await asyncio.get_event_loop().run_in_executor(None, partial)
-
- assert isinstance(response, ApplicationResponse)
-
- logger.debug(f"dashscope resp: {response}")
-
- if response.status_code != 200:
- logger.error(
- f"阿里云百炼请求失败: request_id={response.request_id}, code={response.status_code}, message={response.message}, 请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code",
- )
- return LLMResponse(
- role="err",
- result_chain=MessageChain().message(
- f"阿里云百炼请求失败: message={response.message} code={response.status_code}",
- ),
- )
-
- output_text = response.output.get("text", "") or ""
- # RAG 引用脚标格式化
- output_text = re.sub(r"[\[(\d+)\]]", r"[\1]", output_text)
- if self.output_reference and response.output.get("doc_references", None):
- ref_parts = []
- for ref in response.output.get("doc_references", []) or []:
- ref_title = (
- ref.get("title", "")
- if ref.get("title")
- else ref.get("doc_name", "")
- )
- ref_parts.append(f"{ref['index_id']}. {ref_title}\n")
- ref_str = "".join(ref_parts)
- output_text += f"\n\n回答来源:\n{ref_str}"
-
- llm_response = LLMResponse("assistant")
- llm_response.result_chain = MessageChain().message(output_text)
-
- return llm_response
-
- async def text_chat_stream(
- self,
- prompt,
- session_id=None,
- image_urls=...,
- func_tool=None,
- contexts=...,
- system_prompt=None,
- tool_calls_result=None,
- model=None,
- **kwargs,
- ):
- # raise NotImplementedError("This method is not implemented yet.")
- # 调用 text_chat 模拟流式
- llm_response = await self.text_chat(
- prompt=prompt,
- session_id=session_id,
- image_urls=image_urls,
- func_tool=func_tool,
- contexts=contexts,
- system_prompt=system_prompt,
- tool_calls_result=tool_calls_result,
- )
- llm_response.is_chunk = True
- yield llm_response
- llm_response.is_chunk = False
- yield llm_response
-
- async def forget(self, session_id):
- return True
-
- async def get_current_key(self):
- return self.api_key
-
- async def set_key(self, key):
- raise Exception("阿里云百炼 适配器不支持设置 API Key。")
-
- async def get_models(self):
- return [self.get_model()]
-
- async def get_human_readable_context(self, session_id, page, page_size):
- raise Exception("暂不支持获得 阿里云百炼 的历史消息记录。")
-
- async def terminate(self):
- pass
diff --git a/astrbot/core/provider/sources/dashscope_tts.py b/astrbot/core/provider/sources/dashscope_tts.py
index 44e9965cc..50bc421fd 100644
--- a/astrbot/core/provider/sources/dashscope_tts.py
+++ b/astrbot/core/provider/sources/dashscope_tts.py
@@ -36,7 +36,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
super().__init__(provider_config, provider_settings)
self.chosen_api_key: str = provider_config.get("api_key", "")
self.voice: str = provider_config.get("dashscope_tts_voice", "loongstella")
- self.set_model(provider_config.get("model"))
+ self.set_model(provider_config["model"])
self.timeout_ms = float(provider_config.get("timeout", 20)) * 1000
dashscope.api_key = self.chosen_api_key
@@ -71,9 +71,10 @@ class ProviderDashscopeTTSAPI(TTSProvider):
kwargs = {
"model": model,
- "text": text,
+ "messages": None,
"api_key": self.chosen_api_key,
"voice": self.voice or "Cherry",
+ "text": text,
}
if not self.voice:
logging.warning(
diff --git a/astrbot/core/provider/sources/dify_source.py b/astrbot/core/provider/sources/dify_source.py
deleted file mode 100644
index 9f9f146aa..000000000
--- a/astrbot/core/provider/sources/dify_source.py
+++ /dev/null
@@ -1,287 +0,0 @@
-import os
-
-import astrbot.core.message.components as Comp
-from astrbot.core import logger, sp
-from astrbot.core.message.message_event_result import MessageChain
-from astrbot.core.utils.astrbot_path import get_astrbot_data_path
-from astrbot.core.utils.dify_api_client import DifyAPIClient
-from astrbot.core.utils.io import download_file, download_image_by_url
-
-from .. import Provider
-from ..entities import LLMResponse
-from ..register import register_provider_adapter
-
-
-@register_provider_adapter("dify", "Dify APP 适配器。")
-class ProviderDify(Provider):
- def __init__(
- self,
- provider_config,
- provider_settings,
- default_persona=None,
- ) -> None:
- super().__init__(
- provider_config,
- provider_settings,
- default_persona,
- )
- self.api_key = provider_config.get("dify_api_key", "")
- if not self.api_key:
- raise Exception("Dify API Key 不能为空。")
- api_base = provider_config.get("dify_api_base", "https://api.dify.ai/v1")
- self.api_type = provider_config.get("dify_api_type", "")
- if not self.api_type:
- raise Exception("Dify API 类型不能为空。")
- self.model_name = "dify"
- self.workflow_output_key = provider_config.get(
- "dify_workflow_output_key",
- "astrbot_wf_output",
- )
- self.dify_query_input_key = provider_config.get(
- "dify_query_input_key",
- "astrbot_text_query",
- )
- if not self.dify_query_input_key:
- self.dify_query_input_key = "astrbot_text_query"
- if not self.workflow_output_key:
- self.workflow_output_key = "astrbot_wf_output"
- self.variables: dict = provider_config.get("variables", {})
- self.timeout = provider_config.get("timeout", 120)
- if isinstance(self.timeout, str):
- self.timeout = int(self.timeout)
- self.conversation_ids = {}
- """记录当前 session id 的对话 ID"""
-
- self.api_client = DifyAPIClient(self.api_key, api_base)
-
- async def text_chat(
- self,
- prompt: str,
- session_id=None,
- image_urls=None,
- func_tool=None,
- contexts=None,
- system_prompt=None,
- tool_calls_result=None,
- model=None,
- **kwargs,
- ) -> LLMResponse:
- if image_urls is None:
- image_urls = []
- result = ""
- session_id = session_id or kwargs.get("user") or "unknown" # 1734
- conversation_id = self.conversation_ids.get(session_id, "")
-
- files_payload = []
- for image_url in image_urls:
- image_path = (
- await download_image_by_url(image_url)
- if image_url.startswith("http")
- else image_url
- )
- file_response = await self.api_client.file_upload(
- image_path,
- user=session_id,
- )
- logger.debug(f"Dify 上传图片响应:{file_response}")
- if "id" not in file_response:
- logger.warning(
- f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。",
- )
- continue
- files_payload.append(
- {
- "type": "image",
- "transfer_method": "local_file",
- "upload_file_id": file_response["id"],
- },
- )
-
- # 获得会话变量
- payload_vars = self.variables.copy()
- # 动态变量
- session_var = await sp.session_get(session_id, "session_variables", default={})
- payload_vars.update(session_var)
- payload_vars["system_prompt"] = system_prompt
-
- try:
- match self.api_type:
- case "chat" | "agent" | "chatflow":
- if not prompt:
- prompt = "请描述这张图片。"
-
- async for chunk in self.api_client.chat_messages(
- inputs={
- **payload_vars,
- },
- query=prompt,
- user=session_id,
- conversation_id=conversation_id,
- files=files_payload,
- timeout=self.timeout,
- ):
- logger.debug(f"dify resp chunk: {chunk}")
- if (
- chunk["event"] == "message"
- or chunk["event"] == "agent_message"
- ):
- result += chunk["answer"]
- if not conversation_id:
- self.conversation_ids[session_id] = chunk[
- "conversation_id"
- ]
- conversation_id = chunk["conversation_id"]
- elif chunk["event"] == "message_end":
- logger.debug("Dify message end")
- break
- elif chunk["event"] == "error":
- logger.error(f"Dify 出现错误:{chunk}")
- raise Exception(
- f"Dify 出现错误 status: {chunk['status']} message: {chunk['message']}",
- )
-
- case "workflow":
- async for chunk in self.api_client.workflow_run(
- inputs={
- self.dify_query_input_key: prompt,
- "astrbot_session_id": session_id,
- **payload_vars,
- },
- user=session_id,
- files=files_payload,
- timeout=self.timeout,
- ):
- match chunk["event"]:
- case "workflow_started":
- logger.info(
- f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行。",
- )
- case "node_finished":
- logger.debug(
- f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束。",
- )
- case "workflow_finished":
- logger.info(
- f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束",
- )
- logger.debug(f"Dify 工作流结果:{chunk}")
- if chunk["data"]["error"]:
- logger.error(
- f"Dify 工作流出现错误:{chunk['data']['error']}",
- )
- raise Exception(
- f"Dify 工作流出现错误:{chunk['data']['error']}",
- )
- if (
- self.workflow_output_key
- not in chunk["data"]["outputs"]
- ):
- raise Exception(
- f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}",
- )
- result = chunk
- case _:
- raise Exception(f"未知的 Dify API 类型:{self.api_type}")
- except Exception as e:
- logger.error(f"Dify 请求失败:{e!s}")
- return LLMResponse(role="err", completion_text=f"Dify 请求失败:{e!s}")
-
- if not result:
- logger.warning("Dify 请求结果为空,请查看 Debug 日志。")
-
- chain = await self.parse_dify_result(result)
-
- return LLMResponse(role="assistant", result_chain=chain)
-
- async def text_chat_stream(
- self,
- prompt,
- session_id=None,
- image_urls=...,
- func_tool=None,
- contexts=...,
- system_prompt=None,
- tool_calls_result=None,
- model=None,
- **kwargs,
- ):
- # raise NotImplementedError("This method is not implemented yet.")
- # 调用 text_chat 模拟流式
- llm_response = await self.text_chat(
- prompt=prompt,
- session_id=session_id,
- image_urls=image_urls,
- func_tool=func_tool,
- contexts=contexts,
- system_prompt=system_prompt,
- tool_calls_result=tool_calls_result,
- )
- llm_response.is_chunk = True
- yield llm_response
- llm_response.is_chunk = False
- yield llm_response
-
- async def parse_dify_result(self, chunk: dict | str) -> MessageChain:
- if isinstance(chunk, str):
- # Chat
- return MessageChain(chain=[Comp.Plain(chunk)])
-
- async def parse_file(item: dict):
- match item["type"]:
- case "image":
- return Comp.Image(file=item["url"], url=item["url"])
- case "audio":
- # 仅支持 wav
- temp_dir = os.path.join(get_astrbot_data_path(), "temp")
- path = os.path.join(temp_dir, f"{item['filename']}.wav")
- await download_file(item["url"], path)
- return Comp.Image(file=item["url"], url=item["url"])
- case "video":
- return Comp.Video(file=item["url"])
- case _:
- return Comp.File(name=item["filename"], file=item["url"])
-
- output = chunk["data"]["outputs"][self.workflow_output_key]
- chains = []
- if isinstance(output, str):
- # 纯文本输出
- chains.append(Comp.Plain(output))
- elif isinstance(output, list):
- # 主要适配 Dify 的 HTTP 请求结点的多模态输出
- for item in output:
- # handle Array[File]
- if (
- not isinstance(item, dict)
- or item.get("dify_model_identity", "") != "__dify__file__"
- ):
- chains.append(Comp.Plain(str(output)))
- break
- else:
- chains.append(Comp.Plain(str(output)))
-
- # scan file
- files = chunk["data"].get("files", [])
- for item in files:
- comp = await parse_file(item)
- chains.append(comp)
-
- return MessageChain(chain=chains)
-
- async def forget(self, session_id):
- self.conversation_ids[session_id] = ""
- return True
-
- async def get_current_key(self):
- return self.api_key
-
- async def set_key(self, key):
- raise Exception("Dify 适配器不支持设置 API Key。")
-
- async def get_models(self):
- return [self.get_model()]
-
- async def get_human_readable_context(self, session_id, page, page_size):
- raise Exception("暂不支持获得 Dify 的历史消息记录。")
-
- async def terminate(self):
- await self.api_client.close()
diff --git a/astrbot/core/provider/sources/edge_tts_source.py b/astrbot/core/provider/sources/edge_tts_source.py
index 8bbf62325..71a5a82d6 100644
--- a/astrbot/core/provider/sources/edge_tts_source.py
+++ b/astrbot/core/provider/sources/edge_tts_source.py
@@ -67,7 +67,7 @@ class ProviderEdgeTTS(TTSProvider):
from pyffmpeg import FFmpeg
ff = FFmpeg()
- ff.convert(input=mp3_path, output=wav_path)
+ ff.convert(input_file=mp3_path, output_file=wav_path)
except Exception as e:
logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
# use ffmpeg command line
diff --git a/astrbot/core/provider/sources/fishaudio_tts_api_source.py b/astrbot/core/provider/sources/fishaudio_tts_api_source.py
index ca571c3ee..8362ce1b4 100644
--- a/astrbot/core/provider/sources/fishaudio_tts_api_source.py
+++ b/astrbot/core/provider/sources/fishaudio_tts_api_source.py
@@ -59,9 +59,9 @@ class ProviderFishAudioTTSAPI(TTSProvider):
self.headers = {
"Authorization": f"Bearer {self.chosen_api_key}",
}
- self.set_model(provider_config.get("model"))
+ self.set_model(provider_config["model"])
- async def _get_reference_id_by_character(self, character: str) -> str:
+ async def _get_reference_id_by_character(self, character: str) -> str | None:
"""获取角色的reference_id
Args:
@@ -109,7 +109,7 @@ class ProviderFishAudioTTSAPI(TTSProvider):
pattern = r"^[a-fA-F0-9]{32}$"
return bool(re.match(pattern, reference_id.strip()))
- async def _generate_request(self, text: str) -> dict:
+ async def _generate_request(self, text: str) -> ServeTTSRequest:
# 向前兼容逻辑:优先使用reference_id,如果没有则使用角色名称查询
if self.reference_id and self.reference_id.strip():
# 验证reference_id格式
@@ -146,5 +146,6 @@ class ProviderFishAudioTTSAPI(TTSProvider):
async for chunk in response.aiter_bytes():
f.write(chunk)
return path
- text = await response.aread()
+ body = await response.aread()
+ text = body.decode("utf-8", errors="replace")
raise Exception(f"Fish Audio API请求失败: {text}")
diff --git a/astrbot/core/provider/sources/gemini_embedding_source.py b/astrbot/core/provider/sources/gemini_embedding_source.py
index 8d11cce5f..146b50a4e 100644
--- a/astrbot/core/provider/sources/gemini_embedding_source.py
+++ b/astrbot/core/provider/sources/gemini_embedding_source.py
@@ -1,3 +1,5 @@
+from typing import cast
+
from google import genai
from google.genai import types
from google.genai.errors import APIError
@@ -18,8 +20,8 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
self.provider_config = provider_config
self.provider_settings = provider_settings
- api_key: str = provider_config.get("embedding_api_key")
- api_base: str = provider_config.get("embedding_api_base")
+ api_key: str = provider_config["embedding_api_key"]
+ api_base: str = provider_config["embedding_api_base"]
timeout: int = int(provider_config.get("timeout", 20))
http_options = types.HttpOptions(timeout=timeout * 1000)
@@ -41,18 +43,26 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
model=self.model,
contents=text,
)
+ assert result.embeddings is not None
+ assert result.embeddings[0].values is not None
return result.embeddings[0].values
except APIError as e:
raise Exception(f"Gemini Embedding API请求失败: {e.message}")
- async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
+ async def get_embeddings(self, text: list[str]) -> list[list[float]]:
"""批量获取文本的嵌入"""
try:
result = await self.client.models.embed_content(
model=self.model,
- contents=texts,
+ contents=cast(types.ContentListUnion, text),
)
- return [embedding.values for embedding in result.embeddings]
+ assert result.embeddings is not None
+
+ embeddings: list[list[float]] = []
+ for embedding in result.embeddings:
+ assert embedding.values is not None
+ embeddings.append(embedding.values)
+ return embeddings
except APIError as e:
raise Exception(f"Gemini Embedding API批量请求失败: {e.message}")
diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py
index c3c9253a5..5a56170a5 100644
--- a/astrbot/core/provider/sources/gemini_source.py
+++ b/astrbot/core/provider/sources/gemini_source.py
@@ -4,6 +4,7 @@ import json
import logging
import random
from collections.abc import AsyncGenerator
+from typing import cast
from google import genai
from google.genai import types
@@ -13,7 +14,7 @@ import astrbot.core.message.components as Comp
from astrbot import logger
from astrbot.api.provider import Provider
from astrbot.core.message.message_event_result import MessageChain
-from astrbot.core.provider.entities import LLMResponse
+from astrbot.core.provider.entities import LLMResponse, TokenUsage
from astrbot.core.provider.func_tool_manager import ToolSet
from astrbot.core.utils.io import download_image_by_url
@@ -53,12 +54,10 @@ class ProviderGoogleGenAI(Provider):
self,
provider_config,
provider_settings,
- default_persona=None,
) -> None:
super().__init__(
provider_config,
provider_settings,
- default_persona,
)
self.api_keys: list = super().get_keys()
self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else ""
@@ -113,9 +112,9 @@ class ProviderGoogleGenAI(Provider):
f"检测到 Key 异常({e.message}),且已没有可用的 Key。 当前 Key: {self.chosen_api_key[:12]}...",
)
raise Exception("达到了 Gemini 速率限制, 请稍后再试...")
- logger.error(
- f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}",
- )
+ # logger.error(
+ # f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}",
+ # )
raise e
async def _prepare_query_config(
@@ -128,18 +127,18 @@ class ProviderGoogleGenAI(Provider):
) -> types.GenerateContentConfig:
"""准备查询配置"""
if not modalities:
- modalities = ["Text"]
+ modalities = ["TEXT"]
# 流式输出不支持图片模态
if (
self.provider_settings.get("streaming_response", False)
- and "Image" in modalities
+ and "IMAGE" in modalities
):
logger.warning("流式输出不支持图片模态,已自动降级为文本模态")
- modalities = ["Text"]
+ modalities = ["TEXT"]
- tool_list = []
- model_name = self.get_model()
+ tool_list: list[types.Tool] | None = []
+ model_name = payloads.get("model", self.get_model())
native_coderunner = self.provider_config.get("gm_native_coderunner", False)
native_search = self.provider_config.get("gm_native_search", False)
url_context = self.provider_config.get("gm_url_context", False)
@@ -198,6 +197,37 @@ class ProviderGoogleGenAI(Provider):
types.Tool(function_declarations=func_desc["function_declarations"]),
]
+ # oper thinking config
+ thinking_config = None
+ if model_name.startswith("gemini-2.5"):
+ # The thinkingBudget parameter, introduced with the Gemini 2.5 series
+ thinking_budget = self.provider_config.get("gm_thinking_config", {}).get(
+ "budget", 0
+ )
+ if thinking_budget is not None:
+ thinking_config = types.ThinkingConfig(
+ thinking_budget=thinking_budget,
+ )
+ elif model_name.startswith("gemini-3"):
+ # The thinkingLevel parameter, recommended for Gemini 3 models and onwards
+ # Gemini 2.5 series models don't support thinkingLevel; use thinkingBudget instead.
+ thinking_level = self.provider_config.get("gm_thinking_config", {}).get(
+ "level", "HIGH"
+ )
+ if thinking_level and isinstance(thinking_level, str):
+ thinking_level = thinking_level.upper()
+ if thinking_level not in ["MINIMAL", "LOW", "MEDIUM", "HIGH"]:
+ logger.warning(
+ f"Invalid thinking level: {thinking_level}, using HIGH"
+ )
+ thinking_level = "HIGH"
+ level = types.ThinkingLevel(thinking_level)
+ thinking_config = types.ThinkingConfig()
+ if not hasattr(types.ThinkingConfig, "thinking_level"):
+ setattr(types.ThinkingConfig, "thinking_level", level)
+ else:
+ thinking_config.thinking_level = level
+
return types.GenerateContentConfig(
system_instruction=system_instruction,
temperature=temperature,
@@ -215,24 +245,9 @@ class ProviderGoogleGenAI(Provider):
logprobs=payloads.get("logprobs"),
seed=payloads.get("seed"),
response_modalities=modalities,
- tools=tool_list,
+ tools=cast(types.ToolListUnion | None, tool_list),
safety_settings=self.safety_settings if self.safety_settings else None,
- thinking_config=(
- types.ThinkingConfig(
- thinking_budget=min(
- int(
- self.provider_config.get("gm_thinking_config", {}).get(
- "budget",
- 0,
- ),
- ),
- 24576,
- ),
- )
- if "gemini-2.5-flash" in self.get_model()
- and hasattr(types.ThinkingConfig, "thinking_budget")
- else None
- ),
+ thinking_config=thinking_config,
automatic_function_calling=types.AutomaticFunctionCallingConfig(
disable=True,
),
@@ -259,6 +274,7 @@ class ProviderGoogleGenAI(Provider):
content_cls: type[types.Content],
) -> None:
if contents and isinstance(contents[-1], content_cls):
+ assert contents[-1].parts is not None
contents[-1].parts.extend(part)
else:
contents.append(content_cls(parts=part))
@@ -292,13 +308,24 @@ class ProviderGoogleGenAI(Provider):
parts = [types.Part.from_text(text=content)]
append_or_extend(gemini_contents, parts, types.ModelContent)
elif not native_tool_enabled and "tool_calls" in message:
- parts = [
- types.Part.from_function_call(
+ parts = []
+ for tool in message["tool_calls"]:
+ part = types.Part.from_function_call(
name=tool["function"]["name"],
args=json.loads(tool["function"]["arguments"]),
)
- for tool in message["tool_calls"]
- ]
+ # we should set thought_signature back to part if exists
+ # for more info about thought_signature, see:
+ # https://ai.google.dev/gemini-api/docs/thought-signatures
+ if "extra_content" in tool and tool["extra_content"]:
+ ts_bs64 = (
+ tool["extra_content"]
+ .get("google", {})
+ .get("thought_signature")
+ )
+ if ts_bs64:
+ part.thought_signature = base64.b64decode(ts_bs64)
+ parts.append(part)
append_or_extend(gemini_contents, parts, types.ModelContent)
else:
logger.warning("assistant 角色的消息内容为空,已添加空格占位")
@@ -326,8 +353,28 @@ class ProviderGoogleGenAI(Provider):
return gemini_contents
- @staticmethod
+ def _extract_reasoning_content(self, candidate: types.Candidate) -> str:
+ """Extract reasoning content from candidate parts"""
+ if not candidate.content or not candidate.content.parts:
+ return ""
+
+ thought_buf: list[str] = [
+ (p.text or "") for p in candidate.content.parts if p.thought
+ ]
+ return "".join(thought_buf).strip()
+
+ def _extract_usage(
+ self, usage_metadata: types.GenerateContentResponseUsageMetadata
+ ) -> TokenUsage:
+ """Extract usage from candidate"""
+ return TokenUsage(
+ input_other=usage_metadata.prompt_token_count or 0,
+ input_cached=usage_metadata.cached_content_token_count or 0,
+ output=usage_metadata.candidates_token_count or 0,
+ )
+
def _process_content_parts(
+ self,
candidate: types.Candidate,
llm_response: LLMResponse,
) -> MessageChain:
@@ -358,6 +405,11 @@ class ProviderGoogleGenAI(Provider):
logger.warning(f"收到的 candidate.content.parts 为空: {candidate}")
raise Exception("API 返回的 candidate.content.parts 为空。")
+ # 提取 reasoning content
+ reasoning = self._extract_reasoning_content(candidate)
+ if reasoning:
+ llm_response.reasoning_content = reasoning
+
chain = []
part: types.Part
@@ -380,10 +432,15 @@ class ProviderGoogleGenAI(Provider):
llm_response.role = "tool"
llm_response.tools_call_name.append(part.function_call.name)
llm_response.tools_call_args.append(part.function_call.args)
- # gemini 返回的 function_call.id 可能为 None
- llm_response.tools_call_ids.append(
- part.function_call.id or part.function_call.name,
- )
+ # function_call.id might be None, use name as fallback
+ tool_call_id = part.function_call.id or part.function_call.name
+ llm_response.tools_call_ids.append(tool_call_id)
+ # extra_content
+ if part.thought_signature:
+ ts_bs64 = base64.b64encode(part.thought_signature).decode("utf-8")
+ llm_response.tools_call_extra_content[tool_call_id] = {
+ "google": {"thought_signature": ts_bs64}
+ }
elif (
part.inline_data
and part.inline_data.mime_type
@@ -400,9 +457,11 @@ class ProviderGoogleGenAI(Provider):
None,
)
- modalities = ["Text"]
+ model = payloads.get("model", self.get_model())
+
+ modalities = ["TEXT"]
if self.provider_config.get("gm_resp_image_modal", False):
- modalities.append("Image")
+ modalities.append("IMAGE")
conversation = self._prepare_conversation(payloads)
temperature = payloads.get("temperature", 0.7)
@@ -418,10 +477,11 @@ class ProviderGoogleGenAI(Provider):
temperature,
)
result = await self.client.models.generate_content(
- model=self.get_model(),
- contents=conversation,
+ model=model,
+ contents=cast(types.ContentListUnion, conversation),
config=config,
)
+ logger.debug(f"genai result: {result}")
if not result.candidates:
logger.error(f"请求失败, 返回的 candidates 为空: {result}")
@@ -443,11 +503,11 @@ class ProviderGoogleGenAI(Provider):
e.message = ""
if "Developer instruction is not enabled" in e.message:
logger.warning(
- f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)",
+ f"{model} 不支持 system prompt,已自动去除(影响人格设置)",
)
system_instruction = None
elif "Function calling is not enabled" in e.message:
- logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除")
+ logger.warning(f"{model} 不支持函数调用,已自动去除")
tools = None
elif (
"Multi-modal output is not supported" in e.message
@@ -456,9 +516,9 @@ class ProviderGoogleGenAI(Provider):
or "only supports text output" in e.message
):
logger.warning(
- f"{self.get_model()} 不支持多模态输出,降级为文本模态",
+ f"{model} 不支持多模态输出,降级为文本模态",
)
- modalities = ["Text"]
+ modalities = ["TEXT"]
else:
raise
continue
@@ -469,6 +529,9 @@ class ProviderGoogleGenAI(Provider):
result.candidates[0],
llm_response,
)
+ llm_response.id = result.response_id
+ if result.usage_metadata:
+ llm_response.usage = self._extract_usage(result.usage_metadata)
return llm_response
async def _query_stream(
@@ -481,7 +544,7 @@ class ProviderGoogleGenAI(Provider):
(msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
None,
)
-
+ model = payloads.get("model", self.get_model())
conversation = self._prepare_conversation(payloads)
result = None
@@ -493,8 +556,8 @@ class ProviderGoogleGenAI(Provider):
system_instruction,
)
result = await self.client.models.generate_content_stream(
- model=self.get_model(),
- contents=conversation,
+ model=model,
+ contents=cast(types.ContentListUnion, conversation),
config=config,
)
break
@@ -503,11 +566,11 @@ class ProviderGoogleGenAI(Provider):
e.message = ""
if "Developer instruction is not enabled" in e.message:
logger.warning(
- f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)",
+ f"{model} 不支持 system prompt,已自动去除(影响人格设置)",
)
system_instruction = None
elif "Function calling is not enabled" in e.message:
- logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除")
+ logger.warning(f"{model} 不支持函数调用,已自动去除")
tools = None
else:
raise
@@ -515,6 +578,7 @@ class ProviderGoogleGenAI(Provider):
# Accumulate the complete response text for the final response
accumulated_text = ""
+ accumulated_reasoning = ""
final_response = None
async for chunk in result:
@@ -536,12 +600,25 @@ class ProviderGoogleGenAI(Provider):
chunk.candidates[0],
llm_response,
)
+ llm_response.id = chunk.response_id
+ if chunk.usage_metadata:
+ llm_response.usage = self._extract_usage(chunk.usage_metadata)
yield llm_response
return
+ _f = False
+
+ # 提取 reasoning content
+ reasoning = self._extract_reasoning_content(chunk.candidates[0])
+ if reasoning:
+ _f = True
+ accumulated_reasoning += reasoning
+ llm_response.reasoning_content = reasoning
if chunk.text:
+ _f = True
accumulated_text += chunk.text
llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)])
+ if _f:
yield llm_response
if chunk.candidates[0].finish_reason:
@@ -553,12 +630,19 @@ class ProviderGoogleGenAI(Provider):
chunk.candidates[0],
final_response,
)
+ final_response.id = chunk.response_id
+ if chunk.usage_metadata:
+ final_response.usage = self._extract_usage(chunk.usage_metadata)
break
# Yield final complete response with accumulated text
if not final_response:
final_response = LLMResponse("assistant", is_chunk=False)
+ # Set the complete accumulated reasoning in the final response
+ if accumulated_reasoning:
+ final_response.reasoning_content = accumulated_reasoning
+
# Set the complete accumulated text in the final response
if accumulated_text:
final_response.result_chain = MessageChain(
diff --git a/astrbot/core/provider/sources/groq_source.py b/astrbot/core/provider/sources/groq_source.py
new file mode 100644
index 000000000..fcc8f238f
--- /dev/null
+++ b/astrbot/core/provider/sources/groq_source.py
@@ -0,0 +1,15 @@
+from ..register import register_provider_adapter
+from .openai_source import ProviderOpenAIOfficial
+
+
+@register_provider_adapter(
+ "groq_chat_completion", "Groq Chat Completion Provider Adapter"
+)
+class ProviderGroq(ProviderOpenAIOfficial):
+ def __init__(
+ self,
+ provider_config: dict,
+ provider_settings: dict,
+ ) -> None:
+ super().__init__(provider_config, provider_settings)
+ self.reasoning_key = "reasoning"
diff --git a/astrbot/core/provider/sources/minimax_tts_api_source.py b/astrbot/core/provider/sources/minimax_tts_api_source.py
index 5ffc7cc63..9e2d665c7 100644
--- a/astrbot/core/provider/sources/minimax_tts_api_source.py
+++ b/astrbot/core/provider/sources/minimax_tts_api_source.py
@@ -87,7 +87,7 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
return json.dumps(dict_body)
- async def _call_tts_stream(self, text: str) -> AsyncIterator[bytes]:
+ async def _call_tts_stream(self, text: str) -> AsyncIterator[str]:
"""进行流式请求"""
try:
async with (
@@ -117,7 +117,9 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
data = json.loads(message[6:])
if "extra_info" in data:
continue
- audio = data.get("data", {}).get("audio")
+ audio: str | None = data.get("data", {}).get(
+ "audio"
+ )
if audio is not None:
yield audio
except json.JSONDecodeError:
diff --git a/astrbot/core/provider/sources/openai_embedding_source.py b/astrbot/core/provider/sources/openai_embedding_source.py
index 368e610ec..c9e03d7af 100644
--- a/astrbot/core/provider/sources/openai_embedding_source.py
+++ b/astrbot/core/provider/sources/openai_embedding_source.py
@@ -30,9 +30,9 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
embedding = await self.client.embeddings.create(input=text, model=self.model)
return embedding.data[0].embedding
- async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
+ async def get_embeddings(self, text: list[str]) -> list[list[float]]:
"""批量获取文本的嵌入"""
- embeddings = await self.client.embeddings.create(input=texts, model=self.model)
+ embeddings = await self.client.embeddings.create(input=text, model=self.model)
return [item.embedding for item in embeddings.data]
def get_dim(self) -> int:
diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py
index 823287b6f..4aeacf672 100644
--- a/astrbot/core/provider/sources/openai_source.py
+++ b/astrbot/core/provider/sources/openai_source.py
@@ -4,12 +4,15 @@ import inspect
import json
import os
import random
+import re
from collections.abc import AsyncGenerator
from openai import AsyncAzureOpenAI, AsyncOpenAI
-from openai._exceptions import NotFoundError, UnprocessableEntityError
+from openai._exceptions import NotFoundError
from openai.lib.streaming.chat._completions import ChatCompletionStreamState
from openai.types.chat.chat_completion import ChatCompletion
+from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
+from openai.types.completion_usage import CompletionUsage
import astrbot.core.message.components as Comp
from astrbot import logger
@@ -17,7 +20,7 @@ from astrbot.api.provider import Provider
from astrbot.core.agent.message import Message
from astrbot.core.agent.tool import ToolSet
from astrbot.core.message.message_event_result import MessageChain
-from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
+from astrbot.core.provider.entities import LLMResponse, TokenUsage, ToolCallsResult
from astrbot.core.utils.io import download_image_by_url
from ..register import register_provider_adapter
@@ -28,17 +31,8 @@ from ..register import register_provider_adapter
"OpenAI API Chat Completion 提供商适配器",
)
class ProviderOpenAIOfficial(Provider):
- def __init__(
- self,
- provider_config,
- provider_settings,
- default_persona=None,
- ) -> None:
- super().__init__(
- provider_config,
- provider_settings,
- default_persona,
- )
+ def __init__(self, provider_config, provider_settings) -> None:
+ super().__init__(provider_config, provider_settings)
self.chosen_api_key = None
self.api_keys: list = super().get_keys()
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
@@ -53,9 +47,8 @@ class ProviderOpenAIOfficial(Provider):
for key in self.custom_headers:
self.custom_headers[key] = str(self.custom_headers[key])
- # 适配 azure openai #332
if "api_version" in provider_config:
- # 使用 azure api
+ # Using Azure OpenAI API
self.client = AsyncAzureOpenAI(
api_key=self.chosen_api_key,
api_version=provider_config.get("api_version", None),
@@ -64,7 +57,7 @@ class ProviderOpenAIOfficial(Provider):
timeout=self.timeout,
)
else:
- # 使用 openai api
+ # Using OpenAI Official API
self.client = AsyncOpenAI(
api_key=self.chosen_api_key,
base_url=provider_config.get("api_base", None),
@@ -80,6 +73,8 @@ class ProviderOpenAIOfficial(Provider):
model = model_config.get("model", "unknown")
self.set_model(model)
+ self.reasoning_key = "reasoning_content"
+
def _maybe_inject_xai_search(self, payloads: dict, **kwargs):
"""当开启 xAI 原生搜索时,向请求体注入 Live Search 参数。
@@ -157,7 +152,7 @@ class ProviderOpenAIOfficial(Provider):
logger.debug(f"completion: {completion}")
- llm_response = await self.parse_openai_completion(completion, tools)
+ llm_response = await self._parse_openai_completion(completion, tools)
return llm_response
@@ -210,43 +205,102 @@ class ProviderOpenAIOfficial(Provider):
if len(chunk.choices) == 0:
continue
delta = chunk.choices[0].delta
- # 处理文本内容
+ # logger.debug(f"chunk delta: {delta}")
+ # handle the content delta
+ reasoning = self._extract_reasoning_content(chunk)
+ _y = False
+ llm_response.id = chunk.id
+ if reasoning:
+ llm_response.reasoning_content = reasoning
+ _y = True
if delta.content:
completion_text = delta.content
llm_response.result_chain = MessageChain(
chain=[Comp.Plain(completion_text)],
)
+ _y = True
+ if chunk.usage:
+ llm_response.usage = self._extract_usage(chunk.usage)
+ if _y:
yield llm_response
final_completion = state.get_final_completion()
- llm_response = await self.parse_openai_completion(final_completion, tools)
+ llm_response = await self._parse_openai_completion(final_completion, tools)
yield llm_response
- async def parse_openai_completion(
+ def _extract_reasoning_content(
+ self,
+ completion: ChatCompletion | ChatCompletionChunk,
+ ) -> str:
+ """Extract reasoning content from OpenAI ChatCompletion if available."""
+ reasoning_text = ""
+ if len(completion.choices) == 0:
+ return reasoning_text
+ if isinstance(completion, ChatCompletion):
+ choice = completion.choices[0]
+ reasoning_attr = getattr(choice.message, self.reasoning_key, None)
+ if reasoning_attr:
+ reasoning_text = str(reasoning_attr)
+ elif isinstance(completion, ChatCompletionChunk):
+ delta = completion.choices[0].delta
+ reasoning_attr = getattr(delta, self.reasoning_key, None)
+ if reasoning_attr:
+ reasoning_text = str(reasoning_attr)
+ return reasoning_text
+
+ def _extract_usage(self, usage: CompletionUsage) -> TokenUsage:
+ ptd = usage.prompt_tokens_details
+ cached = ptd.cached_tokens if ptd and ptd.cached_tokens else 0
+ return TokenUsage(
+ input_other=usage.prompt_tokens - cached,
+ input_cached=ptd.cached_tokens if ptd and ptd.cached_tokens else 0,
+ output=usage.completion_tokens,
+ )
+
+ async def _parse_openai_completion(
self, completion: ChatCompletion, tools: ToolSet | None
) -> LLMResponse:
- """解析 OpenAI 的 ChatCompletion 响应"""
+ """Parse OpenAI ChatCompletion into LLMResponse"""
llm_response = LLMResponse("assistant")
if len(completion.choices) == 0:
raise Exception("API 返回的 completion 为空。")
choice = completion.choices[0]
+ # parse the text completion
if choice.message.content is not None:
# text completion
completion_text = str(choice.message.content).strip()
+ # specially, some providers may set tags around reasoning content in the completion text,
+ # we use regex to remove them, and store then in reasoning_content field
+ reasoning_pattern = re.compile(r"(.*?) ", re.DOTALL)
+ matches = reasoning_pattern.findall(completion_text)
+ if matches:
+ llm_response.reasoning_content = "\n".join(
+ [match.strip() for match in matches],
+ )
+ completion_text = reasoning_pattern.sub("", completion_text).strip()
llm_response.result_chain = MessageChain().message(completion_text)
+ # parse the reasoning content if any
+ # the priority is higher than the tag extraction
+ llm_response.reasoning_content = self._extract_reasoning_content(completion)
+
+ # parse tool calls if any
if choice.message.tool_calls and tools is not None:
- # tools call (function calling)
args_ls = []
func_name_ls = []
tool_call_ids = []
+ tool_call_extra_content_dict = {}
for tool_call in choice.message.tool_calls:
if isinstance(tool_call, str):
# workaround for #1359
tool_call = json.loads(tool_call)
+ if tools is None:
+ # 工具集未提供
+ # Should be unreachable
+ raise Exception("工具集未提供")
for tool in tools.func_list:
if (
tool_call.type == "function"
@@ -260,21 +314,30 @@ class ProviderOpenAIOfficial(Provider):
args_ls.append(args)
func_name_ls.append(tool_call.function.name)
tool_call_ids.append(tool_call.id)
+
+ # gemini-2.5 / gemini-3 series extra_content handling
+ extra_content = getattr(tool_call, "extra_content", None)
+ if extra_content is not None:
+ tool_call_extra_content_dict[tool_call.id] = extra_content
llm_response.role = "tool"
llm_response.tools_call_args = args_ls
llm_response.tools_call_name = func_name_ls
llm_response.tools_call_ids = tool_call_ids
-
+ llm_response.tools_call_extra_content = tool_call_extra_content_dict
+ # specially handle finish reason
if choice.finish_reason == "content_filter":
raise Exception(
"API 返回的 completion 由于内容安全过滤被拒绝(非 AstrBot)。",
)
-
if llm_response.completion_text is None and not llm_response.tools_call_args:
logger.error(f"API 返回的 completion 无法解析:{completion}。")
raise Exception(f"API 返回的 completion 无法解析:{completion}。")
llm_response.raw_completion = completion
+ llm_response.id = completion.id
+
+ if completion.usage:
+ llm_response.usage = self._extract_usage(completion.usage)
return llm_response
@@ -317,7 +380,7 @@ class ProviderOpenAIOfficial(Provider):
payloads = {"messages": context_query, **model_config}
- # xAI 原生搜索参数(最小侵入地在此处注入)
+ # xAI origin search tool inject
self._maybe_inject_xai_search(payloads, **kwargs)
return payloads, context_query
@@ -391,7 +454,7 @@ class ProviderOpenAIOfficial(Provider):
)
payloads.pop("tools", None)
return False, chosen_key, available_api_keys, payloads, context_query, None
- logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
+ # logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
if "tool" in str(e).lower() and "support" in str(e).lower():
logger.error("疑似该模型不支持函数调用工具调用。请输入 /tool off_all")
@@ -439,12 +502,6 @@ class ProviderOpenAIOfficial(Provider):
self.client.api_key = chosen_key
llm_response = await self._query(payloads, func_tool)
break
- except UnprocessableEntityError as e:
- logger.warning(f"不可处理的实体错误:{e},尝试删除图片。")
- # 尝试删除所有 image
- new_contexts = await self._remove_image_from_context(context_query)
- payloads["messages"] = new_contexts
- context_query = new_contexts
except Exception as e:
last_exception = e
(
@@ -509,12 +566,6 @@ class ProviderOpenAIOfficial(Provider):
async for response in self._query_stream(payloads, func_tool):
yield response
break
- except UnprocessableEntityError as e:
- logger.warning(f"不可处理的实体错误:{e},尝试删除图片。")
- # 尝试删除所有 image
- new_contexts = await self._remove_image_from_context(context_query)
- payloads["messages"] = new_contexts
- context_query = new_contexts
except Exception as e:
last_exception = e
(
@@ -610,4 +661,3 @@ class ProviderOpenAIOfficial(Provider):
with open(image_url, "rb") as f:
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
return "data:image/jpeg;base64," + image_bs64
- return ""
diff --git a/astrbot/core/provider/sources/sensevoice_selfhosted_source.py b/astrbot/core/provider/sources/sensevoice_selfhosted_source.py
index 67947c685..a41bd72fd 100644
--- a/astrbot/core/provider/sources/sensevoice_selfhosted_source.py
+++ b/astrbot/core/provider/sources/sensevoice_selfhosted_source.py
@@ -7,6 +7,7 @@ import asyncio
import os
import re
from datetime import datetime
+from typing import cast
from funasr_onnx import SenseVoiceSmall
from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess
@@ -32,7 +33,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
- self.set_model(provider_config.get("stt_model"))
+ self.set_model(provider_config["stt_model"])
self.model = None
self.is_emotion = provider_config.get("is_emotion", False)
@@ -86,7 +87,9 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
loop = asyncio.get_event_loop()
res = await loop.run_in_executor(
None, # 使用默认的线程池
- lambda: self.model(audio_url, language="auto", use_itn=True),
+ lambda: cast(SenseVoiceSmall, self.model)(
+ audio_url, language="auto", use_itn=True
+ ),
)
# res = self.model(audio_url, language="auto", use_itn=True)
diff --git a/astrbot/core/provider/sources/vllm_rerank_source.py b/astrbot/core/provider/sources/vllm_rerank_source.py
index 3e6f3d33c..edd8a5491 100644
--- a/astrbot/core/provider/sources/vllm_rerank_source.py
+++ b/astrbot/core/provider/sources/vllm_rerank_source.py
@@ -44,6 +44,7 @@ class VLLMRerankProvider(RerankProvider):
}
if top_n is not None:
payload["top_n"] = top_n
+ assert self.client is not None
async with self.client.post(
f"{self.base_url}/v1/rerank",
json=payload,
diff --git a/astrbot/core/provider/sources/whisper_api_source.py b/astrbot/core/provider/sources/whisper_api_source.py
index 8f6d9e292..fa69206ef 100644
--- a/astrbot/core/provider/sources/whisper_api_source.py
+++ b/astrbot/core/provider/sources/whisper_api_source.py
@@ -6,7 +6,10 @@ from openai import NOT_GIVEN, AsyncOpenAI
from astrbot.core import logger
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_file
-from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
+from astrbot.core.utils.tencent_record_helper import (
+ convert_to_pcm_wav,
+ tencent_silk_to_wav,
+)
from ..entities import ProviderType
from ..provider import STTProvider
@@ -33,20 +36,30 @@ class ProviderOpenAIWhisperAPI(STTProvider):
timeout=provider_config.get("timeout", NOT_GIVEN),
)
- self.set_model(provider_config.get("model"))
+ self.set_model(provider_config["model"])
- async def _is_silk_file(self, file_path):
+ async def _get_audio_format(self, file_path):
+ # 定义要检测的头部字节
silk_header = b"SILK"
- with open(file_path, "rb") as f:
- file_header = f.read(8)
+ amr_header = b"#!AMR"
+
+ try:
+ with open(file_path, "rb") as f:
+ file_header = f.read(8)
+ except FileNotFoundError:
+ return None
if silk_header in file_header:
- return True
- return False
+ return "silk"
+
+ if amr_header in file_header:
+ return "amr"
+ return None
async def get_text(self, audio_url: str) -> str:
"""Only supports mp3, mp4, mpeg, m4a, wav, webm"""
is_tencent = False
+ output_path = None
if audio_url.startswith("http"):
if "multimedia.nt.qq.com.cn" in audio_url:
@@ -62,16 +75,35 @@ class ProviderOpenAIWhisperAPI(STTProvider):
raise FileNotFoundError(f"文件不存在: {audio_url}")
if audio_url.endswith(".amr") or audio_url.endswith(".silk") or is_tencent:
- is_silk = await self._is_silk_file(audio_url)
- if is_silk:
- logger.info("Converting silk file to wav ...")
+ file_format = await self._get_audio_format(audio_url)
+
+ # 判断是否需要转换
+ if file_format in ["silk", "amr"]:
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
output_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".wav")
- await tencent_silk_to_wav(audio_url, output_path)
+
+ if file_format == "silk":
+ logger.info(
+ "Converting silk file to wav using tencent_silk_to_wav..."
+ )
+ await tencent_silk_to_wav(audio_url, output_path)
+ elif file_format == "amr":
+ logger.info(
+ "Converting amr file to wav using convert_to_pcm_wav..."
+ )
+ await convert_to_pcm_wav(audio_url, output_path)
+
audio_url = output_path
result = await self.client.audio.transcriptions.create(
model=self.model_name,
- file=open(audio_url, "rb"),
+ file=("audio.wav", open(audio_url, "rb")),
)
+
+ # remove temp file
+ if output_path and os.path.exists(output_path):
+ try:
+ os.remove(audio_url)
+ except Exception as e:
+ logger.error(f"Failed to remove temp file {audio_url}: {e}")
return result.text
diff --git a/astrbot/core/provider/sources/whisper_selfhosted_source.py b/astrbot/core/provider/sources/whisper_selfhosted_source.py
index fbdc7d626..a14f93f14 100644
--- a/astrbot/core/provider/sources/whisper_selfhosted_source.py
+++ b/astrbot/core/provider/sources/whisper_selfhosted_source.py
@@ -1,6 +1,7 @@
import asyncio
import os
import uuid
+from typing import cast
import whisper
@@ -26,7 +27,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
- self.set_model(provider_config.get("model"))
+ self.set_model(provider_config["model"])
self.model = None
async def initialize(self):
@@ -75,5 +76,8 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
await tencent_silk_to_wav(audio_url, output_path)
audio_url = output_path
+ if not self.model:
+ raise RuntimeError("Whisper 模型未初始化")
+
result = await loop.run_in_executor(None, self.model.transcribe, audio_url)
- return result["text"]
+ return cast(str, result["text"])
diff --git a/astrbot/core/provider/sources/xinference_rerank_source.py b/astrbot/core/provider/sources/xinference_rerank_source.py
index 29f3ab095..960408550 100644
--- a/astrbot/core/provider/sources/xinference_rerank_source.py
+++ b/astrbot/core/provider/sources/xinference_rerank_source.py
@@ -1,6 +1,11 @@
+from typing import cast
+
from xinference_client.client.restful.async_restful_client import (
AsyncClient as Client,
)
+from xinference_client.client.restful.async_restful_client import (
+ AsyncRESTfulRerankModelHandle,
+)
from astrbot import logger
@@ -29,7 +34,7 @@ class XinferenceRerankProvider(RerankProvider):
False,
)
self.client = None
- self.model = None
+ self.model: AsyncRESTfulRerankModelHandle | None = None
self.model_uid = None
async def initialize(self):
@@ -65,7 +70,10 @@ class XinferenceRerankProvider(RerankProvider):
return
if self.model_uid:
- self.model = await self.client.get_model(self.model_uid)
+ self.model = cast(
+ AsyncRESTfulRerankModelHandle,
+ await self.client.get_model(self.model_uid),
+ )
except Exception as e:
logger.error(f"Failed to initialize Xinference model: {e}")
diff --git a/astrbot/core/provider/sources/zhipu_source.py b/astrbot/core/provider/sources/zhipu_source.py
index e7b6ee4f4..ed4bc0bf8 100644
--- a/astrbot/core/provider/sources/zhipu_source.py
+++ b/astrbot/core/provider/sources/zhipu_source.py
@@ -12,10 +12,5 @@ class ProviderZhipu(ProviderOpenAIOfficial):
self,
provider_config: dict,
provider_settings: dict,
- default_persona=None,
) -> None:
- super().__init__(
- provider_config,
- provider_settings,
- default_persona,
- )
+ super().__init__(provider_config, provider_settings)
diff --git a/astrbot/core/star/__init__.py b/astrbot/core/star/__init__.py
index e27db7405..c474962c5 100644
--- a/astrbot/core/star/__init__.py
+++ b/astrbot/core/star/__init__.py
@@ -2,15 +2,19 @@ from astrbot.core import html_renderer
from astrbot.core.provider import Provider
from astrbot.core.star.star_tools import StarTools
from astrbot.core.utils.command_parser import CommandParserMixin
+from astrbot.core.utils.plugin_kv_store import PluginKVStoreMixin
from .context import Context
from .star import StarMetadata, star_map, star_registry
from .star_manager import PluginManager
-class Star(CommandParserMixin):
+class Star(CommandParserMixin, PluginKVStoreMixin):
"""所有插件(Star)的父类,所有插件都应该继承于这个类"""
+ author: str
+ name: str
+
def __init__(self, context: Context, config: dict | None = None):
StarTools.initialize(context)
self.context = context
diff --git a/astrbot/core/star/command_management.py b/astrbot/core/star/command_management.py
new file mode 100644
index 000000000..a0b125d33
--- /dev/null
+++ b/astrbot/core/star/command_management.py
@@ -0,0 +1,449 @@
+from __future__ import annotations
+
+from collections import defaultdict
+from dataclasses import dataclass, field
+from typing import Any
+
+from astrbot.core import db_helper
+from astrbot.core.db.po import CommandConfig
+from astrbot.core.star.filter.command import CommandFilter
+from astrbot.core.star.filter.command_group import CommandGroupFilter
+from astrbot.core.star.filter.permission import PermissionType, PermissionTypeFilter
+from astrbot.core.star.star import star_map
+from astrbot.core.star.star_handler import StarHandlerMetadata, star_handlers_registry
+
+
+@dataclass
+class CommandDescriptor:
+ handler: StarHandlerMetadata = field(repr=False)
+ filter_ref: CommandFilter | CommandGroupFilter | None = field(
+ default=None,
+ repr=False,
+ )
+ handler_full_name: str = ""
+ handler_name: str = ""
+ plugin_name: str = ""
+ plugin_display_name: str | None = None
+ module_path: str = ""
+ description: str = ""
+ command_type: str = "command" # "command" | "group" | "sub_command"
+ raw_command_name: str | None = None
+ current_fragment: str | None = None
+ parent_signature: str = ""
+ parent_group_handler: str = ""
+ original_command: str | None = None
+ effective_command: str | None = None
+ aliases: list[str] = field(default_factory=list)
+ permission: str = "everyone"
+ enabled: bool = True
+ is_group: bool = False
+ is_sub_command: bool = False
+ reserved: bool = False
+ config: CommandConfig | None = None
+ has_conflict: bool = False
+ sub_commands: list[CommandDescriptor] = field(default_factory=list)
+
+
+async def sync_command_configs() -> None:
+ """同步指令配置,清理过期配置。"""
+ descriptors = _collect_descriptors(include_sub_commands=False)
+ config_records = await db_helper.get_command_configs()
+ config_map = _bind_configs_to_descriptors(descriptors, config_records)
+ live_handlers = {desc.handler_full_name for desc in descriptors}
+
+ stale_configs = [key for key in config_map if key not in live_handlers]
+ if stale_configs:
+ await db_helper.delete_command_configs(stale_configs)
+
+
+async def toggle_command(handler_full_name: str, enabled: bool) -> CommandDescriptor:
+ descriptor = _build_descriptor_by_full_name(handler_full_name)
+ if not descriptor:
+ raise ValueError("指定的处理函数不存在或不是指令。")
+
+ existing_cfg = await db_helper.get_command_config(handler_full_name)
+ config = await db_helper.upsert_command_config(
+ handler_full_name=handler_full_name,
+ plugin_name=descriptor.plugin_name or "",
+ module_path=descriptor.module_path,
+ original_command=descriptor.original_command or descriptor.handler_name,
+ resolved_command=(
+ existing_cfg.resolved_command
+ if existing_cfg
+ else descriptor.current_fragment
+ ),
+ enabled=enabled,
+ keep_original_alias=False,
+ conflict_key=existing_cfg.conflict_key
+ if existing_cfg and existing_cfg.conflict_key
+ else descriptor.original_command,
+ resolution_strategy=existing_cfg.resolution_strategy if existing_cfg else None,
+ note=existing_cfg.note if existing_cfg else None,
+ extra_data=existing_cfg.extra_data if existing_cfg else None,
+ auto_managed=False,
+ )
+ _bind_descriptor_with_config(descriptor, config)
+ await sync_command_configs()
+ return descriptor
+
+
+async def rename_command(
+ handler_full_name: str,
+ new_fragment: str,
+) -> CommandDescriptor:
+ descriptor = _build_descriptor_by_full_name(handler_full_name)
+ if not descriptor:
+ raise ValueError("指定的处理函数不存在或不是指令。")
+
+ new_fragment = new_fragment.strip()
+ if not new_fragment:
+ raise ValueError("指令名不能为空。")
+
+ candidate_full = _compose_command(descriptor.parent_signature, new_fragment)
+ if _is_command_in_use(handler_full_name, candidate_full):
+ raise ValueError("新的指令名已被其他指令占用,请换一个名称。")
+
+ config = await db_helper.upsert_command_config(
+ handler_full_name=handler_full_name,
+ plugin_name=descriptor.plugin_name or "",
+ module_path=descriptor.module_path,
+ original_command=descriptor.original_command or descriptor.handler_name,
+ resolved_command=new_fragment,
+ enabled=True if descriptor.enabled else False,
+ keep_original_alias=False,
+ conflict_key=descriptor.original_command,
+ resolution_strategy="manual_rename",
+ note=None,
+ extra_data=None,
+ auto_managed=False,
+ )
+ _bind_descriptor_with_config(descriptor, config)
+
+ await sync_command_configs()
+ return descriptor
+
+
+async def list_commands() -> list[dict[str, Any]]:
+ descriptors = _collect_descriptors(include_sub_commands=True)
+ config_records = await db_helper.get_command_configs()
+ _bind_configs_to_descriptors(descriptors, config_records)
+
+ conflict_groups = _group_conflicts(descriptors)
+ conflict_handler_names: set[str] = {
+ d.handler_full_name for group in conflict_groups.values() for d in group
+ }
+
+ # 分类,设置冲突标志,将子指令挂载到父指令组
+ group_map: dict[str, CommandDescriptor] = {}
+ sub_commands: list[CommandDescriptor] = []
+ root_commands: list[CommandDescriptor] = []
+
+ for desc in descriptors:
+ desc.has_conflict = desc.handler_full_name in conflict_handler_names
+ if desc.is_group:
+ group_map[desc.handler_full_name] = desc
+ elif desc.is_sub_command:
+ sub_commands.append(desc)
+ else:
+ root_commands.append(desc)
+
+ for sub in sub_commands:
+ if sub.parent_group_handler and sub.parent_group_handler in group_map:
+ group_map[sub.parent_group_handler].sub_commands.append(sub)
+ else:
+ root_commands.append(sub)
+
+ # 指令组 + 普通指令,按 effective_command 字母排序
+ all_commands = list(group_map.values()) + root_commands
+ all_commands.sort(key=lambda d: (d.effective_command or "").lower())
+
+ result = [_descriptor_to_dict(desc) for desc in all_commands]
+ return result
+
+
+async def list_command_conflicts() -> list[dict[str, Any]]:
+ """列出所有冲突的指令组。"""
+ descriptors = _collect_descriptors(include_sub_commands=False)
+ config_records = await db_helper.get_command_configs()
+ _bind_configs_to_descriptors(descriptors, config_records)
+
+ conflict_groups = _group_conflicts(descriptors)
+ details = [
+ {
+ "conflict_key": key,
+ "handlers": [
+ {
+ "handler_full_name": item.handler_full_name,
+ "plugin": item.plugin_name,
+ "current_name": item.effective_command,
+ }
+ for item in group
+ ],
+ }
+ for key, group in conflict_groups.items()
+ ]
+ return details
+
+
+# Internal helpers ----------------------------------------------------------
+
+
+def _collect_descriptors(include_sub_commands: bool) -> list[CommandDescriptor]:
+ """收集指令,按需包含子指令。"""
+ descriptors: list[CommandDescriptor] = []
+ for handler in star_handlers_registry:
+ desc = _build_descriptor(handler)
+ if not desc:
+ continue
+ if not include_sub_commands and desc.is_sub_command:
+ continue
+ descriptors.append(desc)
+ return descriptors
+
+
+def _build_descriptor(handler: StarHandlerMetadata) -> CommandDescriptor | None:
+ filter_ref = _locate_primary_filter(handler)
+ if filter_ref is None:
+ return None
+
+ plugin_meta = star_map.get(handler.handler_module_path)
+ plugin_name = (
+ plugin_meta.name if plugin_meta else None
+ ) or handler.handler_module_path
+ plugin_display = plugin_meta.display_name if plugin_meta else None
+
+ is_sub_command = bool(handler.extras_configs.get("sub_command"))
+ parent_group_handler = ""
+
+ if isinstance(filter_ref, CommandFilter):
+ raw_fragment = getattr(
+ filter_ref, "_original_command_name", filter_ref.command_name
+ )
+ current_fragment = filter_ref.command_name
+ parent_signature = (filter_ref.parent_command_names or [""])[0].strip()
+ # 如果是子指令,尝试找到父指令组的 handler_full_name
+ if is_sub_command and parent_signature:
+ parent_group_handler = _find_parent_group_handler(
+ handler.handler_module_path, parent_signature
+ )
+ else:
+ raw_fragment = getattr(
+ filter_ref, "_original_group_name", filter_ref.group_name
+ )
+ current_fragment = filter_ref.group_name
+ parent_signature = _resolve_group_parent_signature(filter_ref)
+
+ original_command = _compose_command(parent_signature, raw_fragment)
+ effective_command = _compose_command(parent_signature, current_fragment)
+
+ # 确定 command_type
+ if isinstance(filter_ref, CommandGroupFilter):
+ command_type = "group"
+ elif is_sub_command:
+ command_type = "sub_command"
+ else:
+ command_type = "command"
+
+ descriptor = CommandDescriptor(
+ handler=handler,
+ filter_ref=filter_ref,
+ handler_full_name=handler.handler_full_name,
+ handler_name=handler.handler_name,
+ plugin_name=plugin_name,
+ plugin_display_name=plugin_display,
+ module_path=handler.handler_module_path,
+ description=handler.desc or "",
+ command_type=command_type,
+ raw_command_name=raw_fragment,
+ current_fragment=current_fragment,
+ parent_signature=parent_signature,
+ parent_group_handler=parent_group_handler,
+ original_command=original_command,
+ effective_command=effective_command,
+ aliases=sorted(getattr(filter_ref, "alias", set())),
+ permission=_determine_permission(handler),
+ enabled=handler.enabled,
+ is_group=isinstance(filter_ref, CommandGroupFilter),
+ is_sub_command=is_sub_command,
+ reserved=plugin_meta.reserved if plugin_meta else False,
+ )
+ return descriptor
+
+
+def _build_descriptor_by_full_name(full_name: str) -> CommandDescriptor | None:
+ handler = star_handlers_registry.get_handler_by_full_name(full_name)
+ if not handler:
+ return None
+ return _build_descriptor(handler)
+
+
+def _locate_primary_filter(
+ handler: StarHandlerMetadata,
+) -> CommandFilter | CommandGroupFilter | None:
+ for filter_ref in handler.event_filters:
+ if isinstance(filter_ref, (CommandFilter, CommandGroupFilter)):
+ return filter_ref
+ return None
+
+
+def _determine_permission(handler: StarHandlerMetadata) -> str:
+ for filter_ref in handler.event_filters:
+ if isinstance(filter_ref, PermissionTypeFilter):
+ return (
+ "admin"
+ if filter_ref.permission_type == PermissionType.ADMIN
+ else "member"
+ )
+ return "everyone"
+
+
+def _resolve_group_parent_signature(group_filter: CommandGroupFilter) -> str:
+ signatures: list[str] = []
+ parent = group_filter.parent_group
+ while parent:
+ signatures.append(getattr(parent, "_original_group_name", parent.group_name))
+ parent = parent.parent_group
+ return " ".join(reversed(signatures)).strip()
+
+
+def _find_parent_group_handler(module_path: str, parent_signature: str) -> str:
+ """根据模块路径和父级签名,找到对应的指令组 handler_full_name。"""
+ parent_sig_normalized = parent_signature.strip()
+ for handler in star_handlers_registry:
+ if handler.handler_module_path != module_path:
+ continue
+ filter_ref = _locate_primary_filter(handler)
+ if not isinstance(filter_ref, CommandGroupFilter):
+ continue
+ # 检查该指令组的完整指令名是否匹配 parent_signature
+ group_names = filter_ref.get_complete_command_names()
+ if parent_sig_normalized in group_names:
+ return handler.handler_full_name
+ return ""
+
+
+def _compose_command(parent_signature: str, fragment: str | None) -> str:
+ fragment = (fragment or "").strip()
+ parent_signature = parent_signature.strip()
+ if not parent_signature:
+ return fragment
+ if not fragment:
+ return parent_signature
+ return f"{parent_signature} {fragment}"
+
+
+def _bind_descriptor_with_config(
+ descriptor: CommandDescriptor,
+ config: CommandConfig,
+) -> None:
+ _apply_config_to_descriptor(descriptor, config)
+ _apply_config_to_runtime(descriptor, config)
+
+
+def _apply_config_to_descriptor(
+ descriptor: CommandDescriptor,
+ config: CommandConfig,
+) -> None:
+ descriptor.config = config
+ descriptor.enabled = config.enabled
+
+ if config.original_command:
+ descriptor.original_command = config.original_command
+
+ new_fragment = config.resolved_command or descriptor.current_fragment
+ descriptor.current_fragment = new_fragment
+ descriptor.effective_command = _compose_command(
+ descriptor.parent_signature,
+ new_fragment,
+ )
+
+
+def _apply_config_to_runtime(
+ descriptor: CommandDescriptor,
+ config: CommandConfig,
+) -> None:
+ descriptor.handler.enabled = config.enabled
+ if descriptor.filter_ref and descriptor.current_fragment:
+ _set_filter_fragment(descriptor.filter_ref, descriptor.current_fragment)
+
+
+def _bind_configs_to_descriptors(
+ descriptors: list[CommandDescriptor],
+ config_records: list[CommandConfig],
+) -> dict[str, CommandConfig]:
+ config_map = {cfg.handler_full_name: cfg for cfg in config_records}
+ for desc in descriptors:
+ if cfg := config_map.get(desc.handler_full_name):
+ _bind_descriptor_with_config(desc, cfg)
+ return config_map
+
+
+def _group_conflicts(
+ descriptors: list[CommandDescriptor],
+) -> dict[str, list[CommandDescriptor]]:
+ conflicts: dict[str, list[CommandDescriptor]] = defaultdict(list)
+ for desc in descriptors:
+ if desc.effective_command and desc.enabled:
+ conflicts[desc.effective_command].append(desc)
+ return {k: v for k, v in conflicts.items() if len(v) > 1}
+
+
+def _set_filter_fragment(
+ filter_ref: CommandFilter | CommandGroupFilter,
+ fragment: str,
+) -> None:
+ attr = (
+ "group_name" if isinstance(filter_ref, CommandGroupFilter) else "command_name"
+ )
+ current_value = getattr(filter_ref, attr)
+ if fragment == current_value:
+ return
+ setattr(filter_ref, attr, fragment)
+ if hasattr(filter_ref, "_cmpl_cmd_names"):
+ filter_ref._cmpl_cmd_names = None
+
+
+def _is_command_in_use(
+ target_handler_full_name: str,
+ candidate_full_command: str,
+) -> bool:
+ candidate = candidate_full_command.strip()
+ for handler in star_handlers_registry:
+ if handler.handler_full_name == target_handler_full_name:
+ continue
+ filter_ref = _locate_primary_filter(handler)
+ if not filter_ref:
+ continue
+ names = {name.strip() for name in filter_ref.get_complete_command_names()}
+ if candidate in names:
+ return True
+ return False
+
+
+def _descriptor_to_dict(desc: CommandDescriptor) -> dict[str, Any]:
+ result = {
+ "handler_full_name": desc.handler_full_name,
+ "handler_name": desc.handler_name,
+ "plugin": desc.plugin_name,
+ "plugin_display_name": desc.plugin_display_name,
+ "module_path": desc.module_path,
+ "description": desc.description,
+ "type": desc.command_type,
+ "parent_signature": desc.parent_signature,
+ "parent_group_handler": desc.parent_group_handler,
+ "original_command": desc.original_command,
+ "current_fragment": desc.current_fragment,
+ "effective_command": desc.effective_command,
+ "aliases": desc.aliases,
+ "permission": desc.permission,
+ "enabled": desc.enabled,
+ "is_group": desc.is_group,
+ "has_conflict": desc.has_conflict,
+ "reserved": desc.reserved,
+ }
+ # 如果是指令组,包含子指令列表
+ if desc.is_group and desc.sub_commands:
+ result["sub_commands"] = [_descriptor_to_dict(sub) for sub in desc.sub_commands]
+ else:
+ result["sub_commands"] = []
+ return result
diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py
index 1a5bc53d9..2561762f1 100644
--- a/astrbot/core/star/context.py
+++ b/astrbot/core/star/context.py
@@ -5,6 +5,10 @@ from typing import Any
from deprecated import deprecated
+from astrbot.core.agent.hooks import BaseAgentRunHooks
+from astrbot.core.agent.message import Message
+from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
+from astrbot.core.agent.tool import ToolSet
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.conversation_mgr import ConversationManager
@@ -13,10 +17,10 @@ from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.persona_mgr import PersonaManager
from astrbot.core.platform import Platform
-from astrbot.core.platform.astr_message_event import MessageSesion
+from astrbot.core.platform.astr_message_event import AstrMessageEvent, MessageSesion
from astrbot.core.platform.manager import PlatformManager
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
-from astrbot.core.provider.entities import ProviderType
+from astrbot.core.provider.entities import LLMResponse, ProviderRequest, ProviderType
from astrbot.core.provider.func_tool_manager import FunctionTool, FunctionToolManager
from astrbot.core.provider.manager import ProviderManager
from astrbot.core.provider.provider import (
@@ -31,6 +35,7 @@ from astrbot.core.star.filter.platform_adapter_type import (
PlatformAdapterType,
)
+from ..exceptions import ProviderNotFoundError
from .filter.command import CommandFilter
from .filter.regex import RegexFilter
from .star import StarMetadata, star_map, star_registry
@@ -75,6 +80,153 @@ class Context:
self.astrbot_config_mgr = astrbot_config_mgr
self.kb_manager = knowledge_base_manager
+ async def llm_generate(
+ self,
+ *,
+ chat_provider_id: str,
+ prompt: str | None = None,
+ image_urls: list[str] | None = None,
+ tools: ToolSet | None = None,
+ system_prompt: str | None = None,
+ contexts: list[Message] | None = None,
+ **kwargs: Any,
+ ) -> LLMResponse:
+ """Call the LLM to generate a response. The method will not automatically execute tool calls. If you want to use tool calls, please use `tool_loop_agent()`.
+
+ .. versionadded:: 4.5.7 (sdk)
+
+ Args:
+ chat_provider_id: The chat provider ID to use.
+ prompt: The prompt to send to the LLM, if `contexts` and `prompt` are both provided, `prompt` will be appended as the last user message
+ image_urls: List of image URLs to include in the prompt, if `contexts` and `prompt` are both provided, `image_urls` will be appended to the last user message
+ tools: ToolSet of tools available to the LLM
+ system_prompt: System prompt to guide the LLM's behavior, if provided, it will always insert as the first system message in the context
+ contexts: context messages for the LLM
+ **kwargs: Additional keyword arguments for LLM generation, OpenAI compatible
+
+ Raises:
+ ChatProviderNotFoundError: If the specified chat provider ID is not found
+ Exception: For other errors during LLM generation
+ """
+ prov = await self.provider_manager.get_provider_by_id(chat_provider_id)
+ if not prov or not isinstance(prov, Provider):
+ raise ProviderNotFoundError(f"Provider {chat_provider_id} not found")
+ llm_resp = await prov.text_chat(
+ prompt=prompt,
+ image_urls=image_urls,
+ func_tool=tools,
+ contexts=contexts,
+ system_prompt=system_prompt,
+ **kwargs,
+ )
+ return llm_resp
+
+ async def tool_loop_agent(
+ self,
+ *,
+ event: AstrMessageEvent,
+ chat_provider_id: str,
+ prompt: str | None = None,
+ image_urls: list[str] | None = None,
+ tools: ToolSet | None = None,
+ system_prompt: str | None = None,
+ contexts: list[Message] | None = None,
+ max_steps: int = 30,
+ tool_call_timeout: int = 60,
+ **kwargs: Any,
+ ) -> LLMResponse:
+ """Run an agent loop that allows the LLM to call tools iteratively until a final answer is produced.
+ If you do not pass the agent_context parameter, the method will recreate a new agent context.
+
+ .. versionadded:: 4.5.7 (sdk)
+
+ Args:
+ chat_provider_id: The chat provider ID to use.
+ prompt: The prompt to send to the LLM, if `contexts` and `prompt` are both provided, `prompt` will be appended as the last user message
+ image_urls: List of image URLs to include in the prompt, if `contexts` and `prompt` are both provided, `image_urls` will be appended to the last user message
+ tools: ToolSet of tools available to the LLM
+ system_prompt: System prompt to guide the LLM's behavior, if provided, it will always insert as the first system message in the context
+ contexts: context messages for the LLM
+ max_steps: Maximum number of tool calls before stopping the loop
+ **kwargs: Additional keyword arguments. The kwargs will not be passed to the LLM directly for now, but can include:
+ agent_hooks: BaseAgentRunHooks[AstrAgentContext] - hooks to run during agent execution
+ agent_context: AstrAgentContext - context to use for the agent
+
+ Returns:
+ The final LLMResponse after tool calls are completed.
+
+ Raises:
+ ChatProviderNotFoundError: If the specified chat provider ID is not found
+ Exception: For other errors during LLM generation
+ """
+ # Import here to avoid circular imports
+ from astrbot.core.astr_agent_context import (
+ AgentContextWrapper,
+ AstrAgentContext,
+ )
+ from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor
+
+ prov = await self.provider_manager.get_provider_by_id(chat_provider_id)
+ if not prov or not isinstance(prov, Provider):
+ raise ProviderNotFoundError(f"Provider {chat_provider_id} not found")
+
+ agent_hooks = kwargs.get("agent_hooks") or BaseAgentRunHooks[AstrAgentContext]()
+ agent_context = kwargs.get("agent_context")
+
+ context_ = []
+ for msg in contexts or []:
+ if isinstance(msg, Message):
+ context_.append(msg.model_dump())
+ else:
+ context_.append(msg)
+
+ request = ProviderRequest(
+ prompt=prompt,
+ image_urls=image_urls or [],
+ func_tool=tools,
+ contexts=context_,
+ system_prompt=system_prompt or "",
+ )
+ if agent_context is None:
+ agent_context = AstrAgentContext(
+ context=self,
+ event=event,
+ )
+ agent_runner = ToolLoopAgentRunner()
+ tool_executor = FunctionToolExecutor()
+ await agent_runner.reset(
+ provider=prov,
+ request=request,
+ run_context=AgentContextWrapper(
+ context=agent_context,
+ tool_call_timeout=tool_call_timeout,
+ ),
+ tool_executor=tool_executor,
+ agent_hooks=agent_hooks,
+ streaming=kwargs.get("stream", False),
+ )
+ async for _ in agent_runner.step_until_done(max_steps):
+ pass
+ llm_resp = agent_runner.get_final_llm_resp()
+ if not llm_resp:
+ raise Exception("Agent did not produce a final LLM response")
+ return llm_resp
+
+ async def get_current_chat_provider_id(self, umo: str) -> str:
+ """Get the ID of the currently used chat provider.
+
+ Args:
+ umo(str): unified_message_origin value, if provided and user has enabled provider session isolation, the provider preferred by that session will be used.
+
+ Raises:
+ ProviderNotFoundError: If the specified chat provider is not found
+
+ """
+ prov = self.get_using_provider(umo)
+ if not prov:
+ raise ProviderNotFoundError("Provider not found")
+ return prov.meta().id
+
def get_registered_star(self, star_name: str) -> StarMetadata | None:
"""根据插件名获取插件的 Metadata"""
for star in star_registry:
@@ -107,10 +259,6 @@ class Context:
"""
return self.provider_manager.llm_tools.deactivate_llm_tool(name)
- def register_provider(self, provider: Provider):
- """注册一个 LLM Provider(Chat_Completion 类型)。"""
- self.provider_manager.provider_insts.append(provider)
-
def get_provider_by_id(
self,
provider_id: str,
@@ -137,7 +285,7 @@ class Context:
"""获取所有用于 Embedding 任务的 Provider。"""
return self.provider_manager.embedding_provider_insts
- def get_using_provider(self, umo: str | None = None) -> Provider | None:
+ def get_using_provider(self, umo: str | None = None) -> Provider:
"""获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。
Args:
@@ -148,7 +296,11 @@ class Context:
provider_type=ProviderType.CHAT_COMPLETION,
umo=umo,
)
- if prov and not isinstance(prov, Provider):
+ if prov is None:
+ raise ProviderNotFoundError(
+ "provider not found, please choose provider first"
+ )
+ if not isinstance(prov, Provider):
raise ValueError("返回的 Provider 不是 Provider 类型")
return prov
@@ -189,45 +341,6 @@ class Context:
return self._config
return self.astrbot_config_mgr.get_conf(umo)
- def get_db(self) -> BaseDatabase:
- """获取 AstrBot 数据库。"""
- return self._db
-
- def get_event_queue(self) -> Queue:
- """获取事件队列。"""
- return self._event_queue
-
- @deprecated(version="4.0.0", reason="Use get_platform_inst instead")
- def get_platform(self, platform_type: PlatformAdapterType | str) -> Platform | None:
- """获取指定类型的平台适配器。
-
- 该方法已经过时,请使用 get_platform_inst 方法。(>= AstrBot v4.0.0)
- """
- for platform in self.platform_manager.platform_insts:
- name = platform.meta().name
- if isinstance(platform_type, str):
- if name == platform_type:
- return platform
- elif (
- name in ADAPTER_NAME_2_TYPE
- and ADAPTER_NAME_2_TYPE[name] & platform_type
- ):
- return platform
-
- def get_platform_inst(self, platform_id: str) -> Platform | None:
- """获取指定 ID 的平台适配器实例。
-
- Args:
- platform_id (str): 平台适配器的唯一标识符。你可以通过 event.get_platform_id() 获取。
-
- Returns:
- Platform: 平台适配器实例,如果未找到则返回 None。
-
- """
- for platform in self.platform_manager.platform_insts:
- if platform.meta().id == platform_id:
- return platform
-
async def send_message(
self,
session: str | MessageSesion,
@@ -300,6 +413,49 @@ class Context:
以下的方法已经不推荐使用。请从 AstrBot 文档查看更好的注册方式。
"""
+ def get_event_queue(self) -> Queue:
+ """获取事件队列。"""
+ return self._event_queue
+
+ @deprecated(version="4.0.0", reason="Use get_platform_inst instead")
+ def get_platform(self, platform_type: PlatformAdapterType | str) -> Platform | None:
+ """获取指定类型的平台适配器。
+
+ 该方法已经过时,请使用 get_platform_inst 方法。(>= AstrBot v4.0.0)
+ """
+ for platform in self.platform_manager.platform_insts:
+ name = platform.meta().name
+ if isinstance(platform_type, str):
+ if name == platform_type:
+ return platform
+ elif (
+ name in ADAPTER_NAME_2_TYPE
+ and ADAPTER_NAME_2_TYPE[name] & platform_type
+ ):
+ return platform
+
+ def get_platform_inst(self, platform_id: str) -> Platform | None:
+ """获取指定 ID 的平台适配器实例。
+
+ Args:
+ platform_id (str): 平台适配器的唯一标识符。你可以通过 event.get_platform_id() 获取。
+
+ Returns:
+ Platform: 平台适配器实例,如果未找到则返回 None。
+
+ """
+ for platform in self.platform_manager.platform_insts:
+ if platform.meta().id == platform_id:
+ return platform
+
+ def get_db(self) -> BaseDatabase:
+ """获取 AstrBot 数据库。"""
+ return self._db
+
+ def register_provider(self, provider: Provider):
+ """注册一个 LLM Provider(Chat_Completion 类型)。"""
+ self.provider_manager.provider_insts.append(provider)
+
def register_llm_tool(
self,
name: str,
diff --git a/astrbot/core/star/filter/command.py b/astrbot/core/star/filter/command.py
index 2a9868fdc..51ad5f089 100755
--- a/astrbot/core/star/filter/command.py
+++ b/astrbot/core/star/filter/command.py
@@ -40,6 +40,7 @@ class CommandFilter(HandlerFilter):
):
self.command_name = command_name
self.alias = alias if alias else set()
+ self._original_command_name = command_name
self.parent_command_names = (
parent_command_names if parent_command_names is not None else [""]
)
diff --git a/astrbot/core/star/filter/command_group.py b/astrbot/core/star/filter/command_group.py
index e1c2efb22..4cbd2c007 100755
--- a/astrbot/core/star/filter/command_group.py
+++ b/astrbot/core/star/filter/command_group.py
@@ -18,6 +18,7 @@ class CommandGroupFilter(HandlerFilter):
):
self.group_name = group_name
self.alias = alias if alias else set()
+ self._original_group_name = group_name
self.sub_command_filters: list[CommandFilter | CommandGroupFilter] = []
self.custom_filter_list: list[CustomFilter] = []
self.parent_group = parent_group
diff --git a/astrbot/core/star/register/star_handler.py b/astrbot/core/star/register/star_handler.py
index 7ce5febd5..daf36a8f6 100644
--- a/astrbot/core/star/register/star_handler.py
+++ b/astrbot/core/star/register/star_handler.py
@@ -1,6 +1,7 @@
from __future__ import annotations
-from collections.abc import Awaitable, Callable
+import re
+from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any
import docstring_parser
@@ -11,7 +12,8 @@ from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.tool import FunctionTool
from astrbot.core.astr_agent_context import AstrAgentContext
-from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES
+from astrbot.core.message.message_event_result import MessageEventResult
+from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES
from astrbot.core.provider.register import llm_tools
from ..filter.command import CommandFilter
@@ -27,13 +29,19 @@ from ..filter.regex import RegexFilter
from ..star_handler import EventType, StarHandlerMetadata, star_handlers_registry
-def get_handler_full_name(awaitable: Callable[..., Awaitable[Any]]) -> str:
+def get_handler_full_name(
+ awaitable: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
+) -> str:
"""获取 Handler 的全名"""
return f"{awaitable.__module__}_{awaitable.__name__}"
def get_handler_or_create(
- handler: Callable[..., Awaitable[Any]],
+ handler: Callable[
+ ...,
+ Awaitable[MessageEventResult | str | None]
+ | AsyncGenerator[MessageEventResult | str | None],
+ ],
event_type: EventType,
dont_add=False,
**kwargs,
@@ -168,6 +176,8 @@ def register_custom_filter(custom_type_filter, *args, **kwargs):
for (
sub_handle
) in parent_register_commandable.parent_group.sub_command_filters:
+ if isinstance(sub_handle, CommandGroupFilter):
+ continue
# 所有符合fullname一致的子指令handle添加自定义过滤器。
# 不确定是否会有多个子指令有一样的fullname,比如一个方法添加多个command装饰器?
sub_handle_md = sub_handle.get_handler_md()
@@ -179,6 +189,8 @@ def register_custom_filter(custom_type_filter, *args, **kwargs):
else:
# 裸指令
+ # 确保运行时是可调用的 handler,针对类型检查器添加忽略
+ assert isinstance(awaitable, Callable)
handler_md = get_handler_or_create(
awaitable,
EventType.AdapterMessageEvent,
@@ -236,7 +248,7 @@ class RegisteringCommandable:
group: Callable[..., Callable[..., RegisteringCommandable]] = register_command_group
command: Callable[..., Callable[..., None]] = register_command
- custom_filter: Callable[..., Callable[..., None]] = register_custom_filter
+ custom_filter: Callable[..., Callable[..., Any]] = register_custom_filter
def __init__(self, parent_group: CommandGroupFilter):
self.parent_group = parent_group
@@ -411,24 +423,49 @@ def register_llm_tool(name: str | None = None, **kwargs):
if kwargs.get("registering_agent"):
registering_agent = kwargs["registering_agent"]
- def decorator(awaitable: Callable[..., Awaitable[Any]]):
+ def decorator(
+ awaitable: Callable[
+ ...,
+ AsyncGenerator[MessageEventResult | str | None]
+ | Awaitable[MessageEventResult | str | None],
+ ],
+ ):
llm_tool_name = name_ if name_ else awaitable.__name__
func_doc = awaitable.__doc__ or ""
docstring = docstring_parser.parse(func_doc)
args = []
for arg in docstring.params:
- if arg.type_name not in SUPPORTED_TYPES:
+ sub_type_name = None
+ type_name = arg.type_name
+ if not type_name:
+ raise ValueError(
+ f"LLM 函数工具 {awaitable.__module__}_{llm_tool_name} 的参数 {arg.arg_name} 缺少类型注释。",
+ )
+ # parse type_name to handle cases like "list[string]"
+ match = re.match(r"(\w+)\[(\w+)\]", type_name)
+ if match:
+ type_name = match.group(1)
+ sub_type_name = match.group(2)
+ type_name = PY_TO_JSON_TYPE.get(type_name, type_name)
+ if sub_type_name:
+ sub_type_name = PY_TO_JSON_TYPE.get(sub_type_name, sub_type_name)
+ if type_name not in SUPPORTED_TYPES or (
+ sub_type_name and sub_type_name not in SUPPORTED_TYPES
+ ):
raise ValueError(
f"LLM 函数工具 {awaitable.__module__}_{llm_tool_name} 不支持的参数类型:{arg.type_name}",
)
- args.append(
- {
- "type": arg.type_name,
- "name": arg.arg_name,
- "description": arg.description,
- },
- )
- # print(llm_tool_name, registering_agent)
+
+ arg_json_schema = {
+ "type": type_name,
+ "name": arg.arg_name,
+ "description": arg.description,
+ }
+ if sub_type_name:
+ if type_name == "array":
+ arg_json_schema["items"] = {"type": sub_type_name}
+ args.append(arg_json_schema)
+
if not registering_agent:
doc_desc = docstring.description.strip() if docstring.description else ""
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
diff --git a/astrbot/core/star/session_llm_manager.py b/astrbot/core/star/session_llm_manager.py
index 8c40f25c1..9fdca1457 100644
--- a/astrbot/core/star/session_llm_manager.py
+++ b/astrbot/core/star/session_llm_manager.py
@@ -171,110 +171,3 @@ class SessionServiceManager:
# 如果没有配置,默认为启用(兼容性考虑)
return True
-
- @staticmethod
- def set_session_status(session_id: str, enabled: bool) -> None:
- """设置会话的整体启停状态
-
- Args:
- session_id: 会话ID (unified_msg_origin)
- enabled: True表示启用,False表示禁用
-
- """
- session_config = (
- sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
- )
- session_config["session_enabled"] = enabled
- sp.put(
- "session_service_config",
- session_config,
- scope="umo",
- scope_id=session_id,
- )
-
- logger.info(
- f"会话 {session_id} 的整体状态已更新为: {'启用' if enabled else '禁用'}",
- )
-
- @staticmethod
- def should_process_session_request(event: AstrMessageEvent) -> bool:
- """检查是否应该处理会话请求(会话整体启停检查)
-
- Args:
- event: 消息事件
-
- Returns:
- bool: True表示应该处理,False表示跳过
-
- """
- session_id = event.unified_msg_origin
- return SessionServiceManager.is_session_enabled(session_id)
-
- # =============================================================================
- # 会话命名相关方法
- # =============================================================================
-
- @staticmethod
- def get_session_custom_name(session_id: str) -> str | None:
- """获取会话的自定义名称
-
- Args:
- session_id: 会话ID (unified_msg_origin)
-
- Returns:
- str: 自定义名称,如果没有设置则返回None
-
- """
- session_services = sp.get(
- "session_service_config",
- {},
- scope="umo",
- scope_id=session_id,
- )
- return session_services.get("custom_name")
-
- @staticmethod
- def set_session_custom_name(session_id: str, custom_name: str) -> None:
- """设置会话的自定义名称
-
- Args:
- session_id: 会话ID (unified_msg_origin)
- custom_name: 自定义名称,可以为空字符串来清除名称
-
- """
- session_config = (
- sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
- )
- if custom_name and custom_name.strip():
- session_config["custom_name"] = custom_name.strip()
- else:
- # 如果传入空名称,则删除自定义名称
- session_config.pop("custom_name", None)
- sp.put(
- "session_service_config",
- session_config,
- scope="umo",
- scope_id=session_id,
- )
-
- logger.info(
- f"会话 {session_id} 的自定义名称已更新为: {custom_name.strip() if custom_name and custom_name.strip() else '已清除'}",
- )
-
- @staticmethod
- def get_session_display_name(session_id: str) -> str:
- """获取会话的显示名称(优先显示自定义名称,否则显示原始session_id的最后一段)
-
- Args:
- session_id: 会话ID (unified_msg_origin)
-
- Returns:
- str: 显示名称
-
- """
- custom_name = SessionServiceManager.get_session_custom_name(session_id)
- if custom_name:
- return custom_name
-
- # 如果没有自定义名称,返回session_id的最后一段
- return session_id.split(":")[2] if session_id.count(":") >= 2 else session_id
diff --git a/astrbot/core/star/session_plugin_manager.py b/astrbot/core/star/session_plugin_manager.py
index c74546fe7..e2ebd11f0 100644
--- a/astrbot/core/star/session_plugin_manager.py
+++ b/astrbot/core/star/session_plugin_manager.py
@@ -42,87 +42,6 @@ class SessionPluginManager:
# 如果都没有配置,默认为启用(兼容性考虑)
return True
- @staticmethod
- def set_plugin_status_for_session(
- session_id: str,
- plugin_name: str,
- enabled: bool,
- ) -> None:
- """设置插件在指定会话中的启停状态
-
- Args:
- session_id: 会话ID (unified_msg_origin)
- plugin_name: 插件名称
- enabled: True表示启用,False表示禁用
-
- """
- # 获取当前配置
- session_plugin_config = sp.get(
- "session_plugin_config",
- {},
- scope="umo",
- scope_id=session_id,
- )
- if session_id not in session_plugin_config:
- session_plugin_config[session_id] = {
- "enabled_plugins": [],
- "disabled_plugins": [],
- }
-
- session_config = session_plugin_config[session_id]
- enabled_plugins = session_config.get("enabled_plugins", [])
- disabled_plugins = session_config.get("disabled_plugins", [])
-
- if enabled:
- # 启用插件
- if plugin_name in disabled_plugins:
- disabled_plugins.remove(plugin_name)
- if plugin_name not in enabled_plugins:
- enabled_plugins.append(plugin_name)
- else:
- # 禁用插件
- if plugin_name in enabled_plugins:
- enabled_plugins.remove(plugin_name)
- if plugin_name not in disabled_plugins:
- disabled_plugins.append(plugin_name)
-
- # 保存配置
- session_config["enabled_plugins"] = enabled_plugins
- session_config["disabled_plugins"] = disabled_plugins
- session_plugin_config[session_id] = session_config
- sp.put(
- "session_plugin_config",
- session_plugin_config,
- scope="umo",
- scope_id=session_id,
- )
-
- logger.info(
- f"会话 {session_id} 的插件 {plugin_name} 状态已更新为: {'启用' if enabled else '禁用'}",
- )
-
- @staticmethod
- def get_session_plugin_config(session_id: str) -> dict[str, list[str]]:
- """获取指定会话的插件配置
-
- Args:
- session_id: 会话ID (unified_msg_origin)
-
- Returns:
- Dict[str, List[str]]: 包含enabled_plugins和disabled_plugins的字典
-
- """
- session_plugin_config = sp.get(
- "session_plugin_config",
- {},
- scope="umo",
- scope_id=session_id,
- )
- return session_plugin_config.get(
- session_id,
- {"enabled_plugins": [], "disabled_plugins": []},
- )
-
@staticmethod
def filter_handlers_by_session(event: AstrMessageEvent, handlers: list) -> list:
"""根据会话配置过滤处理器列表
diff --git a/astrbot/core/star/star_handler.py b/astrbot/core/star/star_handler.py
index 141f9180a..be5b4679f 100644
--- a/astrbot/core/star/star_handler.py
+++ b/astrbot/core/star/star_handler.py
@@ -1,9 +1,9 @@
from __future__ import annotations
import enum
-from collections.abc import Awaitable, Callable
+from collections.abc import AsyncGenerator, Awaitable, Callable
from dataclasses import dataclass, field
-from typing import Any, Generic, TypeVar
+from typing import Any, Generic, Literal, TypeVar, overload
from .filter import HandlerFilter
from .star import star_map
@@ -29,6 +29,84 @@ class StarHandlerRegistry(Generic[T]):
for handler in self._handlers:
print(handler.handler_full_name)
+ @overload
+ def get_handlers_by_event_type(
+ self,
+ event_type: Literal[EventType.OnAstrBotLoadedEvent],
+ only_activated=True,
+ plugins_name: list[str] | None = None,
+ ) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
+
+ @overload
+ def get_handlers_by_event_type(
+ self,
+ event_type: Literal[EventType.OnPlatformLoadedEvent],
+ only_activated=True,
+ plugins_name: list[str] | None = None,
+ ) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
+
+ @overload
+ def get_handlers_by_event_type(
+ self,
+ event_type: Literal[EventType.AdapterMessageEvent],
+ only_activated=True,
+ plugins_name: list[str] | None = None,
+ ) -> list[
+ StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
+ ]: ...
+
+ @overload
+ def get_handlers_by_event_type(
+ self,
+ event_type: Literal[EventType.OnLLMRequestEvent],
+ only_activated=True,
+ plugins_name: list[str] | None = None,
+ ) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
+
+ @overload
+ def get_handlers_by_event_type(
+ self,
+ event_type: Literal[EventType.OnLLMResponseEvent],
+ only_activated=True,
+ plugins_name: list[str] | None = None,
+ ) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
+
+ @overload
+ def get_handlers_by_event_type(
+ self,
+ event_type: Literal[EventType.OnDecoratingResultEvent],
+ only_activated=True,
+ plugins_name: list[str] | None = None,
+ ) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
+
+ @overload
+ def get_handlers_by_event_type(
+ self,
+ event_type: Literal[EventType.OnCallingFuncToolEvent],
+ only_activated=True,
+ plugins_name: list[str] | None = None,
+ ) -> list[
+ StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
+ ]: ...
+
+ @overload
+ def get_handlers_by_event_type(
+ self,
+ event_type: Literal[EventType.OnAfterMessageSentEvent],
+ only_activated=True,
+ plugins_name: list[str] | None = None,
+ ) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
+
+ @overload
+ def get_handlers_by_event_type(
+ self,
+ event_type: EventType,
+ only_activated=True,
+ plugins_name: list[str] | None = None,
+ ) -> list[
+ StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
+ ]: ...
+
def get_handlers_by_event_type(
self,
event_type: EventType,
@@ -40,6 +118,8 @@ class StarHandlerRegistry(Generic[T]):
# 过滤事件类型
if handler.event_type != event_type:
continue
+ if not handler.enabled:
+ continue
# 过滤启用状态
if only_activated:
plugin = star_map.get(handler.handler_module_path)
@@ -111,8 +191,11 @@ class EventType(enum.Enum):
OnAfterMessageSentEvent = enum.auto() # 发送消息后
+H = TypeVar("H", bound=Callable[..., Any])
+
+
@dataclass
-class StarHandlerMetadata:
+class StarHandlerMetadata(Generic[H]):
"""描述一个 Star 所注册的某一个 Handler。"""
event_type: EventType
@@ -127,7 +210,7 @@ class StarHandlerMetadata:
handler_module_path: str
"""Handler 所在的模块路径。"""
- handler: Callable[..., Awaitable[Any]]
+ handler: H
"""Handler 的函数对象,应当是一个异步函数"""
event_filters: list[HandlerFilter]
@@ -139,6 +222,8 @@ class StarHandlerMetadata:
extras_configs: dict = field(default_factory=dict)
"""插件注册的一些其他的信息, 如 priority 等"""
+ enabled: bool = True
+
def __lt__(self, other: StarHandlerMetadata):
"""定义小于运算符以支持优先队列"""
return self.extras_configs.get("priority", 0) < other.extras_configs.get(
diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py
index abdedc249..1f9f95ae5 100644
--- a/astrbot/core/star/star_manager.py
+++ b/astrbot/core/star/star_manager.py
@@ -23,6 +23,7 @@ from astrbot.core.utils.astrbot_path import (
from astrbot.core.utils.io import remove_dir
from . import StarMetadata
+from .command_management import sync_command_configs
from .context import Context
from .filter.permission import PermissionType, PermissionTypeFilter
from .star import star_map, star_registry
@@ -467,6 +468,18 @@ class PluginManager:
metadata.star_cls = metadata.star_cls_type(
context=self.context,
)
+
+ p_name = (metadata.name or "unknown").lower().replace("/", "_")
+ p_author = (
+ (metadata.author or "unknown").lower().replace("/", "_")
+ )
+ setattr(metadata.star_cls, "name", p_name)
+ setattr(metadata.star_cls, "author", p_author)
+ setattr(
+ metadata.star_cls,
+ "plugin_id",
+ f"{p_author}/{p_name}",
+ )
else:
logger.info(f"插件 {metadata.name} 已被禁用。")
@@ -618,6 +631,7 @@ class PluginManager:
# 清除 pip.main 导致的多余的 logging handlers
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
+ await sync_command_configs()
if not fail_rec:
return True, None
diff --git a/astrbot/core/umop_config_router.py b/astrbot/core/umop_config_router.py
index 07858da5f..27f6232aa 100644
--- a/astrbot/core/umop_config_router.py
+++ b/astrbot/core/umop_config_router.py
@@ -85,3 +85,22 @@ class UmopConfigRouter:
self.umop_to_conf_id[umo] = conf_id
await self.sp.global_put("umop_config_routing", self.umop_to_conf_id)
+
+ async def delete_route(self, umo: str):
+ """删除一条路由
+
+ Args:
+ umo (str): 需要删除的 UMO 字符串
+
+ Raises:
+ ValueError: 当 umo 格式不正确时抛出
+ """
+
+ if not isinstance(umo, str) or len(umo.split(":")) != 3:
+ raise ValueError(
+ "umop must be a string in the format [platform_id]:[message_type]:[session_id], with optional wildcards * or empty for all",
+ )
+
+ if umo in self.umop_to_conf_id:
+ del self.umop_to_conf_id[umo]
+ await self.sp.global_put("umop_config_routing", self.umop_to_conf_id)
diff --git a/astrbot/core/updator.py b/astrbot/core/updator.py
index d13bab687..0a7116a0d 100644
--- a/astrbot/core/updator.py
+++ b/astrbot/core/updator.py
@@ -71,10 +71,10 @@ class AstrBotUpdator(RepoZipUpdator):
async def check_update(
self,
- url: str,
- current_version: str,
+ url: str | None,
+ current_version: str | None,
consider_prerelease: bool = True,
- ) -> ReleaseInfo:
+ ) -> ReleaseInfo | None:
"""检查更新"""
return await super().check_update(
self.ASTRBOT_RELEASE_API,
diff --git a/astrbot/core/utils/file_extract.py b/astrbot/core/utils/file_extract.py
new file mode 100644
index 000000000..020ecc67d
--- /dev/null
+++ b/astrbot/core/utils/file_extract.py
@@ -0,0 +1,23 @@
+from pathlib import Path
+
+from openai import AsyncOpenAI
+
+
+async def extract_file_moonshotai(file_path: str, api_key: str) -> str:
+ """Extract text from a file using Moonshot AI API"""
+ """
+ Args:
+ file_path: The path to the file to extract text from
+ api_key: The API key to use to extract text from the file
+ Returns:
+ The text extracted from the file
+ """
+ client = AsyncOpenAI(
+ api_key=api_key,
+ base_url="https://api.moonshot.cn/v1",
+ )
+ file_object = await client.files.create(
+ file=Path(file_path),
+ purpose="file-extract", # type: ignore
+ )
+ return (await client.files.content(file_id=file_object.id)).text
diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py
index 073c04938..fcf5bb3c7 100644
--- a/astrbot/core/utils/io.py
+++ b/astrbot/core/utils/io.py
@@ -49,7 +49,7 @@ def port_checker(port: int, host: str = "localhost"):
return False
-def save_temp_img(img: Image.Image | str) -> str:
+def save_temp_img(img: Image.Image | bytes) -> str:
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
# 获得文件创建时间,清除超过 12 小时的
try:
diff --git a/astrbot/core/utils/migra_helper.py b/astrbot/core/utils/migra_helper.py
new file mode 100644
index 000000000..5642d606e
--- /dev/null
+++ b/astrbot/core/utils/migra_helper.py
@@ -0,0 +1,73 @@
+import traceback
+
+from astrbot.core import astrbot_config, logger
+from astrbot.core.astrbot_config_mgr import AstrBotConfig, AstrBotConfigManager
+from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46
+from astrbot.core.db.migration.migra_webchat_session import migrate_webchat_session
+
+
+def _migra_agent_runner_configs(conf: AstrBotConfig, ids_map: dict) -> None:
+ """
+ Migra agent runner configs from provider configs.
+ """
+ try:
+ default_prov_id = conf["provider_settings"]["default_provider_id"]
+ if default_prov_id in ids_map:
+ conf["provider_settings"]["default_provider_id"] = ""
+ p = ids_map[default_prov_id]
+ if p["type"] == "dify":
+ conf["provider_settings"]["dify_agent_runner_provider_id"] = p["id"]
+ conf["provider_settings"]["agent_runner_type"] = "dify"
+ elif p["type"] == "coze":
+ conf["provider_settings"]["coze_agent_runner_provider_id"] = p["id"]
+ conf["provider_settings"]["agent_runner_type"] = "coze"
+ elif p["type"] == "dashscope":
+ conf["provider_settings"]["dashscope_agent_runner_provider_id"] = p[
+ "id"
+ ]
+ conf["provider_settings"]["agent_runner_type"] = "dashscope"
+ conf.save_config()
+ except Exception as e:
+ logger.error(f"Migration for third party agent runner configs failed: {e!s}")
+ logger.error(traceback.format_exc())
+
+
+async def migra(
+ db, astrbot_config_mgr, umop_config_router, acm: AstrBotConfigManager
+) -> None:
+ """
+ Stores the migration logic here.
+ btw, i really don't like migration :(
+ """
+ # 4.5 to 4.6 migration for umop_config_router
+ try:
+ await migrate_45_to_46(astrbot_config_mgr, umop_config_router)
+ except Exception as e:
+ logger.error(f"Migration from version 4.5 to 4.6 failed: {e!s}")
+ logger.error(traceback.format_exc())
+
+ # migration for webchat session
+ try:
+ await migrate_webchat_session(db)
+ except Exception as e:
+ logger.error(f"Migration for webchat session failed: {e!s}")
+ logger.error(traceback.format_exc())
+
+ # migra third party agent runner configs
+ _c = False
+ providers = astrbot_config["provider"]
+ ids_map = {}
+ for prov in providers:
+ type_ = prov.get("type")
+ if type_ in ["dify", "coze", "dashscope"]:
+ prov["provider_type"] = "agent_runner"
+ ids_map[prov["id"]] = {
+ "type": type_,
+ "id": prov["id"],
+ }
+ _c = True
+ if _c:
+ astrbot_config.save_config()
+
+ for conf in acm.confs.values():
+ _migra_agent_runner_configs(conf, ids_map)
diff --git a/astrbot/core/utils/plugin_kv_store.py b/astrbot/core/utils/plugin_kv_store.py
new file mode 100644
index 000000000..88460c8e1
--- /dev/null
+++ b/astrbot/core/utils/plugin_kv_store.py
@@ -0,0 +1,28 @@
+from typing import TypeVar
+
+from astrbot.core import sp
+
+SUPPORTED_VALUE_TYPES = int | float | str | bytes | bool | dict | list | None
+_VT = TypeVar("_VT")
+
+
+class PluginKVStoreMixin:
+ """为插件提供键值存储功能的 Mixin 类"""
+
+ plugin_id: str
+
+ async def put_kv_data(
+ self,
+ key: str,
+ value: SUPPORTED_VALUE_TYPES,
+ ) -> None:
+ """为指定插件存储一个键值对"""
+ await sp.put_async("plugin", self.plugin_id, key, value)
+
+ async def get_kv_data(self, key: str, default: _VT) -> _VT | None:
+ """获取指定插件存储的键值对"""
+ return await sp.get_async("plugin", self.plugin_id, key, default)
+
+ async def delete_kv_data(self, key: str) -> None:
+ """删除指定插件存储的键值对"""
+ await sp.remove_async("plugin", self.plugin_id, key)
diff --git a/astrbot/core/utils/session_waiter.py b/astrbot/core/utils/session_waiter.py
index 33b7cb17a..e1f2fbef7 100644
--- a/astrbot/core/utils/session_waiter.py
+++ b/astrbot/core/utils/session_waiter.py
@@ -20,16 +20,16 @@ class SessionController:
def __init__(self):
self.future = asyncio.Future()
- self.current_event: asyncio.Event = None
+ self.current_event: asyncio.Event | None = None
"""当前正在等待的所用的异步事件"""
- self.ts: float = None
+ self.ts: float | None = None
"""上次保持(keep)开始时的时间"""
- self.timeout: float | int = None
+ self.timeout: float | int | None = None
"""上次保持(keep)开始时的超时时间"""
self.history_chains: list[list[Comp.BaseMessageComponent]] = []
- def stop(self, error: Exception = None):
+ def stop(self, error: Exception | None = None):
"""立即结束这个会话"""
if not self.future.done():
if error:
@@ -53,6 +53,8 @@ class SessionController:
self.stop()
return
else:
+ assert self.timeout is not None
+ assert self.ts is not None
left_timeout = self.timeout - (new_ts - self.ts)
timeout = left_timeout + timeout
if timeout <= 0:
@@ -69,7 +71,7 @@ class SessionController:
asyncio.create_task(self._holding(new_event, timeout)) # 开始新的 keep
- async def _holding(self, event: asyncio.Event, timeout: int):
+ async def _holding(self, event: asyncio.Event, timeout: float):
"""等待事件结束或超时"""
try:
await asyncio.wait_for(event.wait(), timeout)
@@ -108,7 +110,9 @@ class SessionWaiter:
):
self.session_id = session_id
self.session_filter = session_filter
- self.handler: Callable[[str], Awaitable[Any]] | None = None # 处理函数
+ self.handler: (
+ Callable[[SessionController, AstrMessageEvent], Awaitable[Any]] | None
+ ) = None # 处理函数
self.session_controller = SessionController()
self.record_history_chains = record_history_chains
@@ -119,7 +123,7 @@ class SessionWaiter:
async def register_wait(
self,
- handler: Callable[[str], Awaitable[Any]],
+ handler: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]],
timeout: int = 30,
) -> Any:
"""等待外部输入并处理"""
@@ -137,7 +141,7 @@ class SessionWaiter:
finally:
self._cleanup()
- def _cleanup(self, error: Exception = None):
+ def _cleanup(self, error: Exception | None = None):
"""清理会话"""
USER_SESSIONS.pop(self.session_id, None)
try:
@@ -161,6 +165,7 @@ class SessionWaiter:
)
try:
# TODO: 这里使用 create_task,跟踪 task,防止超时后这里 handler 仍然在执行
+ assert session.handler is not None
await session.handler(session.session_controller, event)
except Exception as e:
session.session_controller.stop(e)
@@ -173,11 +178,13 @@ def session_waiter(timeout: int = 30, record_history_chains: bool = False):
:param record_history_chain: 是否自动记录历史消息链。可以通过 controller.get_history_chains() 获取。深拷贝。
"""
- def decorator(func: Callable[[str], Awaitable[Any]]):
+ def decorator(
+ func: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]],
+ ):
@functools.wraps(func)
async def wrapper(
event: AstrMessageEvent,
- session_filter: SessionFilter = None,
+ session_filter: SessionFilter | None = None,
*args,
**kwargs,
):
diff --git a/astrbot/core/utils/shared_preferences.py b/astrbot/core/utils/shared_preferences.py
index c6b4c5ede..ccd394ee4 100644
--- a/astrbot/core/utils/shared_preferences.py
+++ b/astrbot/core/utils/shared_preferences.py
@@ -40,9 +40,6 @@ class SharedPreferences:
else:
ret = default
return ret
- raise ValueError(
- "scope_id and key cannot be None when getting a specific preference.",
- )
async def range_get_async(
self,
@@ -56,6 +53,14 @@ class SharedPreferences:
ret = await self.db_helper.get_preferences(scope, scope_id, key)
return ret
+ @overload
+ async def session_get(
+ self,
+ umo: str,
+ key: str,
+ default: _VT = None,
+ ) -> _VT: ...
+
@overload
async def session_get(
self,
@@ -88,7 +93,7 @@ class SharedPreferences:
) -> _VT | list[Preference]:
"""获取会话范围的偏好设置
- Note: 当 scope_id 或者 key 为 None,时,返回 Preference 列表,其中的 value 属性是一个 dict,value["val"] 为值。
+ Note: 当 umo 或者 key 为 None,时,返回 Preference 列表,其中的 value 属性是一个 dict,value["val"] 为值。
"""
if umo is None or key is None:
return await self.range_get_async("umo", umo, key)
diff --git a/astrbot/core/utils/t2i/__init__.py b/astrbot/core/utils/t2i/__init__.py
index 5038a46f7..e4112c354 100644
--- a/astrbot/core/utils/t2i/__init__.py
+++ b/astrbot/core/utils/t2i/__init__.py
@@ -3,11 +3,11 @@ from abc import ABC, abstractmethod
class RenderStrategy(ABC):
@abstractmethod
- def render(self, text: str, return_url: bool) -> str:
+ async def render(self, text: str, return_url: bool) -> str:
pass
@abstractmethod
- def render_custom_template(
+ async def render_custom_template(
self,
tmpl_str: str,
tmpl_data: dict,
diff --git a/astrbot/core/utils/t2i/local_strategy.py b/astrbot/core/utils/t2i/local_strategy.py
index 19eab2efe..2fa235129 100644
--- a/astrbot/core/utils/t2i/local_strategy.py
+++ b/astrbot/core/utils/t2i/local_strategy.py
@@ -20,7 +20,7 @@ class FontManager:
_font_cache = {}
@classmethod
- def get_font(cls, size: int) -> ImageFont.FreeTypeFont:
+ def get_font(cls, size: int) -> ImageFont.FreeTypeFont|ImageFont.ImageFont:
"""获取指定大小的字体,优先从缓存获取"""
if size in cls._font_cache:
return cls._font_cache[size]
@@ -66,23 +66,17 @@ class TextMeasurer:
"""测量文本尺寸的工具类"""
@staticmethod
- def get_text_size(text: str, font: ImageFont.FreeTypeFont) -> Tuple[int, int]:
+ def get_text_size(text: str, font: ImageFont.FreeTypeFont|ImageFont.ImageFont) -> tuple[int, int]:
"""获取文本的尺寸"""
- try:
- # PIL 9.0.0 以上版本
- return (
- font.getbbox(text)[2:]
- if hasattr(font, "getbbox")
- else font.getsize(text)
- )
- except Exception:
- # 兼容旧版本
- return font.getsize(text)
+
+ # 依赖库Pillow>=11.2.1,不再需要考虑<9.0.0
+ left, top, right, bottom = font.getbbox("Hello world")
+ return int(right - left), int(bottom - top)
@staticmethod
def split_text_to_fit_width(
- text: str, font: ImageFont.FreeTypeFont, max_width: int
- ) -> List[str]:
+ text: str, font: ImageFont.FreeTypeFont|ImageFont.ImageFont, max_width: int
+ ) -> list[str]:
"""将文本拆分为多行,确保每行不超过指定宽度"""
lines = []
if not text:
@@ -126,7 +120,7 @@ class MarkdownElement(ABC):
def render(
self,
image: Image.Image,
- draw: ImageDraw.Draw,
+ draw: ImageDraw.ImageDraw,
x: int,
y: int,
image_width: int,
@@ -152,7 +146,7 @@ class TextElement(MarkdownElement):
def render(
self,
image: Image.Image,
- draw: ImageDraw.Draw,
+ draw: ImageDraw.ImageDraw,
x: int,
y: int,
image_width: int,
@@ -186,7 +180,7 @@ class BoldTextElement(MarkdownElement):
def render(
self,
image: Image.Image,
- draw: ImageDraw.Draw,
+ draw: ImageDraw.ImageDraw,
x: int,
y: int,
image_width: int,
@@ -251,7 +245,7 @@ class ItalicTextElement(MarkdownElement):
def render(
self,
image: Image.Image,
- draw: ImageDraw.Draw,
+ draw: ImageDraw.ImageDraw,
x: int,
y: int,
image_width: int,
@@ -299,7 +293,7 @@ class ItalicTextElement(MarkdownElement):
# 倾斜变换,使用仿射变换实现斜体效果
# 变换矩阵: [1, 0.2, 0, 0, 1, 0]
italic_img = text_img.transform(
- text_img.size, Image.AFFINE, (1, 0.2, 0, 0, 1, 0), Image.BICUBIC
+ text_img.size, Image.Transform.AFFINE, (1, 0.2, 0, 0, 1, 0), Image.Resampling.BICUBIC
)
# 粘贴到原图像
@@ -331,7 +325,7 @@ class UnderlineTextElement(MarkdownElement):
def render(
self,
image: Image.Image,
- draw: ImageDraw.Draw,
+ draw: ImageDraw.ImageDraw,
x: int,
y: int,
image_width: int,
@@ -371,7 +365,7 @@ class StrikethroughTextElement(MarkdownElement):
def render(
self,
image: Image.Image,
- draw: ImageDraw.Draw,
+ draw: ImageDraw.ImageDraw,
x: int,
y: int,
image_width: int,
@@ -422,7 +416,7 @@ class HeaderElement(MarkdownElement):
def render(
self,
image: Image.Image,
- draw: ImageDraw.Draw,
+ draw: ImageDraw.ImageDraw,
x: int,
y: int,
image_width: int,
@@ -458,7 +452,7 @@ class QuoteElement(MarkdownElement):
def render(
self,
image: Image.Image,
- draw: ImageDraw.Draw,
+ draw: ImageDraw.ImageDraw,
x: int,
y: int,
image_width: int,
@@ -502,7 +496,7 @@ class ListItemElement(MarkdownElement):
def render(
self,
image: Image.Image,
- draw: ImageDraw.Draw,
+ draw: ImageDraw.ImageDraw,
x: int,
y: int,
image_width: int,
@@ -532,7 +526,7 @@ class ListItemElement(MarkdownElement):
class CodeBlockElement(MarkdownElement):
"""代码块元素"""
- def __init__(self, content: List[str]):
+ def __init__(self, content: list[str]):
super().__init__("\n".join(content))
def calculate_height(self, image_width: int, font_size: int) -> int:
@@ -552,7 +546,7 @@ class CodeBlockElement(MarkdownElement):
def render(
self,
image: Image.Image,
- draw: ImageDraw.Draw,
+ draw: ImageDraw.ImageDraw,
x: int,
y: int,
image_width: int,
@@ -595,7 +589,7 @@ class InlineCodeElement(MarkdownElement):
def render(
self,
image: Image.Image,
- draw: ImageDraw.Draw,
+ draw: ImageDraw.ImageDraw,
x: int,
y: int,
image_width: int,
@@ -667,7 +661,7 @@ class ImageElement(MarkdownElement):
def render(
self,
image: Image.Image,
- draw: ImageDraw.Draw,
+ draw: ImageDraw.ImageDraw,
x: int,
y: int,
image_width: int,
@@ -686,7 +680,7 @@ class ImageElement(MarkdownElement):
if pasted_image.width > max_width:
ratio = max_width / pasted_image.width
new_size = (int(max_width), int(pasted_image.height * ratio))
- pasted_image = pasted_image.resize(new_size, Image.LANCZOS)
+ pasted_image = pasted_image.resize(new_size, Image.Resampling.LANCZOS)
# 计算居中位置
paste_x = x + (image_width - pasted_image.width) // 2 - 10
@@ -705,7 +699,7 @@ class MarkdownParser:
"""Markdown解析器,将文本解析为元素"""
@staticmethod
- async def parse(text: str) -> List[MarkdownElement]:
+ async def parse(text: str) -> list[MarkdownElement]:
elements = []
lines = text.split("\n")
@@ -847,7 +841,7 @@ class MarkdownRenderer:
self,
font_size: int = 26,
width: int = 800,
- bg_color: Tuple[int, int, int] = (255, 255, 255),
+ bg_color: tuple[int, int, int] = (255, 255, 255),
):
self.font_size = font_size
self.width = width
diff --git a/astrbot/core/utils/tencent_record_helper.py b/astrbot/core/utils/tencent_record_helper.py
index 9cc36571e..b58643bd3 100644
--- a/astrbot/core/utils/tencent_record_helper.py
+++ b/astrbot/core/utils/tencent_record_helper.py
@@ -36,7 +36,7 @@ async def wav_to_tencent_silk(wav_path: str, output_path: str) -> int:
import pilk
except (ImportError, ModuleNotFoundError) as _:
raise Exception(
- "pilk 模块未安装,请前往管理面板->控制台->安装pip库 安装 pilk 这个库",
+ "pilk 模块未安装,请前往管理面板->平台日志->安装pip库 安装 pilk 这个库",
)
# with wave.open(wav_path, 'rb') as wav:
# wav_data = wav.readframes(wav.getnframes())
@@ -68,7 +68,7 @@ async def convert_to_pcm_wav(input_path: str, output_path: str) -> str:
from pyffmpeg import FFmpeg
ff = FFmpeg()
- ff.convert(input=input_path, output=output_path)
+ ff.convert(input_file=input_path, output_file=output_path)
except Exception as e:
logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
diff --git a/astrbot/core/utils/version_comparator.py b/astrbot/core/utils/version_comparator.py
index e3bf74951..4ad2da10e 100644
--- a/astrbot/core/utils/version_comparator.py
+++ b/astrbot/core/utils/version_comparator.py
@@ -60,9 +60,12 @@ class VersionComparator:
return -1
if isinstance(p1, str) and isinstance(p2, int):
return 1
- if (isinstance(p1, int) and isinstance(p2, int)) or (
- isinstance(p1, str) and isinstance(p2, str)
- ):
+ if isinstance(p1, int) and isinstance(p2, int):
+ if p1 > p2:
+ return 1
+ if p1 < p2:
+ return -1
+ if isinstance(p1, str) and isinstance(p2, str):
if p1 > p2:
return 1
if p1 < p2:
diff --git a/astrbot/core/utils/webhook_utils.py b/astrbot/core/utils/webhook_utils.py
new file mode 100644
index 000000000..0e1c3f9cd
--- /dev/null
+++ b/astrbot/core/utils/webhook_utils.py
@@ -0,0 +1,66 @@
+import uuid
+
+from astrbot.core import astrbot_config, logger
+from astrbot.core.config.default import WEBHOOK_SUPPORTED_PLATFORMS
+
+
+def _get_callback_api_base() -> str:
+ try:
+ return astrbot_config.get("callback_api_base", "").rstrip("/")
+ except Exception as e:
+ logger.error(f"获取 callback_api_base 失败: {e!s}")
+ return ""
+
+
+def _get_dashboard_port() -> int:
+ try:
+ return astrbot_config.get("dashboard", {}).get("port", 6185)
+ except Exception as e:
+ logger.error(f"获取 dashboard 端口失败: {e!s}")
+ return 6185
+
+
+def log_webhook_info(platform_name: str, webhook_uuid: str):
+ """打印美观的 webhook 信息日志
+
+ Args:
+ platform_name: 平台名称
+ webhook_uuid: webhook 的 UUID
+ """
+
+ callback_base = _get_callback_api_base()
+
+ if not callback_base:
+ callback_base = "http(s)://"
+
+ if not callback_base.startswith("http"):
+ callback_base = f"http(s)://{callback_base}"
+
+ callback_base = callback_base.rstrip("/")
+ webhook_url = f"{callback_base}/api/platform/webhook/{webhook_uuid}"
+
+ display_log = (
+ "\n====================\n"
+ f"🔗 机器人平台 {platform_name} 已启用统一 Webhook 模式\n"
+ f"📍 Webhook 回调地址: \n"
+ f" ➜ http://:{_get_dashboard_port()}/api/platform/webhook/{webhook_uuid}\n"
+ f" ➜ {webhook_url}\n"
+ "====================\n"
+ )
+ logger.info(display_log)
+
+
+def ensure_platform_webhook_config(platform_cfg: dict) -> bool:
+ """为支持统一 webhook 的平台自动生成 webhook_uuid
+
+ Args:
+ platform_cfg (dict): 平台配置字典
+
+ Returns:
+ bool: 如果生成了 webhook_uuid 则返回 True,否则返回 False
+ """
+ pt = platform_cfg.get("type", "")
+ if pt in WEBHOOK_SUPPORTED_PLATFORMS and not platform_cfg.get("webhook_uuid"):
+ platform_cfg["webhook_uuid"] = uuid.uuid4().hex[:16]
+ return True
+ return False
diff --git a/astrbot/dashboard/routes/__init__.py b/astrbot/dashboard/routes/__init__.py
index b7997cf8e..951db956c 100644
--- a/astrbot/dashboard/routes/__init__.py
+++ b/astrbot/dashboard/routes/__init__.py
@@ -1,11 +1,13 @@
from .auth import AuthRoute
from .chat import ChatRoute
+from .command import CommandRoute
from .config import ConfigRoute
from .conversation import ConversationRoute
from .file import FileRoute
from .knowledge_base import KnowledgeBaseRoute
from .log import LogRoute
from .persona import PersonaRoute
+from .platform import PlatformRoute
from .plugin import PluginRoute
from .session_management import SessionManagementRoute
from .stat import StatRoute
@@ -16,12 +18,14 @@ from .update import UpdateRoute
__all__ = [
"AuthRoute",
"ChatRoute",
+ "CommandRoute",
"ConfigRoute",
"ConversationRoute",
"FileRoute",
"KnowledgeBaseRoute",
"LogRoute",
"PersonaRoute",
+ "PlatformRoute",
"PluginRoute",
"SessionManagementRoute",
"StatRoute",
diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py
index d7afcbc17..c2b991ef7 100644
--- a/astrbot/dashboard/routes/chat.py
+++ b/astrbot/dashboard/routes/chat.py
@@ -1,16 +1,17 @@
import asyncio
import json
+import mimetypes
import os
import uuid
from contextlib import asynccontextmanager
+from typing import cast
from quart import Response as QuartResponse
-from quart import g, make_response, request
+from quart import g, make_response, request, send_file
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
-from astrbot.core.platform.astr_message_event import MessageSession
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
@@ -36,13 +37,16 @@ class ChatRoute(Route):
super().__init__(context)
self.routes = {
"/chat/send": ("POST", self.chat),
- "/chat/new_conversation": ("GET", self.new_conversation),
- "/chat/conversations": ("GET", self.get_conversations),
- "/chat/get_conversation": ("GET", self.get_conversation),
- "/chat/delete_conversation": ("GET", self.delete_conversation),
- "/chat/rename_conversation": ("POST", self.rename_conversation),
+ "/chat/new_session": ("GET", self.new_session),
+ "/chat/sessions": ("GET", self.get_sessions),
+ "/chat/get_session": ("GET", self.get_session),
+ "/chat/delete_session": ("GET", self.delete_webchat_session),
+ "/chat/update_session_display_name": (
+ "POST",
+ self.update_session_display_name,
+ ),
"/chat/get_file": ("GET", self.get_file),
- "/chat/post_image": ("POST", self.post_image),
+ "/chat/get_attachment": ("GET", self.get_attachment),
"/chat/post_file": ("POST", self.post_file),
}
self.core_lifecycle = core_lifecycle
@@ -53,6 +57,8 @@ class ChatRoute(Route):
self.supported_imgs = ["jpg", "jpeg", "png", "gif", "webp"]
self.conv_mgr = core_lifecycle.conversation_manager
self.platform_history_mgr = core_lifecycle.platform_message_history_manager
+ self.db = db
+ self.umop_config_router = core_lifecycle.umop_config_router
self.running_convs: dict[str, bool] = {}
@@ -69,95 +75,230 @@ class ChatRoute(Route):
if not real_file_path.startswith(real_imgs_dir):
return Response().error("Invalid file path").__dict__
- with open(real_file_path, "rb") as f:
- filename_ext = os.path.splitext(filename)[1].lower()
-
- if filename_ext == ".wav":
- return QuartResponse(f.read(), mimetype="audio/wav")
- if filename_ext[1:] in self.supported_imgs:
- return QuartResponse(f.read(), mimetype="image/jpeg")
- return QuartResponse(f.read())
+ filename_ext = os.path.splitext(filename)[1].lower()
+ if filename_ext == ".wav":
+ return await send_file(real_file_path, mimetype="audio/wav")
+ if filename_ext[1:] in self.supported_imgs:
+ return await send_file(real_file_path, mimetype="image/jpeg")
+ return await send_file(real_file_path)
except (FileNotFoundError, OSError):
return Response().error("File access error").__dict__
- async def post_image(self):
- post_data = await request.files
- if "file" not in post_data:
- return Response().error("Missing key: file").__dict__
+ async def get_attachment(self):
+ """Get attachment file by attachment_id."""
+ attachment_id = request.args.get("attachment_id")
+ if not attachment_id:
+ return Response().error("Missing key: attachment_id").__dict__
- file = post_data["file"]
- filename = str(uuid.uuid4()) + ".jpg"
- path = os.path.join(self.imgs_dir, filename)
- await file.save(path)
+ try:
+ attachment = await self.db.get_attachment_by_id(attachment_id)
+ if not attachment:
+ return Response().error("Attachment not found").__dict__
- return Response().ok(data={"filename": filename}).__dict__
+ file_path = attachment.path
+ real_file_path = os.path.realpath(file_path)
+
+ return await send_file(real_file_path, mimetype=attachment.mime_type)
+
+ except (FileNotFoundError, OSError):
+ return Response().error("File access error").__dict__
async def post_file(self):
+ """Upload a file and create an attachment record, return attachment_id."""
post_data = await request.files
if "file" not in post_data:
return Response().error("Missing key: file").__dict__
file = post_data["file"]
- filename = f"{uuid.uuid4()!s}"
- # 通过文件格式判断文件类型
- if file.content_type.startswith("audio"):
- filename += ".wav"
+ filename = file.filename or f"{uuid.uuid4()!s}"
+ content_type = file.content_type or "application/octet-stream"
+
+ # 根据 content_type 判断文件类型并添加扩展名
+ if content_type.startswith("image"):
+ attach_type = "image"
+ elif content_type.startswith("audio"):
+ attach_type = "record"
+ elif content_type.startswith("video"):
+ attach_type = "video"
+ else:
+ attach_type = "file"
path = os.path.join(self.imgs_dir, filename)
await file.save(path)
- return Response().ok(data={"filename": filename}).__dict__
+ # 创建 attachment 记录
+ attachment = await self.db.insert_attachment(
+ path=path,
+ type=attach_type,
+ mime_type=content_type,
+ )
+
+ if not attachment:
+ return Response().error("Failed to create attachment").__dict__
+
+ filename = os.path.basename(attachment.path)
+
+ return (
+ Response()
+ .ok(
+ data={
+ "attachment_id": attachment.attachment_id,
+ "filename": filename,
+ "type": attach_type,
+ }
+ )
+ .__dict__
+ )
+
+ async def _build_user_message_parts(self, message: str | list) -> list[dict]:
+ """构建用户消息的部分列表
+
+ Args:
+ message: 文本消息 (str) 或消息段列表 (list)
+ """
+ parts = []
+
+ if isinstance(message, list):
+ for part in message:
+ part_type = part.get("type")
+ if part_type == "plain":
+ parts.append({"type": "plain", "text": part.get("text", "")})
+ elif part_type == "reply":
+ parts.append(
+ {"type": "reply", "message_id": part.get("message_id")}
+ )
+ elif attachment_id := part.get("attachment_id"):
+ attachment = await self.db.get_attachment_by_id(attachment_id)
+ if attachment:
+ parts.append(
+ {
+ "type": attachment.type,
+ "attachment_id": attachment.attachment_id,
+ "filename": os.path.basename(attachment.path),
+ "path": attachment.path, # will be deleted
+ }
+ )
+ return parts
+
+ if message:
+ parts.append({"type": "plain", "text": message})
+
+ return parts
+
+ async def _create_attachment_from_file(
+ self, filename: str, attach_type: str
+ ) -> dict | None:
+ """从本地文件创建 attachment 并返回消息部分
+
+ 用于处理 bot 回复中的媒体文件
+
+ Args:
+ filename: 存储的文件名
+ attach_type: 附件类型 (image, record, file, video)
+ """
+ file_path = os.path.join(self.imgs_dir, os.path.basename(filename))
+ if not os.path.exists(file_path):
+ return None
+
+ # guess mime type
+ mime_type, _ = mimetypes.guess_type(filename)
+ if not mime_type:
+ mime_type = "application/octet-stream"
+
+ # insert attachment
+ attachment = await self.db.insert_attachment(
+ path=file_path,
+ type=attach_type,
+ mime_type=mime_type,
+ )
+ if not attachment:
+ return None
+
+ return {
+ "type": attach_type,
+ "attachment_id": attachment.attachment_id,
+ "filename": os.path.basename(file_path),
+ }
+
+ async def _save_bot_message(
+ self,
+ webchat_conv_id: str,
+ text: str,
+ media_parts: list,
+ reasoning: str,
+ agent_stats: dict,
+ ):
+ """保存 bot 消息到历史记录,返回保存的记录"""
+ bot_message_parts = []
+ bot_message_parts.extend(media_parts)
+ if text:
+ bot_message_parts.append({"type": "plain", "text": text})
+
+ new_his = {"type": "bot", "message": bot_message_parts}
+ if reasoning:
+ new_his["reasoning"] = reasoning
+ if agent_stats:
+ new_his["agent_stats"] = agent_stats
+
+ record = await self.platform_history_mgr.insert(
+ platform_id="webchat",
+ user_id=webchat_conv_id,
+ content=new_his,
+ sender_id="bot",
+ sender_name="bot",
+ )
+ return record
async def chat(self):
username = g.get("username", "guest")
post_data = await request.json
- if "message" not in post_data and "image_url" not in post_data:
- return Response().error("Missing key: message or image_url").__dict__
+ if "message" not in post_data and "files" not in post_data:
+ return Response().error("Missing key: message or files").__dict__
- if "conversation_id" not in post_data:
- return Response().error("Missing key: conversation_id").__dict__
+ if "session_id" not in post_data and "conversation_id" not in post_data:
+ return (
+ Response().error("Missing key: session_id or conversation_id").__dict__
+ )
message = post_data["message"]
- conversation_id = post_data["conversation_id"]
- image_url = post_data.get("image_url")
- audio_url = post_data.get("audio_url")
+ session_id = post_data.get("session_id", post_data.get("conversation_id"))
selected_provider = post_data.get("selected_provider")
selected_model = post_data.get("selected_model")
- enable_streaming = post_data.get("enable_streaming", True) # 默认为 True
+ enable_streaming = post_data.get("enable_streaming", True)
- if not message and not image_url and not audio_url:
- return (
- Response()
- .error("Message and image_url and audio_url are empty")
- .__dict__
+ # 检查消息是否为空
+ if isinstance(message, list):
+ has_content = any(
+ part.get("type") in ("plain", "image", "record", "file", "video")
+ for part in message
)
- if not conversation_id:
- return Response().error("conversation_id is empty").__dict__
+ if not has_content:
+ return (
+ Response()
+ .error("Message content is empty (reply only is not allowed)")
+ .__dict__
+ )
+ elif not message:
+ return Response().error("Message are both empty").__dict__
- # 追加用户消息
- webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id)
+ if not session_id:
+ return Response().error("session_id is empty").__dict__
- # 获取会话特定的队列
+ webchat_conv_id = session_id
back_queue = webchat_queue_mgr.get_or_create_back_queue(webchat_conv_id)
- new_his = {"type": "user", "message": message}
- if image_url:
- new_his["image_url"] = image_url
- if audio_url:
- new_his["audio_url"] = audio_url
- await self.platform_history_mgr.insert(
- platform_id="webchat",
- user_id=webchat_conv_id,
- content=new_his,
- sender_id=username,
- sender_name=username,
- )
+ # 构建用户消息段(包含 path 用于传递给 adapter)
+ message_parts = await self._build_user_message_parts(message)
async def stream():
client_disconnected = False
-
+ accumulated_parts = []
+ accumulated_text = ""
+ accumulated_reasoning = ""
+ tool_calls = {}
+ agent_stats = {}
try:
async with track_conversation(self.running_convs, webchat_conv_id):
while True:
@@ -175,16 +316,27 @@ class ChatRoute(Route):
continue
result_text = result["data"]
- type = result.get("type")
+ msg_type = result.get("type")
streaming = result.get("streaming", False)
+ chain_type = result.get("chain_type")
+ if chain_type == "agent_stats":
+ stats_info = {
+ "type": "agent_stats",
+ "data": json.loads(result_text),
+ }
+ yield f"data: {json.dumps(stats_info, ensure_ascii=False)}\n\n"
+ agent_stats = stats_info["data"]
+ continue
+
+ # 发送 SSE 数据
try:
if not client_disconnected:
yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n"
except Exception as e:
if not client_disconnected:
logger.debug(
- f"[WebChat] 用户 {username} 断开聊天长连接。 {e}",
+ f"[WebChat] 用户 {username} 断开聊天长连接。 {e}"
)
client_disconnected = True
@@ -195,22 +347,97 @@ class ChatRoute(Route):
logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。")
client_disconnected = True
- if type == "end":
+ # 累积消息部分
+ if msg_type == "plain":
+ chain_type = result.get("chain_type")
+ if chain_type == "tool_call":
+ tool_call = json.loads(result_text)
+ tool_calls[tool_call.get("id")] = tool_call
+ if accumulated_text:
+ # 如果累积了文本,则先保存文本
+ accumulated_parts.append(
+ {"type": "plain", "text": accumulated_text}
+ )
+ accumulated_text = ""
+ elif chain_type == "tool_call_result":
+ tcr = json.loads(result_text)
+ tc_id = tcr.get("id")
+ if tc_id in tool_calls:
+ tool_calls[tc_id]["result"] = tcr.get("result")
+ tool_calls[tc_id]["finished_ts"] = tcr.get("ts")
+ accumulated_parts.append(
+ {
+ "type": "tool_call",
+ "tool_calls": [tool_calls[tc_id]],
+ }
+ )
+ tool_calls.pop(tc_id, None)
+ elif chain_type == "reasoning":
+ accumulated_reasoning += result_text
+ elif streaming:
+ accumulated_text += result_text
+ else:
+ accumulated_text = result_text
+ elif msg_type == "image":
+ filename = result_text.replace("[IMAGE]", "")
+ part = await self._create_attachment_from_file(
+ filename, "image"
+ )
+ if part:
+ accumulated_parts.append(part)
+ elif msg_type == "record":
+ filename = result_text.replace("[RECORD]", "")
+ part = await self._create_attachment_from_file(
+ filename, "record"
+ )
+ if part:
+ accumulated_parts.append(part)
+ elif msg_type == "file":
+ # 格式: [FILE]filename
+ filename = result_text.replace("[FILE]", "")
+ part = await self._create_attachment_from_file(
+ filename, "file"
+ )
+ if part:
+ accumulated_parts.append(part)
+
+ # 消息结束处理
+ if msg_type == "end":
break
elif (
- (streaming and type == "complete")
- or not streaming
- or type == "break"
+ (streaming and msg_type == "complete") or not streaming
+ # or msg_type == "break"
):
- # 追加机器人消息
- new_his = {"type": "bot", "message": result_text}
- await self.platform_history_mgr.insert(
- platform_id="webchat",
- user_id=webchat_conv_id,
- content=new_his,
- sender_id="bot",
- sender_name="bot",
+ if (
+ chain_type == "tool_call"
+ or chain_type == "tool_call_result"
+ ):
+ continue
+ saved_record = await self._save_bot_message(
+ webchat_conv_id,
+ accumulated_text,
+ accumulated_parts,
+ accumulated_reasoning,
+ agent_stats,
)
+ # 发送保存的消息信息给前端
+ if saved_record and not client_disconnected:
+ saved_info = {
+ "type": "message_saved",
+ "data": {
+ "id": saved_record.id,
+ "created_at": saved_record.created_at.astimezone().isoformat(),
+ },
+ }
+ try:
+ yield f"data: {json.dumps(saved_info, ensure_ascii=False)}\n\n"
+ except Exception:
+ pass
+ accumulated_parts = []
+ accumulated_text = ""
+ accumulated_reasoning = ""
+ tool_calls = {}
+ agent_stats = {}
except BaseException as e:
logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True)
@@ -221,9 +448,7 @@ class ChatRoute(Route):
username,
webchat_conv_id,
{
- "message": message,
- "image_url": image_url, # list
- "audio_url": audio_url,
+ "message": message_parts,
"selected_provider": selected_provider,
"selected_model": selected_model,
"enable_streaming": enable_streaming,
@@ -231,100 +456,195 @@ class ChatRoute(Route):
),
)
- response = await make_response(
- stream(),
- {
- "Content-Type": "text/event-stream",
- "Cache-Control": "no-cache",
- "Transfer-Encoding": "chunked",
- "Connection": "keep-alive",
- },
+ message_parts_for_storage = []
+ for part in message_parts:
+ part_copy = {k: v for k, v in part.items() if k != "path"}
+ message_parts_for_storage.append(part_copy)
+
+ await self.platform_history_mgr.insert(
+ platform_id="webchat",
+ user_id=webchat_conv_id,
+ content={"type": "user", "message": message_parts_for_storage},
+ sender_id=username,
+ sender_name=username,
+ )
+
+ response = cast(
+ QuartResponse,
+ await make_response(
+ stream(),
+ {
+ "Content-Type": "text/event-stream",
+ "Cache-Control": "no-cache",
+ "Transfer-Encoding": "chunked",
+ "Connection": "keep-alive",
+ },
+ ),
)
response.timeout = None # fix SSE auto disconnect issue
return response
- async def _get_webchat_conv_id_from_conv_id(self, conversation_id: str) -> str:
- """从对话 ID 中提取 WebChat 会话 ID
-
- NOTE: 关于这里为什么要单独做一个 WebChat 的 Conversation ID 出来,这个是为了向前兼容。
- """
- conversation = await self.conv_mgr.get_conversation(
- unified_msg_origin="webchat",
- conversation_id=conversation_id,
- )
- if not conversation:
- raise ValueError(f"Conversation with ID {conversation_id} not found.")
- conv_user_id = conversation.user_id
- webchat_session_id = MessageSession.from_str(conv_user_id).session_id
- if "!" not in webchat_session_id:
- raise ValueError(f"Invalid conv user ID: {conv_user_id}")
- return webchat_session_id.split("!")[-1]
-
- async def delete_conversation(self):
- conversation_id = request.args.get("conversation_id")
- if not conversation_id:
- return Response().error("Missing key: conversation_id").__dict__
+ async def delete_webchat_session(self):
+ """Delete a Platform session and all its related data."""
+ session_id = request.args.get("session_id")
+ if not session_id:
+ return Response().error("Missing key: session_id").__dict__
username = g.get("username", "guest")
- # Clean up queues when deleting conversation
- webchat_queue_mgr.remove_queues(conversation_id)
- webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id)
- await self.conv_mgr.delete_conversation(
- unified_msg_origin=f"webchat:FriendMessage:webchat!{username}!{webchat_conv_id}",
- conversation_id=conversation_id,
+ # 验证会话是否存在且属于当前用户
+ session = await self.db.get_platform_session_by_id(session_id)
+ if not session:
+ return Response().error(f"Session {session_id} not found").__dict__
+ if session.creator != username:
+ return Response().error("Permission denied").__dict__
+
+ # 删除该会话下的所有对话
+ message_type = "GroupMessage" if session.is_group else "FriendMessage"
+ unified_msg_origin = f"{session.platform_id}:{message_type}:{session.platform_id}!{username}!{session_id}"
+ await self.conv_mgr.delete_conversations_by_user_id(unified_msg_origin)
+
+ # 获取消息历史中的所有附件 ID 并删除附件
+ history_list = await self.platform_history_mgr.get(
+ platform_id=session.platform_id,
+ user_id=session_id,
+ page=1,
+ page_size=100000, # 获取足够多的记录
)
+ attachment_ids = self._extract_attachment_ids(history_list)
+ if attachment_ids:
+ await self._delete_attachments(attachment_ids)
+
+ # 删除消息历史
await self.platform_history_mgr.delete(
- platform_id="webchat",
- user_id=webchat_conv_id,
+ platform_id=session.platform_id,
+ user_id=session_id,
offset_sec=99999999,
)
+
+ # 删除与会话关联的配置路由
+ try:
+ await self.umop_config_router.delete_route(unified_msg_origin)
+ except ValueError as exc:
+ logger.warning(
+ "Failed to delete UMO route %s during session cleanup: %s",
+ unified_msg_origin,
+ exc,
+ )
+
+ # 清理队列(仅对 webchat)
+ if session.platform_id == "webchat":
+ webchat_queue_mgr.remove_queues(session_id)
+
+ # 删除会话
+ await self.db.delete_platform_session(session_id)
+
return Response().ok().__dict__
- async def new_conversation(self):
+ def _extract_attachment_ids(self, history_list) -> list[str]:
+ """从消息历史中提取所有 attachment_id"""
+ attachment_ids = []
+ for history in history_list:
+ content = history.content
+ if not content or "message" not in content:
+ continue
+ message_parts = content.get("message", [])
+ for part in message_parts:
+ if isinstance(part, dict) and "attachment_id" in part:
+ attachment_ids.append(part["attachment_id"])
+ return attachment_ids
+
+ async def _delete_attachments(self, attachment_ids: list[str]):
+ """删除附件(包括数据库记录和磁盘文件)"""
+ try:
+ attachments = await self.db.get_attachments(attachment_ids)
+ for attachment in attachments:
+ if not os.path.exists(attachment.path):
+ continue
+ try:
+ os.remove(attachment.path)
+ except OSError as e:
+ logger.warning(
+ f"Failed to delete attachment file {attachment.path}: {e}"
+ )
+ except Exception as e:
+ logger.warning(f"Failed to get attachments: {e}")
+
+ # 批量删除数据库记录
+ try:
+ await self.db.delete_attachments(attachment_ids)
+ except Exception as e:
+ logger.warning(f"Failed to delete attachments: {e}")
+
+ async def new_session(self):
+ """Create a new Platform session (default: webchat)."""
username = g.get("username", "guest")
- webchat_conv_id = str(uuid.uuid4())
- conv_id = await self.conv_mgr.new_conversation(
- unified_msg_origin=f"webchat:FriendMessage:webchat!{username}!{webchat_conv_id}",
- platform_id="webchat",
- content=[],
+
+ # 获取可选的 platform_id 参数,默认为 webchat
+ platform_id = request.args.get("platform_id", "webchat")
+
+ # 创建新会话
+ session = await self.db.create_platform_session(
+ creator=username,
+ platform_id=platform_id,
+ is_group=0,
)
- return Response().ok(data={"conversation_id": conv_id}).__dict__
- async def rename_conversation(self):
- post_data = await request.json
- if "conversation_id" not in post_data or "title" not in post_data:
- return Response().error("Missing key: conversation_id or title").__dict__
-
- conversation_id = post_data["conversation_id"]
- title = post_data["title"]
-
- await self.conv_mgr.update_conversation(
- unified_msg_origin="webchat", # fake
- conversation_id=conversation_id,
- title=title,
+ return (
+ Response()
+ .ok(
+ data={
+ "session_id": session.session_id,
+ "platform_id": session.platform_id,
+ }
+ )
+ .__dict__
)
- return Response().ok(message="重命名成功!").__dict__
- async def get_conversations(self):
- conversations = await self.conv_mgr.get_conversations(platform_id="webchat")
- # remove content
- conversations_ = []
- for conv in conversations:
- conv.history = None
- conversations_.append(conv)
- return Response().ok(data=conversations_).__dict__
+ async def get_sessions(self):
+ """Get all Platform sessions for the current user."""
+ username = g.get("username", "guest")
- async def get_conversation(self):
- conversation_id = request.args.get("conversation_id")
- if not conversation_id:
- return Response().error("Missing key: conversation_id").__dict__
+ # 获取可选的 platform_id 参数
+ platform_id = request.args.get("platform_id")
- webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id)
+ sessions = await self.db.get_platform_sessions_by_creator(
+ creator=username,
+ platform_id=platform_id,
+ page=1,
+ page_size=100, # 暂时返回前100个
+ )
- # Get platform message history
+ # 转换为字典格式,并添加额外信息
+ sessions_data = []
+ for session in sessions:
+ sessions_data.append(
+ {
+ "session_id": session.session_id,
+ "platform_id": session.platform_id,
+ "creator": session.creator,
+ "display_name": session.display_name,
+ "is_group": session.is_group,
+ "created_at": session.created_at.astimezone().isoformat(),
+ "updated_at": session.updated_at.astimezone().isoformat(),
+ }
+ )
+
+ return Response().ok(data=sessions_data).__dict__
+
+ async def get_session(self):
+ """Get session information and message history by session_id."""
+ session_id = request.args.get("session_id")
+ if not session_id:
+ return Response().error("Missing key: session_id").__dict__
+
+ # 获取会话信息以确定 platform_id
+ session = await self.db.get_platform_session_by_id(session_id)
+ platform_id = session.platform_id if session else "webchat"
+
+ # Get platform message history using session_id
history_ls = await self.platform_history_mgr.get(
- platform_id="webchat",
- user_id=webchat_conv_id,
+ platform_id=platform_id,
+ user_id=session_id,
page=1,
page_size=1000,
)
@@ -336,8 +656,37 @@ class ChatRoute(Route):
.ok(
data={
"history": history_res,
- "is_running": self.running_convs.get(webchat_conv_id, False),
+ "is_running": self.running_convs.get(session_id, False),
},
)
.__dict__
)
+
+ async def update_session_display_name(self):
+ """Update a Platform session's display name."""
+ post_data = await request.json
+
+ session_id = post_data.get("session_id")
+ display_name = post_data.get("display_name")
+
+ if not session_id:
+ return Response().error("Missing key: session_id").__dict__
+ if display_name is None:
+ return Response().error("Missing key: display_name").__dict__
+
+ username = g.get("username", "guest")
+
+ # 验证会话是否存在且属于当前用户
+ session = await self.db.get_platform_session_by_id(session_id)
+ if not session:
+ return Response().error(f"Session {session_id} not found").__dict__
+ if session.creator != username:
+ return Response().error("Permission denied").__dict__
+
+ # 更新 display_name
+ await self.db.update_platform_session(
+ session_id=session_id,
+ display_name=display_name,
+ )
+
+ return Response().ok().__dict__
diff --git a/astrbot/dashboard/routes/command.py b/astrbot/dashboard/routes/command.py
new file mode 100644
index 000000000..5cb267169
--- /dev/null
+++ b/astrbot/dashboard/routes/command.py
@@ -0,0 +1,82 @@
+from quart import request
+
+from astrbot.core.star.command_management import (
+ list_command_conflicts,
+ list_commands,
+)
+from astrbot.core.star.command_management import (
+ rename_command as rename_command_service,
+)
+from astrbot.core.star.command_management import (
+ toggle_command as toggle_command_service,
+)
+
+from .route import Response, Route, RouteContext
+
+
+class CommandRoute(Route):
+ def __init__(self, context: RouteContext) -> None:
+ super().__init__(context)
+ self.routes = {
+ "/commands": ("GET", self.get_commands),
+ "/commands/conflicts": ("GET", self.get_conflicts),
+ "/commands/toggle": ("POST", self.toggle_command),
+ "/commands/rename": ("POST", self.rename_command),
+ }
+ self.register_routes()
+
+ async def get_commands(self):
+ commands = await list_commands()
+ summary = {
+ "total": len(commands),
+ "disabled": len([cmd for cmd in commands if not cmd["enabled"]]),
+ "conflicts": len([cmd for cmd in commands if cmd.get("has_conflict")]),
+ }
+ return Response().ok({"items": commands, "summary": summary}).__dict__
+
+ async def get_conflicts(self):
+ conflicts = await list_command_conflicts()
+ return Response().ok(conflicts).__dict__
+
+ async def toggle_command(self):
+ data = await request.get_json()
+ handler_full_name = data.get("handler_full_name")
+ enabled = data.get("enabled")
+
+ if handler_full_name is None or enabled is None:
+ return Response().error("handler_full_name 与 enabled 均为必填。").__dict__
+
+ if isinstance(enabled, str):
+ enabled = enabled.lower() in ("1", "true", "yes", "on")
+
+ try:
+ await toggle_command_service(handler_full_name, bool(enabled))
+ except ValueError as exc:
+ return Response().error(str(exc)).__dict__
+
+ payload = await _get_command_payload(handler_full_name)
+ return Response().ok(payload).__dict__
+
+ async def rename_command(self):
+ data = await request.get_json()
+ handler_full_name = data.get("handler_full_name")
+ new_name = data.get("new_name")
+
+ if not handler_full_name or not new_name:
+ return Response().error("handler_full_name 与 new_name 均为必填。").__dict__
+
+ try:
+ await rename_command_service(handler_full_name, new_name)
+ except ValueError as exc:
+ return Response().error(str(exc)).__dict__
+
+ payload = await _get_command_payload(handler_full_name)
+ return Response().ok(payload).__dict__
+
+
+async def _get_command_payload(handler_full_name: str):
+ commands = await list_commands()
+ for cmd in commands:
+ if cmd["handler_full_name"] == handler_full_name:
+ return cmd
+ return {}
diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py
index b947d26f2..0edbe8377 100644
--- a/astrbot/dashboard/routes/config.py
+++ b/astrbot/dashboard/routes/config.py
@@ -2,6 +2,7 @@ import asyncio
import inspect
import os
import traceback
+from typing import Any
from quart import request
@@ -14,19 +15,18 @@ from astrbot.core.config.default import (
DEFAULT_CONFIG,
DEFAULT_VALUE_MAP,
)
+from astrbot.core.config.i18n_utils import ConfigMetadataI18n
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.platform.register import platform_cls_map, platform_registry
from astrbot.core.provider import Provider
-from astrbot.core.provider.entities import ProviderType
-from astrbot.core.provider.provider import RerankProvider
from astrbot.core.provider.register import provider_registry
from astrbot.core.star.star import star_registry
-from astrbot.core.utils.astrbot_path import get_astrbot_path
+from astrbot.core.utils.webhook_utils import ensure_platform_webhook_config
from .route import Response, Route, RouteContext
-def try_cast(value: str, type_: str):
+def try_cast(value: Any, type_: str):
if type_ == "int":
try:
return int(value)
@@ -133,7 +133,9 @@ def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False)
is_core,
)
else:
- errors, post_config = validate_config(post_config, config.schema, is_core)
+ errors, post_config = validate_config(
+ post_config, getattr(config, "schema", {}), is_core
+ )
except BaseException as e:
logger.error(traceback.format_exc())
logger.warning(f"验证配置时出现异常: {e}")
@@ -247,11 +249,8 @@ class ConfigRoute(Route):
async def get_default_config(self):
"""获取默认配置文件"""
- return (
- Response()
- .ok({"config": DEFAULT_CONFIG, "metadata": CONFIG_METADATA_3})
- .__dict__
- )
+ metadata = ConfigMetadataI18n.convert_to_i18n_keys(CONFIG_METADATA_3)
+ return Response().ok({"config": DEFAULT_CONFIG, "metadata": metadata}).__dict__
async def get_abconf_list(self):
"""获取所有 AstrBot 配置文件的列表"""
@@ -282,17 +281,15 @@ class ConfigRoute(Route):
try:
if system_config:
abconf = self.acm.confs["default"]
- return (
- Response()
- .ok({"config": abconf, "metadata": CONFIG_METADATA_3_SYSTEM})
- .__dict__
+ metadata = ConfigMetadataI18n.convert_to_i18n_keys(
+ CONFIG_METADATA_3_SYSTEM
)
+ return Response().ok({"config": abconf, "metadata": metadata}).__dict__
+ if abconf_id is None:
+ raise ValueError("abconf_id cannot be None")
abconf = self.acm.confs[abconf_id]
- return (
- Response()
- .ok({"config": abconf, "metadata": CONFIG_METADATA_3})
- .__dict__
- )
+ metadata = ConfigMetadataI18n.convert_to_i18n_keys(CONFIG_METADATA_3)
+ return Response().ok({"config": abconf, "metadata": metadata}).__dict__
except ValueError as e:
return Response().error(str(e)).__dict__
@@ -358,169 +355,20 @@ class ConfigRoute(Route):
f"Attempting to check provider: {status_info['name']} (ID: {status_info['id']}, Type: {status_info['type']}, Model: {status_info['model']})",
)
- if provider_capability_type == ProviderType.CHAT_COMPLETION:
- try:
- logger.debug(f"Sending 'Ping' to provider: {status_info['name']}")
- response = await asyncio.wait_for(
- provider.text_chat(prompt="REPLY `PONG` ONLY"),
- timeout=45.0,
- )
- logger.debug(
- f"Received response from {status_info['name']}: {response}",
- )
- if response is not None:
- status_info["status"] = "available"
- response_text_snippet = ""
- if (
- hasattr(response, "completion_text")
- and response.completion_text
- ):
- response_text_snippet = (
- response.completion_text[:70] + "..."
- if len(response.completion_text) > 70
- else response.completion_text
- )
- elif hasattr(response, "result_chain") and response.result_chain:
- try:
- response_text_snippet = (
- response.result_chain.get_plain_text()[:70] + "..."
- if len(response.result_chain.get_plain_text()) > 70
- else response.result_chain.get_plain_text()
- )
- except Exception as _:
- pass
- logger.info(
- f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{response_text_snippet}'",
- )
- else:
- status_info["error"] = (
- "Test call returned None, but expected an LLMResponse object."
- )
- logger.warning(
- f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None.",
- )
-
- except asyncio.TimeoutError:
- status_info["error"] = (
- "Connection timed out after 45 seconds during test call."
- )
- logger.warning(
- f"Provider {status_info['name']} (ID: {status_info['id']}) timed out.",
- )
- except Exception as e:
- error_message = str(e)
- status_info["error"] = error_message
- logger.warning(
- f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}",
- )
- logger.debug(
- f"Traceback for {status_info['name']}:\n{traceback.format_exc()}",
- )
-
- elif provider_capability_type == ProviderType.EMBEDDING:
- try:
- # For embedding, we can call the get_embedding method with a short prompt.
- embedding_result = await provider.get_embedding("health_check")
- if isinstance(embedding_result, list) and (
- not embedding_result or isinstance(embedding_result[0], float)
- ):
- status_info["status"] = "available"
- else:
- status_info["status"] = "unavailable"
- status_info["error"] = (
- f"Embedding test failed: unexpected result type {type(embedding_result)}"
- )
- except Exception as e:
- logger.error(
- f"Error testing embedding provider {provider_name}: {e}",
- exc_info=True,
- )
- status_info["status"] = "unavailable"
- status_info["error"] = f"Embedding test failed: {e!s}"
-
- elif provider_capability_type == ProviderType.TEXT_TO_SPEECH:
- try:
- # For TTS, we can call the get_audio method with a short prompt.
- audio_result = await provider.get_audio("你好")
- if isinstance(audio_result, str) and audio_result:
- status_info["status"] = "available"
- else:
- status_info["status"] = "unavailable"
- status_info["error"] = (
- f"TTS test failed: unexpected result type {type(audio_result)}"
- )
- except Exception as e:
- logger.error(
- f"Error testing TTS provider {provider_name}: {e}",
- exc_info=True,
- )
- status_info["status"] = "unavailable"
- status_info["error"] = f"TTS test failed: {e!s}"
- elif provider_capability_type == ProviderType.SPEECH_TO_TEXT:
- try:
- logger.debug(
- f"Sending health check audio to provider: {status_info['name']}",
- )
- sample_audio_path = os.path.join(
- get_astrbot_path(),
- "samples",
- "stt_health_check.wav",
- )
- if not os.path.exists(sample_audio_path):
- status_info["status"] = "unavailable"
- status_info["error"] = (
- "STT test failed: sample audio file not found."
- )
- logger.warning(
- f"STT test for {status_info['name']} failed: sample audio file not found at {sample_audio_path}",
- )
- else:
- text_result = await provider.get_text(sample_audio_path)
- if isinstance(text_result, str) and text_result:
- status_info["status"] = "available"
- snippet = (
- text_result[:70] + "..."
- if len(text_result) > 70
- else text_result
- )
- logger.info(
- f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{snippet}'",
- )
- else:
- status_info["status"] = "unavailable"
- status_info["error"] = (
- f"STT test failed: unexpected result type {type(text_result)}"
- )
- logger.warning(
- f"STT test for {status_info['name']} failed: unexpected result type {type(text_result)}",
- )
- except Exception as e:
- logger.error(
- f"Error testing STT provider {provider_name}: {e}",
- exc_info=True,
- )
- status_info["status"] = "unavailable"
- status_info["error"] = f"STT test failed: {e!s}"
- elif provider_capability_type == ProviderType.RERANK:
- try:
- assert isinstance(provider, RerankProvider)
- await provider.rerank("Apple", documents=["apple", "banana"])
- status_info["status"] = "available"
- except Exception as e:
- logger.error(
- f"Error testing rerank provider {provider_name}: {e}",
- exc_info=True,
- )
- status_info["status"] = "unavailable"
- status_info["error"] = f"Rerank test failed: {e!s}"
-
- else:
- logger.debug(
- f"Provider {provider_name} is not a Chat Completion or Embedding provider. Marking as available without test. Meta: {meta}",
- )
+ try:
+ await provider.test()
status_info["status"] = "available"
- status_info["error"] = (
- "This provider type is not tested and is assumed to be available."
+ logger.info(
+ f"Provider {status_info['name']} (ID: {status_info['id']}) is available.",
+ )
+ except Exception as e:
+ error_message = str(e)
+ status_info["error"] = error_message
+ logger.warning(
+ f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}",
+ )
+ logger.debug(
+ f"Traceback for {status_info['name']}:\n{traceback.format_exc()}",
)
return status_info
@@ -598,9 +446,15 @@ class ConfigRoute(Route):
return Response().error("缺少参数 provider_id").__dict__
prov_mgr = self.core_lifecycle.provider_manager
- provider: Provider | None = prov_mgr.inst_map.get(provider_id, None)
+ provider = prov_mgr.inst_map.get(provider_id, None)
if not provider:
return Response().error(f"未找到 ID 为 {provider_id} 的提供商").__dict__
+ if not isinstance(provider, Provider):
+ return (
+ Response()
+ .error(f"提供商 {provider_id} 类型不支持获取模型列表")
+ .__dict__
+ )
try:
models = await provider.get_models()
@@ -651,9 +505,9 @@ class ConfigRoute(Route):
if not isinstance(inst, EmbeddingProvider):
return Response().error("提供商不是 EmbeddingProvider 类型").__dict__
- # 初始化
- if getattr(inst, "initialize", None):
- await inst.initialize()
+ init_fn = getattr(inst, "initialize", None)
+ if inspect.iscoroutinefunction(init_fn):
+ await init_fn()
# 获取嵌入向量维度
vec = await inst.get_embedding("echo")
@@ -703,6 +557,10 @@ class ConfigRoute(Route):
async def post_new_platform(self):
new_platform_config = await request.json
+
+ # 如果是支持统一 webhook 模式的平台,生成 webhook_uuid
+ ensure_platform_webhook_config(new_platform_config)
+
self.config["platform"].append(new_platform_config)
try:
save_config(self.config, self.config, is_core=True)
@@ -732,6 +590,9 @@ class ConfigRoute(Route):
if not platform_id or not new_config:
return Response().error("参数错误").__dict__
+ # 如果是支持统一 webhook 模式的平台,且启用了统一 webhook 模式,确保有 webhook_uuid
+ ensure_platform_webhook_config(new_config)
+
for i, platform in enumerate(self.config["platform"]):
if platform["id"] == platform_id:
self.config["platform"][i] = new_config
@@ -906,7 +767,7 @@ class ConfigRoute(Route):
return {"metadata": CONFIG_METADATA_2, "config": config}
async def _get_plugin_config(self, plugin_name: str):
- ret = {"metadata": None, "config": None}
+ ret: dict = {"metadata": None, "config": None}
for plugin_md in star_registry:
if plugin_md.name == plugin_name:
diff --git a/astrbot/dashboard/routes/conversation.py b/astrbot/dashboard/routes/conversation.py
index d19fdf793..513d3603f 100644
--- a/astrbot/dashboard/routes/conversation.py
+++ b/astrbot/dashboard/routes/conversation.py
@@ -1,7 +1,9 @@
import json
import traceback
+from datetime import datetime
+from io import BytesIO
-from quart import request
+from quart import request, send_file
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
@@ -30,6 +32,7 @@ class ConversationRoute(Route):
"POST",
self.update_history,
),
+ "/conversation/export": ("POST", self.export_conversations),
}
self.db_helper = db_helper
self.conv_mgr = core_lifecycle.conversation_manager
@@ -283,3 +286,90 @@ class ConversationRoute(Route):
except Exception as e:
logger.error(f"更新对话历史失败: {e!s}\n{traceback.format_exc()}")
return Response().error(f"更新对话历史失败: {e!s}").__dict__
+
+ async def export_conversations(self):
+ """批量导出对话为 JSONL 格式"""
+ try:
+ data = await request.get_json()
+ conversations_to_export = data.get("conversations", [])
+
+ if not conversations_to_export:
+ return Response().error("导出列表不能为空").__dict__
+
+ # 收集所有对话的内容
+ jsonl_lines = []
+ exported_count = 0
+ failed_items = []
+
+ for conv_info in conversations_to_export:
+ user_id = conv_info.get("user_id")
+ cid = conv_info.get("cid")
+
+ if not user_id or not cid:
+ failed_items.append(
+ f"user_id:{user_id}, cid:{cid} - 缺少必要参数",
+ )
+ continue
+
+ try:
+ conversation = await self.conv_mgr.get_conversation(
+ unified_msg_origin=user_id,
+ conversation_id=cid,
+ )
+
+ if not conversation:
+ failed_items.append(
+ f"user_id:{user_id}, cid:{cid} - 对话不存在"
+ )
+ continue
+
+ # 解析对话内容 (history is always a JSON string from _convert_conv_from_v2_to_v1)
+ content = json.loads(conversation.history)
+
+ # 创建导出记录
+ export_record = {
+ "cid": cid,
+ "user_id": user_id,
+ "platform_id": conversation.platform_id,
+ "title": conversation.title,
+ "persona_id": conversation.persona_id,
+ "created_at": conversation.created_at,
+ "updated_at": conversation.updated_at,
+ "content": content,
+ }
+
+ # 将记录转换为 JSON 字符串并添加到 JSONL
+ jsonl_lines.append(json.dumps(export_record, ensure_ascii=False))
+ exported_count += 1
+
+ except Exception as e:
+ failed_items.append(f"user_id:{user_id}, cid:{cid} - {e!s}")
+ logger.error(
+ f"导出对话失败: user_id={user_id}, cid={cid}, error={e!s}"
+ )
+
+ if exported_count == 0:
+ return Response().error("没有成功导出任何对话").__dict__
+
+ # 创建 JSONL 内容
+ jsonl_content = "\n".join(jsonl_lines)
+
+ # 创建一个内存文件对象
+ file_obj = BytesIO(jsonl_content.encode("utf-8"))
+ file_obj.seek(0)
+
+ # 生成文件名
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ filename = f"astrbot_conversations_export_{timestamp}.jsonl"
+
+ # 返回文件流
+ return await send_file(
+ file_obj,
+ mimetype="application/jsonl",
+ as_attachment=True,
+ attachment_filename=filename,
+ )
+
+ except Exception as e:
+ logger.error(f"批量导出对话失败: {e!s}\n{traceback.format_exc()}")
+ return Response().error(f"批量导出对话失败: {e!s}").__dict__
diff --git a/astrbot/dashboard/routes/knowledge_base.py b/astrbot/dashboard/routes/knowledge_base.py
index b4e21382a..537a81f0b 100644
--- a/astrbot/dashboard/routes/knowledge_base.py
+++ b/astrbot/dashboard/routes/knowledge_base.py
@@ -48,6 +48,8 @@ class KnowledgeBaseRoute(Route):
# 文档管理
"/kb/document/list": ("GET", self.list_documents),
"/kb/document/upload": ("POST", self.upload_document),
+ "/kb/document/import": ("POST", self.import_documents),
+ "/kb/document/upload/url": ("POST", self.upload_document_from_url),
"/kb/document/upload/progress": ("GET", self.get_upload_progress),
"/kb/document/get": ("GET", self.get_document),
"/kb/document/delete": ("POST", self.delete_document),
@@ -59,16 +61,71 @@ class KnowledgeBaseRoute(Route):
# "/kb/media/delete": ("POST", self.delete_media),
# 检索
"/kb/retrieve": ("POST", self.retrieve),
- # 会话知识库配置
- "/kb/session/config/get": ("GET", self.get_session_kb_config),
- "/kb/session/config/set": ("POST", self.set_session_kb_config),
- "/kb/session/config/delete": ("POST", self.delete_session_kb_config),
}
self.register_routes()
def _get_kb_manager(self):
return self.core_lifecycle.kb_manager
+ def _init_task(self, task_id: str, status: str = "pending") -> None:
+ self.upload_tasks[task_id] = {
+ "status": status,
+ "result": None,
+ "error": None,
+ }
+
+ def _set_task_result(
+ self, task_id: str, status: str, result: any = None, error: str | None = None
+ ) -> None:
+ self.upload_tasks[task_id] = {
+ "status": status,
+ "result": result,
+ "error": error,
+ }
+ if task_id in self.upload_progress:
+ self.upload_progress[task_id]["status"] = status
+
+ def _update_progress(
+ self,
+ task_id: str,
+ *,
+ status: str | None = None,
+ file_index: int | None = None,
+ file_name: str | None = None,
+ stage: str | None = None,
+ current: int | None = None,
+ total: int | None = None,
+ ) -> None:
+ if task_id not in self.upload_progress:
+ return
+ p = self.upload_progress[task_id]
+ if status is not None:
+ p["status"] = status
+ if file_index is not None:
+ p["file_index"] = file_index
+ if file_name is not None:
+ p["file_name"] = file_name
+ if stage is not None:
+ p["stage"] = stage
+ if current is not None:
+ p["current"] = current
+ if total is not None:
+ p["total"] = total
+
+ def _make_progress_callback(self, task_id: str, file_idx: int, file_name: str):
+ async def _callback(stage: str, current: int, total: int):
+ self._update_progress(
+ task_id,
+ status="processing",
+ file_index=file_idx,
+ file_name=file_name,
+ stage=stage,
+ current=current,
+ total=total,
+ )
+
+ return _callback
+
async def _background_upload_task(
self,
task_id: str,
@@ -83,11 +140,7 @@ class KnowledgeBaseRoute(Route):
"""后台上传任务"""
try:
# 初始化任务状态
- self.upload_tasks[task_id] = {
- "status": "processing",
- "result": None,
- "error": None,
- }
+ self._init_task(task_id, status="processing")
self.upload_progress[task_id] = {
"status": "processing",
"file_index": 0,
@@ -103,30 +156,20 @@ class KnowledgeBaseRoute(Route):
for file_idx, file_info in enumerate(files_to_upload):
try:
# 更新整体进度
- self.upload_progress[task_id].update(
- {
- "status": "processing",
- "file_index": file_idx,
- "file_name": file_info["file_name"],
- "stage": "parsing",
- "current": 0,
- "total": 100,
- },
+ self._update_progress(
+ task_id,
+ status="processing",
+ file_index=file_idx,
+ file_name=file_info["file_name"],
+ stage="parsing",
+ current=0,
+ total=100,
)
# 创建进度回调函数
- async def progress_callback(stage, current, total):
- if task_id in self.upload_progress:
- self.upload_progress[task_id].update(
- {
- "status": "processing",
- "file_index": file_idx,
- "file_name": file_info["file_name"],
- "stage": stage,
- "current": current,
- "total": total,
- },
- )
+ progress_callback = self._make_progress_callback(
+ task_id, file_idx, file_info["file_name"]
+ )
doc = await kb_helper.upload_document(
file_name=file_info["file_name"],
@@ -157,23 +200,99 @@ class KnowledgeBaseRoute(Route):
"failed_count": len(failed_docs),
}
- self.upload_tasks[task_id] = {
- "status": "completed",
- "result": result,
- "error": None,
- }
- self.upload_progress[task_id]["status"] = "completed"
+ self._set_task_result(task_id, "completed", result=result)
except Exception as e:
logger.error(f"后台上传任务 {task_id} 失败: {e}")
logger.error(traceback.format_exc())
- self.upload_tasks[task_id] = {
- "status": "failed",
- "result": None,
- "error": str(e),
+ self._set_task_result(task_id, "failed", error=str(e))
+
+ async def _background_import_task(
+ self,
+ task_id: str,
+ kb_helper,
+ documents: list,
+ batch_size: int,
+ tasks_limit: int,
+ max_retries: int,
+ ):
+ """后台导入预切片文档任务"""
+ try:
+ # 初始化任务状态
+ self._init_task(task_id, status="processing")
+ self.upload_progress[task_id] = {
+ "status": "processing",
+ "file_index": 0,
+ "file_total": len(documents),
+ "stage": "waiting",
+ "current": 0,
+ "total": 100,
}
- if task_id in self.upload_progress:
- self.upload_progress[task_id]["status"] = "failed"
+
+ uploaded_docs = []
+ failed_docs = []
+
+ for file_idx, doc_info in enumerate(documents):
+ file_name = doc_info.get("file_name", f"imported_doc_{file_idx}")
+ chunks = doc_info.get("chunks", [])
+
+ try:
+ # 更新整体进度
+ self._update_progress(
+ task_id,
+ status="processing",
+ file_index=file_idx,
+ file_name=file_name,
+ stage="importing",
+ current=0,
+ total=100,
+ )
+
+ # 创建进度回调函数
+ progress_callback = self._make_progress_callback(
+ task_id, file_idx, file_name
+ )
+
+ # 调用 upload_document,传入 pre_chunked_text
+ doc = await kb_helper.upload_document(
+ file_name=file_name,
+ file_content=None, # 预切片模式下不需要原始内容
+ file_type=doc_info.get("file_type")
+ or (
+ file_name.rsplit(".", 1)[-1].lower()
+ if "." in file_name
+ else "txt"
+ ),
+ batch_size=batch_size,
+ tasks_limit=tasks_limit,
+ max_retries=max_retries,
+ progress_callback=progress_callback,
+ pre_chunked_text=chunks,
+ )
+
+ uploaded_docs.append(doc.model_dump())
+ except Exception as e:
+ logger.error(f"导入文档 {file_name} 失败: {e}")
+ failed_docs.append(
+ {"file_name": file_name, "error": str(e)},
+ )
+
+ # 更新任务完成状态
+ result = {
+ "task_id": task_id,
+ "uploaded": uploaded_docs,
+ "failed": failed_docs,
+ "total": len(documents),
+ "success_count": len(uploaded_docs),
+ "failed_count": len(failed_docs),
+ }
+
+ self._set_task_result(task_id, "completed", result=result)
+
+ except Exception as e:
+ logger.error(f"后台导入任务 {task_id} 失败: {e}")
+ logger.error(traceback.format_exc())
+ self._set_task_result(task_id, "failed", error=str(e))
async def list_kbs(self):
"""获取知识库列表
@@ -277,7 +396,7 @@ class KnowledgeBaseRoute(Route):
except Exception as e:
return (
Response()
- .error(f"测试重排序模型失败: {e!s},请检查控制台日志输出。")
+ .error(f"测试重排序模型失败: {e!s},请检查平台日志输出。")
.__dict__
)
@@ -617,11 +736,7 @@ class KnowledgeBaseRoute(Route):
task_id = str(uuid.uuid4())
# 初始化任务状态
- self.upload_tasks[task_id] = {
- "status": "pending",
- "result": None,
- "error": None,
- }
+ self._init_task(task_id, status="pending")
# 启动后台任务
asyncio.create_task(
@@ -656,6 +771,93 @@ class KnowledgeBaseRoute(Route):
logger.error(traceback.format_exc())
return Response().error(f"上传文档失败: {e!s}").__dict__
+ def _validate_import_request(self, data: dict):
+ kb_id = data.get("kb_id")
+ if not kb_id:
+ raise ValueError("缺少参数 kb_id")
+
+ documents = data.get("documents")
+ if not documents or not isinstance(documents, list):
+ raise ValueError("缺少参数 documents 或格式错误")
+
+ for doc in documents:
+ if "file_name" not in doc or "chunks" not in doc:
+ raise ValueError("文档格式错误,必须包含 file_name 和 chunks")
+ if not isinstance(doc["chunks"], list):
+ raise ValueError("chunks 必须是列表")
+ if not all(
+ isinstance(chunk, str) and chunk.strip() for chunk in doc["chunks"]
+ ):
+ raise ValueError("chunks 必须是非空字符串列表")
+
+ batch_size = data.get("batch_size", 32)
+ tasks_limit = data.get("tasks_limit", 3)
+ max_retries = data.get("max_retries", 3)
+ return kb_id, documents, batch_size, tasks_limit, max_retries
+
+ async def import_documents(self):
+ """导入预切片文档
+
+ Body:
+ - kb_id: 知识库 ID (必填)
+ - documents: 文档列表 (必填)
+ - file_name: 文件名 (必填)
+ - chunks: 切片列表 (必填, list[str])
+ - file_type: 文件类型 (可选, 默认从文件名推断或为 txt)
+ - batch_size: 批处理大小 (可选, 默认32)
+ - tasks_limit: 并发任务限制 (可选, 默认3)
+ - max_retries: 最大重试次数 (可选, 默认3)
+ """
+ try:
+ kb_manager = self._get_kb_manager()
+ data = await request.json
+
+ kb_id, documents, batch_size, tasks_limit, max_retries = (
+ self._validate_import_request(data)
+ )
+
+ # 获取知识库
+ kb_helper = await kb_manager.get_kb(kb_id)
+ if not kb_helper:
+ return Response().error("知识库不存在").__dict__
+
+ # 生成任务ID
+ task_id = str(uuid.uuid4())
+
+ # 初始化任务状态
+ self._init_task(task_id, status="pending")
+
+ # 启动后台任务
+ asyncio.create_task(
+ self._background_import_task(
+ task_id=task_id,
+ kb_helper=kb_helper,
+ documents=documents,
+ batch_size=batch_size,
+ tasks_limit=tasks_limit,
+ max_retries=max_retries,
+ ),
+ )
+
+ return (
+ Response()
+ .ok(
+ {
+ "task_id": task_id,
+ "doc_count": len(documents),
+ "message": "import task created, processing in background",
+ },
+ )
+ .__dict__
+ )
+
+ except ValueError as e:
+ return Response().error(str(e)).__dict__
+ except Exception as e:
+ logger.error(f"导入文档失败: {e}")
+ logger.error(traceback.format_exc())
+ return Response().error(f"导入文档失败: {e!s}").__dict__
+
async def get_upload_progress(self):
"""获取上传进度和结果
@@ -919,154 +1121,143 @@ class KnowledgeBaseRoute(Route):
logger.error(traceback.format_exc())
return Response().error(f"检索失败: {e!s}").__dict__
- # ===== 会话知识库配置 API =====
+ async def upload_document_from_url(self):
+ """从 URL 上传文档
- async def get_session_kb_config(self):
- """获取会话的知识库配置
-
- Query 参数:
- - session_id: 会话 ID (必填)
+ Body:
+ - kb_id: 知识库 ID (必填)
+ - url: 要提取内容的网页 URL (必填)
+ - chunk_size: 分块大小 (可选, 默认512)
+ - chunk_overlap: 块重叠大小 (可选, 默认50)
+ - batch_size: 批处理大小 (可选, 默认32)
+ - tasks_limit: 并发任务限制 (可选, 默认3)
+ - max_retries: 最大重试次数 (可选, 默认3)
返回:
- - kb_ids: 知识库 ID 列表
- - top_k: 返回结果数量
- - enable_rerank: 是否启用重排序
+ - task_id: 任务ID,用于查询上传进度和结果
"""
try:
- from astrbot.core import sp
-
- session_id = request.args.get("session_id")
-
- if not session_id:
- return Response().error("缺少参数 session_id").__dict__
-
- # 从 SharedPreferences 获取配置
- config = await sp.session_get(session_id, "kb_config", default={})
-
- logger.debug(f"[KB配置] 读取到配置: session_id={session_id}")
-
- # 如果没有配置,返回默认值
- if not config:
- config = {"kb_ids": [], "top_k": 5, "enable_rerank": True}
-
- return Response().ok(config).__dict__
-
- except Exception as e:
- logger.error(f"[KB配置] 获取配置时出错: {e}", exc_info=True)
- return Response().error(f"获取会话知识库配置失败: {e!s}").__dict__
-
- async def set_session_kb_config(self):
- """设置会话的知识库配置
-
- Body:
- - scope: 配置范围 (目前只支持 "session")
- - scope_id: 会话 ID (必填)
- - kb_ids: 知识库 ID 列表 (必填)
- - top_k: 返回结果数量 (可选, 默认 5)
- - enable_rerank: 是否启用重排序 (可选, 默认 true)
- """
- try:
- from astrbot.core import sp
-
+ kb_manager = self._get_kb_manager()
data = await request.json
- scope = data.get("scope")
- scope_id = data.get("scope_id")
- kb_ids = data.get("kb_ids", [])
- top_k = data.get("top_k", 5)
- enable_rerank = data.get("enable_rerank", True)
+ kb_id = data.get("kb_id")
+ if not kb_id:
+ return Response().error("缺少参数 kb_id").__dict__
- # 验证参数
- if scope != "session":
- return Response().error("目前仅支持 session 范围的配置").__dict__
+ url = data.get("url")
+ if not url:
+ return Response().error("缺少参数 url").__dict__
- if not scope_id:
- return Response().error("缺少参数 scope_id").__dict__
+ chunk_size = data.get("chunk_size", 512)
+ chunk_overlap = data.get("chunk_overlap", 50)
+ batch_size = data.get("batch_size", 32)
+ tasks_limit = data.get("tasks_limit", 3)
+ max_retries = data.get("max_retries", 3)
+ enable_cleaning = data.get("enable_cleaning", False)
+ cleaning_provider_id = data.get("cleaning_provider_id")
- if not isinstance(kb_ids, list):
- return Response().error("kb_ids 必须是列表").__dict__
+ # 获取知识库
+ kb_helper = await kb_manager.get_kb(kb_id)
+ if not kb_helper:
+ return Response().error("知识库不存在").__dict__
- # 验证知识库是否存在
- kb_mgr = self._get_kb_manager()
- invalid_ids = []
- valid_ids = []
- for kb_id in kb_ids:
- kb_helper = await kb_mgr.get_kb(kb_id)
- if kb_helper:
- valid_ids.append(kb_id)
- else:
- invalid_ids.append(kb_id)
- logger.warning(f"[KB配置] 知识库不存在: {kb_id}")
+ # 生成任务ID
+ task_id = str(uuid.uuid4())
- if invalid_ids:
- logger.warning(f"[KB配置] 以下知识库ID无效: {invalid_ids}")
+ # 初始化任务状态
+ self._init_task(task_id, status="pending")
- # 允许保存空列表,表示明确不使用任何知识库
- if kb_ids and not valid_ids:
- # 只有当用户提供了 kb_ids 但全部无效时才报错
- return Response().error(f"所有提供的知识库ID都无效: {kb_ids}").__dict__
+ # 启动后台任务
+ asyncio.create_task(
+ self._background_upload_from_url_task(
+ task_id=task_id,
+ kb_helper=kb_helper,
+ url=url,
+ chunk_size=chunk_size,
+ chunk_overlap=chunk_overlap,
+ batch_size=batch_size,
+ tasks_limit=tasks_limit,
+ max_retries=max_retries,
+ enable_cleaning=enable_cleaning,
+ cleaning_provider_id=cleaning_provider_id,
+ ),
+ )
- # 如果 kb_ids 为空列表,表示用户想清空配置
- if not kb_ids:
- valid_ids = []
+ return (
+ Response()
+ .ok(
+ {
+ "task_id": task_id,
+ "url": url,
+ "message": "URL upload task created, processing in background",
+ },
+ )
+ .__dict__
+ )
- # 构建配置对象(只保存有效的ID)
- config = {
- "kb_ids": valid_ids,
- "top_k": top_k,
- "enable_rerank": enable_rerank,
+ except ValueError as e:
+ return Response().error(str(e)).__dict__
+ except Exception as e:
+ logger.error(f"从URL上传文档失败: {e}")
+ logger.error(traceback.format_exc())
+ return Response().error(f"从URL上传文档失败: {e!s}").__dict__
+
+ async def _background_upload_from_url_task(
+ self,
+ task_id: str,
+ kb_helper,
+ url: str,
+ chunk_size: int,
+ chunk_overlap: int,
+ batch_size: int,
+ tasks_limit: int,
+ max_retries: int,
+ enable_cleaning: bool,
+ cleaning_provider_id: str | None,
+ ):
+ """后台上传URL任务"""
+ try:
+ # 初始化任务状态
+ self._init_task(task_id, status="processing")
+ self.upload_progress[task_id] = {
+ "status": "processing",
+ "file_index": 0,
+ "file_total": 1,
+ "file_name": f"URL: {url}",
+ "stage": "extracting",
+ "current": 0,
+ "total": 100,
}
- # 保存到 SharedPreferences
- await sp.session_put(scope_id, "kb_config", config)
+ # 创建进度回调函数
+ progress_callback = self._make_progress_callback(task_id, 0, f"URL: {url}")
- # 立即验证是否保存成功
- verify_config = await sp.session_get(scope_id, "kb_config", default={})
+ # 上传文档
+ doc = await kb_helper.upload_from_url(
+ url=url,
+ chunk_size=chunk_size,
+ chunk_overlap=chunk_overlap,
+ batch_size=batch_size,
+ tasks_limit=tasks_limit,
+ max_retries=max_retries,
+ progress_callback=progress_callback,
+ enable_cleaning=enable_cleaning,
+ cleaning_provider_id=cleaning_provider_id,
+ )
- if verify_config == config:
- return (
- Response()
- .ok(
- {"valid_ids": valid_ids, "invalid_ids": invalid_ids},
- "保存知识库配置成功",
- )
- .__dict__
- )
- logger.error("[KB配置] 配置保存失败,验证不匹配")
- return Response().error("配置保存失败").__dict__
+ # 更新任务完成状态
+ result = {
+ "task_id": task_id,
+ "uploaded": [doc.model_dump()],
+ "failed": [],
+ "total": 1,
+ "success_count": 1,
+ "failed_count": 0,
+ }
+
+ self._set_task_result(task_id, "completed", result=result)
except Exception as e:
- logger.error(f"[KB配置] 设置配置时出错: {e}", exc_info=True)
- return Response().error(f"设置会话知识库配置失败: {e!s}").__dict__
-
- async def delete_session_kb_config(self):
- """删除会话的知识库配置
-
- Body:
- - scope: 配置范围 (目前只支持 "session")
- - scope_id: 会话 ID (必填)
- """
- try:
- from astrbot.core import sp
-
- data = await request.json
-
- scope = data.get("scope")
- scope_id = data.get("scope_id")
-
- # 验证参数
- if scope != "session":
- return Response().error("目前仅支持 session 范围的配置").__dict__
-
- if not scope_id:
- return Response().error("缺少参数 scope_id").__dict__
-
- # 从 SharedPreferences 删除配置
- await sp.session_remove(scope_id, "kb_config")
-
- return Response().ok(message="删除知识库配置成功").__dict__
-
- except Exception as e:
- logger.error(f"删除会话知识库配置失败: {e}")
+ logger.error(f"后台上传URL任务 {task_id} 失败: {e}")
logger.error(traceback.format_exc())
- return Response().error(f"删除会话知识库配置失败: {e!s}").__dict__
+ self._set_task_result(task_id, "failed", error=str(e))
diff --git a/astrbot/dashboard/routes/log.py b/astrbot/dashboard/routes/log.py
index eb02fdf40..86cc8c6ca 100644
--- a/astrbot/dashboard/routes/log.py
+++ b/astrbot/dashboard/routes/log.py
@@ -1,6 +1,8 @@
import asyncio
import json
+from typing import cast
+from quart import Response as QuartResponse
from quart import make_response
from astrbot.core import LogBroker, logger
@@ -39,14 +41,17 @@ class LogRoute(Route):
if queue:
self.log_broker.unregister(queue)
- response = await make_response(
- stream(),
- {
- "Content-Type": "text/event-stream",
- "Cache-Control": "no-cache",
- "Connection": "keep-alive",
- "Transfer-Encoding": "chunked",
- },
+ response = cast(
+ QuartResponse,
+ await make_response(
+ stream(),
+ {
+ "Content-Type": "text/event-stream",
+ "Cache-Control": "no-cache",
+ "Connection": "keep-alive",
+ "Transfer-Encoding": "chunked",
+ },
+ ),
)
response.timeout = None
return response
diff --git a/astrbot/dashboard/routes/platform.py b/astrbot/dashboard/routes/platform.py
new file mode 100644
index 000000000..4d8fdddfe
--- /dev/null
+++ b/astrbot/dashboard/routes/platform.py
@@ -0,0 +1,100 @@
+"""统一 Webhook 路由
+
+提供统一的 webhook 回调入口,支持多个平台使用同一端口接收回调。
+"""
+
+from quart import request
+
+from astrbot.core import logger
+from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
+from astrbot.core.platform import Platform
+
+from .route import Response, Route, RouteContext
+
+
+class PlatformRoute(Route):
+ """统一 Webhook 路由"""
+
+ def __init__(
+ self,
+ context: RouteContext,
+ core_lifecycle: AstrBotCoreLifecycle,
+ ) -> None:
+ super().__init__(context)
+ self.core_lifecycle = core_lifecycle
+ self.platform_manager = core_lifecycle.platform_manager
+
+ self._register_webhook_routes()
+
+ def _register_webhook_routes(self):
+ """注册 webhook 路由"""
+ # 统一 webhook 入口,支持 GET 和 POST
+ self.app.add_url_rule(
+ "/api/platform/webhook/",
+ view_func=self.unified_webhook_callback,
+ methods=["GET", "POST"],
+ )
+
+ # 平台统计信息接口
+ self.app.add_url_rule(
+ "/api/platform/stats",
+ view_func=self.get_platform_stats,
+ methods=["GET"],
+ )
+
+ async def unified_webhook_callback(self, webhook_uuid: str):
+ """统一 webhook 回调入口
+
+ Args:
+ webhook_uuid: 平台配置中的 webhook_uuid
+
+ Returns:
+ 根据平台适配器返回相应的响应
+ """
+ # 根据 webhook_uuid 查找对应的平台
+ platform_adapter = self._find_platform_by_uuid(webhook_uuid)
+
+ if not platform_adapter:
+ logger.warning(f"未找到 webhook_uuid 为 {webhook_uuid} 的平台")
+ return Response().error("未找到对应平台").__dict__, 404
+
+ # 调用平台适配器的 webhook_callback 方法
+ try:
+ result = await platform_adapter.webhook_callback(request)
+ return result
+ except NotImplementedError:
+ logger.error(
+ f"平台 {platform_adapter.meta().name} 未实现 webhook_callback 方法"
+ )
+ return Response().error("平台未支持统一 Webhook 模式").__dict__, 500
+ except Exception as e:
+ logger.error(f"处理 webhook 回调时发生错误: {e}", exc_info=True)
+ return Response().error("处理回调失败").__dict__, 500
+
+ def _find_platform_by_uuid(self, webhook_uuid: str) -> Platform | None:
+ """根据 webhook_uuid 查找对应的平台适配器
+
+ Args:
+ webhook_uuid: webhook UUID
+
+ Returns:
+ 平台适配器实例,未找到则返回 None
+ """
+ for platform in self.platform_manager.platform_insts:
+ if platform.config.get("webhook_uuid") == webhook_uuid:
+ if platform.unified_webhook():
+ return platform
+ return None
+
+ async def get_platform_stats(self):
+ """获取所有平台的统计信息
+
+ Returns:
+ 包含平台统计信息的响应
+ """
+ try:
+ stats = self.platform_manager.get_all_stats()
+ return Response().ok(stats).__dict__
+ except Exception as e:
+ logger.error(f"获取平台统计信息失败: {e}", exc_info=True)
+ return Response().error(f"获取统计信息失败: {e}").__dict__, 500
diff --git a/astrbot/dashboard/routes/plugin.py b/astrbot/dashboard/routes/plugin.py
index 597a245d4..fd808c6c9 100644
--- a/astrbot/dashboard/routes/plugin.py
+++ b/astrbot/dashboard/routes/plugin.py
@@ -1,13 +1,17 @@
+import asyncio
+import hashlib
import json
import os
import ssl
import traceback
+from dataclasses import dataclass
from datetime import datetime
import aiohttp
import certifi
from quart import request
+from astrbot.api import sp
from astrbot.core import DEMO_MODE, file_token_service, logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.star.filter.command import CommandFilter
@@ -19,6 +23,17 @@ from astrbot.core.star.star_manager import PluginManager
from .route import Response, Route, RouteContext
+PLUGIN_UPDATE_CONCURRENCY = (
+ 3 # limit concurrent updates to avoid overwhelming plugin sources
+)
+
+
+@dataclass
+class RegistrySource:
+ urls: list[str]
+ cache_file: str
+ md5_url: str | None # None means "no remote MD5, always treat cache as stale"
+
class PluginRoute(Route):
def __init__(
@@ -33,12 +48,15 @@ class PluginRoute(Route):
"/plugin/install": ("POST", self.install_plugin),
"/plugin/install-upload": ("POST", self.install_plugin_upload),
"/plugin/update": ("POST", self.update_plugin),
+ "/plugin/update-all": ("POST", self.update_all_plugins),
"/plugin/uninstall": ("POST", self.uninstall_plugin),
"/plugin/market_list": ("GET", self.get_online_plugins),
"/plugin/off": ("POST", self.off_plugin),
"/plugin/on": ("POST", self.on_plugin),
"/plugin/reload": ("POST", self.reload_plugins),
"/plugin/readme": ("GET", self.get_plugin_readme),
+ "/plugin/source/get": ("GET", self.get_custom_source),
+ "/plugin/source/save": ("POST", self.save_custom_source),
}
self.core_lifecycle = core_lifecycle
self.plugin_manager = plugin_manager
@@ -63,7 +81,7 @@ class PluginRoute(Route):
.__dict__
)
- data = await request.json
+ data = await request.get_json()
plugin_name = data.get("name", None)
try:
success, message = await self.plugin_manager.reload(plugin_name)
@@ -78,22 +96,15 @@ class PluginRoute(Route):
custom = request.args.get("custom_registry")
force_refresh = request.args.get("force_refresh", "false").lower() == "true"
- cache_file = "data/plugins.json"
-
- if custom:
- urls = [custom]
- else:
- urls = [
- "https://api.soulter.top/astrbot/plugins",
- "https://github.com/AstrBotDevs/AstrBot_Plugins_Collection/raw/refs/heads/main/plugin_cache_original.json",
- ]
+ # 构建注册表源信息
+ source = self._build_registry_source(custom)
# 如果不是强制刷新,先检查缓存是否有效
cached_data = None
if not force_refresh:
# 先检查MD5是否匹配,如果匹配则使用缓存
- if await self._is_cache_valid(cache_file):
- cached_data = self._load_plugin_cache(cache_file)
+ if await self._is_cache_valid(source):
+ cached_data = self._load_plugin_cache(source.cache_file)
if cached_data:
logger.debug("缓存MD5匹配,使用缓存的插件市场数据")
return Response().ok(cached_data).__dict__
@@ -103,7 +114,7 @@ class PluginRoute(Route):
ssl_context = ssl.create_default_context(cafile=certifi.where())
connector = aiohttp.TCPConnector(ssl=ssl_context)
- for url in urls:
+ for url in source.urls:
try:
async with (
aiohttp.ClientSession(
@@ -113,7 +124,11 @@ class PluginRoute(Route):
session.get(url) as response,
):
if response.status == 200:
- remote_data = await response.json()
+ try:
+ remote_data = await response.json()
+ except aiohttp.ContentTypeError:
+ remote_text = await response.text()
+ remote_data = json.loads(remote_text)
# 检查远程数据是否为空
if not remote_data or (
@@ -122,11 +137,13 @@ class PluginRoute(Route):
logger.warning(f"远程插件市场数据为空: {url}")
continue # 继续尝试其他URL或使用缓存
- logger.info("成功获取远程插件市场数据")
+ logger.info(
+ f"成功获取远程插件市场数据,包含 {len(remote_data)} 个插件"
+ )
# 获取最新的MD5并保存到缓存
- current_md5 = await self._get_remote_md5()
+ current_md5 = await self._fetch_remote_md5(source.md5_url)
self._save_plugin_cache(
- cache_file,
+ source.cache_file,
remote_data,
current_md5,
)
@@ -137,7 +154,7 @@ class PluginRoute(Route):
# 如果远程获取失败,尝试使用缓存数据
if not cached_data:
- cached_data = self._load_plugin_cache(cache_file)
+ cached_data = self._load_plugin_cache(source.cache_file)
if cached_data:
logger.warning("远程插件市场数据获取失败,使用缓存数据")
@@ -145,24 +162,75 @@ class PluginRoute(Route):
return Response().error("获取插件列表失败,且没有可用的缓存数据").__dict__
- async def _is_cache_valid(self, cache_file: str) -> bool:
- """检查缓存是否有效(基于MD5)"""
- try:
- if not os.path.exists(cache_file):
- return False
+ def _build_registry_source(self, custom_url: str | None) -> RegistrySource:
+ """构建注册表源信息"""
+ if custom_url:
+ # 对自定义URL生成一个安全的文件名
+ url_hash = hashlib.md5(custom_url.encode()).hexdigest()[:8]
+ cache_file = f"data/plugins_custom_{url_hash}.json"
- # 加载缓存文件
+ # 更安全的后缀处理方式
+ if custom_url.endswith(".json"):
+ md5_url = custom_url[:-5] + "-md5.json"
+ else:
+ md5_url = custom_url + "-md5.json"
+
+ urls = [custom_url]
+ else:
+ cache_file = "data/plugins.json"
+ md5_url = "https://api.soulter.top/astrbot/plugins-md5"
+ urls = [
+ "https://api.soulter.top/astrbot/plugins",
+ "https://github.com/AstrBotDevs/AstrBot_Plugins_Collection/raw/refs/heads/main/plugin_cache_original.json",
+ ]
+ return RegistrySource(urls=urls, cache_file=cache_file, md5_url=md5_url)
+
+ def _load_cached_md5(self, cache_file: str) -> str | None:
+ """从缓存文件中加载MD5"""
+ if not os.path.exists(cache_file):
+ return None
+
+ try:
with open(cache_file, encoding="utf-8") as f:
cache_data = json.load(f)
+ return cache_data.get("md5")
+ except Exception as e:
+ logger.warning(f"加载缓存MD5失败: {e}")
+ return None
- cached_md5 = cache_data.get("md5")
+ async def _fetch_remote_md5(self, md5_url: str | None) -> str | None:
+ """获取远程MD5"""
+ if not md5_url:
+ return None
+
+ try:
+ ssl_context = ssl.create_default_context(cafile=certifi.where())
+ connector = aiohttp.TCPConnector(ssl=ssl_context)
+
+ async with (
+ aiohttp.ClientSession(
+ trust_env=True,
+ connector=connector,
+ ) as session,
+ session.get(md5_url) as response,
+ ):
+ if response.status == 200:
+ data = await response.json()
+ return data.get("md5", "")
+ except Exception as e:
+ logger.debug(f"获取远程MD5失败: {e}")
+ return None
+
+ async def _is_cache_valid(self, source: RegistrySource) -> bool:
+ """检查缓存是否有效(基于MD5)"""
+ try:
+ cached_md5 = self._load_cached_md5(source.cache_file)
if not cached_md5:
logger.debug("缓存文件中没有MD5信息")
return False
- # 获取远程MD5
- remote_md5 = await self._get_remote_md5()
- if not remote_md5:
+ remote_md5 = await self._fetch_remote_md5(source.md5_url)
+ if remote_md5 is None:
logger.warning("无法获取远程MD5,将使用缓存")
return True # 如果无法获取远程MD5,认为缓存有效
@@ -176,30 +244,6 @@ class PluginRoute(Route):
logger.warning(f"检查缓存有效性失败: {e}")
return False
- async def _get_remote_md5(self) -> str:
- """获取远程插件数据的MD5"""
- try:
- ssl_context = ssl.create_default_context(cafile=certifi.where())
- connector = aiohttp.TCPConnector(ssl=ssl_context)
-
- async with (
- aiohttp.ClientSession(
- trust_env=True,
- connector=connector,
- ) as session,
- session.get(
- "https://api.soulter.top/astrbot/plugins-md5",
- ) as response,
- ):
- if response.status == 200:
- data = await response.json()
- return data.get("md5", "")
- logger.error(f"获取MD5失败,状态码:{response.status}")
- return ""
- except Exception as e:
- logger.error(f"获取远程MD5失败: {e}")
- return ""
-
def _load_plugin_cache(self, cache_file: str):
"""加载本地缓存的插件市场数据"""
try:
@@ -346,7 +390,7 @@ class PluginRoute(Route):
.__dict__
)
- post_data = await request.json
+ post_data = await request.get_json()
repo_url = post_data["url"]
proxy: str = post_data.get("proxy", None)
@@ -393,7 +437,7 @@ class PluginRoute(Route):
.__dict__
)
- post_data = await request.json
+ post_data = await request.get_json()
plugin_name = post_data["name"]
delete_config = post_data.get("delete_config", False)
delete_data = post_data.get("delete_data", False)
@@ -418,7 +462,7 @@ class PluginRoute(Route):
.__dict__
)
- post_data = await request.json
+ post_data = await request.get_json()
plugin_name = post_data["name"]
proxy: str = post_data.get("proxy", None)
try:
@@ -432,6 +476,59 @@ class PluginRoute(Route):
logger.error(f"/api/plugin/update: {traceback.format_exc()}")
return Response().error(str(e)).__dict__
+ async def update_all_plugins(self):
+ if DEMO_MODE:
+ return (
+ Response()
+ .error("You are not permitted to do this operation in demo mode")
+ .__dict__
+ )
+
+ post_data = await request.get_json()
+ plugin_names: list[str] = post_data.get("names") or []
+ proxy: str = post_data.get("proxy", "")
+
+ if not isinstance(plugin_names, list) or not plugin_names:
+ return Response().error("插件列表不能为空").__dict__
+
+ results = []
+ sem = asyncio.Semaphore(PLUGIN_UPDATE_CONCURRENCY)
+
+ async def _update_one(name: str):
+ async with sem:
+ try:
+ logger.info(f"批量更新插件 {name}")
+ await self.plugin_manager.update_plugin(name, proxy)
+ return {"name": name, "status": "ok", "message": "更新成功"}
+ except Exception as e:
+ logger.error(
+ f"/api/plugin/update-all: 更新插件 {name} 失败: {traceback.format_exc()}",
+ )
+ return {"name": name, "status": "error", "message": str(e)}
+
+ raw_results = await asyncio.gather(
+ *(_update_one(name) for name in plugin_names),
+ return_exceptions=True,
+ )
+ for name, result in zip(plugin_names, raw_results):
+ if isinstance(result, asyncio.CancelledError):
+ raise result
+ if isinstance(result, BaseException):
+ results.append(
+ {"name": name, "status": "error", "message": str(result)}
+ )
+ else:
+ results.append(result)
+
+ failed = [r for r in results if r["status"] == "error"]
+ message = (
+ "批量更新完成,全部成功。"
+ if not failed
+ else f"批量更新完成,其中 {len(failed)}/{len(results)} 个插件失败。"
+ )
+
+ return Response().ok({"results": results}, message).__dict__
+
async def off_plugin(self):
if DEMO_MODE:
return (
@@ -440,7 +537,7 @@ class PluginRoute(Route):
.__dict__
)
- post_data = await request.json
+ post_data = await request.get_json()
plugin_name = post_data["name"]
try:
await self.plugin_manager.turn_off_plugin(plugin_name)
@@ -458,7 +555,7 @@ class PluginRoute(Route):
.__dict__
)
- post_data = await request.json
+ post_data = await request.get_json()
plugin_name = post_data["name"]
try:
await self.plugin_manager.turn_on_plugin(plugin_name)
@@ -486,9 +583,13 @@ class PluginRoute(Route):
logger.warning(f"插件 {plugin_name} 不存在")
return Response().error(f"插件 {plugin_name} 不存在").__dict__
+ if not plugin_obj.root_dir_name:
+ logger.warning(f"插件 {plugin_name} 目录不存在")
+ return Response().error(f"插件 {plugin_name} 目录不存在").__dict__
+
plugin_dir = os.path.join(
self.plugin_manager.plugin_store_path,
- plugin_obj.root_dir_name,
+ plugin_obj.root_dir_name or "",
)
if not os.path.isdir(plugin_dir):
@@ -513,3 +614,22 @@ class PluginRoute(Route):
except Exception as e:
logger.error(f"/api/plugin/readme: {traceback.format_exc()}")
return Response().error(f"读取README文件失败: {e!s}").__dict__
+
+ async def get_custom_source(self):
+ """获取自定义插件源"""
+ sources = await sp.global_get("custom_plugin_sources", [])
+ return Response().ok(sources).__dict__
+
+ async def save_custom_source(self):
+ """保存自定义插件源"""
+ try:
+ data = await request.get_json()
+ sources = data.get("sources", [])
+ if not isinstance(sources, list):
+ return Response().error("sources fields must be a list").__dict__
+
+ await sp.global_put("custom_plugin_sources", sources)
+ return Response().ok(None, "保存成功").__dict__
+ except Exception as e:
+ logger.error(f"/api/plugin/source/save: {traceback.format_exc()}")
+ return Response().error(str(e)).__dict__
diff --git a/astrbot/dashboard/routes/route.py b/astrbot/dashboard/routes/route.py
index 1105b69a7..01ab292d4 100644
--- a/astrbot/dashboard/routes/route.py
+++ b/astrbot/dashboard/routes/route.py
@@ -12,6 +12,8 @@ class RouteContext:
class Route:
+ routes: list | dict
+
def __init__(self, context: RouteContext):
self.app = context.app
self.config = context.config
diff --git a/astrbot/dashboard/routes/session_management.py b/astrbot/dashboard/routes/session_management.py
index 0b16c0949..a938d662d 100644
--- a/astrbot/dashboard/routes/session_management.py
+++ b/astrbot/dashboard/routes/session_management.py
@@ -1,16 +1,24 @@
-import traceback
-
from quart import request
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlmodel import col, select
from astrbot.core import logger, sp
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
+from astrbot.core.db.po import ConversationV2, Preference
from astrbot.core.provider.entities import ProviderType
-from astrbot.core.star.session_llm_manager import SessionServiceManager
-from astrbot.core.star.session_plugin_manager import SessionPluginManager
from .route import Response, Route, RouteContext
+AVAILABLE_SESSION_RULE_KEYS = [
+ "session_service_config",
+ "session_plugin_config",
+ "kb_config",
+ f"provider_perf_{ProviderType.CHAT_COMPLETION.value}",
+ f"provider_perf_{ProviderType.SPEECH_TO_TEXT.value}",
+ f"provider_perf_{ProviderType.TEXT_TO_SPEECH.value}",
+]
+
class SessionManagementRoute(Route):
def __init__(
@@ -22,667 +30,364 @@ class SessionManagementRoute(Route):
super().__init__(context)
self.db_helper = db_helper
self.routes = {
- "/session/list": ("GET", self.list_sessions),
- "/session/update_persona": ("POST", self.update_session_persona),
- "/session/update_provider": ("POST", self.update_session_provider),
- "/session/plugins": ("GET", self.get_session_plugins),
- "/session/update_plugin": ("POST", self.update_session_plugin),
- "/session/update_llm": ("POST", self.update_session_llm),
- "/session/update_tts": ("POST", self.update_session_tts),
- "/session/update_name": ("POST", self.update_session_name),
- "/session/update_status": ("POST", self.update_session_status),
- "/session/delete": ("POST", self.delete_session),
+ "/session/list-rule": ("GET", self.list_session_rule),
+ "/session/update-rule": ("POST", self.update_session_rule),
+ "/session/delete-rule": ("POST", self.delete_session_rule),
+ "/session/batch-delete-rule": ("POST", self.batch_delete_session_rule),
+ "/session/active-umos": ("GET", self.list_umos),
}
self.conv_mgr = core_lifecycle.conversation_manager
self.core_lifecycle = core_lifecycle
self.register_routes()
- async def list_sessions(self):
- """获取所有会话的列表,包括 persona 和 provider 信息"""
- try:
- page = int(request.args.get("page", 1))
- page_size = int(request.args.get("page_size", 20))
- search_query = request.args.get("search", "")
- platform = request.args.get("platform", "")
+ async def _get_umo_rules(
+ self, page: int = 1, page_size: int = 10, search: str = ""
+ ) -> tuple[dict, int]:
+ """获取所有带有自定义规则的 umo 及其规则内容(支持分页和搜索)。
- # 获取活跃的会话数据(处于对话内的会话)
- sessions_data, total = await self.db_helper.get_session_conversations(
- page,
- page_size,
- search_query,
- platform,
+ 如果某个 umo 在 preference 中有以下字段,则表示有自定义规则:
+
+ 1. session_service_config (包含了 是否启用这个umo, 这个umo是否启用 llm, 这个umo是否启用tts, umo自定义名称。)
+ 2. session_plugin_config (包含了 这个 umo 的 plugin set)
+ 3. provider_perf_{ProviderType.value} (包含了这个 umo 所选择使用的 provider 信息)
+ 4. kb_config (包含了这个 umo 的知识库相关配置)
+
+ Args:
+ page: 页码,从 1 开始
+ page_size: 每页数量
+ search: 搜索关键词,匹配 umo 或 custom_name
+
+ Returns:
+ tuple[dict, int]: (umo_rules, total) - 分页后的 umo 规则和总数
+ """
+ umo_rules = {}
+ async with self.db_helper.get_db() as session:
+ session: AsyncSession
+ result = await session.execute(
+ select(Preference).where(
+ col(Preference.scope) == "umo",
+ col(Preference.key).in_(AVAILABLE_SESSION_RULE_KEYS),
+ )
+ )
+ prefs = result.scalars().all()
+ for pref in prefs:
+ umo_id = pref.scope_id
+ if umo_id not in umo_rules:
+ umo_rules[umo_id] = {}
+ if pref.key == "session_plugin_config" and umo_id in pref.value["val"]:
+ umo_rules[umo_id][pref.key] = pref.value["val"][umo_id]
+ else:
+ umo_rules[umo_id][pref.key] = pref.value["val"]
+
+ # 搜索过滤
+ if search:
+ search_lower = search.lower()
+ filtered_rules = {}
+ for umo_id, rules in umo_rules.items():
+ # 匹配 umo
+ if search_lower in umo_id.lower():
+ filtered_rules[umo_id] = rules
+ continue
+ # 匹配 custom_name
+ svc_config = rules.get("session_service_config", {})
+ custom_name = svc_config.get("custom_name", "") if svc_config else ""
+ if custom_name and search_lower in custom_name.lower():
+ filtered_rules[umo_id] = rules
+ umo_rules = filtered_rules
+
+ # 获取总数
+ total = len(umo_rules)
+
+ # 分页处理
+ all_umo_ids = list(umo_rules.keys())
+ start_idx = (page - 1) * page_size
+ end_idx = start_idx + page_size
+ paginated_umo_ids = all_umo_ids[start_idx:end_idx]
+
+ # 只返回分页后的数据
+ paginated_rules = {umo_id: umo_rules[umo_id] for umo_id in paginated_umo_ids}
+
+ return paginated_rules, total
+
+ async def list_session_rule(self):
+ """获取所有自定义的规则(支持分页和搜索)
+
+ 返回已配置规则的 umo 列表及其规则内容,以及可用的 personas 和 providers
+
+ Query 参数:
+ page: 页码,默认为 1
+ page_size: 每页数量,默认为 10
+ search: 搜索关键词,匹配 umo 或 custom_name
+ """
+ try:
+ # 获取分页和搜索参数
+ page = request.args.get("page", 1, type=int)
+ page_size = request.args.get("page_size", 10, type=int)
+ search = request.args.get("search", "", type=str).strip()
+
+ # 参数校验
+ if page < 1:
+ page = 1
+ if page_size < 1:
+ page_size = 10
+ if page_size > 100:
+ page_size = 100
+
+ umo_rules, total = await self._get_umo_rules(
+ page=page, page_size=page_size, search=search
)
+ # 构建规则列表
+ rules_list = []
+ for umo, rules in umo_rules.items():
+ rule_info = {
+ "umo": umo,
+ "rules": rules,
+ }
+ # 解析 umo 格式: 平台:消息类型:会话ID
+ parts = umo.split(":")
+ if len(parts) >= 3:
+ rule_info["platform"] = parts[0]
+ rule_info["message_type"] = parts[1]
+ rule_info["session_id"] = parts[2]
+ rules_list.append(rule_info)
+
+ # 获取可用的 providers 和 personas
provider_manager = self.core_lifecycle.provider_manager
persona_mgr = self.core_lifecycle.persona_mgr
- personas = persona_mgr.personas_v3
- sessions = []
-
- # 循环补充非数据库信息,如 provider 和 session 状态
- for data in sessions_data:
- session_id = data["session_id"]
- conversation_id = data["conversation_id"]
- conv_persona_id = data["persona_id"]
- title = data["title"]
- persona_name = data["persona_name"]
-
- # 处理 persona 显示
- if persona_name is None:
- if conv_persona_id is None:
- if default_persona := persona_mgr.selected_default_persona_v3:
- persona_name = default_persona["name"]
- else:
- persona_name = "[%None]"
-
- session_info = {
- "session_id": session_id,
- "conversation_id": conversation_id,
- "persona_id": persona_name,
- "chat_provider_id": None,
- "stt_provider_id": None,
- "tts_provider_id": None,
- "session_enabled": SessionServiceManager.is_session_enabled(
- session_id,
- ),
- "llm_enabled": SessionServiceManager.is_llm_enabled_for_session(
- session_id,
- ),
- "tts_enabled": SessionServiceManager.is_tts_enabled_for_session(
- session_id,
- ),
- "platform": session_id.split(":")[0]
- if ":" in session_id
- else "unknown",
- "message_type": session_id.split(":")[1]
- if session_id.count(":") >= 1
- else "unknown",
- "session_name": SessionServiceManager.get_session_display_name(
- session_id,
- ),
- "session_raw_name": session_id.split(":")[2]
- if session_id.count(":") >= 2
- else session_id,
- "title": title,
- }
-
- # 获取 provider 信息
- chat_provider = provider_manager.get_using_provider(
- provider_type=ProviderType.CHAT_COMPLETION,
- umo=session_id,
- )
- tts_provider = provider_manager.get_using_provider(
- provider_type=ProviderType.TEXT_TO_SPEECH,
- umo=session_id,
- )
- stt_provider = provider_manager.get_using_provider(
- provider_type=ProviderType.SPEECH_TO_TEXT,
- umo=session_id,
- )
- if chat_provider:
- meta = chat_provider.meta()
- session_info["chat_provider_id"] = meta.id
- if tts_provider:
- meta = tts_provider.meta()
- session_info["tts_provider_id"] = meta.id
- if stt_provider:
- meta = stt_provider.meta()
- session_info["stt_provider_id"] = meta.id
-
- sessions.append(session_info)
-
- # 获取可用的 personas 和 providers 列表
available_personas = [
- {"name": p["name"], "prompt": p.get("prompt", "")} for p in personas
+ {"name": p["name"], "prompt": p.get("prompt", "")}
+ for p in persona_mgr.personas_v3
]
- available_chat_providers = []
- for provider in provider_manager.provider_insts:
- meta = provider.meta()
- available_chat_providers.append(
- {
- "id": meta.id,
- "name": meta.id,
- "model": meta.model,
- "type": meta.type,
- },
- )
-
- available_stt_providers = []
- for provider in provider_manager.stt_provider_insts:
- meta = provider.meta()
- available_stt_providers.append(
- {
- "id": meta.id,
- "name": meta.id,
- "model": meta.model,
- "type": meta.type,
- },
- )
-
- available_tts_providers = []
- for provider in provider_manager.tts_provider_insts:
- meta = provider.meta()
- available_tts_providers.append(
- {
- "id": meta.id,
- "name": meta.id,
- "model": meta.model,
- "type": meta.type,
- },
- )
-
- result = {
- "sessions": sessions,
- "available_personas": available_personas,
- "available_chat_providers": available_chat_providers,
- "available_stt_providers": available_stt_providers,
- "available_tts_providers": available_tts_providers,
- "pagination": {
- "page": page,
- "page_size": page_size,
- "total": total,
- "total_pages": (total + page_size - 1) // page_size
- if page_size > 0
- else 0,
- },
- }
-
- return Response().ok(result).__dict__
-
- except Exception as e:
- error_msg = f"获取会话列表失败: {e!s}\n{traceback.format_exc()}"
- logger.error(error_msg)
- return Response().error(f"获取会话列表失败: {e!s}").__dict__
-
- async def _update_single_session_persona(self, session_id: str, persona_name: str):
- """更新单个会话的 persona 的内部方法"""
- conversation_manager = self.core_lifecycle.star_context.conversation_manager
- conversation_id = await conversation_manager.get_curr_conversation_id(
- session_id,
- )
-
- conv = None
- if conversation_id:
- conv = await conversation_manager.get_conversation(
- unified_msg_origin=session_id,
- conversation_id=conversation_id,
- )
- if not conv or not conversation_id:
- conversation_id = await conversation_manager.new_conversation(session_id)
-
- # 更新 persona
- await conversation_manager.update_conversation_persona_id(
- session_id,
- persona_name,
- )
-
- async def _handle_batch_operation(
- self,
- session_ids: list,
- operation_func,
- operation_name: str,
- **kwargs,
- ):
- """通用的批量操作处理方法"""
- success_count = 0
- error_sessions = []
-
- for session_id in session_ids:
- try:
- await operation_func(session_id, **kwargs)
- success_count += 1
- except Exception as e:
- logger.error(f"批量{operation_name} 会话 {session_id} 失败: {e!s}")
- error_sessions.append(session_id)
-
- if error_sessions:
- return (
- Response()
- .ok(
- {
- "message": f"批量更新完成,成功: {success_count},失败: {len(error_sessions)}",
- "success_count": success_count,
- "error_count": len(error_sessions),
- "error_sessions": error_sessions,
- },
- )
- .__dict__
- )
- return (
- Response()
- .ok(
+ available_chat_providers = [
{
- "message": f"成功批量{operation_name} {success_count} 个会话",
- "success_count": success_count,
- },
- )
- .__dict__
- )
+ "id": p.meta().id,
+ "name": p.meta().id,
+ "model": p.meta().model,
+ }
+ for p in provider_manager.provider_insts
+ ]
- async def update_session_persona(self):
- """更新指定会话的 persona,支持批量操作"""
- try:
- data = await request.get_json()
- is_batch = data.get("is_batch", False)
- persona_name = data.get("persona_name")
+ available_stt_providers = [
+ {
+ "id": p.meta().id,
+ "name": p.meta().id,
+ "model": p.meta().model,
+ }
+ for p in provider_manager.stt_provider_insts
+ ]
- if persona_name is None:
- return Response().error("缺少必要参数: persona_name").__dict__
+ available_tts_providers = [
+ {
+ "id": p.meta().id,
+ "name": p.meta().id,
+ "model": p.meta().model,
+ }
+ for p in provider_manager.tts_provider_insts
+ ]
- if is_batch:
- session_ids = data.get("session_ids", [])
- if not session_ids:
- return Response().error("缺少必要参数: session_ids").__dict__
+ # 获取可用的插件列表(排除 reserved 的系统插件)
+ plugin_manager = self.core_lifecycle.plugin_manager
+ available_plugins = [
+ {
+ "name": p.name,
+ "display_name": p.display_name or p.name,
+ "desc": p.desc,
+ }
+ for p in plugin_manager.context.get_all_stars()
+ if not p.reserved and p.name
+ ]
- return await self._handle_batch_operation(
- session_ids,
- self._update_single_session_persona,
- "更新人格",
- persona_name=persona_name,
- )
- session_id = data.get("session_id")
- if not session_id:
- return Response().error("缺少必要参数: session_id").__dict__
+ # 获取可用的知识库列表
+ available_kbs = []
+ kb_manager = self.core_lifecycle.kb_manager
+ if kb_manager:
+ try:
+ kbs = await kb_manager.list_kbs()
+ available_kbs = [
+ {
+ "kb_id": kb.kb_id,
+ "kb_name": kb.kb_name,
+ "emoji": kb.emoji,
+ }
+ for kb in kbs
+ ]
+ except Exception as e:
+ logger.warning(f"获取知识库列表失败: {e!s}")
- await self._update_single_session_persona(session_id, persona_name)
return (
Response()
.ok(
{
- "message": f"成功更新会话 {session_id} 的人格为 {persona_name}",
- },
+ "rules": rules_list,
+ "total": total,
+ "page": page,
+ "page_size": page_size,
+ "available_personas": available_personas,
+ "available_chat_providers": available_chat_providers,
+ "available_stt_providers": available_stt_providers,
+ "available_tts_providers": available_tts_providers,
+ "available_plugins": available_plugins,
+ "available_kbs": available_kbs,
+ "available_rule_keys": AVAILABLE_SESSION_RULE_KEYS,
+ }
)
.__dict__
)
-
except Exception as e:
- error_msg = f"更新会话人格失败: {e!s}\n{traceback.format_exc()}"
- logger.error(error_msg)
- return Response().error(f"更新会话人格失败: {e!s}").__dict__
+ logger.error(f"获取规则列表失败: {e!s}")
+ return Response().error(f"获取规则列表失败: {e!s}").__dict__
- async def _update_single_session_provider(
- self,
- session_id: str,
- provider_id: str,
- provider_type_enum,
- ):
- """更新单个会话的 provider 的内部方法"""
- provider_manager = self.core_lifecycle.star_context.provider_manager
- await provider_manager.set_provider(
- provider_id=provider_id,
- provider_type=provider_type_enum,
- umo=session_id,
- )
+ async def update_session_rule(self):
+ """更新某个 umo 的自定义规则
- async def update_session_provider(self):
- """更新指定会话的 provider,支持批量操作"""
+ 请求体:
+ {
+ "umo": "平台:消息类型:会话ID",
+ "rule_key": "session_service_config" | "session_plugin_config" | "kb_config" | "provider_perf_xxx",
+ "rule_value": {...} // 规则值,具体结构根据 rule_key 不同而不同
+ }
+ """
try:
data = await request.get_json()
- is_batch = data.get("is_batch", False)
- provider_id = data.get("provider_id")
- provider_type = data.get("provider_type")
+ umo = data.get("umo")
+ rule_key = data.get("rule_key")
+ rule_value = data.get("rule_value")
- if not provider_id or not provider_type:
+ if not umo:
+ return Response().error("缺少必要参数: umo").__dict__
+ if not rule_key:
+ return Response().error("缺少必要参数: rule_key").__dict__
+ if rule_key not in AVAILABLE_SESSION_RULE_KEYS:
+ return Response().error(f"不支持的规则键: {rule_key}").__dict__
+
+ if rule_key == "session_plugin_config":
+ rule_value = {
+ umo: rule_value,
+ }
+
+ # 使用 shared preferences 更新规则
+ await sp.session_put(umo, rule_key, rule_value)
+
+ return (
+ Response()
+ .ok({"message": f"规则 {rule_key} 已更新", "umo": umo})
+ .__dict__
+ )
+ except Exception as e:
+ logger.error(f"更新会话规则失败: {e!s}")
+ return Response().error(f"更新会话规则失败: {e!s}").__dict__
+
+ async def delete_session_rule(self):
+ """删除某个 umo 的自定义规则
+
+ 请求体:
+ {
+ "umo": "平台:消息类型:会话ID",
+ "rule_key": "session_service_config" | "session_plugin_config" | ... (可选,不传则删除所有规则)
+ }
+ """
+ try:
+ data = await request.get_json()
+ umo = data.get("umo")
+ rule_key = data.get("rule_key")
+
+ if not umo:
+ return Response().error("缺少必要参数: umo").__dict__
+
+ if rule_key:
+ # 删除单个规则
+ if rule_key not in AVAILABLE_SESSION_RULE_KEYS:
+ return Response().error(f"不支持的规则键: {rule_key}").__dict__
+ await sp.session_remove(umo, rule_key)
return (
Response()
- .error("缺少必要参数: provider_id, provider_type")
+ .ok({"message": f"规则 {rule_key} 已删除", "umo": umo})
.__dict__
)
+ else:
+ # 删除该 umo 的所有规则
+ await sp.clear_async("umo", umo)
+ return Response().ok({"message": "所有规则已删除", "umo": umo}).__dict__
+ except Exception as e:
+ logger.error(f"删除会话规则失败: {e!s}")
+ return Response().error(f"删除会话规则失败: {e!s}").__dict__
- # 转换 provider_type 字符串为枚举
- if provider_type == "chat_completion":
- provider_type_enum = ProviderType.CHAT_COMPLETION
- elif provider_type == "speech_to_text":
- provider_type_enum = ProviderType.SPEECH_TO_TEXT
- elif provider_type == "text_to_speech":
- provider_type_enum = ProviderType.TEXT_TO_SPEECH
+ async def batch_delete_session_rule(self):
+ """批量删除多个 umo 的自定义规则
+
+ 请求体:
+ {
+ "umos": ["平台:消息类型:会话ID", ...] // umo 列表
+ }
+ """
+ try:
+ data = await request.get_json()
+ umos = data.get("umos", [])
+
+ if not umos:
+ return Response().error("缺少必要参数: umos").__dict__
+
+ if not isinstance(umos, list):
+ return Response().error("参数 umos 必须是数组").__dict__
+
+ # 批量删除
+ deleted_count = 0
+ failed_umos = []
+ for umo in umos:
+ try:
+ await sp.clear_async("umo", umo)
+ deleted_count += 1
+ except Exception as e:
+ logger.error(f"删除 umo {umo} 的规则失败: {e!s}")
+ failed_umos.append(umo)
+
+ if failed_umos:
+ return (
+ Response()
+ .ok(
+ {
+ "message": f"已删除 {deleted_count} 条规则,{len(failed_umos)} 条删除失败",
+ "deleted_count": deleted_count,
+ "failed_umos": failed_umos,
+ }
+ )
+ .__dict__
+ )
else:
return (
Response()
- .error(f"不支持的 provider_type: {provider_type}")
- .__dict__
- )
-
- if is_batch:
- session_ids = data.get("session_ids", [])
- if not session_ids:
- return Response().error("缺少必要参数: session_ids").__dict__
-
- return await self._handle_batch_operation(
- session_ids,
- self._update_single_session_provider,
- f"更新 {provider_type} 提供商",
- provider_id=provider_id,
- provider_type_enum=provider_type_enum,
- )
- session_id = data.get("session_id")
- if not session_id:
- return Response().error("缺少必要参数: session_id").__dict__
-
- await self._update_single_session_provider(
- session_id,
- provider_id,
- provider_type_enum,
- )
- return (
- Response()
- .ok(
- {
- "message": f"成功更新会话 {session_id} 的 {provider_type} 提供商为 {provider_id}",
- },
- )
- .__dict__
- )
-
- except Exception as e:
- error_msg = f"更新会话提供商失败: {e!s}\n{traceback.format_exc()}"
- logger.error(error_msg)
- return Response().error(f"更新会话提供商失败: {e!s}").__dict__
-
- async def get_session_plugins(self):
- """获取指定会话的插件配置信息"""
- try:
- session_id = request.args.get("session_id")
-
- if not session_id:
- return Response().error("缺少必要参数: session_id").__dict__
-
- # 获取所有已激活的插件
- all_plugins = []
- plugin_manager = self.core_lifecycle.plugin_manager
-
- for plugin in plugin_manager.context.get_all_stars():
- # 只显示已激活的插件,不包括保留插件
- if plugin.activated and not plugin.reserved:
- plugin_name = plugin.name or ""
- plugin_enabled = SessionPluginManager.is_plugin_enabled_for_session(
- session_id,
- plugin_name,
- )
-
- all_plugins.append(
+ .ok(
{
- "name": plugin_name,
- "author": plugin.author,
- "desc": plugin.desc,
- "enabled": plugin_enabled,
- },
+ "message": f"已删除 {deleted_count} 条规则",
+ "deleted_count": deleted_count,
+ }
)
-
- return (
- Response()
- .ok(
- {
- "session_id": session_id,
- "plugins": all_plugins,
- },
- )
- .__dict__
- )
-
- except Exception as e:
- error_msg = f"获取会话插件配置失败: {e!s}\n{traceback.format_exc()}"
- logger.error(error_msg)
- return Response().error(f"获取会话插件配置失败: {e!s}").__dict__
-
- async def update_session_plugin(self):
- """更新指定会话的插件启停状态"""
- try:
- data = await request.get_json()
- session_id = data.get("session_id")
- plugin_name = data.get("plugin_name")
- enabled = data.get("enabled")
-
- if not session_id:
- return Response().error("缺少必要参数: session_id").__dict__
-
- if not plugin_name:
- return Response().error("缺少必要参数: plugin_name").__dict__
-
- if enabled is None:
- return Response().error("缺少必要参数: enabled").__dict__
-
- # 验证插件是否存在且已激活
- plugin_manager = self.core_lifecycle.plugin_manager
- plugin = plugin_manager.context.get_registered_star(plugin_name)
-
- if not plugin:
- return Response().error(f"插件 {plugin_name} 不存在").__dict__
-
- if not plugin.activated:
- return Response().error(f"插件 {plugin_name} 未激活").__dict__
-
- if plugin.reserved:
- return (
- Response()
- .error(f"插件 {plugin_name} 是系统保留插件,无法管理")
.__dict__
)
-
- # 使用 SessionPluginManager 更新插件状态
- SessionPluginManager.set_plugin_status_for_session(
- session_id,
- plugin_name,
- enabled,
- )
-
- return (
- Response()
- .ok(
- {
- "message": f"插件 {plugin_name} 已{'启用' if enabled else '禁用'}",
- "session_id": session_id,
- "plugin_name": plugin_name,
- "enabled": enabled,
- },
- )
- .__dict__
- )
-
except Exception as e:
- error_msg = f"更新会话插件状态失败: {e!s}\n{traceback.format_exc()}"
- logger.error(error_msg)
- return Response().error(f"更新会话插件状态失败: {e!s}").__dict__
+ logger.error(f"批量删除会话规则失败: {e!s}")
+ return Response().error(f"批量删除会话规则失败: {e!s}").__dict__
- async def _update_single_session_llm(self, session_id: str, enabled: bool):
- """更新单个会话的LLM状态的内部方法"""
- SessionServiceManager.set_llm_status_for_session(session_id, enabled)
+ async def list_umos(self):
+ """列出所有有对话记录的 umo,从 Conversations 表中找
- async def update_session_llm(self):
- """更新指定会话的LLM启停状态,支持批量操作"""
+ 仅返回 umo 字符串列表,用于用户在创建规则时选择 umo
+ """
try:
- data = await request.get_json()
- is_batch = data.get("is_batch", False)
- enabled = data.get("enabled")
-
- if enabled is None:
- return Response().error("缺少必要参数: enabled").__dict__
-
- if is_batch:
- session_ids = data.get("session_ids", [])
- if not session_ids:
- return Response().error("缺少必要参数: session_ids").__dict__
-
- result = await self._handle_batch_operation(
- session_ids,
- self._update_single_session_llm,
- f"{'启用' if enabled else '禁用'}LLM",
- enabled=enabled,
+ # 从 Conversation 表获取所有 distinct user_id (即 umo)
+ async with self.db_helper.get_db() as session:
+ session: AsyncSession
+ result = await session.execute(
+ select(ConversationV2.user_id)
+ .distinct()
+ .order_by(ConversationV2.user_id)
)
- return result
- session_id = data.get("session_id")
- if not session_id:
- return Response().error("缺少必要参数: session_id").__dict__
-
- await self._update_single_session_llm(session_id, enabled)
- return (
- Response()
- .ok(
- {
- "message": f"LLM已{'启用' if enabled else '禁用'}",
- "session_id": session_id,
- "llm_enabled": enabled,
- },
- )
- .__dict__
- )
+ umos = [row[0] for row in result.fetchall()]
+ return Response().ok({"umos": umos}).__dict__
except Exception as e:
- error_msg = f"更新会话LLM状态失败: {e!s}\n{traceback.format_exc()}"
- logger.error(error_msg)
- return Response().error(f"更新会话LLM状态失败: {e!s}").__dict__
-
- async def _update_single_session_tts(self, session_id: str, enabled: bool):
- """更新单个会话的TTS状态的内部方法"""
- SessionServiceManager.set_tts_status_for_session(session_id, enabled)
-
- async def update_session_tts(self):
- """更新指定会话的TTS启停状态,支持批量操作"""
- try:
- data = await request.get_json()
- is_batch = data.get("is_batch", False)
- enabled = data.get("enabled")
-
- if enabled is None:
- return Response().error("缺少必要参数: enabled").__dict__
-
- if is_batch:
- session_ids = data.get("session_ids", [])
- if not session_ids:
- return Response().error("缺少必要参数: session_ids").__dict__
-
- result = await self._handle_batch_operation(
- session_ids,
- self._update_single_session_tts,
- f"{'启用' if enabled else '禁用'}TTS",
- enabled=enabled,
- )
- return result
- session_id = data.get("session_id")
- if not session_id:
- return Response().error("缺少必要参数: session_id").__dict__
-
- await self._update_single_session_tts(session_id, enabled)
- return (
- Response()
- .ok(
- {
- "message": f"TTS已{'启用' if enabled else '禁用'}",
- "session_id": session_id,
- "tts_enabled": enabled,
- },
- )
- .__dict__
- )
-
- except Exception as e:
- error_msg = f"更新会话TTS状态失败: {e!s}\n{traceback.format_exc()}"
- logger.error(error_msg)
- return Response().error(f"更新会话TTS状态失败: {e!s}").__dict__
-
- async def update_session_name(self):
- """更新指定会话的自定义名称"""
- try:
- data = await request.get_json()
- session_id = data.get("session_id")
- custom_name = data.get("custom_name", "")
-
- if not session_id:
- return Response().error("缺少必要参数: session_id").__dict__
-
- # 使用 SessionServiceManager 更新会话名称
- SessionServiceManager.set_session_custom_name(session_id, custom_name)
-
- return (
- Response()
- .ok(
- {
- "message": f"会话名称已更新为: {custom_name if custom_name.strip() else '已清除自定义名称'}",
- "session_id": session_id,
- "custom_name": custom_name,
- "display_name": SessionServiceManager.get_session_display_name(
- session_id,
- ),
- },
- )
- .__dict__
- )
-
- except Exception as e:
- error_msg = f"更新会话名称失败: {e!s}\n{traceback.format_exc()}"
- logger.error(error_msg)
- return Response().error(f"更新会话名称失败: {e!s}").__dict__
-
- async def update_session_status(self):
- """更新指定会话的整体启停状态"""
- try:
- data = await request.get_json()
- session_id = data.get("session_id")
- session_enabled = data.get("session_enabled")
-
- if not session_id:
- return Response().error("缺少必要参数: session_id").__dict__
-
- if session_enabled is None:
- return Response().error("缺少必要参数: session_enabled").__dict__
-
- # 使用 SessionServiceManager 更新会话整体状态
- SessionServiceManager.set_session_status(session_id, session_enabled)
-
- return (
- Response()
- .ok(
- {
- "message": f"会话整体状态已更新为: {'启用' if session_enabled else '禁用'}",
- "session_id": session_id,
- "session_enabled": session_enabled,
- },
- )
- .__dict__
- )
-
- except Exception as e:
- error_msg = f"更新会话整体状态失败: {e!s}\n{traceback.format_exc()}"
- logger.error(error_msg)
- return Response().error(f"更新会话整体状态失败: {e!s}").__dict__
-
- async def delete_session(self):
- """删除指定会话及其所有相关数据"""
- try:
- data = await request.get_json()
- session_id = data.get("session_id")
-
- if not session_id:
- return Response().error("缺少必要参数: session_id").__dict__
-
- # 删除会话的所有相关数据
- conversation_manager = self.core_lifecycle.conversation_manager
-
- # 1. 删除会话的所有对话
- try:
- await conversation_manager.delete_conversations_by_user_id(session_id)
- except Exception as e:
- logger.warning(f"删除会话 {session_id} 的对话失败: {e!s}")
-
- # 2. 清除会话的偏好设置数据(清空该会话的所有配置)
- try:
- await sp.clear_async("umo", session_id)
- except Exception as e:
- logger.warning(f"清除会话 {session_id} 的偏好设置失败: {e!s}")
-
- return (
- Response()
- .ok(
- {
- "message": f"会话 {session_id} 及其相关所有对话数据已成功删除",
- "session_id": session_id,
- },
- )
- .__dict__
- )
-
- except Exception as e:
- error_msg = f"删除会话失败: {e!s}\n{traceback.format_exc()}"
- logger.error(error_msg)
- return Response().error(f"删除会话失败: {e!s}").__dict__
+ logger.error(f"获取 UMO 列表失败: {e!s}")
+ return Response().error(f"获取 UMO 列表失败: {e!s}").__dict__
diff --git a/astrbot/dashboard/routes/tools.py b/astrbot/dashboard/routes/tools.py
index 64cd78caa..d7b082000 100644
--- a/astrbot/dashboard/routes/tools.py
+++ b/astrbot/dashboard/routes/tools.py
@@ -3,6 +3,7 @@ import traceback
from quart import request
from astrbot.core import logger
+from astrbot.core.agent.mcp_client import MCPTool
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.star import star_map
@@ -296,15 +297,30 @@ class ToolsRoute(Route):
"""获取所有注册的工具列表"""
try:
tools = self.tool_mgr.func_list
- tools_dict = [
- {
+ tools_dict = []
+ for tool in tools:
+ if isinstance(tool, MCPTool):
+ origin = "mcp"
+ origin_name = tool.mcp_server_name
+ elif tool.handler_module_path and star_map.get(
+ tool.handler_module_path
+ ):
+ star = star_map[tool.handler_module_path]
+ origin = "plugin"
+ origin_name = star.name
+ else:
+ origin = "unknown"
+ origin_name = "unknown"
+
+ tool_info = {
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters,
"active": tool.active,
+ "origin": origin,
+ "origin_name": origin_name,
}
- for tool in tools
- ]
+ tools_dict.append(tool_info)
return Response().ok(data=tools_dict).__dict__
except Exception as e:
logger.error(traceback.format_exc())
diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py
index 84976f2ba..6d6530c90 100644
--- a/astrbot/dashboard/server.py
+++ b/astrbot/dashboard/server.py
@@ -2,9 +2,12 @@ import asyncio
import logging
import os
import socket
+from typing import cast
import jwt
import psutil
+from flask.json.provider import DefaultJSONProvider
+from psutil._common import addr as psutil_addr
from quart import Quart, g, jsonify, request
from quart.logging import default_handler
@@ -16,11 +19,12 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import get_local_ip_addresses
from .routes import *
+from .routes.platform import PlatformRoute
from .routes.route import Response, RouteContext
from .routes.session_management import SessionManagementRoute
from .routes.t2i import T2iRoute
-APP: Quart = None
+APP: Quart
class AstrBotDashboard:
@@ -47,7 +51,7 @@ class AstrBotDashboard:
self.app.config["MAX_CONTENT_LENGTH"] = (
128 * 1024 * 1024
) # 将 Flask 允许的最大上传文件体大小设置为 128 MB
- self.app.json.sort_keys = False
+ cast(DefaultJSONProvider, self.app.json).sort_keys = False
self.app.before_request(self.auth_middleware)
# token 用于验证请求
logging.getLogger(self.app.name).removeHandler(default_handler)
@@ -63,6 +67,7 @@ class AstrBotDashboard:
core_lifecycle,
core_lifecycle.plugin_manager,
)
+ self.command_route = CommandRoute(self.context)
self.cr = ConfigRoute(self.context, core_lifecycle)
self.lr = LogRoute(self.context, core_lifecycle.log_broker)
self.sfr = StaticFileRoute(self.context)
@@ -79,6 +84,7 @@ class AstrBotDashboard:
self.persona_route = PersonaRoute(self.context, db, core_lifecycle)
self.t2i_route = T2iRoute(self.context, core_lifecycle)
self.kb_route = KnowledgeBaseRoute(self.context, core_lifecycle)
+ self.platform_route = PlatformRoute(self.context, core_lifecycle)
self.app.add_url_rule(
"/api/plug/",
@@ -102,7 +108,7 @@ class AstrBotDashboard:
async def auth_middleware(self):
if not request.path.startswith("/api"):
return None
- allowed_endpoints = ["/api/auth/login", "/api/file"]
+ allowed_endpoints = ["/api/auth/login", "/api/file", "/api/platform/webhook"]
if any(request.path.startswith(prefix) for prefix in allowed_endpoints):
return None
# 声明 JWT
@@ -145,7 +151,7 @@ class AstrBotDashboard:
"""获取占用端口的进程详细信息"""
try:
for conn in psutil.net_connections(kind="inet"):
- if conn.laddr.port == port:
+ if cast(psutil_addr, conn.laddr).port == port:
try:
process = psutil.Process(conn.pid)
# 获取详细信息
diff --git a/changelogs/v4.5.7.md b/changelogs/v4.5.7.md
new file mode 100644
index 000000000..06317f57b
--- /dev/null
+++ b/changelogs/v4.5.7.md
@@ -0,0 +1,12 @@
+## What's Changed
+
+1. 新增:支持为 OpenAI API 提供商自定义请求头 ([#3581](https://github.com/AstrBotDevs/AstrBot/issues/3581))
+2. 新增:为 WebChat 为 Thinking 模型添加思考过程展示功能;支持快捷切换流式输出 / 非流式输出。([#3632](https://github.com/AstrBotDevs/AstrBot/issues/3632))
+3. 新增:优化插件调用 LLM 和 Agent 的路径,为 Context 类引入多个调用 LLM 和 Agent 的便捷方法 ([#3636](https://github.com/AstrBotDevs/AstrBot/issues/3636))
+4. 优化:改善不支持流式输出的消息平台的回退策略 ([#3547](https://github.com/AstrBotDevs/AstrBot/issues/3547))
+5. 优化:当同一个会话(umo)下同时有多个请求时,执行排队处理,避免并发请求导致的上下文混乱问题 ([#3607](https://github.com/AstrBotDevs/AstrBot/issues/3607))
+6. 优化:优化 WebUI 的登录界面和 Changelog 页面的显示效果
+7. 修复:修复在知识库名字过长的情况下,“选择知识库”按钮显示异常的问题 ([#3582](https://github.com/AstrBotDevs/AstrBot/issues/3582))
+8. 修复:修复部分情况下,分段消息发送时导致的死锁问题(由 PR #3607 引入)
+9. 修复:钉钉适配器使用部分指令无法生效的问题 ([#3634](https://github.com/AstrBotDevs/AstrBot/issues/3634))
+10. 其他:为部分适配器添加缺失的 send_streaming 方法 ([#3545](https://github.com/AstrBotDevs/AstrBot/issues/3545))
diff --git a/changelogs/v4.5.8.md b/changelogs/v4.5.8.md
new file mode 100644
index 000000000..2f2364623
--- /dev/null
+++ b/changelogs/v4.5.8.md
@@ -0,0 +1,5 @@
+## What's Changed
+
+hot fix of 4.5.7
+
+fix: 无法正常发送图片,报错 `pydantic_core._pydantic_core.ValidationError`
diff --git a/changelogs/v4.6.0.md b/changelogs/v4.6.0.md
new file mode 100644
index 000000000..ca5439900
--- /dev/null
+++ b/changelogs/v4.6.0.md
@@ -0,0 +1,23 @@
+## What's Changed
+
+1. 新增: 支持 gemini-3 系列的 thought signature ([#3698](https://github.com/AstrBotDevs/AstrBot/issues/3698))
+2. 新增: 支持知识库的 Agentic 检索功能 ([#3667](https://github.com/AstrBotDevs/AstrBot/issues/3667))
+3. 新增: 为知识库添加 URL 文档解析器 ([#3622](https://github.com/AstrBotDevs/AstrBot/issues/3622))
+4. 修复(core.platform): 修复启用多个企业微信智能机器人适配器时消息混乱的问题 ([#3693](https://github.com/AstrBotDevs/AstrBot/issues/3693))
+5. 修复: MCP Server 连接成功一段时间后,调用 mcp 工具时可能出现 `anyio.ClosedResourceError` 错误 ([#3700](https://github.com/AstrBotDevs/AstrBot/issues/3700))
+6. 新增(chat): 重构聊天组件结构并添加新功能 ([#3701](https://github.com/AstrBotDevs/AstrBot/issues/3701))
+7. 修复(dashboard.i18n): 完善缺失的英文国际化键值 ([#3699](https://github.com/AstrBotDevs/AstrBot/issues/3699))
+8. 重构: 实现 WebChat 会话管理及从版本 4.6 迁移到 4.7
+9. 持续集成(docker-build): 每日构建 Nightly 版本 Docker 镜像 ([#3120](https://github.com/AstrBotDevs/AstrBot/issues/3120))
+
+---
+
+1. feat: add supports for gemini-3 series thought signature ([#3698](https://github.com/AstrBotDevs/AstrBot/issues/3698))
+2. feat: supports knowledge base agentic search ([#3667](https://github.com/AstrBotDevs/AstrBot/issues/3667))
+3. feat: Add URL document parser for knowledge base ([#3622](https://github.com/AstrBotDevs/AstrBot/issues/3622))
+4. fix(core.platform): fix message mix-up issue when enabling multiple WeCom AI Bot adapters ([#3693](https://github.com/AstrBotDevs/AstrBot/issues/3693))
+5. fix: fix `anyio.ClosedResourceError` that may occur when calling mcp tools after a period of successful connection to MCP Server ([#3700](https://github.com/AstrBotDevs/AstrBot/issues/3700))
+6. feat(chat): refactor chat component structure and add new features ([#3701](https://github.com/AstrBotDevs/AstrBot/issues/3701))
+7. fix(dashboard.i18n): complete the missing i18n keys for en([#3699](https://github.com/AstrBotDevs/AstrBot/issues/3699))
+8. refactor: Implement WebChat session management and migration from version 4.6 to 4.7
+9. ci(docker-build): build nightly image everyday ([#3120](https://github.com/AstrBotDevs/AstrBot/issues/3120))
diff --git a/changelogs/v4.6.1.md b/changelogs/v4.6.1.md
new file mode 100644
index 000000000..97c6f8a3e
--- /dev/null
+++ b/changelogs/v4.6.1.md
@@ -0,0 +1,29 @@
+## What's Changed
+
+**hot fix of v4.6.0**
+
+fix(core.db): 修复升级后 webchat 相关对话数据未正确迁移的问题 ([#3745](https://github.com/AstrBotDevs/AstrBot/issues/3745))
+
+---
+
+1. 新增: 支持 gemini-3 系列的 thought signature ([#3698](https://github.com/AstrBotDevs/AstrBot/issues/3698))
+2. 新增: 支持知识库的 Agentic 检索功能 ([#3667](https://github.com/AstrBotDevs/AstrBot/issues/3667))
+3. 新增: 为知识库添加 URL 文档解析器 ([#3622](https://github.com/AstrBotDevs/AstrBot/issues/3622))
+4. 修复(core.platform): 修复启用多个企业微信智能机器人适配器时消息混乱的问题 ([#3693](https://github.com/AstrBotDevs/AstrBot/issues/3693))
+5. 修复: MCP Server 连接成功一段时间后,调用 mcp 工具时可能出现 `anyio.ClosedResourceError` 错误 ([#3700](https://github.com/AstrBotDevs/AstrBot/issues/3700))
+6. 新增(chat): 重构聊天组件结构并添加新功能 ([#3701](https://github.com/AstrBotDevs/AstrBot/issues/3701))
+7. 修复(dashboard.i18n): 完善缺失的英文国际化键值 ([#3699](https://github.com/AstrBotDevs/AstrBot/issues/3699))
+8. 重构: 实现 WebChat 会话管理及从版本 4.6 迁移到 4.7
+9. 持续集成(docker-build): 每日构建 Nightly 版本 Docker 镜像 ([#3120](https://github.com/AstrBotDevs/AstrBot/issues/3120))
+
+---
+
+1. feat: add supports for gemini-3 series thought signature ([#3698](https://github.com/AstrBotDevs/AstrBot/issues/3698))
+2. feat: supports knowledge base agentic search ([#3667](https://github.com/AstrBotDevs/AstrBot/issues/3667))
+3. feat: Add URL document parser for knowledge base ([#3622](https://github.com/AstrBotDevs/AstrBot/issues/3622))
+4. fix(core.platform): fix message mix-up issue when enabling multiple WeCom AI Bot adapters ([#3693](https://github.com/AstrBotDevs/AstrBot/issues/3693))
+5. fix: fix `anyio.ClosedResourceError` that may occur when calling mcp tools after a period of successful connection to MCP Server ([#3700](https://github.com/AstrBotDevs/AstrBot/issues/3700))
+6. feat(chat): refactor chat component structure and add new features ([#3701](https://github.com/AstrBotDevs/AstrBot/issues/3701))
+7. fix(dashboard.i18n): complete the missing i18n keys for en([#3699](https://github.com/AstrBotDevs/AstrBot/issues/3699))
+8. refactor: Implement WebChat session management and migration from version 4.6 to 4.7
+9. ci(docker-build): build nightly image everyday ([#3120](https://github.com/AstrBotDevs/AstrBot/issues/3120))
diff --git a/changelogs/v4.7.0.md b/changelogs/v4.7.0.md
new file mode 100644
index 000000000..687e3479b
--- /dev/null
+++ b/changelogs/v4.7.0.md
@@ -0,0 +1,18 @@
+## What's Changed
+
+重构:
+- 将 Dify、Coze、阿里云百炼应用等 LLMOps 提供商迁移到 Agent 执行器层,理清和本地 Agent 执行器的边界
+- 将「会话管理」功能重构为「自定义规则」功能,理清和多配置文件功能的边界。详见:[自定义规则](https://docs.astrbot.app/use/custom-rules.html)
+
+优化:
+- Dify、阿里云百炼应用支持流式输出
+- 防止分段回复正则表达式解析错误导致消息不发送
+- 群聊上下文感知记录 At 信息
+- 优化模型提供商页面的测试提供商功能
+
+新增:
+- 支持在配置文件页面快速测试对话
+- 为配置文件配置项内容添加国际化支持
+
+修复:
+- 在更新 MCP Server 配置后,MCP 无法正常重启的问题
diff --git a/changelogs/v4.7.1.md b/changelogs/v4.7.1.md
new file mode 100644
index 000000000..ff8b8ba05
--- /dev/null
+++ b/changelogs/v4.7.1.md
@@ -0,0 +1,22 @@
+## What's Changed
+
+### 修复了自定义规则页面无法设置插件和知识库的规则的问题
+
+---
+
+重构:
+- 将 Dify、Coze、阿里云百炼应用等 LLMOps 提供商迁移到 Agent 执行器层,理清和本地 Agent 执行器的边界。详见:[Agent 执行器](https://docs.astrbot.app/use/agent-runner.html)
+- 将「会话管理」功能重构为「自定义规则」功能,理清和多配置文件功能的边界。详见:[自定义规则](https://docs.astrbot.app/use/custom-rules.html)
+
+优化:
+- Dify、阿里云百炼应用支持流式输出
+- 防止分段回复正则表达式解析错误导致消息不发送
+- 群聊上下文感知记录 At 信息
+- 优化模型提供商页面的测试提供商功能
+
+新增:
+- 支持在配置文件页面快速测试对话
+- 为配置文件配置项内容添加国际化支持
+
+修复:
+- 在更新 MCP Server 配置后,MCP 无法正常重启的问题
diff --git a/changelogs/v4.7.3.md b/changelogs/v4.7.3.md
new file mode 100644
index 000000000..f5105d862
--- /dev/null
+++ b/changelogs/v4.7.3.md
@@ -0,0 +1,25 @@
+## What's Changed
+
+1. 修复使用非默认配置文件情况下时,第三方 Agent Runner (Dify、Coze、阿里云百炼应用等)无法正常工作的问题
+2. 修复当“聊天模型”未设置,并且模型提供商中仅有 Agent Runner 时,无法正常使用 Agent Runner 的问题
+3. 修复部分情况下报错 `pydantic_core._pydantic_core.ValidationError: 1 validation error for Message content` 的问题
+4. 新增群聊模式下的专用图片转述模型配置 ([#3822](https://github.com/AstrBotDevs/AstrBot/issues/3822))
+
+---
+
+重构:
+- 将 Dify、Coze、阿里云百炼应用等 LLMOps 提供商迁移到 Agent 执行器层,理清和本地 Agent 执行器的边界。详见:[Agent 执行器](https://docs.astrbot.app/use/agent-runner.html)
+- 将「会话管理」功能重构为「自定义规则」功能,理清和多配置文件功能的边界。详见:[自定义规则](https://docs.astrbot.app/use/custom-rules.html)
+
+优化:
+- Dify、阿里云百炼应用支持流式输出
+- 防止分段回复正则表达式解析错误导致消息不发送
+- 群聊上下文感知记录 At 信息
+- 优化模型提供商页面的测试提供商功能
+
+新增:
+- 支持在配置文件页面快速测试对话
+- 为配置文件配置项内容添加国际化支持
+
+修复:
+- 在更新 MCP Server 配置后,MCP 无法正常重启的问题
diff --git a/changelogs/v4.7.4.md b/changelogs/v4.7.4.md
new file mode 100644
index 000000000..3929f9744
--- /dev/null
+++ b/changelogs/v4.7.4.md
@@ -0,0 +1,7 @@
+## What's Changed
+
+1. 修复:assistant message 中 tool_call 存在但 content 不存在时,导致验证错误的问题 ([#3862](https://github.com/AstrBotDevs/AstrBot/issues/3862))
+2. 修复:fix: aiocqhttp 适配器 NapCat 文件名获取为空 ([#3853](https://github.com/AstrBotDevs/AstrBot/issues/3853))
+3. 新增:升级所有插件按钮
+4. 新增:/provider 指令支持同时测试提供商可用性
+5. 优化:主动回复的 prompt
\ No newline at end of file
diff --git a/changelogs/v4.8.0.md b/changelogs/v4.8.0.md
new file mode 100644
index 000000000..c0831c52d
--- /dev/null
+++ b/changelogs/v4.8.0.md
@@ -0,0 +1,15 @@
+## What's Changed
+
+**新增:**
+- 对部分需要 Webhook 的适配器(QQ 官方机器人、Slack、企业微信、微信客服、企业微信智能机器人、微信公众号)支持统一的 Webhook 链接模式,避免开多个端口。并支持在 WebUI 机器人卡片中查看和复制 Webhook 链接。详情请看:[统一 Webhook 模式](https://docs.astrbot.app/use/unified-webhook.html)
+- 新增 Kubernetes 部署文档。
+
+**修复:**
+- 修复:Telegram 和 QQ 场景下,使用 Whisper API 报错。
+- 修复:部分情况下 Slack 输出消息段代码的问题。
+- 修复:当启动了流式输出时,QQ 官方机器人适配器无法正常回复消息。
+- 修复:对话数据页的对话详情在暗夜模式下显示异常的问题。
+
+**优化:**
+- 重构:WebChat 的消息数据结构,支持引用回复、文件发送、时间显示等功能,优化思考内容显示的部分 Bug。
+- 优化:机器人页面支持显示报错信息,方便排查问题。
diff --git a/changelogs/v4.9.0.md b/changelogs/v4.9.0.md
new file mode 100644
index 000000000..aeccdb006
--- /dev/null
+++ b/changelogs/v4.9.0.md
@@ -0,0 +1,19 @@
+## What's Changed
+
+### 新增
+
+- 支持自定义插件源。
+- 支持飞书(Lark)的 Webhook 模式(将事件推送至开发者服务器)。
+- 支持 “禁用自带指令” 快捷配置项,启用后将禁用所有 AstrBot 自带指令。入口: WebUI -> 配置文件 -> 平台配置。
+
+### 优化
+
+- 从 WebUI 移除了开发版本渠道。
+- 当试图测试"Agent Runner"时,提示前往配置文件页测试。
+- WebUI 列表项支持批量粘贴、回车创建项目。
+
+### 修复
+
+- Gemini API 部分调用失败的问题。
+- WebUI 插件安装加载 Dialog 关闭按钮在手机端下显示异常的问题。
+- 部分情况下,WebUI 日志显示不全的问题。
\ No newline at end of file
diff --git a/changelogs/v4.9.1.md b/changelogs/v4.9.1.md
new file mode 100644
index 000000000..f7e4c2e5c
--- /dev/null
+++ b/changelogs/v4.9.1.md
@@ -0,0 +1,3 @@
+## What's Changed
+
+-
\ No newline at end of file
diff --git a/changelogs/v4.9.2.md b/changelogs/v4.9.2.md
new file mode 100644
index 000000000..87538c6ca
--- /dev/null
+++ b/changelogs/v4.9.2.md
@@ -0,0 +1,17 @@
+## What's Changed
+
+### 修复
+
+- 企业自部署飞书(自定义 domain)可以接收消息但无法发送消息的问题。
+- 安装插件 Dialog 的深色样式问题。
+
+### 优化
+
+- 避免某些插件在流式响应结束后重d复发送消息的问题。
+
+### 新增
+
+- 支持在对话管理批量导出对话轨迹数据为 `jsonl` 格式文件。入口:WebUI -> 对话管理 -> 批量选中 -> 导出。
+- 支持对 TTS(文本转语音)设置概率触发。
+- (插件开发)支持在 schema 中对 float 和 int 类型设置 `slider` 滑块控件。例如 `slider: {min: 0, max: 1, step: 0.1}`。
+- (插件开发)支持 key-value 存储功能。例如使用 `await self.put_kv_data("key", value)`, `await self.get_kv_data("key", default_value)` 和 `await self.delete_kv_data("key")`。
\ No newline at end of file
diff --git a/compose.yml b/compose.yml
index 2b3185301..99557a1d8 100644
--- a/compose.yml
+++ b/compose.yml
@@ -9,10 +9,9 @@ services:
restart: always
ports: # mappings description: https://github.com/AstrBotDevs/AstrBot/issues/497
- "6185:6185" # 必选,AstrBot WebUI 端口
- - "6195:6195" # 可选, 企业微信 Webhook 端口
- "6199:6199" # 可选, QQ 个人号 WebSocket 端口
- - "6196:6196" # 可选, QQ 官方接口 Webhook 端口
- - "11451:11451" # 可选, 微信个人号 Webhook 端口
+ # - "6195:6195" # 可选, 企业微信 Webhook 端口
+ # - "6196:6196" # 可选, QQ 官方接口 Webhook 端口
environment:
- TZ=Asia/Shanghai
volumes:
diff --git a/dashboard/src/assets/images/loading-seio.webp b/dashboard/src/assets/images/loading-seio.webp
new file mode 100644
index 000000000..62e159f98
Binary files /dev/null and b/dashboard/src/assets/images/loading-seio.webp differ
diff --git a/dashboard/src/assets/images/platform_logos/onebot.png b/dashboard/src/assets/images/platform_logos/onebot.png
new file mode 100644
index 000000000..70cc8829f
Binary files /dev/null and b/dashboard/src/assets/images/platform_logos/onebot.png differ
diff --git a/dashboard/src/assets/images/plugin_icon.png b/dashboard/src/assets/images/plugin_icon.png
new file mode 100644
index 000000000..7e4c4f3a9
Binary files /dev/null and b/dashboard/src/assets/images/plugin_icon.png differ
diff --git a/dashboard/src/components/chat/Chat.vue b/dashboard/src/components/chat/Chat.vue
index 26c7df563..5524e787d 100644
--- a/dashboard/src/components/chat/Chat.vue
+++ b/dashboard/src/components/chat/Chat.vue
@@ -5,89 +5,20 @@
-
+
@@ -109,7 +40,7 @@
mdi-fullscreen
@@ -131,7 +62,7 @@
- mdi-fullscreen-exit
@@ -140,93 +71,39 @@
Hello, I'm
AstrBot ⭐
-
- {{ t('core.common.type') }}
- help
- {{ tm('shortcuts.help') }} 😊
-
-
- {{ t('core.common.longPress') }}
- Ctrl + B
- {{ tm('shortcuts.voiceRecord') }} 🎤
-
-
- {{ t('core.common.press') }}
- Ctrl + V
- {{ tm('shortcuts.pasteImage') }} 🏞️
-
-
+
@@ -242,8 +119,8 @@
- {{ t('core.common.cancel') }}
- {{ t('core.common.save') }}
+ {{ t('core.common.cancel') }}
+ {{ t('core.common.save') }}
@@ -262,980 +139,334 @@
-
-
\ No newline at end of file
+
+.mobile-menu-btn {
+ margin-right: 8px;
+}
+
+.conversation-header-actions {
+ display: flex;
+ gap: 8px;
+ align-items: center;
+}
+
+.fullscreen-icon {
+ cursor: pointer;
+ margin-left: 8px;
+}
+
+.welcome-container {
+ height: 100%;
+ display: flex;
+ justify-content: center;
+ align-items: center;
+ flex-direction: column;
+}
+
+.welcome-title {
+ font-size: 28px;
+ margin-bottom: 16px;
+}
+
+.bot-name {
+ font-weight: 700;
+ margin-left: 8px;
+ color: var(--v-theme-secondary);
+}
+
+.fade-in {
+ animation: fadeIn 0.3s ease-in-out;
+}
+
+.dialog-title {
+ font-size: 18px;
+ font-weight: 500;
+ padding-bottom: 8px;
+}
+
+/* 手机端样式调整 */
+@media (max-width: 768px) {
+ .chat-content-panel {
+ width: 100%;
+ }
+
+ .chat-page-container {
+ padding: 0 !important;
+ }
+
+ .conversation-header {
+ padding: 2px;
+ }
+}
+
diff --git a/dashboard/src/components/chat/ChatInput.vue b/dashboard/src/components/chat/ChatInput.vue
new file mode 100644
index 000000000..53e1e30c0
--- /dev/null
+++ b/dashboard/src/components/chat/ChatInput.vue
@@ -0,0 +1,397 @@
+
+
+
+
+
+
+
diff --git a/dashboard/src/components/chat/ConfigSelector.vue b/dashboard/src/components/chat/ConfigSelector.vue
new file mode 100644
index 000000000..9cb5eeaac
--- /dev/null
+++ b/dashboard/src/components/chat/ConfigSelector.vue
@@ -0,0 +1,313 @@
+
+
+
+
+
+ mdi-cog
+ {{ selectedConfigLabel }}
+
+
+
+
+
+
+
+ 选择配置文件
+
+ mdi-close
+
+
+
+
+
+
+
+
+
+ {{ config.name }}
+
+ {{ config.id }}
+
+
+ mdi-check
+
+
+
+ 暂无可选配置,请先在配置页创建。
+
+
+
+
+
+ 取消
+
+ 应用
+
+
+
+
+
+
+
+
+
+
diff --git a/dashboard/src/components/chat/ConversationSidebar.vue b/dashboard/src/components/chat/ConversationSidebar.vue
new file mode 100644
index 000000000..062588854
--- /dev/null
+++ b/dashboard/src/components/chat/ConversationSidebar.vue
@@ -0,0 +1,303 @@
+
+
+
+
+
+
+
+
diff --git a/dashboard/src/components/chat/MessageList.vue b/dashboard/src/components/chat/MessageList.vue
index 7ab592497..cd14c6574 100644
--- a/dashboard/src/components/chat/MessageList.vue
+++ b/dashboard/src/components/chat/MessageList.vue
@@ -5,64 +5,222 @@
-
-
{{ msg.content.message }}
-
-
-
-
-
+
+
+
+
+ mdi-reply
+ {{ getReplyContent(part.message_id) }}
-
-
-
-
-
- {{ t('messages.errors.browser.audioNotSupported') }}
-
-
+
+
{{ part.text }}
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ t('messages.errors.browser.audioNotSupported') }}
+
+
+
+
+
+
-
-
- mdi-star-four-points-small
+
+ mdi-star-four-points-small
-
-
+
+
+ {{ tm('message.loading') }}
+
-
-
-
-
+
+
+
-
-
-
-
-
- {{ t('messages.errors.browser.audioNotSupported') }}
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ t('messages.errors.browser.audioNotSupported') }}
+
+
+
+
+
+
+
-
-
+ {{ formatMessageTime(msg.created_at)
+ }}
+
+
+
+ mdi-information-outline
+
+
+
+
+
@@ -76,6 +234,7 @@ import { useI18n, useModuleI18n } from '@/i18n/composables';
import MarkdownIt from 'markdown-it';
import hljs from 'highlight.js';
import 'highlight.js/styles/github.css';
+import axios from 'axios';
const md = new MarkdownIt({
html: false,
@@ -109,7 +268,7 @@ export default {
default: false
}
},
- emits: ['openImagePreview'],
+ emits: ['openImagePreview', 'replyMessage'],
setup() {
const { t } = useI18n();
const { tm } = useModuleI18n('features/chat');
@@ -125,7 +284,12 @@ export default {
copiedMessages: new Set(),
isUserNearBottom: true,
scrollThreshold: 1,
- scrollTimer: null
+ scrollTimer: null,
+ expandedReasoning: new Set(), // Track which reasoning blocks are expanded
+ downloadingFiles: new Set(), // Track which files are being downloaded
+ expandedToolCalls: new Set(), // Track which tool call cards are expanded
+ elapsedTimeTimer: null, // Timer for updating elapsed time
+ currentTime: Date.now() / 1000, // Current time for elapsed time calculation
};
},
mounted() {
@@ -133,6 +297,7 @@ export default {
this.initImageClickEvents();
this.addScrollListener();
this.scrollToBottom();
+ this.startElapsedTimeTimer();
},
updated() {
this.initCodeCopyButtons();
@@ -142,6 +307,94 @@ export default {
}
},
methods: {
+ // 检查 message 中是否有音频
+ hasAudio(messageParts) {
+ if (!Array.isArray(messageParts)) return false;
+ return messageParts.some(part => part.type === 'record' && part.embedded_url);
+ },
+
+ // 获取被引用消息的内容
+ getReplyContent(messageId) {
+ const replyMsg = this.messages.find(m => m.id === messageId);
+ if (!replyMsg) {
+ return this.tm('reply.notFound');
+ }
+ let content = '';
+ if (Array.isArray(replyMsg.content.message)) {
+ const textParts = replyMsg.content.message
+ .filter(part => part.type === 'plain' && part.text)
+ .map(part => part.text);
+ content = textParts.join('');
+ }
+ // 截断过长内容
+ if (content.length > 50) {
+ content = content.substring(0, 50) + '...';
+ }
+ return content || '[媒体内容]';
+ },
+
+ // 滚动到指定消息
+ scrollToMessage(messageId) {
+ const msgIndex = this.messages.findIndex(m => m.id === messageId);
+ if (msgIndex === -1) return;
+
+ const container = this.$refs.messageContainer;
+ const messageItems = container?.querySelectorAll('.message-item');
+ if (messageItems && messageItems[msgIndex]) {
+ messageItems[msgIndex].scrollIntoView({ behavior: 'smooth', block: 'center' });
+ // 高亮一下
+ messageItems[msgIndex].classList.add('highlight-message');
+ setTimeout(() => {
+ messageItems[msgIndex].classList.remove('highlight-message');
+ }, 2000);
+ }
+ },
+
+ // Toggle reasoning expansion state
+ toggleReasoning(messageIndex) {
+ if (this.expandedReasoning.has(messageIndex)) {
+ this.expandedReasoning.delete(messageIndex);
+ } else {
+ this.expandedReasoning.add(messageIndex);
+ }
+ // Force reactivity
+ this.expandedReasoning = new Set(this.expandedReasoning);
+ },
+
+ // Check if reasoning is expanded
+ isReasoningExpanded(messageIndex) {
+ return this.expandedReasoning.has(messageIndex);
+ },
+
+ // 下载文件
+ async downloadFile(file) {
+ if (!file.attachment_id) return;
+
+ // 标记为下载中
+ this.downloadingFiles.add(file.attachment_id);
+ this.downloadingFiles = new Set(this.downloadingFiles);
+
+ try {
+ const response = await axios.get(`/api/chat/get_attachment?attachment_id=${file.attachment_id}`, {
+ responseType: 'blob'
+ });
+
+ const url = URL.createObjectURL(response.data);
+ const a = document.createElement('a');
+ a.href = url;
+ a.download = file.filename || 'file';
+ document.body.appendChild(a);
+ a.click();
+ document.body.removeChild(a);
+ setTimeout(() => URL.revokeObjectURL(url), 100);
+ } catch (err) {
+ console.error('Download file failed:', err);
+ } finally {
+ this.downloadingFiles.delete(file.attachment_id);
+ this.downloadingFiles = new Set(this.downloadingFiles);
+ }
+ },
+
// 复制代码到剪贴板
copyCodeToClipboard(code) {
navigator.clipboard.writeText(code).then(() => {
@@ -164,29 +417,29 @@ export default {
},
// 复制bot消息到剪贴板
- copyBotMessage(message, messageIndex) {
- // 获取对应的消息对象
- const msgObj = this.messages[messageIndex].content;
+ copyBotMessage(messageParts, messageIndex) {
let textToCopy = '';
- // 如果有文本消息,添加到复制内容中
- if (message && message.trim()) {
- // 移除HTML标签,获取纯文本
- const tempDiv = document.createElement('div');
- tempDiv.innerHTML = message;
- textToCopy = tempDiv.textContent || tempDiv.innerText || message;
- }
+ if (Array.isArray(messageParts)) {
+ // 提取所有文本内容
+ const textContents = messageParts
+ .filter(part => part.type === 'plain' && part.text)
+ .map(part => part.text);
+ textToCopy = textContents.join('\n');
- // 如果有内嵌图片,添加说明
- if (msgObj && msgObj.embedded_images && msgObj.embedded_images.length > 0) {
- if (textToCopy) textToCopy += '\n\n';
- textToCopy += `[包含 ${msgObj.embedded_images.length} 张图片]`;
- }
+ // 检查是否有图片
+ const imageCount = messageParts.filter(part => part.type === 'image' && part.embedded_url).length;
+ if (imageCount > 0) {
+ if (textToCopy) textToCopy += '\n\n';
+ textToCopy += `[包含 ${imageCount} 张图片]`;
+ }
- // 如果有内嵌音频,添加说明
- if (msgObj && msgObj.embedded_audio) {
- if (textToCopy) textToCopy += '\n\n';
- textToCopy += '[包含音频内容]';
+ // 检查是否有音频
+ const hasAudio = messageParts.some(part => part.type === 'record' && part.embedded_url);
+ if (hasAudio) {
+ if (textToCopy) textToCopy += '\n\n';
+ textToCopy += '[包含音频内容]';
+ }
}
// 如果没有任何内容,使用默认文本
@@ -338,6 +591,150 @@ export default {
clearTimeout(this.scrollTimer);
this.scrollTimer = null;
}
+ // 清理 elapsed time 计时器
+ if (this.elapsedTimeTimer) {
+ clearInterval(this.elapsedTimeTimer);
+ this.elapsedTimeTimer = null;
+ }
+ },
+
+ // 格式化消息时间,支持别名显示
+ formatMessageTime(dateStr) {
+ if (!dateStr) return '';
+
+ const date = new Date(dateStr);
+ const now = new Date();
+
+ // 获取本地时间的日期部分
+ const dateDay = new Date(date.getFullYear(), date.getMonth(), date.getDate());
+ const todayDay = new Date(now.getFullYear(), now.getMonth(), now.getDate());
+ const yesterdayDay = new Date(todayDay);
+ yesterdayDay.setDate(yesterdayDay.getDate() - 1);
+
+ // 格式化时间 HH:MM
+ const hours = date.getHours().toString().padStart(2, '0');
+ const minutes = date.getMinutes().toString().padStart(2, '0');
+ const timeStr = `${hours}:${minutes}`;
+
+ // 判断是今天、昨天还是更早
+ if (dateDay.getTime() === todayDay.getTime()) {
+ return `${this.tm('time.today')} ${timeStr}`;
+ } else if (dateDay.getTime() === yesterdayDay.getTime()) {
+ return `${this.tm('time.yesterday')} ${timeStr}`;
+ } else {
+ // 更早的日期显示完整格式
+ const month = (date.getMonth() + 1).toString().padStart(2, '0');
+ const day = date.getDate().toString().padStart(2, '0');
+ return `${month}-${day} ${timeStr}`;
+ }
+ },
+
+ // Tool call related methods
+ toggleToolCall(messageIndex, partIndex, toolCallIndex) {
+ const key = `${messageIndex}-${partIndex}-${toolCallIndex}`;
+ if (this.expandedToolCalls.has(key)) {
+ this.expandedToolCalls.delete(key);
+ } else {
+ this.expandedToolCalls.add(key);
+ }
+ // Force reactivity
+ this.expandedToolCalls = new Set(this.expandedToolCalls);
+ },
+
+ isToolCallExpanded(messageIndex, partIndex, toolCallIndex) {
+ return this.expandedToolCalls.has(`${messageIndex}-${partIndex}-${toolCallIndex}`);
+ },
+
+ // Start timer for updating elapsed time
+ startElapsedTimeTimer() {
+ // Update every 12ms for sub-second precision, then every second after 1s
+ let fastUpdateCount = 0;
+ const fastUpdateInterval = 12;
+ const slowUpdateInterval = 1000;
+
+ const updateTime = () => {
+ this.currentTime = Date.now() / 1000;
+
+ // Check if there are any running tool calls
+ const hasRunningToolCalls = this.messages.some(msg =>
+ Array.isArray(msg.content.message) && msg.content.message.some(part =>
+ part.type === 'tool_call' && part.tool_calls?.some(tc => !tc.finished_ts)
+ )
+ );
+
+ if (hasRunningToolCalls) {
+ // Check if any running tool call is under 1 second
+ const hasSubSecondToolCall = this.messages.some(msg =>
+ Array.isArray(msg.content.message) && msg.content.message.some(part =>
+ part.type === 'tool_call' && part.tool_calls?.some(tc =>
+ !tc.finished_ts && (this.currentTime - tc.ts) < 1
+ )
+ )
+ );
+
+ if (hasSubSecondToolCall) {
+ fastUpdateCount++;
+ this.elapsedTimeTimer = setTimeout(updateTime, fastUpdateInterval);
+ } else {
+ this.elapsedTimeTimer = setTimeout(updateTime, slowUpdateInterval);
+ }
+ } else {
+ // No running tool calls, check again after 1 second
+ this.elapsedTimeTimer = setTimeout(updateTime, slowUpdateInterval);
+ }
+ };
+
+ updateTime();
+ },
+
+ // Get elapsed time string for a tool call
+ getElapsedTime(startTs) {
+ const elapsed = this.currentTime - startTs;
+ return this.formatDuration(elapsed);
+ },
+
+ // Format duration in seconds to human readable string
+ formatDuration(seconds) {
+ if (seconds < 1) {
+ return `${Math.round(seconds * 1000)}ms`;
+ } else if (seconds < 60) {
+ return `${seconds.toFixed(1)}s`;
+ } else {
+ const minutes = Math.floor(seconds / 60);
+ const secs = Math.round(seconds % 60);
+ return `${minutes}m ${secs}s`;
+ }
+ },
+
+ // Format tool result for display
+ formatToolResult(result) {
+ if (!result) return '';
+ // Try to parse as JSON for pretty formatting
+ try {
+ const parsed = JSON.parse(result);
+ return JSON.stringify(parsed, null, 2);
+ } catch {
+ return result;
+ }
+ },
+
+ // Get input tokens (input_other + input_cached)
+ getInputTokens(tokenUsage) {
+ if (!tokenUsage) return 0;
+ return (tokenUsage.input_other || 0) + (tokenUsage.input_cached || 0);
+ },
+
+ // Format agent duration
+ formatAgentDuration(agentStats) {
+ if (!agentStats) return '';
+ const duration = agentStats.end_time - agentStats.start_time;
+ return this.formatDuration(duration);
+ },
+
+ // Format time to first token
+ formatTTFT(ttft) {
+ if (!ttft || ttft <= 0) return '';
+ return this.formatDuration(ttft);
}
}
}
@@ -348,7 +745,7 @@ export default {
@keyframes fadeIn {
from {
opacity: 0;
- transform: translateY(10px);
+ transform: translateY(0);
}
to {
@@ -368,6 +765,22 @@ export default {
min-height: 0;
}
+.message-bubble {
+ padding: 2px 16px;
+ border-radius: 12px;
+}
+
+
+@media (max-width: 768px) {
+ .messages-container {
+ padding: 0;
+ }
+
+ .message-bubble {
+ padding: 2px 8px;
+ }
+}
+
/* 消息列表样式 */
.message-list {
max-width: 900px;
@@ -376,7 +789,7 @@ export default {
}
.message-item {
- margin-bottom: 24px;
+ margin-bottom: 12px;
animation: fadeIn 0.3s ease-out;
}
@@ -404,10 +817,36 @@ export default {
.message-actions {
display: flex;
- gap: 4px;
+ align-items: center;
+ gap: 8px;
opacity: 0;
transition: opacity 0.2s ease;
- margin-left: 8px;
+ margin-left: 16px;
+}
+
+/* 最后一条消息始终显示操作按钮 */
+.message-item:last-child .message-actions {
+ opacity: 1;
+}
+
+.message-time {
+ font-size: 12px;
+ color: var(--v-theme-secondaryText);
+ opacity: 0.7;
+ white-space: nowrap;
+}
+
+/* Agent Stats Info Icon */
+.stats-info-icon {
+ margin-left: 6px;
+ color: var(--v-theme-secondaryText);
+ opacity: 0.6;
+ cursor: pointer;
+ transition: opacity 0.2s ease;
+}
+
+.stats-info-icon:hover {
+ opacity: 1;
}
.bot-message:hover .message-actions {
@@ -435,11 +874,64 @@ export default {
background-color: rgba(76, 175, 80, 0.1);
}
-.message-bubble {
- padding: 2px 16px;
- border-radius: 12px;
+.reply-message-btn {
+ opacity: 0.6;
+ transition: all 0.2s ease;
+ color: var(--v-theme-secondary);
}
+.reply-message-btn:hover {
+ opacity: 1;
+ background-color: rgba(103, 58, 183, 0.1);
+}
+
+/* 引用消息显示样式 */
+.reply-quote {
+ display: flex;
+ align-items: center;
+ gap: 6px;
+ padding: 6px 10px;
+ margin-bottom: 8px;
+ background-color: rgba(103, 58, 183, 0.08);
+ border-left: 3px solid var(--v-theme-secondary);
+ border-radius: 4px;
+ cursor: pointer;
+ transition: background-color 0.2s ease;
+}
+
+.reply-quote:hover {
+ background-color: rgba(103, 58, 183, 0.15);
+}
+
+.reply-quote-icon {
+ color: var(--v-theme-secondary);
+ flex-shrink: 0;
+}
+
+.reply-quote-text {
+ font-size: 13px;
+ color: var(--v-theme-secondaryText);
+ overflow: hidden;
+ text-overflow: ellipsis;
+ white-space: nowrap;
+}
+
+/* 消息高亮动画 */
+.highlight-message {
+ animation: highlightPulse 2s ease-out;
+}
+
+@keyframes highlightPulse {
+ 0% {
+ background-color: rgba(103, 58, 183, 0.3);
+ }
+
+ 100% {
+ background-color: transparent;
+ }
+}
+
+
.user-bubble {
color: var(--v-theme-primaryText);
padding: 12px 18px;
@@ -512,19 +1004,14 @@ export default {
}
.bot-embedded-image {
- max-width: 80%;
+ max-width: 40%;
width: auto;
height: auto;
border-radius: 8px;
- box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
cursor: pointer;
transition: transform 0.2s ease;
}
-.bot-embedded-image:hover {
- transform: scale(1.02);
-}
-
.embedded-audio {
width: 300px;
margin-top: 8px;
@@ -535,10 +1022,307 @@ export default {
max-width: 300px;
}
+/* 文件附件样式 */
+.file-attachments,
+.embedded-files {
+ margin-top: 8px;
+ display: flex;
+ flex-direction: column;
+ gap: 6px;
+}
+
+.file-attachment,
+.embedded-file {
+ display: flex;
+ align-items: center;
+}
+
+.file-link {
+ display: inline-flex;
+ align-items: center;
+ gap: 6px;
+ padding: 8px 12px;
+ background-color: rgba(var(--v-theme-primary), 0.08);
+ border: 1px solid rgba(var(--v-theme-primary), 0.2);
+ border-radius: 8px;
+ color: rgb(var(--v-theme-primary));
+ text-decoration: none;
+ font-size: 14px;
+ transition: all 0.2s ease;
+ max-width: 300px;
+}
+
+.file-link-download {
+ cursor: pointer;
+}
+
+.download-icon {
+ margin-left: 4px;
+ opacity: 0.7;
+}
+
+.file-icon {
+ flex-shrink: 0;
+ color: rgb(var(--v-theme-primary));
+}
+
+.file-name {
+ overflow: hidden;
+ text-overflow: ellipsis;
+ white-space: nowrap;
+}
+
+.v-theme--dark .file-link {
+ background-color: rgba(255, 255, 255, 0.05);
+ border-color: rgba(255, 255, 255, 0.1);
+ color: var(--v-theme-secondary);
+}
+
+.v-theme--dark .file-link:hover {
+ background-color: rgba(255, 255, 255, 0.1);
+ border-color: rgba(255, 255, 255, 0.2);
+}
+
+.v-theme--dark .file-icon {
+ color: var(--v-theme-secondary);
+}
+
/* 动画类 */
.fade-in {
animation: fadeIn 0.3s ease-in-out;
}
+
+/* Reasoning 区块样式 */
+.reasoning-container {
+ margin-bottom: 12px;
+ margin-top: 6px;
+ border: 1px solid var(--v-theme-border);
+ border-radius: 8px;
+ overflow: hidden;
+ width: fit-content;
+}
+
+.v-theme--dark .reasoning-container {
+ background-color: rgba(103, 58, 183, 0.08);
+}
+
+.reasoning-header {
+ display: inline-flex;
+ align-items: center;
+ padding: 8px 8px;
+ cursor: pointer;
+ user-select: none;
+ transition: background-color 0.2s ease;
+ border-radius: 8px;
+}
+
+.reasoning-header:hover {
+ background-color: rgba(103, 58, 183, 0.08);
+}
+
+.v-theme--dark .reasoning-header:hover {
+ background-color: rgba(103, 58, 183, 0.15);
+}
+
+.reasoning-icon {
+ margin-right: 6px;
+ color: var(--v-theme-secondary);
+ transition: transform 0.2s ease;
+}
+
+.reasoning-label {
+ font-size: 13px;
+ font-weight: 500;
+ color: var(--v-theme-secondary);
+ letter-spacing: 0.3px;
+}
+
+.reasoning-content {
+ padding: 0px 12px;
+ border-top: 1px solid var(--v-theme-border);
+ color: gray;
+ animation: fadeIn 0.2s ease-in-out;
+ font-style: italic;
+}
+
+.reasoning-text {
+ font-size: 14px;
+ line-height: 1.6;
+ color: var(--v-theme-secondaryText);
+}
+
+.v-theme--dark .reasoning-text {
+ opacity: 0.85;
+}
+
+/* Tool Call Card Styles */
+.tool-calls-container {
+ display: flex;
+ flex-direction: column;
+ gap: 8px;
+ margin-bottom: 12px;
+ margin-top: 6px;
+}
+
+.tool-call-card {
+ border-radius: 8px;
+ overflow: hidden;
+ background-color: #eff3f6;
+ margin: 8px 0px;
+}
+
+.v-theme--dark .tool-call-card {
+ background-color: rgba(40, 60, 100, 0.4);
+ border-color: rgba(100, 140, 200, 0.4);
+}
+
+.tool-call-header {
+ display: flex;
+ align-items: center;
+ padding: 10px 12px;
+ cursor: pointer;
+ user-select: none;
+ transition: background-color 0.2s ease;
+ gap: 8px;
+}
+
+.tool-call-header:hover {
+ background-color: rgba(169, 194, 219, 0.15);
+}
+
+.v-theme--dark .tool-call-header:hover {
+ background-color: rgba(100, 150, 200, 0.2);
+}
+
+.tool-call-expand-icon {
+ color: var(--v-theme-secondary);
+ transition: transform 0.2s ease;
+ flex-shrink: 0;
+}
+
+.tool-call-icon {
+ color: var(--v-theme-secondary);
+ flex-shrink: 0;
+}
+
+.tool-call-info {
+ display: flex;
+ flex-direction: column;
+ gap: 2px;
+ flex: 1;
+ min-width: 0;
+}
+
+.tool-call-name {
+ font-size: 13px;
+ font-weight: 600;
+ color: var(--v-theme-secondary);
+}
+
+.tool-call-id {
+ font-size: 11px;
+ color: var(--v-theme-secondaryText);
+ opacity: 0.7;
+ overflow: hidden;
+ text-overflow: ellipsis;
+ white-space: nowrap;
+}
+
+.tool-call-status {
+ margin-left: 8px;
+ display: flex;
+ align-items: center;
+ gap: 4px;
+ font-size: 12px;
+ font-weight: 500;
+ flex-shrink: 0;
+}
+
+.tool-call-status.status-running {
+ color: #ff9800;
+}
+
+.tool-call-status.status-finished {
+ color: #4caf50;
+}
+
+.tool-call-status .status-icon {
+ font-size: 14px;
+}
+
+.tool-call-status .status-icon.spinning {
+ animation: spin 1s linear infinite;
+}
+
+@keyframes spin {
+ from {
+ transform: rotate(0deg);
+ }
+
+ to {
+ transform: rotate(360deg);
+ }
+}
+
+.tool-call-details {
+ padding: 12px;
+ background-color: rgba(255, 255, 255, 0.5);
+ animation: fadeIn 0.2s ease-in-out;
+}
+
+.v-theme--dark .tool-call-details {
+ border-top-color: rgba(100, 140, 200, 0.3);
+ background-color: rgba(30, 45, 70, 0.5);
+}
+
+.tool-call-detail-row {
+ display: flex;
+ flex-direction: column;
+ margin-bottom: 8px;
+}
+
+.tool-call-detail-row:last-child {
+ margin-bottom: 0;
+}
+
+.detail-label {
+ font-size: 11px;
+ font-weight: 600;
+ color: var(--v-theme-secondaryText);
+ text-transform: uppercase;
+ letter-spacing: 0.5px;
+ margin-bottom: 4px;
+}
+
+.detail-value {
+ font-size: 12px;
+ color: var(--v-theme-primaryText);
+ background-color: transparent;
+ padding: 4px 8px;
+ border-radius: 4px;
+ word-break: break-all;
+}
+
+.detail-json {
+ font-family: 'Fira Code', 'Consolas', monospace;
+ white-space: pre-wrap;
+ max-height: 200px;
+ overflow-y: auto;
+ margin: 0;
+}
+
+.detail-result {
+ max-height: 300px;
+ background-color: transparent;
+}
+
+.v-theme--dark .detail-value {
+ background-color: transparent;
+}
+
+.v-theme--dark .detail-result {
+ background-color: transparent;
+}
diff --git a/dashboard/src/components/chat/ProviderModelSelector.vue b/dashboard/src/components/chat/ProviderModelSelector.vue
index 0b983e416..b1b9f9fb4 100644
--- a/dashboard/src/components/chat/ProviderModelSelector.vue
+++ b/dashboard/src/components/chat/ProviderModelSelector.vue
@@ -3,6 +3,7 @@
+ mdi-creation
{{ selectedProviderId }} / {{ selectedModelName }}
@@ -10,7 +11,7 @@
-
+
选择提供商和模型
diff --git a/dashboard/src/components/chat/StandaloneChat.vue b/dashboard/src/components/chat/StandaloneChat.vue
new file mode 100644
index 000000000..2dcc8aeb8
--- /dev/null
+++ b/dashboard/src/components/chat/StandaloneChat.vue
@@ -0,0 +1,324 @@
+
+
+
+
+
+
+
+
+
+ Hello, I'm
+ AstrBot ⭐
+
+
+ 测试配置: {{ configId || 'default' }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ t('core.common.imagePreview') }}
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dashboard/src/components/config/AstrBotCoreConfigWrapper.vue b/dashboard/src/components/config/AstrBotCoreConfigWrapper.vue
index f3d5abe14..2a84989df 100644
--- a/dashboard/src/components/config/AstrBotCoreConfigWrapper.vue
+++ b/dashboard/src/components/config/AstrBotCoreConfigWrapper.vue
@@ -4,7 +4,7 @@
:align-tabs="$vuetify.display.mobile ? 'left' : 'start'" color="deep-purple-accent-4" class="config-tabs">
- {{ metadata[key]['name'] }}
+ {{ tm(metadata[key]['name']) }}
@@ -59,7 +59,17 @@ export default {
}
},
setup() {
- const { tm } = useModuleI18n('features/config');
+ const { tm: tmConfig } = useModuleI18n('features/config');
+ const { tm: tmMetadata } = useModuleI18n('features/config-metadata');
+
+ const tm = (key) => {
+ const metadataResult = tmMetadata(key);
+ if (!metadataResult.startsWith('[MISSING:') && !metadataResult.startsWith('[INVALID:')) {
+ return metadataResult;
+ }
+ return tmConfig(key);
+ };
+
return {
tm
};
diff --git a/dashboard/src/views/ToolUsePage.vue b/dashboard/src/components/extension/McpServersSection.vue
similarity index 62%
rename from dashboard/src/views/ToolUsePage.vue
rename to dashboard/src/components/extension/McpServersSection.vue
index db8fee905..fe20497f8 100644
--- a/dashboard/src/views/ToolUsePage.vue
+++ b/dashboard/src/components/extension/McpServersSection.vue
@@ -4,42 +4,18 @@
-
- mdi-function-variant {{ tm('title') }}
-
-
- {{ tm('subtitle') }}
-
-
-
- mdi-information
-
-
- {{ tm('tooltip.info') }}
-
-
-
-
-
- {{ tm('functionTools.buttons.view') }}({{ tools.length }})
-
+ @click="showMcpServerDialog = true" >
{{ tm('mcpServers.buttons.add') }}
+ >
{{ tm('mcpServers.buttons.sync') }}
-
-
-
mdi-server-off
{{ tm('mcpServers.empty') }}
@@ -57,7 +33,6 @@
-
@@ -67,8 +42,7 @@
- {{ tm('mcpServers.status.availableTools', { count: item.tools.length }) }} ({{
- item.tools.length }})
+ {{ tm('mcpServers.status.availableTools', { count: item.tools.length }) }} ({{ item.tools.length }})
@@ -78,10 +52,7 @@
- {{
- tool
- }}
-
+ {{ tool }}
@@ -91,8 +62,6 @@
-
-
@@ -105,8 +74,6 @@
-
-
@@ -183,8 +150,7 @@
-
-
+
@@ -240,115 +206,8 @@
-
-
-
-
- {{ tm('functionTools.title') }}
- {{ tools.length }}
-
-
-
-
-
-
mdi-api-off
-
{{ tm('functionTools.empty') }}
-
-
-
-
-
-
复选框代表该工具是否被启用。
-
-
-
-
-
-
-
-
-
-
-
- {{ tool.name.includes(':') ? 'mdi-server-network' : 'mdi-function-variant' }}
-
-
- {{ formatToolName(tool.name) }}
-
-
-
-
- {{ tool.description }}
-
-
-
-
-
-
-
-
- mdi-information
- {{ tm('functionTools.description') }}
-
- {{ tool.description }}
-
-
-
- mdi-code-json
- {{ tm('functionTools.parameters') }}
-
-
-
-
-
- {{ tm('functionTools.table.paramName') }}
- {{ tm('functionTools.table.type') }}
- {{ tm('functionTools.table.description') }}
-
-
-
-
- {{ paramName }}
-
-
- {{ param.type }}
-
-
- {{ param.description }}
-
-
-
-
-
-
mdi-code-brackets
-
{{ tm('functionTools.noParameters') }}
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- {{ tm('dialogs.serverDetail.buttons.close') }}
-
-
-
-
-
-
+
{{ save_message }}
@@ -356,15 +215,13 @@
\ No newline at end of file
+
diff --git a/dashboard/src/components/extension/componentPanel/components/CommandFilters.vue b/dashboard/src/components/extension/componentPanel/components/CommandFilters.vue
new file mode 100644
index 000000000..c4b212803
--- /dev/null
+++ b/dashboard/src/components/extension/componentPanel/components/CommandFilters.vue
@@ -0,0 +1,155 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ tm('filters.showSystemPlugins') }}
+
+
+ mdi-alert-circle
+
+ {{ tm('filters.systemPluginConflictHint') }}
+
+
+
+
+
+
+
+
diff --git a/dashboard/src/components/extension/componentPanel/components/CommandTable.vue b/dashboard/src/components/extension/componentPanel/components/CommandTable.vue
new file mode 100644
index 000000000..f8bb6fa82
--- /dev/null
+++ b/dashboard/src/components/extension/componentPanel/components/CommandTable.vue
@@ -0,0 +1,257 @@
+
+
+
+
+
+
+
+
+
+ {{ isGroupExpanded(item) ? 'mdi-chevron-down' : 'mdi-chevron-right' }}
+
+
+
+
+
+ {{ item.effective_command }}
+
+
+
+
+
+
+
+ {{ getTypeInfo(item.type).icon }}
+ {{ getTypeInfo(item.type).text }}{{ item.is_group && item.sub_commands?.length > 0 ? `(${item.sub_commands.length})` : '' }}
+
+
+
+
+ {{ item.plugin_display_name || item.plugin }}
+
+
+
+
+ {{ item.description || '-' }}
+
+
+
+
+
+ {{ getPermissionLabel(item.permission) }}
+
+
+
+
+
+ {{ getStatusInfo(item).text }}
+
+
+
+
+
+
+
+ mdi-play
+ {{ tm('tooltips.enable') }}
+
+
+ mdi-pause
+ {{ tm('tooltips.disable') }}
+
+
+
+ mdi-pencil
+ {{ tm('tooltips.rename') }}
+
+
+
+ mdi-information
+ {{ tm('tooltips.viewDetails') }}
+
+
+
+
+
+
+
+
mdi-console-line
+
{{ tm('empty.noCommands') }}
+
{{ tm('empty.noCommandsDesc') }}
+
+
+
+
+
+
+
+
+
+
diff --git a/dashboard/src/components/extension/componentPanel/components/DetailsDialog.vue b/dashboard/src/components/extension/componentPanel/components/DetailsDialog.vue
new file mode 100644
index 000000000..6d9188374
--- /dev/null
+++ b/dashboard/src/components/extension/componentPanel/components/DetailsDialog.vue
@@ -0,0 +1,143 @@
+
+
+
+
+
+ {{ tm('dialogs.details.title') }}
+
+
+
+ {{ tm('dialogs.details.type') }}
+
+
+ {{ getTypeInfo(command.type).icon }}
+ {{ getTypeInfo(command.type).text }}
+
+
+
+
+ {{ tm('dialogs.details.handler') }}
+ {{ command.handler_name }}
+
+
+ {{ tm('dialogs.details.module') }}
+ {{ command.module_path }}
+
+
+ {{ tm('dialogs.details.originalCommand') }}
+ {{ command.original_command }}
+
+
+ {{ tm('dialogs.details.effectiveCommand') }}
+ {{ command.effective_command }}
+
+
+ {{ tm('dialogs.details.parentGroup') }}
+ {{ command.parent_signature }}
+
+
+ {{ tm('dialogs.details.aliases') }}
+
+
+ {{ alias }}
+
+
+
+
+ {{ tm('dialogs.details.subCommands') }}
+
+
+
+ {{ sub.current_fragment }}
+
+
+
+
+
+ {{ tm('dialogs.details.permission') }}
+
+
+ {{ getPermissionLabel(command.permission) }}
+
+
+
+
+ {{ tm('dialogs.details.conflictStatus') }}
+
+ {{ tm('status.conflict') }}
+
+
+
+
+
+
+
+ {{ t('core.actions.close') }}
+
+
+
+
+
+
+
diff --git a/dashboard/src/components/extension/componentPanel/components/RenameDialog.vue b/dashboard/src/components/extension/componentPanel/components/RenameDialog.vue
new file mode 100644
index 000000000..ffdc5a826
--- /dev/null
+++ b/dashboard/src/components/extension/componentPanel/components/RenameDialog.vue
@@ -0,0 +1,53 @@
+
+
+
+
+
+ {{ tm('dialogs.rename.title') }}
+
+
+
+
+
+
+ {{ tm('dialogs.rename.cancel') }}
+
+
+ {{ tm('dialogs.rename.confirm') }}
+
+
+
+
+
diff --git a/dashboard/src/components/extension/componentPanel/components/ToolTable.vue b/dashboard/src/components/extension/componentPanel/components/ToolTable.vue
new file mode 100644
index 000000000..1b6fecfc1
--- /dev/null
+++ b/dashboard/src/components/extension/componentPanel/components/ToolTable.vue
@@ -0,0 +1,144 @@
+
+
+
+
+
+
+
+
+ {{ item.name.includes(':') ? 'mdi-server-network' : 'mdi-function-variant' }}
+
+
+
+
+
+
+
+ {{ item.description || '-' }}
+
+
+
+
+
+ {{ item.origin || '-' }}
+
+
+
+
+
+ {{ item.origin_name || '-' }}
+
+
+
+
+
+ {{ item.active ? tmCommand('status.enabled') : tmCommand('status.disabled') }}
+
+
+
+
+
+
+
+
+
+
mdi-function-variant
+
{{ tmTool('functionTools.empty') }}
+
+
+
+
+
+
+
mdi-code-json
+
+
{{ tmTool('functionTools.parameters') }}
+
+ {{ tmTool('functionTools.noParameters') }}
+
+
+
+
+ {{ tmTool('functionTools.table.paramName') }}
+ {{ tmTool('functionTools.table.type') }}
+ {{ tmTool('functionTools.table.description') }}
+
+
+
+
+ {{ paramName }}
+
+
+ {{ param?.type || '-' }}
+
+
+ {{ param?.description || '-' }}
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dashboard/src/components/extension/componentPanel/composables/useCommandActions.ts b/dashboard/src/components/extension/componentPanel/composables/useCommandActions.ts
new file mode 100644
index 000000000..a285c473f
--- /dev/null
+++ b/dashboard/src/components/extension/componentPanel/composables/useCommandActions.ts
@@ -0,0 +1,177 @@
+/**
+ * 指令操作方法 Composable
+ */
+import { reactive } from 'vue';
+import axios from 'axios';
+import type { CommandItem, RenameDialogState, DetailsDialogState, TypeInfo, StatusInfo } from '../types';
+
+export function useCommandActions(
+ toast: (message: string, color?: string) => void,
+ fetchCommands: () => Promise
+) {
+ // 重命名对话框状态
+ const renameDialog = reactive({
+ show: false,
+ command: null,
+ newName: '',
+ loading: false
+ });
+
+ // 详情对话框状态
+ const detailsDialog = reactive({
+ show: false,
+ command: null
+ });
+
+ /**
+ * 切换指令启用/禁用状态
+ */
+ const toggleCommand = async (
+ cmd: CommandItem,
+ successMessage: string,
+ errorMessage: string
+ ) => {
+ try {
+ const res = await axios.post('/api/commands/toggle', {
+ handler_full_name: cmd.handler_full_name,
+ enabled: !cmd.enabled
+ });
+ if (res.data.status === 'ok') {
+ toast(successMessage, 'success');
+ await fetchCommands();
+ } else {
+ toast(res.data.message || errorMessage, 'error');
+ }
+ } catch (err: any) {
+ toast(err?.message || errorMessage, 'error');
+ }
+ };
+
+ /**
+ * 打开重命名对话框
+ */
+ const openRenameDialog = (cmd: CommandItem) => {
+ renameDialog.command = cmd;
+ renameDialog.newName = cmd.current_fragment || '';
+ renameDialog.show = true;
+ };
+
+ /**
+ * 确认重命名
+ */
+ const confirmRename = async (successMessage: string, errorMessage: string) => {
+ if (!renameDialog.command || !renameDialog.newName.trim()) return;
+
+ renameDialog.loading = true;
+ try {
+ const res = await axios.post('/api/commands/rename', {
+ handler_full_name: renameDialog.command.handler_full_name,
+ new_name: renameDialog.newName.trim()
+ });
+ if (res.data.status === 'ok') {
+ toast(successMessage, 'success');
+ renameDialog.show = false;
+ await fetchCommands();
+ } else {
+ toast(res.data.message || errorMessage, 'error');
+ }
+ } catch (err: any) {
+ toast(err?.message || errorMessage, 'error');
+ } finally {
+ renameDialog.loading = false;
+ }
+ };
+
+ /**
+ * 打开详情对话框
+ */
+ const openDetailsDialog = (cmd: CommandItem) => {
+ detailsDialog.command = cmd;
+ detailsDialog.show = true;
+ };
+
+ /**
+ * 获取类型显示信息
+ */
+ const getTypeInfo = (type: string, translations: { group: string; subCommand: string; command: string }): TypeInfo => {
+ switch (type) {
+ case 'group':
+ return { text: translations.group, color: 'info', icon: 'mdi-folder-outline' };
+ case 'sub_command':
+ return { text: translations.subCommand, color: 'secondary', icon: 'mdi-subdirectory-arrow-right' };
+ default:
+ return { text: translations.command, color: 'primary', icon: 'mdi-console-line' };
+ }
+ };
+
+ /**
+ * 获取权限颜色
+ */
+ const getPermissionColor = (permission: string): string => {
+ switch (permission) {
+ case 'admin': return 'error';
+ default: return 'success';
+ }
+ };
+
+ /**
+ * 获取权限标签
+ */
+ const getPermissionLabel = (permission: string, translations: { admin: string; everyone: string }): string => {
+ switch (permission) {
+ case 'admin': return translations.admin;
+ default: return translations.everyone;
+ }
+ };
+
+ /**
+ * 获取状态显示信息
+ */
+ const getStatusInfo = (
+ cmd: CommandItem,
+ translations: { conflict: string; enabled: string; disabled: string }
+ ): StatusInfo => {
+ if (cmd.has_conflict) {
+ return { text: translations.conflict, color: 'warning', variant: 'flat' };
+ }
+ if (cmd.enabled) {
+ return { text: translations.enabled, color: 'success', variant: 'flat' };
+ }
+ return { text: translations.disabled, color: 'error', variant: 'outlined' };
+ };
+
+ /**
+ * 获取表格行属性(用于冲突高亮和子指令样式)
+ */
+ const getRowProps = ({ item }: { item: CommandItem }) => {
+ const classes: string[] = [];
+ if (item.has_conflict) {
+ classes.push('conflict-row');
+ }
+ if (item.type === 'sub_command') {
+ classes.push('sub-command-row');
+ }
+ if (item.is_group) {
+ classes.push('group-row');
+ }
+ return classes.length > 0 ? { class: classes.join(' ') } : {};
+ };
+
+ return {
+ // 状态
+ renameDialog,
+ detailsDialog,
+
+ // 方法
+ toggleCommand,
+ openRenameDialog,
+ confirmRename,
+ openDetailsDialog,
+ getTypeInfo,
+ getPermissionColor,
+ getPermissionLabel,
+ getStatusInfo,
+ getRowProps
+ };
+}
+
diff --git a/dashboard/src/components/extension/componentPanel/composables/useCommandFilters.ts b/dashboard/src/components/extension/componentPanel/composables/useCommandFilters.ts
new file mode 100644
index 000000000..f7d5bbc0e
--- /dev/null
+++ b/dashboard/src/components/extension/componentPanel/composables/useCommandFilters.ts
@@ -0,0 +1,187 @@
+/**
+ * 指令过滤逻辑 Composable
+ */
+import { ref, computed, type Ref } from 'vue';
+import type { CommandItem, FilterState } from '../types';
+
+export function useCommandFilters(commands: Ref) {
+ // 过滤状态
+ const searchQuery = ref('');
+ const pluginFilter = ref('all');
+ const permissionFilter = ref('all');
+ const statusFilter = ref('all');
+ const typeFilter = ref('all');
+ const showSystemPlugins = ref(false);
+
+ // 展开的指令组
+ const expandedGroups = ref>(new Set());
+
+ /**
+ * 检查是否有涉及系统插件的冲突
+ */
+ const hasSystemPluginConflict = computed(() => {
+ return commands.value.some(cmd => cmd.has_conflict && cmd.reserved);
+ });
+
+ /**
+ * 实际是否显示系统插件(如果有系统插件冲突则强制显示)
+ */
+ const effectiveShowSystemPlugins = computed(() => {
+ return showSystemPlugins.value || hasSystemPluginConflict.value;
+ });
+
+ /**
+ * 获取可用的插件列表(用于过滤下拉框)
+ */
+ const availablePlugins = computed(() => {
+ const plugins = new Set(
+ commands.value
+ .filter(cmd => effectiveShowSystemPlugins.value || !cmd.reserved)
+ .map(cmd => cmd.plugin)
+ );
+ return Array.from(plugins).sort();
+ });
+
+ /**
+ * 检查指令是否匹配过滤条件
+ */
+ const matchesFilters = (cmd: CommandItem, query: string): boolean => {
+ // 系统插件过滤(除非显示系统插件)
+ if (!effectiveShowSystemPlugins.value && cmd.reserved) {
+ return false;
+ }
+
+ // 搜索过滤
+ if (query) {
+ const matchesSearch =
+ cmd.effective_command?.toLowerCase().includes(query) ||
+ cmd.description?.toLowerCase().includes(query) ||
+ cmd.plugin?.toLowerCase().includes(query);
+ if (!matchesSearch) return false;
+ }
+
+ // 插件过滤
+ if (pluginFilter.value !== 'all' && cmd.plugin !== pluginFilter.value) {
+ return false;
+ }
+
+ // 权限过滤
+ if (permissionFilter.value !== 'all') {
+ if (permissionFilter.value === 'everyone') {
+ if (cmd.permission !== 'everyone' && cmd.permission !== 'member') return false;
+ } else if (cmd.permission !== permissionFilter.value) {
+ return false;
+ }
+ }
+
+ // 状态过滤
+ if (statusFilter.value !== 'all') {
+ if (statusFilter.value === 'enabled' && !cmd.enabled) return false;
+ if (statusFilter.value === 'disabled' && cmd.enabled) return false;
+ if (statusFilter.value === 'conflict' && !cmd.has_conflict) return false;
+ }
+
+ // 类型过滤
+ if (typeFilter.value !== 'all') {
+ if (typeFilter.value === 'group' && cmd.type !== 'group') return false;
+ if (typeFilter.value === 'command' && cmd.type !== 'command') return false;
+ if (typeFilter.value === 'sub_command' && cmd.type !== 'sub_command') return false;
+ }
+
+ return true;
+ };
+
+ /**
+ * 过滤后的指令列表(支持层级结构)
+ */
+ const filteredCommands = computed(() => {
+ const query = searchQuery.value.toLowerCase();
+ const conflictCmds: CommandItem[] = [];
+ const normalCmds: CommandItem[] = [];
+
+ for (const cmd of commands.value) {
+ // 对于指令组,检查组本身或子指令是否匹配
+ if (cmd.is_group) {
+ const groupMatches = matchesFilters(cmd, query);
+ const matchingSubCmds = (cmd.sub_commands || []).filter(sub => matchesFilters(sub, query));
+
+ // 如果组匹配或有匹配的子指令,则包含它
+ if (groupMatches || matchingSubCmds.length > 0) {
+ if (cmd.has_conflict) {
+ conflictCmds.push(cmd);
+ } else {
+ normalCmds.push(cmd);
+ }
+
+ // 如果组已展开,添加匹配的子指令
+ if (expandedGroups.value.has(cmd.handler_full_name)) {
+ const subsToShow = query ? matchingSubCmds : (cmd.sub_commands || []);
+ for (const sub of subsToShow) {
+ if (sub.has_conflict) {
+ conflictCmds.push(sub);
+ } else {
+ normalCmds.push(sub);
+ }
+ }
+ }
+ }
+ } else if (cmd.type !== 'sub_command') {
+ // 普通指令(子指令通过组处理)
+ if (matchesFilters(cmd, query)) {
+ if (cmd.has_conflict) {
+ conflictCmds.push(cmd);
+ } else {
+ normalCmds.push(cmd);
+ }
+ }
+ }
+ }
+
+ // 按 effective_command 排序冲突指令,使其分组在一起
+ conflictCmds.sort((a, b) => (a.effective_command || '').localeCompare(b.effective_command || ''));
+
+ return [...conflictCmds, ...normalCmds];
+ });
+
+ /**
+ * 切换指令组的展开/折叠状态
+ */
+ const toggleGroupExpand = (cmd: CommandItem) => {
+ if (!cmd.is_group) return;
+ if (expandedGroups.value.has(cmd.handler_full_name)) {
+ expandedGroups.value.delete(cmd.handler_full_name);
+ } else {
+ expandedGroups.value.add(cmd.handler_full_name);
+ }
+ };
+
+ /**
+ * 检查指令组是否已展开
+ */
+ const isGroupExpanded = (cmd: CommandItem): boolean => {
+ return expandedGroups.value.has(cmd.handler_full_name);
+ };
+
+ return {
+ // 状态
+ searchQuery,
+ pluginFilter,
+ permissionFilter,
+ statusFilter,
+ typeFilter,
+ showSystemPlugins,
+ expandedGroups,
+
+ // 计算属性
+ hasSystemPluginConflict,
+ effectiveShowSystemPlugins,
+ availablePlugins,
+ filteredCommands,
+
+ // 方法
+ matchesFilters,
+ toggleGroupExpand,
+ isGroupExpanded
+ };
+}
+
diff --git a/dashboard/src/components/extension/componentPanel/composables/useComponentData.ts b/dashboard/src/components/extension/componentPanel/composables/useComponentData.ts
new file mode 100644
index 000000000..291ba53c4
--- /dev/null
+++ b/dashboard/src/components/extension/componentPanel/composables/useComponentData.ts
@@ -0,0 +1,83 @@
+/**
+ * 指令数据管理 Composable
+ */
+import { ref, reactive } from 'vue';
+import axios from 'axios';
+import type { CommandItem, CommandSummary, SnackbarState, ToolItem } from '../types';
+
+export function useComponentData() {
+ const loading = ref(false);
+ const commands = ref([]);
+ const tools = ref([]);
+ const toolsLoading = ref(false);
+ const summary = reactive({
+ disabled: 0,
+ conflicts: 0
+ });
+
+ const snackbar = reactive({
+ show: false,
+ message: '',
+ color: 'success'
+ });
+
+ /**
+ * 显示 Toast 消息
+ */
+ const toast = (message: string, color: string = 'success') => {
+ snackbar.message = message;
+ snackbar.color = color;
+ snackbar.show = true;
+ };
+
+ /**
+ * 获取指令列表
+ */
+ const fetchCommands = async (errorMessage: string) => {
+ loading.value = true;
+ try {
+ const res = await axios.get('/api/commands');
+ if (res.data.status === 'ok') {
+ commands.value = res.data.data.items || [];
+ const s = res.data.data.summary || {};
+ summary.disabled = s.disabled || 0;
+ summary.conflicts = s.conflicts || 0;
+ } else {
+ toast(res.data.message || errorMessage, 'error');
+ }
+ } catch (err: any) {
+ toast(err?.message || errorMessage, 'error');
+ } finally {
+ loading.value = false;
+ }
+ };
+
+ const fetchTools = async (errorMessage: string) => {
+ toolsLoading.value = true;
+ try {
+ const res = await axios.get('/api/tools/list');
+ if (res.data.status === 'ok') {
+ tools.value = res.data.data || [];
+ } else {
+ toast(res.data.message || errorMessage, 'error');
+ }
+ } catch (err: any) {
+ toast(err?.message || errorMessage, 'error');
+ } finally {
+ toolsLoading.value = false;
+ }
+ };
+
+ return {
+ loading,
+ commands,
+ tools,
+ toolsLoading,
+ summary,
+ snackbar,
+ toast,
+ fetchCommands,
+ fetchTools
+ };
+}
+
diff --git a/dashboard/src/components/extension/componentPanel/index.vue b/dashboard/src/components/extension/componentPanel/index.vue
new file mode 100644
index 000000000..912af9156
--- /dev/null
+++ b/dashboard/src/components/extension/componentPanel/index.vue
@@ -0,0 +1,307 @@
+
+
+
+
+
+
+
+
+
+
+ mdi-console-line
+ {{ tm('type.command') }}
+
+
+ mdi-function-variant
+ {{ tmTool('functionTools.title') }}
+
+
+
+
+
+
+
+
+
+
+ mdi-console-line
+ {{ tm('summary.total') }}:
+ {{ filteredCommands.length }}
+
+
+
+ mdi-close-circle-outline
+ {{ tm('summary.disabled') }}:
+ {{ summary.disabled }}
+
+
+
+
+
+
+ mdi-alert-circle
+
+
+ {{ tm('conflictAlert.title') }}
+
+
+ {{ tm('conflictAlert.description', { count: summary.conflicts }) }}
+
+
+ mdi-lightbulb-outline
+ {{ tm('conflictAlert.hint') }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+ mdi-function-variant
+ {{ tm('summary.total') }}:
+ {{ filteredTools.length }}
+
+
+
+ mdi-check-circle-outline
+ {{ tm('status.enabled') }}:
+ {{ filteredTools.filter(t => t.active).length }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ snackbar.message }}
+
+
diff --git a/dashboard/src/components/extension/componentPanel/types.ts b/dashboard/src/components/extension/componentPanel/types.ts
new file mode 100644
index 000000000..d2b388ec9
--- /dev/null
+++ b/dashboard/src/components/extension/componentPanel/types.ts
@@ -0,0 +1,102 @@
+/**
+ * 指令管理模块 - 类型定义
+ */
+
+/** 指令项接口 */
+export interface CommandItem {
+ handler_full_name: string;
+ handler_name: string;
+ plugin: string;
+ plugin_display_name: string | null;
+ module_path: string;
+ description: string;
+ type: CommandType;
+ parent_signature: string;
+ parent_group_handler: string;
+ original_command: string;
+ current_fragment: string;
+ effective_command: string;
+ aliases: string[];
+ permission: PermissionType;
+ enabled: boolean;
+ is_group: boolean;
+ has_conflict: boolean;
+ reserved: boolean;
+ sub_commands: CommandItem[];
+}
+
+/** 指令类型 */
+export type CommandType = 'command' | 'group' | 'sub_command';
+
+/** 权限类型 */
+export type PermissionType = 'admin' | 'everyone' | 'member';
+
+/** 指令摘要统计 */
+export interface CommandSummary {
+ disabled: number;
+ conflicts: number;
+}
+
+/** 过滤器状态 */
+export interface FilterState {
+ searchQuery: string;
+ pluginFilter: string;
+ permissionFilter: string;
+ statusFilter: string;
+ typeFilter: string;
+ showSystemPlugins: boolean;
+}
+
+/** 重命名对话框状态 */
+export interface RenameDialogState {
+ show: boolean;
+ command: CommandItem | null;
+ newName: string;
+ loading: boolean;
+}
+
+/** 详情对话框状态 */
+export interface DetailsDialogState {
+ show: boolean;
+ command: CommandItem | null;
+}
+
+/** Toast 消息状态 */
+export interface SnackbarState {
+ show: boolean;
+ message: string;
+ color: string;
+}
+
+/** 类型信息展示 */
+export interface TypeInfo {
+ text: string;
+ color: string;
+ icon: string;
+}
+
+/** 状态信息展示 */
+export interface StatusInfo {
+ text: string;
+ color: string;
+ variant: 'flat' | 'outlined' | 'text' | 'elevated' | 'tonal' | 'plain';
+}
+
+/** MCP/函数工具参数定义 */
+export interface ToolParameter {
+ type?: string;
+ description?: string;
+}
+
+/** MCP/函数工具对象 */
+export interface ToolItem {
+ name: string;
+ description: string;
+ active: boolean;
+ parameters?: {
+ properties?: Record;
+ };
+ origin?: string;
+ origin_name?: string;
+}
+
diff --git a/dashboard/src/components/provider/AddNewProvider.vue b/dashboard/src/components/provider/AddNewProvider.vue
index b4cd1eb92..f59c24942 100644
--- a/dashboard/src/components/provider/AddNewProvider.vue
+++ b/dashboard/src/components/provider/AddNewProvider.vue
@@ -7,6 +7,10 @@
mdi-message-text
{{ tm('dialogs.addProvider.tabs.basic') }}
+
+ mdi-cogs
+ {{ tm('dialogs.addProvider.tabs.agentRunner') }}
+
mdi-microphone-message
{{ tm('dialogs.addProvider.tabs.speechToText') }}
@@ -27,7 +31,7 @@
-
接入 {{ name }}
+
{{ name }}
{{ getProviderDescription(template, name) }}
@@ -54,7 +58,7 @@
- {{ tm('dialogs.addProvider.noTemplates', { type: getTabTypeName(tabType) }) }}
+ {{ tm('dialogs.addProvider.noTemplates') }}
@@ -104,19 +108,6 @@ export default {
this.$emit('update:show', value);
}
},
-
- // 翻译消息的计算属性
- messages() {
- return {
- tabTypes: {
- 'chat_completion': this.tm('providers.tabs.chatCompletion'),
- 'speech_to_text': this.tm('providers.tabs.speechToText'),
- 'text_to_speech': this.tm('providers.tabs.textToSpeech'),
- 'embedding': this.tm('providers.tabs.embedding'),
- 'rerank': this.tm('providers.tabs.rerank')
- }
- };
- }
},
methods: {
closeDialog() {
@@ -140,11 +131,6 @@ export default {
// 从工具函数导入
getProviderIcon,
- // 获取Tab类型的中文名称
- getTabTypeName(tabType) {
- return this.messages.tabTypes[tabType] || tabType;
- },
-
// 获取提供商简介
getProviderDescription(template, name) {
return getProviderDescription(template, name, this.tm);
diff --git a/dashboard/src/components/shared/AstrBotConfig.vue b/dashboard/src/components/shared/AstrBotConfig.vue
index d6c6fee9c..361794156 100644
--- a/dashboard/src/components/shared/AstrBotConfig.vue
+++ b/dashboard/src/components/shared/AstrBotConfig.vue
@@ -304,16 +304,32 @@ function hasVisibleItemsAfter(items, currentIndex) {
hide-details
>
-
-
+
+ class="d-flex align-center gap-3"
+ >
+
+
+
-
-
+
+ class="d-flex align-center gap-3"
+ >
+
+
+
{
+ if (!value || typeof value !== 'string') return value
+ return tm(value)
+}
+
+// 处理labels翻译 - labels可以是数组或国际化键
+const getTranslatedLabels = (itemMeta) => {
+ if (!itemMeta?.labels) return null
+
+ // 如果labels是字符串(国际化键)
+ if (typeof itemMeta.labels === 'string') {
+ const translatedLabels = getRaw(itemMeta.labels)
+ // 如果翻译成功且是数组,返回翻译结果
+ if (Array.isArray(translatedLabels)) {
+ return translatedLabels
+ }
+ }
+
+ // 如果labels是数组,直接返回
+ if (Array.isArray(itemMeta.labels)) {
+ return itemMeta.labels
+ }
+
+ return null
+}
const dialog = ref(false)
const currentEditingKey = ref('')
@@ -101,6 +129,21 @@ function shouldShowItem(itemMeta, itemKey) {
return true
}
+// 检查最外层的 object 是否应该显示
+function shouldShowSection() {
+ const sectionMeta = props.metadata[props.metadataKey]
+ if (!sectionMeta?.condition) {
+ return true
+ }
+ for (const [conditionKey, expectedValue] of Object.entries(sectionMeta.condition)) {
+ const actualValue = getValueBySelector(props.iterable, conditionKey)
+ if (actualValue !== expectedValue) {
+ return false
+ }
+ }
+ return true
+}
+
function hasVisibleItemsAfter(items, currentIndex) {
const itemEntries = Object.entries(items)
@@ -114,19 +157,40 @@ function hasVisibleItemsAfter(items, currentIndex) {
return false
}
+
+function parseSpecialValue(value) {
+ if (!value || typeof value !== 'string') {
+ return { name: '', subtype: '' }
+ }
+ const [name, ...rest] = value.split(':')
+ return {
+ name,
+ subtype: rest.join(':') || ''
+ }
+}
+
+function getSpecialName(value) {
+ return parseSpecialValue(value).name
+}
+
+function getSpecialSubtype(value) {
+ return parseSpecialValue(value).subtype
+}
+
-
+
- {{ metadata[metadataKey]?.description }}
+ {{ translateIfKey(metadata[metadataKey]?.description) }}
‼️
- {{ metadata[metadataKey]?.hint }}
+ {{ translateIfKey(metadata[metadataKey]?.hint) }}
@@ -140,13 +204,13 @@ function hasVisibleItemsAfter(items, currentIndex) {
- {{ itemMeta?.description || itemKey }}
+ {{ translateIfKey(itemMeta?.description) || itemKey }}
({{ itemKey }})
‼️
- {{ itemMeta?.hint }}
+ {{ translateIfKey(itemMeta?.hint) }}
@@ -154,7 +218,13 @@ function hasVisibleItemsAfter(items, currentIndex) {
@@ -175,10 +245,29 @@ function hasVisibleItemsAfter(items, currentIndex) {
-
-
+
+
+
+
+
+ color="primary" inset density="compact" hide-details
+ style="display: flex; justify-content: end;">
-
+
-
+
+
-
+
@@ -262,21 +331,17 @@ function hasVisibleItemsAfter(items, currentIndex) {
-
+
-
+
-
+
{{ plugin === '*' ? '所有插件' : plugin }}
@@ -284,7 +349,8 @@ function hasVisibleItemsAfter(items, currentIndex) {
-
+
diff --git a/dashboard/src/components/shared/ConsoleDisplayer.vue b/dashboard/src/components/shared/ConsoleDisplayer.vue
index ea2ce2a95..7d6759dfd 100644
--- a/dashboard/src/components/shared/ConsoleDisplayer.vue
+++ b/dashboard/src/components/shared/ConsoleDisplayer.vue
@@ -1,5 +1,7 @@
@@ -7,8 +9,8 @@ import { useCommonStore } from '@/stores/common';
-
+
{{ level }}
@@ -35,7 +37,6 @@ export default {
'\u001b[32m': 'color: #00FF00;', // green
'default': 'color: #FFFFFF;'
},
- logCache: useCommonStore().getLogCache(),
historyNum_: -1,
logLevels: ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
selectedLevels: [0, 1, 2, 3, 4], // 默认选中所有级别
@@ -45,7 +46,17 @@ export default {
'WARNING': 'amber',
'ERROR': 'red',
'CRITICAL': 'purple'
- }
+ },
+ lastProcessedTime: 0, // 记录最后处理的日志时间戳
+ localLogCache: [], // 本地日志缓存
+ }
+ },
+ computed: {
+ commonStore() {
+ return useCommonStore();
+ },
+ logCache() {
+ return this.commonStore.log_cache;
}
},
props: {
@@ -60,13 +71,39 @@ export default {
},
watch: {
logCache: {
- handler(val) {
- const lastLog = val[this.logCache.length - 1];
- if (lastLog && this.isLevelSelected(lastLog.level)) {
- this.printLog(lastLog.data);
+ handler(newVal) {
+ // 基于 timestamp 处理新增的日志
+ if (newVal && newVal.length > 0) {
+ // 确保 DOM 已经准备好
+ this.$nextTick(() => {
+ // 合并到本地缓存并按时间排序
+ const newLogs = newVal.filter(log => log.time > this.lastProcessedTime);
+
+ if (newLogs.length > 0) {
+ this.localLogCache.push(...newLogs);
+ // 按时间戳排序
+ this.localLogCache.sort((a, b) => a.time - b.time);
+
+ // 只保留最新的 log_cache_max_len 条
+ if (this.localLogCache.length > this.commonStore.log_cache_max_len) {
+ this.localLogCache.splice(0, this.localLogCache.length - this.commonStore.log_cache_max_len);
+ }
+
+ // 显示新日志
+ newLogs.forEach(logItem => {
+ if (this.isLevelSelected(logItem.level)) {
+ this.printLog(logItem.data);
+ }
+ });
+
+ // 更新最后处理时间
+ this.lastProcessedTime = Math.max(...newLogs.map(log => log.time));
+ }
+ });
}
},
- deep: true
+ deep: true,
+ immediate: false
},
selectedLevels: {
handler() {
@@ -75,14 +112,37 @@ export default {
deep: true
}
},
- mounted() {
- if (this.logCache.length === 0) {
- this.delayInit()
- } else {
- this.init()
- }
+ async mounted() {
+ // 请求历史日志
+ await this.fetchLogHistory();
+
+ // 等待 DOM 准备好后,显示历史日志
+ this.$nextTick(() => {
+ if (this.localLogCache.length > 0) {
+ this.localLogCache.forEach(logItem => {
+ if (this.isLevelSelected(logItem.level)) {
+ this.printLog(logItem.data);
+ }
+ });
+ // 更新最后处理时间
+ this.lastProcessedTime = Math.max(...this.localLogCache.map(log => log.time));
+ }
+ });
},
methods: {
+ async fetchLogHistory() {
+ try {
+ const res = await axios.get('/api/log-history');
+ if (res.data.data.logs && res.data.data.logs.length > 0) {
+ this.localLogCache = [...res.data.data.logs];
+ // 按时间戳排序
+ this.localLogCache.sort((a, b) => a.time - b.time);
+ }
+ } catch (err) {
+ console.error('Failed to fetch log history:', err);
+ }
+ },
+
getLevelColor(level) {
return this.levelColors[level] || 'grey';
},
@@ -98,41 +158,22 @@ export default {
},
refreshDisplay() {
- // 清空现有的显示
const termElement = document.getElementById('term');
if (termElement) {
termElement.innerHTML = '';
- }
-
- // 重新显示符合筛选条件的日志
- this.init();
- },
-
- delayInit() {
- if (this.logCache.length === 0) {
- setTimeout(() => {
- this.delayInit()
- }, 500)
- } else {
- this.init()
- }
- },
-
- init() {
- this.historyNum_ = parseInt(this.historyNum)
- let i = 0
- for (let log of this.logCache) {
- if (this.isLevelSelected(log.level)) { // 只显示选中级别的日志
- if (this.historyNum_ != -1 && i >= this.logCache.length - this.historyNum_) {
- this.printLog(log.data)
- ++i
- } else if (this.historyNum_ == -1) {
- this.printLog(log.data)
- }
+
+ // 重新显示所有符合筛选条件的日志
+ if (this.localLogCache && this.localLogCache.length > 0) {
+ this.localLogCache.forEach(logItem => {
+ if (this.isLevelSelected(logItem.level)) {
+ this.printLog(logItem.data);
+ }
+ });
}
}
},
+
toggleAutoScroll() {
this.autoScroll = !this.autoScroll;
},
@@ -140,6 +181,11 @@ export default {
printLog(log) {
// append 一个 span 标签到 term,block 的方式
let ele = document.getElementById('term')
+ if (!ele) {
+ console.warn('term element not found, skipping log print');
+ return;
+ }
+
let span = document.createElement('pre')
let style = this.logColorAnsiMap['default']
for (let key in this.logColorAnsiMap) {
@@ -168,6 +214,7 @@ export default {
flex-wrap: wrap;
gap: 8px;
margin-bottom: 8px;
+ margin-left: 20px;
}
.fade-in {
diff --git a/dashboard/src/components/shared/KnowledgeBaseSelector.vue b/dashboard/src/components/shared/KnowledgeBaseSelector.vue
index e959b948a..8c8dae6ae 100644
--- a/dashboard/src/components/shared/KnowledgeBaseSelector.vue
+++ b/dashboard/src/components/shared/KnowledgeBaseSelector.vue
@@ -3,7 +3,7 @@
- 未选择
+ {{ tm('knowledgeBaseSelector.notSelected') }}
- 选择知识库
+ {{ tm('knowledgeBaseSelector.dialogTitle') }}
@@ -50,9 +50,9 @@
{{ kb.kb_name }}
- {{ kb.description || '无描述' }}
- - {{ kb.doc_count }} 个文档
- - {{ kb.chunk_count }} 个块
+ {{ kb.description || tm('knowledgeBaseSelector.noDescription') }}
+ - {{ tm('knowledgeBaseSelector.documentCount', { count: kb.doc_count }) }}
+ - {{ tm('knowledgeBaseSelector.chunkCount', { count: kb.chunk_count }) }}
@@ -68,9 +68,9 @@
mdi-database-off
-
暂无知识库
+
{{ tm('knowledgeBaseSelector.noKnowledgeBases') }}
- 创建知识库
+ {{ tm('knowledgeBaseSelector.createKnowledgeBase') }}
@@ -78,14 +78,14 @@
- 已选择 {{ selectedKnowledgeBases.length }} 个知识库
+ {{ tm('knowledgeBaseSelector.selectedCount', { count: selectedKnowledgeBases.length }) }}
- 取消
+ {{ tm('knowledgeBaseSelector.cancelSelection') }}
- 确认选择
+ {{ tm('knowledgeBaseSelector.confirmSelection') }}
@@ -96,6 +96,7 @@
import { ref, watch } from 'vue'
import axios from 'axios'
import { useRouter } from 'vue-router'
+import { useModuleI18n } from '@/i18n/composables'
const props = defineProps({
modelValue: {
@@ -110,6 +111,7 @@ const props = defineProps({
const emit = defineEmits(['update:modelValue'])
const router = useRouter()
+const { tm } = useModuleI18n('core.shared')
const dialog = ref(false)
const knowledgeBaseList = ref([])
diff --git a/dashboard/src/components/shared/ListConfigItem.vue b/dashboard/src/components/shared/ListConfigItem.vue
index 96c0ba372..626218223 100644
--- a/dashboard/src/components/shared/ListConfigItem.vue
+++ b/dashboard/src/components/shared/ListConfigItem.vue
@@ -2,7 +2,7 @@
- 暂无项目
+ {{ t('core.common.list.noItems') }}
@@ -14,7 +14,7 @@
- {{ buttonText }}
+ {{ buttonText || t('core.common.list.modifyButton') }}
@@ -22,17 +22,43 @@
- {{ dialogTitle }}
+ {{ dialogTitle || t('core.common.list.editTitle') }}
+
+
+
+
+
+
+ mdi-import
+ {{ t('core.common.list.batchImport') }}
+
+
+
+
-
+ class="ma-1 list-item-clickable"
+ @click="startEdit(index, item)">
+
{{ item }}
-
-
- mdi-pencil
-
-
- mdi-close
-
-
-
-
+
+
mdi-check
-
+
mdi-close
@@ -69,34 +99,43 @@
mdi-format-list-bulleted
-
暂无项目
-
-
-
-
-
-
-
-
-
- mdi-plus
- {{ t('core.common.list.addButton') }}
-
+
{{ t('core.common.list.noItemsHint') }}
- 取消
- 确认
+ {{ t('core.common.cancel') }}
+ {{ t('core.common.confirm') }}
+
+
+
+
+
+
+
+
+ {{ t('core.common.list.batchImportTitle') }}
+
+
+
+
+
+
+
+
+ {{ t('core.common.cancel') }}
+
+ {{ t('core.common.list.batchImportButton', { count: batchImportPreviewCount }) }}
+
@@ -139,12 +178,24 @@ const originalItems = ref([])
const newItem = ref('')
const editIndex = ref(-1)
const editItem = ref('')
+const showBatchImport = ref(false)
+const batchImportText = ref('')
// 计算要显示的项目
const displayItems = computed(() => {
return props.modelValue.slice(0, props.maxDisplayItems)
})
+// 计算批量导入的项目数量
+const batchImportPreviewCount = computed(() => {
+ if (!batchImportText.value) return 0
+ return batchImportText.value
+ .split('\n')
+ .map(line => line.trim())
+ .filter(line => line.length > 0)
+ .length
+})
+
// 监听 modelValue 变化,同步到 localItems
watch(() => props.modelValue, (newValue) => {
localItems.value = [...(newValue || [])]
@@ -199,6 +250,24 @@ function cancelDialog() {
newItem.value = ''
dialog.value = false
}
+
+function confirmBatchImport() {
+ if (batchImportText.value.trim()) {
+ const newItems = batchImportText.value
+ .split('\n')
+ .map(line => line.trim())
+ .filter(line => line.length > 0)
+
+ localItems.value.push(...newItems)
+ batchImportText.value = ''
+ showBatchImport.value = false
+ }
+}
+
+function cancelBatchImport() {
+ batchImportText.value = ''
+ showBatchImport.value = false
+}
\ No newline at end of file
diff --git a/dashboard/src/views/ConsolePage.vue b/dashboard/src/views/ConsolePage.vue
index 7df3aeca5..0ffff380d 100644
--- a/dashboard/src/views/ConsolePage.vue
+++ b/dashboard/src/views/ConsolePage.vue
@@ -13,10 +13,11 @@ const { tm } = useModuleI18n('features/console');
{{ tm('title') }}
@@ -57,7 +58,7 @@ export default {
},
data() {
return {
- autoScrollDisabled: false,
+ autoScrollEnabled: true,
pipDialog: false,
pipInstallPayload: {
package: '',
@@ -68,9 +69,9 @@ export default {
}
},
watch: {
- autoScrollDisabled(val) {
+ autoScrollEnabled(val) {
if (this.$refs.consoleDisplayer) {
- this.$refs.consoleDisplayer.autoScroll = !val;
+ this.$refs.consoleDisplayer.autoScroll = val;
}
}
},
diff --git a/dashboard/src/views/ConversationPage.vue b/dashboard/src/views/ConversationPage.vue
index a1b853f3f..8a4debd5e 100644
--- a/dashboard/src/views/ConversationPage.vue
+++ b/dashboard/src/views/ConversationPage.vue
@@ -40,6 +40,17 @@
:loading="loading" size="small" class="mr-2">
{{ tm('history.refresh') }}
+
+ {{ tm('batch.exportSelected', { count: selectedItems.length }) }}
+
-
+
mdi-chat-remove
@@ -195,7 +206,7 @@
-
+
@@ -320,6 +331,7 @@ import { debounce } from 'lodash';
import { VueMonacoEditor } from '@guolao/vue-monaco-editor';
import MarkdownIt from 'markdown-it';
import { useCommonStore } from '@/stores/common';
+import { useCustomizerStore } from '@/stores/customizer';
import { useI18n, useModuleI18n } from '@/i18n/composables';
import MessageList from '@/components/chat/MessageList.vue';
@@ -341,11 +353,13 @@ export default {
setup() {
const { t, locale } = useI18n();
const { tm } = useModuleI18n('features/conversation');
+ const customizerStore = useCustomizerStore();
return {
t,
tm,
- locale
+ locale,
+ customizerStore
};
},
@@ -485,6 +499,12 @@ export default {
};
},
+ // 检测是否为暗色模式
+ isDark() {
+ console.log('isDark', this.customizerStore.uiTheme);
+ return this.customizerStore.uiTheme === 'PurpleThemeDark';
+ },
+
// 将对话历史转换为 MessageList 组件期望的格式
formattedMessages() {
return this.conversationHistory.map(msg => {
@@ -901,6 +921,53 @@ export default {
}
},
+ // 导出选中的对话
+ async exportConversations() {
+ if (this.selectedItems.length === 0) {
+ this.showErrorMessage(this.tm('messages.noItemSelectedForExport'));
+ return;
+ }
+
+ this.loading = true;
+ try {
+ // 准备导出的数据
+ const conversations = this.selectedItems.map(item => ({
+ user_id: item.user_id,
+ cid: item.cid
+ }));
+
+ const response = await axios.post('/api/conversation/export', {
+ conversations: conversations
+ }, {
+ responseType: 'blob' // 重要:告诉 axios 响应是一个 blob
+ });
+
+ // 创建一个下载链接
+ const url = window.URL.createObjectURL(response.data);
+ const link = document.createElement('a');
+ link.href = url;
+
+ // 生成文件名(使用时间戳)
+ const timestamp = new Date().toISOString().replace(/[:.]/g, '-').slice(0, -5);
+ const filename = `conversations_export_${timestamp}.jsonl`;
+
+ link.setAttribute('download', filename);
+ document.body.appendChild(link);
+ link.click();
+
+ // 清理
+ link.remove();
+ window.URL.revokeObjectURL(url);
+
+ this.showSuccessMessage(this.tm('messages.exportSuccess'));
+ } catch (error) {
+ console.error(this.tm('messages.exportError'), error);
+ this.showErrorMessage(error.response?.data?.message || error.message || this.tm('messages.exportError'));
+ } finally {
+ this.loading = false;
+ }
+ },
+
// 格式化时间戳
formatTimestamp(timestamp) {
if (!timestamp) return this.tm('status.unknown');
@@ -987,6 +1054,11 @@ export default {
background-color: #f9f9f9;
}
+/* 暗色模式下的聊天消息容器 */
+.v-theme--dark .conversation-messages-container {
+ background-color: #1e1e1e;
+}
+
/* 对话详情卡片 */
.conversation-detail-card {
max-height: 90vh;
diff --git a/dashboard/src/views/ExtensionPage.vue b/dashboard/src/views/ExtensionPage.vue
index bb83dc912..5a5037efb 100644
--- a/dashboard/src/views/ExtensionPage.vue
+++ b/dashboard/src/views/ExtensionPage.vue
@@ -5,17 +5,45 @@ import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
import ReadmeDialog from '@/components/shared/ReadmeDialog.vue';
import ProxySelector from '@/components/shared/ProxySelector.vue';
import UninstallConfirmDialog from '@/components/shared/UninstallConfirmDialog.vue';
+import McpServersSection from '@/components/extension/McpServersSection.vue';
+import ComponentPanel from '@/components/extension/componentPanel/index.vue';
import axios from 'axios';
import { pinyin } from 'pinyin-pro';
import { useCommonStore } from '@/stores/common';
import { useI18n, useModuleI18n } from '@/i18n/composables';
+import defaultPluginIcon from '@/assets/images/plugin_icon.png';
-import { ref, computed, onMounted, reactive, inject, watch } from 'vue';
-
+import { ref, computed, onMounted, reactive, watch } from 'vue';
+import { useRouter } from 'vue-router';
const commonStore = useCommonStore();
const { t } = useI18n();
const { tm } = useModuleI18n('features/extension');
+const router = useRouter();
+
+// 检查指令冲突并提示
+const conflictDialog = reactive({
+ show: false,
+ count: 0
+});
+const checkAndPromptConflicts = async () => {
+ try {
+ const res = await axios.get('/api/commands');
+ if (res.data.status === 'ok') {
+ const conflicts = res.data.data.summary?.conflicts || 0;
+ if (conflicts > 0) {
+ conflictDialog.count = conflicts;
+ conflictDialog.show = true;
+ }
+ }
+ } catch (err) {
+ console.debug('Failed to check command conflicts:', err);
+ }
+};
+const handleConflictConfirm = () => {
+ activeTab.value = 'commands';
+};
+
const fileInput = ref(null);
const activeTab = ref('installed');
const extension_data = reactive({
@@ -41,6 +69,7 @@ const loadingDialog = reactive({
const showPluginInfoDialog = ref(false);
const selectedPlugin = ref({});
const curr_namespace = ref("");
+const updatingAll = ref(false);
const readmeDialog = reactive({
show: false,
@@ -65,6 +94,17 @@ const selectedDangerPlugin = ref(null);
const showUninstallDialog = ref(false);
const pluginToUninstall = ref(null);
+// 自定义插件源相关
+const showSourceDialog = ref(false);
+const sourceName = ref("");
+const sourceUrl = ref("");
+const customSources = ref([]);
+const selectedSource = ref(null);
+const showRemoveSourceDialog = ref(false);
+const sourceToRemove = ref(null);
+const editingSource = ref(false);
+const originalSourceUrl = ref("");
+
// 插件市场相关
const extension_url = ref("");
const dialog = ref(false);
@@ -136,10 +176,11 @@ const pluginMarketHeaders = computed(() => [
// 过滤要显示的插件
const filteredExtensions = computed(() => {
+ const data = Array.isArray(extension_data?.data) ? extension_data.data : [];
if (!showReserved.value) {
- return extension_data?.data?.filter(ext => !ext.reserved) || [];
+ return data.filter(ext => !ext.reserved);
}
- return extension_data.data || [];
+ return data;
});
// 通过搜索过滤插件
@@ -225,6 +266,10 @@ const paginatedPlugins = computed(() => {
return sortedPlugins.value.slice(start, end);
});
+const updatableExtensions = computed(() => {
+ return extension_data?.data?.filter(ext => ext.has_update) || [];
+});
+
// 方法
const toggleShowReserved = () => {
showReserved.value = !showReserved.value;
@@ -274,7 +319,8 @@ const checkUpdate = () => {
onlinePluginsNameMap.set(plugin.name, plugin);
});
- extension_data.data.forEach(extension => {
+ const data = Array.isArray(extension_data?.data) ? extension_data.data : [];
+ data.forEach(extension => {
const repoKey = extension.repo?.toLowerCase();
const onlinePlugin = repoKey ? onlinePluginsMap.get(repoKey) : null;
const onlinePluginByName = onlinePluginsNameMap.get(extension.name);
@@ -371,6 +417,56 @@ const updateExtension = async (extension_name) => {
}
};
+const updateAllExtensions = async () => {
+ if (updatingAll.value || updatableExtensions.value.length === 0) return;
+ updatingAll.value = true;
+ loadingDialog.title = tm('status.loading');
+ loadingDialog.statusCode = 0;
+ loadingDialog.result = "";
+ loadingDialog.show = true;
+
+ const targets = updatableExtensions.value.map(ext => ext.name);
+ try {
+ const res = await axios.post('/api/plugin/update-all', {
+ names: targets,
+ proxy: localStorage.getItem('selectedGitHubProxy') || ""
+ });
+
+ if (res.data.status === "error") {
+ onLoadingDialogResult(2, res.data.message || tm('messages.updateAllFailed', {
+ failed: targets.length,
+ total: targets.length
+ }), -1);
+ return;
+ }
+
+ const results = res.data.data?.results || [];
+ const failures = results.filter(r => r.status !== 'ok');
+ try {
+ await getExtensions();
+ } catch (err) {
+ const errorMsg = err.response?.data?.message || err.message || String(err);
+ failures.push({ name: 'refresh', status: 'error', message: errorMsg });
+ }
+
+ if (failures.length === 0) {
+ onLoadingDialogResult(1, tm('messages.updateAllSuccess'));
+ } else {
+ const failureText = tm('messages.updateAllFailed', {
+ failed: failures.length,
+ total: targets.length
+ });
+ const detail = failures.map(f => `${f.name}: ${f.message}`).join('\n');
+ onLoadingDialogResult(2, `${failureText}\n${detail}`, -1);
+ }
+ } catch (err) {
+ const errorMsg = err.response?.data?.message || err.message || String(err);
+ onLoadingDialogResult(2, errorMsg, -1);
+ } finally {
+ updatingAll.value = false;
+ }
+};
+
const pluginOn = async (extension) => {
try {
const res = await axios.post('/api/plugin/on', { name: extension.name });
@@ -379,7 +475,9 @@ const pluginOn = async (extension) => {
return;
}
toast(res.data.message, "success");
- getExtensions();
+ await getExtensions();
+
+ await checkAndPromptConflicts();
} catch (err) {
toast(err, "error");
}
@@ -491,6 +589,156 @@ const cancelDangerInstall = () => {
selectedDangerPlugin.value = null;
};
+// 自定义插件源管理方法
+const loadCustomSources = async () => {
+ try {
+ const res = await axios.get('/api/plugin/source/get');
+ if (res.data.status === "ok") {
+ customSources.value = res.data.data;
+ } else {
+ toast(res.data.message, "error");
+ }
+ } catch (e) {
+ console.warn('Failed to load custom sources:', e);
+ customSources.value = [];
+ }
+
+ // 加载当前选中的插件源
+ const currentSource = localStorage.getItem('selectedPluginSource');
+ if (currentSource) {
+ selectedSource.value = currentSource;
+ }
+};
+
+const saveCustomSources = async () => {
+ try {
+ const res = await axios.post('/api/plugin/source/save', {
+ sources: customSources.value
+ });
+ if (res.data.status !== "ok") {
+ toast(res.data.message, "error");
+ }
+ } catch (e) {
+ toast(e, "error");
+ }
+};
+
+const addCustomSource = () => {
+ editingSource.value = false;
+ originalSourceUrl.value = '';
+ sourceName.value = '';
+ sourceUrl.value = '';
+ showSourceDialog.value = true;
+};
+
+const selectPluginSource = (sourceUrl) => {
+ selectedSource.value = sourceUrl;
+ if (sourceUrl) {
+ localStorage.setItem('selectedPluginSource', sourceUrl);
+ } else {
+ localStorage.removeItem('selectedPluginSource');
+ }
+ // 重新加载插件市场数据
+ refreshPluginMarket();
+};
+
+// 获取当前选中的源对象
+const selectedSourceObj = computed(() => {
+ if (!selectedSource.value) return null;
+ return customSources.value.find(s => s.url === selectedSource.value) || null;
+});
+
+const editCustomSource = (source) => {
+ if (!source) return;
+ editingSource.value = true;
+ originalSourceUrl.value = source.url;
+ sourceName.value = source.name;
+ sourceUrl.value = source.url;
+ showSourceDialog.value = true;
+};
+
+const removeCustomSource = (source) => {
+ if (!source) return;
+ sourceToRemove.value = source;
+ showRemoveSourceDialog.value = true;
+};
+
+const confirmRemoveSource = () => {
+ if (sourceToRemove.value) {
+ customSources.value = customSources.value.filter(s => s.url !== sourceToRemove.value.url);
+ saveCustomSources();
+
+ // 如果删除的是当前选中的源,切换到默认源
+ if (selectedSource.value === sourceToRemove.value.url) {
+ selectedSource.value = null;
+ localStorage.removeItem('selectedPluginSource');
+ // 重新加载插件市场数据
+ refreshPluginMarket();
+ }
+
+ toast(tm('market.sourceRemoved'), 'success');
+ showRemoveSourceDialog.value = false;
+ sourceToRemove.value = null;
+ }
+};
+
+const saveCustomSource = () => {
+ const normalizedUrl = sourceUrl.value.trim();
+
+ if (!sourceName.value.trim() || !normalizedUrl) {
+ toast(tm('messages.fillSourceNameAndUrl'), 'error');
+ return;
+ }
+
+ // 检查URL格式
+ try {
+ new URL(normalizedUrl);
+ } catch (e) {
+ toast(tm('messages.invalidUrl'), 'error');
+ return;
+ }
+
+ if (editingSource.value) {
+ // 编辑模式:更新现有源
+ const index = customSources.value.findIndex(s => s.url === originalSourceUrl.value);
+ if (index !== -1) {
+ customSources.value[index] = {
+ name: sourceName.value.trim(),
+ url: normalizedUrl
+ };
+
+ // 如果编辑的是当前选中的源,更新选中源
+ if (selectedSource.value === originalSourceUrl.value) {
+ selectedSource.value = normalizedUrl;
+ localStorage.setItem('selectedPluginSource', selectedSource.value);
+ // 重新加载插件市场数据
+ refreshPluginMarket();
+ }
+ }
+ } else {
+ // 添加模式:检查是否已存在
+ if (customSources.value.some(source => source.url === normalizedUrl)) {
+ toast(tm('market.sourceExists'), 'error');
+ return;
+ }
+
+ customSources.value.push({
+ name: sourceName.value.trim(),
+ url: normalizedUrl
+ });
+ }
+
+ saveCustomSources();
+ toast(editingSource.value ? tm('market.sourceUpdated') : tm('market.sourceAdded'), 'success');
+
+ // 重置表单
+ sourceName.value = '';
+ sourceUrl.value = '';
+ editingSource.value = false;
+ originalSourceUrl.value = '';
+ showSourceDialog.value = false;
+};
+
// 插件市场显示完整插件名称
const trimExtensionName = () => {
pluginMarketData.value.forEach(plugin => {
@@ -506,8 +754,9 @@ const trimExtensionName = () => {
};
const checkAlreadyInstalled = () => {
- const installedRepos = new Set(extension_data.data.map(ext => ext.repo?.toLowerCase()));
- const installedNames = new Set(extension_data.data.map(ext => ext.name));
+ const data = Array.isArray(extension_data?.data) ? extension_data.data : [];
+ const installedRepos = new Set(data.map(ext => ext.repo?.toLowerCase()));
+ const installedNames = new Set(data.map(ext => ext.name));
for (let i = 0; i < pluginMarketData.value.length; i++) {
const plugin = pluginMarketData.value[i];
@@ -562,6 +811,8 @@ const newExtension = async () => {
name: res.data.data.name,
repo: res.data.data.repo || null
});
+
+ await checkAndPromptConflicts();
}).catch((err) => {
loading_.value = false;
onLoadingDialogResult(2, err, -1);
@@ -588,6 +839,8 @@ const newExtension = async () => {
name: res.data.data.name,
repo: res.data.data.repo || null
});
+
+ await checkAndPromptConflicts();
}).catch((err) => {
loading_.value = false;
toast(tm('messages.installFailed') + " " + err, "error");
@@ -601,7 +854,7 @@ const refreshPluginMarket = async () => {
refreshingMarket.value = true;
try {
// 强制刷新插件市场数据
- const data = await commonStore.getPluginCollections(true);
+ const data = await commonStore.getPluginCollections(true, selectedSource.value);
pluginMarketData.value = data;
trimExtensionName();
checkAlreadyInstalled();
@@ -619,6 +872,9 @@ const refreshPluginMarket = async () => {
// 生命周期
onMounted(async () => {
await getExtensions();
+
+ // 加载自定义插件源
+ loadCustomSources();
// 检查是否有 open_config 参数
let urlParams;
@@ -638,7 +894,7 @@ onMounted(async () => {
}
try {
- const data = await commonStore.getPluginCollections();
+ const data = await commonStore.getPluginCollections(false, selectedSource.value);
pluginMarketData.value = data;
trimExtensionName();
checkAlreadyInstalled();
@@ -677,21 +933,29 @@ watch(marketSearch, (newVal) => {
mdi-puzzle
- {{ tm('tabs.installed') }}
+ {{ tm('tabs.installedPlugins') }}
+
+
+ mdi-server-network
+ {{ tm('tabs.installedMcpServers') }}
mdi-store
{{ tm('tabs.market') }}
+
+ mdi-wrench
+ {{ tm('tabs.handlersOperation') }}
+
-
-
@@ -719,6 +983,12 @@ watch(marketSearch, (newVal) => {
{{ showReserved ? tm('buttons.hideSystemPlugins') : tm('buttons.showSystemPlugins') }}
+
+ mdi-update
+ {{ tm('buttons.updateAll') }}
+
+
mdi-plus
{{ tm('buttons.install') }}
@@ -889,14 +1159,167 @@ watch(marketSearch, (newVal) => {
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ mdi-source-branch
+
+ {{ tm('market.source') }}
+
+
+
+
+
+
+ {{ selectedSource ? customSources.find(s => s.url === selectedSource)?.name : tm('market.defaultSource') }}
+
+ mdi-chevron-down
+ {{ selectedSource || tm('market.defaultOfficialSource') }}
+
+
+
+
+ {{ tm('market.availableSources') }}
+
+
+
+
+
+ {{ tm('market.defaultSource') }}
+
+
+
+
+
+
+
+
+ {{ source.name }}
+ {{ source.url }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
-
-
+
+
+
+
+
+
+
+
+
+
+
+
@@ -939,7 +1362,7 @@ watch(marketSearch, (newVal) => {
-
@@ -950,8 +1373,8 @@ watch(marketSearch, (newVal) => {
-
-
+
@@ -986,8 +1409,7 @@ watch(marketSearch, (newVal) => {
-
+
{{ plugin.desc }}
@@ -1063,9 +1485,29 @@ watch(marketSearch, (newVal) => {
- {{ tm('market.devDocs') }} |
- {{ tm('market.submitRepo')
- }}
+
+
+ {{ tm('market.devDocs') }}
+
+
+
+ {{ tm('market.submitRepo') }}
+
+
@@ -1090,7 +1532,7 @@ watch(marketSearch, (newVal) => {
{{ loadingDialog.title }}
-
+
@@ -1161,6 +1603,34 @@ watch(marketSearch, (newVal) => {
+
+
+
+
+ mdi-alert-circle
+ {{ tm('conflicts.title') }}
+
+
+
+
+ {{ conflictDialog.count }}
+
+ {{ tm('conflicts.pairs') }}
+
+
+ {{ tm('conflicts.message') }}
+
+
+
+
+ {{ tm('conflicts.later') }}
+
+ {{ tm('conflicts.goToManage') }}
+
+
+
+
+
@@ -1185,10 +1655,15 @@ watch(marketSearch, (newVal) => {
-
- {{ tm('dialogs.install.title') }}
-
-
+
+
+
+
+
+
{{ tm('dialogs.install.title') }}
+
+
+
{{ tm('dialogs.install.fromFile') }}
{{ tm('dialogs.install.fromUrl') }}
@@ -1199,7 +1674,7 @@ watch(marketSearch, (newVal) => {
-
+
{{ tm('buttons.selectFile') }}
@@ -1221,19 +1696,79 @@ watch(marketSearch, (newVal) => {
-
-
+
+
+
{{ tm('buttons.cancel') }}
{{ tm('buttons.install') }}
+
+
+
+
+
+
+
+ {{ editingSource ? tm('market.editSource') : tm('market.addSource') }}
+
+
+
+
+
+
+
+ {{ tm('messages.enterJsonUrl') }}
+
+
+
+
+
+ {{ tm('buttons.cancel') }}
+ {{ tm('buttons.save') }}
+
+
+
+
+
+
+
+
+ mdi-alert-circle
+ {{ tm('dialogs.uninstall.title') }}
+
+
+ {{ tm('market.confirmRemoveSource') }}
+
+
{{ sourceToRemove.name }}
+
{{ sourceToRemove.url }}
+
+
+
+
+ {{ tm('buttons.cancel') }}
+ {{ tm('buttons.deleteSource') }}
@@ -1246,4 +1781,46 @@ watch(marketSearch, (newVal) => {
border-radius: 5px;
background-color: #f5f5f5;
}
+
+.plugin-description {
+ color: rgba(var(--v-theme-on-surface), 0.6);
+ line-height: 1.3;
+ margin-bottom: 6px;
+ flex: 1;
+ overflow-y: hidden;
+}
+
+.plugin-card:hover .plugin-description {
+ overflow-y: auto;
+}
+
+.plugin-description::-webkit-scrollbar {
+ width: 8px;
+ height: 8px;
+}
+
+.plugin-description::-webkit-scrollbar-track {
+ background: transparent;
+}
+
+.plugin-description::-webkit-scrollbar-thumb {
+ background-color: rgba(var(--v-theme-primary-rgb), 0.4);
+ border-radius: 4px;
+ border: 2px solid transparent;
+ background-clip: content-box;
+}
+
+.plugin-description::-webkit-scrollbar-thumb:hover {
+ background-color: rgba(var(--v-theme-primary-rgb), 0.6);
+}
+
+.fab-button {
+ transition: all 0.3s cubic-bezier(0.25, 0.8, 0.25, 1);
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
+}
+
+.fab-button:hover {
+ transform: translateY(-4px) scale(1.05);
+ box-shadow: 0 12px 20px rgba(var(--v-theme-primary), 0.4);
+}
diff --git a/dashboard/src/views/PlatformPage.vue b/dashboard/src/views/PlatformPage.vue
index dda6d5298..56d23852d 100644
--- a/dashboard/src/views/PlatformPage.vue
+++ b/dashboard/src/views/PlatformPage.vue
@@ -29,13 +29,54 @@
+
+
+
+
+
+ {{ getStatusIcon(getPlatformStat(item.id)?.status) }}
+ {{ tm('runtimeStatus.' + (getPlatformStat(item.id)?.status || 'unknown')) }}
+
+
+
+ mdi-bug
+ {{ getPlatformStat(item.id)?.error_count }} {{ tm('runtimeStatus.errors') }}
+
+
+
+
+ mdi-webhook
+ {{ tm('viewWebhook') }}
+
+
+
-
+
mdi-console-line
{{ tm('logs.title') }}
@@ -60,6 +101,84 @@
:updating-mode="updatingMode" :updating-platform-config="updatingPlatformConfig" @update="getConfig"
@show-toast="showToast" @refresh-config="getConfig"/>
+
+
+
+
+ mdi-webhook
+ {{ tm('webhookDialog.title') }}
+
+
+ {{ tm('webhookDialog.description') }}
+
+
+
+ mdi-content-copy
+
+
+
+
+
+
+
+ {{ tm('webhookDialog.close') }}
+
+
+
+
+
+
+
+
+
+ mdi-alert-circle
+ {{ tm('errorDialog.title') }}
+
+
+
+ {{ tm('errorDialog.platformId') }}: {{ currentErrorPlatform.id }}
+
+
+ {{ tm('errorDialog.errorCount') }}: {{ currentErrorPlatform.error_count }}
+
+
+
+ {{ tm('errorDialog.lastError') }}:
+
+
+ {{ currentErrorPlatform.last_error.message }}
+
+ {{ tm('errorDialog.occurredAt') }}: {{ new Date(currentErrorPlatform.last_error.timestamp).toLocaleString() }}
+
+
+
+
+ {{ tm('errorDialog.traceback') }}:
+
+
{{ currentErrorPlatform.last_error.traceback }}
+
+
+
+
+
+
+ {{ tm('errorDialog.close') }}
+
+
+
+
+
@@ -97,18 +216,6 @@ export default {
tm
};
},
- computed: {
- // 安全访问翻译的计算属性
- messages() {
- return {
- updateSuccess: this.tm('messages.updateSuccess'),
- addSuccess: this.tm('messages.addSuccess'),
- deleteSuccess: this.tm('messages.deleteSuccess'),
- statusUpdateSuccess: this.tm('messages.statusUpdateSuccess'),
- deleteConfirm: this.tm('messages.deleteConfirm')
- };
- }
- },
data() {
return {
config_data: {},
@@ -125,6 +232,17 @@ export default {
showConsole: false,
+ showWebhookDialog: false,
+ currentWebhookUuid: '',
+
+ // 平台统计信息
+ platformStats: {},
+ statsRefreshInterval: null,
+
+ // 错误详情对话框
+ showErrorDialog: false,
+ currentErrorPlatform: null,
+
store: useCommonStore()
}
},
@@ -147,6 +265,17 @@ export default {
mounted() {
this.getConfig();
+ this.getPlatformStats();
+ // 每 10 秒刷新一次平台状态
+ this.statsRefreshInterval = setInterval(() => {
+ this.getPlatformStats();
+ }, 10000);
+ },
+
+ beforeUnmount() {
+ if (this.statsRefreshInterval) {
+ clearInterval(this.statsRefreshInterval);
+ }
},
methods: {
@@ -171,6 +300,53 @@ export default {
});
},
+ getPlatformStats() {
+ axios.get('/api/platform/stats').then((res) => {
+ if (res.data.status === 'ok') {
+ // 将数组转换为以 id 为 key 的对象,方便查找
+ const stats = {};
+ for (const platform of res.data.data.platforms || []) {
+ stats[platform.id] = platform;
+ }
+ this.platformStats = stats;
+ }
+ }).catch((err) => {
+ console.warn('获取平台统计信息失败:', err);
+ });
+ },
+
+ getPlatformStat(platformId) {
+ return this.platformStats[platformId] || null;
+ },
+
+ getStatusColor(status) {
+ switch (status) {
+ case 'running': return 'success';
+ case 'error': return 'error';
+ case 'pending': return 'warning';
+ case 'stopped': return 'grey';
+ default: return 'grey';
+ }
+ },
+
+ getStatusIcon(status) {
+ switch (status) {
+ case 'running': return 'mdi-check-circle';
+ case 'error': return 'mdi-alert-circle';
+ case 'pending': return 'mdi-clock-outline';
+ case 'stopped': return 'mdi-stop-circle';
+ default: return 'mdi-help-circle';
+ }
+ },
+
+ showErrorDetails(platform) {
+ const stat = this.getPlatformStat(platform.id);
+ if (stat && stat.error_count > 0) {
+ this.currentErrorPlatform = stat;
+ this.showErrorDialog = true;
+ }
+ },
+
editPlatform(platform) {
this.updatingPlatformConfig = JSON.parse(JSON.stringify(platform));
this.updatingMode = true;
@@ -224,6 +400,47 @@ export default {
this.save_message = message;
this.save_message_success = "error";
this.save_message_snack = true;
+ },
+
+ getWebhookUrl(webhookUuid) {
+ let callbackBase = this.config_data.callback_api_base || '';
+ if (!callbackBase) {
+ callbackBase = "http(s)://";
+ }
+ if (callbackBase) {
+ return `${callbackBase.replace(/\/$/, '')}/api/platform/webhook/${webhookUuid}`;
+ }
+ return `/api/platform/webhook/${webhookUuid}`;
+ },
+
+ openWebhookDialog(webhookUuid) {
+ this.currentWebhookUuid = webhookUuid;
+ this.showWebhookDialog = true;
+ },
+
+ async copyWebhookUrl(webhookUuid) {
+ const url = this.getWebhookUrl(webhookUuid);
+ try {
+ await navigator.clipboard.writeText(url);
+ this.showSuccess(this.tm('webhookCopied'));
+ } catch (err) {
+ this.showError(this.tm('webhookCopyFailed'));
+ }
+ }
+ },
+ computed: {
+ // 安全访问翻译的计算属性
+ messages() {
+ return {
+ updateSuccess: this.tm('messages.updateSuccess'),
+ addSuccess: this.tm('messages.addSuccess'),
+ deleteSuccess: this.tm('messages.deleteSuccess'),
+ statusUpdateSuccess: this.tm('messages.statusUpdateSuccess'),
+ deleteConfirm: this.tm('messages.deleteConfirm')
+ };
+ },
+ currentWebhookUrl() {
+ return this.getWebhookUrl(this.currentWebhookUuid);
}
}
}
@@ -233,5 +450,52 @@ export default {
.platform-page {
padding: 20px;
padding-top: 8px;
+ padding-bottom: 40px;
+}
+
+.webhook-info {
+ margin-top: 4px;
+}
+
+.webhook-chip {
+ cursor: pointer;
+}
+
+.platform-status-row {
+ display: flex;
+ align-items: center;
+ flex-wrap: wrap;
+ gap: 4px;
+}
+
+.status-chip {
+ font-size: 12px;
+}
+
+.error-chip {
+ cursor: pointer;
+ font-size: 12px;
+}
+
+.error-details {
+ margin-top: 8px;
+}
+
+.error-message {
+ word-break: break-word;
+}
+
+.traceback-box {
+ background-color: #1e1e1e;
+ color: #d4d4d4;
+ padding: 12px;
+ border-radius: 8px;
+ font-size: 12px;
+ line-height: 1.5;
+ overflow-x: auto;
+ white-space: pre-wrap;
+ word-break: break-word;
+ max-height: 300px;
+ overflow-y: auto;
}
diff --git a/dashboard/src/views/ProviderPage.vue b/dashboard/src/views/ProviderPage.vue
index e973ef624..643beaac6 100644
--- a/dashboard/src/views/ProviderPage.vue
+++ b/dashboard/src/views/ProviderPage.vue
@@ -30,6 +30,10 @@
mdi-message-text
{{ tm('providers.tabs.chatCompletion') }}
+
+ mdi-message-text
+ {{ tm('providers.tabs.agentRunner') }}
+
mdi-microphone-message
{{ tm('providers.tabs.speechToText') }}
@@ -48,87 +52,103 @@
-
-
- mdi-api-off
- {{ getEmptyText() }}
-
-
+
+
+
+ mdi-api-off
+ {{ getEmptyText() }}
+
+
+
+
+
{{ group.label }}
+
+
+
+
+
+
+
+
+
+ {{ getProviderStatus(item.id).status === 'available' ? 'mdi-check-circle' :
+ getProviderStatus(item.id).status === 'unavailable' ? 'mdi-alert-circle' :
+ 'mdi-clock-outline' }}
+
+ {{ getStatusText(getProviderStatus(item.id).status) }}
+
+
+
+ {{ getProviderStatus(item.id).error }}
+
+ {{ getStatusText(getProviderStatus(item.id).status) }}
+
+
+
+
+ {{ tm('availability.test') }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+ mdi-api-off
+ {{ getEmptyText() }}
+
+
+
+
+
-
-
-
-
-
- {{ tm('availability.test') }}
-
-
-
-
-
-
-
+
+
+
+
+
+
+ {{ getProviderStatus(item.id).status === 'available' ? 'mdi-check-circle' :
+ getProviderStatus(item.id).status === 'unavailable' ? 'mdi-alert-circle' :
+ 'mdi-clock-outline' }}
+
+ {{ getStatusText(getProviderStatus(item.id).status) }}
+
+
+
+ {{ getProviderStatus(item.id).error }}
+
+ {{ getStatusText(getProviderStatus(item.id).status) }}
+
+
+
+
+ {{ tm('availability.test') }}
+
+
+
+
+
+
-
-
-
- mdi-heart-pulse
- {{ tm('availability.title') }}
-
-
- mdi-refresh
- {{ tm('availability.refresh') }}
-
-
- {{ showStatus ? tm('logs.collapse') : tm('logs.expand') }}
- {{ showStatus ? 'mdi-chevron-up' : 'mdi-chevron-down' }}
-
-
-
-
-
-
-
- {{ tm('availability.noData') }}
-
-
-
-
-
-
-
- mdi-check-circle
- mdi-alert-circle
-
-
- {{ status.id }}
-
-
- {{ getStatusText(status.status) }}
-
-
-
- {{ tm('availability.errorMessage') }}: {{ status.error }}
-
-
-
-
-
-
-
-
-
-
-
+
mdi-console-line
{{ tm('logs.title') }}
@@ -182,6 +202,30 @@
+
+
+
+
+ mdi-information
+ 请前往「配置文件」页测试 Agent 执行器
+
+
+ Agent 执行器的测试请在「配置文件」页进行。
+
+ 找到对应的配置文件并打开。
+ 找到 Agent 执行方式部分,修改执行器后点击保存。
+ 点击右下角的 💬 聊天按钮进行测试。
+
+ 要让机器人应用这个 Agent 执行器,你也需要前往修改 Agent 执行器。
+
+
+
+ 好的
+ 点击前往
+
+
+
+
@@ -258,6 +302,9 @@ export default {
showKeyConfirm: false,
keyConfirmResolve: null,
+ // Agent Runner 提示对话框
+ showAgentRunnerDialog: false,
+
newSelectedProviderName: '',
newSelectedProviderConfig: {},
updatingMode: false,
@@ -289,8 +336,8 @@ export default {
"anthropic_chat_completion": "chat_completion",
"googlegenai_chat_completion": "chat_completion",
"zhipu_chat_completion": "chat_completion",
- "dify": "chat_completion",
- "coze": "chat_completion",
+ "dify": "agent_runner",
+ "coze": "agent_runner",
"dashscope": "chat_completion",
"openai_whisper_api": "speech_to_text",
"openai_whisper_selfhost": "speech_to_text",
@@ -334,6 +381,7 @@ export default {
},
tabTypes: {
'chat_completion': this.tm('providers.tabs.chatCompletion'),
+ 'agent_runner': this.tm('providers.tabs.agentRunner'),
'speech_to_text': this.tm('providers.tabs.speechToText'),
'text_to_speech': this.tm('providers.tabs.textToSpeech'),
'embedding': this.tm('providers.tabs.embedding'),
@@ -363,6 +411,52 @@ export default {
};
},
+ groupedProviders() {
+ if (!this.config_data.provider) {
+ return [];
+ }
+
+ const typeOrder = [
+ 'chat_completion',
+ 'agent_runner',
+ 'speech_to_text',
+ 'text_to_speech',
+ 'embedding',
+ 'rerank',
+ ];
+
+ const assigned = new Set();
+ const groups = typeOrder
+ .map((typeKey) => {
+ const items = this.config_data.provider.filter((provider) => {
+ const resolved = this.getProviderType(provider);
+ if (resolved === typeKey) {
+ assigned.add(provider.id);
+ return true;
+ }
+ return false;
+ });
+ return {
+ typeKey,
+ label: this.messages.tabTypes[typeKey] || typeKey,
+ items,
+ };
+ })
+ .filter((group) => group.items.length > 0);
+
+ const remaining = this.config_data.provider.filter(
+ (provider) => !assigned.has(provider.id),
+ );
+ if (remaining.length > 0) {
+ groups.push({
+ typeKey: 'others',
+ label: this.tm('providers.tabs.all'),
+ items: remaining,
+ });
+ }
+ return groups;
+ },
+
// 根据选择的标签过滤提供商列表
filteredProviders() {
if (!this.config_data.provider || this.activeProviderTypeTab === 'all') {
@@ -371,13 +465,7 @@ export default {
return this.config_data.provider.filter(provider => {
// 如果provider.provider_type已经存在,直接使用它
- if (provider.provider_type) {
- return provider.provider_type === this.activeProviderTypeTab;
- }
-
- // 否则使用映射关系
- const mappedType = this.oldVersionProviderTypeMapping[provider.type];
- return mappedType === this.activeProviderTypeTab;
+ return this.getProviderType(provider) === this.activeProviderTypeTab;
});
}
},
@@ -387,6 +475,14 @@ export default {
},
methods: {
+ getProviderType(provider) {
+ if (!provider) return undefined;
+ if (provider.provider_type) {
+ return provider.provider_type;
+ }
+ return this.oldVersionProviderTypeMapping[provider.type];
+ },
+
getConfig() {
axios.get('/api/config/get').then((res) => {
this.config_data = res.data.data.config;
@@ -666,11 +762,14 @@ export default {
return this.testingProviders.includes(providerId);
},
+ getProviderStatus(providerId) {
+ return this.providerStatuses.find(s => s.id === providerId);
+ },
+
async testSingleProvider(provider) {
if (this.isProviderTesting(provider.id)) return;
this.testingProviders.push(provider.id);
- this.showStatus = true; // 自动展开状态部分
// 更新UI为pending状态
const statusIndex = this.providerStatuses.findIndex(s => s.id === provider.id);
@@ -690,6 +789,11 @@ export default {
if (!provider.enable) {
throw new Error('该提供商未被用户启用');
}
+ if (provider.provider_type === 'agent_runner') {
+ this.showAgentRunnerDialog = true;
+ this.providerStatuses = this.providerStatuses.filter(s => s.id !== provider.id);
+ return;
+ }
const res = await axios.get(`/api/config/provider/check_one?id=${provider.id}`);
if (res.data && res.data.status === 'ok') {
@@ -750,6 +854,10 @@ export default {
}
this.showIdConflictDialog = false;
},
+ goToConfigPage() {
+ this.showAgentRunnerDialog = false;
+ this.$router.push({ name: 'Configs' });
+ },
getStatusColor(status) {
switch (status) {
case 'available':
@@ -774,6 +882,7 @@ export default {
.provider-page {
padding: 20px;
padding-top: 8px;
+ padding-bottom: 40px;
}
.status-card {
diff --git a/dashboard/src/views/SessionManagementPage.vue b/dashboard/src/views/SessionManagementPage.vue
index 8adf1f221..0e2b59555 100644
--- a/dashboard/src/views/SessionManagementPage.vue
+++ b/dashboard/src/views/SessionManagementPage.vue
@@ -3,16 +3,22 @@
- {{ tm('sessions.activeSessions') }}
- {{ totalItems }} {{ tm('sessions.sessionCount') }}
+ {{ tm('customRules.title') }}
+
+ {{ totalItems }} {{ tm('customRules.rulesCount') }}
-
+ hide-details clearable variant="solo-filled" flat class="me-4" density="compact">
-
+ {{ tm('buttons.batchDelete') }} ({{ selectedItems.length }})
+
+
+ {{ tm('buttons.addRule') }}
+
+
{{ tm('buttons.refresh') }}
@@ -21,442 +27,313 @@
-
-
+
-
-
- updateSessionStatus(item, value)" :loading="item.updating" hide-details
- density="compact" color="success">
-
-
-
-
-
+
+
+
+ {{ item.platform || 'unknown' }}
+
+
{{ item.umo }}
+
+
+ ({{ item.rules?.session_service_config?.custom_name }})
+
+
+ mdi-pencil-outline
+ {{ tm('buttons.editCustomName') }}
+
+
-
-
- {{ item.session_name }}
- ({{ item.session_id }})
-
-
- {{ item.session_id }}
-
+
+ mdi-information-outline
-
使用 /sid 指令可查看会话 ID。
-
会话信息:
-
- 机器人 ID: {{ item.platform }}
- 消息类型: {{ item.message_type }}
- 会话 ID: {{ item.session_raw_name }}
- 用户: {{ item.user_name }}
-
+
UMO: {{ item.umo }}
+
平台: {{ item.platform }}
+
消息类型: {{ item.message_type }}
+
会话 ID: {{ item.session_id }}
-
-
- mdi-pencil
-
- {{ tm('buttons.editName') }}
-
-
+
-
-
- updatePersona(item, value)" :loading="item.updating"
- :disabled="!item.session_enabled">
-
- {{ selection.raw.label }}
-
-
-
-
-
-
- updateProvider(item, value, 'chat_completion')" :loading="item.updating"
- :disabled="!item.session_enabled">
-
- {{ selection.raw.label }}
-
-
-
-
-
-
- updateProvider(item, value, 'speech_to_text')" :loading="item.updating"
- :disabled="sttProviderOptions.length === 0 || !item.session_enabled">
-
- {{ selection.raw.label }}
-
-
-
-
-
-
- updateProvider(item, value, 'text_to_speech')" :loading="item.updating"
- :disabled="ttsProviderOptions.length === 0 || !item.session_enabled">
-
- {{ selection.raw.label }}
-
-
-
-
-
- updateLLM(item, value)"
- :loading="item.updating" :disabled="!item.session_enabled" hide-details density="compact"
- color="primary">
-
-
-
-
-
- updateTTS(item, value)"
- :loading="item.updating" :disabled="!item.session_enabled" hide-details density="compact"
- color="secondary">
-
-
-
-
-
-
- {{ tm('knowledgeBase.configure') }}
-
-
-
-
-
-
- {{ tm('buttons.edit') }}
-
+
+
+
+
+ {{ tm('customRules.serviceConfig') }}
+
+
+ {{ tm('customRules.pluginConfig') }}
+
+
+ {{ tm('customRules.kbConfig') }}
+
+
+ {{ tm('customRules.providerConfig') }}
+
+
-
+
+ mdi-pencil
+ {{ tm('buttons.editRule') }}
+
+
mdi-delete
-
- {{ tm('buttons.delete') }}
-
+ {{ tm('buttons.deleteAllRules') }}
-
mdi-account-group-outline
-
{{ tm('sessions.noActiveSessions') }}
-
{{ tm('sessions.noActiveSessionsDesc') }}
+
mdi-file-document-edit-outline
+
{{ tm('customRules.noRules') }}
+
{{ tm('customRules.noRulesDesc') }}
+
+ mdi-plus
+ {{ tm('buttons.addRule') }}
+
-
-
-
- {{ tm('batchOperations.title') }}
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- mdi-check-all
- {{ tm('buttons.apply') }}
-
-
-
-
-
-
-
-
-
-
-
-
- {{ tm('pluginManagement.title') }} - {{ selectedSessionForPlugin.session_name }}
+
+
+
+
+ {{ tm('addRule.title') }}
-
+
mdi-close
-
-
-
-
mdi-puzzle-outline
-
{{ tm('pluginManagement.noPlugins') }}
-
{{ tm('pluginManagement.noPluginsDesc') }}
-
+
+
+ {{ tm('addRule.description') }}
+
-
-
-
-
- {{ plugin.enabled ? 'mdi-check-circle' : 'mdi-circle-outline' }}
-
-
-
-
- {{ plugin.name }}
-
-
-
- {{ tm('pluginManagement.author') }}: {{ plugin.author }}
-
-
-
- togglePlugin(plugin, value)"
- :loading="plugin.updating">
-
-
-
-
-
-
-
-
-
- {{ tm('pluginManagement.loading') }}
-
-
-
-
-
-
-
-
- {{ tm('nameEditor.title') }}
-
-
- mdi-close
-
-
-
-
-
-
-
-
- {{ tm('nameEditor.originalName') }}: {{ selectedSessionForName.session_raw_name }}
-
-
-
- {{ tm('nameEditor.fullSessionId') }}: {{ selectedSessionForName.session_id }}
-
-
-
- {{ tm('nameEditor.hint') }}
-
-
+
-
- {{ tm('buttons.cancel') }}
-
-
- {{ tm('buttons.save') }}
+ {{ tm('buttons.cancel') }}
+
+ {{ tm('buttons.next') }}
-
- { if (!val) closeKBDialog(); }">
-
-
- {{ tm('knowledgeBase.title') }} - {{ selectedSessionForKB.session_name }}
+
+
+
+
+ {{ tm('ruleEditor.title') }}
+
+ {{ selectedUmo.umo }}
+
-
- mdi-close
-
+
-
-
-
- {{ tm('knowledgeBase.description') }}
-
-
-
-
-
{{ tm('knowledgeBase.selectKB') }}
-
-
- {{ tm('knowledgeBase.noKBAvailable') || '暂无可用知识库' }}
-
-
-
-
-
-
{{ kb.emoji }}
-
-
{{ kb.kb_name }}
-
- {{ kb.description || tm('knowledgeBase.noKBDesc') }} - {{ kb.doc_count }} {{ tm('list.documents', { count: kb.doc_count }) }}
-
-
-
-
-
-
-
-
- {{ tm('knowledgeBase.selectMultiple') }}
-
+
+
+
+
+
{{ tm('ruleEditor.serviceConfig.title') }}
-
-
-
-
- mdi-cog
- {{ tm('knowledgeBase.advancedSettings') }}
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
-
-
mdi-database-off
-
{{ tm('knowledgeBase.noKBAvailable') }}
-
- {{ tm('knowledgeBase.createKB') }}
+
+
+ {{ tm('buttons.save') }}
+
+
+
+
+
+
{{ tm('ruleEditor.providerConfig.title') }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ tm('buttons.save') }}
+
+
+
+
+
+
{{ tm('ruleEditor.personaConfig.title') }}
+
+
+
+
+
+
+
+
+ {{ tm('ruleEditor.personaConfig.hint') }}
+
+
+
+
+
+
+ {{ tm('buttons.save') }}
+
+
+
+
+
+
{{ tm('ruleEditor.pluginConfig.title') }}
+
+
+
+
+
+
+
+
+ {{ tm('ruleEditor.pluginConfig.hint') }}
+
+
+
+
+
+
+ {{ tm('buttons.save') }}
+
+
+
+
+
+
{{ tm('ruleEditor.kbConfig.title') }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ tm('buttons.save') }}
+
+
-
-
- {{ tm('knowledgeBase.loading') }}
+
+
+
+ {{ tm('deleteConfirm.title') }}
+
+ {{ tm('deleteConfirm.message') }}
+
+ {{ deleteTarget?.umo }}
+
+
+ {{ tm('buttons.cancel') }}
+ {{ tm('buttons.delete')
+ }}
+
+
+
-
-
-
-
- {{ tm('knowledgeBase.clearConfig') }}
-
-
-
- {{ tm('knowledgeBase.cancel') }}
-
-
- {{ tm('knowledgeBase.save') }}
+
+
+
+ {{ tm('batchDeleteConfirm.title') }}
+
+ {{ tm('batchDeleteConfirm.message', { count: selectedItems.length }) }}
+
+
+ {{ item.rules?.session_service_config?.custom_name || item.umo }}
+
+
+
+
+
+ {{ tm('buttons.cancel') }}
+
+ {{ tm('buttons.delete') }}
@@ -466,13 +343,30 @@
{{ snackbarText }}
+
+
+
+
+ {{ tm('quickEditName.title') }}
+
+
+
+
+
+ {{ tm('buttons.cancel') }}
+
+ {{ tm('buttons.save') }}
+
+
+
+
diff --git a/dashboard/src/views/knowledge-base/components/DocumentsTab.vue b/dashboard/src/views/knowledge-base/components/DocumentsTab.vue
index 9e146e86a..bf110f282 100644
--- a/dashboard/src/views/knowledge-base/components/DocumentsTab.vue
+++ b/dashboard/src/views/knowledge-base/components/DocumentsTab.vue
@@ -57,7 +57,7 @@
-
+
{{ t('upload.title') }}
@@ -67,40 +67,91 @@
-
-
-
-
mdi-cloud-upload
-
{{ t('upload.dropzone') }}
-
{{ t('upload.supportedFormats') }}.txt, .md, .pdf, .docx,
- .xls, .xlsx
-
{{ t('upload.maxSize') }}
-
最多可上传 10 个文件
-
-
+
+ {{ t('upload.fileUpload') }}
+
+ {{ t('upload.fromUrl') }}
+
+
+
-
-
- 已选择 {{ selectedFiles.length }} 个文件
- 清空
-
-
-
-
-
-
{{ getFileIcon(file.name) }}
-
-
{{ file.name }}
-
{{ formatFileSize(file.size) }}
+
+
+
+
+
+
+
mdi-cloud-upload
+
{{ t('upload.dropzone') }}
+
{{ t('upload.supportedFormats') }}.txt, .md, .pdf,
+ .docx,
+ .xls, .xlsx
+
{{ t('upload.maxSize') }}
+
最多可上传 10 个文件
+
+
+
+
+
+ 已选择 {{ selectedFiles.length }} 个文件
+ 清空
+
+
+
+
+
+
{{ getFileIcon(file.name) }}
+
+
{{ file.name }}
+
{{ formatFileSize(file.size) }}
+
+
+
-
+
+
+
+
+
+
+
+
+
+ {{ tavilyConfigStatus === 'error' ? '检查网页搜索配置失败' : '使用此功能需要配置 Tavily Key' }}
+
+
+ 配置
+
+
+
+
+
+
+
+
+
+
+
+
+
{{ t('upload.cleaningSettings') }}
+
+
+
+
+
+
+
+
@@ -151,8 +202,8 @@
{{ t('upload.cancel') }}
-
+
{{ t('upload.submit') }}
@@ -185,11 +236,15 @@
{{ snackbar.text }}
+
+
+
\ No newline at end of file
diff --git a/k8s/astrbot/00-namespace.yaml b/k8s/astrbot/00-namespace.yaml
new file mode 100644
index 000000000..547118bd8
--- /dev/null
+++ b/k8s/astrbot/00-namespace.yaml
@@ -0,0 +1,4 @@
+apiVersion: v1
+kind: Namespace
+metadata:
+ name: astrbot-standalone-ns
\ No newline at end of file
diff --git a/k8s/astrbot/01-pvc.yaml b/k8s/astrbot/01-pvc.yaml
new file mode 100644
index 000000000..a219aa0f6
--- /dev/null
+++ b/k8s/astrbot/01-pvc.yaml
@@ -0,0 +1,14 @@
+apiVersion: v1
+kind: PersistentVolumeClaim
+metadata:
+ name: astrbot-data-pvc
+ namespace: astrbot-standalone-ns
+ labels:
+ app: astrbot-standalone
+spec:
+ accessModes:
+ - ReadWriteOnce
+ resources:
+ requests:
+ storage: 10Gi
+ # storageClassName: standard # uncomment and set proper StorageClass
\ No newline at end of file
diff --git a/k8s/astrbot/02-deployment.yaml b/k8s/astrbot/02-deployment.yaml
new file mode 100644
index 000000000..d2799ab90
--- /dev/null
+++ b/k8s/astrbot/02-deployment.yaml
@@ -0,0 +1,49 @@
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+ name: astrbot-standalone
+ namespace: astrbot-standalone-ns
+ labels:
+ app: astrbot-standalone
+spec:
+ replicas: 1
+ strategy:
+ type: Recreate
+ selector:
+ matchLabels:
+ app: astrbot-standalone
+ template:
+ metadata:
+ labels:
+ app: astrbot-standalone
+ spec:
+ containers:
+ - name: astrbot
+ image: soulter/astrbot:latest
+ imagePullPolicy: IfNotPresent
+ env:
+ - name: TZ
+ value: "Asia/Shanghai"
+ ports:
+ - containerPort: 6185
+ name: webui
+ - containerPort: 6199
+ name: qq-ws
+ # - containerPort: 6195
+ # name: wecom-wh
+ # - containerPort: 6196
+ # name: qq-off-wh
+ volumeMounts:
+ - name: data
+ mountPath: /AstrBot/data
+ - name: localtime
+ mountPath: /etc/localtime
+ readOnly: true
+ volumes:
+ - name: data
+ persistentVolumeClaim:
+ claimName: astrbot-data-pvc
+ - name: localtime
+ hostPath:
+ path: /etc/localtime
+ type: File
\ No newline at end of file
diff --git a/k8s/astrbot/03-service-nodeport.yaml b/k8s/astrbot/03-service-nodeport.yaml
new file mode 100644
index 000000000..7342bd97a
--- /dev/null
+++ b/k8s/astrbot/03-service-nodeport.yaml
@@ -0,0 +1,28 @@
+apiVersion: v1
+kind: Service
+metadata:
+ name: astrbot-standalone-nodeport
+ namespace: astrbot-standalone-ns
+ labels:
+ app: astrbot-standalone
+spec:
+ type: NodePort
+ selector:
+ app: astrbot-standalone
+ ports:
+ - name: webui
+ port: 6185
+ targetPort: 6185
+ nodePort: 30185
+ - name: qq-ws
+ port: 6199
+ targetPort: 6199
+ nodePort: 30199
+ # - name: wecom-wh
+ # port: 6195
+ # targetPort: 6195
+ # nodePort: 30195
+ # - name: qq-off-wh
+ # port: 6196
+ # targetPort: 6196
+ # nodePort: 30196
\ No newline at end of file
diff --git a/k8s/astrbot/04-service-loadbalancer.yaml b/k8s/astrbot/04-service-loadbalancer.yaml
new file mode 100644
index 000000000..f841594d4
--- /dev/null
+++ b/k8s/astrbot/04-service-loadbalancer.yaml
@@ -0,0 +1,24 @@
+apiVersion: v1
+kind: Service
+metadata:
+ name: astrbot-standalone-lb
+ namespace: astrbot-standalone-ns
+ labels:
+ app: astrbot-standalone
+spec:
+ type: LoadBalancer
+ selector:
+ app: astrbot-standalone
+ ports:
+ - name: webui
+ port: 6185
+ targetPort: 6185
+ - name: qq-ws
+ port: 6199
+ targetPort: 6199
+ # - name: wecom-wh
+ # port: 6195
+ # targetPort: 6195
+ # - name: qq-off-wh
+ # port: 6196
+ # targetPort: 6196
\ No newline at end of file
diff --git a/k8s/astrbot_with_napcat/00-namespace.yaml b/k8s/astrbot_with_napcat/00-namespace.yaml
new file mode 100644
index 000000000..1e6ab5016
--- /dev/null
+++ b/k8s/astrbot_with_napcat/00-namespace.yaml
@@ -0,0 +1,4 @@
+apiVersion: v1
+kind: Namespace
+metadata:
+ name: astrbot-ns
\ No newline at end of file
diff --git a/k8s/astrbot_with_napcat/01-pvc.yaml b/k8s/astrbot_with_napcat/01-pvc.yaml
new file mode 100644
index 000000000..8efd67f9a
--- /dev/null
+++ b/k8s/astrbot_with_napcat/01-pvc.yaml
@@ -0,0 +1,46 @@
+apiVersion: v1
+kind: PersistentVolumeClaim
+metadata:
+ name: astrbot-data-shared-pvc
+ namespace: astrbot-ns
+ labels:
+ app: astrbot-stack
+spec:
+ accessModes:
+ - ReadWriteMany
+ resources:
+ requests:
+ storage: 10Gi
+ # storageClassName: nfs-client # Uncomment and set your RWX storage class if needed
+
+---
+apiVersion: v1
+kind: PersistentVolumeClaim
+metadata:
+ name: napcat-config-pvc
+ namespace: astrbot-ns
+ labels:
+ app: astrbot-stack
+spec:
+ accessModes:
+ - ReadWriteOnce
+ resources:
+ requests:
+ storage: 5Gi
+ # storageClassName: standard
+
+---
+apiVersion: v1
+kind: PersistentVolumeClaim
+metadata:
+ name: napcat-qq-pvc
+ namespace: astrbot-ns
+ labels:
+ app: astrbot-stack
+spec:
+ accessModes:
+ - ReadWriteOnce
+ resources:
+ requests:
+ storage: 5Gi
+ # storageClassName: standard
\ No newline at end of file
diff --git a/k8s/astrbot_with_napcat/02-deployment.yaml b/k8s/astrbot_with_napcat/02-deployment.yaml
new file mode 100644
index 000000000..53bf98db2
--- /dev/null
+++ b/k8s/astrbot_with_napcat/02-deployment.yaml
@@ -0,0 +1,64 @@
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+ name: astrbot-stack
+ namespace: astrbot-ns
+ labels:
+ app: astrbot-stack
+spec:
+ replicas: 1
+ strategy:
+ type: Recreate # Use Recreate strategy for stateful applications
+ selector:
+ matchLabels:
+ app: astrbot-stack
+ template:
+ metadata:
+ labels:
+ app: astrbot-stack
+ spec:
+ containers:
+ - name: napcat
+ image: mlikiowa/napcat-docker:latest
+ imagePullPolicy: IfNotPresent
+ env:
+ - name: NAPCAT_UID
+ value: "1000"
+ - name: NAPCAT_GID
+ value: "1000"
+ - name: MODE
+ value: "astrbot"
+ ports:
+ - containerPort: 6099
+ name: napcat-web
+ volumeMounts:
+ - name: shared-data
+ mountPath: /AstrBot/data
+ - name: napcat-config
+ mountPath: /app/napcat/config
+ - name: napcat-qq
+ mountPath: /app/.config/QQ
+
+ - name: astrbot
+ image: soulter/astrbot:latest
+ imagePullPolicy: IfNotPresent
+ env:
+ - name: TZ
+ value: "Asia/Shanghai"
+ ports:
+ - containerPort: 6185
+ name: astrbot-web
+ volumeMounts:
+ - name: shared-data
+ mountPath: /AstrBot/data
+
+ volumes:
+ - name: shared-data
+ persistentVolumeClaim:
+ claimName: astrbot-data-shared-pvc
+ - name: napcat-config
+ persistentVolumeClaim:
+ claimName: napcat-config-pvc
+ - name: napcat-qq
+ persistentVolumeClaim:
+ claimName: napcat-qq-pvc
\ No newline at end of file
diff --git a/k8s/astrbot_with_napcat/03-service-nodeport.yaml b/k8s/astrbot_with_napcat/03-service-nodeport.yaml
new file mode 100644
index 000000000..2bd2f333c
--- /dev/null
+++ b/k8s/astrbot_with_napcat/03-service-nodeport.yaml
@@ -0,0 +1,20 @@
+apiVersion: v1
+kind: Service
+metadata:
+ name: astrbot-service-nodeport
+ namespace: astrbot-ns
+ labels:
+ app: astrbot-stack
+spec:
+ type: NodePort
+ selector:
+ app: astrbot-stack
+ ports:
+ - name: napcat-web
+ port: 6099
+ targetPort: 6099
+ # nodePort: 30099 # Optional: Specify a fixed NodePort if needed, otherwise remove this line
+ - name: astrbot-web
+ port: 6185
+ targetPort: 6185
+ # nodePort: 30185 # Optional: Specify a fixed NodePort if needed, otherwise remove this line
\ No newline at end of file
diff --git a/k8s/astrbot_with_napcat/04-service-loadbalancer.yaml b/k8s/astrbot_with_napcat/04-service-loadbalancer.yaml
new file mode 100644
index 000000000..b519b3c9b
--- /dev/null
+++ b/k8s/astrbot_with_napcat/04-service-loadbalancer.yaml
@@ -0,0 +1,18 @@
+apiVersion: v1
+kind: Service
+metadata:
+ name: astrbot-service-lb
+ namespace: astrbot-ns
+ labels:
+ app: astrbot-stack
+spec:
+ type: LoadBalancer
+ selector:
+ app: astrbot-stack
+ ports:
+ - name: napcat-web
+ port: 6099
+ targetPort: 6099
+ - name: astrbot-web
+ port: 6185
+ targetPort: 6185
\ No newline at end of file
diff --git a/packages/astrbot/commands/help.py b/packages/astrbot/commands/help.py
deleted file mode 100644
index 7f5b6c170..000000000
--- a/packages/astrbot/commands/help.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import aiohttp
-
-from astrbot.api import star
-from astrbot.api.event import AstrMessageEvent, MessageEventResult
-from astrbot.core.config.default import VERSION
-from astrbot.core.utils.io import get_dashboard_version
-
-
-class HelpCommand:
- def __init__(self, context: star.Context):
- self.context = context
-
- async def _query_astrbot_notice(self):
- try:
- async with aiohttp.ClientSession(trust_env=True) as session:
- async with session.get(
- "https://astrbot.app/notice.json",
- timeout=2,
- ) as resp:
- return (await resp.json())["notice"]
- except BaseException:
- return ""
-
- async def help(self, event: AstrMessageEvent):
- """查看帮助"""
- notice = ""
- try:
- notice = await self._query_astrbot_notice()
- except BaseException:
- pass
-
- dashboard_version = await get_dashboard_version()
-
- msg = f"""AstrBot v{VERSION}(WebUI: {dashboard_version})
-内置指令:
-[System]
-/plugin: 查看插件、插件帮助
-/t2i: 开关文本转图片
-/tts: 开关文本转语音
-/sid: 获取会话 ID
-/op: 管理员
-/wl: 白名单
-/dashboard_update: 更新管理面板(op)
-/alter_cmd: 设置指令权限(op)
-
-[大模型]
-/llm: 开启/关闭 LLM
-/provider: 大模型提供商
-/model: 模型列表
-/ls: 对话列表
-/new: 创建新对话
-/groupnew 群号: 为群聊创建新对话(op)
-/switch 序号: 切换对话
-/rename 新名字: 重命名当前对话
-/del: 删除当前会话对话(op)
-/reset: 重置 LLM 会话
-/history: 当前对话的对话记录
-/persona: 人格情景(op)
-/key: API Key(op)
-/websearch: 网页搜索
-{notice}"""
-
- event.set_result(MessageEventResult().message(msg).use_t2i(False))
diff --git a/packages/astrbot/long_term_memory.py b/packages/astrbot/long_term_memory.py
index ceca60ef7..610995db2 100644
--- a/packages/astrbot/long_term_memory.py
+++ b/packages/astrbot/long_term_memory.py
@@ -6,9 +6,9 @@ from collections import defaultdict
from astrbot import logger
from astrbot.api import star
from astrbot.api.event import AstrMessageEvent
-from astrbot.api.message_components import Image, Plain
+from astrbot.api.message_components import At, Image, Plain
from astrbot.api.platform import MessageType
-from astrbot.api.provider import Provider, ProviderRequest
+from astrbot.api.provider import LLMResponse, Provider, ProviderRequest
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
"""
@@ -30,16 +30,13 @@ class LongTermMemory:
except BaseException as e:
logger.error(e)
max_cnt = 300
- image_caption = (
- True
- if cfg["provider_settings"]["default_image_caption_provider_id"]
- and cfg["provider_ltm_settings"]["image_caption"]
- else False
- )
image_caption_prompt = cfg["provider_settings"]["image_caption_prompt"]
- image_caption_provider_id = cfg["provider_settings"][
- "default_image_caption_provider_id"
- ]
+ image_caption_provider_id = cfg["provider_ltm_settings"].get(
+ "image_caption_provider_id"
+ )
+ image_caption = cfg["provider_ltm_settings"]["image_caption"] and bool(
+ image_caption_provider_id
+ )
active_reply = cfg["provider_ltm_settings"]["active_reply"]
enable_active_reply = active_reply.get("enable", False)
ar_method = active_reply["method"]
@@ -142,6 +139,8 @@ class LongTermMemory:
logger.error(f"获取图片描述失败: {e}")
else:
parts.append(" [Image]")
+ elif isinstance(comp, At):
+ parts.append(f" [At: {comp.name}]")
final_message = "".join(parts)
logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}")
@@ -159,8 +158,12 @@ class LongTermMemory:
cfg = self.cfg(event)
if cfg["enable_active_reply"]:
prompt = req.prompt
- req.prompt = f"You are now in a chatroom. The chat history is as follows:\n{chats_str}"
- req.prompt += f"\nNow, a new message is coming: `{prompt}`. Please react to it. Only output your response and do not output any other information."
+ req.prompt = (
+ f"You are now in a chatroom. The chat history is as follows:\n{chats_str}"
+ f"\nNow, a new message is coming: `{prompt}`. "
+ "Please react to it. Only output your response and do not output any other information. "
+ "You MUST use the SAME language as the chatroom is using."
+ )
req.contexts = [] # 清空上下文,当使用了主动回复,所有聊天记录都在一个prompt中。
else:
req.system_prompt += (
@@ -168,13 +171,15 @@ class LongTermMemory:
)
req.system_prompt += chats_str
- async def after_req_llm(self, event: AstrMessageEvent):
+ async def after_req_llm(self, event: AstrMessageEvent, llm_resp: LLMResponse):
if event.unified_msg_origin not in self.session_chats:
return
- if event.get_result() and event.get_result().is_llm_result():
- final_message = f"[You/{datetime.datetime.now().strftime('%H:%M:%S')}]: {event.get_result().get_plain_text()}"
- logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}")
+ if llm_resp.completion_text:
+ final_message = f"[You/{datetime.datetime.now().strftime('%H:%M:%S')}]: {llm_resp.completion_text}"
+ logger.debug(
+ f"Recorded AI response: {event.unified_msg_origin} | {final_message}"
+ )
self.session_chats[event.unified_msg_origin].append(final_message)
cfg = self.cfg(event)
if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]:
diff --git a/packages/astrbot/main.py b/packages/astrbot/main.py
index 8b33b887d..09859ab95 100644
--- a/packages/astrbot/main.py
+++ b/packages/astrbot/main.py
@@ -3,25 +3,9 @@ import traceback
from astrbot.api import star
from astrbot.api.event import AstrMessageEvent, filter
from astrbot.api.message_components import Image, Plain
-from astrbot.api.provider import ProviderRequest
+from astrbot.api.provider import LLMResponse, ProviderRequest
from astrbot.core import logger
-from astrbot.core.provider.sources.dify_source import ProviderDify
-from .commands import (
- AdminCommands,
- AlterCmdCommands,
- ConversationCommands,
- HelpCommand,
- LLMCommands,
- PersonaCommands,
- PluginCommands,
- ProviderCommands,
- SetUnsetCommands,
- SIDCommand,
- T2ICommand,
- ToolCommands,
- TTSCommand,
-)
from .long_term_memory import LongTermMemory
from .process_llm_request import ProcessLLMRequest
@@ -35,19 +19,6 @@ class Main(star.Star):
except BaseException as e:
logger.error(f"聊天增强 err: {e}")
- self.help_c = HelpCommand(self.context)
- self.llm_c = LLMCommands(self.context)
- self.tool_c = ToolCommands(self.context)
- self.plugin_c = PluginCommands(self.context)
- self.admin_c = AdminCommands(self.context)
- self.conversation_c = ConversationCommands(self.context, self.ltm)
- self.provider_c = ProviderCommands(self.context)
- self.persona_c = PersonaCommands(self.context)
- self.alter_cmd_c = AlterCmdCommands(self.context)
- self.setunset_c = SetUnsetCommands(self.context)
- self.t2i_c = T2ICommand(self.context)
- self.tts_c = TTSCommand(self.context)
- self.sid_c = SIDCommand(self.context)
self.proc_llm_req = ProcessLLMRequest(self.context)
def ltm_enabled(self, event: AstrMessageEvent):
@@ -56,199 +27,6 @@ class Main(star.Star):
]
return ltmse["group_icl_enable"] or ltmse["active_reply"]["enable"]
- @filter.command("help")
- async def help(self, event: AstrMessageEvent):
- """查看帮助"""
- await self.help_c.help(event)
-
- @filter.permission_type(filter.PermissionType.ADMIN)
- @filter.command("llm")
- async def llm(self, event: AstrMessageEvent):
- """开启/关闭 LLM"""
- await self.llm_c.llm(event)
-
- @filter.command_group("tool")
- def tool(self):
- pass
-
- @tool.command("ls")
- async def tool_ls(self, event: AstrMessageEvent):
- """查看函数工具列表"""
- await self.tool_c.tool_ls(event)
-
- @tool.command("on")
- async def tool_on(self, event: AstrMessageEvent, tool_name: str):
- """启用一个函数工具"""
- await self.tool_c.tool_on(event, tool_name)
-
- @tool.command("off")
- async def tool_off(self, event: AstrMessageEvent, tool_name: str):
- """停用一个函数工具"""
- await self.tool_c.tool_off(event, tool_name)
-
- @tool.command("off_all")
- async def tool_all_off(self, event: AstrMessageEvent):
- """停用所有函数工具"""
- await self.tool_c.tool_all_off(event)
-
- @filter.command_group("plugin")
- def plugin(self):
- pass
-
- @plugin.command("ls")
- async def plugin_ls(self, event: AstrMessageEvent):
- """获取已经安装的插件列表。"""
- await self.plugin_c.plugin_ls(event)
-
- @filter.permission_type(filter.PermissionType.ADMIN)
- @plugin.command("off")
- async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = ""):
- """禁用插件"""
- await self.plugin_c.plugin_off(event, plugin_name)
-
- @filter.permission_type(filter.PermissionType.ADMIN)
- @plugin.command("on")
- async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = ""):
- """启用插件"""
- await self.plugin_c.plugin_on(event, plugin_name)
-
- @filter.permission_type(filter.PermissionType.ADMIN)
- @plugin.command("get")
- async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = ""):
- """安装插件"""
- await self.plugin_c.plugin_get(event, plugin_repo)
-
- @plugin.command("help")
- async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = ""):
- """获取插件帮助"""
- await self.plugin_c.plugin_help(event, plugin_name)
-
- @filter.command("t2i")
- async def t2i(self, event: AstrMessageEvent):
- """开关文本转图片"""
- await self.t2i_c.t2i(event)
-
- @filter.command("tts")
- async def tts(self, event: AstrMessageEvent):
- """开关文本转语音(会话级别)"""
- await self.tts_c.tts(event)
-
- @filter.command("sid")
- async def sid(self, event: AstrMessageEvent):
- """获取会话 ID 和 管理员 ID"""
- await self.sid_c.sid(event)
-
- @filter.permission_type(filter.PermissionType.ADMIN)
- @filter.command("op")
- async def op(self, event: AstrMessageEvent, admin_id: str = ""):
- """授权管理员。op
"""
- await self.admin_c.op(event, admin_id)
-
- @filter.permission_type(filter.PermissionType.ADMIN)
- @filter.command("deop")
- async def deop(self, event: AstrMessageEvent, admin_id: str):
- """取消授权管理员。deop """
- await self.admin_c.deop(event, admin_id)
-
- @filter.permission_type(filter.PermissionType.ADMIN)
- @filter.command("wl")
- async def wl(self, event: AstrMessageEvent, sid: str = ""):
- """添加白名单。wl """
- await self.admin_c.wl(event, sid)
-
- @filter.permission_type(filter.PermissionType.ADMIN)
- @filter.command("dwl")
- async def dwl(self, event: AstrMessageEvent, sid: str):
- """删除白名单。dwl """
- await self.admin_c.dwl(event, sid)
-
- @filter.permission_type(filter.PermissionType.ADMIN)
- @filter.command("provider")
- async def provider(
- self,
- event: AstrMessageEvent,
- idx: str | int | None = None,
- idx2: int | None = None,
- ):
- """查看或者切换 LLM Provider"""
- await self.provider_c.provider(event, idx, idx2)
-
- @filter.command("reset")
- async def reset(self, message: AstrMessageEvent):
- """重置 LLM 会话"""
- await self.conversation_c.reset(message)
-
- @filter.permission_type(filter.PermissionType.ADMIN)
- @filter.command("model")
- async def model_ls(
- self,
- message: AstrMessageEvent,
- idx_or_name: int | str | None = None,
- ):
- """查看或者切换模型"""
- await self.provider_c.model_ls(message, idx_or_name)
-
- @filter.command("history")
- async def his(self, message: AstrMessageEvent, page: int = 1):
- """查看对话记录"""
- await self.conversation_c.his(message, page)
-
- @filter.command("ls")
- async def convs(self, message: AstrMessageEvent, page: int = 1):
- """查看对话列表"""
- await self.conversation_c.convs(message, page)
-
- @filter.command("new")
- async def new_conv(self, message: AstrMessageEvent):
- """创建新对话"""
- await self.conversation_c.new_conv(message)
-
- @filter.permission_type(filter.PermissionType.ADMIN)
- @filter.command("groupnew")
- async def groupnew_conv(self, message: AstrMessageEvent, sid: str):
- """创建新群聊对话"""
- await self.conversation_c.groupnew_conv(message, sid)
-
- @filter.command("switch")
- async def switch_conv(self, message: AstrMessageEvent, index: int | None = None):
- """通过 /ls 前面的序号切换对话"""
- await self.conversation_c.switch_conv(message, index)
-
- @filter.command("rename")
- async def rename_conv(self, message: AstrMessageEvent, new_name: str):
- """重命名对话"""
- await self.conversation_c.rename_conv(message, new_name)
-
- @filter.command("del")
- async def del_conv(self, message: AstrMessageEvent):
- """删除当前对话"""
- await self.conversation_c.del_conv(message)
-
- @filter.permission_type(filter.PermissionType.ADMIN)
- @filter.command("key")
- async def key(self, message: AstrMessageEvent, index: int | None = None):
- """查看或者切换 Key"""
- await self.provider_c.key(message, index)
-
- @filter.permission_type(filter.PermissionType.ADMIN)
- @filter.command("persona")
- async def persona(self, message: AstrMessageEvent):
- """查看或者切换 Persona"""
- await self.persona_c.persona(message)
-
- @filter.permission_type(filter.PermissionType.ADMIN)
- @filter.command("dashboard_update")
- async def update_dashboard(self, event: AstrMessageEvent):
- await self.admin_c.update_dashboard(event)
-
- @filter.command("set")
- async def set_variable(self, event: AstrMessageEvent, key: str, value: str):
- await self.setunset_c.set_variable(event, key, value)
-
- @filter.command("unset")
- async def unset_variable(self, event: AstrMessageEvent, key: str):
- await self.setunset_c.unset_variable(event, key)
-
@filter.platform_adapter_type(filter.PlatformAdapterType.ALL)
async def on_message(self, event: AstrMessageEvent):
"""群聊记忆增强"""
@@ -279,33 +57,20 @@ class Main(star.Star):
return
try:
conv = None
- if provider.meta().type != "dify":
- session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id(
- event.unified_msg_origin,
- )
+ session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id(
+ event.unified_msg_origin,
+ )
- if not session_curr_cid:
- logger.error(
- "当前未处于对话状态,无法主动回复,请确保 平台设置->会话隔离(unique_session) 未开启,并使用 /switch 序号 切换或者 /new 创建一个会话。",
- )
- return
+ if not session_curr_cid:
+ logger.error(
+ "当前未处于对话状态,无法主动回复,请确保 平台设置->会话隔离(unique_session) 未开启,并使用 /switch 序号 切换或者 /new 创建一个会话。",
+ )
+ return
- conv = await self.context.conversation_manager.get_conversation(
- event.unified_msg_origin,
- session_curr_cid,
- )
- else:
- # Dify 自己有维护对话,不需要 bot 端维护。
- assert isinstance(provider, ProviderDify)
- cid = provider.conversation_ids.get(
- event.unified_msg_origin,
- None,
- )
- if cid is None:
- logger.error(
- "[Dify] 当前未处于对话状态,无法主动回复,请确保 平台设置->会话隔离(unique_session) 未开启,并使用 /switch 序号 切换或者 /new 创建一个会话。",
- )
- return
+ conv = await self.context.conversation_manager.get_conversation(
+ event.unified_msg_origin,
+ session_curr_cid,
+ )
prompt = event.message_str
@@ -334,17 +99,30 @@ class Main(star.Star):
except BaseException as e:
logger.error(f"ltm: {e}")
- @filter.after_message_sent()
- async def after_llm_req(self, event: AstrMessageEvent):
- """在 LLM 请求后记录对话"""
+ @filter.on_llm_response()
+ async def inject_reasoning(self, event: AstrMessageEvent, resp: LLMResponse):
+ """在 LLM 响应后基于配置注入思考过程文本 / 在 LLM 响应后记录对话"""
+ umo = event.unified_msg_origin
+ cfg = self.context.get_config(umo).get("provider_settings", {})
+ show_reasoning = cfg.get("display_reasoning_text", False)
+ if show_reasoning and resp.reasoning_content:
+ resp.completion_text = (
+ f"🤔 思考: {resp.reasoning_content}\n\n{resp.completion_text}"
+ )
+
if self.ltm and self.ltm_enabled(event):
try:
- await self.ltm.after_req_llm(event)
+ await self.ltm.after_req_llm(event, resp)
except Exception as e:
logger.error(f"ltm: {e}")
- @filter.permission_type(filter.PermissionType.ADMIN)
- @filter.command("alter_cmd", alias={"alter"})
- async def alter_cmd(self, event: AstrMessageEvent):
- """修改命令权限"""
- await self.alter_cmd_c.alter_cmd(event)
+ @filter.after_message_sent()
+ async def after_message_sent(self, event: AstrMessageEvent):
+ """消息发送后处理"""
+ if self.ltm and self.ltm_enabled(event):
+ try:
+ clean_session = event.get_extra("_clean_ltm_session", False)
+ if clean_session:
+ await self.ltm.remove_session(event)
+ except Exception as e:
+ logger.error(f"ltm: {e}")
diff --git a/packages/astrbot/metadata.yaml b/packages/astrbot/metadata.yaml
index 81b1e5c7f..93affaf70 100644
--- a/packages/astrbot/metadata.yaml
+++ b/packages/astrbot/metadata.yaml
@@ -1,4 +1,4 @@
name: astrbot
-desc: AstrBot 基础指令结合 + 拓展功能
+desc: AstrBot 自带插件,包含人格注入、思考内容注入、群聊上下文感知等功能的实现,禁用后将无法使用这些功能。
author: Soulter
-version: 4.0.0
\ No newline at end of file
+version: 4.1.0
\ No newline at end of file
diff --git a/packages/astrbot/process_llm_request.py b/packages/astrbot/process_llm_request.py
index 6d8c896f4..28c41df9f 100644
--- a/packages/astrbot/process_llm_request.py
+++ b/packages/astrbot/process_llm_request.py
@@ -3,7 +3,7 @@ import copy
import datetime
import zoneinfo
-from astrbot.api import logger, star
+from astrbot.api import logger, sp, star
from astrbot.api.event import AstrMessageEvent
from astrbot.api.message_components import Image, Reply
from astrbot.api.provider import Provider, ProviderRequest
@@ -21,16 +21,26 @@ class ProcessLLMRequest:
else:
logger.info(f"Timezone set to: {self.timezone}")
- def _ensure_persona(self, req: ProviderRequest, cfg: dict):
+ async def _ensure_persona(self, req: ProviderRequest, cfg: dict, umo: str):
"""确保用户人格已加载"""
if not req.conversation:
return
# persona inject
- persona_id = req.conversation.persona_id or cfg.get("default_personality")
- if not persona_id and persona_id != "[%None]": # [%None] 为用户取消人格
- default_persona = self.ctx.persona_manager.selected_default_persona_v3
- if default_persona:
- persona_id = default_persona["name"]
+
+ # custom rule is preferred
+ persona_id = (
+ await sp.get_async(
+ scope="umo", scope_id=umo, key="session_service_config", default={}
+ )
+ ).get("persona_id")
+
+ if not persona_id:
+ persona_id = req.conversation.persona_id or cfg.get("default_personality")
+ if not persona_id and persona_id != "[%None]": # [%None] 为用户取消人格
+ default_persona = self.ctx.persona_manager.selected_default_persona_v3
+ if default_persona:
+ persona_id = default_persona["name"]
+
persona = next(
builtins.filter(
lambda persona: persona["name"] == persona_id,
@@ -129,6 +139,11 @@ class ProcessLLMRequest:
# group name identifier
if cfg.get("group_name_display") and event.message_obj.group_id:
+ if not event.message_obj.group:
+ logger.error(
+ f"Group name display enabled but group object is None. Group ID: {event.message_obj.group_id}"
+ )
+ return
group_name = event.message_obj.group.group_name
if group_name:
req.system_prompt += f"\nGroup name: {group_name}\n"
@@ -152,7 +167,7 @@ class ProcessLLMRequest:
img_cap_prov_id: str = cfg.get("default_image_caption_provider_id") or ""
if req.conversation:
# inject persona for this request
- self._ensure_persona(req, cfg)
+ await self._ensure_persona(req, cfg, event.unified_msg_origin)
# image caption
if img_cap_prov_id and req.image_urls:
diff --git a/packages/astrbot/commands/__init__.py b/packages/builtin_commands/commands/__init__.py
similarity index 100%
rename from packages/astrbot/commands/__init__.py
rename to packages/builtin_commands/commands/__init__.py
diff --git a/packages/astrbot/commands/admin.py b/packages/builtin_commands/commands/admin.py
similarity index 99%
rename from packages/astrbot/commands/admin.py
rename to packages/builtin_commands/commands/admin.py
index 2073f45a2..83d4b5974 100644
--- a/packages/astrbot/commands/admin.py
+++ b/packages/builtin_commands/commands/admin.py
@@ -71,6 +71,7 @@ class AdminCommands:
event.set_result(MessageEventResult().message("此 SID 不在白名单内。"))
async def update_dashboard(self, event: AstrMessageEvent):
+ """更新管理面板"""
await event.send(MessageChain().message("正在尝试更新管理面板..."))
await download_dashboard(version=f"v{VERSION}", latest=False)
await event.send(MessageChain().message("管理面板更新完成。"))
diff --git a/packages/astrbot/commands/alter_cmd.py b/packages/builtin_commands/commands/alter_cmd.py
similarity index 100%
rename from packages/astrbot/commands/alter_cmd.py
rename to packages/builtin_commands/commands/alter_cmd.py
diff --git a/packages/astrbot/commands/conversation.py b/packages/builtin_commands/commands/conversation.py
similarity index 61%
rename from packages/astrbot/commands/conversation.py
rename to packages/builtin_commands/commands/conversation.py
index 9538d8f53..de3d11ac8 100644
--- a/packages/astrbot/commands/conversation.py
+++ b/packages/builtin_commands/commands/conversation.py
@@ -1,20 +1,23 @@
import datetime
-from astrbot.api import logger, sp, star
+from astrbot.api import sp, star
from astrbot.api.event import AstrMessageEvent, MessageEventResult
-from astrbot.core.platform.astr_message_event import MessageSesion
+from astrbot.core.platform.astr_message_event import MessageSession
from astrbot.core.platform.message_type import MessageType
-from astrbot.core.provider.sources.coze_source import ProviderCoze
-from astrbot.core.provider.sources.dify_source import ProviderDify
-from ..long_term_memory import LongTermMemory
from .utils.rst_scene import RstScene
+THIRD_PARTY_AGENT_RUNNER_KEY = {
+ "dify": "dify_conversation_id",
+ "coze": "coze_conversation_id",
+ "dashscope": "dashscope_conversation_id",
+}
+THIRD_PARTY_AGENT_RUNNER_STR = ", ".join(THIRD_PARTY_AGENT_RUNNER_KEY.keys())
+
class ConversationCommands:
- def __init__(self, context: star.Context, ltm: LongTermMemory | None = None):
+ def __init__(self, context: star.Context):
self.context = context
- self.ltm = ltm
async def _get_current_persona_id(self, session_id):
curr = await self.context.conversation_manager.get_curr_conversation_id(
@@ -26,21 +29,15 @@ class ConversationCommands:
session_id,
curr,
)
+ if not conv:
+ return None
return conv.persona_id
- def ltm_enabled(self, event: AstrMessageEvent):
- if not self.ltm:
- return False
- ltmse = self.context.get_config(umo=event.unified_msg_origin)[
- "provider_ltm_settings"
- ]
- return ltmse["group_icl_enable"] or ltmse["active_reply"]["enable"]
-
async def reset(self, message: AstrMessageEvent):
"""重置 LLM 会话"""
- is_unique_session = self.context.get_config()["platform_settings"][
- "unique_session"
- ]
+ umo = message.unified_msg_origin
+ cfg = self.context.get_config(umo=message.unified_msg_origin)
+ is_unique_session = cfg["platform_settings"]["unique_session"]
is_group = bool(message.get_group_id())
scene = RstScene.get_scene(is_group, is_unique_session)
@@ -63,28 +60,23 @@ class ConversationCommands:
)
return
- if not self.context.get_using_provider(message.unified_msg_origin):
+ agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
+ if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY:
+ await sp.remove_async(
+ scope="umo",
+ scope_id=umo,
+ key=THIRD_PARTY_AGENT_RUNNER_KEY[agent_runner_type],
+ )
+ message.set_result(MessageEventResult().message("重置对话成功。"))
+ return
+
+ if not self.context.get_using_provider(umo):
message.set_result(
MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"),
)
return
- provider = self.context.get_using_provider(message.unified_msg_origin)
- if provider and provider.meta().type in ["dify", "coze"]:
- assert isinstance(provider, (ProviderDify, ProviderCoze)), (
- "provider type is not dify or coze"
- )
- await provider.forget(message.unified_msg_origin)
- message.set_result(
- MessageEventResult().message(
- "已重置当前 Dify / Coze 会话,新聊天将更换到新的会话。",
- ),
- )
- return
-
- cid = await self.context.conversation_manager.get_curr_conversation_id(
- message.unified_msg_origin,
- )
+ cid = await self.context.conversation_manager.get_curr_conversation_id(umo)
if not cid:
message.set_result(
@@ -95,15 +87,14 @@ class ConversationCommands:
return
await self.context.conversation_manager.update_conversation(
- message.unified_msg_origin,
+ umo,
cid,
[],
)
- ret = "清除会话 LLM 聊天历史成功。"
- if self.ltm and self.ltm_enabled(message):
- cnt = await self.ltm.remove_session(event=message)
- ret += f"\n聊天增强: 已清除 {cnt} 条聊天记录。"
+ ret = "清除聊天历史成功!"
+
+ message.set_extra("_clean_ltm_session", True)
message.set_result(MessageEventResult().message(ret))
@@ -152,29 +143,14 @@ class ConversationCommands:
async def convs(self, message: AstrMessageEvent, page: int = 1):
"""查看对话列表"""
- provider = self.context.get_using_provider(message.unified_msg_origin)
- if provider and provider.meta().type == "dify":
- """原有的Dify处理逻辑保持不变"""
- parts = ["Dify 对话列表:\n"]
- assert isinstance(provider, ProviderDify)
- data = await provider.api_client.get_chat_convs(message.unified_msg_origin)
- idx = 1
- for conv in data["data"]:
- ts_h = datetime.datetime.fromtimestamp(conv["updated_at"]).strftime(
- "%m-%d %H:%M",
- )
- parts.append(
- f"{idx}. {conv['name']}({conv['id'][:4]})\n 上次更新:{ts_h}\n"
- )
- idx += 1
- if idx == 1:
- parts.append("没有找到任何对话。")
- dify_cid = provider.conversation_ids.get(message.unified_msg_origin, None)
- parts.append(
- f"\n\n用户: {message.unified_msg_origin}\n当前对话: {dify_cid}\n使用 /switch <序号> 切换对话。"
+ cfg = self.context.get_config(umo=message.unified_msg_origin)
+ agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
+ if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY:
+ message.set_result(
+ MessageEventResult().message(
+ f"{THIRD_PARTY_AGENT_RUNNER_STR} 对话列表功能暂不支持。",
+ ),
)
- ret = "".join(parts)
- message.set_result(MessageEventResult().message(ret))
return
size_per_page = 6
@@ -227,9 +203,8 @@ class ConversationCommands:
else:
ret += "\n当前对话: 无"
- unique_session = self.context.get_config()["platform_settings"][
- "unique_session"
- ]
+ cfg = self.context.get_config(umo=message.unified_msg_origin)
+ unique_session = cfg["platform_settings"]["unique_session"]
if unique_session:
ret += "\n会话隔离粒度: 个人"
else:
@@ -243,15 +218,15 @@ class ConversationCommands:
async def new_conv(self, message: AstrMessageEvent):
"""创建新对话"""
- provider = self.context.get_using_provider(message.unified_msg_origin)
- if provider and provider.meta().type in ["dify", "coze"]:
- assert isinstance(provider, (ProviderDify, ProviderCoze)), (
- "provider type is not dify or coze"
- )
- await provider.forget(message.unified_msg_origin)
- message.set_result(
- MessageEventResult().message("成功,下次聊天将是新对话。"),
+ cfg = self.context.get_config(umo=message.unified_msg_origin)
+ agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
+ if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY:
+ await sp.remove_async(
+ scope="umo",
+ scope_id=message.unified_msg_origin,
+ key=THIRD_PARTY_AGENT_RUNNER_KEY[agent_runner_type],
)
+ message.set_result(MessageEventResult().message("已创建新对话。"))
return
cpersona = await self._get_current_persona_id(message.unified_msg_origin)
@@ -261,12 +236,7 @@ class ConversationCommands:
persona_id=cpersona,
)
- # 长期记忆
- if self.ltm and self.ltm_enabled(message):
- try:
- await self.ltm.remove_session(event=message)
- except Exception as e:
- logger.error(f"清理聊天增强记录失败: {e}")
+ message.set_extra("_clean_ltm_session", True)
message.set_result(
MessageEventResult().message(f"切换到新对话: 新对话({cid[:4]})。"),
@@ -274,19 +244,9 @@ class ConversationCommands:
async def groupnew_conv(self, message: AstrMessageEvent, sid: str = ""):
"""创建新群聊对话"""
- provider = self.context.get_using_provider(message.unified_msg_origin)
- if provider and provider.meta().type in ["dify", "coze"]:
- assert isinstance(provider, (ProviderDify, ProviderCoze)), (
- "provider type is not dify or coze"
- )
- await provider.forget(message.unified_msg_origin)
- message.set_result(
- MessageEventResult().message("成功,下次聊天将是新对话。"),
- )
- return
if sid:
session = str(
- MessageSesion(
+ MessageSession(
platform_name=message.platform_meta.id,
message_type=MessageType("GroupMessage"),
session_id=sid,
@@ -321,31 +281,6 @@ class ConversationCommands:
)
return
- provider = self.context.get_using_provider(message.unified_msg_origin)
- if provider and provider.meta().type == "dify":
- assert isinstance(provider, ProviderDify), "provider type is not dify"
- data = await provider.api_client.get_chat_convs(message.unified_msg_origin)
- if not data["data"]:
- message.set_result(MessageEventResult().message("未找到任何对话。"))
- return
- selected_conv = None
- if index is not None:
- try:
- selected_conv = data["data"][index - 1]
- except IndexError:
- message.set_result(
- MessageEventResult().message("对话序号错误,请使用 /ls 查看"),
- )
- return
- else:
- selected_conv = data["data"][0]
- ret = (
- f"Dify 切换到对话: {selected_conv['name']}({selected_conv['id'][:4]})。"
- )
- provider.conversation_ids[message.unified_msg_origin] = selected_conv["id"]
- message.set_result(MessageEventResult().message(ret))
- return
-
if index is None:
message.set_result(
MessageEventResult().message(
@@ -378,19 +313,6 @@ class ConversationCommands:
if not new_name:
message.set_result(MessageEventResult().message("请输入新的对话名称。"))
return
-
- provider = self.context.get_using_provider(message.unified_msg_origin)
-
- if provider and provider.meta().type == "dify":
- assert isinstance(provider, ProviderDify)
- cid = provider.conversation_ids.get(message.unified_msg_origin, None)
- if not cid:
- message.set_result(MessageEventResult().message("未找到当前对话。"))
- return
- await provider.api_client.rename(cid, new_name, message.unified_msg_origin)
- message.set_result(MessageEventResult().message("重命名对话成功。"))
- return
-
await self.context.conversation_manager.update_conversation_title(
message.unified_msg_origin,
new_name,
@@ -399,9 +321,8 @@ class ConversationCommands:
async def del_conv(self, message: AstrMessageEvent):
"""删除当前对话"""
- is_unique_session = self.context.get_config()["platform_settings"][
- "unique_session"
- ]
+ cfg = self.context.get_config(umo=message.unified_msg_origin)
+ is_unique_session = cfg["platform_settings"]["unique_session"]
if message.get_group_id() and not is_unique_session and message.role != "admin":
# 群聊,没开独立会话,发送人不是管理员
message.set_result(
@@ -411,20 +332,14 @@ class ConversationCommands:
)
return
- provider = self.context.get_using_provider(message.unified_msg_origin)
- if provider and provider.meta().type == "dify":
- assert isinstance(provider, ProviderDify)
- dify_cid = provider.conversation_ids.pop(message.unified_msg_origin, None)
- if dify_cid:
- await provider.api_client.delete_chat_conv(
- message.unified_msg_origin,
- dify_cid,
- )
- message.set_result(
- MessageEventResult().message(
- "删除当前对话成功。不再处于对话状态,使用 /switch 序号 切换到其他对话或 /new 创建。",
- ),
+ agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
+ if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY:
+ await sp.remove_async(
+ scope="umo",
+ scope_id=message.unified_msg_origin,
+ key=THIRD_PARTY_AGENT_RUNNER_KEY[agent_runner_type],
)
+ message.set_result(MessageEventResult().message("重置对话成功。"))
return
session_curr_cid = (
@@ -447,7 +362,5 @@ class ConversationCommands:
)
ret = "删除当前对话成功。不再处于对话状态,使用 /switch 序号 切换到其他对话或 /new 创建。"
- if self.ltm and self.ltm_enabled(message):
- cnt = await self.ltm.remove_session(event=message)
- ret += f"\n聊天增强: 已清除 {cnt} 条聊天记录。"
+ message.set_extra("_clean_ltm_session", True)
message.set_result(MessageEventResult().message(ret))
diff --git a/packages/builtin_commands/commands/help.py b/packages/builtin_commands/commands/help.py
new file mode 100644
index 000000000..092fc59ec
--- /dev/null
+++ b/packages/builtin_commands/commands/help.py
@@ -0,0 +1,88 @@
+import aiohttp
+
+from astrbot.api import star
+from astrbot.api.event import AstrMessageEvent, MessageEventResult
+from astrbot.core.config.default import VERSION
+from astrbot.core.star import command_management
+from astrbot.core.utils.io import get_dashboard_version
+
+
+class HelpCommand:
+ def __init__(self, context: star.Context):
+ self.context = context
+
+ async def _query_astrbot_notice(self):
+ try:
+ async with aiohttp.ClientSession(trust_env=True) as session:
+ async with session.get(
+ "https://astrbot.app/notice.json",
+ timeout=2,
+ ) as resp:
+ return (await resp.json())["notice"]
+ except BaseException:
+ return ""
+
+ async def _build_reserved_command_lines(self) -> list[str]:
+ """
+ 使用实时指令配置生成内置指令清单,确保重命名/禁用后与实际生效状态保持一致。
+ """
+ try:
+ commands = await command_management.list_commands()
+ except BaseException:
+ return []
+
+ lines: list[str] = []
+ hidden_commands = {"set", "unset", "websearch"}
+
+ def walk(items: list[dict], indent: int = 0):
+ for item in items:
+ if not item.get("reserved") or not item.get("enabled"):
+ continue
+ # 仅展示顶级指令或指令组
+ if item.get("type") == "sub_command":
+ continue
+ if item.get("parent_signature"):
+ continue
+
+ effective = (
+ item.get("effective_command")
+ or item.get("original_command")
+ or item.get("handler_name")
+ )
+ if not effective:
+ continue
+ if effective in hidden_commands:
+ continue
+
+ description = item.get("description") or ""
+ desc_text = f" - {description}" if description else ""
+ indent_prefix = " " * indent
+ lines.append(f"{indent_prefix}/{effective}{desc_text}")
+
+ walk(commands)
+ return lines
+
+ async def help(self, event: AstrMessageEvent):
+ """查看帮助"""
+ notice = ""
+ try:
+ notice = await self._query_astrbot_notice()
+ except BaseException:
+ pass
+
+ dashboard_version = await get_dashboard_version()
+ command_lines = await self._build_reserved_command_lines()
+ commands_section = (
+ "\n".join(command_lines) if command_lines else "暂无启用的内置指令"
+ )
+
+ msg_parts = [
+ f"AstrBot v{VERSION}(WebUI: {dashboard_version})",
+ "内置指令:",
+ commands_section,
+ ]
+ if notice:
+ msg_parts.append(notice)
+ msg = "\n".join(msg_parts)
+
+ event.set_result(MessageEventResult().message(msg).use_t2i(False))
diff --git a/packages/astrbot/commands/llm.py b/packages/builtin_commands/commands/llm.py
similarity index 100%
rename from packages/astrbot/commands/llm.py
rename to packages/builtin_commands/commands/llm.py
diff --git a/packages/astrbot/commands/persona.py b/packages/builtin_commands/commands/persona.py
similarity index 85%
rename from packages/astrbot/commands/persona.py
rename to packages/builtin_commands/commands/persona.py
index 1289cb569..13a57f07f 100644
--- a/packages/astrbot/commands/persona.py
+++ b/packages/builtin_commands/commands/persona.py
@@ -1,6 +1,6 @@
import builtins
-from astrbot.api import star
+from astrbot.api import sp, star
from astrbot.api.event import AstrMessageEvent, MessageEventResult
@@ -17,6 +17,13 @@ class PersonaCommands:
default_persona = await self.context.persona_manager.get_default_persona_v3(
umo=umo,
)
+
+ force_applied_persona_id = (
+ await sp.get_async(
+ scope="umo", scope_id=umo, key="session_service_config", default={}
+ )
+ ).get("persona_id")
+
curr_cid_title = "无"
if cid:
conv = await self.context.conversation_manager.get_conversation(
@@ -36,6 +43,9 @@ class PersonaCommands:
else:
curr_persona_name = conv.persona_id
+ if force_applied_persona_id:
+ curr_persona_name = f"{curr_persona_name} (自定义规则)"
+
curr_cid_title = conv.title if conv.title else "新对话"
curr_cid_title += f"({cid[:4]})"
@@ -113,9 +123,15 @@ class PersonaCommands:
message.unified_msg_origin,
ps,
)
+ force_warn_msg = ""
+ if force_applied_persona_id:
+ force_warn_msg = (
+ "提醒:由于自定义规则,您现在切换的人格将不会生效。"
+ )
+
message.set_result(
MessageEventResult().message(
- "设置成功。如果您正在切换到不同的人格,请注意使用 /reset 来清空上下文,防止原人格对话影响现人格。",
+ f"设置成功。如果您正在切换到不同的人格,请注意使用 /reset 来清空上下文,防止原人格对话影响现人格。{force_warn_msg}",
),
)
else:
diff --git a/packages/astrbot/commands/plugin.py b/packages/builtin_commands/commands/plugin.py
similarity index 100%
rename from packages/astrbot/commands/plugin.py
rename to packages/builtin_commands/commands/plugin.py
diff --git a/packages/astrbot/commands/provider.py b/packages/builtin_commands/commands/provider.py
similarity index 59%
rename from packages/astrbot/commands/provider.py
rename to packages/builtin_commands/commands/provider.py
index 8db7324e4..ce8f31831 100644
--- a/packages/astrbot/commands/provider.py
+++ b/packages/builtin_commands/commands/provider.py
@@ -1,5 +1,7 @@
+import asyncio
import re
+from astrbot import logger
from astrbot.api import star
from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.core.provider.entities import ProviderType
@@ -9,6 +11,39 @@ class ProviderCommands:
def __init__(self, context: star.Context):
self.context = context
+ def _log_reachability_failure(
+ self,
+ provider,
+ provider_capability_type: ProviderType | None,
+ err_code: str,
+ err_reason: str,
+ ):
+ """记录不可达原因到日志。"""
+ meta = provider.meta()
+ logger.warning(
+ "Provider reachability check failed: id=%s type=%s code=%s reason=%s",
+ meta.id,
+ provider_capability_type.name if provider_capability_type else "unknown",
+ err_code,
+ err_reason,
+ )
+
+ async def _test_provider_capability(self, provider):
+ """测试单个 provider 的可用性"""
+ meta = provider.meta()
+ provider_capability_type = meta.provider_type
+
+ try:
+ await provider.test()
+ return True, None, None
+ except Exception as e:
+ err_code = "TEST_FAILED"
+ err_reason = str(e)
+ self._log_reachability_failure(
+ provider, provider_capability_type, err_code, err_reason
+ )
+ return False, err_code, err_reason
+
async def provider(
self,
event: AstrMessageEvent,
@@ -17,46 +52,131 @@ class ProviderCommands:
):
"""查看或者切换 LLM Provider"""
umo = event.unified_msg_origin
+ cfg = self.context.get_config(umo).get("provider_settings", {})
+ reachability_check_enabled = cfg.get("reachability_check", True)
if idx is None:
parts = ["## 载入的 LLM 提供商\n"]
- for idx, llm in enumerate(self.context.get_all_providers()):
- id_ = llm.meta().id
- line = f"{idx + 1}. {id_} ({llm.meta().model})"
+
+ # 获取所有类型的提供商
+ llms = list(self.context.get_all_providers())
+ ttss = self.context.get_all_tts_providers()
+ stts = self.context.get_all_stt_providers()
+
+ # 构造待检测列表: [(provider, type_label), ...]
+ all_providers = []
+ all_providers.extend([(p, "llm") for p in llms])
+ all_providers.extend([(p, "tts") for p in ttss])
+ all_providers.extend([(p, "stt") for p in stts])
+
+ # 并发测试连通性
+ if reachability_check_enabled:
+ if all_providers:
+ await event.send(
+ MessageEventResult().message(
+ "正在进行提供商可达性测试,请稍候..."
+ )
+ )
+ check_results = await asyncio.gather(
+ *[self._test_provider_capability(p) for p, _ in all_providers],
+ return_exceptions=True,
+ )
+ else:
+ # 用 None 表示未检测
+ check_results = [None for _ in all_providers]
+
+ # 整合结果
+ display_data = []
+ for (p, p_type), reachable in zip(all_providers, check_results):
+ meta = p.meta()
+ id_ = meta.id
+ error_code = None
+
+ if isinstance(reachable, Exception):
+ # 异常情况下兜底处理,避免单个 provider 导致列表失败
+ self._log_reachability_failure(
+ p,
+ None,
+ reachable.__class__.__name__,
+ str(reachable),
+ )
+ reachable_flag = False
+ error_code = reachable.__class__.__name__
+ elif isinstance(reachable, tuple):
+ reachable_flag, error_code, _ = reachable
+ else:
+ reachable_flag = reachable
+
+ # 根据类型构建显示名称
+ if p_type == "llm":
+ info = f"{id_} ({meta.model})"
+ else:
+ info = f"{id_}"
+
+ # 确定状态标记
+ if reachable_flag is True:
+ mark = " ✅"
+ elif reachable_flag is False:
+ if error_code:
+ mark = f" ❌(错误码: {error_code})"
+ else:
+ mark = " ❌"
+ else:
+ mark = "" # 不支持检测时不显示标记
+
+ display_data.append(
+ {
+ "type": p_type,
+ "info": info,
+ "mark": mark,
+ "provider": p,
+ }
+ )
+
+ # 分组输出
+ # 1. LLM
+ llm_data = [d for d in display_data if d["type"] == "llm"]
+ for i, d in enumerate(llm_data):
+ line = f"{i + 1}. {d['info']}{d['mark']}"
provider_using = self.context.get_using_provider(umo=umo)
- if provider_using and provider_using.meta().id == id_:
+ if (
+ provider_using
+ and provider_using.meta().id == d["provider"].meta().id
+ ):
line += " (当前使用)"
parts.append(line + "\n")
- tts_providers = self.context.get_all_tts_providers()
- if tts_providers:
+ # 2. TTS
+ tts_data = [d for d in display_data if d["type"] == "tts"]
+ if tts_data:
parts.append("\n## 载入的 TTS 提供商\n")
- for idx, tts in enumerate(tts_providers):
- id_ = tts.meta().id
- line = f"{idx + 1}. {id_}"
+ for i, d in enumerate(tts_data):
+ line = f"{i + 1}. {d['info']}{d['mark']}"
tts_using = self.context.get_using_tts_provider(umo=umo)
- if tts_using and tts_using.meta().id == id_:
+ if tts_using and tts_using.meta().id == d["provider"].meta().id:
line += " (当前使用)"
parts.append(line + "\n")
- stt_providers = self.context.get_all_stt_providers()
- if stt_providers:
+ # 3. STT
+ stt_data = [d for d in display_data if d["type"] == "stt"]
+ if stt_data:
parts.append("\n## 载入的 STT 提供商\n")
- for idx, stt in enumerate(stt_providers):
- id_ = stt.meta().id
- line = f"{idx + 1}. {id_}"
+ for i, d in enumerate(stt_data):
+ line = f"{i + 1}. {d['info']}{d['mark']}"
stt_using = self.context.get_using_stt_provider(umo=umo)
- if stt_using and stt_using.meta().id == id_:
+ if stt_using and stt_using.meta().id == d["provider"].meta().id:
line += " (当前使用)"
parts.append(line + "\n")
parts.append("\n使用 /provider <序号> 切换 LLM 提供商。")
ret = "".join(parts)
- if tts_providers:
+ if ttss:
ret += "\n使用 /provider tts <序号> 切换 TTS 提供商。"
- if stt_providers:
- ret += "\n使用 /provider stt <切换> STT 提供商。"
+ if stts:
+ ret += "\n使用 /provider stt <序号> 切换 STT 提供商。"
+ if not reachability_check_enabled:
+ ret += "\n已跳过提供商可达性检测,如需检测请在配置文件中开启。"
event.set_result(MessageEventResult().message(ret))
elif idx == "tts":
diff --git a/packages/astrbot/commands/setunset.py b/packages/builtin_commands/commands/setunset.py
similarity index 100%
rename from packages/astrbot/commands/setunset.py
rename to packages/builtin_commands/commands/setunset.py
diff --git a/packages/astrbot/commands/sid.py b/packages/builtin_commands/commands/sid.py
similarity index 100%
rename from packages/astrbot/commands/sid.py
rename to packages/builtin_commands/commands/sid.py
diff --git a/packages/astrbot/commands/t2i.py b/packages/builtin_commands/commands/t2i.py
similarity index 100%
rename from packages/astrbot/commands/t2i.py
rename to packages/builtin_commands/commands/t2i.py
diff --git a/packages/astrbot/commands/tool.py b/packages/builtin_commands/commands/tool.py
similarity index 100%
rename from packages/astrbot/commands/tool.py
rename to packages/builtin_commands/commands/tool.py
diff --git a/packages/astrbot/commands/tts.py b/packages/builtin_commands/commands/tts.py
similarity index 100%
rename from packages/astrbot/commands/tts.py
rename to packages/builtin_commands/commands/tts.py
diff --git a/packages/astrbot/commands/utils/rst_scene.py b/packages/builtin_commands/commands/utils/rst_scene.py
similarity index 100%
rename from packages/astrbot/commands/utils/rst_scene.py
rename to packages/builtin_commands/commands/utils/rst_scene.py
diff --git a/packages/builtin_commands/main.py b/packages/builtin_commands/main.py
new file mode 100644
index 000000000..7809c4359
--- /dev/null
+++ b/packages/builtin_commands/main.py
@@ -0,0 +1,237 @@
+from astrbot.api import star
+from astrbot.api.event import AstrMessageEvent, filter
+
+from .commands import (
+ AdminCommands,
+ AlterCmdCommands,
+ ConversationCommands,
+ HelpCommand,
+ LLMCommands,
+ PersonaCommands,
+ PluginCommands,
+ ProviderCommands,
+ SetUnsetCommands,
+ SIDCommand,
+ T2ICommand,
+ ToolCommands,
+ TTSCommand,
+)
+
+
+class Main(star.Star):
+ def __init__(self, context: star.Context) -> None:
+ self.context = context
+
+ self.help_c = HelpCommand(self.context)
+ self.llm_c = LLMCommands(self.context)
+ self.tool_c = ToolCommands(self.context)
+ self.plugin_c = PluginCommands(self.context)
+ self.admin_c = AdminCommands(self.context)
+ self.conversation_c = ConversationCommands(self.context)
+ self.provider_c = ProviderCommands(self.context)
+ self.persona_c = PersonaCommands(self.context)
+ self.alter_cmd_c = AlterCmdCommands(self.context)
+ self.setunset_c = SetUnsetCommands(self.context)
+ self.t2i_c = T2ICommand(self.context)
+ self.tts_c = TTSCommand(self.context)
+ self.sid_c = SIDCommand(self.context)
+
+ @filter.command("help")
+ async def help(self, event: AstrMessageEvent):
+ """查看帮助"""
+ await self.help_c.help(event)
+
+ @filter.permission_type(filter.PermissionType.ADMIN)
+ @filter.command("llm")
+ async def llm(self, event: AstrMessageEvent):
+ """开启/关闭 LLM"""
+ await self.llm_c.llm(event)
+
+ @filter.command_group("tool")
+ def tool(self):
+ """函数工具管理"""
+
+ @tool.command("ls")
+ async def tool_ls(self, event: AstrMessageEvent):
+ """查看函数工具列表"""
+ await self.tool_c.tool_ls(event)
+
+ @tool.command("on")
+ async def tool_on(self, event: AstrMessageEvent, tool_name: str):
+ """启用一个函数工具"""
+ await self.tool_c.tool_on(event, tool_name)
+
+ @tool.command("off")
+ async def tool_off(self, event: AstrMessageEvent, tool_name: str):
+ """停用一个函数工具"""
+ await self.tool_c.tool_off(event, tool_name)
+
+ @tool.command("off_all")
+ async def tool_all_off(self, event: AstrMessageEvent):
+ """停用所有函数工具"""
+ await self.tool_c.tool_all_off(event)
+
+ @filter.command_group("plugin")
+ def plugin(self):
+ """插件管理"""
+
+ @plugin.command("ls")
+ async def plugin_ls(self, event: AstrMessageEvent):
+ """获取已经安装的插件列表。"""
+ await self.plugin_c.plugin_ls(event)
+
+ @filter.permission_type(filter.PermissionType.ADMIN)
+ @plugin.command("off")
+ async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = ""):
+ """禁用插件"""
+ await self.plugin_c.plugin_off(event, plugin_name)
+
+ @filter.permission_type(filter.PermissionType.ADMIN)
+ @plugin.command("on")
+ async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = ""):
+ """启用插件"""
+ await self.plugin_c.plugin_on(event, plugin_name)
+
+ @filter.permission_type(filter.PermissionType.ADMIN)
+ @plugin.command("get")
+ async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = ""):
+ """安装插件"""
+ await self.plugin_c.plugin_get(event, plugin_repo)
+
+ @plugin.command("help")
+ async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = ""):
+ """获取插件帮助"""
+ await self.plugin_c.plugin_help(event, plugin_name)
+
+ @filter.command("t2i")
+ async def t2i(self, event: AstrMessageEvent):
+ """开关文本转图片"""
+ await self.t2i_c.t2i(event)
+
+ @filter.command("tts")
+ async def tts(self, event: AstrMessageEvent):
+ """开关文本转语音(会话级别)"""
+ await self.tts_c.tts(event)
+
+ @filter.command("sid")
+ async def sid(self, event: AstrMessageEvent):
+ """获取会话 ID 和 管理员 ID"""
+ await self.sid_c.sid(event)
+
+ @filter.permission_type(filter.PermissionType.ADMIN)
+ @filter.command("op")
+ async def op(self, event: AstrMessageEvent, admin_id: str = ""):
+ """授权管理员。op """
+ await self.admin_c.op(event, admin_id)
+
+ @filter.permission_type(filter.PermissionType.ADMIN)
+ @filter.command("deop")
+ async def deop(self, event: AstrMessageEvent, admin_id: str):
+ """取消授权管理员。deop """
+ await self.admin_c.deop(event, admin_id)
+
+ @filter.permission_type(filter.PermissionType.ADMIN)
+ @filter.command("wl")
+ async def wl(self, event: AstrMessageEvent, sid: str = ""):
+ """添加白名单。wl """
+ await self.admin_c.wl(event, sid)
+
+ @filter.permission_type(filter.PermissionType.ADMIN)
+ @filter.command("dwl")
+ async def dwl(self, event: AstrMessageEvent, sid: str):
+ """删除白名单。dwl """
+ await self.admin_c.dwl(event, sid)
+
+ @filter.permission_type(filter.PermissionType.ADMIN)
+ @filter.command("provider")
+ async def provider(
+ self,
+ event: AstrMessageEvent,
+ idx: str | int | None = None,
+ idx2: int | None = None,
+ ):
+ """查看或者切换 LLM Provider"""
+ await self.provider_c.provider(event, idx, idx2)
+
+ @filter.command("reset")
+ async def reset(self, message: AstrMessageEvent):
+ """重置 LLM 会话"""
+ await self.conversation_c.reset(message)
+
+ @filter.permission_type(filter.PermissionType.ADMIN)
+ @filter.command("model")
+ async def model_ls(
+ self,
+ message: AstrMessageEvent,
+ idx_or_name: int | str | None = None,
+ ):
+ """查看或者切换模型"""
+ await self.provider_c.model_ls(message, idx_or_name)
+
+ @filter.command("history")
+ async def his(self, message: AstrMessageEvent, page: int = 1):
+ """查看对话记录"""
+ await self.conversation_c.his(message, page)
+
+ @filter.command("ls")
+ async def convs(self, message: AstrMessageEvent, page: int = 1):
+ """查看对话列表"""
+ await self.conversation_c.convs(message, page)
+
+ @filter.command("new")
+ async def new_conv(self, message: AstrMessageEvent):
+ """创建新对话"""
+ await self.conversation_c.new_conv(message)
+
+ @filter.permission_type(filter.PermissionType.ADMIN)
+ @filter.command("groupnew")
+ async def groupnew_conv(self, message: AstrMessageEvent, sid: str):
+ """创建新群聊对话"""
+ await self.conversation_c.groupnew_conv(message, sid)
+
+ @filter.command("switch")
+ async def switch_conv(self, message: AstrMessageEvent, index: int | None = None):
+ """通过 /ls 前面的序号切换对话"""
+ await self.conversation_c.switch_conv(message, index)
+
+ @filter.command("rename")
+ async def rename_conv(self, message: AstrMessageEvent, new_name: str):
+ """重命名对话"""
+ await self.conversation_c.rename_conv(message, new_name)
+
+ @filter.command("del")
+ async def del_conv(self, message: AstrMessageEvent):
+ """删除当前对话"""
+ await self.conversation_c.del_conv(message)
+
+ @filter.permission_type(filter.PermissionType.ADMIN)
+ @filter.command("key")
+ async def key(self, message: AstrMessageEvent, index: int | None = None):
+ """查看或者切换 Key"""
+ await self.provider_c.key(message, index)
+
+ @filter.permission_type(filter.PermissionType.ADMIN)
+ @filter.command("persona")
+ async def persona(self, message: AstrMessageEvent):
+ """查看或者切换 Persona"""
+ await self.persona_c.persona(message)
+
+ @filter.permission_type(filter.PermissionType.ADMIN)
+ @filter.command("dashboard_update")
+ async def update_dashboard(self, event: AstrMessageEvent):
+ """更新管理面板"""
+ await self.admin_c.update_dashboard(event)
+
+ @filter.command("set")
+ async def set_variable(self, event: AstrMessageEvent, key: str, value: str):
+ await self.setunset_c.set_variable(event, key, value)
+
+ @filter.command("unset")
+ async def unset_variable(self, event: AstrMessageEvent, key: str):
+ await self.setunset_c.unset_variable(event, key)
+
+ @filter.permission_type(filter.PermissionType.ADMIN)
+ @filter.command("alter_cmd", alias={"alter"})
+ async def alter_cmd(self, event: AstrMessageEvent):
+ """修改命令权限"""
+ await self.alter_cmd_c.alter_cmd(event)
diff --git a/packages/builtin_commands/metadata.yaml b/packages/builtin_commands/metadata.yaml
new file mode 100644
index 000000000..5e283b9f1
--- /dev/null
+++ b/packages/builtin_commands/metadata.yaml
@@ -0,0 +1,4 @@
+name: builtin_commands
+desc: AstrBot 自带指令,提供常用的对话管理、工具使用、插件管理等功能。
+author: Soulter
+version: 0.0.1
\ No newline at end of file
diff --git a/packages/python_interpreter/main.py b/packages/python_interpreter/main.py
index 35a2f2698..afbef7560 100644
--- a/packages/python_interpreter/main.py
+++ b/packages/python_interpreter/main.py
@@ -14,6 +14,7 @@ from astrbot.api import llm_tool, logger, star
from astrbot.api.event import AstrMessageEvent, MessageEventResult, filter
from astrbot.api.message_components import File, Image
from astrbot.api.provider import ProviderRequest
+from astrbot.core.message.components import BaseMessageComponent
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_file, download_image_by_url
@@ -224,6 +225,8 @@ class Main(star.Star):
del self.user_waiting[uid]
elif isinstance(comp, Image):
image_url = comp.url if comp.url else comp.file
+ if image_url is None:
+ raise ValueError("Image URL is None")
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
elif image_url.startswith("file:///"):
@@ -240,11 +243,13 @@ class Main(star.Star):
async def on_llm_req(self, event: AstrMessageEvent, request: ProviderRequest):
if event.get_session_id() in self.user_file_msg_buffer:
files = self.user_file_msg_buffer[event.get_session_id()]
+ if not request.prompt:
+ request.prompt = ""
request.prompt += f"\nUser provided files: {files}"
@filter.command_group("pi")
def pi(self):
- pass
+ """代码执行器配置"""
@pi.command("absdir")
async def pi_absdir(self, event: AstrMessageEvent, path: str = ""):
@@ -477,7 +482,9 @@ class Main(star.Star):
# file_s3_url = await self.file_upload(file_path)
# logger.info(f"文件上传到 AstrBot 云节点: {file_s3_url}")
file_name = os.path.basename(file_path)
- chain = [File(name=file_name, file=file_path)]
+ chain: list[BaseMessageComponent] = [
+ File(name=file_name, file=file_path)
+ ]
yield event.set_result(MessageEventResult(chain=chain))
elif "Traceback (most recent call last)" in log or "[Error]: " in log:
diff --git a/packages/reminder/main.py b/packages/reminder/main.py
index eaeec8d73..62af7ae56 100644
--- a/packages/reminder/main.py
+++ b/packages/reminder/main.py
@@ -5,6 +5,7 @@ import uuid
import zoneinfo
from apscheduler.schedulers.asyncio import AsyncIOScheduler
+from apscheduler.triggers.cron import CronTrigger
from astrbot.api import llm_tool, logger, star
from astrbot.api.event import AstrMessageEvent, MessageEventResult, filter
@@ -62,13 +63,13 @@ class Main(star.Star):
misfire_grace_time=60,
)
elif "cron" in reminder:
+ trigger = CronTrigger(**self._parse_cron_expr(reminder["cron"]))
self.scheduler.add_job(
self._reminder_callback,
- trigger="cron",
+ trigger=trigger,
id=id_,
args=[group, reminder],
misfire_grace_time=60,
- **self._parse_cron_expr(reminder["cron"]),
)
def check_is_outdated(self, reminder: dict):
@@ -101,10 +102,10 @@ class Main(star.Star):
async def reminder_tool(
self,
event: AstrMessageEvent,
- text: str = None,
- datetime_str: str = None,
- cron_expression: str = None,
- human_readable_cron: str = None,
+ text: str | None = None,
+ datetime_str: str | None = None,
+ cron_expression: str | None = None,
+ human_readable_cron: str | None = None,
):
"""Call this function when user is asking for setting a reminder.
@@ -139,17 +140,19 @@ class Main(star.Star):
"id": str(uuid.uuid4()),
}
self.reminder_data[event.unified_msg_origin].append(d)
+ trigger = CronTrigger(**self._parse_cron_expr(cron_expression))
self.scheduler.add_job(
self._reminder_callback,
- "cron",
+ trigger,
id=d["id"],
misfire_grace_time=60,
- **self._parse_cron_expr(cron_expression),
args=[event.unified_msg_origin, d],
)
if human_readable_cron:
reminder_time = f"{human_readable_cron}(Cron: {cron_expression})"
else:
+ if datetime_str is None:
+ raise ValueError("datetime_str cannot be None.")
d = {"text": text, "datetime": datetime_str, "id": str(uuid.uuid4())}
self.reminder_data[event.unified_msg_origin].append(d)
datetime_scheduled = datetime.datetime.strptime(
@@ -176,7 +179,7 @@ class Main(star.Star):
@filter.command_group("reminder")
def reminder(self):
- """The command group of the reminder."""
+ """待办提醒"""
async def get_upcoming_reminders(self, unified_msg_origin: str):
"""Get upcoming reminders."""
diff --git a/packages/session_controller/main.py b/packages/session_controller/main.py
index 4d4a42528..9ea62ea30 100644
--- a/packages/session_controller/main.py
+++ b/packages/session_controller/main.py
@@ -14,7 +14,7 @@ from astrbot.core.utils.session_waiter import (
)
-class Waiter(Star):
+class Main(Star):
"""会话控制"""
def __init__(self, context: Context):
diff --git a/packages/thinking_filter/main.py b/packages/thinking_filter/main.py
deleted file mode 100644
index a3bc65d20..000000000
--- a/packages/thinking_filter/main.py
+++ /dev/null
@@ -1,208 +0,0 @@
-import json
-import logging
-import re
-from typing import Any
-
-from openai.types.chat.chat_completion import ChatCompletion
-
-from astrbot.api.event import AstrMessageEvent, filter
-from astrbot.api.provider import LLMResponse
-from astrbot.api.star import Context, Star
-
-try:
- # 谨慎引入,避免在未安装 google-genai 的环境下报错
- from google.genai.types import GenerateContentResponse
-except Exception: # pragma: no cover - 兼容无此依赖的运行环境
- GenerateContentResponse = None # type: ignore
-
-
-class R1Filter(Star):
- def __init__(self, context: Context):
- super().__init__(context)
-
- @filter.on_llm_response()
- async def resp(self, event: AstrMessageEvent, response: LLMResponse):
- cfg = self.context.get_config(umo=event.unified_msg_origin).get(
- "provider_settings",
- {},
- )
- show_reasoning = cfg.get("display_reasoning_text", False)
-
- # --- Gemini: 过滤/展示 thought:true 片段 ---
- # Gemini 可能在 parts 中注入 {"thought": true, "text": "..."}
- # 官方 SDK 默认不会返回此字段。
- if GenerateContentResponse is not None and isinstance(
- response.raw_completion,
- GenerateContentResponse,
- ):
- thought_text, answer_text = self._extract_gemini_texts(
- response.raw_completion,
- )
-
- if thought_text or answer_text:
- # 有明确的思考/正文分离信号,则按配置处理
- if show_reasoning:
- merged = (
- (f"🤔思考:{thought_text}\n\n" if thought_text else "")
- + (answer_text or "")
- ).strip()
- if merged:
- response.completion_text = merged
- return
- # 默认隐藏思考内容,仅保留正文
- elif answer_text:
- response.completion_text = answer_text
- return
-
- # --- 非 Gemini 或无明确 thought:true 情况 ---
- if show_reasoning:
- # 显示推理内容的处理逻辑
- if (
- response
- and response.raw_completion
- and isinstance(response.raw_completion, ChatCompletion)
- and len(response.raw_completion.choices) > 0
- and response.raw_completion.choices[0].message
- ):
- message = response.raw_completion.choices[0].message
- reasoning_content = "" # 初始化 reasoning_content
-
- # 检查 Groq deepseek-r1-distill-llama-70b 模型的 'reasoning' 属性
- if hasattr(message, "reasoning") and message.reasoning:
- reasoning_content = message.reasoning
- # 检查 DeepSeek deepseek-reasoner 模型的 'reasoning_content'
- elif (
- hasattr(message, "reasoning_content") and message.reasoning_content
- ):
- reasoning_content = message.reasoning_content
-
- if reasoning_content:
- response.completion_text = (
- f"🤔思考:{reasoning_content}\n\n{message.content}"
- )
- else:
- response.completion_text = message.content
- else:
- # 过滤推理标签的处理逻辑
- completion_text = response.completion_text
-
- # 检查并移除 标签
- if r"" in completion_text or r" " in completion_text:
- # 移除配对的标签及其内容
- completion_text = re.sub(
- r".*? ",
- "",
- completion_text,
- flags=re.DOTALL,
- ).strip()
-
- # 移除可能残留的单个标签
- completion_text = (
- completion_text.replace(r"", "")
- .replace(r" ", "")
- .strip()
- )
-
- response.completion_text = completion_text
-
- # ------------------------
- # helpers
- # ------------------------
- def _get_part_dict(self, p: Any) -> dict:
- """优先使用 SDK 标准序列化方法获取字典,失败则逐级回退。
-
- 顺序: model_dump → model_dump_json → json → to_dict → dict → __dict__。
- """
- for getter in ("model_dump", "model_dump_json", "json", "to_dict", "dict"):
- fn = getattr(p, getter, None)
- if callable(fn):
- try:
- result = fn()
- if isinstance(result, (str, bytes)):
- try:
- if isinstance(result, bytes):
- result = result.decode("utf-8", "ignore")
- return json.loads(result) or {}
- except json.JSONDecodeError:
- continue
- if isinstance(result, dict):
- return result
- except (AttributeError, TypeError):
- continue
- except Exception as e:
- logging.exception(
- f"Unexpected error when calling {getter} on {type(p).__name__}: {e}",
- )
- continue
- try:
- d = getattr(p, "__dict__", None)
- if isinstance(d, dict):
- return d
- except (AttributeError, TypeError):
- pass
- except Exception as e:
- logging.exception(
- f"Unexpected error when accessing __dict__ on {type(p).__name__}: {e}",
- )
- return {}
-
- def _is_thought_part(self, p: Any) -> bool:
- """判断是否为思考片段。
-
- 规则:
- 1) 直接 thought 属性
- 2) 字典字段 thought 或 metadata.thought
- 3) data/raw/extra/_raw 中嵌入的 JSON 串包含 thought: true
- """
- try:
- if getattr(p, "thought", False):
- return True
- except Exception:
- # best-effort
- pass
-
- d = self._get_part_dict(p)
- if d.get("thought") is True:
- return True
- meta = d.get("metadata")
- if isinstance(meta, dict) and meta.get("thought") is True:
- return True
- for k in ("data", "raw", "extra", "_raw"):
- v = d.get(k)
- if isinstance(v, (str, bytes)):
- try:
- if isinstance(v, bytes):
- v = v.decode("utf-8", "ignore")
- parsed = json.loads(v)
- if isinstance(parsed, dict) and parsed.get("thought") is True:
- return True
- except json.JSONDecodeError:
- continue
- return False
-
- def _extract_gemini_texts(self, resp: Any) -> tuple[str, str]:
- """从 GenerateContentResponse 中提取 (思考文本, 正文文本)。"""
- try:
- cand0 = next(iter(getattr(resp, "candidates", []) or []), None)
- if not cand0:
- return "", ""
- content = getattr(cand0, "content", None)
- parts = getattr(content, "parts", None) or []
- except (AttributeError, TypeError, ValueError):
- return "", ""
-
- thought_buf: list[str] = []
- answer_buf: list[str] = []
- for p in parts:
- txt = getattr(p, "text", None)
- if txt is None:
- continue
- txt_str = str(txt).strip()
- if not txt_str:
- continue
- if self._is_thought_part(p):
- thought_buf.append(txt_str)
- else:
- answer_buf.append(txt_str)
-
- return "\n".join(thought_buf).strip(), "\n".join(answer_buf).strip()
diff --git a/packages/thinking_filter/metadata.yaml b/packages/thinking_filter/metadata.yaml
deleted file mode 100644
index 8afbff1e0..000000000
--- a/packages/thinking_filter/metadata.yaml
+++ /dev/null
@@ -1,5 +0,0 @@
-name: thinking_filter
-desc: 可选择是否过滤推理模型的思考内容
-author: Soulter
-version: 1.0.0
-repo: https://astrbot.app
\ No newline at end of file
diff --git a/packages/web_searcher/engines/__init__.py b/packages/web_searcher/engines/__init__.py
index 706cfa87b..699438602 100644
--- a/packages/web_searcher/engines/__init__.py
+++ b/packages/web_searcher/engines/__init__.py
@@ -3,7 +3,7 @@ import urllib.parse
from dataclasses import dataclass
from aiohttp import ClientSession
-from bs4 import BeautifulSoup
+from bs4 import BeautifulSoup, Tag
HEADERS = {
"User-Agent": "Mozilla/5.0 (Windows NT 6.1; rv:84.0) Gecko/20100101 Firefox/84.0",
@@ -45,13 +45,13 @@ class SearchEngine:
self.page = 1
self.headers = HEADERS
- def _set_selector(self, selector: str) -> None:
+ def _set_selector(self, selector: str) -> str:
raise NotImplementedError
- def _get_next_page(self):
+ def _get_next_page(self, query: str):
raise NotImplementedError
- async def _get_html(self, url: str, data: dict = None) -> str:
+ async def _get_html(self, url: str, data: dict | None = None) -> str:
headers = self.headers
headers["Referer"] = url
headers["User-Agent"] = random.choice(USER_AGENTS)
@@ -83,6 +83,9 @@ class SearchEngine:
"""清理文本,去除空格、换行符等"""
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
+ def _get_url(self, tag: Tag) -> str:
+ return self.tidy_text(tag.get_text())
+
async def search(self, query: str, num_results: int) -> list[SearchResult]:
query = urllib.parse.quote(query)
@@ -92,12 +95,16 @@ class SearchEngine:
links = soup.select(self._set_selector("links"))
results = []
for link in links:
- title = self.tidy_text(
- link.select_one(self._set_selector("title")).text,
- )
- url = link.select_one(self._set_selector("url"))
+ # Safely get the title text (select_one may return None)
+ title_elem = link.select_one(self._set_selector("title"))
+ title = ""
+ if title_elem is not None:
+ title = self.tidy_text(title_elem.get_text())
+
+ url_tag = link.select_one(self._set_selector("url"))
snippet = ""
- if title and url:
+ if title and url_tag:
+ url = self._get_url(url_tag)
results.append(SearchResult(title=title, url=url, snippet=snippet))
return results[:num_results] if len(results) > num_results else results
except Exception as e:
diff --git a/packages/web_searcher/engines/bing.py b/packages/web_searcher/engines/bing.py
index 4c2ec319d..7565e5df3 100644
--- a/packages/web_searcher/engines/bing.py
+++ b/packages/web_searcher/engines/bing.py
@@ -1,4 +1,4 @@
-from . import USER_AGENT_BING, SearchEngine, SearchResult
+from . import USER_AGENT_BING, SearchEngine
class Bing(SearchEngine):
@@ -28,11 +28,3 @@ class Bing(SearchEngine):
self.base_url = base_url
continue
raise Exception("Bing search failed")
-
- async def search(self, query: str, num_results: int) -> list[SearchResult]:
- results = await super().search(query, num_results)
- for result in results:
- if not isinstance(result.url, str):
- result.url = result.url.text
-
- return results
diff --git a/packages/web_searcher/engines/sogo.py b/packages/web_searcher/engines/sogo.py
index 382e7c937..f490f1106 100644
--- a/packages/web_searcher/engines/sogo.py
+++ b/packages/web_searcher/engines/sogo.py
@@ -1,7 +1,8 @@
import random
import re
+from typing import cast
-from bs4 import BeautifulSoup
+from bs4 import BeautifulSoup, Tag
from . import USER_AGENTS, SearchEngine, SearchResult
@@ -26,10 +27,12 @@ class Sogo(SearchEngine):
url = f"{self.base_url}/web?query={query}"
return await self._get_html(url, None)
+ def _get_url(self, tag: Tag) -> str:
+ return cast(str, tag.get("href"))
+
async def search(self, query: str, num_results: int) -> list[SearchResult]:
results = await super().search(query, num_results)
for result in results:
- result.url = result.url.get("href")
if result.url.startswith("/link?"):
result.url = self.base_url + result.url
result.url = await self._parse_url(result.url)
@@ -40,7 +43,10 @@ class Sogo(SearchEngine):
soup = BeautifulSoup(html, "html.parser")
script = soup.find("script")
if script:
- url = re.search(r'window.location.replace\("(.+?)"\)', script.string).group(
- 1,
+ script_text = (
+ script.string if script.string is not None else script.get_text()
)
+ match = re.search(r'window.location.replace\("(.+?)"\)', script_text)
+ if match:
+ url = match.group(1)
return url
diff --git a/packages/web_searcher/main.py b/packages/web_searcher/main.py
index 118ef2483..4745cd0c0 100644
--- a/packages/web_searcher/main.py
+++ b/packages/web_searcher/main.py
@@ -185,6 +185,7 @@ class Main(star.Star):
@filter.command("websearch")
async def websearch(self, event: AstrMessageEvent, oper: str | None = None):
+ """网页搜索指令(已废弃)"""
event.set_result(
MessageEventResult().message(
"此指令已经被废弃,请在 WebUI 中开启或关闭网页搜索功能。",
diff --git a/pyproject.toml b/pyproject.toml
index 861a799a8..f56b101ef 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "AstrBot"
-version = "4.5.6"
+version = "4.9.2"
description = "Easy-to-use multi-platform LLM chatbot and development framework"
readme = "README.md"
requires-python = ">=3.10"
@@ -26,7 +26,7 @@ dependencies = [
"docstring-parser>=0.16",
"faiss-cpu==1.10.0",
"filelock>=3.18.0",
- "google-genai>=1.14.0",
+ "google-genai>=1.56.0",
"lark-oapi>=1.4.15",
"lxml-html-clean>=0.4.2",
"mcp>=1.8.0",
@@ -59,6 +59,7 @@ dependencies = [
"jieba>=0.42.1",
"markitdown-no-magika[docx,xls,xlsx]>=0.1.2",
"xinference-client",
+ "tenacity>=9.1.2",
]
[dependency-groups]
@@ -107,4 +108,4 @@ exclude = ["dashboard", "node_modules", "dist", "data", "tests"]
[build-system]
requires = ["hatchling"]
-build-backend = "hatchling.build"
\ No newline at end of file
+build-backend = "hatchling.build"
diff --git a/requirements.txt b/requirements.txt
index e8b3dee3c..5b70f33ff 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -19,7 +19,7 @@ dingtalk-stream>=0.22.1
docstring-parser>=0.16
faiss-cpu==1.10.0
filelock>=3.18.0
-google-genai>=1.14.0
+google-genai>=1.56.0
lark-oapi>=1.4.15
lxml-html-clean>=0.4.2
mcp>=1.8.0
@@ -52,3 +52,4 @@ rank-bm25>=0.2.2
jieba>=0.42.1
markitdown-no-magika[docx,xls,xlsx]>=0.1.2
xinference-client
+tenacity>=9.1.2
\ No newline at end of file
diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py
index a2710c841..969f0da6d 100644
--- a/tests/test_dashboard.py
+++ b/tests/test_dashboard.py
@@ -21,7 +21,17 @@ async def core_lifecycle_td(tmp_path_factory):
log_broker = LogBroker()
core_lifecycle = AstrBotCoreLifecycle(log_broker, db)
await core_lifecycle.initialize()
- return core_lifecycle
+ try:
+ yield core_lifecycle
+ finally:
+ # 优先停止核心生命周期以释放资源(包括关闭 MCP 等后台任务)
+ try:
+ _stop_res = core_lifecycle.stop()
+ if asyncio.iscoroutine(_stop_res):
+ await _stop_res
+ except Exception:
+ # 停止过程中如有异常,不影响后续清理
+ pass
@pytest.fixture(scope="module")
@@ -150,6 +160,34 @@ async def test_plugins(app: Quart, authenticated_header: dict):
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
+@pytest.mark.asyncio
+async def test_commands_api(app: Quart, authenticated_header: dict):
+ """Tests the command management API endpoints."""
+ test_client = app.test_client()
+
+ # GET /api/commands - list commands
+ response = await test_client.get("/api/commands", headers=authenticated_header)
+ assert response.status_code == 200
+ data = await response.get_json()
+ assert data["status"] == "ok"
+ assert "items" in data["data"]
+ assert "summary" in data["data"]
+ summary = data["data"]["summary"]
+ assert "total" in summary
+ assert "disabled" in summary
+ assert "conflicts" in summary
+
+ # GET /api/commands/conflicts - list conflicts
+ response = await test_client.get(
+ "/api/commands/conflicts", headers=authenticated_header
+ )
+ assert response.status_code == 200
+ data = await response.get_json()
+ assert data["status"] == "ok"
+ # conflicts is a list
+ assert isinstance(data["data"], list)
+
+
@pytest.mark.asyncio
async def test_check_update(app: Quart, authenticated_header: dict):
test_client = app.test_client()
diff --git a/tests/test_kb_import.py b/tests/test_kb_import.py
new file mode 100644
index 000000000..8ad40f540
--- /dev/null
+++ b/tests/test_kb_import.py
@@ -0,0 +1,209 @@
+import asyncio
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+import pytest_asyncio
+from quart import Quart
+
+from astrbot.core import LogBroker
+from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
+from astrbot.core.db.sqlite import SQLiteDatabase
+from astrbot.core.knowledge_base.kb_helper import KBHelper
+from astrbot.core.knowledge_base.models import KBDocument
+from astrbot.dashboard.server import AstrBotDashboard
+
+
+@pytest_asyncio.fixture(scope="module")
+async def core_lifecycle_td(tmp_path_factory):
+ """Creates and initializes a core lifecycle instance with a temporary database."""
+ tmp_db_path = tmp_path_factory.mktemp("data") / "test_data_kb.db"
+ db = SQLiteDatabase(str(tmp_db_path))
+ log_broker = LogBroker()
+ core_lifecycle = AstrBotCoreLifecycle(log_broker, db)
+ await core_lifecycle.initialize()
+
+ # Mock kb_manager and kb_helper
+ kb_manager = MagicMock()
+ kb_helper = AsyncMock(spec=KBHelper)
+
+ # Configure get_kb to be an async mock that returns kb_helper
+ kb_manager.get_kb = AsyncMock(return_value=kb_helper)
+
+ # Mock upload_document return value
+ mock_doc = KBDocument(
+ doc_id="test_doc_id",
+ kb_id="test_kb_id",
+ doc_name="test_file.txt",
+ file_type="txt",
+ file_size=100,
+ file_path="",
+ chunk_count=2,
+ media_count=0,
+ )
+ kb_helper.upload_document.return_value = mock_doc
+
+ # kb_manager.get_kb.return_value = kb_helper # Removed this line as it's handled above
+ core_lifecycle.kb_manager = kb_manager
+
+ try:
+ yield core_lifecycle
+ finally:
+ try:
+ _stop_res = core_lifecycle.stop()
+ if asyncio.iscoroutine(_stop_res):
+ await _stop_res
+ except Exception:
+ pass
+
+
+@pytest.fixture(scope="module")
+def app(core_lifecycle_td: AstrBotCoreLifecycle):
+ """Creates a Quart app instance for testing."""
+ shutdown_event = asyncio.Event()
+ server = AstrBotDashboard(core_lifecycle_td, core_lifecycle_td.db, shutdown_event)
+ return server.app
+
+
+@pytest_asyncio.fixture(scope="module")
+async def authenticated_header(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle):
+ """Handles login and returns an authenticated header."""
+ test_client = app.test_client()
+ response = await test_client.post(
+ "/api/auth/login",
+ json={
+ "username": core_lifecycle_td.astrbot_config["dashboard"]["username"],
+ "password": core_lifecycle_td.astrbot_config["dashboard"]["password"],
+ },
+ )
+ data = await response.get_json()
+ assert data["status"] == "ok"
+ token = data["data"]["token"]
+ return {"Authorization": f"Bearer {token}"}
+
+
+@pytest.mark.asyncio
+async def test_import_documents(
+ app: Quart, authenticated_header: dict, core_lifecycle_td: AstrBotCoreLifecycle
+):
+ """Tests the import documents functionality."""
+ test_client = app.test_client()
+
+ # Test data
+ import_data = {
+ "kb_id": "test_kb_id",
+ "documents": [
+ {"file_name": "test_file_1.txt", "chunks": ["chunk1", "chunk2"]},
+ {"file_name": "test_file_2.md", "chunks": ["chunk3", "chunk4", "chunk5"]},
+ ],
+ }
+
+ # Send request
+ response = await test_client.post(
+ "/api/kb/document/import", json=import_data, headers=authenticated_header
+ )
+
+ # Verify response
+ assert response.status_code == 200
+ data = await response.get_json()
+ assert data["status"] == "ok"
+ assert "task_id" in data["data"]
+ assert data["data"]["doc_count"] == 2
+
+ task_id = data["data"]["task_id"]
+
+ # Wait for background task to complete (mocked)
+ # Since we mocked upload_document, it should be fast, but we might need to poll progress
+ for _ in range(10):
+ progress_response = await test_client.get(
+ f"/api/kb/document/upload/progress?task_id={task_id}",
+ headers=authenticated_header,
+ )
+ progress_data = await progress_response.get_json()
+ if progress_data["data"]["status"] == "completed":
+ break
+ await asyncio.sleep(0.1)
+
+ assert progress_data["data"]["status"] == "completed"
+ result = progress_data["data"]["result"]
+ assert result["success_count"] == 2
+ assert result["failed_count"] == 0
+
+ # Verify kb_helper.upload_document was called correctly
+ kb_helper = await core_lifecycle_td.kb_manager.get_kb("test_kb_id")
+ assert kb_helper.upload_document.call_count == 2
+
+ # Check first call arguments
+ call_args_list = kb_helper.upload_document.call_args_list
+
+ # First document
+ args1, kwargs1 = call_args_list[0]
+ assert kwargs1["file_name"] == "test_file_1.txt"
+ assert kwargs1["pre_chunked_text"] == ["chunk1", "chunk2"]
+
+ # Second document
+ args2, kwargs2 = call_args_list[1]
+ assert kwargs2["file_name"] == "test_file_2.md"
+ assert kwargs2["pre_chunked_text"] == ["chunk3", "chunk4", "chunk5"]
+
+
+@pytest.mark.asyncio
+async def test_import_documents_invalid_input(app: Quart, authenticated_header: dict):
+ """Tests import documents with invalid input."""
+ test_client = app.test_client()
+
+ # Missing kb_id
+ response = await test_client.post(
+ "/api/kb/document/import", json={"documents": []}, headers=authenticated_header
+ )
+ data = await response.get_json()
+ assert data["status"] == "error"
+ assert "缺少参数 kb_id" in data["message"]
+
+ # Missing documents
+ response = await test_client.post(
+ "/api/kb/document/import",
+ json={"kb_id": "test_kb"},
+ headers=authenticated_header,
+ )
+ data = await response.get_json()
+ assert data["status"] == "error"
+ assert "缺少参数 documents" in data["message"]
+
+ # Invalid document format
+ response = await test_client.post(
+ "/api/kb/document/import",
+ json={
+ "kb_id": "test_kb",
+ "documents": [{"file_name": "test"}], # Missing chunks
+ },
+ headers=authenticated_header,
+ )
+ data = await response.get_json()
+ assert data["status"] == "error"
+ assert "文档格式错误" in data["message"]
+
+ # Invalid chunks type
+ response = await test_client.post(
+ "/api/kb/document/import",
+ json={
+ "kb_id": "test_kb",
+ "documents": [{"file_name": "test", "chunks": "not-a-list"}],
+ },
+ headers=authenticated_header,
+ )
+ data = await response.get_json()
+ assert data["status"] == "error"
+ assert "chunks 必须是列表" in data["message"]
+
+ # Invalid chunks content
+ response = await test_client.post(
+ "/api/kb/document/import",
+ json={
+ "kb_id": "test_kb",
+ "documents": [{"file_name": "test", "chunks": ["valid", ""]}],
+ },
+ headers=authenticated_header,
+ )
+ data = await response.get_json()
+ assert data["status"] == "error"
+ assert "chunks 必须是非空字符串列表" in data["message"]
diff --git a/tests/test_plugin_manager.py b/tests/test_plugin_manager.py
index 277f8fa4d..1e4cd866a 100644
--- a/tests/test_plugin_manager.py
+++ b/tests/test_plugin_manager.py
@@ -39,6 +39,7 @@ def plugin_manager_pm(tmp_path):
message_history_manager = MagicMock()
persona_manager = MagicMock()
astrbot_config_mgr = MagicMock()
+ knowledge_base_manager = MagicMock()
star_context = Context(
event_queue,
@@ -50,6 +51,7 @@ def plugin_manager_pm(tmp_path):
message_history_manager,
persona_manager,
astrbot_config_mgr,
+ knowledge_base_manager=knowledge_base_manager,
)
# Create the PluginManager instance
diff --git a/typings/faiss/__init__.pyi b/typings/faiss/__init__.pyi
new file mode 100644
index 000000000..6f2bace36
--- /dev/null
+++ b/typings/faiss/__init__.pyi
@@ -0,0 +1,90 @@
+"""Minimal type stubs for faiss used in this project.
+
+This file only exposes a small subset of the faiss API that the
+project uses, including the runtime-monkeypatched signatures such as
+`Index.add_with_ids` so Pyright/Pylance stops reporting false positives.
+"""
+
+from typing import Any, overload
+
+import numpy as np
+
+class Index:
+ d: int
+ ntotal: int
+ code_size: int
+ nprobe: int
+
+ def add(self, x: np.ndarray) -> None: ...
+ def add_with_ids(self, x: np.ndarray, ids: np.ndarray) -> None: ...
+ def search(
+ self,
+ x: np.ndarray,
+ k: int,
+ *,
+ params: Any = ...,
+ D: np.ndarray | None = ...,
+ I: np.ndarray | None = ...,
+ ) -> tuple[np.ndarray, np.ndarray]: ...
+ def remove_ids(self, x: np.ndarray) -> int: ...
+ @overload
+ def reconstruct(self, key: int) -> np.ndarray: ...
+ @overload
+ def reconstruct(self, key: int, x: np.ndarray) -> None: ...
+ def reconstruct(
+ self, key: int, x: np.ndarray | None = ...
+ ) -> np.ndarray | None: ...
+ @overload
+ def reconstruct_n(self, n0: int, ni: int) -> np.ndarray: ...
+ @overload
+ def reconstruct_n(self, n0: int, ni: int, x: np.ndarray) -> None: ...
+ def reconstruct_n(
+ self, n0: int = ..., ni: int = ..., x: np.ndarray | None = ...
+ ) -> np.ndarray | None: ...
+ def range_search(
+ self, x: np.ndarray, thresh: float, *, params: Any = ...
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: ...
+ def add_sa_codes(self, codes: np.ndarray, ids: np.ndarray | None = ...) -> None: ...
+ def sa_encode(self, x: np.ndarray) -> np.ndarray: ...
+ def sa_decode(self, codes: np.ndarray) -> np.ndarray: ...
+
+class IndexFlatL2(Index):
+ def __init__(self, d: int) -> None: ...
+
+class IndexIDMap(Index):
+ index: Index
+
+ def __init__(self, index: Index) -> None: ...
+
+def read_index(path: str) -> Index: ...
+def write_index(index: Index, path: str | None = ...) -> None: ...
+def normalize_L2(x: np.ndarray) -> None: ...
+
+# Additional concrete-ish classes exposed by some faiss builds (SWIG helpers
+# expose `downcast_*` helpers to convert generic objects to these concrete
+# types). We keep these minimal — only the names are important for typing.
+class IndexBinary(Index):
+ def __init__(self, d: int) -> None: ...
+
+class InvertedLists:
+ def __len__(self) -> int: ...
+
+class AdditiveQuantizer:
+ pass
+
+class Quantizer:
+ pass
+
+class VectorTransform:
+ pass
+
+# SWIG-provided downcast helpers (present in some faiss Python builds).
+def downcast_IndexBinary(obj: Any) -> IndexBinary: ...
+def downcast_InvertedLists(obj: Any) -> InvertedLists: ...
+def downcast_AdditiveQuantizer(obj: Any) -> AdditiveQuantizer: ...
+def downcast_Quantizer(obj: Any) -> Quantizer: ...
+def downcast_VectorTransform(obj: Any) -> VectorTransform: ...
+def downcast_index(obj: Any) -> Index: ...
+
+# version exposed by runtime
+__version__: str