Compare commits
108 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e14ed804da | |||
| 8e4e49df20 | |||
| 5d856900ef | |||
| 380a68b96c | |||
| 8879bd7e9d | |||
| 2cce09400f | |||
| 54d26dcd38 | |||
| 205024f27a | |||
| efde994907 | |||
| 8ca4f9cb74 | |||
| 54e49b997b | |||
| 5714944eef | |||
| defc46b6c9 | |||
| 4d819546b0 | |||
| 8006981976 | |||
| f7a716af43 | |||
| a708901e7f | |||
| e9be8cf69f | |||
| 31d53edb9d | |||
| 2ba0460f19 | |||
| 0e034f0fbd | |||
| 2a7d03f9e1 | |||
| 72fac4b9f1 | |||
| 38281ba2cf | |||
| 21aa3174f4 | |||
| dcda871fc0 | |||
| c13c51f499 | |||
| a130db5cf4 | |||
| 7faeb5cea8 | |||
| 8d3ff61e0d | |||
| 4c03e82570 | |||
| e7e8664ab4 | |||
| 1dd1623e7d | |||
| 80d8161d58 | |||
| fc80d7d681 | |||
| c2f036b27c | |||
| 4087bbb512 | |||
| e1c728582d | |||
| 93c69a639a | |||
| a7fdc98b29 | |||
| 85b7f104df | |||
| d76d1bd7fe | |||
| df4412aa80 | |||
| ab2c94e19a | |||
| 37cc4e2121 | |||
| 60dfdd0a66 | |||
| bb8b2cb194 | |||
| 4e29684aa3 | |||
| 0e17e3553d | |||
| 0a55060e89 | |||
| 77859c7daa | |||
| ba39c393a0 | |||
| 6a50d316d9 | |||
| 88c1d77f0b | |||
| 758ce40cc1 | |||
| 3e7bb80492 | |||
| 75e95aa9ca | |||
| a389842e25 | |||
| 0f6a3c3f5a | |||
| 133f27422d | |||
| abc6deb244 | |||
| 06869b4597 | |||
| d32cea9870 | |||
| 4b68100f16 | |||
| 5c5515d462 | |||
| 3932b8f982 | |||
| 82488ca900 | |||
| 29d9b9b2d6 | |||
| 02215e9b7b | |||
| 7160b7a18b | |||
| ea8dac837a | |||
| e2a7a028bd | |||
| 70db8d264b | |||
| 0518e6d487 | |||
| 39eb367866 | |||
| f1d51a22ad | |||
| 77fb554e8f | |||
| 91f8a0ae09 | |||
| 370cda7cf0 | |||
| 66b3eed273 | |||
| 99b061a143 | |||
| 5f3c7ed673 | |||
| a6dc458212 | |||
| 520f521887 | |||
| 01427d9969 | |||
| 34c03ce983 | |||
| 95e9da42d6 | |||
| 1338cab61b | |||
| 7ba98c1e91 | |||
| 9a5f507cbe | |||
| d560671d1f | |||
| 82c9cf4db6 | |||
| 910ec6c695 | |||
| 766d6f2bec | |||
| 9f39140987 | |||
| 89716ef4da | |||
| 3c4ea5a339 | |||
| 601846a8c1 | |||
| 85d66c1056 | |||
| b89d3f663c | |||
| 0260d430d1 | |||
| 2e608cdc09 | |||
| 234ce93dc1 | |||
| 2ada1deb9a | |||
| 788ceb9721 | |||
| 61a68477d0 | |||
| e74f626383 | |||
| ef99f64291 |
@@ -13,7 +13,7 @@ jobs:
|
||||
contents: write
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v5
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Dashboard Build
|
||||
run: |
|
||||
@@ -70,7 +70,7 @@ jobs:
|
||||
needs: build-and-publish-to-github-release
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v5
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
|
||||
@@ -12,7 +12,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
|
||||
@@ -56,7 +56,7 @@ jobs:
|
||||
# your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v5
|
||||
uses: actions/checkout@v6
|
||||
|
||||
# Initializes the CodeQL tools for scanning.
|
||||
- name: Initialize CodeQL
|
||||
|
||||
@@ -17,7 +17,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v5
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v5
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
|
||||
@@ -20,7 +20,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v5
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 1
|
||||
fetch-tag: true
|
||||
@@ -118,7 +118,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v5
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 1
|
||||
fetch-tag: true
|
||||
|
||||
@@ -34,6 +34,7 @@ dashboard/node_modules/
|
||||
dashboard/dist/
|
||||
package-lock.json
|
||||
package.json
|
||||
yarn.lock
|
||||
|
||||
# Operating System
|
||||
**/.DS_Store
|
||||
@@ -47,3 +48,5 @@ astrbot.lock
|
||||
chroma
|
||||
venv/*
|
||||
pytest.ini
|
||||
AGENTS.md
|
||||
IFLOW.md
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
# CONTRIBUTING
|
||||
|
||||
## 贡献指南
|
||||
|
||||
首先,感谢您花时间做出贡献!❤️
|
||||
|
||||
所有类型的贡献都受到鼓励和重视。有关不同的帮助方式和处理方式的详细信息,请参阅[目录](#目录)。在做出贡献之前,请确保阅读相关部分。这将使我们维护人员的工作变得更加容易,并为所有参与者带来顺畅的体验。社区期待您的贡献。🎉
|
||||
|
||||
### 目录
|
||||
|
||||
- [报告问题](#报告问题)
|
||||
- [提交代码更改](#提交代码更改)
|
||||
|
||||
### 报告问题
|
||||
|
||||
如果您在使用 AstrBot 时遇到任何问题,请按照以下步骤报告:
|
||||
|
||||
1. **检查现有问题**:在提交新问题之前,请先检查 [Issues](https://github.com/AstrBotDevs/AstrBot/issues) 中是否已经存在类似的问题。
|
||||
2. **创建新问题**:如果没有类似的问题,请创建一个新问题。请确保提供以下信息:
|
||||
- 问题的简要描述
|
||||
- 重现问题的步骤
|
||||
- 预期结果和实际结果
|
||||
- 相关日志或错误消息
|
||||
|
||||
### 提交代码更改
|
||||
|
||||
#### 分支命名
|
||||
|
||||
我们使用 `fix/` 前缀来修复错误,使用 `feat/` 前缀来添加新功能。对于 `fix/` 分支,请使用简短的描述,或者直接使用 Issue 编号。例如:`fix/1234` 或者 `fix/1234-login-typo`。对于 `feat/` 分支,请使用简短的描述,例如:`feat/add-user-profile`。
|
||||
|
||||
#### PR 描述
|
||||
|
||||
- 请使用英文描述您的 PR。
|
||||
- 标题请使用 `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` 等语义化前缀,并简要描述更改内容。如:`fix: correct login page typo`。
|
||||
|
||||
## Contributing Guide
|
||||
|
||||
First off, thanks for taking the time to contribute! ❤️
|
||||
|
||||
All types of contributions are encouraged and valued. See the [Table of Contents](#table-of-contents) for different ways to help and details about how this project handles them. Please make sure to read the relevant section before making your contribution. It will make it a lot easier for us maintainers and smooth out the experience for all involved. The community looks forward to your contributions. 🎉
|
||||
|
||||
### Table of Contents
|
||||
|
||||
- [Reporting Issues](#reporting-issues)
|
||||
- [Pull Requests](#pull-requests)
|
||||
|
||||
### Reporting Issues
|
||||
|
||||
If you encounter any issues while using AstrBot, please follow these steps to report them:
|
||||
1. **Check Existing Issues**: Before submitting a new issue, please check if a similar issue already exists in the [Issues](https://github.com/AstrBotDevs/AstrBot/issues) section of the repository.
|
||||
2. **Create a New Issue**: If no similar issue exists, please create a new issue. Make sure to provide the following information:
|
||||
- A brief description of the issue
|
||||
- Steps to reproduce the issue
|
||||
- Expected and actual results
|
||||
- Relevant logs or error messages
|
||||
|
||||
### Pull Requests
|
||||
|
||||
#### Branch Naming
|
||||
|
||||
We use the `fix/` prefix for bug fixes and the `feat/` prefix for new features. For `fix/` branches, please use a short description or directly use the Issue number, e.g., `fix/1234` or `fix/1234-login-typo`. For `feat/` branches, please use a short description, e.g., `feat/add-user-profile`.
|
||||
|
||||
#### PR Description
|
||||
- Please use English to describe your PR.
|
||||
- Use semantic prefixes like `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` in the title, followed by a brief description of the changes, e.g., `fix: correct login page typo`.
|
||||
@@ -1,10 +1,13 @@
|
||||

|
||||
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
|
||||
<br>
|
||||
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_zh-TW.md">繁體中文</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_fr.md">Français</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ru.md">Русский</a>
|
||||
|
||||
<div>
|
||||
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
@@ -14,35 +17,37 @@
|
||||
<br>
|
||||
|
||||
<div>
|
||||
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?style=for-the-badge&color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
|
||||
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
|
||||
<img src="https://deepwiki.com/badge.svg" href="https://deepwiki.com/AstrBotDevs/AstrBot">
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?color=76bad9"/></a>
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
|
||||
<img src="https://gitcode.com/Soulter/AstrBot/star/badge.svg" href="https://gitcode.com/Soulter/AstrBot">
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||
<a href="https://astrbot.app/">文档</a> |
|
||||
<a href="https://blog.astrbot.app/">Blog</a> |
|
||||
<a href="https://astrbot.featurebase.app/roadmap">路线图</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/issues">问题提交</a>
|
||||
</div>
|
||||
|
||||
AstrBot 是一个开源的一站式 Agent 聊天机器人平台及开发框架。
|
||||
AstrBot 是一个开源的一站式 Agent 聊天机器人平台,可接入主流即时通讯软件,为个人、开发者和团队打造可靠、可扩展的对话式智能基础设施。无论是个人 AI 伙伴、智能客服、自动化助手,还是企业知识库,AstrBot 都能在你的即时通讯软件平台的工作流中快速构建生产可用的 AI 应用。
|
||||
|
||||
<img width="1776" height="1080" alt="image" src="https://github.com/user-attachments/assets/00782c4c-4437-4d97-aabc-605e3738da5c" />
|
||||
|
||||
## 主要功能
|
||||
|
||||
1. **大模型对话**。支持接入多种大模型服务。支持多模态、工具调用、MCP、原生知识库、人设等功能。
|
||||
2. **多消息平台支持**。支持接入 QQ、企业微信、微信公众号、飞书、Telegram、钉钉、Discord、KOOK 等平台。支持速率限制、白名单、百度内容审核。
|
||||
3. **Agent**。完善适配的 Agentic 能力。支持多轮工具调用、内置沙盒代码执行器、网页搜索等功能。
|
||||
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,社区插件生态丰富。
|
||||
5. **WebUI**。可视化配置和管理机器人,功能齐全。
|
||||
1. 💯 免费 & 开源。
|
||||
1. ✨ AI 大模型对话,多模态,Agent,MCP,知识库,人格设定。
|
||||
2. 🤖 支持接入 Dify、阿里云百炼、Coze 等智能体平台。
|
||||
2. 🌐 多平台,支持 QQ、企业微信、飞书、钉钉、微信公众号、Telegram、Slack 以及[更多](#支持的消息平台)。
|
||||
3. 📦 插件扩展,已有近 800 个插件可一键安装。
|
||||
5. 💻 WebUI 支持。
|
||||
6. 🌐 国际化(i18n)支持。
|
||||
|
||||
## 部署方式
|
||||
## 快速开始
|
||||
|
||||
#### Docker 部署(推荐 🥳)
|
||||
|
||||
@@ -50,6 +55,12 @@ AstrBot 是一个开源的一站式 Agent 聊天机器人平台及开发框架
|
||||
|
||||
请参阅官方文档 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) 。
|
||||
|
||||
#### uv 部署
|
||||
|
||||
```bash
|
||||
uvx astrbot
|
||||
```
|
||||
|
||||
#### 宝塔面板部署
|
||||
|
||||
AstrBot 与宝塔面板合作,已上架至宝塔面板。
|
||||
@@ -101,24 +112,6 @@ uv run main.py
|
||||
|
||||
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
|
||||
|
||||
## 🌍 社区
|
||||
|
||||
### QQ 群组
|
||||
|
||||
- 1 群:322154837
|
||||
- 3 群:630166526
|
||||
- 5 群:822130018
|
||||
- 6 群:753075035
|
||||
- 开发者群:975206796
|
||||
|
||||
### Telegram 群组
|
||||
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
### Discord 群组
|
||||
|
||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
## 支持的消息平台
|
||||
|
||||
**官方维护**
|
||||
@@ -205,6 +198,24 @@ pip install pre-commit
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
## 🌍 社区
|
||||
|
||||
### QQ 群组
|
||||
|
||||
- 1 群:322154837
|
||||
- 3 群:630166526
|
||||
- 5 群:822130018
|
||||
- 6 群:753075035
|
||||
- 开发者群:975206796
|
||||
|
||||
### Telegram 群组
|
||||
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
### Discord 群组
|
||||
|
||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
## ❤️ Special Thanks
|
||||
|
||||
特别感谢所有 Contributors 和插件开发者对 AstrBot 的贡献 ❤️
|
||||
|
||||
+40
-26
@@ -19,30 +19,38 @@
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%20plugins&style=for-the-badge&label=Marketplace&cacheSeconds=3600">
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">中文</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_zh-TW.md">繁體中文</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_fr.md">Français</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ru.md">Русский</a>
|
||||
|
||||
<a href="https://astrbot.app/">Documentation</a> |
|
||||
<a href="https://blog.astrbot.app/">Blog</a> |
|
||||
<a href="https://astrbot.featurebase.app/roadmap">Roadmap</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Issue Tracker</a>
|
||||
</div>
|
||||
|
||||
AstrBot is an open-source all-in-one Agent chatbot platform and development framework.
|
||||
AstrBot is an open-source all-in-one Agent chatbot platform that integrates with mainstream instant messaging apps. It provides reliable and scalable conversational AI infrastructure for individuals, developers, and teams. Whether you're building a personal AI companion, intelligent customer service, automation assistant, or enterprise knowledge base, AstrBot enables you to quickly build production-ready AI applications within your IM platform workflows.
|
||||
|
||||
<img width="1776" height="1080" alt="image" src="https://github.com/user-attachments/assets/00782c4c-4437-4d97-aabc-605e3738da5c" />
|
||||
|
||||
## Key Features
|
||||
|
||||
1. **LLM Conversations**. Supports integration with various large language model services. Features include multimodal capabilities, tool calling, MCP, native knowledge base, character personas, and more.
|
||||
2. **Multi-Platform Support**. Integrates with QQ, WeChat Work, WeChat Official Accounts, Feishu, Telegram, DingTalk, Discord, KOOK, and other platforms. Supports rate limiting, whitelisting, and Baidu content moderation.
|
||||
3. **Agent Capabilities**. Fully optimized agentic features including multi-turn tool calling, built-in sandboxed code executor, web search, and more.
|
||||
4. **Plugin Extensions**. Deeply optimized plugin mechanism supporting [plugin development](https://astrbot.app/dev/plugin.html) to extend functionality, with a rich community plugin ecosystem.
|
||||
5. **Web UI**. Visual configuration and management of your bot with comprehensive features.
|
||||
1. 💯 Free & Open Source.
|
||||
2. ✨ AI LLM Conversations, Multimodal, Agent, MCP, Knowledge Base, Persona Settings.
|
||||
3. 🤖 Supports integration with Dify, Alibaba Cloud Bailian, Coze and other agent platforms.
|
||||
4. 🌐 Multi-Platform: QQ, WeChat Work, Feishu, DingTalk, WeChat Official Accounts, Telegram, Slack, and [more](#supported-messaging-platforms).
|
||||
5. 📦 Plugin Extensions with nearly 800 plugins available for one-click installation.
|
||||
6. 💻 WebUI Support.
|
||||
7. 🌐 Internationalization (i18n) Support.
|
||||
|
||||
## Deployment Methods
|
||||
## Quick Start
|
||||
|
||||
#### Docker Deployment (Recommended 🥳)
|
||||
|
||||
@@ -50,6 +58,12 @@ We recommend deploying AstrBot using Docker or Docker Compose.
|
||||
|
||||
Please refer to the official documentation: [Deploy AstrBot with Docker](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot).
|
||||
|
||||
#### uv Deployment
|
||||
|
||||
```bash
|
||||
uvx astrbot
|
||||
```
|
||||
|
||||
#### BT-Panel Deployment
|
||||
|
||||
AstrBot has partnered with BT-Panel and is now available in their marketplace.
|
||||
@@ -101,24 +115,6 @@ uv run main.py
|
||||
|
||||
Or refer to the official documentation: [Deploy AstrBot from Source](https://astrbot.app/deploy/astrbot/cli.html).
|
||||
|
||||
## 🌍 Community
|
||||
|
||||
### QQ Groups
|
||||
|
||||
- Group 1: 322154837
|
||||
- Group 3: 630166526
|
||||
- Group 5: 822130018
|
||||
- Group 6: 753075035
|
||||
- Developer Group: 975206796
|
||||
|
||||
### Telegram Group
|
||||
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
### Discord Server
|
||||
|
||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
## Supported Messaging Platforms
|
||||
|
||||
**Officially Maintained**
|
||||
@@ -205,6 +201,24 @@ pip install pre-commit
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
## 🌍 Community
|
||||
|
||||
### QQ Groups
|
||||
|
||||
- Group 1: 322154837
|
||||
- Group 3: 630166526
|
||||
- Group 5: 822130018
|
||||
- Group 6: 753075035
|
||||
- Developer Group: 975206796
|
||||
|
||||
### Telegram Group
|
||||
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
### Discord Server
|
||||
|
||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
## ❤️ Special Thanks
|
||||
|
||||
Special thanks to all Contributors and plugin developers for their contributions to AstrBot ❤️
|
||||
|
||||
+248
@@ -0,0 +1,248 @@
|
||||

|
||||
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
|
||||
<br>
|
||||
|
||||
<div>
|
||||
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp&t=2" alt="Featured|HelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<div>
|
||||
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?style=for-the-badge&color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%20plugins&style=for-the-badge&label=Marketplace&cacheSeconds=3600">
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">中文</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_zh-TW.md">繁體中文</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ru.md">Русский</a>
|
||||
|
||||
<a href="https://astrbot.app/">Documentation</a> |
|
||||
<a href="https://blog.astrbot.app/">Blog</a> |
|
||||
<a href="https://astrbot.featurebase.app/roadmap">Feuille de route</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Signaler un problème</a>
|
||||
</div>
|
||||
|
||||
AstrBot est une plateforme de chatbot Agent tout-en-un open source qui s'intègre aux principales applications de messagerie instantanée. Elle fournit une infrastructure d'IA conversationnelle fiable et évolutive pour les particuliers, les développeurs et les équipes. Que vous construisiez un compagnon IA personnel, un service client intelligent, un assistant d'automatisation ou une base de connaissances d'entreprise, AstrBot vous permet de créer rapidement des applications d'IA prêtes pour la production dans les flux de travail de votre plateforme de messagerie.
|
||||
|
||||
<img width="1776" height="1080" alt="image" src="https://github.com/user-attachments/assets/00782c4c-4437-4d97-aabc-605e3738da5c" />
|
||||
|
||||
## Fonctionnalités principales
|
||||
|
||||
1. 💯 Gratuit & Open Source.
|
||||
2. ✨ Conversations avec LLM IA, Multimodal, Agent, MCP, Base de connaissances, Paramètres de personnalité.
|
||||
3. 🤖 Prise en charge de l'intégration avec Dify, Alibaba Cloud Bailian, Coze et autres plateformes d'agents.
|
||||
4. 🌐 Multi-plateforme : QQ, WeChat Work, Feishu, DingTalk, Comptes officiels WeChat, Telegram, Slack, et [plus encore](#plateformes-de-messagerie-prises-en-charge).
|
||||
5. 📦 Extensions de plugins avec près de 800 plugins disponibles pour une installation en un clic.
|
||||
6. 💻 Support WebUI.
|
||||
7. 🌐 Support de l'internationalisation (i18n).
|
||||
|
||||
## Démarrage rapide
|
||||
|
||||
#### Déploiement Docker (Recommandé 🥳)
|
||||
|
||||
Nous recommandons de déployer AstrBot en utilisant Docker ou Docker Compose.
|
||||
|
||||
Veuillez consulter la documentation officielle : [Déployer AstrBot avec Docker](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot).
|
||||
|
||||
#### Déploiement uv
|
||||
|
||||
```bash
|
||||
uvx astrbot
|
||||
```
|
||||
|
||||
#### Déploiement BT-Panel
|
||||
|
||||
AstrBot s'est associé à BT-Panel et est maintenant disponible sur leur marketplace.
|
||||
|
||||
Veuillez consulter la documentation officielle : [Déploiement BT-Panel](https://astrbot.app/deploy/astrbot/btpanel.html).
|
||||
|
||||
#### Déploiement 1Panel
|
||||
|
||||
AstrBot a été officiellement listé sur le marketplace 1Panel.
|
||||
|
||||
Veuillez consulter la documentation officielle : [Déploiement 1Panel](https://astrbot.app/deploy/astrbot/1panel.html).
|
||||
|
||||
#### Déployer sur RainYun
|
||||
|
||||
AstrBot a été officiellement listé sur la plateforme d'applications cloud de RainYun avec un déploiement en un clic.
|
||||
|
||||
[](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
|
||||
|
||||
#### Déployer sur Replit
|
||||
|
||||
Méthode de déploiement contribuée par la communauté.
|
||||
|
||||
[](https://repl.it/github/AstrBotDevs/AstrBot)
|
||||
|
||||
#### Installateur Windows en un clic
|
||||
|
||||
Veuillez consulter la documentation officielle : [Déployer AstrBot avec l'installateur Windows en un clic](https://astrbot.app/deploy/astrbot/windows.html).
|
||||
|
||||
#### Déploiement CasaOS
|
||||
|
||||
Méthode de déploiement contribuée par la communauté.
|
||||
|
||||
Veuillez consulter la documentation officielle : [Déploiement CasaOS](https://astrbot.app/deploy/astrbot/casaos.html).
|
||||
|
||||
#### Déploiement manuel
|
||||
|
||||
Tout d'abord, installez uv :
|
||||
|
||||
```bash
|
||||
pip install uv
|
||||
```
|
||||
|
||||
Installez AstrBot via Git Clone :
|
||||
|
||||
```bash
|
||||
git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
|
||||
uv run main.py
|
||||
```
|
||||
|
||||
Ou consultez la documentation officielle : [Déployer AstrBot depuis les sources](https://astrbot.app/deploy/astrbot/cli.html).
|
||||
|
||||
## Plateformes de messagerie prises en charge
|
||||
|
||||
**Maintenues officiellement**
|
||||
|
||||
- QQ (Plateforme officielle & OneBot)
|
||||
- Telegram
|
||||
- Application WeChat Work & Bot intelligent WeChat Work
|
||||
- Service client WeChat & Comptes officiels WeChat
|
||||
- Feishu (Lark)
|
||||
- DingTalk
|
||||
- Slack
|
||||
- Discord
|
||||
- Satori
|
||||
- Misskey
|
||||
- WhatsApp (Bientôt disponible)
|
||||
- LINE (Bientôt disponible)
|
||||
|
||||
**Maintenues par la communauté**
|
||||
|
||||
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
|
||||
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
|
||||
- [Messages directs Bilibili](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
|
||||
- [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11)
|
||||
|
||||
## Services de modèles pris en charge
|
||||
|
||||
**Services LLM**
|
||||
|
||||
- OpenAI et services compatibles
|
||||
- Anthropic
|
||||
- Google Gemini
|
||||
- Moonshot AI
|
||||
- Zhipu AI
|
||||
- DeepSeek
|
||||
- Ollama (Auto-hébergé)
|
||||
- LM Studio (Auto-hébergé)
|
||||
- [CompShare](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74)
|
||||
- [302.AI](https://share.302.ai/rr1M3l)
|
||||
- [TokenPony](https://www.tokenpony.cn/3YPyf)
|
||||
- [SiliconFlow](https://docs.siliconflow.cn/cn/usecases/use-siliconcloud-in-astrbot)
|
||||
- [PPIO Cloud](https://ppio.com/user/register?invited_by=AIOONE)
|
||||
- ModelScope
|
||||
- OneAPI
|
||||
|
||||
**Plateformes LLMOps**
|
||||
|
||||
- Dify
|
||||
- Applications Alibaba Cloud Bailian
|
||||
- Coze
|
||||
|
||||
**Services de reconnaissance vocale**
|
||||
|
||||
- OpenAI Whisper
|
||||
- SenseVoice
|
||||
|
||||
**Services de synthèse vocale**
|
||||
|
||||
- OpenAI TTS
|
||||
- Gemini TTS
|
||||
- GPT-Sovits-Inference
|
||||
- GPT-Sovits
|
||||
- FishAudio
|
||||
- Edge TTS
|
||||
- Alibaba Cloud Bailian TTS
|
||||
- Azure TTS
|
||||
- Minimax TTS
|
||||
- Volcano Engine TTS
|
||||
|
||||
## ❤️ Contribuer
|
||||
|
||||
Les Issues et Pull Requests sont toujours les bienvenues ! N'hésitez pas à soumettre vos modifications à ce projet :)
|
||||
|
||||
### Comment contribuer
|
||||
|
||||
Vous pouvez contribuer en examinant les issues ou en aidant à la revue des pull requests. Toutes les issues ou PRs sont les bienvenues pour encourager la participation de la communauté. Bien sûr, ce ne sont que des suggestions - vous pouvez contribuer de la manière que vous souhaitez. Pour l'ajout de nouvelles fonctionnalités, veuillez d'abord en discuter via une Issue.
|
||||
|
||||
### Environnement de développement
|
||||
|
||||
AstrBot utilise `ruff` pour le formatage et le linting du code.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/AstrBotDevs/AstrBot
|
||||
pip install pre-commit
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
## 🌍 Communauté
|
||||
|
||||
### Groupes QQ
|
||||
|
||||
- Groupe 1 : 322154837
|
||||
- Groupe 3 : 630166526
|
||||
- Groupe 5 : 822130018
|
||||
- Groupe 6 : 753075035
|
||||
- Groupe développeurs : 975206796
|
||||
|
||||
### Groupe Telegram
|
||||
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
### Serveur Discord
|
||||
|
||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
## ❤️ Remerciements spéciaux
|
||||
|
||||
Un grand merci à tous les contributeurs et développeurs de plugins pour leurs contributions à AstrBot ❤️
|
||||
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/graphs/contributors">
|
||||
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
|
||||
</a>
|
||||
|
||||
De plus, la naissance de ce projet n'aurait pas été possible sans l'aide des projets open source suivants :
|
||||
|
||||
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - L'incroyable framework chat
|
||||
|
||||
## ⭐ Historique des étoiles
|
||||
|
||||
> [!TIP]
|
||||
> Si ce projet vous a aidé dans votre vie ou votre travail, ou si vous êtes intéressé par son développement futur, veuillez donner une étoile au projet. C'est la force motrice derrière la maintenance de ce projet open source <3
|
||||
|
||||
<div align="center">
|
||||
|
||||
[](https://star-history.com/#astrbotdevs/astrbot&Date)
|
||||
|
||||
</div>
|
||||
|
||||
</details>
|
||||
|
||||
_私は、高性能ですから!_
|
||||
|
||||
+40
-26
@@ -19,30 +19,38 @@
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E5%80%8B&style=for-the-badge&label=%E3%83%97%E3%83%A9%E3%82%B0%E3%82%A4%E3%83%B3&cacheSeconds=3600">
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">中文</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_zh-TW.md">繁體中文</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_fr.md">Français</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ru.md">Русский</a>
|
||||
|
||||
<a href="https://astrbot.app/">ドキュメント</a> |
|
||||
<a href="https://blog.astrbot.app/">Blog</a> |
|
||||
<a href="https://astrbot.featurebase.app/roadmap">ロードマップ</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Issue</a>
|
||||
</div>
|
||||
|
||||
AstrBot は、オープンソースのオールインワン Agent チャットボットプラットフォーム及び開発フレームワークです。
|
||||
AstrBot は、主要なインスタントメッセージングアプリと統合できるオープンソースのオールインワン Agent チャットボットプラットフォームです。個人、開発者、チームに信頼性が高くスケーラブルな会話型 AI インフラストラクチャを提供します。パーソナル AI コンパニオン、インテリジェントカスタマーサービス、オートメーションアシスタント、エンタープライズナレッジベースなど、AstrBot を使用すると、IM プラットフォームのワークフロー内で本番環境対応の AI アプリケーションを迅速に構築できます。
|
||||
|
||||
<img width="1776" height="1080" alt="image" src="https://github.com/user-attachments/assets/00782c4c-4437-4d97-aabc-605e3738da5c" />
|
||||
|
||||
## 主な機能
|
||||
|
||||
1. **大規模言語モデル対話**。多様な大規模言語モデルサービスとの統合をサポート。マルチモーダル、ツール呼び出し、MCP、ネイティブナレッジベース、キャラクター設定などの機能を搭載。
|
||||
2. **マルチメッセージプラットフォームサポート**。QQ、WeChat Work、WeChat公式アカウント、Feishu、Telegram、DingTalk、Discord、KOOK などのプラットフォームと統合可能。レート制限、ホワイトリスト、Baidu コンテンツ審査をサポート。
|
||||
3. **Agent**。完全に最適化された Agentic 機能。マルチターンツール呼び出し、内蔵サンドボックスコード実行環境、Web 検索などの機能をサポート。
|
||||
4. **プラグイン拡張**。深く最適化されたプラグインメカニズムで、[プラグイン開発](https://astrbot.app/dev/plugin.html)による機能拡張をサポート。豊富なコミュニティプラグインエコシステム。
|
||||
5. **WebUI**。ビジュアル設定とボット管理、充実した機能。
|
||||
1. 💯 無料 & オープンソース。
|
||||
2. ✨ AI 大規模言語モデル対話、マルチモーダル、Agent、MCP、ナレッジベース、ペルソナ設定。
|
||||
3. 🤖 Dify、Alibaba Cloud 百炼、Coze などの Agent プラットフォームとの統合をサポート。
|
||||
4. 🌐 マルチプラットフォーム:QQ、WeChat Work、Feishu、DingTalk、WeChat 公式アカウント、Telegram、Slack、[その他](#サポートされているメッセージプラットフォーム)。
|
||||
5. 📦 約800個のプラグインをワンクリックでインストール可能なプラグイン拡張機能。
|
||||
6. 💻 WebUI サポート。
|
||||
7. 🌐 国際化(i18n)サポート。
|
||||
|
||||
## デプロイ方法
|
||||
## クイックスタート
|
||||
|
||||
#### Docker デプロイ(推奨 🥳)
|
||||
|
||||
@@ -50,6 +58,12 @@ Docker / Docker Compose を使用した AstrBot のデプロイを推奨しま
|
||||
|
||||
公式ドキュメント [Docker を使用した AstrBot のデプロイ](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) をご参照ください。
|
||||
|
||||
#### uv デプロイ
|
||||
|
||||
```bash
|
||||
uvx astrbot
|
||||
```
|
||||
|
||||
#### 宝塔パネルデプロイ
|
||||
|
||||
AstrBot は宝塔パネルと提携し、宝塔パネルに公開されています。
|
||||
@@ -101,24 +115,6 @@ uv run main.py
|
||||
|
||||
または、公式ドキュメント [ソースコードから AstrBot をデプロイ](https://astrbot.app/deploy/astrbot/cli.html) をご参照ください。
|
||||
|
||||
## 🌍 コミュニティ
|
||||
|
||||
### QQ グループ
|
||||
|
||||
- 1群:322154837
|
||||
- 3群:630166526
|
||||
- 5群:822130018
|
||||
- 6群:753075035
|
||||
- 開発者群:975206796
|
||||
|
||||
### Telegram グループ
|
||||
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
### Discord サーバー
|
||||
|
||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
## サポートされているメッセージプラットフォーム
|
||||
|
||||
**公式メンテナンス**
|
||||
@@ -205,6 +201,24 @@ pip install pre-commit
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
## 🌍 コミュニティ
|
||||
|
||||
### QQ グループ
|
||||
|
||||
- 1群: 322154837
|
||||
- 3群: 630166526
|
||||
- 5群: 822130018
|
||||
- 6群: 753075035
|
||||
- 開発者群: 975206796
|
||||
|
||||
### Telegram グループ
|
||||
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
### Discord サーバー
|
||||
|
||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
## ❤️ Special Thanks
|
||||
|
||||
AstrBot への貢献をしていただいたすべてのコントリビューターとプラグイン開発者に特別な感謝を ❤️
|
||||
|
||||
+248
@@ -0,0 +1,248 @@
|
||||

|
||||
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
|
||||
<br>
|
||||
|
||||
<div>
|
||||
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp&t=2" alt="Featured|HelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<div>
|
||||
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?style=for-the-badge&color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%20%D0%BF%D0%BB%D0%B0%D0%B3%D0%B8%D0%BD%D0%BE%D0%B2&style=for-the-badge&label=%D0%9C%D0%B0%D0%B3%D0%B0%D0%B7%D0%B8%D0%BD&cacheSeconds=3600">
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">中文</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_zh-TW.md">繁體中文</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_fr.md">Français</a>
|
||||
|
||||
<a href="https://astrbot.app/">Документация</a> |
|
||||
<a href="https://blog.astrbot.app/">Блог</a> |
|
||||
<a href="https://astrbot.featurebase.app/roadmap">Дорожная карта</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Сообщить о проблеме</a>
|
||||
</div>
|
||||
|
||||
AstrBot — это универсальная платформа Agent-чатботов с открытым исходным кодом, которая интегрируется с основными приложениями для обмена мгновенными сообщениями. Она предоставляет надёжную и масштабируемую инфраструктуру разговорного ИИ для частных лиц, разработчиков и команд. Будь то персональный ИИ-компаньон, интеллектуальная служба поддержки, автоматизированный помощник или корпоративная база знаний — AstrBot позволяет быстро создавать готовые к использованию ИИ-приложения в рабочих процессах вашей платформы обмена сообщениями.
|
||||
|
||||
<img width="1776" height="1080" alt="image" src="https://github.com/user-attachments/assets/00782c4c-4437-4d97-aabc-605e3738da5c" />
|
||||
|
||||
## Основные возможности
|
||||
|
||||
1. 💯 Бесплатно и с открытым исходным кодом.
|
||||
2. ✨ ИИ-диалоги с LLM, мультимодальность, Agent, MCP, база знаний, настройки личности.
|
||||
3. 🤖 Поддержка интеграции с Dify, Alibaba Cloud Bailian, Coze и другими платформами агентов.
|
||||
4. 🌐 Мультиплатформенность: QQ, WeChat Work, Feishu, DingTalk, официальные аккаунты WeChat, Telegram, Slack и [другие](#поддерживаемые-платформы-обмена-сообщениями).
|
||||
5. 📦 Расширения плагинов с почти 800 плагинами, доступными для установки в один клик.
|
||||
6. 💻 Поддержка WebUI.
|
||||
7. 🌐 Поддержка интернационализации (i18n).
|
||||
|
||||
## Быстрый старт
|
||||
|
||||
#### Развёртывание Docker (Рекомендуется 🥳)
|
||||
|
||||
Мы рекомендуем развёртывать AstrBot с помощью Docker или Docker Compose.
|
||||
|
||||
См. официальную документацию: [Развёртывание AstrBot с Docker](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot).
|
||||
|
||||
#### Развёртывание uv
|
||||
|
||||
```bash
|
||||
uvx astrbot
|
||||
```
|
||||
|
||||
#### Развёртывание BT-Panel
|
||||
|
||||
AstrBot в партнёрстве с BT-Panel теперь доступен на их маркетплейсе.
|
||||
|
||||
См. официальную документацию: [Развёртывание BT-Panel](https://astrbot.app/deploy/astrbot/btpanel.html).
|
||||
|
||||
#### Развёртывание 1Panel
|
||||
|
||||
AstrBot официально размещён на маркетплейсе 1Panel.
|
||||
|
||||
См. официальную документацию: [Развёртывание 1Panel](https://astrbot.app/deploy/astrbot/1panel.html).
|
||||
|
||||
#### Развёртывание на RainYun
|
||||
|
||||
AstrBot официально размещён на облачной платформе приложений RainYun с развёртыванием в один клик.
|
||||
|
||||
[](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
|
||||
|
||||
#### Развёртывание на Replit
|
||||
|
||||
Метод развёртывания от сообщества.
|
||||
|
||||
[](https://repl.it/github/AstrBotDevs/AstrBot)
|
||||
|
||||
#### Установщик Windows в один клик
|
||||
|
||||
См. официальную документацию: [Развёртывание AstrBot с установщиком Windows в один клик](https://astrbot.app/deploy/astrbot/windows.html).
|
||||
|
||||
#### Развёртывание CasaOS
|
||||
|
||||
Метод развёртывания от сообщества.
|
||||
|
||||
См. официальную документацию: [Развёртывание CasaOS](https://astrbot.app/deploy/astrbot/casaos.html).
|
||||
|
||||
#### Ручное развёртывание
|
||||
|
||||
Сначала установите uv:
|
||||
|
||||
```bash
|
||||
pip install uv
|
||||
```
|
||||
|
||||
Установите AstrBot через Git Clone:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
|
||||
uv run main.py
|
||||
```
|
||||
|
||||
Или см. официальную документацию: [Развёртывание AstrBot из исходного кода](https://astrbot.app/deploy/astrbot/cli.html).
|
||||
|
||||
## Поддерживаемые платформы обмена сообщениями
|
||||
|
||||
**Официально поддерживаемые**
|
||||
|
||||
- QQ (Официальная платформа и OneBot)
|
||||
- Telegram
|
||||
- Приложение WeChat Work и интеллектуальный бот WeChat Work
|
||||
- Служба поддержки WeChat и официальные аккаунты WeChat
|
||||
- Feishu (Lark)
|
||||
- DingTalk
|
||||
- Slack
|
||||
- Discord
|
||||
- Satori
|
||||
- Misskey
|
||||
- WhatsApp (Скоро)
|
||||
- LINE (Скоро)
|
||||
|
||||
**Поддерживаемые сообществом**
|
||||
|
||||
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
|
||||
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
|
||||
- [Личные сообщения Bilibili](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
|
||||
- [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11)
|
||||
|
||||
## Поддерживаемые сервисы моделей
|
||||
|
||||
**Сервисы LLM**
|
||||
|
||||
- OpenAI и совместимые сервисы
|
||||
- Anthropic
|
||||
- Google Gemini
|
||||
- Moonshot AI
|
||||
- Zhipu AI
|
||||
- DeepSeek
|
||||
- Ollama (Самостоятельное размещение)
|
||||
- LM Studio (Самостоятельное размещение)
|
||||
- [CompShare](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74)
|
||||
- [302.AI](https://share.302.ai/rr1M3l)
|
||||
- [TokenPony](https://www.tokenpony.cn/3YPyf)
|
||||
- [SiliconFlow](https://docs.siliconflow.cn/cn/usecases/use-siliconcloud-in-astrbot)
|
||||
- [PPIO Cloud](https://ppio.com/user/register?invited_by=AIOONE)
|
||||
- ModelScope
|
||||
- OneAPI
|
||||
|
||||
**Платформы LLMOps**
|
||||
|
||||
- Dify
|
||||
- Приложения Alibaba Cloud Bailian
|
||||
- Coze
|
||||
|
||||
**Сервисы распознавания речи**
|
||||
|
||||
- OpenAI Whisper
|
||||
- SenseVoice
|
||||
|
||||
**Сервисы синтеза речи**
|
||||
|
||||
- OpenAI TTS
|
||||
- Gemini TTS
|
||||
- GPT-Sovits-Inference
|
||||
- GPT-Sovits
|
||||
- FishAudio
|
||||
- Edge TTS
|
||||
- Alibaba Cloud Bailian TTS
|
||||
- Azure TTS
|
||||
- Minimax TTS
|
||||
- Volcano Engine TTS
|
||||
|
||||
## ❤️ Вклад в проект
|
||||
|
||||
Issues и Pull Request всегда приветствуются! Не стесняйтесь отправлять свои изменения в этот проект :)
|
||||
|
||||
### Как внести вклад
|
||||
|
||||
Вы можете внести вклад, просматривая issues или помогая с ревью pull request. Любые issues или PR приветствуются для поощрения участия сообщества. Конечно, это лишь предложения — вы можете вносить вклад любым удобным для вас способом. Для добавления новых функций сначала обсудите это через Issue.
|
||||
|
||||
### Среда разработки
|
||||
|
||||
AstrBot использует `ruff` для форматирования и линтинга кода.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/AstrBotDevs/AstrBot
|
||||
pip install pre-commit
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
## 🌍 Сообщество
|
||||
|
||||
### Группы QQ
|
||||
|
||||
- Группа 1: 322154837
|
||||
- Группа 3: 630166526
|
||||
- Группа 5: 822130018
|
||||
- Группа 6: 753075035
|
||||
- Группа разработчиков: 975206796
|
||||
|
||||
### Группа Telegram
|
||||
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
### Сервер Discord
|
||||
|
||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
## ❤️ Особая благодарность
|
||||
|
||||
Особая благодарность всем контрибьюторам и разработчикам плагинов за их вклад в AstrBot ❤️
|
||||
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/graphs/contributors">
|
||||
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
|
||||
</a>
|
||||
|
||||
Кроме того, рождение этого проекта было бы невозможно без помощи следующих проектов с открытым исходным кодом:
|
||||
|
||||
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - Замечательный кошачий фреймворк
|
||||
|
||||
## ⭐ История звёзд
|
||||
|
||||
> [!TIP]
|
||||
> Если этот проект помог вам в жизни или работе, или если вас интересует его будущее развитие, пожалуйста, поставьте проекту звезду. Это движущая сила поддержки этого проекта с открытым исходным кодом <3
|
||||
|
||||
<div align="center">
|
||||
|
||||
[](https://star-history.com/#astrbotdevs/astrbot&Date)
|
||||
|
||||
</div>
|
||||
|
||||
</details>
|
||||
|
||||
_私は、高性能ですから!_
|
||||
|
||||
+248
@@ -0,0 +1,248 @@
|
||||

|
||||
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
|
||||
<br>
|
||||
|
||||
<div>
|
||||
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp&t=2" alt="Featured|HelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<div>
|
||||
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?style=for-the-badge&color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E5%80%8B&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%A0%B4&cacheSeconds=3600">
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">简体中文</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_fr.md">Français</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ru.md">Русский</a>
|
||||
|
||||
<a href="https://astrbot.app/">文件</a> |
|
||||
<a href="https://blog.astrbot.app/">Blog</a> |
|
||||
<a href="https://astrbot.featurebase.app/roadmap">路線圖</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/issues">問題回報</a>
|
||||
</div>
|
||||
|
||||
AstrBot 是一個開源的一站式 Agent 聊天機器人平台,可接入主流即時通訊軟體,為個人、開發者和團隊打造可靠、可擴展的對話式智慧基礎設施。無論是個人 AI 夥伴、智慧客服、自動化助手,還是企業知識庫,AstrBot 都能在您的即時通訊軟體平台的工作流程中快速構建生產可用的 AI 應用程式。
|
||||
|
||||
<img width="1776" height="1080" alt="image" src="https://github.com/user-attachments/assets/00782c4c-4437-4d97-aabc-605e3738da5c" />
|
||||
|
||||
## 主要功能
|
||||
|
||||
1. 💯 免費 & 開源。
|
||||
2. ✨ AI 大型模型對話,多模態,Agent,MCP,知識庫,人格設定。
|
||||
3. 🤖 支援接入 Dify、阿里雲百煉、Coze 等智慧體平台。
|
||||
4. 🌐 多平台:QQ、企業微信、飛書、釘釘、微信公眾號、Telegram、Slack 以及[更多](#支援的訊息平台)。
|
||||
5. 📦 外掛擴充,已有近 800 個外掛可一鍵安裝。
|
||||
6. 💻 WebUI 支援。
|
||||
7. 🌐 國際化(i18n)支援。
|
||||
|
||||
## 快速開始
|
||||
|
||||
#### Docker 部署(推薦 🥳)
|
||||
|
||||
推薦使用 Docker / Docker Compose 方式部署 AstrBot。
|
||||
|
||||
請參閱官方文件 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot)。
|
||||
|
||||
#### uv 部署
|
||||
|
||||
```bash
|
||||
uvx astrbot
|
||||
```
|
||||
|
||||
#### 寶塔面板部署
|
||||
|
||||
AstrBot 與寶塔面板合作,已上架至寶塔面板。
|
||||
|
||||
請參閱官方文件 [寶塔面板部署](https://astrbot.app/deploy/astrbot/btpanel.html)。
|
||||
|
||||
#### 1Panel 部署
|
||||
|
||||
AstrBot 已由 1Panel 官方上架至 1Panel 面板。
|
||||
|
||||
請參閱官方文件 [1Panel 部署](https://astrbot.app/deploy/astrbot/1panel.html)。
|
||||
|
||||
#### 在雨雲上部署
|
||||
|
||||
AstrBot 已由雨雲官方上架至雲端應用程式平台,可一鍵部署。
|
||||
|
||||
[](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
|
||||
|
||||
#### 在 Replit 上部署
|
||||
|
||||
社群貢獻的部署方式。
|
||||
|
||||
[](https://repl.it/github/AstrBotDevs/AstrBot)
|
||||
|
||||
#### Windows 一鍵安裝器部署
|
||||
|
||||
請參閱官方文件 [使用 Windows 一鍵安裝器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html)。
|
||||
|
||||
#### CasaOS 部署
|
||||
|
||||
社群貢獻的部署方式。
|
||||
|
||||
請參閱官方文件 [CasaOS 部署](https://astrbot.app/deploy/astrbot/casaos.html)。
|
||||
|
||||
#### 手動部署
|
||||
|
||||
首先安裝 uv:
|
||||
|
||||
```bash
|
||||
pip install uv
|
||||
```
|
||||
|
||||
透過 Git Clone 安裝 AstrBot:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
|
||||
uv run main.py
|
||||
```
|
||||
|
||||
或者請參閱官方文件 [透過原始碼部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html)。
|
||||
|
||||
## 支援的訊息平台
|
||||
|
||||
**官方維護**
|
||||
|
||||
- QQ(官方平台 & OneBot)
|
||||
- Telegram
|
||||
- 企微應用 & 企微智慧機器人
|
||||
- 微信客服 & 微信公眾號
|
||||
- 飛書
|
||||
- 釘釘
|
||||
- Slack
|
||||
- Discord
|
||||
- Satori
|
||||
- Misskey
|
||||
- Whatsapp(即將支援)
|
||||
- LINE(即將支援)
|
||||
|
||||
**社群維護**
|
||||
|
||||
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
|
||||
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
|
||||
- [Bilibili 私訊](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
|
||||
- [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11)
|
||||
|
||||
## 支援的模型服務
|
||||
|
||||
**大型模型服務**
|
||||
|
||||
- OpenAI 及相容服務
|
||||
- Anthropic
|
||||
- Google Gemini
|
||||
- Moonshot AI
|
||||
- 智譜 AI
|
||||
- DeepSeek
|
||||
- Ollama(本機部署)
|
||||
- LM Studio(本機部署)
|
||||
- [優雲智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74)
|
||||
- [302.AI](https://share.302.ai/rr1M3l)
|
||||
- [小馬算力](https://www.tokenpony.cn/3YPyf)
|
||||
- [矽基流動](https://docs.siliconflow.cn/cn/usercases/use-siliconcloud-in-astrbot)
|
||||
- [PPIO 派歐雲](https://ppio.com/user/register?invited_by=AIOONE)
|
||||
- ModelScope
|
||||
- OneAPI
|
||||
|
||||
**LLMOps 平台**
|
||||
|
||||
- Dify
|
||||
- 阿里雲百煉應用
|
||||
- Coze
|
||||
|
||||
**語音轉文字服務**
|
||||
|
||||
- OpenAI Whisper
|
||||
- SenseVoice
|
||||
|
||||
**文字轉語音服務**
|
||||
|
||||
- OpenAI TTS
|
||||
- Gemini TTS
|
||||
- GPT-Sovits-Inference
|
||||
- GPT-Sovits
|
||||
- FishAudio
|
||||
- Edge TTS
|
||||
- 阿里雲百煉 TTS
|
||||
- Azure TTS
|
||||
- Minimax TTS
|
||||
- 火山引擎 TTS
|
||||
|
||||
## ❤️ 貢獻
|
||||
|
||||
歡迎任何 Issues/Pull Requests!只需要將您的變更提交到此專案 :)
|
||||
|
||||
### 如何貢獻
|
||||
|
||||
您可以透過檢視問題或協助審核 PR(拉取請求)來貢獻。任何問題或 PR 都歡迎參與,以促進社群貢獻。當然,這些只是建議,您可以以任何方式進行貢獻。對於新功能的新增,請先透過 Issue 討論。
|
||||
|
||||
### 開發環境
|
||||
|
||||
AstrBot 使用 `ruff` 進行程式碼格式化和檢查。
|
||||
|
||||
```bash
|
||||
git clone https://github.com/AstrBotDevs/AstrBot
|
||||
pip install pre-commit
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
## 🌍 社群
|
||||
|
||||
### QQ 群組
|
||||
|
||||
- 1 群:322154837
|
||||
- 3 群:630166526
|
||||
- 5 群:822130018
|
||||
- 6 群:753075035
|
||||
- 開發者群:975206796
|
||||
|
||||
### Telegram 群組
|
||||
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
### Discord 群組
|
||||
|
||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
## ❤️ Special Thanks
|
||||
|
||||
特別感謝所有 Contributors 和外掛開發者對 AstrBot 的貢獻 ❤️
|
||||
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/graphs/contributors">
|
||||
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
|
||||
</a>
|
||||
|
||||
此外,本專案的誕生離不開以下開源專案的幫助:
|
||||
|
||||
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - 偉大的貓貓框架
|
||||
|
||||
## ⭐ Star History
|
||||
|
||||
> [!TIP]
|
||||
> 如果本專案對您的生活 / 工作產生了幫助,或者您關注本專案的未來發展,請給專案 Star,這是我們維護這個開源專案的動力 <3
|
||||
|
||||
<div align="center">
|
||||
|
||||
[](https://star-history.com/#astrbotdevs/astrbot&Date)
|
||||
|
||||
</div>
|
||||
|
||||
</details>
|
||||
|
||||
_私は、高性能ですから!_
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "3.5.23"
|
||||
__version__ = "4.8.0"
|
||||
|
||||
@@ -345,9 +345,6 @@ class MCPClient:
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up resources including old exit stacks from reconnections"""
|
||||
# Set running_event first to unblock any waiting tasks
|
||||
self.running_event.set()
|
||||
|
||||
# Close current exit stack
|
||||
try:
|
||||
await self.exit_stack.aclose()
|
||||
@@ -359,6 +356,9 @@ class MCPClient:
|
||||
# Just clear the list to release references
|
||||
self._old_exit_stacks.clear()
|
||||
|
||||
# Set running_event first to unblock any waiting tasks
|
||||
self.running_event.set()
|
||||
|
||||
|
||||
class MCPTool(FunctionTool, Generic[TContext]):
|
||||
"""A function tool that calls an MCP service."""
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from typing import Any, ClassVar, Literal, cast
|
||||
|
||||
from pydantic import BaseModel, GetCoreSchemaHandler
|
||||
from pydantic import BaseModel, GetCoreSchemaHandler, model_validator
|
||||
from pydantic_core import core_schema
|
||||
|
||||
|
||||
@@ -145,22 +145,39 @@ class Message(BaseModel):
|
||||
"tool",
|
||||
]
|
||||
|
||||
content: str | list[ContentPart]
|
||||
content: str | list[ContentPart] | None = None
|
||||
"""The content of the message."""
|
||||
|
||||
tool_calls: list[ToolCall] | list[dict] | None = None
|
||||
"""The tool calls of the message."""
|
||||
|
||||
tool_call_id: str | None = None
|
||||
"""The ID of the tool call."""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_content_required(self):
|
||||
# assistant + tool_calls is not None: allow content to be None
|
||||
if self.role == "assistant" and self.tool_calls is not None:
|
||||
return self
|
||||
|
||||
# other all cases: content is required
|
||||
if self.content is None:
|
||||
raise ValueError(
|
||||
"content is required unless role='assistant' and tool_calls is not None"
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class AssistantMessageSegment(Message):
|
||||
"""A message segment from the assistant."""
|
||||
|
||||
role: Literal["assistant"] = "assistant"
|
||||
tool_calls: list[ToolCall] | list[dict] | None = None
|
||||
|
||||
|
||||
class ToolCallMessageSegment(Message):
|
||||
"""A message segment representing a tool call."""
|
||||
|
||||
role: Literal["tool"] = "tool"
|
||||
tool_call_id: str
|
||||
|
||||
|
||||
class UserMessageSegment(Message):
|
||||
|
||||
@@ -2,13 +2,12 @@ import abc
|
||||
import typing as T
|
||||
from enum import Enum, auto
|
||||
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot import logger
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
|
||||
from ..hooks import BaseAgentRunHooks
|
||||
from ..response import AgentResponse
|
||||
from ..run_context import ContextWrapper, TContext
|
||||
from ..tool_executor import BaseFunctionToolExecutor
|
||||
|
||||
|
||||
class AgentState(Enum):
|
||||
@@ -24,9 +23,7 @@ class BaseAgentRunner(T.Generic[TContext]):
|
||||
@abc.abstractmethod
|
||||
async def reset(
|
||||
self,
|
||||
provider: Provider,
|
||||
run_context: ContextWrapper[TContext],
|
||||
tool_executor: BaseFunctionToolExecutor[TContext],
|
||||
agent_hooks: BaseAgentRunHooks[TContext],
|
||||
**kwargs: T.Any,
|
||||
) -> None:
|
||||
@@ -60,3 +57,9 @@ class BaseAgentRunner(T.Generic[TContext]):
|
||||
This method should be called after the agent is done.
|
||||
"""
|
||||
...
|
||||
|
||||
def _transition_state(self, new_state: AgentState) -> None:
|
||||
"""Transition the agent state."""
|
||||
if self._state != new_state:
|
||||
logger.debug(f"Agent state transition: {self._state} -> {new_state}")
|
||||
self._state = new_state
|
||||
|
||||
@@ -0,0 +1,367 @@
|
||||
import base64
|
||||
import json
|
||||
import sys
|
||||
import typing as T
|
||||
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot import logger
|
||||
from astrbot.core import sp
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import (
|
||||
LLMResponse,
|
||||
ProviderRequest,
|
||||
)
|
||||
|
||||
from ...hooks import BaseAgentRunHooks
|
||||
from ...response import AgentResponseData
|
||||
from ...run_context import ContextWrapper, TContext
|
||||
from ..base import AgentResponse, AgentState, BaseAgentRunner
|
||||
from .coze_api_client import CozeAPIClient
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class CozeAgentRunner(BaseAgentRunner[TContext]):
|
||||
"""Coze Agent Runner"""
|
||||
|
||||
@override
|
||||
async def reset(
|
||||
self,
|
||||
request: ProviderRequest,
|
||||
run_context: ContextWrapper[TContext],
|
||||
agent_hooks: BaseAgentRunHooks[TContext],
|
||||
provider_config: dict,
|
||||
**kwargs: T.Any,
|
||||
) -> None:
|
||||
self.req = request
|
||||
self.streaming = kwargs.get("streaming", False)
|
||||
self.final_llm_resp = None
|
||||
self._state = AgentState.IDLE
|
||||
self.agent_hooks = agent_hooks
|
||||
self.run_context = run_context
|
||||
|
||||
self.api_key = provider_config.get("coze_api_key", "")
|
||||
if not self.api_key:
|
||||
raise Exception("Coze API Key 不能为空。")
|
||||
self.bot_id = provider_config.get("bot_id", "")
|
||||
if not self.bot_id:
|
||||
raise Exception("Coze Bot ID 不能为空。")
|
||||
self.api_base: str = provider_config.get("coze_api_base", "https://api.coze.cn")
|
||||
|
||||
if not isinstance(self.api_base, str) or not self.api_base.startswith(
|
||||
("http://", "https://"),
|
||||
):
|
||||
raise Exception(
|
||||
"Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头。",
|
||||
)
|
||||
|
||||
self.timeout = provider_config.get("timeout", 120)
|
||||
if isinstance(self.timeout, str):
|
||||
self.timeout = int(self.timeout)
|
||||
self.auto_save_history = provider_config.get("auto_save_history", True)
|
||||
|
||||
# 创建 API 客户端
|
||||
self.api_client = CozeAPIClient(api_key=self.api_key, api_base=self.api_base)
|
||||
|
||||
# 会话相关缓存
|
||||
self.file_id_cache: dict[str, dict[str, str]] = {}
|
||||
|
||||
@override
|
||||
async def step(self):
|
||||
"""
|
||||
执行 Coze Agent 的一个步骤
|
||||
"""
|
||||
if not self.req:
|
||||
raise ValueError("Request is not set. Please call reset() first.")
|
||||
|
||||
if self._state == AgentState.IDLE:
|
||||
try:
|
||||
await self.agent_hooks.on_agent_begin(self.run_context)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True)
|
||||
|
||||
# 开始处理,转换到运行状态
|
||||
self._transition_state(AgentState.RUNNING)
|
||||
|
||||
try:
|
||||
# 执行 Coze 请求并处理结果
|
||||
async for response in self._execute_coze_request():
|
||||
yield response
|
||||
except Exception as e:
|
||||
logger.error(f"Coze 请求失败:{str(e)}")
|
||||
self._transition_state(AgentState.ERROR)
|
||||
self.final_llm_resp = LLMResponse(
|
||||
role="err", completion_text=f"Coze 请求失败:{str(e)}"
|
||||
)
|
||||
yield AgentResponse(
|
||||
type="err",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(f"Coze 请求失败:{str(e)}")
|
||||
),
|
||||
)
|
||||
finally:
|
||||
await self.api_client.close()
|
||||
|
||||
@override
|
||||
async def step_until_done(
|
||||
self, max_step: int = 30
|
||||
) -> T.AsyncGenerator[AgentResponse, None]:
|
||||
while not self.done():
|
||||
async for resp in self.step():
|
||||
yield resp
|
||||
|
||||
async def _execute_coze_request(self):
|
||||
"""执行 Coze 请求的核心逻辑"""
|
||||
prompt = self.req.prompt or ""
|
||||
session_id = self.req.session_id or "unknown"
|
||||
image_urls = self.req.image_urls or []
|
||||
contexts = self.req.contexts or []
|
||||
system_prompt = self.req.system_prompt
|
||||
|
||||
# 用户ID参数
|
||||
user_id = session_id
|
||||
|
||||
# 获取或创建会话ID
|
||||
conversation_id = await sp.get_async(
|
||||
scope="umo",
|
||||
scope_id=user_id,
|
||||
key="coze_conversation_id",
|
||||
default="",
|
||||
)
|
||||
|
||||
# 构建消息
|
||||
additional_messages = []
|
||||
|
||||
if system_prompt:
|
||||
if not self.auto_save_history or not conversation_id:
|
||||
additional_messages.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
"content_type": "text",
|
||||
},
|
||||
)
|
||||
|
||||
# 处理历史上下文
|
||||
if not self.auto_save_history and contexts:
|
||||
for ctx in contexts:
|
||||
if isinstance(ctx, dict) and "role" in ctx and "content" in ctx:
|
||||
# 处理上下文中的图片
|
||||
content = ctx["content"]
|
||||
if isinstance(content, list):
|
||||
# 多模态内容,需要处理图片
|
||||
processed_content = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
if item.get("type") == "text":
|
||||
processed_content.append(item)
|
||||
elif item.get("type") == "image_url":
|
||||
# 处理图片上传
|
||||
try:
|
||||
image_data = item.get("image_url", {})
|
||||
url = image_data.get("url", "")
|
||||
if url:
|
||||
file_id = (
|
||||
await self._download_and_upload_image(
|
||||
url, session_id
|
||||
)
|
||||
)
|
||||
processed_content.append(
|
||||
{
|
||||
"type": "file",
|
||||
"file_id": file_id,
|
||||
"file_url": url,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"处理上下文图片失败: {e}")
|
||||
continue
|
||||
|
||||
if processed_content:
|
||||
additional_messages.append(
|
||||
{
|
||||
"role": ctx["role"],
|
||||
"content": processed_content,
|
||||
"content_type": "object_string",
|
||||
}
|
||||
)
|
||||
else:
|
||||
# 纯文本内容
|
||||
additional_messages.append(
|
||||
{
|
||||
"role": ctx["role"],
|
||||
"content": content,
|
||||
"content_type": "text",
|
||||
}
|
||||
)
|
||||
|
||||
# 构建当前消息
|
||||
if prompt or image_urls:
|
||||
if image_urls:
|
||||
# 多模态
|
||||
object_string_content = []
|
||||
if prompt:
|
||||
object_string_content.append({"type": "text", "text": prompt})
|
||||
|
||||
for url in image_urls:
|
||||
# the url is a base64 string
|
||||
try:
|
||||
image_data = base64.b64decode(url)
|
||||
file_id = await self.api_client.upload_file(image_data)
|
||||
object_string_content.append(
|
||||
{
|
||||
"type": "image",
|
||||
"file_id": file_id,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"处理图片失败 {url}: {e}")
|
||||
continue
|
||||
|
||||
if object_string_content:
|
||||
content = json.dumps(object_string_content, ensure_ascii=False)
|
||||
additional_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": content,
|
||||
"content_type": "object_string",
|
||||
}
|
||||
)
|
||||
elif prompt:
|
||||
# 纯文本
|
||||
additional_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
"content_type": "text",
|
||||
},
|
||||
)
|
||||
|
||||
# 执行 Coze API 请求
|
||||
accumulated_content = ""
|
||||
message_started = False
|
||||
|
||||
async for chunk in self.api_client.chat_messages(
|
||||
bot_id=self.bot_id,
|
||||
user_id=user_id,
|
||||
additional_messages=additional_messages,
|
||||
conversation_id=conversation_id,
|
||||
auto_save_history=self.auto_save_history,
|
||||
stream=True,
|
||||
timeout=self.timeout,
|
||||
):
|
||||
event_type = chunk.get("event")
|
||||
data = chunk.get("data", {})
|
||||
|
||||
if event_type == "conversation.chat.created":
|
||||
if isinstance(data, dict) and "conversation_id" in data:
|
||||
await sp.put_async(
|
||||
scope="umo",
|
||||
scope_id=user_id,
|
||||
key="coze_conversation_id",
|
||||
value=data["conversation_id"],
|
||||
)
|
||||
|
||||
if event_type == "conversation.message.delta":
|
||||
# 增量消息
|
||||
content = data.get("content", "")
|
||||
if not content and "delta" in data:
|
||||
content = data["delta"].get("content", "")
|
||||
if not content and "text" in data:
|
||||
content = data.get("text", "")
|
||||
|
||||
if content:
|
||||
accumulated_content += content
|
||||
message_started = True
|
||||
|
||||
# 如果是流式响应,发送增量数据
|
||||
if self.streaming:
|
||||
yield AgentResponse(
|
||||
type="streaming_delta",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(content)
|
||||
),
|
||||
)
|
||||
|
||||
elif event_type == "conversation.message.completed":
|
||||
# 消息完成
|
||||
logger.debug("Coze message completed")
|
||||
message_started = True
|
||||
|
||||
elif event_type == "conversation.chat.completed":
|
||||
# 对话完成
|
||||
logger.debug("Coze chat completed")
|
||||
break
|
||||
|
||||
elif event_type == "error":
|
||||
# 错误处理
|
||||
error_msg = data.get("msg", "未知错误")
|
||||
error_code = data.get("code", "UNKNOWN")
|
||||
logger.error(f"Coze 出现错误: {error_code} - {error_msg}")
|
||||
raise Exception(f"Coze 出现错误: {error_code} - {error_msg}")
|
||||
|
||||
if not message_started and not accumulated_content:
|
||||
logger.warning("Coze 未返回任何内容")
|
||||
accumulated_content = ""
|
||||
|
||||
# 创建最终响应
|
||||
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
|
||||
self.final_llm_resp = LLMResponse(role="assistant", result_chain=chain)
|
||||
self._transition_state(AgentState.DONE)
|
||||
|
||||
try:
|
||||
await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
|
||||
|
||||
# 返回最终结果
|
||||
yield AgentResponse(
|
||||
type="llm_result",
|
||||
data=AgentResponseData(chain=chain),
|
||||
)
|
||||
|
||||
async def _download_and_upload_image(
|
||||
self,
|
||||
image_url: str,
|
||||
session_id: str | None = None,
|
||||
) -> str:
|
||||
"""下载图片并上传到 Coze,返回 file_id"""
|
||||
import hashlib
|
||||
|
||||
# 计算哈希实现缓存
|
||||
cache_key = hashlib.md5(image_url.encode("utf-8")).hexdigest()
|
||||
|
||||
if session_id:
|
||||
if session_id not in self.file_id_cache:
|
||||
self.file_id_cache[session_id] = {}
|
||||
|
||||
if cache_key in self.file_id_cache[session_id]:
|
||||
file_id = self.file_id_cache[session_id][cache_key]
|
||||
logger.debug(f"[Coze] 使用缓存的 file_id: {file_id}")
|
||||
return file_id
|
||||
|
||||
try:
|
||||
image_data = await self.api_client.download_image(image_url)
|
||||
file_id = await self.api_client.upload_file(image_data)
|
||||
|
||||
if session_id:
|
||||
self.file_id_cache[session_id][cache_key] = file_id
|
||||
logger.debug(f"[Coze] 图片上传成功并缓存,file_id: {file_id}")
|
||||
|
||||
return file_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理图片失败 {image_url}: {e!s}")
|
||||
raise Exception(f"处理图片失败: {e!s}")
|
||||
|
||||
@override
|
||||
def done(self) -> bool:
|
||||
"""检查 Agent 是否已完成工作"""
|
||||
return self._state in (AgentState.DONE, AgentState.ERROR)
|
||||
|
||||
@override
|
||||
def get_final_llm_resp(self) -> LLMResponse | None:
|
||||
return self.final_llm_resp
|
||||
@@ -0,0 +1,403 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import queue
|
||||
import re
|
||||
import sys
|
||||
import threading
|
||||
import typing as T
|
||||
|
||||
from dashscope import Application
|
||||
from dashscope.app.application_response import ApplicationResponse
|
||||
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import (
|
||||
LLMResponse,
|
||||
ProviderRequest,
|
||||
)
|
||||
|
||||
from ...hooks import BaseAgentRunHooks
|
||||
from ...response import AgentResponseData
|
||||
from ...run_context import ContextWrapper, TContext
|
||||
from ..base import AgentResponse, AgentState, BaseAgentRunner
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class DashscopeAgentRunner(BaseAgentRunner[TContext]):
|
||||
"""Dashscope Agent Runner"""
|
||||
|
||||
@override
|
||||
async def reset(
|
||||
self,
|
||||
request: ProviderRequest,
|
||||
run_context: ContextWrapper[TContext],
|
||||
agent_hooks: BaseAgentRunHooks[TContext],
|
||||
provider_config: dict,
|
||||
**kwargs: T.Any,
|
||||
) -> None:
|
||||
self.req = request
|
||||
self.streaming = kwargs.get("streaming", False)
|
||||
self.final_llm_resp = None
|
||||
self._state = AgentState.IDLE
|
||||
self.agent_hooks = agent_hooks
|
||||
self.run_context = run_context
|
||||
|
||||
self.api_key = provider_config.get("dashscope_api_key", "")
|
||||
if not self.api_key:
|
||||
raise Exception("阿里云百炼 API Key 不能为空。")
|
||||
self.app_id = provider_config.get("dashscope_app_id", "")
|
||||
if not self.app_id:
|
||||
raise Exception("阿里云百炼 APP ID 不能为空。")
|
||||
self.dashscope_app_type = provider_config.get("dashscope_app_type", "")
|
||||
if not self.dashscope_app_type:
|
||||
raise Exception("阿里云百炼 APP 类型不能为空。")
|
||||
|
||||
self.variables: dict = provider_config.get("variables", {}) or {}
|
||||
self.rag_options: dict = provider_config.get("rag_options", {})
|
||||
self.output_reference = self.rag_options.get("output_reference", False)
|
||||
self.rag_options = self.rag_options.copy()
|
||||
self.rag_options.pop("output_reference", None)
|
||||
|
||||
self.timeout = provider_config.get("timeout", 120)
|
||||
if isinstance(self.timeout, str):
|
||||
self.timeout = int(self.timeout)
|
||||
|
||||
def has_rag_options(self):
|
||||
"""判断是否有 RAG 选项
|
||||
|
||||
Returns:
|
||||
bool: 是否有 RAG 选项
|
||||
|
||||
"""
|
||||
if self.rag_options and (
|
||||
len(self.rag_options.get("pipeline_ids", [])) > 0
|
||||
or len(self.rag_options.get("file_ids", [])) > 0
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
@override
|
||||
async def step(self):
|
||||
"""
|
||||
执行 Dashscope Agent 的一个步骤
|
||||
"""
|
||||
if not self.req:
|
||||
raise ValueError("Request is not set. Please call reset() first.")
|
||||
|
||||
if self._state == AgentState.IDLE:
|
||||
try:
|
||||
await self.agent_hooks.on_agent_begin(self.run_context)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True)
|
||||
|
||||
# 开始处理,转换到运行状态
|
||||
self._transition_state(AgentState.RUNNING)
|
||||
|
||||
try:
|
||||
# 执行 Dashscope 请求并处理结果
|
||||
async for response in self._execute_dashscope_request():
|
||||
yield response
|
||||
except Exception as e:
|
||||
logger.error(f"阿里云百炼请求失败:{str(e)}")
|
||||
self._transition_state(AgentState.ERROR)
|
||||
self.final_llm_resp = LLMResponse(
|
||||
role="err", completion_text=f"阿里云百炼请求失败:{str(e)}"
|
||||
)
|
||||
yield AgentResponse(
|
||||
type="err",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(f"阿里云百炼请求失败:{str(e)}")
|
||||
),
|
||||
)
|
||||
|
||||
@override
|
||||
async def step_until_done(
|
||||
self, max_step: int = 30
|
||||
) -> T.AsyncGenerator[AgentResponse, None]:
|
||||
while not self.done():
|
||||
async for resp in self.step():
|
||||
yield resp
|
||||
|
||||
def _consume_sync_generator(
|
||||
self, response: T.Any, response_queue: queue.Queue
|
||||
) -> None:
|
||||
"""在线程中消费同步generator,将结果放入队列
|
||||
|
||||
Args:
|
||||
response: 同步generator对象
|
||||
response_queue: 用于传递数据的队列
|
||||
|
||||
"""
|
||||
try:
|
||||
if self.streaming:
|
||||
for chunk in response:
|
||||
response_queue.put(("data", chunk))
|
||||
else:
|
||||
response_queue.put(("data", response))
|
||||
except Exception as e:
|
||||
response_queue.put(("error", e))
|
||||
finally:
|
||||
response_queue.put(("done", None))
|
||||
|
||||
async def _process_stream_chunk(
|
||||
self, chunk: ApplicationResponse, output_text: str
|
||||
) -> tuple[str, list | None, AgentResponse | None]:
|
||||
"""处理流式响应的单个chunk
|
||||
|
||||
Args:
|
||||
chunk: Dashscope响应chunk
|
||||
output_text: 当前累积的输出文本
|
||||
|
||||
Returns:
|
||||
(更新后的output_text, doc_references, AgentResponse或None)
|
||||
|
||||
"""
|
||||
logger.debug(f"dashscope stream chunk: {chunk}")
|
||||
|
||||
if chunk.status_code != 200:
|
||||
logger.error(
|
||||
f"阿里云百炼请求失败: request_id={chunk.request_id}, code={chunk.status_code}, message={chunk.message}, 请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code",
|
||||
)
|
||||
self._transition_state(AgentState.ERROR)
|
||||
error_msg = (
|
||||
f"阿里云百炼请求失败: message={chunk.message} code={chunk.status_code}"
|
||||
)
|
||||
self.final_llm_resp = LLMResponse(
|
||||
role="err",
|
||||
result_chain=MessageChain().message(error_msg),
|
||||
)
|
||||
return (
|
||||
output_text,
|
||||
None,
|
||||
AgentResponse(
|
||||
type="err",
|
||||
data=AgentResponseData(chain=MessageChain().message(error_msg)),
|
||||
),
|
||||
)
|
||||
|
||||
chunk_text = chunk.output.get("text", "") or ""
|
||||
# RAG 引用脚标格式化
|
||||
chunk_text = re.sub(r"<ref>\[(\d+)\]</ref>", r"[\1]", chunk_text)
|
||||
|
||||
response = None
|
||||
if chunk_text:
|
||||
output_text += chunk_text
|
||||
response = AgentResponse(
|
||||
type="streaming_delta",
|
||||
data=AgentResponseData(chain=MessageChain().message(chunk_text)),
|
||||
)
|
||||
|
||||
# 获取文档引用
|
||||
doc_references = chunk.output.get("doc_references", None)
|
||||
|
||||
return output_text, doc_references, response
|
||||
|
||||
def _format_doc_references(self, doc_references: list) -> str:
|
||||
"""格式化文档引用为文本
|
||||
|
||||
Args:
|
||||
doc_references: 文档引用列表
|
||||
|
||||
Returns:
|
||||
格式化后的引用文本
|
||||
|
||||
"""
|
||||
ref_parts = []
|
||||
for ref in doc_references:
|
||||
ref_title = (
|
||||
ref.get("title", "") if ref.get("title") else ref.get("doc_name", "")
|
||||
)
|
||||
ref_parts.append(f"{ref['index_id']}. {ref_title}\n")
|
||||
ref_str = "".join(ref_parts)
|
||||
return f"\n\n回答来源:\n{ref_str}"
|
||||
|
||||
async def _build_request_payload(
|
||||
self, prompt: str, session_id: str, contexts: list, system_prompt: str
|
||||
) -> dict:
|
||||
"""构建请求payload
|
||||
|
||||
Args:
|
||||
prompt: 用户输入
|
||||
session_id: 会话ID
|
||||
contexts: 上下文列表
|
||||
system_prompt: 系统提示词
|
||||
|
||||
Returns:
|
||||
请求payload字典
|
||||
|
||||
"""
|
||||
conversation_id = await sp.get_async(
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
key="dashscope_conversation_id",
|
||||
default="",
|
||||
)
|
||||
# 获得会话变量
|
||||
payload_vars = self.variables.copy()
|
||||
session_var = await sp.get_async(
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
key="session_variables",
|
||||
default={},
|
||||
)
|
||||
payload_vars.update(session_var)
|
||||
|
||||
if (
|
||||
self.dashscope_app_type in ["agent", "dialog-workflow"]
|
||||
and not self.has_rag_options()
|
||||
):
|
||||
# 支持多轮对话的
|
||||
p = {
|
||||
"app_id": self.app_id,
|
||||
"api_key": self.api_key,
|
||||
"prompt": prompt,
|
||||
"biz_params": payload_vars or None,
|
||||
"stream": self.streaming,
|
||||
"incremental_output": True,
|
||||
}
|
||||
if conversation_id:
|
||||
p["session_id"] = conversation_id
|
||||
return p
|
||||
else:
|
||||
# 不支持多轮对话的
|
||||
payload = {
|
||||
"app_id": self.app_id,
|
||||
"prompt": prompt,
|
||||
"api_key": self.api_key,
|
||||
"biz_params": payload_vars or None,
|
||||
"stream": self.streaming,
|
||||
"incremental_output": True,
|
||||
}
|
||||
if self.rag_options:
|
||||
payload["rag_options"] = self.rag_options
|
||||
return payload
|
||||
|
||||
async def _handle_streaming_response(
|
||||
self, response: T.Any, session_id: str
|
||||
) -> T.AsyncGenerator[AgentResponse, None]:
|
||||
"""处理流式响应
|
||||
|
||||
Args:
|
||||
response: Dashscope 流式响应 generator
|
||||
|
||||
Yields:
|
||||
AgentResponse 对象
|
||||
|
||||
"""
|
||||
response_queue = queue.Queue()
|
||||
consumer_thread = threading.Thread(
|
||||
target=self._consume_sync_generator,
|
||||
args=(response, response_queue),
|
||||
daemon=True,
|
||||
)
|
||||
consumer_thread.start()
|
||||
|
||||
output_text = ""
|
||||
doc_references = None
|
||||
|
||||
while True:
|
||||
try:
|
||||
item_type, item_data = await asyncio.get_event_loop().run_in_executor(
|
||||
None, response_queue.get, True, 1
|
||||
)
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
if item_type == "done":
|
||||
break
|
||||
elif item_type == "error":
|
||||
raise item_data
|
||||
elif item_type == "data":
|
||||
chunk = item_data
|
||||
assert isinstance(chunk, ApplicationResponse)
|
||||
|
||||
(
|
||||
output_text,
|
||||
chunk_doc_refs,
|
||||
response,
|
||||
) = await self._process_stream_chunk(chunk, output_text)
|
||||
|
||||
if response:
|
||||
if response.type == "err":
|
||||
yield response
|
||||
return
|
||||
yield response
|
||||
|
||||
if chunk_doc_refs:
|
||||
doc_references = chunk_doc_refs
|
||||
|
||||
if chunk.output.session_id:
|
||||
await sp.put_async(
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
key="dashscope_conversation_id",
|
||||
value=chunk.output.session_id,
|
||||
)
|
||||
|
||||
# 添加 RAG 引用
|
||||
if self.output_reference and doc_references:
|
||||
ref_text = self._format_doc_references(doc_references)
|
||||
output_text += ref_text
|
||||
|
||||
if self.streaming:
|
||||
yield AgentResponse(
|
||||
type="streaming_delta",
|
||||
data=AgentResponseData(chain=MessageChain().message(ref_text)),
|
||||
)
|
||||
|
||||
# 创建最终响应
|
||||
chain = MessageChain(chain=[Comp.Plain(output_text)])
|
||||
self.final_llm_resp = LLMResponse(role="assistant", result_chain=chain)
|
||||
self._transition_state(AgentState.DONE)
|
||||
|
||||
try:
|
||||
await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
|
||||
|
||||
# 返回最终结果
|
||||
yield AgentResponse(
|
||||
type="llm_result",
|
||||
data=AgentResponseData(chain=chain),
|
||||
)
|
||||
|
||||
async def _execute_dashscope_request(self):
|
||||
"""执行 Dashscope 请求的核心逻辑"""
|
||||
prompt = self.req.prompt or ""
|
||||
session_id = self.req.session_id or "unknown"
|
||||
image_urls = self.req.image_urls or []
|
||||
contexts = self.req.contexts or []
|
||||
system_prompt = self.req.system_prompt
|
||||
|
||||
# 检查图片输入
|
||||
if image_urls:
|
||||
logger.warning("阿里云百炼暂不支持图片输入,将自动忽略图片内容。")
|
||||
|
||||
# 构建请求payload
|
||||
payload = await self._build_request_payload(
|
||||
prompt, session_id, contexts, system_prompt
|
||||
)
|
||||
|
||||
if not self.streaming:
|
||||
payload["incremental_output"] = False
|
||||
|
||||
# 发起请求
|
||||
partial = functools.partial(Application.call, **payload)
|
||||
response = await asyncio.get_event_loop().run_in_executor(None, partial)
|
||||
|
||||
async for resp in self._handle_streaming_response(response, session_id):
|
||||
yield resp
|
||||
|
||||
@override
|
||||
def done(self) -> bool:
|
||||
"""检查 Agent 是否已完成工作"""
|
||||
return self._state in (AgentState.DONE, AgentState.ERROR)
|
||||
|
||||
@override
|
||||
def get_final_llm_resp(self) -> LLMResponse | None:
|
||||
return self.final_llm_resp
|
||||
@@ -0,0 +1,336 @@
|
||||
import base64
|
||||
import os
|
||||
import sys
|
||||
import typing as T
|
||||
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import (
|
||||
LLMResponse,
|
||||
ProviderRequest,
|
||||
)
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.io import download_file
|
||||
|
||||
from ...hooks import BaseAgentRunHooks
|
||||
from ...response import AgentResponseData
|
||||
from ...run_context import ContextWrapper, TContext
|
||||
from ..base import AgentResponse, AgentState, BaseAgentRunner
|
||||
from .dify_api_client import DifyAPIClient
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class DifyAgentRunner(BaseAgentRunner[TContext]):
|
||||
"""Dify Agent Runner"""
|
||||
|
||||
@override
|
||||
async def reset(
|
||||
self,
|
||||
request: ProviderRequest,
|
||||
run_context: ContextWrapper[TContext],
|
||||
agent_hooks: BaseAgentRunHooks[TContext],
|
||||
provider_config: dict,
|
||||
**kwargs: T.Any,
|
||||
) -> None:
|
||||
self.req = request
|
||||
self.streaming = kwargs.get("streaming", False)
|
||||
self.final_llm_resp = None
|
||||
self._state = AgentState.IDLE
|
||||
self.agent_hooks = agent_hooks
|
||||
self.run_context = run_context
|
||||
|
||||
self.api_key = provider_config.get("dify_api_key", "")
|
||||
self.api_base = provider_config.get("dify_api_base", "https://api.dify.ai/v1")
|
||||
self.api_type = provider_config.get("dify_api_type", "chat")
|
||||
self.workflow_output_key = provider_config.get(
|
||||
"dify_workflow_output_key",
|
||||
"astrbot_wf_output",
|
||||
)
|
||||
self.dify_query_input_key = provider_config.get(
|
||||
"dify_query_input_key",
|
||||
"astrbot_text_query",
|
||||
)
|
||||
self.variables: dict = provider_config.get("variables", {}) or {}
|
||||
self.timeout = provider_config.get("timeout", 60)
|
||||
if isinstance(self.timeout, str):
|
||||
self.timeout = int(self.timeout)
|
||||
|
||||
self.api_client = DifyAPIClient(self.api_key, self.api_base)
|
||||
|
||||
@override
|
||||
async def step(self):
|
||||
"""
|
||||
执行 Dify Agent 的一个步骤
|
||||
"""
|
||||
if not self.req:
|
||||
raise ValueError("Request is not set. Please call reset() first.")
|
||||
|
||||
if self._state == AgentState.IDLE:
|
||||
try:
|
||||
await self.agent_hooks.on_agent_begin(self.run_context)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True)
|
||||
|
||||
# 开始处理,转换到运行状态
|
||||
self._transition_state(AgentState.RUNNING)
|
||||
|
||||
try:
|
||||
# 执行 Dify 请求并处理结果
|
||||
async for response in self._execute_dify_request():
|
||||
yield response
|
||||
except Exception as e:
|
||||
logger.error(f"Dify 请求失败:{str(e)}")
|
||||
self._transition_state(AgentState.ERROR)
|
||||
self.final_llm_resp = LLMResponse(
|
||||
role="err", completion_text=f"Dify 请求失败:{str(e)}"
|
||||
)
|
||||
yield AgentResponse(
|
||||
type="err",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(f"Dify 请求失败:{str(e)}")
|
||||
),
|
||||
)
|
||||
finally:
|
||||
await self.api_client.close()
|
||||
|
||||
@override
|
||||
async def step_until_done(
|
||||
self, max_step: int = 30
|
||||
) -> T.AsyncGenerator[AgentResponse, None]:
|
||||
while not self.done():
|
||||
async for resp in self.step():
|
||||
yield resp
|
||||
|
||||
async def _execute_dify_request(self):
|
||||
"""执行 Dify 请求的核心逻辑"""
|
||||
prompt = self.req.prompt or ""
|
||||
session_id = self.req.session_id or "unknown"
|
||||
image_urls = self.req.image_urls or []
|
||||
system_prompt = self.req.system_prompt
|
||||
|
||||
conversation_id = await sp.get_async(
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
key="dify_conversation_id",
|
||||
default="",
|
||||
)
|
||||
result = ""
|
||||
|
||||
# 处理图片上传
|
||||
files_payload = []
|
||||
for image_url in image_urls:
|
||||
# image_url is a base64 string
|
||||
try:
|
||||
image_data = base64.b64decode(image_url)
|
||||
file_response = await self.api_client.file_upload(
|
||||
file_data=image_data,
|
||||
user=session_id,
|
||||
mime_type="image/png",
|
||||
file_name="image.png",
|
||||
)
|
||||
logger.debug(f"Dify 上传图片响应:{file_response}")
|
||||
if "id" not in file_response:
|
||||
logger.warning(
|
||||
f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。"
|
||||
)
|
||||
continue
|
||||
files_payload.append(
|
||||
{
|
||||
"type": "image",
|
||||
"transfer_method": "local_file",
|
||||
"upload_file_id": file_response["id"],
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"上传图片失败:{e}")
|
||||
continue
|
||||
|
||||
# 获得会话变量
|
||||
payload_vars = self.variables.copy()
|
||||
# 动态变量
|
||||
session_var = await sp.get_async(
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
key="session_variables",
|
||||
default={},
|
||||
)
|
||||
payload_vars.update(session_var)
|
||||
payload_vars["system_prompt"] = system_prompt
|
||||
|
||||
# 处理不同的 API 类型
|
||||
match self.api_type:
|
||||
case "chat" | "agent" | "chatflow":
|
||||
if not prompt:
|
||||
prompt = "请描述这张图片。"
|
||||
|
||||
async for chunk in self.api_client.chat_messages(
|
||||
inputs={
|
||||
**payload_vars,
|
||||
},
|
||||
query=prompt,
|
||||
user=session_id,
|
||||
conversation_id=conversation_id,
|
||||
files=files_payload,
|
||||
timeout=self.timeout,
|
||||
):
|
||||
logger.debug(f"dify resp chunk: {chunk}")
|
||||
if chunk["event"] == "message" or chunk["event"] == "agent_message":
|
||||
result += chunk["answer"]
|
||||
if not conversation_id:
|
||||
await sp.put_async(
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
key="dify_conversation_id",
|
||||
value=chunk["conversation_id"],
|
||||
)
|
||||
conversation_id = chunk["conversation_id"]
|
||||
|
||||
# 如果是流式响应,发送增量数据
|
||||
if self.streaming and chunk["answer"]:
|
||||
yield AgentResponse(
|
||||
type="streaming_delta",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(chunk["answer"])
|
||||
),
|
||||
)
|
||||
elif chunk["event"] == "message_end":
|
||||
logger.debug("Dify message end")
|
||||
break
|
||||
elif chunk["event"] == "error":
|
||||
logger.error(f"Dify 出现错误:{chunk}")
|
||||
raise Exception(
|
||||
f"Dify 出现错误 status: {chunk['status']} message: {chunk['message']}"
|
||||
)
|
||||
|
||||
case "workflow":
|
||||
async for chunk in self.api_client.workflow_run(
|
||||
inputs={
|
||||
self.dify_query_input_key: prompt,
|
||||
"astrbot_session_id": session_id,
|
||||
**payload_vars,
|
||||
},
|
||||
user=session_id,
|
||||
files=files_payload,
|
||||
timeout=self.timeout,
|
||||
):
|
||||
logger.debug(f"dify workflow resp chunk: {chunk}")
|
||||
match chunk["event"]:
|
||||
case "workflow_started":
|
||||
logger.info(
|
||||
f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行。"
|
||||
)
|
||||
case "node_finished":
|
||||
logger.debug(
|
||||
f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束。"
|
||||
)
|
||||
case "text_chunk":
|
||||
if self.streaming and chunk["data"]["text"]:
|
||||
yield AgentResponse(
|
||||
type="streaming_delta",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(
|
||||
chunk["data"]["text"]
|
||||
)
|
||||
),
|
||||
)
|
||||
case "workflow_finished":
|
||||
logger.info(
|
||||
f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束"
|
||||
)
|
||||
logger.debug(f"Dify 工作流结果:{chunk}")
|
||||
if chunk["data"]["error"]:
|
||||
logger.error(
|
||||
f"Dify 工作流出现错误:{chunk['data']['error']}"
|
||||
)
|
||||
raise Exception(
|
||||
f"Dify 工作流出现错误:{chunk['data']['error']}"
|
||||
)
|
||||
if self.workflow_output_key not in chunk["data"]["outputs"]:
|
||||
raise Exception(
|
||||
f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}"
|
||||
)
|
||||
result = chunk
|
||||
case _:
|
||||
raise Exception(f"未知的 Dify API 类型:{self.api_type}")
|
||||
|
||||
if not result:
|
||||
logger.warning("Dify 请求结果为空,请查看 Debug 日志。")
|
||||
|
||||
# 解析结果
|
||||
chain = await self.parse_dify_result(result)
|
||||
|
||||
# 创建最终响应
|
||||
self.final_llm_resp = LLMResponse(role="assistant", result_chain=chain)
|
||||
self._transition_state(AgentState.DONE)
|
||||
|
||||
try:
|
||||
await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
|
||||
|
||||
# 返回最终结果
|
||||
yield AgentResponse(
|
||||
type="llm_result",
|
||||
data=AgentResponseData(chain=chain),
|
||||
)
|
||||
|
||||
async def parse_dify_result(self, chunk: dict | str) -> MessageChain:
|
||||
"""解析 Dify 的响应结果"""
|
||||
if isinstance(chunk, str):
|
||||
# Chat
|
||||
return MessageChain(chain=[Comp.Plain(chunk)])
|
||||
|
||||
async def parse_file(item: dict):
|
||||
match item["type"]:
|
||||
case "image":
|
||||
return Comp.Image(file=item["url"], url=item["url"])
|
||||
case "audio":
|
||||
# 仅支持 wav
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, f"{item['filename']}.wav")
|
||||
await download_file(item["url"], path)
|
||||
return Comp.Image(file=item["url"], url=item["url"])
|
||||
case "video":
|
||||
return Comp.Video(file=item["url"])
|
||||
case _:
|
||||
return Comp.File(name=item["filename"], file=item["url"])
|
||||
|
||||
output = chunk["data"]["outputs"][self.workflow_output_key]
|
||||
chains = []
|
||||
if isinstance(output, str):
|
||||
# 纯文本输出
|
||||
chains.append(Comp.Plain(output))
|
||||
elif isinstance(output, list):
|
||||
# 主要适配 Dify 的 HTTP 请求结点的多模态输出
|
||||
for item in output:
|
||||
# handle Array[File]
|
||||
if (
|
||||
not isinstance(item, dict)
|
||||
or item.get("dify_model_identity", "") != "__dify__file__"
|
||||
):
|
||||
chains.append(Comp.Plain(str(output)))
|
||||
break
|
||||
else:
|
||||
chains.append(Comp.Plain(str(output)))
|
||||
|
||||
# scan file
|
||||
files = chunk["data"].get("files", [])
|
||||
for item in files:
|
||||
comp = await parse_file(item)
|
||||
chains.append(comp)
|
||||
|
||||
return MessageChain(chain=chains)
|
||||
|
||||
@override
|
||||
def done(self) -> bool:
|
||||
"""检查 Agent 是否已完成工作"""
|
||||
return self._state in (AgentState.DONE, AgentState.ERROR)
|
||||
|
||||
@override
|
||||
def get_final_llm_resp(self) -> LLMResponse | None:
|
||||
return self.final_llm_resp
|
||||
+51
-13
@@ -3,7 +3,7 @@ import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from aiohttp import ClientResponse, ClientSession
|
||||
from aiohttp import ClientResponse, ClientSession, FormData
|
||||
|
||||
from astrbot.core import logger
|
||||
|
||||
@@ -101,21 +101,59 @@ class DifyAPIClient:
|
||||
|
||||
async def file_upload(
|
||||
self,
|
||||
file_path: str,
|
||||
user: str,
|
||||
file_path: str | None = None,
|
||||
file_data: bytes | None = None,
|
||||
file_name: str | None = None,
|
||||
mime_type: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Upload a file to Dify. Must provide either file_path or file_data.
|
||||
|
||||
Args:
|
||||
user: The user ID.
|
||||
file_path: The path to the file to upload.
|
||||
file_data: The file data in bytes.
|
||||
file_name: Optional file name when using file_data.
|
||||
Returns:
|
||||
A dictionary containing the uploaded file information.
|
||||
"""
|
||||
url = f"{self.api_base}/files/upload"
|
||||
with open(file_path, "rb") as f:
|
||||
payload = {
|
||||
"user": user,
|
||||
"file": f,
|
||||
}
|
||||
async with self.session.post(
|
||||
url,
|
||||
data=payload,
|
||||
headers=self.headers,
|
||||
) as resp:
|
||||
return await resp.json() # {"id": "xxx", ...}
|
||||
|
||||
form = FormData()
|
||||
form.add_field("user", user)
|
||||
|
||||
if file_data is not None:
|
||||
# 使用 bytes 数据
|
||||
form.add_field(
|
||||
"file",
|
||||
file_data,
|
||||
filename=file_name or "uploaded_file",
|
||||
content_type=mime_type or "application/octet-stream",
|
||||
)
|
||||
elif file_path is not None:
|
||||
# 使用文件路径
|
||||
import os
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
file_content = f.read()
|
||||
form.add_field(
|
||||
"file",
|
||||
file_content,
|
||||
filename=os.path.basename(file_path),
|
||||
content_type=mime_type or "application/octet-stream",
|
||||
)
|
||||
else:
|
||||
raise ValueError("file_path 和 file_data 不能同时为 None")
|
||||
|
||||
async with self.session.post(
|
||||
url,
|
||||
data=form,
|
||||
headers=self.headers, # 不包含 Content-Type,让 aiohttp 自动设置
|
||||
) as resp:
|
||||
if resp.status != 200 and resp.status != 201:
|
||||
text = await resp.text()
|
||||
raise Exception(f"Dify 文件上传失败:{resp.status}. {text}")
|
||||
return await resp.json() # {"id": "xxx", ...}
|
||||
|
||||
async def close(self):
|
||||
await self.session.close()
|
||||
@@ -69,12 +69,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
)
|
||||
self.run_context.messages = messages
|
||||
|
||||
def _transition_state(self, new_state: AgentState) -> None:
|
||||
"""转换 Agent 状态"""
|
||||
if self._state != new_state:
|
||||
logger.debug(f"Agent state transition: {self._state} -> {new_state}")
|
||||
self._state = new_state
|
||||
|
||||
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
|
||||
"""Yields chunks *and* a final LLMResponse."""
|
||||
if self.streaming:
|
||||
|
||||
@@ -9,6 +9,7 @@ from astrbot.core.message.message_event_result import (
|
||||
MessageEventResult,
|
||||
ResultContentType,
|
||||
)
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
|
||||
AgentRunner = ToolLoopAgentRunner[AstrAgentContext]
|
||||
|
||||
@@ -72,7 +73,20 @@ async def run_agent(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n"
|
||||
|
||||
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在平台日志查看和分享错误详情。\n"
|
||||
|
||||
error_llm_response = LLMResponse(
|
||||
role="err",
|
||||
completion_text=err_msg,
|
||||
)
|
||||
try:
|
||||
await agent_runner.agent_hooks.on_agent_done(
|
||||
agent_runner.run_context, error_llm_response
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error in on_agent_done hook")
|
||||
|
||||
if agent_runner.streaming:
|
||||
yield MessageChain().message(err_msg)
|
||||
else:
|
||||
|
||||
+290
-27
@@ -4,9 +4,17 @@ import os
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "4.6.0"
|
||||
VERSION = "4.8.0"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||
|
||||
WEBHOOK_SUPPORTED_PLATFORMS = [
|
||||
"qq_official_webhook",
|
||||
"weixin_official_account",
|
||||
"wecom",
|
||||
"wecom_ai_bot",
|
||||
"slack",
|
||||
]
|
||||
|
||||
# 默认配置
|
||||
DEFAULT_CONFIG = {
|
||||
"config_version": 2,
|
||||
@@ -68,9 +76,19 @@ DEFAULT_CONFIG = {
|
||||
"dequeue_context_length": 1,
|
||||
"streaming_response": False,
|
||||
"show_tool_use_status": False,
|
||||
"agent_runner_type": "local",
|
||||
"dify_agent_runner_provider_id": "",
|
||||
"coze_agent_runner_provider_id": "",
|
||||
"dashscope_agent_runner_provider_id": "",
|
||||
"unsupported_streaming_strategy": "realtime_segmenting",
|
||||
"reachability_check": False,
|
||||
"max_agent_step": 30,
|
||||
"tool_call_timeout": 60,
|
||||
"file_extract": {
|
||||
"enable": False,
|
||||
"provider": "moonshotai",
|
||||
"moonshotai_api_key": "",
|
||||
},
|
||||
},
|
||||
"provider_stt_settings": {
|
||||
"enable": False,
|
||||
@@ -86,6 +104,7 @@ DEFAULT_CONFIG = {
|
||||
"group_icl_enable": False,
|
||||
"group_message_max_cnt": 300,
|
||||
"image_caption": False,
|
||||
"image_caption_provider_id": "",
|
||||
"active_reply": {
|
||||
"enable": False,
|
||||
"method": "possibility_reply",
|
||||
@@ -141,7 +160,16 @@ DEFAULT_CONFIG = {
|
||||
}
|
||||
|
||||
|
||||
# 配置项的中文描述、值类型
|
||||
"""
|
||||
AstrBot v3 时代的配置元数据,目前仅承担以下功能:
|
||||
|
||||
1. 保存配置时,配置项的类型验证
|
||||
2. WebUI 展示提供商和平台适配器模版
|
||||
|
||||
WebUI 的配置文件在 `CONFIG_METADATA_3` 中。
|
||||
|
||||
未来将会逐步淘汰此配置元数据。
|
||||
"""
|
||||
CONFIG_METADATA_2 = {
|
||||
"platform_group": {
|
||||
"metadata": {
|
||||
@@ -165,6 +193,8 @@ CONFIG_METADATA_2 = {
|
||||
"appid": "",
|
||||
"secret": "",
|
||||
"is_sandbox": False,
|
||||
"unified_webhook_mode": True,
|
||||
"webhook_uuid": "",
|
||||
"callback_server_host": "0.0.0.0",
|
||||
"port": 6196,
|
||||
},
|
||||
@@ -195,6 +225,8 @@ CONFIG_METADATA_2 = {
|
||||
"token": "",
|
||||
"encoding_aes_key": "",
|
||||
"api_base_url": "https://api.weixin.qq.com/cgi-bin/",
|
||||
"unified_webhook_mode": True,
|
||||
"webhook_uuid": "",
|
||||
"callback_server_host": "0.0.0.0",
|
||||
"port": 6194,
|
||||
"active_send_mode": False,
|
||||
@@ -209,6 +241,8 @@ CONFIG_METADATA_2 = {
|
||||
"encoding_aes_key": "",
|
||||
"kf_name": "",
|
||||
"api_base_url": "https://qyapi.weixin.qq.com/cgi-bin/",
|
||||
"unified_webhook_mode": True,
|
||||
"webhook_uuid": "",
|
||||
"callback_server_host": "0.0.0.0",
|
||||
"port": 6195,
|
||||
},
|
||||
@@ -221,6 +255,8 @@ CONFIG_METADATA_2 = {
|
||||
"wecom_ai_bot_name": "",
|
||||
"token": "",
|
||||
"encoding_aes_key": "",
|
||||
"unified_webhook_mode": True,
|
||||
"webhook_uuid": "",
|
||||
"callback_server_host": "0.0.0.0",
|
||||
"port": 6198,
|
||||
},
|
||||
@@ -288,6 +324,8 @@ CONFIG_METADATA_2 = {
|
||||
"app_token": "",
|
||||
"signing_secret": "",
|
||||
"slack_connection_mode": "socket", # webhook, socket
|
||||
"unified_webhook_mode": True,
|
||||
"webhook_uuid": "",
|
||||
"slack_webhook_host": "0.0.0.0",
|
||||
"slack_webhook_port": 6197,
|
||||
"slack_webhook_path": "/astrbot-slack-webhook/callback",
|
||||
@@ -367,16 +405,28 @@ CONFIG_METADATA_2 = {
|
||||
"description": "Slack Webhook Host",
|
||||
"type": "string",
|
||||
"hint": "Only valid when Slack connection mode is `webhook`.",
|
||||
"condition": {
|
||||
"slack_connection_mode": "webhook",
|
||||
"unified_webhook_mode": False,
|
||||
},
|
||||
},
|
||||
"slack_webhook_port": {
|
||||
"description": "Slack Webhook Port",
|
||||
"type": "int",
|
||||
"hint": "Only valid when Slack connection mode is `webhook`.",
|
||||
"condition": {
|
||||
"slack_connection_mode": "webhook",
|
||||
"unified_webhook_mode": False,
|
||||
},
|
||||
},
|
||||
"slack_webhook_path": {
|
||||
"description": "Slack Webhook Path",
|
||||
"type": "string",
|
||||
"hint": "Only valid when Slack connection mode is `webhook`.",
|
||||
"condition": {
|
||||
"slack_connection_mode": "webhook",
|
||||
"unified_webhook_mode": False,
|
||||
},
|
||||
},
|
||||
"active_send_mode": {
|
||||
"description": "是否换用主动发送接口",
|
||||
@@ -567,6 +617,33 @@ CONFIG_METADATA_2 = {
|
||||
"type": "string",
|
||||
"hint": "可选的 Discord 活动名称。留空则不设置活动。",
|
||||
},
|
||||
"port": {
|
||||
"description": "回调服务器端口",
|
||||
"type": "int",
|
||||
"hint": "回调服务器端口。留空则不启用回调服务器。",
|
||||
"condition": {
|
||||
"unified_webhook_mode": False,
|
||||
},
|
||||
},
|
||||
"callback_server_host": {
|
||||
"description": "回调服务器主机",
|
||||
"type": "string",
|
||||
"hint": "回调服务器主机。留空则不启用回调服务器。",
|
||||
"condition": {
|
||||
"unified_webhook_mode": False,
|
||||
},
|
||||
},
|
||||
"unified_webhook_mode": {
|
||||
"description": "统一 Webhook 模式",
|
||||
"type": "bool",
|
||||
"hint": "启用后,将使用 AstrBot 统一 Webhook 入口,无需单独开启端口。回调地址为 /api/platform/webhook/{webhook_uuid}。",
|
||||
},
|
||||
"webhook_uuid": {
|
||||
"invisible": True,
|
||||
"description": "Webhook UUID",
|
||||
"type": "string",
|
||||
"hint": "统一 Webhook 模式下的唯一标识符,创建平台时自动生成。",
|
||||
},
|
||||
},
|
||||
},
|
||||
"platform_settings": {
|
||||
@@ -634,7 +711,7 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"words_count_threshold": {
|
||||
"type": "int",
|
||||
"hint": "超过这个字数的消息不会被分段回复。默认为 150",
|
||||
"hint": "分段回复的字数上限。只有字数小于此值的消息才会被分段,超过此值的长消息将直接发送(不分段)。默认为 150",
|
||||
},
|
||||
"regex": {
|
||||
"type": "string",
|
||||
@@ -1011,7 +1088,7 @@ CONFIG_METADATA_2 = {
|
||||
"id": "dify_app_default",
|
||||
"provider": "dify",
|
||||
"type": "dify",
|
||||
"provider_type": "chat_completion",
|
||||
"provider_type": "agent_runner",
|
||||
"enable": True,
|
||||
"dify_api_type": "chat",
|
||||
"dify_api_key": "",
|
||||
@@ -1025,20 +1102,20 @@ CONFIG_METADATA_2 = {
|
||||
"Coze": {
|
||||
"id": "coze",
|
||||
"provider": "coze",
|
||||
"provider_type": "chat_completion",
|
||||
"provider_type": "agent_runner",
|
||||
"type": "coze",
|
||||
"enable": True,
|
||||
"coze_api_key": "",
|
||||
"bot_id": "",
|
||||
"coze_api_base": "https://api.coze.cn",
|
||||
"timeout": 60,
|
||||
"auto_save_history": True,
|
||||
# "auto_save_history": True,
|
||||
},
|
||||
"阿里云百炼应用": {
|
||||
"id": "dashscope",
|
||||
"provider": "dashscope",
|
||||
"type": "dashscope",
|
||||
"provider_type": "chat_completion",
|
||||
"provider_type": "agent_runner",
|
||||
"enable": True,
|
||||
"dashscope_app_type": "agent",
|
||||
"dashscope_api_key": "",
|
||||
@@ -1087,7 +1164,7 @@ CONFIG_METADATA_2 = {
|
||||
"api_base": "",
|
||||
"model": "whisper-1",
|
||||
},
|
||||
"Whisper(本地加载)": {
|
||||
"Whisper(Local)": {
|
||||
"hint": "启用前请 pip 安装 openai-whisper 库(N卡用户大约下载 2GB,主要是 torch 和 cuda,CPU 用户大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
|
||||
"provider": "openai",
|
||||
"type": "openai_whisper_selfhost",
|
||||
@@ -1096,7 +1173,7 @@ CONFIG_METADATA_2 = {
|
||||
"id": "whisper_selfhost",
|
||||
"model": "tiny",
|
||||
},
|
||||
"SenseVoice(本地加载)": {
|
||||
"SenseVoice(Local)": {
|
||||
"hint": "启用前请 pip 安装 funasr、funasr_onnx、torchaudio、torch、modelscope、jieba 库(默认使用CPU,大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
|
||||
"type": "sensevoice_stt_selfhost",
|
||||
"provider": "sensevoice",
|
||||
@@ -1131,7 +1208,7 @@ CONFIG_METADATA_2 = {
|
||||
"pitch": "+0Hz",
|
||||
"timeout": 20,
|
||||
},
|
||||
"GSV TTS(本地加载)": {
|
||||
"GSV TTS(Local)": {
|
||||
"id": "gsv_tts",
|
||||
"enable": False,
|
||||
"provider": "gpt_sovits",
|
||||
@@ -1308,6 +1385,19 @@ CONFIG_METADATA_2 = {
|
||||
"timeout": 20,
|
||||
"launch_model_if_not_running": False,
|
||||
},
|
||||
"阿里云百炼重排序": {
|
||||
"id": "bailian_rerank",
|
||||
"type": "bailian_rerank",
|
||||
"provider": "bailian",
|
||||
"provider_type": "rerank",
|
||||
"enable": True,
|
||||
"rerank_api_key": "",
|
||||
"rerank_api_base": "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank",
|
||||
"rerank_model": "qwen3-rerank",
|
||||
"timeout": 30,
|
||||
"return_documents": False,
|
||||
"instruct": "",
|
||||
},
|
||||
"Xinference STT": {
|
||||
"id": "xinference_stt",
|
||||
"type": "xinference_stt",
|
||||
@@ -1342,6 +1432,16 @@ CONFIG_METADATA_2 = {
|
||||
"description": "重排序模型名称",
|
||||
"type": "string",
|
||||
},
|
||||
"return_documents": {
|
||||
"description": "是否在排序结果中返回文档原文",
|
||||
"type": "bool",
|
||||
"hint": "默认值false,以减少网络传输开销。",
|
||||
},
|
||||
"instruct": {
|
||||
"description": "自定义排序任务类型说明",
|
||||
"type": "string",
|
||||
"hint": "仅在使用 qwen3-rerank 模型时生效。建议使用英文撰写。",
|
||||
},
|
||||
"launch_model_if_not_running": {
|
||||
"description": "模型未运行时自动启动",
|
||||
"type": "bool",
|
||||
@@ -1884,7 +1984,6 @@ CONFIG_METADATA_2 = {
|
||||
"enable": {
|
||||
"description": "启用",
|
||||
"type": "bool",
|
||||
"hint": "是否启用。",
|
||||
},
|
||||
"key": {
|
||||
"description": "API Key",
|
||||
@@ -2014,14 +2113,38 @@ CONFIG_METADATA_2 = {
|
||||
"unsupported_streaming_strategy": {
|
||||
"type": "string",
|
||||
},
|
||||
"agent_runner_type": {
|
||||
"type": "string",
|
||||
},
|
||||
"dify_agent_runner_provider_id": {
|
||||
"type": "string",
|
||||
},
|
||||
"coze_agent_runner_provider_id": {
|
||||
"type": "string",
|
||||
},
|
||||
"dashscope_agent_runner_provider_id": {
|
||||
"type": "string",
|
||||
},
|
||||
"max_agent_step": {
|
||||
"description": "工具调用轮数上限",
|
||||
"type": "int",
|
||||
},
|
||||
"tool_call_timeout": {
|
||||
"description": "工具调用超时时间(秒)",
|
||||
"type": "int",
|
||||
},
|
||||
"file_extract": {
|
||||
"type": "object",
|
||||
"items": {
|
||||
"enable": {
|
||||
"type": "bool",
|
||||
},
|
||||
"provider": {
|
||||
"type": "string",
|
||||
},
|
||||
"moonshotai_api_key": {
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"provider_stt_settings": {
|
||||
@@ -2064,6 +2187,9 @@ CONFIG_METADATA_2 = {
|
||||
"image_caption": {
|
||||
"type": "bool",
|
||||
},
|
||||
"image_caption_provider_id": {
|
||||
"type": "string",
|
||||
},
|
||||
"image_caption_prompt": {
|
||||
"type": "string",
|
||||
},
|
||||
@@ -2153,34 +2279,87 @@ CONFIG_METADATA_2 = {
|
||||
}
|
||||
|
||||
|
||||
"""
|
||||
v4.7.0 之后,name, description, hint 等字段已经实现 i18n 国际化。国际化资源文件位于:
|
||||
|
||||
- dashboard/src/i18n/locales/en-US/features/config-metadata.json
|
||||
- dashboard/src/i18n/locales/zh-CN/features/config-metadata.json
|
||||
|
||||
如果在此文件中添加了新的配置字段,请务必同步更新上述两个国际化资源文件。
|
||||
"""
|
||||
CONFIG_METADATA_3 = {
|
||||
"ai_group": {
|
||||
"name": "AI 配置",
|
||||
"metadata": {
|
||||
"ai": {
|
||||
"description": "模型",
|
||||
"agent_runner": {
|
||||
"description": "Agent 执行方式",
|
||||
"hint": "选择 AI 对话的执行器,默认为 AstrBot 内置 Agent 执行器,可使用 AstrBot 内的知识库、人格、工具调用功能。如果不打算接入 Dify 或 Coze 等第三方 Agent 执行器,不需要修改此节。",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"provider_settings.enable": {
|
||||
"description": "启用大语言模型聊天",
|
||||
"description": "启用",
|
||||
"type": "bool",
|
||||
"hint": "AI 对话总开关",
|
||||
},
|
||||
"provider_settings.agent_runner_type": {
|
||||
"description": "执行器",
|
||||
"type": "string",
|
||||
"options": ["local", "dify", "coze", "dashscope"],
|
||||
"labels": ["内置 Agent", "Dify", "Coze", "阿里云百炼应用"],
|
||||
"condition": {
|
||||
"provider_settings.enable": True,
|
||||
},
|
||||
},
|
||||
"provider_settings.coze_agent_runner_provider_id": {
|
||||
"description": "Coze Agent 执行器提供商 ID",
|
||||
"type": "string",
|
||||
"_special": "select_agent_runner_provider:coze",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "coze",
|
||||
"provider_settings.enable": True,
|
||||
},
|
||||
},
|
||||
"provider_settings.dify_agent_runner_provider_id": {
|
||||
"description": "Dify Agent 执行器提供商 ID",
|
||||
"type": "string",
|
||||
"_special": "select_agent_runner_provider:dify",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "dify",
|
||||
"provider_settings.enable": True,
|
||||
},
|
||||
},
|
||||
"provider_settings.dashscope_agent_runner_provider_id": {
|
||||
"description": "阿里云百炼应用 Agent 执行器提供商 ID",
|
||||
"type": "string",
|
||||
"_special": "select_agent_runner_provider:dashscope",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "dashscope",
|
||||
"provider_settings.enable": True,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"ai": {
|
||||
"description": "模型",
|
||||
"hint": "当使用非内置 Agent 执行器时,默认聊天模型和默认图片转述模型可能会无效,但某些插件会依赖此配置项来调用 AI 能力。",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"provider_settings.default_provider_id": {
|
||||
"description": "默认聊天模型",
|
||||
"type": "string",
|
||||
"_special": "select_provider",
|
||||
"hint": "留空时使用第一个模型。",
|
||||
"hint": "留空时使用第一个模型",
|
||||
},
|
||||
"provider_settings.default_image_caption_provider_id": {
|
||||
"description": "默认图片转述模型",
|
||||
"type": "string",
|
||||
"_special": "select_provider",
|
||||
"hint": "留空代表不使用。可用于不支持视觉模态的聊天模型。",
|
||||
"hint": "留空代表不使用,可用于非多模态模型",
|
||||
},
|
||||
"provider_stt_settings.enable": {
|
||||
"description": "启用语音转文本",
|
||||
"type": "bool",
|
||||
"hint": "STT 总开关。",
|
||||
"hint": "STT 总开关",
|
||||
},
|
||||
"provider_stt_settings.provider_id": {
|
||||
"description": "默认语音转文本模型",
|
||||
@@ -2194,12 +2373,11 @@ CONFIG_METADATA_3 = {
|
||||
"provider_tts_settings.enable": {
|
||||
"description": "启用文本转语音",
|
||||
"type": "bool",
|
||||
"hint": "TTS 总开关。当关闭时,会话启用 TTS 也不会生效。",
|
||||
"hint": "TTS 总开关",
|
||||
},
|
||||
"provider_tts_settings.provider_id": {
|
||||
"description": "默认文本转语音模型",
|
||||
"type": "string",
|
||||
"hint": "用户也可使用 /provider 单独选择会话的 TTS 模型。",
|
||||
"_special": "select_provider_tts",
|
||||
"condition": {
|
||||
"provider_tts_settings.enable": True,
|
||||
@@ -2210,6 +2388,9 @@ CONFIG_METADATA_3 = {
|
||||
"type": "text",
|
||||
},
|
||||
},
|
||||
"condition": {
|
||||
"provider_settings.enable": True,
|
||||
},
|
||||
},
|
||||
"persona": {
|
||||
"description": "人格",
|
||||
@@ -2221,6 +2402,10 @@ CONFIG_METADATA_3 = {
|
||||
"_special": "select_persona",
|
||||
},
|
||||
},
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
"provider_settings.enable": True,
|
||||
},
|
||||
},
|
||||
"knowledgebase": {
|
||||
"description": "知识库",
|
||||
@@ -2249,6 +2434,10 @@ CONFIG_METADATA_3 = {
|
||||
"hint": "启用后,知识库检索将作为 LLM Tool,由模型自主决定何时调用知识库进行查询。需要模型支持函数调用能力。",
|
||||
},
|
||||
},
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
"provider_settings.enable": True,
|
||||
},
|
||||
},
|
||||
"websearch": {
|
||||
"description": "网页搜索",
|
||||
@@ -2285,7 +2474,41 @@ CONFIG_METADATA_3 = {
|
||||
"type": "bool",
|
||||
},
|
||||
},
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
"provider_settings.enable": True,
|
||||
},
|
||||
},
|
||||
# "file_extract": {
|
||||
# "description": "文档解析能力 [beta]",
|
||||
# "type": "object",
|
||||
# "items": {
|
||||
# "provider_settings.file_extract.enable": {
|
||||
# "description": "启用文档解析能力",
|
||||
# "type": "bool",
|
||||
# },
|
||||
# "provider_settings.file_extract.provider": {
|
||||
# "description": "文档解析提供商",
|
||||
# "type": "string",
|
||||
# "options": ["moonshotai"],
|
||||
# "condition": {
|
||||
# "provider_settings.file_extract.enable": True,
|
||||
# },
|
||||
# },
|
||||
# "provider_settings.file_extract.moonshotai_api_key": {
|
||||
# "description": "Moonshot AI API Key",
|
||||
# "type": "string",
|
||||
# "condition": {
|
||||
# "provider_settings.file_extract.provider": "moonshotai",
|
||||
# "provider_settings.file_extract.enable": True,
|
||||
# },
|
||||
# },
|
||||
# },
|
||||
# "condition": {
|
||||
# "provider_settings.agent_runner_type": "local",
|
||||
# "provider_settings.enable": True,
|
||||
# },
|
||||
# },
|
||||
"others": {
|
||||
"description": "其他配置",
|
||||
"type": "object",
|
||||
@@ -2293,34 +2516,51 @@ CONFIG_METADATA_3 = {
|
||||
"provider_settings.display_reasoning_text": {
|
||||
"description": "显示思考内容",
|
||||
"type": "bool",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.identifier": {
|
||||
"description": "用户识别",
|
||||
"type": "bool",
|
||||
"hint": "启用后,会在提示词前包含用户 ID 信息。",
|
||||
},
|
||||
"provider_settings.group_name_display": {
|
||||
"description": "显示群名称",
|
||||
"type": "bool",
|
||||
"hint": "启用后,在支持的平台(aiocqhttp)上会在 prompt 中包含群名称信息。",
|
||||
"hint": "启用后,在支持的平台(OneBot v11)上会在提示词前包含群名称信息。",
|
||||
},
|
||||
"provider_settings.datetime_system_prompt": {
|
||||
"description": "现实世界时间感知",
|
||||
"type": "bool",
|
||||
"hint": "启用后,会在系统提示词中附带当前时间信息。",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.show_tool_use_status": {
|
||||
"description": "输出函数调用状态",
|
||||
"type": "bool",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.max_agent_step": {
|
||||
"description": "工具调用轮数上限",
|
||||
"type": "int",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.tool_call_timeout": {
|
||||
"description": "工具调用超时时间(秒)",
|
||||
"type": "int",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.streaming_response": {
|
||||
"description": "流式回复",
|
||||
"description": "流式输出",
|
||||
"type": "bool",
|
||||
},
|
||||
"provider_settings.unsupported_streaming_strategy": {
|
||||
@@ -2336,17 +2576,23 @@ CONFIG_METADATA_3 = {
|
||||
"provider_settings.max_context_length": {
|
||||
"description": "最多携带对话轮数",
|
||||
"type": "int",
|
||||
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条。-1 为不限制。",
|
||||
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.dequeue_context_length": {
|
||||
"description": "丢弃对话轮数",
|
||||
"type": "int",
|
||||
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数。",
|
||||
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.wake_prefix": {
|
||||
"description": "LLM 聊天额外唤醒前缀 ",
|
||||
"type": "string",
|
||||
"hint": "如果唤醒前缀为 `/`, 额外聊天唤醒前缀为 `chat`,则需要 `/chat` 才会触发 LLM 请求。默认为空。",
|
||||
"hint": "如果唤醒前缀为 /, 额外聊天唤醒前缀为 chat,则需要 /chat 才会触发 LLM 请求",
|
||||
},
|
||||
"provider_settings.prompt_prefix": {
|
||||
"description": "用户提示词",
|
||||
@@ -2357,6 +2603,14 @@ CONFIG_METADATA_3 = {
|
||||
"description": "开启 TTS 时同时输出语音和文字内容",
|
||||
"type": "bool",
|
||||
},
|
||||
"provider_settings.reachability_check": {
|
||||
"description": "提供商可达性检测",
|
||||
"type": "bool",
|
||||
"hint": "/provider 命令列出模型时是否并发检测连通性。开启后会主动调用模型测试连通性,可能产生额外 token 消耗。",
|
||||
},
|
||||
},
|
||||
"condition": {
|
||||
"provider_settings.enable": True,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -2647,7 +2901,16 @@ CONFIG_METADATA_3 = {
|
||||
"provider_ltm_settings.image_caption": {
|
||||
"description": "自动理解图片",
|
||||
"type": "bool",
|
||||
"hint": "需要设置默认图片转述模型。",
|
||||
"hint": "需要设置群聊图片转述模型。",
|
||||
},
|
||||
"provider_ltm_settings.image_caption_provider_id": {
|
||||
"description": "群聊图片转述模型",
|
||||
"type": "string",
|
||||
"_special": "select_provider",
|
||||
"hint": "用于群聊上下文感知的图片理解,与默认图片转述模型分开配置。",
|
||||
"condition": {
|
||||
"provider_ltm_settings.image_caption": True,
|
||||
},
|
||||
},
|
||||
"provider_ltm_settings.active_reply.enable": {
|
||||
"description": "主动回复",
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
"""
|
||||
配置元数据国际化工具
|
||||
|
||||
提供配置元数据的国际化键转换功能
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ConfigMetadataI18n:
|
||||
"""配置元数据国际化转换器"""
|
||||
|
||||
@staticmethod
|
||||
def _get_i18n_key(group: str, section: str, field: str, attr: str) -> str:
|
||||
"""
|
||||
生成国际化键
|
||||
|
||||
Args:
|
||||
group: 配置组,如 'ai_group', 'platform_group'
|
||||
section: 配置节,如 'agent_runner', 'general'
|
||||
field: 字段名,如 'enable', 'default_provider'
|
||||
attr: 属性类型,如 'description', 'hint', 'labels'
|
||||
|
||||
Returns:
|
||||
国际化键,格式如: 'ai_group.agent_runner.enable.description'
|
||||
"""
|
||||
if field:
|
||||
return f"{group}.{section}.{field}.{attr}"
|
||||
else:
|
||||
return f"{group}.{section}.{attr}"
|
||||
|
||||
@staticmethod
|
||||
def convert_to_i18n_keys(metadata: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
将配置元数据转换为使用国际化键
|
||||
|
||||
Args:
|
||||
metadata: 原始配置元数据字典
|
||||
|
||||
Returns:
|
||||
使用国际化键的配置元数据字典
|
||||
"""
|
||||
result = {}
|
||||
|
||||
for group_key, group_data in metadata.items():
|
||||
group_result = {
|
||||
"name": f"{group_key}.name",
|
||||
"metadata": {},
|
||||
}
|
||||
|
||||
for section_key, section_data in group_data.get("metadata", {}).items():
|
||||
section_result = {
|
||||
"description": f"{group_key}.{section_key}.description",
|
||||
"type": section_data.get("type"),
|
||||
}
|
||||
|
||||
# 复制其他属性
|
||||
for key in ["items", "condition", "_special", "invisible"]:
|
||||
if key in section_data:
|
||||
section_result[key] = section_data[key]
|
||||
|
||||
# 处理 hint
|
||||
if "hint" in section_data:
|
||||
section_result["hint"] = f"{group_key}.{section_key}.hint"
|
||||
|
||||
# 处理 items 中的字段
|
||||
if "items" in section_data and isinstance(section_data["items"], dict):
|
||||
items_result = {}
|
||||
for field_key, field_data in section_data["items"].items():
|
||||
# 处理嵌套的点号字段名(如 provider_settings.enable)
|
||||
field_name = field_key
|
||||
|
||||
field_result = {}
|
||||
|
||||
# 复制基本属性
|
||||
for attr in [
|
||||
"type",
|
||||
"condition",
|
||||
"_special",
|
||||
"invisible",
|
||||
"options",
|
||||
]:
|
||||
if attr in field_data:
|
||||
field_result[attr] = field_data[attr]
|
||||
|
||||
# 转换文本属性为国际化键
|
||||
if "description" in field_data:
|
||||
field_result["description"] = (
|
||||
f"{group_key}.{section_key}.{field_name}.description"
|
||||
)
|
||||
|
||||
if "hint" in field_data:
|
||||
field_result["hint"] = (
|
||||
f"{group_key}.{section_key}.{field_name}.hint"
|
||||
)
|
||||
|
||||
if "labels" in field_data:
|
||||
field_result["labels"] = (
|
||||
f"{group_key}.{section_key}.{field_name}.labels"
|
||||
)
|
||||
|
||||
items_result[field_key] = field_result
|
||||
|
||||
section_result["items"] = items_result
|
||||
|
||||
group_result["metadata"][section_key] = section_result
|
||||
|
||||
result[group_key] = group_result
|
||||
|
||||
return result
|
||||
@@ -16,15 +16,13 @@ import time
|
||||
import traceback
|
||||
from asyncio import Queue
|
||||
|
||||
from astrbot.core import LogBroker, logger, sp
|
||||
from astrbot.api import logger, sp
|
||||
from astrbot.core import LogBroker
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.conversation_mgr import ConversationManager
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46
|
||||
from astrbot.core.db.migration.migra_webchat_session import migrate_webchat_session
|
||||
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
||||
from astrbot.core.memory.memory_manager import MemoryManager
|
||||
from astrbot.core.persona_mgr import PersonaManager
|
||||
from astrbot.core.pipeline.scheduler import PipelineContext, PipelineScheduler
|
||||
from astrbot.core.platform.manager import PlatformManager
|
||||
@@ -35,6 +33,7 @@ from astrbot.core.star.context import Context
|
||||
from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map
|
||||
from astrbot.core.umop_config_router import UmopConfigRouter
|
||||
from astrbot.core.updator import AstrBotUpdator
|
||||
from astrbot.core.utils.migra_helper import migra
|
||||
|
||||
from . import astrbot_config, html_renderer
|
||||
from .event_bus import EventBus
|
||||
@@ -98,18 +97,16 @@ class AstrBotCoreLifecycle:
|
||||
sp=sp,
|
||||
)
|
||||
|
||||
# 4.5 to 4.6 migration for umop_config_router
|
||||
# apply migration
|
||||
try:
|
||||
await migrate_45_to_46(self.astrbot_config_mgr, self.umop_config_router)
|
||||
await migra(
|
||||
self.db,
|
||||
self.astrbot_config_mgr,
|
||||
self.umop_config_router,
|
||||
self.astrbot_config_mgr,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Migration from version 4.5 to 4.6 failed: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# migration for webchat session
|
||||
try:
|
||||
await migrate_webchat_session(self.db)
|
||||
except Exception as e:
|
||||
logger.error(f"Migration for webchat session failed: {e!s}")
|
||||
logger.error(f"AstrBot migration failed: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# 初始化事件队列
|
||||
@@ -137,8 +134,6 @@ class AstrBotCoreLifecycle:
|
||||
|
||||
# 初始化知识库管理器
|
||||
self.kb_manager = KnowledgeBaseManager(self.provider_manager)
|
||||
# 初始化记忆管理器
|
||||
self.memory_manager = MemoryManager()
|
||||
|
||||
# 初始化提供给插件的上下文
|
||||
self.star_context = Context(
|
||||
@@ -152,7 +147,6 @@ class AstrBotCoreLifecycle:
|
||||
self.persona_mgr,
|
||||
self.astrbot_config_mgr,
|
||||
self.kb_manager,
|
||||
self.memory_manager,
|
||||
)
|
||||
|
||||
# 初始化插件管理器
|
||||
|
||||
@@ -173,7 +173,7 @@ class BaseDatabase(abc.ABC):
|
||||
content: dict,
|
||||
sender_id: str | None = None,
|
||||
sender_name: str | None = None,
|
||||
) -> None:
|
||||
) -> PlatformMessageHistory:
|
||||
"""Insert a new platform message history record."""
|
||||
...
|
||||
|
||||
@@ -198,6 +198,14 @@ class BaseDatabase(abc.ABC):
|
||||
"""Get platform message history for a specific user."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_platform_message_history_by_id(
|
||||
self,
|
||||
message_id: int,
|
||||
) -> PlatformMessageHistory | None:
|
||||
"""Get a platform message history record by its ID."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def insert_attachment(
|
||||
self,
|
||||
@@ -213,6 +221,27 @@ class BaseDatabase(abc.ABC):
|
||||
"""Get an attachment by its ID."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_attachments(self, attachment_ids: list[str]) -> list[Attachment]:
|
||||
"""Get multiple attachments by their IDs."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_attachment(self, attachment_id: str) -> bool:
|
||||
"""Delete an attachment by its ID.
|
||||
|
||||
Returns True if the attachment was deleted, False if it was not found.
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_attachments(self, attachment_ids: list[str]) -> int:
|
||||
"""Delete multiple attachments by their IDs.
|
||||
|
||||
Returns the number of attachments deleted.
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def insert_persona(
|
||||
self,
|
||||
|
||||
@@ -25,7 +25,7 @@ async def migrate_webchat_session(db_helper: BaseDatabase):
|
||||
"""
|
||||
# 检查是否已经完成迁移
|
||||
migration_done = await db_helper.get_preference(
|
||||
"global", "global", "migration_done_webchat_session"
|
||||
"global", "global", "migration_done_webchat_session_1"
|
||||
)
|
||||
if migration_done:
|
||||
return
|
||||
@@ -43,7 +43,7 @@ async def migrate_webchat_session(db_helper: BaseDatabase):
|
||||
func.max(PlatformMessageHistory.updated_at).label("latest"),
|
||||
)
|
||||
.where(col(PlatformMessageHistory.platform_id) == "webchat")
|
||||
.where(col(PlatformMessageHistory.sender_id) == "astrbot")
|
||||
.where(col(PlatformMessageHistory.sender_id) != "bot")
|
||||
.group_by(col(PlatformMessageHistory.user_id))
|
||||
)
|
||||
|
||||
@@ -53,7 +53,7 @@ async def migrate_webchat_session(db_helper: BaseDatabase):
|
||||
if not webchat_users:
|
||||
logger.info("没有找到需要迁移的 WebChat 数据")
|
||||
await sp.put_async(
|
||||
"global", "global", "migration_done_webchat_session", True
|
||||
"global", "global", "migration_done_webchat_session_1", True
|
||||
)
|
||||
return
|
||||
|
||||
@@ -124,7 +124,7 @@ async def migrate_webchat_session(db_helper: BaseDatabase):
|
||||
logger.info("没有新会话需要迁移")
|
||||
|
||||
# 标记迁移完成
|
||||
await sp.put_async("global", "global", "migration_done_webchat_session", True)
|
||||
await sp.put_async("global", "global", "migration_done_webchat_session_1", True)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"迁移过程中发生错误: {e}", exc_info=True)
|
||||
|
||||
@@ -173,7 +173,7 @@ class PlatformSession(SQLModel, table=True):
|
||||
max_length=100,
|
||||
nullable=False,
|
||||
unique=True,
|
||||
default_factory=lambda: f"webchat_{uuid.uuid4()}",
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
)
|
||||
platform_id: str = Field(default="webchat", nullable=False)
|
||||
"""Platform identifier (e.g., 'webchat', 'qq', 'discord')"""
|
||||
|
||||
@@ -105,8 +105,8 @@ class SQLiteDatabase(BaseDatabase):
|
||||
text("""
|
||||
SELECT * FROM platform_stats
|
||||
WHERE timestamp >= :start_time
|
||||
ORDER BY timestamp DESC
|
||||
GROUP BY platform_id
|
||||
ORDER BY timestamp DESC
|
||||
"""),
|
||||
{"start_time": start_time},
|
||||
)
|
||||
@@ -449,6 +449,18 @@ class SQLiteDatabase(BaseDatabase):
|
||||
result = await session.execute(query.offset(offset).limit(page_size))
|
||||
return result.scalars().all()
|
||||
|
||||
async def get_platform_message_history_by_id(
|
||||
self, message_id: int
|
||||
) -> PlatformMessageHistory | None:
|
||||
"""Get a platform message history record by its ID."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
query = select(PlatformMessageHistory).where(
|
||||
PlatformMessageHistory.id == message_id
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def insert_attachment(self, path, type, mime_type):
|
||||
"""Insert a new attachment record."""
|
||||
async with self.get_db() as session:
|
||||
@@ -470,6 +482,48 @@ class SQLiteDatabase(BaseDatabase):
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_attachments(self, attachment_ids: list[str]) -> list:
|
||||
"""Get multiple attachments by their IDs."""
|
||||
if not attachment_ids:
|
||||
return []
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
query = select(Attachment).where(
|
||||
Attachment.attachment_id.in_(attachment_ids)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def delete_attachment(self, attachment_id: str) -> bool:
|
||||
"""Delete an attachment by its ID.
|
||||
|
||||
Returns True if the attachment was deleted, False if it was not found.
|
||||
"""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
query = delete(Attachment).where(
|
||||
col(Attachment.attachment_id) == attachment_id
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return result.rowcount > 0
|
||||
|
||||
async def delete_attachments(self, attachment_ids: list[str]) -> int:
|
||||
"""Delete multiple attachments by their IDs.
|
||||
|
||||
Returns the number of attachments deleted.
|
||||
"""
|
||||
if not attachment_ids:
|
||||
return 0
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
query = delete(Attachment).where(
|
||||
col(Attachment.attachment_id).in_(attachment_ids)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return result.rowcount
|
||||
|
||||
async def insert_persona(
|
||||
self,
|
||||
persona_id,
|
||||
@@ -794,7 +848,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
|
||||
await session.execute(
|
||||
update(PlatformSession)
|
||||
.where(col(PlatformSession.session_id == session_id))
|
||||
.where(col(PlatformSession.session_id) == session_id)
|
||||
.values(**values),
|
||||
)
|
||||
|
||||
@@ -805,6 +859,6 @@ class SQLiteDatabase(BaseDatabase):
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(PlatformSession).where(
|
||||
col(PlatformSession.session_id == session_id),
|
||||
col(PlatformSession.session_id) == session_id,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,20 +1,11 @@
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
from typing import TypedDict
|
||||
|
||||
|
||||
@dataclass
|
||||
class Result:
|
||||
class ResultData(TypedDict):
|
||||
id: str
|
||||
doc_id: str
|
||||
text: str
|
||||
metadata: str
|
||||
created_at: int
|
||||
updated_at: int
|
||||
|
||||
similarity: float
|
||||
data: ResultData | dict
|
||||
data: dict
|
||||
|
||||
|
||||
class BaseVecDB:
|
||||
|
||||
@@ -1,822 +0,0 @@
|
||||
{
|
||||
"type": "excalidraw",
|
||||
"version": 2,
|
||||
"source": "https://marketplace.visualstudio.com/items?itemName=pomdtr.excalidraw-editor",
|
||||
"elements": [
|
||||
{
|
||||
"id": "l6cYurMvF69IM4Kc33Qou",
|
||||
"type": "rectangle",
|
||||
"x": 173.140625,
|
||||
"y": -29.0234375,
|
||||
"width": 92.95703125,
|
||||
"height": 77.109375,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "a0",
|
||||
"roundness": {
|
||||
"type": 3
|
||||
},
|
||||
"seed": 1409469537,
|
||||
"version": 91,
|
||||
"versionNonce": 307958671,
|
||||
"isDeleted": false,
|
||||
"boundElements": [],
|
||||
"updated": 1763703733605,
|
||||
"link": null,
|
||||
"locked": false
|
||||
},
|
||||
{
|
||||
"id": "1ZvS6t8U6ihUjNU0dakgl",
|
||||
"type": "arrow",
|
||||
"x": 409.30859375,
|
||||
"y": 9.6875,
|
||||
"width": 118.2734375,
|
||||
"height": 1.9609375,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "a1",
|
||||
"roundness": {
|
||||
"type": 2
|
||||
},
|
||||
"seed": 326508865,
|
||||
"version": 120,
|
||||
"versionNonce": 199367023,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763703733605,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"points": [
|
||||
[
|
||||
0,
|
||||
0
|
||||
],
|
||||
[
|
||||
-118.2734375,
|
||||
-1.9609375
|
||||
]
|
||||
],
|
||||
"lastCommittedPoint": null,
|
||||
"startBinding": null,
|
||||
"endBinding": null,
|
||||
"startArrowhead": null,
|
||||
"endArrowhead": "arrow",
|
||||
"elbowed": false
|
||||
},
|
||||
{
|
||||
"id": "tfdUGiJdcMoOHGfqFHXK6",
|
||||
"type": "text",
|
||||
"x": 153.46875,
|
||||
"y": -70.9765625,
|
||||
"width": 136.4598846435547,
|
||||
"height": 25,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "a2",
|
||||
"roundness": null,
|
||||
"seed": 688712865,
|
||||
"version": 67,
|
||||
"versionNonce": 300660705,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763703743816,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "FAISS+SQLite",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "FAISS+SQLite",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "AeL3kEB9a8_TAvAXpAbpl",
|
||||
"type": "text",
|
||||
"x": 438.36328125,
|
||||
"y": -3.78125,
|
||||
"width": 116.109375,
|
||||
"height": 25,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "a3",
|
||||
"roundness": null,
|
||||
"seed": 788579535,
|
||||
"version": 33,
|
||||
"versionNonce": 946602095,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763703932431,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "FACT",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "FACT",
|
||||
"autoResize": false,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "Pe3TeMZvxQ8tRTcbD5v6P",
|
||||
"type": "arrow",
|
||||
"x": 297.125,
|
||||
"y": 40.2578125,
|
||||
"width": 120.2421875,
|
||||
"height": 1.421875,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "a4",
|
||||
"roundness": {
|
||||
"type": 2
|
||||
},
|
||||
"seed": 1146229999,
|
||||
"version": 44,
|
||||
"versionNonce": 636917679,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763703759050,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"points": [
|
||||
[
|
||||
0,
|
||||
0
|
||||
],
|
||||
[
|
||||
120.2421875,
|
||||
1.421875
|
||||
]
|
||||
],
|
||||
"lastCommittedPoint": null,
|
||||
"startBinding": null,
|
||||
"endBinding": null,
|
||||
"startArrowhead": null,
|
||||
"endArrowhead": "arrow",
|
||||
"elbowed": false
|
||||
},
|
||||
{
|
||||
"id": "GhmQoadtQRK8c8aEEbYKQ",
|
||||
"type": "text",
|
||||
"x": 283.53515625,
|
||||
"y": 64.76171875,
|
||||
"width": 130.85989379882812,
|
||||
"height": 50,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "a5",
|
||||
"roundness": null,
|
||||
"seed": 1445650959,
|
||||
"version": 79,
|
||||
"versionNonce": 566193167,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763703768982,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "top-n Similary\n",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "top-n Similary\n",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "uTEFJs8cNS09WFq2pi9P7",
|
||||
"type": "rectangle",
|
||||
"x": 528.1586158430439,
|
||||
"y": -173.43472375183552,
|
||||
"width": 135.7578125,
|
||||
"height": 128.73828125,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "a6",
|
||||
"roundness": {
|
||||
"type": 3
|
||||
},
|
||||
"seed": 223409231,
|
||||
"version": 44,
|
||||
"versionNonce": 1066827105,
|
||||
"isDeleted": false,
|
||||
"boundElements": [
|
||||
{
|
||||
"id": "FfWdx1_yCq6UYfXamJX9N",
|
||||
"type": "arrow"
|
||||
}
|
||||
],
|
||||
"updated": 1763704050188,
|
||||
"link": null,
|
||||
"locked": false
|
||||
},
|
||||
{
|
||||
"id": "2SzqzpJ4C2ymVj8-8vN7H",
|
||||
"type": "text",
|
||||
"x": 548.1480270948795,
|
||||
"y": -211,
|
||||
"width": 86.43992614746094,
|
||||
"height": 25,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "a7",
|
||||
"roundness": null,
|
||||
"seed": 1015608623,
|
||||
"version": 23,
|
||||
"versionNonce": 950374849,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763704047884,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "Memories",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "Memories",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "CgW6Yf9v0a9q1tsjhDl7b",
|
||||
"type": "text",
|
||||
"x": 568.3099317299038,
|
||||
"y": -154.69469411681115,
|
||||
"width": 62.099945068359375,
|
||||
"height": 25,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aA",
|
||||
"roundness": null,
|
||||
"seed": 452254927,
|
||||
"version": 10,
|
||||
"versionNonce": 972895023,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763704057762,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "chunk1",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "chunk1",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "knvlKpaFZ8lY-73Y-e9W6",
|
||||
"type": "text",
|
||||
"x": 569.11328125,
|
||||
"y": -116.91056665512056,
|
||||
"width": 67.55995178222656,
|
||||
"height": 25,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aB",
|
||||
"roundness": null,
|
||||
"seed": 914644015,
|
||||
"version": 90,
|
||||
"versionNonce": 158135631,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763704057762,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "chunk2",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "chunk2",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "Q7URqvTSMpvj08ye-afTT",
|
||||
"type": "rectangle",
|
||||
"x": 444.515625,
|
||||
"y": 36.7890625,
|
||||
"width": 58.859375,
|
||||
"height": 29.41796875,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aC",
|
||||
"roundness": {
|
||||
"type": 3
|
||||
},
|
||||
"seed": 1642537601,
|
||||
"version": 19,
|
||||
"versionNonce": 948406575,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763703870173,
|
||||
"link": null,
|
||||
"locked": false
|
||||
},
|
||||
{
|
||||
"id": "JjxBt9cZIZXNTd6CmwyKL",
|
||||
"type": "rectangle",
|
||||
"x": 452.203125,
|
||||
"y": 46.064453125,
|
||||
"width": 58.859375,
|
||||
"height": 29.41796875,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aD",
|
||||
"roundness": {
|
||||
"type": 3
|
||||
},
|
||||
"seed": 1746916641,
|
||||
"version": 40,
|
||||
"versionNonce": 1650978255,
|
||||
"isDeleted": false,
|
||||
"boundElements": [],
|
||||
"updated": 1763703871882,
|
||||
"link": null,
|
||||
"locked": false
|
||||
},
|
||||
{
|
||||
"id": "XGBCPPFnjriqsL8LvLwyQ",
|
||||
"type": "rectangle",
|
||||
"x": 461.56640625,
|
||||
"y": 56.162109375,
|
||||
"width": 58.859375,
|
||||
"height": 29.41796875,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aE",
|
||||
"roundness": {
|
||||
"type": 3
|
||||
},
|
||||
"seed": 529794575,
|
||||
"version": 85,
|
||||
"versionNonce": 2131900641,
|
||||
"isDeleted": false,
|
||||
"boundElements": [],
|
||||
"updated": 1763703874182,
|
||||
"link": null,
|
||||
"locked": false
|
||||
},
|
||||
{
|
||||
"id": "FfWdx1_yCq6UYfXamJX9N",
|
||||
"type": "arrow",
|
||||
"x": 537.6875,
|
||||
"y": 48.203125,
|
||||
"width": 6.615850226297994,
|
||||
"height": 75.81335873223107,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aF",
|
||||
"roundness": {
|
||||
"type": 2
|
||||
},
|
||||
"seed": 1982870689,
|
||||
"version": 90,
|
||||
"versionNonce": 25307457,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763704050188,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"points": [
|
||||
[
|
||||
0,
|
||||
0
|
||||
],
|
||||
[
|
||||
6.615850226297994,
|
||||
-75.81335873223107
|
||||
]
|
||||
],
|
||||
"lastCommittedPoint": null,
|
||||
"startBinding": null,
|
||||
"endBinding": {
|
||||
"elementId": "uTEFJs8cNS09WFq2pi9P7",
|
||||
"focus": 0.6071885090336794,
|
||||
"gap": 24.64453125
|
||||
},
|
||||
"startArrowhead": null,
|
||||
"endArrowhead": "arrow",
|
||||
"elbowed": false
|
||||
},
|
||||
{
|
||||
"id": "jgJgqGMRWcaNX_28wY4CU",
|
||||
"type": "text",
|
||||
"x": 570,
|
||||
"y": 10,
|
||||
"width": 67.11994934082031,
|
||||
"height": 25,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aG",
|
||||
"roundness": null,
|
||||
"seed": 1065220559,
|
||||
"version": 26,
|
||||
"versionNonce": 2115991521,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763703959397,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "update",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "update",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "_5pSPPOpp9h1TpFCIc055",
|
||||
"type": "text",
|
||||
"x": 292.36328125,
|
||||
"y": -138.5703125,
|
||||
"width": 122.87992858886719,
|
||||
"height": 25,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aH",
|
||||
"roundness": null,
|
||||
"seed": 51461025,
|
||||
"version": 26,
|
||||
"versionNonce": 1647492655,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763703925147,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "ADD Memory",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "ADD Memory",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "YG6MdL14l7lk4ypQNMZ_k",
|
||||
"type": "text",
|
||||
"x": 296.71885397566257,
|
||||
"y": 161.399157096715,
|
||||
"width": 295.27984619140625,
|
||||
"height": 25,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aJ",
|
||||
"roundness": null,
|
||||
"seed": 1183210273,
|
||||
"version": 122,
|
||||
"versionNonce": 1702733281,
|
||||
"isDeleted": false,
|
||||
"boundElements": [],
|
||||
"updated": 1763704085083,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "RETRIEVE Memory (STATIC)",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "RETRIEVE Memory (STATIC)",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "Foa3VPJYqhj1uAX5mn3n0",
|
||||
"type": "rectangle",
|
||||
"x": 324.7616636099071,
|
||||
"y": 248.63213980937013,
|
||||
"width": 135.7578125,
|
||||
"height": 128.73828125,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aL",
|
||||
"roundness": {
|
||||
"type": 3
|
||||
},
|
||||
"seed": 995116257,
|
||||
"version": 225,
|
||||
"versionNonce": 1886900225,
|
||||
"isDeleted": false,
|
||||
"boundElements": [],
|
||||
"updated": 1763704055846,
|
||||
"link": null,
|
||||
"locked": false
|
||||
},
|
||||
{
|
||||
"id": "pe3veI_yBFKYtbaJwDKQT",
|
||||
"type": "text",
|
||||
"x": 344.7510748617428,
|
||||
"y": 211.06686356120565,
|
||||
"width": 86.43992614746094,
|
||||
"height": 25,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aM",
|
||||
"roundness": null,
|
||||
"seed": 26673345,
|
||||
"version": 204,
|
||||
"versionNonce": 1004546017,
|
||||
"isDeleted": false,
|
||||
"boundElements": [],
|
||||
"updated": 1763704055846,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "Memories",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "Memories",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "bOlhO8AaKE86_43viu5UG",
|
||||
"type": "text",
|
||||
"x": 365.50408375566445,
|
||||
"y": 269.24725381983865,
|
||||
"width": 62.099945068359375,
|
||||
"height": 25,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aN",
|
||||
"roundness": null,
|
||||
"seed": 1849784033,
|
||||
"version": 106,
|
||||
"versionNonce": 762320737,
|
||||
"isDeleted": false,
|
||||
"boundElements": [],
|
||||
"updated": 1763704060295,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "chunk1",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "chunk1",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "V_iDW10PKwMe7vWb5S5HF",
|
||||
"type": "text",
|
||||
"x": 366.3074332757606,
|
||||
"y": 307.03138128152926,
|
||||
"width": 67.55995178222656,
|
||||
"height": 25,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aO",
|
||||
"roundness": null,
|
||||
"seed": 1670509249,
|
||||
"version": 186,
|
||||
"versionNonce": 1964540737,
|
||||
"isDeleted": false,
|
||||
"boundElements": [],
|
||||
"updated": 1763704060295,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "chunk2",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "chunk2",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "LHKMRdSowgcl2LsKacxTz",
|
||||
"type": "text",
|
||||
"x": 484.9493410573871,
|
||||
"y": 292.45619471187945,
|
||||
"width": 273.579833984375,
|
||||
"height": 50,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aP",
|
||||
"roundness": null,
|
||||
"seed": 945666991,
|
||||
"version": 104,
|
||||
"versionNonce": 1512137505,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763704096016,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "RANKED By DECAY SCORE,\nTOP K",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "RANKED By DECAY SCORE,\nTOP K",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
}
|
||||
],
|
||||
"appState": {
|
||||
"gridSize": 20,
|
||||
"gridStep": 5,
|
||||
"gridModeEnabled": false,
|
||||
"viewBackgroundColor": "#ffffff"
|
||||
},
|
||||
"files": {}
|
||||
}
|
||||
@@ -1,76 +0,0 @@
|
||||
## Decay Score
|
||||
|
||||
记忆衰减分数定义为:
|
||||
|
||||
\[
|
||||
\text{decay\_score}
|
||||
= \alpha \cdot e^{-\lambda \cdot \Delta t \cdot \beta}
|
||||
|
||||
+ (1-\alpha)\cdot (1 - e^{-\gamma \cdot c})
|
||||
\]
|
||||
|
||||
其中:
|
||||
|
||||
+ \(\Delta t\):自上次检索以来经过的时间(天),由 `last_retrieval_at` 计算;
|
||||
+ \(c\):检索次数,对应字段 `retrieval_count`;
|
||||
+ \(\alpha\):控制时间衰减和检索次数影响的权重;
|
||||
+ \(\gamma\):控制检索次数影响的速率;
|
||||
+ \(\lambda\):控制时间衰减的速率;
|
||||
+ \(\beta\):时间衰减调节因子;
|
||||
|
||||
\[
|
||||
\beta = \frac{1}{1 + a \cdot c}
|
||||
\]
|
||||
|
||||
+ \(a\):控制检索次数对时间衰减影响的权重。
|
||||
|
||||
## ADD MEMORY
|
||||
|
||||
+ LLM 通过 `astr_add_memory` 工具调用,传入记忆内容和记忆类型。
|
||||
+ 生成 `mem_id = uuid4()`。
|
||||
+ 从上下文中获取 `owner_id = unified_message_origin`。
|
||||
|
||||
步骤:
|
||||
|
||||
1. 使用 VecDB 以新记忆内容为 query,检索前 20 条相似记忆。
|
||||
2. 从中取相似度最高的前 5 条:
|
||||
+ 若相似度超过“合并阈值”(如 `sim >= merge_threshold`):
|
||||
+ 将该条记忆视为同一记忆,使用 LLM 将旧内容与新内容合并;
|
||||
+ 在同一个 `mem_id` 上更新 MemoryDB 和 VecDB(UPDATE,而非新建)。
|
||||
+ 否则:
|
||||
+ 作为全新的记忆插入:
|
||||
+ 写入 VecDB(metadata 中包含 `mem_id`, `owner_id`);
|
||||
+ 写入 MemoryDB 的 `memory_chunks` 表,初始化:
|
||||
+ `created_at = now`
|
||||
+ `last_retrieval_at = now`
|
||||
+ `retrieval_count = 1` 等。
|
||||
3. 对 VecDB 返回的前 20 条记忆,如果相似度高于某个“赫布阈值”(`hebb_threshold`),则:
|
||||
+ `retrieval_count += 1`
|
||||
+ `last_retrieval_at = now`
|
||||
|
||||
这一步体现了赫布学习:与新记忆共同被激活的旧记忆会获得一次强化。
|
||||
|
||||
## QUERY MEMORY (STATIC)
|
||||
|
||||
+ LLM 通过 `astr_query_memory` 工具调用,无参数。
|
||||
|
||||
步骤:
|
||||
|
||||
1. 从 MemoryDB 的 `memory_chunks` 表中查询当前用户所有活跃记忆:
|
||||
+ `SELECT * FROM memory_chunks WHERE owner_id = ? AND is_active = 1`
|
||||
2. 对每条记忆,根据 `last_retrieval_at` 和 `retrieval_count` 计算对应的 `decay_score`。
|
||||
3. 按 `decay_score` 从高到低排序,返回前 `top_k` 条记忆内容给 LLM。
|
||||
4. 对返回的这 `top_k` 条记忆:
|
||||
+ `retrieval_count += 1`
|
||||
+ `last_retrieval_at = now`
|
||||
|
||||
## QUERY MEMORY (DYNAMIC)(暂不实现)
|
||||
|
||||
+ LLM 提供查询内容作为语义 query。
|
||||
+ 使用 VecDB 检索与该 query 最相似的前 `N` 条记忆(`N > top_k`)。
|
||||
+ 根据 `mem_id` 从 `memory_chunks` 中加载对应记录。
|
||||
+ 对这批候选记忆计算:
|
||||
+ 语义相似度(来自 VecDB)
|
||||
+ `decay_score`
|
||||
+ 最终排序分数(例如 `w1 * sim + w2 * decay_score`)
|
||||
+ 按最终排序分数从高到低返回前 `top_k` 条记忆内容,并更新它们的 `retrieval_count` 和 `last_retrieval_at`。
|
||||
@@ -1,63 +0,0 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import numpy as np
|
||||
from sqlmodel import Field, MetaData, SQLModel
|
||||
|
||||
MEMORY_TYPE_IMPORTANCE = {"persona": 1.3, "fact": 1.0, "ephemeral": 0.8}
|
||||
|
||||
|
||||
class BaseMemoryModel(SQLModel, table=False):
|
||||
metadata = MetaData()
|
||||
|
||||
|
||||
class MemoryChunk(BaseMemoryModel, table=True):
|
||||
"""A chunk of memory stored in the system."""
|
||||
|
||||
__tablename__ = "memory_chunks" # type: ignore
|
||||
|
||||
id: int | None = Field(
|
||||
primary_key=True,
|
||||
sa_column_kwargs={"autoincrement": True},
|
||||
default=None,
|
||||
)
|
||||
mem_id: str = Field(
|
||||
max_length=36,
|
||||
nullable=False,
|
||||
unique=True,
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
index=True,
|
||||
)
|
||||
fact: str = Field(nullable=False)
|
||||
"""The factual content of the memory chunk."""
|
||||
owner_id: str = Field(max_length=255, nullable=False, index=True)
|
||||
"""The identifier of the owner (user) of the memory chunk."""
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
"""The timestamp when the memory chunk was created."""
|
||||
last_retrieval_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
"""The timestamp when the memory chunk was last retrieved."""
|
||||
retrieval_count: int = Field(default=1, nullable=False)
|
||||
"""The number of times the memory chunk has been retrieved."""
|
||||
memory_type: str = Field(max_length=20, nullable=False, default="fact")
|
||||
"""The type of memory (e.g., 'persona', 'fact', 'ephemeral')."""
|
||||
is_active: bool = Field(default=True, nullable=False)
|
||||
"""Whether the memory chunk is active."""
|
||||
|
||||
def compute_decay_score(self, current_time: datetime) -> float:
|
||||
"""Compute the decay score of the memory chunk based on time and retrievals."""
|
||||
# Constants for the decay formula
|
||||
alpha = 0.5
|
||||
gamma = 0.1
|
||||
lambda_ = 0.05
|
||||
a = 0.1
|
||||
|
||||
# Calculate delta_t in days
|
||||
delta_t = (current_time - self.last_retrieval_at).total_seconds() / 86400
|
||||
c = self.retrieval_count
|
||||
beta = 1 / (1 + a * c)
|
||||
decay_score = alpha * np.exp(-lambda_ * delta_t * beta) + (1 - alpha) * (
|
||||
1 - np.exp(-gamma * c)
|
||||
)
|
||||
return decay_score * MEMORY_TYPE_IMPORTANCE.get(self.memory_type, 1.0)
|
||||
@@ -1,174 +0,0 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import select, text, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlmodel import col
|
||||
|
||||
from astrbot.core import logger
|
||||
|
||||
from .entities import BaseMemoryModel, MemoryChunk
|
||||
|
||||
|
||||
class MemoryDatabase:
|
||||
def __init__(self, db_path: str = "data/astr_memory/memory.db") -> None:
|
||||
"""Initialize memory database
|
||||
|
||||
Args:
|
||||
db_path: Database file path, default is data/astr_memory/memory.db
|
||||
|
||||
"""
|
||||
self.db_path = db_path
|
||||
self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
|
||||
self.inited = False
|
||||
|
||||
# Ensure directory exists
|
||||
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create async engine
|
||||
self.engine = create_async_engine(
|
||||
self.DATABASE_URL,
|
||||
echo=False,
|
||||
pool_pre_ping=True,
|
||||
pool_recycle=3600,
|
||||
)
|
||||
|
||||
# Create session factory
|
||||
self.async_session = async_sessionmaker(
|
||||
self.engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_db(self):
|
||||
"""Get database session
|
||||
|
||||
Usage:
|
||||
async with mem_db.get_db() as session:
|
||||
# Perform database operations
|
||||
result = await session.execute(stmt)
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
yield session
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize database, create tables and configure SQLite parameters"""
|
||||
async with self.engine.begin() as conn:
|
||||
# Create all memory related tables
|
||||
await conn.run_sync(BaseMemoryModel.metadata.create_all)
|
||||
|
||||
# Configure SQLite performance optimization parameters
|
||||
await conn.execute(text("PRAGMA journal_mode=WAL"))
|
||||
await conn.execute(text("PRAGMA synchronous=NORMAL"))
|
||||
await conn.execute(text("PRAGMA cache_size=20000"))
|
||||
await conn.execute(text("PRAGMA temp_store=MEMORY"))
|
||||
await conn.execute(text("PRAGMA mmap_size=134217728"))
|
||||
await conn.execute(text("PRAGMA optimize"))
|
||||
await conn.commit()
|
||||
|
||||
await self._create_indexes()
|
||||
self.inited = True
|
||||
logger.info(f"Memory database initialized: {self.db_path}")
|
||||
|
||||
async def _create_indexes(self) -> None:
|
||||
"""Create indexes for memory_chunks table"""
|
||||
async with self.get_db() as session:
|
||||
async with session.begin():
|
||||
# Create memory chunks table indexes
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_mem_mem_id "
|
||||
"ON memory_chunks(mem_id)",
|
||||
),
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_mem_owner_id "
|
||||
"ON memory_chunks(owner_id)",
|
||||
),
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_mem_owner_active "
|
||||
"ON memory_chunks(owner_id, is_active)",
|
||||
),
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close database connection"""
|
||||
await self.engine.dispose()
|
||||
logger.info(f"Memory database closed: {self.db_path}")
|
||||
|
||||
async def insert_memory(self, memory: MemoryChunk) -> MemoryChunk:
|
||||
"""Insert a new memory chunk"""
|
||||
async with self.get_db() as session:
|
||||
session.add(memory)
|
||||
await session.commit()
|
||||
await session.refresh(memory)
|
||||
return memory
|
||||
|
||||
async def get_memory_by_id(self, mem_id: str) -> MemoryChunk | None:
|
||||
"""Get memory chunk by mem_id"""
|
||||
async with self.get_db() as session:
|
||||
stmt = select(MemoryChunk).where(col(MemoryChunk.mem_id) == mem_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def update_memory(self, memory: MemoryChunk) -> MemoryChunk:
|
||||
"""Update an existing memory chunk"""
|
||||
async with self.get_db() as session:
|
||||
session.add(memory)
|
||||
await session.commit()
|
||||
await session.refresh(memory)
|
||||
return memory
|
||||
|
||||
async def get_active_memories(self, owner_id: str) -> list[MemoryChunk]:
|
||||
"""Get all active memories for a user"""
|
||||
async with self.get_db() as session:
|
||||
stmt = select(MemoryChunk).where(
|
||||
col(MemoryChunk.owner_id) == owner_id,
|
||||
col(MemoryChunk.is_active) == True, # noqa: E712
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def update_retrieval_stats(
|
||||
self,
|
||||
mem_ids: list[str],
|
||||
current_time: datetime | None = None,
|
||||
) -> None:
|
||||
"""Update retrieval statistics for multiple memories"""
|
||||
if not mem_ids:
|
||||
return
|
||||
|
||||
if current_time is None:
|
||||
current_time = datetime.now(timezone.utc)
|
||||
|
||||
async with self.get_db() as session:
|
||||
async with session.begin():
|
||||
stmt = (
|
||||
update(MemoryChunk)
|
||||
.where(col(MemoryChunk.mem_id).in_(mem_ids))
|
||||
.values(
|
||||
retrieval_count=MemoryChunk.retrieval_count + 1,
|
||||
last_retrieval_at=current_time,
|
||||
)
|
||||
)
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
async def deactivate_memory(self, mem_id: str) -> bool:
|
||||
"""Deactivate a memory chunk"""
|
||||
async with self.get_db() as session:
|
||||
async with session.begin():
|
||||
stmt = (
|
||||
update(MemoryChunk)
|
||||
.where(col(MemoryChunk.mem_id) == mem_id)
|
||||
.values(is_active=False)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
await session.commit()
|
||||
return result.rowcount > 0 if result.rowcount else False # type: ignore
|
||||
@@ -1,281 +0,0 @@
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
|
||||
from astrbot.core.provider.provider import EmbeddingProvider
|
||||
from astrbot.core.provider.provider import Provider as LLMProvider
|
||||
|
||||
from .entities import MemoryChunk
|
||||
from .mem_db_sqlite import MemoryDatabase
|
||||
|
||||
MERGE_THRESHOLD = 0.85
|
||||
"""Similarity threshold for merging memories"""
|
||||
HEBB_THRESHOLD = 0.70
|
||||
"""Similarity threshold for Hebbian learning reinforcement"""
|
||||
MERGE_SYSTEM_PROMPT = """You are a memory consolidation assistant. Your task is to merge two related memory entries into a single, comprehensive memory.
|
||||
|
||||
Input format:
|
||||
- Old memory: [existing memory content]
|
||||
- New memory: [new memory content to be integrated]
|
||||
|
||||
Your output should be a single, concise memory that combines the essential information from both entries. Preserve specific details, update outdated information, and eliminate redundancy. Output only the merged memory content without any explanations or meta-commentary."""
|
||||
|
||||
|
||||
class MemoryManager:
|
||||
"""Manager for user long-term memory storage and retrieval"""
|
||||
|
||||
def __init__(self, memory_root_dir: str = "data/astr_memory"):
|
||||
self.memory_root_dir = Path(memory_root_dir)
|
||||
self.memory_root_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.mem_db: MemoryDatabase | None = None
|
||||
self.vec_db: FaissVecDB | None = None
|
||||
|
||||
self._initialized = False
|
||||
|
||||
async def initialize(
|
||||
self,
|
||||
embedding_provider: EmbeddingProvider,
|
||||
merge_llm_provider: LLMProvider,
|
||||
):
|
||||
"""Initialize memory database and vector database"""
|
||||
# Initialize MemoryDB
|
||||
db_path = self.memory_root_dir / "memory.db"
|
||||
self.mem_db = MemoryDatabase(db_path.as_posix())
|
||||
await self.mem_db.initialize()
|
||||
|
||||
self.embedding_provider = embedding_provider
|
||||
self.merge_llm_provider = merge_llm_provider
|
||||
|
||||
# Initialize VecDB
|
||||
doc_store_path = self.memory_root_dir / "doc.db"
|
||||
index_store_path = self.memory_root_dir / "index.faiss"
|
||||
self.vec_db = FaissVecDB(
|
||||
doc_store_path=doc_store_path.as_posix(),
|
||||
index_store_path=index_store_path.as_posix(),
|
||||
embedding_provider=self.embedding_provider,
|
||||
)
|
||||
await self.vec_db.initialize()
|
||||
|
||||
logger.info("Memory manager initialized")
|
||||
self._initialized = True
|
||||
|
||||
async def terminate(self):
|
||||
"""Close all database connections"""
|
||||
if self.vec_db:
|
||||
await self.vec_db.close()
|
||||
if self.mem_db:
|
||||
await self.mem_db.close()
|
||||
|
||||
async def add_memory(
|
||||
self,
|
||||
fact: str,
|
||||
owner_id: str,
|
||||
memory_type: str = "fact",
|
||||
) -> MemoryChunk:
|
||||
"""Add a new memory with similarity check and merge logic
|
||||
|
||||
Implements the ADD MEMORY workflow from _README.md:
|
||||
1. Search for similar memories using VecDB
|
||||
2. If similarity >= merge_threshold, merge with existing memory
|
||||
3. Otherwise, create new memory
|
||||
4. Apply Hebbian learning to similar memories (similarity >= hebb_threshold)
|
||||
|
||||
Args:
|
||||
fact: Memory content
|
||||
owner_id: User identifier
|
||||
memory_type: Memory type ('persona', 'fact', 'ephemeral')
|
||||
|
||||
Returns:
|
||||
The created or updated MemoryChunk
|
||||
|
||||
"""
|
||||
if not self.vec_db or not self.mem_db:
|
||||
raise RuntimeError("Memory manager not initialized")
|
||||
|
||||
current_time = datetime.now(timezone.utc)
|
||||
|
||||
# Step 1: Search for similar memories
|
||||
similar_results = await self.vec_db.retrieve(
|
||||
query=fact,
|
||||
k=20,
|
||||
fetch_k=50,
|
||||
metadata_filters={"owner_id": owner_id},
|
||||
)
|
||||
|
||||
# Step 2: Check if we should merge with existing memories (top 3 similar ones)
|
||||
merge_candidates = [
|
||||
r for r in similar_results[:3] if r.similarity >= MERGE_THRESHOLD
|
||||
]
|
||||
|
||||
if merge_candidates:
|
||||
# Get all candidate memories from database
|
||||
candidate_memories: list[tuple[str, MemoryChunk]] = []
|
||||
for candidate in merge_candidates:
|
||||
mem_id = json.loads(candidate.data["metadata"])["mem_id"]
|
||||
memory = await self.mem_db.get_memory_by_id(mem_id)
|
||||
if memory:
|
||||
candidate_memories.append((mem_id, memory))
|
||||
|
||||
if candidate_memories:
|
||||
# Use the most similar memory as the base
|
||||
base_mem_id, base_memory = candidate_memories[0]
|
||||
|
||||
# Collect all facts to merge (existing candidates + new fact)
|
||||
all_facts = [mem.fact for _, mem in candidate_memories] + [fact]
|
||||
merged_fact = await self._merge_multiple_memories(all_facts)
|
||||
|
||||
# Update the base memory
|
||||
base_memory.fact = merged_fact
|
||||
base_memory.last_retrieval_at = current_time
|
||||
base_memory.retrieval_count += 1
|
||||
updated_memory = await self.mem_db.update_memory(base_memory)
|
||||
|
||||
# Update VecDB for base memory
|
||||
await self.vec_db.delete(base_mem_id)
|
||||
await self.vec_db.insert(
|
||||
content=merged_fact,
|
||||
metadata={
|
||||
"mem_id": base_mem_id,
|
||||
"owner_id": owner_id,
|
||||
"memory_type": memory_type,
|
||||
},
|
||||
id=base_mem_id,
|
||||
)
|
||||
|
||||
# Deactivate and remove other merged memories
|
||||
for mem_id, _ in candidate_memories[1:]:
|
||||
await self.mem_db.deactivate_memory(mem_id)
|
||||
await self.vec_db.delete(mem_id)
|
||||
|
||||
logger.info(
|
||||
f"Merged {len(candidate_memories)} memories into {base_mem_id} for user {owner_id}"
|
||||
)
|
||||
return updated_memory
|
||||
|
||||
# Step 3: Create new memory
|
||||
mem_id = str(uuid.uuid4())
|
||||
new_memory = MemoryChunk(
|
||||
mem_id=mem_id,
|
||||
fact=fact,
|
||||
owner_id=owner_id,
|
||||
memory_type=memory_type,
|
||||
created_at=current_time,
|
||||
last_retrieval_at=current_time,
|
||||
retrieval_count=1,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
# Insert into MemoryDB
|
||||
created_memory = await self.mem_db.insert_memory(new_memory)
|
||||
|
||||
# Insert into VecDB
|
||||
await self.vec_db.insert(
|
||||
content=fact,
|
||||
metadata={
|
||||
"mem_id": mem_id,
|
||||
"owner_id": owner_id,
|
||||
"memory_type": memory_type,
|
||||
},
|
||||
id=mem_id,
|
||||
)
|
||||
|
||||
# Step 4: Apply Hebbian learning to similar memories
|
||||
hebb_mem_ids = [
|
||||
json.loads(r.data["metadata"])["mem_id"]
|
||||
for r in similar_results
|
||||
if r.similarity >= HEBB_THRESHOLD
|
||||
]
|
||||
if hebb_mem_ids:
|
||||
await self.mem_db.update_retrieval_stats(hebb_mem_ids, current_time)
|
||||
logger.debug(
|
||||
f"Applied Hebbian learning to {len(hebb_mem_ids)} memories for user {owner_id}",
|
||||
)
|
||||
|
||||
logger.info(f"Created new memory {mem_id} for user {owner_id}")
|
||||
return created_memory
|
||||
|
||||
async def query_memory(
|
||||
self,
|
||||
owner_id: str,
|
||||
top_k: int = 5,
|
||||
) -> list[MemoryChunk]:
|
||||
"""Query user's memories using static retrieval with decay score ranking
|
||||
|
||||
Implements the QUERY MEMORY (STATIC) workflow from _README.md:
|
||||
1. Get all active memories for user from MemoryDB
|
||||
2. Compute decay_score for each memory
|
||||
3. Sort by decay_score and return top_k
|
||||
4. Update retrieval statistics for returned memories
|
||||
|
||||
Args:
|
||||
owner_id: User identifier
|
||||
top_k: Number of memories to return
|
||||
|
||||
Returns:
|
||||
List of top_k MemoryChunk sorted by decay score
|
||||
"""
|
||||
if not self.mem_db:
|
||||
raise RuntimeError("Memory manager not initialized")
|
||||
|
||||
current_time = datetime.now(timezone.utc)
|
||||
|
||||
# Step 1: Get all active memories for user
|
||||
all_memories = await self.mem_db.get_active_memories(owner_id)
|
||||
|
||||
if not all_memories:
|
||||
return []
|
||||
|
||||
# Step 2-3: Compute decay scores and sort
|
||||
memories_with_scores = [
|
||||
(mem, mem.compute_decay_score(current_time)) for mem in all_memories
|
||||
]
|
||||
memories_with_scores.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
# Get top_k memories
|
||||
top_memories = [mem for mem, _ in memories_with_scores[:top_k]]
|
||||
|
||||
# Step 4: Update retrieval statistics
|
||||
mem_ids = [mem.mem_id for mem in top_memories]
|
||||
await self.mem_db.update_retrieval_stats(mem_ids, current_time)
|
||||
|
||||
logger.debug(f"Retrieved {len(top_memories)} memories for user {owner_id}")
|
||||
return top_memories
|
||||
|
||||
async def _merge_multiple_memories(self, facts: list[str]) -> str:
|
||||
"""Merge multiple memory facts using LLM in one call
|
||||
|
||||
Args:
|
||||
facts: List of memory facts to merge
|
||||
|
||||
Returns:
|
||||
Merged memory content
|
||||
"""
|
||||
if not self.merge_llm_provider:
|
||||
return " ".join(facts)
|
||||
|
||||
if len(facts) == 1:
|
||||
return facts[0]
|
||||
|
||||
try:
|
||||
# Format all facts as a numbered list
|
||||
facts_list = "\n".join(f"{i + 1}. {fact}" for i, fact in enumerate(facts))
|
||||
user_prompt = (
|
||||
f"Please merge the following {len(facts)} related memory entries "
|
||||
"into a single, comprehensive memory:"
|
||||
f"\n{facts_list}\n\nOutput only the merged memory content."
|
||||
)
|
||||
response = await self.merge_llm_provider.text_chat(
|
||||
prompt=user_prompt,
|
||||
system_prompt=MERGE_SYSTEM_PROMPT,
|
||||
)
|
||||
|
||||
merged_content = response.completion_text.strip()
|
||||
return merged_content if merged_content else " ".join(facts)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to merge memories with LLM: {e}, using fallback")
|
||||
return " ".join(facts)
|
||||
@@ -1,156 +0,0 @@
|
||||
from pydantic import Field
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from astrbot.core.agent.tool import FunctionTool, ToolExecResult
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext, ContextWrapper
|
||||
|
||||
|
||||
@dataclass
|
||||
class AddMemory(FunctionTool[AstrAgentContext]):
|
||||
"""Tool for adding memories to user's long-term memory storage"""
|
||||
|
||||
name: str = "astr_add_memory"
|
||||
description: str = (
|
||||
"Add a new memory to the user's long-term memory storage. "
|
||||
"Use this tool only when the user explicitly asks you to remember something, "
|
||||
"or when they share stable preferences, identity, or long-term goals that will be useful in future interactions."
|
||||
)
|
||||
parameters: dict = Field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"fact": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The concrete memory content to store, such as a user preference, "
|
||||
"identity detail, long-term goal, or stable profile fact."
|
||||
),
|
||||
},
|
||||
"memory_type": {
|
||||
"type": "string",
|
||||
"enum": ["persona", "fact", "ephemeral"],
|
||||
"description": (
|
||||
"The relative importance of this memory. "
|
||||
"Use 'persona' for core identity or highly impactful information, "
|
||||
"'fact' for normal long-term preferences, "
|
||||
"and 'ephemeral' for minor or tentative facts."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["fact", "memory_type"],
|
||||
}
|
||||
)
|
||||
|
||||
async def call(
|
||||
self, context: ContextWrapper[AstrAgentContext], **kwargs
|
||||
) -> ToolExecResult:
|
||||
"""Add a memory to long-term storage
|
||||
|
||||
Args:
|
||||
context: Agent context
|
||||
**kwargs: Must contain 'fact' and 'memory_type'
|
||||
|
||||
Returns:
|
||||
ToolExecResult with success message
|
||||
|
||||
"""
|
||||
mm = context.context.context.memory_manager
|
||||
fact = kwargs.get("fact")
|
||||
memory_type = kwargs.get("memory_type", "fact")
|
||||
|
||||
if not fact:
|
||||
return "Missing required parameter: fact"
|
||||
|
||||
try:
|
||||
# Get owner_id from context
|
||||
owner_id = context.context.event.unified_msg_origin
|
||||
|
||||
# Add memory using memory manager
|
||||
memory = await mm.add_memory(
|
||||
fact=fact,
|
||||
owner_id=owner_id,
|
||||
memory_type=memory_type,
|
||||
)
|
||||
|
||||
return f"Memory added successfully (ID: {memory.mem_id})"
|
||||
|
||||
except Exception as e:
|
||||
return f"Failed to add memory: {str(e)}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueryMemory(FunctionTool[AstrAgentContext]):
|
||||
"""Tool for querying user's long-term memories"""
|
||||
|
||||
name: str = "astr_query_memory"
|
||||
description: str = (
|
||||
"Query the user's long-term memory storage and return the most relevant memories. "
|
||||
"Use this tool when you need user-specific context, preferences, or past facts "
|
||||
"that are not explicitly present in the current conversation."
|
||||
)
|
||||
parameters: dict = Field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"top_k": {
|
||||
"type": "integer",
|
||||
"description": (
|
||||
"Maximum number of memories to retrieve after retention-based ranking. "
|
||||
"Typically between 3 and 10."
|
||||
),
|
||||
"default": 5,
|
||||
"minimum": 1,
|
||||
"maximum": 20,
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
}
|
||||
)
|
||||
|
||||
async def call(
|
||||
self, context: ContextWrapper[AstrAgentContext], **kwargs
|
||||
) -> ToolExecResult:
|
||||
"""Query memories from long-term storage
|
||||
|
||||
Args:
|
||||
context: Agent context
|
||||
**kwargs: Optional 'top_k' parameter
|
||||
|
||||
Returns:
|
||||
ToolExecResult with formatted memory list
|
||||
|
||||
"""
|
||||
mm = context.context.context.memory_manager
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
|
||||
try:
|
||||
# Get owner_id from context
|
||||
owner_id = context.context.event.unified_msg_origin
|
||||
|
||||
# Query memories using memory manager
|
||||
memories = await mm.query_memory(
|
||||
owner_id=owner_id,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
if not memories:
|
||||
return "No memories found for this user."
|
||||
|
||||
# Format memories for output
|
||||
formatted_memories = []
|
||||
for i, mem in enumerate(memories, 1):
|
||||
formatted_memories.append(
|
||||
f"{i}. [{mem.memory_type.upper()}] {mem.fact} "
|
||||
f"(retrieved {mem.retrieval_count} times, "
|
||||
f"last: {mem.last_retrieval_at.strftime('%Y-%m-%d')})"
|
||||
)
|
||||
|
||||
result_text = "Retrieved memories:\n" + "\n".join(formatted_memories)
|
||||
return result_text
|
||||
|
||||
except Exception as e:
|
||||
return f"Failed to query memories: {str(e)}"
|
||||
|
||||
|
||||
ADD_MEMORY_TOOL = AddMemory()
|
||||
QUERY_MEMORY_TOOL = QueryMemory()
|
||||
@@ -722,7 +722,12 @@ class File(BaseMessageComponent):
|
||||
"""下载文件"""
|
||||
download_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
os.makedirs(download_dir, exist_ok=True)
|
||||
file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
|
||||
if self.name:
|
||||
name, ext = os.path.splitext(self.name)
|
||||
filename = f"{name}_{uuid.uuid4().hex[:8]}{ext}"
|
||||
else:
|
||||
filename = f"{uuid.uuid4().hex}"
|
||||
file_path = os.path.join(download_dir, filename)
|
||||
await download_file(self.url, file_path)
|
||||
self.file_ = os.path.abspath(file_path)
|
||||
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
|
||||
from ...context import PipelineContext
|
||||
from ..stage import Stage
|
||||
from .agent_sub_stages.internal import InternalAgentSubStage
|
||||
from .agent_sub_stages.third_party import ThirdPartyAgentSubStage
|
||||
|
||||
|
||||
class AgentRequestSubStage(Stage):
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.ctx = ctx
|
||||
self.config = ctx.astrbot_config
|
||||
|
||||
self.bot_wake_prefixs: list[str] = self.config["wake_prefix"]
|
||||
self.prov_wake_prefix: str = self.config["provider_settings"]["wake_prefix"]
|
||||
for bwp in self.bot_wake_prefixs:
|
||||
if self.prov_wake_prefix.startswith(bwp):
|
||||
logger.info(
|
||||
f"识别 LLM 聊天额外唤醒前缀 {self.prov_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。",
|
||||
)
|
||||
self.prov_wake_prefix = self.prov_wake_prefix[len(bwp) :]
|
||||
|
||||
agent_runner_type = self.config["provider_settings"]["agent_runner_type"]
|
||||
if agent_runner_type == "local":
|
||||
self.agent_sub_stage = InternalAgentSubStage()
|
||||
else:
|
||||
self.agent_sub_stage = ThirdPartyAgentSubStage()
|
||||
await self.agent_sub_stage.initialize(ctx)
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> AsyncGenerator[None, None]:
|
||||
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
|
||||
logger.debug(
|
||||
"This pipeline does not enable AI capability, skip processing."
|
||||
)
|
||||
return
|
||||
|
||||
if not SessionServiceManager.should_process_llm_request(event):
|
||||
logger.debug(
|
||||
f"The session {event.unified_msg_origin} has disabled AI capability, skipping processing."
|
||||
)
|
||||
return
|
||||
|
||||
async for resp in self.agent_sub_stage.process(event, self.prov_wake_prefix):
|
||||
yield resp
|
||||
+74
-49
@@ -9,7 +9,7 @@ from astrbot.core import logger
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.conversation_mgr import Conversation
|
||||
from astrbot.core.message.components import Image
|
||||
from astrbot.core.message.components import File, Image, Reply
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
MessageEventResult,
|
||||
@@ -21,28 +21,25 @@ from astrbot.core.provider.entities import (
|
||||
LLMResponse,
|
||||
ProviderRequest,
|
||||
)
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
from astrbot.core.star.star_handler import EventType, star_map
|
||||
from astrbot.core.utils.file_extract import extract_file_moonshotai
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.utils.session_lock import session_lock_manager
|
||||
|
||||
from ....astr_agent_context import AgentContextWrapper
|
||||
from ....astr_agent_hooks import MAIN_AGENT_HOOKS
|
||||
from ....astr_agent_run_util import AgentRunner, run_agent
|
||||
from ....astr_agent_tool_exec import FunctionToolExecutor
|
||||
from ....memory.tools import ADD_MEMORY_TOOL, QUERY_MEMORY_TOOL
|
||||
from ...context import PipelineContext, call_event_hook
|
||||
from ..stage import Stage
|
||||
from ..utils import KNOWLEDGE_BASE_QUERY_TOOL, retrieve_knowledge_base
|
||||
from .....astr_agent_context import AgentContextWrapper
|
||||
from .....astr_agent_hooks import MAIN_AGENT_HOOKS
|
||||
from .....astr_agent_run_util import AgentRunner, run_agent
|
||||
from .....astr_agent_tool_exec import FunctionToolExecutor
|
||||
from ....context import PipelineContext, call_event_hook
|
||||
from ...stage import Stage
|
||||
from ...utils import KNOWLEDGE_BASE_QUERY_TOOL, retrieve_knowledge_base
|
||||
|
||||
|
||||
class LLMRequestSubStage(Stage):
|
||||
class InternalAgentSubStage(Stage):
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.ctx = ctx
|
||||
conf = ctx.astrbot_config
|
||||
settings = conf["provider_settings"]
|
||||
self.bot_wake_prefixs: list[str] = conf["wake_prefix"] # list
|
||||
self.provider_wake_prefix: str = settings["wake_prefix"] # str
|
||||
self.max_context_length = settings["max_context_length"] # int
|
||||
self.dequeue_context_length: int = min(
|
||||
max(1, settings["dequeue_context_length"]),
|
||||
@@ -60,12 +57,12 @@ class LLMRequestSubStage(Stage):
|
||||
self.show_reasoning = settings.get("display_reasoning_text", False)
|
||||
self.kb_agentic_mode: bool = conf.get("kb_agentic_mode", False)
|
||||
|
||||
for bwp in self.bot_wake_prefixs:
|
||||
if self.provider_wake_prefix.startswith(bwp):
|
||||
logger.info(
|
||||
f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。",
|
||||
)
|
||||
self.provider_wake_prefix = self.provider_wake_prefix[len(bwp) :]
|
||||
file_extract_conf: dict = settings.get("file_extract", {})
|
||||
self.file_extract_enabled: bool = file_extract_conf.get("enable", False)
|
||||
self.file_extract_prov: str = file_extract_conf.get("provider", "moonshotai")
|
||||
self.file_extract_msh_api_key: str = file_extract_conf.get(
|
||||
"moonshotai_api_key", ""
|
||||
)
|
||||
|
||||
self.conv_manager = ctx.plugin_manager.context.conversation_manager
|
||||
|
||||
@@ -125,14 +122,49 @@ class LLMRequestSubStage(Stage):
|
||||
req.func_tool = ToolSet()
|
||||
req.func_tool.add_tool(KNOWLEDGE_BASE_QUERY_TOOL)
|
||||
|
||||
async def _apply_memory(self, req: ProviderRequest):
|
||||
mm = self.ctx.plugin_manager.context.memory_manager
|
||||
if not mm or not mm._initialized:
|
||||
async def _apply_file_extract(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
):
|
||||
"""Apply file extract to the provider request"""
|
||||
file_paths = []
|
||||
file_names = []
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, File):
|
||||
file_paths.append(await comp.get_file())
|
||||
file_names.append(comp.name)
|
||||
elif isinstance(comp, Reply) and comp.chain:
|
||||
for reply_comp in comp.chain:
|
||||
if isinstance(reply_comp, File):
|
||||
file_paths.append(await reply_comp.get_file())
|
||||
file_names.append(reply_comp.name)
|
||||
if not file_paths:
|
||||
return
|
||||
if req.func_tool is None:
|
||||
req.func_tool = ToolSet()
|
||||
req.func_tool.add_tool(ADD_MEMORY_TOOL)
|
||||
req.func_tool.add_tool(QUERY_MEMORY_TOOL)
|
||||
if not req.prompt:
|
||||
req.prompt = "总结一下文件里面讲了什么?"
|
||||
if self.file_extract_prov == "moonshotai":
|
||||
if not self.file_extract_msh_api_key:
|
||||
logger.error("Moonshot AI API key for file extract is not set")
|
||||
return
|
||||
file_contents = await asyncio.gather(
|
||||
*[
|
||||
extract_file_moonshotai(file_path, self.file_extract_msh_api_key)
|
||||
for file_path in file_paths
|
||||
]
|
||||
)
|
||||
else:
|
||||
logger.error(f"Unsupported file extract provider: {self.file_extract_prov}")
|
||||
return
|
||||
|
||||
# add file extract results to contexts
|
||||
for file_content, file_name in zip(file_contents, file_names):
|
||||
req.contexts.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"File Extract Results of user uploaded files:\n{file_content}\nFile Name: {file_name or 'Unknown'}",
|
||||
},
|
||||
)
|
||||
|
||||
def _truncate_contexts(
|
||||
self,
|
||||
@@ -314,21 +346,10 @@ class LLMRequestSubStage(Stage):
|
||||
return fixed_messages
|
||||
|
||||
async def process(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
_nested: bool = False,
|
||||
) -> None | AsyncGenerator[None, None]:
|
||||
self, event: AstrMessageEvent, provider_wake_prefix: str
|
||||
) -> AsyncGenerator[None, None]:
|
||||
req: ProviderRequest | None = None
|
||||
|
||||
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
|
||||
logger.debug("未启用 LLM 能力,跳过处理。")
|
||||
return
|
||||
|
||||
# 检查会话级别的LLM启停状态
|
||||
if not SessionServiceManager.should_process_llm_request(event):
|
||||
logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM,跳过处理。")
|
||||
return
|
||||
|
||||
provider = self._select_provider(event)
|
||||
if provider is None:
|
||||
return
|
||||
@@ -358,12 +379,12 @@ class LLMRequestSubStage(Stage):
|
||||
req.image_urls = []
|
||||
if sel_model := event.get_extra("selected_model"):
|
||||
req.model = sel_model
|
||||
if self.provider_wake_prefix and not event.message_str.startswith(
|
||||
self.provider_wake_prefix
|
||||
if provider_wake_prefix and not event.message_str.startswith(
|
||||
provider_wake_prefix
|
||||
):
|
||||
return
|
||||
|
||||
req.prompt = event.message_str[len(self.provider_wake_prefix) :]
|
||||
req.prompt = event.message_str[len(provider_wake_prefix) :]
|
||||
# func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
|
||||
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
||||
for comp in event.message_obj.message:
|
||||
@@ -377,6 +398,17 @@ class LLMRequestSubStage(Stage):
|
||||
|
||||
event.set_extra("provider_request", req)
|
||||
|
||||
# fix contexts json str
|
||||
if isinstance(req.contexts, str):
|
||||
req.contexts = json.loads(req.contexts)
|
||||
|
||||
# apply file extract
|
||||
if self.file_extract_enabled:
|
||||
try:
|
||||
await self._apply_file_extract(event, req)
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred while applying file extract: {e}")
|
||||
|
||||
if not req.prompt and not req.image_urls:
|
||||
return
|
||||
|
||||
@@ -387,13 +419,6 @@ class LLMRequestSubStage(Stage):
|
||||
# apply knowledge base feature
|
||||
await self._apply_kb(event, req)
|
||||
|
||||
# apply memory feature
|
||||
await self._apply_memory(req)
|
||||
|
||||
# fix contexts json str
|
||||
if isinstance(req.contexts, str):
|
||||
req.contexts = json.loads(req.contexts)
|
||||
|
||||
# truncate contexts to fit max length
|
||||
if req.contexts:
|
||||
req.contexts = self._truncate_contexts(req.contexts)
|
||||
@@ -0,0 +1,205 @@
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from astrbot.core import astrbot_config, logger
|
||||
from astrbot.core.agent.runners.coze.coze_agent_runner import CozeAgentRunner
|
||||
from astrbot.core.agent.runners.dashscope.dashscope_agent_runner import (
|
||||
DashscopeAgentRunner,
|
||||
)
|
||||
from astrbot.core.agent.runners.dify.dify_agent_runner import DifyAgentRunner
|
||||
from astrbot.core.message.components import Image
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
MessageEventResult,
|
||||
ResultContentType,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.agent.runners.base import BaseAgentRunner
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.provider.entities import (
|
||||
ProviderRequest,
|
||||
)
|
||||
from astrbot.core.star.star_handler import EventType
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
|
||||
from .....astr_agent_context import AgentContextWrapper, AstrAgentContext
|
||||
from .....astr_agent_hooks import MAIN_AGENT_HOOKS
|
||||
from ....context import PipelineContext, call_event_hook
|
||||
from ...stage import Stage
|
||||
|
||||
AGENT_RUNNER_TYPE_KEY = {
|
||||
"dify": "dify_agent_runner_provider_id",
|
||||
"coze": "coze_agent_runner_provider_id",
|
||||
"dashscope": "dashscope_agent_runner_provider_id",
|
||||
}
|
||||
|
||||
|
||||
async def run_third_party_agent(
|
||||
runner: "BaseAgentRunner",
|
||||
stream_to_general: bool = False,
|
||||
) -> AsyncGenerator[MessageChain | None, None]:
|
||||
"""
|
||||
运行第三方 agent runner 并转换响应格式
|
||||
类似于 run_agent 函数,但专门处理第三方 agent runner
|
||||
"""
|
||||
try:
|
||||
async for resp in runner.step_until_done(max_step=30): # type: ignore[misc]
|
||||
if resp.type == "streaming_delta":
|
||||
if stream_to_general:
|
||||
continue
|
||||
yield resp.data["chain"]
|
||||
elif resp.type == "llm_result":
|
||||
if stream_to_general:
|
||||
yield resp.data["chain"]
|
||||
except Exception as e:
|
||||
logger.error(f"Third party agent runner error: {e}")
|
||||
err_msg = (
|
||||
f"\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n"
|
||||
f"错误信息: {e!s}\n\n请在平台日志查看和分享错误详情。\n"
|
||||
)
|
||||
yield MessageChain().message(err_msg)
|
||||
|
||||
|
||||
class ThirdPartyAgentSubStage(Stage):
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.ctx = ctx
|
||||
self.conf = ctx.astrbot_config
|
||||
self.runner_type = self.conf["provider_settings"]["agent_runner_type"]
|
||||
self.prov_id = self.conf["provider_settings"].get(
|
||||
AGENT_RUNNER_TYPE_KEY.get(self.runner_type, ""),
|
||||
"",
|
||||
)
|
||||
settings = ctx.astrbot_config["provider_settings"]
|
||||
self.streaming_response: bool = settings["streaming_response"]
|
||||
self.unsupported_streaming_strategy: str = settings[
|
||||
"unsupported_streaming_strategy"
|
||||
]
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent, provider_wake_prefix: str
|
||||
) -> AsyncGenerator[None, None]:
|
||||
req: ProviderRequest | None = None
|
||||
|
||||
if provider_wake_prefix and not event.message_str.startswith(
|
||||
provider_wake_prefix
|
||||
):
|
||||
return
|
||||
|
||||
self.prov_cfg: dict = next(
|
||||
(p for p in astrbot_config["provider"] if p["id"] == self.prov_id),
|
||||
{},
|
||||
)
|
||||
if not self.prov_id:
|
||||
logger.error("没有填写 Agent Runner 提供商 ID,请前往配置页面配置。")
|
||||
return
|
||||
if not self.prov_cfg:
|
||||
logger.error(
|
||||
f"Agent Runner 提供商 {self.prov_id} 配置不存在,请前往配置页面修改配置。"
|
||||
)
|
||||
return
|
||||
|
||||
# make provider request
|
||||
req = ProviderRequest()
|
||||
req.session_id = event.unified_msg_origin
|
||||
req.prompt = event.message_str[len(provider_wake_prefix) :]
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Image):
|
||||
image_path = await comp.convert_to_base64()
|
||||
req.image_urls.append(image_path)
|
||||
|
||||
if not req.prompt and not req.image_urls:
|
||||
return
|
||||
|
||||
# call event hook
|
||||
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
|
||||
return
|
||||
|
||||
if self.runner_type == "dify":
|
||||
runner = DifyAgentRunner[AstrAgentContext]()
|
||||
elif self.runner_type == "coze":
|
||||
runner = CozeAgentRunner[AstrAgentContext]()
|
||||
elif self.runner_type == "dashscope":
|
||||
runner = DashscopeAgentRunner[AstrAgentContext]()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported third party agent runner type: {self.runner_type}",
|
||||
)
|
||||
|
||||
astr_agent_ctx = AstrAgentContext(
|
||||
context=self.ctx.plugin_manager.context,
|
||||
event=event,
|
||||
)
|
||||
|
||||
streaming_response = self.streaming_response
|
||||
if (enable_streaming := event.get_extra("enable_streaming")) is not None:
|
||||
streaming_response = bool(enable_streaming)
|
||||
|
||||
stream_to_general = (
|
||||
self.unsupported_streaming_strategy == "turn_off"
|
||||
and not event.platform_meta.support_streaming_message
|
||||
)
|
||||
|
||||
await runner.reset(
|
||||
request=req,
|
||||
run_context=AgentContextWrapper(
|
||||
context=astr_agent_ctx,
|
||||
tool_call_timeout=60,
|
||||
),
|
||||
agent_hooks=MAIN_AGENT_HOOKS,
|
||||
provider_config=self.prov_cfg,
|
||||
streaming=streaming_response,
|
||||
)
|
||||
|
||||
if streaming_response and not stream_to_general:
|
||||
# 流式响应
|
||||
event.set_result(
|
||||
MessageEventResult()
|
||||
.set_result_content_type(ResultContentType.STREAMING_RESULT)
|
||||
.set_async_stream(
|
||||
run_third_party_agent(
|
||||
runner,
|
||||
stream_to_general=False,
|
||||
),
|
||||
),
|
||||
)
|
||||
yield
|
||||
if runner.done():
|
||||
final_resp = runner.get_final_llm_resp()
|
||||
if final_resp and final_resp.result_chain:
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=final_resp.result_chain.chain or [],
|
||||
result_content_type=ResultContentType.STREAMING_FINISH,
|
||||
),
|
||||
)
|
||||
else:
|
||||
# 非流式响应或转换为普通响应
|
||||
async for _ in run_third_party_agent(
|
||||
runner,
|
||||
stream_to_general=stream_to_general,
|
||||
):
|
||||
yield
|
||||
|
||||
final_resp = runner.get_final_llm_resp()
|
||||
|
||||
if not final_resp or not final_resp.result_chain:
|
||||
logger.warning("Agent Runner 未返回最终结果。")
|
||||
return
|
||||
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=final_resp.result_chain.chain or [],
|
||||
result_content_type=ResultContentType.LLM_RESULT,
|
||||
),
|
||||
)
|
||||
yield
|
||||
|
||||
asyncio.create_task(
|
||||
Metric.upload(
|
||||
llm_tick=1,
|
||||
model_name=self.runner_type,
|
||||
provider_type=self.runner_type,
|
||||
),
|
||||
)
|
||||
@@ -24,7 +24,7 @@ class StarRequestSubStage(Stage):
|
||||
async def process(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
) -> None | AsyncGenerator[None, None]:
|
||||
) -> AsyncGenerator[None, None]:
|
||||
activated_handlers: list[StarHandlerMetadata] = event.get_extra(
|
||||
"activated_handlers",
|
||||
)
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.provider.entities import ProviderRequest
|
||||
from astrbot.core.star.star_handler import StarHandlerMetadata
|
||||
|
||||
from ..context import PipelineContext
|
||||
from ..stage import Stage, register_stage
|
||||
from .method.llm_request import LLMRequestSubStage
|
||||
from .method.agent_request import AgentRequestSubStage
|
||||
from .method.star_request import StarRequestSubStage
|
||||
|
||||
|
||||
@@ -17,9 +16,12 @@ class ProcessStage(Stage):
|
||||
self.ctx = ctx
|
||||
self.config = ctx.astrbot_config
|
||||
self.plugin_manager = ctx.plugin_manager
|
||||
self.llm_request_sub_stage = LLMRequestSubStage()
|
||||
await self.llm_request_sub_stage.initialize(ctx)
|
||||
|
||||
# initialize agent sub stage
|
||||
self.agent_sub_stage = AgentRequestSubStage()
|
||||
await self.agent_sub_stage.initialize(ctx)
|
||||
|
||||
# initialize star request sub stage
|
||||
self.star_request_sub_stage = StarRequestSubStage()
|
||||
await self.star_request_sub_stage.initialize(ctx)
|
||||
|
||||
@@ -39,7 +41,7 @@ class ProcessStage(Stage):
|
||||
# Handler 的 LLM 请求
|
||||
event.set_extra("provider_request", resp)
|
||||
_t = False
|
||||
async for _ in self.llm_request_sub_stage.process(event):
|
||||
async for _ in self.agent_sub_stage.process(event):
|
||||
_t = True
|
||||
yield
|
||||
if not _t:
|
||||
@@ -60,12 +62,5 @@ class ProcessStage(Stage):
|
||||
if (
|
||||
event.get_result() and not event.get_result().is_stopped()
|
||||
) or not event.get_result():
|
||||
# 事件没有终止传播
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
||||
|
||||
if not provider:
|
||||
logger.info("未找到可用的 LLM 提供商,请先前往配置服务提供商。")
|
||||
return
|
||||
|
||||
async for _ in self.llm_request_sub_stage.process(event):
|
||||
async for _ in self.agent_sub_stage.process(event):
|
||||
yield
|
||||
|
||||
@@ -161,11 +161,21 @@ class ResultDecorateStage(Stage):
|
||||
# 不分段回复
|
||||
new_chain.append(comp)
|
||||
continue
|
||||
split_response = re.findall(
|
||||
self.regex,
|
||||
comp.text,
|
||||
re.DOTALL | re.MULTILINE,
|
||||
)
|
||||
try:
|
||||
split_response = re.findall(
|
||||
self.regex,
|
||||
comp.text,
|
||||
re.DOTALL | re.MULTILINE,
|
||||
)
|
||||
except re.error:
|
||||
logger.error(
|
||||
f"分段回复正则表达式错误,使用默认分段方式: {traceback.format_exc()}",
|
||||
)
|
||||
split_response = re.findall(
|
||||
r".*?[。?!~…]+|.+$",
|
||||
comp.text,
|
||||
re.DOTALL | re.MULTILINE,
|
||||
)
|
||||
if not split_response:
|
||||
new_chain.append(comp)
|
||||
continue
|
||||
|
||||
@@ -6,7 +6,7 @@ from astrbot.core import logger
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map
|
||||
|
||||
from .platform import Platform
|
||||
from .platform import Platform, PlatformStatus
|
||||
from .register import platform_cls_map
|
||||
from .sources.webchat.webchat_adapter import WebChatAdapter
|
||||
|
||||
@@ -16,7 +16,7 @@ class PlatformManager:
|
||||
self.platform_insts: list[Platform] = []
|
||||
"""加载的 Platform 的实例"""
|
||||
|
||||
self._inst_map = {}
|
||||
self._inst_map: dict[str, dict] = {}
|
||||
|
||||
self.platforms_config = config["platform"]
|
||||
self.settings = config["platform_settings"]
|
||||
@@ -37,7 +37,10 @@ class PlatformManager:
|
||||
webchat_inst = WebChatAdapter({}, self.settings, self.event_queue)
|
||||
self.platform_insts.append(webchat_inst)
|
||||
asyncio.create_task(
|
||||
self._task_wrapper(asyncio.create_task(webchat_inst.run(), name="webchat")),
|
||||
self._task_wrapper(
|
||||
asyncio.create_task(webchat_inst.run(), name="webchat"),
|
||||
platform=webchat_inst,
|
||||
),
|
||||
)
|
||||
|
||||
async def load_platform(self, platform_config: dict):
|
||||
@@ -107,7 +110,7 @@ class PlatformManager:
|
||||
)
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.error(
|
||||
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。",
|
||||
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->平台日志->安装Pip库 中安装依赖库。",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。")
|
||||
@@ -131,6 +134,7 @@ class PlatformManager:
|
||||
inst.run(),
|
||||
name=f"platform_{platform_config['type']}_{platform_config['id']}",
|
||||
),
|
||||
platform=inst,
|
||||
),
|
||||
)
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
@@ -145,17 +149,28 @@ class PlatformManager:
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def _task_wrapper(self, task: asyncio.Task):
|
||||
async def _task_wrapper(self, task: asyncio.Task, platform: Platform | None = None):
|
||||
# 设置平台状态为运行中
|
||||
if platform:
|
||||
platform.status = PlatformStatus.RUNNING
|
||||
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if platform:
|
||||
platform.status = PlatformStatus.STOPPED
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
tb_str = traceback.format_exc()
|
||||
logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}")
|
||||
for line in traceback.format_exc().split("\n"):
|
||||
for line in tb_str.split("\n"):
|
||||
logger.error(f"| {line}")
|
||||
logger.error("-------")
|
||||
|
||||
# 记录错误到平台实例
|
||||
if platform:
|
||||
platform.record_error(error_msg, tb_str)
|
||||
|
||||
async def reload(self, platform_config: dict):
|
||||
await self.terminate_platform(platform_config["id"])
|
||||
if platform_config["enable"]:
|
||||
@@ -172,9 +187,9 @@ class PlatformManager:
|
||||
logger.info(f"正在尝试终止 {platform_id} 平台适配器 ...")
|
||||
|
||||
# client_id = self._inst_map.pop(platform_id, None)
|
||||
info = self._inst_map.pop(platform_id, None)
|
||||
info = self._inst_map.pop(platform_id)
|
||||
client_id = info["client_id"]
|
||||
inst = info["inst"]
|
||||
inst: Platform = info["inst"]
|
||||
try:
|
||||
self.platform_insts.remove(
|
||||
next(
|
||||
@@ -196,3 +211,46 @@ class PlatformManager:
|
||||
|
||||
def get_insts(self):
|
||||
return self.platform_insts
|
||||
|
||||
def get_all_stats(self) -> dict:
|
||||
"""获取所有平台的统计信息
|
||||
|
||||
Returns:
|
||||
包含所有平台统计信息的字典
|
||||
"""
|
||||
stats_list = []
|
||||
total_errors = 0
|
||||
running_count = 0
|
||||
error_count = 0
|
||||
|
||||
for inst in self.platform_insts:
|
||||
try:
|
||||
stat = inst.get_stats()
|
||||
stats_list.append(stat)
|
||||
total_errors += stat.get("error_count", 0)
|
||||
if stat.get("status") == PlatformStatus.RUNNING.value:
|
||||
running_count += 1
|
||||
elif stat.get("status") == PlatformStatus.ERROR.value:
|
||||
error_count += 1
|
||||
except Exception as e:
|
||||
# 如果获取统计信息失败,记录基本信息
|
||||
logger.warning(f"获取平台统计信息失败: {e}")
|
||||
stats_list.append(
|
||||
{
|
||||
"id": getattr(inst, "config", {}).get("id", "unknown"),
|
||||
"type": "unknown",
|
||||
"status": "unknown",
|
||||
"error_count": 0,
|
||||
"last_error": None,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"platforms": stats_list,
|
||||
"summary": {
|
||||
"total": len(stats_list),
|
||||
"running": running_count,
|
||||
"error": error_count,
|
||||
"total_errors": total_errors,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -2,6 +2,9 @@ import abc
|
||||
import uuid
|
||||
from asyncio import Queue
|
||||
from collections.abc import Awaitable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
@@ -12,13 +15,90 @@ from .message_session import MessageSesion
|
||||
from .platform_metadata import PlatformMetadata
|
||||
|
||||
|
||||
class PlatformStatus(Enum):
|
||||
"""平台运行状态"""
|
||||
|
||||
PENDING = "pending" # 待启动
|
||||
RUNNING = "running" # 运行中
|
||||
ERROR = "error" # 发生错误
|
||||
STOPPED = "stopped" # 已停止
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlatformError:
|
||||
"""平台错误信息"""
|
||||
|
||||
message: str
|
||||
timestamp: datetime = field(default_factory=datetime.now)
|
||||
traceback: str | None = None
|
||||
|
||||
|
||||
class Platform(abc.ABC):
|
||||
def __init__(self, event_queue: Queue):
|
||||
def __init__(self, config: dict, event_queue: Queue):
|
||||
super().__init__()
|
||||
# 平台配置
|
||||
self.config = config
|
||||
# 维护了消息平台的事件队列,EventBus 会从这里取出事件并处理。
|
||||
self._event_queue = event_queue
|
||||
self.client_self_id = uuid.uuid4().hex
|
||||
|
||||
# 平台运行状态
|
||||
self._status: PlatformStatus = PlatformStatus.PENDING
|
||||
self._errors: list[PlatformError] = []
|
||||
self._started_at: datetime | None = None
|
||||
|
||||
@property
|
||||
def status(self) -> PlatformStatus:
|
||||
"""获取平台运行状态"""
|
||||
return self._status
|
||||
|
||||
@status.setter
|
||||
def status(self, value: PlatformStatus):
|
||||
"""设置平台运行状态"""
|
||||
self._status = value
|
||||
if value == PlatformStatus.RUNNING and self._started_at is None:
|
||||
self._started_at = datetime.now()
|
||||
|
||||
@property
|
||||
def errors(self) -> list[PlatformError]:
|
||||
"""获取错误列表"""
|
||||
return self._errors
|
||||
|
||||
@property
|
||||
def last_error(self) -> PlatformError | None:
|
||||
"""获取最近的错误"""
|
||||
return self._errors[-1] if self._errors else None
|
||||
|
||||
def record_error(self, message: str, traceback_str: str | None = None):
|
||||
"""记录一个错误"""
|
||||
self._errors.append(PlatformError(message=message, traceback=traceback_str))
|
||||
self._status = PlatformStatus.ERROR
|
||||
|
||||
def clear_errors(self):
|
||||
"""清除错误记录"""
|
||||
self._errors.clear()
|
||||
if self._status == PlatformStatus.ERROR:
|
||||
self._status = PlatformStatus.RUNNING
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""获取平台统计信息"""
|
||||
meta = self.meta()
|
||||
return {
|
||||
"id": meta.id or self.config.get("id"),
|
||||
"type": meta.name,
|
||||
"display_name": meta.adapter_display_name or meta.name,
|
||||
"status": self._status.value,
|
||||
"started_at": self._started_at.isoformat() if self._started_at else None,
|
||||
"error_count": len(self._errors),
|
||||
"last_error": {
|
||||
"message": self.last_error.message,
|
||||
"timestamp": self.last_error.timestamp.isoformat(),
|
||||
"traceback": self.last_error.traceback,
|
||||
}
|
||||
if self.last_error
|
||||
else None,
|
||||
}
|
||||
|
||||
@abc.abstractmethod
|
||||
def run(self) -> Awaitable[Any]:
|
||||
"""得到一个平台的运行实例,需要返回一个协程对象。"""
|
||||
@@ -36,7 +116,7 @@ class Platform(abc.ABC):
|
||||
self,
|
||||
session: MessageSesion,
|
||||
message_chain: MessageChain,
|
||||
) -> Awaitable[Any]:
|
||||
):
|
||||
"""通过会话发送消息。该方法旨在让插件能够直接通过**可持久化的会话数据**发送消息,而不需要保存 event 对象。
|
||||
|
||||
异步方法。
|
||||
@@ -49,3 +129,20 @@ class Platform(abc.ABC):
|
||||
|
||||
def get_client(self):
|
||||
"""获取平台的客户端对象。"""
|
||||
|
||||
async def webhook_callback(self, request: Any) -> Any:
|
||||
"""统一 Webhook 回调入口。
|
||||
|
||||
支持统一 Webhook 模式的平台需要实现此方法。
|
||||
当 Dashboard 收到 /api/platform/webhook/{uuid} 请求时,会调用此方法。
|
||||
|
||||
Args:
|
||||
request: Quart 请求对象
|
||||
|
||||
Returns:
|
||||
响应内容,格式取决于具体平台的要求
|
||||
|
||||
Raises:
|
||||
NotImplementedError: 平台未实现统一 Webhook 模式
|
||||
"""
|
||||
raise NotImplementedError(f"平台 {self.meta().name} 未实现统一 Webhook 模式")
|
||||
|
||||
@@ -38,9 +38,8 @@ class AiocqhttpAdapter(Platform):
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
super().__init__(platform_config, event_queue)
|
||||
|
||||
self.config = platform_config
|
||||
self.settings = platform_settings
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
self.host = platform_config["ws_reverse_host"]
|
||||
@@ -154,7 +153,9 @@ class AiocqhttpAdapter(Platform):
|
||||
"""OneBot V11 通知类事件"""
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = str(event.self_id)
|
||||
abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id)
|
||||
abm.sender = MessageMember(
|
||||
user_id=str(event.user_id), nickname=str(event.user_id)
|
||||
)
|
||||
abm.type = MessageType.OTHER_MESSAGE
|
||||
if event.get("group_id"):
|
||||
abm.group_id = str(event.group_id)
|
||||
@@ -246,7 +247,13 @@ class AiocqhttpAdapter(Platform):
|
||||
if m["data"].get("url") and m["data"].get("url").startswith("http"):
|
||||
# Lagrange
|
||||
logger.info("guessing lagrange")
|
||||
file_name = m["data"].get("file_name", "file")
|
||||
# 检查多个可能的文件名字段
|
||||
file_name = (
|
||||
m["data"].get("file_name", "")
|
||||
or m["data"].get("name", "")
|
||||
or m["data"].get("file", "")
|
||||
or "file"
|
||||
)
|
||||
abm.message.append(File(name=file_name, url=m["data"]["url"]))
|
||||
else:
|
||||
try:
|
||||
@@ -265,7 +272,14 @@ class AiocqhttpAdapter(Platform):
|
||||
)
|
||||
if ret and "url" in ret:
|
||||
file_url = ret["url"] # https
|
||||
a = File(name="", url=file_url)
|
||||
# 优先从 API 返回值获取文件名,其次从原始消息数据获取
|
||||
file_name = (
|
||||
ret.get("file_name", "")
|
||||
or ret.get("name", "")
|
||||
or m["data"].get("file", "")
|
||||
or m["data"].get("file_name", "")
|
||||
)
|
||||
a = File(name=file_name, url=file_url)
|
||||
abm.message.append(a)
|
||||
else:
|
||||
logger.error(f"获取文件失败: {ret}")
|
||||
|
||||
@@ -47,9 +47,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
|
||||
self.config = platform_config
|
||||
super().__init__(platform_config, event_queue)
|
||||
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
|
||||
@@ -76,13 +74,13 @@ class DingtalkPlatformAdapter(Platform):
|
||||
)
|
||||
self.client_ = client # 用于 websockets 的 client
|
||||
|
||||
def _id_to_sid(self, dingtalk_id: str | None) -> str | None:
|
||||
def _id_to_sid(self, dingtalk_id: str | None) -> str:
|
||||
if not dingtalk_id:
|
||||
return dingtalk_id
|
||||
return dingtalk_id or "unknown"
|
||||
prefix = "$:LWCP_v1:$"
|
||||
if dingtalk_id.startswith(prefix):
|
||||
return dingtalk_id[len(prefix) :]
|
||||
return dingtalk_id
|
||||
return dingtalk_id or "unknown"
|
||||
|
||||
async def send_by_session(
|
||||
self,
|
||||
@@ -250,7 +248,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
|
||||
async def terminate(self):
|
||||
def monkey_patch_close():
|
||||
raise Exception("Graceful shutdown")
|
||||
raise KeyboardInterrupt("Graceful shutdown")
|
||||
|
||||
self.client_.open_connection = monkey_patch_close
|
||||
await self.client_.websocket.close(code=1000, reason="Graceful shutdown")
|
||||
|
||||
@@ -44,8 +44,7 @@ class DiscordPlatformAdapter(Platform):
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
self.config = platform_config
|
||||
super().__init__(platform_config, event_queue)
|
||||
self.settings = platform_settings
|
||||
self.client_self_id = None
|
||||
self.registered_handlers = []
|
||||
|
||||
@@ -33,9 +33,7 @@ class LarkPlatformAdapter(Platform):
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
|
||||
self.config = platform_config
|
||||
super().__init__(platform_config, event_queue)
|
||||
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
|
||||
|
||||
@@ -55,8 +55,7 @@ class MisskeyPlatformAdapter(Platform):
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
self.config = platform_config or {}
|
||||
super().__init__(platform_config or {}, event_queue)
|
||||
self.settings = platform_settings or {}
|
||||
self.instance_url = self.config.get("misskey_instance_url", "")
|
||||
self.access_token = self.config.get("misskey_token", "")
|
||||
|
||||
@@ -69,6 +69,8 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
# 结束流式对话,并且传输 buffer 中剩余的消息
|
||||
stream_payload["state"] = 10
|
||||
ret = await self._post_send(stream=stream_payload)
|
||||
else:
|
||||
ret = await self._post_send()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送流式消息时出错: {e}", exc_info=True)
|
||||
|
||||
@@ -97,9 +97,7 @@ class QQOfficialPlatformAdapter(Platform):
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
|
||||
self.config = platform_config
|
||||
super().__init__(platform_config, event_queue)
|
||||
|
||||
self.appid = platform_config["appid"]
|
||||
self.secret = platform_config["secret"]
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import botpy
|
||||
import botpy.message
|
||||
@@ -11,6 +12,7 @@ from astrbot import logger
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, MessageType, Platform, PlatformMetadata
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.utils.webhook_utils import log_webhook_info
|
||||
|
||||
from ...register import register_platform_adapter
|
||||
from ..qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter
|
||||
@@ -87,13 +89,12 @@ class QQOfficialWebhookPlatformAdapter(Platform):
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
|
||||
self.config = platform_config
|
||||
super().__init__(platform_config, event_queue)
|
||||
|
||||
self.appid = platform_config["appid"]
|
||||
self.secret = platform_config["secret"]
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
self.unified_webhook_mode = platform_config.get("unified_webhook_mode", False)
|
||||
|
||||
intents = botpy.Intents(
|
||||
public_messages=True,
|
||||
@@ -106,6 +107,7 @@ class QQOfficialWebhookPlatformAdapter(Platform):
|
||||
timeout=20,
|
||||
)
|
||||
self.client.set_platform(self)
|
||||
self.webhook_helper = None
|
||||
|
||||
async def send_by_session(
|
||||
self,
|
||||
@@ -128,16 +130,37 @@ class QQOfficialWebhookPlatformAdapter(Platform):
|
||||
self.client,
|
||||
)
|
||||
await self.webhook_helper.initialize()
|
||||
await self.webhook_helper.start_polling()
|
||||
|
||||
# 如果启用统一 webhook 模式,则不启动独立服务器
|
||||
webhook_uuid = self.config.get("webhook_uuid")
|
||||
if self.unified_webhook_mode and webhook_uuid:
|
||||
log_webhook_info(f"{self.meta().id}(QQ 官方机器人 Webhook)", webhook_uuid)
|
||||
# 保持运行状态,等待 shutdown
|
||||
await self.webhook_helper.shutdown_event.wait()
|
||||
else:
|
||||
await self.webhook_helper.start_polling()
|
||||
|
||||
def get_client(self) -> botClient:
|
||||
return self.client
|
||||
|
||||
async def webhook_callback(self, request: Any) -> Any:
|
||||
"""统一 Webhook 回调入口"""
|
||||
if not self.webhook_helper:
|
||||
return {"error": "Webhook helper not initialized"}, 500
|
||||
|
||||
# 复用 webhook_helper 的回调处理逻辑
|
||||
return await self.webhook_helper.handle_callback(request)
|
||||
|
||||
async def terminate(self):
|
||||
self.webhook_helper.shutdown_event.set()
|
||||
if self.webhook_helper:
|
||||
self.webhook_helper.shutdown_event.set()
|
||||
await self.client.close()
|
||||
try:
|
||||
await self.webhook_helper.server.shutdown()
|
||||
except Exception as _:
|
||||
pass
|
||||
if self.webhook_helper and not self.unified_webhook_mode:
|
||||
try:
|
||||
await self.webhook_helper.server.shutdown()
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
f"Exception occurred during QQOfficialWebhook server shutdown: {exc}",
|
||||
exc_info=True,
|
||||
)
|
||||
logger.info("QQ 机器人官方 API 适配器已经被优雅地关闭")
|
||||
|
||||
@@ -78,7 +78,19 @@ class QQOfficialWebhook:
|
||||
return response
|
||||
|
||||
async def callback(self):
|
||||
msg: dict = await quart.request.json
|
||||
"""内部服务器的回调入口"""
|
||||
return await self.handle_callback(quart.request)
|
||||
|
||||
async def handle_callback(self, request) -> dict:
|
||||
"""处理 webhook 回调,可被统一 webhook 入口复用
|
||||
|
||||
Args:
|
||||
request: Quart 请求对象
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
msg: dict = await request.json
|
||||
logger.debug(f"收到 qq_official_webhook 回调: {msg}")
|
||||
|
||||
event = msg.get("t")
|
||||
|
||||
@@ -38,8 +38,7 @@ class SatoriPlatformAdapter(Platform):
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
self.config = platform_config
|
||||
super().__init__(platform_config, event_queue)
|
||||
self.settings = platform_settings
|
||||
|
||||
self.api_base_url = self.config.get(
|
||||
|
||||
@@ -47,51 +47,62 @@ class SlackWebhookClient:
|
||||
|
||||
@self.app.route(self.path, methods=["POST"])
|
||||
async def slack_events():
|
||||
"""处理 Slack 事件"""
|
||||
try:
|
||||
# 获取请求体和头部
|
||||
body = await request.get_data()
|
||||
event_data = json.loads(body.decode("utf-8"))
|
||||
|
||||
# Verify Slack request signature
|
||||
timestamp = request.headers.get("X-Slack-Request-Timestamp")
|
||||
signature = request.headers.get("X-Slack-Signature")
|
||||
if not timestamp or not signature:
|
||||
return Response("Missing headers", status=400)
|
||||
# Calculate the HMAC signature
|
||||
sig_basestring = f"v0:{timestamp}:{body.decode('utf-8')}"
|
||||
my_signature = (
|
||||
"v0="
|
||||
+ hmac.new(
|
||||
self.signing_secret.encode("utf-8"),
|
||||
sig_basestring.encode("utf-8"),
|
||||
hashlib.sha256,
|
||||
).hexdigest()
|
||||
)
|
||||
# Verify the signature
|
||||
if not hmac.compare_digest(my_signature, signature):
|
||||
logger.warning("Slack request signature verification failed")
|
||||
return Response("Invalid signature", status=400)
|
||||
logger.info(f"Received Slack event: {event_data}")
|
||||
|
||||
# 处理 URL 验证事件
|
||||
if event_data.get("type") == "url_verification":
|
||||
return {"challenge": event_data.get("challenge")}
|
||||
# 处理事件
|
||||
if self.event_handler and event_data.get("type") == "event_callback":
|
||||
await self.event_handler(event_data)
|
||||
|
||||
return Response("", status=200)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理 Slack 事件时出错: {e}")
|
||||
return Response("Internal Server Error", status=500)
|
||||
"""内部服务器的 POST 回调入口"""
|
||||
return await self.handle_callback(request)
|
||||
|
||||
@self.app.route("/health", methods=["GET"])
|
||||
async def health_check():
|
||||
"""健康检查端点"""
|
||||
return {"status": "ok", "service": "slack-webhook"}
|
||||
|
||||
async def handle_callback(self, req):
|
||||
"""处理 Slack 回调请求,可被统一 webhook 入口复用
|
||||
|
||||
Args:
|
||||
req: Quart 请求对象
|
||||
|
||||
Returns:
|
||||
Response 对象或字典
|
||||
"""
|
||||
try:
|
||||
# 获取请求体和头部
|
||||
body = await req.get_data()
|
||||
event_data = json.loads(body.decode("utf-8"))
|
||||
|
||||
# Verify Slack request signature
|
||||
timestamp = req.headers.get("X-Slack-Request-Timestamp")
|
||||
signature = req.headers.get("X-Slack-Signature")
|
||||
if not timestamp or not signature:
|
||||
return Response("Missing headers", status=400)
|
||||
# Calculate the HMAC signature
|
||||
sig_basestring = f"v0:{timestamp}:{body.decode('utf-8')}"
|
||||
my_signature = (
|
||||
"v0="
|
||||
+ hmac.new(
|
||||
self.signing_secret.encode("utf-8"),
|
||||
sig_basestring.encode("utf-8"),
|
||||
hashlib.sha256,
|
||||
).hexdigest()
|
||||
)
|
||||
# Verify the signature
|
||||
if not hmac.compare_digest(my_signature, signature):
|
||||
logger.warning("Slack request signature verification failed")
|
||||
return Response("Invalid signature", status=400)
|
||||
logger.info(f"Received Slack event: {event_data}")
|
||||
|
||||
# 处理 URL 验证事件
|
||||
if event_data.get("type") == "url_verification":
|
||||
return {"challenge": event_data.get("challenge")}
|
||||
# 处理事件
|
||||
if self.event_handler and event_data.get("type") == "event_callback":
|
||||
await self.event_handler(event_data)
|
||||
|
||||
return Response("", status=200)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理 Slack 事件时出错: {e}")
|
||||
return Response("Internal Server Error", status=500)
|
||||
|
||||
async def start(self):
|
||||
"""启动 Webhook 服务器"""
|
||||
logger.info(
|
||||
|
||||
@@ -21,6 +21,7 @@ from astrbot.api.platform import (
|
||||
PlatformMetadata,
|
||||
)
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.utils.webhook_utils import log_webhook_info
|
||||
|
||||
from ...register import register_platform_adapter
|
||||
from .client import SlackSocketClient, SlackWebhookClient
|
||||
@@ -39,9 +40,7 @@ class SlackAdapter(Platform):
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
|
||||
self.config = platform_config
|
||||
super().__init__(platform_config, event_queue)
|
||||
self.settings = platform_settings
|
||||
self.unique_session = platform_settings.get("unique_session", False)
|
||||
|
||||
@@ -49,6 +48,7 @@ class SlackAdapter(Platform):
|
||||
self.app_token = platform_config.get("app_token")
|
||||
self.signing_secret = platform_config.get("signing_secret")
|
||||
self.connection_mode = platform_config.get("slack_connection_mode", "socket")
|
||||
self.unified_webhook_mode = platform_config.get("unified_webhook_mode", False)
|
||||
self.webhook_host = platform_config.get("slack_webhook_host", "0.0.0.0")
|
||||
self.webhook_port = platform_config.get("slack_webhook_port", 3000)
|
||||
self.webhook_path = platform_config.get(
|
||||
@@ -361,10 +361,17 @@ class SlackAdapter(Platform):
|
||||
self._handle_webhook_event,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Slack 适配器 (Webhook Mode) 启动中,监听 {self.webhook_host}:{self.webhook_port}{self.webhook_path}...",
|
||||
)
|
||||
await self.webhook_client.start()
|
||||
# 如果启用统一 webhook 模式,则不启动独立服务器
|
||||
webhook_uuid = self.config.get("webhook_uuid")
|
||||
if self.unified_webhook_mode and webhook_uuid:
|
||||
log_webhook_info(f"{self.meta().id}(Slack)", webhook_uuid)
|
||||
# 保持运行状态,等待 shutdown
|
||||
await self.webhook_client.shutdown_event.wait()
|
||||
else:
|
||||
logger.info(
|
||||
f"Slack 适配器 (Webhook Mode) 启动中,监听 {self.webhook_host}:{self.webhook_port}{self.webhook_path}...",
|
||||
)
|
||||
await self.webhook_client.start()
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -391,6 +398,13 @@ class SlackAdapter(Platform):
|
||||
if abm:
|
||||
await self.handle_msg(abm)
|
||||
|
||||
async def webhook_callback(self, request: Any) -> Any:
|
||||
"""统一 Webhook 回调入口"""
|
||||
if self.connection_mode != "webhook" or not self.webhook_client:
|
||||
return {"error": "Slack adapter is not in webhook mode"}, 400
|
||||
|
||||
return await self.webhook_client.handle_callback(request)
|
||||
|
||||
async def terminate(self):
|
||||
if self.socket_client:
|
||||
await self.socket_client.stop()
|
||||
|
||||
@@ -31,7 +31,7 @@ class SlackMessageEvent(AstrMessageEvent):
|
||||
async def _from_segment_to_slack_block(
|
||||
segment: BaseMessageComponent,
|
||||
web_client: AsyncWebClient,
|
||||
) -> dict:
|
||||
) -> dict | None:
|
||||
"""将消息段转换为 Slack 块格式"""
|
||||
if isinstance(segment, Plain):
|
||||
return {"type": "section", "text": {"type": "mrkdwn", "text": segment.text}}
|
||||
@@ -85,7 +85,6 @@ class SlackMessageEvent(AstrMessageEvent):
|
||||
"text": f"文件: <{file_url}|{segment.name or '文件'}>",
|
||||
},
|
||||
}
|
||||
return {"type": "section", "text": {"type": "mrkdwn", "text": str(segment)}}
|
||||
|
||||
@staticmethod
|
||||
async def _parse_slack_blocks(
|
||||
@@ -115,7 +114,8 @@ class SlackMessageEvent(AstrMessageEvent):
|
||||
segment,
|
||||
web_client,
|
||||
)
|
||||
blocks.append(block)
|
||||
if block:
|
||||
blocks.append(block)
|
||||
|
||||
# 如果最后还有文本内容
|
||||
if text_content.strip():
|
||||
|
||||
@@ -42,8 +42,7 @@ class TelegramPlatformAdapter(Platform):
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
self.config = platform_config
|
||||
super().__init__(platform_config, event_queue)
|
||||
self.settings = platform_settings
|
||||
self.client_self_id = uuid.uuid4().hex[:8]
|
||||
|
||||
@@ -381,7 +380,9 @@ class TelegramPlatformAdapter(Platform):
|
||||
f"Telegram document file_path is None, cannot save the file {file_name}.",
|
||||
)
|
||||
else:
|
||||
message.message.append(Comp.File(file=file_path, name=file_name))
|
||||
message.message.append(
|
||||
Comp.File(file=file_path, name=file_name, url=file_path)
|
||||
)
|
||||
|
||||
elif update.message.video:
|
||||
file = await update.message.video.get_file()
|
||||
|
||||
@@ -6,7 +6,9 @@ from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.message.components import Image, Plain, Record
|
||||
from astrbot.core import db_helper
|
||||
from astrbot.core.db.po import PlatformMessageHistory
|
||||
from astrbot.core.message.components import File, Image, Plain, Record, Reply, Video
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.platform import (
|
||||
AstrBotMessage,
|
||||
@@ -74,9 +76,8 @@ class WebChatAdapter(Platform):
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
super().__init__(platform_config, event_queue)
|
||||
|
||||
self.config = platform_config
|
||||
self.settings = platform_settings
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
|
||||
@@ -96,6 +97,92 @@ class WebChatAdapter(Platform):
|
||||
await WebChatMessageEvent._send(message_chain, session.session_id)
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
async def _get_message_history(
|
||||
self, message_id: int
|
||||
) -> PlatformMessageHistory | None:
|
||||
return await db_helper.get_platform_message_history_by_id(message_id)
|
||||
|
||||
async def _parse_message_parts(
|
||||
self,
|
||||
message_parts: list,
|
||||
depth: int = 0,
|
||||
max_depth: int = 1,
|
||||
) -> tuple[list, list[str]]:
|
||||
"""解析消息段列表,返回消息组件列表和纯文本列表
|
||||
|
||||
Args:
|
||||
message_parts: 消息段列表
|
||||
depth: 当前递归深度
|
||||
max_depth: 最大递归深度(用于处理 reply)
|
||||
|
||||
Returns:
|
||||
tuple[list, list[str]]: (消息组件列表, 纯文本列表)
|
||||
"""
|
||||
components = []
|
||||
text_parts = []
|
||||
|
||||
for part in message_parts:
|
||||
part_type = part.get("type")
|
||||
if part_type == "plain":
|
||||
text = part.get("text", "")
|
||||
components.append(Plain(text))
|
||||
text_parts.append(text)
|
||||
elif part_type == "reply":
|
||||
message_id = part.get("message_id")
|
||||
reply_chain = []
|
||||
reply_message_str = ""
|
||||
sender_id = None
|
||||
sender_name = None
|
||||
|
||||
# recursively get the content of the referenced message
|
||||
if depth < max_depth and message_id:
|
||||
history = await self._get_message_history(message_id)
|
||||
if history and history.content:
|
||||
reply_parts = history.content.get("message", [])
|
||||
if isinstance(reply_parts, list):
|
||||
(
|
||||
reply_chain,
|
||||
reply_text_parts,
|
||||
) = await self._parse_message_parts(
|
||||
reply_parts,
|
||||
depth=depth + 1,
|
||||
max_depth=max_depth,
|
||||
)
|
||||
reply_message_str = "".join(reply_text_parts)
|
||||
sender_id = history.sender_id
|
||||
sender_name = history.sender_name
|
||||
|
||||
components.append(
|
||||
Reply(
|
||||
id=message_id,
|
||||
chain=reply_chain,
|
||||
message_str=reply_message_str,
|
||||
sender_id=sender_id,
|
||||
sender_nickname=sender_name,
|
||||
)
|
||||
)
|
||||
elif part_type == "image":
|
||||
path = part.get("path")
|
||||
if path:
|
||||
components.append(Image.fromFileSystem(path))
|
||||
elif part_type == "record":
|
||||
path = part.get("path")
|
||||
if path:
|
||||
components.append(Record.fromFileSystem(path))
|
||||
elif part_type == "file":
|
||||
path = part.get("path")
|
||||
if path:
|
||||
filename = part.get("filename") or (
|
||||
os.path.basename(path) if path else "file"
|
||||
)
|
||||
components.append(File(name=filename, file=path))
|
||||
elif part_type == "video":
|
||||
path = part.get("path")
|
||||
if path:
|
||||
components.append(Video.fromFileSystem(path))
|
||||
|
||||
return components, text_parts
|
||||
|
||||
async def convert_message(self, data: tuple) -> AstrBotMessage:
|
||||
username, cid, payload = data
|
||||
|
||||
@@ -108,36 +195,15 @@ class WebChatAdapter(Platform):
|
||||
abm.session_id = f"webchat!{username}!{cid}"
|
||||
|
||||
abm.message_id = str(uuid.uuid4())
|
||||
abm.message = []
|
||||
|
||||
if payload["message"]:
|
||||
abm.message.append(Plain(payload["message"]))
|
||||
if payload["image_url"]:
|
||||
if isinstance(payload["image_url"], list):
|
||||
for img in payload["image_url"]:
|
||||
abm.message.append(
|
||||
Image.fromFileSystem(os.path.join(self.imgs_dir, img)),
|
||||
)
|
||||
else:
|
||||
abm.message.append(
|
||||
Image.fromFileSystem(
|
||||
os.path.join(self.imgs_dir, payload["image_url"]),
|
||||
),
|
||||
)
|
||||
if payload["audio_url"]:
|
||||
if isinstance(payload["audio_url"], list):
|
||||
for audio in payload["audio_url"]:
|
||||
path = os.path.join(self.imgs_dir, audio)
|
||||
abm.message.append(Record(file=path, path=path))
|
||||
else:
|
||||
path = os.path.join(self.imgs_dir, payload["audio_url"])
|
||||
abm.message.append(Record(file=path, path=path))
|
||||
# 处理消息段列表
|
||||
message_parts = payload.get("message", [])
|
||||
abm.message, message_str_parts = await self._parse_message_parts(message_parts)
|
||||
|
||||
logger.debug(f"WebChatAdapter: {abm.message}")
|
||||
|
||||
message_str = payload["message"]
|
||||
abm.timestamp = int(time.time())
|
||||
abm.message_str = message_str
|
||||
abm.message_str = "".join(message_str_parts)
|
||||
abm.raw_message = data
|
||||
return abm
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import base64
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import Image, Plain, Record
|
||||
from astrbot.api.message_components import File, Image, Plain, Record
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
|
||||
from .webchat_queue_mgr import webchat_queue_mgr
|
||||
|
||||
@@ -19,7 +19,9 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
os.makedirs(imgs_dir, exist_ok=True)
|
||||
|
||||
@staticmethod
|
||||
async def _send(message: MessageChain, session_id: str, streaming: bool = False):
|
||||
async def _send(
|
||||
message: MessageChain | None, session_id: str, streaming: bool = False
|
||||
) -> str | None:
|
||||
cid = session_id.split("!")[-1]
|
||||
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
|
||||
if not message:
|
||||
@@ -30,7 +32,7 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
"streaming": False,
|
||||
}, # end means this request is finished
|
||||
)
|
||||
return ""
|
||||
return
|
||||
|
||||
data = ""
|
||||
for comp in message.chain:
|
||||
@@ -47,24 +49,11 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
elif isinstance(comp, Image):
|
||||
# save image to local
|
||||
filename = str(uuid.uuid4()) + ".jpg"
|
||||
filename = f"{str(uuid.uuid4())}.jpg"
|
||||
path = os.path.join(imgs_dir, filename)
|
||||
if comp.file and comp.file.startswith("file:///"):
|
||||
ph = comp.file[8:]
|
||||
with open(path, "wb") as f:
|
||||
with open(ph, "rb") as f2:
|
||||
f.write(f2.read())
|
||||
elif comp.file.startswith("base64://"):
|
||||
base64_str = comp.file[9:]
|
||||
image_data = base64.b64decode(base64_str)
|
||||
with open(path, "wb") as f:
|
||||
f.write(image_data)
|
||||
elif comp.file and comp.file.startswith("http"):
|
||||
await download_image_by_url(comp.file, path=path)
|
||||
else:
|
||||
with open(path, "wb") as f:
|
||||
with open(comp.file, "rb") as f2:
|
||||
f.write(f2.read())
|
||||
image_base64 = await comp.convert_to_base64()
|
||||
with open(path, "wb") as f:
|
||||
f.write(base64.b64decode(image_base64))
|
||||
data = f"[IMAGE]{filename}"
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
@@ -76,19 +65,11 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
elif isinstance(comp, Record):
|
||||
# save record to local
|
||||
filename = str(uuid.uuid4()) + ".wav"
|
||||
filename = f"{str(uuid.uuid4())}.wav"
|
||||
path = os.path.join(imgs_dir, filename)
|
||||
if comp.file and comp.file.startswith("file:///"):
|
||||
ph = comp.file[8:]
|
||||
with open(path, "wb") as f:
|
||||
with open(ph, "rb") as f2:
|
||||
f.write(f2.read())
|
||||
elif comp.file and comp.file.startswith("http"):
|
||||
await download_image_by_url(comp.file, path=path)
|
||||
else:
|
||||
with open(path, "wb") as f:
|
||||
with open(comp.file, "rb") as f2:
|
||||
f.write(f2.read())
|
||||
record_base64 = await comp.convert_to_base64()
|
||||
with open(path, "wb") as f:
|
||||
f.write(base64.b64decode(record_base64))
|
||||
data = f"[RECORD]{filename}"
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
@@ -98,6 +79,23 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
"streaming": streaming,
|
||||
},
|
||||
)
|
||||
elif isinstance(comp, File):
|
||||
# save file to local
|
||||
file_path = await comp.get_file()
|
||||
original_name = comp.name or os.path.basename(file_path)
|
||||
ext = os.path.splitext(original_name)[1] or ""
|
||||
filename = f"{uuid.uuid4()!s}{ext}"
|
||||
dest_path = os.path.join(imgs_dir, filename)
|
||||
shutil.copy2(file_path, dest_path)
|
||||
data = f"[FILE]{filename}|{original_name}"
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "file",
|
||||
"cid": cid,
|
||||
"data": data,
|
||||
"streaming": streaming,
|
||||
},
|
||||
)
|
||||
else:
|
||||
logger.debug(f"webchat 忽略: {comp.type}")
|
||||
|
||||
@@ -131,6 +129,8 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
session_id=self.session_id,
|
||||
streaming=True,
|
||||
)
|
||||
if not r:
|
||||
continue
|
||||
if chain.type == "reasoning":
|
||||
reasoning_content += chain.get_plain_text()
|
||||
else:
|
||||
|
||||
@@ -42,10 +42,9 @@ class WeChatPadProAdapter(Platform):
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
super().__init__(platform_config, event_queue)
|
||||
self._shutdown_event = None
|
||||
self.wxnewpass = None
|
||||
self.config = platform_config
|
||||
self.settings = platform_settings
|
||||
self.unique_session = platform_settings.get("unique_session", False)
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import asyncio
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import quart
|
||||
from requests import Response
|
||||
@@ -24,6 +25,7 @@ from astrbot.api.platform import (
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.webhook_utils import log_webhook_info
|
||||
|
||||
from .wecom_event import WecomPlatformEvent
|
||||
from .wecom_kf import WeChatKF
|
||||
@@ -62,8 +64,20 @@ class WecomServer:
|
||||
self.shutdown_event = asyncio.Event()
|
||||
|
||||
async def verify(self):
|
||||
logger.info(f"验证请求有效性: {quart.request.args}")
|
||||
args = quart.request.args
|
||||
"""内部服务器的 GET 验证入口"""
|
||||
return await self.handle_verify(quart.request)
|
||||
|
||||
async def handle_verify(self, request) -> str:
|
||||
"""处理验证请求,可被统一 webhook 入口复用
|
||||
|
||||
Args:
|
||||
request: Quart 请求对象
|
||||
|
||||
Returns:
|
||||
验证响应
|
||||
"""
|
||||
logger.info(f"验证请求有效性: {request.args}")
|
||||
args = request.args
|
||||
try:
|
||||
echo_str = self.crypto.check_signature(
|
||||
args.get("msg_signature"),
|
||||
@@ -78,10 +92,22 @@ class WecomServer:
|
||||
raise
|
||||
|
||||
async def callback_command(self):
|
||||
data = await quart.request.get_data()
|
||||
msg_signature = quart.request.args.get("msg_signature")
|
||||
timestamp = quart.request.args.get("timestamp")
|
||||
nonce = quart.request.args.get("nonce")
|
||||
"""内部服务器的 POST 回调入口"""
|
||||
return await self.handle_callback(quart.request)
|
||||
|
||||
async def handle_callback(self, request) -> str:
|
||||
"""处理回调请求,可被统一 webhook 入口复用
|
||||
|
||||
Args:
|
||||
request: Quart 请求对象
|
||||
|
||||
Returns:
|
||||
响应内容
|
||||
"""
|
||||
data = await request.get_data()
|
||||
msg_signature = request.args.get("msg_signature")
|
||||
timestamp = request.args.get("timestamp")
|
||||
nonce = request.args.get("nonce")
|
||||
try:
|
||||
xml = self.crypto.decrypt_message(data, msg_signature, timestamp, nonce)
|
||||
except InvalidSignatureException:
|
||||
@@ -118,14 +144,14 @@ class WecomPlatformAdapter(Platform):
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
self.config = platform_config
|
||||
super().__init__(platform_config, event_queue)
|
||||
self.settingss = platform_settings
|
||||
self.client_self_id = uuid.uuid4().hex[:8]
|
||||
self.api_base_url = platform_config.get(
|
||||
"api_base_url",
|
||||
"https://qyapi.weixin.qq.com/cgi-bin/",
|
||||
)
|
||||
self.unified_webhook_mode = platform_config.get("unified_webhook_mode", False)
|
||||
|
||||
if not self.api_base_url:
|
||||
self.api_base_url = "https://qyapi.weixin.qq.com/cgi-bin/"
|
||||
@@ -232,7 +258,23 @@ class WecomPlatformAdapter(Platform):
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
await self.server.start_polling()
|
||||
|
||||
# 如果启用统一 webhook 模式,则不启动独立服务器
|
||||
webhook_uuid = self.config.get("webhook_uuid")
|
||||
if self.unified_webhook_mode and webhook_uuid:
|
||||
log_webhook_info(f"{self.meta().id}(企业微信)", webhook_uuid)
|
||||
# 保持运行状态,等待 shutdown
|
||||
await self.server.shutdown_event.wait()
|
||||
else:
|
||||
await self.server.start_polling()
|
||||
|
||||
async def webhook_callback(self, request: Any) -> Any:
|
||||
"""统一 Webhook 回调入口"""
|
||||
# 根据请求方法分发到不同的处理函数
|
||||
if request.method == "GET":
|
||||
return await self.server.handle_verify(request)
|
||||
else:
|
||||
return await self.server.handle_callback(request)
|
||||
|
||||
async def convert_message(self, msg: BaseMessage) -> AstrBotMessage | None:
|
||||
abm = AstrBotMessage()
|
||||
|
||||
@@ -16,7 +16,7 @@ try:
|
||||
import pydub
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"检测到 pydub 库未安装,企业微信将无法语音收发。如需使用语音,请前往管理面板 -> 控制台 -> 安装 Pip 库安装 pydub。",
|
||||
"检测到 pydub 库未安装,企业微信将无法语音收发。如需使用语音,请前往管理面板 -> 平台日志 -> 安装 Pip 库安装 pydub。",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ from astrbot.api.platform import (
|
||||
PlatformMetadata,
|
||||
)
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.utils.webhook_utils import log_webhook_info
|
||||
|
||||
from ...register import register_platform_adapter
|
||||
from .wecomai_api import (
|
||||
@@ -103,9 +104,7 @@ class WecomAIBotAdapter(Platform):
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
|
||||
self.config = platform_config
|
||||
super().__init__(platform_config, event_queue)
|
||||
self.settings = platform_settings
|
||||
|
||||
# 初始化配置参数
|
||||
@@ -122,6 +121,7 @@ class WecomAIBotAdapter(Platform):
|
||||
"wecomaibot_friend_message_welcome_text",
|
||||
"",
|
||||
)
|
||||
self.unified_webhook_mode = self.config.get("unified_webhook_mode", False)
|
||||
|
||||
# 平台元数据
|
||||
self.metadata = PlatformMetadata(
|
||||
@@ -425,17 +425,34 @@ class WecomAIBotAdapter(Platform):
|
||||
|
||||
def run(self) -> Awaitable[Any]:
|
||||
"""运行适配器,同时启动HTTP服务器和队列监听器"""
|
||||
logger.info("启动企业微信智能机器人适配器,监听 %s:%d", self.host, self.port)
|
||||
|
||||
async def run_both():
|
||||
# 同时运行HTTP服务器和队列监听器
|
||||
await asyncio.gather(
|
||||
self.server.start_server(),
|
||||
self.queue_listener.run(),
|
||||
)
|
||||
# 如果启用统一 webhook 模式,则不启动独立服务器
|
||||
webhook_uuid = self.config.get("webhook_uuid")
|
||||
if self.unified_webhook_mode and webhook_uuid:
|
||||
log_webhook_info(f"{self.meta().id}(企业微信智能机器人)", webhook_uuid)
|
||||
# 只运行队列监听器
|
||||
await self.queue_listener.run()
|
||||
else:
|
||||
logger.info(
|
||||
"启动企业微信智能机器人适配器,监听 %s:%d", self.host, self.port
|
||||
)
|
||||
# 同时运行HTTP服务器和队列监听器
|
||||
await asyncio.gather(
|
||||
self.server.start_server(),
|
||||
self.queue_listener.run(),
|
||||
)
|
||||
|
||||
return run_both()
|
||||
|
||||
async def webhook_callback(self, request: Any) -> Any:
|
||||
"""统一 Webhook 回调入口"""
|
||||
# 根据请求方法分发到不同的处理函数
|
||||
if request.method == "GET":
|
||||
return await self.server.handle_verify(request)
|
||||
else:
|
||||
return await self.server.handle_callback(request)
|
||||
|
||||
async def terminate(self):
|
||||
"""终止适配器"""
|
||||
logger.info("企业微信智能机器人适配器正在关闭...")
|
||||
|
||||
@@ -59,8 +59,19 @@ class WecomAIBotServer:
|
||||
)
|
||||
|
||||
async def verify_url(self):
|
||||
"""验证回调 URL"""
|
||||
args = quart.request.args
|
||||
"""内部服务器的 GET 验证入口"""
|
||||
return await self.handle_verify(quart.request)
|
||||
|
||||
async def handle_verify(self, request):
|
||||
"""处理 URL 验证请求,可被统一 webhook 入口复用
|
||||
|
||||
Args:
|
||||
request: Quart 请求对象
|
||||
|
||||
Returns:
|
||||
验证响应元组 (content, status_code, headers)
|
||||
"""
|
||||
args = request.args
|
||||
msg_signature = args.get("msg_signature")
|
||||
timestamp = args.get("timestamp")
|
||||
nonce = args.get("nonce")
|
||||
@@ -81,8 +92,19 @@ class WecomAIBotServer:
|
||||
return result, 200, {"Content-Type": "text/plain"}
|
||||
|
||||
async def handle_message(self):
|
||||
"""处理消息回调"""
|
||||
args = quart.request.args
|
||||
"""内部服务器的 POST 消息回调入口"""
|
||||
return await self.handle_callback(quart.request)
|
||||
|
||||
async def handle_callback(self, request):
|
||||
"""处理消息回调,可被统一 webhook 入口复用
|
||||
|
||||
Args:
|
||||
request: Quart 请求对象
|
||||
|
||||
Returns:
|
||||
响应元组 (content, status_code, headers)
|
||||
"""
|
||||
args = request.args
|
||||
msg_signature = args.get("msg_signature")
|
||||
timestamp = args.get("timestamp")
|
||||
nonce = args.get("nonce")
|
||||
@@ -102,7 +124,7 @@ class WecomAIBotServer:
|
||||
|
||||
try:
|
||||
# 获取请求体
|
||||
post_data = await quart.request.get_data()
|
||||
post_data = await request.get_data()
|
||||
|
||||
# 确保 post_data 是 bytes 类型
|
||||
if isinstance(post_data, str):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import sys
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import quart
|
||||
from requests import Response
|
||||
@@ -22,6 +23,7 @@ from astrbot.api.platform import (
|
||||
)
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.utils.webhook_utils import log_webhook_info
|
||||
|
||||
from .weixin_offacc_event import WeixinOfficialAccountPlatformEvent
|
||||
|
||||
@@ -31,7 +33,7 @@ else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class WecomServer:
|
||||
class WeixinOfficialAccountServer:
|
||||
def __init__(self, event_queue: asyncio.Queue, config: dict):
|
||||
self.server = quart.Quart(__name__)
|
||||
self.port = int(config.get("port"))
|
||||
@@ -57,9 +59,21 @@ class WecomServer:
|
||||
self.shutdown_event = asyncio.Event()
|
||||
|
||||
async def verify(self):
|
||||
logger.info(f"验证请求有效性: {quart.request.args}")
|
||||
"""内部服务器的 GET 验证入口"""
|
||||
return await self.handle_verify(quart.request)
|
||||
|
||||
args = quart.request.args
|
||||
async def handle_verify(self, request) -> str:
|
||||
"""处理验证请求,可被统一 webhook 入口复用
|
||||
|
||||
Args:
|
||||
request: Quart 请求对象
|
||||
|
||||
Returns:
|
||||
验证响应
|
||||
"""
|
||||
logger.info(f"验证请求有效性: {request.args}")
|
||||
|
||||
args = request.args
|
||||
if not args.get("signature", None):
|
||||
logger.error("未知的响应,请检查回调地址是否填写正确。")
|
||||
return "err"
|
||||
@@ -77,10 +91,22 @@ class WecomServer:
|
||||
return "err"
|
||||
|
||||
async def callback_command(self):
|
||||
data = await quart.request.get_data()
|
||||
msg_signature = quart.request.args.get("msg_signature")
|
||||
timestamp = quart.request.args.get("timestamp")
|
||||
nonce = quart.request.args.get("nonce")
|
||||
"""内部服务器的 POST 回调入口"""
|
||||
return await self.handle_callback(quart.request)
|
||||
|
||||
async def handle_callback(self, request) -> str:
|
||||
"""处理回调请求,可被统一 webhook 入口复用
|
||||
|
||||
Args:
|
||||
request: Quart 请求对象
|
||||
|
||||
Returns:
|
||||
响应内容
|
||||
"""
|
||||
data = await request.get_data()
|
||||
msg_signature = request.args.get("msg_signature")
|
||||
timestamp = request.args.get("timestamp")
|
||||
nonce = request.args.get("nonce")
|
||||
try:
|
||||
xml = self.crypto.decrypt_message(data, msg_signature, timestamp, nonce)
|
||||
except InvalidSignatureException:
|
||||
@@ -123,8 +149,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
self.config = platform_config
|
||||
super().__init__(platform_config, event_queue)
|
||||
self.settingss = platform_settings
|
||||
self.client_self_id = uuid.uuid4().hex[:8]
|
||||
self.api_base_url = platform_config.get(
|
||||
@@ -132,6 +157,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
"https://api.weixin.qq.com/cgi-bin/",
|
||||
)
|
||||
self.active_send_mode = self.config.get("active_send_mode", False)
|
||||
self.unified_webhook_mode = platform_config.get("unified_webhook_mode", False)
|
||||
|
||||
if not self.api_base_url:
|
||||
self.api_base_url = "https://api.weixin.qq.com/cgi-bin/"
|
||||
@@ -143,7 +169,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
if not self.api_base_url.endswith("/"):
|
||||
self.api_base_url += "/"
|
||||
|
||||
self.server = WecomServer(self._event_queue, self.config)
|
||||
self.server = WeixinOfficialAccountServer(self._event_queue, self.config)
|
||||
|
||||
self.client = WeChatClient(
|
||||
self.config["appid"].strip(),
|
||||
@@ -202,7 +228,22 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
|
||||
@override
|
||||
async def run(self):
|
||||
await self.server.start_polling()
|
||||
# 如果启用统一 webhook 模式,则不启动独立服务器
|
||||
webhook_uuid = self.config.get("webhook_uuid")
|
||||
if self.unified_webhook_mode and webhook_uuid:
|
||||
log_webhook_info(f"{self.meta().id}(微信公众平台)", webhook_uuid)
|
||||
# 保持运行状态,等待 shutdown
|
||||
await self.server.shutdown_event.wait()
|
||||
else:
|
||||
await self.server.start_polling()
|
||||
|
||||
async def webhook_callback(self, request: Any) -> Any:
|
||||
"""统一 Webhook 回调入口"""
|
||||
# 根据请求方法分发到不同的处理函数
|
||||
if request.method == "GET":
|
||||
return await self.server.handle_verify(request)
|
||||
else:
|
||||
return await self.server.handle_callback(request)
|
||||
|
||||
async def convert_message(
|
||||
self,
|
||||
|
||||
@@ -13,7 +13,7 @@ try:
|
||||
import pydub
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"检测到 pydub 库未安装,微信公众平台将无法语音收发。如需使用语音,请前往管理面板 -> 控制台 -> 安装 Pip 库安装 pydub。",
|
||||
"检测到 pydub 库未安装,微信公众平台将无法语音收发。如需使用语音,请前往管理面板 -> 平台日志 -> 安装 Pip 库安装 pydub。",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -10,12 +10,12 @@ class PlatformMessageHistoryManager:
|
||||
self,
|
||||
platform_id: str,
|
||||
user_id: str,
|
||||
content: list[dict], # TODO: parse from message chain
|
||||
content: dict, # TODO: parse from message chain
|
||||
sender_id: str | None = None,
|
||||
sender_name: str | None = None,
|
||||
):
|
||||
) -> PlatformMessageHistory:
|
||||
"""Insert a new platform message history record."""
|
||||
await self.db.insert_platform_message_history(
|
||||
return await self.db.insert_platform_message_history(
|
||||
platform_id=platform_id,
|
||||
user_id=user_id,
|
||||
content=content,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core import astrbot_config, logger, sp
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
from astrbot.core.db import BaseDatabase
|
||||
|
||||
@@ -24,6 +24,7 @@ class ProviderManager:
|
||||
db_helper: BaseDatabase,
|
||||
persona_mgr: PersonaManager,
|
||||
):
|
||||
self.reload_lock = asyncio.Lock()
|
||||
self.persona_mgr = persona_mgr
|
||||
self.acm = acm
|
||||
config = acm.confs["default"]
|
||||
@@ -226,6 +227,9 @@ class ProviderManager:
|
||||
|
||||
async def load_provider(self, provider_config: dict):
|
||||
if not provider_config["enable"]:
|
||||
logger.info(f"Provider {provider_config['id']} is disabled, skipping")
|
||||
return
|
||||
if provider_config.get("provider_type", "") == "agent_runner":
|
||||
return
|
||||
|
||||
logger.info(
|
||||
@@ -247,14 +251,6 @@ class ProviderManager:
|
||||
from .sources.anthropic_source import (
|
||||
ProviderAnthropic as ProviderAnthropic,
|
||||
)
|
||||
case "dify":
|
||||
from .sources.dify_source import ProviderDify as ProviderDify
|
||||
case "coze":
|
||||
from .sources.coze_source import ProviderCoze as ProviderCoze
|
||||
case "dashscope":
|
||||
from .sources.dashscope_source import (
|
||||
ProviderDashscope as ProviderDashscope,
|
||||
)
|
||||
case "googlegenai_chat_completion":
|
||||
from .sources.gemini_source import (
|
||||
ProviderGoogleGenAI as ProviderGoogleGenAI,
|
||||
@@ -331,6 +327,10 @@ class ProviderManager:
|
||||
from .sources.xinference_rerank_source import (
|
||||
XinferenceRerankProvider as XinferenceRerankProvider,
|
||||
)
|
||||
case "bailian_rerank":
|
||||
from .sources.bailian_rerank_source import (
|
||||
BailianRerankProvider as BailianRerankProvider,
|
||||
)
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.critical(
|
||||
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。",
|
||||
@@ -436,40 +436,46 @@ class ProviderManager:
|
||||
)
|
||||
|
||||
async def reload(self, provider_config: dict):
|
||||
await self.terminate_provider(provider_config["id"])
|
||||
if provider_config["enable"]:
|
||||
await self.load_provider(provider_config)
|
||||
async with self.reload_lock:
|
||||
await self.terminate_provider(provider_config["id"])
|
||||
if provider_config["enable"]:
|
||||
await self.load_provider(provider_config)
|
||||
|
||||
# 和配置文件保持同步
|
||||
config_ids = [provider["id"] for provider in self.providers_config]
|
||||
logger.debug(f"providers in user's config: {config_ids}")
|
||||
for key in list(self.inst_map.keys()):
|
||||
if key not in config_ids:
|
||||
await self.terminate_provider(key)
|
||||
# 和配置文件保持同步
|
||||
self.providers_config = astrbot_config["provider"]
|
||||
config_ids = [provider["id"] for provider in self.providers_config]
|
||||
logger.info(f"providers in user's config: {config_ids}")
|
||||
for key in list(self.inst_map.keys()):
|
||||
if key not in config_ids:
|
||||
await self.terminate_provider(key)
|
||||
|
||||
if len(self.provider_insts) == 0:
|
||||
self.curr_provider_inst = None
|
||||
elif self.curr_provider_inst is None and len(self.provider_insts) > 0:
|
||||
self.curr_provider_inst = self.provider_insts[0]
|
||||
logger.info(
|
||||
f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。",
|
||||
)
|
||||
if len(self.provider_insts) == 0:
|
||||
self.curr_provider_inst = None
|
||||
elif self.curr_provider_inst is None and len(self.provider_insts) > 0:
|
||||
self.curr_provider_inst = self.provider_insts[0]
|
||||
logger.info(
|
||||
f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。",
|
||||
)
|
||||
|
||||
if len(self.stt_provider_insts) == 0:
|
||||
self.curr_stt_provider_inst = None
|
||||
elif self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0:
|
||||
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
||||
logger.info(
|
||||
f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。",
|
||||
)
|
||||
if len(self.stt_provider_insts) == 0:
|
||||
self.curr_stt_provider_inst = None
|
||||
elif (
|
||||
self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0
|
||||
):
|
||||
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
||||
logger.info(
|
||||
f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。",
|
||||
)
|
||||
|
||||
if len(self.tts_provider_insts) == 0:
|
||||
self.curr_tts_provider_inst = None
|
||||
elif self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0:
|
||||
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||||
logger.info(
|
||||
f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。",
|
||||
)
|
||||
if len(self.tts_provider_insts) == 0:
|
||||
self.curr_tts_provider_inst = None
|
||||
elif (
|
||||
self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0
|
||||
):
|
||||
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||||
logger.info(
|
||||
f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。",
|
||||
)
|
||||
|
||||
def get_insts(self):
|
||||
return self.provider_insts
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import abc
|
||||
import asyncio
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core.agent.message import Message
|
||||
@@ -11,6 +12,7 @@ from astrbot.core.provider.entities import (
|
||||
ToolCallsResult,
|
||||
)
|
||||
from astrbot.core.provider.register import provider_cls_map
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_path
|
||||
|
||||
|
||||
class AbstractProvider(abc.ABC):
|
||||
@@ -43,6 +45,14 @@ class AbstractProvider(abc.ABC):
|
||||
)
|
||||
return meta
|
||||
|
||||
async def test(self):
|
||||
"""test the provider is a
|
||||
|
||||
raises:
|
||||
Exception: if the provider is not available
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class Provider(AbstractProvider):
|
||||
"""Chat Provider"""
|
||||
@@ -165,6 +175,12 @@ class Provider(AbstractProvider):
|
||||
|
||||
return dicts
|
||||
|
||||
async def test(self, timeout: float = 45.0):
|
||||
await asyncio.wait_for(
|
||||
self.text_chat(prompt="REPLY `PONG` ONLY"),
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
|
||||
class STTProvider(AbstractProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
@@ -177,6 +193,14 @@ class STTProvider(AbstractProvider):
|
||||
"""获取音频的文本"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def test(self):
|
||||
sample_audio_path = os.path.join(
|
||||
get_astrbot_path(),
|
||||
"samples",
|
||||
"stt_health_check.wav",
|
||||
)
|
||||
await self.get_text(sample_audio_path)
|
||||
|
||||
|
||||
class TTSProvider(AbstractProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
@@ -189,6 +213,9 @@ class TTSProvider(AbstractProvider):
|
||||
"""获取文本的音频,返回音频文件路径"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def test(self):
|
||||
await self.get_audio("hi")
|
||||
|
||||
|
||||
class EmbeddingProvider(AbstractProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
@@ -211,6 +238,9 @@ class EmbeddingProvider(AbstractProvider):
|
||||
"""获取向量的维度"""
|
||||
...
|
||||
|
||||
async def test(self):
|
||||
await self.get_embedding("astrbot")
|
||||
|
||||
async def get_embeddings_batch(
|
||||
self,
|
||||
texts: list[str],
|
||||
@@ -294,3 +324,8 @@ class RerankProvider(AbstractProvider):
|
||||
) -> list[RerankResult]:
|
||||
"""获取查询和文档的重排序分数"""
|
||||
...
|
||||
|
||||
async def test(self):
|
||||
result = await self.rerank("Apple", documents=["apple", "banana"])
|
||||
if not result:
|
||||
raise Exception("Rerank provider test failed, no results returned")
|
||||
|
||||
@@ -290,7 +290,7 @@ class ProviderAnthropic(Provider):
|
||||
try:
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
except Exception as e:
|
||||
logger.error(f"发生了错误。Provider 配置如下: {model_config}")
|
||||
# logger.error(f"发生了错误。Provider 配置如下: {model_config}")
|
||||
raise e
|
||||
|
||||
return llm_response
|
||||
|
||||
@@ -0,0 +1,236 @@
|
||||
import os
|
||||
|
||||
import aiohttp
|
||||
|
||||
from astrbot import logger
|
||||
|
||||
from ..entities import ProviderType, RerankResult
|
||||
from ..provider import RerankProvider
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
|
||||
class BailianRerankError(Exception):
|
||||
"""百炼重排序服务异常基类"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class BailianAPIError(BailianRerankError):
|
||||
"""百炼API返回错误"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class BailianNetworkError(BailianRerankError):
|
||||
"""百炼网络请求错误"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
"bailian_rerank", "阿里云百炼文本排序适配器", provider_type=ProviderType.RERANK
|
||||
)
|
||||
class BailianRerankProvider(RerankProvider):
|
||||
"""阿里云百炼文本重排序适配器."""
|
||||
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.provider_config = provider_config
|
||||
self.provider_settings = provider_settings
|
||||
|
||||
# API配置
|
||||
self.api_key = provider_config.get("rerank_api_key") or os.getenv(
|
||||
"DASHSCOPE_API_KEY", ""
|
||||
)
|
||||
if not self.api_key:
|
||||
raise ValueError("阿里云百炼 API Key 不能为空。")
|
||||
|
||||
self.model = provider_config.get("rerank_model", "qwen3-rerank")
|
||||
self.timeout = provider_config.get("timeout", 30)
|
||||
self.return_documents = provider_config.get("return_documents", False)
|
||||
self.instruct = provider_config.get("instruct", "")
|
||||
|
||||
self.base_url = provider_config.get(
|
||||
"rerank_api_base",
|
||||
"https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank",
|
||||
)
|
||||
|
||||
# 设置HTTP客户端
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
self.client = aiohttp.ClientSession(
|
||||
headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
|
||||
)
|
||||
|
||||
# 设置模型名称
|
||||
self.set_model(self.model)
|
||||
|
||||
logger.info(f"AstrBot 百炼 Rerank 初始化完成。模型: {self.model}")
|
||||
|
||||
def _build_payload(
|
||||
self, query: str, documents: list[str], top_n: int | None
|
||||
) -> dict:
|
||||
"""构建请求载荷
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
documents: 文档列表
|
||||
top_n: 返回前N个结果,如果为None则返回所有结果
|
||||
|
||||
Returns:
|
||||
请求载荷字典
|
||||
"""
|
||||
base = {"model": self.model, "input": {"query": query, "documents": documents}}
|
||||
|
||||
params = {
|
||||
k: v
|
||||
for k, v in [
|
||||
("top_n", top_n if top_n is not None and top_n > 0 else None),
|
||||
("return_documents", True if self.return_documents else None),
|
||||
(
|
||||
"instruct",
|
||||
self.instruct
|
||||
if self.instruct and self.model == "qwen3-rerank"
|
||||
else None,
|
||||
),
|
||||
]
|
||||
if v is not None
|
||||
}
|
||||
|
||||
if params:
|
||||
base["parameters"] = params
|
||||
|
||||
return base
|
||||
|
||||
def _parse_results(self, data: dict) -> list[RerankResult]:
|
||||
"""解析API响应结果
|
||||
|
||||
Args:
|
||||
data: API响应数据
|
||||
|
||||
Returns:
|
||||
重排序结果列表
|
||||
|
||||
Raises:
|
||||
BailianAPIError: API返回错误
|
||||
KeyError: 结果缺少必要字段
|
||||
"""
|
||||
# 检查响应状态
|
||||
if data.get("code", "200") != "200":
|
||||
raise BailianAPIError(
|
||||
f"百炼 API 错误: {data.get('code')} – {data.get('message', '')}"
|
||||
)
|
||||
|
||||
results = data.get("output", {}).get("results", [])
|
||||
if not results:
|
||||
logger.warning(f"百炼 Rerank 返回空结果: {data}")
|
||||
return []
|
||||
|
||||
# 转换为RerankResult对象,使用.get()避免KeyError
|
||||
rerank_results = []
|
||||
for idx, result in enumerate(results):
|
||||
try:
|
||||
index = result.get("index", idx)
|
||||
relevance_score = result.get("relevance_score", 0.0)
|
||||
|
||||
if relevance_score is None:
|
||||
logger.warning(f"结果 {idx} 缺少 relevance_score,使用默认值 0.0")
|
||||
relevance_score = 0.0
|
||||
|
||||
rerank_result = RerankResult(
|
||||
index=index, relevance_score=relevance_score
|
||||
)
|
||||
rerank_results.append(rerank_result)
|
||||
except Exception as e:
|
||||
logger.warning(f"解析结果 {idx} 时出错: {e}, result={result}")
|
||||
continue
|
||||
|
||||
return rerank_results
|
||||
|
||||
def _log_usage(self, data: dict) -> None:
|
||||
"""记录使用量信息
|
||||
|
||||
Args:
|
||||
data: API响应数据
|
||||
"""
|
||||
tokens = data.get("usage", {}).get("total_tokens", 0)
|
||||
if tokens > 0:
|
||||
logger.debug(f"百炼 Rerank 消耗 Token: {tokens}")
|
||||
|
||||
async def rerank(
|
||||
self,
|
||||
query: str,
|
||||
documents: list[str],
|
||||
top_n: int | None = None,
|
||||
) -> list[RerankResult]:
|
||||
"""
|
||||
对文档进行重排序
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
documents: 待排序的文档列表
|
||||
top_n: 返回前N个结果,如果为None则使用配置中的默认值
|
||||
|
||||
Returns:
|
||||
重排序结果列表
|
||||
"""
|
||||
if not documents:
|
||||
logger.warning("文档列表为空,返回空结果")
|
||||
return []
|
||||
|
||||
if not query.strip():
|
||||
logger.warning("查询文本为空,返回空结果")
|
||||
return []
|
||||
|
||||
# 检查限制
|
||||
if len(documents) > 500:
|
||||
logger.warning(
|
||||
f"文档数量({len(documents)})超过限制(500),将截断前500个文档"
|
||||
)
|
||||
documents = documents[:500]
|
||||
|
||||
try:
|
||||
# 构建请求载荷,如果top_n为None则返回所有重排序结果
|
||||
payload = self._build_payload(query, documents, top_n)
|
||||
|
||||
logger.debug(
|
||||
f"百炼 Rerank 请求: query='{query[:50]}...', 文档数量={len(documents)}"
|
||||
)
|
||||
|
||||
# 发送请求
|
||||
async with self.client.post(self.base_url, json=payload) as response:
|
||||
response.raise_for_status()
|
||||
response_data = await response.json()
|
||||
|
||||
# 解析结果并记录使用量
|
||||
results = self._parse_results(response_data)
|
||||
self._log_usage(response_data)
|
||||
|
||||
logger.debug(f"百炼 Rerank 成功返回 {len(results)} 个结果")
|
||||
|
||||
return results
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
error_msg = f"网络请求失败: {e}"
|
||||
logger.error(f"百炼 Rerank 网络请求失败: {e}")
|
||||
raise BailianNetworkError(error_msg) from e
|
||||
except BailianRerankError:
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"重排序失败: {e}"
|
||||
logger.error(f"百炼 Rerank 处理失败: {e}")
|
||||
raise BailianRerankError(error_msg) from e
|
||||
|
||||
async def terminate(self) -> None:
|
||||
"""关闭HTTP客户端会话."""
|
||||
if self.client:
|
||||
logger.info("关闭 百炼 Rerank 客户端会话")
|
||||
try:
|
||||
await self.client.close()
|
||||
except Exception as e:
|
||||
logger.error(f"关闭 百炼 Rerank 客户端时出错: {e}")
|
||||
finally:
|
||||
self.client = None
|
||||
@@ -1,650 +0,0 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot import logger
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
|
||||
from ..register import register_provider_adapter
|
||||
from .coze_api_client import CozeAPIClient
|
||||
|
||||
|
||||
@register_provider_adapter("coze", "Coze (扣子) 智能体适配器")
|
||||
class ProviderCoze(Provider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config,
|
||||
provider_settings,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
provider_config,
|
||||
provider_settings,
|
||||
)
|
||||
self.api_key = provider_config.get("coze_api_key", "")
|
||||
if not self.api_key:
|
||||
raise Exception("Coze API Key 不能为空。")
|
||||
self.bot_id = provider_config.get("bot_id", "")
|
||||
if not self.bot_id:
|
||||
raise Exception("Coze Bot ID 不能为空。")
|
||||
self.api_base: str = provider_config.get("coze_api_base", "https://api.coze.cn")
|
||||
|
||||
if not isinstance(self.api_base, str) or not self.api_base.startswith(
|
||||
("http://", "https://"),
|
||||
):
|
||||
raise Exception(
|
||||
"Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头。",
|
||||
)
|
||||
|
||||
self.timeout = provider_config.get("timeout", 120)
|
||||
if isinstance(self.timeout, str):
|
||||
self.timeout = int(self.timeout)
|
||||
self.auto_save_history = provider_config.get("auto_save_history", True)
|
||||
self.conversation_ids: dict[str, str] = {}
|
||||
self.file_id_cache: dict[str, dict[str, str]] = {}
|
||||
|
||||
# 创建 API 客户端
|
||||
self.api_client = CozeAPIClient(api_key=self.api_key, api_base=self.api_base)
|
||||
|
||||
def _generate_cache_key(self, data: str, is_base64: bool = False) -> str:
|
||||
"""生成统一的缓存键
|
||||
|
||||
Args:
|
||||
data: 图片数据或路径
|
||||
is_base64: 是否是 base64 数据
|
||||
|
||||
Returns:
|
||||
str: 缓存键
|
||||
|
||||
"""
|
||||
try:
|
||||
if is_base64 and data.startswith("data:image/"):
|
||||
try:
|
||||
header, encoded = data.split(",", 1)
|
||||
image_bytes = base64.b64decode(encoded)
|
||||
cache_key = hashlib.md5(image_bytes).hexdigest()
|
||||
return cache_key
|
||||
except Exception:
|
||||
cache_key = hashlib.md5(encoded.encode("utf-8")).hexdigest()
|
||||
return cache_key
|
||||
elif data.startswith(("http://", "https://")):
|
||||
# URL图片,使用URL作为缓存键
|
||||
cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
|
||||
return cache_key
|
||||
else:
|
||||
clean_path = (
|
||||
data.split("_")[0]
|
||||
if "_" in data and len(data.split("_")) >= 3
|
||||
else data
|
||||
)
|
||||
|
||||
if os.path.exists(clean_path):
|
||||
with open(clean_path, "rb") as f:
|
||||
file_content = f.read()
|
||||
cache_key = hashlib.md5(file_content).hexdigest()
|
||||
return cache_key
|
||||
cache_key = hashlib.md5(clean_path.encode("utf-8")).hexdigest()
|
||||
return cache_key
|
||||
|
||||
except Exception as e:
|
||||
cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
|
||||
logger.debug(f"[Coze] 异常文件缓存键: {cache_key}, error={e}")
|
||||
return cache_key
|
||||
|
||||
async def _upload_file(
|
||||
self,
|
||||
file_data: bytes,
|
||||
session_id: str | None = None,
|
||||
cache_key: str | None = None,
|
||||
) -> str:
|
||||
"""上传文件到 Coze 并返回 file_id"""
|
||||
# 使用 API 客户端上传文件
|
||||
file_id = await self.api_client.upload_file(file_data)
|
||||
|
||||
# 缓存 file_id
|
||||
if session_id and cache_key:
|
||||
if session_id not in self.file_id_cache:
|
||||
self.file_id_cache[session_id] = {}
|
||||
self.file_id_cache[session_id][cache_key] = file_id
|
||||
logger.debug(f"[Coze] 图片上传成功并缓存,file_id: {file_id}")
|
||||
|
||||
return file_id
|
||||
|
||||
async def _download_and_upload_image(
|
||||
self,
|
||||
image_url: str,
|
||||
session_id: str | None = None,
|
||||
) -> str:
|
||||
"""下载图片并上传到 Coze,返回 file_id"""
|
||||
# 计算哈希实现缓存
|
||||
cache_key = self._generate_cache_key(image_url) if session_id else None
|
||||
|
||||
if session_id and cache_key:
|
||||
if session_id not in self.file_id_cache:
|
||||
self.file_id_cache[session_id] = {}
|
||||
|
||||
if cache_key in self.file_id_cache[session_id]:
|
||||
file_id = self.file_id_cache[session_id][cache_key]
|
||||
return file_id
|
||||
|
||||
try:
|
||||
image_data = await self.api_client.download_image(image_url)
|
||||
|
||||
file_id = await self._upload_file(image_data, session_id, cache_key)
|
||||
|
||||
if session_id and cache_key:
|
||||
self.file_id_cache[session_id][cache_key] = file_id
|
||||
|
||||
return file_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理图片失败 {image_url}: {e!s}")
|
||||
raise Exception(f"处理图片失败: {e!s}")
|
||||
|
||||
async def _process_context_images(
|
||||
self,
|
||||
content: str | list,
|
||||
session_id: str,
|
||||
) -> str:
|
||||
"""处理上下文中的图片内容,将 base64 图片上传并替换为 file_id"""
|
||||
try:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
processed_content = []
|
||||
if session_id not in self.file_id_cache:
|
||||
self.file_id_cache[session_id] = {}
|
||||
|
||||
for item in content:
|
||||
if not isinstance(item, dict):
|
||||
processed_content.append(item)
|
||||
continue
|
||||
if item.get("type") == "text":
|
||||
processed_content.append(item)
|
||||
elif item.get("type") == "image_url":
|
||||
# 处理图片逻辑
|
||||
if "file_id" in item:
|
||||
# 已经有 file_id
|
||||
logger.debug(f"[Coze] 图片已有file_id: {item['file_id']}")
|
||||
processed_content.append(item)
|
||||
else:
|
||||
# 获取图片数据
|
||||
image_data = ""
|
||||
if "image_url" in item and isinstance(item["image_url"], dict):
|
||||
image_data = item["image_url"].get("url", "")
|
||||
elif "data" in item:
|
||||
image_data = item.get("data", "")
|
||||
elif "url" in item:
|
||||
image_data = item.get("url", "")
|
||||
|
||||
if not image_data:
|
||||
continue
|
||||
# 计算哈希用于缓存
|
||||
cache_key = self._generate_cache_key(
|
||||
image_data,
|
||||
is_base64=image_data.startswith("data:image/"),
|
||||
)
|
||||
|
||||
# 检查缓存
|
||||
if cache_key in self.file_id_cache[session_id]:
|
||||
file_id = self.file_id_cache[session_id][cache_key]
|
||||
processed_content.append(
|
||||
{"type": "image", "file_id": file_id},
|
||||
)
|
||||
else:
|
||||
# 上传图片并缓存
|
||||
if image_data.startswith("data:image/"):
|
||||
# base64 处理
|
||||
_, encoded = image_data.split(",", 1)
|
||||
image_bytes = base64.b64decode(encoded)
|
||||
file_id = await self._upload_file(
|
||||
image_bytes,
|
||||
session_id,
|
||||
cache_key,
|
||||
)
|
||||
elif image_data.startswith(("http://", "https://")):
|
||||
# URL 图片
|
||||
file_id = await self._download_and_upload_image(
|
||||
image_data,
|
||||
session_id,
|
||||
)
|
||||
# 为URL图片也添加缓存
|
||||
self.file_id_cache[session_id][cache_key] = file_id
|
||||
elif os.path.exists(image_data):
|
||||
# 本地文件
|
||||
with open(image_data, "rb") as f:
|
||||
image_bytes = f.read()
|
||||
file_id = await self._upload_file(
|
||||
image_bytes,
|
||||
session_id,
|
||||
cache_key,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"无法处理的图片格式: {image_data[:50]}...",
|
||||
)
|
||||
continue
|
||||
|
||||
processed_content.append(
|
||||
{"type": "image", "file_id": file_id},
|
||||
)
|
||||
|
||||
result = json.dumps(processed_content, ensure_ascii=False)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"处理上下文图片失败: {e!s}")
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
return json.dumps(content, ensure_ascii=False)
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id=None,
|
||||
image_urls=None,
|
||||
func_tool=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
"""文本对话, 内部使用流式接口实现非流式
|
||||
|
||||
Args:
|
||||
prompt (str): 用户提示词
|
||||
session_id (str): 会话ID
|
||||
image_urls (List[str]): 图片URL列表
|
||||
func_tool (FuncCall): 函数调用工具(不支持)
|
||||
contexts (List): 上下文列表
|
||||
system_prompt (str): 系统提示语
|
||||
tool_calls_result (ToolCallsResult | List[ToolCallsResult]): 工具调用结果(不支持)
|
||||
model (str): 模型名称(不支持)
|
||||
|
||||
Returns:
|
||||
LLMResponse: LLM响应对象
|
||||
|
||||
"""
|
||||
accumulated_content = ""
|
||||
final_response = None
|
||||
|
||||
async for llm_response in self.text_chat_stream(
|
||||
prompt=prompt,
|
||||
session_id=session_id,
|
||||
image_urls=image_urls,
|
||||
func_tool=func_tool,
|
||||
contexts=contexts,
|
||||
system_prompt=system_prompt,
|
||||
tool_calls_result=tool_calls_result,
|
||||
model=model,
|
||||
**kwargs,
|
||||
):
|
||||
if llm_response.is_chunk:
|
||||
if llm_response.completion_text:
|
||||
accumulated_content += llm_response.completion_text
|
||||
else:
|
||||
final_response = llm_response
|
||||
|
||||
if final_response:
|
||||
return final_response
|
||||
|
||||
if accumulated_content:
|
||||
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
|
||||
return LLMResponse(role="assistant", result_chain=chain)
|
||||
return LLMResponse(role="assistant", completion_text="")
|
||||
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id=None,
|
||||
image_urls=None,
|
||||
func_tool=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
"""流式对话接口"""
|
||||
# 用户ID参数(参考文档, 可以自定义)
|
||||
user_id = session_id or kwargs.get("user", "default_user")
|
||||
|
||||
# 获取或创建会话ID
|
||||
conversation_id = self.conversation_ids.get(user_id)
|
||||
|
||||
# 构建消息
|
||||
additional_messages = []
|
||||
|
||||
if system_prompt:
|
||||
if not self.auto_save_history or not conversation_id:
|
||||
additional_messages.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
"content_type": "text",
|
||||
},
|
||||
)
|
||||
|
||||
contexts = self._ensure_message_to_dicts(contexts)
|
||||
if not self.auto_save_history and contexts:
|
||||
# 如果关闭了自动保存历史,传入上下文
|
||||
for ctx in contexts:
|
||||
if isinstance(ctx, dict) and "role" in ctx and "content" in ctx:
|
||||
content = ctx["content"]
|
||||
content_type = ctx.get("content_type", "text")
|
||||
|
||||
# 处理可能包含图片的上下文
|
||||
if (
|
||||
content_type == "object_string"
|
||||
or (isinstance(content, str) and content.startswith("["))
|
||||
or (
|
||||
isinstance(content, list)
|
||||
and any(
|
||||
isinstance(item, dict)
|
||||
and item.get("type") == "image_url"
|
||||
for item in content
|
||||
)
|
||||
)
|
||||
):
|
||||
processed_content = await self._process_context_images(
|
||||
content,
|
||||
user_id,
|
||||
)
|
||||
additional_messages.append(
|
||||
{
|
||||
"role": ctx["role"],
|
||||
"content": processed_content,
|
||||
"content_type": "object_string",
|
||||
},
|
||||
)
|
||||
else:
|
||||
# 纯文本
|
||||
additional_messages.append(
|
||||
{
|
||||
"role": ctx["role"],
|
||||
"content": (
|
||||
content
|
||||
if isinstance(content, str)
|
||||
else json.dumps(content, ensure_ascii=False)
|
||||
),
|
||||
"content_type": "text",
|
||||
},
|
||||
)
|
||||
else:
|
||||
logger.info(f"[Coze] 跳过格式不正确的上下文: {ctx}")
|
||||
|
||||
if prompt or image_urls:
|
||||
if image_urls:
|
||||
# 多模态
|
||||
object_string_content = []
|
||||
if prompt:
|
||||
object_string_content.append({"type": "text", "text": prompt})
|
||||
|
||||
for url in image_urls:
|
||||
try:
|
||||
if url.startswith(("http://", "https://")):
|
||||
# 网络图片
|
||||
file_id = await self._download_and_upload_image(
|
||||
url,
|
||||
user_id,
|
||||
)
|
||||
else:
|
||||
# 本地文件或 base64
|
||||
if url.startswith("data:image/"):
|
||||
# base64
|
||||
_, encoded = url.split(",", 1)
|
||||
image_data = base64.b64decode(encoded)
|
||||
cache_key = self._generate_cache_key(
|
||||
url,
|
||||
is_base64=True,
|
||||
)
|
||||
file_id = await self._upload_file(
|
||||
image_data,
|
||||
user_id,
|
||||
cache_key,
|
||||
)
|
||||
# 本地文件
|
||||
elif os.path.exists(url):
|
||||
with open(url, "rb") as f:
|
||||
image_data = f.read()
|
||||
# 用文件路径和修改时间来缓存
|
||||
file_stat = os.stat(url)
|
||||
cache_key = self._generate_cache_key(
|
||||
f"{url}_{file_stat.st_mtime}_{file_stat.st_size}",
|
||||
is_base64=False,
|
||||
)
|
||||
file_id = await self._upload_file(
|
||||
image_data,
|
||||
user_id,
|
||||
cache_key,
|
||||
)
|
||||
else:
|
||||
logger.warning(f"图片文件不存在: {url}")
|
||||
continue
|
||||
|
||||
object_string_content.append(
|
||||
{
|
||||
"type": "image",
|
||||
"file_id": file_id,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"处理图片失败 {url}: {e!s}")
|
||||
continue
|
||||
|
||||
if object_string_content:
|
||||
content = json.dumps(object_string_content, ensure_ascii=False)
|
||||
additional_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": content,
|
||||
"content_type": "object_string",
|
||||
},
|
||||
)
|
||||
# 纯文本
|
||||
elif prompt:
|
||||
additional_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
"content_type": "text",
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
accumulated_content = ""
|
||||
message_started = False
|
||||
|
||||
async for chunk in self.api_client.chat_messages(
|
||||
bot_id=self.bot_id,
|
||||
user_id=user_id,
|
||||
additional_messages=additional_messages,
|
||||
conversation_id=conversation_id,
|
||||
auto_save_history=self.auto_save_history,
|
||||
stream=True,
|
||||
timeout=self.timeout,
|
||||
):
|
||||
event_type = chunk.get("event")
|
||||
data = chunk.get("data", {})
|
||||
|
||||
if event_type == "conversation.chat.created":
|
||||
if isinstance(data, dict) and "conversation_id" in data:
|
||||
self.conversation_ids[user_id] = data["conversation_id"]
|
||||
|
||||
elif event_type == "conversation.message.delta":
|
||||
if isinstance(data, dict):
|
||||
content = data.get("content", "")
|
||||
if not content and "delta" in data:
|
||||
content = data["delta"].get("content", "")
|
||||
if not content and "text" in data:
|
||||
content = data.get("text", "")
|
||||
|
||||
if content:
|
||||
message_started = True
|
||||
accumulated_content += content
|
||||
yield LLMResponse(
|
||||
role="assistant",
|
||||
completion_text=content,
|
||||
is_chunk=True,
|
||||
)
|
||||
|
||||
elif event_type == "conversation.message.completed":
|
||||
if isinstance(data, dict):
|
||||
msg_type = data.get("type")
|
||||
if msg_type == "answer" and data.get("role") == "assistant":
|
||||
final_content = data.get("content", "")
|
||||
if not accumulated_content and final_content:
|
||||
chain = MessageChain(chain=[Comp.Plain(final_content)])
|
||||
yield LLMResponse(
|
||||
role="assistant",
|
||||
result_chain=chain,
|
||||
is_chunk=False,
|
||||
)
|
||||
|
||||
elif event_type == "conversation.chat.completed":
|
||||
if accumulated_content:
|
||||
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
|
||||
yield LLMResponse(
|
||||
role="assistant",
|
||||
result_chain=chain,
|
||||
is_chunk=False,
|
||||
)
|
||||
break
|
||||
|
||||
elif event_type == "done":
|
||||
break
|
||||
|
||||
elif event_type == "error":
|
||||
error_msg = (
|
||||
data.get("message", "未知错误")
|
||||
if isinstance(data, dict)
|
||||
else str(data)
|
||||
)
|
||||
logger.error(f"Coze 流式响应错误: {error_msg}")
|
||||
yield LLMResponse(
|
||||
role="err",
|
||||
completion_text=f"Coze 错误: {error_msg}",
|
||||
is_chunk=False,
|
||||
)
|
||||
break
|
||||
|
||||
if not message_started and not accumulated_content:
|
||||
yield LLMResponse(
|
||||
role="assistant",
|
||||
completion_text="LLM 未响应任何内容。",
|
||||
is_chunk=False,
|
||||
)
|
||||
elif message_started and accumulated_content:
|
||||
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
|
||||
yield LLMResponse(
|
||||
role="assistant",
|
||||
result_chain=chain,
|
||||
is_chunk=False,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Coze 流式请求失败: {e!s}")
|
||||
yield LLMResponse(
|
||||
role="err",
|
||||
completion_text=f"Coze 流式请求失败: {e!s}",
|
||||
is_chunk=False,
|
||||
)
|
||||
|
||||
async def forget(self, session_id: str):
|
||||
"""清空指定会话的上下文"""
|
||||
user_id = session_id
|
||||
conversation_id = self.conversation_ids.get(user_id)
|
||||
|
||||
if user_id in self.file_id_cache:
|
||||
self.file_id_cache.pop(user_id, None)
|
||||
|
||||
if not conversation_id:
|
||||
return True
|
||||
|
||||
try:
|
||||
response = await self.api_client.clear_context(conversation_id)
|
||||
|
||||
if "code" in response and response["code"] == 0:
|
||||
self.conversation_ids.pop(user_id, None)
|
||||
return True
|
||||
logger.warning(f"清空 Coze 会话上下文失败: {response}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清空 Coze 会话失败: {e!s}")
|
||||
return False
|
||||
|
||||
async def get_current_key(self):
|
||||
"""获取当前API Key"""
|
||||
return self.api_key
|
||||
|
||||
async def set_key(self, key: str):
|
||||
"""设置新的API Key"""
|
||||
raise NotImplementedError("Coze 适配器不支持设置 API Key。")
|
||||
|
||||
async def get_models(self):
|
||||
"""获取可用模型列表"""
|
||||
return [f"bot_{self.bot_id}"]
|
||||
|
||||
def get_model(self):
|
||||
"""获取当前模型"""
|
||||
return f"bot_{self.bot_id}"
|
||||
|
||||
def set_model(self, model: str):
|
||||
"""设置模型(在Coze中是Bot ID)"""
|
||||
if model.startswith("bot_"):
|
||||
self.bot_id = model[4:]
|
||||
else:
|
||||
self.bot_id = model
|
||||
|
||||
async def get_human_readable_context(
|
||||
self,
|
||||
session_id: str,
|
||||
page: int = 1,
|
||||
page_size: int = 10,
|
||||
):
|
||||
"""获取人类可读的上下文历史"""
|
||||
user_id = session_id
|
||||
conversation_id = self.conversation_ids.get(user_id)
|
||||
|
||||
if not conversation_id:
|
||||
return []
|
||||
|
||||
try:
|
||||
data = await self.api_client.get_message_list(
|
||||
conversation_id=conversation_id,
|
||||
order="desc",
|
||||
limit=page_size,
|
||||
offset=(page - 1) * page_size,
|
||||
)
|
||||
|
||||
if data.get("code") != 0:
|
||||
logger.warning(f"获取 Coze 消息历史失败: {data}")
|
||||
return []
|
||||
|
||||
messages = data.get("data", {}).get("messages", [])
|
||||
|
||||
readable_history = []
|
||||
for msg in messages:
|
||||
role = msg.get("role", "unknown")
|
||||
content = msg.get("content", "")
|
||||
msg_type = msg.get("type", "")
|
||||
|
||||
if role == "user":
|
||||
readable_history.append(f"用户: {content}")
|
||||
elif role == "assistant" and msg_type == "answer":
|
||||
readable_history.append(f"助手: {content}")
|
||||
|
||||
return readable_history
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取 Coze 消息历史失败: {e!s}")
|
||||
return []
|
||||
|
||||
async def terminate(self):
|
||||
"""清理资源"""
|
||||
await self.api_client.close()
|
||||
@@ -1,207 +0,0 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import re
|
||||
|
||||
from dashscope import Application
|
||||
from dashscope.app.application_response import ApplicationResponse
|
||||
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
|
||||
from .. import Provider
|
||||
from ..entities import LLMResponse
|
||||
from ..register import register_provider_adapter
|
||||
from .openai_source import ProviderOpenAIOfficial
|
||||
|
||||
|
||||
@register_provider_adapter("dashscope", "Dashscope APP 适配器。")
|
||||
class ProviderDashscope(ProviderOpenAIOfficial):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
) -> None:
|
||||
Provider.__init__(
|
||||
self,
|
||||
provider_config,
|
||||
provider_settings,
|
||||
)
|
||||
self.api_key = provider_config.get("dashscope_api_key", "")
|
||||
if not self.api_key:
|
||||
raise Exception("阿里云百炼 API Key 不能为空。")
|
||||
self.app_id = provider_config.get("dashscope_app_id", "")
|
||||
if not self.app_id:
|
||||
raise Exception("阿里云百炼 APP ID 不能为空。")
|
||||
self.dashscope_app_type = provider_config.get("dashscope_app_type", "")
|
||||
if not self.dashscope_app_type:
|
||||
raise Exception("阿里云百炼 APP 类型不能为空。")
|
||||
self.model_name = "dashscope"
|
||||
self.variables: dict = provider_config.get("variables", {})
|
||||
self.rag_options: dict = provider_config.get("rag_options", {})
|
||||
self.output_reference = self.rag_options.get("output_reference", False)
|
||||
self.rag_options = self.rag_options.copy()
|
||||
self.rag_options.pop("output_reference", None)
|
||||
|
||||
self.timeout = provider_config.get("timeout", 120)
|
||||
if isinstance(self.timeout, str):
|
||||
self.timeout = int(self.timeout)
|
||||
|
||||
def has_rag_options(self):
|
||||
"""判断是否有 RAG 选项
|
||||
|
||||
Returns:
|
||||
bool: 是否有 RAG 选项
|
||||
|
||||
"""
|
||||
if self.rag_options and (
|
||||
len(self.rag_options.get("pipeline_ids", [])) > 0
|
||||
or len(self.rag_options.get("file_ids", [])) > 0
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id=None,
|
||||
image_urls=None,
|
||||
func_tool=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
model=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if image_urls is None:
|
||||
image_urls = []
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
# 获得会话变量
|
||||
payload_vars = self.variables.copy()
|
||||
# 动态变量
|
||||
session_var = await sp.session_get(session_id, "session_variables", default={})
|
||||
payload_vars.update(session_var)
|
||||
|
||||
if (
|
||||
self.dashscope_app_type in ["agent", "dialog-workflow"]
|
||||
and not self.has_rag_options()
|
||||
):
|
||||
# 支持多轮对话的
|
||||
new_record = {"role": "user", "content": prompt}
|
||||
if image_urls:
|
||||
logger.warning("阿里云百炼暂不支持图片输入,将自动忽略图片内容。")
|
||||
contexts_no_img = await self._remove_image_from_context(contexts)
|
||||
context_query = [*contexts_no_img, new_record]
|
||||
if system_prompt:
|
||||
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||
for part in context_query:
|
||||
if "_no_save" in part:
|
||||
del part["_no_save"]
|
||||
# 调用阿里云百炼 API
|
||||
payload = {
|
||||
"app_id": self.app_id,
|
||||
"api_key": self.api_key,
|
||||
"messages": context_query,
|
||||
"biz_params": payload_vars or None,
|
||||
}
|
||||
partial = functools.partial(
|
||||
Application.call,
|
||||
**payload,
|
||||
)
|
||||
response = await asyncio.get_event_loop().run_in_executor(None, partial)
|
||||
else:
|
||||
# 不支持多轮对话的
|
||||
# 调用阿里云百炼 API
|
||||
payload = {
|
||||
"app_id": self.app_id,
|
||||
"prompt": prompt,
|
||||
"api_key": self.api_key,
|
||||
"biz_params": payload_vars or None,
|
||||
}
|
||||
if self.rag_options:
|
||||
payload["rag_options"] = self.rag_options
|
||||
partial = functools.partial(
|
||||
Application.call,
|
||||
**payload,
|
||||
)
|
||||
response = await asyncio.get_event_loop().run_in_executor(None, partial)
|
||||
|
||||
assert isinstance(response, ApplicationResponse)
|
||||
|
||||
logger.debug(f"dashscope resp: {response}")
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(
|
||||
f"阿里云百炼请求失败: request_id={response.request_id}, code={response.status_code}, message={response.message}, 请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code",
|
||||
)
|
||||
return LLMResponse(
|
||||
role="err",
|
||||
result_chain=MessageChain().message(
|
||||
f"阿里云百炼请求失败: message={response.message} code={response.status_code}",
|
||||
),
|
||||
)
|
||||
|
||||
output_text = response.output.get("text", "") or ""
|
||||
# RAG 引用脚标格式化
|
||||
output_text = re.sub(r"<ref>\[(\d+)\]</ref>", r"[\1]", output_text)
|
||||
if self.output_reference and response.output.get("doc_references", None):
|
||||
ref_parts = []
|
||||
for ref in response.output.get("doc_references", []) or []:
|
||||
ref_title = (
|
||||
ref.get("title", "")
|
||||
if ref.get("title")
|
||||
else ref.get("doc_name", "")
|
||||
)
|
||||
ref_parts.append(f"{ref['index_id']}. {ref_title}\n")
|
||||
ref_str = "".join(ref_parts)
|
||||
output_text += f"\n\n回答来源:\n{ref_str}"
|
||||
|
||||
llm_response = LLMResponse("assistant")
|
||||
llm_response.result_chain = MessageChain().message(output_text)
|
||||
|
||||
return llm_response
|
||||
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt,
|
||||
session_id=None,
|
||||
image_urls=...,
|
||||
func_tool=None,
|
||||
contexts=...,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
**kwargs,
|
||||
):
|
||||
# raise NotImplementedError("This method is not implemented yet.")
|
||||
# 调用 text_chat 模拟流式
|
||||
llm_response = await self.text_chat(
|
||||
prompt=prompt,
|
||||
session_id=session_id,
|
||||
image_urls=image_urls,
|
||||
func_tool=func_tool,
|
||||
contexts=contexts,
|
||||
system_prompt=system_prompt,
|
||||
tool_calls_result=tool_calls_result,
|
||||
)
|
||||
llm_response.is_chunk = True
|
||||
yield llm_response
|
||||
llm_response.is_chunk = False
|
||||
yield llm_response
|
||||
|
||||
async def forget(self, session_id):
|
||||
return True
|
||||
|
||||
async def get_current_key(self):
|
||||
return self.api_key
|
||||
|
||||
async def set_key(self, key):
|
||||
raise Exception("阿里云百炼 适配器不支持设置 API Key。")
|
||||
|
||||
async def get_models(self):
|
||||
return [self.get_model()]
|
||||
|
||||
async def get_human_readable_context(self, session_id, page, page_size):
|
||||
raise Exception("暂不支持获得 阿里云百炼 的历史消息记录。")
|
||||
|
||||
async def terminate(self):
|
||||
pass
|
||||
@@ -1,285 +0,0 @@
|
||||
import os
|
||||
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.dify_api_client import DifyAPIClient
|
||||
from astrbot.core.utils.io import download_file, download_image_by_url
|
||||
|
||||
from .. import Provider
|
||||
from ..entities import LLMResponse
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
|
||||
@register_provider_adapter("dify", "Dify APP 适配器。")
|
||||
class ProviderDify(Provider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config,
|
||||
provider_settings,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
provider_config,
|
||||
provider_settings,
|
||||
)
|
||||
self.api_key = provider_config.get("dify_api_key", "")
|
||||
if not self.api_key:
|
||||
raise Exception("Dify API Key 不能为空。")
|
||||
api_base = provider_config.get("dify_api_base", "https://api.dify.ai/v1")
|
||||
self.api_type = provider_config.get("dify_api_type", "")
|
||||
if not self.api_type:
|
||||
raise Exception("Dify API 类型不能为空。")
|
||||
self.model_name = "dify"
|
||||
self.workflow_output_key = provider_config.get(
|
||||
"dify_workflow_output_key",
|
||||
"astrbot_wf_output",
|
||||
)
|
||||
self.dify_query_input_key = provider_config.get(
|
||||
"dify_query_input_key",
|
||||
"astrbot_text_query",
|
||||
)
|
||||
if not self.dify_query_input_key:
|
||||
self.dify_query_input_key = "astrbot_text_query"
|
||||
if not self.workflow_output_key:
|
||||
self.workflow_output_key = "astrbot_wf_output"
|
||||
self.variables: dict = provider_config.get("variables", {})
|
||||
self.timeout = provider_config.get("timeout", 120)
|
||||
if isinstance(self.timeout, str):
|
||||
self.timeout = int(self.timeout)
|
||||
self.conversation_ids = {}
|
||||
"""记录当前 session id 的对话 ID"""
|
||||
|
||||
self.api_client = DifyAPIClient(self.api_key, api_base)
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id=None,
|
||||
image_urls=None,
|
||||
func_tool=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if image_urls is None:
|
||||
image_urls = []
|
||||
result = ""
|
||||
session_id = session_id or kwargs.get("user") or "unknown" # 1734
|
||||
conversation_id = self.conversation_ids.get(session_id, "")
|
||||
|
||||
files_payload = []
|
||||
for image_url in image_urls:
|
||||
image_path = (
|
||||
await download_image_by_url(image_url)
|
||||
if image_url.startswith("http")
|
||||
else image_url
|
||||
)
|
||||
file_response = await self.api_client.file_upload(
|
||||
image_path,
|
||||
user=session_id,
|
||||
)
|
||||
logger.debug(f"Dify 上传图片响应:{file_response}")
|
||||
if "id" not in file_response:
|
||||
logger.warning(
|
||||
f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。",
|
||||
)
|
||||
continue
|
||||
files_payload.append(
|
||||
{
|
||||
"type": "image",
|
||||
"transfer_method": "local_file",
|
||||
"upload_file_id": file_response["id"],
|
||||
},
|
||||
)
|
||||
|
||||
# 获得会话变量
|
||||
payload_vars = self.variables.copy()
|
||||
# 动态变量
|
||||
session_var = await sp.session_get(session_id, "session_variables", default={})
|
||||
payload_vars.update(session_var)
|
||||
payload_vars["system_prompt"] = system_prompt
|
||||
|
||||
try:
|
||||
match self.api_type:
|
||||
case "chat" | "agent" | "chatflow":
|
||||
if not prompt:
|
||||
prompt = "请描述这张图片。"
|
||||
|
||||
async for chunk in self.api_client.chat_messages(
|
||||
inputs={
|
||||
**payload_vars,
|
||||
},
|
||||
query=prompt,
|
||||
user=session_id,
|
||||
conversation_id=conversation_id,
|
||||
files=files_payload,
|
||||
timeout=self.timeout,
|
||||
):
|
||||
logger.debug(f"dify resp chunk: {chunk}")
|
||||
if (
|
||||
chunk["event"] == "message"
|
||||
or chunk["event"] == "agent_message"
|
||||
):
|
||||
result += chunk["answer"]
|
||||
if not conversation_id:
|
||||
self.conversation_ids[session_id] = chunk[
|
||||
"conversation_id"
|
||||
]
|
||||
conversation_id = chunk["conversation_id"]
|
||||
elif chunk["event"] == "message_end":
|
||||
logger.debug("Dify message end")
|
||||
break
|
||||
elif chunk["event"] == "error":
|
||||
logger.error(f"Dify 出现错误:{chunk}")
|
||||
raise Exception(
|
||||
f"Dify 出现错误 status: {chunk['status']} message: {chunk['message']}",
|
||||
)
|
||||
|
||||
case "workflow":
|
||||
async for chunk in self.api_client.workflow_run(
|
||||
inputs={
|
||||
self.dify_query_input_key: prompt,
|
||||
"astrbot_session_id": session_id,
|
||||
**payload_vars,
|
||||
},
|
||||
user=session_id,
|
||||
files=files_payload,
|
||||
timeout=self.timeout,
|
||||
):
|
||||
match chunk["event"]:
|
||||
case "workflow_started":
|
||||
logger.info(
|
||||
f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行。",
|
||||
)
|
||||
case "node_finished":
|
||||
logger.debug(
|
||||
f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束。",
|
||||
)
|
||||
case "workflow_finished":
|
||||
logger.info(
|
||||
f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束",
|
||||
)
|
||||
logger.debug(f"Dify 工作流结果:{chunk}")
|
||||
if chunk["data"]["error"]:
|
||||
logger.error(
|
||||
f"Dify 工作流出现错误:{chunk['data']['error']}",
|
||||
)
|
||||
raise Exception(
|
||||
f"Dify 工作流出现错误:{chunk['data']['error']}",
|
||||
)
|
||||
if (
|
||||
self.workflow_output_key
|
||||
not in chunk["data"]["outputs"]
|
||||
):
|
||||
raise Exception(
|
||||
f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}",
|
||||
)
|
||||
result = chunk
|
||||
case _:
|
||||
raise Exception(f"未知的 Dify API 类型:{self.api_type}")
|
||||
except Exception as e:
|
||||
logger.error(f"Dify 请求失败:{e!s}")
|
||||
return LLMResponse(role="err", completion_text=f"Dify 请求失败:{e!s}")
|
||||
|
||||
if not result:
|
||||
logger.warning("Dify 请求结果为空,请查看 Debug 日志。")
|
||||
|
||||
chain = await self.parse_dify_result(result)
|
||||
|
||||
return LLMResponse(role="assistant", result_chain=chain)
|
||||
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt,
|
||||
session_id=None,
|
||||
image_urls=...,
|
||||
func_tool=None,
|
||||
contexts=...,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
**kwargs,
|
||||
):
|
||||
# raise NotImplementedError("This method is not implemented yet.")
|
||||
# 调用 text_chat 模拟流式
|
||||
llm_response = await self.text_chat(
|
||||
prompt=prompt,
|
||||
session_id=session_id,
|
||||
image_urls=image_urls,
|
||||
func_tool=func_tool,
|
||||
contexts=contexts,
|
||||
system_prompt=system_prompt,
|
||||
tool_calls_result=tool_calls_result,
|
||||
)
|
||||
llm_response.is_chunk = True
|
||||
yield llm_response
|
||||
llm_response.is_chunk = False
|
||||
yield llm_response
|
||||
|
||||
async def parse_dify_result(self, chunk: dict | str) -> MessageChain:
|
||||
if isinstance(chunk, str):
|
||||
# Chat
|
||||
return MessageChain(chain=[Comp.Plain(chunk)])
|
||||
|
||||
async def parse_file(item: dict):
|
||||
match item["type"]:
|
||||
case "image":
|
||||
return Comp.Image(file=item["url"], url=item["url"])
|
||||
case "audio":
|
||||
# 仅支持 wav
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, f"{item['filename']}.wav")
|
||||
await download_file(item["url"], path)
|
||||
return Comp.Image(file=item["url"], url=item["url"])
|
||||
case "video":
|
||||
return Comp.Video(file=item["url"])
|
||||
case _:
|
||||
return Comp.File(name=item["filename"], file=item["url"])
|
||||
|
||||
output = chunk["data"]["outputs"][self.workflow_output_key]
|
||||
chains = []
|
||||
if isinstance(output, str):
|
||||
# 纯文本输出
|
||||
chains.append(Comp.Plain(output))
|
||||
elif isinstance(output, list):
|
||||
# 主要适配 Dify 的 HTTP 请求结点的多模态输出
|
||||
for item in output:
|
||||
# handle Array[File]
|
||||
if (
|
||||
not isinstance(item, dict)
|
||||
or item.get("dify_model_identity", "") != "__dify__file__"
|
||||
):
|
||||
chains.append(Comp.Plain(str(output)))
|
||||
break
|
||||
else:
|
||||
chains.append(Comp.Plain(str(output)))
|
||||
|
||||
# scan file
|
||||
files = chunk["data"].get("files", [])
|
||||
for item in files:
|
||||
comp = await parse_file(item)
|
||||
chains.append(comp)
|
||||
|
||||
return MessageChain(chain=chains)
|
||||
|
||||
async def forget(self, session_id):
|
||||
self.conversation_ids[session_id] = ""
|
||||
return True
|
||||
|
||||
async def get_current_key(self):
|
||||
return self.api_key
|
||||
|
||||
async def set_key(self, key):
|
||||
raise Exception("Dify 适配器不支持设置 API Key。")
|
||||
|
||||
async def get_models(self):
|
||||
return [self.get_model()]
|
||||
|
||||
async def get_human_readable_context(self, session_id, page, page_size):
|
||||
raise Exception("暂不支持获得 Dify 的历史消息记录。")
|
||||
|
||||
async def terminate(self):
|
||||
await self.api_client.close()
|
||||
@@ -111,9 +111,9 @@ class ProviderGoogleGenAI(Provider):
|
||||
f"检测到 Key 异常({e.message}),且已没有可用的 Key。 当前 Key: {self.chosen_api_key[:12]}...",
|
||||
)
|
||||
raise Exception("达到了 Gemini 速率限制, 请稍后再试...")
|
||||
logger.error(
|
||||
f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}",
|
||||
)
|
||||
# logger.error(
|
||||
# f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}",
|
||||
# )
|
||||
raise e
|
||||
|
||||
async def _prepare_query_config(
|
||||
|
||||
@@ -433,7 +433,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
)
|
||||
payloads.pop("tools", None)
|
||||
return False, chosen_key, available_api_keys, payloads, context_query, None
|
||||
logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
|
||||
# logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
|
||||
|
||||
if "tool" in str(e).lower() and "support" in str(e).lower():
|
||||
logger.error("疑似该模型不支持函数调用工具调用。请输入 /tool off_all")
|
||||
|
||||
@@ -6,7 +6,10 @@ from openai import NOT_GIVEN, AsyncOpenAI
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.io import download_file
|
||||
from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
|
||||
from astrbot.core.utils.tencent_record_helper import (
|
||||
convert_to_pcm_wav,
|
||||
tencent_silk_to_wav,
|
||||
)
|
||||
|
||||
from ..entities import ProviderType
|
||||
from ..provider import STTProvider
|
||||
@@ -35,18 +38,28 @@ class ProviderOpenAIWhisperAPI(STTProvider):
|
||||
|
||||
self.set_model(provider_config.get("model"))
|
||||
|
||||
async def _is_silk_file(self, file_path):
|
||||
async def _get_audio_format(self, file_path):
|
||||
# 定义要检测的头部字节
|
||||
silk_header = b"SILK"
|
||||
with open(file_path, "rb") as f:
|
||||
file_header = f.read(8)
|
||||
amr_header = b"#!AMR"
|
||||
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
file_header = f.read(8)
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
if silk_header in file_header:
|
||||
return True
|
||||
return False
|
||||
return "silk"
|
||||
|
||||
if amr_header in file_header:
|
||||
return "amr"
|
||||
return None
|
||||
|
||||
async def get_text(self, audio_url: str) -> str:
|
||||
"""Only supports mp3, mp4, mpeg, m4a, wav, webm"""
|
||||
is_tencent = False
|
||||
output_path = None
|
||||
|
||||
if audio_url.startswith("http"):
|
||||
if "multimedia.nt.qq.com.cn" in audio_url:
|
||||
@@ -62,16 +75,35 @@ class ProviderOpenAIWhisperAPI(STTProvider):
|
||||
raise FileNotFoundError(f"文件不存在: {audio_url}")
|
||||
|
||||
if audio_url.endswith(".amr") or audio_url.endswith(".silk") or is_tencent:
|
||||
is_silk = await self._is_silk_file(audio_url)
|
||||
if is_silk:
|
||||
logger.info("Converting silk file to wav ...")
|
||||
file_format = await self._get_audio_format(audio_url)
|
||||
|
||||
# 判断是否需要转换
|
||||
if file_format in ["silk", "amr"]:
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
output_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".wav")
|
||||
await tencent_silk_to_wav(audio_url, output_path)
|
||||
|
||||
if file_format == "silk":
|
||||
logger.info(
|
||||
"Converting silk file to wav using tencent_silk_to_wav..."
|
||||
)
|
||||
await tencent_silk_to_wav(audio_url, output_path)
|
||||
elif file_format == "amr":
|
||||
logger.info(
|
||||
"Converting amr file to wav using convert_to_pcm_wav..."
|
||||
)
|
||||
await convert_to_pcm_wav(audio_url, output_path)
|
||||
|
||||
audio_url = output_path
|
||||
|
||||
result = await self.client.audio.transcriptions.create(
|
||||
model=self.model_name,
|
||||
file=open(audio_url, "rb"),
|
||||
file=("audio.wav", open(audio_url, "rb")),
|
||||
)
|
||||
|
||||
# remove temp file
|
||||
if output_path and os.path.exists(output_path):
|
||||
try:
|
||||
os.remove(audio_url)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to remove temp file {audio_url}: {e}")
|
||||
return result.text
|
||||
|
||||
@@ -14,7 +14,6 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.conversation_mgr import ConversationManager
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
||||
from astrbot.core.memory.memory_manager import MemoryManager
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.persona_mgr import PersonaManager
|
||||
from astrbot.core.platform import Platform
|
||||
@@ -66,7 +65,6 @@ class Context:
|
||||
persona_manager: PersonaManager,
|
||||
astrbot_config_mgr: AstrBotConfigManager,
|
||||
knowledge_base_manager: KnowledgeBaseManager,
|
||||
memory_manager: MemoryManager,
|
||||
):
|
||||
self._event_queue = event_queue
|
||||
"""事件队列。消息平台通过事件队列传递消息事件。"""
|
||||
@@ -81,7 +79,6 @@ class Context:
|
||||
self.persona_manager = persona_manager
|
||||
self.astrbot_config_mgr = astrbot_config_mgr
|
||||
self.kb_manager = knowledge_base_manager
|
||||
self.memory_manager = memory_manager
|
||||
|
||||
async def llm_generate(
|
||||
self,
|
||||
|
||||
@@ -171,110 +171,3 @@ class SessionServiceManager:
|
||||
|
||||
# 如果没有配置,默认为启用(兼容性考虑)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def set_session_status(session_id: str, enabled: bool) -> None:
|
||||
"""设置会话的整体启停状态
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
enabled: True表示启用,False表示禁用
|
||||
|
||||
"""
|
||||
session_config = (
|
||||
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
|
||||
)
|
||||
session_config["session_enabled"] = enabled
|
||||
sp.put(
|
||||
"session_service_config",
|
||||
session_config,
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的整体状态已更新为: {'启用' if enabled else '禁用'}",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def should_process_session_request(event: AstrMessageEvent) -> bool:
|
||||
"""检查是否应该处理会话请求(会话整体启停检查)
|
||||
|
||||
Args:
|
||||
event: 消息事件
|
||||
|
||||
Returns:
|
||||
bool: True表示应该处理,False表示跳过
|
||||
|
||||
"""
|
||||
session_id = event.unified_msg_origin
|
||||
return SessionServiceManager.is_session_enabled(session_id)
|
||||
|
||||
# =============================================================================
|
||||
# 会话命名相关方法
|
||||
# =============================================================================
|
||||
|
||||
@staticmethod
|
||||
def get_session_custom_name(session_id: str) -> str | None:
|
||||
"""获取会话的自定义名称
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
|
||||
Returns:
|
||||
str: 自定义名称,如果没有设置则返回None
|
||||
|
||||
"""
|
||||
session_services = sp.get(
|
||||
"session_service_config",
|
||||
{},
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
)
|
||||
return session_services.get("custom_name")
|
||||
|
||||
@staticmethod
|
||||
def set_session_custom_name(session_id: str, custom_name: str) -> None:
|
||||
"""设置会话的自定义名称
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
custom_name: 自定义名称,可以为空字符串来清除名称
|
||||
|
||||
"""
|
||||
session_config = (
|
||||
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
|
||||
)
|
||||
if custom_name and custom_name.strip():
|
||||
session_config["custom_name"] = custom_name.strip()
|
||||
else:
|
||||
# 如果传入空名称,则删除自定义名称
|
||||
session_config.pop("custom_name", None)
|
||||
sp.put(
|
||||
"session_service_config",
|
||||
session_config,
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的自定义名称已更新为: {custom_name.strip() if custom_name and custom_name.strip() else '已清除'}",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_session_display_name(session_id: str) -> str:
|
||||
"""获取会话的显示名称(优先显示自定义名称,否则显示原始session_id的最后一段)
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
|
||||
Returns:
|
||||
str: 显示名称
|
||||
|
||||
"""
|
||||
custom_name = SessionServiceManager.get_session_custom_name(session_id)
|
||||
if custom_name:
|
||||
return custom_name
|
||||
|
||||
# 如果没有自定义名称,返回session_id的最后一段
|
||||
return session_id.split(":")[2] if session_id.count(":") >= 2 else session_id
|
||||
|
||||
@@ -42,87 +42,6 @@ class SessionPluginManager:
|
||||
# 如果都没有配置,默认为启用(兼容性考虑)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def set_plugin_status_for_session(
|
||||
session_id: str,
|
||||
plugin_name: str,
|
||||
enabled: bool,
|
||||
) -> None:
|
||||
"""设置插件在指定会话中的启停状态
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
plugin_name: 插件名称
|
||||
enabled: True表示启用,False表示禁用
|
||||
|
||||
"""
|
||||
# 获取当前配置
|
||||
session_plugin_config = sp.get(
|
||||
"session_plugin_config",
|
||||
{},
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
)
|
||||
if session_id not in session_plugin_config:
|
||||
session_plugin_config[session_id] = {
|
||||
"enabled_plugins": [],
|
||||
"disabled_plugins": [],
|
||||
}
|
||||
|
||||
session_config = session_plugin_config[session_id]
|
||||
enabled_plugins = session_config.get("enabled_plugins", [])
|
||||
disabled_plugins = session_config.get("disabled_plugins", [])
|
||||
|
||||
if enabled:
|
||||
# 启用插件
|
||||
if plugin_name in disabled_plugins:
|
||||
disabled_plugins.remove(plugin_name)
|
||||
if plugin_name not in enabled_plugins:
|
||||
enabled_plugins.append(plugin_name)
|
||||
else:
|
||||
# 禁用插件
|
||||
if plugin_name in enabled_plugins:
|
||||
enabled_plugins.remove(plugin_name)
|
||||
if plugin_name not in disabled_plugins:
|
||||
disabled_plugins.append(plugin_name)
|
||||
|
||||
# 保存配置
|
||||
session_config["enabled_plugins"] = enabled_plugins
|
||||
session_config["disabled_plugins"] = disabled_plugins
|
||||
session_plugin_config[session_id] = session_config
|
||||
sp.put(
|
||||
"session_plugin_config",
|
||||
session_plugin_config,
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的插件 {plugin_name} 状态已更新为: {'启用' if enabled else '禁用'}",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_session_plugin_config(session_id: str) -> dict[str, list[str]]:
|
||||
"""获取指定会话的插件配置
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
|
||||
Returns:
|
||||
Dict[str, List[str]]: 包含enabled_plugins和disabled_plugins的字典
|
||||
|
||||
"""
|
||||
session_plugin_config = sp.get(
|
||||
"session_plugin_config",
|
||||
{},
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
)
|
||||
return session_plugin_config.get(
|
||||
session_id,
|
||||
{"enabled_plugins": [], "disabled_plugins": []},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def filter_handlers_by_session(event: AstrMessageEvent, handlers: list) -> list:
|
||||
"""根据会话配置过滤处理器列表
|
||||
|
||||
@@ -85,3 +85,22 @@ class UmopConfigRouter:
|
||||
|
||||
self.umop_to_conf_id[umo] = conf_id
|
||||
await self.sp.global_put("umop_config_routing", self.umop_to_conf_id)
|
||||
|
||||
async def delete_route(self, umo: str):
|
||||
"""删除一条路由
|
||||
|
||||
Args:
|
||||
umo (str): 需要删除的 UMO 字符串
|
||||
|
||||
Raises:
|
||||
ValueError: 当 umo 格式不正确时抛出
|
||||
"""
|
||||
|
||||
if not isinstance(umo, str) or len(umo.split(":")) != 3:
|
||||
raise ValueError(
|
||||
"umop must be a string in the format [platform_id]:[message_type]:[session_id], with optional wildcards * or empty for all",
|
||||
)
|
||||
|
||||
if umo in self.umop_to_conf_id:
|
||||
del self.umop_to_conf_id[umo]
|
||||
await self.sp.global_put("umop_config_routing", self.umop_to_conf_id)
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
from pathlib import Path
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
|
||||
async def extract_file_moonshotai(file_path: str, api_key: str) -> str:
|
||||
"""Extract text from a file using Moonshot AI API"""
|
||||
"""
|
||||
Args:
|
||||
file_path: The path to the file to extract text from
|
||||
api_key: The API key to use to extract text from the file
|
||||
Returns:
|
||||
The text extracted from the file
|
||||
"""
|
||||
client = AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
base_url="https://api.moonshot.cn/v1",
|
||||
)
|
||||
file_object = await client.files.create(
|
||||
file=Path(file_path),
|
||||
purpose="file-extract", # type: ignore
|
||||
)
|
||||
return (await client.files.content(file_id=file_object.id)).text
|
||||
@@ -0,0 +1,73 @@
|
||||
import traceback
|
||||
|
||||
from astrbot.core import astrbot_config, logger
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfig, AstrBotConfigManager
|
||||
from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46
|
||||
from astrbot.core.db.migration.migra_webchat_session import migrate_webchat_session
|
||||
|
||||
|
||||
def _migra_agent_runner_configs(conf: AstrBotConfig, ids_map: dict) -> None:
|
||||
"""
|
||||
Migra agent runner configs from provider configs.
|
||||
"""
|
||||
try:
|
||||
default_prov_id = conf["provider_settings"]["default_provider_id"]
|
||||
if default_prov_id in ids_map:
|
||||
conf["provider_settings"]["default_provider_id"] = ""
|
||||
p = ids_map[default_prov_id]
|
||||
if p["type"] == "dify":
|
||||
conf["provider_settings"]["dify_agent_runner_provider_id"] = p["id"]
|
||||
conf["provider_settings"]["agent_runner_type"] = "dify"
|
||||
elif p["type"] == "coze":
|
||||
conf["provider_settings"]["coze_agent_runner_provider_id"] = p["id"]
|
||||
conf["provider_settings"]["agent_runner_type"] = "coze"
|
||||
elif p["type"] == "dashscope":
|
||||
conf["provider_settings"]["dashscope_agent_runner_provider_id"] = p[
|
||||
"id"
|
||||
]
|
||||
conf["provider_settings"]["agent_runner_type"] = "dashscope"
|
||||
conf.save_config()
|
||||
except Exception as e:
|
||||
logger.error(f"Migration for third party agent runner configs failed: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
async def migra(
|
||||
db, astrbot_config_mgr, umop_config_router, acm: AstrBotConfigManager
|
||||
) -> None:
|
||||
"""
|
||||
Stores the migration logic here.
|
||||
btw, i really don't like migration :(
|
||||
"""
|
||||
# 4.5 to 4.6 migration for umop_config_router
|
||||
try:
|
||||
await migrate_45_to_46(astrbot_config_mgr, umop_config_router)
|
||||
except Exception as e:
|
||||
logger.error(f"Migration from version 4.5 to 4.6 failed: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# migration for webchat session
|
||||
try:
|
||||
await migrate_webchat_session(db)
|
||||
except Exception as e:
|
||||
logger.error(f"Migration for webchat session failed: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# migra third party agent runner configs
|
||||
_c = False
|
||||
providers = astrbot_config["provider"]
|
||||
ids_map = {}
|
||||
for prov in providers:
|
||||
type_ = prov.get("type")
|
||||
if type_ in ["dify", "coze", "dashscope"]:
|
||||
prov["provider_type"] = "agent_runner"
|
||||
ids_map[prov["id"]] = {
|
||||
"type": type_,
|
||||
"id": prov["id"],
|
||||
}
|
||||
_c = True
|
||||
if _c:
|
||||
astrbot_config.save_config()
|
||||
|
||||
for conf in acm.confs.values():
|
||||
_migra_agent_runner_configs(conf, ids_map)
|
||||
@@ -40,9 +40,6 @@ class SharedPreferences:
|
||||
else:
|
||||
ret = default
|
||||
return ret
|
||||
raise ValueError(
|
||||
"scope_id and key cannot be None when getting a specific preference.",
|
||||
)
|
||||
|
||||
async def range_get_async(
|
||||
self,
|
||||
@@ -56,30 +53,6 @@ class SharedPreferences:
|
||||
ret = await self.db_helper.get_preferences(scope, scope_id, key)
|
||||
return ret
|
||||
|
||||
@overload
|
||||
async def session_get(
|
||||
self,
|
||||
umo: None,
|
||||
key: str,
|
||||
default: Any = None,
|
||||
) -> list[Preference]: ...
|
||||
|
||||
@overload
|
||||
async def session_get(
|
||||
self,
|
||||
umo: str,
|
||||
key: None,
|
||||
default: Any = None,
|
||||
) -> list[Preference]: ...
|
||||
|
||||
@overload
|
||||
async def session_get(
|
||||
self,
|
||||
umo: None,
|
||||
key: None,
|
||||
default: Any = None,
|
||||
) -> list[Preference]: ...
|
||||
|
||||
async def session_get(
|
||||
self,
|
||||
umo: str | None,
|
||||
@@ -88,7 +61,7 @@ class SharedPreferences:
|
||||
) -> _VT | list[Preference]:
|
||||
"""获取会话范围的偏好设置
|
||||
|
||||
Note: 当 scope_id 或者 key 为 None,时,返回 Preference 列表,其中的 value 属性是一个 dict,value["val"] 为值。
|
||||
Note: 当 umo 或者 key 为 None,时,返回 Preference 列表,其中的 value 属性是一个 dict,value["val"] 为值。
|
||||
"""
|
||||
if umo is None or key is None:
|
||||
return await self.range_get_async("umo", umo, key)
|
||||
|
||||
@@ -36,7 +36,7 @@ async def wav_to_tencent_silk(wav_path: str, output_path: str) -> int:
|
||||
import pilk
|
||||
except (ImportError, ModuleNotFoundError) as _:
|
||||
raise Exception(
|
||||
"pilk 模块未安装,请前往管理面板->控制台->安装pip库 安装 pilk 这个库",
|
||||
"pilk 模块未安装,请前往管理面板->平台日志->安装pip库 安装 pilk 这个库",
|
||||
)
|
||||
# with wave.open(wav_path, 'rb') as wav:
|
||||
# wav_data = wav.readframes(wav.getnframes())
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
from astrbot.core import astrbot_config, logger
|
||||
|
||||
|
||||
def _get_callback_api_base() -> str:
|
||||
try:
|
||||
return astrbot_config.get("callback_api_base", "").rstrip("/")
|
||||
except Exception as e:
|
||||
logger.error(f"获取 callback_api_base 失败: {e!s}")
|
||||
return ""
|
||||
|
||||
|
||||
def _get_dashboard_port() -> int:
|
||||
try:
|
||||
return astrbot_config.get("dashboard", {}).get("port", 6185)
|
||||
except Exception as e:
|
||||
logger.error(f"获取 dashboard 端口失败: {e!s}")
|
||||
return 6185
|
||||
|
||||
|
||||
def log_webhook_info(platform_name: str, webhook_uuid: str):
|
||||
"""打印美观的 webhook 信息日志
|
||||
|
||||
Args:
|
||||
platform_name: 平台名称
|
||||
webhook_uuid: webhook 的 UUID
|
||||
"""
|
||||
|
||||
callback_base = _get_callback_api_base()
|
||||
|
||||
if not callback_base:
|
||||
callback_base = "http(s)://<your-astrbot-domain>"
|
||||
|
||||
if not callback_base.startswith("http"):
|
||||
callback_base = f"http(s)://{callback_base}"
|
||||
|
||||
callback_base = callback_base.rstrip("/")
|
||||
webhook_url = f"{callback_base}/api/platform/webhook/{webhook_uuid}"
|
||||
|
||||
display_log = (
|
||||
"\n====================\n"
|
||||
f"🔗 机器人平台 {platform_name} 已启用统一 Webhook 模式\n"
|
||||
f"📍 Webhook 回调地址: \n"
|
||||
f" ➜ http://<your-ip>:{_get_dashboard_port()}/api/platform/webhook/{webhook_uuid}\n"
|
||||
f" ➜ {webhook_url}\n"
|
||||
"====================\n"
|
||||
)
|
||||
logger.info(display_log)
|
||||
@@ -5,8 +5,8 @@ from .conversation import ConversationRoute
|
||||
from .file import FileRoute
|
||||
from .knowledge_base import KnowledgeBaseRoute
|
||||
from .log import LogRoute
|
||||
from .memory import MemoryRoute
|
||||
from .persona import PersonaRoute
|
||||
from .platform import PlatformRoute
|
||||
from .plugin import PluginRoute
|
||||
from .session_management import SessionManagementRoute
|
||||
from .stat import StatRoute
|
||||
@@ -22,8 +22,8 @@ __all__ = [
|
||||
"FileRoute",
|
||||
"KnowledgeBaseRoute",
|
||||
"LogRoute",
|
||||
"MemoryRoute",
|
||||
"PersonaRoute",
|
||||
"PlatformRoute",
|
||||
"PluginRoute",
|
||||
"SessionManagementRoute",
|
||||
"StatRoute",
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import asyncio
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from quart import Response as QuartResponse
|
||||
from quart import g, make_response, request
|
||||
from quart import g, make_response, request, send_file
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
@@ -44,7 +44,7 @@ class ChatRoute(Route):
|
||||
self.update_session_display_name,
|
||||
),
|
||||
"/chat/get_file": ("GET", self.get_file),
|
||||
"/chat/post_image": ("POST", self.post_image),
|
||||
"/chat/get_attachment": ("GET", self.get_attachment),
|
||||
"/chat/post_file": ("POST", self.post_file),
|
||||
}
|
||||
self.core_lifecycle = core_lifecycle
|
||||
@@ -56,6 +56,7 @@ class ChatRoute(Route):
|
||||
self.conv_mgr = core_lifecycle.conversation_manager
|
||||
self.platform_history_mgr = core_lifecycle.platform_message_history_manager
|
||||
self.db = db
|
||||
self.umop_config_router = core_lifecycle.umop_config_router
|
||||
|
||||
self.running_convs: dict[str, bool] = {}
|
||||
|
||||
@@ -72,52 +73,184 @@ class ChatRoute(Route):
|
||||
if not real_file_path.startswith(real_imgs_dir):
|
||||
return Response().error("Invalid file path").__dict__
|
||||
|
||||
with open(real_file_path, "rb") as f:
|
||||
filename_ext = os.path.splitext(filename)[1].lower()
|
||||
|
||||
if filename_ext == ".wav":
|
||||
return QuartResponse(f.read(), mimetype="audio/wav")
|
||||
if filename_ext[1:] in self.supported_imgs:
|
||||
return QuartResponse(f.read(), mimetype="image/jpeg")
|
||||
return QuartResponse(f.read())
|
||||
filename_ext = os.path.splitext(filename)[1].lower()
|
||||
if filename_ext == ".wav":
|
||||
return await send_file(real_file_path, mimetype="audio/wav")
|
||||
if filename_ext[1:] in self.supported_imgs:
|
||||
return await send_file(real_file_path, mimetype="image/jpeg")
|
||||
return await send_file(real_file_path)
|
||||
|
||||
except (FileNotFoundError, OSError):
|
||||
return Response().error("File access error").__dict__
|
||||
|
||||
async def post_image(self):
|
||||
post_data = await request.files
|
||||
if "file" not in post_data:
|
||||
return Response().error("Missing key: file").__dict__
|
||||
async def get_attachment(self):
|
||||
"""Get attachment file by attachment_id."""
|
||||
attachment_id = request.args.get("attachment_id")
|
||||
if not attachment_id:
|
||||
return Response().error("Missing key: attachment_id").__dict__
|
||||
|
||||
file = post_data["file"]
|
||||
filename = str(uuid.uuid4()) + ".jpg"
|
||||
path = os.path.join(self.imgs_dir, filename)
|
||||
await file.save(path)
|
||||
try:
|
||||
attachment = await self.db.get_attachment_by_id(attachment_id)
|
||||
if not attachment:
|
||||
return Response().error("Attachment not found").__dict__
|
||||
|
||||
return Response().ok(data={"filename": filename}).__dict__
|
||||
file_path = attachment.path
|
||||
real_file_path = os.path.realpath(file_path)
|
||||
|
||||
return await send_file(real_file_path, mimetype=attachment.mime_type)
|
||||
|
||||
except (FileNotFoundError, OSError):
|
||||
return Response().error("File access error").__dict__
|
||||
|
||||
async def post_file(self):
|
||||
"""Upload a file and create an attachment record, return attachment_id."""
|
||||
post_data = await request.files
|
||||
if "file" not in post_data:
|
||||
return Response().error("Missing key: file").__dict__
|
||||
|
||||
file = post_data["file"]
|
||||
filename = f"{uuid.uuid4()!s}"
|
||||
# 通过文件格式判断文件类型
|
||||
if file.content_type.startswith("audio"):
|
||||
filename += ".wav"
|
||||
filename = file.filename or f"{uuid.uuid4()!s}"
|
||||
content_type = file.content_type or "application/octet-stream"
|
||||
|
||||
# 根据 content_type 判断文件类型并添加扩展名
|
||||
if content_type.startswith("image"):
|
||||
attach_type = "image"
|
||||
elif content_type.startswith("audio"):
|
||||
attach_type = "record"
|
||||
elif content_type.startswith("video"):
|
||||
attach_type = "video"
|
||||
else:
|
||||
attach_type = "file"
|
||||
|
||||
path = os.path.join(self.imgs_dir, filename)
|
||||
await file.save(path)
|
||||
|
||||
return Response().ok(data={"filename": filename}).__dict__
|
||||
# 创建 attachment 记录
|
||||
attachment = await self.db.insert_attachment(
|
||||
path=path,
|
||||
type=attach_type,
|
||||
mime_type=content_type,
|
||||
)
|
||||
|
||||
if not attachment:
|
||||
return Response().error("Failed to create attachment").__dict__
|
||||
|
||||
filename = os.path.basename(attachment.path)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
data={
|
||||
"attachment_id": attachment.attachment_id,
|
||||
"filename": filename,
|
||||
"type": attach_type,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
async def _build_user_message_parts(self, message: str | list) -> list[dict]:
|
||||
"""构建用户消息的部分列表
|
||||
|
||||
Args:
|
||||
message: 文本消息 (str) 或消息段列表 (list)
|
||||
"""
|
||||
parts = []
|
||||
|
||||
if isinstance(message, list):
|
||||
for part in message:
|
||||
part_type = part.get("type")
|
||||
if part_type == "plain":
|
||||
parts.append({"type": "plain", "text": part.get("text", "")})
|
||||
elif part_type == "reply":
|
||||
parts.append(
|
||||
{"type": "reply", "message_id": part.get("message_id")}
|
||||
)
|
||||
elif attachment_id := part.get("attachment_id"):
|
||||
attachment = await self.db.get_attachment_by_id(attachment_id)
|
||||
if attachment:
|
||||
parts.append(
|
||||
{
|
||||
"type": attachment.type,
|
||||
"attachment_id": attachment.attachment_id,
|
||||
"filename": os.path.basename(attachment.path),
|
||||
"path": attachment.path, # will be deleted
|
||||
}
|
||||
)
|
||||
return parts
|
||||
|
||||
if message:
|
||||
parts.append({"type": "plain", "text": message})
|
||||
|
||||
return parts
|
||||
|
||||
async def _create_attachment_from_file(
|
||||
self, filename: str, attach_type: str
|
||||
) -> dict | None:
|
||||
"""从本地文件创建 attachment 并返回消息部分
|
||||
|
||||
用于处理 bot 回复中的媒体文件
|
||||
|
||||
Args:
|
||||
filename: 存储的文件名
|
||||
attach_type: 附件类型 (image, record, file, video)
|
||||
"""
|
||||
file_path = os.path.join(self.imgs_dir, os.path.basename(filename))
|
||||
if not os.path.exists(file_path):
|
||||
return None
|
||||
|
||||
# guess mime type
|
||||
mime_type, _ = mimetypes.guess_type(filename)
|
||||
if not mime_type:
|
||||
mime_type = "application/octet-stream"
|
||||
|
||||
# insert attachment
|
||||
attachment = await self.db.insert_attachment(
|
||||
path=file_path,
|
||||
type=attach_type,
|
||||
mime_type=mime_type,
|
||||
)
|
||||
if not attachment:
|
||||
return None
|
||||
|
||||
return {
|
||||
"type": attach_type,
|
||||
"attachment_id": attachment.attachment_id,
|
||||
"filename": os.path.basename(file_path),
|
||||
}
|
||||
|
||||
async def _save_bot_message(
|
||||
self,
|
||||
webchat_conv_id: str,
|
||||
text: str,
|
||||
media_parts: list,
|
||||
reasoning: str,
|
||||
):
|
||||
"""保存 bot 消息到历史记录,返回保存的记录"""
|
||||
bot_message_parts = []
|
||||
if text:
|
||||
bot_message_parts.append({"type": "plain", "text": text})
|
||||
bot_message_parts.extend(media_parts)
|
||||
|
||||
new_his = {"type": "bot", "message": bot_message_parts}
|
||||
if reasoning:
|
||||
new_his["reasoning"] = reasoning
|
||||
|
||||
record = await self.platform_history_mgr.insert(
|
||||
platform_id="webchat",
|
||||
user_id=webchat_conv_id,
|
||||
content=new_his,
|
||||
sender_id="bot",
|
||||
sender_name="bot",
|
||||
)
|
||||
return record
|
||||
|
||||
async def chat(self):
|
||||
username = g.get("username", "guest")
|
||||
|
||||
post_data = await request.json
|
||||
if "message" not in post_data and "image_url" not in post_data:
|
||||
return Response().error("Missing key: message or image_url").__dict__
|
||||
if "message" not in post_data and "files" not in post_data:
|
||||
return Response().error("Missing key: message or files").__dict__
|
||||
|
||||
if "session_id" not in post_data and "conversation_id" not in post_data:
|
||||
return (
|
||||
@@ -125,44 +258,40 @@ class ChatRoute(Route):
|
||||
)
|
||||
|
||||
message = post_data["message"]
|
||||
# conversation_id = post_data["conversation_id"]
|
||||
session_id = post_data.get("session_id", post_data.get("conversation_id"))
|
||||
image_url = post_data.get("image_url")
|
||||
audio_url = post_data.get("audio_url")
|
||||
selected_provider = post_data.get("selected_provider")
|
||||
selected_model = post_data.get("selected_model")
|
||||
enable_streaming = post_data.get("enable_streaming", True) # 默认为 True
|
||||
enable_streaming = post_data.get("enable_streaming", True)
|
||||
|
||||
if not message and not image_url and not audio_url:
|
||||
return (
|
||||
Response()
|
||||
.error("Message and image_url and audio_url are empty")
|
||||
.__dict__
|
||||
# 检查消息是否为空
|
||||
if isinstance(message, list):
|
||||
has_content = any(
|
||||
part.get("type") in ("plain", "image", "record", "file", "video")
|
||||
for part in message
|
||||
)
|
||||
if not has_content:
|
||||
return (
|
||||
Response()
|
||||
.error("Message content is empty (reply only is not allowed)")
|
||||
.__dict__
|
||||
)
|
||||
elif not message:
|
||||
return Response().error("Message are both empty").__dict__
|
||||
|
||||
if not session_id:
|
||||
return Response().error("session_id is empty").__dict__
|
||||
|
||||
# 追加用户消息
|
||||
webchat_conv_id = session_id
|
||||
|
||||
# 获取会话特定的队列
|
||||
back_queue = webchat_queue_mgr.get_or_create_back_queue(webchat_conv_id)
|
||||
|
||||
new_his = {"type": "user", "message": message}
|
||||
if image_url:
|
||||
new_his["image_url"] = image_url
|
||||
if audio_url:
|
||||
new_his["audio_url"] = audio_url
|
||||
await self.platform_history_mgr.insert(
|
||||
platform_id="webchat",
|
||||
user_id=webchat_conv_id,
|
||||
content=new_his,
|
||||
sender_id=username,
|
||||
sender_name=username,
|
||||
)
|
||||
# 构建用户消息段(包含 path 用于传递给 adapter)
|
||||
message_parts = await self._build_user_message_parts(message)
|
||||
|
||||
async def stream():
|
||||
client_disconnected = False
|
||||
accumulated_parts = []
|
||||
accumulated_text = ""
|
||||
accumulated_reasoning = ""
|
||||
|
||||
try:
|
||||
async with track_conversation(self.running_convs, webchat_conv_id):
|
||||
@@ -181,16 +310,17 @@ class ChatRoute(Route):
|
||||
continue
|
||||
|
||||
result_text = result["data"]
|
||||
type = result.get("type")
|
||||
msg_type = result.get("type")
|
||||
streaming = result.get("streaming", False)
|
||||
|
||||
# 发送 SSE 数据
|
||||
try:
|
||||
if not client_disconnected:
|
||||
yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n"
|
||||
except Exception as e:
|
||||
if not client_disconnected:
|
||||
logger.debug(
|
||||
f"[WebChat] 用户 {username} 断开聊天长连接。 {e}",
|
||||
f"[WebChat] 用户 {username} 断开聊天长连接。 {e}"
|
||||
)
|
||||
client_disconnected = True
|
||||
|
||||
@@ -201,24 +331,68 @@ class ChatRoute(Route):
|
||||
logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。")
|
||||
client_disconnected = True
|
||||
|
||||
if type == "end":
|
||||
# 累积消息部分
|
||||
if msg_type == "plain":
|
||||
chain_type = result.get("chain_type", "normal")
|
||||
if chain_type == "reasoning":
|
||||
accumulated_reasoning += result_text
|
||||
else:
|
||||
accumulated_text += result_text
|
||||
elif msg_type == "image":
|
||||
filename = result_text.replace("[IMAGE]", "")
|
||||
part = await self._create_attachment_from_file(
|
||||
filename, "image"
|
||||
)
|
||||
if part:
|
||||
accumulated_parts.append(part)
|
||||
elif msg_type == "record":
|
||||
filename = result_text.replace("[RECORD]", "")
|
||||
part = await self._create_attachment_from_file(
|
||||
filename, "record"
|
||||
)
|
||||
if part:
|
||||
accumulated_parts.append(part)
|
||||
elif msg_type == "file":
|
||||
# 格式: [FILE]filename
|
||||
filename = result_text.replace("[FILE]", "")
|
||||
part = await self._create_attachment_from_file(
|
||||
filename, "file"
|
||||
)
|
||||
if part:
|
||||
accumulated_parts.append(part)
|
||||
|
||||
# 消息结束处理
|
||||
if msg_type == "end":
|
||||
break
|
||||
elif (
|
||||
(streaming and type == "complete")
|
||||
(streaming and msg_type == "complete")
|
||||
or not streaming
|
||||
or type == "break"
|
||||
or msg_type == "break"
|
||||
):
|
||||
# 追加机器人消息
|
||||
new_his = {"type": "bot", "message": result_text}
|
||||
if "reasoning" in result:
|
||||
new_his["reasoning"] = result["reasoning"]
|
||||
await self.platform_history_mgr.insert(
|
||||
platform_id="webchat",
|
||||
user_id=webchat_conv_id,
|
||||
content=new_his,
|
||||
sender_id="bot",
|
||||
sender_name="bot",
|
||||
saved_record = await self._save_bot_message(
|
||||
webchat_conv_id,
|
||||
accumulated_text,
|
||||
accumulated_parts,
|
||||
accumulated_reasoning,
|
||||
)
|
||||
# 发送保存的消息信息给前端
|
||||
if saved_record and not client_disconnected:
|
||||
saved_info = {
|
||||
"type": "message_saved",
|
||||
"data": {
|
||||
"id": saved_record.id,
|
||||
"created_at": saved_record.created_at.astimezone().isoformat(),
|
||||
},
|
||||
}
|
||||
try:
|
||||
yield f"data: {json.dumps(saved_info, ensure_ascii=False)}\n\n"
|
||||
except Exception:
|
||||
pass
|
||||
# 重置累积变量 (对于 break 后的下一段消息)
|
||||
if msg_type == "break":
|
||||
accumulated_parts = []
|
||||
accumulated_text = ""
|
||||
accumulated_reasoning = ""
|
||||
except BaseException as e:
|
||||
logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True)
|
||||
|
||||
@@ -229,9 +403,7 @@ class ChatRoute(Route):
|
||||
username,
|
||||
webchat_conv_id,
|
||||
{
|
||||
"message": message,
|
||||
"image_url": image_url, # list
|
||||
"audio_url": audio_url,
|
||||
"message": message_parts,
|
||||
"selected_provider": selected_provider,
|
||||
"selected_model": selected_model,
|
||||
"enable_streaming": enable_streaming,
|
||||
@@ -239,6 +411,19 @@ class ChatRoute(Route):
|
||||
),
|
||||
)
|
||||
|
||||
message_parts_for_storage = []
|
||||
for part in message_parts:
|
||||
part_copy = {k: v for k, v in part.items() if k != "path"}
|
||||
message_parts_for_storage.append(part_copy)
|
||||
|
||||
await self.platform_history_mgr.insert(
|
||||
platform_id="webchat",
|
||||
user_id=webchat_conv_id,
|
||||
content={"type": "user", "message": message_parts_for_storage},
|
||||
sender_id=username,
|
||||
sender_name=username,
|
||||
)
|
||||
|
||||
response = await make_response(
|
||||
stream(),
|
||||
{
|
||||
@@ -248,7 +433,7 @@ class ChatRoute(Route):
|
||||
"Connection": "keep-alive",
|
||||
},
|
||||
)
|
||||
response.timeout = None # fix SSE auto disconnect issue
|
||||
response.timeout = None # fix SSE auto disconnect issue # pyright: ignore[reportAttributeAccessIssue]
|
||||
return response
|
||||
|
||||
async def delete_webchat_session(self):
|
||||
@@ -266,9 +451,21 @@ class ChatRoute(Route):
|
||||
return Response().error("Permission denied").__dict__
|
||||
|
||||
# 删除该会话下的所有对话
|
||||
unified_msg_origin = f"{session.platform_id}:FriendMessage:{session.platform_id}!{username}!{session_id}"
|
||||
message_type = "GroupMessage" if session.is_group else "FriendMessage"
|
||||
unified_msg_origin = f"{session.platform_id}:{message_type}:{session.platform_id}!{username}!{session_id}"
|
||||
await self.conv_mgr.delete_conversations_by_user_id(unified_msg_origin)
|
||||
|
||||
# 获取消息历史中的所有附件 ID 并删除附件
|
||||
history_list = await self.platform_history_mgr.get(
|
||||
platform_id=session.platform_id,
|
||||
user_id=session_id,
|
||||
page=1,
|
||||
page_size=100000, # 获取足够多的记录
|
||||
)
|
||||
attachment_ids = self._extract_attachment_ids(history_list)
|
||||
if attachment_ids:
|
||||
await self._delete_attachments(attachment_ids)
|
||||
|
||||
# 删除消息历史
|
||||
await self.platform_history_mgr.delete(
|
||||
platform_id=session.platform_id,
|
||||
@@ -276,6 +473,16 @@ class ChatRoute(Route):
|
||||
offset_sec=99999999,
|
||||
)
|
||||
|
||||
# 删除与会话关联的配置路由
|
||||
try:
|
||||
await self.umop_config_router.delete_route(unified_msg_origin)
|
||||
except ValueError as exc:
|
||||
logger.warning(
|
||||
"Failed to delete UMO route %s during session cleanup: %s",
|
||||
unified_msg_origin,
|
||||
exc,
|
||||
)
|
||||
|
||||
# 清理队列(仅对 webchat)
|
||||
if session.platform_id == "webchat":
|
||||
webchat_queue_mgr.remove_queues(session_id)
|
||||
@@ -285,6 +492,41 @@ class ChatRoute(Route):
|
||||
|
||||
return Response().ok().__dict__
|
||||
|
||||
def _extract_attachment_ids(self, history_list) -> list[str]:
|
||||
"""从消息历史中提取所有 attachment_id"""
|
||||
attachment_ids = []
|
||||
for history in history_list:
|
||||
content = history.content
|
||||
if not content or "message" not in content:
|
||||
continue
|
||||
message_parts = content.get("message", [])
|
||||
for part in message_parts:
|
||||
if isinstance(part, dict) and "attachment_id" in part:
|
||||
attachment_ids.append(part["attachment_id"])
|
||||
return attachment_ids
|
||||
|
||||
async def _delete_attachments(self, attachment_ids: list[str]):
|
||||
"""删除附件(包括数据库记录和磁盘文件)"""
|
||||
try:
|
||||
attachments = await self.db.get_attachments(attachment_ids)
|
||||
for attachment in attachments:
|
||||
if not os.path.exists(attachment.path):
|
||||
continue
|
||||
try:
|
||||
os.remove(attachment.path)
|
||||
except OSError as e:
|
||||
logger.warning(
|
||||
f"Failed to delete attachment file {attachment.path}: {e}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get attachments: {e}")
|
||||
|
||||
# 批量删除数据库记录
|
||||
try:
|
||||
await self.db.delete_attachments(attachment_ids)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete attachments: {e}")
|
||||
|
||||
async def new_session(self):
|
||||
"""Create a new Platform session (default: webchat)."""
|
||||
username = g.get("username", "guest")
|
||||
|
||||
@@ -2,6 +2,7 @@ import asyncio
|
||||
import inspect
|
||||
import os
|
||||
import traceback
|
||||
import uuid
|
||||
|
||||
from quart import request
|
||||
|
||||
@@ -13,15 +14,14 @@ from astrbot.core.config.default import (
|
||||
CONFIG_METADATA_3_SYSTEM,
|
||||
DEFAULT_CONFIG,
|
||||
DEFAULT_VALUE_MAP,
|
||||
WEBHOOK_SUPPORTED_PLATFORMS,
|
||||
)
|
||||
from astrbot.core.config.i18n_utils import ConfigMetadataI18n
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.platform.register import platform_cls_map, platform_registry
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
from astrbot.core.provider.provider import RerankProvider
|
||||
from astrbot.core.provider.register import provider_registry
|
||||
from astrbot.core.star.star import star_registry
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_path
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
@@ -133,7 +133,9 @@ def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False)
|
||||
is_core,
|
||||
)
|
||||
else:
|
||||
errors, post_config = validate_config(post_config, config.schema, is_core)
|
||||
errors, post_config = validate_config(
|
||||
post_config, getattr(config, "schema", {}), is_core
|
||||
)
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.warning(f"验证配置时出现异常: {e}")
|
||||
@@ -247,11 +249,8 @@ class ConfigRoute(Route):
|
||||
|
||||
async def get_default_config(self):
|
||||
"""获取默认配置文件"""
|
||||
return (
|
||||
Response()
|
||||
.ok({"config": DEFAULT_CONFIG, "metadata": CONFIG_METADATA_3})
|
||||
.__dict__
|
||||
)
|
||||
metadata = ConfigMetadataI18n.convert_to_i18n_keys(CONFIG_METADATA_3)
|
||||
return Response().ok({"config": DEFAULT_CONFIG, "metadata": metadata}).__dict__
|
||||
|
||||
async def get_abconf_list(self):
|
||||
"""获取所有 AstrBot 配置文件的列表"""
|
||||
@@ -282,17 +281,15 @@ class ConfigRoute(Route):
|
||||
try:
|
||||
if system_config:
|
||||
abconf = self.acm.confs["default"]
|
||||
return (
|
||||
Response()
|
||||
.ok({"config": abconf, "metadata": CONFIG_METADATA_3_SYSTEM})
|
||||
.__dict__
|
||||
metadata = ConfigMetadataI18n.convert_to_i18n_keys(
|
||||
CONFIG_METADATA_3_SYSTEM
|
||||
)
|
||||
return Response().ok({"config": abconf, "metadata": metadata}).__dict__
|
||||
if abconf_id is None:
|
||||
raise ValueError("abconf_id cannot be None")
|
||||
abconf = self.acm.confs[abconf_id]
|
||||
return (
|
||||
Response()
|
||||
.ok({"config": abconf, "metadata": CONFIG_METADATA_3})
|
||||
.__dict__
|
||||
)
|
||||
metadata = ConfigMetadataI18n.convert_to_i18n_keys(CONFIG_METADATA_3)
|
||||
return Response().ok({"config": abconf, "metadata": metadata}).__dict__
|
||||
except ValueError as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
@@ -358,169 +355,20 @@ class ConfigRoute(Route):
|
||||
f"Attempting to check provider: {status_info['name']} (ID: {status_info['id']}, Type: {status_info['type']}, Model: {status_info['model']})",
|
||||
)
|
||||
|
||||
if provider_capability_type == ProviderType.CHAT_COMPLETION:
|
||||
try:
|
||||
logger.debug(f"Sending 'Ping' to provider: {status_info['name']}")
|
||||
response = await asyncio.wait_for(
|
||||
provider.text_chat(prompt="REPLY `PONG` ONLY"),
|
||||
timeout=45.0,
|
||||
)
|
||||
logger.debug(
|
||||
f"Received response from {status_info['name']}: {response}",
|
||||
)
|
||||
if response is not None:
|
||||
status_info["status"] = "available"
|
||||
response_text_snippet = ""
|
||||
if (
|
||||
hasattr(response, "completion_text")
|
||||
and response.completion_text
|
||||
):
|
||||
response_text_snippet = (
|
||||
response.completion_text[:70] + "..."
|
||||
if len(response.completion_text) > 70
|
||||
else response.completion_text
|
||||
)
|
||||
elif hasattr(response, "result_chain") and response.result_chain:
|
||||
try:
|
||||
response_text_snippet = (
|
||||
response.result_chain.get_plain_text()[:70] + "..."
|
||||
if len(response.result_chain.get_plain_text()) > 70
|
||||
else response.result_chain.get_plain_text()
|
||||
)
|
||||
except Exception as _:
|
||||
pass
|
||||
logger.info(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{response_text_snippet}'",
|
||||
)
|
||||
else:
|
||||
status_info["error"] = (
|
||||
"Test call returned None, but expected an LLMResponse object."
|
||||
)
|
||||
logger.warning(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None.",
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
status_info["error"] = (
|
||||
"Connection timed out after 45 seconds during test call."
|
||||
)
|
||||
logger.warning(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) timed out.",
|
||||
)
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
status_info["error"] = error_message
|
||||
logger.warning(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}",
|
||||
)
|
||||
logger.debug(
|
||||
f"Traceback for {status_info['name']}:\n{traceback.format_exc()}",
|
||||
)
|
||||
|
||||
elif provider_capability_type == ProviderType.EMBEDDING:
|
||||
try:
|
||||
# For embedding, we can call the get_embedding method with a short prompt.
|
||||
embedding_result = await provider.get_embedding("health_check")
|
||||
if isinstance(embedding_result, list) and (
|
||||
not embedding_result or isinstance(embedding_result[0], float)
|
||||
):
|
||||
status_info["status"] = "available"
|
||||
else:
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = (
|
||||
f"Embedding test failed: unexpected result type {type(embedding_result)}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error testing embedding provider {provider_name}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = f"Embedding test failed: {e!s}"
|
||||
|
||||
elif provider_capability_type == ProviderType.TEXT_TO_SPEECH:
|
||||
try:
|
||||
# For TTS, we can call the get_audio method with a short prompt.
|
||||
audio_result = await provider.get_audio("你好")
|
||||
if isinstance(audio_result, str) and audio_result:
|
||||
status_info["status"] = "available"
|
||||
else:
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = (
|
||||
f"TTS test failed: unexpected result type {type(audio_result)}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error testing TTS provider {provider_name}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = f"TTS test failed: {e!s}"
|
||||
elif provider_capability_type == ProviderType.SPEECH_TO_TEXT:
|
||||
try:
|
||||
logger.debug(
|
||||
f"Sending health check audio to provider: {status_info['name']}",
|
||||
)
|
||||
sample_audio_path = os.path.join(
|
||||
get_astrbot_path(),
|
||||
"samples",
|
||||
"stt_health_check.wav",
|
||||
)
|
||||
if not os.path.exists(sample_audio_path):
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = (
|
||||
"STT test failed: sample audio file not found."
|
||||
)
|
||||
logger.warning(
|
||||
f"STT test for {status_info['name']} failed: sample audio file not found at {sample_audio_path}",
|
||||
)
|
||||
else:
|
||||
text_result = await provider.get_text(sample_audio_path)
|
||||
if isinstance(text_result, str) and text_result:
|
||||
status_info["status"] = "available"
|
||||
snippet = (
|
||||
text_result[:70] + "..."
|
||||
if len(text_result) > 70
|
||||
else text_result
|
||||
)
|
||||
logger.info(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{snippet}'",
|
||||
)
|
||||
else:
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = (
|
||||
f"STT test failed: unexpected result type {type(text_result)}"
|
||||
)
|
||||
logger.warning(
|
||||
f"STT test for {status_info['name']} failed: unexpected result type {type(text_result)}",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error testing STT provider {provider_name}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = f"STT test failed: {e!s}"
|
||||
elif provider_capability_type == ProviderType.RERANK:
|
||||
try:
|
||||
assert isinstance(provider, RerankProvider)
|
||||
await provider.rerank("Apple", documents=["apple", "banana"])
|
||||
status_info["status"] = "available"
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error testing rerank provider {provider_name}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = f"Rerank test failed: {e!s}"
|
||||
|
||||
else:
|
||||
logger.debug(
|
||||
f"Provider {provider_name} is not a Chat Completion or Embedding provider. Marking as available without test. Meta: {meta}",
|
||||
)
|
||||
try:
|
||||
await provider.test()
|
||||
status_info["status"] = "available"
|
||||
status_info["error"] = (
|
||||
"This provider type is not tested and is assumed to be available."
|
||||
logger.info(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) is available.",
|
||||
)
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
status_info["error"] = error_message
|
||||
logger.warning(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}",
|
||||
)
|
||||
logger.debug(
|
||||
f"Traceback for {status_info['name']}:\n{traceback.format_exc()}",
|
||||
)
|
||||
|
||||
return status_info
|
||||
@@ -598,9 +446,15 @@ class ConfigRoute(Route):
|
||||
return Response().error("缺少参数 provider_id").__dict__
|
||||
|
||||
prov_mgr = self.core_lifecycle.provider_manager
|
||||
provider: Provider | None = prov_mgr.inst_map.get(provider_id, None)
|
||||
provider = prov_mgr.inst_map.get(provider_id, None)
|
||||
if not provider:
|
||||
return Response().error(f"未找到 ID 为 {provider_id} 的提供商").__dict__
|
||||
if not isinstance(provider, Provider):
|
||||
return (
|
||||
Response()
|
||||
.error(f"提供商 {provider_id} 类型不支持获取模型列表")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
try:
|
||||
models = await provider.get_models()
|
||||
@@ -703,6 +557,15 @@ class ConfigRoute(Route):
|
||||
|
||||
async def post_new_platform(self):
|
||||
new_platform_config = await request.json
|
||||
|
||||
# 如果是支持统一 webhook 模式的平台,且启用了统一 webhook 模式,自动生成 webhook_uuid
|
||||
platform_type = new_platform_config.get("type", "")
|
||||
if platform_type in WEBHOOK_SUPPORTED_PLATFORMS:
|
||||
if new_platform_config.get("unified_webhook_mode", False):
|
||||
# 如果没有 webhook_uuid 或为空,自动生成
|
||||
if not new_platform_config.get("webhook_uuid"):
|
||||
new_platform_config["webhook_uuid"] = uuid.uuid4().hex[:16]
|
||||
|
||||
self.config["platform"].append(new_platform_config)
|
||||
try:
|
||||
save_config(self.config, self.config, is_core=True)
|
||||
@@ -732,6 +595,14 @@ class ConfigRoute(Route):
|
||||
if not platform_id or not new_config:
|
||||
return Response().error("参数错误").__dict__
|
||||
|
||||
# 如果是支持统一 webhook 模式的平台,且启用了统一 webhook 模式,确保有 webhook_uuid
|
||||
platform_type = new_config.get("type", "")
|
||||
if platform_type in WEBHOOK_SUPPORTED_PLATFORMS:
|
||||
if new_config.get("unified_webhook_mode", False):
|
||||
# 如果没有 webhook_uuid 或为空,自动生成
|
||||
if not new_config.get("webhook_uuid"):
|
||||
new_config["webhook_uuid"] = uuid.uuid4().hex
|
||||
|
||||
for i, platform in enumerate(self.config["platform"]):
|
||||
if platform["id"] == platform_id:
|
||||
self.config["platform"][i] = new_config
|
||||
|
||||
@@ -60,10 +60,6 @@ class KnowledgeBaseRoute(Route):
|
||||
# "/kb/media/delete": ("POST", self.delete_media),
|
||||
# 检索
|
||||
"/kb/retrieve": ("POST", self.retrieve),
|
||||
# 会话知识库配置
|
||||
"/kb/session/config/get": ("GET", self.get_session_kb_config),
|
||||
"/kb/session/config/set": ("POST", self.set_session_kb_config),
|
||||
"/kb/session/config/delete": ("POST", self.delete_session_kb_config),
|
||||
}
|
||||
self.register_routes()
|
||||
|
||||
@@ -278,7 +274,7 @@ class KnowledgeBaseRoute(Route):
|
||||
except Exception as e:
|
||||
return (
|
||||
Response()
|
||||
.error(f"测试重排序模型失败: {e!s},请检查控制台日志输出。")
|
||||
.error(f"测试重排序模型失败: {e!s},请检查平台日志输出。")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
@@ -920,158 +916,6 @@ class KnowledgeBaseRoute(Route):
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"检索失败: {e!s}").__dict__
|
||||
|
||||
# ===== 会话知识库配置 API =====
|
||||
|
||||
async def get_session_kb_config(self):
|
||||
"""获取会话的知识库配置
|
||||
|
||||
Query 参数:
|
||||
- session_id: 会话 ID (必填)
|
||||
|
||||
返回:
|
||||
- kb_ids: 知识库 ID 列表
|
||||
- top_k: 返回结果数量
|
||||
- enable_rerank: 是否启用重排序
|
||||
"""
|
||||
try:
|
||||
from astrbot.core import sp
|
||||
|
||||
session_id = request.args.get("session_id")
|
||||
|
||||
if not session_id:
|
||||
return Response().error("缺少参数 session_id").__dict__
|
||||
|
||||
# 从 SharedPreferences 获取配置
|
||||
config = await sp.session_get(session_id, "kb_config", default={})
|
||||
|
||||
logger.debug(f"[KB配置] 读取到配置: session_id={session_id}")
|
||||
|
||||
# 如果没有配置,返回默认值
|
||||
if not config:
|
||||
config = {"kb_ids": [], "top_k": 5, "enable_rerank": True}
|
||||
|
||||
return Response().ok(config).__dict__
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[KB配置] 获取配置时出错: {e}", exc_info=True)
|
||||
return Response().error(f"获取会话知识库配置失败: {e!s}").__dict__
|
||||
|
||||
async def set_session_kb_config(self):
|
||||
"""设置会话的知识库配置
|
||||
|
||||
Body:
|
||||
- scope: 配置范围 (目前只支持 "session")
|
||||
- scope_id: 会话 ID (必填)
|
||||
- kb_ids: 知识库 ID 列表 (必填)
|
||||
- top_k: 返回结果数量 (可选, 默认 5)
|
||||
- enable_rerank: 是否启用重排序 (可选, 默认 true)
|
||||
"""
|
||||
try:
|
||||
from astrbot.core import sp
|
||||
|
||||
data = await request.json
|
||||
|
||||
scope = data.get("scope")
|
||||
scope_id = data.get("scope_id")
|
||||
kb_ids = data.get("kb_ids", [])
|
||||
top_k = data.get("top_k", 5)
|
||||
enable_rerank = data.get("enable_rerank", True)
|
||||
|
||||
# 验证参数
|
||||
if scope != "session":
|
||||
return Response().error("目前仅支持 session 范围的配置").__dict__
|
||||
|
||||
if not scope_id:
|
||||
return Response().error("缺少参数 scope_id").__dict__
|
||||
|
||||
if not isinstance(kb_ids, list):
|
||||
return Response().error("kb_ids 必须是列表").__dict__
|
||||
|
||||
# 验证知识库是否存在
|
||||
kb_mgr = self._get_kb_manager()
|
||||
invalid_ids = []
|
||||
valid_ids = []
|
||||
for kb_id in kb_ids:
|
||||
kb_helper = await kb_mgr.get_kb(kb_id)
|
||||
if kb_helper:
|
||||
valid_ids.append(kb_id)
|
||||
else:
|
||||
invalid_ids.append(kb_id)
|
||||
logger.warning(f"[KB配置] 知识库不存在: {kb_id}")
|
||||
|
||||
if invalid_ids:
|
||||
logger.warning(f"[KB配置] 以下知识库ID无效: {invalid_ids}")
|
||||
|
||||
# 允许保存空列表,表示明确不使用任何知识库
|
||||
if kb_ids and not valid_ids:
|
||||
# 只有当用户提供了 kb_ids 但全部无效时才报错
|
||||
return Response().error(f"所有提供的知识库ID都无效: {kb_ids}").__dict__
|
||||
|
||||
# 如果 kb_ids 为空列表,表示用户想清空配置
|
||||
if not kb_ids:
|
||||
valid_ids = []
|
||||
|
||||
# 构建配置对象(只保存有效的ID)
|
||||
config = {
|
||||
"kb_ids": valid_ids,
|
||||
"top_k": top_k,
|
||||
"enable_rerank": enable_rerank,
|
||||
}
|
||||
|
||||
# 保存到 SharedPreferences
|
||||
await sp.session_put(scope_id, "kb_config", config)
|
||||
|
||||
# 立即验证是否保存成功
|
||||
verify_config = await sp.session_get(scope_id, "kb_config", default={})
|
||||
|
||||
if verify_config == config:
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{"valid_ids": valid_ids, "invalid_ids": invalid_ids},
|
||||
"保存知识库配置成功",
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
logger.error("[KB配置] 配置保存失败,验证不匹配")
|
||||
return Response().error("配置保存失败").__dict__
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[KB配置] 设置配置时出错: {e}", exc_info=True)
|
||||
return Response().error(f"设置会话知识库配置失败: {e!s}").__dict__
|
||||
|
||||
async def delete_session_kb_config(self):
|
||||
"""删除会话的知识库配置
|
||||
|
||||
Body:
|
||||
- scope: 配置范围 (目前只支持 "session")
|
||||
- scope_id: 会话 ID (必填)
|
||||
"""
|
||||
try:
|
||||
from astrbot.core import sp
|
||||
|
||||
data = await request.json
|
||||
|
||||
scope = data.get("scope")
|
||||
scope_id = data.get("scope_id")
|
||||
|
||||
# 验证参数
|
||||
if scope != "session":
|
||||
return Response().error("目前仅支持 session 范围的配置").__dict__
|
||||
|
||||
if not scope_id:
|
||||
return Response().error("缺少参数 scope_id").__dict__
|
||||
|
||||
# 从 SharedPreferences 删除配置
|
||||
await sp.session_remove(scope_id, "kb_config")
|
||||
|
||||
return Response().ok(message="删除知识库配置成功").__dict__
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"删除会话知识库配置失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"删除会话知识库配置失败: {e!s}").__dict__
|
||||
|
||||
async def upload_document_from_url(self):
|
||||
"""从 URL 上传文档
|
||||
|
||||
|
||||
@@ -1,174 +0,0 @@
|
||||
"""Memory management API routes"""
|
||||
|
||||
from quart import jsonify, request
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db import BaseDatabase
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
|
||||
class MemoryRoute(Route):
|
||||
"""Memory management routes"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: RouteContext,
|
||||
db: BaseDatabase,
|
||||
core_lifecycle: AstrBotCoreLifecycle,
|
||||
):
|
||||
super().__init__(context)
|
||||
self.db = db
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.memory_manager = core_lifecycle.memory_manager
|
||||
self.provider_manager = core_lifecycle.provider_manager
|
||||
self.routes = [
|
||||
("/memory/status", ("GET", self.get_status)),
|
||||
("/memory/initialize", ("POST", self.initialize)),
|
||||
("/memory/update_merge_llm", ("POST", self.update_merge_llm)),
|
||||
]
|
||||
self.register_routes()
|
||||
|
||||
async def get_status(self):
|
||||
"""Get memory system status"""
|
||||
try:
|
||||
is_initialized = self.memory_manager._initialized
|
||||
|
||||
status_data = {
|
||||
"initialized": is_initialized,
|
||||
"embedding_provider_id": None,
|
||||
"merge_llm_provider_id": None,
|
||||
}
|
||||
|
||||
if is_initialized:
|
||||
# Get embedding provider info
|
||||
if self.memory_manager.embedding_provider:
|
||||
status_data["embedding_provider_id"] = (
|
||||
self.memory_manager.embedding_provider.provider_config["id"]
|
||||
)
|
||||
# Get merge LLM provider info
|
||||
if self.memory_manager.merge_llm_provider:
|
||||
status_data["merge_llm_provider_id"] = (
|
||||
self.memory_manager.merge_llm_provider.provider_config["id"]
|
||||
)
|
||||
|
||||
return jsonify(Response().ok(status_data).__dict__)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get memory status: {e}")
|
||||
return jsonify(Response().error(str(e)).__dict__)
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize memory system with embedding and merge LLM providers"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
embedding_provider_id = data.get("embedding_provider_id")
|
||||
merge_llm_provider_id = data.get("merge_llm_provider_id")
|
||||
|
||||
if not embedding_provider_id or not merge_llm_provider_id:
|
||||
return jsonify(
|
||||
Response()
|
||||
.error(
|
||||
"embedding_provider_id and merge_llm_provider_id are required"
|
||||
)
|
||||
.__dict__,
|
||||
)
|
||||
|
||||
# Check if already initialized
|
||||
if self.memory_manager._initialized:
|
||||
return jsonify(
|
||||
Response()
|
||||
.error(
|
||||
"Memory system already initialized. Embedding provider cannot be changed.",
|
||||
)
|
||||
.__dict__,
|
||||
)
|
||||
|
||||
# Get providers
|
||||
embedding_provider = await self.provider_manager.get_provider_by_id(
|
||||
embedding_provider_id,
|
||||
)
|
||||
merge_llm_provider = await self.provider_manager.get_provider_by_id(
|
||||
merge_llm_provider_id,
|
||||
)
|
||||
|
||||
if not embedding_provider:
|
||||
return jsonify(
|
||||
Response()
|
||||
.error(f"Embedding provider {embedding_provider_id} not found")
|
||||
.__dict__,
|
||||
)
|
||||
|
||||
if not merge_llm_provider:
|
||||
return jsonify(
|
||||
Response()
|
||||
.error(f"Merge LLM provider {merge_llm_provider_id} not found")
|
||||
.__dict__,
|
||||
)
|
||||
|
||||
# Initialize memory manager
|
||||
await self.memory_manager.initialize(
|
||||
embedding_provider=embedding_provider,
|
||||
merge_llm_provider=merge_llm_provider,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Memory system initialized with embedding: {embedding_provider_id}, "
|
||||
f"merge LLM: {merge_llm_provider_id}",
|
||||
)
|
||||
|
||||
return jsonify(
|
||||
Response()
|
||||
.ok({"message": "Memory system initialized successfully"})
|
||||
.__dict__,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize memory system: {e}")
|
||||
return jsonify(Response().error(str(e)).__dict__)
|
||||
|
||||
async def update_merge_llm(self):
|
||||
"""Update merge LLM provider (only allowed after initialization)"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
merge_llm_provider_id = data.get("merge_llm_provider_id")
|
||||
|
||||
if not merge_llm_provider_id:
|
||||
return jsonify(
|
||||
Response().error("merge_llm_provider_id is required").__dict__,
|
||||
)
|
||||
|
||||
# Check if initialized
|
||||
if not self.memory_manager._initialized:
|
||||
return jsonify(
|
||||
Response()
|
||||
.error("Memory system not initialized. Please initialize first.")
|
||||
.__dict__,
|
||||
)
|
||||
|
||||
# Get new merge LLM provider
|
||||
merge_llm_provider = await self.provider_manager.get_provider_by_id(
|
||||
merge_llm_provider_id,
|
||||
)
|
||||
|
||||
if not merge_llm_provider:
|
||||
return jsonify(
|
||||
Response()
|
||||
.error(f"Merge LLM provider {merge_llm_provider_id} not found")
|
||||
.__dict__,
|
||||
)
|
||||
|
||||
# Update merge LLM provider
|
||||
self.memory_manager.merge_llm_provider = merge_llm_provider
|
||||
|
||||
logger.info(f"Updated merge LLM provider to: {merge_llm_provider_id}")
|
||||
|
||||
return jsonify(
|
||||
Response()
|
||||
.ok({"message": "Merge LLM provider updated successfully"})
|
||||
.__dict__,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update merge LLM provider: {e}")
|
||||
return jsonify(Response().error(str(e)).__dict__)
|
||||
@@ -0,0 +1,100 @@
|
||||
"""统一 Webhook 路由
|
||||
|
||||
提供统一的 webhook 回调入口,支持多个平台使用同一端口接收回调。
|
||||
"""
|
||||
|
||||
from quart import request
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.platform import Platform
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
|
||||
class PlatformRoute(Route):
|
||||
"""统一 Webhook 路由"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: RouteContext,
|
||||
core_lifecycle: AstrBotCoreLifecycle,
|
||||
) -> None:
|
||||
super().__init__(context)
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.platform_manager = core_lifecycle.platform_manager
|
||||
|
||||
self._register_webhook_routes()
|
||||
|
||||
def _register_webhook_routes(self):
|
||||
"""注册 webhook 路由"""
|
||||
# 统一 webhook 入口,支持 GET 和 POST
|
||||
self.app.add_url_rule(
|
||||
"/api/platform/webhook/<webhook_uuid>",
|
||||
view_func=self.unified_webhook_callback,
|
||||
methods=["GET", "POST"],
|
||||
)
|
||||
|
||||
# 平台统计信息接口
|
||||
self.app.add_url_rule(
|
||||
"/api/platform/stats",
|
||||
view_func=self.get_platform_stats,
|
||||
methods=["GET"],
|
||||
)
|
||||
|
||||
async def unified_webhook_callback(self, webhook_uuid: str):
|
||||
"""统一 webhook 回调入口
|
||||
|
||||
Args:
|
||||
webhook_uuid: 平台配置中的 webhook_uuid
|
||||
|
||||
Returns:
|
||||
根据平台适配器返回相应的响应
|
||||
"""
|
||||
# 根据 webhook_uuid 查找对应的平台
|
||||
platform_adapter = self._find_platform_by_uuid(webhook_uuid)
|
||||
|
||||
if not platform_adapter:
|
||||
logger.warning(f"未找到 webhook_uuid 为 {webhook_uuid} 的平台")
|
||||
return Response().error("未找到对应平台").__dict__, 404
|
||||
|
||||
# 调用平台适配器的 webhook_callback 方法
|
||||
try:
|
||||
result = await platform_adapter.webhook_callback(request)
|
||||
return result
|
||||
except NotImplementedError:
|
||||
logger.error(
|
||||
f"平台 {platform_adapter.meta().name} 未实现 webhook_callback 方法"
|
||||
)
|
||||
return Response().error("平台未支持统一 Webhook 模式").__dict__, 500
|
||||
except Exception as e:
|
||||
logger.error(f"处理 webhook 回调时发生错误: {e}", exc_info=True)
|
||||
return Response().error("处理回调失败").__dict__, 500
|
||||
|
||||
def _find_platform_by_uuid(self, webhook_uuid: str) -> Platform | None:
|
||||
"""根据 webhook_uuid 查找对应的平台适配器
|
||||
|
||||
Args:
|
||||
webhook_uuid: webhook UUID
|
||||
|
||||
Returns:
|
||||
平台适配器实例,未找到则返回 None
|
||||
"""
|
||||
for platform in self.platform_manager.platform_insts:
|
||||
if platform.config.get("webhook_uuid") == webhook_uuid:
|
||||
if platform.config.get("unified_webhook_mode", False):
|
||||
return platform
|
||||
return None
|
||||
|
||||
async def get_platform_stats(self):
|
||||
"""获取所有平台的统计信息
|
||||
|
||||
Returns:
|
||||
包含平台统计信息的响应
|
||||
"""
|
||||
try:
|
||||
stats = self.platform_manager.get_all_stats()
|
||||
return Response().ok(stats).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"获取平台统计信息失败: {e}", exc_info=True)
|
||||
return Response().error(f"获取统计信息失败: {e}").__dict__, 500
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import ssl
|
||||
@@ -19,6 +20,10 @@ from astrbot.core.star.star_manager import PluginManager
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
PLUGIN_UPDATE_CONCURRENCY = (
|
||||
3 # limit concurrent updates to avoid overwhelming plugin sources
|
||||
)
|
||||
|
||||
|
||||
class PluginRoute(Route):
|
||||
def __init__(
|
||||
@@ -33,6 +38,7 @@ class PluginRoute(Route):
|
||||
"/plugin/install": ("POST", self.install_plugin),
|
||||
"/plugin/install-upload": ("POST", self.install_plugin_upload),
|
||||
"/plugin/update": ("POST", self.update_plugin),
|
||||
"/plugin/update-all": ("POST", self.update_all_plugins),
|
||||
"/plugin/uninstall": ("POST", self.uninstall_plugin),
|
||||
"/plugin/market_list": ("GET", self.get_online_plugins),
|
||||
"/plugin/off": ("POST", self.off_plugin),
|
||||
@@ -63,7 +69,7 @@ class PluginRoute(Route):
|
||||
.__dict__
|
||||
)
|
||||
|
||||
data = await request.json
|
||||
data = await request.get_json()
|
||||
plugin_name = data.get("name", None)
|
||||
try:
|
||||
success, message = await self.plugin_manager.reload(plugin_name)
|
||||
@@ -346,7 +352,7 @@ class PluginRoute(Route):
|
||||
.__dict__
|
||||
)
|
||||
|
||||
post_data = await request.json
|
||||
post_data = await request.get_json()
|
||||
repo_url = post_data["url"]
|
||||
|
||||
proxy: str = post_data.get("proxy", None)
|
||||
@@ -393,7 +399,7 @@ class PluginRoute(Route):
|
||||
.__dict__
|
||||
)
|
||||
|
||||
post_data = await request.json
|
||||
post_data = await request.get_json()
|
||||
plugin_name = post_data["name"]
|
||||
delete_config = post_data.get("delete_config", False)
|
||||
delete_data = post_data.get("delete_data", False)
|
||||
@@ -418,7 +424,7 @@ class PluginRoute(Route):
|
||||
.__dict__
|
||||
)
|
||||
|
||||
post_data = await request.json
|
||||
post_data = await request.get_json()
|
||||
plugin_name = post_data["name"]
|
||||
proxy: str = post_data.get("proxy", None)
|
||||
try:
|
||||
@@ -432,6 +438,59 @@ class PluginRoute(Route):
|
||||
logger.error(f"/api/plugin/update: {traceback.format_exc()}")
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def update_all_plugins(self):
|
||||
if DEMO_MODE:
|
||||
return (
|
||||
Response()
|
||||
.error("You are not permitted to do this operation in demo mode")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
post_data = await request.get_json()
|
||||
plugin_names: list[str] = post_data.get("names") or []
|
||||
proxy: str = post_data.get("proxy", "")
|
||||
|
||||
if not isinstance(plugin_names, list) or not plugin_names:
|
||||
return Response().error("插件列表不能为空").__dict__
|
||||
|
||||
results = []
|
||||
sem = asyncio.Semaphore(PLUGIN_UPDATE_CONCURRENCY)
|
||||
|
||||
async def _update_one(name: str):
|
||||
async with sem:
|
||||
try:
|
||||
logger.info(f"批量更新插件 {name}")
|
||||
await self.plugin_manager.update_plugin(name, proxy)
|
||||
return {"name": name, "status": "ok", "message": "更新成功"}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"/api/plugin/update-all: 更新插件 {name} 失败: {traceback.format_exc()}",
|
||||
)
|
||||
return {"name": name, "status": "error", "message": str(e)}
|
||||
|
||||
raw_results = await asyncio.gather(
|
||||
*(_update_one(name) for name in plugin_names),
|
||||
return_exceptions=True,
|
||||
)
|
||||
for name, result in zip(plugin_names, raw_results):
|
||||
if isinstance(result, asyncio.CancelledError):
|
||||
raise result
|
||||
if isinstance(result, BaseException):
|
||||
results.append(
|
||||
{"name": name, "status": "error", "message": str(result)}
|
||||
)
|
||||
else:
|
||||
results.append(result)
|
||||
|
||||
failed = [r for r in results if r["status"] == "error"]
|
||||
message = (
|
||||
"批量更新完成,全部成功。"
|
||||
if not failed
|
||||
else f"批量更新完成,其中 {len(failed)}/{len(results)} 个插件失败。"
|
||||
)
|
||||
|
||||
return Response().ok({"results": results}, message).__dict__
|
||||
|
||||
async def off_plugin(self):
|
||||
if DEMO_MODE:
|
||||
return (
|
||||
@@ -440,7 +499,7 @@ class PluginRoute(Route):
|
||||
.__dict__
|
||||
)
|
||||
|
||||
post_data = await request.json
|
||||
post_data = await request.get_json()
|
||||
plugin_name = post_data["name"]
|
||||
try:
|
||||
await self.plugin_manager.turn_off_plugin(plugin_name)
|
||||
@@ -458,7 +517,7 @@ class PluginRoute(Route):
|
||||
.__dict__
|
||||
)
|
||||
|
||||
post_data = await request.json
|
||||
post_data = await request.get_json()
|
||||
plugin_name = post_data["name"]
|
||||
try:
|
||||
await self.plugin_manager.turn_on_plugin(plugin_name)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -16,6 +16,7 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.io import get_local_ip_addresses
|
||||
|
||||
from .routes import *
|
||||
from .routes.platform import PlatformRoute
|
||||
from .routes.route import Response, RouteContext
|
||||
from .routes.session_management import SessionManagementRoute
|
||||
from .routes.t2i import T2iRoute
|
||||
@@ -79,7 +80,7 @@ class AstrBotDashboard:
|
||||
self.persona_route = PersonaRoute(self.context, db, core_lifecycle)
|
||||
self.t2i_route = T2iRoute(self.context, core_lifecycle)
|
||||
self.kb_route = KnowledgeBaseRoute(self.context, core_lifecycle)
|
||||
self.memory_route = MemoryRoute(self.context, db, core_lifecycle)
|
||||
self.platform_route = PlatformRoute(self.context, core_lifecycle)
|
||||
|
||||
self.app.add_url_rule(
|
||||
"/api/plug/<path:subpath>",
|
||||
@@ -103,7 +104,7 @@ class AstrBotDashboard:
|
||||
async def auth_middleware(self):
|
||||
if not request.path.startswith("/api"):
|
||||
return None
|
||||
allowed_endpoints = ["/api/auth/login", "/api/file"]
|
||||
allowed_endpoints = ["/api/auth/login", "/api/file", "/api/platform/webhook"]
|
||||
if any(request.path.startswith(prefix) for prefix in allowed_endpoints):
|
||||
return None
|
||||
# 声明 JWT
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user