Compare commits
158 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 82a96a8cce | |||
| 343b153263 | |||
| 3a41b19318 | |||
| af444ea6cc | |||
| cb84db532e | |||
| 99b82f48ec | |||
| 00471f904e | |||
| 5df15c60ff | |||
| 32e523b7da | |||
| 0de4fd9f0d | |||
| e23a7e2505 | |||
| 1ed4d9f484 | |||
| d842155770 | |||
| 7f5cc7cf1a | |||
| f26867c77d | |||
| a14d588b44 | |||
| e236402d92 | |||
| 454841de10 | |||
| 442b5403df | |||
| 9db7bf59b8 | |||
| 3622504021 | |||
| fc42db40ce | |||
| e413a002c1 | |||
| 6437d759a3 | |||
| c758b2d888 | |||
| 510290fe0e | |||
| c61d62edb6 | |||
| 45bce6fe76 | |||
| f156adddf8 | |||
| b5a4b80c36 | |||
| 792fb69d6d | |||
| 300a73ace0 | |||
| a5b9de3695 | |||
| 90142bcafe | |||
| 79d0487c03 | |||
| 4f15102e79 | |||
| ef1feb639c | |||
| 1039a4f864 | |||
| 66e2f49c11 | |||
| c5773fe63e | |||
| 4e9ef48af2 | |||
| 9eafd7b44a | |||
| fc61f7ad32 | |||
| f51810997a | |||
| fb4baf676f | |||
| 71ad974c3c | |||
| f0fff68947 | |||
| 3e3599835e | |||
| 5255388e2d | |||
| fbdd60b64c | |||
| bd1b0a2836 | |||
| 19541d9d07 | |||
| 2a5d574394 | |||
| f2924fbd1b | |||
| 703e208947 | |||
| 9a5cc977c2 | |||
| aa38fe776a | |||
| 701399c00c | |||
| eaee98d4b8 | |||
| 76c66000a7 | |||
| 4b365143c0 | |||
| 6e4e5011e2 | |||
| d853bfde84 | |||
| a0e856f80f | |||
| 8c94a0010c | |||
| a44fdaaec0 | |||
| 60105c76f5 | |||
| bcf87d3ce4 | |||
| 4d7c8c8453 | |||
| a064a9115f | |||
| 6ef99e1553 | |||
| c0dbe5cf65 | |||
| 3598c51eff | |||
| b5cdb8f650 | |||
| fc5b520f9b | |||
| 904f56b32f | |||
| 2f15fd019c | |||
| 82330b8d10 | |||
| 3ee6af7027 | |||
| 6e20ebe901 | |||
| 4d6150fd6d | |||
| 544e52191b | |||
| f2c2a6da4a | |||
| dd3df425ee | |||
| 40b4a27a3d | |||
| 9d991c7468 | |||
| ad6a8b5c94 | |||
| 1b4bfcbd72 | |||
| 9d3cc593a1 | |||
| f0dee35ba9 | |||
| 4135bd84d5 | |||
| f6da614e5d | |||
| 5f531c9be5 | |||
| 94591d965b | |||
| 8a0f865af1 | |||
| 4aced976a8 | |||
| 0299aa6e4c | |||
| e8b54a019e | |||
| 98ce796275 | |||
| b87dcf2275 | |||
| 591a228431 | |||
| f52f375154 | |||
| 975c685a17 | |||
| 6db80d36a8 | |||
| 4651bd2807 | |||
| 94ada3793e | |||
| fd05b0bf09 | |||
| 4d046f8490 | |||
| 58e32b7b70 | |||
| 903dd0f9f7 | |||
| 1acac0cac2 | |||
| 80b89fd2ea | |||
| 26f863ba81 | |||
| f78a90218e | |||
| a3ecebd2aa | |||
| 67c33b842d | |||
| 5431c9f46e | |||
| 764b91a5f7 | |||
| c20c1b84bf | |||
| fd66a0ac00 | |||
| aaee283367 | |||
| 4a5b7d1976 | |||
| 08244548ab | |||
| b486de6a98 | |||
| e2f928a7e5 | |||
| b8e4068c75 | |||
| 0916177a57 | |||
| 02cd5e396b | |||
| 56673ad78f | |||
| 9a4d05e2b6 | |||
| b2e9dab233 | |||
| 45110200ea | |||
| c3f45449e8 | |||
| 65da469deb | |||
| 16df64c405 | |||
| 6b73b19e54 | |||
| a70088b799 | |||
| e7e97730af | |||
| 467ca1eb5c | |||
| bb45d9cb54 | |||
| 46528391c2 | |||
| 8a0b7717cc | |||
| 3b81fb4985 | |||
| c09d57a820 | |||
| ec408a2aff | |||
| 417179a6b9 | |||
| fcd29445c7 | |||
| 5f535001db | |||
| 750d245b16 | |||
| f624971613 | |||
| aa6d07afcc | |||
| 2c36649874 | |||
| c95735dcc0 | |||
| 03bb278f50 | |||
| a5e0974da3 | |||
| f0fb447fbc | |||
| 37566182b0 | |||
| e460b411da |
@@ -15,7 +15,6 @@ Always reference these instructions first and fallback to search or bash command
|
||||
### Running the Application
|
||||
- Run main application: `uv run main.py` -- starts in ~3 seconds
|
||||
- Application creates WebUI on http://localhost:6185 (default credentials: `astrbot`/`astrbot`)
|
||||
- Application loads plugins automatically from `packages/` and `data/plugins/` directories
|
||||
|
||||
### Dashboard Build (Vue.js/Node.js)
|
||||
- **Prerequisites**: Node.js 20+ and npm 10+ required
|
||||
@@ -35,7 +34,7 @@ Always reference these instructions first and fallback to search or bash command
|
||||
- **ALWAYS** run `uv run ruff check .` and `uv run ruff format .` before committing changes
|
||||
|
||||
### Plugin Development
|
||||
- Plugins load from `packages/` (built-in) and `data/plugins/` (user-installed)
|
||||
- Plugins load from `astrbot/builtin_stars/` (built-in) and `data/plugins/` (user-installed)
|
||||
- Plugin system supports function tools and message handlers
|
||||
- Key plugins: python_interpreter, web_searcher, astrbot, reminder, session_controller
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ jobs:
|
||||
zip -r dist.zip dist
|
||||
|
||||
- name: Archive production artifacts
|
||||
uses: actions/upload-artifact@v5
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: dist-without-markdown
|
||||
path: |
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
name: Smoke Test
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
paths-ignore:
|
||||
- 'README*.md'
|
||||
- 'changelogs/**'
|
||||
- 'dashboard/**'
|
||||
pull_request:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
smoke-test:
|
||||
name: Run smoke tests
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install UV package manager
|
||||
run: |
|
||||
pip install uv
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv sync
|
||||
timeout-minutes: 15
|
||||
|
||||
- name: Run smoke tests
|
||||
run: |
|
||||
uv run main.py &
|
||||
APP_PID=$!
|
||||
|
||||
echo "Waiting for application to start..."
|
||||
for i in {1..60}; do
|
||||
if curl -f http://localhost:6185 > /dev/null 2>&1; then
|
||||
echo "Application started successfully!"
|
||||
kill $APP_PID
|
||||
exit 0
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
|
||||
echo "Application failed to start within 30 seconds"
|
||||
kill $APP_PID 2>/dev/null || true
|
||||
exit 1
|
||||
timeout-minutes: 2
|
||||
+52
-15
@@ -1,27 +1,64 @@
|
||||
# This workflow warns and then closes issues and PRs that have had no activity for a specified amount of time.
|
||||
# 本工作流用于标记并关闭长期不活跃的 Issue。
|
||||
# 目前仅针对带 `bug` 标签的 Issue 生效,不会处理 PR。
|
||||
#
|
||||
# You can adjust the behavior by modifying this file.
|
||||
# For more information, see:
|
||||
# https://github.com/actions/stale
|
||||
name: Mark stale issues and pull requests
|
||||
# 文档: https://github.com/actions/stale
|
||||
name: Mark stale bug issues
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: '21 23 * * *'
|
||||
# 每天 UTC 08:30 执行 (北京时间 16:30)
|
||||
- cron: '30 8 * * *'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
dry-run:
|
||||
description: '仅预览, 不实际执行 (Dry run mode)'
|
||||
required: false
|
||||
default: true
|
||||
type: boolean
|
||||
|
||||
jobs:
|
||||
stale:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
issues: write
|
||||
pull-requests: write
|
||||
|
||||
steps:
|
||||
- uses: actions/stale@v10
|
||||
with:
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
stale-issue-message: 'Stale issue message'
|
||||
stale-pr-message: 'Stale pull request message'
|
||||
stale-issue-label: 'no-issue-activity'
|
||||
stale-pr-label: 'no-pr-activity'
|
||||
- uses: actions/stale@v10
|
||||
with:
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
operations-per-run: 200
|
||||
|
||||
# 只处理带 bug 标签的 Issue
|
||||
any-of-labels: 'bug'
|
||||
|
||||
# 不处理 PR
|
||||
days-before-pr-stale: -1
|
||||
days-before-pr-close: -1
|
||||
|
||||
# 不活跃判定与关闭策略: 先标记 stale, 再延迟关闭
|
||||
days-before-issue-stale: 60
|
||||
days-before-issue-close: 30
|
||||
|
||||
stale-issue-label: 'stale'
|
||||
stale-issue-message: |
|
||||
This issue has been automatically marked as **stale** because it has not had any activity.
|
||||
It will be closed in a certain period of time if no further activity occurs.
|
||||
If this issue is still relevant, please leave a comment.
|
||||
|
||||
---
|
||||
|
||||
该 Issue 已较长时间无活动, 已被标记为 `stale`。
|
||||
如无后续活动, 将在一段时间后自动关闭。
|
||||
如仍需跟进, 请回复评论。
|
||||
close-issue-message: |
|
||||
This issue has been automatically closed due to inactivity.
|
||||
If the problem still exists, feel free to reopen or create a new issue with updated information.
|
||||
|
||||
---
|
||||
|
||||
该 Issue 因长期无活动已自动关闭。
|
||||
如问题仍存在, 欢迎补充复现信息并重新打开或新建 Issue。
|
||||
|
||||
remove-stale-when-updated: true
|
||||
|
||||
debug-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run }}
|
||||
|
||||
+2
-2
@@ -24,9 +24,9 @@ configs/session
|
||||
configs/config.yaml
|
||||
cmd_config.json
|
||||
|
||||
# Plugins and packages
|
||||
# Plugins
|
||||
addons/plugins
|
||||
packages/python_interpreter/workplace
|
||||
astrbot/builtin_stars/python_interpreter/workplace
|
||||
tests/astrbot_plugin_openai
|
||||
|
||||
# Dashboard
|
||||
|
||||
+26
-1
@@ -33,6 +33,20 @@
|
||||
- 请使用英文描述您的 PR。
|
||||
- 标题请使用 `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` 等语义化前缀,并简要描述更改内容。如:`fix: correct login page typo`。
|
||||
|
||||
#### 代码规范
|
||||
|
||||
##### Core
|
||||
|
||||
我们使用 Ruff 作为代码格式化和静态分析工具。在提交代码之前,请运行以下命令以确保代码符合规范:
|
||||
|
||||
```bash
|
||||
ruff format .
|
||||
ruff check .
|
||||
```
|
||||
|
||||
如果您使用 VSCode,可以安装 `Ruff` 插件。
|
||||
|
||||
|
||||
## Contributing Guide
|
||||
|
||||
First off, thanks for taking the time to contribute! ❤️
|
||||
@@ -62,4 +76,15 @@ We use the `fix/` prefix for bug fixes and the `feat/` prefix for new features.
|
||||
|
||||
#### PR Description
|
||||
- Please use English to describe your PR.
|
||||
- Use semantic prefixes like `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` in the title, followed by a brief description of the changes, e.g., `fix: correct login page typo`.
|
||||
- Use semantic prefixes like `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` in the title, followed by a brief description of the changes, e.g., `fix: correct login page typo`.
|
||||
|
||||
#### Code Style
|
||||
|
||||
##### Core
|
||||
|
||||
We use Ruff as our code formatter and static analysis tool. Before submitting your code, please run the following commands to ensure your code adheres to the style guidelines:
|
||||
|
||||
```bash
|
||||
ruff format .
|
||||
ruff check .
|
||||
```
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
|
||||
<img src="https://deepwiki.com/badge.svg" href="https://deepwiki.com/AstrBotDevs/AstrBot">
|
||||
<a href="https://zread.ai/AstrBotDevs/AstrBot" target="_blank"><img src="https://img.shields.io/badge/Ask_Zread-_.svg?style=flat&color=00b0aa&labelColor=000000&logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iMTYiIGhlaWdodD0iMTYiIHZpZXdCb3g9IjAgMCAxNiAxNiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTQuOTYxNTYgMS42MDAxSDIuMjQxNTZDMS44ODgxIDEuNjAwMSAxLjYwMTU2IDEuODg2NjQgMS42MDE1NiAyLjI0MDFWNC45NjAxQzEuNjAxNTYgNS4zMTM1NiAxLjg4ODEgNS42MDAxIDIuMjQxNTYgNS42MDAxSDQuOTYxNTZDNS4zMTUwMiA1LjYwMDEgNS42MDE1NiA1LjMxMzU2IDUuNjAxNTYgNC45NjAxVjIuMjQwMUM1LjYwMTU2IDEuODg2NjQgNS4zMTUwMiAxLjYwMDEgNC45NjE1NiAxLjYwMDFaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00Ljk2MTU2IDEwLjM5OTlIMi4yNDE1NkMxLjg4ODEgMTAuMzk5OSAxLjYwMTU2IDEwLjY4NjQgMS42MDE1NiAxMS4wMzk5VjEzLjc1OTlDMS42MDE1NiAxNC4xMTM0IDEuODg4MSAxNC4zOTk5IDIuMjQxNTYgMTQuMzk5OUg0Ljk2MTU2QzUuMzE1MDIgMTQuMzk5OSA1LjYwMTU2IDE0LjExMzQgNS42MDE1NiAxMy43NTk5VjExLjAzOTlDNS42MDE1NiAxMC42ODY0IDUuMzE1MDIgMTAuMzk5OSA0Ljk2MTU2IDEwLjM5OTlaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik0xMy43NTg0IDEuNjAwMUgxMS4wMzg0QzEwLjY4NSAxLjYwMDEgMTAuMzk4NCAxLjg4NjY0IDEwLjM5ODQgMi4yNDAxVjQuOTYwMUMxMC4zOTg0IDUuMzEzNTYgMTAuNjg1IDUuNjAwMSAxMS4wMzg0IDUuNjAwMUgxMy43NTg0QzE0LjExMTkgNS42MDAxIDE0LjM5ODQgNS4zMTM1NiAxNC4zOTg0IDQuOTYwMVYyLjI0MDFDMTQuMzk4NCAxLjg4NjY0IDE0LjExMTkgMS42MDAxIDEzLjc1ODQgMS42MDAxWiIgZmlsbD0iI2ZmZiIvPgo8cGF0aCBkPSJNNCAxMkwxMiA0TDQgMTJaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDQiIHN0cm9rZT0iI2ZmZiIgc3Ryb2tlLXdpZHRoPSIxLjUiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIvPgo8L3N2Zz4K&logoColor=ffffff" alt="zread"/></a>
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?color=76bad9"/></a>
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
|
||||
<img src="https://gitcode.com/Soulter/AstrBot/star/badge.svg" href="https://gitcode.com/Soulter/AstrBot">
|
||||
@@ -131,6 +132,7 @@ uv run main.py
|
||||
|
||||
**社区维护**
|
||||
|
||||
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
|
||||
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
|
||||
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
|
||||
- [Bilibili 私信](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
|
||||
@@ -206,6 +208,8 @@ pre-commit install
|
||||
- 3 群:630166526
|
||||
- 5 群:822130018
|
||||
- 6 群:753075035
|
||||
- 7 群:743746109
|
||||
- 8 群:1030353265
|
||||
- 开发者群:975206796
|
||||
|
||||
### Telegram 群组
|
||||
@@ -241,4 +245,10 @@ pre-commit install
|
||||
|
||||
</details>
|
||||
|
||||
<div align="center">
|
||||
|
||||
_私は、高性能ですから!_
|
||||
|
||||
<img src="https://files.astrbot.app/watashiwa-koseino-desukara.gif" width="100"/>
|
||||
</div
|
||||
|
||||
|
||||
@@ -134,6 +134,7 @@ Or refer to the official documentation: [Deploy AstrBot from Source](https://ast
|
||||
|
||||
**Community Maintained**
|
||||
|
||||
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
|
||||
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
|
||||
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
|
||||
- [Bilibili Direct Messages](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
|
||||
|
||||
@@ -134,6 +134,7 @@ Ou consultez la documentation officielle : [Déployer AstrBot depuis les sources
|
||||
|
||||
**Maintenues par la communauté**
|
||||
|
||||
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
|
||||
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
|
||||
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
|
||||
- [Messages directs Bilibili](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
|
||||
|
||||
@@ -134,6 +134,7 @@ uv run main.py
|
||||
|
||||
**コミュニティメンテナンス**
|
||||
|
||||
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
|
||||
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
|
||||
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
|
||||
- [Bilibili ダイレクトメッセージ](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
|
||||
|
||||
@@ -134,6 +134,7 @@ uv run main.py
|
||||
|
||||
**Поддерживаемые сообществом**
|
||||
|
||||
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
|
||||
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
|
||||
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
|
||||
- [Личные сообщения Bilibili](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
|
||||
|
||||
@@ -134,6 +134,7 @@ uv run main.py
|
||||
|
||||
**社群維護**
|
||||
|
||||
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
|
||||
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
|
||||
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
|
||||
- [Bilibili 私訊](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
|
||||
|
||||
@@ -21,6 +21,9 @@ from astrbot.core.star.register import (
|
||||
from astrbot.core.star.register import register_on_llm_request as on_llm_request
|
||||
from astrbot.core.star.register import register_on_llm_response as on_llm_response
|
||||
from astrbot.core.star.register import register_on_platform_loaded as on_platform_loaded
|
||||
from astrbot.core.star.register import (
|
||||
register_on_waiting_llm_request as on_waiting_llm_request,
|
||||
)
|
||||
from astrbot.core.star.register import register_permission_type as permission_type
|
||||
from astrbot.core.star.register import (
|
||||
register_platform_adapter_type as platform_adapter_type,
|
||||
@@ -46,6 +49,7 @@ __all__ = [
|
||||
"on_llm_request",
|
||||
"on_llm_response",
|
||||
"on_platform_loaded",
|
||||
"on_waiting_llm_request",
|
||||
"permission_type",
|
||||
"platform_adapter_type",
|
||||
"regex",
|
||||
|
||||
@@ -0,0 +1,120 @@
|
||||
import traceback
|
||||
|
||||
from astrbot.api import star
|
||||
from astrbot.api.event import AstrMessageEvent, filter
|
||||
from astrbot.api.message_components import Image, Plain
|
||||
from astrbot.api.provider import LLMResponse, ProviderRequest
|
||||
from astrbot.core import logger
|
||||
|
||||
from .long_term_memory import LongTermMemory
|
||||
from .process_llm_request import ProcessLLMRequest
|
||||
|
||||
|
||||
class Main(star.Star):
|
||||
def __init__(self, context: star.Context) -> None:
|
||||
self.context = context
|
||||
self.ltm = None
|
||||
try:
|
||||
self.ltm = LongTermMemory(self.context.astrbot_config_mgr, self.context)
|
||||
except BaseException as e:
|
||||
logger.error(f"聊天增强 err: {e}")
|
||||
|
||||
self.proc_llm_req = ProcessLLMRequest(self.context)
|
||||
|
||||
def ltm_enabled(self, event: AstrMessageEvent):
|
||||
ltmse = self.context.get_config(umo=event.unified_msg_origin)[
|
||||
"provider_ltm_settings"
|
||||
]
|
||||
return ltmse["group_icl_enable"] or ltmse["active_reply"]["enable"]
|
||||
|
||||
@filter.platform_adapter_type(filter.PlatformAdapterType.ALL)
|
||||
async def on_message(self, event: AstrMessageEvent):
|
||||
"""群聊记忆增强"""
|
||||
has_image_or_plain = False
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Plain) or isinstance(comp, Image):
|
||||
has_image_or_plain = True
|
||||
break
|
||||
|
||||
if self.ltm_enabled(event) and self.ltm and has_image_or_plain:
|
||||
need_active = await self.ltm.need_active_reply(event)
|
||||
|
||||
group_icl_enable = self.context.get_config()["provider_ltm_settings"][
|
||||
"group_icl_enable"
|
||||
]
|
||||
if group_icl_enable:
|
||||
"""记录对话"""
|
||||
try:
|
||||
await self.ltm.handle_message(event)
|
||||
except BaseException as e:
|
||||
logger.error(e)
|
||||
|
||||
if need_active:
|
||||
"""主动回复"""
|
||||
provider = self.context.get_using_provider(event.unified_msg_origin)
|
||||
if not provider:
|
||||
logger.error("未找到任何 LLM 提供商。请先配置。无法主动回复")
|
||||
return
|
||||
try:
|
||||
conv = None
|
||||
session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id(
|
||||
event.unified_msg_origin,
|
||||
)
|
||||
|
||||
if not session_curr_cid:
|
||||
logger.error(
|
||||
"当前未处于对话状态,无法主动回复,请确保 平台设置->会话隔离(unique_session) 未开启,并使用 /switch 序号 切换或者 /new 创建一个会话。",
|
||||
)
|
||||
return
|
||||
|
||||
conv = await self.context.conversation_manager.get_conversation(
|
||||
event.unified_msg_origin,
|
||||
session_curr_cid,
|
||||
)
|
||||
|
||||
prompt = event.message_str
|
||||
|
||||
if not conv:
|
||||
logger.error("未找到对话,无法主动回复")
|
||||
return
|
||||
|
||||
yield event.request_llm(
|
||||
prompt=prompt,
|
||||
func_tool_manager=self.context.get_llm_tool_manager(),
|
||||
session_id=event.session_id,
|
||||
conversation=conv,
|
||||
)
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"主动回复失败: {e}")
|
||||
|
||||
@filter.on_llm_request()
|
||||
async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest):
|
||||
"""在请求 LLM 前注入人格信息、Identifier、时间、回复内容等 System Prompt"""
|
||||
await self.proc_llm_req.process_llm_request(event, req)
|
||||
|
||||
if self.ltm and self.ltm_enabled(event):
|
||||
try:
|
||||
await self.ltm.on_req_llm(event, req)
|
||||
except BaseException as e:
|
||||
logger.error(f"ltm: {e}")
|
||||
|
||||
@filter.on_llm_response()
|
||||
async def record_llm_resp_to_ltm(self, event: AstrMessageEvent, resp: LLMResponse):
|
||||
"""在 LLM 响应后记录对话"""
|
||||
if self.ltm and self.ltm_enabled(event):
|
||||
try:
|
||||
await self.ltm.after_req_llm(event, resp)
|
||||
except Exception as e:
|
||||
logger.error(f"ltm: {e}")
|
||||
|
||||
@filter.after_message_sent()
|
||||
async def after_message_sent(self, event: AstrMessageEvent):
|
||||
"""消息发送后处理"""
|
||||
if self.ltm and self.ltm_enabled(event):
|
||||
try:
|
||||
clean_session = event.get_extra("_clean_ltm_session", False)
|
||||
if clean_session:
|
||||
await self.ltm.remove_session(event)
|
||||
except Exception as e:
|
||||
logger.error(f"ltm: {e}")
|
||||
@@ -0,0 +1,4 @@
|
||||
name: astrbot
|
||||
desc: AstrBot 自带插件,包含人格注入、思考内容注入、群聊上下文感知等功能的实现,禁用后将无法使用这些功能。
|
||||
author: Soulter
|
||||
version: 4.1.0
|
||||
+49
-16
@@ -7,6 +7,7 @@ from astrbot.api import logger, sp, star
|
||||
from astrbot.api.event import AstrMessageEvent
|
||||
from astrbot.api.message_components import Image, Reply
|
||||
from astrbot.api.provider import Provider, ProviderRequest
|
||||
from astrbot.core.agent.message import TextPart
|
||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||
|
||||
|
||||
@@ -85,7 +86,9 @@ class ProcessLLMRequest:
|
||||
req.image_urls,
|
||||
)
|
||||
if caption:
|
||||
req.prompt = f"(Image Caption: {caption})\n\n{req.prompt}"
|
||||
req.extra_user_content_parts.append(
|
||||
TextPart(text=f"<image_caption>{caption}</image_caption>")
|
||||
)
|
||||
req.image_urls = []
|
||||
except Exception as e:
|
||||
logger.error(f"处理图片描述失败: {e}")
|
||||
@@ -129,19 +132,25 @@ class ProcessLLMRequest:
|
||||
else:
|
||||
req.prompt = prefix + req.prompt
|
||||
|
||||
# 收集系统提醒信息
|
||||
system_parts = []
|
||||
|
||||
# user identifier
|
||||
if cfg.get("identifier"):
|
||||
user_id = event.message_obj.sender.user_id
|
||||
user_nickname = event.message_obj.sender.nickname
|
||||
req.prompt = (
|
||||
f"\n[User ID: {user_id}, Nickname: {user_nickname}]\n{req.prompt}"
|
||||
)
|
||||
system_parts.append(f"User ID: {user_id}, Nickname: {user_nickname}")
|
||||
|
||||
# group name identifier
|
||||
if cfg.get("group_name_display") and event.message_obj.group_id:
|
||||
if not event.message_obj.group:
|
||||
logger.error(
|
||||
f"Group name display enabled but group object is None. Group ID: {event.message_obj.group_id}"
|
||||
)
|
||||
return
|
||||
group_name = event.message_obj.group.group_name
|
||||
if group_name:
|
||||
req.system_prompt += f"\nGroup name: {group_name}\n"
|
||||
system_parts.append(f"Group name: {group_name}")
|
||||
|
||||
# time info
|
||||
if cfg.get("datetime_system_prompt"):
|
||||
@@ -157,7 +166,7 @@ class ProcessLLMRequest:
|
||||
current_time = (
|
||||
datetime.datetime.now().astimezone().strftime("%Y-%m-%d %H:%M (%Z)")
|
||||
)
|
||||
req.system_prompt += f"\nCurrent datetime: {current_time}\n"
|
||||
system_parts.append(f"Current datetime: {current_time}")
|
||||
|
||||
img_cap_prov_id: str = cfg.get("default_image_caption_provider_id") or ""
|
||||
if req.conversation:
|
||||
@@ -176,37 +185,61 @@ class ProcessLLMRequest:
|
||||
quote = comp
|
||||
break
|
||||
if quote:
|
||||
sender_info = ""
|
||||
if quote.sender_nickname:
|
||||
sender_info = f"(Sent by {quote.sender_nickname})"
|
||||
message_str = quote.message_str or "[Empty Text]"
|
||||
req.system_prompt += (
|
||||
f"\nUser is quoting a message{sender_info}.\n"
|
||||
f"Here are the information of the quoted message: Text Content: {message_str}.\n"
|
||||
content_parts = []
|
||||
|
||||
# 1. 处理引用的文本
|
||||
sender_info = (
|
||||
f"({quote.sender_nickname}): " if quote.sender_nickname else ""
|
||||
)
|
||||
message_str = quote.message_str or "[Empty Text]"
|
||||
content_parts.append(f"{sender_info}{message_str}")
|
||||
|
||||
# 2. 处理引用的图片 (保留原有逻辑,但改变输出目标)
|
||||
image_seg = None
|
||||
if quote.chain:
|
||||
for comp in quote.chain:
|
||||
if isinstance(comp, Image):
|
||||
image_seg = comp
|
||||
break
|
||||
|
||||
if image_seg:
|
||||
try:
|
||||
# 找到可以生成图片描述的 provider
|
||||
prov = None
|
||||
if img_cap_prov_id:
|
||||
prov = self.ctx.get_provider_by_id(img_cap_prov_id)
|
||||
if prov is None:
|
||||
prov = self.ctx.get_using_provider(event.unified_msg_origin)
|
||||
|
||||
# 调用 provider 生成图片描述
|
||||
if prov and isinstance(prov, Provider):
|
||||
llm_resp = await prov.text_chat(
|
||||
prompt="Please describe the image content.",
|
||||
image_urls=[await image_seg.convert_to_file_path()],
|
||||
)
|
||||
if llm_resp.completion_text:
|
||||
req.system_prompt += (
|
||||
f"Image Caption: {llm_resp.completion_text}\n"
|
||||
# 将图片描述作为文本添加到 content_parts
|
||||
content_parts.append(
|
||||
f"[Image Caption in quoted message]: {llm_resp.completion_text}"
|
||||
)
|
||||
else:
|
||||
logger.warning("No provider found for image captioning.")
|
||||
logger.warning(
|
||||
"No provider found for image captioning in quote."
|
||||
)
|
||||
except BaseException as e:
|
||||
logger.error(f"处理引用图片失败: {e}")
|
||||
|
||||
# 3. 将所有部分组合成文本并添加到 extra_user_content_parts 中
|
||||
# 确保引用内容被正确的标签包裹
|
||||
quoted_content = "\n".join(content_parts)
|
||||
# 确保所有内容都在<Quoted Message>标签内
|
||||
quoted_text = f"<Quoted Message>\n{quoted_content}\n</Quoted Message>"
|
||||
|
||||
req.extra_user_content_parts.append(TextPart(text=quoted_text))
|
||||
|
||||
# 统一包裹所有系统提醒
|
||||
if system_parts:
|
||||
system_content = (
|
||||
"<system_reminder>" + "\n".join(system_parts) + "</system_reminder>"
|
||||
)
|
||||
req.extra_user_content_parts.append(TextPart(text=system_content))
|
||||
+1
@@ -71,6 +71,7 @@ class AdminCommands:
|
||||
event.set_result(MessageEventResult().message("此 SID 不在白名单内。"))
|
||||
|
||||
async def update_dashboard(self, event: AstrMessageEvent):
|
||||
"""更新管理面板"""
|
||||
await event.send(MessageChain().message("正在尝试更新管理面板..."))
|
||||
await download_dashboard(version=f"v{VERSION}", latest=False)
|
||||
await event.send(MessageChain().message("管理面板更新完成。"))
|
||||
+10
-25
@@ -1,24 +1,23 @@
|
||||
import datetime
|
||||
|
||||
from astrbot.api import logger, sp, star
|
||||
from astrbot.api import sp, star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.core.platform.astr_message_event import MessageSession
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
|
||||
from ..long_term_memory import LongTermMemory
|
||||
from .utils.rst_scene import RstScene
|
||||
|
||||
THIRD_PARTY_AGENT_RUNNER_KEY = {
|
||||
"dify": "dify_conversation_id",
|
||||
"coze": "coze_conversation_id",
|
||||
"dashscope": "dashscope_conversation_id",
|
||||
}
|
||||
THIRD_PARTY_AGENT_RUNNER_STR = ", ".join(THIRD_PARTY_AGENT_RUNNER_KEY.keys())
|
||||
|
||||
|
||||
class ConversationCommands:
|
||||
def __init__(self, context: star.Context, ltm: LongTermMemory | None = None):
|
||||
def __init__(self, context: star.Context):
|
||||
self.context = context
|
||||
self.ltm = ltm
|
||||
|
||||
async def _get_current_persona_id(self, session_id):
|
||||
curr = await self.context.conversation_manager.get_curr_conversation_id(
|
||||
@@ -30,16 +29,10 @@ class ConversationCommands:
|
||||
session_id,
|
||||
curr,
|
||||
)
|
||||
if not conv:
|
||||
return None
|
||||
return conv.persona_id
|
||||
|
||||
def ltm_enabled(self, event: AstrMessageEvent):
|
||||
if not self.ltm:
|
||||
return False
|
||||
ltmse = self.context.get_config(umo=event.unified_msg_origin)[
|
||||
"provider_ltm_settings"
|
||||
]
|
||||
return ltmse["group_icl_enable"] or ltmse["active_reply"]["enable"]
|
||||
|
||||
async def reset(self, message: AstrMessageEvent):
|
||||
"""重置 LLM 会话"""
|
||||
umo = message.unified_msg_origin
|
||||
@@ -99,10 +92,9 @@ class ConversationCommands:
|
||||
[],
|
||||
)
|
||||
|
||||
ret = "清除会话 LLM 聊天历史成功。"
|
||||
if self.ltm and self.ltm_enabled(message):
|
||||
cnt = await self.ltm.remove_session(event=message)
|
||||
ret += f"\n聊天增强: 已清除 {cnt} 条聊天记录。"
|
||||
ret = "清除聊天历史成功!"
|
||||
|
||||
message.set_extra("_clean_ltm_session", True)
|
||||
|
||||
message.set_result(MessageEventResult().message(ret))
|
||||
|
||||
@@ -244,12 +236,7 @@ class ConversationCommands:
|
||||
persona_id=cpersona,
|
||||
)
|
||||
|
||||
# 长期记忆
|
||||
if self.ltm and self.ltm_enabled(message):
|
||||
try:
|
||||
await self.ltm.remove_session(event=message)
|
||||
except Exception as e:
|
||||
logger.error(f"清理聊天增强记录失败: {e}")
|
||||
message.set_extra("_clean_ltm_session", True)
|
||||
|
||||
message.set_result(
|
||||
MessageEventResult().message(f"切换到新对话: 新对话({cid[:4]})。"),
|
||||
@@ -375,7 +362,5 @@ class ConversationCommands:
|
||||
)
|
||||
|
||||
ret = "删除当前对话成功。不再处于对话状态,使用 /switch 序号 切换到其他对话或 /new 创建。"
|
||||
if self.ltm and self.ltm_enabled(message):
|
||||
cnt = await self.ltm.remove_session(event=message)
|
||||
ret += f"\n聊天增强: 已清除 {cnt} 条聊天记录。"
|
||||
message.set_extra("_clean_ltm_session", True)
|
||||
message.set_result(MessageEventResult().message(ret))
|
||||
@@ -0,0 +1,88 @@
|
||||
import aiohttp
|
||||
|
||||
from astrbot.api import star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.star import command_management
|
||||
from astrbot.core.utils.io import get_dashboard_version
|
||||
|
||||
|
||||
class HelpCommand:
|
||||
def __init__(self, context: star.Context):
|
||||
self.context = context
|
||||
|
||||
async def _query_astrbot_notice(self):
|
||||
try:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(
|
||||
"https://astrbot.app/notice.json",
|
||||
timeout=2,
|
||||
) as resp:
|
||||
return (await resp.json())["notice"]
|
||||
except BaseException:
|
||||
return ""
|
||||
|
||||
async def _build_reserved_command_lines(self) -> list[str]:
|
||||
"""
|
||||
使用实时指令配置生成内置指令清单,确保重命名/禁用后与实际生效状态保持一致。
|
||||
"""
|
||||
try:
|
||||
commands = await command_management.list_commands()
|
||||
except BaseException:
|
||||
return []
|
||||
|
||||
lines: list[str] = []
|
||||
hidden_commands = {"set", "unset", "websearch"}
|
||||
|
||||
def walk(items: list[dict], indent: int = 0):
|
||||
for item in items:
|
||||
if not item.get("reserved") or not item.get("enabled"):
|
||||
continue
|
||||
# 仅展示顶级指令或指令组
|
||||
if item.get("type") == "sub_command":
|
||||
continue
|
||||
if item.get("parent_signature"):
|
||||
continue
|
||||
|
||||
effective = (
|
||||
item.get("effective_command")
|
||||
or item.get("original_command")
|
||||
or item.get("handler_name")
|
||||
)
|
||||
if not effective:
|
||||
continue
|
||||
if effective in hidden_commands:
|
||||
continue
|
||||
|
||||
description = item.get("description") or ""
|
||||
desc_text = f" - {description}" if description else ""
|
||||
indent_prefix = " " * indent
|
||||
lines.append(f"{indent_prefix}/{effective}{desc_text}")
|
||||
|
||||
walk(commands)
|
||||
return lines
|
||||
|
||||
async def help(self, event: AstrMessageEvent):
|
||||
"""查看帮助"""
|
||||
notice = ""
|
||||
try:
|
||||
notice = await self._query_astrbot_notice()
|
||||
except BaseException:
|
||||
pass
|
||||
|
||||
dashboard_version = await get_dashboard_version()
|
||||
command_lines = await self._build_reserved_command_lines()
|
||||
commands_section = (
|
||||
"\n".join(command_lines) if command_lines else "暂无启用的内置指令"
|
||||
)
|
||||
|
||||
msg_parts = [
|
||||
f"AstrBot v{VERSION}(WebUI: {dashboard_version})",
|
||||
"内置指令:",
|
||||
commands_section,
|
||||
]
|
||||
if notice:
|
||||
msg_parts.append(notice)
|
||||
msg = "\n".join(msg_parts)
|
||||
|
||||
event.set_result(MessageEventResult().message(msg).use_t2i(False))
|
||||
+6
-4
@@ -184,7 +184,8 @@ class ProviderCommands:
|
||||
event.set_result(MessageEventResult().message("请输入序号。"))
|
||||
return
|
||||
if idx2 > len(self.context.get_all_tts_providers()) or idx2 < 1:
|
||||
event.set_result(MessageEventResult().message("无效的序号。"))
|
||||
event.set_result(MessageEventResult().message("无效的提供商序号。"))
|
||||
return
|
||||
provider = self.context.get_all_tts_providers()[idx2 - 1]
|
||||
id_ = provider.meta().id
|
||||
await self.context.provider_manager.set_provider(
|
||||
@@ -198,7 +199,8 @@ class ProviderCommands:
|
||||
event.set_result(MessageEventResult().message("请输入序号。"))
|
||||
return
|
||||
if idx2 > len(self.context.get_all_stt_providers()) or idx2 < 1:
|
||||
event.set_result(MessageEventResult().message("无效的序号。"))
|
||||
event.set_result(MessageEventResult().message("无效的提供商序号。"))
|
||||
return
|
||||
provider = self.context.get_all_stt_providers()[idx2 - 1]
|
||||
id_ = provider.meta().id
|
||||
await self.context.provider_manager.set_provider(
|
||||
@@ -209,8 +211,8 @@ class ProviderCommands:
|
||||
event.set_result(MessageEventResult().message(f"成功切换到 {id_}。"))
|
||||
elif isinstance(idx, int):
|
||||
if idx > len(self.context.get_all_providers()) or idx < 1:
|
||||
event.set_result(MessageEventResult().message("无效的序号。"))
|
||||
|
||||
event.set_result(MessageEventResult().message("无效的提供商序号。"))
|
||||
return
|
||||
provider = self.context.get_all_providers()[idx - 1]
|
||||
id_ = provider.meta().id
|
||||
await self.context.provider_manager.set_provider(
|
||||
+2
-2
@@ -14,13 +14,13 @@ class TTSCommand:
|
||||
async def tts(self, event: AstrMessageEvent):
|
||||
"""开关文本转语音(会话级别)"""
|
||||
umo = event.unified_msg_origin
|
||||
ses_tts = SessionServiceManager.is_tts_enabled_for_session(umo)
|
||||
ses_tts = await SessionServiceManager.is_tts_enabled_for_session(umo)
|
||||
cfg = self.context.get_config(umo=umo)
|
||||
tts_enable = cfg["provider_tts_settings"]["enable"]
|
||||
|
||||
# 切换状态
|
||||
new_status = not ses_tts
|
||||
SessionServiceManager.set_tts_status_for_session(umo, new_status)
|
||||
await SessionServiceManager.set_tts_status_for_session(umo, new_status)
|
||||
|
||||
status_text = "已开启" if new_status else "已关闭"
|
||||
|
||||
@@ -1,10 +1,5 @@
|
||||
import traceback
|
||||
|
||||
from astrbot.api import star
|
||||
from astrbot.api.event import AstrMessageEvent, filter
|
||||
from astrbot.api.message_components import Image, Plain
|
||||
from astrbot.api.provider import LLMResponse, ProviderRequest
|
||||
from astrbot.core import logger
|
||||
|
||||
from .commands import (
|
||||
AdminCommands,
|
||||
@@ -21,25 +16,18 @@ from .commands import (
|
||||
ToolCommands,
|
||||
TTSCommand,
|
||||
)
|
||||
from .long_term_memory import LongTermMemory
|
||||
from .process_llm_request import ProcessLLMRequest
|
||||
|
||||
|
||||
class Main(star.Star):
|
||||
def __init__(self, context: star.Context) -> None:
|
||||
self.context = context
|
||||
self.ltm = None
|
||||
try:
|
||||
self.ltm = LongTermMemory(self.context.astrbot_config_mgr, self.context)
|
||||
except BaseException as e:
|
||||
logger.error(f"聊天增强 err: {e}")
|
||||
|
||||
self.help_c = HelpCommand(self.context)
|
||||
self.llm_c = LLMCommands(self.context)
|
||||
self.tool_c = ToolCommands(self.context)
|
||||
self.plugin_c = PluginCommands(self.context)
|
||||
self.admin_c = AdminCommands(self.context)
|
||||
self.conversation_c = ConversationCommands(self.context, self.ltm)
|
||||
self.conversation_c = ConversationCommands(self.context)
|
||||
self.provider_c = ProviderCommands(self.context)
|
||||
self.persona_c = PersonaCommands(self.context)
|
||||
self.alter_cmd_c = AlterCmdCommands(self.context)
|
||||
@@ -47,13 +35,6 @@ class Main(star.Star):
|
||||
self.t2i_c = T2ICommand(self.context)
|
||||
self.tts_c = TTSCommand(self.context)
|
||||
self.sid_c = SIDCommand(self.context)
|
||||
self.proc_llm_req = ProcessLLMRequest(self.context)
|
||||
|
||||
def ltm_enabled(self, event: AstrMessageEvent):
|
||||
ltmse = self.context.get_config(umo=event.unified_msg_origin)[
|
||||
"provider_ltm_settings"
|
||||
]
|
||||
return ltmse["group_icl_enable"] or ltmse["active_reply"]["enable"]
|
||||
|
||||
@filter.command("help")
|
||||
async def help(self, event: AstrMessageEvent):
|
||||
@@ -68,7 +49,7 @@ class Main(star.Star):
|
||||
|
||||
@filter.command_group("tool")
|
||||
def tool(self):
|
||||
pass
|
||||
"""函数工具管理"""
|
||||
|
||||
@tool.command("ls")
|
||||
async def tool_ls(self, event: AstrMessageEvent):
|
||||
@@ -92,7 +73,7 @@ class Main(star.Star):
|
||||
|
||||
@filter.command_group("plugin")
|
||||
def plugin(self):
|
||||
pass
|
||||
"""插件管理"""
|
||||
|
||||
@plugin.command("ls")
|
||||
async def plugin_ls(self, event: AstrMessageEvent):
|
||||
@@ -238,6 +219,7 @@ class Main(star.Star):
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("dashboard_update")
|
||||
async def update_dashboard(self, event: AstrMessageEvent):
|
||||
"""更新管理面板"""
|
||||
await self.admin_c.update_dashboard(event)
|
||||
|
||||
@filter.command("set")
|
||||
@@ -248,95 +230,6 @@ class Main(star.Star):
|
||||
async def unset_variable(self, event: AstrMessageEvent, key: str):
|
||||
await self.setunset_c.unset_variable(event, key)
|
||||
|
||||
@filter.platform_adapter_type(filter.PlatformAdapterType.ALL)
|
||||
async def on_message(self, event: AstrMessageEvent):
|
||||
"""群聊记忆增强"""
|
||||
has_image_or_plain = False
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Plain) or isinstance(comp, Image):
|
||||
has_image_or_plain = True
|
||||
break
|
||||
|
||||
if self.ltm_enabled(event) and self.ltm and has_image_or_plain:
|
||||
need_active = await self.ltm.need_active_reply(event)
|
||||
|
||||
group_icl_enable = self.context.get_config()["provider_ltm_settings"][
|
||||
"group_icl_enable"
|
||||
]
|
||||
if group_icl_enable:
|
||||
"""记录对话"""
|
||||
try:
|
||||
await self.ltm.handle_message(event)
|
||||
except BaseException as e:
|
||||
logger.error(e)
|
||||
|
||||
if need_active:
|
||||
"""主动回复"""
|
||||
provider = self.context.get_using_provider(event.unified_msg_origin)
|
||||
if not provider:
|
||||
logger.error("未找到任何 LLM 提供商。请先配置。无法主动回复")
|
||||
return
|
||||
try:
|
||||
conv = None
|
||||
session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id(
|
||||
event.unified_msg_origin,
|
||||
)
|
||||
|
||||
if not session_curr_cid:
|
||||
logger.error(
|
||||
"当前未处于对话状态,无法主动回复,请确保 平台设置->会话隔离(unique_session) 未开启,并使用 /switch 序号 切换或者 /new 创建一个会话。",
|
||||
)
|
||||
return
|
||||
|
||||
conv = await self.context.conversation_manager.get_conversation(
|
||||
event.unified_msg_origin,
|
||||
session_curr_cid,
|
||||
)
|
||||
|
||||
prompt = event.message_str
|
||||
|
||||
if not conv:
|
||||
logger.error("未找到对话,无法主动回复")
|
||||
return
|
||||
|
||||
yield event.request_llm(
|
||||
prompt=prompt,
|
||||
func_tool_manager=self.context.get_llm_tool_manager(),
|
||||
session_id=event.session_id,
|
||||
conversation=conv,
|
||||
)
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"主动回复失败: {e}")
|
||||
|
||||
@filter.on_llm_request()
|
||||
async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest):
|
||||
"""在请求 LLM 前注入人格信息、Identifier、时间、回复内容等 System Prompt"""
|
||||
await self.proc_llm_req.process_llm_request(event, req)
|
||||
|
||||
if self.ltm and self.ltm_enabled(event):
|
||||
try:
|
||||
await self.ltm.on_req_llm(event, req)
|
||||
except BaseException as e:
|
||||
logger.error(f"ltm: {e}")
|
||||
|
||||
@filter.on_llm_response()
|
||||
async def inject_reasoning(self, event: AstrMessageEvent, resp: LLMResponse):
|
||||
"""在 LLM 响应后基于配置注入思考过程文本 / 在 LLM 响应后记录对话"""
|
||||
umo = event.unified_msg_origin
|
||||
cfg = self.context.get_config(umo).get("provider_settings", {})
|
||||
show_reasoning = cfg.get("display_reasoning_text", False)
|
||||
if show_reasoning and resp.reasoning_content:
|
||||
resp.completion_text = (
|
||||
f"🤔 思考: {resp.reasoning_content}\n\n{resp.completion_text}"
|
||||
)
|
||||
|
||||
if self.ltm and self.ltm_enabled(event):
|
||||
try:
|
||||
await self.ltm.after_req_llm(event, resp)
|
||||
except Exception as e:
|
||||
logger.error(f"ltm: {e}")
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("alter_cmd", alias={"alter"})
|
||||
async def alter_cmd(self, event: AstrMessageEvent):
|
||||
@@ -0,0 +1,4 @@
|
||||
name: builtin_commands
|
||||
desc: AstrBot 自带指令,提供常用的对话管理、工具使用、插件管理等功能。
|
||||
author: Soulter
|
||||
version: 0.0.1
|
||||
+138
-132
@@ -14,6 +14,7 @@ from astrbot.api import llm_tool, logger, star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult, filter
|
||||
from astrbot.api.message_components import File, Image
|
||||
from astrbot.api.provider import ProviderRequest
|
||||
from astrbot.core.message.components import BaseMessageComponent
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.io import download_file, download_image_by_url
|
||||
|
||||
@@ -156,9 +157,8 @@ class Main(star.Star):
|
||||
async def is_docker_available(self) -> bool:
|
||||
"""Check if docker is available"""
|
||||
try:
|
||||
docker = aiodocker.Docker()
|
||||
await docker.version()
|
||||
await docker.close()
|
||||
async with aiodocker.Docker() as docker:
|
||||
await docker.version()
|
||||
return True
|
||||
except BaseException as e:
|
||||
logger.info(f"检查 Docker 可用性: {e}")
|
||||
@@ -224,6 +224,8 @@ class Main(star.Star):
|
||||
del self.user_waiting[uid]
|
||||
elif isinstance(comp, Image):
|
||||
image_url = comp.url if comp.url else comp.file
|
||||
if image_url is None:
|
||||
raise ValueError("Image URL is None")
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
elif image_url.startswith("file:///"):
|
||||
@@ -240,11 +242,13 @@ class Main(star.Star):
|
||||
async def on_llm_req(self, event: AstrMessageEvent, request: ProviderRequest):
|
||||
if event.get_session_id() in self.user_file_msg_buffer:
|
||||
files = self.user_file_msg_buffer[event.get_session_id()]
|
||||
if not request.prompt:
|
||||
request.prompt = ""
|
||||
request.prompt += f"\nUser provided files: {files}"
|
||||
|
||||
@filter.command_group("pi")
|
||||
def pi(self):
|
||||
pass
|
||||
"""代码执行器配置"""
|
||||
|
||||
@pi.command("absdir")
|
||||
async def pi_absdir(self, event: AstrMessageEvent, path: str = ""):
|
||||
@@ -274,14 +278,14 @@ class Main(star.Star):
|
||||
@pi.command("repull")
|
||||
async def pi_repull(self, event: AstrMessageEvent):
|
||||
"""重新拉取沙箱镜像"""
|
||||
docker = aiodocker.Docker()
|
||||
image_name = await self.get_image_name()
|
||||
try:
|
||||
await docker.images.get(image_name)
|
||||
await docker.images.delete(image_name, force=True)
|
||||
except aiodocker.exceptions.DockerError:
|
||||
pass
|
||||
await docker.images.pull(image_name)
|
||||
async with aiodocker.Docker() as docker:
|
||||
image_name = await self.get_image_name()
|
||||
try:
|
||||
await docker.images.get(image_name)
|
||||
await docker.images.delete(image_name, force=True)
|
||||
except aiodocker.exceptions.DockerError:
|
||||
pass
|
||||
await docker.images.pull(image_name)
|
||||
yield event.plain_result("重新拉取沙箱镜像成功。")
|
||||
|
||||
@pi.command("file")
|
||||
@@ -366,135 +370,137 @@ class Main(star.Star):
|
||||
obs = ""
|
||||
n = 5
|
||||
|
||||
for i in range(n):
|
||||
if i > 0:
|
||||
logger.info(f"Try {i + 1}/{n}")
|
||||
async with aiodocker.Docker() as docker:
|
||||
for i in range(n):
|
||||
if i > 0:
|
||||
logger.info(f"Try {i + 1}/{n}")
|
||||
|
||||
PROMPT_ = PROMPT.format(
|
||||
prompt=plain_text,
|
||||
extra_input=extra_inputs,
|
||||
extra_prompt=obs,
|
||||
)
|
||||
provider = self.context.get_using_provider()
|
||||
llm_response = await provider.text_chat(
|
||||
prompt=PROMPT_,
|
||||
session_id=f"{event.session_id}_{magic_code}_{i!s}",
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"code interpreter llm gened code:" + llm_response.completion_text,
|
||||
)
|
||||
|
||||
# 整理代码并保存
|
||||
code_clean = await self.tidy_code(llm_response.completion_text)
|
||||
with open(os.path.join(workplace_path, "exec.py"), "w") as f:
|
||||
f.write(code_clean)
|
||||
|
||||
# 启动容器
|
||||
docker = aiodocker.Docker()
|
||||
|
||||
# 检查有没有image
|
||||
image_name = await self.get_image_name()
|
||||
try:
|
||||
await docker.images.get(image_name)
|
||||
except aiodocker.exceptions.DockerError:
|
||||
# 拉取镜像
|
||||
logger.info(f"未找到沙箱镜像,正在尝试拉取 {image_name}...")
|
||||
await docker.images.pull(image_name)
|
||||
|
||||
yield event.plain_result(
|
||||
f"使用沙箱执行代码中,请稍等...(尝试次数: {i + 1}/{n})",
|
||||
)
|
||||
|
||||
self.docker_host_astrbot_abs_path = self.config.get(
|
||||
"docker_host_astrbot_abs_path",
|
||||
"",
|
||||
)
|
||||
if self.docker_host_astrbot_abs_path:
|
||||
host_shared = os.path.join(
|
||||
self.docker_host_astrbot_abs_path,
|
||||
self.shared_path,
|
||||
PROMPT_ = PROMPT.format(
|
||||
prompt=plain_text,
|
||||
extra_input=extra_inputs,
|
||||
extra_prompt=obs,
|
||||
)
|
||||
host_output = os.path.join(
|
||||
self.docker_host_astrbot_abs_path,
|
||||
output_path,
|
||||
)
|
||||
host_workplace = os.path.join(
|
||||
self.docker_host_astrbot_abs_path,
|
||||
workplace_path,
|
||||
provider = self.context.get_using_provider()
|
||||
llm_response = await provider.text_chat(
|
||||
prompt=PROMPT_,
|
||||
session_id=f"{event.session_id}_{magic_code}_{i!s}",
|
||||
)
|
||||
|
||||
else:
|
||||
host_shared = os.path.abspath(self.shared_path)
|
||||
host_output = os.path.abspath(output_path)
|
||||
host_workplace = os.path.abspath(workplace_path)
|
||||
logger.debug(
|
||||
"code interpreter llm gened code:" + llm_response.completion_text,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"host_shared: {host_shared}, host_output: {host_output}, host_workplace: {host_workplace}",
|
||||
)
|
||||
# 整理代码并保存
|
||||
code_clean = await self.tidy_code(llm_response.completion_text)
|
||||
with open(os.path.join(workplace_path, "exec.py"), "w") as f:
|
||||
f.write(code_clean)
|
||||
|
||||
container = await docker.containers.run(
|
||||
{
|
||||
"Image": image_name,
|
||||
"Cmd": ["python", "exec.py"],
|
||||
"Memory": 512 * 1024 * 1024,
|
||||
"NanoCPUs": 1000000000,
|
||||
"HostConfig": {
|
||||
"Binds": [
|
||||
f"{host_shared}:/astrbot_sandbox/shared:ro",
|
||||
f"{host_output}:/astrbot_sandbox/output:rw",
|
||||
f"{host_workplace}:/astrbot_sandbox:rw",
|
||||
],
|
||||
},
|
||||
"Env": [f"MAGIC_CODE={magic_code}"],
|
||||
"AutoRemove": True,
|
||||
},
|
||||
)
|
||||
# 检查有没有image
|
||||
image_name = await self.get_image_name()
|
||||
try:
|
||||
await docker.images.get(image_name)
|
||||
except aiodocker.exceptions.DockerError:
|
||||
# 拉取镜像
|
||||
logger.info(f"未找到沙箱镜像,正在尝试拉取 {image_name}...")
|
||||
await docker.images.pull(image_name)
|
||||
|
||||
logger.debug(f"Container {container.id} created.")
|
||||
logs = await self.run_container(container)
|
||||
yield event.plain_result(
|
||||
f"使用沙箱执行代码中,请稍等...(尝试次数: {i + 1}/{n})",
|
||||
)
|
||||
|
||||
logger.debug(f"Container {container.id} finished.")
|
||||
logger.debug(f"Container {container.id} logs: {logs}")
|
||||
|
||||
# 发送结果
|
||||
pattern = r"\[ASTRBOT_(TEXT|IMAGE|FILE)_OUTPUT#\w+\]: (.*)"
|
||||
ok = False
|
||||
traceback = ""
|
||||
for idx, log in enumerate(logs):
|
||||
match = re.match(pattern, log)
|
||||
if match:
|
||||
ok = True
|
||||
if match.group(1) == "TEXT":
|
||||
yield event.plain_result(match.group(2))
|
||||
elif match.group(1) == "IMAGE":
|
||||
image_path = os.path.join(workplace_path, match.group(2))
|
||||
logger.debug(f"Sending image: {image_path}")
|
||||
yield event.image_result(image_path)
|
||||
elif match.group(1) == "FILE":
|
||||
file_path = os.path.join(workplace_path, match.group(2))
|
||||
# logger.debug(f"Sending file: {file_path}")
|
||||
# file_s3_url = await self.file_upload(file_path)
|
||||
# logger.info(f"文件上传到 AstrBot 云节点: {file_s3_url}")
|
||||
file_name = os.path.basename(file_path)
|
||||
chain = [File(name=file_name, file=file_path)]
|
||||
yield event.set_result(MessageEventResult(chain=chain))
|
||||
|
||||
elif "Traceback (most recent call last)" in log or "[Error]: " in log:
|
||||
traceback = "\n".join(logs[idx:])
|
||||
|
||||
if not ok:
|
||||
if traceback:
|
||||
obs = f"## Observation \n When execute the code: ```python\n{code_clean}\n```\n\n Error occurred:\n\n{traceback}\n Need to improve/fix the code."
|
||||
else:
|
||||
logger.warning(
|
||||
f"未从沙箱输出中捕获到合法的输出。沙箱输出日志: {logs}",
|
||||
self.docker_host_astrbot_abs_path = self.config.get(
|
||||
"docker_host_astrbot_abs_path",
|
||||
"",
|
||||
)
|
||||
if self.docker_host_astrbot_abs_path:
|
||||
host_shared = os.path.join(
|
||||
self.docker_host_astrbot_abs_path,
|
||||
self.shared_path,
|
||||
)
|
||||
break
|
||||
else:
|
||||
# 成功了
|
||||
self.user_file_msg_buffer.pop(event.get_session_id())
|
||||
return
|
||||
host_output = os.path.join(
|
||||
self.docker_host_astrbot_abs_path,
|
||||
output_path,
|
||||
)
|
||||
host_workplace = os.path.join(
|
||||
self.docker_host_astrbot_abs_path,
|
||||
workplace_path,
|
||||
)
|
||||
|
||||
else:
|
||||
host_shared = os.path.abspath(self.shared_path)
|
||||
host_output = os.path.abspath(output_path)
|
||||
host_workplace = os.path.abspath(workplace_path)
|
||||
|
||||
logger.debug(
|
||||
f"host_shared: {host_shared}, host_output: {host_output}, host_workplace: {host_workplace}",
|
||||
)
|
||||
|
||||
container = await docker.containers.run(
|
||||
{
|
||||
"Image": image_name,
|
||||
"Cmd": ["python", "exec.py"],
|
||||
"Memory": 512 * 1024 * 1024,
|
||||
"NanoCPUs": 1000000000,
|
||||
"HostConfig": {
|
||||
"Binds": [
|
||||
f"{host_shared}:/astrbot_sandbox/shared:ro",
|
||||
f"{host_output}:/astrbot_sandbox/output:rw",
|
||||
f"{host_workplace}:/astrbot_sandbox:rw",
|
||||
],
|
||||
},
|
||||
"Env": [f"MAGIC_CODE={magic_code}"],
|
||||
"AutoRemove": True,
|
||||
},
|
||||
)
|
||||
|
||||
logger.debug(f"Container {container.id} created.")
|
||||
logs = await self.run_container(container)
|
||||
|
||||
logger.debug(f"Container {container.id} finished.")
|
||||
logger.debug(f"Container {container.id} logs: {logs}")
|
||||
|
||||
# 发送结果
|
||||
pattern = r"\[ASTRBOT_(TEXT|IMAGE|FILE)_OUTPUT#\w+\]: (.*)"
|
||||
ok = False
|
||||
traceback = ""
|
||||
for idx, log in enumerate(logs):
|
||||
match = re.match(pattern, log)
|
||||
if match:
|
||||
ok = True
|
||||
if match.group(1) == "TEXT":
|
||||
yield event.plain_result(match.group(2))
|
||||
elif match.group(1) == "IMAGE":
|
||||
image_path = os.path.join(workplace_path, match.group(2))
|
||||
logger.debug(f"Sending image: {image_path}")
|
||||
yield event.image_result(image_path)
|
||||
elif match.group(1) == "FILE":
|
||||
file_path = os.path.join(workplace_path, match.group(2))
|
||||
# logger.debug(f"Sending file: {file_path}")
|
||||
# file_s3_url = await self.file_upload(file_path)
|
||||
# logger.info(f"文件上传到 AstrBot 云节点: {file_s3_url}")
|
||||
file_name = os.path.basename(file_path)
|
||||
chain: list[BaseMessageComponent] = [
|
||||
File(name=file_name, file=file_path)
|
||||
]
|
||||
yield event.set_result(MessageEventResult(chain=chain))
|
||||
|
||||
elif (
|
||||
"Traceback (most recent call last)" in log or "[Error]: " in log
|
||||
):
|
||||
traceback = "\n".join(logs[idx:])
|
||||
|
||||
if not ok:
|
||||
if traceback:
|
||||
obs = f"## Observation \n When execute the code: ```python\n{code_clean}\n```\n\n Error occurred:\n\n{traceback}\n Need to improve/fix the code."
|
||||
else:
|
||||
logger.warning(
|
||||
f"未从沙箱输出中捕获到合法的输出。沙箱输出日志: {logs}",
|
||||
)
|
||||
break
|
||||
else:
|
||||
# 成功了
|
||||
self.user_file_msg_buffer.pop(event.get_session_id())
|
||||
return
|
||||
|
||||
yield event.plain_result(
|
||||
"经过多次尝试后,未从沙箱输出中捕获到合法的输出,请更换问法或者查看日志。",
|
||||
@@ -5,6 +5,7 @@ import uuid
|
||||
import zoneinfo
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
|
||||
from astrbot.api import llm_tool, logger, star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult, filter
|
||||
@@ -62,13 +63,13 @@ class Main(star.Star):
|
||||
misfire_grace_time=60,
|
||||
)
|
||||
elif "cron" in reminder:
|
||||
trigger = CronTrigger(**self._parse_cron_expr(reminder["cron"]))
|
||||
self.scheduler.add_job(
|
||||
self._reminder_callback,
|
||||
trigger="cron",
|
||||
trigger=trigger,
|
||||
id=id_,
|
||||
args=[group, reminder],
|
||||
misfire_grace_time=60,
|
||||
**self._parse_cron_expr(reminder["cron"]),
|
||||
)
|
||||
|
||||
def check_is_outdated(self, reminder: dict):
|
||||
@@ -101,10 +102,10 @@ class Main(star.Star):
|
||||
async def reminder_tool(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
text: str = None,
|
||||
datetime_str: str = None,
|
||||
cron_expression: str = None,
|
||||
human_readable_cron: str = None,
|
||||
text: str | None = None,
|
||||
datetime_str: str | None = None,
|
||||
cron_expression: str | None = None,
|
||||
human_readable_cron: str | None = None,
|
||||
):
|
||||
"""Call this function when user is asking for setting a reminder.
|
||||
|
||||
@@ -139,17 +140,19 @@ class Main(star.Star):
|
||||
"id": str(uuid.uuid4()),
|
||||
}
|
||||
self.reminder_data[event.unified_msg_origin].append(d)
|
||||
trigger = CronTrigger(**self._parse_cron_expr(cron_expression))
|
||||
self.scheduler.add_job(
|
||||
self._reminder_callback,
|
||||
"cron",
|
||||
trigger,
|
||||
id=d["id"],
|
||||
misfire_grace_time=60,
|
||||
**self._parse_cron_expr(cron_expression),
|
||||
args=[event.unified_msg_origin, d],
|
||||
)
|
||||
if human_readable_cron:
|
||||
reminder_time = f"{human_readable_cron}(Cron: {cron_expression})"
|
||||
else:
|
||||
if datetime_str is None:
|
||||
raise ValueError("datetime_str cannot be None.")
|
||||
d = {"text": text, "datetime": datetime_str, "id": str(uuid.uuid4())}
|
||||
self.reminder_data[event.unified_msg_origin].append(d)
|
||||
datetime_scheduled = datetime.datetime.strptime(
|
||||
@@ -176,7 +179,7 @@ class Main(star.Star):
|
||||
|
||||
@filter.command_group("reminder")
|
||||
def reminder(self):
|
||||
"""The command group of the reminder."""
|
||||
"""待办提醒"""
|
||||
|
||||
async def get_upcoming_reminders(self, unified_msg_origin: str):
|
||||
"""Get upcoming reminders."""
|
||||
+16
-9
@@ -3,7 +3,7 @@ import urllib.parse
|
||||
from dataclasses import dataclass
|
||||
|
||||
from aiohttp import ClientSession
|
||||
from bs4 import BeautifulSoup
|
||||
from bs4 import BeautifulSoup, Tag
|
||||
|
||||
HEADERS = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 6.1; rv:84.0) Gecko/20100101 Firefox/84.0",
|
||||
@@ -45,13 +45,13 @@ class SearchEngine:
|
||||
self.page = 1
|
||||
self.headers = HEADERS
|
||||
|
||||
def _set_selector(self, selector: str) -> None:
|
||||
def _set_selector(self, selector: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_next_page(self):
|
||||
def _get_next_page(self, query: str):
|
||||
raise NotImplementedError
|
||||
|
||||
async def _get_html(self, url: str, data: dict = None) -> str:
|
||||
async def _get_html(self, url: str, data: dict | None = None) -> str:
|
||||
headers = self.headers
|
||||
headers["Referer"] = url
|
||||
headers["User-Agent"] = random.choice(USER_AGENTS)
|
||||
@@ -83,6 +83,9 @@ class SearchEngine:
|
||||
"""清理文本,去除空格、换行符等"""
|
||||
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
|
||||
|
||||
def _get_url(self, tag: Tag) -> str:
|
||||
return self.tidy_text(tag.get_text())
|
||||
|
||||
async def search(self, query: str, num_results: int) -> list[SearchResult]:
|
||||
query = urllib.parse.quote(query)
|
||||
|
||||
@@ -92,12 +95,16 @@ class SearchEngine:
|
||||
links = soup.select(self._set_selector("links"))
|
||||
results = []
|
||||
for link in links:
|
||||
title = self.tidy_text(
|
||||
link.select_one(self._set_selector("title")).text,
|
||||
)
|
||||
url = link.select_one(self._set_selector("url"))
|
||||
# Safely get the title text (select_one may return None)
|
||||
title_elem = link.select_one(self._set_selector("title"))
|
||||
title = ""
|
||||
if title_elem is not None:
|
||||
title = self.tidy_text(title_elem.get_text())
|
||||
|
||||
url_tag = link.select_one(self._set_selector("url"))
|
||||
snippet = ""
|
||||
if title and url:
|
||||
if title and url_tag:
|
||||
url = self._get_url(url_tag)
|
||||
results.append(SearchResult(title=title, url=url, snippet=snippet))
|
||||
return results[:num_results] if len(results) > num_results else results
|
||||
except Exception as e:
|
||||
+1
-9
@@ -1,4 +1,4 @@
|
||||
from . import USER_AGENT_BING, SearchEngine, SearchResult
|
||||
from . import USER_AGENT_BING, SearchEngine
|
||||
|
||||
|
||||
class Bing(SearchEngine):
|
||||
@@ -28,11 +28,3 @@ class Bing(SearchEngine):
|
||||
self.base_url = base_url
|
||||
continue
|
||||
raise Exception("Bing search failed")
|
||||
|
||||
async def search(self, query: str, num_results: int) -> list[SearchResult]:
|
||||
results = await super().search(query, num_results)
|
||||
for result in results:
|
||||
if not isinstance(result.url, str):
|
||||
result.url = result.url.text
|
||||
|
||||
return results
|
||||
+10
-4
@@ -1,7 +1,8 @@
|
||||
import random
|
||||
import re
|
||||
from typing import cast
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
from bs4 import BeautifulSoup, Tag
|
||||
|
||||
from . import USER_AGENTS, SearchEngine, SearchResult
|
||||
|
||||
@@ -26,10 +27,12 @@ class Sogo(SearchEngine):
|
||||
url = f"{self.base_url}/web?query={query}"
|
||||
return await self._get_html(url, None)
|
||||
|
||||
def _get_url(self, tag: Tag) -> str:
|
||||
return cast(str, tag.get("href"))
|
||||
|
||||
async def search(self, query: str, num_results: int) -> list[SearchResult]:
|
||||
results = await super().search(query, num_results)
|
||||
for result in results:
|
||||
result.url = result.url.get("href")
|
||||
if result.url.startswith("/link?"):
|
||||
result.url = self.base_url + result.url
|
||||
result.url = await self._parse_url(result.url)
|
||||
@@ -40,7 +43,10 @@ class Sogo(SearchEngine):
|
||||
soup = BeautifulSoup(html, "html.parser")
|
||||
script = soup.find("script")
|
||||
if script:
|
||||
url = re.search(r'window.location.replace\("(.+?)"\)', script.string).group(
|
||||
1,
|
||||
script_text = (
|
||||
script.string if script.string is not None else script.get_text()
|
||||
)
|
||||
match = re.search(r'window.location.replace\("(.+?)"\)', script_text)
|
||||
if match:
|
||||
url = match.group(1)
|
||||
return url
|
||||
@@ -185,6 +185,7 @@ class Main(star.Star):
|
||||
|
||||
@filter.command("websearch")
|
||||
async def websearch(self, event: AstrMessageEvent, oper: str | None = None):
|
||||
"""网页搜索指令(已废弃)"""
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
"此指令已经被废弃,请在 WebUI 中开启或关闭网页搜索功能。",
|
||||
@@ -1 +1 @@
|
||||
__version__ = "4.8.0"
|
||||
__version__ = "4.11.0"
|
||||
|
||||
@@ -0,0 +1,243 @@
|
||||
from typing import TYPE_CHECKING, Protocol, runtime_checkable
|
||||
|
||||
from ..message import Message
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot import logger
|
||||
else:
|
||||
try:
|
||||
from astrbot import logger
|
||||
except ImportError:
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.provider.provider import Provider
|
||||
|
||||
from ..context.truncator import ContextTruncator
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ContextCompressor(Protocol):
|
||||
"""
|
||||
Protocol for context compressors.
|
||||
Provides an interface for compressing message lists.
|
||||
"""
|
||||
|
||||
def should_compress(
|
||||
self, messages: list[Message], current_tokens: int, max_tokens: int
|
||||
) -> bool:
|
||||
"""Check if compression is needed.
|
||||
|
||||
Args:
|
||||
messages: The message list to evaluate.
|
||||
current_tokens: The current token count.
|
||||
max_tokens: The maximum allowed tokens for the model.
|
||||
|
||||
Returns:
|
||||
True if compression is needed, False otherwise.
|
||||
"""
|
||||
...
|
||||
|
||||
async def __call__(self, messages: list[Message]) -> list[Message]:
|
||||
"""Compress the message list.
|
||||
|
||||
Args:
|
||||
messages: The original message list.
|
||||
|
||||
Returns:
|
||||
The compressed message list.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class TruncateByTurnsCompressor:
|
||||
"""Truncate by turns compressor implementation.
|
||||
Truncates the message list by removing older turns.
|
||||
"""
|
||||
|
||||
def __init__(self, truncate_turns: int = 1, compression_threshold: float = 0.82):
|
||||
"""Initialize the truncate by turns compressor.
|
||||
|
||||
Args:
|
||||
truncate_turns: The number of turns to remove when truncating (default: 1).
|
||||
compression_threshold: The compression trigger threshold (default: 0.82).
|
||||
"""
|
||||
self.truncate_turns = truncate_turns
|
||||
self.compression_threshold = compression_threshold
|
||||
|
||||
def should_compress(
|
||||
self, messages: list[Message], current_tokens: int, max_tokens: int
|
||||
) -> bool:
|
||||
"""Check if compression is needed.
|
||||
|
||||
Args:
|
||||
messages: The message list to evaluate.
|
||||
current_tokens: The current token count.
|
||||
max_tokens: The maximum allowed tokens.
|
||||
|
||||
Returns:
|
||||
True if compression is needed, False otherwise.
|
||||
"""
|
||||
if max_tokens <= 0 or current_tokens <= 0:
|
||||
return False
|
||||
usage_rate = current_tokens / max_tokens
|
||||
return usage_rate > self.compression_threshold
|
||||
|
||||
async def __call__(self, messages: list[Message]) -> list[Message]:
|
||||
truncator = ContextTruncator()
|
||||
truncated_messages = truncator.truncate_by_dropping_oldest_turns(
|
||||
messages,
|
||||
drop_turns=self.truncate_turns,
|
||||
)
|
||||
return truncated_messages
|
||||
|
||||
|
||||
def split_history(
|
||||
messages: list[Message], keep_recent: int
|
||||
) -> tuple[list[Message], list[Message], list[Message]]:
|
||||
"""Split the message list into system messages, messages to summarize, and recent messages.
|
||||
|
||||
Ensures that the split point is between complete user-assistant pairs to maintain conversation flow.
|
||||
|
||||
Args:
|
||||
messages: The original message list.
|
||||
keep_recent: The number of latest messages to keep.
|
||||
|
||||
Returns:
|
||||
tuple: (system_messages, messages_to_summarize, recent_messages)
|
||||
"""
|
||||
# keep the system messages
|
||||
first_non_system = 0
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.role != "system":
|
||||
first_non_system = i
|
||||
break
|
||||
|
||||
system_messages = messages[:first_non_system]
|
||||
non_system_messages = messages[first_non_system:]
|
||||
|
||||
if len(non_system_messages) <= keep_recent:
|
||||
return system_messages, [], non_system_messages
|
||||
|
||||
# Find the split point, ensuring recent_messages starts with a user message
|
||||
# This maintains complete conversation turns
|
||||
split_index = len(non_system_messages) - keep_recent
|
||||
|
||||
# Search backward from split_index to find the first user message
|
||||
# This ensures recent_messages starts with a user message (complete turn)
|
||||
while split_index > 0 and non_system_messages[split_index].role != "user":
|
||||
# TODO: +=1 or -=1 ? calculate by tokens
|
||||
split_index -= 1
|
||||
|
||||
# If we couldn't find a user message, keep all messages as recent
|
||||
if split_index == 0:
|
||||
return system_messages, [], non_system_messages
|
||||
|
||||
messages_to_summarize = non_system_messages[:split_index]
|
||||
recent_messages = non_system_messages[split_index:]
|
||||
|
||||
return system_messages, messages_to_summarize, recent_messages
|
||||
|
||||
|
||||
class LLMSummaryCompressor:
|
||||
"""LLM-based summary compressor.
|
||||
Uses LLM to summarize the old conversation history, keeping the latest messages.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider: "Provider",
|
||||
keep_recent: int = 4,
|
||||
instruction_text: str | None = None,
|
||||
compression_threshold: float = 0.82,
|
||||
):
|
||||
"""Initialize the LLM summary compressor.
|
||||
|
||||
Args:
|
||||
provider: The LLM provider instance.
|
||||
keep_recent: The number of latest messages to keep (default: 4).
|
||||
instruction_text: Custom instruction for summary generation.
|
||||
compression_threshold: The compression trigger threshold (default: 0.82).
|
||||
"""
|
||||
self.provider = provider
|
||||
self.keep_recent = keep_recent
|
||||
self.compression_threshold = compression_threshold
|
||||
|
||||
self.instruction_text = instruction_text or (
|
||||
"Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n"
|
||||
"1. Systematically cover all core topics discussed and the final conclusion/outcome for each; clearly highlight the latest primary focus.\n"
|
||||
"2. If any tools were used, summarize tool usage (total call count) and extract the most valuable insights from tool outputs.\n"
|
||||
"3. If there was an initial user goal, state it first and describe the current progress/status.\n"
|
||||
"4. Write the summary in the user's language.\n"
|
||||
)
|
||||
|
||||
def should_compress(
|
||||
self, messages: list[Message], current_tokens: int, max_tokens: int
|
||||
) -> bool:
|
||||
"""Check if compression is needed.
|
||||
|
||||
Args:
|
||||
messages: The message list to evaluate.
|
||||
current_tokens: The current token count.
|
||||
max_tokens: The maximum allowed tokens.
|
||||
|
||||
Returns:
|
||||
True if compression is needed, False otherwise.
|
||||
"""
|
||||
if max_tokens <= 0 or current_tokens <= 0:
|
||||
return False
|
||||
usage_rate = current_tokens / max_tokens
|
||||
return usage_rate > self.compression_threshold
|
||||
|
||||
async def __call__(self, messages: list[Message]) -> list[Message]:
|
||||
"""Use LLM to generate a summary of the conversation history.
|
||||
|
||||
Process:
|
||||
1. Divide messages: keep the system message and the latest N messages.
|
||||
2. Send the old messages + the instruction message to the LLM.
|
||||
3. Reconstruct the message list: [system message, summary message, latest messages].
|
||||
"""
|
||||
if len(messages) <= self.keep_recent + 1:
|
||||
return messages
|
||||
|
||||
system_messages, messages_to_summarize, recent_messages = split_history(
|
||||
messages, self.keep_recent
|
||||
)
|
||||
|
||||
if not messages_to_summarize:
|
||||
return messages
|
||||
|
||||
# build payload
|
||||
instruction_message = Message(role="user", content=self.instruction_text)
|
||||
llm_payload = messages_to_summarize + [instruction_message]
|
||||
|
||||
# generate summary
|
||||
try:
|
||||
response = await self.provider.text_chat(contexts=llm_payload)
|
||||
summary_content = response.completion_text
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate summary: {e}")
|
||||
return messages
|
||||
|
||||
# build result
|
||||
result = []
|
||||
result.extend(system_messages)
|
||||
|
||||
result.append(
|
||||
Message(
|
||||
role="user",
|
||||
content=f"Our previous history conversation summary: {summary_content}",
|
||||
)
|
||||
)
|
||||
result.append(
|
||||
Message(
|
||||
role="assistant",
|
||||
content="Acknowledged the summary of our previous conversation history.",
|
||||
)
|
||||
)
|
||||
|
||||
result.extend(recent_messages)
|
||||
|
||||
return result
|
||||
@@ -0,0 +1,35 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .compressor import ContextCompressor
|
||||
from .token_counter import TokenCounter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.provider.provider import Provider
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextConfig:
|
||||
"""Context configuration class."""
|
||||
|
||||
max_context_tokens: int = 0
|
||||
"""Maximum number of context tokens. <= 0 means no limit."""
|
||||
enforce_max_turns: int = -1 # -1 means no limit
|
||||
"""Maximum number of conversation turns to keep. -1 means no limit. Executed before compression."""
|
||||
truncate_turns: int = 1
|
||||
"""Number of conversation turns to discard at once when truncation is triggered.
|
||||
Two processes will use this value:
|
||||
|
||||
1. Enforce max turns truncation.
|
||||
2. Truncation by turns compression strategy.
|
||||
"""
|
||||
llm_compress_instruction: str | None = None
|
||||
"""Instruction prompt for LLM-based compression."""
|
||||
llm_compress_keep_recent: int = 0
|
||||
"""Number of recent messages to keep during LLM-based compression."""
|
||||
llm_compress_provider: "Provider | None" = None
|
||||
"""LLM provider used for compression tasks. If None, truncation strategy is used."""
|
||||
custom_token_counter: TokenCounter | None = None
|
||||
"""Custom token counting method. If None, the default method is used."""
|
||||
custom_compressor: ContextCompressor | None = None
|
||||
"""Custom context compression method. If None, the default method is used."""
|
||||
@@ -0,0 +1,120 @@
|
||||
from astrbot import logger
|
||||
|
||||
from ..message import Message
|
||||
from .compressor import LLMSummaryCompressor, TruncateByTurnsCompressor
|
||||
from .config import ContextConfig
|
||||
from .token_counter import EstimateTokenCounter
|
||||
from .truncator import ContextTruncator
|
||||
|
||||
|
||||
class ContextManager:
|
||||
"""Context compression manager."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ContextConfig,
|
||||
):
|
||||
"""Initialize the context manager.
|
||||
|
||||
There are two strategies to handle context limit reached:
|
||||
1. Truncate by turns: remove older messages by turns.
|
||||
2. LLM-based compression: use LLM to summarize old messages.
|
||||
|
||||
Args:
|
||||
config: The context configuration.
|
||||
"""
|
||||
self.config = config
|
||||
|
||||
self.token_counter = config.custom_token_counter or EstimateTokenCounter()
|
||||
self.truncator = ContextTruncator()
|
||||
|
||||
if config.custom_compressor:
|
||||
self.compressor = config.custom_compressor
|
||||
elif config.llm_compress_provider:
|
||||
self.compressor = LLMSummaryCompressor(
|
||||
provider=config.llm_compress_provider,
|
||||
keep_recent=config.llm_compress_keep_recent,
|
||||
instruction_text=config.llm_compress_instruction,
|
||||
)
|
||||
else:
|
||||
self.compressor = TruncateByTurnsCompressor(
|
||||
truncate_turns=config.truncate_turns
|
||||
)
|
||||
|
||||
async def process(
|
||||
self, messages: list[Message], trusted_token_usage: int = 0
|
||||
) -> list[Message]:
|
||||
"""Process the messages.
|
||||
|
||||
Args:
|
||||
messages: The original message list.
|
||||
|
||||
Returns:
|
||||
The processed message list.
|
||||
"""
|
||||
try:
|
||||
result = messages
|
||||
|
||||
# 1. 基于轮次的截断 (Enforce max turns)
|
||||
if self.config.enforce_max_turns != -1:
|
||||
result = self.truncator.truncate_by_turns(
|
||||
result,
|
||||
keep_most_recent_turns=self.config.enforce_max_turns,
|
||||
drop_turns=self.config.truncate_turns,
|
||||
)
|
||||
|
||||
# 2. 基于 token 的压缩
|
||||
if self.config.max_context_tokens > 0:
|
||||
total_tokens = self.token_counter.count_tokens(
|
||||
result, trusted_token_usage
|
||||
)
|
||||
|
||||
if self.compressor.should_compress(
|
||||
result, total_tokens, self.config.max_context_tokens
|
||||
):
|
||||
result = await self._run_compression(result, total_tokens)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error during context processing: {e}", exc_info=True)
|
||||
return messages
|
||||
|
||||
async def _run_compression(
|
||||
self, messages: list[Message], prev_tokens: int
|
||||
) -> list[Message]:
|
||||
"""
|
||||
Compress/truncate the messages.
|
||||
|
||||
Args:
|
||||
messages: The original message list.
|
||||
prev_tokens: The token count before compression.
|
||||
|
||||
Returns:
|
||||
The compressed/truncated message list.
|
||||
"""
|
||||
logger.debug("Compress triggered, starting compression...")
|
||||
|
||||
messages = await self.compressor(messages)
|
||||
|
||||
# double check
|
||||
tokens_after_summary = self.token_counter.count_tokens(messages)
|
||||
|
||||
# calculate compress rate
|
||||
compress_rate = (tokens_after_summary / self.config.max_context_tokens) * 100
|
||||
logger.info(
|
||||
f"Compress completed."
|
||||
f" {prev_tokens} -> {tokens_after_summary} tokens,"
|
||||
f" compression rate: {compress_rate:.2f}%.",
|
||||
)
|
||||
|
||||
# last check
|
||||
if self.compressor.should_compress(
|
||||
messages, tokens_after_summary, self.config.max_context_tokens
|
||||
):
|
||||
logger.info(
|
||||
"Context still exceeds max tokens after compression, applying halving truncation..."
|
||||
)
|
||||
# still need compress, truncate by half
|
||||
messages = self.truncator.truncate_by_halving(messages)
|
||||
|
||||
return messages
|
||||
@@ -0,0 +1,64 @@
|
||||
import json
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from ..message import Message, TextPart
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class TokenCounter(Protocol):
|
||||
"""
|
||||
Protocol for token counters.
|
||||
Provides an interface for counting tokens in message lists.
|
||||
"""
|
||||
|
||||
def count_tokens(
|
||||
self, messages: list[Message], trusted_token_usage: int = 0
|
||||
) -> int:
|
||||
"""Count the total tokens in the message list.
|
||||
|
||||
Args:
|
||||
messages: The message list.
|
||||
trusted_token_usage: The total token usage that LLM API returned.
|
||||
For some cases, this value is more accurate.
|
||||
But some API does not return it, so the value defaults to 0.
|
||||
|
||||
Returns:
|
||||
The total token count.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class EstimateTokenCounter:
|
||||
"""Estimate token counter implementation.
|
||||
Provides a simple estimation of token count based on character types.
|
||||
"""
|
||||
|
||||
def count_tokens(
|
||||
self, messages: list[Message], trusted_token_usage: int = 0
|
||||
) -> int:
|
||||
if trusted_token_usage > 0:
|
||||
return trusted_token_usage
|
||||
|
||||
total = 0
|
||||
for msg in messages:
|
||||
content = msg.content
|
||||
if isinstance(content, str):
|
||||
total += self._estimate_tokens(content)
|
||||
elif isinstance(content, list):
|
||||
# 处理多模态内容
|
||||
for part in content:
|
||||
if isinstance(part, TextPart):
|
||||
total += self._estimate_tokens(part.text)
|
||||
|
||||
# 处理 Tool Calls
|
||||
if msg.tool_calls:
|
||||
for tc in msg.tool_calls:
|
||||
tc_str = json.dumps(tc if isinstance(tc, dict) else tc.model_dump())
|
||||
total += self._estimate_tokens(tc_str)
|
||||
|
||||
return total
|
||||
|
||||
def _estimate_tokens(self, text: str) -> int:
|
||||
chinese_count = len([c for c in text if "\u4e00" <= c <= "\u9fff"])
|
||||
other_count = len(text) - chinese_count
|
||||
return int(chinese_count * 0.6 + other_count * 0.3)
|
||||
@@ -0,0 +1,141 @@
|
||||
from ..message import Message
|
||||
|
||||
|
||||
class ContextTruncator:
|
||||
"""Context truncator."""
|
||||
|
||||
def fix_messages(self, messages: list[Message]) -> list[Message]:
|
||||
fixed_messages = []
|
||||
for message in messages:
|
||||
if message.role == "tool":
|
||||
# tool block 前面必须要有 user 和 assistant block
|
||||
if len(fixed_messages) < 2:
|
||||
# 这种情况可能是上下文被截断导致的
|
||||
# 我们直接将之前的上下文都清空
|
||||
fixed_messages = []
|
||||
else:
|
||||
fixed_messages.append(message)
|
||||
else:
|
||||
fixed_messages.append(message)
|
||||
return fixed_messages
|
||||
|
||||
def truncate_by_turns(
|
||||
self,
|
||||
messages: list[Message],
|
||||
keep_most_recent_turns: int,
|
||||
drop_turns: int = 1,
|
||||
) -> list[Message]:
|
||||
"""截断上下文列表,确保不超过最大长度。
|
||||
一个 turn 包含一个 user 消息和一个 assistant 消息。
|
||||
这个方法会保证截断后的上下文列表符合 OpenAI 的上下文格式。
|
||||
|
||||
Args:
|
||||
messages: 上下文列表
|
||||
keep_most_recent_turns: 保留最近的对话轮数
|
||||
drop_turns: 一次性丢弃的对话轮数
|
||||
|
||||
Returns:
|
||||
截断后的上下文列表
|
||||
"""
|
||||
if keep_most_recent_turns == -1:
|
||||
return messages
|
||||
|
||||
first_non_system = 0
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.role != "system":
|
||||
first_non_system = i
|
||||
break
|
||||
|
||||
system_messages = messages[:first_non_system]
|
||||
non_system_messages = messages[first_non_system:]
|
||||
|
||||
if len(non_system_messages) // 2 <= keep_most_recent_turns:
|
||||
return messages
|
||||
|
||||
num_to_keep = keep_most_recent_turns - drop_turns + 1
|
||||
if num_to_keep <= 0:
|
||||
truncated_contexts = []
|
||||
else:
|
||||
truncated_contexts = non_system_messages[-num_to_keep * 2 :]
|
||||
|
||||
# 找到第一个 role 为 user 的索引,确保上下文格式正确
|
||||
index = next(
|
||||
(i for i, item in enumerate(truncated_contexts) if item.role == "user"),
|
||||
None,
|
||||
)
|
||||
if index is not None and index > 0:
|
||||
truncated_contexts = truncated_contexts[index:]
|
||||
|
||||
result = system_messages + truncated_contexts
|
||||
|
||||
return self.fix_messages(result)
|
||||
|
||||
def truncate_by_dropping_oldest_turns(
|
||||
self,
|
||||
messages: list[Message],
|
||||
drop_turns: int = 1,
|
||||
) -> list[Message]:
|
||||
"""丢弃最旧的 N 个对话轮次。"""
|
||||
if drop_turns <= 0:
|
||||
return messages
|
||||
|
||||
first_non_system = 0
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.role != "system":
|
||||
first_non_system = i
|
||||
break
|
||||
|
||||
system_messages = messages[:first_non_system]
|
||||
non_system_messages = messages[first_non_system:]
|
||||
|
||||
if len(non_system_messages) // 2 <= drop_turns:
|
||||
truncated_non_system = []
|
||||
else:
|
||||
truncated_non_system = non_system_messages[drop_turns * 2 :]
|
||||
|
||||
index = next(
|
||||
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
|
||||
None,
|
||||
)
|
||||
if index is not None:
|
||||
truncated_non_system = truncated_non_system[index:]
|
||||
elif truncated_non_system:
|
||||
truncated_non_system = []
|
||||
|
||||
result = system_messages + truncated_non_system
|
||||
|
||||
return self.fix_messages(result)
|
||||
|
||||
def truncate_by_halving(
|
||||
self,
|
||||
messages: list[Message],
|
||||
) -> list[Message]:
|
||||
"""对半砍策略,删除 50% 的消息"""
|
||||
if len(messages) <= 2:
|
||||
return messages
|
||||
|
||||
first_non_system = 0
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.role != "system":
|
||||
first_non_system = i
|
||||
break
|
||||
|
||||
system_messages = messages[:first_non_system]
|
||||
non_system_messages = messages[first_non_system:]
|
||||
|
||||
messages_to_delete = len(non_system_messages) // 2
|
||||
if messages_to_delete == 0:
|
||||
return messages
|
||||
|
||||
truncated_non_system = non_system_messages[messages_to_delete:]
|
||||
|
||||
index = next(
|
||||
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
|
||||
None,
|
||||
)
|
||||
if index is not None:
|
||||
truncated_non_system = truncated_non_system[index:]
|
||||
|
||||
result = system_messages + truncated_non_system
|
||||
|
||||
return self.fix_messages(result)
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from typing import Any, ClassVar, Literal, cast
|
||||
|
||||
from pydantic import BaseModel, GetCoreSchemaHandler, model_validator
|
||||
from pydantic import BaseModel, GetCoreSchemaHandler, model_serializer, model_validator
|
||||
from pydantic_core import core_schema
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ class ContentPart(BaseModel):
|
||||
|
||||
__content_part_registry: ClassVar[dict[str, type["ContentPart"]]] = {}
|
||||
|
||||
type: str
|
||||
type: Literal["text", "think", "image_url", "audio_url"]
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
@@ -63,6 +63,28 @@ class TextPart(ContentPart):
|
||||
text: str
|
||||
|
||||
|
||||
class ThinkPart(ContentPart):
|
||||
"""
|
||||
>>> ThinkPart(think="I think I need to think about this.").model_dump()
|
||||
{'type': 'think', 'think': 'I think I need to think about this.', 'encrypted': None}
|
||||
"""
|
||||
|
||||
type: str = "think"
|
||||
think: str
|
||||
encrypted: str | None = None
|
||||
"""Encrypted thinking content, or signature."""
|
||||
|
||||
def merge_in_place(self, other: Any) -> bool:
|
||||
if not isinstance(other, ThinkPart):
|
||||
return False
|
||||
if self.encrypted:
|
||||
return False
|
||||
self.think += other.think
|
||||
if other.encrypted:
|
||||
self.encrypted = other.encrypted
|
||||
return True
|
||||
|
||||
|
||||
class ImageURLPart(ContentPart):
|
||||
"""
|
||||
>>> ImageURLPart(image_url="http://example.com/image.jpg").model_dump()
|
||||
@@ -122,10 +144,12 @@ class ToolCall(BaseModel):
|
||||
extra_content: dict[str, Any] | None = None
|
||||
"""Extra metadata for the tool call."""
|
||||
|
||||
def model_dump(self, **kwargs: Any) -> dict[str, Any]:
|
||||
@model_serializer(mode="wrap")
|
||||
def serialize(self, handler):
|
||||
data = handler(self)
|
||||
if self.extra_content is None:
|
||||
kwargs.setdefault("exclude", set()).add("extra_content")
|
||||
return super().model_dump(**kwargs)
|
||||
data.pop("extra_content", None)
|
||||
return data
|
||||
|
||||
|
||||
class ToolCallPart(BaseModel):
|
||||
@@ -167,6 +191,15 @@ class Message(BaseModel):
|
||||
)
|
||||
return self
|
||||
|
||||
@model_serializer(mode="wrap")
|
||||
def serialize(self, handler):
|
||||
data = handler(self)
|
||||
if self.tool_calls is None:
|
||||
data.pop("tool_calls", None)
|
||||
if self.tool_call_id is None:
|
||||
data.pop("tool_call_id", None)
|
||||
return data
|
||||
|
||||
|
||||
class AssistantMessageSegment(Message):
|
||||
"""A message segment from the assistant."""
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import typing as T
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import TokenUsage
|
||||
|
||||
|
||||
class AgentResponseData(T.TypedDict):
|
||||
@@ -12,3 +13,23 @@ class AgentResponseData(T.TypedDict):
|
||||
class AgentResponse:
|
||||
type: str
|
||||
data: AgentResponseData
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentStats:
|
||||
token_usage: TokenUsage = field(default_factory=TokenUsage)
|
||||
start_time: float = 0.0
|
||||
end_time: float = 0.0
|
||||
time_to_first_token: float = 0.0
|
||||
|
||||
@property
|
||||
def duration(self) -> float:
|
||||
return self.end_time - self.start_time
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"token_usage": self.token_usage.__dict__,
|
||||
"start_time": self.start_time,
|
||||
"end_time": self.end_time,
|
||||
"time_to_first_token": self.time_to_first_token,
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ from .message import Message
|
||||
TContext = TypeVar("TContext", default=Any)
|
||||
|
||||
|
||||
@dataclass(config={"arbitrary_types_allowed": True})
|
||||
@dataclass
|
||||
class ContextWrapper(Generic[TContext]):
|
||||
"""A context for running an agent, which can be used to pass additional data or state."""
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import typing as T
|
||||
|
||||
@@ -12,6 +13,8 @@ from mcp.types import (
|
||||
)
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.agent.message import TextPart, ThinkPart
|
||||
from astrbot.core.message.components import Json
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
)
|
||||
@@ -22,9 +25,13 @@ from astrbot.core.provider.entities import (
|
||||
)
|
||||
from astrbot.core.provider.provider import Provider
|
||||
|
||||
from ..context.compressor import ContextCompressor
|
||||
from ..context.config import ContextConfig
|
||||
from ..context.manager import ContextManager
|
||||
from ..context.token_counter import TokenCounter
|
||||
from ..hooks import BaseAgentRunHooks
|
||||
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
|
||||
from ..response import AgentResponseData
|
||||
from ..response import AgentResponseData, AgentStats
|
||||
from ..run_context import ContextWrapper, TContext
|
||||
from ..tool_executor import BaseFunctionToolExecutor
|
||||
from .base import AgentResponse, AgentState, BaseAgentRunner
|
||||
@@ -44,10 +51,47 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
run_context: ContextWrapper[TContext],
|
||||
tool_executor: BaseFunctionToolExecutor[TContext],
|
||||
agent_hooks: BaseAgentRunHooks[TContext],
|
||||
streaming: bool = False,
|
||||
# enforce max turns, will discard older turns when exceeded BEFORE compression
|
||||
# -1 means no limit
|
||||
enforce_max_turns: int = -1,
|
||||
# llm compressor
|
||||
llm_compress_instruction: str | None = None,
|
||||
llm_compress_keep_recent: int = 0,
|
||||
llm_compress_provider: Provider | None = None,
|
||||
# truncate by turns compressor
|
||||
truncate_turns: int = 1,
|
||||
# customize
|
||||
custom_token_counter: TokenCounter | None = None,
|
||||
custom_compressor: ContextCompressor | None = None,
|
||||
**kwargs: T.Any,
|
||||
) -> None:
|
||||
self.req = request
|
||||
self.streaming = kwargs.get("streaming", False)
|
||||
self.streaming = streaming
|
||||
self.enforce_max_turns = enforce_max_turns
|
||||
self.llm_compress_instruction = llm_compress_instruction
|
||||
self.llm_compress_keep_recent = llm_compress_keep_recent
|
||||
self.llm_compress_provider = llm_compress_provider
|
||||
self.truncate_turns = truncate_turns
|
||||
self.custom_token_counter = custom_token_counter
|
||||
self.custom_compressor = custom_compressor
|
||||
# we will do compress when:
|
||||
# 1. before requesting LLM
|
||||
# TODO: 2. after LLM output a tool call
|
||||
self.context_config = ContextConfig(
|
||||
# <=0 will never do compress
|
||||
max_context_tokens=provider.provider_config.get("max_context_tokens", 0),
|
||||
# enforce max turns before compression
|
||||
enforce_max_turns=self.enforce_max_turns,
|
||||
truncate_turns=self.truncate_turns,
|
||||
llm_compress_instruction=self.llm_compress_instruction,
|
||||
llm_compress_keep_recent=self.llm_compress_keep_recent,
|
||||
llm_compress_provider=self.llm_compress_provider,
|
||||
custom_token_counter=self.custom_token_counter,
|
||||
custom_compressor=self.custom_compressor,
|
||||
)
|
||||
self.context_manager = ContextManager(self.context_config)
|
||||
|
||||
self.provider = provider
|
||||
self.final_llm_resp = None
|
||||
self._state = AgentState.IDLE
|
||||
@@ -69,14 +113,25 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
)
|
||||
self.run_context.messages = messages
|
||||
|
||||
self.stats = AgentStats()
|
||||
self.stats.start_time = time.time()
|
||||
|
||||
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
|
||||
"""Yields chunks *and* a final LLMResponse."""
|
||||
payload = {
|
||||
"contexts": self.run_context.messages, # list[Message]
|
||||
"func_tool": self.req.func_tool,
|
||||
"model": self.req.model, # NOTE: in fact, this arg is None in most cases
|
||||
"session_id": self.req.session_id,
|
||||
"extra_user_content_parts": self.req.extra_user_content_parts, # list[ContentPart]
|
||||
}
|
||||
|
||||
if self.streaming:
|
||||
stream = self.provider.text_chat_stream(**self.req.__dict__)
|
||||
stream = self.provider.text_chat_stream(**payload)
|
||||
async for resp in stream: # type: ignore
|
||||
yield resp
|
||||
else:
|
||||
yield await self.provider.text_chat(**self.req.__dict__)
|
||||
yield await self.provider.text_chat(**payload)
|
||||
|
||||
@override
|
||||
async def step(self):
|
||||
@@ -96,9 +151,18 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
self._transition_state(AgentState.RUNNING)
|
||||
llm_resp_result = None
|
||||
|
||||
# do truncate and compress
|
||||
token_usage = self.req.conversation.token_usage if self.req.conversation else 0
|
||||
self.run_context.messages = await self.context_manager.process(
|
||||
self.run_context.messages, trusted_token_usage=token_usage
|
||||
)
|
||||
|
||||
async for llm_response in self._iter_llm_responses():
|
||||
assert isinstance(llm_response, LLMResponse)
|
||||
if llm_response.is_chunk:
|
||||
# update ttft
|
||||
if self.stats.time_to_first_token == 0:
|
||||
self.stats.time_to_first_token = time.time() - self.stats.start_time
|
||||
|
||||
if llm_response.result_chain:
|
||||
yield AgentResponse(
|
||||
type="streaming_delta",
|
||||
@@ -122,6 +186,10 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
)
|
||||
continue
|
||||
llm_resp_result = llm_response
|
||||
|
||||
if not llm_response.is_chunk and llm_response.usage:
|
||||
# only count the token usage of the final response for computation purpose
|
||||
self.stats.token_usage += llm_response.usage
|
||||
break # got final response
|
||||
|
||||
if not llm_resp_result:
|
||||
@@ -133,6 +201,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
if llm_resp.role == "err":
|
||||
# 如果 LLM 响应错误,转换到错误状态
|
||||
self.final_llm_resp = llm_resp
|
||||
self.stats.end_time = time.time()
|
||||
self._transition_state(AgentState.ERROR)
|
||||
yield AgentResponse(
|
||||
type="err",
|
||||
@@ -147,13 +216,21 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
# 如果没有工具调用,转换到完成状态
|
||||
self.final_llm_resp = llm_resp
|
||||
self._transition_state(AgentState.DONE)
|
||||
self.stats.end_time = time.time()
|
||||
|
||||
# record the final assistant message
|
||||
self.run_context.messages.append(
|
||||
Message(
|
||||
role="assistant",
|
||||
content=llm_resp.completion_text or "",
|
||||
),
|
||||
)
|
||||
parts = []
|
||||
if llm_resp.reasoning_content or llm_resp.reasoning_signature:
|
||||
parts.append(
|
||||
ThinkPart(
|
||||
think=llm_resp.reasoning_content,
|
||||
encrypted=llm_resp.reasoning_signature,
|
||||
)
|
||||
)
|
||||
parts.append(TextPart(text=llm_resp.completion_text or "*No response*"))
|
||||
self.run_context.messages.append(Message(role="assistant", content=parts))
|
||||
|
||||
# call the on_agent_done hook
|
||||
try:
|
||||
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
|
||||
except Exception as e:
|
||||
@@ -176,29 +253,35 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
# 如果有工具调用,还需处理工具调用
|
||||
if llm_resp.tools_call_name:
|
||||
tool_call_result_blocks = []
|
||||
for tool_call_name in llm_resp.tools_call_name:
|
||||
yield AgentResponse(
|
||||
type="tool_call",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain(type="tool_call").message(
|
||||
f"🔨 调用工具: {tool_call_name}"
|
||||
),
|
||||
),
|
||||
)
|
||||
async for result in self._handle_function_tools(self.req, llm_resp):
|
||||
if isinstance(result, list):
|
||||
tool_call_result_blocks = result
|
||||
elif isinstance(result, MessageChain):
|
||||
result.type = "tool_call_result"
|
||||
if result.type is None:
|
||||
# should not happen
|
||||
continue
|
||||
if result.type == "tool_direct_result":
|
||||
ar_type = "tool_call_result"
|
||||
else:
|
||||
ar_type = result.type
|
||||
yield AgentResponse(
|
||||
type="tool_call_result",
|
||||
type=ar_type,
|
||||
data=AgentResponseData(chain=result),
|
||||
)
|
||||
# 将结果添加到上下文中
|
||||
parts = []
|
||||
if llm_resp.reasoning_content or llm_resp.reasoning_signature:
|
||||
parts.append(
|
||||
ThinkPart(
|
||||
think=llm_resp.reasoning_content,
|
||||
encrypted=llm_resp.reasoning_signature,
|
||||
)
|
||||
)
|
||||
parts.append(TextPart(text=llm_resp.completion_text or "*No response*"))
|
||||
tool_calls_result = ToolCallsResult(
|
||||
tool_calls_info=AssistantMessageSegment(
|
||||
tool_calls=llm_resp.to_openai_to_calls_model(),
|
||||
content=llm_resp.completion_text,
|
||||
content=parts,
|
||||
),
|
||||
tool_calls_result=tool_call_result_blocks,
|
||||
)
|
||||
@@ -219,6 +302,25 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
async for resp in self.step():
|
||||
yield resp
|
||||
|
||||
# 如果循环结束了但是 agent 还没有完成,说明是达到了 max_step
|
||||
if not self.done():
|
||||
logger.warning(
|
||||
f"Agent reached max steps ({max_step}), forcing a final response."
|
||||
)
|
||||
# 拔掉所有工具
|
||||
if self.req:
|
||||
self.req.func_tool = None
|
||||
# 注入提示词
|
||||
self.run_context.messages.append(
|
||||
Message(
|
||||
role="user",
|
||||
content="工具调用次数已达到上限,请停止使用工具,并根据已经收集到的信息,对你的任务和发现进行总结,然后直接回复用户。",
|
||||
)
|
||||
)
|
||||
# 再执行最后一步
|
||||
async for resp in self.step():
|
||||
yield resp
|
||||
|
||||
async def _handle_function_tools(
|
||||
self,
|
||||
req: ProviderRequest,
|
||||
@@ -234,6 +336,19 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
llm_response.tools_call_args,
|
||||
llm_response.tools_call_ids,
|
||||
):
|
||||
yield MessageChain(
|
||||
type="tool_call",
|
||||
chain=[
|
||||
Json(
|
||||
data={
|
||||
"id": func_tool_id,
|
||||
"name": func_tool_name,
|
||||
"args": func_tool_args,
|
||||
"ts": time.time(),
|
||||
}
|
||||
)
|
||||
],
|
||||
)
|
||||
try:
|
||||
if not req.func_tool:
|
||||
return
|
||||
@@ -307,7 +422,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
content=res.content[0].text,
|
||||
),
|
||||
)
|
||||
yield MessageChain().message(res.content[0].text)
|
||||
elif isinstance(res.content[0], ImageContent):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
@@ -329,7 +443,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
content=resource.text,
|
||||
),
|
||||
)
|
||||
yield MessageChain().message(resource.text)
|
||||
elif (
|
||||
isinstance(resource, BlobResourceContents)
|
||||
and resource.mimeType
|
||||
@@ -353,20 +466,34 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
content="返回的数据类型不受支持",
|
||||
),
|
||||
)
|
||||
yield MessageChain().message("返回的数据类型不受支持。")
|
||||
|
||||
elif resp is None:
|
||||
# Tool 直接请求发送消息给用户
|
||||
# 这里我们将直接结束 Agent Loop。
|
||||
# 发送消息逻辑在 ToolExecutor 中处理了。
|
||||
logger.warning(
|
||||
f"{func_tool_name} 没有没有返回值或者将结果直接发送给用户,此工具调用不会被记录到历史中。"
|
||||
f"{func_tool_name} 没有没有返回值或者将结果直接发送给用户。"
|
||||
)
|
||||
self._transition_state(AgentState.DONE)
|
||||
self.stats.end_time = time.time()
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="*工具没有返回值或者将结果直接发送给了用户*",
|
||||
),
|
||||
)
|
||||
else:
|
||||
# 不应该出现其他类型
|
||||
logger.warning(
|
||||
f"Tool 返回了不支持的类型: {type(resp)},将忽略。",
|
||||
f"Tool 返回了不支持的类型: {type(resp)}。",
|
||||
)
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="*工具返回了不支持的类型,请告诉用户检查这个工具的定义和实现。*",
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -388,6 +515,22 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
),
|
||||
)
|
||||
|
||||
# yield the last tool call result
|
||||
if tool_call_result_blocks:
|
||||
last_tcr_content = str(tool_call_result_blocks[-1].content)
|
||||
yield MessageChain(
|
||||
type="tool_call_result",
|
||||
chain=[
|
||||
Json(
|
||||
data={
|
||||
"id": func_tool_id,
|
||||
"ts": time.time(),
|
||||
"result": last_tcr_content,
|
||||
}
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# 处理函数调用响应
|
||||
if tool_call_result_blocks:
|
||||
yield tool_call_result_blocks
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from collections.abc import Awaitable, Callable
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from typing import Any, Generic
|
||||
|
||||
import jsonschema
|
||||
@@ -7,6 +7,8 @@ from deprecated import deprecated
|
||||
from pydantic import Field, model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from astrbot.core.message.message_event_result import MessageEventResult
|
||||
|
||||
from .run_context import ContextWrapper, TContext
|
||||
|
||||
ParametersType = dict[str, Any]
|
||||
@@ -38,7 +40,10 @@ class ToolSchema:
|
||||
class FunctionTool(ToolSchema, Generic[TContext]):
|
||||
"""A callable tool, for function calling."""
|
||||
|
||||
handler: Callable[..., Awaitable[Any]] | None = None
|
||||
handler: (
|
||||
Callable[..., Awaitable[str | None] | AsyncGenerator[MessageEventResult, None]]
|
||||
| None
|
||||
) = None
|
||||
"""a callable that implements the tool's functionality. It should be an async function."""
|
||||
|
||||
handler_module_path: str | None = None
|
||||
|
||||
@@ -6,8 +6,10 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.context import Context
|
||||
|
||||
|
||||
@dataclass(config={"arbitrary_types_allowed": True})
|
||||
@dataclass
|
||||
class AstrAgentContext:
|
||||
__pydantic_config__ = {"arbitrary_types_allowed": True}
|
||||
|
||||
context: Context
|
||||
"""The star context instance"""
|
||||
event: AstrMessageEvent
|
||||
|
||||
@@ -13,6 +13,12 @@ from astrbot.core.star.star_handler import EventType
|
||||
class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
|
||||
async def on_agent_done(self, run_context, llm_response):
|
||||
# 执行事件钩子
|
||||
if llm_response and llm_response.reasoning_content:
|
||||
# we will use this in result_decorate stage to inject reasoning content to chain
|
||||
run_context.context.event.set_extra(
|
||||
"_llm_reasoning_content", llm_response.reasoning_content
|
||||
)
|
||||
|
||||
await call_event_hook(
|
||||
run_context.context.event,
|
||||
EventType.OnLLMResponseEvent,
|
||||
|
||||
@@ -2,8 +2,10 @@ import traceback
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.message.components import Json
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
MessageEventResult,
|
||||
@@ -23,8 +25,25 @@ async def run_agent(
|
||||
) -> AsyncGenerator[MessageChain | None, None]:
|
||||
step_idx = 0
|
||||
astr_event = agent_runner.run_context.context.event
|
||||
while step_idx < max_step:
|
||||
while step_idx < max_step + 1:
|
||||
step_idx += 1
|
||||
|
||||
if step_idx == max_step + 1:
|
||||
logger.warning(
|
||||
f"Agent reached max steps ({max_step}), forcing a final response."
|
||||
)
|
||||
if not agent_runner.done():
|
||||
# 拔掉所有工具
|
||||
if agent_runner.req:
|
||||
agent_runner.req.func_tool = None
|
||||
# 注入提示词
|
||||
agent_runner.run_context.messages.append(
|
||||
Message(
|
||||
role="user",
|
||||
content="工具调用次数已达到上限,请停止使用工具,并根据已经收集到的信息,对你的任务和发现进行总结,然后直接回复用户。",
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
async for resp in agent_runner.step():
|
||||
if astr_event.is_stopped():
|
||||
@@ -33,16 +52,27 @@ async def run_agent(
|
||||
msg_chain = resp.data["chain"]
|
||||
if msg_chain.type == "tool_direct_result":
|
||||
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
|
||||
await astr_event.send(resp.data["chain"])
|
||||
await astr_event.send(msg_chain)
|
||||
continue
|
||||
if astr_event.get_platform_id() == "webchat":
|
||||
await astr_event.send(msg_chain)
|
||||
# 对于其他情况,暂时先不处理
|
||||
continue
|
||||
elif resp.type == "tool_call":
|
||||
if agent_runner.streaming:
|
||||
# 用来标记流式响应需要分节
|
||||
yield MessageChain(chain=[], type="break")
|
||||
if show_tool_use:
|
||||
|
||||
if astr_event.get_platform_name() == "webchat":
|
||||
await astr_event.send(resp.data["chain"])
|
||||
elif show_tool_use:
|
||||
json_comp = resp.data["chain"].chain[0]
|
||||
if isinstance(json_comp, Json):
|
||||
m = f"🔨 调用工具: {json_comp.data.get('name')}"
|
||||
else:
|
||||
m = "🔨 调用工具..."
|
||||
chain = MessageChain(type="tool_call").message(m)
|
||||
await astr_event.send(chain)
|
||||
continue
|
||||
|
||||
if stream_to_general and resp.type == "streaming_delta":
|
||||
@@ -69,6 +99,15 @@ async def run_agent(
|
||||
continue
|
||||
yield resp.data["chain"] # MessageChain
|
||||
if agent_runner.done():
|
||||
# send agent stats to webchat
|
||||
if astr_event.get_platform_name() == "webchat":
|
||||
await astr_event.send(
|
||||
MessageChain(
|
||||
type="agent_stats",
|
||||
chain=[Json(data=agent_runner.stats.to_dict())],
|
||||
)
|
||||
)
|
||||
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -185,7 +185,11 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
|
||||
async def call_local_llm_tool(
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
handler: T.Callable[..., T.Awaitable[T.Any]],
|
||||
handler: T.Callable[
|
||||
...,
|
||||
T.Awaitable[MessageEventResult | mcp.types.CallToolResult | str | None]
|
||||
| T.AsyncGenerator[MessageEventResult | CommandResult | str | None, None],
|
||||
],
|
||||
method_name: str,
|
||||
*args,
|
||||
**kwargs,
|
||||
@@ -205,12 +209,42 @@ async def call_local_llm_tool(
|
||||
else:
|
||||
raise ValueError(f"未知的方法名: {method_name}")
|
||||
except ValueError as e:
|
||||
logger.error(f"调用本地 LLM 工具时出错: {e}", exc_info=True)
|
||||
except TypeError:
|
||||
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
|
||||
raise Exception(f"Tool execution ValueError: {e}") from e
|
||||
except TypeError as e:
|
||||
# 获取函数的签名(包括类型),除了第一个 event/context 参数。
|
||||
try:
|
||||
sig = inspect.signature(handler)
|
||||
params = list(sig.parameters.values())
|
||||
# 跳过第一个参数(event 或 context)
|
||||
if params:
|
||||
params = params[1:]
|
||||
|
||||
param_strs = []
|
||||
for param in params:
|
||||
param_str = param.name
|
||||
if param.annotation != inspect.Parameter.empty:
|
||||
# 获取类型注解的字符串表示
|
||||
if isinstance(param.annotation, type):
|
||||
type_str = param.annotation.__name__
|
||||
else:
|
||||
type_str = str(param.annotation)
|
||||
param_str += f": {type_str}"
|
||||
if param.default != inspect.Parameter.empty:
|
||||
param_str += f" = {param.default!r}"
|
||||
param_strs.append(param_str)
|
||||
|
||||
handler_param_str = (
|
||||
", ".join(param_strs) if param_strs else "(no additional parameters)"
|
||||
)
|
||||
except Exception:
|
||||
handler_param_str = "(unable to inspect signature)"
|
||||
|
||||
raise Exception(
|
||||
f"Tool handler parameter mismatch, please check the handler definition. Handler parameters: {handler_param_str}"
|
||||
) from e
|
||||
except Exception as e:
|
||||
trace_ = traceback.format_exc()
|
||||
logger.error(f"调用本地 LLM 工具时出错: {e}\n{trace_}")
|
||||
raise Exception(f"Tool execution error: {e}. Traceback: {trace_}") from e
|
||||
|
||||
if not ready_to_call:
|
||||
return
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
"""AstrBot 备份与恢复模块
|
||||
|
||||
提供数据导出和导入功能,支持用户在服务器迁移时一键备份和恢复所有数据。
|
||||
"""
|
||||
|
||||
# 从 constants 模块导入共享常量
|
||||
from .constants import (
|
||||
BACKUP_MANIFEST_VERSION,
|
||||
KB_METADATA_MODELS,
|
||||
MAIN_DB_MODELS,
|
||||
get_backup_directories,
|
||||
)
|
||||
|
||||
# 导入导出器和导入器
|
||||
from .exporter import AstrBotExporter
|
||||
from .importer import AstrBotImporter, ImportPreCheckResult
|
||||
|
||||
__all__ = [
|
||||
"AstrBotExporter",
|
||||
"AstrBotImporter",
|
||||
"ImportPreCheckResult",
|
||||
"MAIN_DB_MODELS",
|
||||
"KB_METADATA_MODELS",
|
||||
"get_backup_directories",
|
||||
"BACKUP_MANIFEST_VERSION",
|
||||
]
|
||||
@@ -0,0 +1,77 @@
|
||||
"""AstrBot 备份模块共享常量
|
||||
|
||||
此文件定义了导出器和导入器共享的常量,确保两端配置一致。
|
||||
"""
|
||||
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from astrbot.core.db.po import (
|
||||
Attachment,
|
||||
CommandConfig,
|
||||
CommandConflict,
|
||||
ConversationV2,
|
||||
Persona,
|
||||
PlatformMessageHistory,
|
||||
PlatformSession,
|
||||
PlatformStat,
|
||||
Preference,
|
||||
)
|
||||
from astrbot.core.knowledge_base.models import (
|
||||
KBDocument,
|
||||
KBMedia,
|
||||
KnowledgeBase,
|
||||
)
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_config_path,
|
||||
get_astrbot_plugin_data_path,
|
||||
get_astrbot_plugin_path,
|
||||
get_astrbot_t2i_templates_path,
|
||||
get_astrbot_temp_path,
|
||||
get_astrbot_webchat_path,
|
||||
)
|
||||
|
||||
# ============================================================
|
||||
# 共享常量 - 确保导出和导入端配置一致
|
||||
# ============================================================
|
||||
|
||||
# 主数据库模型类映射
|
||||
MAIN_DB_MODELS: dict[str, type[SQLModel]] = {
|
||||
"platform_stats": PlatformStat,
|
||||
"conversations": ConversationV2,
|
||||
"personas": Persona,
|
||||
"preferences": Preference,
|
||||
"platform_message_history": PlatformMessageHistory,
|
||||
"platform_sessions": PlatformSession,
|
||||
"attachments": Attachment,
|
||||
"command_configs": CommandConfig,
|
||||
"command_conflicts": CommandConflict,
|
||||
}
|
||||
|
||||
# 知识库元数据模型类映射
|
||||
KB_METADATA_MODELS: dict[str, type[SQLModel]] = {
|
||||
"knowledge_bases": KnowledgeBase,
|
||||
"kb_documents": KBDocument,
|
||||
"kb_media": KBMedia,
|
||||
}
|
||||
|
||||
|
||||
def get_backup_directories() -> dict[str, str]:
|
||||
"""获取需要备份的目录列表
|
||||
|
||||
使用 astrbot_path 模块动态获取路径,支持通过环境变量 ASTRBOT_ROOT 自定义根目录。
|
||||
|
||||
Returns:
|
||||
dict: 键为备份文件中的目录名称,值为目录的绝对路径
|
||||
"""
|
||||
return {
|
||||
"plugins": get_astrbot_plugin_path(), # 插件本体
|
||||
"plugin_data": get_astrbot_plugin_data_path(), # 插件数据
|
||||
"config": get_astrbot_config_path(), # 配置目录
|
||||
"t2i_templates": get_astrbot_t2i_templates_path(), # T2I 模板
|
||||
"webchat": get_astrbot_webchat_path(), # WebChat 数据
|
||||
"temp": get_astrbot_temp_path(), # 临时文件
|
||||
}
|
||||
|
||||
|
||||
# 备份清单版本号
|
||||
BACKUP_MANIFEST_VERSION = "1.1"
|
||||
@@ -0,0 +1,477 @@
|
||||
"""AstrBot 数据导出器
|
||||
|
||||
负责将所有数据导出为 ZIP 备份文件。
|
||||
导出格式为 JSON,这是数据库无关的方案,支持未来向 MySQL/PostgreSQL 迁移。
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import zipfile
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_backups_path,
|
||||
get_astrbot_data_path,
|
||||
)
|
||||
|
||||
# 从共享常量模块导入
|
||||
from .constants import (
|
||||
BACKUP_MANIFEST_VERSION,
|
||||
KB_METADATA_MODELS,
|
||||
MAIN_DB_MODELS,
|
||||
get_backup_directories,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
||||
|
||||
CMD_CONFIG_FILE_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
|
||||
|
||||
|
||||
class AstrBotExporter:
|
||||
"""AstrBot 数据导出器
|
||||
|
||||
导出内容:
|
||||
- 主数据库所有表(data/data_v4.db)
|
||||
- 知识库元数据(data/knowledge_base/kb.db)
|
||||
- 每个知识库的向量文档数据
|
||||
- 配置文件(data/cmd_config.json)
|
||||
- 附件文件
|
||||
- 知识库多媒体文件
|
||||
- 插件目录(data/plugins)
|
||||
- 插件数据目录(data/plugin_data)
|
||||
- 配置目录(data/config)
|
||||
- T2I 模板目录(data/t2i_templates)
|
||||
- WebChat 数据目录(data/webchat)
|
||||
- 临时文件目录(data/temp)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
main_db: BaseDatabase,
|
||||
kb_manager: "KnowledgeBaseManager | None" = None,
|
||||
config_path: str = CMD_CONFIG_FILE_PATH,
|
||||
):
|
||||
self.main_db = main_db
|
||||
self.kb_manager = kb_manager
|
||||
self.config_path = config_path
|
||||
self._checksums: dict[str, str] = {}
|
||||
|
||||
async def export_all(
|
||||
self,
|
||||
output_dir: str | None = None,
|
||||
progress_callback: Any | None = None,
|
||||
) -> str:
|
||||
"""导出所有数据到 ZIP 文件
|
||||
|
||||
Args:
|
||||
output_dir: 输出目录
|
||||
progress_callback: 进度回调函数,接收参数 (stage, current, total, message)
|
||||
|
||||
Returns:
|
||||
str: 生成的 ZIP 文件路径
|
||||
"""
|
||||
if output_dir is None:
|
||||
output_dir = get_astrbot_backups_path()
|
||||
|
||||
# 确保输出目录存在
|
||||
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
zip_filename = f"astrbot_backup_{timestamp}.zip"
|
||||
zip_path = os.path.join(output_dir, zip_filename)
|
||||
|
||||
logger.info(f"开始导出备份到 {zip_path}")
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
# 1. 导出主数据库
|
||||
if progress_callback:
|
||||
await progress_callback("main_db", 0, 100, "正在导出主数据库...")
|
||||
main_data = await self._export_main_database()
|
||||
main_db_json = json.dumps(
|
||||
main_data, ensure_ascii=False, indent=2, default=str
|
||||
)
|
||||
zf.writestr("databases/main_db.json", main_db_json)
|
||||
self._add_checksum("databases/main_db.json", main_db_json)
|
||||
if progress_callback:
|
||||
await progress_callback("main_db", 100, 100, "主数据库导出完成")
|
||||
|
||||
# 2. 导出知识库数据
|
||||
kb_meta_data: dict[str, Any] = {
|
||||
"knowledge_bases": [],
|
||||
"kb_documents": [],
|
||||
"kb_media": [],
|
||||
}
|
||||
if self.kb_manager:
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"kb_metadata", 0, 100, "正在导出知识库元数据..."
|
||||
)
|
||||
kb_meta_data = await self._export_kb_metadata()
|
||||
kb_meta_json = json.dumps(
|
||||
kb_meta_data, ensure_ascii=False, indent=2, default=str
|
||||
)
|
||||
zf.writestr("databases/kb_metadata.json", kb_meta_json)
|
||||
self._add_checksum("databases/kb_metadata.json", kb_meta_json)
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"kb_metadata", 100, 100, "知识库元数据导出完成"
|
||||
)
|
||||
|
||||
# 导出每个知识库的文档数据
|
||||
kb_insts = self.kb_manager.kb_insts
|
||||
total_kbs = len(kb_insts)
|
||||
for idx, (kb_id, kb_helper) in enumerate(kb_insts.items()):
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"kb_documents",
|
||||
idx,
|
||||
total_kbs,
|
||||
f"正在导出知识库 {kb_helper.kb.kb_name} 的文档数据...",
|
||||
)
|
||||
doc_data = await self._export_kb_documents(kb_helper)
|
||||
doc_json = json.dumps(
|
||||
doc_data, ensure_ascii=False, indent=2, default=str
|
||||
)
|
||||
doc_path = f"databases/kb_{kb_id}/documents.json"
|
||||
zf.writestr(doc_path, doc_json)
|
||||
self._add_checksum(doc_path, doc_json)
|
||||
|
||||
# 导出 FAISS 索引文件
|
||||
await self._export_faiss_index(zf, kb_helper, kb_id)
|
||||
|
||||
# 导出知识库多媒体文件
|
||||
await self._export_kb_media_files(zf, kb_helper, kb_id)
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"kb_documents", total_kbs, total_kbs, "知识库文档导出完成"
|
||||
)
|
||||
|
||||
# 3. 导出配置文件
|
||||
if progress_callback:
|
||||
await progress_callback("config", 0, 100, "正在导出配置文件...")
|
||||
if os.path.exists(self.config_path):
|
||||
with open(self.config_path, encoding="utf-8") as f:
|
||||
config_content = f.read()
|
||||
zf.writestr("config/cmd_config.json", config_content)
|
||||
self._add_checksum("config/cmd_config.json", config_content)
|
||||
if progress_callback:
|
||||
await progress_callback("config", 100, 100, "配置文件导出完成")
|
||||
|
||||
# 4. 导出附件文件
|
||||
if progress_callback:
|
||||
await progress_callback("attachments", 0, 100, "正在导出附件...")
|
||||
await self._export_attachments(zf, main_data.get("attachments", []))
|
||||
if progress_callback:
|
||||
await progress_callback("attachments", 100, 100, "附件导出完成")
|
||||
|
||||
# 5. 导出插件和其他目录
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"directories", 0, 100, "正在导出插件和数据目录..."
|
||||
)
|
||||
dir_stats = await self._export_directories(zf)
|
||||
if progress_callback:
|
||||
await progress_callback("directories", 100, 100, "目录导出完成")
|
||||
|
||||
# 6. 生成 manifest
|
||||
if progress_callback:
|
||||
await progress_callback("manifest", 0, 100, "正在生成清单...")
|
||||
manifest = self._generate_manifest(main_data, kb_meta_data, dir_stats)
|
||||
manifest_json = json.dumps(manifest, ensure_ascii=False, indent=2)
|
||||
zf.writestr("manifest.json", manifest_json)
|
||||
if progress_callback:
|
||||
await progress_callback("manifest", 100, 100, "清单生成完成")
|
||||
|
||||
logger.info(f"备份导出完成: {zip_path}")
|
||||
return zip_path
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"备份导出失败: {e}")
|
||||
# 清理失败的文件
|
||||
if os.path.exists(zip_path):
|
||||
os.remove(zip_path)
|
||||
raise
|
||||
|
||||
async def _export_main_database(self) -> dict[str, list[dict]]:
|
||||
"""导出主数据库所有表"""
|
||||
export_data: dict[str, list[dict]] = {}
|
||||
|
||||
async with self.main_db.get_db() as session:
|
||||
for table_name, model_class in MAIN_DB_MODELS.items():
|
||||
try:
|
||||
result = await session.execute(select(model_class))
|
||||
records = result.scalars().all()
|
||||
export_data[table_name] = [
|
||||
self._model_to_dict(record) for record in records
|
||||
]
|
||||
logger.debug(
|
||||
f"导出表 {table_name}: {len(export_data[table_name])} 条记录"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"导出表 {table_name} 失败: {e}")
|
||||
export_data[table_name] = []
|
||||
|
||||
return export_data
|
||||
|
||||
async def _export_kb_metadata(self) -> dict[str, list[dict]]:
|
||||
"""导出知识库元数据库"""
|
||||
if not self.kb_manager:
|
||||
return {"knowledge_bases": [], "kb_documents": [], "kb_media": []}
|
||||
|
||||
export_data: dict[str, list[dict]] = {}
|
||||
|
||||
async with self.kb_manager.kb_db.get_db() as session:
|
||||
for table_name, model_class in KB_METADATA_MODELS.items():
|
||||
try:
|
||||
result = await session.execute(select(model_class))
|
||||
records = result.scalars().all()
|
||||
export_data[table_name] = [
|
||||
self._model_to_dict(record) for record in records
|
||||
]
|
||||
logger.debug(
|
||||
f"导出知识库表 {table_name}: {len(export_data[table_name])} 条记录"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"导出知识库表 {table_name} 失败: {e}")
|
||||
export_data[table_name] = []
|
||||
|
||||
return export_data
|
||||
|
||||
async def _export_kb_documents(self, kb_helper: Any) -> dict[str, Any]:
|
||||
"""导出知识库的文档块数据"""
|
||||
try:
|
||||
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
|
||||
|
||||
vec_db: FaissVecDB = kb_helper.vec_db
|
||||
if not vec_db or not vec_db.document_storage:
|
||||
return {"documents": []}
|
||||
|
||||
# 获取所有文档
|
||||
docs = await vec_db.document_storage.get_documents(
|
||||
metadata_filters={},
|
||||
offset=0,
|
||||
limit=None, # 获取全部
|
||||
)
|
||||
|
||||
return {"documents": docs}
|
||||
except Exception as e:
|
||||
logger.warning(f"导出知识库文档失败: {e}")
|
||||
return {"documents": []}
|
||||
|
||||
async def _export_faiss_index(
|
||||
self,
|
||||
zf: zipfile.ZipFile,
|
||||
kb_helper: Any,
|
||||
kb_id: str,
|
||||
) -> None:
|
||||
"""导出 FAISS 索引文件"""
|
||||
try:
|
||||
index_path = kb_helper.kb_dir / "index.faiss"
|
||||
if index_path.exists():
|
||||
archive_path = f"databases/kb_{kb_id}/index.faiss"
|
||||
zf.write(str(index_path), archive_path)
|
||||
logger.debug(f"导出 FAISS 索引: {archive_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"导出 FAISS 索引失败: {e}")
|
||||
|
||||
async def _export_kb_media_files(
|
||||
self, zf: zipfile.ZipFile, kb_helper: Any, kb_id: str
|
||||
) -> None:
|
||||
"""导出知识库的多媒体文件"""
|
||||
try:
|
||||
media_dir = kb_helper.kb_medias_dir
|
||||
if not media_dir.exists():
|
||||
return
|
||||
|
||||
for root, _, files in os.walk(media_dir):
|
||||
for file in files:
|
||||
file_path = Path(root) / file
|
||||
# 计算相对路径
|
||||
rel_path = file_path.relative_to(kb_helper.kb_dir)
|
||||
archive_path = f"files/kb_media/{kb_id}/{rel_path}"
|
||||
zf.write(str(file_path), archive_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"导出知识库媒体文件失败: {e}")
|
||||
|
||||
async def _export_directories(
|
||||
self, zf: zipfile.ZipFile
|
||||
) -> dict[str, dict[str, int]]:
|
||||
"""导出插件和其他数据目录
|
||||
|
||||
Returns:
|
||||
dict: 每个目录的统计信息 {dir_name: {"files": count, "size": bytes}}
|
||||
"""
|
||||
stats: dict[str, dict[str, int]] = {}
|
||||
backup_directories = get_backup_directories()
|
||||
|
||||
for dir_name, dir_path in backup_directories.items():
|
||||
full_path = Path(dir_path)
|
||||
if not full_path.exists():
|
||||
logger.debug(f"目录不存在,跳过: {full_path}")
|
||||
continue
|
||||
|
||||
file_count = 0
|
||||
total_size = 0
|
||||
|
||||
try:
|
||||
for root, dirs, files in os.walk(full_path):
|
||||
# 跳过 __pycache__ 目录
|
||||
dirs[:] = [d for d in dirs if d != "__pycache__"]
|
||||
|
||||
for file in files:
|
||||
# 跳过 .pyc 文件
|
||||
if file.endswith(".pyc"):
|
||||
continue
|
||||
|
||||
file_path = Path(root) / file
|
||||
try:
|
||||
# 计算相对路径
|
||||
rel_path = file_path.relative_to(full_path)
|
||||
archive_path = f"directories/{dir_name}/{rel_path}"
|
||||
zf.write(str(file_path), archive_path)
|
||||
file_count += 1
|
||||
total_size += file_path.stat().st_size
|
||||
except Exception as e:
|
||||
logger.warning(f"导出文件 {file_path} 失败: {e}")
|
||||
|
||||
stats[dir_name] = {"files": file_count, "size": total_size}
|
||||
logger.debug(
|
||||
f"导出目录 {dir_name}: {file_count} 个文件, {total_size} 字节"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"导出目录 {dir_path} 失败: {e}")
|
||||
stats[dir_name] = {"files": 0, "size": 0}
|
||||
|
||||
return stats
|
||||
|
||||
async def _export_attachments(
|
||||
self, zf: zipfile.ZipFile, attachments: list[dict]
|
||||
) -> None:
|
||||
"""导出附件文件"""
|
||||
for attachment in attachments:
|
||||
try:
|
||||
file_path = attachment.get("path", "")
|
||||
if file_path and os.path.exists(file_path):
|
||||
# 使用 attachment_id 作为文件名
|
||||
attachment_id = attachment.get("attachment_id", "")
|
||||
ext = os.path.splitext(file_path)[1]
|
||||
archive_path = f"files/attachments/{attachment_id}{ext}"
|
||||
zf.write(file_path, archive_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"导出附件失败: {e}")
|
||||
|
||||
def _model_to_dict(self, record: Any) -> dict:
|
||||
"""将 SQLModel 实例转换为字典
|
||||
|
||||
这是数据库无关的序列化方式,支持未来迁移到其他数据库。
|
||||
"""
|
||||
# 使用 SQLModel 内置的 model_dump 方法(如果可用)
|
||||
if hasattr(record, "model_dump"):
|
||||
data = record.model_dump(mode="python")
|
||||
# 处理 datetime 类型
|
||||
for key, value in data.items():
|
||||
if isinstance(value, datetime):
|
||||
data[key] = value.isoformat()
|
||||
return data
|
||||
|
||||
# 回退到手动提取
|
||||
data = {}
|
||||
# 使用 inspect 获取表信息
|
||||
from sqlalchemy import inspect as sa_inspect
|
||||
|
||||
mapper = sa_inspect(record.__class__)
|
||||
for column in mapper.columns:
|
||||
value = getattr(record, column.name)
|
||||
# 处理 datetime 类型 - 统一转为 ISO 格式字符串
|
||||
if isinstance(value, datetime):
|
||||
value = value.isoformat()
|
||||
data[column.name] = value
|
||||
return data
|
||||
|
||||
def _add_checksum(self, path: str, content: str | bytes) -> None:
|
||||
"""计算并添加文件校验和"""
|
||||
if isinstance(content, str):
|
||||
content = content.encode("utf-8")
|
||||
checksum = hashlib.sha256(content).hexdigest()
|
||||
self._checksums[path] = f"sha256:{checksum}"
|
||||
|
||||
def _generate_manifest(
|
||||
self,
|
||||
main_data: dict[str, list[dict]],
|
||||
kb_meta_data: dict[str, list[dict]],
|
||||
dir_stats: dict[str, dict[str, int]] | None = None,
|
||||
) -> dict:
|
||||
"""生成备份清单"""
|
||||
if dir_stats is None:
|
||||
dir_stats = {}
|
||||
# 收集知识库 ID
|
||||
kb_document_tables = {}
|
||||
if self.kb_manager:
|
||||
for kb_id in self.kb_manager.kb_insts.keys():
|
||||
kb_document_tables[kb_id] = "documents"
|
||||
|
||||
# 收集附件文件列表
|
||||
attachment_files = []
|
||||
for attachment in main_data.get("attachments", []):
|
||||
attachment_id = attachment.get("attachment_id", "")
|
||||
path = attachment.get("path", "")
|
||||
if attachment_id and path:
|
||||
ext = os.path.splitext(path)[1]
|
||||
attachment_files.append(f"{attachment_id}{ext}")
|
||||
|
||||
# 收集知识库媒体文件
|
||||
kb_media_files: dict[str, list[str]] = {}
|
||||
if self.kb_manager:
|
||||
for kb_id, kb_helper in self.kb_manager.kb_insts.items():
|
||||
media_files: list[str] = []
|
||||
media_dir = kb_helper.kb_medias_dir
|
||||
if media_dir.exists():
|
||||
for root, _, files in os.walk(media_dir):
|
||||
for file in files:
|
||||
media_files.append(file)
|
||||
if media_files:
|
||||
kb_media_files[kb_id] = media_files
|
||||
|
||||
manifest = {
|
||||
"version": BACKUP_MANIFEST_VERSION,
|
||||
"astrbot_version": VERSION,
|
||||
"exported_at": datetime.now(timezone.utc).isoformat(),
|
||||
"origin": "exported", # 标记备份来源:exported=本实例导出, uploaded=用户上传
|
||||
"schema_version": {
|
||||
"main_db": "v4",
|
||||
"kb_db": "v1",
|
||||
},
|
||||
"tables": {
|
||||
"main_db": list(main_data.keys()),
|
||||
"kb_metadata": list(kb_meta_data.keys()),
|
||||
"kb_documents": kb_document_tables,
|
||||
},
|
||||
"files": {
|
||||
"attachments": attachment_files,
|
||||
"kb_media": kb_media_files,
|
||||
},
|
||||
"directories": list(dir_stats.keys()),
|
||||
"checksums": self._checksums,
|
||||
"statistics": {
|
||||
"main_db": {
|
||||
table: len(records) for table, records in main_data.items()
|
||||
},
|
||||
"kb_metadata": {
|
||||
table: len(records) for table, records in kb_meta_data.items()
|
||||
},
|
||||
"directories": dir_stats,
|
||||
},
|
||||
}
|
||||
|
||||
return manifest
|
||||
@@ -0,0 +1,761 @@
|
||||
"""AstrBot 数据导入器
|
||||
|
||||
负责从 ZIP 备份文件恢复所有数据。
|
||||
导入时进行版本校验:
|
||||
- 主版本(前两位)不同时直接拒绝导入
|
||||
- 小版本(第三位)不同时提示警告,用户可选择强制导入
|
||||
- 版本匹配时也需要用户确认
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import zipfile
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from sqlalchemy import delete
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_data_path,
|
||||
get_astrbot_knowledge_base_path,
|
||||
)
|
||||
from astrbot.core.utils.version_comparator import VersionComparator
|
||||
|
||||
# 从共享常量模块导入
|
||||
from .constants import (
|
||||
KB_METADATA_MODELS,
|
||||
MAIN_DB_MODELS,
|
||||
get_backup_directories,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
||||
|
||||
|
||||
def _get_major_version(version_str: str) -> str:
|
||||
"""提取版本的主版本部分(前两位)
|
||||
|
||||
Args:
|
||||
version_str: 版本字符串,如 "4.9.1", "4.10.0-beta"
|
||||
|
||||
Returns:
|
||||
主版本字符串,如 "4.9", "4.10"
|
||||
"""
|
||||
if not version_str:
|
||||
return "0.0"
|
||||
# 移除 v 前缀和预发布标签
|
||||
version = version_str.lower().replace("v", "").split("-")[0].split("+")[0]
|
||||
parts = [p for p in version.split(".") if p] # 过滤空字符串
|
||||
if len(parts) >= 2:
|
||||
return f"{parts[0]}.{parts[1]}"
|
||||
elif len(parts) == 1 and parts[0]:
|
||||
return f"{parts[0]}.0"
|
||||
return "0.0"
|
||||
|
||||
|
||||
CMD_CONFIG_FILE_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
|
||||
KB_PATH = get_astrbot_knowledge_base_path()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImportPreCheckResult:
|
||||
"""导入预检查结果
|
||||
|
||||
用于在实际导入前检查备份文件的版本兼容性,
|
||||
并返回确认信息让用户决定是否继续导入。
|
||||
"""
|
||||
|
||||
# 检查是否通过(文件有效且版本可导入)
|
||||
valid: bool = False
|
||||
# 是否可以导入(版本兼容)
|
||||
can_import: bool = False
|
||||
# 版本状态: match(完全匹配), minor_diff(小版本差异), major_diff(主版本不同,拒绝)
|
||||
version_status: str = ""
|
||||
# 备份文件中的 AstrBot 版本
|
||||
backup_version: str = ""
|
||||
# 当前运行的 AstrBot 版本
|
||||
current_version: str = VERSION
|
||||
# 备份创建时间
|
||||
backup_time: str = ""
|
||||
# 确认消息(显示给用户)
|
||||
confirm_message: str = ""
|
||||
# 警告消息列表
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
# 错误消息(如果检查失败)
|
||||
error: str = ""
|
||||
# 备份包含的内容摘要
|
||||
backup_summary: dict = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"valid": self.valid,
|
||||
"can_import": self.can_import,
|
||||
"version_status": self.version_status,
|
||||
"backup_version": self.backup_version,
|
||||
"current_version": self.current_version,
|
||||
"backup_time": self.backup_time,
|
||||
"confirm_message": self.confirm_message,
|
||||
"warnings": self.warnings,
|
||||
"error": self.error,
|
||||
"backup_summary": self.backup_summary,
|
||||
}
|
||||
|
||||
|
||||
class ImportResult:
|
||||
"""导入结果"""
|
||||
|
||||
def __init__(self):
|
||||
self.success = True
|
||||
self.imported_tables: dict[str, int] = {}
|
||||
self.imported_files: dict[str, int] = {}
|
||||
self.imported_directories: dict[str, int] = {}
|
||||
self.warnings: list[str] = []
|
||||
self.errors: list[str] = []
|
||||
|
||||
def add_warning(self, msg: str) -> None:
|
||||
self.warnings.append(msg)
|
||||
logger.warning(msg)
|
||||
|
||||
def add_error(self, msg: str) -> None:
|
||||
self.errors.append(msg)
|
||||
self.success = False
|
||||
logger.error(msg)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"success": self.success,
|
||||
"imported_tables": self.imported_tables,
|
||||
"imported_files": self.imported_files,
|
||||
"imported_directories": self.imported_directories,
|
||||
"warnings": self.warnings,
|
||||
"errors": self.errors,
|
||||
}
|
||||
|
||||
|
||||
class AstrBotImporter:
|
||||
"""AstrBot 数据导入器
|
||||
|
||||
导入备份文件中的所有数据,包括:
|
||||
- 主数据库所有表
|
||||
- 知识库元数据和文档
|
||||
- 配置文件
|
||||
- 附件文件
|
||||
- 知识库多媒体文件
|
||||
- 插件目录(data/plugins)
|
||||
- 插件数据目录(data/plugin_data)
|
||||
- 配置目录(data/config)
|
||||
- T2I 模板目录(data/t2i_templates)
|
||||
- WebChat 数据目录(data/webchat)
|
||||
- 临时文件目录(data/temp)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
main_db: BaseDatabase,
|
||||
kb_manager: "KnowledgeBaseManager | None" = None,
|
||||
config_path: str = CMD_CONFIG_FILE_PATH,
|
||||
kb_root_dir: str = KB_PATH,
|
||||
):
|
||||
self.main_db = main_db
|
||||
self.kb_manager = kb_manager
|
||||
self.config_path = config_path
|
||||
self.kb_root_dir = kb_root_dir
|
||||
|
||||
def pre_check(self, zip_path: str) -> ImportPreCheckResult:
|
||||
"""预检查备份文件
|
||||
|
||||
在实际导入前检查备份文件的有效性和版本兼容性。
|
||||
返回检查结果供前端显示确认对话框。
|
||||
|
||||
Args:
|
||||
zip_path: ZIP 备份文件路径
|
||||
|
||||
Returns:
|
||||
ImportPreCheckResult: 预检查结果
|
||||
"""
|
||||
result = ImportPreCheckResult()
|
||||
result.current_version = VERSION
|
||||
|
||||
if not os.path.exists(zip_path):
|
||||
result.error = f"备份文件不存在: {zip_path}"
|
||||
return result
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
# 读取 manifest
|
||||
try:
|
||||
manifest_data = zf.read("manifest.json")
|
||||
manifest = json.loads(manifest_data)
|
||||
except KeyError:
|
||||
result.error = "备份文件缺少 manifest.json,不是有效的 AstrBot 备份"
|
||||
return result
|
||||
except json.JSONDecodeError as e:
|
||||
result.error = f"manifest.json 格式错误: {e}"
|
||||
return result
|
||||
|
||||
# 提取基本信息
|
||||
result.backup_version = manifest.get("astrbot_version", "未知")
|
||||
result.backup_time = manifest.get("exported_at", "未知")
|
||||
result.valid = True
|
||||
|
||||
# 构建备份摘要
|
||||
result.backup_summary = {
|
||||
"tables": list(manifest.get("tables", {}).keys()),
|
||||
"has_knowledge_bases": manifest.get("has_knowledge_bases", False),
|
||||
"has_config": manifest.get("has_config", False),
|
||||
"directories": manifest.get("directories", []),
|
||||
}
|
||||
|
||||
# 检查版本兼容性
|
||||
version_check = self._check_version_compatibility(result.backup_version)
|
||||
result.version_status = version_check["status"]
|
||||
result.can_import = version_check["can_import"]
|
||||
|
||||
# 版本信息由前端根据 version_status 和 i18n 生成显示
|
||||
# 不再将版本消息添加到 warnings 列表中,避免中文硬编码
|
||||
# warnings 列表保留用于其他非版本相关的警告
|
||||
|
||||
return result
|
||||
|
||||
except zipfile.BadZipFile:
|
||||
result.error = "无效的 ZIP 文件"
|
||||
return result
|
||||
except Exception as e:
|
||||
result.error = f"检查备份文件失败: {e}"
|
||||
return result
|
||||
|
||||
def _check_version_compatibility(self, backup_version: str) -> dict:
|
||||
"""检查版本兼容性
|
||||
|
||||
规则:
|
||||
- 主版本(前两位,如 4.9)必须一致,否则拒绝
|
||||
- 小版本(第三位,如 4.9.1 vs 4.9.2)不同时,警告但允许导入
|
||||
|
||||
Returns:
|
||||
dict: {status, can_import, message}
|
||||
"""
|
||||
if not backup_version:
|
||||
return {
|
||||
"status": "major_diff",
|
||||
"can_import": False,
|
||||
"message": "备份文件缺少版本信息",
|
||||
}
|
||||
|
||||
# 提取主版本(前两位)进行比较
|
||||
backup_major = _get_major_version(backup_version)
|
||||
current_major = _get_major_version(VERSION)
|
||||
|
||||
# 比较主版本
|
||||
if VersionComparator.compare_version(backup_major, current_major) != 0:
|
||||
return {
|
||||
"status": "major_diff",
|
||||
"can_import": False,
|
||||
"message": (
|
||||
f"主版本不兼容: 备份版本 {backup_version}, 当前版本 {VERSION}。"
|
||||
f"跨主版本导入可能导致数据损坏,请使用相同主版本的 AstrBot。"
|
||||
),
|
||||
}
|
||||
|
||||
# 比较完整版本
|
||||
version_cmp = VersionComparator.compare_version(backup_version, VERSION)
|
||||
if version_cmp != 0:
|
||||
return {
|
||||
"status": "minor_diff",
|
||||
"can_import": True,
|
||||
"message": (
|
||||
f"小版本差异: 备份版本 {backup_version}, 当前版本 {VERSION}。"
|
||||
),
|
||||
}
|
||||
|
||||
return {
|
||||
"status": "match",
|
||||
"can_import": True,
|
||||
"message": "版本匹配",
|
||||
}
|
||||
|
||||
async def import_all(
|
||||
self,
|
||||
zip_path: str,
|
||||
mode: str = "replace", # "replace" 清空后导入
|
||||
progress_callback: Any | None = None,
|
||||
) -> ImportResult:
|
||||
"""从 ZIP 文件导入所有数据
|
||||
|
||||
Args:
|
||||
zip_path: ZIP 备份文件路径
|
||||
mode: 导入模式,目前仅支持 "replace"(清空后导入)
|
||||
progress_callback: 进度回调函数,接收参数 (stage, current, total, message)
|
||||
|
||||
Returns:
|
||||
ImportResult: 导入结果
|
||||
"""
|
||||
result = ImportResult()
|
||||
|
||||
if not os.path.exists(zip_path):
|
||||
result.add_error(f"备份文件不存在: {zip_path}")
|
||||
return result
|
||||
|
||||
logger.info(f"开始从 {zip_path} 导入备份")
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
# 1. 读取并验证 manifest
|
||||
if progress_callback:
|
||||
await progress_callback("validate", 0, 100, "正在验证备份文件...")
|
||||
|
||||
try:
|
||||
manifest_data = zf.read("manifest.json")
|
||||
manifest = json.loads(manifest_data)
|
||||
except KeyError:
|
||||
result.add_error("备份文件缺少 manifest.json")
|
||||
return result
|
||||
except json.JSONDecodeError as e:
|
||||
result.add_error(f"manifest.json 格式错误: {e}")
|
||||
return result
|
||||
|
||||
# 版本校验
|
||||
try:
|
||||
self._validate_version(manifest)
|
||||
except ValueError as e:
|
||||
result.add_error(str(e))
|
||||
return result
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("validate", 100, 100, "验证完成")
|
||||
|
||||
# 2. 导入主数据库
|
||||
if progress_callback:
|
||||
await progress_callback("main_db", 0, 100, "正在导入主数据库...")
|
||||
|
||||
try:
|
||||
main_data_content = zf.read("databases/main_db.json")
|
||||
main_data = json.loads(main_data_content)
|
||||
|
||||
if mode == "replace":
|
||||
await self._clear_main_db()
|
||||
|
||||
imported = await self._import_main_database(main_data)
|
||||
result.imported_tables.update(imported)
|
||||
except Exception as e:
|
||||
result.add_error(f"导入主数据库失败: {e}")
|
||||
return result
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("main_db", 100, 100, "主数据库导入完成")
|
||||
|
||||
# 3. 导入知识库
|
||||
if self.kb_manager and "databases/kb_metadata.json" in zf.namelist():
|
||||
if progress_callback:
|
||||
await progress_callback("kb", 0, 100, "正在导入知识库...")
|
||||
|
||||
try:
|
||||
kb_meta_content = zf.read("databases/kb_metadata.json")
|
||||
kb_meta_data = json.loads(kb_meta_content)
|
||||
|
||||
if mode == "replace":
|
||||
await self._clear_kb_data()
|
||||
|
||||
await self._import_knowledge_bases(zf, kb_meta_data, result)
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入知识库失败: {e}")
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("kb", 100, 100, "知识库导入完成")
|
||||
|
||||
# 4. 导入配置文件
|
||||
if progress_callback:
|
||||
await progress_callback("config", 0, 100, "正在导入配置文件...")
|
||||
|
||||
if "config/cmd_config.json" in zf.namelist():
|
||||
try:
|
||||
config_content = zf.read("config/cmd_config.json")
|
||||
# 备份现有配置
|
||||
if os.path.exists(self.config_path):
|
||||
backup_path = f"{self.config_path}.bak"
|
||||
shutil.copy2(self.config_path, backup_path)
|
||||
|
||||
with open(self.config_path, "wb") as f:
|
||||
f.write(config_content)
|
||||
result.imported_files["config"] = 1
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入配置文件失败: {e}")
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("config", 100, 100, "配置文件导入完成")
|
||||
|
||||
# 5. 导入附件文件
|
||||
if progress_callback:
|
||||
await progress_callback("attachments", 0, 100, "正在导入附件...")
|
||||
|
||||
attachment_count = await self._import_attachments(
|
||||
zf, main_data.get("attachments", [])
|
||||
)
|
||||
result.imported_files["attachments"] = attachment_count
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("attachments", 100, 100, "附件导入完成")
|
||||
|
||||
# 6. 导入插件和其他目录
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"directories", 0, 100, "正在导入插件和数据目录..."
|
||||
)
|
||||
|
||||
dir_stats = await self._import_directories(zf, manifest, result)
|
||||
result.imported_directories = dir_stats
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("directories", 100, 100, "目录导入完成")
|
||||
|
||||
logger.info(f"备份导入完成: {result.to_dict()}")
|
||||
return result
|
||||
|
||||
except zipfile.BadZipFile:
|
||||
result.add_error("无效的 ZIP 文件")
|
||||
return result
|
||||
except Exception as e:
|
||||
result.add_error(f"导入失败: {e}")
|
||||
return result
|
||||
|
||||
def _validate_version(self, manifest: dict) -> None:
|
||||
"""验证版本兼容性 - 仅允许相同主版本导入
|
||||
|
||||
注意:此方法仅在 import_all 中调用,用于双重校验。
|
||||
前端应先调用 pre_check 获取详细的版本信息并让用户确认。
|
||||
"""
|
||||
backup_version = manifest.get("astrbot_version")
|
||||
if not backup_version:
|
||||
raise ValueError("备份文件缺少版本信息")
|
||||
|
||||
# 使用新的版本兼容性检查
|
||||
version_check = self._check_version_compatibility(backup_version)
|
||||
|
||||
if version_check["status"] == "major_diff":
|
||||
raise ValueError(version_check["message"])
|
||||
|
||||
# minor_diff 和 match 都允许导入
|
||||
if version_check["status"] == "minor_diff":
|
||||
logger.warning(f"版本差异警告: {version_check['message']}")
|
||||
|
||||
async def _clear_main_db(self) -> None:
|
||||
"""清空主数据库所有表"""
|
||||
async with self.main_db.get_db() as session:
|
||||
async with session.begin():
|
||||
for table_name, model_class in MAIN_DB_MODELS.items():
|
||||
try:
|
||||
await session.execute(delete(model_class))
|
||||
logger.debug(f"已清空表 {table_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"清空表 {table_name} 失败: {e}")
|
||||
|
||||
async def _clear_kb_data(self) -> None:
|
||||
"""清空知识库数据"""
|
||||
if not self.kb_manager:
|
||||
return
|
||||
|
||||
# 清空知识库元数据表
|
||||
async with self.kb_manager.kb_db.get_db() as session:
|
||||
async with session.begin():
|
||||
for table_name, model_class in KB_METADATA_MODELS.items():
|
||||
try:
|
||||
await session.execute(delete(model_class))
|
||||
logger.debug(f"已清空知识库表 {table_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"清空知识库表 {table_name} 失败: {e}")
|
||||
|
||||
# 删除知识库文件目录
|
||||
for kb_id in list(self.kb_manager.kb_insts.keys()):
|
||||
try:
|
||||
kb_helper = self.kb_manager.kb_insts[kb_id]
|
||||
await kb_helper.terminate()
|
||||
if kb_helper.kb_dir.exists():
|
||||
shutil.rmtree(kb_helper.kb_dir)
|
||||
except Exception as e:
|
||||
logger.warning(f"清理知识库 {kb_id} 失败: {e}")
|
||||
|
||||
self.kb_manager.kb_insts.clear()
|
||||
|
||||
async def _import_main_database(
|
||||
self, data: dict[str, list[dict]]
|
||||
) -> dict[str, int]:
|
||||
"""导入主数据库数据"""
|
||||
imported: dict[str, int] = {}
|
||||
|
||||
async with self.main_db.get_db() as session:
|
||||
async with session.begin():
|
||||
for table_name, rows in data.items():
|
||||
model_class = MAIN_DB_MODELS.get(table_name)
|
||||
if not model_class:
|
||||
logger.warning(f"未知的表: {table_name}")
|
||||
continue
|
||||
|
||||
count = 0
|
||||
for row in rows:
|
||||
try:
|
||||
# 转换 datetime 字符串为 datetime 对象
|
||||
row = self._convert_datetime_fields(row, model_class)
|
||||
obj = model_class(**row)
|
||||
session.add(obj)
|
||||
count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"导入记录到 {table_name} 失败: {e}")
|
||||
|
||||
imported[table_name] = count
|
||||
logger.debug(f"导入表 {table_name}: {count} 条记录")
|
||||
|
||||
return imported
|
||||
|
||||
async def _import_knowledge_bases(
|
||||
self,
|
||||
zf: zipfile.ZipFile,
|
||||
kb_meta_data: dict[str, list[dict]],
|
||||
result: ImportResult,
|
||||
) -> None:
|
||||
"""导入知识库数据"""
|
||||
if not self.kb_manager:
|
||||
return
|
||||
|
||||
# 1. 导入知识库元数据
|
||||
async with self.kb_manager.kb_db.get_db() as session:
|
||||
async with session.begin():
|
||||
for table_name, rows in kb_meta_data.items():
|
||||
model_class = KB_METADATA_MODELS.get(table_name)
|
||||
if not model_class:
|
||||
continue
|
||||
|
||||
count = 0
|
||||
for row in rows:
|
||||
try:
|
||||
row = self._convert_datetime_fields(row, model_class)
|
||||
obj = model_class(**row)
|
||||
session.add(obj)
|
||||
count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"导入知识库记录到 {table_name} 失败: {e}")
|
||||
|
||||
result.imported_tables[f"kb_{table_name}"] = count
|
||||
|
||||
# 2. 导入每个知识库的文档和文件
|
||||
for kb_data in kb_meta_data.get("knowledge_bases", []):
|
||||
kb_id = kb_data.get("kb_id")
|
||||
if not kb_id:
|
||||
continue
|
||||
|
||||
# 创建知识库目录
|
||||
kb_dir = Path(self.kb_root_dir) / kb_id
|
||||
kb_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 导入文档数据
|
||||
doc_path = f"databases/kb_{kb_id}/documents.json"
|
||||
if doc_path in zf.namelist():
|
||||
try:
|
||||
doc_content = zf.read(doc_path)
|
||||
doc_data = json.loads(doc_content)
|
||||
|
||||
# 导入到文档存储数据库
|
||||
await self._import_kb_documents(kb_id, doc_data)
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入知识库 {kb_id} 的文档失败: {e}")
|
||||
|
||||
# 导入 FAISS 索引
|
||||
faiss_path = f"databases/kb_{kb_id}/index.faiss"
|
||||
if faiss_path in zf.namelist():
|
||||
try:
|
||||
target_path = kb_dir / "index.faiss"
|
||||
with zf.open(faiss_path) as src, open(target_path, "wb") as dst:
|
||||
dst.write(src.read())
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入知识库 {kb_id} 的 FAISS 索引失败: {e}")
|
||||
|
||||
# 导入媒体文件
|
||||
media_prefix = f"files/kb_media/{kb_id}/"
|
||||
for name in zf.namelist():
|
||||
if name.startswith(media_prefix):
|
||||
try:
|
||||
rel_path = name[len(media_prefix) :]
|
||||
target_path = kb_dir / rel_path
|
||||
target_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with zf.open(name) as src, open(target_path, "wb") as dst:
|
||||
dst.write(src.read())
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入媒体文件 {name} 失败: {e}")
|
||||
|
||||
# 3. 重新加载知识库实例
|
||||
await self.kb_manager.load_kbs()
|
||||
|
||||
async def _import_kb_documents(self, kb_id: str, doc_data: dict) -> None:
|
||||
"""导入知识库文档到向量数据库"""
|
||||
from astrbot.core.db.vec_db.faiss_impl.document_storage import DocumentStorage
|
||||
|
||||
kb_dir = Path(self.kb_root_dir) / kb_id
|
||||
doc_db_path = kb_dir / "doc.db"
|
||||
|
||||
# 初始化文档存储
|
||||
doc_storage = DocumentStorage(str(doc_db_path))
|
||||
await doc_storage.initialize()
|
||||
|
||||
try:
|
||||
documents = doc_data.get("documents", [])
|
||||
for doc in documents:
|
||||
try:
|
||||
await doc_storage.insert_document(
|
||||
doc_id=doc.get("doc_id", ""),
|
||||
text=doc.get("text", ""),
|
||||
metadata=json.loads(doc.get("metadata", "{}")),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"导入文档块失败: {e}")
|
||||
finally:
|
||||
await doc_storage.close()
|
||||
|
||||
async def _import_attachments(
|
||||
self,
|
||||
zf: zipfile.ZipFile,
|
||||
attachments: list[dict],
|
||||
) -> int:
|
||||
"""导入附件文件"""
|
||||
count = 0
|
||||
|
||||
attachments_dir = Path(self.config_path).parent / "attachments"
|
||||
attachments_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
attachment_prefix = "files/attachments/"
|
||||
for name in zf.namelist():
|
||||
if name.startswith(attachment_prefix) and name != attachment_prefix:
|
||||
try:
|
||||
# 从附件记录中找到原始路径
|
||||
attachment_id = os.path.splitext(os.path.basename(name))[0]
|
||||
original_path = None
|
||||
for att in attachments:
|
||||
if att.get("attachment_id") == attachment_id:
|
||||
original_path = att.get("path")
|
||||
break
|
||||
|
||||
if original_path:
|
||||
target_path = Path(original_path)
|
||||
else:
|
||||
target_path = attachments_dir / os.path.basename(name)
|
||||
|
||||
target_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with zf.open(name) as src, open(target_path, "wb") as dst:
|
||||
dst.write(src.read())
|
||||
count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"导入附件 {name} 失败: {e}")
|
||||
|
||||
return count
|
||||
|
||||
async def _import_directories(
|
||||
self,
|
||||
zf: zipfile.ZipFile,
|
||||
manifest: dict,
|
||||
result: ImportResult,
|
||||
) -> dict[str, int]:
|
||||
"""导入插件和其他数据目录
|
||||
|
||||
Args:
|
||||
zf: ZIP 文件对象
|
||||
manifest: 备份清单
|
||||
result: 导入结果对象
|
||||
|
||||
Returns:
|
||||
dict: 每个目录导入的文件数量
|
||||
"""
|
||||
dir_stats: dict[str, int] = {}
|
||||
|
||||
# 检查备份版本是否支持目录备份(需要版本 >= 1.1)
|
||||
backup_version = manifest.get("version", "1.0")
|
||||
if VersionComparator.compare_version(backup_version, "1.1") < 0:
|
||||
logger.info("备份版本不支持目录备份,跳过目录导入")
|
||||
return dir_stats
|
||||
|
||||
backed_up_dirs = manifest.get("directories", [])
|
||||
backup_directories = get_backup_directories()
|
||||
|
||||
for dir_name in backed_up_dirs:
|
||||
if dir_name not in backup_directories:
|
||||
result.add_warning(f"未知的目录类型: {dir_name}")
|
||||
continue
|
||||
|
||||
target_dir = Path(backup_directories[dir_name])
|
||||
archive_prefix = f"directories/{dir_name}/"
|
||||
|
||||
file_count = 0
|
||||
|
||||
try:
|
||||
# 获取该目录下的所有文件
|
||||
dir_files = [
|
||||
name
|
||||
for name in zf.namelist()
|
||||
if name.startswith(archive_prefix) and name != archive_prefix
|
||||
]
|
||||
|
||||
if not dir_files:
|
||||
continue
|
||||
|
||||
# 备份现有目录(如果存在)
|
||||
if target_dir.exists():
|
||||
backup_path = Path(f"{target_dir}.bak")
|
||||
if backup_path.exists():
|
||||
shutil.rmtree(backup_path)
|
||||
shutil.move(str(target_dir), str(backup_path))
|
||||
logger.debug(f"已备份现有目录 {target_dir} 到 {backup_path}")
|
||||
|
||||
# 创建目标目录
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 解压文件
|
||||
for name in dir_files:
|
||||
try:
|
||||
# 计算相对路径
|
||||
rel_path = name[len(archive_prefix) :]
|
||||
if not rel_path: # 跳过目录条目
|
||||
continue
|
||||
|
||||
target_path = target_dir / rel_path
|
||||
target_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with zf.open(name) as src, open(target_path, "wb") as dst:
|
||||
dst.write(src.read())
|
||||
file_count += 1
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入文件 {name} 失败: {e}")
|
||||
|
||||
dir_stats[dir_name] = file_count
|
||||
logger.debug(f"导入目录 {dir_name}: {file_count} 个文件")
|
||||
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入目录 {dir_name} 失败: {e}")
|
||||
dir_stats[dir_name] = 0
|
||||
|
||||
return dir_stats
|
||||
|
||||
def _convert_datetime_fields(self, row: dict, model_class: type) -> dict:
|
||||
"""转换 datetime 字符串字段为 datetime 对象"""
|
||||
result = row.copy()
|
||||
|
||||
# 获取模型的 datetime 字段
|
||||
from sqlalchemy import inspect as sa_inspect
|
||||
|
||||
try:
|
||||
mapper = sa_inspect(model_class)
|
||||
for column in mapper.columns:
|
||||
if column.name in result and result[column.name] is not None:
|
||||
# 检查是否是 datetime 类型的列
|
||||
from sqlalchemy import DateTime
|
||||
|
||||
if isinstance(column.type, DateTime):
|
||||
value = result[column.name]
|
||||
if isinstance(value, str):
|
||||
# 解析 ISO 格式的日期时间字符串
|
||||
result[column.name] = datetime.fromisoformat(value)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return result
|
||||
@@ -24,6 +24,10 @@ class AstrBotConfig(dict):
|
||||
- 如果传入了 schema,将会通过 schema 解析出 default_config,此时传入的 default_config 会被忽略。
|
||||
"""
|
||||
|
||||
config_path: str
|
||||
default_config: dict
|
||||
schema: dict | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config_path: str = ASTRBOT_CONFIG_PATH,
|
||||
@@ -76,6 +80,8 @@ class AstrBotConfig(dict):
|
||||
if v["type"] == "object":
|
||||
conf[k] = {}
|
||||
_parse_schema(v["items"], conf[k])
|
||||
elif v["type"] == "template_list":
|
||||
conf[k] = default
|
||||
else:
|
||||
conf[k] = default
|
||||
|
||||
|
||||
+358
-242
@@ -1,10 +1,11 @@
|
||||
"""如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。"""
|
||||
|
||||
import os
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "4.8.0"
|
||||
VERSION = "4.11.0"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||
|
||||
WEBHOOK_SUPPORTED_PLATFORMS = [
|
||||
@@ -13,6 +14,7 @@ WEBHOOK_SUPPORTED_PLATFORMS = [
|
||||
"wecom",
|
||||
"wecom_ai_bot",
|
||||
"slack",
|
||||
"lark",
|
||||
]
|
||||
|
||||
# 默认配置
|
||||
@@ -42,7 +44,15 @@ DEFAULT_CONFIG = {
|
||||
"interval": "1.5,3.5",
|
||||
"log_base": 2.6,
|
||||
"words_count_threshold": 150,
|
||||
"split_mode": "regex", # regex 或 words
|
||||
"regex": ".*?[。?!~…]+|.+$",
|
||||
"split_words": [
|
||||
"。",
|
||||
"?",
|
||||
"!",
|
||||
"~",
|
||||
"…",
|
||||
], # 当 split_mode 为 words 时使用
|
||||
"content_cleanup_rule": "",
|
||||
},
|
||||
"no_permission_reply": True,
|
||||
@@ -52,7 +62,8 @@ DEFAULT_CONFIG = {
|
||||
"ignore_bot_self_message": False,
|
||||
"ignore_at_all": False,
|
||||
},
|
||||
"provider": [],
|
||||
"provider_sources": [], # provider sources
|
||||
"provider": [], # models from provider_sources
|
||||
"provider_settings": {
|
||||
"enable": True,
|
||||
"default_provider_id": "",
|
||||
@@ -72,6 +83,16 @@ DEFAULT_CONFIG = {
|
||||
"default_personality": "default",
|
||||
"persona_pool": ["*"],
|
||||
"prompt_prefix": "{{prompt}}",
|
||||
"context_limit_reached_strategy": "truncate_by_turns", # or llm_compress
|
||||
"llm_compress_instruction": (
|
||||
"Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n"
|
||||
"1. Systematically cover all core topics discussed and the final conclusion/outcome for each; clearly highlight the latest primary focus.\n"
|
||||
"2. If any tools were used, summarize tool usage (total call count) and extract the most valuable insights from tool outputs.\n"
|
||||
"3. If there was an initial user goal, state it first and describe the current progress/status.\n"
|
||||
"4. Write the summary in the user's language.\n"
|
||||
),
|
||||
"llm_compress_keep_recent": 4,
|
||||
"llm_compress_provider_id": "",
|
||||
"max_context_length": -1,
|
||||
"dequeue_context_length": 1,
|
||||
"streaming_response": False,
|
||||
@@ -99,6 +120,7 @@ DEFAULT_CONFIG = {
|
||||
"provider_id": "",
|
||||
"dual_output": False,
|
||||
"use_file_service": False,
|
||||
"trigger_probability": 1.0,
|
||||
},
|
||||
"provider_ltm_settings": {
|
||||
"group_icl_enable": False,
|
||||
@@ -157,9 +179,28 @@ DEFAULT_CONFIG = {
|
||||
"kb_fusion_top_k": 20, # 知识库检索融合阶段返回结果数量
|
||||
"kb_final_top_k": 5, # 知识库检索最终返回结果数量
|
||||
"kb_agentic_mode": False,
|
||||
"disable_builtin_commands": False,
|
||||
}
|
||||
|
||||
|
||||
class ChatProviderTemplate(TypedDict):
|
||||
id: str
|
||||
provider_source_id: str
|
||||
model: str
|
||||
modalities: list
|
||||
custom_extra_body: dict[str, Any]
|
||||
max_context_tokens: int
|
||||
|
||||
|
||||
CHAT_PROVIDER_TEMPLATE = {
|
||||
"id": "",
|
||||
"provide_source_id": "",
|
||||
"model": "",
|
||||
"modalities": [],
|
||||
"custom_extra_body": {},
|
||||
"max_context_tokens": 0,
|
||||
}
|
||||
|
||||
"""
|
||||
AstrBot v3 时代的配置元数据,目前仅承担以下功能:
|
||||
|
||||
@@ -198,7 +239,7 @@ CONFIG_METADATA_2 = {
|
||||
"callback_server_host": "0.0.0.0",
|
||||
"port": 6196,
|
||||
},
|
||||
"QQ 个人号(OneBot v11)": {
|
||||
"OneBot v11 (QQ 个人号等)": {
|
||||
"id": "default",
|
||||
"type": "aiocqhttp",
|
||||
"enable": False,
|
||||
@@ -206,16 +247,6 @@ CONFIG_METADATA_2 = {
|
||||
"ws_reverse_port": 6199,
|
||||
"ws_reverse_token": "",
|
||||
},
|
||||
"WeChatPadPro": {
|
||||
"id": "wechatpadpro",
|
||||
"type": "wechatpadpro",
|
||||
"enable": False,
|
||||
"admin_key": "stay33",
|
||||
"host": "这里填写你的局域网IP或者公网服务器IP",
|
||||
"port": 8059,
|
||||
"wpp_active_message_poll": False,
|
||||
"wpp_active_message_poll_interval": 3,
|
||||
},
|
||||
"微信公众平台": {
|
||||
"id": "weixin_official_account",
|
||||
"type": "weixin_official_account",
|
||||
@@ -268,6 +299,10 @@ CONFIG_METADATA_2 = {
|
||||
"app_id": "",
|
||||
"app_secret": "",
|
||||
"domain": "https://open.feishu.cn",
|
||||
"lark_connection_mode": "socket", # webhook, socket
|
||||
"webhook_uuid": "",
|
||||
"lark_encrypt_key": "",
|
||||
"lark_verification_token": "",
|
||||
},
|
||||
"钉钉(DingTalk)": {
|
||||
"id": "dingtalk",
|
||||
@@ -341,6 +376,16 @@ CONFIG_METADATA_2 = {
|
||||
"satori_heartbeat_interval": 10,
|
||||
"satori_reconnect_delay": 5,
|
||||
},
|
||||
"WeChatPadPro": {
|
||||
"id": "wechatpadpro",
|
||||
"type": "wechatpadpro",
|
||||
"enable": False,
|
||||
"admin_key": "stay33",
|
||||
"host": "这里填写你的局域网IP或者公网服务器IP",
|
||||
"port": 8059,
|
||||
"wpp_active_message_poll": False,
|
||||
"wpp_active_message_poll_interval": 3,
|
||||
},
|
||||
# "WebChat": {
|
||||
# "id": "webchat",
|
||||
# "type": "webchat",
|
||||
@@ -361,6 +406,28 @@ CONFIG_METADATA_2 = {
|
||||
# "type": "string",
|
||||
# "options": ["fullscreen", "embedded"],
|
||||
# },
|
||||
"lark_connection_mode": {
|
||||
"description": "订阅方式",
|
||||
"type": "string",
|
||||
"options": ["socket", "webhook"],
|
||||
"labels": ["长连接模式", "推送至服务器模式"],
|
||||
},
|
||||
"lark_encrypt_key": {
|
||||
"description": "Encrypt Key",
|
||||
"type": "string",
|
||||
"hint": "用于解密飞书回调数据的加密密钥",
|
||||
"condition": {
|
||||
"lark_connection_mode": "webhook",
|
||||
},
|
||||
},
|
||||
"lark_verification_token": {
|
||||
"description": "Verification Token",
|
||||
"type": "string",
|
||||
"hint": "用于验证飞书回调请求的令牌",
|
||||
"condition": {
|
||||
"lark_connection_mode": "webhook",
|
||||
},
|
||||
},
|
||||
"is_sandbox": {
|
||||
"description": "沙箱模式",
|
||||
"type": "bool",
|
||||
@@ -807,6 +874,7 @@ CONFIG_METADATA_2 = {
|
||||
"metadata": {
|
||||
"provider": {
|
||||
"type": "list",
|
||||
# provider sources templates
|
||||
"config_template": {
|
||||
"OpenAI": {
|
||||
"id": "openai",
|
||||
@@ -817,107 +885,10 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.openai.com/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
"hint": "也兼容所有与 OpenAI API 兼容的服务。",
|
||||
},
|
||||
"Azure OpenAI": {
|
||||
"id": "azure",
|
||||
"provider": "azure",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"api_version": "2024-05-01-preview",
|
||||
"key": [],
|
||||
"api_base": "",
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"xAI": {
|
||||
"id": "xai",
|
||||
"provider": "xai",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.x.ai/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "grok-2-latest", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"xai_native_search": False,
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"Anthropic": {
|
||||
"hint": "注意Claude系列模型的温度调节范围为0到1.0,超出可能导致报错",
|
||||
"id": "claude",
|
||||
"provider": "anthropic",
|
||||
"type": "anthropic_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.anthropic.com/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "claude-3-5-sonnet-latest",
|
||||
"max_tokens": 4096,
|
||||
"temperature": 0.2,
|
||||
},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"Ollama": {
|
||||
"hint": "启用前请确保已正确安装并运行 Ollama 服务端,Ollama默认不带鉴权,无需修改key",
|
||||
"id": "ollama_default",
|
||||
"provider": "ollama",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": ["ollama"], # ollama 的 key 默认是 ollama
|
||||
"api_base": "http://localhost:11434/v1",
|
||||
"model_config": {"model": "llama3.1-8b", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"LM Studio": {
|
||||
"id": "lm_studio",
|
||||
"provider": "lm_studio",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": ["lmstudio"],
|
||||
"api_base": "http://localhost:1234/v1",
|
||||
"model_config": {
|
||||
"model": "llama-3.1-8b",
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"Gemini(OpenAI兼容)": {
|
||||
"id": "gemini_default",
|
||||
"provider": "google",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "gemini-1.5-flash",
|
||||
"temperature": 0.4,
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"Gemini": {
|
||||
"id": "gemini_default",
|
||||
"Google Gemini": {
|
||||
"id": "google_gemini",
|
||||
"provider": "google",
|
||||
"type": "googlegenai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
@@ -925,10 +896,6 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://generativelanguage.googleapis.com/",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "gemini-2.0-flash-exp",
|
||||
"temperature": 0.4,
|
||||
},
|
||||
"gm_resp_image_modal": False,
|
||||
"gm_native_search": False,
|
||||
"gm_native_coderunner": False,
|
||||
@@ -939,13 +906,44 @@ CONFIG_METADATA_2 = {
|
||||
"sexually_explicit": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
"dangerous_content": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
},
|
||||
"gm_thinking_config": {
|
||||
"budget": 0,
|
||||
},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
"gm_thinking_config": {"budget": 0, "level": "HIGH"},
|
||||
},
|
||||
"Anthropic": {
|
||||
"id": "anthropic",
|
||||
"provider": "anthropic",
|
||||
"type": "anthropic_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.anthropic.com/v1",
|
||||
"timeout": 120,
|
||||
"anth_thinking_config": {"budget": 0},
|
||||
},
|
||||
"Moonshot": {
|
||||
"id": "moonshot",
|
||||
"provider": "moonshot",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://api.moonshot.cn/v1",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"xAI": {
|
||||
"id": "xai",
|
||||
"provider": "xai",
|
||||
"type": "xai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.x.ai/v1",
|
||||
"timeout": 120,
|
||||
"custom_headers": {},
|
||||
"xai_native_search": False,
|
||||
},
|
||||
"DeepSeek": {
|
||||
"id": "deepseek_default",
|
||||
"id": "deepseek",
|
||||
"provider": "deepseek",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
@@ -953,13 +951,75 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.deepseek.com/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "deepseek-chat", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "tool_use"],
|
||||
},
|
||||
"Zhipu": {
|
||||
"id": "zhipu",
|
||||
"provider": "zhipu",
|
||||
"type": "zhipu_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"Azure OpenAI": {
|
||||
"id": "azure_openai",
|
||||
"provider": "azure",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"api_version": "2024-05-01-preview",
|
||||
"key": [],
|
||||
"api_base": "",
|
||||
"timeout": 120,
|
||||
"custom_headers": {},
|
||||
},
|
||||
"Ollama": {
|
||||
"id": "ollama",
|
||||
"provider": "ollama",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": ["ollama"], # ollama 的 key 默认是 ollama
|
||||
"api_base": "http://127.0.0.1:11434/v1",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"LM Studio": {
|
||||
"id": "lm_studio",
|
||||
"provider": "lm_studio",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": ["lmstudio"],
|
||||
"api_base": "http://127.0.0.1:1234/v1",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"ModelStack": {
|
||||
"id": "modelstack",
|
||||
"provider": "modelstack",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://modelstack.app/v1",
|
||||
"timeout": 120,
|
||||
"custom_headers": {},
|
||||
},
|
||||
"Gemini_OpenAI_API": {
|
||||
"id": "google_gemini_openai",
|
||||
"provider": "google",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||
"timeout": 120,
|
||||
"custom_headers": {},
|
||||
},
|
||||
"Groq": {
|
||||
"id": "groq_default",
|
||||
"id": "groq",
|
||||
"provider": "groq",
|
||||
"type": "groq_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
@@ -967,13 +1027,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.groq.com/openai/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "openai/gpt-oss-20b",
|
||||
"temperature": 0.4,
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "tool_use"],
|
||||
},
|
||||
"302.AI": {
|
||||
"id": "302ai",
|
||||
@@ -984,12 +1038,9 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.302.ai/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "gpt-4.1-mini", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"硅基流动": {
|
||||
"SiliconFlow": {
|
||||
"id": "siliconflow",
|
||||
"provider": "siliconflow",
|
||||
"type": "openai_chat_completion",
|
||||
@@ -998,15 +1049,9 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://api.siliconflow.cn/v1",
|
||||
"model_config": {
|
||||
"model": "deepseek-ai/DeepSeek-V3",
|
||||
"temperature": 0.4,
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"PPIO派欧云": {
|
||||
"PPIO": {
|
||||
"id": "ppio",
|
||||
"provider": "ppio",
|
||||
"type": "openai_chat_completion",
|
||||
@@ -1015,14 +1060,9 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.ppinfra.com/v3/openai",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "deepseek/deepseek-r1",
|
||||
"temperature": 0.4,
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
},
|
||||
"小马算力": {
|
||||
"TokenPony": {
|
||||
"id": "tokenpony",
|
||||
"provider": "tokenpony",
|
||||
"type": "openai_chat_completion",
|
||||
@@ -1031,14 +1071,9 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.tokenpony.cn/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "kimi-k2-instruct-0905",
|
||||
"temperature": 0.7,
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
},
|
||||
"优云智算": {
|
||||
"Compshare": {
|
||||
"id": "compshare",
|
||||
"provider": "compshare",
|
||||
"type": "openai_chat_completion",
|
||||
@@ -1047,42 +1082,18 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.modelverse.cn/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "moonshotai/Kimi-K2-Instruct",
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"Kimi": {
|
||||
"id": "moonshot",
|
||||
"provider": "moonshot",
|
||||
"ModelScope": {
|
||||
"id": "modelscope",
|
||||
"provider": "modelscope",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://api.moonshot.cn/v1",
|
||||
"model_config": {"model": "moonshot-v1-8k", "temperature": 0.4},
|
||||
"api_base": "https://api-inference.modelscope.cn/v1",
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"智谱 AI": {
|
||||
"id": "zhipu_default",
|
||||
"provider": "zhipu",
|
||||
"type": "zhipu_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
|
||||
"model_config": {
|
||||
"model": "glm-4-flash",
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"Dify": {
|
||||
"id": "dify_app_default",
|
||||
@@ -1097,7 +1108,6 @@ CONFIG_METADATA_2 = {
|
||||
"dify_query_input_key": "astrbot_text_query",
|
||||
"variables": {},
|
||||
"timeout": 60,
|
||||
"hint": "请确保你在 AstrBot 里设置的 APP 类型和 Dify 里面创建的应用的类型一致!",
|
||||
},
|
||||
"Coze": {
|
||||
"id": "coze",
|
||||
@@ -1128,20 +1138,6 @@ CONFIG_METADATA_2 = {
|
||||
"variables": {},
|
||||
"timeout": 60,
|
||||
},
|
||||
"ModelScope": {
|
||||
"id": "modelscope",
|
||||
"provider": "modelscope",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://api-inference.modelscope.cn/v1",
|
||||
"model_config": {"model": "Qwen/Qwen3-32B", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"FastGPT": {
|
||||
"id": "fastgpt",
|
||||
"provider": "fastgpt",
|
||||
@@ -1165,7 +1161,6 @@ CONFIG_METADATA_2 = {
|
||||
"model": "whisper-1",
|
||||
},
|
||||
"Whisper(Local)": {
|
||||
"hint": "启用前请 pip 安装 openai-whisper 库(N卡用户大约下载 2GB,主要是 torch 和 cuda,CPU 用户大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
|
||||
"provider": "openai",
|
||||
"type": "openai_whisper_selfhost",
|
||||
"provider_type": "speech_to_text",
|
||||
@@ -1174,7 +1169,6 @@ CONFIG_METADATA_2 = {
|
||||
"model": "tiny",
|
||||
},
|
||||
"SenseVoice(Local)": {
|
||||
"hint": "启用前请 pip 安装 funasr、funasr_onnx、torchaudio、torch、modelscope、jieba 库(默认使用CPU,大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
|
||||
"type": "sensevoice_stt_selfhost",
|
||||
"provider": "sensevoice",
|
||||
"provider_type": "speech_to_text",
|
||||
@@ -1196,7 +1190,6 @@ CONFIG_METADATA_2 = {
|
||||
"timeout": "20",
|
||||
},
|
||||
"Edge TTS": {
|
||||
"hint": "提示:使用这个服务前需要安装有 ffmpeg,并且可以直接在终端调用 ffmpeg 指令。",
|
||||
"id": "edge_tts",
|
||||
"provider": "microsoft",
|
||||
"type": "edge_tts",
|
||||
@@ -1306,7 +1299,7 @@ CONFIG_METADATA_2 = {
|
||||
"minimax-is-timber-weight": False,
|
||||
"minimax-voice-id": "female-shaonv",
|
||||
"minimax-timber-weight": '[\n {\n "voice_id": "Chinese (Mandarin)_Warm_Girl",\n "weight": 25\n },\n {\n "voice_id": "Chinese (Mandarin)_BashfulGirl",\n "weight": 50\n }\n]',
|
||||
"minimax-voice-emotion": "neutral",
|
||||
"minimax-voice-emotion": "auto",
|
||||
"minimax-voice-latex": False,
|
||||
"minimax-voice-english-normalization": False,
|
||||
"timeout": 20,
|
||||
@@ -1412,6 +1405,10 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
},
|
||||
"items": {
|
||||
"provider_source_id": {
|
||||
"invisible": True,
|
||||
"type": "string",
|
||||
},
|
||||
"xai_native_search": {
|
||||
"description": "启用原生搜索功能",
|
||||
"type": "bool",
|
||||
@@ -1466,7 +1463,32 @@ CONFIG_METADATA_2 = {
|
||||
"description": "自定义请求体参数",
|
||||
"type": "dict",
|
||||
"items": {},
|
||||
"hint": "此处添加的键值对将被合并到发送给 API 的 extra_body 中。值可以是字符串、数字或布尔值。",
|
||||
"hint": "用于在请求时添加额外的参数,如 temperature、top_p、max_tokens 等。",
|
||||
"template_schema": {
|
||||
"temperature": {
|
||||
"name": "Temperature",
|
||||
"description": "温度参数",
|
||||
"hint": "控制输出的随机性,范围通常为 0-2。值越高越随机。",
|
||||
"type": "float",
|
||||
"default": 0.6,
|
||||
"slider": {"min": 0, "max": 2, "step": 0.1},
|
||||
},
|
||||
"top_p": {
|
||||
"name": "Top-p",
|
||||
"description": "Top-p 采样",
|
||||
"hint": "核采样参数,范围通常为 0-1。控制模型考虑的概率质量。",
|
||||
"type": "float",
|
||||
"default": 1.0,
|
||||
"slider": {"min": 0, "max": 1, "step": 0.01},
|
||||
},
|
||||
"max_tokens": {
|
||||
"name": "Max Tokens",
|
||||
"description": "最大令牌数",
|
||||
"hint": "生成的最大令牌数。",
|
||||
"type": "int",
|
||||
"default": 8192,
|
||||
},
|
||||
},
|
||||
},
|
||||
"provider": {
|
||||
"type": "string",
|
||||
@@ -1782,13 +1804,35 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
},
|
||||
"gm_thinking_config": {
|
||||
"description": "Gemini思考设置",
|
||||
"description": "Thinking Config",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"budget": {
|
||||
"description": "思考预算",
|
||||
"description": "Thinking Budget",
|
||||
"type": "int",
|
||||
"hint": "模型应该生成的思考Token的数量,设为0关闭思考。除gemini-2.5-flash外的模型会静默忽略此参数。",
|
||||
"hint": "Guides the model on the specific number of thinking tokens to use for reasoning. See: https://ai.google.dev/gemini-api/docs/thinking#set-budget",
|
||||
},
|
||||
"level": {
|
||||
"description": "Thinking Level",
|
||||
"type": "string",
|
||||
"hint": "Recommended for Gemini 3 models and onwards, lets you control reasoning behavior.See: https://ai.google.dev/gemini-api/docs/thinking#thinking-levels",
|
||||
"options": [
|
||||
"MINIMAL",
|
||||
"LOW",
|
||||
"MEDIUM",
|
||||
"HIGH",
|
||||
],
|
||||
},
|
||||
},
|
||||
},
|
||||
"anth_thinking_config": {
|
||||
"description": "Thinking Config",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"budget": {
|
||||
"description": "Thinking Budget",
|
||||
"type": "int",
|
||||
"hint": "Anthropic thinking.budget_tokens param. Must >= 1024. See: https://platform.claude.com/docs/en/build-with-claude/extended-thinking",
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1863,15 +1907,18 @@ CONFIG_METADATA_2 = {
|
||||
"minimax-voice-emotion": {
|
||||
"type": "string",
|
||||
"description": "情绪",
|
||||
"hint": "控制合成语音的情绪",
|
||||
"hint": "控制合成语音的情绪。当为 auto 时,将根据文本内容自动选择情绪。",
|
||||
"options": [
|
||||
"auto",
|
||||
"happy",
|
||||
"sad",
|
||||
"angry",
|
||||
"fearful",
|
||||
"disgusted",
|
||||
"surprised",
|
||||
"neutral",
|
||||
"calm",
|
||||
"fluent",
|
||||
"whisper",
|
||||
],
|
||||
},
|
||||
"minimax-voice-latex": {
|
||||
@@ -1969,7 +2016,6 @@ CONFIG_METADATA_2 = {
|
||||
"id": {
|
||||
"description": "ID",
|
||||
"type": "string",
|
||||
"hint": "模型提供商名字。",
|
||||
},
|
||||
"type": {
|
||||
"description": "模型提供商种类",
|
||||
@@ -1989,29 +2035,20 @@ CONFIG_METADATA_2 = {
|
||||
"description": "API Key",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"hint": "提供商 API Key。",
|
||||
},
|
||||
"api_base": {
|
||||
"description": "API Base URL",
|
||||
"type": "string",
|
||||
"hint": "API Base URL 请在模型提供商处获得。如出现 404 报错,尝试在地址末尾加上 /v1",
|
||||
},
|
||||
"model_config": {
|
||||
"description": "模型配置",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"model": {
|
||||
"description": "模型名称",
|
||||
"type": "string",
|
||||
"hint": "模型名称,如 gpt-4o-mini, deepseek-chat。",
|
||||
},
|
||||
"max_tokens": {
|
||||
"description": "模型最大输出长度(tokens)",
|
||||
"type": "int",
|
||||
},
|
||||
"temperature": {"description": "温度", "type": "float"},
|
||||
"top_p": {"description": "Top P值", "type": "float"},
|
||||
},
|
||||
"model": {
|
||||
"description": "模型 ID",
|
||||
"type": "string",
|
||||
"hint": "模型名称,如 gpt-4o-mini, deepseek-chat。",
|
||||
},
|
||||
"max_context_tokens": {
|
||||
"description": "模型上下文窗口大小",
|
||||
"type": "int",
|
||||
"hint": "模型最大上下文 Token 大小。如果为 0,则会自动从模型元数据填充(如有),也可手动修改。",
|
||||
},
|
||||
"dify_api_key": {
|
||||
"description": "API Key",
|
||||
@@ -2173,6 +2210,9 @@ CONFIG_METADATA_2 = {
|
||||
"use_file_service": {
|
||||
"type": "bool",
|
||||
},
|
||||
"trigger_probability": {
|
||||
"type": "float",
|
||||
},
|
||||
},
|
||||
},
|
||||
"provider_ltm_settings": {
|
||||
@@ -2383,6 +2423,14 @@ CONFIG_METADATA_3 = {
|
||||
"provider_tts_settings.enable": True,
|
||||
},
|
||||
},
|
||||
"provider_tts_settings.trigger_probability": {
|
||||
"description": "TTS 触发概率",
|
||||
"type": "float",
|
||||
"slider": {"min": 0, "max": 1, "step": 0.05},
|
||||
"condition": {
|
||||
"provider_tts_settings.enable": True,
|
||||
},
|
||||
},
|
||||
"provider_settings.image_caption_prompt": {
|
||||
"description": "图片转述提示词",
|
||||
"type": "text",
|
||||
@@ -2509,6 +2557,66 @@ CONFIG_METADATA_3 = {
|
||||
# "provider_settings.enable": True,
|
||||
# },
|
||||
# },
|
||||
"truncate_and_compress": {
|
||||
"description": "上下文管理策略",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"provider_settings.max_context_length": {
|
||||
"description": "最多携带对话轮数",
|
||||
"type": "int",
|
||||
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.dequeue_context_length": {
|
||||
"description": "丢弃对话轮数",
|
||||
"type": "int",
|
||||
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.context_limit_reached_strategy": {
|
||||
"description": "超出模型上下文窗口时的处理方式",
|
||||
"type": "string",
|
||||
"options": ["truncate_by_turns", "llm_compress"],
|
||||
"labels": ["按对话轮数截断", "由 LLM 压缩上下文"],
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
"hint": "",
|
||||
},
|
||||
"provider_settings.llm_compress_instruction": {
|
||||
"description": "上下文压缩提示词",
|
||||
"type": "text",
|
||||
"hint": "如果为空则使用默认提示词。",
|
||||
"condition": {
|
||||
"provider_settings.context_limit_reached_strategy": "llm_compress",
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.llm_compress_keep_recent": {
|
||||
"description": "压缩时保留最近对话轮数",
|
||||
"type": "int",
|
||||
"hint": "始终保留的最近 N 轮对话。",
|
||||
"condition": {
|
||||
"provider_settings.context_limit_reached_strategy": "llm_compress",
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.llm_compress_provider_id": {
|
||||
"description": "用于上下文压缩的模型提供商 ID",
|
||||
"type": "string",
|
||||
"_special": "select_provider",
|
||||
"hint": "留空时将降级为“按对话轮数截断”的策略。",
|
||||
"condition": {
|
||||
"provider_settings.context_limit_reached_strategy": "llm_compress",
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"others": {
|
||||
"description": "其他配置",
|
||||
"type": "object",
|
||||
@@ -2573,22 +2681,6 @@ CONFIG_METADATA_3 = {
|
||||
"provider_settings.streaming_response": True,
|
||||
},
|
||||
},
|
||||
"provider_settings.max_context_length": {
|
||||
"description": "最多携带对话轮数",
|
||||
"type": "int",
|
||||
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.dequeue_context_length": {
|
||||
"description": "丢弃对话轮数",
|
||||
"type": "int",
|
||||
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.wake_prefix": {
|
||||
"description": "LLM 聊天额外唤醒前缀 ",
|
||||
"type": "string",
|
||||
@@ -2661,6 +2753,11 @@ CONFIG_METADATA_3 = {
|
||||
"description": "只 @ 机器人是否触发等待",
|
||||
"type": "bool",
|
||||
},
|
||||
"disable_builtin_commands": {
|
||||
"description": "禁用自带指令",
|
||||
"type": "bool",
|
||||
"hint": "禁用所有 AstrBot 的自带指令,如 help, provider, model 等。",
|
||||
},
|
||||
},
|
||||
},
|
||||
"whitelist": {
|
||||
@@ -2875,9 +2972,26 @@ CONFIG_METADATA_3 = {
|
||||
"description": "分段回复字数阈值",
|
||||
"type": "int",
|
||||
},
|
||||
"platform_settings.segmented_reply.split_mode": {
|
||||
"description": "分段模式",
|
||||
"type": "string",
|
||||
"options": ["regex", "words"],
|
||||
"labels": ["正则表达式", "分段词列表"],
|
||||
},
|
||||
"platform_settings.segmented_reply.regex": {
|
||||
"description": "分段正则表达式",
|
||||
"type": "string",
|
||||
"condition": {
|
||||
"platform_settings.segmented_reply.split_mode": "regex",
|
||||
},
|
||||
},
|
||||
"platform_settings.segmented_reply.split_words": {
|
||||
"description": "分段词列表",
|
||||
"type": "list",
|
||||
"hint": "检测到列表中的任意词时进行分段,如:。、?、!等",
|
||||
"condition": {
|
||||
"platform_settings.segmented_reply.split_mode": "words",
|
||||
},
|
||||
},
|
||||
"platform_settings.segmented_reply.content_cleanup_rule": {
|
||||
"description": "内容过滤正则表达式",
|
||||
@@ -2928,6 +3042,7 @@ CONFIG_METADATA_3 = {
|
||||
"description": "回复概率",
|
||||
"type": "float",
|
||||
"hint": "0.0-1.0 之间的数值",
|
||||
"slider": {"min": 0, "max": 1, "step": 0.05},
|
||||
"condition": {
|
||||
"provider_ltm_settings.active_reply.enable": True,
|
||||
},
|
||||
@@ -3035,4 +3150,5 @@ DEFAULT_VALUE_MAP = {
|
||||
"text": "",
|
||||
"list": [],
|
||||
"object": {},
|
||||
"template_list": [],
|
||||
}
|
||||
|
||||
@@ -79,6 +79,7 @@ class ConfigMetadataI18n:
|
||||
"_special",
|
||||
"invisible",
|
||||
"options",
|
||||
"slider",
|
||||
]:
|
||||
if attr in field_data:
|
||||
field_result[attr] = field_data[attr]
|
||||
|
||||
@@ -69,6 +69,7 @@ class ConversationManager:
|
||||
persona_id=conv_v2.persona_id,
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
token_usage=conv_v2.token_usage,
|
||||
)
|
||||
|
||||
async def new_conversation(
|
||||
@@ -256,6 +257,7 @@ class ConversationManager:
|
||||
history: list[dict] | None = None,
|
||||
title: str | None = None,
|
||||
persona_id: str | None = None,
|
||||
token_usage: int | None = None,
|
||||
) -> None:
|
||||
"""更新会话的对话.
|
||||
|
||||
@@ -263,6 +265,7 @@ class ConversationManager:
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||
history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段
|
||||
token_usage (int | None): token 使用量。None 表示不更新
|
||||
|
||||
"""
|
||||
if not conversation_id:
|
||||
@@ -274,6 +277,7 @@ class ConversationManager:
|
||||
title=title,
|
||||
persona_id=persona_id,
|
||||
content=history,
|
||||
token_usage=token_usage,
|
||||
)
|
||||
|
||||
async def update_conversation_title(
|
||||
|
||||
@@ -33,6 +33,7 @@ from astrbot.core.star.context import Context
|
||||
from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map
|
||||
from astrbot.core.umop_config_router import UmopConfigRouter
|
||||
from astrbot.core.updator import AstrBotUpdator
|
||||
from astrbot.core.utils.llm_metadata import update_llm_metadata
|
||||
from astrbot.core.utils.migra_helper import migra
|
||||
|
||||
from . import astrbot_config, html_renderer
|
||||
@@ -89,6 +90,7 @@ class AstrBotCoreLifecycle:
|
||||
|
||||
# 初始化 UMOP 配置路由器
|
||||
self.umop_config_router = UmopConfigRouter(sp=sp)
|
||||
await self.umop_config_router.initialize()
|
||||
|
||||
# 初始化 AstrBot 配置管理器
|
||||
self.astrbot_config_mgr = AstrBotConfigManager(
|
||||
@@ -185,6 +187,8 @@ class AstrBotCoreLifecycle:
|
||||
# 初始化关闭控制面板的事件
|
||||
self.dashboard_shutdown_event = asyncio.Event()
|
||||
|
||||
asyncio.create_task(update_llm_metadata())
|
||||
|
||||
def _load(self) -> None:
|
||||
"""加载事件总线和任务并初始化."""
|
||||
# 创建一个异步任务来执行事件总线的 dispatch() 方法
|
||||
@@ -197,7 +201,7 @@ class AstrBotCoreLifecycle:
|
||||
# 把插件中注册的所有协程函数注册到事件总线中并执行
|
||||
extra_tasks = []
|
||||
for task in self.star_context._register_tasks:
|
||||
extra_tasks.append(asyncio.create_task(task, name=task.__name__))
|
||||
extra_tasks.append(asyncio.create_task(task, name=task.__name__)) # type: ignore
|
||||
|
||||
tasks_ = [event_bus_task, *extra_tasks]
|
||||
for task in tasks_:
|
||||
|
||||
@@ -5,11 +5,12 @@ from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
|
||||
from deprecated import deprecated
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from astrbot.core.db.po import (
|
||||
Attachment,
|
||||
CommandConfig,
|
||||
CommandConflict,
|
||||
ConversationV2,
|
||||
Persona,
|
||||
PlatformMessageHistory,
|
||||
@@ -32,7 +33,7 @@ class BaseDatabase(abc.ABC):
|
||||
echo=False,
|
||||
future=True,
|
||||
)
|
||||
self.AsyncSessionLocal = sessionmaker(
|
||||
self.AsyncSessionLocal = async_sessionmaker(
|
||||
self.engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
@@ -151,6 +152,7 @@ class BaseDatabase(abc.ABC):
|
||||
title: str | None = None,
|
||||
persona_id: str | None = None,
|
||||
content: list[dict] | None = None,
|
||||
token_usage: int | None = None,
|
||||
) -> None:
|
||||
"""Update a conversation's history."""
|
||||
...
|
||||
@@ -315,6 +317,76 @@ class BaseDatabase(abc.ABC):
|
||||
"""Clear all preferences for a specific scope ID."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_command_configs(self) -> list[CommandConfig]:
|
||||
"""Get all stored command configurations."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_command_config(self, handler_full_name: str) -> CommandConfig | None:
|
||||
"""Fetch a single command configuration by handler."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def upsert_command_config(
|
||||
self,
|
||||
handler_full_name: str,
|
||||
plugin_name: str,
|
||||
module_path: str,
|
||||
original_command: str,
|
||||
*,
|
||||
resolved_command: str | None = None,
|
||||
enabled: bool | None = None,
|
||||
keep_original_alias: bool | None = None,
|
||||
conflict_key: str | None = None,
|
||||
resolution_strategy: str | None = None,
|
||||
note: str | None = None,
|
||||
extra_data: dict | None = None,
|
||||
auto_managed: bool | None = None,
|
||||
) -> CommandConfig:
|
||||
"""Create or update a command configuration."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_command_config(self, handler_full_name: str) -> None:
|
||||
"""Delete a single command configuration."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_command_configs(self, handler_full_names: list[str]) -> None:
|
||||
"""Bulk delete command configurations."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def list_command_conflicts(
|
||||
self,
|
||||
status: str | None = None,
|
||||
) -> list[CommandConflict]:
|
||||
"""List recorded command conflict entries."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def upsert_command_conflict(
|
||||
self,
|
||||
conflict_key: str,
|
||||
handler_full_name: str,
|
||||
plugin_name: str,
|
||||
*,
|
||||
status: str | None = None,
|
||||
resolution: str | None = None,
|
||||
resolved_command: str | None = None,
|
||||
note: str | None = None,
|
||||
extra_data: dict | None = None,
|
||||
auto_generated: bool | None = None,
|
||||
) -> CommandConflict:
|
||||
"""Create or update a conflict record."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_command_conflicts(self, ids: list[int]) -> None:
|
||||
"""Delete conflict records."""
|
||||
...
|
||||
|
||||
# @abc.abstractmethod
|
||||
# async def insert_llm_message(
|
||||
# self,
|
||||
|
||||
@@ -70,6 +70,7 @@ async def migration_conversation_table(
|
||||
logger.info(
|
||||
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
|
||||
)
|
||||
continue
|
||||
if ":" not in conv.user_id:
|
||||
continue
|
||||
session = MessageSesion.from_str(session_str=conv.user_id)
|
||||
@@ -207,6 +208,7 @@ async def migration_webchat_data(
|
||||
logger.info(
|
||||
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
|
||||
)
|
||||
continue
|
||||
if ":" in conv.user_id:
|
||||
continue
|
||||
platform_id = "webchat"
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
"""Migration script to add token_usage column to conversations table.
|
||||
|
||||
This migration adds the token_usage field to track token consumption for each conversation.
|
||||
|
||||
Changes:
|
||||
- Adds token_usage column to conversations table (default: 0)
|
||||
"""
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from astrbot.api import logger, sp
|
||||
from astrbot.core.db import BaseDatabase
|
||||
|
||||
|
||||
async def migrate_token_usage(db_helper: BaseDatabase):
|
||||
"""Add token_usage column to conversations table.
|
||||
|
||||
This migration adds a new column to track token consumption in conversations.
|
||||
"""
|
||||
# 检查是否已经完成迁移
|
||||
migration_done = await db_helper.get_preference(
|
||||
"global", "global", "migration_done_token_usage_1"
|
||||
)
|
||||
if migration_done:
|
||||
return
|
||||
|
||||
logger.info("开始执行数据库迁移(添加 conversations.token_usage 列)...")
|
||||
|
||||
# 这里只适配了 SQLite。因为截止至这一版本,AstrBot 仅支持 SQLite。
|
||||
|
||||
try:
|
||||
async with db_helper.get_db() as session:
|
||||
# 检查列是否已存在
|
||||
result = await session.execute(text("PRAGMA table_info(conversations)"))
|
||||
columns = result.fetchall()
|
||||
column_names = [col[1] for col in columns]
|
||||
|
||||
if "token_usage" in column_names:
|
||||
logger.info("token_usage 列已存在,跳过迁移")
|
||||
await sp.put_async(
|
||||
"global", "global", "migration_done_token_usage_1", True
|
||||
)
|
||||
return
|
||||
|
||||
# 添加 token_usage 列
|
||||
await session.execute(
|
||||
text(
|
||||
"ALTER TABLE conversations ADD COLUMN token_usage INTEGER NOT NULL DEFAULT 0"
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
logger.info("token_usage 列添加成功")
|
||||
|
||||
# 标记迁移完成
|
||||
await sp.put_async("global", "global", "migration_done_token_usage_1", True)
|
||||
logger.info("token_usage 迁移完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"迁移过程中发生错误: {e}", exc_info=True)
|
||||
raise
|
||||
@@ -127,7 +127,7 @@ class SQLiteDatabase:
|
||||
conn.text_factory = str
|
||||
return conn
|
||||
|
||||
def _exec_sql(self, sql: str, params: tuple = None):
|
||||
def _exec_sql(self, sql: str, params: tuple | None = None):
|
||||
conn = self.conn
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
@@ -224,9 +224,11 @@ class SQLiteDatabase:
|
||||
|
||||
c.close()
|
||||
|
||||
return Stats(platform, [], [])
|
||||
return Stats(platform)
|
||||
|
||||
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
|
||||
def get_conversation_by_user_id(
|
||||
self, user_id: str, cid: str
|
||||
) -> Conversation | None:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
@@ -258,7 +260,7 @@ class SQLiteDatabase:
|
||||
(user_id, cid, history, updated_at, created_at),
|
||||
)
|
||||
|
||||
def get_conversations(self, user_id: str) -> tuple:
|
||||
def get_conversations(self, user_id: str) -> list[Conversation]:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
|
||||
+82
-15
@@ -12,7 +12,7 @@ class PlatformStat(SQLModel, table=True):
|
||||
Note: In astrbot v4, we moved `platform` table to here.
|
||||
"""
|
||||
|
||||
__tablename__ = "platform_stats" # type: ignore
|
||||
__tablename__: str = "platform_stats"
|
||||
|
||||
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
|
||||
timestamp: datetime = Field(nullable=False)
|
||||
@@ -31,9 +31,10 @@ class PlatformStat(SQLModel, table=True):
|
||||
|
||||
|
||||
class ConversationV2(SQLModel, table=True):
|
||||
__tablename__ = "conversations" # type: ignore
|
||||
__tablename__: str = "conversations"
|
||||
|
||||
inner_conversation_id: int = Field(
|
||||
inner_conversation_id: int | None = Field(
|
||||
default=None,
|
||||
primary_key=True,
|
||||
sa_column_kwargs={"autoincrement": True},
|
||||
)
|
||||
@@ -53,6 +54,11 @@ class ConversationV2(SQLModel, table=True):
|
||||
)
|
||||
title: str | None = Field(default=None, max_length=255)
|
||||
persona_id: str | None = Field(default=None)
|
||||
token_usage: int = Field(default=0, nullable=False)
|
||||
"""content is a list of OpenAI-formated messages in list[dict] format.
|
||||
token_usage is the total token value of the messages.
|
||||
when 0, will use estimated token counter.
|
||||
"""
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
@@ -68,7 +74,7 @@ class Persona(SQLModel, table=True):
|
||||
It can be used to customize the behavior of LLMs.
|
||||
"""
|
||||
|
||||
__tablename__ = "personas" # type: ignore
|
||||
__tablename__: str = "personas"
|
||||
|
||||
id: int | None = Field(
|
||||
primary_key=True,
|
||||
@@ -98,7 +104,7 @@ class Persona(SQLModel, table=True):
|
||||
class Preference(SQLModel, table=True):
|
||||
"""This class represents preferences for bots."""
|
||||
|
||||
__tablename__ = "preferences" # type: ignore
|
||||
__tablename__: str = "preferences"
|
||||
|
||||
id: int | None = Field(
|
||||
default=None,
|
||||
@@ -134,7 +140,7 @@ class PlatformMessageHistory(SQLModel, table=True):
|
||||
or platform-specific messages.
|
||||
"""
|
||||
|
||||
__tablename__ = "platform_message_history" # type: ignore
|
||||
__tablename__: str = "platform_message_history"
|
||||
|
||||
id: int | None = Field(
|
||||
primary_key=True,
|
||||
@@ -162,7 +168,7 @@ class PlatformSession(SQLModel, table=True):
|
||||
Each session can have multiple conversations (对话) associated with it.
|
||||
"""
|
||||
|
||||
__tablename__ = "platform_sessions" # type: ignore
|
||||
__tablename__: str = "platform_sessions"
|
||||
|
||||
inner_id: int | None = Field(
|
||||
primary_key=True,
|
||||
@@ -203,7 +209,7 @@ class Attachment(SQLModel, table=True):
|
||||
Attachments can be images, files, or other media types.
|
||||
"""
|
||||
|
||||
__tablename__ = "attachments" # type: ignore
|
||||
__tablename__: str = "attachments"
|
||||
|
||||
inner_attachment_id: int | None = Field(
|
||||
primary_key=True,
|
||||
@@ -233,6 +239,65 @@ class Attachment(SQLModel, table=True):
|
||||
)
|
||||
|
||||
|
||||
class CommandConfig(SQLModel, table=True):
|
||||
"""Per-command configuration overrides for dashboard management."""
|
||||
|
||||
__tablename__ = "command_configs" # type: ignore
|
||||
|
||||
handler_full_name: str = Field(
|
||||
primary_key=True,
|
||||
max_length=512,
|
||||
)
|
||||
plugin_name: str = Field(nullable=False, max_length=255)
|
||||
module_path: str = Field(nullable=False, max_length=255)
|
||||
original_command: str = Field(nullable=False, max_length=255)
|
||||
resolved_command: str | None = Field(default=None, max_length=255)
|
||||
enabled: bool = Field(default=True, nullable=False)
|
||||
keep_original_alias: bool = Field(default=False, nullable=False)
|
||||
conflict_key: str | None = Field(default=None, max_length=255)
|
||||
resolution_strategy: str | None = Field(default=None, max_length=64)
|
||||
note: str | None = Field(default=None, sa_type=Text)
|
||||
extra_data: dict | None = Field(default=None, sa_type=JSON)
|
||||
auto_managed: bool = Field(default=False, nullable=False)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
|
||||
class CommandConflict(SQLModel, table=True):
|
||||
"""Conflict tracking for duplicated command names."""
|
||||
|
||||
__tablename__ = "command_conflicts" # type: ignore
|
||||
|
||||
id: int | None = Field(
|
||||
default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
|
||||
)
|
||||
conflict_key: str = Field(nullable=False, max_length=255)
|
||||
handler_full_name: str = Field(nullable=False, max_length=512)
|
||||
plugin_name: str = Field(nullable=False, max_length=255)
|
||||
status: str = Field(default="pending", max_length=32)
|
||||
resolution: str | None = Field(default=None, max_length=64)
|
||||
resolved_command: str | None = Field(default=None, max_length=255)
|
||||
note: str | None = Field(default=None, sa_type=Text)
|
||||
extra_data: dict | None = Field(default=None, sa_type=JSON)
|
||||
auto_generated: bool = Field(default=False, nullable=False)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"conflict_key",
|
||||
"handler_full_name",
|
||||
name="uix_conflict_handler",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Conversation:
|
||||
"""LLM 对话类
|
||||
@@ -253,6 +318,8 @@ class Conversation:
|
||||
persona_id: str | None = ""
|
||||
created_at: int = 0
|
||||
updated_at: int = 0
|
||||
token_usage: int = 0
|
||||
"""对话的总 token 数量。AstrBot 会保留最近一次 LLM 请求返回的总 token 数,方便统计。token_usage 可能为 0,表示未知。"""
|
||||
|
||||
|
||||
class Personality(TypedDict):
|
||||
@@ -261,17 +328,17 @@ class Personality(TypedDict):
|
||||
在 v4.0.0 版本及之后,推荐使用上面的 Persona 类。并且, mood_imitation_dialogs 字段已被废弃。
|
||||
"""
|
||||
|
||||
prompt: str = ""
|
||||
name: str = ""
|
||||
begin_dialogs: list[str] = []
|
||||
mood_imitation_dialogs: list[str] = []
|
||||
prompt: str
|
||||
name: str
|
||||
begin_dialogs: list[str]
|
||||
mood_imitation_dialogs: list[str]
|
||||
"""情感模拟对话预设。在 v4.0.0 版本及之后,已被废弃。"""
|
||||
tools: list[str] | None = None
|
||||
tools: list[str] | None
|
||||
"""工具列表。None 表示使用所有工具,空列表表示不使用任何工具"""
|
||||
|
||||
# cache
|
||||
_begin_dialogs_processed: list[dict] = []
|
||||
_mood_imitation_dialogs_processed: str = ""
|
||||
_begin_dialogs_processed: list[dict]
|
||||
_mood_imitation_dialogs_processed: str
|
||||
|
||||
|
||||
# ====
|
||||
|
||||
+249
-4
@@ -1,14 +1,18 @@
|
||||
import asyncio
|
||||
import threading
|
||||
import typing as T
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy import CursorResult
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import col, delete, desc, func, or_, select, text, update
|
||||
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.po import (
|
||||
Attachment,
|
||||
CommandConfig,
|
||||
CommandConflict,
|
||||
ConversationV2,
|
||||
Persona,
|
||||
PlatformMessageHistory,
|
||||
@@ -25,6 +29,7 @@ from astrbot.core.db.po import (
|
||||
)
|
||||
|
||||
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
|
||||
TxResult = T.TypeVar("TxResult")
|
||||
|
||||
|
||||
class SQLiteDatabase(BaseDatabase):
|
||||
@@ -236,7 +241,9 @@ class SQLiteDatabase(BaseDatabase):
|
||||
session.add(new_conversation)
|
||||
return new_conversation
|
||||
|
||||
async def update_conversation(self, cid, title=None, persona_id=None, content=None):
|
||||
async def update_conversation(
|
||||
self, cid, title=None, persona_id=None, content=None, token_usage=None
|
||||
):
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
@@ -250,6 +257,8 @@ class SQLiteDatabase(BaseDatabase):
|
||||
values["persona_id"] = persona_id
|
||||
if content is not None:
|
||||
values["content"] = content
|
||||
if token_usage is not None:
|
||||
values["token_usage"] = token_usage
|
||||
if not values:
|
||||
return None
|
||||
query = query.values(**values)
|
||||
@@ -489,7 +498,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
query = select(Attachment).where(
|
||||
Attachment.attachment_id.in_(attachment_ids)
|
||||
col(Attachment.attachment_id).in_(attachment_ids)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
@@ -505,7 +514,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
query = delete(Attachment).where(
|
||||
col(Attachment.attachment_id) == attachment_id
|
||||
)
|
||||
result = await session.execute(query)
|
||||
result = T.cast(CursorResult, await session.execute(query))
|
||||
return result.rowcount > 0
|
||||
|
||||
async def delete_attachments(self, attachment_ids: list[str]) -> int:
|
||||
@@ -521,7 +530,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
query = delete(Attachment).where(
|
||||
col(Attachment.attachment_id).in_(attachment_ids)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
result = T.cast(CursorResult, await session.execute(query))
|
||||
return result.rowcount
|
||||
|
||||
async def insert_persona(
|
||||
@@ -669,6 +678,242 @@ class SQLiteDatabase(BaseDatabase):
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# ====
|
||||
# Command Configuration & Conflict Tracking
|
||||
# ====
|
||||
|
||||
async def _run_in_tx(
|
||||
self,
|
||||
fn: Callable[[AsyncSession], Awaitable[TxResult]],
|
||||
) -> TxResult:
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
return await fn(session)
|
||||
|
||||
@staticmethod
|
||||
def _apply_updates(model, **updates) -> None:
|
||||
for field, value in updates.items():
|
||||
if value is not None:
|
||||
setattr(model, field, value)
|
||||
|
||||
@staticmethod
|
||||
def _new_command_config(
|
||||
handler_full_name: str,
|
||||
plugin_name: str,
|
||||
module_path: str,
|
||||
original_command: str,
|
||||
*,
|
||||
resolved_command: str | None = None,
|
||||
enabled: bool | None = None,
|
||||
keep_original_alias: bool | None = None,
|
||||
conflict_key: str | None = None,
|
||||
resolution_strategy: str | None = None,
|
||||
note: str | None = None,
|
||||
extra_data: dict | None = None,
|
||||
auto_managed: bool | None = None,
|
||||
) -> CommandConfig:
|
||||
return CommandConfig(
|
||||
handler_full_name=handler_full_name,
|
||||
plugin_name=plugin_name,
|
||||
module_path=module_path,
|
||||
original_command=original_command,
|
||||
resolved_command=resolved_command,
|
||||
enabled=True if enabled is None else enabled,
|
||||
keep_original_alias=False
|
||||
if keep_original_alias is None
|
||||
else keep_original_alias,
|
||||
conflict_key=conflict_key or original_command,
|
||||
resolution_strategy=resolution_strategy,
|
||||
note=note,
|
||||
extra_data=extra_data,
|
||||
auto_managed=bool(auto_managed),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _new_command_conflict(
|
||||
conflict_key: str,
|
||||
handler_full_name: str,
|
||||
plugin_name: str,
|
||||
*,
|
||||
status: str | None = None,
|
||||
resolution: str | None = None,
|
||||
resolved_command: str | None = None,
|
||||
note: str | None = None,
|
||||
extra_data: dict | None = None,
|
||||
auto_generated: bool | None = None,
|
||||
) -> CommandConflict:
|
||||
return CommandConflict(
|
||||
conflict_key=conflict_key,
|
||||
handler_full_name=handler_full_name,
|
||||
plugin_name=plugin_name,
|
||||
status=status or "pending",
|
||||
resolution=resolution,
|
||||
resolved_command=resolved_command,
|
||||
note=note,
|
||||
extra_data=extra_data,
|
||||
auto_generated=bool(auto_generated),
|
||||
)
|
||||
|
||||
async def get_command_configs(self) -> list[CommandConfig]:
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
result = await session.execute(select(CommandConfig))
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_command_config(
|
||||
self,
|
||||
handler_full_name: str,
|
||||
) -> CommandConfig | None:
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
return await session.get(CommandConfig, handler_full_name)
|
||||
|
||||
async def upsert_command_config(
|
||||
self,
|
||||
handler_full_name: str,
|
||||
plugin_name: str,
|
||||
module_path: str,
|
||||
original_command: str,
|
||||
*,
|
||||
resolved_command: str | None = None,
|
||||
enabled: bool | None = None,
|
||||
keep_original_alias: bool | None = None,
|
||||
conflict_key: str | None = None,
|
||||
resolution_strategy: str | None = None,
|
||||
note: str | None = None,
|
||||
extra_data: dict | None = None,
|
||||
auto_managed: bool | None = None,
|
||||
) -> CommandConfig:
|
||||
async def _op(session: AsyncSession) -> CommandConfig:
|
||||
config = await session.get(CommandConfig, handler_full_name)
|
||||
if not config:
|
||||
config = self._new_command_config(
|
||||
handler_full_name,
|
||||
plugin_name,
|
||||
module_path,
|
||||
original_command,
|
||||
resolved_command=resolved_command,
|
||||
enabled=enabled,
|
||||
keep_original_alias=keep_original_alias,
|
||||
conflict_key=conflict_key,
|
||||
resolution_strategy=resolution_strategy,
|
||||
note=note,
|
||||
extra_data=extra_data,
|
||||
auto_managed=auto_managed,
|
||||
)
|
||||
session.add(config)
|
||||
else:
|
||||
self._apply_updates(
|
||||
config,
|
||||
plugin_name=plugin_name,
|
||||
module_path=module_path,
|
||||
original_command=original_command,
|
||||
resolved_command=resolved_command,
|
||||
enabled=enabled,
|
||||
keep_original_alias=keep_original_alias,
|
||||
conflict_key=conflict_key,
|
||||
resolution_strategy=resolution_strategy,
|
||||
note=note,
|
||||
extra_data=extra_data,
|
||||
auto_managed=auto_managed,
|
||||
)
|
||||
await session.flush()
|
||||
await session.refresh(config)
|
||||
return config
|
||||
|
||||
return await self._run_in_tx(_op)
|
||||
|
||||
async def delete_command_config(self, handler_full_name: str) -> None:
|
||||
await self.delete_command_configs([handler_full_name])
|
||||
|
||||
async def delete_command_configs(self, handler_full_names: list[str]) -> None:
|
||||
if not handler_full_names:
|
||||
return
|
||||
|
||||
async def _op(session: AsyncSession) -> None:
|
||||
await session.execute(
|
||||
delete(CommandConfig).where(
|
||||
col(CommandConfig.handler_full_name).in_(handler_full_names),
|
||||
),
|
||||
)
|
||||
|
||||
await self._run_in_tx(_op)
|
||||
|
||||
async def list_command_conflicts(
|
||||
self,
|
||||
status: str | None = None,
|
||||
) -> list[CommandConflict]:
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
query = select(CommandConflict)
|
||||
if status:
|
||||
query = query.where(CommandConflict.status == status)
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def upsert_command_conflict(
|
||||
self,
|
||||
conflict_key: str,
|
||||
handler_full_name: str,
|
||||
plugin_name: str,
|
||||
*,
|
||||
status: str | None = None,
|
||||
resolution: str | None = None,
|
||||
resolved_command: str | None = None,
|
||||
note: str | None = None,
|
||||
extra_data: dict | None = None,
|
||||
auto_generated: bool | None = None,
|
||||
) -> CommandConflict:
|
||||
async def _op(session: AsyncSession) -> CommandConflict:
|
||||
result = await session.execute(
|
||||
select(CommandConflict).where(
|
||||
CommandConflict.conflict_key == conflict_key,
|
||||
CommandConflict.handler_full_name == handler_full_name,
|
||||
),
|
||||
)
|
||||
record = result.scalar_one_or_none()
|
||||
if not record:
|
||||
record = self._new_command_conflict(
|
||||
conflict_key,
|
||||
handler_full_name,
|
||||
plugin_name,
|
||||
status=status,
|
||||
resolution=resolution,
|
||||
resolved_command=resolved_command,
|
||||
note=note,
|
||||
extra_data=extra_data,
|
||||
auto_generated=auto_generated,
|
||||
)
|
||||
session.add(record)
|
||||
else:
|
||||
self._apply_updates(
|
||||
record,
|
||||
plugin_name=plugin_name,
|
||||
status=status,
|
||||
resolution=resolution,
|
||||
resolved_command=resolved_command,
|
||||
note=note,
|
||||
extra_data=extra_data,
|
||||
auto_generated=auto_generated,
|
||||
)
|
||||
await session.flush()
|
||||
await session.refresh(record)
|
||||
return record
|
||||
|
||||
return await self._run_in_tx(_op)
|
||||
|
||||
async def delete_command_conflicts(self, ids: list[int]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
|
||||
async def _op(session: AsyncSession) -> None:
|
||||
await session.execute(
|
||||
delete(CommandConflict).where(col(CommandConflict.id).in_(ids)),
|
||||
)
|
||||
|
||||
await self._run_in_tx(_op)
|
||||
|
||||
# ====
|
||||
# Deprecated Methods
|
||||
# ====
|
||||
|
||||
@@ -90,4 +90,6 @@ class EmbeddingStorage:
|
||||
path (str): 保存索引的路径
|
||||
|
||||
"""
|
||||
if self.index is None:
|
||||
return
|
||||
faiss.write_index(self.index, self.path)
|
||||
|
||||
@@ -27,7 +27,7 @@ class EventBus:
|
||||
self,
|
||||
event_queue: Queue,
|
||||
pipeline_scheduler_mapping: dict[str, PipelineScheduler],
|
||||
astrbot_config_mgr: AstrBotConfigManager = None,
|
||||
astrbot_config_mgr: AstrBotConfigManager,
|
||||
):
|
||||
self.event_queue = event_queue # 事件队列
|
||||
# abconf uuid -> scheduler
|
||||
@@ -40,6 +40,11 @@ class EventBus:
|
||||
conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin)
|
||||
self._print_event(event, conf_info["name"])
|
||||
scheduler = self.pipeline_scheduler_mapping.get(conf_info["id"])
|
||||
if not scheduler:
|
||||
logger.error(
|
||||
f"PipelineScheduler not found for id: {conf_info['id']}, event ignored."
|
||||
)
|
||||
continue
|
||||
asyncio.create_task(scheduler.execute(event))
|
||||
|
||||
def _print_event(self, event: AstrMessageEvent, conf_name: str):
|
||||
|
||||
@@ -149,8 +149,16 @@ class RecursiveCharacterChunker(BaseChunker):
|
||||
分割后的文本块列表
|
||||
|
||||
"""
|
||||
chunk_size = chunk_size or self.chunk_size
|
||||
overlap = overlap or self.chunk_overlap
|
||||
if chunk_size is None:
|
||||
chunk_size = self.chunk_size
|
||||
if overlap is None:
|
||||
overlap = self.chunk_overlap
|
||||
if chunk_size <= 0:
|
||||
raise ValueError("chunk_size must be greater than 0")
|
||||
if overlap < 0:
|
||||
raise ValueError("chunk_overlap must be non-negative")
|
||||
if overlap >= chunk_size:
|
||||
raise ValueError("chunk_overlap must be less than chunk_size")
|
||||
result = []
|
||||
for i in range(0, len(text), chunk_size - overlap):
|
||||
end = min(i + chunk_size, len(text))
|
||||
|
||||
@@ -166,7 +166,11 @@ class RetrievalManager:
|
||||
# 5. Rerank
|
||||
first_rerank = None
|
||||
for kb_id in kb_ids:
|
||||
vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
|
||||
vec_db = kb_options[kb_id]["vec_db"]
|
||||
if not isinstance(vec_db, FaissVecDB):
|
||||
logger.warning(f"vec_db for kb_id {kb_id} is not FaissVecDB")
|
||||
continue
|
||||
|
||||
rerank_pi = kb_options[kb_id]["rerank_provider_id"]
|
||||
if (
|
||||
vec_db
|
||||
|
||||
+3
-2
@@ -24,6 +24,7 @@ import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from asyncio import Queue
|
||||
from collections import deque
|
||||
|
||||
@@ -57,7 +58,7 @@ def is_plugin_path(pathname):
|
||||
return False
|
||||
|
||||
norm_path = os.path.normpath(pathname)
|
||||
return ("data/plugins" in norm_path) or ("packages/" in norm_path)
|
||||
return ("data/plugins" in norm_path) or ("astrbot/builtin_stars/" in norm_path)
|
||||
|
||||
|
||||
def get_short_level_name(level_name):
|
||||
@@ -148,7 +149,7 @@ class LogQueueHandler(logging.Handler):
|
||||
self.log_broker.publish(
|
||||
{
|
||||
"level": record.levelname,
|
||||
"time": record.asctime,
|
||||
"time": time.time(),
|
||||
"data": log_entry,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -66,6 +66,9 @@ class ComponentType(str, Enum):
|
||||
class BaseMessageComponent(BaseModel):
|
||||
type: ComponentType
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def toDict(self):
|
||||
data = {}
|
||||
for k, v in self.__dict__.items():
|
||||
@@ -551,7 +554,7 @@ class Node(BaseMessageComponent):
|
||||
id: int | None = 0 # 忽略
|
||||
name: str | None = "" # qq昵称
|
||||
uin: str | None = "0" # qq号
|
||||
content: list[BaseMessageComponent] | None = []
|
||||
content: list[BaseMessageComponent] = []
|
||||
seq: str | list | None = "" # 忽略
|
||||
time: int | None = 0 # 忽略
|
||||
|
||||
@@ -615,7 +618,7 @@ class Nodes(BaseMessageComponent):
|
||||
ret["messages"].append(d)
|
||||
return ret
|
||||
|
||||
async def to_dict(self):
|
||||
async def to_dict(self) -> dict:
|
||||
"""将 Nodes 转换为字典格式,适用于 OneBot JSON 格式"""
|
||||
ret = {"messages": []}
|
||||
for node in self.nodes:
|
||||
@@ -626,12 +629,11 @@ class Nodes(BaseMessageComponent):
|
||||
|
||||
class Json(BaseMessageComponent):
|
||||
type = ComponentType.Json
|
||||
data: str | dict
|
||||
resid: int | None = 0
|
||||
data: dict
|
||||
|
||||
def __init__(self, data, **_):
|
||||
if isinstance(data, dict):
|
||||
data = json.dumps(data)
|
||||
def __init__(self, data: str | dict, **_):
|
||||
if isinstance(data, str):
|
||||
data = json.loads(data)
|
||||
super().__init__(data=data, **_)
|
||||
|
||||
|
||||
@@ -714,12 +716,15 @@ class File(BaseMessageComponent):
|
||||
|
||||
if self.url:
|
||||
await self._download_file()
|
||||
return os.path.abspath(self.file_)
|
||||
if self.file_:
|
||||
return os.path.abspath(self.file_)
|
||||
|
||||
return ""
|
||||
|
||||
async def _download_file(self):
|
||||
"""下载文件"""
|
||||
if not self.url:
|
||||
raise ValueError("Download failed: No URL provided in File component.")
|
||||
download_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
os.makedirs(download_dir, exist_ok=True)
|
||||
if self.name:
|
||||
|
||||
@@ -98,8 +98,8 @@ class PersonaManager:
|
||||
self,
|
||||
persona_id: str,
|
||||
system_prompt: str,
|
||||
begin_dialogs: list[str] = None,
|
||||
tools: list[str] = None,
|
||||
begin_dialogs: list[str] | None = None,
|
||||
tools: list[str] | None = None,
|
||||
) -> Persona:
|
||||
"""创建新的 persona。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
|
||||
if await self.db.get_persona_by_id(persona_id):
|
||||
|
||||
@@ -24,7 +24,7 @@ class ContentSafetyCheckStage(Stage):
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
check_text: str | None = None,
|
||||
) -> None | AsyncGenerator[None, None]:
|
||||
) -> AsyncGenerator[None, None]:
|
||||
"""检查内容安全"""
|
||||
text = check_text if check_text else event.get_message_str()
|
||||
ok, info = self.strategy_selector.check(text)
|
||||
|
||||
@@ -11,7 +11,7 @@ from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
||||
|
||||
async def call_handler(
|
||||
event: AstrMessageEvent,
|
||||
handler: T.Callable[..., T.Awaitable[T.Any]],
|
||||
handler: T.Callable[..., T.Awaitable[T.Any] | T.AsyncGenerator[T.Any, None]],
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> T.AsyncGenerator[T.Any, None]:
|
||||
@@ -91,6 +91,7 @@ async def call_event_hook(
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
assert inspect.iscoroutinefunction(handler.handler)
|
||||
logger.debug(
|
||||
f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}",
|
||||
)
|
||||
|
||||
@@ -38,7 +38,7 @@ class AgentRequestSubStage(Stage):
|
||||
)
|
||||
return
|
||||
|
||||
if not SessionServiceManager.should_process_llm_request(event):
|
||||
if not await SessionServiceManager.should_process_llm_request(event):
|
||||
logger.debug(
|
||||
f"The session {event.unified_msg_origin} has disabled AI capability, skipping processing."
|
||||
)
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
"""本地 Agent 模式的 LLM 调用 Stage"""
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.response import AgentStats
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.conversation_mgr import Conversation
|
||||
@@ -23,6 +24,7 @@ from astrbot.core.provider.entities import (
|
||||
)
|
||||
from astrbot.core.star.star_handler import EventType, star_map
|
||||
from astrbot.core.utils.file_extract import extract_file_moonshotai
|
||||
from astrbot.core.utils.llm_metadata import LLM_METADATAS
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.utils.session_lock import session_lock_manager
|
||||
|
||||
@@ -40,11 +42,6 @@ class InternalAgentSubStage(Stage):
|
||||
self.ctx = ctx
|
||||
conf = ctx.astrbot_config
|
||||
settings = conf["provider_settings"]
|
||||
self.max_context_length = settings["max_context_length"] # int
|
||||
self.dequeue_context_length: int = min(
|
||||
max(1, settings["dequeue_context_length"]),
|
||||
self.max_context_length - 1,
|
||||
)
|
||||
self.streaming_response: bool = settings["streaming_response"]
|
||||
self.unsupported_streaming_strategy: str = settings[
|
||||
"unsupported_streaming_strategy"
|
||||
@@ -64,6 +61,25 @@ class InternalAgentSubStage(Stage):
|
||||
"moonshotai_api_key", ""
|
||||
)
|
||||
|
||||
# 上下文管理相关
|
||||
self.context_limit_reached_strategy: str = settings.get(
|
||||
"context_limit_reached_strategy", "truncate_by_turns"
|
||||
)
|
||||
self.llm_compress_instruction: str = settings.get(
|
||||
"llm_compress_instruction", ""
|
||||
)
|
||||
self.llm_compress_keep_recent: int = settings.get("llm_compress_keep_recent", 4)
|
||||
self.llm_compress_provider_id: str = settings.get(
|
||||
"llm_compress_provider_id", ""
|
||||
)
|
||||
self.max_context_length = settings["max_context_length"] # int
|
||||
self.dequeue_context_length: int = min(
|
||||
max(1, settings["dequeue_context_length"]),
|
||||
self.max_context_length - 1,
|
||||
)
|
||||
if self.dequeue_context_length <= 0:
|
||||
self.dequeue_context_length = 1
|
||||
|
||||
self.conv_manager = ctx.plugin_manager.context.conversation_manager
|
||||
|
||||
def _select_provider(self, event: AstrMessageEvent):
|
||||
@@ -166,34 +182,6 @@ class InternalAgentSubStage(Stage):
|
||||
},
|
||||
)
|
||||
|
||||
def _truncate_contexts(
|
||||
self,
|
||||
contexts: list[dict],
|
||||
) -> list[dict]:
|
||||
"""截断上下文列表,确保不超过最大长度"""
|
||||
if self.max_context_length == -1:
|
||||
return contexts
|
||||
|
||||
if len(contexts) // 2 <= self.max_context_length:
|
||||
return contexts
|
||||
|
||||
truncated_contexts = contexts[
|
||||
-(self.max_context_length - self.dequeue_context_length + 1) * 2 :
|
||||
]
|
||||
# 找到第一个role 为 user 的索引,确保上下文格式正确
|
||||
index = next(
|
||||
(
|
||||
i
|
||||
for i, item in enumerate(truncated_contexts)
|
||||
if item.get("role") == "user"
|
||||
),
|
||||
None,
|
||||
)
|
||||
if index is not None and index > 0:
|
||||
truncated_contexts = truncated_contexts[index:]
|
||||
|
||||
return truncated_contexts
|
||||
|
||||
def _modalities_fix(
|
||||
self,
|
||||
provider: Provider,
|
||||
@@ -294,6 +282,8 @@ class InternalAgentSubStage(Stage):
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
llm_response: LLMResponse | None,
|
||||
all_messages: list[Message],
|
||||
runner_stats: AgentStats | None,
|
||||
):
|
||||
if (
|
||||
not req
|
||||
@@ -307,217 +297,255 @@ class InternalAgentSubStage(Stage):
|
||||
logger.debug("LLM 响应为空,不保存记录。")
|
||||
return
|
||||
|
||||
if req.contexts is None:
|
||||
req.contexts = []
|
||||
# using agent context messages to save to history
|
||||
message_to_save = []
|
||||
for message in all_messages:
|
||||
if message.role == "system":
|
||||
# we do not save system messages to history
|
||||
continue
|
||||
if message.role in ["assistant", "user"] and getattr(
|
||||
message, "_no_save", None
|
||||
):
|
||||
# we do not save user and assistant messages that are marked as _no_save
|
||||
continue
|
||||
message_to_save.append(message.model_dump())
|
||||
|
||||
# get token usage from agent runner stats
|
||||
token_usage = None
|
||||
if runner_stats:
|
||||
token_usage = runner_stats.token_usage.total
|
||||
|
||||
# 历史上下文
|
||||
messages = copy.deepcopy(req.contexts)
|
||||
# 这一轮对话请求的用户输入
|
||||
messages.append(await req.assemble_context())
|
||||
# 这一轮对话的 LLM 响应
|
||||
if req.tool_calls_result:
|
||||
if not isinstance(req.tool_calls_result, list):
|
||||
messages.extend(req.tool_calls_result.to_openai_messages())
|
||||
elif isinstance(req.tool_calls_result, list):
|
||||
for tcr in req.tool_calls_result:
|
||||
messages.extend(tcr.to_openai_messages())
|
||||
messages.append({"role": "assistant", "content": llm_response.completion_text})
|
||||
messages = list(filter(lambda item: "_no_save" not in item, messages))
|
||||
await self.conv_manager.update_conversation(
|
||||
event.unified_msg_origin,
|
||||
req.conversation.cid,
|
||||
history=messages,
|
||||
history=message_to_save,
|
||||
token_usage=token_usage,
|
||||
)
|
||||
|
||||
def _fix_messages(self, messages: list[dict]) -> list[dict]:
|
||||
"""验证并且修复上下文"""
|
||||
fixed_messages = []
|
||||
for message in messages:
|
||||
if message.get("role") == "tool":
|
||||
# tool block 前面必须要有 user 和 assistant block
|
||||
if len(fixed_messages) < 2:
|
||||
# 这种情况可能是上下文被截断导致的
|
||||
# 我们直接将之前的上下文都清空
|
||||
fixed_messages = []
|
||||
else:
|
||||
fixed_messages.append(message)
|
||||
else:
|
||||
fixed_messages.append(message)
|
||||
return fixed_messages
|
||||
def _get_compress_provider(self) -> Provider | None:
|
||||
if not self.llm_compress_provider_id:
|
||||
return None
|
||||
if self.context_limit_reached_strategy != "llm_compress":
|
||||
return None
|
||||
provider = self.ctx.plugin_manager.context.get_provider_by_id(
|
||||
self.llm_compress_provider_id,
|
||||
)
|
||||
if provider is None:
|
||||
logger.warning(
|
||||
f"未找到指定的上下文压缩模型 {self.llm_compress_provider_id},将跳过压缩。",
|
||||
)
|
||||
return None
|
||||
if not isinstance(provider, Provider):
|
||||
logger.warning(
|
||||
f"指定的上下文压缩模型 {self.llm_compress_provider_id} 不是对话模型,将跳过压缩。"
|
||||
)
|
||||
return None
|
||||
return provider
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent, provider_wake_prefix: str
|
||||
) -> AsyncGenerator[None, None]:
|
||||
req: ProviderRequest | None = None
|
||||
|
||||
provider = self._select_provider(event)
|
||||
if provider is None:
|
||||
return
|
||||
if not isinstance(provider, Provider):
|
||||
logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。")
|
||||
return
|
||||
|
||||
streaming_response = self.streaming_response
|
||||
if (enable_streaming := event.get_extra("enable_streaming")) is not None:
|
||||
streaming_response = bool(enable_streaming)
|
||||
|
||||
logger.debug("ready to request llm provider")
|
||||
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
|
||||
logger.debug("acquired session lock for llm request")
|
||||
if event.get_extra("provider_request"):
|
||||
req = event.get_extra("provider_request")
|
||||
assert isinstance(req, ProviderRequest), (
|
||||
"provider_request 必须是 ProviderRequest 类型。"
|
||||
try:
|
||||
provider = self._select_provider(event)
|
||||
if provider is None:
|
||||
return
|
||||
if not isinstance(provider, Provider):
|
||||
logger.error(
|
||||
f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。"
|
||||
)
|
||||
return
|
||||
|
||||
if req.conversation:
|
||||
req.contexts = json.loads(req.conversation.history)
|
||||
streaming_response = self.streaming_response
|
||||
if (enable_streaming := event.get_extra("enable_streaming")) is not None:
|
||||
streaming_response = bool(enable_streaming)
|
||||
|
||||
else:
|
||||
req = ProviderRequest()
|
||||
req.prompt = ""
|
||||
req.image_urls = []
|
||||
if sel_model := event.get_extra("selected_model"):
|
||||
req.model = sel_model
|
||||
if provider_wake_prefix and not event.message_str.startswith(
|
||||
provider_wake_prefix
|
||||
):
|
||||
logger.debug("ready to request llm provider")
|
||||
|
||||
# 通知等待调用 LLM(在获取锁之前)
|
||||
await call_event_hook(event, EventType.OnWaitingLLMRequestEvent)
|
||||
|
||||
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
|
||||
logger.debug("acquired session lock for llm request")
|
||||
if event.get_extra("provider_request"):
|
||||
req = event.get_extra("provider_request")
|
||||
assert isinstance(req, ProviderRequest), (
|
||||
"provider_request 必须是 ProviderRequest 类型。"
|
||||
)
|
||||
|
||||
if req.conversation:
|
||||
req.contexts = json.loads(req.conversation.history)
|
||||
|
||||
else:
|
||||
req = ProviderRequest()
|
||||
req.prompt = ""
|
||||
req.image_urls = []
|
||||
if sel_model := event.get_extra("selected_model"):
|
||||
req.model = sel_model
|
||||
if provider_wake_prefix and not event.message_str.startswith(
|
||||
provider_wake_prefix
|
||||
):
|
||||
return
|
||||
|
||||
req.prompt = event.message_str[len(provider_wake_prefix) :]
|
||||
# func_tool selection 现在已经转移到 astrbot/builtin_stars/astrbot 插件中进行选择。
|
||||
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Image):
|
||||
image_path = await comp.convert_to_file_path()
|
||||
req.image_urls.append(image_path)
|
||||
|
||||
conversation = await self._get_session_conv(event)
|
||||
req.conversation = conversation
|
||||
req.contexts = json.loads(conversation.history)
|
||||
|
||||
event.set_extra("provider_request", req)
|
||||
|
||||
# fix contexts json str
|
||||
if isinstance(req.contexts, str):
|
||||
req.contexts = json.loads(req.contexts)
|
||||
|
||||
# apply file extract
|
||||
if self.file_extract_enabled:
|
||||
try:
|
||||
await self._apply_file_extract(event, req)
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred while applying file extract: {e}")
|
||||
|
||||
if not req.prompt and not req.image_urls:
|
||||
return
|
||||
|
||||
req.prompt = event.message_str[len(provider_wake_prefix) :]
|
||||
# func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
|
||||
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Image):
|
||||
image_path = await comp.convert_to_file_path()
|
||||
req.image_urls.append(image_path)
|
||||
# call event hook
|
||||
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
|
||||
return
|
||||
|
||||
conversation = await self._get_session_conv(event)
|
||||
req.conversation = conversation
|
||||
req.contexts = json.loads(conversation.history)
|
||||
# apply knowledge base feature
|
||||
await self._apply_kb(event, req)
|
||||
|
||||
event.set_extra("provider_request", req)
|
||||
# truncate contexts to fit max length
|
||||
# NOW moved to ContextManager inside ToolLoopAgentRunner
|
||||
# if req.contexts:
|
||||
# req.contexts = self._truncate_contexts(req.contexts)
|
||||
# self._fix_messages(req.contexts)
|
||||
|
||||
# fix contexts json str
|
||||
if isinstance(req.contexts, str):
|
||||
req.contexts = json.loads(req.contexts)
|
||||
# session_id
|
||||
if not req.session_id:
|
||||
req.session_id = event.unified_msg_origin
|
||||
|
||||
# apply file extract
|
||||
if self.file_extract_enabled:
|
||||
try:
|
||||
await self._apply_file_extract(event, req)
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred while applying file extract: {e}")
|
||||
# check provider modalities, if provider does not support image/tool_use, clear them in request.
|
||||
self._modalities_fix(provider, req)
|
||||
|
||||
if not req.prompt and not req.image_urls:
|
||||
return
|
||||
# filter tools, only keep tools from this pipeline's selected plugins
|
||||
self._plugin_tool_fix(event, req)
|
||||
|
||||
# call event hook
|
||||
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
|
||||
return
|
||||
|
||||
# apply knowledge base feature
|
||||
await self._apply_kb(event, req)
|
||||
|
||||
# truncate contexts to fit max length
|
||||
if req.contexts:
|
||||
req.contexts = self._truncate_contexts(req.contexts)
|
||||
self._fix_messages(req.contexts)
|
||||
|
||||
# session_id
|
||||
if not req.session_id:
|
||||
req.session_id = event.unified_msg_origin
|
||||
|
||||
# check provider modalities, if provider does not support image/tool_use, clear them in request.
|
||||
self._modalities_fix(provider, req)
|
||||
|
||||
# filter tools, only keep tools from this pipeline's selected plugins
|
||||
self._plugin_tool_fix(event, req)
|
||||
|
||||
stream_to_general = (
|
||||
self.unsupported_streaming_strategy == "turn_off"
|
||||
and not event.platform_meta.support_streaming_message
|
||||
)
|
||||
# 备份 req.contexts
|
||||
backup_contexts = copy.deepcopy(req.contexts)
|
||||
|
||||
# run agent
|
||||
agent_runner = AgentRunner()
|
||||
logger.debug(
|
||||
f"handle provider[id: {provider.provider_config['id']}] request: {req}",
|
||||
)
|
||||
astr_agent_ctx = AstrAgentContext(
|
||||
context=self.ctx.plugin_manager.context,
|
||||
event=event,
|
||||
)
|
||||
await agent_runner.reset(
|
||||
provider=provider,
|
||||
request=req,
|
||||
run_context=AgentContextWrapper(
|
||||
context=astr_agent_ctx,
|
||||
tool_call_timeout=self.tool_call_timeout,
|
||||
),
|
||||
tool_executor=FunctionToolExecutor(),
|
||||
agent_hooks=MAIN_AGENT_HOOKS,
|
||||
streaming=streaming_response,
|
||||
)
|
||||
|
||||
if streaming_response and not stream_to_general:
|
||||
# 流式响应
|
||||
event.set_result(
|
||||
MessageEventResult()
|
||||
.set_result_content_type(ResultContentType.STREAMING_RESULT)
|
||||
.set_async_stream(
|
||||
run_agent(
|
||||
agent_runner,
|
||||
self.max_step,
|
||||
self.show_tool_use,
|
||||
show_reasoning=self.show_reasoning,
|
||||
),
|
||||
),
|
||||
stream_to_general = (
|
||||
self.unsupported_streaming_strategy == "turn_off"
|
||||
and not event.platform_meta.support_streaming_message
|
||||
)
|
||||
yield
|
||||
if agent_runner.done():
|
||||
if final_llm_resp := agent_runner.get_final_llm_resp():
|
||||
if final_llm_resp.completion_text:
|
||||
chain = (
|
||||
MessageChain()
|
||||
.message(final_llm_resp.completion_text)
|
||||
.chain
|
||||
)
|
||||
elif final_llm_resp.result_chain:
|
||||
chain = final_llm_resp.result_chain.chain
|
||||
else:
|
||||
chain = MessageChain().chain
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=chain,
|
||||
result_content_type=ResultContentType.STREAMING_FINISH,
|
||||
|
||||
# run agent
|
||||
agent_runner = AgentRunner()
|
||||
logger.debug(
|
||||
f"handle provider[id: {provider.provider_config['id']}] request: {req}",
|
||||
)
|
||||
astr_agent_ctx = AstrAgentContext(
|
||||
context=self.ctx.plugin_manager.context,
|
||||
event=event,
|
||||
)
|
||||
|
||||
# inject model context length limit
|
||||
if provider.provider_config.get("max_context_tokens", 0) <= 0:
|
||||
model = provider.get_model()
|
||||
if model_info := LLM_METADATAS.get(model):
|
||||
provider.provider_config["max_context_tokens"] = model_info[
|
||||
"limit"
|
||||
]["context"]
|
||||
|
||||
await agent_runner.reset(
|
||||
provider=provider,
|
||||
request=req,
|
||||
run_context=AgentContextWrapper(
|
||||
context=astr_agent_ctx,
|
||||
tool_call_timeout=self.tool_call_timeout,
|
||||
),
|
||||
tool_executor=FunctionToolExecutor(),
|
||||
agent_hooks=MAIN_AGENT_HOOKS,
|
||||
streaming=streaming_response,
|
||||
llm_compress_instruction=self.llm_compress_instruction,
|
||||
llm_compress_keep_recent=self.llm_compress_keep_recent,
|
||||
llm_compress_provider=self._get_compress_provider(),
|
||||
truncate_turns=self.dequeue_context_length,
|
||||
enforce_max_turns=self.max_context_length,
|
||||
)
|
||||
|
||||
if streaming_response and not stream_to_general:
|
||||
# 流式响应
|
||||
event.set_result(
|
||||
MessageEventResult()
|
||||
.set_result_content_type(ResultContentType.STREAMING_RESULT)
|
||||
.set_async_stream(
|
||||
run_agent(
|
||||
agent_runner,
|
||||
self.max_step,
|
||||
self.show_tool_use,
|
||||
show_reasoning=self.show_reasoning,
|
||||
),
|
||||
)
|
||||
else:
|
||||
async for _ in run_agent(
|
||||
agent_runner,
|
||||
self.max_step,
|
||||
self.show_tool_use,
|
||||
stream_to_general,
|
||||
show_reasoning=self.show_reasoning,
|
||||
):
|
||||
),
|
||||
)
|
||||
yield
|
||||
if agent_runner.done():
|
||||
if final_llm_resp := agent_runner.get_final_llm_resp():
|
||||
if final_llm_resp.completion_text:
|
||||
chain = (
|
||||
MessageChain()
|
||||
.message(final_llm_resp.completion_text)
|
||||
.chain
|
||||
)
|
||||
elif final_llm_resp.result_chain:
|
||||
chain = final_llm_resp.result_chain.chain
|
||||
else:
|
||||
chain = MessageChain().chain
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=chain,
|
||||
result_content_type=ResultContentType.STREAMING_FINISH,
|
||||
),
|
||||
)
|
||||
else:
|
||||
async for _ in run_agent(
|
||||
agent_runner,
|
||||
self.max_step,
|
||||
self.show_tool_use,
|
||||
stream_to_general,
|
||||
show_reasoning=self.show_reasoning,
|
||||
):
|
||||
yield
|
||||
|
||||
# 恢复备份的 contexts
|
||||
req.contexts = backup_contexts
|
||||
await self._save_to_history(
|
||||
event,
|
||||
req,
|
||||
agent_runner.get_final_llm_resp(),
|
||||
agent_runner.run_context.messages,
|
||||
agent_runner.stats,
|
||||
)
|
||||
|
||||
await self._save_to_history(event, req, agent_runner.get_final_llm_resp())
|
||||
# 异步处理 WebChat 特殊情况
|
||||
if event.get_platform_name() == "webchat":
|
||||
asyncio.create_task(self._handle_webchat(event, req, provider))
|
||||
|
||||
# 异步处理 WebChat 特殊情况
|
||||
if event.get_platform_name() == "webchat":
|
||||
asyncio.create_task(self._handle_webchat(event, req, provider))
|
||||
asyncio.create_task(
|
||||
Metric.upload(
|
||||
llm_tick=1,
|
||||
model_name=agent_runner.provider.get_model(),
|
||||
provider_type=agent_runner.provider.meta().type,
|
||||
),
|
||||
)
|
||||
|
||||
asyncio.create_task(
|
||||
Metric.upload(
|
||||
llm_tick=1,
|
||||
model_name=agent_runner.provider.get_model(),
|
||||
provider_type=agent_runner.provider.meta().type,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred while processing agent: {e}")
|
||||
await event.send(
|
||||
MessageChain().message(
|
||||
f"Error occurred while processing agent request: {e}"
|
||||
)
|
||||
)
|
||||
|
||||
@@ -16,7 +16,6 @@ from ..stage import Stage
|
||||
|
||||
class StarRequestSubStage(Stage):
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.curr_provider = ctx.plugin_manager.context.get_using_provider()
|
||||
self.prompt_prefix = ctx.astrbot_config["provider_settings"]["prompt_prefix"]
|
||||
self.identifier = ctx.astrbot_config["provider_settings"]["identifier"]
|
||||
self.ctx = ctx
|
||||
@@ -24,7 +23,7 @@ class StarRequestSubStage(Stage):
|
||||
async def process(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
) -> AsyncGenerator[None, None]:
|
||||
) -> AsyncGenerator[Any, None]:
|
||||
activated_handlers: list[StarHandlerMetadata] = event.get_extra(
|
||||
"activated_handlers",
|
||||
)
|
||||
|
||||
@@ -60,7 +60,7 @@ class ProcessStage(Stage):
|
||||
):
|
||||
# 是否有过发送操作 and 是否是被 @ 或者通过唤醒前缀
|
||||
if (
|
||||
event.get_result() and not event.get_result().is_stopped()
|
||||
event.get_result() and not event.is_stopped()
|
||||
) or not event.get_result():
|
||||
async for _ in self.agent_sub_stage.process(event):
|
||||
yield
|
||||
|
||||
@@ -117,7 +117,9 @@ class RespondStage(Stage):
|
||||
if not self.enable_seg:
|
||||
return False
|
||||
|
||||
if self.only_llm_result and not event.get_result().is_llm_result():
|
||||
if (result := event.get_result()) is None:
|
||||
return False
|
||||
if self.only_llm_result and not result.is_llm_result():
|
||||
return False
|
||||
|
||||
if event.get_platform_name() in [
|
||||
@@ -156,7 +158,11 @@ class RespondStage(Stage):
|
||||
result = event.get_result()
|
||||
if result is None:
|
||||
return
|
||||
if event.get_extra("_streaming_finished", False):
|
||||
# prevent some plugin make result content type to LLM_RESULT after streaming finished, lead to send again
|
||||
return
|
||||
if result.result_content_type == ResultContentType.STREAMING_FINISH:
|
||||
event.set_extra("_streaming_finished", True)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
@@ -185,7 +191,7 @@ class RespondStage(Stage):
|
||||
if isinstance(component, Comp.File) and component.file:
|
||||
# 支持 File 消息段的路径映射。
|
||||
component.file = path_Mapping(mappings, component.file)
|
||||
event.get_result().chain[idx] = component
|
||||
result.chain[idx] = component
|
||||
|
||||
# 检查消息链是否为空
|
||||
try:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
@@ -6,6 +7,7 @@ from collections.abc import AsyncGenerator
|
||||
from astrbot.core import file_token_service, html_renderer, logger
|
||||
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
|
||||
from astrbot.core.message.message_event_result import ResultContentType
|
||||
from astrbot.core.pipeline.content_safety_check.stage import ContentSafetyCheckStage
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
@@ -41,6 +43,18 @@ class ResultDecorateStage(Stage):
|
||||
"forward_threshold"
|
||||
]
|
||||
|
||||
trigger_probability = ctx.astrbot_config["provider_tts_settings"].get(
|
||||
"trigger_probability",
|
||||
1,
|
||||
)
|
||||
try:
|
||||
self.tts_trigger_probability = max(
|
||||
0.0,
|
||||
min(float(trigger_probability), 1.0),
|
||||
)
|
||||
except (TypeError, ValueError):
|
||||
self.tts_trigger_probability = 1.0
|
||||
|
||||
# 分段回复
|
||||
self.words_count_threshold = int(
|
||||
ctx.astrbot_config["platform_settings"]["segmented_reply"][
|
||||
@@ -53,7 +67,22 @@ class ResultDecorateStage(Stage):
|
||||
self.only_llm_result = ctx.astrbot_config["platform_settings"][
|
||||
"segmented_reply"
|
||||
]["only_llm_result"]
|
||||
self.split_mode = ctx.astrbot_config["platform_settings"][
|
||||
"segmented_reply"
|
||||
].get("split_mode", "regex")
|
||||
self.regex = ctx.astrbot_config["platform_settings"]["segmented_reply"]["regex"]
|
||||
self.split_words = ctx.astrbot_config["platform_settings"][
|
||||
"segmented_reply"
|
||||
].get("split_words", ["。", "?", "!", "~", "…"])
|
||||
if self.split_words:
|
||||
escaped_words = sorted(
|
||||
[re.escape(word) for word in self.split_words], key=len, reverse=True
|
||||
)
|
||||
self.split_words_pattern = re.compile(
|
||||
f"(.*?({'|'.join(escaped_words)})|.+$)", re.DOTALL
|
||||
)
|
||||
else:
|
||||
self.split_words_pattern = None
|
||||
self.content_cleanup_rule = ctx.astrbot_config["platform_settings"][
|
||||
"segmented_reply"
|
||||
]["content_cleanup_rule"]
|
||||
@@ -69,6 +98,31 @@ class ResultDecorateStage(Stage):
|
||||
self.content_safe_check_stage = stage_cls()
|
||||
await self.content_safe_check_stage.initialize(ctx)
|
||||
|
||||
provider_cfg = ctx.astrbot_config.get("provider_settings", {})
|
||||
self.show_reasoning = provider_cfg.get("display_reasoning_text", False)
|
||||
|
||||
def _split_text_by_words(self, text: str) -> list[str]:
|
||||
"""使用分段词列表分段文本"""
|
||||
if not self.split_words_pattern:
|
||||
return [text]
|
||||
|
||||
segments = self.split_words_pattern.findall(text)
|
||||
result = []
|
||||
for seg in segments:
|
||||
if isinstance(seg, tuple):
|
||||
content = seg[0]
|
||||
if not isinstance(content, str):
|
||||
continue
|
||||
for word in self.split_words:
|
||||
if content.endswith(word):
|
||||
content = content[: -len(word)]
|
||||
break
|
||||
if content.strip():
|
||||
result.append(content)
|
||||
elif seg and seg.strip():
|
||||
result.append(seg)
|
||||
return result if result else [text]
|
||||
|
||||
async def process(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
@@ -93,11 +147,13 @@ class ResultDecorateStage(Stage):
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain):
|
||||
text += comp.text
|
||||
async for _ in self.content_safe_check_stage.process(
|
||||
event,
|
||||
check_text=text,
|
||||
):
|
||||
yield
|
||||
|
||||
if isinstance(self.content_safe_check_stage, ContentSafetyCheckStage):
|
||||
async for _ in self.content_safe_check_stage.process(
|
||||
event,
|
||||
check_text=text,
|
||||
):
|
||||
yield
|
||||
|
||||
# 发送消息前事件钩子
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
@@ -114,7 +170,8 @@ class ResultDecorateStage(Stage):
|
||||
"启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作",
|
||||
)
|
||||
await handler.handler(event)
|
||||
if event.get_result() is None or not event.get_result().chain:
|
||||
|
||||
if (result := event.get_result()) is None or not result.chain:
|
||||
logger.debug(
|
||||
f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name} 将消息结果清空。",
|
||||
)
|
||||
@@ -161,21 +218,27 @@ class ResultDecorateStage(Stage):
|
||||
# 不分段回复
|
||||
new_chain.append(comp)
|
||||
continue
|
||||
try:
|
||||
split_response = re.findall(
|
||||
self.regex,
|
||||
comp.text,
|
||||
re.DOTALL | re.MULTILINE,
|
||||
)
|
||||
except re.error:
|
||||
logger.error(
|
||||
f"分段回复正则表达式错误,使用默认分段方式: {traceback.format_exc()}",
|
||||
)
|
||||
split_response = re.findall(
|
||||
r".*?[。?!~…]+|.+$",
|
||||
comp.text,
|
||||
re.DOTALL | re.MULTILINE,
|
||||
)
|
||||
|
||||
# 根据 split_mode 选择分段方式
|
||||
if self.split_mode == "words":
|
||||
split_response = self._split_text_by_words(comp.text)
|
||||
else: # regex 模式
|
||||
try:
|
||||
split_response = re.findall(
|
||||
self.regex,
|
||||
comp.text,
|
||||
re.DOTALL | re.MULTILINE,
|
||||
)
|
||||
except re.error:
|
||||
logger.error(
|
||||
f"分段回复正则表达式错误,使用默认分段方式: {traceback.format_exc()}",
|
||||
)
|
||||
split_response = re.findall(
|
||||
r".*?[。?!~…]+|.+$",
|
||||
comp.text,
|
||||
re.DOTALL | re.MULTILINE,
|
||||
)
|
||||
|
||||
if not split_response:
|
||||
new_chain.append(comp)
|
||||
continue
|
||||
@@ -194,63 +257,75 @@ class ResultDecorateStage(Stage):
|
||||
event.unified_msg_origin,
|
||||
)
|
||||
|
||||
if (
|
||||
self.ctx.astrbot_config["provider_tts_settings"]["enable"]
|
||||
should_tts = (
|
||||
bool(self.ctx.astrbot_config["provider_tts_settings"]["enable"])
|
||||
and result.is_llm_result()
|
||||
and SessionServiceManager.should_process_tts_request(event)
|
||||
and await SessionServiceManager.should_process_tts_request(event)
|
||||
and random.random() <= self.tts_trigger_probability
|
||||
and tts_provider
|
||||
)
|
||||
if should_tts and not tts_provider:
|
||||
logger.warning(
|
||||
f"会话 {event.unified_msg_origin} 未配置文本转语音模型。",
|
||||
)
|
||||
|
||||
if (
|
||||
not should_tts
|
||||
and self.show_reasoning
|
||||
and event.get_extra("_llm_reasoning_content")
|
||||
):
|
||||
if not tts_provider:
|
||||
logger.warning(
|
||||
f"会话 {event.unified_msg_origin} 未配置文本转语音模型。",
|
||||
)
|
||||
else:
|
||||
new_chain = []
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain) and len(comp.text) > 1:
|
||||
try:
|
||||
logger.info(f"TTS 请求: {comp.text}")
|
||||
audio_path = await tts_provider.get_audio(comp.text)
|
||||
logger.info(f"TTS 结果: {audio_path}")
|
||||
if not audio_path:
|
||||
logger.error(
|
||||
f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}",
|
||||
)
|
||||
new_chain.append(comp)
|
||||
continue
|
||||
# inject reasoning content to chain
|
||||
reasoning_content = event.get_extra("_llm_reasoning_content")
|
||||
result.chain.insert(0, Plain(f"🤔 思考: {reasoning_content}\n"))
|
||||
|
||||
use_file_service = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["use_file_service"]
|
||||
callback_api_base = self.ctx.astrbot_config[
|
||||
"callback_api_base"
|
||||
]
|
||||
dual_output = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["dual_output"]
|
||||
|
||||
url = None
|
||||
if use_file_service and callback_api_base:
|
||||
token = await file_token_service.register_file(
|
||||
audio_path,
|
||||
)
|
||||
url = f"{callback_api_base}/api/file/{token}"
|
||||
logger.debug(f"已注册:{url}")
|
||||
|
||||
new_chain.append(
|
||||
Record(
|
||||
file=url or audio_path,
|
||||
url=url or audio_path,
|
||||
),
|
||||
if should_tts and tts_provider:
|
||||
new_chain = []
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain) and len(comp.text) > 1:
|
||||
try:
|
||||
logger.info(f"TTS 请求: {comp.text}")
|
||||
audio_path = await tts_provider.get_audio(comp.text)
|
||||
logger.info(f"TTS 结果: {audio_path}")
|
||||
if not audio_path:
|
||||
logger.error(
|
||||
f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}",
|
||||
)
|
||||
if dual_output:
|
||||
new_chain.append(comp)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error("TTS 失败,使用文本发送。")
|
||||
new_chain.append(comp)
|
||||
else:
|
||||
continue
|
||||
|
||||
use_file_service = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["use_file_service"]
|
||||
callback_api_base = self.ctx.astrbot_config[
|
||||
"callback_api_base"
|
||||
]
|
||||
dual_output = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["dual_output"]
|
||||
|
||||
url = None
|
||||
if use_file_service and callback_api_base:
|
||||
token = await file_token_service.register_file(
|
||||
audio_path,
|
||||
)
|
||||
url = f"{callback_api_base}/api/file/{token}"
|
||||
logger.debug(f"已注册:{url}")
|
||||
|
||||
new_chain.append(
|
||||
Record(
|
||||
file=url or audio_path,
|
||||
url=url or audio_path,
|
||||
),
|
||||
)
|
||||
if dual_output:
|
||||
new_chain.append(comp)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error("TTS 失败,使用文本发送。")
|
||||
new_chain.append(comp)
|
||||
result.chain = new_chain
|
||||
else:
|
||||
new_chain.append(comp)
|
||||
result.chain = new_chain
|
||||
|
||||
# 文本转图片
|
||||
elif (
|
||||
|
||||
@@ -2,6 +2,10 @@ from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.platform import AstrMessageEvent
|
||||
from astrbot.core.platform.sources.webchat.webchat_event import WebChatMessageEvent
|
||||
from astrbot.core.platform.sources.wecom_ai_bot.wecomai_event import (
|
||||
WecomAIBotMessageEvent,
|
||||
)
|
||||
|
||||
from . import STAGES_ORDER
|
||||
from .context import PipelineContext
|
||||
@@ -78,7 +82,7 @@ class PipelineScheduler:
|
||||
await self._process_stages(event)
|
||||
|
||||
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
|
||||
if event.get_platform_name() in ["webchat", "wecom_ai_bot"]:
|
||||
if isinstance(event, (WebChatMessageEvent, WecomAIBotMessageEvent)):
|
||||
await event.send(None)
|
||||
|
||||
logger.debug("pipeline 执行完毕。")
|
||||
|
||||
@@ -21,7 +21,7 @@ class SessionStatusCheckStage(Stage):
|
||||
event: AstrMessageEvent,
|
||||
) -> None | AsyncGenerator[None, None]:
|
||||
# 检查会话是否整体启用
|
||||
if not SessionServiceManager.is_session_enabled(event.unified_msg_origin):
|
||||
if not await SessionServiceManager.is_session_enabled(event.unified_msg_origin):
|
||||
logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。")
|
||||
|
||||
# workaround for #2309
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.message.components import At, AtAll, Reply
|
||||
from astrbot.core.message.message_event_result import MessageChain, MessageEventResult
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
||||
from astrbot.core.star.session_plugin_manager import SessionPluginManager
|
||||
@@ -13,6 +14,23 @@ from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
||||
from ..context import PipelineContext
|
||||
from ..stage import Stage, register_stage
|
||||
|
||||
UNIQUE_SESSION_ID_BUILDERS: dict[str, Callable[[AstrMessageEvent], str | None]] = {
|
||||
"aiocqhttp": lambda e: f"{e.get_sender_id()}_{e.get_group_id()}",
|
||||
"slack": lambda e: f"{e.get_sender_id()}_{e.get_group_id()}",
|
||||
"dingtalk": lambda e: e.get_sender_id(),
|
||||
"qq_official": lambda e: e.get_sender_id(),
|
||||
"qq_official_webhook": lambda e: e.get_sender_id(),
|
||||
"lark": lambda e: f"{e.get_sender_id()}%{e.get_group_id()}",
|
||||
"misskey": lambda e: f"{e.get_session_id()}_{e.get_sender_id()}",
|
||||
"wechatpadpro": lambda e: f"{e.get_group_id()}#{e.get_sender_id()}",
|
||||
}
|
||||
|
||||
|
||||
def build_unique_session_id(event: AstrMessageEvent) -> str | None:
|
||||
platform = event.get_platform_name()
|
||||
builder = UNIQUE_SESSION_ID_BUILDERS.get(platform)
|
||||
return builder(event) if builder else None
|
||||
|
||||
|
||||
@register_stage
|
||||
class WakingCheckStage(Stage):
|
||||
@@ -50,18 +68,30 @@ class WakingCheckStage(Stage):
|
||||
"ignore_at_all",
|
||||
False,
|
||||
)
|
||||
self.disable_builtin_commands = self.ctx.astrbot_config.get(
|
||||
"disable_builtin_commands", False
|
||||
)
|
||||
platform_settings = self.ctx.astrbot_config.get("platform_settings", {})
|
||||
self.unique_session = platform_settings.get("unique_session", False)
|
||||
|
||||
async def process(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
) -> None | AsyncGenerator[None, None]:
|
||||
# apply unique session
|
||||
if self.unique_session and event.message_obj.type == MessageType.GROUP_MESSAGE:
|
||||
sid = build_unique_session_id(event)
|
||||
if sid:
|
||||
event.session_id = sid
|
||||
|
||||
# ignore bot self message
|
||||
if (
|
||||
self.ignore_bot_self_message
|
||||
and event.get_self_id() == event.get_sender_id()
|
||||
):
|
||||
# 忽略机器人自己发送的消息
|
||||
event.stop_event()
|
||||
return
|
||||
|
||||
# 设置 sender 身份
|
||||
event.message_str = event.message_str.strip()
|
||||
for admin_id in self.ctx.astrbot_config["admins_id"]:
|
||||
@@ -131,6 +161,14 @@ class WakingCheckStage(Stage):
|
||||
EventType.AdapterMessageEvent,
|
||||
plugins_name=event.plugins_name,
|
||||
):
|
||||
if (
|
||||
self.disable_builtin_commands
|
||||
and handler.handler_module_path
|
||||
== "astrbot.builtin_stars.builtin_commands.main"
|
||||
):
|
||||
logger.debug("skipping builtin command")
|
||||
continue
|
||||
|
||||
# filter 需满足 AND 逻辑关系
|
||||
passed = True
|
||||
permission_not_pass = False
|
||||
@@ -189,7 +227,7 @@ class WakingCheckStage(Stage):
|
||||
event._extras.pop("parsed_params", None)
|
||||
|
||||
# 根据会话配置过滤插件处理器
|
||||
activated_handlers = SessionPluginManager.filter_handlers_by_session(
|
||||
activated_handlers = await SessionPluginManager.filter_handlers_by_session(
|
||||
event,
|
||||
activated_handlers,
|
||||
)
|
||||
|
||||
@@ -153,7 +153,9 @@ class AstrMessageEvent(abc.ABC):
|
||||
|
||||
def get_sender_name(self) -> str:
|
||||
"""获取消息发送者的名称。(可能会返回空字符串)"""
|
||||
return self.message_obj.sender.nickname
|
||||
if isinstance(self.message_obj.sender.nickname, str):
|
||||
return self.message_obj.sender.nickname
|
||||
return ""
|
||||
|
||||
def set_extra(self, key, value):
|
||||
"""设置额外的信息。"""
|
||||
@@ -270,7 +272,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
"""
|
||||
self.call_llm = call_llm
|
||||
|
||||
def get_result(self) -> MessageEventResult:
|
||||
def get_result(self) -> MessageEventResult | None:
|
||||
"""获取消息事件的结果。"""
|
||||
return self._result
|
||||
|
||||
@@ -320,7 +322,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
self,
|
||||
prompt: str,
|
||||
func_tool_manager=None,
|
||||
session_id: str = None,
|
||||
session_id: str = "",
|
||||
image_urls: list[str] | None = None,
|
||||
contexts: list | None = None,
|
||||
system_prompt: str = "",
|
||||
|
||||
@@ -54,7 +54,7 @@ class AstrBotMessage:
|
||||
self_id: str # 机器人的识别id
|
||||
session_id: str # 会话id。取决于 unique_session 的设置。
|
||||
message_id: str # 消息id
|
||||
group: Group # 群组
|
||||
group: Group | None # 群组
|
||||
sender: MessageMember # 发送者
|
||||
message: list[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
|
||||
message_str: str # 最直观的纯文本消息字符串
|
||||
@@ -78,7 +78,7 @@ class AstrBotMessage:
|
||||
return ""
|
||||
|
||||
@group_id.setter
|
||||
def group_id(self, value: str):
|
||||
def group_id(self, value: str | None):
|
||||
"""设置 group_id"""
|
||||
if value:
|
||||
if self.group:
|
||||
|
||||
@@ -5,6 +5,7 @@ from asyncio import Queue
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map
|
||||
from astrbot.core.utils.webhook_utils import ensure_platform_webhook_config
|
||||
|
||||
from .platform import Platform, PlatformStatus
|
||||
from .register import platform_cls_map
|
||||
@@ -18,6 +19,7 @@ class PlatformManager:
|
||||
|
||||
self._inst_map: dict[str, dict] = {}
|
||||
|
||||
self.astrbot_config = config
|
||||
self.platforms_config = config["platform"]
|
||||
self.settings = config["platform_settings"]
|
||||
"""NOTE: 这里是 default 的配置文件,以保证最大的兼容性;
|
||||
@@ -29,6 +31,8 @@ class PlatformManager:
|
||||
"""初始化所有平台适配器"""
|
||||
for platform in self.platforms_config:
|
||||
try:
|
||||
if ensure_platform_webhook_config(platform):
|
||||
self.astrbot_config.save_config()
|
||||
await self.load_platform(platform)
|
||||
except Exception as e:
|
||||
logger.error(f"初始化 {platform} 平台适配器失败: {e}")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import abc
|
||||
import uuid
|
||||
from asyncio import Queue
|
||||
from collections.abc import Awaitable
|
||||
from collections.abc import Coroutine
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
@@ -80,6 +80,13 @@ class Platform(abc.ABC):
|
||||
if self._status == PlatformStatus.ERROR:
|
||||
self._status = PlatformStatus.RUNNING
|
||||
|
||||
def unified_webhook(self) -> bool:
|
||||
"""是否正在使用统一 Webhook 模式"""
|
||||
return bool(
|
||||
self.config.get("unified_webhook_mode", False)
|
||||
and self.config.get("webhook_uuid")
|
||||
)
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""获取平台统计信息"""
|
||||
meta = self.meta()
|
||||
@@ -97,10 +104,11 @@ class Platform(abc.ABC):
|
||||
}
|
||||
if self.last_error
|
||||
else None,
|
||||
"unified_webhook": self.unified_webhook(),
|
||||
}
|
||||
|
||||
@abc.abstractmethod
|
||||
def run(self) -> Awaitable[Any]:
|
||||
def run(self) -> Coroutine[Any, Any, None]:
|
||||
"""得到一个平台的运行实例,需要返回一个协程对象。"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -116,7 +124,7 @@ class Platform(abc.ABC):
|
||||
self,
|
||||
session: MessageSesion,
|
||||
message_chain: MessageChain,
|
||||
):
|
||||
) -> None:
|
||||
"""通过会话发送消息。该方法旨在让插件能够直接通过**可持久化的会话数据**发送消息,而不需要保存 event 对象。
|
||||
|
||||
异步方法。
|
||||
|
||||
@@ -7,7 +7,7 @@ class PlatformMetadata:
|
||||
"""平台的名称,即平台的类型,如 aiocqhttp, discord, slack"""
|
||||
description: str
|
||||
"""平台的描述"""
|
||||
id: str | None = None
|
||||
id: str
|
||||
"""平台的唯一标识符,用于配置中识别特定平台"""
|
||||
|
||||
default_config_tmpl: dict | None = None
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user