Compare commits
23 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c76b7ec387 | |||
| b7f3010d72 | |||
| fbbaf1cd08 | |||
| 9c8025acce | |||
| 98c5466b5d | |||
| 6345ac6ff8 | |||
| 5bcd683012 | |||
| eaa193c6c5 | |||
| 1bdcaa1318 | |||
| 6b6c48354d | |||
| 774efb2fe0 | |||
| 3ec76636f9 | |||
| 283810d103 | |||
| 81a76bc8e5 | |||
| 788764be02 | |||
| 802ab26934 | |||
| 6857c81a14 | |||
| a6ed511a30 | |||
| 44c2b58206 | |||
| 0e2adab3fd | |||
| 0fe87d6b98 | |||
| 31ef3d1084 | |||
| b984bb2513 |
@@ -15,6 +15,7 @@ Always reference these instructions first and fallback to search or bash command
|
||||
### Running the Application
|
||||
- Run main application: `uv run main.py` -- starts in ~3 seconds
|
||||
- Application creates WebUI on http://localhost:6185 (default credentials: `astrbot`/`astrbot`)
|
||||
- Application loads plugins automatically from `packages/` and `data/plugins/` directories
|
||||
|
||||
### Dashboard Build (Vue.js/Node.js)
|
||||
- **Prerequisites**: Node.js 20+ and npm 10+ required
|
||||
@@ -34,7 +35,7 @@ Always reference these instructions first and fallback to search or bash command
|
||||
- **ALWAYS** run `uv run ruff check .` and `uv run ruff format .` before committing changes
|
||||
|
||||
### Plugin Development
|
||||
- Plugins load from `astrbot/builtin_stars/` (built-in) and `data/plugins/` (user-installed)
|
||||
- Plugins load from `packages/` (built-in) and `data/plugins/` (user-installed)
|
||||
- Plugin system supports function tools and message handlers
|
||||
- Key plugins: python_interpreter, web_searcher, astrbot, reminder, session_controller
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ jobs:
|
||||
contents: write
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Dashboard Build
|
||||
run: |
|
||||
@@ -70,7 +70,7 @@ jobs:
|
||||
needs: build-and-publish-to-github-release
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
|
||||
@@ -12,7 +12,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
|
||||
@@ -56,7 +56,7 @@ jobs:
|
||||
# your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v5
|
||||
|
||||
# Initializes the CodeQL tools for scanning.
|
||||
- name: Initialize CodeQL
|
||||
|
||||
@@ -17,7 +17,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
@@ -36,7 +36,7 @@ jobs:
|
||||
zip -r dist.zip dist
|
||||
|
||||
- name: Archive production artifacts
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v5
|
||||
with:
|
||||
name: dist-without-markdown
|
||||
path: |
|
||||
|
||||
@@ -20,7 +20,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
fetch-depth: 1
|
||||
fetch-tag: true
|
||||
@@ -118,7 +118,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
fetch-depth: 1
|
||||
fetch-tag: true
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
name: Smoke Test
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
paths-ignore:
|
||||
- 'README*.md'
|
||||
- 'changelogs/**'
|
||||
- 'dashboard/**'
|
||||
pull_request:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
smoke-test:
|
||||
name: Run smoke tests
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install UV package manager
|
||||
run: |
|
||||
pip install uv
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv sync
|
||||
timeout-minutes: 15
|
||||
|
||||
- name: Run smoke tests
|
||||
run: |
|
||||
uv run main.py &
|
||||
APP_PID=$!
|
||||
|
||||
echo "Waiting for application to start..."
|
||||
for i in {1..60}; do
|
||||
if curl -f http://localhost:6185 > /dev/null 2>&1; then
|
||||
echo "Application started successfully!"
|
||||
kill $APP_PID
|
||||
exit 0
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
|
||||
echo "Application failed to start within 30 seconds"
|
||||
kill $APP_PID 2>/dev/null || true
|
||||
exit 1
|
||||
timeout-minutes: 2
|
||||
+15
-52
@@ -1,64 +1,27 @@
|
||||
# 本工作流用于标记并关闭长期不活跃的 Issue。
|
||||
# 目前仅针对带 `bug` 标签的 Issue 生效,不会处理 PR。
|
||||
# This workflow warns and then closes issues and PRs that have had no activity for a specified amount of time.
|
||||
#
|
||||
# 文档: https://github.com/actions/stale
|
||||
name: Mark stale bug issues
|
||||
# You can adjust the behavior by modifying this file.
|
||||
# For more information, see:
|
||||
# https://github.com/actions/stale
|
||||
name: Mark stale issues and pull requests
|
||||
|
||||
on:
|
||||
schedule:
|
||||
# 每天 UTC 08:30 执行 (北京时间 16:30)
|
||||
- cron: '30 8 * * *'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
dry-run:
|
||||
description: '仅预览, 不实际执行 (Dry run mode)'
|
||||
required: false
|
||||
default: true
|
||||
type: boolean
|
||||
- cron: '21 23 * * *'
|
||||
|
||||
jobs:
|
||||
stale:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
issues: write
|
||||
pull-requests: write
|
||||
|
||||
steps:
|
||||
- uses: actions/stale@v10
|
||||
with:
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
operations-per-run: 200
|
||||
|
||||
# 只处理带 bug 标签的 Issue
|
||||
any-of-labels: 'bug'
|
||||
|
||||
# 不处理 PR
|
||||
days-before-pr-stale: -1
|
||||
days-before-pr-close: -1
|
||||
|
||||
# 不活跃判定与关闭策略: 先标记 stale, 再延迟关闭
|
||||
days-before-issue-stale: 60
|
||||
days-before-issue-close: 30
|
||||
|
||||
stale-issue-label: 'stale'
|
||||
stale-issue-message: |
|
||||
This issue has been automatically marked as **stale** because it has not had any activity.
|
||||
It will be closed in a certain period of time if no further activity occurs.
|
||||
If this issue is still relevant, please leave a comment.
|
||||
|
||||
---
|
||||
|
||||
该 Issue 已较长时间无活动, 已被标记为 `stale`。
|
||||
如无后续活动, 将在一段时间后自动关闭。
|
||||
如仍需跟进, 请回复评论。
|
||||
close-issue-message: |
|
||||
This issue has been automatically closed due to inactivity.
|
||||
If the problem still exists, feel free to reopen or create a new issue with updated information.
|
||||
|
||||
---
|
||||
|
||||
该 Issue 因长期无活动已自动关闭。
|
||||
如问题仍存在, 欢迎补充复现信息并重新打开或新建 Issue。
|
||||
|
||||
remove-stale-when-updated: true
|
||||
|
||||
debug-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run }}
|
||||
- uses: actions/stale@v10
|
||||
with:
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
stale-issue-message: 'Stale issue message'
|
||||
stale-pr-message: 'Stale pull request message'
|
||||
stale-issue-label: 'no-issue-activity'
|
||||
stale-pr-label: 'no-pr-activity'
|
||||
|
||||
+2
-5
@@ -24,9 +24,9 @@ configs/session
|
||||
configs/config.yaml
|
||||
cmd_config.json
|
||||
|
||||
# Plugins
|
||||
# Plugins and packages
|
||||
addons/plugins
|
||||
astrbot/builtin_stars/python_interpreter/workplace
|
||||
packages/python_interpreter/workplace
|
||||
tests/astrbot_plugin_openai
|
||||
|
||||
# Dashboard
|
||||
@@ -34,7 +34,6 @@ dashboard/node_modules/
|
||||
dashboard/dist/
|
||||
package-lock.json
|
||||
package.json
|
||||
yarn.lock
|
||||
|
||||
# Operating System
|
||||
**/.DS_Store
|
||||
@@ -48,5 +47,3 @@ astrbot.lock
|
||||
chroma
|
||||
venv/*
|
||||
pytest.ini
|
||||
AGENTS.md
|
||||
IFLOW.md
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
# CONTRIBUTING
|
||||
|
||||
## 贡献指南
|
||||
|
||||
首先,感谢您花时间做出贡献!❤️
|
||||
|
||||
所有类型的贡献都受到鼓励和重视。有关不同的帮助方式和处理方式的详细信息,请参阅[目录](#目录)。在做出贡献之前,请确保阅读相关部分。这将使我们维护人员的工作变得更加容易,并为所有参与者带来顺畅的体验。社区期待您的贡献。🎉
|
||||
|
||||
### 目录
|
||||
|
||||
- [报告问题](#报告问题)
|
||||
- [提交代码更改](#提交代码更改)
|
||||
|
||||
### 报告问题
|
||||
|
||||
如果您在使用 AstrBot 时遇到任何问题,请按照以下步骤报告:
|
||||
|
||||
1. **检查现有问题**:在提交新问题之前,请先检查 [Issues](https://github.com/AstrBotDevs/AstrBot/issues) 中是否已经存在类似的问题。
|
||||
2. **创建新问题**:如果没有类似的问题,请创建一个新问题。请确保提供以下信息:
|
||||
- 问题的简要描述
|
||||
- 重现问题的步骤
|
||||
- 预期结果和实际结果
|
||||
- 相关日志或错误消息
|
||||
|
||||
### 提交代码更改
|
||||
|
||||
#### 分支命名
|
||||
|
||||
我们使用 `fix/` 前缀来修复错误,使用 `feat/` 前缀来添加新功能。对于 `fix/` 分支,请使用简短的描述,或者直接使用 Issue 编号。例如:`fix/1234` 或者 `fix/1234-login-typo`。对于 `feat/` 分支,请使用简短的描述,例如:`feat/add-user-profile`。
|
||||
|
||||
#### PR 描述
|
||||
|
||||
- 请使用英文描述您的 PR。
|
||||
- 标题请使用 `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` 等语义化前缀,并简要描述更改内容。如:`fix: correct login page typo`。
|
||||
|
||||
#### 代码规范
|
||||
|
||||
##### Core
|
||||
|
||||
我们使用 Ruff 作为代码格式化和静态分析工具。在提交代码之前,请运行以下命令以确保代码符合规范:
|
||||
|
||||
```bash
|
||||
ruff format .
|
||||
ruff check .
|
||||
```
|
||||
|
||||
如果您使用 VSCode,可以安装 `Ruff` 插件。
|
||||
|
||||
|
||||
## Contributing Guide
|
||||
|
||||
First off, thanks for taking the time to contribute! ❤️
|
||||
|
||||
All types of contributions are encouraged and valued. See the [Table of Contents](#table-of-contents) for different ways to help and details about how this project handles them. Please make sure to read the relevant section before making your contribution. It will make it a lot easier for us maintainers and smooth out the experience for all involved. The community looks forward to your contributions. 🎉
|
||||
|
||||
### Table of Contents
|
||||
|
||||
- [Reporting Issues](#reporting-issues)
|
||||
- [Pull Requests](#pull-requests)
|
||||
|
||||
### Reporting Issues
|
||||
|
||||
If you encounter any issues while using AstrBot, please follow these steps to report them:
|
||||
1. **Check Existing Issues**: Before submitting a new issue, please check if a similar issue already exists in the [Issues](https://github.com/AstrBotDevs/AstrBot/issues) section of the repository.
|
||||
2. **Create a New Issue**: If no similar issue exists, please create a new issue. Make sure to provide the following information:
|
||||
- A brief description of the issue
|
||||
- Steps to reproduce the issue
|
||||
- Expected and actual results
|
||||
- Relevant logs or error messages
|
||||
|
||||
### Pull Requests
|
||||
|
||||
#### Branch Naming
|
||||
|
||||
We use the `fix/` prefix for bug fixes and the `feat/` prefix for new features. For `fix/` branches, please use a short description or directly use the Issue number, e.g., `fix/1234` or `fix/1234-login-typo`. For `feat/` branches, please use a short description, e.g., `feat/add-user-profile`.
|
||||
|
||||
#### PR Description
|
||||
- Please use English to describe your PR.
|
||||
- Use semantic prefixes like `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` in the title, followed by a brief description of the changes, e.g., `fix: correct login page typo`.
|
||||
|
||||
#### Code Style
|
||||
|
||||
##### Core
|
||||
|
||||
We use Ruff as our code formatter and static analysis tool. Before submitting your code, please run the following commands to ensure your code adheres to the style guidelines:
|
||||
|
||||
```bash
|
||||
ruff format .
|
||||
ruff check .
|
||||
```
|
||||
@@ -1,13 +1,10 @@
|
||||

|
||||
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
|
||||
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_zh-TW.md">繁體中文</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_fr.md">Français</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ru.md">Русский</a>
|
||||
<br>
|
||||
|
||||
<div>
|
||||
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
@@ -17,38 +14,35 @@
|
||||
<br>
|
||||
|
||||
<div>
|
||||
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
|
||||
<img src="https://deepwiki.com/badge.svg" href="https://deepwiki.com/AstrBotDevs/AstrBot">
|
||||
<a href="https://zread.ai/AstrBotDevs/AstrBot" target="_blank"><img src="https://img.shields.io/badge/Ask_Zread-_.svg?style=flat&color=00b0aa&labelColor=000000&logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iMTYiIGhlaWdodD0iMTYiIHZpZXdCb3g9IjAgMCAxNiAxNiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTQuOTYxNTYgMS42MDAxSDIuMjQxNTZDMS44ODgxIDEuNjAwMSAxLjYwMTU2IDEuODg2NjQgMS42MDE1NiAyLjI0MDFWNC45NjAxQzEuNjAxNTYgNS4zMTM1NiAxLjg4ODEgNS42MDAxIDIuMjQxNTYgNS42MDAxSDQuOTYxNTZDNS4zMTUwMiA1LjYwMDEgNS42MDE1NiA1LjMxMzU2IDUuNjAxNTYgNC45NjAxVjIuMjQwMUM1LjYwMTU2IDEuODg2NjQgNS4zMTUwMiAxLjYwMDEgNC45NjE1NiAxLjYwMDFaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00Ljk2MTU2IDEwLjM5OTlIMi4yNDE1NkMxLjg4ODEgMTAuMzk5OSAxLjYwMTU2IDEwLjY4NjQgMS42MDE1NiAxMS4wMzk5VjEzLjc1OTlDMS42MDE1NiAxNC4xMTM0IDEuODg4MSAxNC4zOTk5IDIuMjQxNTYgMTQuMzk5OUg0Ljk2MTU2QzUuMzE1MDIgMTQuMzk5OSA1LjYwMTU2IDE0LjExMzQgNS42MDE1NiAxMy43NTk5VjExLjAzOTlDNS42MDE1NiAxMC42ODY0IDUuMzE1MDIgMTAuMzk5OSA0Ljk2MTU2IDEwLjM5OTlaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik0xMy43NTg0IDEuNjAwMUgxMS4wMzg0QzEwLjY4NSAxLjYwMDEgMTAuMzk4NCAxLjg4NjY0IDEwLjM5ODQgMi4yNDAxVjQuOTYwMUMxMC4zOTg0IDUuMzEzNTYgMTAuNjg1IDUuNjAwMSAxMS4wMzg0IDUuNjAwMUgxMy43NTg0QzE0LjExMTkgNS42MDAxIDE0LjM5ODQgNS4zMTM1NiAxNC4zOTg0IDQuOTYwMVYyLjI0MDFDMTQuMzk4NCAxLjg4NjY0IDE0LjExMTkgMS42MDAxIDEzLjc1ODQgMS42MDAxWiIgZmlsbD0iI2ZmZiIvPgo8cGF0aCBkPSJNNCAxMkwxMiA0TDQgMTJaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDQiIHN0cm9rZT0iI2ZmZiIgc3Ryb2tlLXdpZHRoPSIxLjUiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIvPgo8L3N2Zz4K&logoColor=ffffff" alt="zread"/></a>
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?color=76bad9"/></a>
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
|
||||
<img src="https://gitcode.com/Soulter/AstrBot/star/badge.svg" href="https://gitcode.com/Soulter/AstrBot">
|
||||
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?style=for-the-badge&color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||
<a href="https://astrbot.app/">文档</a> |
|
||||
<a href="https://blog.astrbot.app/">Blog</a> |
|
||||
<a href="https://astrbot.featurebase.app/roadmap">路线图</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/issues">问题提交</a>
|
||||
</div>
|
||||
|
||||
AstrBot 是一个开源的一站式 Agent 聊天机器人平台,可接入主流即时通讯软件,为个人、开发者和团队打造可靠、可扩展的对话式智能基础设施。无论是个人 AI 伙伴、智能客服、自动化助手,还是企业知识库,AstrBot 都能在你的即时通讯软件平台的工作流中快速构建生产可用的 AI 应用。
|
||||
|
||||
<img width="1776" height="1080" alt="image" src="https://github.com/user-attachments/assets/00782c4c-4437-4d97-aabc-605e3738da5c" />
|
||||
AstrBot 是一个开源的一站式 Agent 聊天机器人平台及开发框架。
|
||||
|
||||
## 主要功能
|
||||
|
||||
1. 💯 免费 & 开源。
|
||||
1. ✨ AI 大模型对话,多模态,Agent,MCP,知识库,人格设定。
|
||||
2. 🤖 支持接入 Dify、阿里云百炼、Coze 等智能体平台。
|
||||
2. 🌐 多平台,支持 QQ、企业微信、飞书、钉钉、微信公众号、Telegram、Slack 以及[更多](#支持的消息平台)。
|
||||
3. 📦 插件扩展,已有近 800 个插件可一键安装。
|
||||
5. 💻 WebUI 支持。
|
||||
6. 🌐 国际化(i18n)支持。
|
||||
1. **大模型对话**。支持接入多种大模型服务。支持多模态、工具调用、MCP、原生知识库、人设等功能。
|
||||
2. **多消息平台支持**。支持接入 QQ、企业微信、微信公众号、飞书、Telegram、钉钉、Discord、KOOK 等平台。支持速率限制、白名单、百度内容审核。
|
||||
3. **Agent**。完善适配的 Agentic 能力。支持多轮工具调用、内置沙盒代码执行器、网页搜索等功能。
|
||||
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,社区插件生态丰富。
|
||||
5. **WebUI**。可视化配置和管理机器人,功能齐全。
|
||||
|
||||
## 快速开始
|
||||
## 部署方式
|
||||
|
||||
#### Docker 部署(推荐 🥳)
|
||||
|
||||
@@ -56,12 +50,6 @@ AstrBot 是一个开源的一站式 Agent 聊天机器人平台,可接入主
|
||||
|
||||
请参阅官方文档 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) 。
|
||||
|
||||
#### uv 部署
|
||||
|
||||
```bash
|
||||
uvx astrbot
|
||||
```
|
||||
|
||||
#### 宝塔面板部署
|
||||
|
||||
AstrBot 与宝塔面板合作,已上架至宝塔面板。
|
||||
@@ -113,6 +101,24 @@ uv run main.py
|
||||
|
||||
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
|
||||
|
||||
## 🌍 社区
|
||||
|
||||
### QQ 群组
|
||||
|
||||
- 1 群:322154837
|
||||
- 3 群:630166526
|
||||
- 5 群:822130018
|
||||
- 6 群:753075035
|
||||
- 开发者群:975206796
|
||||
|
||||
### Telegram 群组
|
||||
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
### Discord 群组
|
||||
|
||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
## 支持的消息平台
|
||||
|
||||
**官方维护**
|
||||
@@ -132,7 +138,6 @@ uv run main.py
|
||||
|
||||
**社区维护**
|
||||
|
||||
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
|
||||
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
|
||||
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
|
||||
- [Bilibili 私信](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
|
||||
@@ -200,26 +205,6 @@ pip install pre-commit
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
## 🌍 社区
|
||||
|
||||
### QQ 群组
|
||||
|
||||
- 1 群:322154837
|
||||
- 3 群:630166526
|
||||
- 5 群:822130018
|
||||
- 6 群:753075035
|
||||
- 7 群:743746109
|
||||
- 8 群:1030353265
|
||||
- 开发者群:975206796
|
||||
|
||||
### Telegram 群组
|
||||
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
### Discord 群组
|
||||
|
||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
## ❤️ Special Thanks
|
||||
|
||||
特别感谢所有 Contributors 和插件开发者对 AstrBot 的贡献 ❤️
|
||||
@@ -245,10 +230,4 @@ pre-commit install
|
||||
|
||||
</details>
|
||||
|
||||
<div align="center">
|
||||
|
||||
_私は、高性能ですから!_
|
||||
|
||||
<img src="https://files.astrbot.app/watashiwa-koseino-desukara.gif" width="100"/>
|
||||
</div
|
||||
|
||||
|
||||
+26
-41
@@ -19,38 +19,30 @@
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%20plugins&style=for-the-badge&label=Marketplace&cacheSeconds=3600">
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">中文</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_zh-TW.md">繁體中文</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_fr.md">Français</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ru.md">Русский</a>
|
||||
|
||||
<a href="https://astrbot.app/">Documentation</a> |
|
||||
<a href="https://blog.astrbot.app/">Blog</a> |
|
||||
<a href="https://astrbot.featurebase.app/roadmap">Roadmap</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Issue Tracker</a>
|
||||
</div>
|
||||
|
||||
AstrBot is an open-source all-in-one Agent chatbot platform that integrates with mainstream instant messaging apps. It provides reliable and scalable conversational AI infrastructure for individuals, developers, and teams. Whether you're building a personal AI companion, intelligent customer service, automation assistant, or enterprise knowledge base, AstrBot enables you to quickly build production-ready AI applications within your IM platform workflows.
|
||||
|
||||
<img width="1776" height="1080" alt="image" src="https://github.com/user-attachments/assets/00782c4c-4437-4d97-aabc-605e3738da5c" />
|
||||
AstrBot is an open-source all-in-one Agent chatbot platform and development framework.
|
||||
|
||||
## Key Features
|
||||
|
||||
1. 💯 Free & Open Source.
|
||||
2. ✨ AI LLM Conversations, Multimodal, Agent, MCP, Knowledge Base, Persona Settings.
|
||||
3. 🤖 Supports integration with Dify, Alibaba Cloud Bailian, Coze and other agent platforms.
|
||||
4. 🌐 Multi-Platform: QQ, WeChat Work, Feishu, DingTalk, WeChat Official Accounts, Telegram, Slack, and [more](#supported-messaging-platforms).
|
||||
5. 📦 Plugin Extensions with nearly 800 plugins available for one-click installation.
|
||||
6. 💻 WebUI Support.
|
||||
7. 🌐 Internationalization (i18n) Support.
|
||||
1. **LLM Conversations**. Supports integration with various large language model services. Features include multimodal capabilities, tool calling, MCP, native knowledge base, character personas, and more.
|
||||
2. **Multi-Platform Support**. Integrates with QQ, WeChat Work, WeChat Official Accounts, Feishu, Telegram, DingTalk, Discord, KOOK, and other platforms. Supports rate limiting, whitelisting, and Baidu content moderation.
|
||||
3. **Agent Capabilities**. Fully optimized agentic features including multi-turn tool calling, built-in sandboxed code executor, web search, and more.
|
||||
4. **Plugin Extensions**. Deeply optimized plugin mechanism supporting [plugin development](https://astrbot.app/dev/plugin.html) to extend functionality, with a rich community plugin ecosystem.
|
||||
5. **Web UI**. Visual configuration and management of your bot with comprehensive features.
|
||||
|
||||
## Quick Start
|
||||
## Deployment Methods
|
||||
|
||||
#### Docker Deployment (Recommended 🥳)
|
||||
|
||||
@@ -58,12 +50,6 @@ We recommend deploying AstrBot using Docker or Docker Compose.
|
||||
|
||||
Please refer to the official documentation: [Deploy AstrBot with Docker](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot).
|
||||
|
||||
#### uv Deployment
|
||||
|
||||
```bash
|
||||
uvx astrbot
|
||||
```
|
||||
|
||||
#### BT-Panel Deployment
|
||||
|
||||
AstrBot has partnered with BT-Panel and is now available in their marketplace.
|
||||
@@ -115,6 +101,24 @@ uv run main.py
|
||||
|
||||
Or refer to the official documentation: [Deploy AstrBot from Source](https://astrbot.app/deploy/astrbot/cli.html).
|
||||
|
||||
## 🌍 Community
|
||||
|
||||
### QQ Groups
|
||||
|
||||
- Group 1: 322154837
|
||||
- Group 3: 630166526
|
||||
- Group 5: 822130018
|
||||
- Group 6: 753075035
|
||||
- Developer Group: 975206796
|
||||
|
||||
### Telegram Group
|
||||
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
### Discord Server
|
||||
|
||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
## Supported Messaging Platforms
|
||||
|
||||
**Officially Maintained**
|
||||
@@ -134,7 +138,6 @@ Or refer to the official documentation: [Deploy AstrBot from Source](https://ast
|
||||
|
||||
**Community Maintained**
|
||||
|
||||
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
|
||||
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
|
||||
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
|
||||
- [Bilibili Direct Messages](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
|
||||
@@ -202,24 +205,6 @@ pip install pre-commit
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
## 🌍 Community
|
||||
|
||||
### QQ Groups
|
||||
|
||||
- Group 1: 322154837
|
||||
- Group 3: 630166526
|
||||
- Group 5: 822130018
|
||||
- Group 6: 753075035
|
||||
- Developer Group: 975206796
|
||||
|
||||
### Telegram Group
|
||||
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
### Discord Server
|
||||
|
||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
## ❤️ Special Thanks
|
||||
|
||||
Special thanks to all Contributors and plugin developers for their contributions to AstrBot ❤️
|
||||
|
||||
-249
@@ -1,249 +0,0 @@
|
||||

|
||||
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
|
||||
<br>
|
||||
|
||||
<div>
|
||||
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp&t=2" alt="Featured|HelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<div>
|
||||
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?style=for-the-badge&color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%20plugins&style=for-the-badge&label=Marketplace&cacheSeconds=3600">
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">中文</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_zh-TW.md">繁體中文</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ru.md">Русский</a>
|
||||
|
||||
<a href="https://astrbot.app/">Documentation</a> |
|
||||
<a href="https://blog.astrbot.app/">Blog</a> |
|
||||
<a href="https://astrbot.featurebase.app/roadmap">Feuille de route</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Signaler un problème</a>
|
||||
</div>
|
||||
|
||||
AstrBot est une plateforme de chatbot Agent tout-en-un open source qui s'intègre aux principales applications de messagerie instantanée. Elle fournit une infrastructure d'IA conversationnelle fiable et évolutive pour les particuliers, les développeurs et les équipes. Que vous construisiez un compagnon IA personnel, un service client intelligent, un assistant d'automatisation ou une base de connaissances d'entreprise, AstrBot vous permet de créer rapidement des applications d'IA prêtes pour la production dans les flux de travail de votre plateforme de messagerie.
|
||||
|
||||
<img width="1776" height="1080" alt="image" src="https://github.com/user-attachments/assets/00782c4c-4437-4d97-aabc-605e3738da5c" />
|
||||
|
||||
## Fonctionnalités principales
|
||||
|
||||
1. 💯 Gratuit & Open Source.
|
||||
2. ✨ Conversations avec LLM IA, Multimodal, Agent, MCP, Base de connaissances, Paramètres de personnalité.
|
||||
3. 🤖 Prise en charge de l'intégration avec Dify, Alibaba Cloud Bailian, Coze et autres plateformes d'agents.
|
||||
4. 🌐 Multi-plateforme : QQ, WeChat Work, Feishu, DingTalk, Comptes officiels WeChat, Telegram, Slack, et [plus encore](#plateformes-de-messagerie-prises-en-charge).
|
||||
5. 📦 Extensions de plugins avec près de 800 plugins disponibles pour une installation en un clic.
|
||||
6. 💻 Support WebUI.
|
||||
7. 🌐 Support de l'internationalisation (i18n).
|
||||
|
||||
## Démarrage rapide
|
||||
|
||||
#### Déploiement Docker (Recommandé 🥳)
|
||||
|
||||
Nous recommandons de déployer AstrBot en utilisant Docker ou Docker Compose.
|
||||
|
||||
Veuillez consulter la documentation officielle : [Déployer AstrBot avec Docker](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot).
|
||||
|
||||
#### Déploiement uv
|
||||
|
||||
```bash
|
||||
uvx astrbot
|
||||
```
|
||||
|
||||
#### Déploiement BT-Panel
|
||||
|
||||
AstrBot s'est associé à BT-Panel et est maintenant disponible sur leur marketplace.
|
||||
|
||||
Veuillez consulter la documentation officielle : [Déploiement BT-Panel](https://astrbot.app/deploy/astrbot/btpanel.html).
|
||||
|
||||
#### Déploiement 1Panel
|
||||
|
||||
AstrBot a été officiellement listé sur le marketplace 1Panel.
|
||||
|
||||
Veuillez consulter la documentation officielle : [Déploiement 1Panel](https://astrbot.app/deploy/astrbot/1panel.html).
|
||||
|
||||
#### Déployer sur RainYun
|
||||
|
||||
AstrBot a été officiellement listé sur la plateforme d'applications cloud de RainYun avec un déploiement en un clic.
|
||||
|
||||
[](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
|
||||
|
||||
#### Déployer sur Replit
|
||||
|
||||
Méthode de déploiement contribuée par la communauté.
|
||||
|
||||
[](https://repl.it/github/AstrBotDevs/AstrBot)
|
||||
|
||||
#### Installateur Windows en un clic
|
||||
|
||||
Veuillez consulter la documentation officielle : [Déployer AstrBot avec l'installateur Windows en un clic](https://astrbot.app/deploy/astrbot/windows.html).
|
||||
|
||||
#### Déploiement CasaOS
|
||||
|
||||
Méthode de déploiement contribuée par la communauté.
|
||||
|
||||
Veuillez consulter la documentation officielle : [Déploiement CasaOS](https://astrbot.app/deploy/astrbot/casaos.html).
|
||||
|
||||
#### Déploiement manuel
|
||||
|
||||
Tout d'abord, installez uv :
|
||||
|
||||
```bash
|
||||
pip install uv
|
||||
```
|
||||
|
||||
Installez AstrBot via Git Clone :
|
||||
|
||||
```bash
|
||||
git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
|
||||
uv run main.py
|
||||
```
|
||||
|
||||
Ou consultez la documentation officielle : [Déployer AstrBot depuis les sources](https://astrbot.app/deploy/astrbot/cli.html).
|
||||
|
||||
## Plateformes de messagerie prises en charge
|
||||
|
||||
**Maintenues officiellement**
|
||||
|
||||
- QQ (Plateforme officielle & OneBot)
|
||||
- Telegram
|
||||
- Application WeChat Work & Bot intelligent WeChat Work
|
||||
- Service client WeChat & Comptes officiels WeChat
|
||||
- Feishu (Lark)
|
||||
- DingTalk
|
||||
- Slack
|
||||
- Discord
|
||||
- Satori
|
||||
- Misskey
|
||||
- WhatsApp (Bientôt disponible)
|
||||
- LINE (Bientôt disponible)
|
||||
|
||||
**Maintenues par la communauté**
|
||||
|
||||
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
|
||||
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
|
||||
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
|
||||
- [Messages directs Bilibili](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
|
||||
- [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11)
|
||||
|
||||
## Services de modèles pris en charge
|
||||
|
||||
**Services LLM**
|
||||
|
||||
- OpenAI et services compatibles
|
||||
- Anthropic
|
||||
- Google Gemini
|
||||
- Moonshot AI
|
||||
- Zhipu AI
|
||||
- DeepSeek
|
||||
- Ollama (Auto-hébergé)
|
||||
- LM Studio (Auto-hébergé)
|
||||
- [CompShare](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74)
|
||||
- [302.AI](https://share.302.ai/rr1M3l)
|
||||
- [TokenPony](https://www.tokenpony.cn/3YPyf)
|
||||
- [SiliconFlow](https://docs.siliconflow.cn/cn/usecases/use-siliconcloud-in-astrbot)
|
||||
- [PPIO Cloud](https://ppio.com/user/register?invited_by=AIOONE)
|
||||
- ModelScope
|
||||
- OneAPI
|
||||
|
||||
**Plateformes LLMOps**
|
||||
|
||||
- Dify
|
||||
- Applications Alibaba Cloud Bailian
|
||||
- Coze
|
||||
|
||||
**Services de reconnaissance vocale**
|
||||
|
||||
- OpenAI Whisper
|
||||
- SenseVoice
|
||||
|
||||
**Services de synthèse vocale**
|
||||
|
||||
- OpenAI TTS
|
||||
- Gemini TTS
|
||||
- GPT-Sovits-Inference
|
||||
- GPT-Sovits
|
||||
- FishAudio
|
||||
- Edge TTS
|
||||
- Alibaba Cloud Bailian TTS
|
||||
- Azure TTS
|
||||
- Minimax TTS
|
||||
- Volcano Engine TTS
|
||||
|
||||
## ❤️ Contribuer
|
||||
|
||||
Les Issues et Pull Requests sont toujours les bienvenues ! N'hésitez pas à soumettre vos modifications à ce projet :)
|
||||
|
||||
### Comment contribuer
|
||||
|
||||
Vous pouvez contribuer en examinant les issues ou en aidant à la revue des pull requests. Toutes les issues ou PRs sont les bienvenues pour encourager la participation de la communauté. Bien sûr, ce ne sont que des suggestions - vous pouvez contribuer de la manière que vous souhaitez. Pour l'ajout de nouvelles fonctionnalités, veuillez d'abord en discuter via une Issue.
|
||||
|
||||
### Environnement de développement
|
||||
|
||||
AstrBot utilise `ruff` pour le formatage et le linting du code.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/AstrBotDevs/AstrBot
|
||||
pip install pre-commit
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
## 🌍 Communauté
|
||||
|
||||
### Groupes QQ
|
||||
|
||||
- Groupe 1 : 322154837
|
||||
- Groupe 3 : 630166526
|
||||
- Groupe 5 : 822130018
|
||||
- Groupe 6 : 753075035
|
||||
- Groupe développeurs : 975206796
|
||||
|
||||
### Groupe Telegram
|
||||
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
### Serveur Discord
|
||||
|
||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
## ❤️ Remerciements spéciaux
|
||||
|
||||
Un grand merci à tous les contributeurs et développeurs de plugins pour leurs contributions à AstrBot ❤️
|
||||
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/graphs/contributors">
|
||||
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
|
||||
</a>
|
||||
|
||||
De plus, la naissance de ce projet n'aurait pas été possible sans l'aide des projets open source suivants :
|
||||
|
||||
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - L'incroyable framework chat
|
||||
|
||||
## ⭐ Historique des étoiles
|
||||
|
||||
> [!TIP]
|
||||
> Si ce projet vous a aidé dans votre vie ou votre travail, ou si vous êtes intéressé par son développement futur, veuillez donner une étoile au projet. C'est la force motrice derrière la maintenance de ce projet open source <3
|
||||
|
||||
<div align="center">
|
||||
|
||||
[](https://star-history.com/#astrbotdevs/astrbot&Date)
|
||||
|
||||
</div>
|
||||
|
||||
</details>
|
||||
|
||||
_私は、高性能ですから!_
|
||||
|
||||
+26
-41
@@ -19,38 +19,30 @@
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E5%80%8B&style=for-the-badge&label=%E3%83%97%E3%83%A9%E3%82%B0%E3%82%A4%E3%83%B3&cacheSeconds=3600">
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">中文</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_zh-TW.md">繁體中文</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_fr.md">Français</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ru.md">Русский</a>
|
||||
|
||||
<a href="https://astrbot.app/">ドキュメント</a> |
|
||||
<a href="https://blog.astrbot.app/">Blog</a> |
|
||||
<a href="https://astrbot.featurebase.app/roadmap">ロードマップ</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Issue</a>
|
||||
</div>
|
||||
|
||||
AstrBot は、主要なインスタントメッセージングアプリと統合できるオープンソースのオールインワン Agent チャットボットプラットフォームです。個人、開発者、チームに信頼性が高くスケーラブルな会話型 AI インフラストラクチャを提供します。パーソナル AI コンパニオン、インテリジェントカスタマーサービス、オートメーションアシスタント、エンタープライズナレッジベースなど、AstrBot を使用すると、IM プラットフォームのワークフロー内で本番環境対応の AI アプリケーションを迅速に構築できます。
|
||||
|
||||
<img width="1776" height="1080" alt="image" src="https://github.com/user-attachments/assets/00782c4c-4437-4d97-aabc-605e3738da5c" />
|
||||
AstrBot は、オープンソースのオールインワン Agent チャットボットプラットフォーム及び開発フレームワークです。
|
||||
|
||||
## 主な機能
|
||||
|
||||
1. 💯 無料 & オープンソース。
|
||||
2. ✨ AI 大規模言語モデル対話、マルチモーダル、Agent、MCP、ナレッジベース、ペルソナ設定。
|
||||
3. 🤖 Dify、Alibaba Cloud 百炼、Coze などの Agent プラットフォームとの統合をサポート。
|
||||
4. 🌐 マルチプラットフォーム:QQ、WeChat Work、Feishu、DingTalk、WeChat 公式アカウント、Telegram、Slack、[その他](#サポートされているメッセージプラットフォーム)。
|
||||
5. 📦 約800個のプラグインをワンクリックでインストール可能なプラグイン拡張機能。
|
||||
6. 💻 WebUI サポート。
|
||||
7. 🌐 国際化(i18n)サポート。
|
||||
1. **大規模言語モデル対話**。多様な大規模言語モデルサービスとの統合をサポート。マルチモーダル、ツール呼び出し、MCP、ネイティブナレッジベース、キャラクター設定などの機能を搭載。
|
||||
2. **マルチメッセージプラットフォームサポート**。QQ、WeChat Work、WeChat公式アカウント、Feishu、Telegram、DingTalk、Discord、KOOK などのプラットフォームと統合可能。レート制限、ホワイトリスト、Baidu コンテンツ審査をサポート。
|
||||
3. **Agent**。完全に最適化された Agentic 機能。マルチターンツール呼び出し、内蔵サンドボックスコード実行環境、Web 検索などの機能をサポート。
|
||||
4. **プラグイン拡張**。深く最適化されたプラグインメカニズムで、[プラグイン開発](https://astrbot.app/dev/plugin.html)による機能拡張をサポート。豊富なコミュニティプラグインエコシステム。
|
||||
5. **WebUI**。ビジュアル設定とボット管理、充実した機能。
|
||||
|
||||
## クイックスタート
|
||||
## デプロイ方法
|
||||
|
||||
#### Docker デプロイ(推奨 🥳)
|
||||
|
||||
@@ -58,12 +50,6 @@ Docker / Docker Compose を使用した AstrBot のデプロイを推奨しま
|
||||
|
||||
公式ドキュメント [Docker を使用した AstrBot のデプロイ](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) をご参照ください。
|
||||
|
||||
#### uv デプロイ
|
||||
|
||||
```bash
|
||||
uvx astrbot
|
||||
```
|
||||
|
||||
#### 宝塔パネルデプロイ
|
||||
|
||||
AstrBot は宝塔パネルと提携し、宝塔パネルに公開されています。
|
||||
@@ -115,6 +101,24 @@ uv run main.py
|
||||
|
||||
または、公式ドキュメント [ソースコードから AstrBot をデプロイ](https://astrbot.app/deploy/astrbot/cli.html) をご参照ください。
|
||||
|
||||
## 🌍 コミュニティ
|
||||
|
||||
### QQ グループ
|
||||
|
||||
- 1群:322154837
|
||||
- 3群:630166526
|
||||
- 5群:822130018
|
||||
- 6群:753075035
|
||||
- 開発者群:975206796
|
||||
|
||||
### Telegram グループ
|
||||
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
### Discord サーバー
|
||||
|
||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
## サポートされているメッセージプラットフォーム
|
||||
|
||||
**公式メンテナンス**
|
||||
@@ -134,7 +138,6 @@ uv run main.py
|
||||
|
||||
**コミュニティメンテナンス**
|
||||
|
||||
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
|
||||
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
|
||||
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
|
||||
- [Bilibili ダイレクトメッセージ](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
|
||||
@@ -202,24 +205,6 @@ pip install pre-commit
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
## 🌍 コミュニティ
|
||||
|
||||
### QQ グループ
|
||||
|
||||
- 1群: 322154837
|
||||
- 3群: 630166526
|
||||
- 5群: 822130018
|
||||
- 6群: 753075035
|
||||
- 開発者群: 975206796
|
||||
|
||||
### Telegram グループ
|
||||
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
### Discord サーバー
|
||||
|
||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
## ❤️ Special Thanks
|
||||
|
||||
AstrBot への貢献をしていただいたすべてのコントリビューターとプラグイン開発者に特別な感謝を ❤️
|
||||
|
||||
-249
@@ -1,249 +0,0 @@
|
||||

|
||||
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
|
||||
<br>
|
||||
|
||||
<div>
|
||||
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp&t=2" alt="Featured|HelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<div>
|
||||
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?style=for-the-badge&color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%20%D0%BF%D0%BB%D0%B0%D0%B3%D0%B8%D0%BD%D0%BE%D0%B2&style=for-the-badge&label=%D0%9C%D0%B0%D0%B3%D0%B0%D0%B7%D0%B8%D0%BD&cacheSeconds=3600">
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">中文</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_zh-TW.md">繁體中文</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_fr.md">Français</a>
|
||||
|
||||
<a href="https://astrbot.app/">Документация</a> |
|
||||
<a href="https://blog.astrbot.app/">Блог</a> |
|
||||
<a href="https://astrbot.featurebase.app/roadmap">Дорожная карта</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Сообщить о проблеме</a>
|
||||
</div>
|
||||
|
||||
AstrBot — это универсальная платформа Agent-чатботов с открытым исходным кодом, которая интегрируется с основными приложениями для обмена мгновенными сообщениями. Она предоставляет надёжную и масштабируемую инфраструктуру разговорного ИИ для частных лиц, разработчиков и команд. Будь то персональный ИИ-компаньон, интеллектуальная служба поддержки, автоматизированный помощник или корпоративная база знаний — AstrBot позволяет быстро создавать готовые к использованию ИИ-приложения в рабочих процессах вашей платформы обмена сообщениями.
|
||||
|
||||
<img width="1776" height="1080" alt="image" src="https://github.com/user-attachments/assets/00782c4c-4437-4d97-aabc-605e3738da5c" />
|
||||
|
||||
## Основные возможности
|
||||
|
||||
1. 💯 Бесплатно и с открытым исходным кодом.
|
||||
2. ✨ ИИ-диалоги с LLM, мультимодальность, Agent, MCP, база знаний, настройки личности.
|
||||
3. 🤖 Поддержка интеграции с Dify, Alibaba Cloud Bailian, Coze и другими платформами агентов.
|
||||
4. 🌐 Мультиплатформенность: QQ, WeChat Work, Feishu, DingTalk, официальные аккаунты WeChat, Telegram, Slack и [другие](#поддерживаемые-платформы-обмена-сообщениями).
|
||||
5. 📦 Расширения плагинов с почти 800 плагинами, доступными для установки в один клик.
|
||||
6. 💻 Поддержка WebUI.
|
||||
7. 🌐 Поддержка интернационализации (i18n).
|
||||
|
||||
## Быстрый старт
|
||||
|
||||
#### Развёртывание Docker (Рекомендуется 🥳)
|
||||
|
||||
Мы рекомендуем развёртывать AstrBot с помощью Docker или Docker Compose.
|
||||
|
||||
См. официальную документацию: [Развёртывание AstrBot с Docker](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot).
|
||||
|
||||
#### Развёртывание uv
|
||||
|
||||
```bash
|
||||
uvx astrbot
|
||||
```
|
||||
|
||||
#### Развёртывание BT-Panel
|
||||
|
||||
AstrBot в партнёрстве с BT-Panel теперь доступен на их маркетплейсе.
|
||||
|
||||
См. официальную документацию: [Развёртывание BT-Panel](https://astrbot.app/deploy/astrbot/btpanel.html).
|
||||
|
||||
#### Развёртывание 1Panel
|
||||
|
||||
AstrBot официально размещён на маркетплейсе 1Panel.
|
||||
|
||||
См. официальную документацию: [Развёртывание 1Panel](https://astrbot.app/deploy/astrbot/1panel.html).
|
||||
|
||||
#### Развёртывание на RainYun
|
||||
|
||||
AstrBot официально размещён на облачной платформе приложений RainYun с развёртыванием в один клик.
|
||||
|
||||
[](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
|
||||
|
||||
#### Развёртывание на Replit
|
||||
|
||||
Метод развёртывания от сообщества.
|
||||
|
||||
[](https://repl.it/github/AstrBotDevs/AstrBot)
|
||||
|
||||
#### Установщик Windows в один клик
|
||||
|
||||
См. официальную документацию: [Развёртывание AstrBot с установщиком Windows в один клик](https://astrbot.app/deploy/astrbot/windows.html).
|
||||
|
||||
#### Развёртывание CasaOS
|
||||
|
||||
Метод развёртывания от сообщества.
|
||||
|
||||
См. официальную документацию: [Развёртывание CasaOS](https://astrbot.app/deploy/astrbot/casaos.html).
|
||||
|
||||
#### Ручное развёртывание
|
||||
|
||||
Сначала установите uv:
|
||||
|
||||
```bash
|
||||
pip install uv
|
||||
```
|
||||
|
||||
Установите AstrBot через Git Clone:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
|
||||
uv run main.py
|
||||
```
|
||||
|
||||
Или см. официальную документацию: [Развёртывание AstrBot из исходного кода](https://astrbot.app/deploy/astrbot/cli.html).
|
||||
|
||||
## Поддерживаемые платформы обмена сообщениями
|
||||
|
||||
**Официально поддерживаемые**
|
||||
|
||||
- QQ (Официальная платформа и OneBot)
|
||||
- Telegram
|
||||
- Приложение WeChat Work и интеллектуальный бот WeChat Work
|
||||
- Служба поддержки WeChat и официальные аккаунты WeChat
|
||||
- Feishu (Lark)
|
||||
- DingTalk
|
||||
- Slack
|
||||
- Discord
|
||||
- Satori
|
||||
- Misskey
|
||||
- WhatsApp (Скоро)
|
||||
- LINE (Скоро)
|
||||
|
||||
**Поддерживаемые сообществом**
|
||||
|
||||
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
|
||||
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
|
||||
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
|
||||
- [Личные сообщения Bilibili](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
|
||||
- [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11)
|
||||
|
||||
## Поддерживаемые сервисы моделей
|
||||
|
||||
**Сервисы LLM**
|
||||
|
||||
- OpenAI и совместимые сервисы
|
||||
- Anthropic
|
||||
- Google Gemini
|
||||
- Moonshot AI
|
||||
- Zhipu AI
|
||||
- DeepSeek
|
||||
- Ollama (Самостоятельное размещение)
|
||||
- LM Studio (Самостоятельное размещение)
|
||||
- [CompShare](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74)
|
||||
- [302.AI](https://share.302.ai/rr1M3l)
|
||||
- [TokenPony](https://www.tokenpony.cn/3YPyf)
|
||||
- [SiliconFlow](https://docs.siliconflow.cn/cn/usecases/use-siliconcloud-in-astrbot)
|
||||
- [PPIO Cloud](https://ppio.com/user/register?invited_by=AIOONE)
|
||||
- ModelScope
|
||||
- OneAPI
|
||||
|
||||
**Платформы LLMOps**
|
||||
|
||||
- Dify
|
||||
- Приложения Alibaba Cloud Bailian
|
||||
- Coze
|
||||
|
||||
**Сервисы распознавания речи**
|
||||
|
||||
- OpenAI Whisper
|
||||
- SenseVoice
|
||||
|
||||
**Сервисы синтеза речи**
|
||||
|
||||
- OpenAI TTS
|
||||
- Gemini TTS
|
||||
- GPT-Sovits-Inference
|
||||
- GPT-Sovits
|
||||
- FishAudio
|
||||
- Edge TTS
|
||||
- Alibaba Cloud Bailian TTS
|
||||
- Azure TTS
|
||||
- Minimax TTS
|
||||
- Volcano Engine TTS
|
||||
|
||||
## ❤️ Вклад в проект
|
||||
|
||||
Issues и Pull Request всегда приветствуются! Не стесняйтесь отправлять свои изменения в этот проект :)
|
||||
|
||||
### Как внести вклад
|
||||
|
||||
Вы можете внести вклад, просматривая issues или помогая с ревью pull request. Любые issues или PR приветствуются для поощрения участия сообщества. Конечно, это лишь предложения — вы можете вносить вклад любым удобным для вас способом. Для добавления новых функций сначала обсудите это через Issue.
|
||||
|
||||
### Среда разработки
|
||||
|
||||
AstrBot использует `ruff` для форматирования и линтинга кода.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/AstrBotDevs/AstrBot
|
||||
pip install pre-commit
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
## 🌍 Сообщество
|
||||
|
||||
### Группы QQ
|
||||
|
||||
- Группа 1: 322154837
|
||||
- Группа 3: 630166526
|
||||
- Группа 5: 822130018
|
||||
- Группа 6: 753075035
|
||||
- Группа разработчиков: 975206796
|
||||
|
||||
### Группа Telegram
|
||||
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
### Сервер Discord
|
||||
|
||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
## ❤️ Особая благодарность
|
||||
|
||||
Особая благодарность всем контрибьюторам и разработчикам плагинов за их вклад в AstrBot ❤️
|
||||
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/graphs/contributors">
|
||||
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
|
||||
</a>
|
||||
|
||||
Кроме того, рождение этого проекта было бы невозможно без помощи следующих проектов с открытым исходным кодом:
|
||||
|
||||
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - Замечательный кошачий фреймворк
|
||||
|
||||
## ⭐ История звёзд
|
||||
|
||||
> [!TIP]
|
||||
> Если этот проект помог вам в жизни или работе, или если вас интересует его будущее развитие, пожалуйста, поставьте проекту звезду. Это движущая сила поддержки этого проекта с открытым исходным кодом <3
|
||||
|
||||
<div align="center">
|
||||
|
||||
[](https://star-history.com/#astrbotdevs/astrbot&Date)
|
||||
|
||||
</div>
|
||||
|
||||
</details>
|
||||
|
||||
_私は、高性能ですから!_
|
||||
|
||||
-249
@@ -1,249 +0,0 @@
|
||||

|
||||
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
|
||||
<br>
|
||||
|
||||
<div>
|
||||
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp&t=2" alt="Featured|HelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<div>
|
||||
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?style=for-the-badge&color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E5%80%8B&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%A0%B4&cacheSeconds=3600">
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">简体中文</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_fr.md">Français</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ru.md">Русский</a>
|
||||
|
||||
<a href="https://astrbot.app/">文件</a> |
|
||||
<a href="https://blog.astrbot.app/">Blog</a> |
|
||||
<a href="https://astrbot.featurebase.app/roadmap">路線圖</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/issues">問題回報</a>
|
||||
</div>
|
||||
|
||||
AstrBot 是一個開源的一站式 Agent 聊天機器人平台,可接入主流即時通訊軟體,為個人、開發者和團隊打造可靠、可擴展的對話式智慧基礎設施。無論是個人 AI 夥伴、智慧客服、自動化助手,還是企業知識庫,AstrBot 都能在您的即時通訊軟體平台的工作流程中快速構建生產可用的 AI 應用程式。
|
||||
|
||||
<img width="1776" height="1080" alt="image" src="https://github.com/user-attachments/assets/00782c4c-4437-4d97-aabc-605e3738da5c" />
|
||||
|
||||
## 主要功能
|
||||
|
||||
1. 💯 免費 & 開源。
|
||||
2. ✨ AI 大型模型對話,多模態,Agent,MCP,知識庫,人格設定。
|
||||
3. 🤖 支援接入 Dify、阿里雲百煉、Coze 等智慧體平台。
|
||||
4. 🌐 多平台:QQ、企業微信、飛書、釘釘、微信公眾號、Telegram、Slack 以及[更多](#支援的訊息平台)。
|
||||
5. 📦 外掛擴充,已有近 800 個外掛可一鍵安裝。
|
||||
6. 💻 WebUI 支援。
|
||||
7. 🌐 國際化(i18n)支援。
|
||||
|
||||
## 快速開始
|
||||
|
||||
#### Docker 部署(推薦 🥳)
|
||||
|
||||
推薦使用 Docker / Docker Compose 方式部署 AstrBot。
|
||||
|
||||
請參閱官方文件 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot)。
|
||||
|
||||
#### uv 部署
|
||||
|
||||
```bash
|
||||
uvx astrbot
|
||||
```
|
||||
|
||||
#### 寶塔面板部署
|
||||
|
||||
AstrBot 與寶塔面板合作,已上架至寶塔面板。
|
||||
|
||||
請參閱官方文件 [寶塔面板部署](https://astrbot.app/deploy/astrbot/btpanel.html)。
|
||||
|
||||
#### 1Panel 部署
|
||||
|
||||
AstrBot 已由 1Panel 官方上架至 1Panel 面板。
|
||||
|
||||
請參閱官方文件 [1Panel 部署](https://astrbot.app/deploy/astrbot/1panel.html)。
|
||||
|
||||
#### 在雨雲上部署
|
||||
|
||||
AstrBot 已由雨雲官方上架至雲端應用程式平台,可一鍵部署。
|
||||
|
||||
[](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
|
||||
|
||||
#### 在 Replit 上部署
|
||||
|
||||
社群貢獻的部署方式。
|
||||
|
||||
[](https://repl.it/github/AstrBotDevs/AstrBot)
|
||||
|
||||
#### Windows 一鍵安裝器部署
|
||||
|
||||
請參閱官方文件 [使用 Windows 一鍵安裝器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html)。
|
||||
|
||||
#### CasaOS 部署
|
||||
|
||||
社群貢獻的部署方式。
|
||||
|
||||
請參閱官方文件 [CasaOS 部署](https://astrbot.app/deploy/astrbot/casaos.html)。
|
||||
|
||||
#### 手動部署
|
||||
|
||||
首先安裝 uv:
|
||||
|
||||
```bash
|
||||
pip install uv
|
||||
```
|
||||
|
||||
透過 Git Clone 安裝 AstrBot:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
|
||||
uv run main.py
|
||||
```
|
||||
|
||||
或者請參閱官方文件 [透過原始碼部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html)。
|
||||
|
||||
## 支援的訊息平台
|
||||
|
||||
**官方維護**
|
||||
|
||||
- QQ(官方平台 & OneBot)
|
||||
- Telegram
|
||||
- 企微應用 & 企微智慧機器人
|
||||
- 微信客服 & 微信公眾號
|
||||
- 飛書
|
||||
- 釘釘
|
||||
- Slack
|
||||
- Discord
|
||||
- Satori
|
||||
- Misskey
|
||||
- Whatsapp(即將支援)
|
||||
- LINE(即將支援)
|
||||
|
||||
**社群維護**
|
||||
|
||||
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
|
||||
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
|
||||
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
|
||||
- [Bilibili 私訊](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
|
||||
- [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11)
|
||||
|
||||
## 支援的模型服務
|
||||
|
||||
**大型模型服務**
|
||||
|
||||
- OpenAI 及相容服務
|
||||
- Anthropic
|
||||
- Google Gemini
|
||||
- Moonshot AI
|
||||
- 智譜 AI
|
||||
- DeepSeek
|
||||
- Ollama(本機部署)
|
||||
- LM Studio(本機部署)
|
||||
- [優雲智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74)
|
||||
- [302.AI](https://share.302.ai/rr1M3l)
|
||||
- [小馬算力](https://www.tokenpony.cn/3YPyf)
|
||||
- [矽基流動](https://docs.siliconflow.cn/cn/usercases/use-siliconcloud-in-astrbot)
|
||||
- [PPIO 派歐雲](https://ppio.com/user/register?invited_by=AIOONE)
|
||||
- ModelScope
|
||||
- OneAPI
|
||||
|
||||
**LLMOps 平台**
|
||||
|
||||
- Dify
|
||||
- 阿里雲百煉應用
|
||||
- Coze
|
||||
|
||||
**語音轉文字服務**
|
||||
|
||||
- OpenAI Whisper
|
||||
- SenseVoice
|
||||
|
||||
**文字轉語音服務**
|
||||
|
||||
- OpenAI TTS
|
||||
- Gemini TTS
|
||||
- GPT-Sovits-Inference
|
||||
- GPT-Sovits
|
||||
- FishAudio
|
||||
- Edge TTS
|
||||
- 阿里雲百煉 TTS
|
||||
- Azure TTS
|
||||
- Minimax TTS
|
||||
- 火山引擎 TTS
|
||||
|
||||
## ❤️ 貢獻
|
||||
|
||||
歡迎任何 Issues/Pull Requests!只需要將您的變更提交到此專案 :)
|
||||
|
||||
### 如何貢獻
|
||||
|
||||
您可以透過檢視問題或協助審核 PR(拉取請求)來貢獻。任何問題或 PR 都歡迎參與,以促進社群貢獻。當然,這些只是建議,您可以以任何方式進行貢獻。對於新功能的新增,請先透過 Issue 討論。
|
||||
|
||||
### 開發環境
|
||||
|
||||
AstrBot 使用 `ruff` 進行程式碼格式化和檢查。
|
||||
|
||||
```bash
|
||||
git clone https://github.com/AstrBotDevs/AstrBot
|
||||
pip install pre-commit
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
## 🌍 社群
|
||||
|
||||
### QQ 群組
|
||||
|
||||
- 1 群:322154837
|
||||
- 3 群:630166526
|
||||
- 5 群:822130018
|
||||
- 6 群:753075035
|
||||
- 開發者群:975206796
|
||||
|
||||
### Telegram 群組
|
||||
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
### Discord 群組
|
||||
|
||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
## ❤️ Special Thanks
|
||||
|
||||
特別感謝所有 Contributors 和外掛開發者對 AstrBot 的貢獻 ❤️
|
||||
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/graphs/contributors">
|
||||
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
|
||||
</a>
|
||||
|
||||
此外,本專案的誕生離不開以下開源專案的幫助:
|
||||
|
||||
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - 偉大的貓貓框架
|
||||
|
||||
## ⭐ Star History
|
||||
|
||||
> [!TIP]
|
||||
> 如果本專案對您的生活 / 工作產生了幫助,或者您關注本專案的未來發展,請給專案 Star,這是我們維護這個開源專案的動力 <3
|
||||
|
||||
<div align="center">
|
||||
|
||||
[](https://star-history.com/#astrbotdevs/astrbot&Date)
|
||||
|
||||
</div>
|
||||
|
||||
</details>
|
||||
|
||||
_私は、高性能ですから!_
|
||||
|
||||
@@ -21,9 +21,6 @@ from astrbot.core.star.register import (
|
||||
from astrbot.core.star.register import register_on_llm_request as on_llm_request
|
||||
from astrbot.core.star.register import register_on_llm_response as on_llm_response
|
||||
from astrbot.core.star.register import register_on_platform_loaded as on_platform_loaded
|
||||
from astrbot.core.star.register import (
|
||||
register_on_waiting_llm_request as on_waiting_llm_request,
|
||||
)
|
||||
from astrbot.core.star.register import register_permission_type as permission_type
|
||||
from astrbot.core.star.register import (
|
||||
register_platform_adapter_type as platform_adapter_type,
|
||||
@@ -49,7 +46,6 @@ __all__ = [
|
||||
"on_llm_request",
|
||||
"on_llm_response",
|
||||
"on_platform_loaded",
|
||||
"on_waiting_llm_request",
|
||||
"permission_type",
|
||||
"platform_adapter_type",
|
||||
"regex",
|
||||
|
||||
@@ -1,120 +0,0 @@
|
||||
import traceback
|
||||
|
||||
from astrbot.api import star
|
||||
from astrbot.api.event import AstrMessageEvent, filter
|
||||
from astrbot.api.message_components import Image, Plain
|
||||
from astrbot.api.provider import LLMResponse, ProviderRequest
|
||||
from astrbot.core import logger
|
||||
|
||||
from .long_term_memory import LongTermMemory
|
||||
from .process_llm_request import ProcessLLMRequest
|
||||
|
||||
|
||||
class Main(star.Star):
|
||||
def __init__(self, context: star.Context) -> None:
|
||||
self.context = context
|
||||
self.ltm = None
|
||||
try:
|
||||
self.ltm = LongTermMemory(self.context.astrbot_config_mgr, self.context)
|
||||
except BaseException as e:
|
||||
logger.error(f"聊天增强 err: {e}")
|
||||
|
||||
self.proc_llm_req = ProcessLLMRequest(self.context)
|
||||
|
||||
def ltm_enabled(self, event: AstrMessageEvent):
|
||||
ltmse = self.context.get_config(umo=event.unified_msg_origin)[
|
||||
"provider_ltm_settings"
|
||||
]
|
||||
return ltmse["group_icl_enable"] or ltmse["active_reply"]["enable"]
|
||||
|
||||
@filter.platform_adapter_type(filter.PlatformAdapterType.ALL)
|
||||
async def on_message(self, event: AstrMessageEvent):
|
||||
"""群聊记忆增强"""
|
||||
has_image_or_plain = False
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Plain) or isinstance(comp, Image):
|
||||
has_image_or_plain = True
|
||||
break
|
||||
|
||||
if self.ltm_enabled(event) and self.ltm and has_image_or_plain:
|
||||
need_active = await self.ltm.need_active_reply(event)
|
||||
|
||||
group_icl_enable = self.context.get_config()["provider_ltm_settings"][
|
||||
"group_icl_enable"
|
||||
]
|
||||
if group_icl_enable:
|
||||
"""记录对话"""
|
||||
try:
|
||||
await self.ltm.handle_message(event)
|
||||
except BaseException as e:
|
||||
logger.error(e)
|
||||
|
||||
if need_active:
|
||||
"""主动回复"""
|
||||
provider = self.context.get_using_provider(event.unified_msg_origin)
|
||||
if not provider:
|
||||
logger.error("未找到任何 LLM 提供商。请先配置。无法主动回复")
|
||||
return
|
||||
try:
|
||||
conv = None
|
||||
session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id(
|
||||
event.unified_msg_origin,
|
||||
)
|
||||
|
||||
if not session_curr_cid:
|
||||
logger.error(
|
||||
"当前未处于对话状态,无法主动回复,请确保 平台设置->会话隔离(unique_session) 未开启,并使用 /switch 序号 切换或者 /new 创建一个会话。",
|
||||
)
|
||||
return
|
||||
|
||||
conv = await self.context.conversation_manager.get_conversation(
|
||||
event.unified_msg_origin,
|
||||
session_curr_cid,
|
||||
)
|
||||
|
||||
prompt = event.message_str
|
||||
|
||||
if not conv:
|
||||
logger.error("未找到对话,无法主动回复")
|
||||
return
|
||||
|
||||
yield event.request_llm(
|
||||
prompt=prompt,
|
||||
func_tool_manager=self.context.get_llm_tool_manager(),
|
||||
session_id=event.session_id,
|
||||
conversation=conv,
|
||||
)
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"主动回复失败: {e}")
|
||||
|
||||
@filter.on_llm_request()
|
||||
async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest):
|
||||
"""在请求 LLM 前注入人格信息、Identifier、时间、回复内容等 System Prompt"""
|
||||
await self.proc_llm_req.process_llm_request(event, req)
|
||||
|
||||
if self.ltm and self.ltm_enabled(event):
|
||||
try:
|
||||
await self.ltm.on_req_llm(event, req)
|
||||
except BaseException as e:
|
||||
logger.error(f"ltm: {e}")
|
||||
|
||||
@filter.on_llm_response()
|
||||
async def record_llm_resp_to_ltm(self, event: AstrMessageEvent, resp: LLMResponse):
|
||||
"""在 LLM 响应后记录对话"""
|
||||
if self.ltm and self.ltm_enabled(event):
|
||||
try:
|
||||
await self.ltm.after_req_llm(event, resp)
|
||||
except Exception as e:
|
||||
logger.error(f"ltm: {e}")
|
||||
|
||||
@filter.after_message_sent()
|
||||
async def after_message_sent(self, event: AstrMessageEvent):
|
||||
"""消息发送后处理"""
|
||||
if self.ltm and self.ltm_enabled(event):
|
||||
try:
|
||||
clean_session = event.get_extra("_clean_ltm_session", False)
|
||||
if clean_session:
|
||||
await self.ltm.remove_session(event)
|
||||
except Exception as e:
|
||||
logger.error(f"ltm: {e}")
|
||||
@@ -1,4 +0,0 @@
|
||||
name: astrbot
|
||||
desc: AstrBot 自带插件,包含人格注入、思考内容注入、群聊上下文感知等功能的实现,禁用后将无法使用这些功能。
|
||||
author: Soulter
|
||||
version: 4.1.0
|
||||
@@ -1,88 +0,0 @@
|
||||
import aiohttp
|
||||
|
||||
from astrbot.api import star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.star import command_management
|
||||
from astrbot.core.utils.io import get_dashboard_version
|
||||
|
||||
|
||||
class HelpCommand:
|
||||
def __init__(self, context: star.Context):
|
||||
self.context = context
|
||||
|
||||
async def _query_astrbot_notice(self):
|
||||
try:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(
|
||||
"https://astrbot.app/notice.json",
|
||||
timeout=2,
|
||||
) as resp:
|
||||
return (await resp.json())["notice"]
|
||||
except BaseException:
|
||||
return ""
|
||||
|
||||
async def _build_reserved_command_lines(self) -> list[str]:
|
||||
"""
|
||||
使用实时指令配置生成内置指令清单,确保重命名/禁用后与实际生效状态保持一致。
|
||||
"""
|
||||
try:
|
||||
commands = await command_management.list_commands()
|
||||
except BaseException:
|
||||
return []
|
||||
|
||||
lines: list[str] = []
|
||||
hidden_commands = {"set", "unset", "websearch"}
|
||||
|
||||
def walk(items: list[dict], indent: int = 0):
|
||||
for item in items:
|
||||
if not item.get("reserved") or not item.get("enabled"):
|
||||
continue
|
||||
# 仅展示顶级指令或指令组
|
||||
if item.get("type") == "sub_command":
|
||||
continue
|
||||
if item.get("parent_signature"):
|
||||
continue
|
||||
|
||||
effective = (
|
||||
item.get("effective_command")
|
||||
or item.get("original_command")
|
||||
or item.get("handler_name")
|
||||
)
|
||||
if not effective:
|
||||
continue
|
||||
if effective in hidden_commands:
|
||||
continue
|
||||
|
||||
description = item.get("description") or ""
|
||||
desc_text = f" - {description}" if description else ""
|
||||
indent_prefix = " " * indent
|
||||
lines.append(f"{indent_prefix}/{effective}{desc_text}")
|
||||
|
||||
walk(commands)
|
||||
return lines
|
||||
|
||||
async def help(self, event: AstrMessageEvent):
|
||||
"""查看帮助"""
|
||||
notice = ""
|
||||
try:
|
||||
notice = await self._query_astrbot_notice()
|
||||
except BaseException:
|
||||
pass
|
||||
|
||||
dashboard_version = await get_dashboard_version()
|
||||
command_lines = await self._build_reserved_command_lines()
|
||||
commands_section = (
|
||||
"\n".join(command_lines) if command_lines else "暂无启用的内置指令"
|
||||
)
|
||||
|
||||
msg_parts = [
|
||||
f"AstrBot v{VERSION}(WebUI: {dashboard_version})",
|
||||
"内置指令:",
|
||||
commands_section,
|
||||
]
|
||||
if notice:
|
||||
msg_parts.append(notice)
|
||||
msg = "\n".join(msg_parts)
|
||||
|
||||
event.set_result(MessageEventResult().message(msg).use_t2i(False))
|
||||
@@ -1,4 +0,0 @@
|
||||
name: builtin_commands
|
||||
desc: AstrBot 自带指令,提供常用的对话管理、工具使用、插件管理等功能。
|
||||
author: Soulter
|
||||
version: 0.0.1
|
||||
@@ -1 +1 @@
|
||||
__version__ = "4.11.2"
|
||||
__version__ = "3.5.23"
|
||||
|
||||
@@ -1,243 +0,0 @@
|
||||
from typing import TYPE_CHECKING, Protocol, runtime_checkable
|
||||
|
||||
from ..message import Message
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot import logger
|
||||
else:
|
||||
try:
|
||||
from astrbot import logger
|
||||
except ImportError:
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.provider.provider import Provider
|
||||
|
||||
from ..context.truncator import ContextTruncator
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ContextCompressor(Protocol):
|
||||
"""
|
||||
Protocol for context compressors.
|
||||
Provides an interface for compressing message lists.
|
||||
"""
|
||||
|
||||
def should_compress(
|
||||
self, messages: list[Message], current_tokens: int, max_tokens: int
|
||||
) -> bool:
|
||||
"""Check if compression is needed.
|
||||
|
||||
Args:
|
||||
messages: The message list to evaluate.
|
||||
current_tokens: The current token count.
|
||||
max_tokens: The maximum allowed tokens for the model.
|
||||
|
||||
Returns:
|
||||
True if compression is needed, False otherwise.
|
||||
"""
|
||||
...
|
||||
|
||||
async def __call__(self, messages: list[Message]) -> list[Message]:
|
||||
"""Compress the message list.
|
||||
|
||||
Args:
|
||||
messages: The original message list.
|
||||
|
||||
Returns:
|
||||
The compressed message list.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class TruncateByTurnsCompressor:
|
||||
"""Truncate by turns compressor implementation.
|
||||
Truncates the message list by removing older turns.
|
||||
"""
|
||||
|
||||
def __init__(self, truncate_turns: int = 1, compression_threshold: float = 0.82):
|
||||
"""Initialize the truncate by turns compressor.
|
||||
|
||||
Args:
|
||||
truncate_turns: The number of turns to remove when truncating (default: 1).
|
||||
compression_threshold: The compression trigger threshold (default: 0.82).
|
||||
"""
|
||||
self.truncate_turns = truncate_turns
|
||||
self.compression_threshold = compression_threshold
|
||||
|
||||
def should_compress(
|
||||
self, messages: list[Message], current_tokens: int, max_tokens: int
|
||||
) -> bool:
|
||||
"""Check if compression is needed.
|
||||
|
||||
Args:
|
||||
messages: The message list to evaluate.
|
||||
current_tokens: The current token count.
|
||||
max_tokens: The maximum allowed tokens.
|
||||
|
||||
Returns:
|
||||
True if compression is needed, False otherwise.
|
||||
"""
|
||||
if max_tokens <= 0 or current_tokens <= 0:
|
||||
return False
|
||||
usage_rate = current_tokens / max_tokens
|
||||
return usage_rate > self.compression_threshold
|
||||
|
||||
async def __call__(self, messages: list[Message]) -> list[Message]:
|
||||
truncator = ContextTruncator()
|
||||
truncated_messages = truncator.truncate_by_dropping_oldest_turns(
|
||||
messages,
|
||||
drop_turns=self.truncate_turns,
|
||||
)
|
||||
return truncated_messages
|
||||
|
||||
|
||||
def split_history(
|
||||
messages: list[Message], keep_recent: int
|
||||
) -> tuple[list[Message], list[Message], list[Message]]:
|
||||
"""Split the message list into system messages, messages to summarize, and recent messages.
|
||||
|
||||
Ensures that the split point is between complete user-assistant pairs to maintain conversation flow.
|
||||
|
||||
Args:
|
||||
messages: The original message list.
|
||||
keep_recent: The number of latest messages to keep.
|
||||
|
||||
Returns:
|
||||
tuple: (system_messages, messages_to_summarize, recent_messages)
|
||||
"""
|
||||
# keep the system messages
|
||||
first_non_system = 0
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.role != "system":
|
||||
first_non_system = i
|
||||
break
|
||||
|
||||
system_messages = messages[:first_non_system]
|
||||
non_system_messages = messages[first_non_system:]
|
||||
|
||||
if len(non_system_messages) <= keep_recent:
|
||||
return system_messages, [], non_system_messages
|
||||
|
||||
# Find the split point, ensuring recent_messages starts with a user message
|
||||
# This maintains complete conversation turns
|
||||
split_index = len(non_system_messages) - keep_recent
|
||||
|
||||
# Search backward from split_index to find the first user message
|
||||
# This ensures recent_messages starts with a user message (complete turn)
|
||||
while split_index > 0 and non_system_messages[split_index].role != "user":
|
||||
# TODO: +=1 or -=1 ? calculate by tokens
|
||||
split_index -= 1
|
||||
|
||||
# If we couldn't find a user message, keep all messages as recent
|
||||
if split_index == 0:
|
||||
return system_messages, [], non_system_messages
|
||||
|
||||
messages_to_summarize = non_system_messages[:split_index]
|
||||
recent_messages = non_system_messages[split_index:]
|
||||
|
||||
return system_messages, messages_to_summarize, recent_messages
|
||||
|
||||
|
||||
class LLMSummaryCompressor:
|
||||
"""LLM-based summary compressor.
|
||||
Uses LLM to summarize the old conversation history, keeping the latest messages.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider: "Provider",
|
||||
keep_recent: int = 4,
|
||||
instruction_text: str | None = None,
|
||||
compression_threshold: float = 0.82,
|
||||
):
|
||||
"""Initialize the LLM summary compressor.
|
||||
|
||||
Args:
|
||||
provider: The LLM provider instance.
|
||||
keep_recent: The number of latest messages to keep (default: 4).
|
||||
instruction_text: Custom instruction for summary generation.
|
||||
compression_threshold: The compression trigger threshold (default: 0.82).
|
||||
"""
|
||||
self.provider = provider
|
||||
self.keep_recent = keep_recent
|
||||
self.compression_threshold = compression_threshold
|
||||
|
||||
self.instruction_text = instruction_text or (
|
||||
"Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n"
|
||||
"1. Systematically cover all core topics discussed and the final conclusion/outcome for each; clearly highlight the latest primary focus.\n"
|
||||
"2. If any tools were used, summarize tool usage (total call count) and extract the most valuable insights from tool outputs.\n"
|
||||
"3. If there was an initial user goal, state it first and describe the current progress/status.\n"
|
||||
"4. Write the summary in the user's language.\n"
|
||||
)
|
||||
|
||||
def should_compress(
|
||||
self, messages: list[Message], current_tokens: int, max_tokens: int
|
||||
) -> bool:
|
||||
"""Check if compression is needed.
|
||||
|
||||
Args:
|
||||
messages: The message list to evaluate.
|
||||
current_tokens: The current token count.
|
||||
max_tokens: The maximum allowed tokens.
|
||||
|
||||
Returns:
|
||||
True if compression is needed, False otherwise.
|
||||
"""
|
||||
if max_tokens <= 0 or current_tokens <= 0:
|
||||
return False
|
||||
usage_rate = current_tokens / max_tokens
|
||||
return usage_rate > self.compression_threshold
|
||||
|
||||
async def __call__(self, messages: list[Message]) -> list[Message]:
|
||||
"""Use LLM to generate a summary of the conversation history.
|
||||
|
||||
Process:
|
||||
1. Divide messages: keep the system message and the latest N messages.
|
||||
2. Send the old messages + the instruction message to the LLM.
|
||||
3. Reconstruct the message list: [system message, summary message, latest messages].
|
||||
"""
|
||||
if len(messages) <= self.keep_recent + 1:
|
||||
return messages
|
||||
|
||||
system_messages, messages_to_summarize, recent_messages = split_history(
|
||||
messages, self.keep_recent
|
||||
)
|
||||
|
||||
if not messages_to_summarize:
|
||||
return messages
|
||||
|
||||
# build payload
|
||||
instruction_message = Message(role="user", content=self.instruction_text)
|
||||
llm_payload = messages_to_summarize + [instruction_message]
|
||||
|
||||
# generate summary
|
||||
try:
|
||||
response = await self.provider.text_chat(contexts=llm_payload)
|
||||
summary_content = response.completion_text
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate summary: {e}")
|
||||
return messages
|
||||
|
||||
# build result
|
||||
result = []
|
||||
result.extend(system_messages)
|
||||
|
||||
result.append(
|
||||
Message(
|
||||
role="user",
|
||||
content=f"Our previous history conversation summary: {summary_content}",
|
||||
)
|
||||
)
|
||||
result.append(
|
||||
Message(
|
||||
role="assistant",
|
||||
content="Acknowledged the summary of our previous conversation history.",
|
||||
)
|
||||
)
|
||||
|
||||
result.extend(recent_messages)
|
||||
|
||||
return result
|
||||
@@ -1,35 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .compressor import ContextCompressor
|
||||
from .token_counter import TokenCounter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.provider.provider import Provider
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextConfig:
|
||||
"""Context configuration class."""
|
||||
|
||||
max_context_tokens: int = 0
|
||||
"""Maximum number of context tokens. <= 0 means no limit."""
|
||||
enforce_max_turns: int = -1 # -1 means no limit
|
||||
"""Maximum number of conversation turns to keep. -1 means no limit. Executed before compression."""
|
||||
truncate_turns: int = 1
|
||||
"""Number of conversation turns to discard at once when truncation is triggered.
|
||||
Two processes will use this value:
|
||||
|
||||
1. Enforce max turns truncation.
|
||||
2. Truncation by turns compression strategy.
|
||||
"""
|
||||
llm_compress_instruction: str | None = None
|
||||
"""Instruction prompt for LLM-based compression."""
|
||||
llm_compress_keep_recent: int = 0
|
||||
"""Number of recent messages to keep during LLM-based compression."""
|
||||
llm_compress_provider: "Provider | None" = None
|
||||
"""LLM provider used for compression tasks. If None, truncation strategy is used."""
|
||||
custom_token_counter: TokenCounter | None = None
|
||||
"""Custom token counting method. If None, the default method is used."""
|
||||
custom_compressor: ContextCompressor | None = None
|
||||
"""Custom context compression method. If None, the default method is used."""
|
||||
@@ -1,120 +0,0 @@
|
||||
from astrbot import logger
|
||||
|
||||
from ..message import Message
|
||||
from .compressor import LLMSummaryCompressor, TruncateByTurnsCompressor
|
||||
from .config import ContextConfig
|
||||
from .token_counter import EstimateTokenCounter
|
||||
from .truncator import ContextTruncator
|
||||
|
||||
|
||||
class ContextManager:
|
||||
"""Context compression manager."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ContextConfig,
|
||||
):
|
||||
"""Initialize the context manager.
|
||||
|
||||
There are two strategies to handle context limit reached:
|
||||
1. Truncate by turns: remove older messages by turns.
|
||||
2. LLM-based compression: use LLM to summarize old messages.
|
||||
|
||||
Args:
|
||||
config: The context configuration.
|
||||
"""
|
||||
self.config = config
|
||||
|
||||
self.token_counter = config.custom_token_counter or EstimateTokenCounter()
|
||||
self.truncator = ContextTruncator()
|
||||
|
||||
if config.custom_compressor:
|
||||
self.compressor = config.custom_compressor
|
||||
elif config.llm_compress_provider:
|
||||
self.compressor = LLMSummaryCompressor(
|
||||
provider=config.llm_compress_provider,
|
||||
keep_recent=config.llm_compress_keep_recent,
|
||||
instruction_text=config.llm_compress_instruction,
|
||||
)
|
||||
else:
|
||||
self.compressor = TruncateByTurnsCompressor(
|
||||
truncate_turns=config.truncate_turns
|
||||
)
|
||||
|
||||
async def process(
|
||||
self, messages: list[Message], trusted_token_usage: int = 0
|
||||
) -> list[Message]:
|
||||
"""Process the messages.
|
||||
|
||||
Args:
|
||||
messages: The original message list.
|
||||
|
||||
Returns:
|
||||
The processed message list.
|
||||
"""
|
||||
try:
|
||||
result = messages
|
||||
|
||||
# 1. 基于轮次的截断 (Enforce max turns)
|
||||
if self.config.enforce_max_turns != -1:
|
||||
result = self.truncator.truncate_by_turns(
|
||||
result,
|
||||
keep_most_recent_turns=self.config.enforce_max_turns,
|
||||
drop_turns=self.config.truncate_turns,
|
||||
)
|
||||
|
||||
# 2. 基于 token 的压缩
|
||||
if self.config.max_context_tokens > 0:
|
||||
total_tokens = self.token_counter.count_tokens(
|
||||
result, trusted_token_usage
|
||||
)
|
||||
|
||||
if self.compressor.should_compress(
|
||||
result, total_tokens, self.config.max_context_tokens
|
||||
):
|
||||
result = await self._run_compression(result, total_tokens)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error during context processing: {e}", exc_info=True)
|
||||
return messages
|
||||
|
||||
async def _run_compression(
|
||||
self, messages: list[Message], prev_tokens: int
|
||||
) -> list[Message]:
|
||||
"""
|
||||
Compress/truncate the messages.
|
||||
|
||||
Args:
|
||||
messages: The original message list.
|
||||
prev_tokens: The token count before compression.
|
||||
|
||||
Returns:
|
||||
The compressed/truncated message list.
|
||||
"""
|
||||
logger.debug("Compress triggered, starting compression...")
|
||||
|
||||
messages = await self.compressor(messages)
|
||||
|
||||
# double check
|
||||
tokens_after_summary = self.token_counter.count_tokens(messages)
|
||||
|
||||
# calculate compress rate
|
||||
compress_rate = (tokens_after_summary / self.config.max_context_tokens) * 100
|
||||
logger.info(
|
||||
f"Compress completed."
|
||||
f" {prev_tokens} -> {tokens_after_summary} tokens,"
|
||||
f" compression rate: {compress_rate:.2f}%.",
|
||||
)
|
||||
|
||||
# last check
|
||||
if self.compressor.should_compress(
|
||||
messages, tokens_after_summary, self.config.max_context_tokens
|
||||
):
|
||||
logger.info(
|
||||
"Context still exceeds max tokens after compression, applying halving truncation..."
|
||||
)
|
||||
# still need compress, truncate by half
|
||||
messages = self.truncator.truncate_by_halving(messages)
|
||||
|
||||
return messages
|
||||
@@ -1,64 +0,0 @@
|
||||
import json
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from ..message import Message, TextPart
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class TokenCounter(Protocol):
|
||||
"""
|
||||
Protocol for token counters.
|
||||
Provides an interface for counting tokens in message lists.
|
||||
"""
|
||||
|
||||
def count_tokens(
|
||||
self, messages: list[Message], trusted_token_usage: int = 0
|
||||
) -> int:
|
||||
"""Count the total tokens in the message list.
|
||||
|
||||
Args:
|
||||
messages: The message list.
|
||||
trusted_token_usage: The total token usage that LLM API returned.
|
||||
For some cases, this value is more accurate.
|
||||
But some API does not return it, so the value defaults to 0.
|
||||
|
||||
Returns:
|
||||
The total token count.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class EstimateTokenCounter:
|
||||
"""Estimate token counter implementation.
|
||||
Provides a simple estimation of token count based on character types.
|
||||
"""
|
||||
|
||||
def count_tokens(
|
||||
self, messages: list[Message], trusted_token_usage: int = 0
|
||||
) -> int:
|
||||
if trusted_token_usage > 0:
|
||||
return trusted_token_usage
|
||||
|
||||
total = 0
|
||||
for msg in messages:
|
||||
content = msg.content
|
||||
if isinstance(content, str):
|
||||
total += self._estimate_tokens(content)
|
||||
elif isinstance(content, list):
|
||||
# 处理多模态内容
|
||||
for part in content:
|
||||
if isinstance(part, TextPart):
|
||||
total += self._estimate_tokens(part.text)
|
||||
|
||||
# 处理 Tool Calls
|
||||
if msg.tool_calls:
|
||||
for tc in msg.tool_calls:
|
||||
tc_str = json.dumps(tc if isinstance(tc, dict) else tc.model_dump())
|
||||
total += self._estimate_tokens(tc_str)
|
||||
|
||||
return total
|
||||
|
||||
def _estimate_tokens(self, text: str) -> int:
|
||||
chinese_count = len([c for c in text if "\u4e00" <= c <= "\u9fff"])
|
||||
other_count = len(text) - chinese_count
|
||||
return int(chinese_count * 0.6 + other_count * 0.3)
|
||||
@@ -1,141 +0,0 @@
|
||||
from ..message import Message
|
||||
|
||||
|
||||
class ContextTruncator:
|
||||
"""Context truncator."""
|
||||
|
||||
def fix_messages(self, messages: list[Message]) -> list[Message]:
|
||||
fixed_messages = []
|
||||
for message in messages:
|
||||
if message.role == "tool":
|
||||
# tool block 前面必须要有 user 和 assistant block
|
||||
if len(fixed_messages) < 2:
|
||||
# 这种情况可能是上下文被截断导致的
|
||||
# 我们直接将之前的上下文都清空
|
||||
fixed_messages = []
|
||||
else:
|
||||
fixed_messages.append(message)
|
||||
else:
|
||||
fixed_messages.append(message)
|
||||
return fixed_messages
|
||||
|
||||
def truncate_by_turns(
|
||||
self,
|
||||
messages: list[Message],
|
||||
keep_most_recent_turns: int,
|
||||
drop_turns: int = 1,
|
||||
) -> list[Message]:
|
||||
"""截断上下文列表,确保不超过最大长度。
|
||||
一个 turn 包含一个 user 消息和一个 assistant 消息。
|
||||
这个方法会保证截断后的上下文列表符合 OpenAI 的上下文格式。
|
||||
|
||||
Args:
|
||||
messages: 上下文列表
|
||||
keep_most_recent_turns: 保留最近的对话轮数
|
||||
drop_turns: 一次性丢弃的对话轮数
|
||||
|
||||
Returns:
|
||||
截断后的上下文列表
|
||||
"""
|
||||
if keep_most_recent_turns == -1:
|
||||
return messages
|
||||
|
||||
first_non_system = 0
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.role != "system":
|
||||
first_non_system = i
|
||||
break
|
||||
|
||||
system_messages = messages[:first_non_system]
|
||||
non_system_messages = messages[first_non_system:]
|
||||
|
||||
if len(non_system_messages) // 2 <= keep_most_recent_turns:
|
||||
return messages
|
||||
|
||||
num_to_keep = keep_most_recent_turns - drop_turns + 1
|
||||
if num_to_keep <= 0:
|
||||
truncated_contexts = []
|
||||
else:
|
||||
truncated_contexts = non_system_messages[-num_to_keep * 2 :]
|
||||
|
||||
# 找到第一个 role 为 user 的索引,确保上下文格式正确
|
||||
index = next(
|
||||
(i for i, item in enumerate(truncated_contexts) if item.role == "user"),
|
||||
None,
|
||||
)
|
||||
if index is not None and index > 0:
|
||||
truncated_contexts = truncated_contexts[index:]
|
||||
|
||||
result = system_messages + truncated_contexts
|
||||
|
||||
return self.fix_messages(result)
|
||||
|
||||
def truncate_by_dropping_oldest_turns(
|
||||
self,
|
||||
messages: list[Message],
|
||||
drop_turns: int = 1,
|
||||
) -> list[Message]:
|
||||
"""丢弃最旧的 N 个对话轮次。"""
|
||||
if drop_turns <= 0:
|
||||
return messages
|
||||
|
||||
first_non_system = 0
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.role != "system":
|
||||
first_non_system = i
|
||||
break
|
||||
|
||||
system_messages = messages[:first_non_system]
|
||||
non_system_messages = messages[first_non_system:]
|
||||
|
||||
if len(non_system_messages) // 2 <= drop_turns:
|
||||
truncated_non_system = []
|
||||
else:
|
||||
truncated_non_system = non_system_messages[drop_turns * 2 :]
|
||||
|
||||
index = next(
|
||||
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
|
||||
None,
|
||||
)
|
||||
if index is not None:
|
||||
truncated_non_system = truncated_non_system[index:]
|
||||
elif truncated_non_system:
|
||||
truncated_non_system = []
|
||||
|
||||
result = system_messages + truncated_non_system
|
||||
|
||||
return self.fix_messages(result)
|
||||
|
||||
def truncate_by_halving(
|
||||
self,
|
||||
messages: list[Message],
|
||||
) -> list[Message]:
|
||||
"""对半砍策略,删除 50% 的消息"""
|
||||
if len(messages) <= 2:
|
||||
return messages
|
||||
|
||||
first_non_system = 0
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.role != "system":
|
||||
first_non_system = i
|
||||
break
|
||||
|
||||
system_messages = messages[:first_non_system]
|
||||
non_system_messages = messages[first_non_system:]
|
||||
|
||||
messages_to_delete = len(non_system_messages) // 2
|
||||
if messages_to_delete == 0:
|
||||
return messages
|
||||
|
||||
truncated_non_system = non_system_messages[messages_to_delete:]
|
||||
|
||||
index = next(
|
||||
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
|
||||
None,
|
||||
)
|
||||
if index is not None:
|
||||
truncated_non_system = truncated_non_system[index:]
|
||||
|
||||
result = system_messages + truncated_non_system
|
||||
|
||||
return self.fix_messages(result)
|
||||
@@ -345,6 +345,9 @@ class MCPClient:
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up resources including old exit stacks from reconnections"""
|
||||
# Set running_event first to unblock any waiting tasks
|
||||
self.running_event.set()
|
||||
|
||||
# Close current exit stack
|
||||
try:
|
||||
await self.exit_stack.aclose()
|
||||
@@ -356,9 +359,6 @@ class MCPClient:
|
||||
# Just clear the list to release references
|
||||
self._old_exit_stacks.clear()
|
||||
|
||||
# Set running_event first to unblock any waiting tasks
|
||||
self.running_event.set()
|
||||
|
||||
|
||||
class MCPTool(FunctionTool, Generic[TContext]):
|
||||
"""A function tool that calls an MCP service."""
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from typing import Any, ClassVar, Literal, cast
|
||||
|
||||
from pydantic import BaseModel, GetCoreSchemaHandler, model_serializer, model_validator
|
||||
from pydantic import BaseModel, GetCoreSchemaHandler
|
||||
from pydantic_core import core_schema
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ class ContentPart(BaseModel):
|
||||
|
||||
__content_part_registry: ClassVar[dict[str, type["ContentPart"]]] = {}
|
||||
|
||||
type: Literal["text", "think", "image_url", "audio_url"]
|
||||
type: str
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
@@ -63,28 +63,6 @@ class TextPart(ContentPart):
|
||||
text: str
|
||||
|
||||
|
||||
class ThinkPart(ContentPart):
|
||||
"""
|
||||
>>> ThinkPart(think="I think I need to think about this.").model_dump()
|
||||
{'type': 'think', 'think': 'I think I need to think about this.', 'encrypted': None}
|
||||
"""
|
||||
|
||||
type: str = "think"
|
||||
think: str
|
||||
encrypted: str | None = None
|
||||
"""Encrypted thinking content, or signature."""
|
||||
|
||||
def merge_in_place(self, other: Any) -> bool:
|
||||
if not isinstance(other, ThinkPart):
|
||||
return False
|
||||
if self.encrypted:
|
||||
return False
|
||||
self.think += other.think
|
||||
if other.encrypted:
|
||||
self.encrypted = other.encrypted
|
||||
return True
|
||||
|
||||
|
||||
class ImageURLPart(ContentPart):
|
||||
"""
|
||||
>>> ImageURLPart(image_url="http://example.com/image.jpg").model_dump()
|
||||
@@ -144,12 +122,10 @@ class ToolCall(BaseModel):
|
||||
extra_content: dict[str, Any] | None = None
|
||||
"""Extra metadata for the tool call."""
|
||||
|
||||
@model_serializer(mode="wrap")
|
||||
def serialize(self, handler):
|
||||
data = handler(self)
|
||||
def model_dump(self, **kwargs: Any) -> dict[str, Any]:
|
||||
if self.extra_content is None:
|
||||
data.pop("extra_content", None)
|
||||
return data
|
||||
kwargs.setdefault("exclude", set()).add("extra_content")
|
||||
return super().model_dump(**kwargs)
|
||||
|
||||
|
||||
class ToolCallPart(BaseModel):
|
||||
@@ -169,48 +145,22 @@ class Message(BaseModel):
|
||||
"tool",
|
||||
]
|
||||
|
||||
content: str | list[ContentPart] | None = None
|
||||
content: str | list[ContentPart]
|
||||
"""The content of the message."""
|
||||
|
||||
tool_calls: list[ToolCall] | list[dict] | None = None
|
||||
"""The tool calls of the message."""
|
||||
|
||||
tool_call_id: str | None = None
|
||||
"""The ID of the tool call."""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_content_required(self):
|
||||
# assistant + tool_calls is not None: allow content to be None
|
||||
if self.role == "assistant" and self.tool_calls is not None:
|
||||
return self
|
||||
|
||||
# other all cases: content is required
|
||||
if self.content is None:
|
||||
raise ValueError(
|
||||
"content is required unless role='assistant' and tool_calls is not None"
|
||||
)
|
||||
return self
|
||||
|
||||
@model_serializer(mode="wrap")
|
||||
def serialize(self, handler):
|
||||
data = handler(self)
|
||||
if self.tool_calls is None:
|
||||
data.pop("tool_calls", None)
|
||||
if self.tool_call_id is None:
|
||||
data.pop("tool_call_id", None)
|
||||
return data
|
||||
|
||||
|
||||
class AssistantMessageSegment(Message):
|
||||
"""A message segment from the assistant."""
|
||||
|
||||
role: Literal["assistant"] = "assistant"
|
||||
tool_calls: list[ToolCall] | list[dict] | None = None
|
||||
|
||||
|
||||
class ToolCallMessageSegment(Message):
|
||||
"""A message segment representing a tool call."""
|
||||
|
||||
role: Literal["tool"] = "tool"
|
||||
tool_call_id: str
|
||||
|
||||
|
||||
class UserMessageSegment(Message):
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import typing as T
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import TokenUsage
|
||||
|
||||
|
||||
class AgentResponseData(T.TypedDict):
|
||||
@@ -13,23 +12,3 @@ class AgentResponseData(T.TypedDict):
|
||||
class AgentResponse:
|
||||
type: str
|
||||
data: AgentResponseData
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentStats:
|
||||
token_usage: TokenUsage = field(default_factory=TokenUsage)
|
||||
start_time: float = 0.0
|
||||
end_time: float = 0.0
|
||||
time_to_first_token: float = 0.0
|
||||
|
||||
@property
|
||||
def duration(self) -> float:
|
||||
return self.end_time - self.start_time
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"token_usage": self.token_usage.__dict__,
|
||||
"start_time": self.start_time,
|
||||
"end_time": self.end_time,
|
||||
"time_to_first_token": self.time_to_first_token,
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ from .message import Message
|
||||
TContext = TypeVar("TContext", default=Any)
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(config={"arbitrary_types_allowed": True})
|
||||
class ContextWrapper(Generic[TContext]):
|
||||
"""A context for running an agent, which can be used to pass additional data or state."""
|
||||
|
||||
|
||||
@@ -2,12 +2,13 @@ import abc
|
||||
import typing as T
|
||||
from enum import Enum, auto
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
|
||||
from ..hooks import BaseAgentRunHooks
|
||||
from ..response import AgentResponse
|
||||
from ..run_context import ContextWrapper, TContext
|
||||
from ..tool_executor import BaseFunctionToolExecutor
|
||||
|
||||
|
||||
class AgentState(Enum):
|
||||
@@ -23,7 +24,9 @@ class BaseAgentRunner(T.Generic[TContext]):
|
||||
@abc.abstractmethod
|
||||
async def reset(
|
||||
self,
|
||||
provider: Provider,
|
||||
run_context: ContextWrapper[TContext],
|
||||
tool_executor: BaseFunctionToolExecutor[TContext],
|
||||
agent_hooks: BaseAgentRunHooks[TContext],
|
||||
**kwargs: T.Any,
|
||||
) -> None:
|
||||
@@ -57,9 +60,3 @@ class BaseAgentRunner(T.Generic[TContext]):
|
||||
This method should be called after the agent is done.
|
||||
"""
|
||||
...
|
||||
|
||||
def _transition_state(self, new_state: AgentState) -> None:
|
||||
"""Transition the agent state."""
|
||||
if self._state != new_state:
|
||||
logger.debug(f"Agent state transition: {self._state} -> {new_state}")
|
||||
self._state = new_state
|
||||
|
||||
@@ -1,367 +0,0 @@
|
||||
import base64
|
||||
import json
|
||||
import sys
|
||||
import typing as T
|
||||
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot import logger
|
||||
from astrbot.core import sp
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import (
|
||||
LLMResponse,
|
||||
ProviderRequest,
|
||||
)
|
||||
|
||||
from ...hooks import BaseAgentRunHooks
|
||||
from ...response import AgentResponseData
|
||||
from ...run_context import ContextWrapper, TContext
|
||||
from ..base import AgentResponse, AgentState, BaseAgentRunner
|
||||
from .coze_api_client import CozeAPIClient
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class CozeAgentRunner(BaseAgentRunner[TContext]):
|
||||
"""Coze Agent Runner"""
|
||||
|
||||
@override
|
||||
async def reset(
|
||||
self,
|
||||
request: ProviderRequest,
|
||||
run_context: ContextWrapper[TContext],
|
||||
agent_hooks: BaseAgentRunHooks[TContext],
|
||||
provider_config: dict,
|
||||
**kwargs: T.Any,
|
||||
) -> None:
|
||||
self.req = request
|
||||
self.streaming = kwargs.get("streaming", False)
|
||||
self.final_llm_resp = None
|
||||
self._state = AgentState.IDLE
|
||||
self.agent_hooks = agent_hooks
|
||||
self.run_context = run_context
|
||||
|
||||
self.api_key = provider_config.get("coze_api_key", "")
|
||||
if not self.api_key:
|
||||
raise Exception("Coze API Key 不能为空。")
|
||||
self.bot_id = provider_config.get("bot_id", "")
|
||||
if not self.bot_id:
|
||||
raise Exception("Coze Bot ID 不能为空。")
|
||||
self.api_base: str = provider_config.get("coze_api_base", "https://api.coze.cn")
|
||||
|
||||
if not isinstance(self.api_base, str) or not self.api_base.startswith(
|
||||
("http://", "https://"),
|
||||
):
|
||||
raise Exception(
|
||||
"Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头。",
|
||||
)
|
||||
|
||||
self.timeout = provider_config.get("timeout", 120)
|
||||
if isinstance(self.timeout, str):
|
||||
self.timeout = int(self.timeout)
|
||||
self.auto_save_history = provider_config.get("auto_save_history", True)
|
||||
|
||||
# 创建 API 客户端
|
||||
self.api_client = CozeAPIClient(api_key=self.api_key, api_base=self.api_base)
|
||||
|
||||
# 会话相关缓存
|
||||
self.file_id_cache: dict[str, dict[str, str]] = {}
|
||||
|
||||
@override
|
||||
async def step(self):
|
||||
"""
|
||||
执行 Coze Agent 的一个步骤
|
||||
"""
|
||||
if not self.req:
|
||||
raise ValueError("Request is not set. Please call reset() first.")
|
||||
|
||||
if self._state == AgentState.IDLE:
|
||||
try:
|
||||
await self.agent_hooks.on_agent_begin(self.run_context)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True)
|
||||
|
||||
# 开始处理,转换到运行状态
|
||||
self._transition_state(AgentState.RUNNING)
|
||||
|
||||
try:
|
||||
# 执行 Coze 请求并处理结果
|
||||
async for response in self._execute_coze_request():
|
||||
yield response
|
||||
except Exception as e:
|
||||
logger.error(f"Coze 请求失败:{str(e)}")
|
||||
self._transition_state(AgentState.ERROR)
|
||||
self.final_llm_resp = LLMResponse(
|
||||
role="err", completion_text=f"Coze 请求失败:{str(e)}"
|
||||
)
|
||||
yield AgentResponse(
|
||||
type="err",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(f"Coze 请求失败:{str(e)}")
|
||||
),
|
||||
)
|
||||
finally:
|
||||
await self.api_client.close()
|
||||
|
||||
@override
|
||||
async def step_until_done(
|
||||
self, max_step: int = 30
|
||||
) -> T.AsyncGenerator[AgentResponse, None]:
|
||||
while not self.done():
|
||||
async for resp in self.step():
|
||||
yield resp
|
||||
|
||||
async def _execute_coze_request(self):
|
||||
"""执行 Coze 请求的核心逻辑"""
|
||||
prompt = self.req.prompt or ""
|
||||
session_id = self.req.session_id or "unknown"
|
||||
image_urls = self.req.image_urls or []
|
||||
contexts = self.req.contexts or []
|
||||
system_prompt = self.req.system_prompt
|
||||
|
||||
# 用户ID参数
|
||||
user_id = session_id
|
||||
|
||||
# 获取或创建会话ID
|
||||
conversation_id = await sp.get_async(
|
||||
scope="umo",
|
||||
scope_id=user_id,
|
||||
key="coze_conversation_id",
|
||||
default="",
|
||||
)
|
||||
|
||||
# 构建消息
|
||||
additional_messages = []
|
||||
|
||||
if system_prompt:
|
||||
if not self.auto_save_history or not conversation_id:
|
||||
additional_messages.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
"content_type": "text",
|
||||
},
|
||||
)
|
||||
|
||||
# 处理历史上下文
|
||||
if not self.auto_save_history and contexts:
|
||||
for ctx in contexts:
|
||||
if isinstance(ctx, dict) and "role" in ctx and "content" in ctx:
|
||||
# 处理上下文中的图片
|
||||
content = ctx["content"]
|
||||
if isinstance(content, list):
|
||||
# 多模态内容,需要处理图片
|
||||
processed_content = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
if item.get("type") == "text":
|
||||
processed_content.append(item)
|
||||
elif item.get("type") == "image_url":
|
||||
# 处理图片上传
|
||||
try:
|
||||
image_data = item.get("image_url", {})
|
||||
url = image_data.get("url", "")
|
||||
if url:
|
||||
file_id = (
|
||||
await self._download_and_upload_image(
|
||||
url, session_id
|
||||
)
|
||||
)
|
||||
processed_content.append(
|
||||
{
|
||||
"type": "file",
|
||||
"file_id": file_id,
|
||||
"file_url": url,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"处理上下文图片失败: {e}")
|
||||
continue
|
||||
|
||||
if processed_content:
|
||||
additional_messages.append(
|
||||
{
|
||||
"role": ctx["role"],
|
||||
"content": processed_content,
|
||||
"content_type": "object_string",
|
||||
}
|
||||
)
|
||||
else:
|
||||
# 纯文本内容
|
||||
additional_messages.append(
|
||||
{
|
||||
"role": ctx["role"],
|
||||
"content": content,
|
||||
"content_type": "text",
|
||||
}
|
||||
)
|
||||
|
||||
# 构建当前消息
|
||||
if prompt or image_urls:
|
||||
if image_urls:
|
||||
# 多模态
|
||||
object_string_content = []
|
||||
if prompt:
|
||||
object_string_content.append({"type": "text", "text": prompt})
|
||||
|
||||
for url in image_urls:
|
||||
# the url is a base64 string
|
||||
try:
|
||||
image_data = base64.b64decode(url)
|
||||
file_id = await self.api_client.upload_file(image_data)
|
||||
object_string_content.append(
|
||||
{
|
||||
"type": "image",
|
||||
"file_id": file_id,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"处理图片失败 {url}: {e}")
|
||||
continue
|
||||
|
||||
if object_string_content:
|
||||
content = json.dumps(object_string_content, ensure_ascii=False)
|
||||
additional_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": content,
|
||||
"content_type": "object_string",
|
||||
}
|
||||
)
|
||||
elif prompt:
|
||||
# 纯文本
|
||||
additional_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
"content_type": "text",
|
||||
},
|
||||
)
|
||||
|
||||
# 执行 Coze API 请求
|
||||
accumulated_content = ""
|
||||
message_started = False
|
||||
|
||||
async for chunk in self.api_client.chat_messages(
|
||||
bot_id=self.bot_id,
|
||||
user_id=user_id,
|
||||
additional_messages=additional_messages,
|
||||
conversation_id=conversation_id,
|
||||
auto_save_history=self.auto_save_history,
|
||||
stream=True,
|
||||
timeout=self.timeout,
|
||||
):
|
||||
event_type = chunk.get("event")
|
||||
data = chunk.get("data", {})
|
||||
|
||||
if event_type == "conversation.chat.created":
|
||||
if isinstance(data, dict) and "conversation_id" in data:
|
||||
await sp.put_async(
|
||||
scope="umo",
|
||||
scope_id=user_id,
|
||||
key="coze_conversation_id",
|
||||
value=data["conversation_id"],
|
||||
)
|
||||
|
||||
if event_type == "conversation.message.delta":
|
||||
# 增量消息
|
||||
content = data.get("content", "")
|
||||
if not content and "delta" in data:
|
||||
content = data["delta"].get("content", "")
|
||||
if not content and "text" in data:
|
||||
content = data.get("text", "")
|
||||
|
||||
if content:
|
||||
accumulated_content += content
|
||||
message_started = True
|
||||
|
||||
# 如果是流式响应,发送增量数据
|
||||
if self.streaming:
|
||||
yield AgentResponse(
|
||||
type="streaming_delta",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(content)
|
||||
),
|
||||
)
|
||||
|
||||
elif event_type == "conversation.message.completed":
|
||||
# 消息完成
|
||||
logger.debug("Coze message completed")
|
||||
message_started = True
|
||||
|
||||
elif event_type == "conversation.chat.completed":
|
||||
# 对话完成
|
||||
logger.debug("Coze chat completed")
|
||||
break
|
||||
|
||||
elif event_type == "error":
|
||||
# 错误处理
|
||||
error_msg = data.get("msg", "未知错误")
|
||||
error_code = data.get("code", "UNKNOWN")
|
||||
logger.error(f"Coze 出现错误: {error_code} - {error_msg}")
|
||||
raise Exception(f"Coze 出现错误: {error_code} - {error_msg}")
|
||||
|
||||
if not message_started and not accumulated_content:
|
||||
logger.warning("Coze 未返回任何内容")
|
||||
accumulated_content = ""
|
||||
|
||||
# 创建最终响应
|
||||
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
|
||||
self.final_llm_resp = LLMResponse(role="assistant", result_chain=chain)
|
||||
self._transition_state(AgentState.DONE)
|
||||
|
||||
try:
|
||||
await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
|
||||
|
||||
# 返回最终结果
|
||||
yield AgentResponse(
|
||||
type="llm_result",
|
||||
data=AgentResponseData(chain=chain),
|
||||
)
|
||||
|
||||
async def _download_and_upload_image(
|
||||
self,
|
||||
image_url: str,
|
||||
session_id: str | None = None,
|
||||
) -> str:
|
||||
"""下载图片并上传到 Coze,返回 file_id"""
|
||||
import hashlib
|
||||
|
||||
# 计算哈希实现缓存
|
||||
cache_key = hashlib.md5(image_url.encode("utf-8")).hexdigest()
|
||||
|
||||
if session_id:
|
||||
if session_id not in self.file_id_cache:
|
||||
self.file_id_cache[session_id] = {}
|
||||
|
||||
if cache_key in self.file_id_cache[session_id]:
|
||||
file_id = self.file_id_cache[session_id][cache_key]
|
||||
logger.debug(f"[Coze] 使用缓存的 file_id: {file_id}")
|
||||
return file_id
|
||||
|
||||
try:
|
||||
image_data = await self.api_client.download_image(image_url)
|
||||
file_id = await self.api_client.upload_file(image_data)
|
||||
|
||||
if session_id:
|
||||
self.file_id_cache[session_id][cache_key] = file_id
|
||||
logger.debug(f"[Coze] 图片上传成功并缓存,file_id: {file_id}")
|
||||
|
||||
return file_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理图片失败 {image_url}: {e!s}")
|
||||
raise Exception(f"处理图片失败: {e!s}")
|
||||
|
||||
@override
|
||||
def done(self) -> bool:
|
||||
"""检查 Agent 是否已完成工作"""
|
||||
return self._state in (AgentState.DONE, AgentState.ERROR)
|
||||
|
||||
@override
|
||||
def get_final_llm_resp(self) -> LLMResponse | None:
|
||||
return self.final_llm_resp
|
||||
@@ -1,403 +0,0 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import queue
|
||||
import re
|
||||
import sys
|
||||
import threading
|
||||
import typing as T
|
||||
|
||||
from dashscope import Application
|
||||
from dashscope.app.application_response import ApplicationResponse
|
||||
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import (
|
||||
LLMResponse,
|
||||
ProviderRequest,
|
||||
)
|
||||
|
||||
from ...hooks import BaseAgentRunHooks
|
||||
from ...response import AgentResponseData
|
||||
from ...run_context import ContextWrapper, TContext
|
||||
from ..base import AgentResponse, AgentState, BaseAgentRunner
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class DashscopeAgentRunner(BaseAgentRunner[TContext]):
|
||||
"""Dashscope Agent Runner"""
|
||||
|
||||
@override
|
||||
async def reset(
|
||||
self,
|
||||
request: ProviderRequest,
|
||||
run_context: ContextWrapper[TContext],
|
||||
agent_hooks: BaseAgentRunHooks[TContext],
|
||||
provider_config: dict,
|
||||
**kwargs: T.Any,
|
||||
) -> None:
|
||||
self.req = request
|
||||
self.streaming = kwargs.get("streaming", False)
|
||||
self.final_llm_resp = None
|
||||
self._state = AgentState.IDLE
|
||||
self.agent_hooks = agent_hooks
|
||||
self.run_context = run_context
|
||||
|
||||
self.api_key = provider_config.get("dashscope_api_key", "")
|
||||
if not self.api_key:
|
||||
raise Exception("阿里云百炼 API Key 不能为空。")
|
||||
self.app_id = provider_config.get("dashscope_app_id", "")
|
||||
if not self.app_id:
|
||||
raise Exception("阿里云百炼 APP ID 不能为空。")
|
||||
self.dashscope_app_type = provider_config.get("dashscope_app_type", "")
|
||||
if not self.dashscope_app_type:
|
||||
raise Exception("阿里云百炼 APP 类型不能为空。")
|
||||
|
||||
self.variables: dict = provider_config.get("variables", {}) or {}
|
||||
self.rag_options: dict = provider_config.get("rag_options", {})
|
||||
self.output_reference = self.rag_options.get("output_reference", False)
|
||||
self.rag_options = self.rag_options.copy()
|
||||
self.rag_options.pop("output_reference", None)
|
||||
|
||||
self.timeout = provider_config.get("timeout", 120)
|
||||
if isinstance(self.timeout, str):
|
||||
self.timeout = int(self.timeout)
|
||||
|
||||
def has_rag_options(self):
|
||||
"""判断是否有 RAG 选项
|
||||
|
||||
Returns:
|
||||
bool: 是否有 RAG 选项
|
||||
|
||||
"""
|
||||
if self.rag_options and (
|
||||
len(self.rag_options.get("pipeline_ids", [])) > 0
|
||||
or len(self.rag_options.get("file_ids", [])) > 0
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
@override
|
||||
async def step(self):
|
||||
"""
|
||||
执行 Dashscope Agent 的一个步骤
|
||||
"""
|
||||
if not self.req:
|
||||
raise ValueError("Request is not set. Please call reset() first.")
|
||||
|
||||
if self._state == AgentState.IDLE:
|
||||
try:
|
||||
await self.agent_hooks.on_agent_begin(self.run_context)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True)
|
||||
|
||||
# 开始处理,转换到运行状态
|
||||
self._transition_state(AgentState.RUNNING)
|
||||
|
||||
try:
|
||||
# 执行 Dashscope 请求并处理结果
|
||||
async for response in self._execute_dashscope_request():
|
||||
yield response
|
||||
except Exception as e:
|
||||
logger.error(f"阿里云百炼请求失败:{str(e)}")
|
||||
self._transition_state(AgentState.ERROR)
|
||||
self.final_llm_resp = LLMResponse(
|
||||
role="err", completion_text=f"阿里云百炼请求失败:{str(e)}"
|
||||
)
|
||||
yield AgentResponse(
|
||||
type="err",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(f"阿里云百炼请求失败:{str(e)}")
|
||||
),
|
||||
)
|
||||
|
||||
@override
|
||||
async def step_until_done(
|
||||
self, max_step: int = 30
|
||||
) -> T.AsyncGenerator[AgentResponse, None]:
|
||||
while not self.done():
|
||||
async for resp in self.step():
|
||||
yield resp
|
||||
|
||||
def _consume_sync_generator(
|
||||
self, response: T.Any, response_queue: queue.Queue
|
||||
) -> None:
|
||||
"""在线程中消费同步generator,将结果放入队列
|
||||
|
||||
Args:
|
||||
response: 同步generator对象
|
||||
response_queue: 用于传递数据的队列
|
||||
|
||||
"""
|
||||
try:
|
||||
if self.streaming:
|
||||
for chunk in response:
|
||||
response_queue.put(("data", chunk))
|
||||
else:
|
||||
response_queue.put(("data", response))
|
||||
except Exception as e:
|
||||
response_queue.put(("error", e))
|
||||
finally:
|
||||
response_queue.put(("done", None))
|
||||
|
||||
async def _process_stream_chunk(
|
||||
self, chunk: ApplicationResponse, output_text: str
|
||||
) -> tuple[str, list | None, AgentResponse | None]:
|
||||
"""处理流式响应的单个chunk
|
||||
|
||||
Args:
|
||||
chunk: Dashscope响应chunk
|
||||
output_text: 当前累积的输出文本
|
||||
|
||||
Returns:
|
||||
(更新后的output_text, doc_references, AgentResponse或None)
|
||||
|
||||
"""
|
||||
logger.debug(f"dashscope stream chunk: {chunk}")
|
||||
|
||||
if chunk.status_code != 200:
|
||||
logger.error(
|
||||
f"阿里云百炼请求失败: request_id={chunk.request_id}, code={chunk.status_code}, message={chunk.message}, 请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code",
|
||||
)
|
||||
self._transition_state(AgentState.ERROR)
|
||||
error_msg = (
|
||||
f"阿里云百炼请求失败: message={chunk.message} code={chunk.status_code}"
|
||||
)
|
||||
self.final_llm_resp = LLMResponse(
|
||||
role="err",
|
||||
result_chain=MessageChain().message(error_msg),
|
||||
)
|
||||
return (
|
||||
output_text,
|
||||
None,
|
||||
AgentResponse(
|
||||
type="err",
|
||||
data=AgentResponseData(chain=MessageChain().message(error_msg)),
|
||||
),
|
||||
)
|
||||
|
||||
chunk_text = chunk.output.get("text", "") or ""
|
||||
# RAG 引用脚标格式化
|
||||
chunk_text = re.sub(r"<ref>\[(\d+)\]</ref>", r"[\1]", chunk_text)
|
||||
|
||||
response = None
|
||||
if chunk_text:
|
||||
output_text += chunk_text
|
||||
response = AgentResponse(
|
||||
type="streaming_delta",
|
||||
data=AgentResponseData(chain=MessageChain().message(chunk_text)),
|
||||
)
|
||||
|
||||
# 获取文档引用
|
||||
doc_references = chunk.output.get("doc_references", None)
|
||||
|
||||
return output_text, doc_references, response
|
||||
|
||||
def _format_doc_references(self, doc_references: list) -> str:
|
||||
"""格式化文档引用为文本
|
||||
|
||||
Args:
|
||||
doc_references: 文档引用列表
|
||||
|
||||
Returns:
|
||||
格式化后的引用文本
|
||||
|
||||
"""
|
||||
ref_parts = []
|
||||
for ref in doc_references:
|
||||
ref_title = (
|
||||
ref.get("title", "") if ref.get("title") else ref.get("doc_name", "")
|
||||
)
|
||||
ref_parts.append(f"{ref['index_id']}. {ref_title}\n")
|
||||
ref_str = "".join(ref_parts)
|
||||
return f"\n\n回答来源:\n{ref_str}"
|
||||
|
||||
async def _build_request_payload(
|
||||
self, prompt: str, session_id: str, contexts: list, system_prompt: str
|
||||
) -> dict:
|
||||
"""构建请求payload
|
||||
|
||||
Args:
|
||||
prompt: 用户输入
|
||||
session_id: 会话ID
|
||||
contexts: 上下文列表
|
||||
system_prompt: 系统提示词
|
||||
|
||||
Returns:
|
||||
请求payload字典
|
||||
|
||||
"""
|
||||
conversation_id = await sp.get_async(
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
key="dashscope_conversation_id",
|
||||
default="",
|
||||
)
|
||||
# 获得会话变量
|
||||
payload_vars = self.variables.copy()
|
||||
session_var = await sp.get_async(
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
key="session_variables",
|
||||
default={},
|
||||
)
|
||||
payload_vars.update(session_var)
|
||||
|
||||
if (
|
||||
self.dashscope_app_type in ["agent", "dialog-workflow"]
|
||||
and not self.has_rag_options()
|
||||
):
|
||||
# 支持多轮对话的
|
||||
p = {
|
||||
"app_id": self.app_id,
|
||||
"api_key": self.api_key,
|
||||
"prompt": prompt,
|
||||
"biz_params": payload_vars or None,
|
||||
"stream": self.streaming,
|
||||
"incremental_output": True,
|
||||
}
|
||||
if conversation_id:
|
||||
p["session_id"] = conversation_id
|
||||
return p
|
||||
else:
|
||||
# 不支持多轮对话的
|
||||
payload = {
|
||||
"app_id": self.app_id,
|
||||
"prompt": prompt,
|
||||
"api_key": self.api_key,
|
||||
"biz_params": payload_vars or None,
|
||||
"stream": self.streaming,
|
||||
"incremental_output": True,
|
||||
}
|
||||
if self.rag_options:
|
||||
payload["rag_options"] = self.rag_options
|
||||
return payload
|
||||
|
||||
async def _handle_streaming_response(
|
||||
self, response: T.Any, session_id: str
|
||||
) -> T.AsyncGenerator[AgentResponse, None]:
|
||||
"""处理流式响应
|
||||
|
||||
Args:
|
||||
response: Dashscope 流式响应 generator
|
||||
|
||||
Yields:
|
||||
AgentResponse 对象
|
||||
|
||||
"""
|
||||
response_queue = queue.Queue()
|
||||
consumer_thread = threading.Thread(
|
||||
target=self._consume_sync_generator,
|
||||
args=(response, response_queue),
|
||||
daemon=True,
|
||||
)
|
||||
consumer_thread.start()
|
||||
|
||||
output_text = ""
|
||||
doc_references = None
|
||||
|
||||
while True:
|
||||
try:
|
||||
item_type, item_data = await asyncio.get_event_loop().run_in_executor(
|
||||
None, response_queue.get, True, 1
|
||||
)
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
if item_type == "done":
|
||||
break
|
||||
elif item_type == "error":
|
||||
raise item_data
|
||||
elif item_type == "data":
|
||||
chunk = item_data
|
||||
assert isinstance(chunk, ApplicationResponse)
|
||||
|
||||
(
|
||||
output_text,
|
||||
chunk_doc_refs,
|
||||
response,
|
||||
) = await self._process_stream_chunk(chunk, output_text)
|
||||
|
||||
if response:
|
||||
if response.type == "err":
|
||||
yield response
|
||||
return
|
||||
yield response
|
||||
|
||||
if chunk_doc_refs:
|
||||
doc_references = chunk_doc_refs
|
||||
|
||||
if chunk.output.session_id:
|
||||
await sp.put_async(
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
key="dashscope_conversation_id",
|
||||
value=chunk.output.session_id,
|
||||
)
|
||||
|
||||
# 添加 RAG 引用
|
||||
if self.output_reference and doc_references:
|
||||
ref_text = self._format_doc_references(doc_references)
|
||||
output_text += ref_text
|
||||
|
||||
if self.streaming:
|
||||
yield AgentResponse(
|
||||
type="streaming_delta",
|
||||
data=AgentResponseData(chain=MessageChain().message(ref_text)),
|
||||
)
|
||||
|
||||
# 创建最终响应
|
||||
chain = MessageChain(chain=[Comp.Plain(output_text)])
|
||||
self.final_llm_resp = LLMResponse(role="assistant", result_chain=chain)
|
||||
self._transition_state(AgentState.DONE)
|
||||
|
||||
try:
|
||||
await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
|
||||
|
||||
# 返回最终结果
|
||||
yield AgentResponse(
|
||||
type="llm_result",
|
||||
data=AgentResponseData(chain=chain),
|
||||
)
|
||||
|
||||
async def _execute_dashscope_request(self):
|
||||
"""执行 Dashscope 请求的核心逻辑"""
|
||||
prompt = self.req.prompt or ""
|
||||
session_id = self.req.session_id or "unknown"
|
||||
image_urls = self.req.image_urls or []
|
||||
contexts = self.req.contexts or []
|
||||
system_prompt = self.req.system_prompt
|
||||
|
||||
# 检查图片输入
|
||||
if image_urls:
|
||||
logger.warning("阿里云百炼暂不支持图片输入,将自动忽略图片内容。")
|
||||
|
||||
# 构建请求payload
|
||||
payload = await self._build_request_payload(
|
||||
prompt, session_id, contexts, system_prompt
|
||||
)
|
||||
|
||||
if not self.streaming:
|
||||
payload["incremental_output"] = False
|
||||
|
||||
# 发起请求
|
||||
partial = functools.partial(Application.call, **payload)
|
||||
response = await asyncio.get_event_loop().run_in_executor(None, partial)
|
||||
|
||||
async for resp in self._handle_streaming_response(response, session_id):
|
||||
yield resp
|
||||
|
||||
@override
|
||||
def done(self) -> bool:
|
||||
"""检查 Agent 是否已完成工作"""
|
||||
return self._state in (AgentState.DONE, AgentState.ERROR)
|
||||
|
||||
@override
|
||||
def get_final_llm_resp(self) -> LLMResponse | None:
|
||||
return self.final_llm_resp
|
||||
@@ -1,336 +0,0 @@
|
||||
import base64
|
||||
import os
|
||||
import sys
|
||||
import typing as T
|
||||
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import (
|
||||
LLMResponse,
|
||||
ProviderRequest,
|
||||
)
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.io import download_file
|
||||
|
||||
from ...hooks import BaseAgentRunHooks
|
||||
from ...response import AgentResponseData
|
||||
from ...run_context import ContextWrapper, TContext
|
||||
from ..base import AgentResponse, AgentState, BaseAgentRunner
|
||||
from .dify_api_client import DifyAPIClient
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class DifyAgentRunner(BaseAgentRunner[TContext]):
|
||||
"""Dify Agent Runner"""
|
||||
|
||||
@override
|
||||
async def reset(
|
||||
self,
|
||||
request: ProviderRequest,
|
||||
run_context: ContextWrapper[TContext],
|
||||
agent_hooks: BaseAgentRunHooks[TContext],
|
||||
provider_config: dict,
|
||||
**kwargs: T.Any,
|
||||
) -> None:
|
||||
self.req = request
|
||||
self.streaming = kwargs.get("streaming", False)
|
||||
self.final_llm_resp = None
|
||||
self._state = AgentState.IDLE
|
||||
self.agent_hooks = agent_hooks
|
||||
self.run_context = run_context
|
||||
|
||||
self.api_key = provider_config.get("dify_api_key", "")
|
||||
self.api_base = provider_config.get("dify_api_base", "https://api.dify.ai/v1")
|
||||
self.api_type = provider_config.get("dify_api_type", "chat")
|
||||
self.workflow_output_key = provider_config.get(
|
||||
"dify_workflow_output_key",
|
||||
"astrbot_wf_output",
|
||||
)
|
||||
self.dify_query_input_key = provider_config.get(
|
||||
"dify_query_input_key",
|
||||
"astrbot_text_query",
|
||||
)
|
||||
self.variables: dict = provider_config.get("variables", {}) or {}
|
||||
self.timeout = provider_config.get("timeout", 60)
|
||||
if isinstance(self.timeout, str):
|
||||
self.timeout = int(self.timeout)
|
||||
|
||||
self.api_client = DifyAPIClient(self.api_key, self.api_base)
|
||||
|
||||
@override
|
||||
async def step(self):
|
||||
"""
|
||||
执行 Dify Agent 的一个步骤
|
||||
"""
|
||||
if not self.req:
|
||||
raise ValueError("Request is not set. Please call reset() first.")
|
||||
|
||||
if self._state == AgentState.IDLE:
|
||||
try:
|
||||
await self.agent_hooks.on_agent_begin(self.run_context)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True)
|
||||
|
||||
# 开始处理,转换到运行状态
|
||||
self._transition_state(AgentState.RUNNING)
|
||||
|
||||
try:
|
||||
# 执行 Dify 请求并处理结果
|
||||
async for response in self._execute_dify_request():
|
||||
yield response
|
||||
except Exception as e:
|
||||
logger.error(f"Dify 请求失败:{str(e)}")
|
||||
self._transition_state(AgentState.ERROR)
|
||||
self.final_llm_resp = LLMResponse(
|
||||
role="err", completion_text=f"Dify 请求失败:{str(e)}"
|
||||
)
|
||||
yield AgentResponse(
|
||||
type="err",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(f"Dify 请求失败:{str(e)}")
|
||||
),
|
||||
)
|
||||
finally:
|
||||
await self.api_client.close()
|
||||
|
||||
@override
|
||||
async def step_until_done(
|
||||
self, max_step: int = 30
|
||||
) -> T.AsyncGenerator[AgentResponse, None]:
|
||||
while not self.done():
|
||||
async for resp in self.step():
|
||||
yield resp
|
||||
|
||||
async def _execute_dify_request(self):
|
||||
"""执行 Dify 请求的核心逻辑"""
|
||||
prompt = self.req.prompt or ""
|
||||
session_id = self.req.session_id or "unknown"
|
||||
image_urls = self.req.image_urls or []
|
||||
system_prompt = self.req.system_prompt
|
||||
|
||||
conversation_id = await sp.get_async(
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
key="dify_conversation_id",
|
||||
default="",
|
||||
)
|
||||
result = ""
|
||||
|
||||
# 处理图片上传
|
||||
files_payload = []
|
||||
for image_url in image_urls:
|
||||
# image_url is a base64 string
|
||||
try:
|
||||
image_data = base64.b64decode(image_url)
|
||||
file_response = await self.api_client.file_upload(
|
||||
file_data=image_data,
|
||||
user=session_id,
|
||||
mime_type="image/png",
|
||||
file_name="image.png",
|
||||
)
|
||||
logger.debug(f"Dify 上传图片响应:{file_response}")
|
||||
if "id" not in file_response:
|
||||
logger.warning(
|
||||
f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。"
|
||||
)
|
||||
continue
|
||||
files_payload.append(
|
||||
{
|
||||
"type": "image",
|
||||
"transfer_method": "local_file",
|
||||
"upload_file_id": file_response["id"],
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"上传图片失败:{e}")
|
||||
continue
|
||||
|
||||
# 获得会话变量
|
||||
payload_vars = self.variables.copy()
|
||||
# 动态变量
|
||||
session_var = await sp.get_async(
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
key="session_variables",
|
||||
default={},
|
||||
)
|
||||
payload_vars.update(session_var)
|
||||
payload_vars["system_prompt"] = system_prompt
|
||||
|
||||
# 处理不同的 API 类型
|
||||
match self.api_type:
|
||||
case "chat" | "agent" | "chatflow":
|
||||
if not prompt:
|
||||
prompt = "请描述这张图片。"
|
||||
|
||||
async for chunk in self.api_client.chat_messages(
|
||||
inputs={
|
||||
**payload_vars,
|
||||
},
|
||||
query=prompt,
|
||||
user=session_id,
|
||||
conversation_id=conversation_id,
|
||||
files=files_payload,
|
||||
timeout=self.timeout,
|
||||
):
|
||||
logger.debug(f"dify resp chunk: {chunk}")
|
||||
if chunk["event"] == "message" or chunk["event"] == "agent_message":
|
||||
result += chunk["answer"]
|
||||
if not conversation_id:
|
||||
await sp.put_async(
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
key="dify_conversation_id",
|
||||
value=chunk["conversation_id"],
|
||||
)
|
||||
conversation_id = chunk["conversation_id"]
|
||||
|
||||
# 如果是流式响应,发送增量数据
|
||||
if self.streaming and chunk["answer"]:
|
||||
yield AgentResponse(
|
||||
type="streaming_delta",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(chunk["answer"])
|
||||
),
|
||||
)
|
||||
elif chunk["event"] == "message_end":
|
||||
logger.debug("Dify message end")
|
||||
break
|
||||
elif chunk["event"] == "error":
|
||||
logger.error(f"Dify 出现错误:{chunk}")
|
||||
raise Exception(
|
||||
f"Dify 出现错误 status: {chunk['status']} message: {chunk['message']}"
|
||||
)
|
||||
|
||||
case "workflow":
|
||||
async for chunk in self.api_client.workflow_run(
|
||||
inputs={
|
||||
self.dify_query_input_key: prompt,
|
||||
"astrbot_session_id": session_id,
|
||||
**payload_vars,
|
||||
},
|
||||
user=session_id,
|
||||
files=files_payload,
|
||||
timeout=self.timeout,
|
||||
):
|
||||
logger.debug(f"dify workflow resp chunk: {chunk}")
|
||||
match chunk["event"]:
|
||||
case "workflow_started":
|
||||
logger.info(
|
||||
f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行。"
|
||||
)
|
||||
case "node_finished":
|
||||
logger.debug(
|
||||
f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束。"
|
||||
)
|
||||
case "text_chunk":
|
||||
if self.streaming and chunk["data"]["text"]:
|
||||
yield AgentResponse(
|
||||
type="streaming_delta",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(
|
||||
chunk["data"]["text"]
|
||||
)
|
||||
),
|
||||
)
|
||||
case "workflow_finished":
|
||||
logger.info(
|
||||
f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束"
|
||||
)
|
||||
logger.debug(f"Dify 工作流结果:{chunk}")
|
||||
if chunk["data"]["error"]:
|
||||
logger.error(
|
||||
f"Dify 工作流出现错误:{chunk['data']['error']}"
|
||||
)
|
||||
raise Exception(
|
||||
f"Dify 工作流出现错误:{chunk['data']['error']}"
|
||||
)
|
||||
if self.workflow_output_key not in chunk["data"]["outputs"]:
|
||||
raise Exception(
|
||||
f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}"
|
||||
)
|
||||
result = chunk
|
||||
case _:
|
||||
raise Exception(f"未知的 Dify API 类型:{self.api_type}")
|
||||
|
||||
if not result:
|
||||
logger.warning("Dify 请求结果为空,请查看 Debug 日志。")
|
||||
|
||||
# 解析结果
|
||||
chain = await self.parse_dify_result(result)
|
||||
|
||||
# 创建最终响应
|
||||
self.final_llm_resp = LLMResponse(role="assistant", result_chain=chain)
|
||||
self._transition_state(AgentState.DONE)
|
||||
|
||||
try:
|
||||
await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
|
||||
|
||||
# 返回最终结果
|
||||
yield AgentResponse(
|
||||
type="llm_result",
|
||||
data=AgentResponseData(chain=chain),
|
||||
)
|
||||
|
||||
async def parse_dify_result(self, chunk: dict | str) -> MessageChain:
|
||||
"""解析 Dify 的响应结果"""
|
||||
if isinstance(chunk, str):
|
||||
# Chat
|
||||
return MessageChain(chain=[Comp.Plain(chunk)])
|
||||
|
||||
async def parse_file(item: dict):
|
||||
match item["type"]:
|
||||
case "image":
|
||||
return Comp.Image(file=item["url"], url=item["url"])
|
||||
case "audio":
|
||||
# 仅支持 wav
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, f"{item['filename']}.wav")
|
||||
await download_file(item["url"], path)
|
||||
return Comp.Image(file=item["url"], url=item["url"])
|
||||
case "video":
|
||||
return Comp.Video(file=item["url"])
|
||||
case _:
|
||||
return Comp.File(name=item["filename"], file=item["url"])
|
||||
|
||||
output = chunk["data"]["outputs"][self.workflow_output_key]
|
||||
chains = []
|
||||
if isinstance(output, str):
|
||||
# 纯文本输出
|
||||
chains.append(Comp.Plain(output))
|
||||
elif isinstance(output, list):
|
||||
# 主要适配 Dify 的 HTTP 请求结点的多模态输出
|
||||
for item in output:
|
||||
# handle Array[File]
|
||||
if (
|
||||
not isinstance(item, dict)
|
||||
or item.get("dify_model_identity", "") != "__dify__file__"
|
||||
):
|
||||
chains.append(Comp.Plain(str(output)))
|
||||
break
|
||||
else:
|
||||
chains.append(Comp.Plain(str(output)))
|
||||
|
||||
# scan file
|
||||
files = chunk["data"].get("files", [])
|
||||
for item in files:
|
||||
comp = await parse_file(item)
|
||||
chains.append(comp)
|
||||
|
||||
return MessageChain(chain=chains)
|
||||
|
||||
@override
|
||||
def done(self) -> bool:
|
||||
"""检查 Agent 是否已完成工作"""
|
||||
return self._state in (AgentState.DONE, AgentState.ERROR)
|
||||
|
||||
@override
|
||||
def get_final_llm_resp(self) -> LLMResponse | None:
|
||||
return self.final_llm_resp
|
||||
@@ -1,5 +1,4 @@
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import typing as T
|
||||
|
||||
@@ -13,8 +12,6 @@ from mcp.types import (
|
||||
)
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.agent.message import TextPart, ThinkPart
|
||||
from astrbot.core.message.components import Json
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
)
|
||||
@@ -25,13 +22,9 @@ from astrbot.core.provider.entities import (
|
||||
)
|
||||
from astrbot.core.provider.provider import Provider
|
||||
|
||||
from ..context.compressor import ContextCompressor
|
||||
from ..context.config import ContextConfig
|
||||
from ..context.manager import ContextManager
|
||||
from ..context.token_counter import TokenCounter
|
||||
from ..hooks import BaseAgentRunHooks
|
||||
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
|
||||
from ..response import AgentResponseData, AgentStats
|
||||
from ..response import AgentResponseData
|
||||
from ..run_context import ContextWrapper, TContext
|
||||
from ..tool_executor import BaseFunctionToolExecutor
|
||||
from .base import AgentResponse, AgentState, BaseAgentRunner
|
||||
@@ -51,47 +44,10 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
run_context: ContextWrapper[TContext],
|
||||
tool_executor: BaseFunctionToolExecutor[TContext],
|
||||
agent_hooks: BaseAgentRunHooks[TContext],
|
||||
streaming: bool = False,
|
||||
# enforce max turns, will discard older turns when exceeded BEFORE compression
|
||||
# -1 means no limit
|
||||
enforce_max_turns: int = -1,
|
||||
# llm compressor
|
||||
llm_compress_instruction: str | None = None,
|
||||
llm_compress_keep_recent: int = 0,
|
||||
llm_compress_provider: Provider | None = None,
|
||||
# truncate by turns compressor
|
||||
truncate_turns: int = 1,
|
||||
# customize
|
||||
custom_token_counter: TokenCounter | None = None,
|
||||
custom_compressor: ContextCompressor | None = None,
|
||||
**kwargs: T.Any,
|
||||
) -> None:
|
||||
self.req = request
|
||||
self.streaming = streaming
|
||||
self.enforce_max_turns = enforce_max_turns
|
||||
self.llm_compress_instruction = llm_compress_instruction
|
||||
self.llm_compress_keep_recent = llm_compress_keep_recent
|
||||
self.llm_compress_provider = llm_compress_provider
|
||||
self.truncate_turns = truncate_turns
|
||||
self.custom_token_counter = custom_token_counter
|
||||
self.custom_compressor = custom_compressor
|
||||
# we will do compress when:
|
||||
# 1. before requesting LLM
|
||||
# TODO: 2. after LLM output a tool call
|
||||
self.context_config = ContextConfig(
|
||||
# <=0 will never do compress
|
||||
max_context_tokens=provider.provider_config.get("max_context_tokens", 0),
|
||||
# enforce max turns before compression
|
||||
enforce_max_turns=self.enforce_max_turns,
|
||||
truncate_turns=self.truncate_turns,
|
||||
llm_compress_instruction=self.llm_compress_instruction,
|
||||
llm_compress_keep_recent=self.llm_compress_keep_recent,
|
||||
llm_compress_provider=self.llm_compress_provider,
|
||||
custom_token_counter=self.custom_token_counter,
|
||||
custom_compressor=self.custom_compressor,
|
||||
)
|
||||
self.context_manager = ContextManager(self.context_config)
|
||||
|
||||
self.streaming = kwargs.get("streaming", False)
|
||||
self.provider = provider
|
||||
self.final_llm_resp = None
|
||||
self._state = AgentState.IDLE
|
||||
@@ -113,25 +69,20 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
)
|
||||
self.run_context.messages = messages
|
||||
|
||||
self.stats = AgentStats()
|
||||
self.stats.start_time = time.time()
|
||||
def _transition_state(self, new_state: AgentState) -> None:
|
||||
"""转换 Agent 状态"""
|
||||
if self._state != new_state:
|
||||
logger.debug(f"Agent state transition: {self._state} -> {new_state}")
|
||||
self._state = new_state
|
||||
|
||||
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
|
||||
"""Yields chunks *and* a final LLMResponse."""
|
||||
payload = {
|
||||
"contexts": self.run_context.messages, # list[Message]
|
||||
"func_tool": self.req.func_tool,
|
||||
"model": self.req.model, # NOTE: in fact, this arg is None in most cases
|
||||
"session_id": self.req.session_id,
|
||||
"extra_user_content_parts": self.req.extra_user_content_parts, # list[ContentPart]
|
||||
}
|
||||
|
||||
if self.streaming:
|
||||
stream = self.provider.text_chat_stream(**payload)
|
||||
stream = self.provider.text_chat_stream(**self.req.__dict__)
|
||||
async for resp in stream: # type: ignore
|
||||
yield resp
|
||||
else:
|
||||
yield await self.provider.text_chat(**payload)
|
||||
yield await self.provider.text_chat(**self.req.__dict__)
|
||||
|
||||
@override
|
||||
async def step(self):
|
||||
@@ -151,18 +102,9 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
self._transition_state(AgentState.RUNNING)
|
||||
llm_resp_result = None
|
||||
|
||||
# do truncate and compress
|
||||
token_usage = self.req.conversation.token_usage if self.req.conversation else 0
|
||||
self.run_context.messages = await self.context_manager.process(
|
||||
self.run_context.messages, trusted_token_usage=token_usage
|
||||
)
|
||||
|
||||
async for llm_response in self._iter_llm_responses():
|
||||
assert isinstance(llm_response, LLMResponse)
|
||||
if llm_response.is_chunk:
|
||||
# update ttft
|
||||
if self.stats.time_to_first_token == 0:
|
||||
self.stats.time_to_first_token = time.time() - self.stats.start_time
|
||||
|
||||
if llm_response.result_chain:
|
||||
yield AgentResponse(
|
||||
type="streaming_delta",
|
||||
@@ -186,10 +128,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
)
|
||||
continue
|
||||
llm_resp_result = llm_response
|
||||
|
||||
if not llm_response.is_chunk and llm_response.usage:
|
||||
# only count the token usage of the final response for computation purpose
|
||||
self.stats.token_usage += llm_response.usage
|
||||
break # got final response
|
||||
|
||||
if not llm_resp_result:
|
||||
@@ -201,7 +139,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
if llm_resp.role == "err":
|
||||
# 如果 LLM 响应错误,转换到错误状态
|
||||
self.final_llm_resp = llm_resp
|
||||
self.stats.end_time = time.time()
|
||||
self._transition_state(AgentState.ERROR)
|
||||
yield AgentResponse(
|
||||
type="err",
|
||||
@@ -216,21 +153,13 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
# 如果没有工具调用,转换到完成状态
|
||||
self.final_llm_resp = llm_resp
|
||||
self._transition_state(AgentState.DONE)
|
||||
self.stats.end_time = time.time()
|
||||
|
||||
# record the final assistant message
|
||||
parts = []
|
||||
if llm_resp.reasoning_content or llm_resp.reasoning_signature:
|
||||
parts.append(
|
||||
ThinkPart(
|
||||
think=llm_resp.reasoning_content,
|
||||
encrypted=llm_resp.reasoning_signature,
|
||||
)
|
||||
)
|
||||
parts.append(TextPart(text=llm_resp.completion_text or "*No response*"))
|
||||
self.run_context.messages.append(Message(role="assistant", content=parts))
|
||||
|
||||
# call the on_agent_done hook
|
||||
self.run_context.messages.append(
|
||||
Message(
|
||||
role="assistant",
|
||||
content=llm_resp.completion_text or "",
|
||||
),
|
||||
)
|
||||
try:
|
||||
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
|
||||
except Exception as e:
|
||||
@@ -253,35 +182,29 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
# 如果有工具调用,还需处理工具调用
|
||||
if llm_resp.tools_call_name:
|
||||
tool_call_result_blocks = []
|
||||
for tool_call_name in llm_resp.tools_call_name:
|
||||
yield AgentResponse(
|
||||
type="tool_call",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain(type="tool_call").message(
|
||||
f"🔨 调用工具: {tool_call_name}"
|
||||
),
|
||||
),
|
||||
)
|
||||
async for result in self._handle_function_tools(self.req, llm_resp):
|
||||
if isinstance(result, list):
|
||||
tool_call_result_blocks = result
|
||||
elif isinstance(result, MessageChain):
|
||||
if result.type is None:
|
||||
# should not happen
|
||||
continue
|
||||
if result.type == "tool_direct_result":
|
||||
ar_type = "tool_call_result"
|
||||
else:
|
||||
ar_type = result.type
|
||||
result.type = "tool_call_result"
|
||||
yield AgentResponse(
|
||||
type=ar_type,
|
||||
type="tool_call_result",
|
||||
data=AgentResponseData(chain=result),
|
||||
)
|
||||
# 将结果添加到上下文中
|
||||
parts = []
|
||||
if llm_resp.reasoning_content or llm_resp.reasoning_signature:
|
||||
parts.append(
|
||||
ThinkPart(
|
||||
think=llm_resp.reasoning_content,
|
||||
encrypted=llm_resp.reasoning_signature,
|
||||
)
|
||||
)
|
||||
parts.append(TextPart(text=llm_resp.completion_text or "*No response*"))
|
||||
tool_calls_result = ToolCallsResult(
|
||||
tool_calls_info=AssistantMessageSegment(
|
||||
tool_calls=llm_resp.to_openai_to_calls_model(),
|
||||
content=parts,
|
||||
content=llm_resp.completion_text,
|
||||
),
|
||||
tool_calls_result=tool_call_result_blocks,
|
||||
)
|
||||
@@ -302,25 +225,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
async for resp in self.step():
|
||||
yield resp
|
||||
|
||||
# 如果循环结束了但是 agent 还没有完成,说明是达到了 max_step
|
||||
if not self.done():
|
||||
logger.warning(
|
||||
f"Agent reached max steps ({max_step}), forcing a final response."
|
||||
)
|
||||
# 拔掉所有工具
|
||||
if self.req:
|
||||
self.req.func_tool = None
|
||||
# 注入提示词
|
||||
self.run_context.messages.append(
|
||||
Message(
|
||||
role="user",
|
||||
content="工具调用次数已达到上限,请停止使用工具,并根据已经收集到的信息,对你的任务和发现进行总结,然后直接回复用户。",
|
||||
)
|
||||
)
|
||||
# 再执行最后一步
|
||||
async for resp in self.step():
|
||||
yield resp
|
||||
|
||||
async def _handle_function_tools(
|
||||
self,
|
||||
req: ProviderRequest,
|
||||
@@ -336,19 +240,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
llm_response.tools_call_args,
|
||||
llm_response.tools_call_ids,
|
||||
):
|
||||
yield MessageChain(
|
||||
type="tool_call",
|
||||
chain=[
|
||||
Json(
|
||||
data={
|
||||
"id": func_tool_id,
|
||||
"name": func_tool_name,
|
||||
"args": func_tool_args,
|
||||
"ts": time.time(),
|
||||
}
|
||||
)
|
||||
],
|
||||
)
|
||||
try:
|
||||
if not req.func_tool:
|
||||
return
|
||||
@@ -422,6 +313,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
content=res.content[0].text,
|
||||
),
|
||||
)
|
||||
yield MessageChain().message(res.content[0].text)
|
||||
elif isinstance(res.content[0], ImageContent):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
@@ -443,6 +335,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
content=resource.text,
|
||||
),
|
||||
)
|
||||
yield MessageChain().message(resource.text)
|
||||
elif (
|
||||
isinstance(resource, BlobResourceContents)
|
||||
and resource.mimeType
|
||||
@@ -466,34 +359,20 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
content="返回的数据类型不受支持",
|
||||
),
|
||||
)
|
||||
yield MessageChain().message("返回的数据类型不受支持。")
|
||||
|
||||
elif resp is None:
|
||||
# Tool 直接请求发送消息给用户
|
||||
# 这里我们将直接结束 Agent Loop。
|
||||
# 发送消息逻辑在 ToolExecutor 中处理了。
|
||||
logger.warning(
|
||||
f"{func_tool_name} 没有没有返回值或者将结果直接发送给用户。"
|
||||
f"{func_tool_name} 没有没有返回值或者将结果直接发送给用户,此工具调用不会被记录到历史中。"
|
||||
)
|
||||
self._transition_state(AgentState.DONE)
|
||||
self.stats.end_time = time.time()
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="*工具没有返回值或者将结果直接发送给了用户*",
|
||||
),
|
||||
)
|
||||
else:
|
||||
# 不应该出现其他类型
|
||||
logger.warning(
|
||||
f"Tool 返回了不支持的类型: {type(resp)}。",
|
||||
)
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="*工具返回了不支持的类型,请告诉用户检查这个工具的定义和实现。*",
|
||||
),
|
||||
f"Tool 返回了不支持的类型: {type(resp)},将忽略。",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -515,22 +394,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
),
|
||||
)
|
||||
|
||||
# yield the last tool call result
|
||||
if tool_call_result_blocks:
|
||||
last_tcr_content = str(tool_call_result_blocks[-1].content)
|
||||
yield MessageChain(
|
||||
type="tool_call_result",
|
||||
chain=[
|
||||
Json(
|
||||
data={
|
||||
"id": func_tool_id,
|
||||
"ts": time.time(),
|
||||
"result": last_tcr_content,
|
||||
}
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# 处理函数调用响应
|
||||
if tool_call_result_blocks:
|
||||
yield tool_call_result_blocks
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, Generic
|
||||
|
||||
import jsonschema
|
||||
@@ -7,8 +7,6 @@ from deprecated import deprecated
|
||||
from pydantic import Field, model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from astrbot.core.message.message_event_result import MessageEventResult
|
||||
|
||||
from .run_context import ContextWrapper, TContext
|
||||
|
||||
ParametersType = dict[str, Any]
|
||||
@@ -40,10 +38,7 @@ class ToolSchema:
|
||||
class FunctionTool(ToolSchema, Generic[TContext]):
|
||||
"""A callable tool, for function calling."""
|
||||
|
||||
handler: (
|
||||
Callable[..., Awaitable[str | None] | AsyncGenerator[MessageEventResult, None]]
|
||||
| None
|
||||
) = None
|
||||
handler: Callable[..., Awaitable[Any]] | None = None
|
||||
"""a callable that implements the tool's functionality. It should be an async function."""
|
||||
|
||||
handler_module_path: str | None = None
|
||||
|
||||
@@ -6,10 +6,8 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.context import Context
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(config={"arbitrary_types_allowed": True})
|
||||
class AstrAgentContext:
|
||||
__pydantic_config__ = {"arbitrary_types_allowed": True}
|
||||
|
||||
context: Context
|
||||
"""The star context instance"""
|
||||
event: AstrMessageEvent
|
||||
|
||||
@@ -13,12 +13,6 @@ from astrbot.core.star.star_handler import EventType
|
||||
class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
|
||||
async def on_agent_done(self, run_context, llm_response):
|
||||
# 执行事件钩子
|
||||
if llm_response and llm_response.reasoning_content:
|
||||
# we will use this in result_decorate stage to inject reasoning content to chain
|
||||
run_context.context.event.set_extra(
|
||||
"_llm_reasoning_content", llm_response.reasoning_content
|
||||
)
|
||||
|
||||
await call_event_hook(
|
||||
run_context.context.event,
|
||||
EventType.OnLLMResponseEvent,
|
||||
|
||||
@@ -2,16 +2,13 @@ import traceback
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.message.components import Json
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
MessageEventResult,
|
||||
ResultContentType,
|
||||
)
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
|
||||
AgentRunner = ToolLoopAgentRunner[AstrAgentContext]
|
||||
|
||||
@@ -25,25 +22,8 @@ async def run_agent(
|
||||
) -> AsyncGenerator[MessageChain | None, None]:
|
||||
step_idx = 0
|
||||
astr_event = agent_runner.run_context.context.event
|
||||
while step_idx < max_step + 1:
|
||||
while step_idx < max_step:
|
||||
step_idx += 1
|
||||
|
||||
if step_idx == max_step + 1:
|
||||
logger.warning(
|
||||
f"Agent reached max steps ({max_step}), forcing a final response."
|
||||
)
|
||||
if not agent_runner.done():
|
||||
# 拔掉所有工具
|
||||
if agent_runner.req:
|
||||
agent_runner.req.func_tool = None
|
||||
# 注入提示词
|
||||
agent_runner.run_context.messages.append(
|
||||
Message(
|
||||
role="user",
|
||||
content="工具调用次数已达到上限,请停止使用工具,并根据已经收集到的信息,对你的任务和发现进行总结,然后直接回复用户。",
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
async for resp in agent_runner.step():
|
||||
if astr_event.is_stopped():
|
||||
@@ -52,27 +32,16 @@ async def run_agent(
|
||||
msg_chain = resp.data["chain"]
|
||||
if msg_chain.type == "tool_direct_result":
|
||||
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
|
||||
await astr_event.send(msg_chain)
|
||||
await astr_event.send(resp.data["chain"])
|
||||
continue
|
||||
if astr_event.get_platform_id() == "webchat":
|
||||
await astr_event.send(msg_chain)
|
||||
# 对于其他情况,暂时先不处理
|
||||
continue
|
||||
elif resp.type == "tool_call":
|
||||
if agent_runner.streaming:
|
||||
# 用来标记流式响应需要分节
|
||||
yield MessageChain(chain=[], type="break")
|
||||
|
||||
if astr_event.get_platform_name() == "webchat":
|
||||
if show_tool_use:
|
||||
await astr_event.send(resp.data["chain"])
|
||||
elif show_tool_use:
|
||||
json_comp = resp.data["chain"].chain[0]
|
||||
if isinstance(json_comp, Json):
|
||||
m = f"🔨 调用工具: {json_comp.data.get('name')}"
|
||||
else:
|
||||
m = "🔨 调用工具..."
|
||||
chain = MessageChain(type="tool_call").message(m)
|
||||
await astr_event.send(chain)
|
||||
continue
|
||||
|
||||
if stream_to_general and resp.type == "streaming_delta":
|
||||
@@ -99,33 +68,11 @@ async def run_agent(
|
||||
continue
|
||||
yield resp.data["chain"] # MessageChain
|
||||
if agent_runner.done():
|
||||
# send agent stats to webchat
|
||||
if astr_event.get_platform_name() == "webchat":
|
||||
await astr_event.send(
|
||||
MessageChain(
|
||||
type="agent_stats",
|
||||
chain=[Json(data=agent_runner.stats.to_dict())],
|
||||
)
|
||||
)
|
||||
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在平台日志查看和分享错误详情。\n"
|
||||
|
||||
error_llm_response = LLMResponse(
|
||||
role="err",
|
||||
completion_text=err_msg,
|
||||
)
|
||||
try:
|
||||
await agent_runner.agent_hooks.on_agent_done(
|
||||
agent_runner.run_context, error_llm_response
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error in on_agent_done hook")
|
||||
|
||||
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n"
|
||||
if agent_runner.streaming:
|
||||
yield MessageChain().message(err_msg)
|
||||
else:
|
||||
|
||||
@@ -185,11 +185,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
|
||||
async def call_local_llm_tool(
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
handler: T.Callable[
|
||||
...,
|
||||
T.Awaitable[MessageEventResult | mcp.types.CallToolResult | str | None]
|
||||
| T.AsyncGenerator[MessageEventResult | CommandResult | str | None, None],
|
||||
],
|
||||
handler: T.Callable[..., T.Awaitable[T.Any]],
|
||||
method_name: str,
|
||||
*args,
|
||||
**kwargs,
|
||||
@@ -209,42 +205,12 @@ async def call_local_llm_tool(
|
||||
else:
|
||||
raise ValueError(f"未知的方法名: {method_name}")
|
||||
except ValueError as e:
|
||||
raise Exception(f"Tool execution ValueError: {e}") from e
|
||||
except TypeError as e:
|
||||
# 获取函数的签名(包括类型),除了第一个 event/context 参数。
|
||||
try:
|
||||
sig = inspect.signature(handler)
|
||||
params = list(sig.parameters.values())
|
||||
# 跳过第一个参数(event 或 context)
|
||||
if params:
|
||||
params = params[1:]
|
||||
|
||||
param_strs = []
|
||||
for param in params:
|
||||
param_str = param.name
|
||||
if param.annotation != inspect.Parameter.empty:
|
||||
# 获取类型注解的字符串表示
|
||||
if isinstance(param.annotation, type):
|
||||
type_str = param.annotation.__name__
|
||||
else:
|
||||
type_str = str(param.annotation)
|
||||
param_str += f": {type_str}"
|
||||
if param.default != inspect.Parameter.empty:
|
||||
param_str += f" = {param.default!r}"
|
||||
param_strs.append(param_str)
|
||||
|
||||
handler_param_str = (
|
||||
", ".join(param_strs) if param_strs else "(no additional parameters)"
|
||||
)
|
||||
except Exception:
|
||||
handler_param_str = "(unable to inspect signature)"
|
||||
|
||||
raise Exception(
|
||||
f"Tool handler parameter mismatch, please check the handler definition. Handler parameters: {handler_param_str}"
|
||||
) from e
|
||||
logger.error(f"调用本地 LLM 工具时出错: {e}", exc_info=True)
|
||||
except TypeError:
|
||||
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
|
||||
except Exception as e:
|
||||
trace_ = traceback.format_exc()
|
||||
raise Exception(f"Tool execution error: {e}. Traceback: {trace_}") from e
|
||||
logger.error(f"调用本地 LLM 工具时出错: {e}\n{trace_}")
|
||||
|
||||
if not ready_to_call:
|
||||
return
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
"""AstrBot 备份与恢复模块
|
||||
|
||||
提供数据导出和导入功能,支持用户在服务器迁移时一键备份和恢复所有数据。
|
||||
"""
|
||||
|
||||
# 从 constants 模块导入共享常量
|
||||
from .constants import (
|
||||
BACKUP_MANIFEST_VERSION,
|
||||
KB_METADATA_MODELS,
|
||||
MAIN_DB_MODELS,
|
||||
get_backup_directories,
|
||||
)
|
||||
|
||||
# 导入导出器和导入器
|
||||
from .exporter import AstrBotExporter
|
||||
from .importer import AstrBotImporter, ImportPreCheckResult
|
||||
|
||||
__all__ = [
|
||||
"AstrBotExporter",
|
||||
"AstrBotImporter",
|
||||
"ImportPreCheckResult",
|
||||
"MAIN_DB_MODELS",
|
||||
"KB_METADATA_MODELS",
|
||||
"get_backup_directories",
|
||||
"BACKUP_MANIFEST_VERSION",
|
||||
]
|
||||
@@ -1,77 +0,0 @@
|
||||
"""AstrBot 备份模块共享常量
|
||||
|
||||
此文件定义了导出器和导入器共享的常量,确保两端配置一致。
|
||||
"""
|
||||
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from astrbot.core.db.po import (
|
||||
Attachment,
|
||||
CommandConfig,
|
||||
CommandConflict,
|
||||
ConversationV2,
|
||||
Persona,
|
||||
PlatformMessageHistory,
|
||||
PlatformSession,
|
||||
PlatformStat,
|
||||
Preference,
|
||||
)
|
||||
from astrbot.core.knowledge_base.models import (
|
||||
KBDocument,
|
||||
KBMedia,
|
||||
KnowledgeBase,
|
||||
)
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_config_path,
|
||||
get_astrbot_plugin_data_path,
|
||||
get_astrbot_plugin_path,
|
||||
get_astrbot_t2i_templates_path,
|
||||
get_astrbot_temp_path,
|
||||
get_astrbot_webchat_path,
|
||||
)
|
||||
|
||||
# ============================================================
|
||||
# 共享常量 - 确保导出和导入端配置一致
|
||||
# ============================================================
|
||||
|
||||
# 主数据库模型类映射
|
||||
MAIN_DB_MODELS: dict[str, type[SQLModel]] = {
|
||||
"platform_stats": PlatformStat,
|
||||
"conversations": ConversationV2,
|
||||
"personas": Persona,
|
||||
"preferences": Preference,
|
||||
"platform_message_history": PlatformMessageHistory,
|
||||
"platform_sessions": PlatformSession,
|
||||
"attachments": Attachment,
|
||||
"command_configs": CommandConfig,
|
||||
"command_conflicts": CommandConflict,
|
||||
}
|
||||
|
||||
# 知识库元数据模型类映射
|
||||
KB_METADATA_MODELS: dict[str, type[SQLModel]] = {
|
||||
"knowledge_bases": KnowledgeBase,
|
||||
"kb_documents": KBDocument,
|
||||
"kb_media": KBMedia,
|
||||
}
|
||||
|
||||
|
||||
def get_backup_directories() -> dict[str, str]:
|
||||
"""获取需要备份的目录列表
|
||||
|
||||
使用 astrbot_path 模块动态获取路径,支持通过环境变量 ASTRBOT_ROOT 自定义根目录。
|
||||
|
||||
Returns:
|
||||
dict: 键为备份文件中的目录名称,值为目录的绝对路径
|
||||
"""
|
||||
return {
|
||||
"plugins": get_astrbot_plugin_path(), # 插件本体
|
||||
"plugin_data": get_astrbot_plugin_data_path(), # 插件数据
|
||||
"config": get_astrbot_config_path(), # 配置目录
|
||||
"t2i_templates": get_astrbot_t2i_templates_path(), # T2I 模板
|
||||
"webchat": get_astrbot_webchat_path(), # WebChat 数据
|
||||
"temp": get_astrbot_temp_path(), # 临时文件
|
||||
}
|
||||
|
||||
|
||||
# 备份清单版本号
|
||||
BACKUP_MANIFEST_VERSION = "1.1"
|
||||
@@ -1,477 +0,0 @@
|
||||
"""AstrBot 数据导出器
|
||||
|
||||
负责将所有数据导出为 ZIP 备份文件。
|
||||
导出格式为 JSON,这是数据库无关的方案,支持未来向 MySQL/PostgreSQL 迁移。
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import zipfile
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_backups_path,
|
||||
get_astrbot_data_path,
|
||||
)
|
||||
|
||||
# 从共享常量模块导入
|
||||
from .constants import (
|
||||
BACKUP_MANIFEST_VERSION,
|
||||
KB_METADATA_MODELS,
|
||||
MAIN_DB_MODELS,
|
||||
get_backup_directories,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
||||
|
||||
CMD_CONFIG_FILE_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
|
||||
|
||||
|
||||
class AstrBotExporter:
|
||||
"""AstrBot 数据导出器
|
||||
|
||||
导出内容:
|
||||
- 主数据库所有表(data/data_v4.db)
|
||||
- 知识库元数据(data/knowledge_base/kb.db)
|
||||
- 每个知识库的向量文档数据
|
||||
- 配置文件(data/cmd_config.json)
|
||||
- 附件文件
|
||||
- 知识库多媒体文件
|
||||
- 插件目录(data/plugins)
|
||||
- 插件数据目录(data/plugin_data)
|
||||
- 配置目录(data/config)
|
||||
- T2I 模板目录(data/t2i_templates)
|
||||
- WebChat 数据目录(data/webchat)
|
||||
- 临时文件目录(data/temp)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
main_db: BaseDatabase,
|
||||
kb_manager: "KnowledgeBaseManager | None" = None,
|
||||
config_path: str = CMD_CONFIG_FILE_PATH,
|
||||
):
|
||||
self.main_db = main_db
|
||||
self.kb_manager = kb_manager
|
||||
self.config_path = config_path
|
||||
self._checksums: dict[str, str] = {}
|
||||
|
||||
async def export_all(
|
||||
self,
|
||||
output_dir: str | None = None,
|
||||
progress_callback: Any | None = None,
|
||||
) -> str:
|
||||
"""导出所有数据到 ZIP 文件
|
||||
|
||||
Args:
|
||||
output_dir: 输出目录
|
||||
progress_callback: 进度回调函数,接收参数 (stage, current, total, message)
|
||||
|
||||
Returns:
|
||||
str: 生成的 ZIP 文件路径
|
||||
"""
|
||||
if output_dir is None:
|
||||
output_dir = get_astrbot_backups_path()
|
||||
|
||||
# 确保输出目录存在
|
||||
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
zip_filename = f"astrbot_backup_{timestamp}.zip"
|
||||
zip_path = os.path.join(output_dir, zip_filename)
|
||||
|
||||
logger.info(f"开始导出备份到 {zip_path}")
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
# 1. 导出主数据库
|
||||
if progress_callback:
|
||||
await progress_callback("main_db", 0, 100, "正在导出主数据库...")
|
||||
main_data = await self._export_main_database()
|
||||
main_db_json = json.dumps(
|
||||
main_data, ensure_ascii=False, indent=2, default=str
|
||||
)
|
||||
zf.writestr("databases/main_db.json", main_db_json)
|
||||
self._add_checksum("databases/main_db.json", main_db_json)
|
||||
if progress_callback:
|
||||
await progress_callback("main_db", 100, 100, "主数据库导出完成")
|
||||
|
||||
# 2. 导出知识库数据
|
||||
kb_meta_data: dict[str, Any] = {
|
||||
"knowledge_bases": [],
|
||||
"kb_documents": [],
|
||||
"kb_media": [],
|
||||
}
|
||||
if self.kb_manager:
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"kb_metadata", 0, 100, "正在导出知识库元数据..."
|
||||
)
|
||||
kb_meta_data = await self._export_kb_metadata()
|
||||
kb_meta_json = json.dumps(
|
||||
kb_meta_data, ensure_ascii=False, indent=2, default=str
|
||||
)
|
||||
zf.writestr("databases/kb_metadata.json", kb_meta_json)
|
||||
self._add_checksum("databases/kb_metadata.json", kb_meta_json)
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"kb_metadata", 100, 100, "知识库元数据导出完成"
|
||||
)
|
||||
|
||||
# 导出每个知识库的文档数据
|
||||
kb_insts = self.kb_manager.kb_insts
|
||||
total_kbs = len(kb_insts)
|
||||
for idx, (kb_id, kb_helper) in enumerate(kb_insts.items()):
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"kb_documents",
|
||||
idx,
|
||||
total_kbs,
|
||||
f"正在导出知识库 {kb_helper.kb.kb_name} 的文档数据...",
|
||||
)
|
||||
doc_data = await self._export_kb_documents(kb_helper)
|
||||
doc_json = json.dumps(
|
||||
doc_data, ensure_ascii=False, indent=2, default=str
|
||||
)
|
||||
doc_path = f"databases/kb_{kb_id}/documents.json"
|
||||
zf.writestr(doc_path, doc_json)
|
||||
self._add_checksum(doc_path, doc_json)
|
||||
|
||||
# 导出 FAISS 索引文件
|
||||
await self._export_faiss_index(zf, kb_helper, kb_id)
|
||||
|
||||
# 导出知识库多媒体文件
|
||||
await self._export_kb_media_files(zf, kb_helper, kb_id)
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"kb_documents", total_kbs, total_kbs, "知识库文档导出完成"
|
||||
)
|
||||
|
||||
# 3. 导出配置文件
|
||||
if progress_callback:
|
||||
await progress_callback("config", 0, 100, "正在导出配置文件...")
|
||||
if os.path.exists(self.config_path):
|
||||
with open(self.config_path, encoding="utf-8") as f:
|
||||
config_content = f.read()
|
||||
zf.writestr("config/cmd_config.json", config_content)
|
||||
self._add_checksum("config/cmd_config.json", config_content)
|
||||
if progress_callback:
|
||||
await progress_callback("config", 100, 100, "配置文件导出完成")
|
||||
|
||||
# 4. 导出附件文件
|
||||
if progress_callback:
|
||||
await progress_callback("attachments", 0, 100, "正在导出附件...")
|
||||
await self._export_attachments(zf, main_data.get("attachments", []))
|
||||
if progress_callback:
|
||||
await progress_callback("attachments", 100, 100, "附件导出完成")
|
||||
|
||||
# 5. 导出插件和其他目录
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"directories", 0, 100, "正在导出插件和数据目录..."
|
||||
)
|
||||
dir_stats = await self._export_directories(zf)
|
||||
if progress_callback:
|
||||
await progress_callback("directories", 100, 100, "目录导出完成")
|
||||
|
||||
# 6. 生成 manifest
|
||||
if progress_callback:
|
||||
await progress_callback("manifest", 0, 100, "正在生成清单...")
|
||||
manifest = self._generate_manifest(main_data, kb_meta_data, dir_stats)
|
||||
manifest_json = json.dumps(manifest, ensure_ascii=False, indent=2)
|
||||
zf.writestr("manifest.json", manifest_json)
|
||||
if progress_callback:
|
||||
await progress_callback("manifest", 100, 100, "清单生成完成")
|
||||
|
||||
logger.info(f"备份导出完成: {zip_path}")
|
||||
return zip_path
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"备份导出失败: {e}")
|
||||
# 清理失败的文件
|
||||
if os.path.exists(zip_path):
|
||||
os.remove(zip_path)
|
||||
raise
|
||||
|
||||
async def _export_main_database(self) -> dict[str, list[dict]]:
|
||||
"""导出主数据库所有表"""
|
||||
export_data: dict[str, list[dict]] = {}
|
||||
|
||||
async with self.main_db.get_db() as session:
|
||||
for table_name, model_class in MAIN_DB_MODELS.items():
|
||||
try:
|
||||
result = await session.execute(select(model_class))
|
||||
records = result.scalars().all()
|
||||
export_data[table_name] = [
|
||||
self._model_to_dict(record) for record in records
|
||||
]
|
||||
logger.debug(
|
||||
f"导出表 {table_name}: {len(export_data[table_name])} 条记录"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"导出表 {table_name} 失败: {e}")
|
||||
export_data[table_name] = []
|
||||
|
||||
return export_data
|
||||
|
||||
async def _export_kb_metadata(self) -> dict[str, list[dict]]:
|
||||
"""导出知识库元数据库"""
|
||||
if not self.kb_manager:
|
||||
return {"knowledge_bases": [], "kb_documents": [], "kb_media": []}
|
||||
|
||||
export_data: dict[str, list[dict]] = {}
|
||||
|
||||
async with self.kb_manager.kb_db.get_db() as session:
|
||||
for table_name, model_class in KB_METADATA_MODELS.items():
|
||||
try:
|
||||
result = await session.execute(select(model_class))
|
||||
records = result.scalars().all()
|
||||
export_data[table_name] = [
|
||||
self._model_to_dict(record) for record in records
|
||||
]
|
||||
logger.debug(
|
||||
f"导出知识库表 {table_name}: {len(export_data[table_name])} 条记录"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"导出知识库表 {table_name} 失败: {e}")
|
||||
export_data[table_name] = []
|
||||
|
||||
return export_data
|
||||
|
||||
async def _export_kb_documents(self, kb_helper: Any) -> dict[str, Any]:
|
||||
"""导出知识库的文档块数据"""
|
||||
try:
|
||||
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
|
||||
|
||||
vec_db: FaissVecDB = kb_helper.vec_db
|
||||
if not vec_db or not vec_db.document_storage:
|
||||
return {"documents": []}
|
||||
|
||||
# 获取所有文档
|
||||
docs = await vec_db.document_storage.get_documents(
|
||||
metadata_filters={},
|
||||
offset=0,
|
||||
limit=None, # 获取全部
|
||||
)
|
||||
|
||||
return {"documents": docs}
|
||||
except Exception as e:
|
||||
logger.warning(f"导出知识库文档失败: {e}")
|
||||
return {"documents": []}
|
||||
|
||||
async def _export_faiss_index(
|
||||
self,
|
||||
zf: zipfile.ZipFile,
|
||||
kb_helper: Any,
|
||||
kb_id: str,
|
||||
) -> None:
|
||||
"""导出 FAISS 索引文件"""
|
||||
try:
|
||||
index_path = kb_helper.kb_dir / "index.faiss"
|
||||
if index_path.exists():
|
||||
archive_path = f"databases/kb_{kb_id}/index.faiss"
|
||||
zf.write(str(index_path), archive_path)
|
||||
logger.debug(f"导出 FAISS 索引: {archive_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"导出 FAISS 索引失败: {e}")
|
||||
|
||||
async def _export_kb_media_files(
|
||||
self, zf: zipfile.ZipFile, kb_helper: Any, kb_id: str
|
||||
) -> None:
|
||||
"""导出知识库的多媒体文件"""
|
||||
try:
|
||||
media_dir = kb_helper.kb_medias_dir
|
||||
if not media_dir.exists():
|
||||
return
|
||||
|
||||
for root, _, files in os.walk(media_dir):
|
||||
for file in files:
|
||||
file_path = Path(root) / file
|
||||
# 计算相对路径
|
||||
rel_path = file_path.relative_to(kb_helper.kb_dir)
|
||||
archive_path = f"files/kb_media/{kb_id}/{rel_path}"
|
||||
zf.write(str(file_path), archive_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"导出知识库媒体文件失败: {e}")
|
||||
|
||||
async def _export_directories(
|
||||
self, zf: zipfile.ZipFile
|
||||
) -> dict[str, dict[str, int]]:
|
||||
"""导出插件和其他数据目录
|
||||
|
||||
Returns:
|
||||
dict: 每个目录的统计信息 {dir_name: {"files": count, "size": bytes}}
|
||||
"""
|
||||
stats: dict[str, dict[str, int]] = {}
|
||||
backup_directories = get_backup_directories()
|
||||
|
||||
for dir_name, dir_path in backup_directories.items():
|
||||
full_path = Path(dir_path)
|
||||
if not full_path.exists():
|
||||
logger.debug(f"目录不存在,跳过: {full_path}")
|
||||
continue
|
||||
|
||||
file_count = 0
|
||||
total_size = 0
|
||||
|
||||
try:
|
||||
for root, dirs, files in os.walk(full_path):
|
||||
# 跳过 __pycache__ 目录
|
||||
dirs[:] = [d for d in dirs if d != "__pycache__"]
|
||||
|
||||
for file in files:
|
||||
# 跳过 .pyc 文件
|
||||
if file.endswith(".pyc"):
|
||||
continue
|
||||
|
||||
file_path = Path(root) / file
|
||||
try:
|
||||
# 计算相对路径
|
||||
rel_path = file_path.relative_to(full_path)
|
||||
archive_path = f"directories/{dir_name}/{rel_path}"
|
||||
zf.write(str(file_path), archive_path)
|
||||
file_count += 1
|
||||
total_size += file_path.stat().st_size
|
||||
except Exception as e:
|
||||
logger.warning(f"导出文件 {file_path} 失败: {e}")
|
||||
|
||||
stats[dir_name] = {"files": file_count, "size": total_size}
|
||||
logger.debug(
|
||||
f"导出目录 {dir_name}: {file_count} 个文件, {total_size} 字节"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"导出目录 {dir_path} 失败: {e}")
|
||||
stats[dir_name] = {"files": 0, "size": 0}
|
||||
|
||||
return stats
|
||||
|
||||
async def _export_attachments(
|
||||
self, zf: zipfile.ZipFile, attachments: list[dict]
|
||||
) -> None:
|
||||
"""导出附件文件"""
|
||||
for attachment in attachments:
|
||||
try:
|
||||
file_path = attachment.get("path", "")
|
||||
if file_path and os.path.exists(file_path):
|
||||
# 使用 attachment_id 作为文件名
|
||||
attachment_id = attachment.get("attachment_id", "")
|
||||
ext = os.path.splitext(file_path)[1]
|
||||
archive_path = f"files/attachments/{attachment_id}{ext}"
|
||||
zf.write(file_path, archive_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"导出附件失败: {e}")
|
||||
|
||||
def _model_to_dict(self, record: Any) -> dict:
|
||||
"""将 SQLModel 实例转换为字典
|
||||
|
||||
这是数据库无关的序列化方式,支持未来迁移到其他数据库。
|
||||
"""
|
||||
# 使用 SQLModel 内置的 model_dump 方法(如果可用)
|
||||
if hasattr(record, "model_dump"):
|
||||
data = record.model_dump(mode="python")
|
||||
# 处理 datetime 类型
|
||||
for key, value in data.items():
|
||||
if isinstance(value, datetime):
|
||||
data[key] = value.isoformat()
|
||||
return data
|
||||
|
||||
# 回退到手动提取
|
||||
data = {}
|
||||
# 使用 inspect 获取表信息
|
||||
from sqlalchemy import inspect as sa_inspect
|
||||
|
||||
mapper = sa_inspect(record.__class__)
|
||||
for column in mapper.columns:
|
||||
value = getattr(record, column.name)
|
||||
# 处理 datetime 类型 - 统一转为 ISO 格式字符串
|
||||
if isinstance(value, datetime):
|
||||
value = value.isoformat()
|
||||
data[column.name] = value
|
||||
return data
|
||||
|
||||
def _add_checksum(self, path: str, content: str | bytes) -> None:
|
||||
"""计算并添加文件校验和"""
|
||||
if isinstance(content, str):
|
||||
content = content.encode("utf-8")
|
||||
checksum = hashlib.sha256(content).hexdigest()
|
||||
self._checksums[path] = f"sha256:{checksum}"
|
||||
|
||||
def _generate_manifest(
|
||||
self,
|
||||
main_data: dict[str, list[dict]],
|
||||
kb_meta_data: dict[str, list[dict]],
|
||||
dir_stats: dict[str, dict[str, int]] | None = None,
|
||||
) -> dict:
|
||||
"""生成备份清单"""
|
||||
if dir_stats is None:
|
||||
dir_stats = {}
|
||||
# 收集知识库 ID
|
||||
kb_document_tables = {}
|
||||
if self.kb_manager:
|
||||
for kb_id in self.kb_manager.kb_insts.keys():
|
||||
kb_document_tables[kb_id] = "documents"
|
||||
|
||||
# 收集附件文件列表
|
||||
attachment_files = []
|
||||
for attachment in main_data.get("attachments", []):
|
||||
attachment_id = attachment.get("attachment_id", "")
|
||||
path = attachment.get("path", "")
|
||||
if attachment_id and path:
|
||||
ext = os.path.splitext(path)[1]
|
||||
attachment_files.append(f"{attachment_id}{ext}")
|
||||
|
||||
# 收集知识库媒体文件
|
||||
kb_media_files: dict[str, list[str]] = {}
|
||||
if self.kb_manager:
|
||||
for kb_id, kb_helper in self.kb_manager.kb_insts.items():
|
||||
media_files: list[str] = []
|
||||
media_dir = kb_helper.kb_medias_dir
|
||||
if media_dir.exists():
|
||||
for root, _, files in os.walk(media_dir):
|
||||
for file in files:
|
||||
media_files.append(file)
|
||||
if media_files:
|
||||
kb_media_files[kb_id] = media_files
|
||||
|
||||
manifest = {
|
||||
"version": BACKUP_MANIFEST_VERSION,
|
||||
"astrbot_version": VERSION,
|
||||
"exported_at": datetime.now(timezone.utc).isoformat(),
|
||||
"origin": "exported", # 标记备份来源:exported=本实例导出, uploaded=用户上传
|
||||
"schema_version": {
|
||||
"main_db": "v4",
|
||||
"kb_db": "v1",
|
||||
},
|
||||
"tables": {
|
||||
"main_db": list(main_data.keys()),
|
||||
"kb_metadata": list(kb_meta_data.keys()),
|
||||
"kb_documents": kb_document_tables,
|
||||
},
|
||||
"files": {
|
||||
"attachments": attachment_files,
|
||||
"kb_media": kb_media_files,
|
||||
},
|
||||
"directories": list(dir_stats.keys()),
|
||||
"checksums": self._checksums,
|
||||
"statistics": {
|
||||
"main_db": {
|
||||
table: len(records) for table, records in main_data.items()
|
||||
},
|
||||
"kb_metadata": {
|
||||
table: len(records) for table, records in kb_meta_data.items()
|
||||
},
|
||||
"directories": dir_stats,
|
||||
},
|
||||
}
|
||||
|
||||
return manifest
|
||||
@@ -1,761 +0,0 @@
|
||||
"""AstrBot 数据导入器
|
||||
|
||||
负责从 ZIP 备份文件恢复所有数据。
|
||||
导入时进行版本校验:
|
||||
- 主版本(前两位)不同时直接拒绝导入
|
||||
- 小版本(第三位)不同时提示警告,用户可选择强制导入
|
||||
- 版本匹配时也需要用户确认
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import zipfile
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from sqlalchemy import delete
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_data_path,
|
||||
get_astrbot_knowledge_base_path,
|
||||
)
|
||||
from astrbot.core.utils.version_comparator import VersionComparator
|
||||
|
||||
# 从共享常量模块导入
|
||||
from .constants import (
|
||||
KB_METADATA_MODELS,
|
||||
MAIN_DB_MODELS,
|
||||
get_backup_directories,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
||||
|
||||
|
||||
def _get_major_version(version_str: str) -> str:
|
||||
"""提取版本的主版本部分(前两位)
|
||||
|
||||
Args:
|
||||
version_str: 版本字符串,如 "4.9.1", "4.10.0-beta"
|
||||
|
||||
Returns:
|
||||
主版本字符串,如 "4.9", "4.10"
|
||||
"""
|
||||
if not version_str:
|
||||
return "0.0"
|
||||
# 移除 v 前缀和预发布标签
|
||||
version = version_str.lower().replace("v", "").split("-")[0].split("+")[0]
|
||||
parts = [p for p in version.split(".") if p] # 过滤空字符串
|
||||
if len(parts) >= 2:
|
||||
return f"{parts[0]}.{parts[1]}"
|
||||
elif len(parts) == 1 and parts[0]:
|
||||
return f"{parts[0]}.0"
|
||||
return "0.0"
|
||||
|
||||
|
||||
CMD_CONFIG_FILE_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
|
||||
KB_PATH = get_astrbot_knowledge_base_path()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImportPreCheckResult:
|
||||
"""导入预检查结果
|
||||
|
||||
用于在实际导入前检查备份文件的版本兼容性,
|
||||
并返回确认信息让用户决定是否继续导入。
|
||||
"""
|
||||
|
||||
# 检查是否通过(文件有效且版本可导入)
|
||||
valid: bool = False
|
||||
# 是否可以导入(版本兼容)
|
||||
can_import: bool = False
|
||||
# 版本状态: match(完全匹配), minor_diff(小版本差异), major_diff(主版本不同,拒绝)
|
||||
version_status: str = ""
|
||||
# 备份文件中的 AstrBot 版本
|
||||
backup_version: str = ""
|
||||
# 当前运行的 AstrBot 版本
|
||||
current_version: str = VERSION
|
||||
# 备份创建时间
|
||||
backup_time: str = ""
|
||||
# 确认消息(显示给用户)
|
||||
confirm_message: str = ""
|
||||
# 警告消息列表
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
# 错误消息(如果检查失败)
|
||||
error: str = ""
|
||||
# 备份包含的内容摘要
|
||||
backup_summary: dict = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"valid": self.valid,
|
||||
"can_import": self.can_import,
|
||||
"version_status": self.version_status,
|
||||
"backup_version": self.backup_version,
|
||||
"current_version": self.current_version,
|
||||
"backup_time": self.backup_time,
|
||||
"confirm_message": self.confirm_message,
|
||||
"warnings": self.warnings,
|
||||
"error": self.error,
|
||||
"backup_summary": self.backup_summary,
|
||||
}
|
||||
|
||||
|
||||
class ImportResult:
|
||||
"""导入结果"""
|
||||
|
||||
def __init__(self):
|
||||
self.success = True
|
||||
self.imported_tables: dict[str, int] = {}
|
||||
self.imported_files: dict[str, int] = {}
|
||||
self.imported_directories: dict[str, int] = {}
|
||||
self.warnings: list[str] = []
|
||||
self.errors: list[str] = []
|
||||
|
||||
def add_warning(self, msg: str) -> None:
|
||||
self.warnings.append(msg)
|
||||
logger.warning(msg)
|
||||
|
||||
def add_error(self, msg: str) -> None:
|
||||
self.errors.append(msg)
|
||||
self.success = False
|
||||
logger.error(msg)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"success": self.success,
|
||||
"imported_tables": self.imported_tables,
|
||||
"imported_files": self.imported_files,
|
||||
"imported_directories": self.imported_directories,
|
||||
"warnings": self.warnings,
|
||||
"errors": self.errors,
|
||||
}
|
||||
|
||||
|
||||
class AstrBotImporter:
|
||||
"""AstrBot 数据导入器
|
||||
|
||||
导入备份文件中的所有数据,包括:
|
||||
- 主数据库所有表
|
||||
- 知识库元数据和文档
|
||||
- 配置文件
|
||||
- 附件文件
|
||||
- 知识库多媒体文件
|
||||
- 插件目录(data/plugins)
|
||||
- 插件数据目录(data/plugin_data)
|
||||
- 配置目录(data/config)
|
||||
- T2I 模板目录(data/t2i_templates)
|
||||
- WebChat 数据目录(data/webchat)
|
||||
- 临时文件目录(data/temp)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
main_db: BaseDatabase,
|
||||
kb_manager: "KnowledgeBaseManager | None" = None,
|
||||
config_path: str = CMD_CONFIG_FILE_PATH,
|
||||
kb_root_dir: str = KB_PATH,
|
||||
):
|
||||
self.main_db = main_db
|
||||
self.kb_manager = kb_manager
|
||||
self.config_path = config_path
|
||||
self.kb_root_dir = kb_root_dir
|
||||
|
||||
def pre_check(self, zip_path: str) -> ImportPreCheckResult:
|
||||
"""预检查备份文件
|
||||
|
||||
在实际导入前检查备份文件的有效性和版本兼容性。
|
||||
返回检查结果供前端显示确认对话框。
|
||||
|
||||
Args:
|
||||
zip_path: ZIP 备份文件路径
|
||||
|
||||
Returns:
|
||||
ImportPreCheckResult: 预检查结果
|
||||
"""
|
||||
result = ImportPreCheckResult()
|
||||
result.current_version = VERSION
|
||||
|
||||
if not os.path.exists(zip_path):
|
||||
result.error = f"备份文件不存在: {zip_path}"
|
||||
return result
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
# 读取 manifest
|
||||
try:
|
||||
manifest_data = zf.read("manifest.json")
|
||||
manifest = json.loads(manifest_data)
|
||||
except KeyError:
|
||||
result.error = "备份文件缺少 manifest.json,不是有效的 AstrBot 备份"
|
||||
return result
|
||||
except json.JSONDecodeError as e:
|
||||
result.error = f"manifest.json 格式错误: {e}"
|
||||
return result
|
||||
|
||||
# 提取基本信息
|
||||
result.backup_version = manifest.get("astrbot_version", "未知")
|
||||
result.backup_time = manifest.get("exported_at", "未知")
|
||||
result.valid = True
|
||||
|
||||
# 构建备份摘要
|
||||
result.backup_summary = {
|
||||
"tables": list(manifest.get("tables", {}).keys()),
|
||||
"has_knowledge_bases": manifest.get("has_knowledge_bases", False),
|
||||
"has_config": manifest.get("has_config", False),
|
||||
"directories": manifest.get("directories", []),
|
||||
}
|
||||
|
||||
# 检查版本兼容性
|
||||
version_check = self._check_version_compatibility(result.backup_version)
|
||||
result.version_status = version_check["status"]
|
||||
result.can_import = version_check["can_import"]
|
||||
|
||||
# 版本信息由前端根据 version_status 和 i18n 生成显示
|
||||
# 不再将版本消息添加到 warnings 列表中,避免中文硬编码
|
||||
# warnings 列表保留用于其他非版本相关的警告
|
||||
|
||||
return result
|
||||
|
||||
except zipfile.BadZipFile:
|
||||
result.error = "无效的 ZIP 文件"
|
||||
return result
|
||||
except Exception as e:
|
||||
result.error = f"检查备份文件失败: {e}"
|
||||
return result
|
||||
|
||||
def _check_version_compatibility(self, backup_version: str) -> dict:
|
||||
"""检查版本兼容性
|
||||
|
||||
规则:
|
||||
- 主版本(前两位,如 4.9)必须一致,否则拒绝
|
||||
- 小版本(第三位,如 4.9.1 vs 4.9.2)不同时,警告但允许导入
|
||||
|
||||
Returns:
|
||||
dict: {status, can_import, message}
|
||||
"""
|
||||
if not backup_version:
|
||||
return {
|
||||
"status": "major_diff",
|
||||
"can_import": False,
|
||||
"message": "备份文件缺少版本信息",
|
||||
}
|
||||
|
||||
# 提取主版本(前两位)进行比较
|
||||
backup_major = _get_major_version(backup_version)
|
||||
current_major = _get_major_version(VERSION)
|
||||
|
||||
# 比较主版本
|
||||
if VersionComparator.compare_version(backup_major, current_major) != 0:
|
||||
return {
|
||||
"status": "major_diff",
|
||||
"can_import": False,
|
||||
"message": (
|
||||
f"主版本不兼容: 备份版本 {backup_version}, 当前版本 {VERSION}。"
|
||||
f"跨主版本导入可能导致数据损坏,请使用相同主版本的 AstrBot。"
|
||||
),
|
||||
}
|
||||
|
||||
# 比较完整版本
|
||||
version_cmp = VersionComparator.compare_version(backup_version, VERSION)
|
||||
if version_cmp != 0:
|
||||
return {
|
||||
"status": "minor_diff",
|
||||
"can_import": True,
|
||||
"message": (
|
||||
f"小版本差异: 备份版本 {backup_version}, 当前版本 {VERSION}。"
|
||||
),
|
||||
}
|
||||
|
||||
return {
|
||||
"status": "match",
|
||||
"can_import": True,
|
||||
"message": "版本匹配",
|
||||
}
|
||||
|
||||
async def import_all(
|
||||
self,
|
||||
zip_path: str,
|
||||
mode: str = "replace", # "replace" 清空后导入
|
||||
progress_callback: Any | None = None,
|
||||
) -> ImportResult:
|
||||
"""从 ZIP 文件导入所有数据
|
||||
|
||||
Args:
|
||||
zip_path: ZIP 备份文件路径
|
||||
mode: 导入模式,目前仅支持 "replace"(清空后导入)
|
||||
progress_callback: 进度回调函数,接收参数 (stage, current, total, message)
|
||||
|
||||
Returns:
|
||||
ImportResult: 导入结果
|
||||
"""
|
||||
result = ImportResult()
|
||||
|
||||
if not os.path.exists(zip_path):
|
||||
result.add_error(f"备份文件不存在: {zip_path}")
|
||||
return result
|
||||
|
||||
logger.info(f"开始从 {zip_path} 导入备份")
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
# 1. 读取并验证 manifest
|
||||
if progress_callback:
|
||||
await progress_callback("validate", 0, 100, "正在验证备份文件...")
|
||||
|
||||
try:
|
||||
manifest_data = zf.read("manifest.json")
|
||||
manifest = json.loads(manifest_data)
|
||||
except KeyError:
|
||||
result.add_error("备份文件缺少 manifest.json")
|
||||
return result
|
||||
except json.JSONDecodeError as e:
|
||||
result.add_error(f"manifest.json 格式错误: {e}")
|
||||
return result
|
||||
|
||||
# 版本校验
|
||||
try:
|
||||
self._validate_version(manifest)
|
||||
except ValueError as e:
|
||||
result.add_error(str(e))
|
||||
return result
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("validate", 100, 100, "验证完成")
|
||||
|
||||
# 2. 导入主数据库
|
||||
if progress_callback:
|
||||
await progress_callback("main_db", 0, 100, "正在导入主数据库...")
|
||||
|
||||
try:
|
||||
main_data_content = zf.read("databases/main_db.json")
|
||||
main_data = json.loads(main_data_content)
|
||||
|
||||
if mode == "replace":
|
||||
await self._clear_main_db()
|
||||
|
||||
imported = await self._import_main_database(main_data)
|
||||
result.imported_tables.update(imported)
|
||||
except Exception as e:
|
||||
result.add_error(f"导入主数据库失败: {e}")
|
||||
return result
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("main_db", 100, 100, "主数据库导入完成")
|
||||
|
||||
# 3. 导入知识库
|
||||
if self.kb_manager and "databases/kb_metadata.json" in zf.namelist():
|
||||
if progress_callback:
|
||||
await progress_callback("kb", 0, 100, "正在导入知识库...")
|
||||
|
||||
try:
|
||||
kb_meta_content = zf.read("databases/kb_metadata.json")
|
||||
kb_meta_data = json.loads(kb_meta_content)
|
||||
|
||||
if mode == "replace":
|
||||
await self._clear_kb_data()
|
||||
|
||||
await self._import_knowledge_bases(zf, kb_meta_data, result)
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入知识库失败: {e}")
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("kb", 100, 100, "知识库导入完成")
|
||||
|
||||
# 4. 导入配置文件
|
||||
if progress_callback:
|
||||
await progress_callback("config", 0, 100, "正在导入配置文件...")
|
||||
|
||||
if "config/cmd_config.json" in zf.namelist():
|
||||
try:
|
||||
config_content = zf.read("config/cmd_config.json")
|
||||
# 备份现有配置
|
||||
if os.path.exists(self.config_path):
|
||||
backup_path = f"{self.config_path}.bak"
|
||||
shutil.copy2(self.config_path, backup_path)
|
||||
|
||||
with open(self.config_path, "wb") as f:
|
||||
f.write(config_content)
|
||||
result.imported_files["config"] = 1
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入配置文件失败: {e}")
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("config", 100, 100, "配置文件导入完成")
|
||||
|
||||
# 5. 导入附件文件
|
||||
if progress_callback:
|
||||
await progress_callback("attachments", 0, 100, "正在导入附件...")
|
||||
|
||||
attachment_count = await self._import_attachments(
|
||||
zf, main_data.get("attachments", [])
|
||||
)
|
||||
result.imported_files["attachments"] = attachment_count
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("attachments", 100, 100, "附件导入完成")
|
||||
|
||||
# 6. 导入插件和其他目录
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"directories", 0, 100, "正在导入插件和数据目录..."
|
||||
)
|
||||
|
||||
dir_stats = await self._import_directories(zf, manifest, result)
|
||||
result.imported_directories = dir_stats
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("directories", 100, 100, "目录导入完成")
|
||||
|
||||
logger.info(f"备份导入完成: {result.to_dict()}")
|
||||
return result
|
||||
|
||||
except zipfile.BadZipFile:
|
||||
result.add_error("无效的 ZIP 文件")
|
||||
return result
|
||||
except Exception as e:
|
||||
result.add_error(f"导入失败: {e}")
|
||||
return result
|
||||
|
||||
def _validate_version(self, manifest: dict) -> None:
|
||||
"""验证版本兼容性 - 仅允许相同主版本导入
|
||||
|
||||
注意:此方法仅在 import_all 中调用,用于双重校验。
|
||||
前端应先调用 pre_check 获取详细的版本信息并让用户确认。
|
||||
"""
|
||||
backup_version = manifest.get("astrbot_version")
|
||||
if not backup_version:
|
||||
raise ValueError("备份文件缺少版本信息")
|
||||
|
||||
# 使用新的版本兼容性检查
|
||||
version_check = self._check_version_compatibility(backup_version)
|
||||
|
||||
if version_check["status"] == "major_diff":
|
||||
raise ValueError(version_check["message"])
|
||||
|
||||
# minor_diff 和 match 都允许导入
|
||||
if version_check["status"] == "minor_diff":
|
||||
logger.warning(f"版本差异警告: {version_check['message']}")
|
||||
|
||||
async def _clear_main_db(self) -> None:
|
||||
"""清空主数据库所有表"""
|
||||
async with self.main_db.get_db() as session:
|
||||
async with session.begin():
|
||||
for table_name, model_class in MAIN_DB_MODELS.items():
|
||||
try:
|
||||
await session.execute(delete(model_class))
|
||||
logger.debug(f"已清空表 {table_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"清空表 {table_name} 失败: {e}")
|
||||
|
||||
async def _clear_kb_data(self) -> None:
|
||||
"""清空知识库数据"""
|
||||
if not self.kb_manager:
|
||||
return
|
||||
|
||||
# 清空知识库元数据表
|
||||
async with self.kb_manager.kb_db.get_db() as session:
|
||||
async with session.begin():
|
||||
for table_name, model_class in KB_METADATA_MODELS.items():
|
||||
try:
|
||||
await session.execute(delete(model_class))
|
||||
logger.debug(f"已清空知识库表 {table_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"清空知识库表 {table_name} 失败: {e}")
|
||||
|
||||
# 删除知识库文件目录
|
||||
for kb_id in list(self.kb_manager.kb_insts.keys()):
|
||||
try:
|
||||
kb_helper = self.kb_manager.kb_insts[kb_id]
|
||||
await kb_helper.terminate()
|
||||
if kb_helper.kb_dir.exists():
|
||||
shutil.rmtree(kb_helper.kb_dir)
|
||||
except Exception as e:
|
||||
logger.warning(f"清理知识库 {kb_id} 失败: {e}")
|
||||
|
||||
self.kb_manager.kb_insts.clear()
|
||||
|
||||
async def _import_main_database(
|
||||
self, data: dict[str, list[dict]]
|
||||
) -> dict[str, int]:
|
||||
"""导入主数据库数据"""
|
||||
imported: dict[str, int] = {}
|
||||
|
||||
async with self.main_db.get_db() as session:
|
||||
async with session.begin():
|
||||
for table_name, rows in data.items():
|
||||
model_class = MAIN_DB_MODELS.get(table_name)
|
||||
if not model_class:
|
||||
logger.warning(f"未知的表: {table_name}")
|
||||
continue
|
||||
|
||||
count = 0
|
||||
for row in rows:
|
||||
try:
|
||||
# 转换 datetime 字符串为 datetime 对象
|
||||
row = self._convert_datetime_fields(row, model_class)
|
||||
obj = model_class(**row)
|
||||
session.add(obj)
|
||||
count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"导入记录到 {table_name} 失败: {e}")
|
||||
|
||||
imported[table_name] = count
|
||||
logger.debug(f"导入表 {table_name}: {count} 条记录")
|
||||
|
||||
return imported
|
||||
|
||||
async def _import_knowledge_bases(
|
||||
self,
|
||||
zf: zipfile.ZipFile,
|
||||
kb_meta_data: dict[str, list[dict]],
|
||||
result: ImportResult,
|
||||
) -> None:
|
||||
"""导入知识库数据"""
|
||||
if not self.kb_manager:
|
||||
return
|
||||
|
||||
# 1. 导入知识库元数据
|
||||
async with self.kb_manager.kb_db.get_db() as session:
|
||||
async with session.begin():
|
||||
for table_name, rows in kb_meta_data.items():
|
||||
model_class = KB_METADATA_MODELS.get(table_name)
|
||||
if not model_class:
|
||||
continue
|
||||
|
||||
count = 0
|
||||
for row in rows:
|
||||
try:
|
||||
row = self._convert_datetime_fields(row, model_class)
|
||||
obj = model_class(**row)
|
||||
session.add(obj)
|
||||
count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"导入知识库记录到 {table_name} 失败: {e}")
|
||||
|
||||
result.imported_tables[f"kb_{table_name}"] = count
|
||||
|
||||
# 2. 导入每个知识库的文档和文件
|
||||
for kb_data in kb_meta_data.get("knowledge_bases", []):
|
||||
kb_id = kb_data.get("kb_id")
|
||||
if not kb_id:
|
||||
continue
|
||||
|
||||
# 创建知识库目录
|
||||
kb_dir = Path(self.kb_root_dir) / kb_id
|
||||
kb_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 导入文档数据
|
||||
doc_path = f"databases/kb_{kb_id}/documents.json"
|
||||
if doc_path in zf.namelist():
|
||||
try:
|
||||
doc_content = zf.read(doc_path)
|
||||
doc_data = json.loads(doc_content)
|
||||
|
||||
# 导入到文档存储数据库
|
||||
await self._import_kb_documents(kb_id, doc_data)
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入知识库 {kb_id} 的文档失败: {e}")
|
||||
|
||||
# 导入 FAISS 索引
|
||||
faiss_path = f"databases/kb_{kb_id}/index.faiss"
|
||||
if faiss_path in zf.namelist():
|
||||
try:
|
||||
target_path = kb_dir / "index.faiss"
|
||||
with zf.open(faiss_path) as src, open(target_path, "wb") as dst:
|
||||
dst.write(src.read())
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入知识库 {kb_id} 的 FAISS 索引失败: {e}")
|
||||
|
||||
# 导入媒体文件
|
||||
media_prefix = f"files/kb_media/{kb_id}/"
|
||||
for name in zf.namelist():
|
||||
if name.startswith(media_prefix):
|
||||
try:
|
||||
rel_path = name[len(media_prefix) :]
|
||||
target_path = kb_dir / rel_path
|
||||
target_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with zf.open(name) as src, open(target_path, "wb") as dst:
|
||||
dst.write(src.read())
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入媒体文件 {name} 失败: {e}")
|
||||
|
||||
# 3. 重新加载知识库实例
|
||||
await self.kb_manager.load_kbs()
|
||||
|
||||
async def _import_kb_documents(self, kb_id: str, doc_data: dict) -> None:
|
||||
"""导入知识库文档到向量数据库"""
|
||||
from astrbot.core.db.vec_db.faiss_impl.document_storage import DocumentStorage
|
||||
|
||||
kb_dir = Path(self.kb_root_dir) / kb_id
|
||||
doc_db_path = kb_dir / "doc.db"
|
||||
|
||||
# 初始化文档存储
|
||||
doc_storage = DocumentStorage(str(doc_db_path))
|
||||
await doc_storage.initialize()
|
||||
|
||||
try:
|
||||
documents = doc_data.get("documents", [])
|
||||
for doc in documents:
|
||||
try:
|
||||
await doc_storage.insert_document(
|
||||
doc_id=doc.get("doc_id", ""),
|
||||
text=doc.get("text", ""),
|
||||
metadata=json.loads(doc.get("metadata", "{}")),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"导入文档块失败: {e}")
|
||||
finally:
|
||||
await doc_storage.close()
|
||||
|
||||
async def _import_attachments(
|
||||
self,
|
||||
zf: zipfile.ZipFile,
|
||||
attachments: list[dict],
|
||||
) -> int:
|
||||
"""导入附件文件"""
|
||||
count = 0
|
||||
|
||||
attachments_dir = Path(self.config_path).parent / "attachments"
|
||||
attachments_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
attachment_prefix = "files/attachments/"
|
||||
for name in zf.namelist():
|
||||
if name.startswith(attachment_prefix) and name != attachment_prefix:
|
||||
try:
|
||||
# 从附件记录中找到原始路径
|
||||
attachment_id = os.path.splitext(os.path.basename(name))[0]
|
||||
original_path = None
|
||||
for att in attachments:
|
||||
if att.get("attachment_id") == attachment_id:
|
||||
original_path = att.get("path")
|
||||
break
|
||||
|
||||
if original_path:
|
||||
target_path = Path(original_path)
|
||||
else:
|
||||
target_path = attachments_dir / os.path.basename(name)
|
||||
|
||||
target_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with zf.open(name) as src, open(target_path, "wb") as dst:
|
||||
dst.write(src.read())
|
||||
count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"导入附件 {name} 失败: {e}")
|
||||
|
||||
return count
|
||||
|
||||
async def _import_directories(
|
||||
self,
|
||||
zf: zipfile.ZipFile,
|
||||
manifest: dict,
|
||||
result: ImportResult,
|
||||
) -> dict[str, int]:
|
||||
"""导入插件和其他数据目录
|
||||
|
||||
Args:
|
||||
zf: ZIP 文件对象
|
||||
manifest: 备份清单
|
||||
result: 导入结果对象
|
||||
|
||||
Returns:
|
||||
dict: 每个目录导入的文件数量
|
||||
"""
|
||||
dir_stats: dict[str, int] = {}
|
||||
|
||||
# 检查备份版本是否支持目录备份(需要版本 >= 1.1)
|
||||
backup_version = manifest.get("version", "1.0")
|
||||
if VersionComparator.compare_version(backup_version, "1.1") < 0:
|
||||
logger.info("备份版本不支持目录备份,跳过目录导入")
|
||||
return dir_stats
|
||||
|
||||
backed_up_dirs = manifest.get("directories", [])
|
||||
backup_directories = get_backup_directories()
|
||||
|
||||
for dir_name in backed_up_dirs:
|
||||
if dir_name not in backup_directories:
|
||||
result.add_warning(f"未知的目录类型: {dir_name}")
|
||||
continue
|
||||
|
||||
target_dir = Path(backup_directories[dir_name])
|
||||
archive_prefix = f"directories/{dir_name}/"
|
||||
|
||||
file_count = 0
|
||||
|
||||
try:
|
||||
# 获取该目录下的所有文件
|
||||
dir_files = [
|
||||
name
|
||||
for name in zf.namelist()
|
||||
if name.startswith(archive_prefix) and name != archive_prefix
|
||||
]
|
||||
|
||||
if not dir_files:
|
||||
continue
|
||||
|
||||
# 备份现有目录(如果存在)
|
||||
if target_dir.exists():
|
||||
backup_path = Path(f"{target_dir}.bak")
|
||||
if backup_path.exists():
|
||||
shutil.rmtree(backup_path)
|
||||
shutil.move(str(target_dir), str(backup_path))
|
||||
logger.debug(f"已备份现有目录 {target_dir} 到 {backup_path}")
|
||||
|
||||
# 创建目标目录
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 解压文件
|
||||
for name in dir_files:
|
||||
try:
|
||||
# 计算相对路径
|
||||
rel_path = name[len(archive_prefix) :]
|
||||
if not rel_path: # 跳过目录条目
|
||||
continue
|
||||
|
||||
target_path = target_dir / rel_path
|
||||
target_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with zf.open(name) as src, open(target_path, "wb") as dst:
|
||||
dst.write(src.read())
|
||||
file_count += 1
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入文件 {name} 失败: {e}")
|
||||
|
||||
dir_stats[dir_name] = file_count
|
||||
logger.debug(f"导入目录 {dir_name}: {file_count} 个文件")
|
||||
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入目录 {dir_name} 失败: {e}")
|
||||
dir_stats[dir_name] = 0
|
||||
|
||||
return dir_stats
|
||||
|
||||
def _convert_datetime_fields(self, row: dict, model_class: type) -> dict:
|
||||
"""转换 datetime 字符串字段为 datetime 对象"""
|
||||
result = row.copy()
|
||||
|
||||
# 获取模型的 datetime 字段
|
||||
from sqlalchemy import inspect as sa_inspect
|
||||
|
||||
try:
|
||||
mapper = sa_inspect(model_class)
|
||||
for column in mapper.columns:
|
||||
if column.name in result and result[column.name] is not None:
|
||||
# 检查是否是 datetime 类型的列
|
||||
from sqlalchemy import DateTime
|
||||
|
||||
if isinstance(column.type, DateTime):
|
||||
value = result[column.name]
|
||||
if isinstance(value, str):
|
||||
# 解析 ISO 格式的日期时间字符串
|
||||
result[column.name] = datetime.fromisoformat(value)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return result
|
||||
@@ -24,10 +24,6 @@ class AstrBotConfig(dict):
|
||||
- 如果传入了 schema,将会通过 schema 解析出 default_config,此时传入的 default_config 会被忽略。
|
||||
"""
|
||||
|
||||
config_path: str
|
||||
default_config: dict
|
||||
schema: dict | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config_path: str = ASTRBOT_CONFIG_PATH,
|
||||
@@ -80,8 +76,6 @@ class AstrBotConfig(dict):
|
||||
if v["type"] == "object":
|
||||
conf[k] = {}
|
||||
_parse_schema(v["items"], conf[k])
|
||||
elif v["type"] == "template_list":
|
||||
conf[k] = default
|
||||
else:
|
||||
conf[k] = default
|
||||
|
||||
|
||||
+260
-629
File diff suppressed because it is too large
Load Diff
@@ -1,111 +0,0 @@
|
||||
"""
|
||||
配置元数据国际化工具
|
||||
|
||||
提供配置元数据的国际化键转换功能
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ConfigMetadataI18n:
|
||||
"""配置元数据国际化转换器"""
|
||||
|
||||
@staticmethod
|
||||
def _get_i18n_key(group: str, section: str, field: str, attr: str) -> str:
|
||||
"""
|
||||
生成国际化键
|
||||
|
||||
Args:
|
||||
group: 配置组,如 'ai_group', 'platform_group'
|
||||
section: 配置节,如 'agent_runner', 'general'
|
||||
field: 字段名,如 'enable', 'default_provider'
|
||||
attr: 属性类型,如 'description', 'hint', 'labels'
|
||||
|
||||
Returns:
|
||||
国际化键,格式如: 'ai_group.agent_runner.enable.description'
|
||||
"""
|
||||
if field:
|
||||
return f"{group}.{section}.{field}.{attr}"
|
||||
else:
|
||||
return f"{group}.{section}.{attr}"
|
||||
|
||||
@staticmethod
|
||||
def convert_to_i18n_keys(metadata: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
将配置元数据转换为使用国际化键
|
||||
|
||||
Args:
|
||||
metadata: 原始配置元数据字典
|
||||
|
||||
Returns:
|
||||
使用国际化键的配置元数据字典
|
||||
"""
|
||||
result = {}
|
||||
|
||||
for group_key, group_data in metadata.items():
|
||||
group_result = {
|
||||
"name": f"{group_key}.name",
|
||||
"metadata": {},
|
||||
}
|
||||
|
||||
for section_key, section_data in group_data.get("metadata", {}).items():
|
||||
section_result = {
|
||||
"description": f"{group_key}.{section_key}.description",
|
||||
"type": section_data.get("type"),
|
||||
}
|
||||
|
||||
# 复制其他属性
|
||||
for key in ["items", "condition", "_special", "invisible"]:
|
||||
if key in section_data:
|
||||
section_result[key] = section_data[key]
|
||||
|
||||
# 处理 hint
|
||||
if "hint" in section_data:
|
||||
section_result["hint"] = f"{group_key}.{section_key}.hint"
|
||||
|
||||
# 处理 items 中的字段
|
||||
if "items" in section_data and isinstance(section_data["items"], dict):
|
||||
items_result = {}
|
||||
for field_key, field_data in section_data["items"].items():
|
||||
# 处理嵌套的点号字段名(如 provider_settings.enable)
|
||||
field_name = field_key
|
||||
|
||||
field_result = {}
|
||||
|
||||
# 复制基本属性
|
||||
for attr in [
|
||||
"type",
|
||||
"condition",
|
||||
"_special",
|
||||
"invisible",
|
||||
"options",
|
||||
"slider",
|
||||
]:
|
||||
if attr in field_data:
|
||||
field_result[attr] = field_data[attr]
|
||||
|
||||
# 转换文本属性为国际化键
|
||||
if "description" in field_data:
|
||||
field_result["description"] = (
|
||||
f"{group_key}.{section_key}.{field_name}.description"
|
||||
)
|
||||
|
||||
if "hint" in field_data:
|
||||
field_result["hint"] = (
|
||||
f"{group_key}.{section_key}.{field_name}.hint"
|
||||
)
|
||||
|
||||
if "labels" in field_data:
|
||||
field_result["labels"] = (
|
||||
f"{group_key}.{section_key}.{field_name}.labels"
|
||||
)
|
||||
|
||||
items_result[field_key] = field_result
|
||||
|
||||
section_result["items"] = items_result
|
||||
|
||||
group_result["metadata"][section_key] = section_result
|
||||
|
||||
result[group_key] = group_result
|
||||
|
||||
return result
|
||||
@@ -69,7 +69,6 @@ class ConversationManager:
|
||||
persona_id=conv_v2.persona_id,
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
token_usage=conv_v2.token_usage,
|
||||
)
|
||||
|
||||
async def new_conversation(
|
||||
@@ -257,7 +256,6 @@ class ConversationManager:
|
||||
history: list[dict] | None = None,
|
||||
title: str | None = None,
|
||||
persona_id: str | None = None,
|
||||
token_usage: int | None = None,
|
||||
) -> None:
|
||||
"""更新会话的对话.
|
||||
|
||||
@@ -265,7 +263,6 @@ class ConversationManager:
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||
history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段
|
||||
token_usage (int | None): token 使用量。None 表示不更新
|
||||
|
||||
"""
|
||||
if not conversation_id:
|
||||
@@ -277,7 +274,6 @@ class ConversationManager:
|
||||
title=title,
|
||||
persona_id=persona_id,
|
||||
content=history,
|
||||
token_usage=token_usage,
|
||||
)
|
||||
|
||||
async def update_conversation_title(
|
||||
|
||||
@@ -16,13 +16,15 @@ import time
|
||||
import traceback
|
||||
from asyncio import Queue
|
||||
|
||||
from astrbot.api import logger, sp
|
||||
from astrbot.core import LogBroker
|
||||
from astrbot.core import LogBroker, logger, sp
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.conversation_mgr import ConversationManager
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46
|
||||
from astrbot.core.db.migration.migra_webchat_session import migrate_webchat_session
|
||||
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
||||
from astrbot.core.memory.memory_manager import MemoryManager
|
||||
from astrbot.core.persona_mgr import PersonaManager
|
||||
from astrbot.core.pipeline.scheduler import PipelineContext, PipelineScheduler
|
||||
from astrbot.core.platform.manager import PlatformManager
|
||||
@@ -33,8 +35,6 @@ from astrbot.core.star.context import Context
|
||||
from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map
|
||||
from astrbot.core.umop_config_router import UmopConfigRouter
|
||||
from astrbot.core.updator import AstrBotUpdator
|
||||
from astrbot.core.utils.llm_metadata import update_llm_metadata
|
||||
from astrbot.core.utils.migra_helper import migra
|
||||
|
||||
from . import astrbot_config, html_renderer
|
||||
from .event_bus import EventBus
|
||||
@@ -90,7 +90,6 @@ class AstrBotCoreLifecycle:
|
||||
|
||||
# 初始化 UMOP 配置路由器
|
||||
self.umop_config_router = UmopConfigRouter(sp=sp)
|
||||
await self.umop_config_router.initialize()
|
||||
|
||||
# 初始化 AstrBot 配置管理器
|
||||
self.astrbot_config_mgr = AstrBotConfigManager(
|
||||
@@ -99,16 +98,18 @@ class AstrBotCoreLifecycle:
|
||||
sp=sp,
|
||||
)
|
||||
|
||||
# apply migration
|
||||
# 4.5 to 4.6 migration for umop_config_router
|
||||
try:
|
||||
await migra(
|
||||
self.db,
|
||||
self.astrbot_config_mgr,
|
||||
self.umop_config_router,
|
||||
self.astrbot_config_mgr,
|
||||
)
|
||||
await migrate_45_to_46(self.astrbot_config_mgr, self.umop_config_router)
|
||||
except Exception as e:
|
||||
logger.error(f"AstrBot migration failed: {e!s}")
|
||||
logger.error(f"Migration from version 4.5 to 4.6 failed: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# migration for webchat session
|
||||
try:
|
||||
await migrate_webchat_session(self.db)
|
||||
except Exception as e:
|
||||
logger.error(f"Migration for webchat session failed: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# 初始化事件队列
|
||||
@@ -136,6 +137,8 @@ class AstrBotCoreLifecycle:
|
||||
|
||||
# 初始化知识库管理器
|
||||
self.kb_manager = KnowledgeBaseManager(self.provider_manager)
|
||||
# 初始化记忆管理器
|
||||
self.memory_manager = MemoryManager()
|
||||
|
||||
# 初始化提供给插件的上下文
|
||||
self.star_context = Context(
|
||||
@@ -149,6 +152,7 @@ class AstrBotCoreLifecycle:
|
||||
self.persona_mgr,
|
||||
self.astrbot_config_mgr,
|
||||
self.kb_manager,
|
||||
self.memory_manager,
|
||||
)
|
||||
|
||||
# 初始化插件管理器
|
||||
@@ -187,8 +191,6 @@ class AstrBotCoreLifecycle:
|
||||
# 初始化关闭控制面板的事件
|
||||
self.dashboard_shutdown_event = asyncio.Event()
|
||||
|
||||
asyncio.create_task(update_llm_metadata())
|
||||
|
||||
def _load(self) -> None:
|
||||
"""加载事件总线和任务并初始化."""
|
||||
# 创建一个异步任务来执行事件总线的 dispatch() 方法
|
||||
@@ -201,7 +203,7 @@ class AstrBotCoreLifecycle:
|
||||
# 把插件中注册的所有协程函数注册到事件总线中并执行
|
||||
extra_tasks = []
|
||||
for task in self.star_context._register_tasks:
|
||||
extra_tasks.append(asyncio.create_task(task, name=task.__name__)) # type: ignore
|
||||
extra_tasks.append(asyncio.create_task(task, name=task.__name__))
|
||||
|
||||
tasks_ = [event_bus_task, *extra_tasks]
|
||||
for task in tasks_:
|
||||
|
||||
+4
-105
@@ -5,12 +5,11 @@ from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
|
||||
from deprecated import deprecated
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from astrbot.core.db.po import (
|
||||
Attachment,
|
||||
CommandConfig,
|
||||
CommandConflict,
|
||||
ConversationV2,
|
||||
Persona,
|
||||
PlatformMessageHistory,
|
||||
@@ -33,7 +32,7 @@ class BaseDatabase(abc.ABC):
|
||||
echo=False,
|
||||
future=True,
|
||||
)
|
||||
self.AsyncSessionLocal = async_sessionmaker(
|
||||
self.AsyncSessionLocal = sessionmaker(
|
||||
self.engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
@@ -152,7 +151,6 @@ class BaseDatabase(abc.ABC):
|
||||
title: str | None = None,
|
||||
persona_id: str | None = None,
|
||||
content: list[dict] | None = None,
|
||||
token_usage: int | None = None,
|
||||
) -> None:
|
||||
"""Update a conversation's history."""
|
||||
...
|
||||
@@ -175,7 +173,7 @@ class BaseDatabase(abc.ABC):
|
||||
content: dict,
|
||||
sender_id: str | None = None,
|
||||
sender_name: str | None = None,
|
||||
) -> PlatformMessageHistory:
|
||||
) -> None:
|
||||
"""Insert a new platform message history record."""
|
||||
...
|
||||
|
||||
@@ -200,14 +198,6 @@ class BaseDatabase(abc.ABC):
|
||||
"""Get platform message history for a specific user."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_platform_message_history_by_id(
|
||||
self,
|
||||
message_id: int,
|
||||
) -> PlatformMessageHistory | None:
|
||||
"""Get a platform message history record by its ID."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def insert_attachment(
|
||||
self,
|
||||
@@ -223,27 +213,6 @@ class BaseDatabase(abc.ABC):
|
||||
"""Get an attachment by its ID."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_attachments(self, attachment_ids: list[str]) -> list[Attachment]:
|
||||
"""Get multiple attachments by their IDs."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_attachment(self, attachment_id: str) -> bool:
|
||||
"""Delete an attachment by its ID.
|
||||
|
||||
Returns True if the attachment was deleted, False if it was not found.
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_attachments(self, attachment_ids: list[str]) -> int:
|
||||
"""Delete multiple attachments by their IDs.
|
||||
|
||||
Returns the number of attachments deleted.
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def insert_persona(
|
||||
self,
|
||||
@@ -317,76 +286,6 @@ class BaseDatabase(abc.ABC):
|
||||
"""Clear all preferences for a specific scope ID."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_command_configs(self) -> list[CommandConfig]:
|
||||
"""Get all stored command configurations."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_command_config(self, handler_full_name: str) -> CommandConfig | None:
|
||||
"""Fetch a single command configuration by handler."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def upsert_command_config(
|
||||
self,
|
||||
handler_full_name: str,
|
||||
plugin_name: str,
|
||||
module_path: str,
|
||||
original_command: str,
|
||||
*,
|
||||
resolved_command: str | None = None,
|
||||
enabled: bool | None = None,
|
||||
keep_original_alias: bool | None = None,
|
||||
conflict_key: str | None = None,
|
||||
resolution_strategy: str | None = None,
|
||||
note: str | None = None,
|
||||
extra_data: dict | None = None,
|
||||
auto_managed: bool | None = None,
|
||||
) -> CommandConfig:
|
||||
"""Create or update a command configuration."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_command_config(self, handler_full_name: str) -> None:
|
||||
"""Delete a single command configuration."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_command_configs(self, handler_full_names: list[str]) -> None:
|
||||
"""Bulk delete command configurations."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def list_command_conflicts(
|
||||
self,
|
||||
status: str | None = None,
|
||||
) -> list[CommandConflict]:
|
||||
"""List recorded command conflict entries."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def upsert_command_conflict(
|
||||
self,
|
||||
conflict_key: str,
|
||||
handler_full_name: str,
|
||||
plugin_name: str,
|
||||
*,
|
||||
status: str | None = None,
|
||||
resolution: str | None = None,
|
||||
resolved_command: str | None = None,
|
||||
note: str | None = None,
|
||||
extra_data: dict | None = None,
|
||||
auto_generated: bool | None = None,
|
||||
) -> CommandConflict:
|
||||
"""Create or update a conflict record."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_command_conflicts(self, ids: list[int]) -> None:
|
||||
"""Delete conflict records."""
|
||||
...
|
||||
|
||||
# @abc.abstractmethod
|
||||
# async def insert_llm_message(
|
||||
# self,
|
||||
|
||||
@@ -70,7 +70,6 @@ async def migration_conversation_table(
|
||||
logger.info(
|
||||
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
|
||||
)
|
||||
continue
|
||||
if ":" not in conv.user_id:
|
||||
continue
|
||||
session = MessageSesion.from_str(session_str=conv.user_id)
|
||||
@@ -208,7 +207,6 @@ async def migration_webchat_data(
|
||||
logger.info(
|
||||
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
|
||||
)
|
||||
continue
|
||||
if ":" in conv.user_id:
|
||||
continue
|
||||
platform_id = "webchat"
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
"""Migration script to add token_usage column to conversations table.
|
||||
|
||||
This migration adds the token_usage field to track token consumption for each conversation.
|
||||
|
||||
Changes:
|
||||
- Adds token_usage column to conversations table (default: 0)
|
||||
"""
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from astrbot.api import logger, sp
|
||||
from astrbot.core.db import BaseDatabase
|
||||
|
||||
|
||||
async def migrate_token_usage(db_helper: BaseDatabase):
|
||||
"""Add token_usage column to conversations table.
|
||||
|
||||
This migration adds a new column to track token consumption in conversations.
|
||||
"""
|
||||
# 检查是否已经完成迁移
|
||||
migration_done = await db_helper.get_preference(
|
||||
"global", "global", "migration_done_token_usage_1"
|
||||
)
|
||||
if migration_done:
|
||||
return
|
||||
|
||||
logger.info("开始执行数据库迁移(添加 conversations.token_usage 列)...")
|
||||
|
||||
# 这里只适配了 SQLite。因为截止至这一版本,AstrBot 仅支持 SQLite。
|
||||
|
||||
try:
|
||||
async with db_helper.get_db() as session:
|
||||
# 检查列是否已存在
|
||||
result = await session.execute(text("PRAGMA table_info(conversations)"))
|
||||
columns = result.fetchall()
|
||||
column_names = [col[1] for col in columns]
|
||||
|
||||
if "token_usage" in column_names:
|
||||
logger.info("token_usage 列已存在,跳过迁移")
|
||||
await sp.put_async(
|
||||
"global", "global", "migration_done_token_usage_1", True
|
||||
)
|
||||
return
|
||||
|
||||
# 添加 token_usage 列
|
||||
await session.execute(
|
||||
text(
|
||||
"ALTER TABLE conversations ADD COLUMN token_usage INTEGER NOT NULL DEFAULT 0"
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
logger.info("token_usage 列添加成功")
|
||||
|
||||
# 标记迁移完成
|
||||
await sp.put_async("global", "global", "migration_done_token_usage_1", True)
|
||||
logger.info("token_usage 迁移完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"迁移过程中发生错误: {e}", exc_info=True)
|
||||
raise
|
||||
@@ -25,7 +25,7 @@ async def migrate_webchat_session(db_helper: BaseDatabase):
|
||||
"""
|
||||
# 检查是否已经完成迁移
|
||||
migration_done = await db_helper.get_preference(
|
||||
"global", "global", "migration_done_webchat_session_1"
|
||||
"global", "global", "migration_done_webchat_session"
|
||||
)
|
||||
if migration_done:
|
||||
return
|
||||
@@ -43,7 +43,7 @@ async def migrate_webchat_session(db_helper: BaseDatabase):
|
||||
func.max(PlatformMessageHistory.updated_at).label("latest"),
|
||||
)
|
||||
.where(col(PlatformMessageHistory.platform_id) == "webchat")
|
||||
.where(col(PlatformMessageHistory.sender_id) != "bot")
|
||||
.where(col(PlatformMessageHistory.sender_id) == "astrbot")
|
||||
.group_by(col(PlatformMessageHistory.user_id))
|
||||
)
|
||||
|
||||
@@ -53,7 +53,7 @@ async def migrate_webchat_session(db_helper: BaseDatabase):
|
||||
if not webchat_users:
|
||||
logger.info("没有找到需要迁移的 WebChat 数据")
|
||||
await sp.put_async(
|
||||
"global", "global", "migration_done_webchat_session_1", True
|
||||
"global", "global", "migration_done_webchat_session", True
|
||||
)
|
||||
return
|
||||
|
||||
@@ -124,7 +124,7 @@ async def migrate_webchat_session(db_helper: BaseDatabase):
|
||||
logger.info("没有新会话需要迁移")
|
||||
|
||||
# 标记迁移完成
|
||||
await sp.put_async("global", "global", "migration_done_webchat_session_1", True)
|
||||
await sp.put_async("global", "global", "migration_done_webchat_session", True)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"迁移过程中发生错误: {e}", exc_info=True)
|
||||
|
||||
@@ -127,7 +127,7 @@ class SQLiteDatabase:
|
||||
conn.text_factory = str
|
||||
return conn
|
||||
|
||||
def _exec_sql(self, sql: str, params: tuple | None = None):
|
||||
def _exec_sql(self, sql: str, params: tuple = None):
|
||||
conn = self.conn
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
@@ -224,11 +224,9 @@ class SQLiteDatabase:
|
||||
|
||||
c.close()
|
||||
|
||||
return Stats(platform)
|
||||
return Stats(platform, [], [])
|
||||
|
||||
def get_conversation_by_user_id(
|
||||
self, user_id: str, cid: str
|
||||
) -> Conversation | None:
|
||||
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
@@ -260,7 +258,7 @@ class SQLiteDatabase:
|
||||
(user_id, cid, history, updated_at, created_at),
|
||||
)
|
||||
|
||||
def get_conversations(self, user_id: str) -> list[Conversation]:
|
||||
def get_conversations(self, user_id: str) -> tuple:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
|
||||
+16
-83
@@ -12,7 +12,7 @@ class PlatformStat(SQLModel, table=True):
|
||||
Note: In astrbot v4, we moved `platform` table to here.
|
||||
"""
|
||||
|
||||
__tablename__: str = "platform_stats"
|
||||
__tablename__ = "platform_stats" # type: ignore
|
||||
|
||||
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
|
||||
timestamp: datetime = Field(nullable=False)
|
||||
@@ -31,10 +31,9 @@ class PlatformStat(SQLModel, table=True):
|
||||
|
||||
|
||||
class ConversationV2(SQLModel, table=True):
|
||||
__tablename__: str = "conversations"
|
||||
__tablename__ = "conversations" # type: ignore
|
||||
|
||||
inner_conversation_id: int | None = Field(
|
||||
default=None,
|
||||
inner_conversation_id: int = Field(
|
||||
primary_key=True,
|
||||
sa_column_kwargs={"autoincrement": True},
|
||||
)
|
||||
@@ -54,11 +53,6 @@ class ConversationV2(SQLModel, table=True):
|
||||
)
|
||||
title: str | None = Field(default=None, max_length=255)
|
||||
persona_id: str | None = Field(default=None)
|
||||
token_usage: int = Field(default=0, nullable=False)
|
||||
"""content is a list of OpenAI-formated messages in list[dict] format.
|
||||
token_usage is the total token value of the messages.
|
||||
when 0, will use estimated token counter.
|
||||
"""
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
@@ -74,7 +68,7 @@ class Persona(SQLModel, table=True):
|
||||
It can be used to customize the behavior of LLMs.
|
||||
"""
|
||||
|
||||
__tablename__: str = "personas"
|
||||
__tablename__ = "personas" # type: ignore
|
||||
|
||||
id: int | None = Field(
|
||||
primary_key=True,
|
||||
@@ -104,7 +98,7 @@ class Persona(SQLModel, table=True):
|
||||
class Preference(SQLModel, table=True):
|
||||
"""This class represents preferences for bots."""
|
||||
|
||||
__tablename__: str = "preferences"
|
||||
__tablename__ = "preferences" # type: ignore
|
||||
|
||||
id: int | None = Field(
|
||||
default=None,
|
||||
@@ -140,7 +134,7 @@ class PlatformMessageHistory(SQLModel, table=True):
|
||||
or platform-specific messages.
|
||||
"""
|
||||
|
||||
__tablename__: str = "platform_message_history"
|
||||
__tablename__ = "platform_message_history" # type: ignore
|
||||
|
||||
id: int | None = Field(
|
||||
primary_key=True,
|
||||
@@ -168,7 +162,7 @@ class PlatformSession(SQLModel, table=True):
|
||||
Each session can have multiple conversations (对话) associated with it.
|
||||
"""
|
||||
|
||||
__tablename__: str = "platform_sessions"
|
||||
__tablename__ = "platform_sessions" # type: ignore
|
||||
|
||||
inner_id: int | None = Field(
|
||||
primary_key=True,
|
||||
@@ -179,7 +173,7 @@ class PlatformSession(SQLModel, table=True):
|
||||
max_length=100,
|
||||
nullable=False,
|
||||
unique=True,
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
default_factory=lambda: f"webchat_{uuid.uuid4()}",
|
||||
)
|
||||
platform_id: str = Field(default="webchat", nullable=False)
|
||||
"""Platform identifier (e.g., 'webchat', 'qq', 'discord')"""
|
||||
@@ -209,7 +203,7 @@ class Attachment(SQLModel, table=True):
|
||||
Attachments can be images, files, or other media types.
|
||||
"""
|
||||
|
||||
__tablename__: str = "attachments"
|
||||
__tablename__ = "attachments" # type: ignore
|
||||
|
||||
inner_attachment_id: int | None = Field(
|
||||
primary_key=True,
|
||||
@@ -239,65 +233,6 @@ class Attachment(SQLModel, table=True):
|
||||
)
|
||||
|
||||
|
||||
class CommandConfig(SQLModel, table=True):
|
||||
"""Per-command configuration overrides for dashboard management."""
|
||||
|
||||
__tablename__ = "command_configs" # type: ignore
|
||||
|
||||
handler_full_name: str = Field(
|
||||
primary_key=True,
|
||||
max_length=512,
|
||||
)
|
||||
plugin_name: str = Field(nullable=False, max_length=255)
|
||||
module_path: str = Field(nullable=False, max_length=255)
|
||||
original_command: str = Field(nullable=False, max_length=255)
|
||||
resolved_command: str | None = Field(default=None, max_length=255)
|
||||
enabled: bool = Field(default=True, nullable=False)
|
||||
keep_original_alias: bool = Field(default=False, nullable=False)
|
||||
conflict_key: str | None = Field(default=None, max_length=255)
|
||||
resolution_strategy: str | None = Field(default=None, max_length=64)
|
||||
note: str | None = Field(default=None, sa_type=Text)
|
||||
extra_data: dict | None = Field(default=None, sa_type=JSON)
|
||||
auto_managed: bool = Field(default=False, nullable=False)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
|
||||
class CommandConflict(SQLModel, table=True):
|
||||
"""Conflict tracking for duplicated command names."""
|
||||
|
||||
__tablename__ = "command_conflicts" # type: ignore
|
||||
|
||||
id: int | None = Field(
|
||||
default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
|
||||
)
|
||||
conflict_key: str = Field(nullable=False, max_length=255)
|
||||
handler_full_name: str = Field(nullable=False, max_length=512)
|
||||
plugin_name: str = Field(nullable=False, max_length=255)
|
||||
status: str = Field(default="pending", max_length=32)
|
||||
resolution: str | None = Field(default=None, max_length=64)
|
||||
resolved_command: str | None = Field(default=None, max_length=255)
|
||||
note: str | None = Field(default=None, sa_type=Text)
|
||||
extra_data: dict | None = Field(default=None, sa_type=JSON)
|
||||
auto_generated: bool = Field(default=False, nullable=False)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"conflict_key",
|
||||
"handler_full_name",
|
||||
name="uix_conflict_handler",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Conversation:
|
||||
"""LLM 对话类
|
||||
@@ -318,8 +253,6 @@ class Conversation:
|
||||
persona_id: str | None = ""
|
||||
created_at: int = 0
|
||||
updated_at: int = 0
|
||||
token_usage: int = 0
|
||||
"""对话的总 token 数量。AstrBot 会保留最近一次 LLM 请求返回的总 token 数,方便统计。token_usage 可能为 0,表示未知。"""
|
||||
|
||||
|
||||
class Personality(TypedDict):
|
||||
@@ -328,17 +261,17 @@ class Personality(TypedDict):
|
||||
在 v4.0.0 版本及之后,推荐使用上面的 Persona 类。并且, mood_imitation_dialogs 字段已被废弃。
|
||||
"""
|
||||
|
||||
prompt: str
|
||||
name: str
|
||||
begin_dialogs: list[str]
|
||||
mood_imitation_dialogs: list[str]
|
||||
prompt: str = ""
|
||||
name: str = ""
|
||||
begin_dialogs: list[str] = []
|
||||
mood_imitation_dialogs: list[str] = []
|
||||
"""情感模拟对话预设。在 v4.0.0 版本及之后,已被废弃。"""
|
||||
tools: list[str] | None
|
||||
tools: list[str] | None = None
|
||||
"""工具列表。None 表示使用所有工具,空列表表示不使用任何工具"""
|
||||
|
||||
# cache
|
||||
_begin_dialogs_processed: list[dict]
|
||||
_mood_imitation_dialogs_processed: str
|
||||
_begin_dialogs_processed: list[dict] = []
|
||||
_mood_imitation_dialogs_processed: str = ""
|
||||
|
||||
|
||||
# ====
|
||||
|
||||
+4
-303
@@ -1,18 +1,14 @@
|
||||
import asyncio
|
||||
import threading
|
||||
import typing as T
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy import CursorResult
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import col, delete, desc, func, or_, select, text, update
|
||||
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.po import (
|
||||
Attachment,
|
||||
CommandConfig,
|
||||
CommandConflict,
|
||||
ConversationV2,
|
||||
Persona,
|
||||
PlatformMessageHistory,
|
||||
@@ -29,7 +25,6 @@ from astrbot.core.db.po import (
|
||||
)
|
||||
|
||||
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
|
||||
TxResult = T.TypeVar("TxResult")
|
||||
|
||||
|
||||
class SQLiteDatabase(BaseDatabase):
|
||||
@@ -110,8 +105,8 @@ class SQLiteDatabase(BaseDatabase):
|
||||
text("""
|
||||
SELECT * FROM platform_stats
|
||||
WHERE timestamp >= :start_time
|
||||
GROUP BY platform_id
|
||||
ORDER BY timestamp DESC
|
||||
GROUP BY platform_id
|
||||
"""),
|
||||
{"start_time": start_time},
|
||||
)
|
||||
@@ -241,9 +236,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
session.add(new_conversation)
|
||||
return new_conversation
|
||||
|
||||
async def update_conversation(
|
||||
self, cid, title=None, persona_id=None, content=None, token_usage=None
|
||||
):
|
||||
async def update_conversation(self, cid, title=None, persona_id=None, content=None):
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
@@ -257,8 +250,6 @@ class SQLiteDatabase(BaseDatabase):
|
||||
values["persona_id"] = persona_id
|
||||
if content is not None:
|
||||
values["content"] = content
|
||||
if token_usage is not None:
|
||||
values["token_usage"] = token_usage
|
||||
if not values:
|
||||
return None
|
||||
query = query.values(**values)
|
||||
@@ -458,18 +449,6 @@ class SQLiteDatabase(BaseDatabase):
|
||||
result = await session.execute(query.offset(offset).limit(page_size))
|
||||
return result.scalars().all()
|
||||
|
||||
async def get_platform_message_history_by_id(
|
||||
self, message_id: int
|
||||
) -> PlatformMessageHistory | None:
|
||||
"""Get a platform message history record by its ID."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
query = select(PlatformMessageHistory).where(
|
||||
PlatformMessageHistory.id == message_id
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def insert_attachment(self, path, type, mime_type):
|
||||
"""Insert a new attachment record."""
|
||||
async with self.get_db() as session:
|
||||
@@ -491,48 +470,6 @@ class SQLiteDatabase(BaseDatabase):
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_attachments(self, attachment_ids: list[str]) -> list:
|
||||
"""Get multiple attachments by their IDs."""
|
||||
if not attachment_ids:
|
||||
return []
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
query = select(Attachment).where(
|
||||
col(Attachment.attachment_id).in_(attachment_ids)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def delete_attachment(self, attachment_id: str) -> bool:
|
||||
"""Delete an attachment by its ID.
|
||||
|
||||
Returns True if the attachment was deleted, False if it was not found.
|
||||
"""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
query = delete(Attachment).where(
|
||||
col(Attachment.attachment_id) == attachment_id
|
||||
)
|
||||
result = T.cast(CursorResult, await session.execute(query))
|
||||
return result.rowcount > 0
|
||||
|
||||
async def delete_attachments(self, attachment_ids: list[str]) -> int:
|
||||
"""Delete multiple attachments by their IDs.
|
||||
|
||||
Returns the number of attachments deleted.
|
||||
"""
|
||||
if not attachment_ids:
|
||||
return 0
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
query = delete(Attachment).where(
|
||||
col(Attachment.attachment_id).in_(attachment_ids)
|
||||
)
|
||||
result = T.cast(CursorResult, await session.execute(query))
|
||||
return result.rowcount
|
||||
|
||||
async def insert_persona(
|
||||
self,
|
||||
persona_id,
|
||||
@@ -678,242 +615,6 @@ class SQLiteDatabase(BaseDatabase):
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# ====
|
||||
# Command Configuration & Conflict Tracking
|
||||
# ====
|
||||
|
||||
async def _run_in_tx(
|
||||
self,
|
||||
fn: Callable[[AsyncSession], Awaitable[TxResult]],
|
||||
) -> TxResult:
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
return await fn(session)
|
||||
|
||||
@staticmethod
|
||||
def _apply_updates(model, **updates) -> None:
|
||||
for field, value in updates.items():
|
||||
if value is not None:
|
||||
setattr(model, field, value)
|
||||
|
||||
@staticmethod
|
||||
def _new_command_config(
|
||||
handler_full_name: str,
|
||||
plugin_name: str,
|
||||
module_path: str,
|
||||
original_command: str,
|
||||
*,
|
||||
resolved_command: str | None = None,
|
||||
enabled: bool | None = None,
|
||||
keep_original_alias: bool | None = None,
|
||||
conflict_key: str | None = None,
|
||||
resolution_strategy: str | None = None,
|
||||
note: str | None = None,
|
||||
extra_data: dict | None = None,
|
||||
auto_managed: bool | None = None,
|
||||
) -> CommandConfig:
|
||||
return CommandConfig(
|
||||
handler_full_name=handler_full_name,
|
||||
plugin_name=plugin_name,
|
||||
module_path=module_path,
|
||||
original_command=original_command,
|
||||
resolved_command=resolved_command,
|
||||
enabled=True if enabled is None else enabled,
|
||||
keep_original_alias=False
|
||||
if keep_original_alias is None
|
||||
else keep_original_alias,
|
||||
conflict_key=conflict_key or original_command,
|
||||
resolution_strategy=resolution_strategy,
|
||||
note=note,
|
||||
extra_data=extra_data,
|
||||
auto_managed=bool(auto_managed),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _new_command_conflict(
|
||||
conflict_key: str,
|
||||
handler_full_name: str,
|
||||
plugin_name: str,
|
||||
*,
|
||||
status: str | None = None,
|
||||
resolution: str | None = None,
|
||||
resolved_command: str | None = None,
|
||||
note: str | None = None,
|
||||
extra_data: dict | None = None,
|
||||
auto_generated: bool | None = None,
|
||||
) -> CommandConflict:
|
||||
return CommandConflict(
|
||||
conflict_key=conflict_key,
|
||||
handler_full_name=handler_full_name,
|
||||
plugin_name=plugin_name,
|
||||
status=status or "pending",
|
||||
resolution=resolution,
|
||||
resolved_command=resolved_command,
|
||||
note=note,
|
||||
extra_data=extra_data,
|
||||
auto_generated=bool(auto_generated),
|
||||
)
|
||||
|
||||
async def get_command_configs(self) -> list[CommandConfig]:
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
result = await session.execute(select(CommandConfig))
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_command_config(
|
||||
self,
|
||||
handler_full_name: str,
|
||||
) -> CommandConfig | None:
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
return await session.get(CommandConfig, handler_full_name)
|
||||
|
||||
async def upsert_command_config(
|
||||
self,
|
||||
handler_full_name: str,
|
||||
plugin_name: str,
|
||||
module_path: str,
|
||||
original_command: str,
|
||||
*,
|
||||
resolved_command: str | None = None,
|
||||
enabled: bool | None = None,
|
||||
keep_original_alias: bool | None = None,
|
||||
conflict_key: str | None = None,
|
||||
resolution_strategy: str | None = None,
|
||||
note: str | None = None,
|
||||
extra_data: dict | None = None,
|
||||
auto_managed: bool | None = None,
|
||||
) -> CommandConfig:
|
||||
async def _op(session: AsyncSession) -> CommandConfig:
|
||||
config = await session.get(CommandConfig, handler_full_name)
|
||||
if not config:
|
||||
config = self._new_command_config(
|
||||
handler_full_name,
|
||||
plugin_name,
|
||||
module_path,
|
||||
original_command,
|
||||
resolved_command=resolved_command,
|
||||
enabled=enabled,
|
||||
keep_original_alias=keep_original_alias,
|
||||
conflict_key=conflict_key,
|
||||
resolution_strategy=resolution_strategy,
|
||||
note=note,
|
||||
extra_data=extra_data,
|
||||
auto_managed=auto_managed,
|
||||
)
|
||||
session.add(config)
|
||||
else:
|
||||
self._apply_updates(
|
||||
config,
|
||||
plugin_name=plugin_name,
|
||||
module_path=module_path,
|
||||
original_command=original_command,
|
||||
resolved_command=resolved_command,
|
||||
enabled=enabled,
|
||||
keep_original_alias=keep_original_alias,
|
||||
conflict_key=conflict_key,
|
||||
resolution_strategy=resolution_strategy,
|
||||
note=note,
|
||||
extra_data=extra_data,
|
||||
auto_managed=auto_managed,
|
||||
)
|
||||
await session.flush()
|
||||
await session.refresh(config)
|
||||
return config
|
||||
|
||||
return await self._run_in_tx(_op)
|
||||
|
||||
async def delete_command_config(self, handler_full_name: str) -> None:
|
||||
await self.delete_command_configs([handler_full_name])
|
||||
|
||||
async def delete_command_configs(self, handler_full_names: list[str]) -> None:
|
||||
if not handler_full_names:
|
||||
return
|
||||
|
||||
async def _op(session: AsyncSession) -> None:
|
||||
await session.execute(
|
||||
delete(CommandConfig).where(
|
||||
col(CommandConfig.handler_full_name).in_(handler_full_names),
|
||||
),
|
||||
)
|
||||
|
||||
await self._run_in_tx(_op)
|
||||
|
||||
async def list_command_conflicts(
|
||||
self,
|
||||
status: str | None = None,
|
||||
) -> list[CommandConflict]:
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
query = select(CommandConflict)
|
||||
if status:
|
||||
query = query.where(CommandConflict.status == status)
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def upsert_command_conflict(
|
||||
self,
|
||||
conflict_key: str,
|
||||
handler_full_name: str,
|
||||
plugin_name: str,
|
||||
*,
|
||||
status: str | None = None,
|
||||
resolution: str | None = None,
|
||||
resolved_command: str | None = None,
|
||||
note: str | None = None,
|
||||
extra_data: dict | None = None,
|
||||
auto_generated: bool | None = None,
|
||||
) -> CommandConflict:
|
||||
async def _op(session: AsyncSession) -> CommandConflict:
|
||||
result = await session.execute(
|
||||
select(CommandConflict).where(
|
||||
CommandConflict.conflict_key == conflict_key,
|
||||
CommandConflict.handler_full_name == handler_full_name,
|
||||
),
|
||||
)
|
||||
record = result.scalar_one_or_none()
|
||||
if not record:
|
||||
record = self._new_command_conflict(
|
||||
conflict_key,
|
||||
handler_full_name,
|
||||
plugin_name,
|
||||
status=status,
|
||||
resolution=resolution,
|
||||
resolved_command=resolved_command,
|
||||
note=note,
|
||||
extra_data=extra_data,
|
||||
auto_generated=auto_generated,
|
||||
)
|
||||
session.add(record)
|
||||
else:
|
||||
self._apply_updates(
|
||||
record,
|
||||
plugin_name=plugin_name,
|
||||
status=status,
|
||||
resolution=resolution,
|
||||
resolved_command=resolved_command,
|
||||
note=note,
|
||||
extra_data=extra_data,
|
||||
auto_generated=auto_generated,
|
||||
)
|
||||
await session.flush()
|
||||
await session.refresh(record)
|
||||
return record
|
||||
|
||||
return await self._run_in_tx(_op)
|
||||
|
||||
async def delete_command_conflicts(self, ids: list[int]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
|
||||
async def _op(session: AsyncSession) -> None:
|
||||
await session.execute(
|
||||
delete(CommandConflict).where(col(CommandConflict.id).in_(ids)),
|
||||
)
|
||||
|
||||
await self._run_in_tx(_op)
|
||||
|
||||
# ====
|
||||
# Deprecated Methods
|
||||
# ====
|
||||
@@ -1093,7 +794,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
|
||||
await session.execute(
|
||||
update(PlatformSession)
|
||||
.where(col(PlatformSession.session_id) == session_id)
|
||||
.where(col(PlatformSession.session_id == session_id))
|
||||
.values(**values),
|
||||
)
|
||||
|
||||
@@ -1104,6 +805,6 @@ class SQLiteDatabase(BaseDatabase):
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(PlatformSession).where(
|
||||
col(PlatformSession.session_id) == session_id,
|
||||
col(PlatformSession.session_id == session_id),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,11 +1,20 @@
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
from typing import TypedDict
|
||||
|
||||
|
||||
@dataclass
|
||||
class Result:
|
||||
class ResultData(TypedDict):
|
||||
id: str
|
||||
doc_id: str
|
||||
text: str
|
||||
metadata: str
|
||||
created_at: int
|
||||
updated_at: int
|
||||
|
||||
similarity: float
|
||||
data: dict
|
||||
data: ResultData | dict
|
||||
|
||||
|
||||
class BaseVecDB:
|
||||
|
||||
@@ -90,6 +90,4 @@ class EmbeddingStorage:
|
||||
path (str): 保存索引的路径
|
||||
|
||||
"""
|
||||
if self.index is None:
|
||||
return
|
||||
faiss.write_index(self.index, self.path)
|
||||
|
||||
@@ -27,7 +27,7 @@ class EventBus:
|
||||
self,
|
||||
event_queue: Queue,
|
||||
pipeline_scheduler_mapping: dict[str, PipelineScheduler],
|
||||
astrbot_config_mgr: AstrBotConfigManager,
|
||||
astrbot_config_mgr: AstrBotConfigManager = None,
|
||||
):
|
||||
self.event_queue = event_queue # 事件队列
|
||||
# abconf uuid -> scheduler
|
||||
@@ -40,11 +40,6 @@ class EventBus:
|
||||
conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin)
|
||||
self._print_event(event, conf_info["name"])
|
||||
scheduler = self.pipeline_scheduler_mapping.get(conf_info["id"])
|
||||
if not scheduler:
|
||||
logger.error(
|
||||
f"PipelineScheduler not found for id: {conf_info['id']}, event ignored."
|
||||
)
|
||||
continue
|
||||
asyncio.create_task(scheduler.execute(event))
|
||||
|
||||
def _print_event(self, event: AstrMessageEvent, conf_name: str):
|
||||
|
||||
@@ -149,16 +149,8 @@ class RecursiveCharacterChunker(BaseChunker):
|
||||
分割后的文本块列表
|
||||
|
||||
"""
|
||||
if chunk_size is None:
|
||||
chunk_size = self.chunk_size
|
||||
if overlap is None:
|
||||
overlap = self.chunk_overlap
|
||||
if chunk_size <= 0:
|
||||
raise ValueError("chunk_size must be greater than 0")
|
||||
if overlap < 0:
|
||||
raise ValueError("chunk_overlap must be non-negative")
|
||||
if overlap >= chunk_size:
|
||||
raise ValueError("chunk_overlap must be less than chunk_size")
|
||||
chunk_size = chunk_size or self.chunk_size
|
||||
overlap = overlap or self.chunk_overlap
|
||||
result = []
|
||||
for i in range(0, len(text), chunk_size - overlap):
|
||||
end = min(i + chunk_size, len(text))
|
||||
|
||||
@@ -166,11 +166,7 @@ class RetrievalManager:
|
||||
# 5. Rerank
|
||||
first_rerank = None
|
||||
for kb_id in kb_ids:
|
||||
vec_db = kb_options[kb_id]["vec_db"]
|
||||
if not isinstance(vec_db, FaissVecDB):
|
||||
logger.warning(f"vec_db for kb_id {kb_id} is not FaissVecDB")
|
||||
continue
|
||||
|
||||
vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
|
||||
rerank_pi = kb_options[kb_id]["rerank_provider_id"]
|
||||
if (
|
||||
vec_db
|
||||
|
||||
+2
-3
@@ -24,7 +24,6 @@ import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from asyncio import Queue
|
||||
from collections import deque
|
||||
|
||||
@@ -58,7 +57,7 @@ def is_plugin_path(pathname):
|
||||
return False
|
||||
|
||||
norm_path = os.path.normpath(pathname)
|
||||
return ("data/plugins" in norm_path) or ("astrbot/builtin_stars/" in norm_path)
|
||||
return ("data/plugins" in norm_path) or ("packages/" in norm_path)
|
||||
|
||||
|
||||
def get_short_level_name(level_name):
|
||||
@@ -149,7 +148,7 @@ class LogQueueHandler(logging.Handler):
|
||||
self.log_broker.publish(
|
||||
{
|
||||
"level": record.levelname,
|
||||
"time": time.time(),
|
||||
"time": record.asctime,
|
||||
"data": log_entry,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -0,0 +1,822 @@
|
||||
{
|
||||
"type": "excalidraw",
|
||||
"version": 2,
|
||||
"source": "https://marketplace.visualstudio.com/items?itemName=pomdtr.excalidraw-editor",
|
||||
"elements": [
|
||||
{
|
||||
"id": "l6cYurMvF69IM4Kc33Qou",
|
||||
"type": "rectangle",
|
||||
"x": 173.140625,
|
||||
"y": -29.0234375,
|
||||
"width": 92.95703125,
|
||||
"height": 77.109375,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "a0",
|
||||
"roundness": {
|
||||
"type": 3
|
||||
},
|
||||
"seed": 1409469537,
|
||||
"version": 91,
|
||||
"versionNonce": 307958671,
|
||||
"isDeleted": false,
|
||||
"boundElements": [],
|
||||
"updated": 1763703733605,
|
||||
"link": null,
|
||||
"locked": false
|
||||
},
|
||||
{
|
||||
"id": "1ZvS6t8U6ihUjNU0dakgl",
|
||||
"type": "arrow",
|
||||
"x": 409.30859375,
|
||||
"y": 9.6875,
|
||||
"width": 118.2734375,
|
||||
"height": 1.9609375,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "a1",
|
||||
"roundness": {
|
||||
"type": 2
|
||||
},
|
||||
"seed": 326508865,
|
||||
"version": 120,
|
||||
"versionNonce": 199367023,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763703733605,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"points": [
|
||||
[
|
||||
0,
|
||||
0
|
||||
],
|
||||
[
|
||||
-118.2734375,
|
||||
-1.9609375
|
||||
]
|
||||
],
|
||||
"lastCommittedPoint": null,
|
||||
"startBinding": null,
|
||||
"endBinding": null,
|
||||
"startArrowhead": null,
|
||||
"endArrowhead": "arrow",
|
||||
"elbowed": false
|
||||
},
|
||||
{
|
||||
"id": "tfdUGiJdcMoOHGfqFHXK6",
|
||||
"type": "text",
|
||||
"x": 153.46875,
|
||||
"y": -70.9765625,
|
||||
"width": 136.4598846435547,
|
||||
"height": 25,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "a2",
|
||||
"roundness": null,
|
||||
"seed": 688712865,
|
||||
"version": 67,
|
||||
"versionNonce": 300660705,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763703743816,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "FAISS+SQLite",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "FAISS+SQLite",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "AeL3kEB9a8_TAvAXpAbpl",
|
||||
"type": "text",
|
||||
"x": 438.36328125,
|
||||
"y": -3.78125,
|
||||
"width": 116.109375,
|
||||
"height": 25,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "a3",
|
||||
"roundness": null,
|
||||
"seed": 788579535,
|
||||
"version": 33,
|
||||
"versionNonce": 946602095,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763703932431,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "FACT",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "FACT",
|
||||
"autoResize": false,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "Pe3TeMZvxQ8tRTcbD5v6P",
|
||||
"type": "arrow",
|
||||
"x": 297.125,
|
||||
"y": 40.2578125,
|
||||
"width": 120.2421875,
|
||||
"height": 1.421875,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "a4",
|
||||
"roundness": {
|
||||
"type": 2
|
||||
},
|
||||
"seed": 1146229999,
|
||||
"version": 44,
|
||||
"versionNonce": 636917679,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763703759050,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"points": [
|
||||
[
|
||||
0,
|
||||
0
|
||||
],
|
||||
[
|
||||
120.2421875,
|
||||
1.421875
|
||||
]
|
||||
],
|
||||
"lastCommittedPoint": null,
|
||||
"startBinding": null,
|
||||
"endBinding": null,
|
||||
"startArrowhead": null,
|
||||
"endArrowhead": "arrow",
|
||||
"elbowed": false
|
||||
},
|
||||
{
|
||||
"id": "GhmQoadtQRK8c8aEEbYKQ",
|
||||
"type": "text",
|
||||
"x": 283.53515625,
|
||||
"y": 64.76171875,
|
||||
"width": 130.85989379882812,
|
||||
"height": 50,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "a5",
|
||||
"roundness": null,
|
||||
"seed": 1445650959,
|
||||
"version": 79,
|
||||
"versionNonce": 566193167,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763703768982,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "top-n Similary\n",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "top-n Similary\n",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "uTEFJs8cNS09WFq2pi9P7",
|
||||
"type": "rectangle",
|
||||
"x": 528.1586158430439,
|
||||
"y": -173.43472375183552,
|
||||
"width": 135.7578125,
|
||||
"height": 128.73828125,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "a6",
|
||||
"roundness": {
|
||||
"type": 3
|
||||
},
|
||||
"seed": 223409231,
|
||||
"version": 44,
|
||||
"versionNonce": 1066827105,
|
||||
"isDeleted": false,
|
||||
"boundElements": [
|
||||
{
|
||||
"id": "FfWdx1_yCq6UYfXamJX9N",
|
||||
"type": "arrow"
|
||||
}
|
||||
],
|
||||
"updated": 1763704050188,
|
||||
"link": null,
|
||||
"locked": false
|
||||
},
|
||||
{
|
||||
"id": "2SzqzpJ4C2ymVj8-8vN7H",
|
||||
"type": "text",
|
||||
"x": 548.1480270948795,
|
||||
"y": -211,
|
||||
"width": 86.43992614746094,
|
||||
"height": 25,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "a7",
|
||||
"roundness": null,
|
||||
"seed": 1015608623,
|
||||
"version": 23,
|
||||
"versionNonce": 950374849,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763704047884,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "Memories",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "Memories",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "CgW6Yf9v0a9q1tsjhDl7b",
|
||||
"type": "text",
|
||||
"x": 568.3099317299038,
|
||||
"y": -154.69469411681115,
|
||||
"width": 62.099945068359375,
|
||||
"height": 25,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aA",
|
||||
"roundness": null,
|
||||
"seed": 452254927,
|
||||
"version": 10,
|
||||
"versionNonce": 972895023,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763704057762,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "chunk1",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "chunk1",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "knvlKpaFZ8lY-73Y-e9W6",
|
||||
"type": "text",
|
||||
"x": 569.11328125,
|
||||
"y": -116.91056665512056,
|
||||
"width": 67.55995178222656,
|
||||
"height": 25,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aB",
|
||||
"roundness": null,
|
||||
"seed": 914644015,
|
||||
"version": 90,
|
||||
"versionNonce": 158135631,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763704057762,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "chunk2",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "chunk2",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "Q7URqvTSMpvj08ye-afTT",
|
||||
"type": "rectangle",
|
||||
"x": 444.515625,
|
||||
"y": 36.7890625,
|
||||
"width": 58.859375,
|
||||
"height": 29.41796875,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aC",
|
||||
"roundness": {
|
||||
"type": 3
|
||||
},
|
||||
"seed": 1642537601,
|
||||
"version": 19,
|
||||
"versionNonce": 948406575,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763703870173,
|
||||
"link": null,
|
||||
"locked": false
|
||||
},
|
||||
{
|
||||
"id": "JjxBt9cZIZXNTd6CmwyKL",
|
||||
"type": "rectangle",
|
||||
"x": 452.203125,
|
||||
"y": 46.064453125,
|
||||
"width": 58.859375,
|
||||
"height": 29.41796875,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aD",
|
||||
"roundness": {
|
||||
"type": 3
|
||||
},
|
||||
"seed": 1746916641,
|
||||
"version": 40,
|
||||
"versionNonce": 1650978255,
|
||||
"isDeleted": false,
|
||||
"boundElements": [],
|
||||
"updated": 1763703871882,
|
||||
"link": null,
|
||||
"locked": false
|
||||
},
|
||||
{
|
||||
"id": "XGBCPPFnjriqsL8LvLwyQ",
|
||||
"type": "rectangle",
|
||||
"x": 461.56640625,
|
||||
"y": 56.162109375,
|
||||
"width": 58.859375,
|
||||
"height": 29.41796875,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aE",
|
||||
"roundness": {
|
||||
"type": 3
|
||||
},
|
||||
"seed": 529794575,
|
||||
"version": 85,
|
||||
"versionNonce": 2131900641,
|
||||
"isDeleted": false,
|
||||
"boundElements": [],
|
||||
"updated": 1763703874182,
|
||||
"link": null,
|
||||
"locked": false
|
||||
},
|
||||
{
|
||||
"id": "FfWdx1_yCq6UYfXamJX9N",
|
||||
"type": "arrow",
|
||||
"x": 537.6875,
|
||||
"y": 48.203125,
|
||||
"width": 6.615850226297994,
|
||||
"height": 75.81335873223107,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aF",
|
||||
"roundness": {
|
||||
"type": 2
|
||||
},
|
||||
"seed": 1982870689,
|
||||
"version": 90,
|
||||
"versionNonce": 25307457,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763704050188,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"points": [
|
||||
[
|
||||
0,
|
||||
0
|
||||
],
|
||||
[
|
||||
6.615850226297994,
|
||||
-75.81335873223107
|
||||
]
|
||||
],
|
||||
"lastCommittedPoint": null,
|
||||
"startBinding": null,
|
||||
"endBinding": {
|
||||
"elementId": "uTEFJs8cNS09WFq2pi9P7",
|
||||
"focus": 0.6071885090336794,
|
||||
"gap": 24.64453125
|
||||
},
|
||||
"startArrowhead": null,
|
||||
"endArrowhead": "arrow",
|
||||
"elbowed": false
|
||||
},
|
||||
{
|
||||
"id": "jgJgqGMRWcaNX_28wY4CU",
|
||||
"type": "text",
|
||||
"x": 570,
|
||||
"y": 10,
|
||||
"width": 67.11994934082031,
|
||||
"height": 25,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aG",
|
||||
"roundness": null,
|
||||
"seed": 1065220559,
|
||||
"version": 26,
|
||||
"versionNonce": 2115991521,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763703959397,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "update",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "update",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "_5pSPPOpp9h1TpFCIc055",
|
||||
"type": "text",
|
||||
"x": 292.36328125,
|
||||
"y": -138.5703125,
|
||||
"width": 122.87992858886719,
|
||||
"height": 25,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aH",
|
||||
"roundness": null,
|
||||
"seed": 51461025,
|
||||
"version": 26,
|
||||
"versionNonce": 1647492655,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763703925147,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "ADD Memory",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "ADD Memory",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "YG6MdL14l7lk4ypQNMZ_k",
|
||||
"type": "text",
|
||||
"x": 296.71885397566257,
|
||||
"y": 161.399157096715,
|
||||
"width": 295.27984619140625,
|
||||
"height": 25,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aJ",
|
||||
"roundness": null,
|
||||
"seed": 1183210273,
|
||||
"version": 122,
|
||||
"versionNonce": 1702733281,
|
||||
"isDeleted": false,
|
||||
"boundElements": [],
|
||||
"updated": 1763704085083,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "RETRIEVE Memory (STATIC)",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "RETRIEVE Memory (STATIC)",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "Foa3VPJYqhj1uAX5mn3n0",
|
||||
"type": "rectangle",
|
||||
"x": 324.7616636099071,
|
||||
"y": 248.63213980937013,
|
||||
"width": 135.7578125,
|
||||
"height": 128.73828125,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aL",
|
||||
"roundness": {
|
||||
"type": 3
|
||||
},
|
||||
"seed": 995116257,
|
||||
"version": 225,
|
||||
"versionNonce": 1886900225,
|
||||
"isDeleted": false,
|
||||
"boundElements": [],
|
||||
"updated": 1763704055846,
|
||||
"link": null,
|
||||
"locked": false
|
||||
},
|
||||
{
|
||||
"id": "pe3veI_yBFKYtbaJwDKQT",
|
||||
"type": "text",
|
||||
"x": 344.7510748617428,
|
||||
"y": 211.06686356120565,
|
||||
"width": 86.43992614746094,
|
||||
"height": 25,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aM",
|
||||
"roundness": null,
|
||||
"seed": 26673345,
|
||||
"version": 204,
|
||||
"versionNonce": 1004546017,
|
||||
"isDeleted": false,
|
||||
"boundElements": [],
|
||||
"updated": 1763704055846,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "Memories",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "Memories",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "bOlhO8AaKE86_43viu5UG",
|
||||
"type": "text",
|
||||
"x": 365.50408375566445,
|
||||
"y": 269.24725381983865,
|
||||
"width": 62.099945068359375,
|
||||
"height": 25,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aN",
|
||||
"roundness": null,
|
||||
"seed": 1849784033,
|
||||
"version": 106,
|
||||
"versionNonce": 762320737,
|
||||
"isDeleted": false,
|
||||
"boundElements": [],
|
||||
"updated": 1763704060295,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "chunk1",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "chunk1",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "V_iDW10PKwMe7vWb5S5HF",
|
||||
"type": "text",
|
||||
"x": 366.3074332757606,
|
||||
"y": 307.03138128152926,
|
||||
"width": 67.55995178222656,
|
||||
"height": 25,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aO",
|
||||
"roundness": null,
|
||||
"seed": 1670509249,
|
||||
"version": 186,
|
||||
"versionNonce": 1964540737,
|
||||
"isDeleted": false,
|
||||
"boundElements": [],
|
||||
"updated": 1763704060295,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "chunk2",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "chunk2",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "LHKMRdSowgcl2LsKacxTz",
|
||||
"type": "text",
|
||||
"x": 484.9493410573871,
|
||||
"y": 292.45619471187945,
|
||||
"width": 273.579833984375,
|
||||
"height": 50,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aP",
|
||||
"roundness": null,
|
||||
"seed": 945666991,
|
||||
"version": 104,
|
||||
"versionNonce": 1512137505,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763704096016,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "RANKED By DECAY SCORE,\nTOP K",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "RANKED By DECAY SCORE,\nTOP K",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
}
|
||||
],
|
||||
"appState": {
|
||||
"gridSize": 20,
|
||||
"gridStep": 5,
|
||||
"gridModeEnabled": false,
|
||||
"viewBackgroundColor": "#ffffff"
|
||||
},
|
||||
"files": {}
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
## Decay Score
|
||||
|
||||
记忆衰减分数定义为:
|
||||
|
||||
\[
|
||||
\text{decay\_score}
|
||||
= \alpha \cdot e^{-\lambda \cdot \Delta t \cdot \beta}
|
||||
|
||||
+ (1-\alpha)\cdot (1 - e^{-\gamma \cdot c})
|
||||
\]
|
||||
|
||||
其中:
|
||||
|
||||
+ \(\Delta t\):自上次检索以来经过的时间(天),由 `last_retrieval_at` 计算;
|
||||
+ \(c\):检索次数,对应字段 `retrieval_count`;
|
||||
+ \(\alpha\):控制时间衰减和检索次数影响的权重;
|
||||
+ \(\gamma\):控制检索次数影响的速率;
|
||||
+ \(\lambda\):控制时间衰减的速率;
|
||||
+ \(\beta\):时间衰减调节因子;
|
||||
|
||||
\[
|
||||
\beta = \frac{1}{1 + a \cdot c}
|
||||
\]
|
||||
|
||||
+ \(a\):控制检索次数对时间衰减影响的权重。
|
||||
|
||||
## ADD MEMORY
|
||||
|
||||
+ LLM 通过 `astr_add_memory` 工具调用,传入记忆内容和记忆类型。
|
||||
+ 生成 `mem_id = uuid4()`。
|
||||
+ 从上下文中获取 `owner_id = unified_message_origin`。
|
||||
|
||||
步骤:
|
||||
|
||||
1. 使用 VecDB 以新记忆内容为 query,检索前 20 条相似记忆。
|
||||
2. 从中取相似度最高的前 5 条:
|
||||
+ 若相似度超过“合并阈值”(如 `sim >= merge_threshold`):
|
||||
+ 将该条记忆视为同一记忆,使用 LLM 将旧内容与新内容合并;
|
||||
+ 在同一个 `mem_id` 上更新 MemoryDB 和 VecDB(UPDATE,而非新建)。
|
||||
+ 否则:
|
||||
+ 作为全新的记忆插入:
|
||||
+ 写入 VecDB(metadata 中包含 `mem_id`, `owner_id`);
|
||||
+ 写入 MemoryDB 的 `memory_chunks` 表,初始化:
|
||||
+ `created_at = now`
|
||||
+ `last_retrieval_at = now`
|
||||
+ `retrieval_count = 1` 等。
|
||||
3. 对 VecDB 返回的前 20 条记忆,如果相似度高于某个“赫布阈值”(`hebb_threshold`),则:
|
||||
+ `retrieval_count += 1`
|
||||
+ `last_retrieval_at = now`
|
||||
|
||||
这一步体现了赫布学习:与新记忆共同被激活的旧记忆会获得一次强化。
|
||||
|
||||
## QUERY MEMORY (STATIC)
|
||||
|
||||
+ LLM 通过 `astr_query_memory` 工具调用,无参数。
|
||||
|
||||
步骤:
|
||||
|
||||
1. 从 MemoryDB 的 `memory_chunks` 表中查询当前用户所有活跃记忆:
|
||||
+ `SELECT * FROM memory_chunks WHERE owner_id = ? AND is_active = 1`
|
||||
2. 对每条记忆,根据 `last_retrieval_at` 和 `retrieval_count` 计算对应的 `decay_score`。
|
||||
3. 按 `decay_score` 从高到低排序,返回前 `top_k` 条记忆内容给 LLM。
|
||||
4. 对返回的这 `top_k` 条记忆:
|
||||
+ `retrieval_count += 1`
|
||||
+ `last_retrieval_at = now`
|
||||
|
||||
## QUERY MEMORY (DYNAMIC)(暂不实现)
|
||||
|
||||
+ LLM 提供查询内容作为语义 query。
|
||||
+ 使用 VecDB 检索与该 query 最相似的前 `N` 条记忆(`N > top_k`)。
|
||||
+ 根据 `mem_id` 从 `memory_chunks` 中加载对应记录。
|
||||
+ 对这批候选记忆计算:
|
||||
+ 语义相似度(来自 VecDB)
|
||||
+ `decay_score`
|
||||
+ 最终排序分数(例如 `w1 * sim + w2 * decay_score`)
|
||||
+ 按最终排序分数从高到低返回前 `top_k` 条记忆内容,并更新它们的 `retrieval_count` 和 `last_retrieval_at`。
|
||||
@@ -0,0 +1,63 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import numpy as np
|
||||
from sqlmodel import Field, MetaData, SQLModel
|
||||
|
||||
MEMORY_TYPE_IMPORTANCE = {"persona": 1.3, "fact": 1.0, "ephemeral": 0.8}
|
||||
|
||||
|
||||
class BaseMemoryModel(SQLModel, table=False):
|
||||
metadata = MetaData()
|
||||
|
||||
|
||||
class MemoryChunk(BaseMemoryModel, table=True):
|
||||
"""A chunk of memory stored in the system."""
|
||||
|
||||
__tablename__ = "memory_chunks" # type: ignore
|
||||
|
||||
id: int | None = Field(
|
||||
primary_key=True,
|
||||
sa_column_kwargs={"autoincrement": True},
|
||||
default=None,
|
||||
)
|
||||
mem_id: str = Field(
|
||||
max_length=36,
|
||||
nullable=False,
|
||||
unique=True,
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
index=True,
|
||||
)
|
||||
fact: str = Field(nullable=False)
|
||||
"""The factual content of the memory chunk."""
|
||||
owner_id: str = Field(max_length=255, nullable=False, index=True)
|
||||
"""The identifier of the owner (user) of the memory chunk."""
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
"""The timestamp when the memory chunk was created."""
|
||||
last_retrieval_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
"""The timestamp when the memory chunk was last retrieved."""
|
||||
retrieval_count: int = Field(default=1, nullable=False)
|
||||
"""The number of times the memory chunk has been retrieved."""
|
||||
memory_type: str = Field(max_length=20, nullable=False, default="fact")
|
||||
"""The type of memory (e.g., 'persona', 'fact', 'ephemeral')."""
|
||||
is_active: bool = Field(default=True, nullable=False)
|
||||
"""Whether the memory chunk is active."""
|
||||
|
||||
def compute_decay_score(self, current_time: datetime) -> float:
|
||||
"""Compute the decay score of the memory chunk based on time and retrievals."""
|
||||
# Constants for the decay formula
|
||||
alpha = 0.5
|
||||
gamma = 0.1
|
||||
lambda_ = 0.05
|
||||
a = 0.1
|
||||
|
||||
# Calculate delta_t in days
|
||||
delta_t = (current_time - self.last_retrieval_at).total_seconds() / 86400
|
||||
c = self.retrieval_count
|
||||
beta = 1 / (1 + a * c)
|
||||
decay_score = alpha * np.exp(-lambda_ * delta_t * beta) + (1 - alpha) * (
|
||||
1 - np.exp(-gamma * c)
|
||||
)
|
||||
return decay_score * MEMORY_TYPE_IMPORTANCE.get(self.memory_type, 1.0)
|
||||
@@ -0,0 +1,174 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import select, text, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlmodel import col
|
||||
|
||||
from astrbot.core import logger
|
||||
|
||||
from .entities import BaseMemoryModel, MemoryChunk
|
||||
|
||||
|
||||
class MemoryDatabase:
|
||||
def __init__(self, db_path: str = "data/astr_memory/memory.db") -> None:
|
||||
"""Initialize memory database
|
||||
|
||||
Args:
|
||||
db_path: Database file path, default is data/astr_memory/memory.db
|
||||
|
||||
"""
|
||||
self.db_path = db_path
|
||||
self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
|
||||
self.inited = False
|
||||
|
||||
# Ensure directory exists
|
||||
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create async engine
|
||||
self.engine = create_async_engine(
|
||||
self.DATABASE_URL,
|
||||
echo=False,
|
||||
pool_pre_ping=True,
|
||||
pool_recycle=3600,
|
||||
)
|
||||
|
||||
# Create session factory
|
||||
self.async_session = async_sessionmaker(
|
||||
self.engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_db(self):
|
||||
"""Get database session
|
||||
|
||||
Usage:
|
||||
async with mem_db.get_db() as session:
|
||||
# Perform database operations
|
||||
result = await session.execute(stmt)
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
yield session
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize database, create tables and configure SQLite parameters"""
|
||||
async with self.engine.begin() as conn:
|
||||
# Create all memory related tables
|
||||
await conn.run_sync(BaseMemoryModel.metadata.create_all)
|
||||
|
||||
# Configure SQLite performance optimization parameters
|
||||
await conn.execute(text("PRAGMA journal_mode=WAL"))
|
||||
await conn.execute(text("PRAGMA synchronous=NORMAL"))
|
||||
await conn.execute(text("PRAGMA cache_size=20000"))
|
||||
await conn.execute(text("PRAGMA temp_store=MEMORY"))
|
||||
await conn.execute(text("PRAGMA mmap_size=134217728"))
|
||||
await conn.execute(text("PRAGMA optimize"))
|
||||
await conn.commit()
|
||||
|
||||
await self._create_indexes()
|
||||
self.inited = True
|
||||
logger.info(f"Memory database initialized: {self.db_path}")
|
||||
|
||||
async def _create_indexes(self) -> None:
|
||||
"""Create indexes for memory_chunks table"""
|
||||
async with self.get_db() as session:
|
||||
async with session.begin():
|
||||
# Create memory chunks table indexes
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_mem_mem_id "
|
||||
"ON memory_chunks(mem_id)",
|
||||
),
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_mem_owner_id "
|
||||
"ON memory_chunks(owner_id)",
|
||||
),
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_mem_owner_active "
|
||||
"ON memory_chunks(owner_id, is_active)",
|
||||
),
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close database connection"""
|
||||
await self.engine.dispose()
|
||||
logger.info(f"Memory database closed: {self.db_path}")
|
||||
|
||||
async def insert_memory(self, memory: MemoryChunk) -> MemoryChunk:
|
||||
"""Insert a new memory chunk"""
|
||||
async with self.get_db() as session:
|
||||
session.add(memory)
|
||||
await session.commit()
|
||||
await session.refresh(memory)
|
||||
return memory
|
||||
|
||||
async def get_memory_by_id(self, mem_id: str) -> MemoryChunk | None:
|
||||
"""Get memory chunk by mem_id"""
|
||||
async with self.get_db() as session:
|
||||
stmt = select(MemoryChunk).where(col(MemoryChunk.mem_id) == mem_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def update_memory(self, memory: MemoryChunk) -> MemoryChunk:
|
||||
"""Update an existing memory chunk"""
|
||||
async with self.get_db() as session:
|
||||
session.add(memory)
|
||||
await session.commit()
|
||||
await session.refresh(memory)
|
||||
return memory
|
||||
|
||||
async def get_active_memories(self, owner_id: str) -> list[MemoryChunk]:
|
||||
"""Get all active memories for a user"""
|
||||
async with self.get_db() as session:
|
||||
stmt = select(MemoryChunk).where(
|
||||
col(MemoryChunk.owner_id) == owner_id,
|
||||
col(MemoryChunk.is_active) == True, # noqa: E712
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def update_retrieval_stats(
|
||||
self,
|
||||
mem_ids: list[str],
|
||||
current_time: datetime | None = None,
|
||||
) -> None:
|
||||
"""Update retrieval statistics for multiple memories"""
|
||||
if not mem_ids:
|
||||
return
|
||||
|
||||
if current_time is None:
|
||||
current_time = datetime.now(timezone.utc)
|
||||
|
||||
async with self.get_db() as session:
|
||||
async with session.begin():
|
||||
stmt = (
|
||||
update(MemoryChunk)
|
||||
.where(col(MemoryChunk.mem_id).in_(mem_ids))
|
||||
.values(
|
||||
retrieval_count=MemoryChunk.retrieval_count + 1,
|
||||
last_retrieval_at=current_time,
|
||||
)
|
||||
)
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
async def deactivate_memory(self, mem_id: str) -> bool:
|
||||
"""Deactivate a memory chunk"""
|
||||
async with self.get_db() as session:
|
||||
async with session.begin():
|
||||
stmt = (
|
||||
update(MemoryChunk)
|
||||
.where(col(MemoryChunk.mem_id) == mem_id)
|
||||
.values(is_active=False)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
await session.commit()
|
||||
return result.rowcount > 0 if result.rowcount else False # type: ignore
|
||||
@@ -0,0 +1,281 @@
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
|
||||
from astrbot.core.provider.provider import EmbeddingProvider
|
||||
from astrbot.core.provider.provider import Provider as LLMProvider
|
||||
|
||||
from .entities import MemoryChunk
|
||||
from .mem_db_sqlite import MemoryDatabase
|
||||
|
||||
MERGE_THRESHOLD = 0.85
|
||||
"""Similarity threshold for merging memories"""
|
||||
HEBB_THRESHOLD = 0.70
|
||||
"""Similarity threshold for Hebbian learning reinforcement"""
|
||||
MERGE_SYSTEM_PROMPT = """You are a memory consolidation assistant. Your task is to merge two related memory entries into a single, comprehensive memory.
|
||||
|
||||
Input format:
|
||||
- Old memory: [existing memory content]
|
||||
- New memory: [new memory content to be integrated]
|
||||
|
||||
Your output should be a single, concise memory that combines the essential information from both entries. Preserve specific details, update outdated information, and eliminate redundancy. Output only the merged memory content without any explanations or meta-commentary."""
|
||||
|
||||
|
||||
class MemoryManager:
|
||||
"""Manager for user long-term memory storage and retrieval"""
|
||||
|
||||
def __init__(self, memory_root_dir: str = "data/astr_memory"):
|
||||
self.memory_root_dir = Path(memory_root_dir)
|
||||
self.memory_root_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.mem_db: MemoryDatabase | None = None
|
||||
self.vec_db: FaissVecDB | None = None
|
||||
|
||||
self._initialized = False
|
||||
|
||||
async def initialize(
|
||||
self,
|
||||
embedding_provider: EmbeddingProvider,
|
||||
merge_llm_provider: LLMProvider,
|
||||
):
|
||||
"""Initialize memory database and vector database"""
|
||||
# Initialize MemoryDB
|
||||
db_path = self.memory_root_dir / "memory.db"
|
||||
self.mem_db = MemoryDatabase(db_path.as_posix())
|
||||
await self.mem_db.initialize()
|
||||
|
||||
self.embedding_provider = embedding_provider
|
||||
self.merge_llm_provider = merge_llm_provider
|
||||
|
||||
# Initialize VecDB
|
||||
doc_store_path = self.memory_root_dir / "doc.db"
|
||||
index_store_path = self.memory_root_dir / "index.faiss"
|
||||
self.vec_db = FaissVecDB(
|
||||
doc_store_path=doc_store_path.as_posix(),
|
||||
index_store_path=index_store_path.as_posix(),
|
||||
embedding_provider=self.embedding_provider,
|
||||
)
|
||||
await self.vec_db.initialize()
|
||||
|
||||
logger.info("Memory manager initialized")
|
||||
self._initialized = True
|
||||
|
||||
async def terminate(self):
|
||||
"""Close all database connections"""
|
||||
if self.vec_db:
|
||||
await self.vec_db.close()
|
||||
if self.mem_db:
|
||||
await self.mem_db.close()
|
||||
|
||||
async def add_memory(
|
||||
self,
|
||||
fact: str,
|
||||
owner_id: str,
|
||||
memory_type: str = "fact",
|
||||
) -> MemoryChunk:
|
||||
"""Add a new memory with similarity check and merge logic
|
||||
|
||||
Implements the ADD MEMORY workflow from _README.md:
|
||||
1. Search for similar memories using VecDB
|
||||
2. If similarity >= merge_threshold, merge with existing memory
|
||||
3. Otherwise, create new memory
|
||||
4. Apply Hebbian learning to similar memories (similarity >= hebb_threshold)
|
||||
|
||||
Args:
|
||||
fact: Memory content
|
||||
owner_id: User identifier
|
||||
memory_type: Memory type ('persona', 'fact', 'ephemeral')
|
||||
|
||||
Returns:
|
||||
The created or updated MemoryChunk
|
||||
|
||||
"""
|
||||
if not self.vec_db or not self.mem_db:
|
||||
raise RuntimeError("Memory manager not initialized")
|
||||
|
||||
current_time = datetime.now(timezone.utc)
|
||||
|
||||
# Step 1: Search for similar memories
|
||||
similar_results = await self.vec_db.retrieve(
|
||||
query=fact,
|
||||
k=20,
|
||||
fetch_k=50,
|
||||
metadata_filters={"owner_id": owner_id},
|
||||
)
|
||||
|
||||
# Step 2: Check if we should merge with existing memories (top 3 similar ones)
|
||||
merge_candidates = [
|
||||
r for r in similar_results[:3] if r.similarity >= MERGE_THRESHOLD
|
||||
]
|
||||
|
||||
if merge_candidates:
|
||||
# Get all candidate memories from database
|
||||
candidate_memories: list[tuple[str, MemoryChunk]] = []
|
||||
for candidate in merge_candidates:
|
||||
mem_id = json.loads(candidate.data["metadata"])["mem_id"]
|
||||
memory = await self.mem_db.get_memory_by_id(mem_id)
|
||||
if memory:
|
||||
candidate_memories.append((mem_id, memory))
|
||||
|
||||
if candidate_memories:
|
||||
# Use the most similar memory as the base
|
||||
base_mem_id, base_memory = candidate_memories[0]
|
||||
|
||||
# Collect all facts to merge (existing candidates + new fact)
|
||||
all_facts = [mem.fact for _, mem in candidate_memories] + [fact]
|
||||
merged_fact = await self._merge_multiple_memories(all_facts)
|
||||
|
||||
# Update the base memory
|
||||
base_memory.fact = merged_fact
|
||||
base_memory.last_retrieval_at = current_time
|
||||
base_memory.retrieval_count += 1
|
||||
updated_memory = await self.mem_db.update_memory(base_memory)
|
||||
|
||||
# Update VecDB for base memory
|
||||
await self.vec_db.delete(base_mem_id)
|
||||
await self.vec_db.insert(
|
||||
content=merged_fact,
|
||||
metadata={
|
||||
"mem_id": base_mem_id,
|
||||
"owner_id": owner_id,
|
||||
"memory_type": memory_type,
|
||||
},
|
||||
id=base_mem_id,
|
||||
)
|
||||
|
||||
# Deactivate and remove other merged memories
|
||||
for mem_id, _ in candidate_memories[1:]:
|
||||
await self.mem_db.deactivate_memory(mem_id)
|
||||
await self.vec_db.delete(mem_id)
|
||||
|
||||
logger.info(
|
||||
f"Merged {len(candidate_memories)} memories into {base_mem_id} for user {owner_id}"
|
||||
)
|
||||
return updated_memory
|
||||
|
||||
# Step 3: Create new memory
|
||||
mem_id = str(uuid.uuid4())
|
||||
new_memory = MemoryChunk(
|
||||
mem_id=mem_id,
|
||||
fact=fact,
|
||||
owner_id=owner_id,
|
||||
memory_type=memory_type,
|
||||
created_at=current_time,
|
||||
last_retrieval_at=current_time,
|
||||
retrieval_count=1,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
# Insert into MemoryDB
|
||||
created_memory = await self.mem_db.insert_memory(new_memory)
|
||||
|
||||
# Insert into VecDB
|
||||
await self.vec_db.insert(
|
||||
content=fact,
|
||||
metadata={
|
||||
"mem_id": mem_id,
|
||||
"owner_id": owner_id,
|
||||
"memory_type": memory_type,
|
||||
},
|
||||
id=mem_id,
|
||||
)
|
||||
|
||||
# Step 4: Apply Hebbian learning to similar memories
|
||||
hebb_mem_ids = [
|
||||
json.loads(r.data["metadata"])["mem_id"]
|
||||
for r in similar_results
|
||||
if r.similarity >= HEBB_THRESHOLD
|
||||
]
|
||||
if hebb_mem_ids:
|
||||
await self.mem_db.update_retrieval_stats(hebb_mem_ids, current_time)
|
||||
logger.debug(
|
||||
f"Applied Hebbian learning to {len(hebb_mem_ids)} memories for user {owner_id}",
|
||||
)
|
||||
|
||||
logger.info(f"Created new memory {mem_id} for user {owner_id}")
|
||||
return created_memory
|
||||
|
||||
async def query_memory(
|
||||
self,
|
||||
owner_id: str,
|
||||
top_k: int = 5,
|
||||
) -> list[MemoryChunk]:
|
||||
"""Query user's memories using static retrieval with decay score ranking
|
||||
|
||||
Implements the QUERY MEMORY (STATIC) workflow from _README.md:
|
||||
1. Get all active memories for user from MemoryDB
|
||||
2. Compute decay_score for each memory
|
||||
3. Sort by decay_score and return top_k
|
||||
4. Update retrieval statistics for returned memories
|
||||
|
||||
Args:
|
||||
owner_id: User identifier
|
||||
top_k: Number of memories to return
|
||||
|
||||
Returns:
|
||||
List of top_k MemoryChunk sorted by decay score
|
||||
"""
|
||||
if not self.mem_db:
|
||||
raise RuntimeError("Memory manager not initialized")
|
||||
|
||||
current_time = datetime.now(timezone.utc)
|
||||
|
||||
# Step 1: Get all active memories for user
|
||||
all_memories = await self.mem_db.get_active_memories(owner_id)
|
||||
|
||||
if not all_memories:
|
||||
return []
|
||||
|
||||
# Step 2-3: Compute decay scores and sort
|
||||
memories_with_scores = [
|
||||
(mem, mem.compute_decay_score(current_time)) for mem in all_memories
|
||||
]
|
||||
memories_with_scores.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
# Get top_k memories
|
||||
top_memories = [mem for mem, _ in memories_with_scores[:top_k]]
|
||||
|
||||
# Step 4: Update retrieval statistics
|
||||
mem_ids = [mem.mem_id for mem in top_memories]
|
||||
await self.mem_db.update_retrieval_stats(mem_ids, current_time)
|
||||
|
||||
logger.debug(f"Retrieved {len(top_memories)} memories for user {owner_id}")
|
||||
return top_memories
|
||||
|
||||
async def _merge_multiple_memories(self, facts: list[str]) -> str:
|
||||
"""Merge multiple memory facts using LLM in one call
|
||||
|
||||
Args:
|
||||
facts: List of memory facts to merge
|
||||
|
||||
Returns:
|
||||
Merged memory content
|
||||
"""
|
||||
if not self.merge_llm_provider:
|
||||
return " ".join(facts)
|
||||
|
||||
if len(facts) == 1:
|
||||
return facts[0]
|
||||
|
||||
try:
|
||||
# Format all facts as a numbered list
|
||||
facts_list = "\n".join(f"{i + 1}. {fact}" for i, fact in enumerate(facts))
|
||||
user_prompt = (
|
||||
f"Please merge the following {len(facts)} related memory entries "
|
||||
"into a single, comprehensive memory:"
|
||||
f"\n{facts_list}\n\nOutput only the merged memory content."
|
||||
)
|
||||
response = await self.merge_llm_provider.text_chat(
|
||||
prompt=user_prompt,
|
||||
system_prompt=MERGE_SYSTEM_PROMPT,
|
||||
)
|
||||
|
||||
merged_content = response.completion_text.strip()
|
||||
return merged_content if merged_content else " ".join(facts)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to merge memories with LLM: {e}, using fallback")
|
||||
return " ".join(facts)
|
||||
@@ -0,0 +1,156 @@
|
||||
from pydantic import Field
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from astrbot.core.agent.tool import FunctionTool, ToolExecResult
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext, ContextWrapper
|
||||
|
||||
|
||||
@dataclass
|
||||
class AddMemory(FunctionTool[AstrAgentContext]):
|
||||
"""Tool for adding memories to user's long-term memory storage"""
|
||||
|
||||
name: str = "astr_add_memory"
|
||||
description: str = (
|
||||
"Add a new memory to the user's long-term memory storage. "
|
||||
"Use this tool only when the user explicitly asks you to remember something, "
|
||||
"or when they share stable preferences, identity, or long-term goals that will be useful in future interactions."
|
||||
)
|
||||
parameters: dict = Field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"fact": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The concrete memory content to store, such as a user preference, "
|
||||
"identity detail, long-term goal, or stable profile fact."
|
||||
),
|
||||
},
|
||||
"memory_type": {
|
||||
"type": "string",
|
||||
"enum": ["persona", "fact", "ephemeral"],
|
||||
"description": (
|
||||
"The relative importance of this memory. "
|
||||
"Use 'persona' for core identity or highly impactful information, "
|
||||
"'fact' for normal long-term preferences, "
|
||||
"and 'ephemeral' for minor or tentative facts."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["fact", "memory_type"],
|
||||
}
|
||||
)
|
||||
|
||||
async def call(
|
||||
self, context: ContextWrapper[AstrAgentContext], **kwargs
|
||||
) -> ToolExecResult:
|
||||
"""Add a memory to long-term storage
|
||||
|
||||
Args:
|
||||
context: Agent context
|
||||
**kwargs: Must contain 'fact' and 'memory_type'
|
||||
|
||||
Returns:
|
||||
ToolExecResult with success message
|
||||
|
||||
"""
|
||||
mm = context.context.context.memory_manager
|
||||
fact = kwargs.get("fact")
|
||||
memory_type = kwargs.get("memory_type", "fact")
|
||||
|
||||
if not fact:
|
||||
return "Missing required parameter: fact"
|
||||
|
||||
try:
|
||||
# Get owner_id from context
|
||||
owner_id = context.context.event.unified_msg_origin
|
||||
|
||||
# Add memory using memory manager
|
||||
memory = await mm.add_memory(
|
||||
fact=fact,
|
||||
owner_id=owner_id,
|
||||
memory_type=memory_type,
|
||||
)
|
||||
|
||||
return f"Memory added successfully (ID: {memory.mem_id})"
|
||||
|
||||
except Exception as e:
|
||||
return f"Failed to add memory: {str(e)}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueryMemory(FunctionTool[AstrAgentContext]):
|
||||
"""Tool for querying user's long-term memories"""
|
||||
|
||||
name: str = "astr_query_memory"
|
||||
description: str = (
|
||||
"Query the user's long-term memory storage and return the most relevant memories. "
|
||||
"Use this tool when you need user-specific context, preferences, or past facts "
|
||||
"that are not explicitly present in the current conversation."
|
||||
)
|
||||
parameters: dict = Field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"top_k": {
|
||||
"type": "integer",
|
||||
"description": (
|
||||
"Maximum number of memories to retrieve after retention-based ranking. "
|
||||
"Typically between 3 and 10."
|
||||
),
|
||||
"default": 5,
|
||||
"minimum": 1,
|
||||
"maximum": 20,
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
}
|
||||
)
|
||||
|
||||
async def call(
|
||||
self, context: ContextWrapper[AstrAgentContext], **kwargs
|
||||
) -> ToolExecResult:
|
||||
"""Query memories from long-term storage
|
||||
|
||||
Args:
|
||||
context: Agent context
|
||||
**kwargs: Optional 'top_k' parameter
|
||||
|
||||
Returns:
|
||||
ToolExecResult with formatted memory list
|
||||
|
||||
"""
|
||||
mm = context.context.context.memory_manager
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
|
||||
try:
|
||||
# Get owner_id from context
|
||||
owner_id = context.context.event.unified_msg_origin
|
||||
|
||||
# Query memories using memory manager
|
||||
memories = await mm.query_memory(
|
||||
owner_id=owner_id,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
if not memories:
|
||||
return "No memories found for this user."
|
||||
|
||||
# Format memories for output
|
||||
formatted_memories = []
|
||||
for i, mem in enumerate(memories, 1):
|
||||
formatted_memories.append(
|
||||
f"{i}. [{mem.memory_type.upper()}] {mem.fact} "
|
||||
f"(retrieved {mem.retrieval_count} times, "
|
||||
f"last: {mem.last_retrieval_at.strftime('%Y-%m-%d')})"
|
||||
)
|
||||
|
||||
result_text = "Retrieved memories:\n" + "\n".join(formatted_memories)
|
||||
return result_text
|
||||
|
||||
except Exception as e:
|
||||
return f"Failed to query memories: {str(e)}"
|
||||
|
||||
|
||||
ADD_MEMORY_TOOL = AddMemory()
|
||||
QUERY_MEMORY_TOOL = QueryMemory()
|
||||
@@ -66,9 +66,6 @@ class ComponentType(str, Enum):
|
||||
class BaseMessageComponent(BaseModel):
|
||||
type: ComponentType
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def toDict(self):
|
||||
data = {}
|
||||
for k, v in self.__dict__.items():
|
||||
@@ -554,7 +551,7 @@ class Node(BaseMessageComponent):
|
||||
id: int | None = 0 # 忽略
|
||||
name: str | None = "" # qq昵称
|
||||
uin: str | None = "0" # qq号
|
||||
content: list[BaseMessageComponent] = []
|
||||
content: list[BaseMessageComponent] | None = []
|
||||
seq: str | list | None = "" # 忽略
|
||||
time: int | None = 0 # 忽略
|
||||
|
||||
@@ -618,7 +615,7 @@ class Nodes(BaseMessageComponent):
|
||||
ret["messages"].append(d)
|
||||
return ret
|
||||
|
||||
async def to_dict(self) -> dict:
|
||||
async def to_dict(self):
|
||||
"""将 Nodes 转换为字典格式,适用于 OneBot JSON 格式"""
|
||||
ret = {"messages": []}
|
||||
for node in self.nodes:
|
||||
@@ -629,11 +626,12 @@ class Nodes(BaseMessageComponent):
|
||||
|
||||
class Json(BaseMessageComponent):
|
||||
type = ComponentType.Json
|
||||
data: dict
|
||||
data: str | dict
|
||||
resid: int | None = 0
|
||||
|
||||
def __init__(self, data: str | dict, **_):
|
||||
if isinstance(data, str):
|
||||
data = json.loads(data)
|
||||
def __init__(self, data, **_):
|
||||
if isinstance(data, dict):
|
||||
data = json.dumps(data)
|
||||
super().__init__(data=data, **_)
|
||||
|
||||
|
||||
@@ -716,23 +714,15 @@ class File(BaseMessageComponent):
|
||||
|
||||
if self.url:
|
||||
await self._download_file()
|
||||
if self.file_:
|
||||
return os.path.abspath(self.file_)
|
||||
return os.path.abspath(self.file_)
|
||||
|
||||
return ""
|
||||
|
||||
async def _download_file(self):
|
||||
"""下载文件"""
|
||||
if not self.url:
|
||||
raise ValueError("Download failed: No URL provided in File component.")
|
||||
download_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
os.makedirs(download_dir, exist_ok=True)
|
||||
if self.name:
|
||||
name, ext = os.path.splitext(self.name)
|
||||
filename = f"{name}_{uuid.uuid4().hex[:8]}{ext}"
|
||||
else:
|
||||
filename = f"{uuid.uuid4().hex}"
|
||||
file_path = os.path.join(download_dir, filename)
|
||||
file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
|
||||
await download_file(self.url, file_path)
|
||||
self.file_ = os.path.abspath(file_path)
|
||||
|
||||
|
||||
@@ -98,8 +98,8 @@ class PersonaManager:
|
||||
self,
|
||||
persona_id: str,
|
||||
system_prompt: str,
|
||||
begin_dialogs: list[str] | None = None,
|
||||
tools: list[str] | None = None,
|
||||
begin_dialogs: list[str] = None,
|
||||
tools: list[str] = None,
|
||||
) -> Persona:
|
||||
"""创建新的 persona。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
|
||||
if await self.db.get_persona_by_id(persona_id):
|
||||
|
||||
@@ -24,7 +24,7 @@ class ContentSafetyCheckStage(Stage):
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
check_text: str | None = None,
|
||||
) -> AsyncGenerator[None, None]:
|
||||
) -> None | AsyncGenerator[None, None]:
|
||||
"""检查内容安全"""
|
||||
text = check_text if check_text else event.get_message_str()
|
||||
ok, info = self.strategy_selector.check(text)
|
||||
|
||||
@@ -11,7 +11,7 @@ from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
||||
|
||||
async def call_handler(
|
||||
event: AstrMessageEvent,
|
||||
handler: T.Callable[..., T.Awaitable[T.Any] | T.AsyncGenerator[T.Any, None]],
|
||||
handler: T.Callable[..., T.Awaitable[T.Any]],
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> T.AsyncGenerator[T.Any, None]:
|
||||
@@ -91,7 +91,6 @@ async def call_event_hook(
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
assert inspect.iscoroutinefunction(handler.handler)
|
||||
logger.debug(
|
||||
f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}",
|
||||
)
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
|
||||
from ...context import PipelineContext
|
||||
from ..stage import Stage
|
||||
from .agent_sub_stages.internal import InternalAgentSubStage
|
||||
from .agent_sub_stages.third_party import ThirdPartyAgentSubStage
|
||||
|
||||
|
||||
class AgentRequestSubStage(Stage):
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.ctx = ctx
|
||||
self.config = ctx.astrbot_config
|
||||
|
||||
self.bot_wake_prefixs: list[str] = self.config["wake_prefix"]
|
||||
self.prov_wake_prefix: str = self.config["provider_settings"]["wake_prefix"]
|
||||
for bwp in self.bot_wake_prefixs:
|
||||
if self.prov_wake_prefix.startswith(bwp):
|
||||
logger.info(
|
||||
f"识别 LLM 聊天额外唤醒前缀 {self.prov_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。",
|
||||
)
|
||||
self.prov_wake_prefix = self.prov_wake_prefix[len(bwp) :]
|
||||
|
||||
agent_runner_type = self.config["provider_settings"]["agent_runner_type"]
|
||||
if agent_runner_type == "local":
|
||||
self.agent_sub_stage = InternalAgentSubStage()
|
||||
else:
|
||||
self.agent_sub_stage = ThirdPartyAgentSubStage()
|
||||
await self.agent_sub_stage.initialize(ctx)
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> AsyncGenerator[None, None]:
|
||||
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
|
||||
logger.debug(
|
||||
"This pipeline does not enable AI capability, skip processing."
|
||||
)
|
||||
return
|
||||
|
||||
if not await SessionServiceManager.should_process_llm_request(event):
|
||||
logger.debug(
|
||||
f"The session {event.unified_msg_origin} has disabled AI capability, skipping processing."
|
||||
)
|
||||
return
|
||||
|
||||
async for resp in self.agent_sub_stage.process(event, self.prov_wake_prefix):
|
||||
yield resp
|
||||
@@ -1,561 +0,0 @@
|
||||
"""本地 Agent 模式的 LLM 调用 Stage"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.response import AgentStats
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.conversation_mgr import Conversation
|
||||
from astrbot.core.message.components import File, Image, Reply
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
MessageEventResult,
|
||||
ResultContentType,
|
||||
)
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot.core.provider.entities import (
|
||||
LLMResponse,
|
||||
ProviderRequest,
|
||||
)
|
||||
from astrbot.core.star.star_handler import EventType, star_map
|
||||
from astrbot.core.utils.file_extract import extract_file_moonshotai
|
||||
from astrbot.core.utils.llm_metadata import LLM_METADATAS
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.utils.session_lock import session_lock_manager
|
||||
|
||||
from .....astr_agent_context import AgentContextWrapper
|
||||
from .....astr_agent_hooks import MAIN_AGENT_HOOKS
|
||||
from .....astr_agent_run_util import AgentRunner, run_agent
|
||||
from .....astr_agent_tool_exec import FunctionToolExecutor
|
||||
from ....context import PipelineContext, call_event_hook
|
||||
from ...stage import Stage
|
||||
from ...utils import KNOWLEDGE_BASE_QUERY_TOOL, retrieve_knowledge_base
|
||||
|
||||
|
||||
class InternalAgentSubStage(Stage):
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.ctx = ctx
|
||||
conf = ctx.astrbot_config
|
||||
settings = conf["provider_settings"]
|
||||
self.streaming_response: bool = settings["streaming_response"]
|
||||
self.unsupported_streaming_strategy: str = settings[
|
||||
"unsupported_streaming_strategy"
|
||||
]
|
||||
self.max_step: int = settings.get("max_agent_step", 30)
|
||||
self.tool_call_timeout: int = settings.get("tool_call_timeout", 60)
|
||||
if isinstance(self.max_step, bool): # workaround: #2622
|
||||
self.max_step = 30
|
||||
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
|
||||
self.show_reasoning = settings.get("display_reasoning_text", False)
|
||||
self.kb_agentic_mode: bool = conf.get("kb_agentic_mode", False)
|
||||
|
||||
file_extract_conf: dict = settings.get("file_extract", {})
|
||||
self.file_extract_enabled: bool = file_extract_conf.get("enable", False)
|
||||
self.file_extract_prov: str = file_extract_conf.get("provider", "moonshotai")
|
||||
self.file_extract_msh_api_key: str = file_extract_conf.get(
|
||||
"moonshotai_api_key", ""
|
||||
)
|
||||
|
||||
# 上下文管理相关
|
||||
self.context_limit_reached_strategy: str = settings.get(
|
||||
"context_limit_reached_strategy", "truncate_by_turns"
|
||||
)
|
||||
self.llm_compress_instruction: str = settings.get(
|
||||
"llm_compress_instruction", ""
|
||||
)
|
||||
self.llm_compress_keep_recent: int = settings.get("llm_compress_keep_recent", 4)
|
||||
self.llm_compress_provider_id: str = settings.get(
|
||||
"llm_compress_provider_id", ""
|
||||
)
|
||||
self.max_context_length = settings["max_context_length"] # int
|
||||
self.dequeue_context_length: int = min(
|
||||
max(1, settings["dequeue_context_length"]),
|
||||
self.max_context_length - 1,
|
||||
)
|
||||
if self.dequeue_context_length <= 0:
|
||||
self.dequeue_context_length = 1
|
||||
|
||||
self.conv_manager = ctx.plugin_manager.context.conversation_manager
|
||||
|
||||
def _select_provider(self, event: AstrMessageEvent):
|
||||
"""选择使用的 LLM 提供商"""
|
||||
sel_provider = event.get_extra("selected_provider")
|
||||
_ctx = self.ctx.plugin_manager.context
|
||||
if sel_provider and isinstance(sel_provider, str):
|
||||
provider = _ctx.get_provider_by_id(sel_provider)
|
||||
if not provider:
|
||||
logger.error(f"未找到指定的提供商: {sel_provider}。")
|
||||
return provider
|
||||
|
||||
return _ctx.get_using_provider(umo=event.unified_msg_origin)
|
||||
|
||||
async def _get_session_conv(self, event: AstrMessageEvent) -> Conversation:
|
||||
umo = event.unified_msg_origin
|
||||
conv_mgr = self.conv_manager
|
||||
|
||||
# 获取对话上下文
|
||||
cid = await conv_mgr.get_curr_conversation_id(umo)
|
||||
if not cid:
|
||||
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
|
||||
conversation = await conv_mgr.get_conversation(umo, cid)
|
||||
if not conversation:
|
||||
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
|
||||
conversation = await conv_mgr.get_conversation(umo, cid)
|
||||
if not conversation:
|
||||
raise RuntimeError("无法创建新的对话。")
|
||||
return conversation
|
||||
|
||||
async def _apply_kb(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
):
|
||||
"""Apply knowledge base context to the provider request"""
|
||||
if not self.kb_agentic_mode:
|
||||
if req.prompt is None:
|
||||
return
|
||||
try:
|
||||
kb_result = await retrieve_knowledge_base(
|
||||
query=req.prompt,
|
||||
umo=event.unified_msg_origin,
|
||||
context=self.ctx.plugin_manager.context,
|
||||
)
|
||||
if not kb_result:
|
||||
return
|
||||
if req.system_prompt is not None:
|
||||
req.system_prompt += (
|
||||
f"\n\n[Related Knowledge Base Results]:\n{kb_result}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred while retrieving knowledge base: {e}")
|
||||
else:
|
||||
if req.func_tool is None:
|
||||
req.func_tool = ToolSet()
|
||||
req.func_tool.add_tool(KNOWLEDGE_BASE_QUERY_TOOL)
|
||||
|
||||
async def _apply_file_extract(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
):
|
||||
"""Apply file extract to the provider request"""
|
||||
file_paths = []
|
||||
file_names = []
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, File):
|
||||
file_paths.append(await comp.get_file())
|
||||
file_names.append(comp.name)
|
||||
elif isinstance(comp, Reply) and comp.chain:
|
||||
for reply_comp in comp.chain:
|
||||
if isinstance(reply_comp, File):
|
||||
file_paths.append(await reply_comp.get_file())
|
||||
file_names.append(reply_comp.name)
|
||||
if not file_paths:
|
||||
return
|
||||
if not req.prompt:
|
||||
req.prompt = "总结一下文件里面讲了什么?"
|
||||
if self.file_extract_prov == "moonshotai":
|
||||
if not self.file_extract_msh_api_key:
|
||||
logger.error("Moonshot AI API key for file extract is not set")
|
||||
return
|
||||
file_contents = await asyncio.gather(
|
||||
*[
|
||||
extract_file_moonshotai(file_path, self.file_extract_msh_api_key)
|
||||
for file_path in file_paths
|
||||
]
|
||||
)
|
||||
else:
|
||||
logger.error(f"Unsupported file extract provider: {self.file_extract_prov}")
|
||||
return
|
||||
|
||||
# add file extract results to contexts
|
||||
for file_content, file_name in zip(file_contents, file_names):
|
||||
req.contexts.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"File Extract Results of user uploaded files:\n{file_content}\nFile Name: {file_name or 'Unknown'}",
|
||||
},
|
||||
)
|
||||
|
||||
def _modalities_fix(
|
||||
self,
|
||||
provider: Provider,
|
||||
req: ProviderRequest,
|
||||
):
|
||||
"""检查提供商的模态能力,清理请求中的不支持内容"""
|
||||
if req.image_urls:
|
||||
provider_cfg = provider.provider_config.get("modalities", ["image"])
|
||||
if "image" not in provider_cfg:
|
||||
logger.debug(f"用户设置提供商 {provider} 不支持图像,清空图像列表。")
|
||||
req.image_urls = []
|
||||
if req.func_tool:
|
||||
provider_cfg = provider.provider_config.get("modalities", ["tool_use"])
|
||||
# 如果模型不支持工具使用,但请求中包含工具列表,则清空。
|
||||
if "tool_use" not in provider_cfg:
|
||||
logger.debug(
|
||||
f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。",
|
||||
)
|
||||
req.func_tool = None
|
||||
|
||||
def _plugin_tool_fix(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
):
|
||||
"""根据事件中的插件设置,过滤请求中的工具列表"""
|
||||
if event.plugins_name is not None and req.func_tool:
|
||||
new_tool_set = ToolSet()
|
||||
for tool in req.func_tool.tools:
|
||||
mp = tool.handler_module_path
|
||||
if not mp:
|
||||
continue
|
||||
plugin = star_map.get(mp)
|
||||
if not plugin:
|
||||
continue
|
||||
if plugin.name in event.plugins_name or plugin.reserved:
|
||||
new_tool_set.add_tool(tool)
|
||||
req.func_tool = new_tool_set
|
||||
|
||||
async def _handle_webchat(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
prov: Provider,
|
||||
):
|
||||
"""处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
|
||||
if not req.conversation:
|
||||
return
|
||||
conversation = await self.conv_manager.get_conversation(
|
||||
event.unified_msg_origin,
|
||||
req.conversation.cid,
|
||||
)
|
||||
if conversation and not req.conversation.title:
|
||||
messages = json.loads(conversation.history)
|
||||
latest_pair = messages[-2:]
|
||||
if not latest_pair:
|
||||
return
|
||||
content = latest_pair[0].get("content", "")
|
||||
if isinstance(content, list):
|
||||
# 多模态
|
||||
text_parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
if item.get("type") == "text":
|
||||
text_parts.append(item.get("text", ""))
|
||||
elif item.get("type") == "image":
|
||||
text_parts.append("[图片]")
|
||||
elif isinstance(item, str):
|
||||
text_parts.append(item)
|
||||
cleaned_text = "User: " + " ".join(text_parts).strip()
|
||||
elif isinstance(content, str):
|
||||
cleaned_text = "User: " + content.strip()
|
||||
else:
|
||||
return
|
||||
logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}")
|
||||
llm_resp = await prov.text_chat(
|
||||
system_prompt="You are expert in summarizing user's query.",
|
||||
prompt=(
|
||||
f"Please summarize the following query of user:\n"
|
||||
f"{cleaned_text}\n"
|
||||
"Only output the summary within 10 words, DO NOT INCLUDE any other text."
|
||||
"You must use the same language as the user."
|
||||
"If you think the dialog is too short to summarize, only output a special mark: `<None>`"
|
||||
),
|
||||
)
|
||||
if llm_resp and llm_resp.completion_text:
|
||||
title = llm_resp.completion_text.strip()
|
||||
if not title or "<None>" in title:
|
||||
return
|
||||
await self.conv_manager.update_conversation_title(
|
||||
unified_msg_origin=event.unified_msg_origin,
|
||||
title=title,
|
||||
conversation_id=req.conversation.cid,
|
||||
)
|
||||
|
||||
async def _save_to_history(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
llm_response: LLMResponse | None,
|
||||
all_messages: list[Message],
|
||||
runner_stats: AgentStats | None,
|
||||
):
|
||||
if (
|
||||
not req
|
||||
or not req.conversation
|
||||
or not llm_response
|
||||
or llm_response.role != "assistant"
|
||||
):
|
||||
return
|
||||
|
||||
if not llm_response.completion_text and not req.tool_calls_result:
|
||||
logger.debug("LLM 响应为空,不保存记录。")
|
||||
return
|
||||
|
||||
# using agent context messages to save to history
|
||||
message_to_save = []
|
||||
for message in all_messages:
|
||||
if message.role == "system":
|
||||
# we do not save system messages to history
|
||||
continue
|
||||
if message.role in ["assistant", "user"] and getattr(
|
||||
message, "_no_save", None
|
||||
):
|
||||
# we do not save user and assistant messages that are marked as _no_save
|
||||
continue
|
||||
message_to_save.append(message.model_dump())
|
||||
|
||||
# get token usage from agent runner stats
|
||||
token_usage = None
|
||||
if runner_stats:
|
||||
token_usage = runner_stats.token_usage.total
|
||||
|
||||
await self.conv_manager.update_conversation(
|
||||
event.unified_msg_origin,
|
||||
req.conversation.cid,
|
||||
history=message_to_save,
|
||||
token_usage=token_usage,
|
||||
)
|
||||
|
||||
def _get_compress_provider(self) -> Provider | None:
|
||||
if not self.llm_compress_provider_id:
|
||||
return None
|
||||
if self.context_limit_reached_strategy != "llm_compress":
|
||||
return None
|
||||
provider = self.ctx.plugin_manager.context.get_provider_by_id(
|
||||
self.llm_compress_provider_id,
|
||||
)
|
||||
if provider is None:
|
||||
logger.warning(
|
||||
f"未找到指定的上下文压缩模型 {self.llm_compress_provider_id},将跳过压缩。",
|
||||
)
|
||||
return None
|
||||
if not isinstance(provider, Provider):
|
||||
logger.warning(
|
||||
f"指定的上下文压缩模型 {self.llm_compress_provider_id} 不是对话模型,将跳过压缩。"
|
||||
)
|
||||
return None
|
||||
return provider
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent, provider_wake_prefix: str
|
||||
) -> AsyncGenerator[None, None]:
|
||||
req: ProviderRequest | None = None
|
||||
|
||||
try:
|
||||
provider = self._select_provider(event)
|
||||
if provider is None:
|
||||
return
|
||||
if not isinstance(provider, Provider):
|
||||
logger.error(
|
||||
f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。"
|
||||
)
|
||||
return
|
||||
|
||||
streaming_response = self.streaming_response
|
||||
if (enable_streaming := event.get_extra("enable_streaming")) is not None:
|
||||
streaming_response = bool(enable_streaming)
|
||||
|
||||
# 检查消息内容是否有效,避免空消息触发钩子
|
||||
has_provider_request = event.get_extra("provider_request") is not None
|
||||
has_valid_message = bool(event.message_str and event.message_str.strip())
|
||||
|
||||
if not has_provider_request and not has_valid_message:
|
||||
logger.debug("skip llm request: empty message and no provider_request")
|
||||
return
|
||||
|
||||
logger.debug("ready to request llm provider")
|
||||
|
||||
# 通知等待调用 LLM(在获取锁之前)
|
||||
await call_event_hook(event, EventType.OnWaitingLLMRequestEvent)
|
||||
|
||||
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
|
||||
logger.debug("acquired session lock for llm request")
|
||||
if event.get_extra("provider_request"):
|
||||
req = event.get_extra("provider_request")
|
||||
assert isinstance(req, ProviderRequest), (
|
||||
"provider_request 必须是 ProviderRequest 类型。"
|
||||
)
|
||||
|
||||
if req.conversation:
|
||||
req.contexts = json.loads(req.conversation.history)
|
||||
|
||||
else:
|
||||
req = ProviderRequest()
|
||||
req.prompt = ""
|
||||
req.image_urls = []
|
||||
if sel_model := event.get_extra("selected_model"):
|
||||
req.model = sel_model
|
||||
if provider_wake_prefix and not event.message_str.startswith(
|
||||
provider_wake_prefix
|
||||
):
|
||||
return
|
||||
|
||||
req.prompt = event.message_str[len(provider_wake_prefix) :]
|
||||
# func_tool selection 现在已经转移到 astrbot/builtin_stars/astrbot 插件中进行选择。
|
||||
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Image):
|
||||
image_path = await comp.convert_to_file_path()
|
||||
req.image_urls.append(image_path)
|
||||
|
||||
conversation = await self._get_session_conv(event)
|
||||
req.conversation = conversation
|
||||
req.contexts = json.loads(conversation.history)
|
||||
|
||||
event.set_extra("provider_request", req)
|
||||
|
||||
# fix contexts json str
|
||||
if isinstance(req.contexts, str):
|
||||
req.contexts = json.loads(req.contexts)
|
||||
|
||||
# apply file extract
|
||||
if self.file_extract_enabled:
|
||||
try:
|
||||
await self._apply_file_extract(event, req)
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred while applying file extract: {e}")
|
||||
|
||||
if not req.prompt and not req.image_urls:
|
||||
return
|
||||
|
||||
# call event hook
|
||||
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
|
||||
return
|
||||
|
||||
# apply knowledge base feature
|
||||
await self._apply_kb(event, req)
|
||||
|
||||
# truncate contexts to fit max length
|
||||
# NOW moved to ContextManager inside ToolLoopAgentRunner
|
||||
# if req.contexts:
|
||||
# req.contexts = self._truncate_contexts(req.contexts)
|
||||
# self._fix_messages(req.contexts)
|
||||
|
||||
# session_id
|
||||
if not req.session_id:
|
||||
req.session_id = event.unified_msg_origin
|
||||
|
||||
# check provider modalities, if provider does not support image/tool_use, clear them in request.
|
||||
self._modalities_fix(provider, req)
|
||||
|
||||
# filter tools, only keep tools from this pipeline's selected plugins
|
||||
self._plugin_tool_fix(event, req)
|
||||
|
||||
stream_to_general = (
|
||||
self.unsupported_streaming_strategy == "turn_off"
|
||||
and not event.platform_meta.support_streaming_message
|
||||
)
|
||||
|
||||
# run agent
|
||||
agent_runner = AgentRunner()
|
||||
logger.debug(
|
||||
f"handle provider[id: {provider.provider_config['id']}] request: {req}",
|
||||
)
|
||||
astr_agent_ctx = AstrAgentContext(
|
||||
context=self.ctx.plugin_manager.context,
|
||||
event=event,
|
||||
)
|
||||
|
||||
# inject model context length limit
|
||||
if provider.provider_config.get("max_context_tokens", 0) <= 0:
|
||||
model = provider.get_model()
|
||||
if model_info := LLM_METADATAS.get(model):
|
||||
provider.provider_config["max_context_tokens"] = model_info[
|
||||
"limit"
|
||||
]["context"]
|
||||
|
||||
await agent_runner.reset(
|
||||
provider=provider,
|
||||
request=req,
|
||||
run_context=AgentContextWrapper(
|
||||
context=astr_agent_ctx,
|
||||
tool_call_timeout=self.tool_call_timeout,
|
||||
),
|
||||
tool_executor=FunctionToolExecutor(),
|
||||
agent_hooks=MAIN_AGENT_HOOKS,
|
||||
streaming=streaming_response,
|
||||
llm_compress_instruction=self.llm_compress_instruction,
|
||||
llm_compress_keep_recent=self.llm_compress_keep_recent,
|
||||
llm_compress_provider=self._get_compress_provider(),
|
||||
truncate_turns=self.dequeue_context_length,
|
||||
enforce_max_turns=self.max_context_length,
|
||||
)
|
||||
|
||||
if streaming_response and not stream_to_general:
|
||||
# 流式响应
|
||||
event.set_result(
|
||||
MessageEventResult()
|
||||
.set_result_content_type(ResultContentType.STREAMING_RESULT)
|
||||
.set_async_stream(
|
||||
run_agent(
|
||||
agent_runner,
|
||||
self.max_step,
|
||||
self.show_tool_use,
|
||||
show_reasoning=self.show_reasoning,
|
||||
),
|
||||
),
|
||||
)
|
||||
yield
|
||||
if agent_runner.done():
|
||||
if final_llm_resp := agent_runner.get_final_llm_resp():
|
||||
if final_llm_resp.completion_text:
|
||||
chain = (
|
||||
MessageChain()
|
||||
.message(final_llm_resp.completion_text)
|
||||
.chain
|
||||
)
|
||||
elif final_llm_resp.result_chain:
|
||||
chain = final_llm_resp.result_chain.chain
|
||||
else:
|
||||
chain = MessageChain().chain
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=chain,
|
||||
result_content_type=ResultContentType.STREAMING_FINISH,
|
||||
),
|
||||
)
|
||||
else:
|
||||
async for _ in run_agent(
|
||||
agent_runner,
|
||||
self.max_step,
|
||||
self.show_tool_use,
|
||||
stream_to_general,
|
||||
show_reasoning=self.show_reasoning,
|
||||
):
|
||||
yield
|
||||
|
||||
# 检查事件是否被停止,如果被停止则不保存历史记录
|
||||
if not event.is_stopped():
|
||||
await self._save_to_history(
|
||||
event,
|
||||
req,
|
||||
agent_runner.get_final_llm_resp(),
|
||||
agent_runner.run_context.messages,
|
||||
agent_runner.stats,
|
||||
)
|
||||
|
||||
# 异步处理 WebChat 特殊情况
|
||||
if event.get_platform_name() == "webchat":
|
||||
asyncio.create_task(self._handle_webchat(event, req, provider))
|
||||
|
||||
asyncio.create_task(
|
||||
Metric.upload(
|
||||
llm_tick=1,
|
||||
model_name=agent_runner.provider.get_model(),
|
||||
provider_type=agent_runner.provider.meta().type,
|
||||
),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred while processing agent: {e}")
|
||||
await event.send(
|
||||
MessageChain().message(
|
||||
f"Error occurred while processing agent request: {e}"
|
||||
)
|
||||
)
|
||||
@@ -1,205 +0,0 @@
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from astrbot.core import astrbot_config, logger
|
||||
from astrbot.core.agent.runners.coze.coze_agent_runner import CozeAgentRunner
|
||||
from astrbot.core.agent.runners.dashscope.dashscope_agent_runner import (
|
||||
DashscopeAgentRunner,
|
||||
)
|
||||
from astrbot.core.agent.runners.dify.dify_agent_runner import DifyAgentRunner
|
||||
from astrbot.core.message.components import Image
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
MessageEventResult,
|
||||
ResultContentType,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.agent.runners.base import BaseAgentRunner
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.provider.entities import (
|
||||
ProviderRequest,
|
||||
)
|
||||
from astrbot.core.star.star_handler import EventType
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
|
||||
from .....astr_agent_context import AgentContextWrapper, AstrAgentContext
|
||||
from .....astr_agent_hooks import MAIN_AGENT_HOOKS
|
||||
from ....context import PipelineContext, call_event_hook
|
||||
from ...stage import Stage
|
||||
|
||||
AGENT_RUNNER_TYPE_KEY = {
|
||||
"dify": "dify_agent_runner_provider_id",
|
||||
"coze": "coze_agent_runner_provider_id",
|
||||
"dashscope": "dashscope_agent_runner_provider_id",
|
||||
}
|
||||
|
||||
|
||||
async def run_third_party_agent(
|
||||
runner: "BaseAgentRunner",
|
||||
stream_to_general: bool = False,
|
||||
) -> AsyncGenerator[MessageChain | None, None]:
|
||||
"""
|
||||
运行第三方 agent runner 并转换响应格式
|
||||
类似于 run_agent 函数,但专门处理第三方 agent runner
|
||||
"""
|
||||
try:
|
||||
async for resp in runner.step_until_done(max_step=30): # type: ignore[misc]
|
||||
if resp.type == "streaming_delta":
|
||||
if stream_to_general:
|
||||
continue
|
||||
yield resp.data["chain"]
|
||||
elif resp.type == "llm_result":
|
||||
if stream_to_general:
|
||||
yield resp.data["chain"]
|
||||
except Exception as e:
|
||||
logger.error(f"Third party agent runner error: {e}")
|
||||
err_msg = (
|
||||
f"\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n"
|
||||
f"错误信息: {e!s}\n\n请在平台日志查看和分享错误详情。\n"
|
||||
)
|
||||
yield MessageChain().message(err_msg)
|
||||
|
||||
|
||||
class ThirdPartyAgentSubStage(Stage):
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.ctx = ctx
|
||||
self.conf = ctx.astrbot_config
|
||||
self.runner_type = self.conf["provider_settings"]["agent_runner_type"]
|
||||
self.prov_id = self.conf["provider_settings"].get(
|
||||
AGENT_RUNNER_TYPE_KEY.get(self.runner_type, ""),
|
||||
"",
|
||||
)
|
||||
settings = ctx.astrbot_config["provider_settings"]
|
||||
self.streaming_response: bool = settings["streaming_response"]
|
||||
self.unsupported_streaming_strategy: str = settings[
|
||||
"unsupported_streaming_strategy"
|
||||
]
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent, provider_wake_prefix: str
|
||||
) -> AsyncGenerator[None, None]:
|
||||
req: ProviderRequest | None = None
|
||||
|
||||
if provider_wake_prefix and not event.message_str.startswith(
|
||||
provider_wake_prefix
|
||||
):
|
||||
return
|
||||
|
||||
self.prov_cfg: dict = next(
|
||||
(p for p in astrbot_config["provider"] if p["id"] == self.prov_id),
|
||||
{},
|
||||
)
|
||||
if not self.prov_id:
|
||||
logger.error("没有填写 Agent Runner 提供商 ID,请前往配置页面配置。")
|
||||
return
|
||||
if not self.prov_cfg:
|
||||
logger.error(
|
||||
f"Agent Runner 提供商 {self.prov_id} 配置不存在,请前往配置页面修改配置。"
|
||||
)
|
||||
return
|
||||
|
||||
# make provider request
|
||||
req = ProviderRequest()
|
||||
req.session_id = event.unified_msg_origin
|
||||
req.prompt = event.message_str[len(provider_wake_prefix) :]
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Image):
|
||||
image_path = await comp.convert_to_base64()
|
||||
req.image_urls.append(image_path)
|
||||
|
||||
if not req.prompt and not req.image_urls:
|
||||
return
|
||||
|
||||
# call event hook
|
||||
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
|
||||
return
|
||||
|
||||
if self.runner_type == "dify":
|
||||
runner = DifyAgentRunner[AstrAgentContext]()
|
||||
elif self.runner_type == "coze":
|
||||
runner = CozeAgentRunner[AstrAgentContext]()
|
||||
elif self.runner_type == "dashscope":
|
||||
runner = DashscopeAgentRunner[AstrAgentContext]()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported third party agent runner type: {self.runner_type}",
|
||||
)
|
||||
|
||||
astr_agent_ctx = AstrAgentContext(
|
||||
context=self.ctx.plugin_manager.context,
|
||||
event=event,
|
||||
)
|
||||
|
||||
streaming_response = self.streaming_response
|
||||
if (enable_streaming := event.get_extra("enable_streaming")) is not None:
|
||||
streaming_response = bool(enable_streaming)
|
||||
|
||||
stream_to_general = (
|
||||
self.unsupported_streaming_strategy == "turn_off"
|
||||
and not event.platform_meta.support_streaming_message
|
||||
)
|
||||
|
||||
await runner.reset(
|
||||
request=req,
|
||||
run_context=AgentContextWrapper(
|
||||
context=astr_agent_ctx,
|
||||
tool_call_timeout=60,
|
||||
),
|
||||
agent_hooks=MAIN_AGENT_HOOKS,
|
||||
provider_config=self.prov_cfg,
|
||||
streaming=streaming_response,
|
||||
)
|
||||
|
||||
if streaming_response and not stream_to_general:
|
||||
# 流式响应
|
||||
event.set_result(
|
||||
MessageEventResult()
|
||||
.set_result_content_type(ResultContentType.STREAMING_RESULT)
|
||||
.set_async_stream(
|
||||
run_third_party_agent(
|
||||
runner,
|
||||
stream_to_general=False,
|
||||
),
|
||||
),
|
||||
)
|
||||
yield
|
||||
if runner.done():
|
||||
final_resp = runner.get_final_llm_resp()
|
||||
if final_resp and final_resp.result_chain:
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=final_resp.result_chain.chain or [],
|
||||
result_content_type=ResultContentType.STREAMING_FINISH,
|
||||
),
|
||||
)
|
||||
else:
|
||||
# 非流式响应或转换为普通响应
|
||||
async for _ in run_third_party_agent(
|
||||
runner,
|
||||
stream_to_general=stream_to_general,
|
||||
):
|
||||
yield
|
||||
|
||||
final_resp = runner.get_final_llm_resp()
|
||||
|
||||
if not final_resp or not final_resp.result_chain:
|
||||
logger.warning("Agent Runner 未返回最终结果。")
|
||||
return
|
||||
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=final_resp.result_chain.chain or [],
|
||||
result_content_type=ResultContentType.LLM_RESULT,
|
||||
),
|
||||
)
|
||||
yield
|
||||
|
||||
asyncio.create_task(
|
||||
Metric.upload(
|
||||
llm_tick=1,
|
||||
model_name=self.runner_type,
|
||||
provider_type=self.runner_type,
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,498 @@
|
||||
"""本地 Agent 模式的 LLM 调用 Stage"""
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.conversation_mgr import Conversation
|
||||
from astrbot.core.message.components import Image
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
MessageEventResult,
|
||||
ResultContentType,
|
||||
)
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot.core.provider.entities import (
|
||||
LLMResponse,
|
||||
ProviderRequest,
|
||||
)
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
from astrbot.core.star.star_handler import EventType, star_map
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.utils.session_lock import session_lock_manager
|
||||
|
||||
from ....astr_agent_context import AgentContextWrapper
|
||||
from ....astr_agent_hooks import MAIN_AGENT_HOOKS
|
||||
from ....astr_agent_run_util import AgentRunner, run_agent
|
||||
from ....astr_agent_tool_exec import FunctionToolExecutor
|
||||
from ....memory.tools import ADD_MEMORY_TOOL, QUERY_MEMORY_TOOL
|
||||
from ...context import PipelineContext, call_event_hook
|
||||
from ..stage import Stage
|
||||
from ..utils import KNOWLEDGE_BASE_QUERY_TOOL, retrieve_knowledge_base
|
||||
|
||||
|
||||
class LLMRequestSubStage(Stage):
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.ctx = ctx
|
||||
conf = ctx.astrbot_config
|
||||
settings = conf["provider_settings"]
|
||||
self.bot_wake_prefixs: list[str] = conf["wake_prefix"] # list
|
||||
self.provider_wake_prefix: str = settings["wake_prefix"] # str
|
||||
self.max_context_length = settings["max_context_length"] # int
|
||||
self.dequeue_context_length: int = min(
|
||||
max(1, settings["dequeue_context_length"]),
|
||||
self.max_context_length - 1,
|
||||
)
|
||||
self.streaming_response: bool = settings["streaming_response"]
|
||||
self.unsupported_streaming_strategy: str = settings[
|
||||
"unsupported_streaming_strategy"
|
||||
]
|
||||
self.max_step: int = settings.get("max_agent_step", 30)
|
||||
self.tool_call_timeout: int = settings.get("tool_call_timeout", 60)
|
||||
if isinstance(self.max_step, bool): # workaround: #2622
|
||||
self.max_step = 30
|
||||
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
|
||||
self.show_reasoning = settings.get("display_reasoning_text", False)
|
||||
self.kb_agentic_mode: bool = conf.get("kb_agentic_mode", False)
|
||||
|
||||
for bwp in self.bot_wake_prefixs:
|
||||
if self.provider_wake_prefix.startswith(bwp):
|
||||
logger.info(
|
||||
f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。",
|
||||
)
|
||||
self.provider_wake_prefix = self.provider_wake_prefix[len(bwp) :]
|
||||
|
||||
self.conv_manager = ctx.plugin_manager.context.conversation_manager
|
||||
|
||||
def _select_provider(self, event: AstrMessageEvent):
|
||||
"""选择使用的 LLM 提供商"""
|
||||
sel_provider = event.get_extra("selected_provider")
|
||||
_ctx = self.ctx.plugin_manager.context
|
||||
if sel_provider and isinstance(sel_provider, str):
|
||||
provider = _ctx.get_provider_by_id(sel_provider)
|
||||
if not provider:
|
||||
logger.error(f"未找到指定的提供商: {sel_provider}。")
|
||||
return provider
|
||||
|
||||
return _ctx.get_using_provider(umo=event.unified_msg_origin)
|
||||
|
||||
async def _get_session_conv(self, event: AstrMessageEvent) -> Conversation:
|
||||
umo = event.unified_msg_origin
|
||||
conv_mgr = self.conv_manager
|
||||
|
||||
# 获取对话上下文
|
||||
cid = await conv_mgr.get_curr_conversation_id(umo)
|
||||
if not cid:
|
||||
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
|
||||
conversation = await conv_mgr.get_conversation(umo, cid)
|
||||
if not conversation:
|
||||
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
|
||||
conversation = await conv_mgr.get_conversation(umo, cid)
|
||||
if not conversation:
|
||||
raise RuntimeError("无法创建新的对话。")
|
||||
return conversation
|
||||
|
||||
async def _apply_kb(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
):
|
||||
"""Apply knowledge base context to the provider request"""
|
||||
if not self.kb_agentic_mode:
|
||||
if req.prompt is None:
|
||||
return
|
||||
try:
|
||||
kb_result = await retrieve_knowledge_base(
|
||||
query=req.prompt,
|
||||
umo=event.unified_msg_origin,
|
||||
context=self.ctx.plugin_manager.context,
|
||||
)
|
||||
if not kb_result:
|
||||
return
|
||||
if req.system_prompt is not None:
|
||||
req.system_prompt += (
|
||||
f"\n\n[Related Knowledge Base Results]:\n{kb_result}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred while retrieving knowledge base: {e}")
|
||||
else:
|
||||
if req.func_tool is None:
|
||||
req.func_tool = ToolSet()
|
||||
req.func_tool.add_tool(KNOWLEDGE_BASE_QUERY_TOOL)
|
||||
|
||||
async def _apply_memory(self, req: ProviderRequest):
|
||||
mm = self.ctx.plugin_manager.context.memory_manager
|
||||
if not mm or not mm._initialized:
|
||||
return
|
||||
if req.func_tool is None:
|
||||
req.func_tool = ToolSet()
|
||||
req.func_tool.add_tool(ADD_MEMORY_TOOL)
|
||||
req.func_tool.add_tool(QUERY_MEMORY_TOOL)
|
||||
|
||||
def _truncate_contexts(
|
||||
self,
|
||||
contexts: list[dict],
|
||||
) -> list[dict]:
|
||||
"""截断上下文列表,确保不超过最大长度"""
|
||||
if self.max_context_length == -1:
|
||||
return contexts
|
||||
|
||||
if len(contexts) // 2 <= self.max_context_length:
|
||||
return contexts
|
||||
|
||||
truncated_contexts = contexts[
|
||||
-(self.max_context_length - self.dequeue_context_length + 1) * 2 :
|
||||
]
|
||||
# 找到第一个role 为 user 的索引,确保上下文格式正确
|
||||
index = next(
|
||||
(
|
||||
i
|
||||
for i, item in enumerate(truncated_contexts)
|
||||
if item.get("role") == "user"
|
||||
),
|
||||
None,
|
||||
)
|
||||
if index is not None and index > 0:
|
||||
truncated_contexts = truncated_contexts[index:]
|
||||
|
||||
return truncated_contexts
|
||||
|
||||
def _modalities_fix(
|
||||
self,
|
||||
provider: Provider,
|
||||
req: ProviderRequest,
|
||||
):
|
||||
"""检查提供商的模态能力,清理请求中的不支持内容"""
|
||||
if req.image_urls:
|
||||
provider_cfg = provider.provider_config.get("modalities", ["image"])
|
||||
if "image" not in provider_cfg:
|
||||
logger.debug(f"用户设置提供商 {provider} 不支持图像,清空图像列表。")
|
||||
req.image_urls = []
|
||||
if req.func_tool:
|
||||
provider_cfg = provider.provider_config.get("modalities", ["tool_use"])
|
||||
# 如果模型不支持工具使用,但请求中包含工具列表,则清空。
|
||||
if "tool_use" not in provider_cfg:
|
||||
logger.debug(
|
||||
f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。",
|
||||
)
|
||||
req.func_tool = None
|
||||
|
||||
def _plugin_tool_fix(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
):
|
||||
"""根据事件中的插件设置,过滤请求中的工具列表"""
|
||||
if event.plugins_name is not None and req.func_tool:
|
||||
new_tool_set = ToolSet()
|
||||
for tool in req.func_tool.tools:
|
||||
mp = tool.handler_module_path
|
||||
if not mp:
|
||||
continue
|
||||
plugin = star_map.get(mp)
|
||||
if not plugin:
|
||||
continue
|
||||
if plugin.name in event.plugins_name or plugin.reserved:
|
||||
new_tool_set.add_tool(tool)
|
||||
req.func_tool = new_tool_set
|
||||
|
||||
async def _handle_webchat(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
prov: Provider,
|
||||
):
|
||||
"""处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
|
||||
if not req.conversation:
|
||||
return
|
||||
conversation = await self.conv_manager.get_conversation(
|
||||
event.unified_msg_origin,
|
||||
req.conversation.cid,
|
||||
)
|
||||
if conversation and not req.conversation.title:
|
||||
messages = json.loads(conversation.history)
|
||||
latest_pair = messages[-2:]
|
||||
if not latest_pair:
|
||||
return
|
||||
content = latest_pair[0].get("content", "")
|
||||
if isinstance(content, list):
|
||||
# 多模态
|
||||
text_parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
if item.get("type") == "text":
|
||||
text_parts.append(item.get("text", ""))
|
||||
elif item.get("type") == "image":
|
||||
text_parts.append("[图片]")
|
||||
elif isinstance(item, str):
|
||||
text_parts.append(item)
|
||||
cleaned_text = "User: " + " ".join(text_parts).strip()
|
||||
elif isinstance(content, str):
|
||||
cleaned_text = "User: " + content.strip()
|
||||
else:
|
||||
return
|
||||
logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}")
|
||||
llm_resp = await prov.text_chat(
|
||||
system_prompt="You are expert in summarizing user's query.",
|
||||
prompt=(
|
||||
f"Please summarize the following query of user:\n"
|
||||
f"{cleaned_text}\n"
|
||||
"Only output the summary within 10 words, DO NOT INCLUDE any other text."
|
||||
"You must use the same language as the user."
|
||||
"If you think the dialog is too short to summarize, only output a special mark: `<None>`"
|
||||
),
|
||||
)
|
||||
if llm_resp and llm_resp.completion_text:
|
||||
title = llm_resp.completion_text.strip()
|
||||
if not title or "<None>" in title:
|
||||
return
|
||||
await self.conv_manager.update_conversation_title(
|
||||
unified_msg_origin=event.unified_msg_origin,
|
||||
title=title,
|
||||
conversation_id=req.conversation.cid,
|
||||
)
|
||||
|
||||
async def _save_to_history(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
llm_response: LLMResponse | None,
|
||||
):
|
||||
if (
|
||||
not req
|
||||
or not req.conversation
|
||||
or not llm_response
|
||||
or llm_response.role != "assistant"
|
||||
):
|
||||
return
|
||||
|
||||
if not llm_response.completion_text and not req.tool_calls_result:
|
||||
logger.debug("LLM 响应为空,不保存记录。")
|
||||
return
|
||||
|
||||
if req.contexts is None:
|
||||
req.contexts = []
|
||||
|
||||
# 历史上下文
|
||||
messages = copy.deepcopy(req.contexts)
|
||||
# 这一轮对话请求的用户输入
|
||||
messages.append(await req.assemble_context())
|
||||
# 这一轮对话的 LLM 响应
|
||||
if req.tool_calls_result:
|
||||
if not isinstance(req.tool_calls_result, list):
|
||||
messages.extend(req.tool_calls_result.to_openai_messages())
|
||||
elif isinstance(req.tool_calls_result, list):
|
||||
for tcr in req.tool_calls_result:
|
||||
messages.extend(tcr.to_openai_messages())
|
||||
messages.append({"role": "assistant", "content": llm_response.completion_text})
|
||||
messages = list(filter(lambda item: "_no_save" not in item, messages))
|
||||
await self.conv_manager.update_conversation(
|
||||
event.unified_msg_origin,
|
||||
req.conversation.cid,
|
||||
history=messages,
|
||||
)
|
||||
|
||||
def _fix_messages(self, messages: list[dict]) -> list[dict]:
|
||||
"""验证并且修复上下文"""
|
||||
fixed_messages = []
|
||||
for message in messages:
|
||||
if message.get("role") == "tool":
|
||||
# tool block 前面必须要有 user 和 assistant block
|
||||
if len(fixed_messages) < 2:
|
||||
# 这种情况可能是上下文被截断导致的
|
||||
# 我们直接将之前的上下文都清空
|
||||
fixed_messages = []
|
||||
else:
|
||||
fixed_messages.append(message)
|
||||
else:
|
||||
fixed_messages.append(message)
|
||||
return fixed_messages
|
||||
|
||||
async def process(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
_nested: bool = False,
|
||||
) -> None | AsyncGenerator[None, None]:
|
||||
req: ProviderRequest | None = None
|
||||
|
||||
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
|
||||
logger.debug("未启用 LLM 能力,跳过处理。")
|
||||
return
|
||||
|
||||
# 检查会话级别的LLM启停状态
|
||||
if not SessionServiceManager.should_process_llm_request(event):
|
||||
logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM,跳过处理。")
|
||||
return
|
||||
|
||||
provider = self._select_provider(event)
|
||||
if provider is None:
|
||||
return
|
||||
if not isinstance(provider, Provider):
|
||||
logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。")
|
||||
return
|
||||
|
||||
streaming_response = self.streaming_response
|
||||
if (enable_streaming := event.get_extra("enable_streaming")) is not None:
|
||||
streaming_response = bool(enable_streaming)
|
||||
|
||||
logger.debug("ready to request llm provider")
|
||||
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
|
||||
logger.debug("acquired session lock for llm request")
|
||||
if event.get_extra("provider_request"):
|
||||
req = event.get_extra("provider_request")
|
||||
assert isinstance(req, ProviderRequest), (
|
||||
"provider_request 必须是 ProviderRequest 类型。"
|
||||
)
|
||||
|
||||
if req.conversation:
|
||||
req.contexts = json.loads(req.conversation.history)
|
||||
|
||||
else:
|
||||
req = ProviderRequest()
|
||||
req.prompt = ""
|
||||
req.image_urls = []
|
||||
if sel_model := event.get_extra("selected_model"):
|
||||
req.model = sel_model
|
||||
if self.provider_wake_prefix and not event.message_str.startswith(
|
||||
self.provider_wake_prefix
|
||||
):
|
||||
return
|
||||
|
||||
req.prompt = event.message_str[len(self.provider_wake_prefix) :]
|
||||
# func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
|
||||
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Image):
|
||||
image_path = await comp.convert_to_file_path()
|
||||
req.image_urls.append(image_path)
|
||||
|
||||
conversation = await self._get_session_conv(event)
|
||||
req.conversation = conversation
|
||||
req.contexts = json.loads(conversation.history)
|
||||
|
||||
event.set_extra("provider_request", req)
|
||||
|
||||
if not req.prompt and not req.image_urls:
|
||||
return
|
||||
|
||||
# call event hook
|
||||
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
|
||||
return
|
||||
|
||||
# apply knowledge base feature
|
||||
await self._apply_kb(event, req)
|
||||
|
||||
# apply memory feature
|
||||
await self._apply_memory(req)
|
||||
|
||||
# fix contexts json str
|
||||
if isinstance(req.contexts, str):
|
||||
req.contexts = json.loads(req.contexts)
|
||||
|
||||
# truncate contexts to fit max length
|
||||
if req.contexts:
|
||||
req.contexts = self._truncate_contexts(req.contexts)
|
||||
self._fix_messages(req.contexts)
|
||||
|
||||
# session_id
|
||||
if not req.session_id:
|
||||
req.session_id = event.unified_msg_origin
|
||||
|
||||
# check provider modalities, if provider does not support image/tool_use, clear them in request.
|
||||
self._modalities_fix(provider, req)
|
||||
|
||||
# filter tools, only keep tools from this pipeline's selected plugins
|
||||
self._plugin_tool_fix(event, req)
|
||||
|
||||
stream_to_general = (
|
||||
self.unsupported_streaming_strategy == "turn_off"
|
||||
and not event.platform_meta.support_streaming_message
|
||||
)
|
||||
# 备份 req.contexts
|
||||
backup_contexts = copy.deepcopy(req.contexts)
|
||||
|
||||
# run agent
|
||||
agent_runner = AgentRunner()
|
||||
logger.debug(
|
||||
f"handle provider[id: {provider.provider_config['id']}] request: {req}",
|
||||
)
|
||||
astr_agent_ctx = AstrAgentContext(
|
||||
context=self.ctx.plugin_manager.context,
|
||||
event=event,
|
||||
)
|
||||
await agent_runner.reset(
|
||||
provider=provider,
|
||||
request=req,
|
||||
run_context=AgentContextWrapper(
|
||||
context=astr_agent_ctx,
|
||||
tool_call_timeout=self.tool_call_timeout,
|
||||
),
|
||||
tool_executor=FunctionToolExecutor(),
|
||||
agent_hooks=MAIN_AGENT_HOOKS,
|
||||
streaming=streaming_response,
|
||||
)
|
||||
|
||||
if streaming_response and not stream_to_general:
|
||||
# 流式响应
|
||||
event.set_result(
|
||||
MessageEventResult()
|
||||
.set_result_content_type(ResultContentType.STREAMING_RESULT)
|
||||
.set_async_stream(
|
||||
run_agent(
|
||||
agent_runner,
|
||||
self.max_step,
|
||||
self.show_tool_use,
|
||||
show_reasoning=self.show_reasoning,
|
||||
),
|
||||
),
|
||||
)
|
||||
yield
|
||||
if agent_runner.done():
|
||||
if final_llm_resp := agent_runner.get_final_llm_resp():
|
||||
if final_llm_resp.completion_text:
|
||||
chain = (
|
||||
MessageChain()
|
||||
.message(final_llm_resp.completion_text)
|
||||
.chain
|
||||
)
|
||||
elif final_llm_resp.result_chain:
|
||||
chain = final_llm_resp.result_chain.chain
|
||||
else:
|
||||
chain = MessageChain().chain
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=chain,
|
||||
result_content_type=ResultContentType.STREAMING_FINISH,
|
||||
),
|
||||
)
|
||||
else:
|
||||
async for _ in run_agent(
|
||||
agent_runner,
|
||||
self.max_step,
|
||||
self.show_tool_use,
|
||||
stream_to_general,
|
||||
show_reasoning=self.show_reasoning,
|
||||
):
|
||||
yield
|
||||
|
||||
# 恢复备份的 contexts
|
||||
req.contexts = backup_contexts
|
||||
|
||||
await self._save_to_history(event, req, agent_runner.get_final_llm_resp())
|
||||
|
||||
# 异步处理 WebChat 特殊情况
|
||||
if event.get_platform_name() == "webchat":
|
||||
asyncio.create_task(self._handle_webchat(event, req, provider))
|
||||
|
||||
asyncio.create_task(
|
||||
Metric.upload(
|
||||
llm_tick=1,
|
||||
model_name=agent_runner.provider.get_model(),
|
||||
provider_type=agent_runner.provider.meta().type,
|
||||
),
|
||||
)
|
||||
@@ -16,6 +16,7 @@ from ..stage import Stage
|
||||
|
||||
class StarRequestSubStage(Stage):
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.curr_provider = ctx.plugin_manager.context.get_using_provider()
|
||||
self.prompt_prefix = ctx.astrbot_config["provider_settings"]["prompt_prefix"]
|
||||
self.identifier = ctx.astrbot_config["provider_settings"]["identifier"]
|
||||
self.ctx = ctx
|
||||
@@ -23,7 +24,7 @@ class StarRequestSubStage(Stage):
|
||||
async def process(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
) -> AsyncGenerator[Any, None]:
|
||||
) -> None | AsyncGenerator[None, None]:
|
||||
activated_handlers: list[StarHandlerMetadata] = event.get_extra(
|
||||
"activated_handlers",
|
||||
)
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.provider.entities import ProviderRequest
|
||||
from astrbot.core.star.star_handler import StarHandlerMetadata
|
||||
|
||||
from ..context import PipelineContext
|
||||
from ..stage import Stage, register_stage
|
||||
from .method.agent_request import AgentRequestSubStage
|
||||
from .method.llm_request import LLMRequestSubStage
|
||||
from .method.star_request import StarRequestSubStage
|
||||
|
||||
|
||||
@@ -16,12 +17,9 @@ class ProcessStage(Stage):
|
||||
self.ctx = ctx
|
||||
self.config = ctx.astrbot_config
|
||||
self.plugin_manager = ctx.plugin_manager
|
||||
self.llm_request_sub_stage = LLMRequestSubStage()
|
||||
await self.llm_request_sub_stage.initialize(ctx)
|
||||
|
||||
# initialize agent sub stage
|
||||
self.agent_sub_stage = AgentRequestSubStage()
|
||||
await self.agent_sub_stage.initialize(ctx)
|
||||
|
||||
# initialize star request sub stage
|
||||
self.star_request_sub_stage = StarRequestSubStage()
|
||||
await self.star_request_sub_stage.initialize(ctx)
|
||||
|
||||
@@ -41,7 +39,7 @@ class ProcessStage(Stage):
|
||||
# Handler 的 LLM 请求
|
||||
event.set_extra("provider_request", resp)
|
||||
_t = False
|
||||
async for _ in self.agent_sub_stage.process(event):
|
||||
async for _ in self.llm_request_sub_stage.process(event):
|
||||
_t = True
|
||||
yield
|
||||
if not _t:
|
||||
@@ -60,7 +58,14 @@ class ProcessStage(Stage):
|
||||
):
|
||||
# 是否有过发送操作 and 是否是被 @ 或者通过唤醒前缀
|
||||
if (
|
||||
event.get_result() and not event.is_stopped()
|
||||
event.get_result() and not event.get_result().is_stopped()
|
||||
) or not event.get_result():
|
||||
async for _ in self.agent_sub_stage.process(event):
|
||||
# 事件没有终止传播
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
||||
|
||||
if not provider:
|
||||
logger.info("未找到可用的 LLM 提供商,请先前往配置服务提供商。")
|
||||
return
|
||||
|
||||
async for _ in self.llm_request_sub_stage.process(event):
|
||||
yield
|
||||
|
||||
@@ -117,9 +117,7 @@ class RespondStage(Stage):
|
||||
if not self.enable_seg:
|
||||
return False
|
||||
|
||||
if (result := event.get_result()) is None:
|
||||
return False
|
||||
if self.only_llm_result and not result.is_llm_result():
|
||||
if self.only_llm_result and not event.get_result().is_llm_result():
|
||||
return False
|
||||
|
||||
if event.get_platform_name() in [
|
||||
@@ -158,11 +156,7 @@ class RespondStage(Stage):
|
||||
result = event.get_result()
|
||||
if result is None:
|
||||
return
|
||||
if event.get_extra("_streaming_finished", False):
|
||||
# prevent some plugin make result content type to LLM_RESULT after streaming finished, lead to send again
|
||||
return
|
||||
if result.result_content_type == ResultContentType.STREAMING_FINISH:
|
||||
event.set_extra("_streaming_finished", True)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
@@ -191,7 +185,7 @@ class RespondStage(Stage):
|
||||
if isinstance(component, Comp.File) and component.file:
|
||||
# 支持 File 消息段的路径映射。
|
||||
component.file = path_Mapping(mappings, component.file)
|
||||
result.chain[idx] = component
|
||||
event.get_result().chain[idx] = component
|
||||
|
||||
# 检查消息链是否为空
|
||||
try:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
@@ -7,7 +6,6 @@ from collections.abc import AsyncGenerator
|
||||
from astrbot.core import file_token_service, html_renderer, logger
|
||||
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
|
||||
from astrbot.core.message.message_event_result import ResultContentType
|
||||
from astrbot.core.pipeline.content_safety_check.stage import ContentSafetyCheckStage
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
@@ -43,18 +41,6 @@ class ResultDecorateStage(Stage):
|
||||
"forward_threshold"
|
||||
]
|
||||
|
||||
trigger_probability = ctx.astrbot_config["provider_tts_settings"].get(
|
||||
"trigger_probability",
|
||||
1,
|
||||
)
|
||||
try:
|
||||
self.tts_trigger_probability = max(
|
||||
0.0,
|
||||
min(float(trigger_probability), 1.0),
|
||||
)
|
||||
except (TypeError, ValueError):
|
||||
self.tts_trigger_probability = 1.0
|
||||
|
||||
# 分段回复
|
||||
self.words_count_threshold = int(
|
||||
ctx.astrbot_config["platform_settings"]["segmented_reply"][
|
||||
@@ -67,22 +53,7 @@ class ResultDecorateStage(Stage):
|
||||
self.only_llm_result = ctx.astrbot_config["platform_settings"][
|
||||
"segmented_reply"
|
||||
]["only_llm_result"]
|
||||
self.split_mode = ctx.astrbot_config["platform_settings"][
|
||||
"segmented_reply"
|
||||
].get("split_mode", "regex")
|
||||
self.regex = ctx.astrbot_config["platform_settings"]["segmented_reply"]["regex"]
|
||||
self.split_words = ctx.astrbot_config["platform_settings"][
|
||||
"segmented_reply"
|
||||
].get("split_words", ["。", "?", "!", "~", "…"])
|
||||
if self.split_words:
|
||||
escaped_words = sorted(
|
||||
[re.escape(word) for word in self.split_words], key=len, reverse=True
|
||||
)
|
||||
self.split_words_pattern = re.compile(
|
||||
f"(.*?({'|'.join(escaped_words)})|.+$)", re.DOTALL
|
||||
)
|
||||
else:
|
||||
self.split_words_pattern = None
|
||||
self.content_cleanup_rule = ctx.astrbot_config["platform_settings"][
|
||||
"segmented_reply"
|
||||
]["content_cleanup_rule"]
|
||||
@@ -98,31 +69,6 @@ class ResultDecorateStage(Stage):
|
||||
self.content_safe_check_stage = stage_cls()
|
||||
await self.content_safe_check_stage.initialize(ctx)
|
||||
|
||||
provider_cfg = ctx.astrbot_config.get("provider_settings", {})
|
||||
self.show_reasoning = provider_cfg.get("display_reasoning_text", False)
|
||||
|
||||
def _split_text_by_words(self, text: str) -> list[str]:
|
||||
"""使用分段词列表分段文本"""
|
||||
if not self.split_words_pattern:
|
||||
return [text]
|
||||
|
||||
segments = self.split_words_pattern.findall(text)
|
||||
result = []
|
||||
for seg in segments:
|
||||
if isinstance(seg, tuple):
|
||||
content = seg[0]
|
||||
if not isinstance(content, str):
|
||||
continue
|
||||
for word in self.split_words:
|
||||
if content.endswith(word):
|
||||
content = content[: -len(word)]
|
||||
break
|
||||
if content.strip():
|
||||
result.append(content)
|
||||
elif seg and seg.strip():
|
||||
result.append(seg)
|
||||
return result if result else [text]
|
||||
|
||||
async def process(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
@@ -147,13 +93,11 @@ class ResultDecorateStage(Stage):
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain):
|
||||
text += comp.text
|
||||
|
||||
if isinstance(self.content_safe_check_stage, ContentSafetyCheckStage):
|
||||
async for _ in self.content_safe_check_stage.process(
|
||||
event,
|
||||
check_text=text,
|
||||
):
|
||||
yield
|
||||
async for _ in self.content_safe_check_stage.process(
|
||||
event,
|
||||
check_text=text,
|
||||
):
|
||||
yield
|
||||
|
||||
# 发送消息前事件钩子
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
@@ -170,8 +114,7 @@ class ResultDecorateStage(Stage):
|
||||
"启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作",
|
||||
)
|
||||
await handler.handler(event)
|
||||
|
||||
if (result := event.get_result()) is None or not result.chain:
|
||||
if event.get_result() is None or not event.get_result().chain:
|
||||
logger.debug(
|
||||
f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name} 将消息结果清空。",
|
||||
)
|
||||
@@ -218,27 +161,11 @@ class ResultDecorateStage(Stage):
|
||||
# 不分段回复
|
||||
new_chain.append(comp)
|
||||
continue
|
||||
|
||||
# 根据 split_mode 选择分段方式
|
||||
if self.split_mode == "words":
|
||||
split_response = self._split_text_by_words(comp.text)
|
||||
else: # regex 模式
|
||||
try:
|
||||
split_response = re.findall(
|
||||
self.regex,
|
||||
comp.text,
|
||||
re.DOTALL | re.MULTILINE,
|
||||
)
|
||||
except re.error:
|
||||
logger.error(
|
||||
f"分段回复正则表达式错误,使用默认分段方式: {traceback.format_exc()}",
|
||||
)
|
||||
split_response = re.findall(
|
||||
r".*?[。?!~…]+|.+$",
|
||||
comp.text,
|
||||
re.DOTALL | re.MULTILINE,
|
||||
)
|
||||
|
||||
split_response = re.findall(
|
||||
self.regex,
|
||||
comp.text,
|
||||
re.DOTALL | re.MULTILINE,
|
||||
)
|
||||
if not split_response:
|
||||
new_chain.append(comp)
|
||||
continue
|
||||
@@ -257,75 +184,63 @@ class ResultDecorateStage(Stage):
|
||||
event.unified_msg_origin,
|
||||
)
|
||||
|
||||
should_tts = (
|
||||
bool(self.ctx.astrbot_config["provider_tts_settings"]["enable"])
|
||||
and result.is_llm_result()
|
||||
and await SessionServiceManager.should_process_tts_request(event)
|
||||
and random.random() <= self.tts_trigger_probability
|
||||
and tts_provider
|
||||
)
|
||||
if should_tts and not tts_provider:
|
||||
logger.warning(
|
||||
f"会话 {event.unified_msg_origin} 未配置文本转语音模型。",
|
||||
)
|
||||
|
||||
if (
|
||||
not should_tts
|
||||
and self.show_reasoning
|
||||
and event.get_extra("_llm_reasoning_content")
|
||||
self.ctx.astrbot_config["provider_tts_settings"]["enable"]
|
||||
and result.is_llm_result()
|
||||
and SessionServiceManager.should_process_tts_request(event)
|
||||
):
|
||||
# inject reasoning content to chain
|
||||
reasoning_content = event.get_extra("_llm_reasoning_content")
|
||||
result.chain.insert(0, Plain(f"🤔 思考: {reasoning_content}\n"))
|
||||
if not tts_provider:
|
||||
logger.warning(
|
||||
f"会话 {event.unified_msg_origin} 未配置文本转语音模型。",
|
||||
)
|
||||
else:
|
||||
new_chain = []
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain) and len(comp.text) > 1:
|
||||
try:
|
||||
logger.info(f"TTS 请求: {comp.text}")
|
||||
audio_path = await tts_provider.get_audio(comp.text)
|
||||
logger.info(f"TTS 结果: {audio_path}")
|
||||
if not audio_path:
|
||||
logger.error(
|
||||
f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}",
|
||||
)
|
||||
new_chain.append(comp)
|
||||
continue
|
||||
|
||||
if should_tts and tts_provider:
|
||||
new_chain = []
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain) and len(comp.text) > 1:
|
||||
try:
|
||||
logger.info(f"TTS 请求: {comp.text}")
|
||||
audio_path = await tts_provider.get_audio(comp.text)
|
||||
logger.info(f"TTS 结果: {audio_path}")
|
||||
if not audio_path:
|
||||
logger.error(
|
||||
f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}",
|
||||
use_file_service = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["use_file_service"]
|
||||
callback_api_base = self.ctx.astrbot_config[
|
||||
"callback_api_base"
|
||||
]
|
||||
dual_output = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["dual_output"]
|
||||
|
||||
url = None
|
||||
if use_file_service and callback_api_base:
|
||||
token = await file_token_service.register_file(
|
||||
audio_path,
|
||||
)
|
||||
url = f"{callback_api_base}/api/file/{token}"
|
||||
logger.debug(f"已注册:{url}")
|
||||
|
||||
new_chain.append(
|
||||
Record(
|
||||
file=url or audio_path,
|
||||
url=url or audio_path,
|
||||
),
|
||||
)
|
||||
if dual_output:
|
||||
new_chain.append(comp)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error("TTS 失败,使用文本发送。")
|
||||
new_chain.append(comp)
|
||||
continue
|
||||
|
||||
use_file_service = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["use_file_service"]
|
||||
callback_api_base = self.ctx.astrbot_config[
|
||||
"callback_api_base"
|
||||
]
|
||||
dual_output = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["dual_output"]
|
||||
|
||||
url = None
|
||||
if use_file_service and callback_api_base:
|
||||
token = await file_token_service.register_file(
|
||||
audio_path,
|
||||
)
|
||||
url = f"{callback_api_base}/api/file/{token}"
|
||||
logger.debug(f"已注册:{url}")
|
||||
|
||||
new_chain.append(
|
||||
Record(
|
||||
file=url or audio_path,
|
||||
url=url or audio_path,
|
||||
),
|
||||
)
|
||||
if dual_output:
|
||||
new_chain.append(comp)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error("TTS 失败,使用文本发送。")
|
||||
else:
|
||||
new_chain.append(comp)
|
||||
else:
|
||||
new_chain.append(comp)
|
||||
result.chain = new_chain
|
||||
result.chain = new_chain
|
||||
|
||||
# 文本转图片
|
||||
elif (
|
||||
|
||||
@@ -2,10 +2,6 @@ from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.platform import AstrMessageEvent
|
||||
from astrbot.core.platform.sources.webchat.webchat_event import WebChatMessageEvent
|
||||
from astrbot.core.platform.sources.wecom_ai_bot.wecomai_event import (
|
||||
WecomAIBotMessageEvent,
|
||||
)
|
||||
|
||||
from . import STAGES_ORDER
|
||||
from .context import PipelineContext
|
||||
@@ -82,7 +78,7 @@ class PipelineScheduler:
|
||||
await self._process_stages(event)
|
||||
|
||||
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
|
||||
if isinstance(event, (WebChatMessageEvent, WecomAIBotMessageEvent)):
|
||||
if event.get_platform_name() in ["webchat", "wecom_ai_bot"]:
|
||||
await event.send(None)
|
||||
|
||||
logger.debug("pipeline 执行完毕。")
|
||||
|
||||
@@ -21,7 +21,7 @@ class SessionStatusCheckStage(Stage):
|
||||
event: AstrMessageEvent,
|
||||
) -> None | AsyncGenerator[None, None]:
|
||||
# 检查会话是否整体启用
|
||||
if not await SessionServiceManager.is_session_enabled(event.unified_msg_origin):
|
||||
if not SessionServiceManager.is_session_enabled(event.unified_msg_origin):
|
||||
logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。")
|
||||
|
||||
# workaround for #2309
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.message.components import At, AtAll, Reply
|
||||
from astrbot.core.message.message_event_result import MessageChain, MessageEventResult
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
||||
from astrbot.core.star.session_plugin_manager import SessionPluginManager
|
||||
@@ -14,22 +13,6 @@ from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
||||
from ..context import PipelineContext
|
||||
from ..stage import Stage, register_stage
|
||||
|
||||
UNIQUE_SESSION_ID_BUILDERS: dict[str, Callable[[AstrMessageEvent], str | None]] = {
|
||||
"aiocqhttp": lambda e: f"{e.get_sender_id()}_{e.get_group_id()}",
|
||||
"slack": lambda e: f"{e.get_sender_id()}_{e.get_group_id()}",
|
||||
"dingtalk": lambda e: e.get_sender_id(),
|
||||
"qq_official": lambda e: e.get_sender_id(),
|
||||
"qq_official_webhook": lambda e: e.get_sender_id(),
|
||||
"lark": lambda e: f"{e.get_sender_id()}%{e.get_group_id()}",
|
||||
"misskey": lambda e: f"{e.get_session_id()}_{e.get_sender_id()}",
|
||||
}
|
||||
|
||||
|
||||
def build_unique_session_id(event: AstrMessageEvent) -> str | None:
|
||||
platform = event.get_platform_name()
|
||||
builder = UNIQUE_SESSION_ID_BUILDERS.get(platform)
|
||||
return builder(event) if builder else None
|
||||
|
||||
|
||||
@register_stage
|
||||
class WakingCheckStage(Stage):
|
||||
@@ -67,30 +50,18 @@ class WakingCheckStage(Stage):
|
||||
"ignore_at_all",
|
||||
False,
|
||||
)
|
||||
self.disable_builtin_commands = self.ctx.astrbot_config.get(
|
||||
"disable_builtin_commands", False
|
||||
)
|
||||
platform_settings = self.ctx.astrbot_config.get("platform_settings", {})
|
||||
self.unique_session = platform_settings.get("unique_session", False)
|
||||
|
||||
async def process(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
) -> None | AsyncGenerator[None, None]:
|
||||
# apply unique session
|
||||
if self.unique_session and event.message_obj.type == MessageType.GROUP_MESSAGE:
|
||||
sid = build_unique_session_id(event)
|
||||
if sid:
|
||||
event.session_id = sid
|
||||
|
||||
# ignore bot self message
|
||||
if (
|
||||
self.ignore_bot_self_message
|
||||
and event.get_self_id() == event.get_sender_id()
|
||||
):
|
||||
# 忽略机器人自己发送的消息
|
||||
event.stop_event()
|
||||
return
|
||||
|
||||
# 设置 sender 身份
|
||||
event.message_str = event.message_str.strip()
|
||||
for admin_id in self.ctx.astrbot_config["admins_id"]:
|
||||
@@ -160,14 +131,6 @@ class WakingCheckStage(Stage):
|
||||
EventType.AdapterMessageEvent,
|
||||
plugins_name=event.plugins_name,
|
||||
):
|
||||
if (
|
||||
self.disable_builtin_commands
|
||||
and handler.handler_module_path
|
||||
== "astrbot.builtin_stars.builtin_commands.main"
|
||||
):
|
||||
logger.debug("skipping builtin command")
|
||||
continue
|
||||
|
||||
# filter 需满足 AND 逻辑关系
|
||||
passed = True
|
||||
permission_not_pass = False
|
||||
@@ -226,7 +189,7 @@ class WakingCheckStage(Stage):
|
||||
event._extras.pop("parsed_params", None)
|
||||
|
||||
# 根据会话配置过滤插件处理器
|
||||
activated_handlers = await SessionPluginManager.filter_handlers_by_session(
|
||||
activated_handlers = SessionPluginManager.filter_handlers_by_session(
|
||||
event,
|
||||
activated_handlers,
|
||||
)
|
||||
|
||||
@@ -153,9 +153,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
|
||||
def get_sender_name(self) -> str:
|
||||
"""获取消息发送者的名称。(可能会返回空字符串)"""
|
||||
if isinstance(self.message_obj.sender.nickname, str):
|
||||
return self.message_obj.sender.nickname
|
||||
return ""
|
||||
return self.message_obj.sender.nickname
|
||||
|
||||
def set_extra(self, key, value):
|
||||
"""设置额外的信息。"""
|
||||
@@ -272,7 +270,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
"""
|
||||
self.call_llm = call_llm
|
||||
|
||||
def get_result(self) -> MessageEventResult | None:
|
||||
def get_result(self) -> MessageEventResult:
|
||||
"""获取消息事件的结果。"""
|
||||
return self._result
|
||||
|
||||
@@ -322,7 +320,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
self,
|
||||
prompt: str,
|
||||
func_tool_manager=None,
|
||||
session_id: str = "",
|
||||
session_id: str = None,
|
||||
image_urls: list[str] | None = None,
|
||||
contexts: list | None = None,
|
||||
system_prompt: str = "",
|
||||
|
||||
@@ -54,7 +54,7 @@ class AstrBotMessage:
|
||||
self_id: str # 机器人的识别id
|
||||
session_id: str # 会话id。取决于 unique_session 的设置。
|
||||
message_id: str # 消息id
|
||||
group: Group | None # 群组
|
||||
group: Group # 群组
|
||||
sender: MessageMember # 发送者
|
||||
message: list[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
|
||||
message_str: str # 最直观的纯文本消息字符串
|
||||
@@ -78,7 +78,7 @@ class AstrBotMessage:
|
||||
return ""
|
||||
|
||||
@group_id.setter
|
||||
def group_id(self, value: str | None):
|
||||
def group_id(self, value: str):
|
||||
"""设置 group_id"""
|
||||
if value:
|
||||
if self.group:
|
||||
|
||||
@@ -5,9 +5,8 @@ from asyncio import Queue
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map
|
||||
from astrbot.core.utils.webhook_utils import ensure_platform_webhook_config
|
||||
|
||||
from .platform import Platform, PlatformStatus
|
||||
from .platform import Platform
|
||||
from .register import platform_cls_map
|
||||
from .sources.webchat.webchat_adapter import WebChatAdapter
|
||||
|
||||
@@ -17,9 +16,8 @@ class PlatformManager:
|
||||
self.platform_insts: list[Platform] = []
|
||||
"""加载的 Platform 的实例"""
|
||||
|
||||
self._inst_map: dict[str, dict] = {}
|
||||
self._inst_map = {}
|
||||
|
||||
self.astrbot_config = config
|
||||
self.platforms_config = config["platform"]
|
||||
self.settings = config["platform_settings"]
|
||||
"""NOTE: 这里是 default 的配置文件,以保证最大的兼容性;
|
||||
@@ -31,8 +29,6 @@ class PlatformManager:
|
||||
"""初始化所有平台适配器"""
|
||||
for platform in self.platforms_config:
|
||||
try:
|
||||
if ensure_platform_webhook_config(platform):
|
||||
self.astrbot_config.save_config()
|
||||
await self.load_platform(platform)
|
||||
except Exception as e:
|
||||
logger.error(f"初始化 {platform} 平台适配器失败: {e}")
|
||||
@@ -41,10 +37,7 @@ class PlatformManager:
|
||||
webchat_inst = WebChatAdapter({}, self.settings, self.event_queue)
|
||||
self.platform_insts.append(webchat_inst)
|
||||
asyncio.create_task(
|
||||
self._task_wrapper(
|
||||
asyncio.create_task(webchat_inst.run(), name="webchat"),
|
||||
platform=webchat_inst,
|
||||
),
|
||||
self._task_wrapper(asyncio.create_task(webchat_inst.run(), name="webchat")),
|
||||
)
|
||||
|
||||
async def load_platform(self, platform_config: dict):
|
||||
@@ -70,6 +63,10 @@ class PlatformManager:
|
||||
from .sources.qqofficial_webhook.qo_webhook_adapter import (
|
||||
QQOfficialWebhookPlatformAdapter, # noqa: F401
|
||||
)
|
||||
case "wechatpadpro":
|
||||
from .sources.wechatpadpro.wechatpadpro_adapter import (
|
||||
WeChatPadProAdapter, # noqa: F401
|
||||
)
|
||||
case "lark":
|
||||
from .sources.lark.lark_adapter import (
|
||||
LarkPlatformAdapter, # noqa: F401
|
||||
@@ -110,7 +107,7 @@ class PlatformManager:
|
||||
)
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.error(
|
||||
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->平台日志->安装Pip库 中安装依赖库。",
|
||||
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。")
|
||||
@@ -134,7 +131,6 @@ class PlatformManager:
|
||||
inst.run(),
|
||||
name=f"platform_{platform_config['type']}_{platform_config['id']}",
|
||||
),
|
||||
platform=inst,
|
||||
),
|
||||
)
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
@@ -149,28 +145,17 @@ class PlatformManager:
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def _task_wrapper(self, task: asyncio.Task, platform: Platform | None = None):
|
||||
# 设置平台状态为运行中
|
||||
if platform:
|
||||
platform.status = PlatformStatus.RUNNING
|
||||
|
||||
async def _task_wrapper(self, task: asyncio.Task):
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
if platform:
|
||||
platform.status = PlatformStatus.STOPPED
|
||||
pass
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
tb_str = traceback.format_exc()
|
||||
logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}")
|
||||
for line in tb_str.split("\n"):
|
||||
for line in traceback.format_exc().split("\n"):
|
||||
logger.error(f"| {line}")
|
||||
logger.error("-------")
|
||||
|
||||
# 记录错误到平台实例
|
||||
if platform:
|
||||
platform.record_error(error_msg, tb_str)
|
||||
|
||||
async def reload(self, platform_config: dict):
|
||||
await self.terminate_platform(platform_config["id"])
|
||||
if platform_config["enable"]:
|
||||
@@ -187,9 +172,9 @@ class PlatformManager:
|
||||
logger.info(f"正在尝试终止 {platform_id} 平台适配器 ...")
|
||||
|
||||
# client_id = self._inst_map.pop(platform_id, None)
|
||||
info = self._inst_map.pop(platform_id)
|
||||
info = self._inst_map.pop(platform_id, None)
|
||||
client_id = info["client_id"]
|
||||
inst: Platform = info["inst"]
|
||||
inst = info["inst"]
|
||||
try:
|
||||
self.platform_insts.remove(
|
||||
next(
|
||||
@@ -211,46 +196,3 @@ class PlatformManager:
|
||||
|
||||
def get_insts(self):
|
||||
return self.platform_insts
|
||||
|
||||
def get_all_stats(self) -> dict:
|
||||
"""获取所有平台的统计信息
|
||||
|
||||
Returns:
|
||||
包含所有平台统计信息的字典
|
||||
"""
|
||||
stats_list = []
|
||||
total_errors = 0
|
||||
running_count = 0
|
||||
error_count = 0
|
||||
|
||||
for inst in self.platform_insts:
|
||||
try:
|
||||
stat = inst.get_stats()
|
||||
stats_list.append(stat)
|
||||
total_errors += stat.get("error_count", 0)
|
||||
if stat.get("status") == PlatformStatus.RUNNING.value:
|
||||
running_count += 1
|
||||
elif stat.get("status") == PlatformStatus.ERROR.value:
|
||||
error_count += 1
|
||||
except Exception as e:
|
||||
# 如果获取统计信息失败,记录基本信息
|
||||
logger.warning(f"获取平台统计信息失败: {e}")
|
||||
stats_list.append(
|
||||
{
|
||||
"id": getattr(inst, "config", {}).get("id", "unknown"),
|
||||
"type": "unknown",
|
||||
"status": "unknown",
|
||||
"error_count": 0,
|
||||
"last_error": None,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"platforms": stats_list,
|
||||
"summary": {
|
||||
"total": len(stats_list),
|
||||
"running": running_count,
|
||||
"error": error_count,
|
||||
"total_errors": total_errors,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
import abc
|
||||
import uuid
|
||||
from asyncio import Queue
|
||||
from collections.abc import Coroutine
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from collections.abc import Awaitable
|
||||
from typing import Any
|
||||
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
@@ -15,100 +12,15 @@ from .message_session import MessageSesion
|
||||
from .platform_metadata import PlatformMetadata
|
||||
|
||||
|
||||
class PlatformStatus(Enum):
|
||||
"""平台运行状态"""
|
||||
|
||||
PENDING = "pending" # 待启动
|
||||
RUNNING = "running" # 运行中
|
||||
ERROR = "error" # 发生错误
|
||||
STOPPED = "stopped" # 已停止
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlatformError:
|
||||
"""平台错误信息"""
|
||||
|
||||
message: str
|
||||
timestamp: datetime = field(default_factory=datetime.now)
|
||||
traceback: str | None = None
|
||||
|
||||
|
||||
class Platform(abc.ABC):
|
||||
def __init__(self, config: dict, event_queue: Queue):
|
||||
def __init__(self, event_queue: Queue):
|
||||
super().__init__()
|
||||
# 平台配置
|
||||
self.config = config
|
||||
# 维护了消息平台的事件队列,EventBus 会从这里取出事件并处理。
|
||||
self._event_queue = event_queue
|
||||
self.client_self_id = uuid.uuid4().hex
|
||||
|
||||
# 平台运行状态
|
||||
self._status: PlatformStatus = PlatformStatus.PENDING
|
||||
self._errors: list[PlatformError] = []
|
||||
self._started_at: datetime | None = None
|
||||
|
||||
@property
|
||||
def status(self) -> PlatformStatus:
|
||||
"""获取平台运行状态"""
|
||||
return self._status
|
||||
|
||||
@status.setter
|
||||
def status(self, value: PlatformStatus):
|
||||
"""设置平台运行状态"""
|
||||
self._status = value
|
||||
if value == PlatformStatus.RUNNING and self._started_at is None:
|
||||
self._started_at = datetime.now()
|
||||
|
||||
@property
|
||||
def errors(self) -> list[PlatformError]:
|
||||
"""获取错误列表"""
|
||||
return self._errors
|
||||
|
||||
@property
|
||||
def last_error(self) -> PlatformError | None:
|
||||
"""获取最近的错误"""
|
||||
return self._errors[-1] if self._errors else None
|
||||
|
||||
def record_error(self, message: str, traceback_str: str | None = None):
|
||||
"""记录一个错误"""
|
||||
self._errors.append(PlatformError(message=message, traceback=traceback_str))
|
||||
self._status = PlatformStatus.ERROR
|
||||
|
||||
def clear_errors(self):
|
||||
"""清除错误记录"""
|
||||
self._errors.clear()
|
||||
if self._status == PlatformStatus.ERROR:
|
||||
self._status = PlatformStatus.RUNNING
|
||||
|
||||
def unified_webhook(self) -> bool:
|
||||
"""是否正在使用统一 Webhook 模式"""
|
||||
return bool(
|
||||
self.config.get("unified_webhook_mode", False)
|
||||
and self.config.get("webhook_uuid")
|
||||
)
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""获取平台统计信息"""
|
||||
meta = self.meta()
|
||||
return {
|
||||
"id": meta.id or self.config.get("id"),
|
||||
"type": meta.name,
|
||||
"display_name": meta.adapter_display_name or meta.name,
|
||||
"status": self._status.value,
|
||||
"started_at": self._started_at.isoformat() if self._started_at else None,
|
||||
"error_count": len(self._errors),
|
||||
"last_error": {
|
||||
"message": self.last_error.message,
|
||||
"timestamp": self.last_error.timestamp.isoformat(),
|
||||
"traceback": self.last_error.traceback,
|
||||
}
|
||||
if self.last_error
|
||||
else None,
|
||||
"unified_webhook": self.unified_webhook(),
|
||||
}
|
||||
|
||||
@abc.abstractmethod
|
||||
def run(self) -> Coroutine[Any, Any, None]:
|
||||
def run(self) -> Awaitable[Any]:
|
||||
"""得到一个平台的运行实例,需要返回一个协程对象。"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -124,7 +36,7 @@ class Platform(abc.ABC):
|
||||
self,
|
||||
session: MessageSesion,
|
||||
message_chain: MessageChain,
|
||||
) -> None:
|
||||
) -> Awaitable[Any]:
|
||||
"""通过会话发送消息。该方法旨在让插件能够直接通过**可持久化的会话数据**发送消息,而不需要保存 event 对象。
|
||||
|
||||
异步方法。
|
||||
@@ -137,20 +49,3 @@ class Platform(abc.ABC):
|
||||
|
||||
def get_client(self):
|
||||
"""获取平台的客户端对象。"""
|
||||
|
||||
async def webhook_callback(self, request: Any) -> Any:
|
||||
"""统一 Webhook 回调入口。
|
||||
|
||||
支持统一 Webhook 模式的平台需要实现此方法。
|
||||
当 Dashboard 收到 /api/platform/webhook/{uuid} 请求时,会调用此方法。
|
||||
|
||||
Args:
|
||||
request: Quart 请求对象
|
||||
|
||||
Returns:
|
||||
响应内容,格式取决于具体平台的要求
|
||||
|
||||
Raises:
|
||||
NotImplementedError: 平台未实现统一 Webhook 模式
|
||||
"""
|
||||
raise NotImplementedError(f"平台 {self.meta().name} 未实现统一 Webhook 模式")
|
||||
|
||||
@@ -7,7 +7,7 @@ class PlatformMetadata:
|
||||
"""平台的名称,即平台的类型,如 aiocqhttp, discord, slack"""
|
||||
description: str
|
||||
"""平台的描述"""
|
||||
id: str
|
||||
id: str | None = None
|
||||
"""平台的唯一标识符,用于配置中识别特定平台"""
|
||||
|
||||
default_config_tmpl: dict | None = None
|
||||
|
||||
@@ -40,7 +40,6 @@ def register_platform_adapter(
|
||||
pm = PlatformMetadata(
|
||||
name=adapter_name,
|
||||
description=desc,
|
||||
id=adapter_name,
|
||||
default_config_tmpl=default_config_tmpl,
|
||||
adapter_display_name=adapter_display_name,
|
||||
logo_path=logo_path,
|
||||
|
||||
@@ -70,18 +70,16 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
bot: CQHttp,
|
||||
event: Event | None,
|
||||
is_group: bool,
|
||||
session_id: str | None,
|
||||
session_id: str,
|
||||
messages: list[dict],
|
||||
):
|
||||
# session_id 必须是纯数字字符串
|
||||
session_id_int = (
|
||||
int(session_id) if session_id and session_id.isdigit() else None
|
||||
)
|
||||
session_id = int(session_id) if session_id.isdigit() else None
|
||||
|
||||
if is_group and isinstance(session_id_int, int):
|
||||
await bot.send_group_msg(group_id=session_id_int, message=messages)
|
||||
elif not is_group and isinstance(session_id_int, int):
|
||||
await bot.send_private_msg(user_id=session_id_int, message=messages)
|
||||
if is_group and isinstance(session_id, int):
|
||||
await bot.send_group_msg(group_id=session_id, message=messages)
|
||||
elif not is_group and isinstance(session_id, int):
|
||||
await bot.send_private_msg(user_id=session_id, message=messages)
|
||||
elif isinstance(event, Event): # 最后兜底
|
||||
await bot.send(event=event, message=messages)
|
||||
else:
|
||||
|
||||
@@ -4,7 +4,7 @@ import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Awaitable
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
from aiocqhttp import CQHttp, Event
|
||||
from aiocqhttp.exceptions import ActionFailed
|
||||
@@ -38,16 +38,18 @@ class AiocqhttpAdapter(Platform):
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(platform_config, event_queue)
|
||||
super().__init__(event_queue)
|
||||
|
||||
self.config = platform_config
|
||||
self.settings = platform_settings
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
self.host = platform_config["ws_reverse_host"]
|
||||
self.port = platform_config["ws_reverse_port"]
|
||||
|
||||
self.metadata = PlatformMetadata(
|
||||
name="aiocqhttp",
|
||||
description="适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
|
||||
id=cast(str, self.config.get("id")),
|
||||
id=self.config.get("id"),
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
@@ -126,20 +128,21 @@ class AiocqhttpAdapter(Platform):
|
||||
"""OneBot V11 请求类事件"""
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = str(event.self_id)
|
||||
abm.sender = MessageMember(
|
||||
user_id=str(event.user_id), nickname=str(event.user_id)
|
||||
)
|
||||
abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id)
|
||||
abm.type = MessageType.OTHER_MESSAGE
|
||||
if event.get("group_id"):
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
abm.group_id = str(event.group_id)
|
||||
else:
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
abm.session_id = (
|
||||
str(event.group_id)
|
||||
if abm.type == MessageType.GROUP_MESSAGE
|
||||
else abm.sender.user_id
|
||||
)
|
||||
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = str(abm.sender.user_id) + "_" + str(event.group_id)
|
||||
else:
|
||||
abm.session_id = (
|
||||
str(event.group_id)
|
||||
if abm.type == MessageType.GROUP_MESSAGE
|
||||
else abm.sender.user_id
|
||||
)
|
||||
abm.message_str = ""
|
||||
abm.message = []
|
||||
abm.timestamp = int(time.time())
|
||||
@@ -151,20 +154,23 @@ class AiocqhttpAdapter(Platform):
|
||||
"""OneBot V11 通知类事件"""
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = str(event.self_id)
|
||||
abm.sender = MessageMember(
|
||||
user_id=str(event.user_id), nickname=str(event.user_id)
|
||||
)
|
||||
abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id)
|
||||
abm.type = MessageType.OTHER_MESSAGE
|
||||
if event.get("group_id"):
|
||||
abm.group_id = str(event.group_id)
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
else:
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
abm.session_id = (
|
||||
str(event.group_id)
|
||||
if abm.type == MessageType.GROUP_MESSAGE
|
||||
else abm.sender.user_id
|
||||
)
|
||||
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = (
|
||||
str(abm.sender.user_id) + "_" + str(event.group_id)
|
||||
) # 也保留群组 id
|
||||
else:
|
||||
abm.session_id = (
|
||||
str(event.group_id)
|
||||
if abm.type == MessageType.GROUP_MESSAGE
|
||||
else abm.sender.user_id
|
||||
)
|
||||
abm.message_str = ""
|
||||
abm.message = []
|
||||
abm.raw_message = event
|
||||
@@ -187,7 +193,6 @@ class AiocqhttpAdapter(Platform):
|
||||
@param event: 事件对象
|
||||
@param get_reply: 是否获取回复消息。这个参数是为了防止多个回复嵌套。
|
||||
"""
|
||||
assert event.sender is not None
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = str(event.self_id)
|
||||
abm.sender = MessageMember(
|
||||
@@ -197,15 +202,19 @@ class AiocqhttpAdapter(Platform):
|
||||
if event["message_type"] == "group":
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
abm.group_id = str(event.group_id)
|
||||
abm.group = Group(str(event.group_id))
|
||||
abm.group.group_name = event.get("group_name", "N/A")
|
||||
elif event["message_type"] == "private":
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
abm.session_id = (
|
||||
str(event.group_id)
|
||||
if abm.type == MessageType.GROUP_MESSAGE
|
||||
else abm.sender.user_id
|
||||
)
|
||||
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = (
|
||||
abm.sender.user_id + "_" + str(event.group_id)
|
||||
) # 也保留群组 id
|
||||
else:
|
||||
abm.session_id = (
|
||||
str(event.group_id)
|
||||
if abm.type == MessageType.GROUP_MESSAGE
|
||||
else abm.sender.user_id
|
||||
)
|
||||
|
||||
abm.message_id = str(event.message_id)
|
||||
abm.message = []
|
||||
@@ -218,7 +227,7 @@ class AiocqhttpAdapter(Platform):
|
||||
await self.bot.send(event, err)
|
||||
except BaseException as e:
|
||||
logger.error(f"回复消息失败: {e}")
|
||||
raise ValueError(err)
|
||||
return None
|
||||
|
||||
# 按消息段类型类型适配
|
||||
for t, m_group in itertools.groupby(event.message, key=lambda x: x["type"]):
|
||||
@@ -237,13 +246,7 @@ class AiocqhttpAdapter(Platform):
|
||||
if m["data"].get("url") and m["data"].get("url").startswith("http"):
|
||||
# Lagrange
|
||||
logger.info("guessing lagrange")
|
||||
# 检查多个可能的文件名字段
|
||||
file_name = (
|
||||
m["data"].get("file_name", "")
|
||||
or m["data"].get("name", "")
|
||||
or m["data"].get("file", "")
|
||||
or "file"
|
||||
)
|
||||
file_name = m["data"].get("file_name", "file")
|
||||
abm.message.append(File(name=file_name, url=m["data"]["url"]))
|
||||
else:
|
||||
try:
|
||||
@@ -262,14 +265,7 @@ class AiocqhttpAdapter(Platform):
|
||||
)
|
||||
if ret and "url" in ret:
|
||||
file_url = ret["url"] # https
|
||||
# 优先从 API 返回值获取文件名,其次从原始消息数据获取
|
||||
file_name = (
|
||||
ret.get("file_name", "")
|
||||
or ret.get("name", "")
|
||||
or m["data"].get("file", "")
|
||||
or m["data"].get("file_name", "")
|
||||
)
|
||||
a = File(name=file_name, url=file_url)
|
||||
a = File(name="", url=file_url)
|
||||
abm.message.append(a)
|
||||
else:
|
||||
logger.error(f"获取文件失败: {ret}")
|
||||
@@ -371,25 +367,10 @@ class AiocqhttpAdapter(Platform):
|
||||
logger.error(f"获取 @ 用户信息失败: {e},此消息段将被忽略。")
|
||||
|
||||
message_str += "".join(at_parts)
|
||||
elif t == "markdown":
|
||||
text = m["data"].get("markdown") or m["data"].get("content", "")
|
||||
abm.message.append(Plain(text=text))
|
||||
message_str += text
|
||||
else:
|
||||
for m in m_group:
|
||||
try:
|
||||
if t not in ComponentTypes:
|
||||
logger.warning(
|
||||
f"不支持的消息段类型,已忽略: {t}, data={m['data']}"
|
||||
)
|
||||
continue
|
||||
a = ComponentTypes[t](**m["data"])
|
||||
abm.message.append(a)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"消息段解析失败: type={t}, data={m['data']}. {e}"
|
||||
)
|
||||
continue
|
||||
a = ComponentTypes[t](**m["data"])
|
||||
abm.message.append(a)
|
||||
|
||||
abm.timestamp = int(time.time())
|
||||
abm.message_str = message_str
|
||||
@@ -422,7 +403,7 @@ class AiocqhttpAdapter(Platform):
|
||||
|
||||
async def shutdown_trigger_placeholder(self):
|
||||
await self.shutdown_event.wait()
|
||||
logger.info("aiocqhttp 适配器已被关闭")
|
||||
logger.info("aiocqhttp 适配器已被优雅地关闭")
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return self.metadata
|
||||
|
||||
@@ -2,7 +2,6 @@ import asyncio
|
||||
import os
|
||||
import threading
|
||||
import uuid
|
||||
from typing import cast
|
||||
|
||||
import aiohttp
|
||||
import dingtalk_stream
|
||||
@@ -48,19 +47,21 @@ class DingtalkPlatformAdapter(Platform):
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(platform_config, event_queue)
|
||||
super().__init__(event_queue)
|
||||
|
||||
self.config = platform_config
|
||||
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
|
||||
self.client_id = platform_config["client_id"]
|
||||
self.client_secret = platform_config["client_secret"]
|
||||
|
||||
outer_self = self
|
||||
|
||||
class AstrCallbackClient(dingtalk_stream.ChatbotHandler):
|
||||
async def process(self, message: dingtalk_stream.CallbackMessage):
|
||||
async def process(self_, message: dingtalk_stream.CallbackMessage):
|
||||
logger.debug(f"dingtalk: {message.data}")
|
||||
im = dingtalk_stream.ChatbotMessage.from_dict(message.data)
|
||||
abm = await outer_self.convert_msg(im)
|
||||
await outer_self.handle_msg(abm)
|
||||
abm = await self.convert_msg(im)
|
||||
await self.handle_msg(abm)
|
||||
|
||||
return AckMessage.STATUS_OK, "OK"
|
||||
|
||||
@@ -74,15 +75,14 @@ class DingtalkPlatformAdapter(Platform):
|
||||
self.client,
|
||||
)
|
||||
self.client_ = client # 用于 websockets 的 client
|
||||
self._shutdown_event: threading.Event | None = None
|
||||
|
||||
def _id_to_sid(self, dingtalk_id: str | None) -> str:
|
||||
def _id_to_sid(self, dingtalk_id: str | None) -> str | None:
|
||||
if not dingtalk_id:
|
||||
return dingtalk_id or "unknown"
|
||||
return dingtalk_id
|
||||
prefix = "$:LWCP_v1:$"
|
||||
if dingtalk_id.startswith(prefix):
|
||||
return dingtalk_id[len(prefix) :]
|
||||
return dingtalk_id or "unknown"
|
||||
return dingtalk_id
|
||||
|
||||
async def send_by_session(
|
||||
self,
|
||||
@@ -95,7 +95,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
return PlatformMetadata(
|
||||
name="dingtalk",
|
||||
description="钉钉机器人官方 API 适配器",
|
||||
id=cast(str, self.config.get("id")),
|
||||
id=self.config.get("id"),
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
@@ -106,7 +106,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
abm = AstrBotMessage()
|
||||
abm.message = []
|
||||
abm.message_str = ""
|
||||
abm.timestamp = int(cast(int, message.create_at) / 1000)
|
||||
abm.timestamp = int(message.create_at / 1000)
|
||||
abm.type = (
|
||||
MessageType.GROUP_MESSAGE
|
||||
if message.conversation_type == "2"
|
||||
@@ -117,7 +117,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
nickname=message.sender_nick,
|
||||
)
|
||||
abm.self_id = self._id_to_sid(message.chatbot_user_id)
|
||||
abm.message_id = cast(str, message.message_id)
|
||||
abm.message_id = message.message_id
|
||||
abm.raw_message = message
|
||||
|
||||
if abm.type == MessageType.GROUP_MESSAGE:
|
||||
@@ -127,20 +127,21 @@ class DingtalkPlatformAdapter(Platform):
|
||||
if id := self._id_to_sid(user.dingtalk_id):
|
||||
abm.message.append(At(qq=id))
|
||||
abm.group_id = message.conversation_id
|
||||
abm.session_id = abm.group_id
|
||||
if self.unique_session:
|
||||
abm.session_id = abm.sender.user_id
|
||||
else:
|
||||
abm.session_id = abm.group_id
|
||||
else:
|
||||
abm.session_id = abm.sender.user_id
|
||||
|
||||
message_type: str = cast(str, message.message_type)
|
||||
message_type: str = message.message_type
|
||||
match message_type:
|
||||
case "text":
|
||||
abm.message_str = message.text.content.strip()
|
||||
abm.message.append(Plain(abm.message_str))
|
||||
case "richText":
|
||||
rtc: dingtalk_stream.RichTextContent = cast(
|
||||
dingtalk_stream.RichTextContent, message.rich_text_content
|
||||
)
|
||||
contents: list[dict] = cast(list[dict], rtc.rich_text_list)
|
||||
rtc: dingtalk_stream.RichTextContent = message.rich_text_content
|
||||
contents: list[dict] = rtc.rich_text_list
|
||||
for content in contents:
|
||||
plains = ""
|
||||
if "text" in content:
|
||||
@@ -149,7 +150,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
elif "type" in content and content["type"] == "picture":
|
||||
f_path = await self.download_ding_file(
|
||||
content["downloadCode"],
|
||||
cast(str, message.robot_code),
|
||||
message.robot_code,
|
||||
"jpg",
|
||||
)
|
||||
abm.message.append(Image.fromFileSystem(f_path))
|
||||
@@ -194,7 +195,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
logger.error(
|
||||
f"下载钉钉文件失败: {resp.status}, {await resp.text()}",
|
||||
)
|
||||
return ""
|
||||
return None
|
||||
resp_data = await resp.json()
|
||||
download_url = resp_data["data"]["downloadUrl"]
|
||||
await download_file(download_url, f_path)
|
||||
@@ -214,7 +215,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
logger.error(
|
||||
f"获取钉钉机器人 access_token 失败: {resp.status}, {await resp.text()}",
|
||||
)
|
||||
return ""
|
||||
return None
|
||||
return (await resp.json())["data"]["accessToken"]
|
||||
|
||||
async def handle_msg(self, abm: AstrBotMessage):
|
||||
@@ -240,7 +241,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
task.result()
|
||||
except Exception as e:
|
||||
if "Graceful shutdown" in str(e):
|
||||
logger.info("钉钉适配器已被关闭")
|
||||
logger.info("钉钉适配器已被优雅地关闭")
|
||||
return
|
||||
logger.error(f"钉钉机器人启动失败: {e}")
|
||||
|
||||
@@ -249,13 +250,11 @@ class DingtalkPlatformAdapter(Platform):
|
||||
|
||||
async def terminate(self):
|
||||
def monkey_patch_close():
|
||||
raise KeyboardInterrupt("Graceful shutdown")
|
||||
raise Exception("Graceful shutdown")
|
||||
|
||||
if self.client_.websocket is not None:
|
||||
self.client_.open_connection = monkey_patch_close
|
||||
await self.client_.websocket.close(code=1000, reason="Graceful shutdown")
|
||||
if self._shutdown_event is not None:
|
||||
self._shutdown_event.set()
|
||||
self.client_.open_connection = monkey_patch_close
|
||||
await self.client_.websocket.close(code=1000, reason="Graceful shutdown")
|
||||
self._shutdown_event.set()
|
||||
|
||||
def get_client(self):
|
||||
return self.client
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
from typing import cast
|
||||
|
||||
import dingtalk_stream
|
||||
|
||||
@@ -25,20 +24,6 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
||||
client: dingtalk_stream.ChatbotHandler,
|
||||
message: MessageChain,
|
||||
):
|
||||
icm = cast(dingtalk_stream.ChatbotMessage, self.message_obj.raw_message)
|
||||
ats = []
|
||||
# fixes: #4218
|
||||
# 钉钉 at 机器人需要使用 sender_staff_id 而不是 sender_id
|
||||
for i in message.chain:
|
||||
if isinstance(i, Comp.At):
|
||||
print(i.qq, icm.sender_id, icm.sender_staff_id)
|
||||
if str(i.qq) in str(icm.sender_id or ""):
|
||||
# 适配器会将开头的 $:LWCP_v1:$ 去掉,因此我们用 in 判断
|
||||
ats.append(f"@{icm.sender_staff_id}")
|
||||
else:
|
||||
ats.append(f"@{i.qq}")
|
||||
at_str = " ".join(ats)
|
||||
|
||||
for segment in message.chain:
|
||||
if isinstance(segment, Comp.Plain):
|
||||
segment.text = segment.text.strip()
|
||||
@@ -46,8 +31,8 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
||||
None,
|
||||
client.reply_markdown,
|
||||
segment.text,
|
||||
f"{at_str} {segment.text}".strip(),
|
||||
cast(dingtalk_stream.ChatbotMessage, self.message_obj.raw_message),
|
||||
segment.text,
|
||||
self.message_obj.raw_message,
|
||||
)
|
||||
elif isinstance(segment, Comp.Image):
|
||||
markdown_str = ""
|
||||
@@ -68,9 +53,7 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
||||
client.reply_markdown,
|
||||
"😄",
|
||||
markdown_str,
|
||||
cast(
|
||||
dingtalk_stream.ChatbotMessage, self.message_obj.raw_message
|
||||
),
|
||||
self.message_obj.raw_message,
|
||||
)
|
||||
logger.debug(f"send image: {ret}")
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import sys
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
import discord
|
||||
|
||||
@@ -28,16 +27,13 @@ class DiscordBotClient(discord.Bot):
|
||||
super().__init__(intents=intents, proxy=proxy)
|
||||
|
||||
# 回调函数
|
||||
self.on_message_received: Callable[[dict], Awaitable[None]] | None = None
|
||||
self.on_ready_once_callback: Callable[[], Awaitable[None]] | None = None
|
||||
self.on_message_received = None
|
||||
self.on_ready_once_callback = None
|
||||
self._ready_once_fired = False
|
||||
|
||||
@override
|
||||
async def on_ready(self):
|
||||
"""当机器人成功连接并准备就绪时触发"""
|
||||
if self.user is None:
|
||||
logger.error("[Discord] 客户端未正确加载用户信息 (self.user is None)")
|
||||
return
|
||||
|
||||
logger.info(f"[Discord] 已作为 {self.user} (ID: {self.user.id}) 登录")
|
||||
logger.info("[Discord] 客户端已准备就绪。")
|
||||
|
||||
@@ -53,9 +49,6 @@ class DiscordBotClient(discord.Bot):
|
||||
|
||||
def _create_message_data(self, message: discord.Message) -> dict:
|
||||
"""从 discord.Message 创建数据字典"""
|
||||
if self.user is None:
|
||||
raise RuntimeError("Bot is not ready: self.user is None")
|
||||
|
||||
is_mentioned = self.user in message.mentions
|
||||
return {
|
||||
"message": message,
|
||||
@@ -73,12 +66,6 @@ class DiscordBotClient(discord.Bot):
|
||||
|
||||
def _create_interaction_data(self, interaction: discord.Interaction) -> dict:
|
||||
"""从 discord.Interaction 创建数据字典"""
|
||||
if self.user is None:
|
||||
raise RuntimeError("Bot is not ready: self.user is None")
|
||||
|
||||
if interaction.user is None:
|
||||
raise ValueError("Interaction received without a valid user")
|
||||
|
||||
return {
|
||||
"interaction": interaction,
|
||||
"bot_id": str(self.user.id),
|
||||
@@ -93,6 +80,7 @@ class DiscordBotClient(discord.Bot):
|
||||
"type": "interaction",
|
||||
}
|
||||
|
||||
@override
|
||||
async def on_message(self, message: discord.Message):
|
||||
"""当接收到消息时触发"""
|
||||
if message.author.bot:
|
||||
|
||||
@@ -97,8 +97,8 @@ class DiscordView(BaseMessageComponent):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
components: list[BaseMessageComponent] | None = None,
|
||||
timeout: float | None = None,
|
||||
components: list[BaseMessageComponent] = None,
|
||||
timeout: float = None,
|
||||
):
|
||||
self.components = components or []
|
||||
self.timeout = timeout
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import asyncio
|
||||
import re
|
||||
import sys
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
import discord
|
||||
from discord.abc import GuildChannel, Messageable, PrivateChannel
|
||||
from discord.abc import Messageable
|
||||
from discord.channel import DMChannel
|
||||
|
||||
from astrbot import logger
|
||||
@@ -44,9 +44,10 @@ class DiscordPlatformAdapter(Platform):
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(platform_config, event_queue)
|
||||
super().__init__(event_queue)
|
||||
self.config = platform_config
|
||||
self.settings = platform_settings
|
||||
self.client_self_id: str | None = None
|
||||
self.client_self_id = None
|
||||
self.registered_handlers = []
|
||||
# 指令注册相关
|
||||
self.enable_command_register = self.config.get("discord_command_register", True)
|
||||
@@ -62,12 +63,6 @@ class DiscordPlatformAdapter(Platform):
|
||||
message_chain: MessageChain,
|
||||
):
|
||||
"""通过会话发送消息"""
|
||||
if self.client.user is None:
|
||||
logger.error(
|
||||
"[Discord] 客户端未就绪 (self.client.user is None),无法发送消息"
|
||||
)
|
||||
return
|
||||
|
||||
# 创建一个 message_obj 以便在 event 中使用
|
||||
message_obj = AstrBotMessage()
|
||||
if "_" in session.session_id:
|
||||
@@ -95,7 +90,7 @@ class DiscordPlatformAdapter(Platform):
|
||||
user_id=str(self.client_self_id),
|
||||
nickname=self.client.user.display_name,
|
||||
)
|
||||
message_obj.self_id = cast(str, self.client_self_id)
|
||||
message_obj.self_id = self.client_self_id
|
||||
message_obj.session_id = session.session_id
|
||||
message_obj.message = message_chain.chain
|
||||
|
||||
@@ -116,7 +111,7 @@ class DiscordPlatformAdapter(Platform):
|
||||
return PlatformMetadata(
|
||||
"discord",
|
||||
"Discord 适配器",
|
||||
id=cast(str, self.config.get("id")),
|
||||
id=self.config.get("id"),
|
||||
default_config_tmpl=self.config,
|
||||
support_streaming_message=False,
|
||||
)
|
||||
@@ -166,7 +161,7 @@ class DiscordPlatformAdapter(Platform):
|
||||
|
||||
def _get_message_type(
|
||||
self,
|
||||
channel: Messageable | GuildChannel | PrivateChannel,
|
||||
channel: Messageable,
|
||||
guild_id: int | None = None,
|
||||
) -> MessageType:
|
||||
"""根据 channel 对象和 guild_id 判断消息类型"""
|
||||
@@ -176,15 +171,13 @@ class DiscordPlatformAdapter(Platform):
|
||||
return MessageType.FRIEND_MESSAGE
|
||||
return MessageType.GROUP_MESSAGE
|
||||
|
||||
def _get_channel_id(
|
||||
self, channel: Messageable | GuildChannel | PrivateChannel
|
||||
) -> str:
|
||||
def _get_channel_id(self, channel: Messageable) -> str:
|
||||
"""根据 channel 对象获取ID"""
|
||||
return str(getattr(channel, "id", None))
|
||||
|
||||
def _convert_message_to_abm(self, data: dict) -> AstrBotMessage:
|
||||
"""将普通消息转换为 AstrBotMessage"""
|
||||
message = data["message"]
|
||||
message: discord.Message = data["message"]
|
||||
|
||||
content = message.content
|
||||
|
||||
@@ -241,7 +234,7 @@ class DiscordPlatformAdapter(Platform):
|
||||
)
|
||||
abm.message = message_chain
|
||||
abm.raw_message = message
|
||||
abm.self_id = cast(str, self.client_self_id)
|
||||
abm.self_id = self.client_self_id
|
||||
abm.session_id = str(message.channel.id)
|
||||
abm.message_id = str(message.id)
|
||||
return abm
|
||||
@@ -262,52 +255,32 @@ class DiscordPlatformAdapter(Platform):
|
||||
interaction_followup_webhook=followup_webhook,
|
||||
)
|
||||
|
||||
if self.client.user is None:
|
||||
logger.error(
|
||||
"[Discord] 客户端未就绪 (self.client.user is None),无法处理消息"
|
||||
)
|
||||
return
|
||||
|
||||
# 检查是否为斜杠指令
|
||||
is_slash_command = message_event.interaction_followup_webhook is not None
|
||||
|
||||
# 1. 优先处理斜杠指令
|
||||
if is_slash_command:
|
||||
message_event.is_wake = True
|
||||
message_event.is_at_or_wake_command = True
|
||||
self.commit_event(message_event)
|
||||
return
|
||||
|
||||
# 2. 处理普通消息(提及检测)
|
||||
# 确保 raw_message 是 discord.Message 类型,以便静态检查通过
|
||||
raw_message = message.raw_message
|
||||
if not isinstance(raw_message, discord.Message):
|
||||
logger.warning(
|
||||
f"[Discord] 收到非 Message 类型的消息: {type(raw_message)},已忽略。"
|
||||
)
|
||||
return
|
||||
|
||||
# 检查是否被@(User Mention 或 Bot 拥有的 Role Mention)
|
||||
is_mention = False
|
||||
|
||||
# User Mention
|
||||
# 此时 Pylance 知道 raw_message 是 discord.Message,具有 mentions 属性
|
||||
if self.client.user in raw_message.mentions:
|
||||
is_mention = True
|
||||
|
||||
if (
|
||||
self.client
|
||||
and self.client.user
|
||||
and hasattr(message.raw_message, "mentions")
|
||||
):
|
||||
if self.client.user in message.raw_message.mentions:
|
||||
is_mention = True
|
||||
# Role Mention(Bot 拥有的角色被提及)
|
||||
if not is_mention and raw_message.role_mentions:
|
||||
if not is_mention and hasattr(message.raw_message, "role_mentions"):
|
||||
bot_member = None
|
||||
if raw_message.guild:
|
||||
if hasattr(message.raw_message, "guild") and message.raw_message.guild:
|
||||
try:
|
||||
bot_member = raw_message.guild.get_member(
|
||||
bot_member = message.raw_message.guild.get_member(
|
||||
self.client.user.id,
|
||||
)
|
||||
except Exception:
|
||||
bot_member = None
|
||||
if bot_member and hasattr(bot_member, "roles"):
|
||||
bot_roles = set(bot_member.roles)
|
||||
mentioned_roles = set(raw_message.role_mentions)
|
||||
mentioned_roles = set(message.raw_message.role_mentions)
|
||||
if (
|
||||
bot_roles
|
||||
and mentioned_roles
|
||||
@@ -315,8 +288,8 @@ class DiscordPlatformAdapter(Platform):
|
||||
):
|
||||
is_mention = True
|
||||
|
||||
# 如果是被@的消息,设置为唤醒状态
|
||||
if is_mention:
|
||||
# 如果是斜杠指令或被@的消息,设置为唤醒状态
|
||||
if is_slash_command or is_mention:
|
||||
message_event.is_wake = True
|
||||
message_event.is_at_or_wake_command = True
|
||||
|
||||
@@ -452,7 +425,7 @@ class DiscordPlatformAdapter(Platform):
|
||||
)
|
||||
abm.message = [Plain(text=message_str_for_filter)]
|
||||
abm.raw_message = ctx.interaction
|
||||
abm.self_id = cast(str, self.client_self_id)
|
||||
abm.self_id = self.client_self_id
|
||||
abm.session_id = str(ctx.channel_id)
|
||||
abm.message_id = str(ctx.interaction.id)
|
||||
|
||||
@@ -465,7 +438,7 @@ class DiscordPlatformAdapter(Platform):
|
||||
def _extract_command_info(
|
||||
event_filter: Any,
|
||||
handler_metadata: StarHandlerMetadata,
|
||||
) -> tuple[str, str, CommandFilter | None] | None:
|
||||
) -> tuple[str, str, CommandFilter] | None:
|
||||
"""从事件过滤器中提取指令信息"""
|
||||
cmd_name = None
|
||||
# is_group = False
|
||||
|
||||
@@ -4,10 +4,8 @@ import binascii
|
||||
from collections.abc import AsyncGenerator
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import discord
|
||||
from discord.types.interactions import ComponentInteractionData
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
@@ -87,9 +85,6 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
channel = await self._get_channel()
|
||||
if not channel:
|
||||
return
|
||||
if not isinstance(channel, discord.abc.Messageable):
|
||||
logger.error(f"[Discord] 频道 {channel.id} 不是可发送消息的类型")
|
||||
return
|
||||
await channel.send(**kwargs)
|
||||
|
||||
except Exception as e:
|
||||
@@ -112,9 +107,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
async def _get_channel(
|
||||
self,
|
||||
) -> discord.Thread | discord.abc.GuildChannel | discord.abc.PrivateChannel | None:
|
||||
async def _get_channel(self) -> discord.abc.Messageable | None:
|
||||
"""获取当前事件对应的频道对象"""
|
||||
try:
|
||||
channel_id = int(self.session_id)
|
||||
@@ -128,13 +121,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
async def _parse_to_discord(
|
||||
self,
|
||||
message: MessageChain,
|
||||
) -> tuple[
|
||||
str,
|
||||
list[discord.File],
|
||||
discord.ui.View | None,
|
||||
list[discord.Embed],
|
||||
str | int | None,
|
||||
]:
|
||||
) -> tuple[str, list[discord.File], discord.ui.View | None, list[discord.Embed]]:
|
||||
"""将 MessageChain 解析为 Discord 发送所需的内容"""
|
||||
content_parts = []
|
||||
files = []
|
||||
@@ -274,9 +261,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
self.message_obj.raw_message,
|
||||
"add_reaction",
|
||||
):
|
||||
await cast(discord.Message, self.message_obj.raw_message).add_reaction(
|
||||
emoji
|
||||
)
|
||||
await self.message_obj.raw_message.add_reaction(emoji)
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] 添加反应失败: {e}")
|
||||
|
||||
@@ -285,7 +270,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
return (
|
||||
hasattr(self.message_obj, "raw_message")
|
||||
and hasattr(self.message_obj.raw_message, "type")
|
||||
and cast(discord.Interaction, self.message_obj.raw_message).type
|
||||
and self.message_obj.raw_message.type
|
||||
== discord.InteractionType.application_command
|
||||
)
|
||||
|
||||
@@ -294,18 +279,14 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
return (
|
||||
hasattr(self.message_obj, "raw_message")
|
||||
and hasattr(self.message_obj.raw_message, "type")
|
||||
and cast(discord.Interaction, self.message_obj.raw_message).type
|
||||
== discord.InteractionType.component
|
||||
and self.message_obj.raw_message.type == discord.InteractionType.component
|
||||
)
|
||||
|
||||
def get_interaction_custom_id(self) -> str:
|
||||
"""获取交互组件的custom_id"""
|
||||
if self.is_button_interaction():
|
||||
try:
|
||||
return cast(
|
||||
ComponentInteractionData,
|
||||
cast(discord.Interaction, self.message_obj.raw_message).data,
|
||||
).get("custom_id", "")
|
||||
return self.message_obj.raw_message.data.get("custom_id", "")
|
||||
except Exception:
|
||||
pass
|
||||
return ""
|
||||
@@ -318,9 +299,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
):
|
||||
return any(
|
||||
mention.id == int(self.message_obj.self_id)
|
||||
for mention in cast(
|
||||
discord.Message, self.message_obj.raw_message
|
||||
).mentions
|
||||
for mention in self.message_obj.raw_message.mentions
|
||||
)
|
||||
return False
|
||||
|
||||
@@ -330,5 +309,5 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
self.message_obj.raw_message,
|
||||
"clean_content",
|
||||
):
|
||||
return cast(discord.Message, self.message_obj.raw_message).clean_content
|
||||
return self.message_obj.raw_message.clean_content
|
||||
return self.message_str
|
||||
|
||||
@@ -2,17 +2,10 @@ import asyncio
|
||||
import base64
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, cast
|
||||
|
||||
import lark_oapi as lark
|
||||
from lark_oapi.api.im.v1 import (
|
||||
CreateMessageRequest,
|
||||
CreateMessageRequestBody,
|
||||
GetMessageResourceRequest,
|
||||
)
|
||||
from lark_oapi.api.im.v1.processor import P2ImMessageReceiveV1Processor
|
||||
from lark_oapi.api.im.v1 import *
|
||||
|
||||
import astrbot.api.message_components as Comp
|
||||
from astrbot import logger
|
||||
@@ -25,11 +18,9 @@ from astrbot.api.platform import (
|
||||
PlatformMetadata,
|
||||
)
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.utils.webhook_utils import log_webhook_info
|
||||
|
||||
from ...register import register_platform_adapter
|
||||
from .lark_event import LarkMessageEvent
|
||||
from .server import LarkWebhookServer
|
||||
|
||||
|
||||
@register_platform_adapter(
|
||||
@@ -42,20 +33,20 @@ class LarkPlatformAdapter(Platform):
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(platform_config, event_queue)
|
||||
super().__init__(event_queue)
|
||||
|
||||
self.config = platform_config
|
||||
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
|
||||
self.appid = platform_config["app_id"]
|
||||
self.appsecret = platform_config["app_secret"]
|
||||
self.domain = platform_config.get("domain", lark.FEISHU_DOMAIN)
|
||||
self.bot_name = platform_config.get("lark_bot_name", "astrbot")
|
||||
|
||||
# socket or webhook
|
||||
self.connection_mode = platform_config.get("lark_connection_mode", "socket")
|
||||
|
||||
if not self.bot_name:
|
||||
logger.warning("未设置飞书机器人名称,@ 机器人可能得不到回复。")
|
||||
|
||||
# 初始化 WebSocket 长连接相关配置
|
||||
async def on_msg_event_recv(event: lark.im.v1.P2ImMessageReceiveV1):
|
||||
await self.convert_msg(event)
|
||||
|
||||
@@ -68,8 +59,6 @@ class LarkPlatformAdapter(Platform):
|
||||
.build()
|
||||
)
|
||||
|
||||
self.do_v2_msg_event = do_v2_msg_event
|
||||
|
||||
self.client = lark.ws.Client(
|
||||
app_id=self.appid,
|
||||
app_secret=self.appsecret,
|
||||
@@ -79,56 +68,14 @@ class LarkPlatformAdapter(Platform):
|
||||
)
|
||||
|
||||
self.lark_api = (
|
||||
lark.Client.builder()
|
||||
.app_id(self.appid)
|
||||
.app_secret(self.appsecret)
|
||||
.log_level(lark.LogLevel.ERROR)
|
||||
.domain(self.domain)
|
||||
.build()
|
||||
lark.Client.builder().app_id(self.appid).app_secret(self.appsecret).build()
|
||||
)
|
||||
|
||||
self.webhook_server = None
|
||||
if self.connection_mode == "webhook":
|
||||
self.webhook_server = LarkWebhookServer(platform_config, event_queue)
|
||||
self.webhook_server.set_callback(self.handle_webhook_event)
|
||||
|
||||
self.event_id_timestamps: dict[str, float] = {}
|
||||
|
||||
def _clean_expired_events(self):
|
||||
"""清理超过 30 分钟的事件记录"""
|
||||
current_time = time.time()
|
||||
expired_keys = [
|
||||
event_id
|
||||
for event_id, timestamp in self.event_id_timestamps.items()
|
||||
if current_time - timestamp > 1800
|
||||
]
|
||||
for event_id in expired_keys:
|
||||
del self.event_id_timestamps[event_id]
|
||||
|
||||
def _is_duplicate_event(self, event_id: str) -> bool:
|
||||
"""检查事件是否重复
|
||||
|
||||
Args:
|
||||
event_id: 事件ID
|
||||
|
||||
Returns:
|
||||
True 表示重复事件,False 表示新事件
|
||||
"""
|
||||
self._clean_expired_events()
|
||||
if event_id in self.event_id_timestamps:
|
||||
return True
|
||||
self.event_id_timestamps[event_id] = time.time()
|
||||
return False
|
||||
|
||||
async def send_by_session(
|
||||
self,
|
||||
session: MessageSesion,
|
||||
message_chain: MessageChain,
|
||||
):
|
||||
if self.lark_api.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化,无法发送消息")
|
||||
return
|
||||
|
||||
res = await LarkMessageEvent._convert_to_lark(message_chain, self.lark_api)
|
||||
wrapped = {
|
||||
"zh_cn": {
|
||||
@@ -169,25 +116,14 @@ class LarkPlatformAdapter(Platform):
|
||||
return PlatformMetadata(
|
||||
name="lark",
|
||||
description="飞书机器人官方 API 适配器",
|
||||
id=cast(str, self.config.get("id")),
|
||||
id=self.config.get("id"),
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):
|
||||
if event.event is None:
|
||||
logger.debug("[Lark] 收到空事件(event.event is None)")
|
||||
return
|
||||
message = event.event.message
|
||||
if message is None:
|
||||
logger.debug("[Lark] 事件中没有消息体(message is None)")
|
||||
return
|
||||
|
||||
abm = AstrBotMessage()
|
||||
|
||||
if message.create_time:
|
||||
abm.timestamp = int(message.create_time) // 1000
|
||||
else:
|
||||
abm.timestamp = int(time.time())
|
||||
abm.timestamp = int(message.create_time) / 1000
|
||||
abm.message = []
|
||||
abm.type = (
|
||||
MessageType.GROUP_MESSAGE
|
||||
@@ -202,28 +138,14 @@ class LarkPlatformAdapter(Platform):
|
||||
at_list = {}
|
||||
if message.mentions:
|
||||
for m in message.mentions:
|
||||
if m.id is None:
|
||||
continue
|
||||
# 飞书 open_id 可能是 None,这里做个防护
|
||||
open_id = m.id.open_id if m.id.open_id else ""
|
||||
at_list[m.key] = Comp.At(qq=open_id, name=m.name)
|
||||
|
||||
at_list[m.key] = Comp.At(qq=m.id.open_id, name=m.name)
|
||||
if m.name == self.bot_name:
|
||||
if m.id.open_id is not None:
|
||||
abm.self_id = m.id.open_id
|
||||
abm.self_id = m.id.open_id
|
||||
|
||||
if message.content is None:
|
||||
logger.warning("[Lark] 消息内容为空")
|
||||
return
|
||||
|
||||
try:
|
||||
content_json_b = json.loads(message.content)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"[Lark] 解析消息内容失败: {message.content}")
|
||||
return
|
||||
content_json_b = json.loads(message.content)
|
||||
|
||||
if message.message_type == "text":
|
||||
message_str_raw = content_json_b.get("text", "") # 带有 @ 的消息
|
||||
message_str_raw = content_json_b["text"] # 带有 @ 的消息
|
||||
at_pattern = r"(@_user_\d+)" # 可以根据需求修改正则
|
||||
# at_users = re.findall(at_pattern, message_str_raw)
|
||||
# 拆分文本,去掉AT符号部分
|
||||
@@ -248,47 +170,27 @@ class LarkPlatformAdapter(Platform):
|
||||
content_json_b = _ls
|
||||
elif message.message_type == "image":
|
||||
content_json_b = [
|
||||
{
|
||||
"tag": "img",
|
||||
"image_key": content_json_b.get("image_key"),
|
||||
"style": [],
|
||||
},
|
||||
{"tag": "img", "image_key": content_json_b["image_key"], "style": []},
|
||||
]
|
||||
|
||||
if message.message_type in ("post", "image"):
|
||||
for comp in content_json_b:
|
||||
if comp.get("tag") == "at":
|
||||
user_id = comp.get("user_id")
|
||||
if user_id in at_list:
|
||||
abm.message.append(at_list[user_id])
|
||||
elif comp.get("tag") == "text" and comp.get("text", "").strip():
|
||||
if comp["tag"] == "at":
|
||||
abm.message.append(at_list[comp["user_id"]])
|
||||
elif comp["tag"] == "text" and comp["text"].strip():
|
||||
abm.message.append(Comp.Plain(comp["text"].strip()))
|
||||
elif comp.get("tag") == "img":
|
||||
image_key = comp.get("image_key")
|
||||
if not image_key:
|
||||
continue
|
||||
|
||||
elif comp["tag"] == "img":
|
||||
image_key = comp["image_key"]
|
||||
request = (
|
||||
GetMessageResourceRequest.builder()
|
||||
.message_id(cast(str, message.message_id))
|
||||
.message_id(message.message_id)
|
||||
.file_key(image_key)
|
||||
.type("image")
|
||||
.build()
|
||||
)
|
||||
|
||||
if self.lark_api.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化")
|
||||
continue
|
||||
|
||||
response = await self.lark_api.im.v1.message_resource.aget(request)
|
||||
if not response.success():
|
||||
logger.error(f"无法下载飞书图片: {image_key}")
|
||||
continue
|
||||
|
||||
if response.file is None:
|
||||
logger.error(f"飞书图片响应中不包含文件流: {image_key}")
|
||||
continue
|
||||
|
||||
image_bytes = response.file.read()
|
||||
image_base64 = base64.b64encode(image_bytes).decode()
|
||||
abm.message.append(Comp.Image.fromBase64(image_base64))
|
||||
@@ -296,27 +198,20 @@ class LarkPlatformAdapter(Platform):
|
||||
for comp in abm.message:
|
||||
if isinstance(comp, Comp.Plain):
|
||||
abm.message_str += comp.text
|
||||
|
||||
if message.message_id is None:
|
||||
logger.error("[Lark] 消息缺少 message_id")
|
||||
return
|
||||
|
||||
if (
|
||||
event.event.sender is None
|
||||
or event.event.sender.sender_id is None
|
||||
or event.event.sender.sender_id.open_id is None
|
||||
):
|
||||
logger.error("[Lark] 消息发送者信息不完整")
|
||||
return
|
||||
|
||||
abm.message_id = message.message_id
|
||||
abm.raw_message = message
|
||||
abm.sender = MessageMember(
|
||||
user_id=event.event.sender.sender_id.open_id,
|
||||
nickname=event.event.sender.sender_id.open_id[:8],
|
||||
)
|
||||
if abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = abm.group_id
|
||||
# 独立会话
|
||||
if not self.unique_session:
|
||||
if abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = abm.group_id
|
||||
else:
|
||||
abm.session_id = abm.sender.user_id
|
||||
elif abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = f"{abm.sender.user_id}%{abm.group_id}" # 也保留群组id
|
||||
else:
|
||||
abm.session_id = abm.sender.user_id
|
||||
|
||||
@@ -334,61 +229,13 @@ class LarkPlatformAdapter(Platform):
|
||||
|
||||
self._event_queue.put_nowait(event)
|
||||
|
||||
async def handle_webhook_event(self, event_data: dict):
|
||||
"""处理 Webhook 事件
|
||||
|
||||
Args:
|
||||
event_data: Webhook 事件数据
|
||||
"""
|
||||
try:
|
||||
header = event_data.get("header", {})
|
||||
event_id = header.get("event_id", "")
|
||||
if event_id and self._is_duplicate_event(event_id):
|
||||
logger.debug(f"[Lark Webhook] 跳过重复事件: {event_id}")
|
||||
return
|
||||
event_type = header.get("event_type", "")
|
||||
if event_type == "im.message.receive_v1":
|
||||
processor = P2ImMessageReceiveV1Processor(self.do_v2_msg_event)
|
||||
data = (processor.type())(event_data)
|
||||
processor.do(data)
|
||||
else:
|
||||
logger.debug(f"[Lark Webhook] 未处理的事件类型: {event_type}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Lark Webhook] 处理事件失败: {e}", exc_info=True)
|
||||
|
||||
async def run(self):
|
||||
if self.connection_mode == "webhook":
|
||||
# Webhook 模式
|
||||
if self.webhook_server is None:
|
||||
logger.error("[Lark] Webhook 模式已启用,但 webhook_server 未初始化")
|
||||
return
|
||||
|
||||
webhook_uuid = self.config.get("webhook_uuid")
|
||||
if webhook_uuid:
|
||||
log_webhook_info(f"{self.meta().id}(飞书 Webhook)", webhook_uuid)
|
||||
else:
|
||||
logger.warning("[Lark] Webhook 模式已启用,但未配置 webhook_uuid")
|
||||
else:
|
||||
# 长连接模式
|
||||
await self.client._connect()
|
||||
|
||||
async def webhook_callback(self, request: Any) -> Any:
|
||||
"""统一 Webhook 回调入口"""
|
||||
if not self.webhook_server:
|
||||
return {"error": "Webhook server not initialized"}, 500
|
||||
|
||||
return await self.webhook_server.handle_callback(request)
|
||||
# self.client.start()
|
||||
await self.client._connect()
|
||||
|
||||
async def terminate(self):
|
||||
if self.connection_mode == "socket":
|
||||
await self.client._disconnect()
|
||||
logger.info("飞书(Lark) 适配器已关闭")
|
||||
await self.client._disconnect()
|
||||
logger.info("飞书(Lark) 适配器已被优雅地关闭")
|
||||
|
||||
def get_client(self) -> lark.ws.Client:
|
||||
def get_client(self) -> lark.Client:
|
||||
return self.client
|
||||
|
||||
def unified_webhook(self) -> bool:
|
||||
return bool(
|
||||
self.config.get("lark_connection_mode", "") == "webhook"
|
||||
and self.config.get("webhook_uuid")
|
||||
)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user