Compare commits
154 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 7cedf0d587 | |||
| aeb21f719e | |||
| 7c1dbecea5 | |||
| 05012af627 | |||
| 17b52ab5dd | |||
| 9449ff668b | |||
| c5a2827def | |||
| 701399c00c | |||
| eaee98d4b8 | |||
| 76c66000a7 | |||
| 4b365143c0 | |||
| 6e4e5011e2 | |||
| d853bfde84 | |||
| a0e856f80f | |||
| 8c94a0010c | |||
| a44fdaaec0 | |||
| 60105c76f5 | |||
| bcf87d3ce4 | |||
| 4d7c8c8453 | |||
| a064a9115f | |||
| 6ef99e1553 | |||
| c0dbe5cf65 | |||
| 3598c51eff | |||
| b5cdb8f650 | |||
| fc5b520f9b | |||
| 904f56b32f | |||
| 2f15fd019c | |||
| 82330b8d10 | |||
| 3ee6af7027 | |||
| 6e20ebe901 | |||
| 4d6150fd6d | |||
| 544e52191b | |||
| f2c2a6da4a | |||
| dd3df425ee | |||
| 40b4a27a3d | |||
| 9d991c7468 | |||
| ad6a8b5c94 | |||
| 1b4bfcbd72 | |||
| 9d3cc593a1 | |||
| f0dee35ba9 | |||
| 4135bd84d5 | |||
| f6da614e5d | |||
| 5f531c9be5 | |||
| 94591d965b | |||
| 8a0f865af1 | |||
| 4aced976a8 | |||
| 0299aa6e4c | |||
| e8b54a019e | |||
| 98ce796275 | |||
| b87dcf2275 | |||
| 591a228431 | |||
| f52f375154 | |||
| 975c685a17 | |||
| 6db80d36a8 | |||
| 4651bd2807 | |||
| 94ada3793e | |||
| fd05b0bf09 | |||
| 4d046f8490 | |||
| 58e32b7b70 | |||
| 903dd0f9f7 | |||
| 1acac0cac2 | |||
| 80b89fd2ea | |||
| 26f863ba81 | |||
| f78a90218e | |||
| a3ecebd2aa | |||
| 67c33b842d | |||
| 5431c9f46e | |||
| 764b91a5f7 | |||
| c20c1b84bf | |||
| fd66a0ac00 | |||
| aaee283367 | |||
| 4a5b7d1976 | |||
| 08244548ab | |||
| b486de6a98 | |||
| e2f928a7e5 | |||
| b8e4068c75 | |||
| 0916177a57 | |||
| 02cd5e396b | |||
| 56673ad78f | |||
| 9a4d05e2b6 | |||
| b2e9dab233 | |||
| 45110200ea | |||
| c3f45449e8 | |||
| 65da469deb | |||
| 16df64c405 | |||
| 6b73b19e54 | |||
| a70088b799 | |||
| e7e97730af | |||
| 467ca1eb5c | |||
| bb45d9cb54 | |||
| 46528391c2 | |||
| 8a0b7717cc | |||
| 3b81fb4985 | |||
| c09d57a820 | |||
| ec408a2aff | |||
| 417179a6b9 | |||
| fcd29445c7 | |||
| 5f535001db | |||
| 750d245b16 | |||
| f624971613 | |||
| aa6d07afcc | |||
| 2c36649874 | |||
| c95735dcc0 | |||
| 03bb278f50 | |||
| a5e0974da3 | |||
| f0fb447fbc | |||
| 37566182b0 | |||
| e460b411da | |||
| e14ed804da | |||
| 8e4e49df20 | |||
| 5d856900ef | |||
| 380a68b96c | |||
| 8879bd7e9d | |||
| 2cce09400f | |||
| 54d26dcd38 | |||
| 205024f27a | |||
| efde994907 | |||
| 8ca4f9cb74 | |||
| 54e49b997b | |||
| 5714944eef | |||
| defc46b6c9 | |||
| 4d819546b0 | |||
| 8006981976 | |||
| f7a716af43 | |||
| a708901e7f | |||
| e9be8cf69f | |||
| 31d53edb9d | |||
| 2ba0460f19 | |||
| 0e034f0fbd | |||
| 2a7d03f9e1 | |||
| 72fac4b9f1 | |||
| 38281ba2cf | |||
| 21aa3174f4 | |||
| dcda871fc0 | |||
| c13c51f499 | |||
| a130db5cf4 | |||
| 7faeb5cea8 | |||
| 8d3ff61e0d | |||
| 4c03e82570 | |||
| e7e8664ab4 | |||
| 1dd1623e7d | |||
| 80d8161d58 | |||
| fc80d7d681 | |||
| c2f036b27c | |||
| 4087bbb512 | |||
| e1c728582d | |||
| 93c69a639a | |||
| a7fdc98b29 | |||
| 85b7f104df | |||
| d76d1bd7fe | |||
| df4412aa80 | |||
| ab2c94e19a | |||
| 37cc4e2121 | |||
| 60dfdd0a66 |
@@ -36,7 +36,7 @@ jobs:
|
||||
zip -r dist.zip dist
|
||||
|
||||
- name: Archive production artifacts
|
||||
uses: actions/upload-artifact@v5
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: dist-without-markdown
|
||||
path: |
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
name: Smoke Test
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
paths-ignore:
|
||||
- 'README*.md'
|
||||
- 'changelogs/**'
|
||||
- 'dashboard/**'
|
||||
pull_request:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
smoke-test:
|
||||
name: Run smoke tests
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install UV package manager
|
||||
run: |
|
||||
pip install uv
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv sync
|
||||
timeout-minutes: 15
|
||||
|
||||
- name: Run smoke tests
|
||||
run: |
|
||||
uv run main.py &
|
||||
APP_PID=$!
|
||||
|
||||
echo "Waiting for application to start..."
|
||||
for i in {1..60}; do
|
||||
if curl -f http://localhost:6185 > /dev/null 2>&1; then
|
||||
echo "Application started successfully!"
|
||||
kill $APP_PID
|
||||
exit 0
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
|
||||
echo "Application failed to start within 30 seconds"
|
||||
kill $APP_PID 2>/dev/null || true
|
||||
exit 1
|
||||
timeout-minutes: 2
|
||||
@@ -34,6 +34,7 @@ dashboard/node_modules/
|
||||
dashboard/dist/
|
||||
package-lock.json
|
||||
package.json
|
||||
yarn.lock
|
||||
|
||||
# Operating System
|
||||
**/.DS_Store
|
||||
@@ -47,3 +48,5 @@ astrbot.lock
|
||||
chroma
|
||||
venv/*
|
||||
pytest.ini
|
||||
AGENTS.md
|
||||
IFLOW.md
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
# CONTRIBUTING
|
||||
|
||||
## 贡献指南
|
||||
|
||||
首先,感谢您花时间做出贡献!❤️
|
||||
|
||||
所有类型的贡献都受到鼓励和重视。有关不同的帮助方式和处理方式的详细信息,请参阅[目录](#目录)。在做出贡献之前,请确保阅读相关部分。这将使我们维护人员的工作变得更加容易,并为所有参与者带来顺畅的体验。社区期待您的贡献。🎉
|
||||
|
||||
### 目录
|
||||
|
||||
- [报告问题](#报告问题)
|
||||
- [提交代码更改](#提交代码更改)
|
||||
|
||||
### 报告问题
|
||||
|
||||
如果您在使用 AstrBot 时遇到任何问题,请按照以下步骤报告:
|
||||
|
||||
1. **检查现有问题**:在提交新问题之前,请先检查 [Issues](https://github.com/AstrBotDevs/AstrBot/issues) 中是否已经存在类似的问题。
|
||||
2. **创建新问题**:如果没有类似的问题,请创建一个新问题。请确保提供以下信息:
|
||||
- 问题的简要描述
|
||||
- 重现问题的步骤
|
||||
- 预期结果和实际结果
|
||||
- 相关日志或错误消息
|
||||
|
||||
### 提交代码更改
|
||||
|
||||
#### 分支命名
|
||||
|
||||
我们使用 `fix/` 前缀来修复错误,使用 `feat/` 前缀来添加新功能。对于 `fix/` 分支,请使用简短的描述,或者直接使用 Issue 编号。例如:`fix/1234` 或者 `fix/1234-login-typo`。对于 `feat/` 分支,请使用简短的描述,例如:`feat/add-user-profile`。
|
||||
|
||||
#### PR 描述
|
||||
|
||||
- 请使用英文描述您的 PR。
|
||||
- 标题请使用 `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` 等语义化前缀,并简要描述更改内容。如:`fix: correct login page typo`。
|
||||
|
||||
#### 代码规范
|
||||
|
||||
##### Core
|
||||
|
||||
我们使用 Ruff 作为代码格式化和静态分析工具。在提交代码之前,请运行以下命令以确保代码符合规范:
|
||||
|
||||
```bash
|
||||
ruff format .
|
||||
ruff check .
|
||||
```
|
||||
|
||||
如果您使用 VSCode,可以安装 `Ruff` 插件。
|
||||
|
||||
|
||||
## Contributing Guide
|
||||
|
||||
First off, thanks for taking the time to contribute! ❤️
|
||||
|
||||
All types of contributions are encouraged and valued. See the [Table of Contents](#table-of-contents) for different ways to help and details about how this project handles them. Please make sure to read the relevant section before making your contribution. It will make it a lot easier for us maintainers and smooth out the experience for all involved. The community looks forward to your contributions. 🎉
|
||||
|
||||
### Table of Contents
|
||||
|
||||
- [Reporting Issues](#reporting-issues)
|
||||
- [Pull Requests](#pull-requests)
|
||||
|
||||
### Reporting Issues
|
||||
|
||||
If you encounter any issues while using AstrBot, please follow these steps to report them:
|
||||
1. **Check Existing Issues**: Before submitting a new issue, please check if a similar issue already exists in the [Issues](https://github.com/AstrBotDevs/AstrBot/issues) section of the repository.
|
||||
2. **Create a New Issue**: If no similar issue exists, please create a new issue. Make sure to provide the following information:
|
||||
- A brief description of the issue
|
||||
- Steps to reproduce the issue
|
||||
- Expected and actual results
|
||||
- Relevant logs or error messages
|
||||
|
||||
### Pull Requests
|
||||
|
||||
#### Branch Naming
|
||||
|
||||
We use the `fix/` prefix for bug fixes and the `feat/` prefix for new features. For `fix/` branches, please use a short description or directly use the Issue number, e.g., `fix/1234` or `fix/1234-login-typo`. For `feat/` branches, please use a short description, e.g., `feat/add-user-profile`.
|
||||
|
||||
#### PR Description
|
||||
- Please use English to describe your PR.
|
||||
- Use semantic prefixes like `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` in the title, followed by a brief description of the changes, e.g., `fix: correct login page typo`.
|
||||
|
||||
#### Code Style
|
||||
|
||||
##### Core
|
||||
|
||||
We use Ruff as our code formatter and static analysis tool. Before submitting your code, please run the following commands to ensure your code adheres to the style guidelines:
|
||||
|
||||
```bash
|
||||
ruff format .
|
||||
ruff check .
|
||||
```
|
||||
@@ -1,10 +1,13 @@
|
||||

|
||||
|
||||
</p>
|
||||

|
||||
|
||||
<div align="center">
|
||||
|
||||
<br>
|
||||
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_zh-TW.md">繁體中文</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_fr.md">Français</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ru.md">Русский</a>
|
||||
|
||||
<div>
|
||||
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
@@ -14,35 +17,38 @@
|
||||
<br>
|
||||
|
||||
<div>
|
||||
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?style=for-the-badge&color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
|
||||
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
|
||||
<img src="https://deepwiki.com/badge.svg" href="https://deepwiki.com/AstrBotDevs/AstrBot">
|
||||
<a href="https://zread.ai/AstrBotDevs/AstrBot" target="_blank"><img src="https://img.shields.io/badge/Ask_Zread-_.svg?style=flat&color=00b0aa&labelColor=000000&logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iMTYiIGhlaWdodD0iMTYiIHZpZXdCb3g9IjAgMCAxNiAxNiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTQuOTYxNTYgMS42MDAxSDIuMjQxNTZDMS44ODgxIDEuNjAwMSAxLjYwMTU2IDEuODg2NjQgMS42MDE1NiAyLjI0MDFWNC45NjAxQzEuNjAxNTYgNS4zMTM1NiAxLjg4ODEgNS42MDAxIDIuMjQxNTYgNS42MDAxSDQuOTYxNTZDNS4zMTUwMiA1LjYwMDEgNS42MDE1NiA1LjMxMzU2IDUuNjAxNTYgNC45NjAxVjIuMjQwMUM1LjYwMTU2IDEuODg2NjQgNS4zMTUwMiAxLjYwMDEgNC45NjE1NiAxLjYwMDFaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00Ljk2MTU2IDEwLjM5OTlIMi4yNDE1NkMxLjg4ODEgMTAuMzk5OSAxLjYwMTU2IDEwLjY4NjQgMS42MDE1NiAxMS4wMzk5VjEzLjc1OTlDMS42MDE1NiAxNC4xMTM0IDEuODg4MSAxNC4zOTk5IDIuMjQxNTYgMTQuMzk5OUg0Ljk2MTU2QzUuMzE1MDIgMTQuMzk5OSA1LjYwMTU2IDE0LjExMzQgNS42MDE1NiAxMy43NTk5VjExLjAzOTlDNS42MDE1NiAxMC42ODY0IDUuMzE1MDIgMTAuMzk5OSA0Ljk2MTU2IDEwLjM5OTlaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik0xMy43NTg0IDEuNjAwMUgxMS4wMzg0QzEwLjY4NSAxLjYwMDEgMTAuMzk4NCAxLjg4NjY0IDEwLjM5ODQgMi4yNDAxVjQuOTYwMUMxMC4zOTg0IDUuMzEzNTYgMTAuNjg1IDUuNjAwMSAxMS4wMzg0IDUuNjAwMUgxMy43NTg0QzE0LjExMTkgNS42MDAxIDE0LjM5ODQgNS4zMTM1NiAxNC4zOTg0IDQuOTYwMVYyLjI0MDFDMTQuMzk4NCAxLjg4NjY0IDE0LjExMTkgMS42MDAxIDEzLjc1ODQgMS42MDAxWiIgZmlsbD0iI2ZmZiIvPgo8cGF0aCBkPSJNNCAxMkwxMiA0TDQgMTJaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDQiIHN0cm9rZT0iI2ZmZiIgc3Ryb2tlLXdpZHRoPSIxLjUiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIvPgo8L3N2Zz4K&logoColor=ffffff" alt="zread"/></a>
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?color=76bad9"/></a>
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
|
||||
<img src="https://gitcode.com/Soulter/AstrBot/star/badge.svg" href="https://gitcode.com/Soulter/AstrBot">
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||
<a href="https://astrbot.app/">文档</a> |
|
||||
<a href="https://blog.astrbot.app/">Blog</a> |
|
||||
<a href="https://astrbot.featurebase.app/roadmap">路线图</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/issues">问题提交</a>
|
||||
</div>
|
||||
|
||||
AstrBot 是一个开源的一站式 Agent 聊天机器人平台,可无缝接入主流即时通讯软件,为个人、开发者和团队打造可靠、可扩展的对话式智能基础设施。无论是个人 AI 伙伴、智能客服、自动化助手,还是企业知识库,AstrBot 都能在你的即时通讯软件平台的工作流中快速构建生产可用的 AI 应用。
|
||||
AstrBot 是一个开源的一站式 Agent 聊天机器人平台,可接入主流即时通讯软件,为个人、开发者和团队打造可靠、可扩展的对话式智能基础设施。无论是个人 AI 伙伴、智能客服、自动化助手,还是企业知识库,AstrBot 都能在你的即时通讯软件平台的工作流中快速构建生产可用的 AI 应用。
|
||||
|
||||
<img width="1776" height="1080" alt="image" src="https://github.com/user-attachments/assets/00782c4c-4437-4d97-aabc-605e3738da5c" />
|
||||
|
||||
## 主要功能
|
||||
|
||||
1. **大模型对话**。支持接入多种大模型服务。支持多模态、工具调用、MCP、原生知识库、人设等功能。
|
||||
2. **多消息平台支持**。支持接入 QQ、企业微信、微信公众号、飞书、Telegram、钉钉、Discord、KOOK 等平台。支持速率限制、白名单、百度内容审核。
|
||||
3. **Agent**。完善适配的 Agentic 能力。支持多轮工具调用、内置沙盒代码执行器、网页搜索等功能。
|
||||
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,社区插件生态丰富。
|
||||
5. **WebUI**。可视化配置和管理机器人,功能齐全。
|
||||
1. 💯 免费 & 开源。
|
||||
1. ✨ AI 大模型对话,多模态,Agent,MCP,知识库,人格设定。
|
||||
2. 🤖 支持接入 Dify、阿里云百炼、Coze 等智能体平台。
|
||||
2. 🌐 多平台,支持 QQ、企业微信、飞书、钉钉、微信公众号、Telegram、Slack 以及[更多](#支持的消息平台)。
|
||||
3. 📦 插件扩展,已有近 800 个插件可一键安装。
|
||||
5. 💻 WebUI 支持。
|
||||
6. 🌐 国际化(i18n)支持。
|
||||
|
||||
## 部署方式
|
||||
## 快速开始
|
||||
|
||||
#### Docker 部署(推荐 🥳)
|
||||
|
||||
@@ -50,6 +56,12 @@ AstrBot 是一个开源的一站式 Agent 聊天机器人平台,可无缝接
|
||||
|
||||
请参阅官方文档 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) 。
|
||||
|
||||
#### uv 部署
|
||||
|
||||
```bash
|
||||
uvx astrbot
|
||||
```
|
||||
|
||||
#### 宝塔面板部署
|
||||
|
||||
AstrBot 与宝塔面板合作,已上架至宝塔面板。
|
||||
@@ -101,24 +113,6 @@ uv run main.py
|
||||
|
||||
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
|
||||
|
||||
## 🌍 社区
|
||||
|
||||
### QQ 群组
|
||||
|
||||
- 1 群:322154837
|
||||
- 3 群:630166526
|
||||
- 5 群:822130018
|
||||
- 6 群:753075035
|
||||
- 开发者群:975206796
|
||||
|
||||
### Telegram 群组
|
||||
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
### Discord 群组
|
||||
|
||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
## 支持的消息平台
|
||||
|
||||
**官方维护**
|
||||
@@ -205,6 +199,25 @@ pip install pre-commit
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
## 🌍 社区
|
||||
|
||||
### QQ 群组
|
||||
|
||||
- 1 群:322154837
|
||||
- 3 群:630166526
|
||||
- 5 群:822130018
|
||||
- 6 群:753075035
|
||||
- 7 群:743746109
|
||||
- 开发者群:975206796
|
||||
|
||||
### Telegram 群组
|
||||
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
### Discord 群组
|
||||
|
||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
## ❤️ Special Thanks
|
||||
|
||||
特别感谢所有 Contributors 和插件开发者对 AstrBot 的贡献 ❤️
|
||||
@@ -230,4 +243,10 @@ pre-commit install
|
||||
|
||||
</details>
|
||||
|
||||
<div align="center">
|
||||
|
||||
_私は、高性能ですから!_
|
||||
|
||||
<img src="https://files.astrbot.app/watashiwa-koseino-desukara.gif" width="100"/>
|
||||
</div
|
||||
|
||||
|
||||
+40
-26
@@ -19,30 +19,38 @@
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%20plugins&style=for-the-badge&label=Marketplace&cacheSeconds=3600">
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">中文</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_zh-TW.md">繁體中文</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_fr.md">Français</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ru.md">Русский</a>
|
||||
|
||||
<a href="https://astrbot.app/">Documentation</a> |
|
||||
<a href="https://blog.astrbot.app/">Blog</a> |
|
||||
<a href="https://astrbot.featurebase.app/roadmap">Roadmap</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Issue Tracker</a>
|
||||
</div>
|
||||
|
||||
AstrBot is an open-source all-in-one Agent chatbot platform and development framework.
|
||||
AstrBot is an open-source all-in-one Agent chatbot platform that integrates with mainstream instant messaging apps. It provides reliable and scalable conversational AI infrastructure for individuals, developers, and teams. Whether you're building a personal AI companion, intelligent customer service, automation assistant, or enterprise knowledge base, AstrBot enables you to quickly build production-ready AI applications within your IM platform workflows.
|
||||
|
||||
<img width="1776" height="1080" alt="image" src="https://github.com/user-attachments/assets/00782c4c-4437-4d97-aabc-605e3738da5c" />
|
||||
|
||||
## Key Features
|
||||
|
||||
1. **LLM Conversations**. Supports integration with various large language model services. Features include multimodal capabilities, tool calling, MCP, native knowledge base, character personas, and more.
|
||||
2. **Multi-Platform Support**. Integrates with QQ, WeChat Work, WeChat Official Accounts, Feishu, Telegram, DingTalk, Discord, KOOK, and other platforms. Supports rate limiting, whitelisting, and Baidu content moderation.
|
||||
3. **Agent Capabilities**. Fully optimized agentic features including multi-turn tool calling, built-in sandboxed code executor, web search, and more.
|
||||
4. **Plugin Extensions**. Deeply optimized plugin mechanism supporting [plugin development](https://astrbot.app/dev/plugin.html) to extend functionality, with a rich community plugin ecosystem.
|
||||
5. **Web UI**. Visual configuration and management of your bot with comprehensive features.
|
||||
1. 💯 Free & Open Source.
|
||||
2. ✨ AI LLM Conversations, Multimodal, Agent, MCP, Knowledge Base, Persona Settings.
|
||||
3. 🤖 Supports integration with Dify, Alibaba Cloud Bailian, Coze and other agent platforms.
|
||||
4. 🌐 Multi-Platform: QQ, WeChat Work, Feishu, DingTalk, WeChat Official Accounts, Telegram, Slack, and [more](#supported-messaging-platforms).
|
||||
5. 📦 Plugin Extensions with nearly 800 plugins available for one-click installation.
|
||||
6. 💻 WebUI Support.
|
||||
7. 🌐 Internationalization (i18n) Support.
|
||||
|
||||
## Deployment Methods
|
||||
## Quick Start
|
||||
|
||||
#### Docker Deployment (Recommended 🥳)
|
||||
|
||||
@@ -50,6 +58,12 @@ We recommend deploying AstrBot using Docker or Docker Compose.
|
||||
|
||||
Please refer to the official documentation: [Deploy AstrBot with Docker](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot).
|
||||
|
||||
#### uv Deployment
|
||||
|
||||
```bash
|
||||
uvx astrbot
|
||||
```
|
||||
|
||||
#### BT-Panel Deployment
|
||||
|
||||
AstrBot has partnered with BT-Panel and is now available in their marketplace.
|
||||
@@ -101,24 +115,6 @@ uv run main.py
|
||||
|
||||
Or refer to the official documentation: [Deploy AstrBot from Source](https://astrbot.app/deploy/astrbot/cli.html).
|
||||
|
||||
## 🌍 Community
|
||||
|
||||
### QQ Groups
|
||||
|
||||
- Group 1: 322154837
|
||||
- Group 3: 630166526
|
||||
- Group 5: 822130018
|
||||
- Group 6: 753075035
|
||||
- Developer Group: 975206796
|
||||
|
||||
### Telegram Group
|
||||
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
### Discord Server
|
||||
|
||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
## Supported Messaging Platforms
|
||||
|
||||
**Officially Maintained**
|
||||
@@ -205,6 +201,24 @@ pip install pre-commit
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
## 🌍 Community
|
||||
|
||||
### QQ Groups
|
||||
|
||||
- Group 1: 322154837
|
||||
- Group 3: 630166526
|
||||
- Group 5: 822130018
|
||||
- Group 6: 753075035
|
||||
- Developer Group: 975206796
|
||||
|
||||
### Telegram Group
|
||||
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
### Discord Server
|
||||
|
||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
## ❤️ Special Thanks
|
||||
|
||||
Special thanks to all Contributors and plugin developers for their contributions to AstrBot ❤️
|
||||
|
||||
+248
@@ -0,0 +1,248 @@
|
||||

|
||||
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
|
||||
<br>
|
||||
|
||||
<div>
|
||||
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp&t=2" alt="Featured|HelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<div>
|
||||
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?style=for-the-badge&color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%20plugins&style=for-the-badge&label=Marketplace&cacheSeconds=3600">
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">中文</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_zh-TW.md">繁體中文</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ru.md">Русский</a>
|
||||
|
||||
<a href="https://astrbot.app/">Documentation</a> |
|
||||
<a href="https://blog.astrbot.app/">Blog</a> |
|
||||
<a href="https://astrbot.featurebase.app/roadmap">Feuille de route</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Signaler un problème</a>
|
||||
</div>
|
||||
|
||||
AstrBot est une plateforme de chatbot Agent tout-en-un open source qui s'intègre aux principales applications de messagerie instantanée. Elle fournit une infrastructure d'IA conversationnelle fiable et évolutive pour les particuliers, les développeurs et les équipes. Que vous construisiez un compagnon IA personnel, un service client intelligent, un assistant d'automatisation ou une base de connaissances d'entreprise, AstrBot vous permet de créer rapidement des applications d'IA prêtes pour la production dans les flux de travail de votre plateforme de messagerie.
|
||||
|
||||
<img width="1776" height="1080" alt="image" src="https://github.com/user-attachments/assets/00782c4c-4437-4d97-aabc-605e3738da5c" />
|
||||
|
||||
## Fonctionnalités principales
|
||||
|
||||
1. 💯 Gratuit & Open Source.
|
||||
2. ✨ Conversations avec LLM IA, Multimodal, Agent, MCP, Base de connaissances, Paramètres de personnalité.
|
||||
3. 🤖 Prise en charge de l'intégration avec Dify, Alibaba Cloud Bailian, Coze et autres plateformes d'agents.
|
||||
4. 🌐 Multi-plateforme : QQ, WeChat Work, Feishu, DingTalk, Comptes officiels WeChat, Telegram, Slack, et [plus encore](#plateformes-de-messagerie-prises-en-charge).
|
||||
5. 📦 Extensions de plugins avec près de 800 plugins disponibles pour une installation en un clic.
|
||||
6. 💻 Support WebUI.
|
||||
7. 🌐 Support de l'internationalisation (i18n).
|
||||
|
||||
## Démarrage rapide
|
||||
|
||||
#### Déploiement Docker (Recommandé 🥳)
|
||||
|
||||
Nous recommandons de déployer AstrBot en utilisant Docker ou Docker Compose.
|
||||
|
||||
Veuillez consulter la documentation officielle : [Déployer AstrBot avec Docker](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot).
|
||||
|
||||
#### Déploiement uv
|
||||
|
||||
```bash
|
||||
uvx astrbot
|
||||
```
|
||||
|
||||
#### Déploiement BT-Panel
|
||||
|
||||
AstrBot s'est associé à BT-Panel et est maintenant disponible sur leur marketplace.
|
||||
|
||||
Veuillez consulter la documentation officielle : [Déploiement BT-Panel](https://astrbot.app/deploy/astrbot/btpanel.html).
|
||||
|
||||
#### Déploiement 1Panel
|
||||
|
||||
AstrBot a été officiellement listé sur le marketplace 1Panel.
|
||||
|
||||
Veuillez consulter la documentation officielle : [Déploiement 1Panel](https://astrbot.app/deploy/astrbot/1panel.html).
|
||||
|
||||
#### Déployer sur RainYun
|
||||
|
||||
AstrBot a été officiellement listé sur la plateforme d'applications cloud de RainYun avec un déploiement en un clic.
|
||||
|
||||
[](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
|
||||
|
||||
#### Déployer sur Replit
|
||||
|
||||
Méthode de déploiement contribuée par la communauté.
|
||||
|
||||
[](https://repl.it/github/AstrBotDevs/AstrBot)
|
||||
|
||||
#### Installateur Windows en un clic
|
||||
|
||||
Veuillez consulter la documentation officielle : [Déployer AstrBot avec l'installateur Windows en un clic](https://astrbot.app/deploy/astrbot/windows.html).
|
||||
|
||||
#### Déploiement CasaOS
|
||||
|
||||
Méthode de déploiement contribuée par la communauté.
|
||||
|
||||
Veuillez consulter la documentation officielle : [Déploiement CasaOS](https://astrbot.app/deploy/astrbot/casaos.html).
|
||||
|
||||
#### Déploiement manuel
|
||||
|
||||
Tout d'abord, installez uv :
|
||||
|
||||
```bash
|
||||
pip install uv
|
||||
```
|
||||
|
||||
Installez AstrBot via Git Clone :
|
||||
|
||||
```bash
|
||||
git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
|
||||
uv run main.py
|
||||
```
|
||||
|
||||
Ou consultez la documentation officielle : [Déployer AstrBot depuis les sources](https://astrbot.app/deploy/astrbot/cli.html).
|
||||
|
||||
## Plateformes de messagerie prises en charge
|
||||
|
||||
**Maintenues officiellement**
|
||||
|
||||
- QQ (Plateforme officielle & OneBot)
|
||||
- Telegram
|
||||
- Application WeChat Work & Bot intelligent WeChat Work
|
||||
- Service client WeChat & Comptes officiels WeChat
|
||||
- Feishu (Lark)
|
||||
- DingTalk
|
||||
- Slack
|
||||
- Discord
|
||||
- Satori
|
||||
- Misskey
|
||||
- WhatsApp (Bientôt disponible)
|
||||
- LINE (Bientôt disponible)
|
||||
|
||||
**Maintenues par la communauté**
|
||||
|
||||
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
|
||||
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
|
||||
- [Messages directs Bilibili](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
|
||||
- [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11)
|
||||
|
||||
## Services de modèles pris en charge
|
||||
|
||||
**Services LLM**
|
||||
|
||||
- OpenAI et services compatibles
|
||||
- Anthropic
|
||||
- Google Gemini
|
||||
- Moonshot AI
|
||||
- Zhipu AI
|
||||
- DeepSeek
|
||||
- Ollama (Auto-hébergé)
|
||||
- LM Studio (Auto-hébergé)
|
||||
- [CompShare](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74)
|
||||
- [302.AI](https://share.302.ai/rr1M3l)
|
||||
- [TokenPony](https://www.tokenpony.cn/3YPyf)
|
||||
- [SiliconFlow](https://docs.siliconflow.cn/cn/usecases/use-siliconcloud-in-astrbot)
|
||||
- [PPIO Cloud](https://ppio.com/user/register?invited_by=AIOONE)
|
||||
- ModelScope
|
||||
- OneAPI
|
||||
|
||||
**Plateformes LLMOps**
|
||||
|
||||
- Dify
|
||||
- Applications Alibaba Cloud Bailian
|
||||
- Coze
|
||||
|
||||
**Services de reconnaissance vocale**
|
||||
|
||||
- OpenAI Whisper
|
||||
- SenseVoice
|
||||
|
||||
**Services de synthèse vocale**
|
||||
|
||||
- OpenAI TTS
|
||||
- Gemini TTS
|
||||
- GPT-Sovits-Inference
|
||||
- GPT-Sovits
|
||||
- FishAudio
|
||||
- Edge TTS
|
||||
- Alibaba Cloud Bailian TTS
|
||||
- Azure TTS
|
||||
- Minimax TTS
|
||||
- Volcano Engine TTS
|
||||
|
||||
## ❤️ Contribuer
|
||||
|
||||
Les Issues et Pull Requests sont toujours les bienvenues ! N'hésitez pas à soumettre vos modifications à ce projet :)
|
||||
|
||||
### Comment contribuer
|
||||
|
||||
Vous pouvez contribuer en examinant les issues ou en aidant à la revue des pull requests. Toutes les issues ou PRs sont les bienvenues pour encourager la participation de la communauté. Bien sûr, ce ne sont que des suggestions - vous pouvez contribuer de la manière que vous souhaitez. Pour l'ajout de nouvelles fonctionnalités, veuillez d'abord en discuter via une Issue.
|
||||
|
||||
### Environnement de développement
|
||||
|
||||
AstrBot utilise `ruff` pour le formatage et le linting du code.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/AstrBotDevs/AstrBot
|
||||
pip install pre-commit
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
## 🌍 Communauté
|
||||
|
||||
### Groupes QQ
|
||||
|
||||
- Groupe 1 : 322154837
|
||||
- Groupe 3 : 630166526
|
||||
- Groupe 5 : 822130018
|
||||
- Groupe 6 : 753075035
|
||||
- Groupe développeurs : 975206796
|
||||
|
||||
### Groupe Telegram
|
||||
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
### Serveur Discord
|
||||
|
||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
## ❤️ Remerciements spéciaux
|
||||
|
||||
Un grand merci à tous les contributeurs et développeurs de plugins pour leurs contributions à AstrBot ❤️
|
||||
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/graphs/contributors">
|
||||
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
|
||||
</a>
|
||||
|
||||
De plus, la naissance de ce projet n'aurait pas été possible sans l'aide des projets open source suivants :
|
||||
|
||||
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - L'incroyable framework chat
|
||||
|
||||
## ⭐ Historique des étoiles
|
||||
|
||||
> [!TIP]
|
||||
> Si ce projet vous a aidé dans votre vie ou votre travail, ou si vous êtes intéressé par son développement futur, veuillez donner une étoile au projet. C'est la force motrice derrière la maintenance de ce projet open source <3
|
||||
|
||||
<div align="center">
|
||||
|
||||
[](https://star-history.com/#astrbotdevs/astrbot&Date)
|
||||
|
||||
</div>
|
||||
|
||||
</details>
|
||||
|
||||
_私は、高性能ですから!_
|
||||
|
||||
+40
-26
@@ -19,30 +19,38 @@
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E5%80%8B&style=for-the-badge&label=%E3%83%97%E3%83%A9%E3%82%B0%E3%82%A4%E3%83%B3&cacheSeconds=3600">
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">中文</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_zh-TW.md">繁體中文</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_fr.md">Français</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ru.md">Русский</a>
|
||||
|
||||
<a href="https://astrbot.app/">ドキュメント</a> |
|
||||
<a href="https://blog.astrbot.app/">Blog</a> |
|
||||
<a href="https://astrbot.featurebase.app/roadmap">ロードマップ</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Issue</a>
|
||||
</div>
|
||||
|
||||
AstrBot は、オープンソースのオールインワン Agent チャットボットプラットフォーム及び開発フレームワークです。
|
||||
AstrBot は、主要なインスタントメッセージングアプリと統合できるオープンソースのオールインワン Agent チャットボットプラットフォームです。個人、開発者、チームに信頼性が高くスケーラブルな会話型 AI インフラストラクチャを提供します。パーソナル AI コンパニオン、インテリジェントカスタマーサービス、オートメーションアシスタント、エンタープライズナレッジベースなど、AstrBot を使用すると、IM プラットフォームのワークフロー内で本番環境対応の AI アプリケーションを迅速に構築できます。
|
||||
|
||||
<img width="1776" height="1080" alt="image" src="https://github.com/user-attachments/assets/00782c4c-4437-4d97-aabc-605e3738da5c" />
|
||||
|
||||
## 主な機能
|
||||
|
||||
1. **大規模言語モデル対話**。多様な大規模言語モデルサービスとの統合をサポート。マルチモーダル、ツール呼び出し、MCP、ネイティブナレッジベース、キャラクター設定などの機能を搭載。
|
||||
2. **マルチメッセージプラットフォームサポート**。QQ、WeChat Work、WeChat公式アカウント、Feishu、Telegram、DingTalk、Discord、KOOK などのプラットフォームと統合可能。レート制限、ホワイトリスト、Baidu コンテンツ審査をサポート。
|
||||
3. **Agent**。完全に最適化された Agentic 機能。マルチターンツール呼び出し、内蔵サンドボックスコード実行環境、Web 検索などの機能をサポート。
|
||||
4. **プラグイン拡張**。深く最適化されたプラグインメカニズムで、[プラグイン開発](https://astrbot.app/dev/plugin.html)による機能拡張をサポート。豊富なコミュニティプラグインエコシステム。
|
||||
5. **WebUI**。ビジュアル設定とボット管理、充実した機能。
|
||||
1. 💯 無料 & オープンソース。
|
||||
2. ✨ AI 大規模言語モデル対話、マルチモーダル、Agent、MCP、ナレッジベース、ペルソナ設定。
|
||||
3. 🤖 Dify、Alibaba Cloud 百炼、Coze などの Agent プラットフォームとの統合をサポート。
|
||||
4. 🌐 マルチプラットフォーム:QQ、WeChat Work、Feishu、DingTalk、WeChat 公式アカウント、Telegram、Slack、[その他](#サポートされているメッセージプラットフォーム)。
|
||||
5. 📦 約800個のプラグインをワンクリックでインストール可能なプラグイン拡張機能。
|
||||
6. 💻 WebUI サポート。
|
||||
7. 🌐 国際化(i18n)サポート。
|
||||
|
||||
## デプロイ方法
|
||||
## クイックスタート
|
||||
|
||||
#### Docker デプロイ(推奨 🥳)
|
||||
|
||||
@@ -50,6 +58,12 @@ Docker / Docker Compose を使用した AstrBot のデプロイを推奨しま
|
||||
|
||||
公式ドキュメント [Docker を使用した AstrBot のデプロイ](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) をご参照ください。
|
||||
|
||||
#### uv デプロイ
|
||||
|
||||
```bash
|
||||
uvx astrbot
|
||||
```
|
||||
|
||||
#### 宝塔パネルデプロイ
|
||||
|
||||
AstrBot は宝塔パネルと提携し、宝塔パネルに公開されています。
|
||||
@@ -101,24 +115,6 @@ uv run main.py
|
||||
|
||||
または、公式ドキュメント [ソースコードから AstrBot をデプロイ](https://astrbot.app/deploy/astrbot/cli.html) をご参照ください。
|
||||
|
||||
## 🌍 コミュニティ
|
||||
|
||||
### QQ グループ
|
||||
|
||||
- 1群:322154837
|
||||
- 3群:630166526
|
||||
- 5群:822130018
|
||||
- 6群:753075035
|
||||
- 開発者群:975206796
|
||||
|
||||
### Telegram グループ
|
||||
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
### Discord サーバー
|
||||
|
||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
## サポートされているメッセージプラットフォーム
|
||||
|
||||
**公式メンテナンス**
|
||||
@@ -205,6 +201,24 @@ pip install pre-commit
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
## 🌍 コミュニティ
|
||||
|
||||
### QQ グループ
|
||||
|
||||
- 1群: 322154837
|
||||
- 3群: 630166526
|
||||
- 5群: 822130018
|
||||
- 6群: 753075035
|
||||
- 開発者群: 975206796
|
||||
|
||||
### Telegram グループ
|
||||
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
### Discord サーバー
|
||||
|
||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
## ❤️ Special Thanks
|
||||
|
||||
AstrBot への貢献をしていただいたすべてのコントリビューターとプラグイン開発者に特別な感謝を ❤️
|
||||
|
||||
+248
@@ -0,0 +1,248 @@
|
||||

|
||||
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
|
||||
<br>
|
||||
|
||||
<div>
|
||||
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp&t=2" alt="Featured|HelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<div>
|
||||
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?style=for-the-badge&color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%20%D0%BF%D0%BB%D0%B0%D0%B3%D0%B8%D0%BD%D0%BE%D0%B2&style=for-the-badge&label=%D0%9C%D0%B0%D0%B3%D0%B0%D0%B7%D0%B8%D0%BD&cacheSeconds=3600">
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">中文</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_zh-TW.md">繁體中文</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_fr.md">Français</a>
|
||||
|
||||
<a href="https://astrbot.app/">Документация</a> |
|
||||
<a href="https://blog.astrbot.app/">Блог</a> |
|
||||
<a href="https://astrbot.featurebase.app/roadmap">Дорожная карта</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Сообщить о проблеме</a>
|
||||
</div>
|
||||
|
||||
AstrBot — это универсальная платформа Agent-чатботов с открытым исходным кодом, которая интегрируется с основными приложениями для обмена мгновенными сообщениями. Она предоставляет надёжную и масштабируемую инфраструктуру разговорного ИИ для частных лиц, разработчиков и команд. Будь то персональный ИИ-компаньон, интеллектуальная служба поддержки, автоматизированный помощник или корпоративная база знаний — AstrBot позволяет быстро создавать готовые к использованию ИИ-приложения в рабочих процессах вашей платформы обмена сообщениями.
|
||||
|
||||
<img width="1776" height="1080" alt="image" src="https://github.com/user-attachments/assets/00782c4c-4437-4d97-aabc-605e3738da5c" />
|
||||
|
||||
## Основные возможности
|
||||
|
||||
1. 💯 Бесплатно и с открытым исходным кодом.
|
||||
2. ✨ ИИ-диалоги с LLM, мультимодальность, Agent, MCP, база знаний, настройки личности.
|
||||
3. 🤖 Поддержка интеграции с Dify, Alibaba Cloud Bailian, Coze и другими платформами агентов.
|
||||
4. 🌐 Мультиплатформенность: QQ, WeChat Work, Feishu, DingTalk, официальные аккаунты WeChat, Telegram, Slack и [другие](#поддерживаемые-платформы-обмена-сообщениями).
|
||||
5. 📦 Расширения плагинов с почти 800 плагинами, доступными для установки в один клик.
|
||||
6. 💻 Поддержка WebUI.
|
||||
7. 🌐 Поддержка интернационализации (i18n).
|
||||
|
||||
## Быстрый старт
|
||||
|
||||
#### Развёртывание Docker (Рекомендуется 🥳)
|
||||
|
||||
Мы рекомендуем развёртывать AstrBot с помощью Docker или Docker Compose.
|
||||
|
||||
См. официальную документацию: [Развёртывание AstrBot с Docker](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot).
|
||||
|
||||
#### Развёртывание uv
|
||||
|
||||
```bash
|
||||
uvx astrbot
|
||||
```
|
||||
|
||||
#### Развёртывание BT-Panel
|
||||
|
||||
AstrBot в партнёрстве с BT-Panel теперь доступен на их маркетплейсе.
|
||||
|
||||
См. официальную документацию: [Развёртывание BT-Panel](https://astrbot.app/deploy/astrbot/btpanel.html).
|
||||
|
||||
#### Развёртывание 1Panel
|
||||
|
||||
AstrBot официально размещён на маркетплейсе 1Panel.
|
||||
|
||||
См. официальную документацию: [Развёртывание 1Panel](https://astrbot.app/deploy/astrbot/1panel.html).
|
||||
|
||||
#### Развёртывание на RainYun
|
||||
|
||||
AstrBot официально размещён на облачной платформе приложений RainYun с развёртыванием в один клик.
|
||||
|
||||
[](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
|
||||
|
||||
#### Развёртывание на Replit
|
||||
|
||||
Метод развёртывания от сообщества.
|
||||
|
||||
[](https://repl.it/github/AstrBotDevs/AstrBot)
|
||||
|
||||
#### Установщик Windows в один клик
|
||||
|
||||
См. официальную документацию: [Развёртывание AstrBot с установщиком Windows в один клик](https://astrbot.app/deploy/astrbot/windows.html).
|
||||
|
||||
#### Развёртывание CasaOS
|
||||
|
||||
Метод развёртывания от сообщества.
|
||||
|
||||
См. официальную документацию: [Развёртывание CasaOS](https://astrbot.app/deploy/astrbot/casaos.html).
|
||||
|
||||
#### Ручное развёртывание
|
||||
|
||||
Сначала установите uv:
|
||||
|
||||
```bash
|
||||
pip install uv
|
||||
```
|
||||
|
||||
Установите AstrBot через Git Clone:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
|
||||
uv run main.py
|
||||
```
|
||||
|
||||
Или см. официальную документацию: [Развёртывание AstrBot из исходного кода](https://astrbot.app/deploy/astrbot/cli.html).
|
||||
|
||||
## Поддерживаемые платформы обмена сообщениями
|
||||
|
||||
**Официально поддерживаемые**
|
||||
|
||||
- QQ (Официальная платформа и OneBot)
|
||||
- Telegram
|
||||
- Приложение WeChat Work и интеллектуальный бот WeChat Work
|
||||
- Служба поддержки WeChat и официальные аккаунты WeChat
|
||||
- Feishu (Lark)
|
||||
- DingTalk
|
||||
- Slack
|
||||
- Discord
|
||||
- Satori
|
||||
- Misskey
|
||||
- WhatsApp (Скоро)
|
||||
- LINE (Скоро)
|
||||
|
||||
**Поддерживаемые сообществом**
|
||||
|
||||
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
|
||||
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
|
||||
- [Личные сообщения Bilibili](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
|
||||
- [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11)
|
||||
|
||||
## Поддерживаемые сервисы моделей
|
||||
|
||||
**Сервисы LLM**
|
||||
|
||||
- OpenAI и совместимые сервисы
|
||||
- Anthropic
|
||||
- Google Gemini
|
||||
- Moonshot AI
|
||||
- Zhipu AI
|
||||
- DeepSeek
|
||||
- Ollama (Самостоятельное размещение)
|
||||
- LM Studio (Самостоятельное размещение)
|
||||
- [CompShare](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74)
|
||||
- [302.AI](https://share.302.ai/rr1M3l)
|
||||
- [TokenPony](https://www.tokenpony.cn/3YPyf)
|
||||
- [SiliconFlow](https://docs.siliconflow.cn/cn/usecases/use-siliconcloud-in-astrbot)
|
||||
- [PPIO Cloud](https://ppio.com/user/register?invited_by=AIOONE)
|
||||
- ModelScope
|
||||
- OneAPI
|
||||
|
||||
**Платформы LLMOps**
|
||||
|
||||
- Dify
|
||||
- Приложения Alibaba Cloud Bailian
|
||||
- Coze
|
||||
|
||||
**Сервисы распознавания речи**
|
||||
|
||||
- OpenAI Whisper
|
||||
- SenseVoice
|
||||
|
||||
**Сервисы синтеза речи**
|
||||
|
||||
- OpenAI TTS
|
||||
- Gemini TTS
|
||||
- GPT-Sovits-Inference
|
||||
- GPT-Sovits
|
||||
- FishAudio
|
||||
- Edge TTS
|
||||
- Alibaba Cloud Bailian TTS
|
||||
- Azure TTS
|
||||
- Minimax TTS
|
||||
- Volcano Engine TTS
|
||||
|
||||
## ❤️ Вклад в проект
|
||||
|
||||
Issues и Pull Request всегда приветствуются! Не стесняйтесь отправлять свои изменения в этот проект :)
|
||||
|
||||
### Как внести вклад
|
||||
|
||||
Вы можете внести вклад, просматривая issues или помогая с ревью pull request. Любые issues или PR приветствуются для поощрения участия сообщества. Конечно, это лишь предложения — вы можете вносить вклад любым удобным для вас способом. Для добавления новых функций сначала обсудите это через Issue.
|
||||
|
||||
### Среда разработки
|
||||
|
||||
AstrBot использует `ruff` для форматирования и линтинга кода.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/AstrBotDevs/AstrBot
|
||||
pip install pre-commit
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
## 🌍 Сообщество
|
||||
|
||||
### Группы QQ
|
||||
|
||||
- Группа 1: 322154837
|
||||
- Группа 3: 630166526
|
||||
- Группа 5: 822130018
|
||||
- Группа 6: 753075035
|
||||
- Группа разработчиков: 975206796
|
||||
|
||||
### Группа Telegram
|
||||
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
### Сервер Discord
|
||||
|
||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
## ❤️ Особая благодарность
|
||||
|
||||
Особая благодарность всем контрибьюторам и разработчикам плагинов за их вклад в AstrBot ❤️
|
||||
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/graphs/contributors">
|
||||
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
|
||||
</a>
|
||||
|
||||
Кроме того, рождение этого проекта было бы невозможно без помощи следующих проектов с открытым исходным кодом:
|
||||
|
||||
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - Замечательный кошачий фреймворк
|
||||
|
||||
## ⭐ История звёзд
|
||||
|
||||
> [!TIP]
|
||||
> Если этот проект помог вам в жизни или работе, или если вас интересует его будущее развитие, пожалуйста, поставьте проекту звезду. Это движущая сила поддержки этого проекта с открытым исходным кодом <3
|
||||
|
||||
<div align="center">
|
||||
|
||||
[](https://star-history.com/#astrbotdevs/astrbot&Date)
|
||||
|
||||
</div>
|
||||
|
||||
</details>
|
||||
|
||||
_私は、高性能ですから!_
|
||||
|
||||
+248
@@ -0,0 +1,248 @@
|
||||

|
||||
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
|
||||
<br>
|
||||
|
||||
<div>
|
||||
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp&t=2" alt="Featured|HelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<div>
|
||||
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?style=for-the-badge&color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E5%80%8B&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%A0%B4&cacheSeconds=3600">
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">简体中文</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_fr.md">Français</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ru.md">Русский</a>
|
||||
|
||||
<a href="https://astrbot.app/">文件</a> |
|
||||
<a href="https://blog.astrbot.app/">Blog</a> |
|
||||
<a href="https://astrbot.featurebase.app/roadmap">路線圖</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/issues">問題回報</a>
|
||||
</div>
|
||||
|
||||
AstrBot 是一個開源的一站式 Agent 聊天機器人平台,可接入主流即時通訊軟體,為個人、開發者和團隊打造可靠、可擴展的對話式智慧基礎設施。無論是個人 AI 夥伴、智慧客服、自動化助手,還是企業知識庫,AstrBot 都能在您的即時通訊軟體平台的工作流程中快速構建生產可用的 AI 應用程式。
|
||||
|
||||
<img width="1776" height="1080" alt="image" src="https://github.com/user-attachments/assets/00782c4c-4437-4d97-aabc-605e3738da5c" />
|
||||
|
||||
## 主要功能
|
||||
|
||||
1. 💯 免費 & 開源。
|
||||
2. ✨ AI 大型模型對話,多模態,Agent,MCP,知識庫,人格設定。
|
||||
3. 🤖 支援接入 Dify、阿里雲百煉、Coze 等智慧體平台。
|
||||
4. 🌐 多平台:QQ、企業微信、飛書、釘釘、微信公眾號、Telegram、Slack 以及[更多](#支援的訊息平台)。
|
||||
5. 📦 外掛擴充,已有近 800 個外掛可一鍵安裝。
|
||||
6. 💻 WebUI 支援。
|
||||
7. 🌐 國際化(i18n)支援。
|
||||
|
||||
## 快速開始
|
||||
|
||||
#### Docker 部署(推薦 🥳)
|
||||
|
||||
推薦使用 Docker / Docker Compose 方式部署 AstrBot。
|
||||
|
||||
請參閱官方文件 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot)。
|
||||
|
||||
#### uv 部署
|
||||
|
||||
```bash
|
||||
uvx astrbot
|
||||
```
|
||||
|
||||
#### 寶塔面板部署
|
||||
|
||||
AstrBot 與寶塔面板合作,已上架至寶塔面板。
|
||||
|
||||
請參閱官方文件 [寶塔面板部署](https://astrbot.app/deploy/astrbot/btpanel.html)。
|
||||
|
||||
#### 1Panel 部署
|
||||
|
||||
AstrBot 已由 1Panel 官方上架至 1Panel 面板。
|
||||
|
||||
請參閱官方文件 [1Panel 部署](https://astrbot.app/deploy/astrbot/1panel.html)。
|
||||
|
||||
#### 在雨雲上部署
|
||||
|
||||
AstrBot 已由雨雲官方上架至雲端應用程式平台,可一鍵部署。
|
||||
|
||||
[](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
|
||||
|
||||
#### 在 Replit 上部署
|
||||
|
||||
社群貢獻的部署方式。
|
||||
|
||||
[](https://repl.it/github/AstrBotDevs/AstrBot)
|
||||
|
||||
#### Windows 一鍵安裝器部署
|
||||
|
||||
請參閱官方文件 [使用 Windows 一鍵安裝器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html)。
|
||||
|
||||
#### CasaOS 部署
|
||||
|
||||
社群貢獻的部署方式。
|
||||
|
||||
請參閱官方文件 [CasaOS 部署](https://astrbot.app/deploy/astrbot/casaos.html)。
|
||||
|
||||
#### 手動部署
|
||||
|
||||
首先安裝 uv:
|
||||
|
||||
```bash
|
||||
pip install uv
|
||||
```
|
||||
|
||||
透過 Git Clone 安裝 AstrBot:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
|
||||
uv run main.py
|
||||
```
|
||||
|
||||
或者請參閱官方文件 [透過原始碼部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html)。
|
||||
|
||||
## 支援的訊息平台
|
||||
|
||||
**官方維護**
|
||||
|
||||
- QQ(官方平台 & OneBot)
|
||||
- Telegram
|
||||
- 企微應用 & 企微智慧機器人
|
||||
- 微信客服 & 微信公眾號
|
||||
- 飛書
|
||||
- 釘釘
|
||||
- Slack
|
||||
- Discord
|
||||
- Satori
|
||||
- Misskey
|
||||
- Whatsapp(即將支援)
|
||||
- LINE(即將支援)
|
||||
|
||||
**社群維護**
|
||||
|
||||
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
|
||||
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
|
||||
- [Bilibili 私訊](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
|
||||
- [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11)
|
||||
|
||||
## 支援的模型服務
|
||||
|
||||
**大型模型服務**
|
||||
|
||||
- OpenAI 及相容服務
|
||||
- Anthropic
|
||||
- Google Gemini
|
||||
- Moonshot AI
|
||||
- 智譜 AI
|
||||
- DeepSeek
|
||||
- Ollama(本機部署)
|
||||
- LM Studio(本機部署)
|
||||
- [優雲智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74)
|
||||
- [302.AI](https://share.302.ai/rr1M3l)
|
||||
- [小馬算力](https://www.tokenpony.cn/3YPyf)
|
||||
- [矽基流動](https://docs.siliconflow.cn/cn/usercases/use-siliconcloud-in-astrbot)
|
||||
- [PPIO 派歐雲](https://ppio.com/user/register?invited_by=AIOONE)
|
||||
- ModelScope
|
||||
- OneAPI
|
||||
|
||||
**LLMOps 平台**
|
||||
|
||||
- Dify
|
||||
- 阿里雲百煉應用
|
||||
- Coze
|
||||
|
||||
**語音轉文字服務**
|
||||
|
||||
- OpenAI Whisper
|
||||
- SenseVoice
|
||||
|
||||
**文字轉語音服務**
|
||||
|
||||
- OpenAI TTS
|
||||
- Gemini TTS
|
||||
- GPT-Sovits-Inference
|
||||
- GPT-Sovits
|
||||
- FishAudio
|
||||
- Edge TTS
|
||||
- 阿里雲百煉 TTS
|
||||
- Azure TTS
|
||||
- Minimax TTS
|
||||
- 火山引擎 TTS
|
||||
|
||||
## ❤️ 貢獻
|
||||
|
||||
歡迎任何 Issues/Pull Requests!只需要將您的變更提交到此專案 :)
|
||||
|
||||
### 如何貢獻
|
||||
|
||||
您可以透過檢視問題或協助審核 PR(拉取請求)來貢獻。任何問題或 PR 都歡迎參與,以促進社群貢獻。當然,這些只是建議,您可以以任何方式進行貢獻。對於新功能的新增,請先透過 Issue 討論。
|
||||
|
||||
### 開發環境
|
||||
|
||||
AstrBot 使用 `ruff` 進行程式碼格式化和檢查。
|
||||
|
||||
```bash
|
||||
git clone https://github.com/AstrBotDevs/AstrBot
|
||||
pip install pre-commit
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
## 🌍 社群
|
||||
|
||||
### QQ 群組
|
||||
|
||||
- 1 群:322154837
|
||||
- 3 群:630166526
|
||||
- 5 群:822130018
|
||||
- 6 群:753075035
|
||||
- 開發者群:975206796
|
||||
|
||||
### Telegram 群組
|
||||
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
### Discord 群組
|
||||
|
||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
## ❤️ Special Thanks
|
||||
|
||||
特別感謝所有 Contributors 和外掛開發者對 AstrBot 的貢獻 ❤️
|
||||
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/graphs/contributors">
|
||||
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
|
||||
</a>
|
||||
|
||||
此外,本專案的誕生離不開以下開源專案的幫助:
|
||||
|
||||
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - 偉大的貓貓框架
|
||||
|
||||
## ⭐ Star History
|
||||
|
||||
> [!TIP]
|
||||
> 如果本專案對您的生活 / 工作產生了幫助,或者您關注本專案的未來發展,請給專案 Star,這是我們維護這個開源專案的動力 <3
|
||||
|
||||
<div align="center">
|
||||
|
||||
[](https://star-history.com/#astrbotdevs/astrbot&Date)
|
||||
|
||||
</div>
|
||||
|
||||
</details>
|
||||
|
||||
_私は、高性能ですから!_
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "3.5.23"
|
||||
__version__ = "4.10.2"
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from typing import Any, ClassVar, Literal, cast
|
||||
|
||||
from pydantic import BaseModel, GetCoreSchemaHandler
|
||||
from pydantic import BaseModel, GetCoreSchemaHandler, model_serializer, model_validator
|
||||
from pydantic_core import core_schema
|
||||
|
||||
|
||||
@@ -122,10 +122,12 @@ class ToolCall(BaseModel):
|
||||
extra_content: dict[str, Any] | None = None
|
||||
"""Extra metadata for the tool call."""
|
||||
|
||||
def model_dump(self, **kwargs: Any) -> dict[str, Any]:
|
||||
@model_serializer(mode="wrap")
|
||||
def serialize(self, handler):
|
||||
data = handler(self)
|
||||
if self.extra_content is None:
|
||||
kwargs.setdefault("exclude", set()).add("extra_content")
|
||||
return super().model_dump(**kwargs)
|
||||
data.pop("extra_content", None)
|
||||
return data
|
||||
|
||||
|
||||
class ToolCallPart(BaseModel):
|
||||
@@ -145,22 +147,39 @@ class Message(BaseModel):
|
||||
"tool",
|
||||
]
|
||||
|
||||
content: str | list[ContentPart]
|
||||
content: str | list[ContentPart] | None = None
|
||||
"""The content of the message."""
|
||||
|
||||
tool_calls: list[ToolCall] | list[dict] | None = None
|
||||
"""The tool calls of the message."""
|
||||
|
||||
tool_call_id: str | None = None
|
||||
"""The ID of the tool call."""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_content_required(self):
|
||||
# assistant + tool_calls is not None: allow content to be None
|
||||
if self.role == "assistant" and self.tool_calls is not None:
|
||||
return self
|
||||
|
||||
# other all cases: content is required
|
||||
if self.content is None:
|
||||
raise ValueError(
|
||||
"content is required unless role='assistant' and tool_calls is not None"
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class AssistantMessageSegment(Message):
|
||||
"""A message segment from the assistant."""
|
||||
|
||||
role: Literal["assistant"] = "assistant"
|
||||
tool_calls: list[ToolCall] | list[dict] | None = None
|
||||
|
||||
|
||||
class ToolCallMessageSegment(Message):
|
||||
"""A message segment representing a tool call."""
|
||||
|
||||
role: Literal["tool"] = "tool"
|
||||
tool_call_id: str
|
||||
|
||||
|
||||
class UserMessageSegment(Message):
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import typing as T
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import TokenUsage
|
||||
|
||||
|
||||
class AgentResponseData(T.TypedDict):
|
||||
@@ -12,3 +13,23 @@ class AgentResponseData(T.TypedDict):
|
||||
class AgentResponse:
|
||||
type: str
|
||||
data: AgentResponseData
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentStats:
|
||||
token_usage: TokenUsage = field(default_factory=TokenUsage)
|
||||
start_time: float = 0.0
|
||||
end_time: float = 0.0
|
||||
time_to_first_token: float = 0.0
|
||||
|
||||
@property
|
||||
def duration(self) -> float:
|
||||
return self.end_time - self.start_time
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"token_usage": self.token_usage.__dict__,
|
||||
"start_time": self.start_time,
|
||||
"end_time": self.end_time,
|
||||
"time_to_first_token": self.time_to_first_token,
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ from .message import Message
|
||||
TContext = TypeVar("TContext", default=Any)
|
||||
|
||||
|
||||
@dataclass(config={"arbitrary_types_allowed": True})
|
||||
@dataclass
|
||||
class ContextWrapper(Generic[TContext]):
|
||||
"""A context for running an agent, which can be used to pass additional data or state."""
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import typing as T
|
||||
|
||||
@@ -12,6 +13,7 @@ from mcp.types import (
|
||||
)
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.message.components import Json
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
)
|
||||
@@ -24,7 +26,7 @@ from astrbot.core.provider.provider import Provider
|
||||
|
||||
from ..hooks import BaseAgentRunHooks
|
||||
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
|
||||
from ..response import AgentResponseData
|
||||
from ..response import AgentResponseData, AgentStats
|
||||
from ..run_context import ContextWrapper, TContext
|
||||
from ..tool_executor import BaseFunctionToolExecutor
|
||||
from .base import AgentResponse, AgentState, BaseAgentRunner
|
||||
@@ -69,14 +71,25 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
)
|
||||
self.run_context.messages = messages
|
||||
|
||||
self.stats = AgentStats()
|
||||
self.stats.start_time = time.time()
|
||||
|
||||
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
|
||||
"""Yields chunks *and* a final LLMResponse."""
|
||||
payload = {
|
||||
"contexts": self.run_context.messages, # list[Message]
|
||||
"func_tool": self.req.func_tool,
|
||||
"model": self.req.model, # NOTE: in fact, this arg is None in most cases
|
||||
"session_id": self.req.session_id,
|
||||
"extra_user_content_parts": self.req.extra_user_content_parts, # list[ContentPart]
|
||||
}
|
||||
|
||||
if self.streaming:
|
||||
stream = self.provider.text_chat_stream(**self.req.__dict__)
|
||||
stream = self.provider.text_chat_stream(**payload)
|
||||
async for resp in stream: # type: ignore
|
||||
yield resp
|
||||
else:
|
||||
yield await self.provider.text_chat(**self.req.__dict__)
|
||||
yield await self.provider.text_chat(**payload)
|
||||
|
||||
@override
|
||||
async def step(self):
|
||||
@@ -97,8 +110,11 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
llm_resp_result = None
|
||||
|
||||
async for llm_response in self._iter_llm_responses():
|
||||
assert isinstance(llm_response, LLMResponse)
|
||||
if llm_response.is_chunk:
|
||||
# update ttft
|
||||
if self.stats.time_to_first_token == 0:
|
||||
self.stats.time_to_first_token = time.time() - self.stats.start_time
|
||||
|
||||
if llm_response.result_chain:
|
||||
yield AgentResponse(
|
||||
type="streaming_delta",
|
||||
@@ -122,6 +138,10 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
)
|
||||
continue
|
||||
llm_resp_result = llm_response
|
||||
|
||||
if not llm_response.is_chunk and llm_response.usage:
|
||||
# only count the token usage of the final response for computation purpose
|
||||
self.stats.token_usage += llm_response.usage
|
||||
break # got final response
|
||||
|
||||
if not llm_resp_result:
|
||||
@@ -133,6 +153,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
if llm_resp.role == "err":
|
||||
# 如果 LLM 响应错误,转换到错误状态
|
||||
self.final_llm_resp = llm_resp
|
||||
self.stats.end_time = time.time()
|
||||
self._transition_state(AgentState.ERROR)
|
||||
yield AgentResponse(
|
||||
type="err",
|
||||
@@ -147,11 +168,12 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
# 如果没有工具调用,转换到完成状态
|
||||
self.final_llm_resp = llm_resp
|
||||
self._transition_state(AgentState.DONE)
|
||||
self.stats.end_time = time.time()
|
||||
# record the final assistant message
|
||||
self.run_context.messages.append(
|
||||
Message(
|
||||
role="assistant",
|
||||
content=llm_resp.completion_text or "",
|
||||
content=llm_resp.completion_text or "*No response*",
|
||||
),
|
||||
)
|
||||
try:
|
||||
@@ -176,22 +198,19 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
# 如果有工具调用,还需处理工具调用
|
||||
if llm_resp.tools_call_name:
|
||||
tool_call_result_blocks = []
|
||||
for tool_call_name in llm_resp.tools_call_name:
|
||||
yield AgentResponse(
|
||||
type="tool_call",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain(type="tool_call").message(
|
||||
f"🔨 调用工具: {tool_call_name}"
|
||||
),
|
||||
),
|
||||
)
|
||||
async for result in self._handle_function_tools(self.req, llm_resp):
|
||||
if isinstance(result, list):
|
||||
tool_call_result_blocks = result
|
||||
elif isinstance(result, MessageChain):
|
||||
result.type = "tool_call_result"
|
||||
if result.type is None:
|
||||
# should not happen
|
||||
continue
|
||||
if result.type == "tool_direct_result":
|
||||
ar_type = "tool_call_result"
|
||||
else:
|
||||
ar_type = result.type
|
||||
yield AgentResponse(
|
||||
type="tool_call_result",
|
||||
type=ar_type,
|
||||
data=AgentResponseData(chain=result),
|
||||
)
|
||||
# 将结果添加到上下文中
|
||||
@@ -219,6 +238,25 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
async for resp in self.step():
|
||||
yield resp
|
||||
|
||||
# 如果循环结束了但是 agent 还没有完成,说明是达到了 max_step
|
||||
if not self.done():
|
||||
logger.warning(
|
||||
f"Agent reached max steps ({max_step}), forcing a final response."
|
||||
)
|
||||
# 拔掉所有工具
|
||||
if self.req:
|
||||
self.req.func_tool = None
|
||||
# 注入提示词
|
||||
self.run_context.messages.append(
|
||||
Message(
|
||||
role="user",
|
||||
content="工具调用次数已达到上限,请停止使用工具,并根据已经收集到的信息,对你的任务和发现进行总结,然后直接回复用户。",
|
||||
)
|
||||
)
|
||||
# 再执行最后一步
|
||||
async for resp in self.step():
|
||||
yield resp
|
||||
|
||||
async def _handle_function_tools(
|
||||
self,
|
||||
req: ProviderRequest,
|
||||
@@ -234,6 +272,19 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
llm_response.tools_call_args,
|
||||
llm_response.tools_call_ids,
|
||||
):
|
||||
yield MessageChain(
|
||||
type="tool_call",
|
||||
chain=[
|
||||
Json(
|
||||
data={
|
||||
"id": func_tool_id,
|
||||
"name": func_tool_name,
|
||||
"args": func_tool_args,
|
||||
"ts": time.time(),
|
||||
}
|
||||
)
|
||||
],
|
||||
)
|
||||
try:
|
||||
if not req.func_tool:
|
||||
return
|
||||
@@ -307,7 +358,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
content=res.content[0].text,
|
||||
),
|
||||
)
|
||||
yield MessageChain().message(res.content[0].text)
|
||||
elif isinstance(res.content[0], ImageContent):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
@@ -329,7 +379,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
content=resource.text,
|
||||
),
|
||||
)
|
||||
yield MessageChain().message(resource.text)
|
||||
elif (
|
||||
isinstance(resource, BlobResourceContents)
|
||||
and resource.mimeType
|
||||
@@ -353,20 +402,34 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
content="返回的数据类型不受支持",
|
||||
),
|
||||
)
|
||||
yield MessageChain().message("返回的数据类型不受支持。")
|
||||
|
||||
elif resp is None:
|
||||
# Tool 直接请求发送消息给用户
|
||||
# 这里我们将直接结束 Agent Loop。
|
||||
# 发送消息逻辑在 ToolExecutor 中处理了。
|
||||
logger.warning(
|
||||
f"{func_tool_name} 没有没有返回值或者将结果直接发送给用户,此工具调用不会被记录到历史中。"
|
||||
f"{func_tool_name} 没有没有返回值或者将结果直接发送给用户。"
|
||||
)
|
||||
self._transition_state(AgentState.DONE)
|
||||
self.stats.end_time = time.time()
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="*工具没有返回值或者将结果直接发送给了用户*",
|
||||
),
|
||||
)
|
||||
else:
|
||||
# 不应该出现其他类型
|
||||
logger.warning(
|
||||
f"Tool 返回了不支持的类型: {type(resp)},将忽略。",
|
||||
f"Tool 返回了不支持的类型: {type(resp)}。",
|
||||
)
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="*工具返回了不支持的类型,请告诉用户检查这个工具的定义和实现。*",
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -388,6 +451,22 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
),
|
||||
)
|
||||
|
||||
# yield the last tool call result
|
||||
if tool_call_result_blocks:
|
||||
last_tcr_content = str(tool_call_result_blocks[-1].content)
|
||||
yield MessageChain(
|
||||
type="tool_call_result",
|
||||
chain=[
|
||||
Json(
|
||||
data={
|
||||
"id": func_tool_id,
|
||||
"ts": time.time(),
|
||||
"result": last_tcr_content,
|
||||
}
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# 处理函数调用响应
|
||||
if tool_call_result_blocks:
|
||||
yield tool_call_result_blocks
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from collections.abc import Awaitable, Callable
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from typing import Any, Generic
|
||||
|
||||
import jsonschema
|
||||
@@ -7,6 +7,8 @@ from deprecated import deprecated
|
||||
from pydantic import Field, model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from astrbot.core.message.message_event_result import MessageEventResult
|
||||
|
||||
from .run_context import ContextWrapper, TContext
|
||||
|
||||
ParametersType = dict[str, Any]
|
||||
@@ -38,7 +40,10 @@ class ToolSchema:
|
||||
class FunctionTool(ToolSchema, Generic[TContext]):
|
||||
"""A callable tool, for function calling."""
|
||||
|
||||
handler: Callable[..., Awaitable[Any]] | None = None
|
||||
handler: (
|
||||
Callable[..., Awaitable[str | None] | AsyncGenerator[MessageEventResult, None]]
|
||||
| None
|
||||
) = None
|
||||
"""a callable that implements the tool's functionality. It should be an async function."""
|
||||
|
||||
handler_module_path: str | None = None
|
||||
|
||||
@@ -6,8 +6,10 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.context import Context
|
||||
|
||||
|
||||
@dataclass(config={"arbitrary_types_allowed": True})
|
||||
@dataclass
|
||||
class AstrAgentContext:
|
||||
__pydantic_config__ = {"arbitrary_types_allowed": True}
|
||||
|
||||
context: Context
|
||||
"""The star context instance"""
|
||||
event: AstrMessageEvent
|
||||
|
||||
@@ -2,13 +2,16 @@ import traceback
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.message.components import Json
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
MessageEventResult,
|
||||
ResultContentType,
|
||||
)
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
|
||||
AgentRunner = ToolLoopAgentRunner[AstrAgentContext]
|
||||
|
||||
@@ -22,8 +25,25 @@ async def run_agent(
|
||||
) -> AsyncGenerator[MessageChain | None, None]:
|
||||
step_idx = 0
|
||||
astr_event = agent_runner.run_context.context.event
|
||||
while step_idx < max_step:
|
||||
while step_idx < max_step + 1:
|
||||
step_idx += 1
|
||||
|
||||
if step_idx == max_step + 1:
|
||||
logger.warning(
|
||||
f"Agent reached max steps ({max_step}), forcing a final response."
|
||||
)
|
||||
if not agent_runner.done():
|
||||
# 拔掉所有工具
|
||||
if agent_runner.req:
|
||||
agent_runner.req.func_tool = None
|
||||
# 注入提示词
|
||||
agent_runner.run_context.messages.append(
|
||||
Message(
|
||||
role="user",
|
||||
content="工具调用次数已达到上限,请停止使用工具,并根据已经收集到的信息,对你的任务和发现进行总结,然后直接回复用户。",
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
async for resp in agent_runner.step():
|
||||
if astr_event.is_stopped():
|
||||
@@ -32,16 +52,27 @@ async def run_agent(
|
||||
msg_chain = resp.data["chain"]
|
||||
if msg_chain.type == "tool_direct_result":
|
||||
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
|
||||
await astr_event.send(resp.data["chain"])
|
||||
await astr_event.send(msg_chain)
|
||||
continue
|
||||
if astr_event.get_platform_id() == "webchat":
|
||||
await astr_event.send(msg_chain)
|
||||
# 对于其他情况,暂时先不处理
|
||||
continue
|
||||
elif resp.type == "tool_call":
|
||||
if agent_runner.streaming:
|
||||
# 用来标记流式响应需要分节
|
||||
yield MessageChain(chain=[], type="break")
|
||||
if show_tool_use:
|
||||
|
||||
if astr_event.get_platform_name() == "webchat":
|
||||
await astr_event.send(resp.data["chain"])
|
||||
elif show_tool_use:
|
||||
json_comp = resp.data["chain"].chain[0]
|
||||
if isinstance(json_comp, Json):
|
||||
m = f"🔨 调用工具: {json_comp.data.get('name')}"
|
||||
else:
|
||||
m = "🔨 调用工具..."
|
||||
chain = MessageChain(type="tool_call").message(m)
|
||||
await astr_event.send(chain)
|
||||
continue
|
||||
|
||||
if stream_to_general and resp.type == "streaming_delta":
|
||||
@@ -68,11 +99,33 @@ async def run_agent(
|
||||
continue
|
||||
yield resp.data["chain"] # MessageChain
|
||||
if agent_runner.done():
|
||||
# send agent stats to webchat
|
||||
if astr_event.get_platform_name() == "webchat":
|
||||
await astr_event.send(
|
||||
MessageChain(
|
||||
type="agent_stats",
|
||||
chain=[Json(data=agent_runner.stats.to_dict())],
|
||||
)
|
||||
)
|
||||
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n"
|
||||
|
||||
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在平台日志查看和分享错误详情。\n"
|
||||
|
||||
error_llm_response = LLMResponse(
|
||||
role="err",
|
||||
completion_text=err_msg,
|
||||
)
|
||||
try:
|
||||
await agent_runner.agent_hooks.on_agent_done(
|
||||
agent_runner.run_context, error_llm_response
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error in on_agent_done hook")
|
||||
|
||||
if agent_runner.streaming:
|
||||
yield MessageChain().message(err_msg)
|
||||
else:
|
||||
|
||||
@@ -185,7 +185,11 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
|
||||
async def call_local_llm_tool(
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
handler: T.Callable[..., T.Awaitable[T.Any]],
|
||||
handler: T.Callable[
|
||||
...,
|
||||
T.Awaitable[MessageEventResult | mcp.types.CallToolResult | str | None]
|
||||
| T.AsyncGenerator[MessageEventResult | CommandResult | str | None, None],
|
||||
],
|
||||
method_name: str,
|
||||
*args,
|
||||
**kwargs,
|
||||
@@ -205,12 +209,42 @@ async def call_local_llm_tool(
|
||||
else:
|
||||
raise ValueError(f"未知的方法名: {method_name}")
|
||||
except ValueError as e:
|
||||
logger.error(f"调用本地 LLM 工具时出错: {e}", exc_info=True)
|
||||
except TypeError:
|
||||
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
|
||||
raise Exception(f"Tool execution ValueError: {e}") from e
|
||||
except TypeError as e:
|
||||
# 获取函数的签名(包括类型),除了第一个 event/context 参数。
|
||||
try:
|
||||
sig = inspect.signature(handler)
|
||||
params = list(sig.parameters.values())
|
||||
# 跳过第一个参数(event 或 context)
|
||||
if params:
|
||||
params = params[1:]
|
||||
|
||||
param_strs = []
|
||||
for param in params:
|
||||
param_str = param.name
|
||||
if param.annotation != inspect.Parameter.empty:
|
||||
# 获取类型注解的字符串表示
|
||||
if isinstance(param.annotation, type):
|
||||
type_str = param.annotation.__name__
|
||||
else:
|
||||
type_str = str(param.annotation)
|
||||
param_str += f": {type_str}"
|
||||
if param.default != inspect.Parameter.empty:
|
||||
param_str += f" = {param.default!r}"
|
||||
param_strs.append(param_str)
|
||||
|
||||
handler_param_str = (
|
||||
", ".join(param_strs) if param_strs else "(no additional parameters)"
|
||||
)
|
||||
except Exception:
|
||||
handler_param_str = "(unable to inspect signature)"
|
||||
|
||||
raise Exception(
|
||||
f"Tool handler parameter mismatch, please check the handler definition. Handler parameters: {handler_param_str}"
|
||||
) from e
|
||||
except Exception as e:
|
||||
trace_ = traceback.format_exc()
|
||||
logger.error(f"调用本地 LLM 工具时出错: {e}\n{trace_}")
|
||||
raise Exception(f"Tool execution error: {e}. Traceback: {trace_}") from e
|
||||
|
||||
if not ready_to_call:
|
||||
return
|
||||
|
||||
@@ -24,6 +24,10 @@ class AstrBotConfig(dict):
|
||||
- 如果传入了 schema,将会通过 schema 解析出 default_config,此时传入的 default_config 会被忽略。
|
||||
"""
|
||||
|
||||
config_path: str
|
||||
default_config: dict
|
||||
schema: dict | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config_path: str = ASTRBOT_CONFIG_PATH,
|
||||
|
||||
+352
-213
@@ -1,12 +1,22 @@
|
||||
"""如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。"""
|
||||
|
||||
import os
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "4.7.1"
|
||||
VERSION = "4.10.2"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||
|
||||
WEBHOOK_SUPPORTED_PLATFORMS = [
|
||||
"qq_official_webhook",
|
||||
"weixin_official_account",
|
||||
"wecom",
|
||||
"wecom_ai_bot",
|
||||
"slack",
|
||||
"lark",
|
||||
]
|
||||
|
||||
# 默认配置
|
||||
DEFAULT_CONFIG = {
|
||||
"config_version": 2,
|
||||
@@ -34,7 +44,15 @@ DEFAULT_CONFIG = {
|
||||
"interval": "1.5,3.5",
|
||||
"log_base": 2.6,
|
||||
"words_count_threshold": 150,
|
||||
"split_mode": "regex", # regex 或 words
|
||||
"regex": ".*?[。?!~…]+|.+$",
|
||||
"split_words": [
|
||||
"。",
|
||||
"?",
|
||||
"!",
|
||||
"~",
|
||||
"…",
|
||||
], # 当 split_mode 为 words 时使用
|
||||
"content_cleanup_rule": "",
|
||||
},
|
||||
"no_permission_reply": True,
|
||||
@@ -44,7 +62,8 @@ DEFAULT_CONFIG = {
|
||||
"ignore_bot_self_message": False,
|
||||
"ignore_at_all": False,
|
||||
},
|
||||
"provider": [],
|
||||
"provider_sources": [], # provider sources
|
||||
"provider": [], # models from provider_sources
|
||||
"provider_settings": {
|
||||
"enable": True,
|
||||
"default_provider_id": "",
|
||||
@@ -73,8 +92,14 @@ DEFAULT_CONFIG = {
|
||||
"coze_agent_runner_provider_id": "",
|
||||
"dashscope_agent_runner_provider_id": "",
|
||||
"unsupported_streaming_strategy": "realtime_segmenting",
|
||||
"reachability_check": False,
|
||||
"max_agent_step": 30,
|
||||
"tool_call_timeout": 60,
|
||||
"file_extract": {
|
||||
"enable": False,
|
||||
"provider": "moonshotai",
|
||||
"moonshotai_api_key": "",
|
||||
},
|
||||
},
|
||||
"provider_stt_settings": {
|
||||
"enable": False,
|
||||
@@ -85,11 +110,13 @@ DEFAULT_CONFIG = {
|
||||
"provider_id": "",
|
||||
"dual_output": False,
|
||||
"use_file_service": False,
|
||||
"trigger_probability": 1.0,
|
||||
},
|
||||
"provider_ltm_settings": {
|
||||
"group_icl_enable": False,
|
||||
"group_message_max_cnt": 300,
|
||||
"image_caption": False,
|
||||
"image_caption_provider_id": "",
|
||||
"active_reply": {
|
||||
"enable": False,
|
||||
"method": "possibility_reply",
|
||||
@@ -142,9 +169,26 @@ DEFAULT_CONFIG = {
|
||||
"kb_fusion_top_k": 20, # 知识库检索融合阶段返回结果数量
|
||||
"kb_final_top_k": 5, # 知识库检索最终返回结果数量
|
||||
"kb_agentic_mode": False,
|
||||
"disable_builtin_commands": False,
|
||||
}
|
||||
|
||||
|
||||
class ChatProviderTemplate(TypedDict):
|
||||
id: str
|
||||
provider_source_id: str
|
||||
model: str
|
||||
modalities: list
|
||||
custom_extra_body: dict[str, Any]
|
||||
|
||||
|
||||
CHAT_PROVIDER_TEMPLATE = {
|
||||
"id": "",
|
||||
"provide_source_id": "",
|
||||
"model": "",
|
||||
"modalities": [],
|
||||
"custom_extra_body": {},
|
||||
}
|
||||
|
||||
"""
|
||||
AstrBot v3 时代的配置元数据,目前仅承担以下功能:
|
||||
|
||||
@@ -178,10 +222,12 @@ CONFIG_METADATA_2 = {
|
||||
"appid": "",
|
||||
"secret": "",
|
||||
"is_sandbox": False,
|
||||
"unified_webhook_mode": True,
|
||||
"webhook_uuid": "",
|
||||
"callback_server_host": "0.0.0.0",
|
||||
"port": 6196,
|
||||
},
|
||||
"QQ 个人号(OneBot v11)": {
|
||||
"OneBot v11": {
|
||||
"id": "default",
|
||||
"type": "aiocqhttp",
|
||||
"enable": False,
|
||||
@@ -208,6 +254,8 @@ CONFIG_METADATA_2 = {
|
||||
"token": "",
|
||||
"encoding_aes_key": "",
|
||||
"api_base_url": "https://api.weixin.qq.com/cgi-bin/",
|
||||
"unified_webhook_mode": True,
|
||||
"webhook_uuid": "",
|
||||
"callback_server_host": "0.0.0.0",
|
||||
"port": 6194,
|
||||
"active_send_mode": False,
|
||||
@@ -222,6 +270,8 @@ CONFIG_METADATA_2 = {
|
||||
"encoding_aes_key": "",
|
||||
"kf_name": "",
|
||||
"api_base_url": "https://qyapi.weixin.qq.com/cgi-bin/",
|
||||
"unified_webhook_mode": True,
|
||||
"webhook_uuid": "",
|
||||
"callback_server_host": "0.0.0.0",
|
||||
"port": 6195,
|
||||
},
|
||||
@@ -234,6 +284,8 @@ CONFIG_METADATA_2 = {
|
||||
"wecom_ai_bot_name": "",
|
||||
"token": "",
|
||||
"encoding_aes_key": "",
|
||||
"unified_webhook_mode": True,
|
||||
"webhook_uuid": "",
|
||||
"callback_server_host": "0.0.0.0",
|
||||
"port": 6198,
|
||||
},
|
||||
@@ -245,6 +297,10 @@ CONFIG_METADATA_2 = {
|
||||
"app_id": "",
|
||||
"app_secret": "",
|
||||
"domain": "https://open.feishu.cn",
|
||||
"lark_connection_mode": "socket", # webhook, socket
|
||||
"webhook_uuid": "",
|
||||
"lark_encrypt_key": "",
|
||||
"lark_verification_token": "",
|
||||
},
|
||||
"钉钉(DingTalk)": {
|
||||
"id": "dingtalk",
|
||||
@@ -301,6 +357,8 @@ CONFIG_METADATA_2 = {
|
||||
"app_token": "",
|
||||
"signing_secret": "",
|
||||
"slack_connection_mode": "socket", # webhook, socket
|
||||
"unified_webhook_mode": True,
|
||||
"webhook_uuid": "",
|
||||
"slack_webhook_host": "0.0.0.0",
|
||||
"slack_webhook_port": 6197,
|
||||
"slack_webhook_path": "/astrbot-slack-webhook/callback",
|
||||
@@ -336,6 +394,28 @@ CONFIG_METADATA_2 = {
|
||||
# "type": "string",
|
||||
# "options": ["fullscreen", "embedded"],
|
||||
# },
|
||||
"lark_connection_mode": {
|
||||
"description": "订阅方式",
|
||||
"type": "string",
|
||||
"options": ["socket", "webhook"],
|
||||
"labels": ["长连接模式", "推送至服务器模式"],
|
||||
},
|
||||
"lark_encrypt_key": {
|
||||
"description": "Encrypt Key",
|
||||
"type": "string",
|
||||
"hint": "用于解密飞书回调数据的加密密钥",
|
||||
"condition": {
|
||||
"lark_connection_mode": "webhook",
|
||||
},
|
||||
},
|
||||
"lark_verification_token": {
|
||||
"description": "Verification Token",
|
||||
"type": "string",
|
||||
"hint": "用于验证飞书回调请求的令牌",
|
||||
"condition": {
|
||||
"lark_connection_mode": "webhook",
|
||||
},
|
||||
},
|
||||
"is_sandbox": {
|
||||
"description": "沙箱模式",
|
||||
"type": "bool",
|
||||
@@ -380,16 +460,28 @@ CONFIG_METADATA_2 = {
|
||||
"description": "Slack Webhook Host",
|
||||
"type": "string",
|
||||
"hint": "Only valid when Slack connection mode is `webhook`.",
|
||||
"condition": {
|
||||
"slack_connection_mode": "webhook",
|
||||
"unified_webhook_mode": False,
|
||||
},
|
||||
},
|
||||
"slack_webhook_port": {
|
||||
"description": "Slack Webhook Port",
|
||||
"type": "int",
|
||||
"hint": "Only valid when Slack connection mode is `webhook`.",
|
||||
"condition": {
|
||||
"slack_connection_mode": "webhook",
|
||||
"unified_webhook_mode": False,
|
||||
},
|
||||
},
|
||||
"slack_webhook_path": {
|
||||
"description": "Slack Webhook Path",
|
||||
"type": "string",
|
||||
"hint": "Only valid when Slack connection mode is `webhook`.",
|
||||
"condition": {
|
||||
"slack_connection_mode": "webhook",
|
||||
"unified_webhook_mode": False,
|
||||
},
|
||||
},
|
||||
"active_send_mode": {
|
||||
"description": "是否换用主动发送接口",
|
||||
@@ -580,6 +672,33 @@ CONFIG_METADATA_2 = {
|
||||
"type": "string",
|
||||
"hint": "可选的 Discord 活动名称。留空则不设置活动。",
|
||||
},
|
||||
"port": {
|
||||
"description": "回调服务器端口",
|
||||
"type": "int",
|
||||
"hint": "回调服务器端口。留空则不启用回调服务器。",
|
||||
"condition": {
|
||||
"unified_webhook_mode": False,
|
||||
},
|
||||
},
|
||||
"callback_server_host": {
|
||||
"description": "回调服务器主机",
|
||||
"type": "string",
|
||||
"hint": "回调服务器主机。留空则不启用回调服务器。",
|
||||
"condition": {
|
||||
"unified_webhook_mode": False,
|
||||
},
|
||||
},
|
||||
"unified_webhook_mode": {
|
||||
"description": "统一 Webhook 模式",
|
||||
"type": "bool",
|
||||
"hint": "启用后,将使用 AstrBot 统一 Webhook 入口,无需单独开启端口。回调地址为 /api/platform/webhook/{webhook_uuid}。",
|
||||
},
|
||||
"webhook_uuid": {
|
||||
"invisible": True,
|
||||
"description": "Webhook UUID",
|
||||
"type": "string",
|
||||
"hint": "统一 Webhook 模式下的唯一标识符,创建平台时自动生成。",
|
||||
},
|
||||
},
|
||||
},
|
||||
"platform_settings": {
|
||||
@@ -743,6 +862,7 @@ CONFIG_METADATA_2 = {
|
||||
"metadata": {
|
||||
"provider": {
|
||||
"type": "list",
|
||||
# provider sources templates
|
||||
"config_template": {
|
||||
"OpenAI": {
|
||||
"id": "openai",
|
||||
@@ -753,107 +873,10 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.openai.com/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
"hint": "也兼容所有与 OpenAI API 兼容的服务。",
|
||||
},
|
||||
"Azure OpenAI": {
|
||||
"id": "azure",
|
||||
"provider": "azure",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"api_version": "2024-05-01-preview",
|
||||
"key": [],
|
||||
"api_base": "",
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"xAI": {
|
||||
"id": "xai",
|
||||
"provider": "xai",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.x.ai/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "grok-2-latest", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"xai_native_search": False,
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"Anthropic": {
|
||||
"hint": "注意Claude系列模型的温度调节范围为0到1.0,超出可能导致报错",
|
||||
"id": "claude",
|
||||
"provider": "anthropic",
|
||||
"type": "anthropic_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.anthropic.com/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "claude-3-5-sonnet-latest",
|
||||
"max_tokens": 4096,
|
||||
"temperature": 0.2,
|
||||
},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"Ollama": {
|
||||
"hint": "启用前请确保已正确安装并运行 Ollama 服务端,Ollama默认不带鉴权,无需修改key",
|
||||
"id": "ollama_default",
|
||||
"provider": "ollama",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": ["ollama"], # ollama 的 key 默认是 ollama
|
||||
"api_base": "http://localhost:11434/v1",
|
||||
"model_config": {"model": "llama3.1-8b", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"LM Studio": {
|
||||
"id": "lm_studio",
|
||||
"provider": "lm_studio",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": ["lmstudio"],
|
||||
"api_base": "http://localhost:1234/v1",
|
||||
"model_config": {
|
||||
"model": "llama-3.1-8b",
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"Gemini(OpenAI兼容)": {
|
||||
"id": "gemini_default",
|
||||
"provider": "google",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "gemini-1.5-flash",
|
||||
"temperature": 0.4,
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"Gemini": {
|
||||
"id": "gemini_default",
|
||||
"Google Gemini": {
|
||||
"id": "google_gemini",
|
||||
"provider": "google",
|
||||
"type": "googlegenai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
@@ -861,10 +884,6 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://generativelanguage.googleapis.com/",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "gemini-2.0-flash-exp",
|
||||
"temperature": 0.4,
|
||||
},
|
||||
"gm_resp_image_modal": False,
|
||||
"gm_native_search": False,
|
||||
"gm_native_coderunner": False,
|
||||
@@ -875,13 +894,43 @@ CONFIG_METADATA_2 = {
|
||||
"sexually_explicit": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
"dangerous_content": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
},
|
||||
"gm_thinking_config": {
|
||||
"budget": 0,
|
||||
},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
"gm_thinking_config": {"budget": 0, "level": "HIGH"},
|
||||
},
|
||||
"Anthropic": {
|
||||
"id": "anthropic",
|
||||
"provider": "anthropic",
|
||||
"type": "anthropic_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.anthropic.com/v1",
|
||||
"timeout": 120,
|
||||
},
|
||||
"Moonshot": {
|
||||
"id": "moonshot",
|
||||
"provider": "moonshot",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://api.moonshot.cn/v1",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"xAI": {
|
||||
"id": "xai",
|
||||
"provider": "xai",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.x.ai/v1",
|
||||
"timeout": 120,
|
||||
"custom_headers": {},
|
||||
"xai_native_search": False,
|
||||
},
|
||||
"DeepSeek": {
|
||||
"id": "deepseek_default",
|
||||
"id": "deepseek",
|
||||
"provider": "deepseek",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
@@ -889,13 +938,75 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.deepseek.com/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "deepseek-chat", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "tool_use"],
|
||||
},
|
||||
"Zhipu": {
|
||||
"id": "zhipu",
|
||||
"provider": "zhipu",
|
||||
"type": "zhipu_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"Azure OpenAI": {
|
||||
"id": "azure_openai",
|
||||
"provider": "azure",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"api_version": "2024-05-01-preview",
|
||||
"key": [],
|
||||
"api_base": "",
|
||||
"timeout": 120,
|
||||
"custom_headers": {},
|
||||
},
|
||||
"Ollama": {
|
||||
"id": "ollama",
|
||||
"provider": "ollama",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": ["ollama"], # ollama 的 key 默认是 ollama
|
||||
"api_base": "http://127.0.0.1:11434/v1",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"LM Studio": {
|
||||
"id": "lm_studio",
|
||||
"provider": "lm_studio",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": ["lmstudio"],
|
||||
"api_base": "http://127.0.0.1:1234/v1",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"ModelStack": {
|
||||
"id": "modelstack",
|
||||
"provider": "modelstack",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://modelstack.app/v1",
|
||||
"timeout": 120,
|
||||
"custom_headers": {},
|
||||
},
|
||||
"Gemini_OpenAI_API": {
|
||||
"id": "google_gemini_openai",
|
||||
"provider": "google",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||
"timeout": 120,
|
||||
"custom_headers": {},
|
||||
},
|
||||
"Groq": {
|
||||
"id": "groq_default",
|
||||
"id": "groq",
|
||||
"provider": "groq",
|
||||
"type": "groq_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
@@ -903,13 +1014,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.groq.com/openai/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "openai/gpt-oss-20b",
|
||||
"temperature": 0.4,
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "tool_use"],
|
||||
},
|
||||
"302.AI": {
|
||||
"id": "302ai",
|
||||
@@ -920,12 +1025,9 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.302.ai/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "gpt-4.1-mini", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"硅基流动": {
|
||||
"SiliconFlow": {
|
||||
"id": "siliconflow",
|
||||
"provider": "siliconflow",
|
||||
"type": "openai_chat_completion",
|
||||
@@ -934,15 +1036,9 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://api.siliconflow.cn/v1",
|
||||
"model_config": {
|
||||
"model": "deepseek-ai/DeepSeek-V3",
|
||||
"temperature": 0.4,
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"PPIO派欧云": {
|
||||
"PPIO": {
|
||||
"id": "ppio",
|
||||
"provider": "ppio",
|
||||
"type": "openai_chat_completion",
|
||||
@@ -951,14 +1047,9 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.ppinfra.com/v3/openai",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "deepseek/deepseek-r1",
|
||||
"temperature": 0.4,
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
},
|
||||
"小马算力": {
|
||||
"TokenPony": {
|
||||
"id": "tokenpony",
|
||||
"provider": "tokenpony",
|
||||
"type": "openai_chat_completion",
|
||||
@@ -967,14 +1058,9 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.tokenpony.cn/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "kimi-k2-instruct-0905",
|
||||
"temperature": 0.7,
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
},
|
||||
"优云智算": {
|
||||
"Compshare": {
|
||||
"id": "compshare",
|
||||
"provider": "compshare",
|
||||
"type": "openai_chat_completion",
|
||||
@@ -983,42 +1069,18 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.modelverse.cn/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "moonshotai/Kimi-K2-Instruct",
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"Kimi": {
|
||||
"id": "moonshot",
|
||||
"provider": "moonshot",
|
||||
"ModelScope": {
|
||||
"id": "modelscope",
|
||||
"provider": "modelscope",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://api.moonshot.cn/v1",
|
||||
"model_config": {"model": "moonshot-v1-8k", "temperature": 0.4},
|
||||
"api_base": "https://api-inference.modelscope.cn/v1",
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"智谱 AI": {
|
||||
"id": "zhipu_default",
|
||||
"provider": "zhipu",
|
||||
"type": "zhipu_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
|
||||
"model_config": {
|
||||
"model": "glm-4-flash",
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"Dify": {
|
||||
"id": "dify_app_default",
|
||||
@@ -1033,7 +1095,6 @@ CONFIG_METADATA_2 = {
|
||||
"dify_query_input_key": "astrbot_text_query",
|
||||
"variables": {},
|
||||
"timeout": 60,
|
||||
"hint": "请确保你在 AstrBot 里设置的 APP 类型和 Dify 里面创建的应用的类型一致!",
|
||||
},
|
||||
"Coze": {
|
||||
"id": "coze",
|
||||
@@ -1064,20 +1125,6 @@ CONFIG_METADATA_2 = {
|
||||
"variables": {},
|
||||
"timeout": 60,
|
||||
},
|
||||
"ModelScope": {
|
||||
"id": "modelscope",
|
||||
"provider": "modelscope",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://api-inference.modelscope.cn/v1",
|
||||
"model_config": {"model": "Qwen/Qwen3-32B", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"FastGPT": {
|
||||
"id": "fastgpt",
|
||||
"provider": "fastgpt",
|
||||
@@ -1101,7 +1148,6 @@ CONFIG_METADATA_2 = {
|
||||
"model": "whisper-1",
|
||||
},
|
||||
"Whisper(Local)": {
|
||||
"hint": "启用前请 pip 安装 openai-whisper 库(N卡用户大约下载 2GB,主要是 torch 和 cuda,CPU 用户大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
|
||||
"provider": "openai",
|
||||
"type": "openai_whisper_selfhost",
|
||||
"provider_type": "speech_to_text",
|
||||
@@ -1110,7 +1156,6 @@ CONFIG_METADATA_2 = {
|
||||
"model": "tiny",
|
||||
},
|
||||
"SenseVoice(Local)": {
|
||||
"hint": "启用前请 pip 安装 funasr、funasr_onnx、torchaudio、torch、modelscope、jieba 库(默认使用CPU,大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
|
||||
"type": "sensevoice_stt_selfhost",
|
||||
"provider": "sensevoice",
|
||||
"provider_type": "speech_to_text",
|
||||
@@ -1132,7 +1177,6 @@ CONFIG_METADATA_2 = {
|
||||
"timeout": "20",
|
||||
},
|
||||
"Edge TTS": {
|
||||
"hint": "提示:使用这个服务前需要安装有 ffmpeg,并且可以直接在终端调用 ffmpeg 指令。",
|
||||
"id": "edge_tts",
|
||||
"provider": "microsoft",
|
||||
"type": "edge_tts",
|
||||
@@ -1348,6 +1392,10 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
},
|
||||
"items": {
|
||||
"provider_source_id": {
|
||||
"invisible": True,
|
||||
"type": "string",
|
||||
},
|
||||
"xai_native_search": {
|
||||
"description": "启用原生搜索功能",
|
||||
"type": "bool",
|
||||
@@ -1718,13 +1766,24 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
},
|
||||
"gm_thinking_config": {
|
||||
"description": "Gemini思考设置",
|
||||
"description": "Thinking Config",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"budget": {
|
||||
"description": "思考预算",
|
||||
"description": "Thinking Budget",
|
||||
"type": "int",
|
||||
"hint": "模型应该生成的思考Token的数量,设为0关闭思考。除gemini-2.5-flash外的模型会静默忽略此参数。",
|
||||
"hint": "Guides the model on the specific number of thinking tokens to use for reasoning. See: https://ai.google.dev/gemini-api/docs/thinking#set-budget",
|
||||
},
|
||||
"level": {
|
||||
"description": "Thinking Level",
|
||||
"type": "string",
|
||||
"hint": "Recommended for Gemini 3 models and onwards, lets you control reasoning behavior.See: https://ai.google.dev/gemini-api/docs/thinking#thinking-levels",
|
||||
"options": [
|
||||
"MINIMAL",
|
||||
"LOW",
|
||||
"MEDIUM",
|
||||
"HIGH",
|
||||
],
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1905,7 +1964,6 @@ CONFIG_METADATA_2 = {
|
||||
"id": {
|
||||
"description": "ID",
|
||||
"type": "string",
|
||||
"hint": "模型提供商名字。",
|
||||
},
|
||||
"type": {
|
||||
"description": "模型提供商种类",
|
||||
@@ -1925,29 +1983,15 @@ CONFIG_METADATA_2 = {
|
||||
"description": "API Key",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"hint": "提供商 API Key。",
|
||||
},
|
||||
"api_base": {
|
||||
"description": "API Base URL",
|
||||
"type": "string",
|
||||
"hint": "API Base URL 请在模型提供商处获得。如出现 404 报错,尝试在地址末尾加上 /v1",
|
||||
},
|
||||
"model_config": {
|
||||
"description": "模型配置",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"model": {
|
||||
"description": "模型名称",
|
||||
"type": "string",
|
||||
"hint": "模型名称,如 gpt-4o-mini, deepseek-chat。",
|
||||
},
|
||||
"max_tokens": {
|
||||
"description": "模型最大输出长度(tokens)",
|
||||
"type": "int",
|
||||
},
|
||||
"temperature": {"description": "温度", "type": "float"},
|
||||
"top_p": {"description": "Top P值", "type": "float"},
|
||||
},
|
||||
"model": {
|
||||
"description": "模型 ID",
|
||||
"type": "string",
|
||||
"hint": "模型名称,如 gpt-4o-mini, deepseek-chat。",
|
||||
},
|
||||
"dify_api_key": {
|
||||
"description": "API Key",
|
||||
@@ -2067,6 +2111,20 @@ CONFIG_METADATA_2 = {
|
||||
"tool_call_timeout": {
|
||||
"type": "int",
|
||||
},
|
||||
"file_extract": {
|
||||
"type": "object",
|
||||
"items": {
|
||||
"enable": {
|
||||
"type": "bool",
|
||||
},
|
||||
"provider": {
|
||||
"type": "string",
|
||||
},
|
||||
"moonshotai_api_key": {
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"provider_stt_settings": {
|
||||
@@ -2095,6 +2153,9 @@ CONFIG_METADATA_2 = {
|
||||
"use_file_service": {
|
||||
"type": "bool",
|
||||
},
|
||||
"trigger_probability": {
|
||||
"type": "float",
|
||||
},
|
||||
},
|
||||
},
|
||||
"provider_ltm_settings": {
|
||||
@@ -2109,6 +2170,9 @@ CONFIG_METADATA_2 = {
|
||||
"image_caption": {
|
||||
"type": "bool",
|
||||
},
|
||||
"image_caption_provider_id": {
|
||||
"type": "string",
|
||||
},
|
||||
"image_caption_prompt": {
|
||||
"type": "string",
|
||||
},
|
||||
@@ -2302,6 +2366,14 @@ CONFIG_METADATA_3 = {
|
||||
"provider_tts_settings.enable": True,
|
||||
},
|
||||
},
|
||||
"provider_tts_settings.trigger_probability": {
|
||||
"description": "TTS 触发概率",
|
||||
"type": "float",
|
||||
"slider": {"min": 0, "max": 1, "step": 0.05},
|
||||
"condition": {
|
||||
"provider_tts_settings.enable": True,
|
||||
},
|
||||
},
|
||||
"provider_settings.image_caption_prompt": {
|
||||
"description": "图片转述提示词",
|
||||
"type": "text",
|
||||
@@ -2398,6 +2470,36 @@ CONFIG_METADATA_3 = {
|
||||
"provider_settings.enable": True,
|
||||
},
|
||||
},
|
||||
# "file_extract": {
|
||||
# "description": "文档解析能力 [beta]",
|
||||
# "type": "object",
|
||||
# "items": {
|
||||
# "provider_settings.file_extract.enable": {
|
||||
# "description": "启用文档解析能力",
|
||||
# "type": "bool",
|
||||
# },
|
||||
# "provider_settings.file_extract.provider": {
|
||||
# "description": "文档解析提供商",
|
||||
# "type": "string",
|
||||
# "options": ["moonshotai"],
|
||||
# "condition": {
|
||||
# "provider_settings.file_extract.enable": True,
|
||||
# },
|
||||
# },
|
||||
# "provider_settings.file_extract.moonshotai_api_key": {
|
||||
# "description": "Moonshot AI API Key",
|
||||
# "type": "string",
|
||||
# "condition": {
|
||||
# "provider_settings.file_extract.provider": "moonshotai",
|
||||
# "provider_settings.file_extract.enable": True,
|
||||
# },
|
||||
# },
|
||||
# },
|
||||
# "condition": {
|
||||
# "provider_settings.agent_runner_type": "local",
|
||||
# "provider_settings.enable": True,
|
||||
# },
|
||||
# },
|
||||
"others": {
|
||||
"description": "其他配置",
|
||||
"type": "object",
|
||||
@@ -2492,6 +2594,11 @@ CONFIG_METADATA_3 = {
|
||||
"description": "开启 TTS 时同时输出语音和文字内容",
|
||||
"type": "bool",
|
||||
},
|
||||
"provider_settings.reachability_check": {
|
||||
"description": "提供商可达性检测",
|
||||
"type": "bool",
|
||||
"hint": "/provider 命令列出模型时是否并发检测连通性。开启后会主动调用模型测试连通性,可能产生额外 token 消耗。",
|
||||
},
|
||||
},
|
||||
"condition": {
|
||||
"provider_settings.enable": True,
|
||||
@@ -2545,6 +2652,11 @@ CONFIG_METADATA_3 = {
|
||||
"description": "只 @ 机器人是否触发等待",
|
||||
"type": "bool",
|
||||
},
|
||||
"disable_builtin_commands": {
|
||||
"description": "禁用自带指令",
|
||||
"type": "bool",
|
||||
"hint": "禁用所有 AstrBot 的自带指令,如 help, provider, model 等。",
|
||||
},
|
||||
},
|
||||
},
|
||||
"whitelist": {
|
||||
@@ -2759,9 +2871,26 @@ CONFIG_METADATA_3 = {
|
||||
"description": "分段回复字数阈值",
|
||||
"type": "int",
|
||||
},
|
||||
"platform_settings.segmented_reply.split_mode": {
|
||||
"description": "分段模式",
|
||||
"type": "string",
|
||||
"options": ["regex", "words"],
|
||||
"labels": ["正则表达式", "分段词列表"],
|
||||
},
|
||||
"platform_settings.segmented_reply.regex": {
|
||||
"description": "分段正则表达式",
|
||||
"type": "string",
|
||||
"condition": {
|
||||
"platform_settings.segmented_reply.split_mode": "regex",
|
||||
},
|
||||
},
|
||||
"platform_settings.segmented_reply.split_words": {
|
||||
"description": "分段词列表",
|
||||
"type": "list",
|
||||
"hint": "检测到列表中的任意词时进行分段,如:。、?、!等",
|
||||
"condition": {
|
||||
"platform_settings.segmented_reply.split_mode": "words",
|
||||
},
|
||||
},
|
||||
"platform_settings.segmented_reply.content_cleanup_rule": {
|
||||
"description": "内容过滤正则表达式",
|
||||
@@ -2785,7 +2914,16 @@ CONFIG_METADATA_3 = {
|
||||
"provider_ltm_settings.image_caption": {
|
||||
"description": "自动理解图片",
|
||||
"type": "bool",
|
||||
"hint": "需要设置默认图片转述模型。",
|
||||
"hint": "需要设置群聊图片转述模型。",
|
||||
},
|
||||
"provider_ltm_settings.image_caption_provider_id": {
|
||||
"description": "群聊图片转述模型",
|
||||
"type": "string",
|
||||
"_special": "select_provider",
|
||||
"hint": "用于群聊上下文感知的图片理解,与默认图片转述模型分开配置。",
|
||||
"condition": {
|
||||
"provider_ltm_settings.image_caption": True,
|
||||
},
|
||||
},
|
||||
"provider_ltm_settings.active_reply.enable": {
|
||||
"description": "主动回复",
|
||||
@@ -2803,6 +2941,7 @@ CONFIG_METADATA_3 = {
|
||||
"description": "回复概率",
|
||||
"type": "float",
|
||||
"hint": "0.0-1.0 之间的数值",
|
||||
"slider": {"min": 0, "max": 1, "step": 0.05},
|
||||
"condition": {
|
||||
"provider_ltm_settings.active_reply.enable": True,
|
||||
},
|
||||
|
||||
@@ -79,6 +79,7 @@ class ConfigMetadataI18n:
|
||||
"_special",
|
||||
"invisible",
|
||||
"options",
|
||||
"slider",
|
||||
]:
|
||||
if attr in field_data:
|
||||
field_result[attr] = field_data[attr]
|
||||
|
||||
@@ -33,6 +33,7 @@ from astrbot.core.star.context import Context
|
||||
from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map
|
||||
from astrbot.core.umop_config_router import UmopConfigRouter
|
||||
from astrbot.core.updator import AstrBotUpdator
|
||||
from astrbot.core.utils.llm_metadata import update_llm_metadata
|
||||
from astrbot.core.utils.migra_helper import migra
|
||||
|
||||
from . import astrbot_config, html_renderer
|
||||
@@ -185,6 +186,8 @@ class AstrBotCoreLifecycle:
|
||||
# 初始化关闭控制面板的事件
|
||||
self.dashboard_shutdown_event = asyncio.Event()
|
||||
|
||||
asyncio.create_task(update_llm_metadata())
|
||||
|
||||
def _load(self) -> None:
|
||||
"""加载事件总线和任务并初始化."""
|
||||
# 创建一个异步任务来执行事件总线的 dispatch() 方法
|
||||
@@ -197,7 +200,7 @@ class AstrBotCoreLifecycle:
|
||||
# 把插件中注册的所有协程函数注册到事件总线中并执行
|
||||
extra_tasks = []
|
||||
for task in self.star_context._register_tasks:
|
||||
extra_tasks.append(asyncio.create_task(task, name=task.__name__))
|
||||
extra_tasks.append(asyncio.create_task(task, name=task.__name__)) # type: ignore
|
||||
|
||||
tasks_ = [event_bus_task, *extra_tasks]
|
||||
for task in tasks_:
|
||||
|
||||
+104
-4
@@ -5,11 +5,12 @@ from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
|
||||
from deprecated import deprecated
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from astrbot.core.db.po import (
|
||||
Attachment,
|
||||
CommandConfig,
|
||||
CommandConflict,
|
||||
ConversationV2,
|
||||
Persona,
|
||||
PlatformMessageHistory,
|
||||
@@ -32,7 +33,7 @@ class BaseDatabase(abc.ABC):
|
||||
echo=False,
|
||||
future=True,
|
||||
)
|
||||
self.AsyncSessionLocal = sessionmaker(
|
||||
self.AsyncSessionLocal = async_sessionmaker(
|
||||
self.engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
@@ -173,7 +174,7 @@ class BaseDatabase(abc.ABC):
|
||||
content: dict,
|
||||
sender_id: str | None = None,
|
||||
sender_name: str | None = None,
|
||||
) -> None:
|
||||
) -> PlatformMessageHistory:
|
||||
"""Insert a new platform message history record."""
|
||||
...
|
||||
|
||||
@@ -198,6 +199,14 @@ class BaseDatabase(abc.ABC):
|
||||
"""Get platform message history for a specific user."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_platform_message_history_by_id(
|
||||
self,
|
||||
message_id: int,
|
||||
) -> PlatformMessageHistory | None:
|
||||
"""Get a platform message history record by its ID."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def insert_attachment(
|
||||
self,
|
||||
@@ -213,6 +222,27 @@ class BaseDatabase(abc.ABC):
|
||||
"""Get an attachment by its ID."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_attachments(self, attachment_ids: list[str]) -> list[Attachment]:
|
||||
"""Get multiple attachments by their IDs."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_attachment(self, attachment_id: str) -> bool:
|
||||
"""Delete an attachment by its ID.
|
||||
|
||||
Returns True if the attachment was deleted, False if it was not found.
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_attachments(self, attachment_ids: list[str]) -> int:
|
||||
"""Delete multiple attachments by their IDs.
|
||||
|
||||
Returns the number of attachments deleted.
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def insert_persona(
|
||||
self,
|
||||
@@ -286,6 +316,76 @@ class BaseDatabase(abc.ABC):
|
||||
"""Clear all preferences for a specific scope ID."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_command_configs(self) -> list[CommandConfig]:
|
||||
"""Get all stored command configurations."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_command_config(self, handler_full_name: str) -> CommandConfig | None:
|
||||
"""Fetch a single command configuration by handler."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def upsert_command_config(
|
||||
self,
|
||||
handler_full_name: str,
|
||||
plugin_name: str,
|
||||
module_path: str,
|
||||
original_command: str,
|
||||
*,
|
||||
resolved_command: str | None = None,
|
||||
enabled: bool | None = None,
|
||||
keep_original_alias: bool | None = None,
|
||||
conflict_key: str | None = None,
|
||||
resolution_strategy: str | None = None,
|
||||
note: str | None = None,
|
||||
extra_data: dict | None = None,
|
||||
auto_managed: bool | None = None,
|
||||
) -> CommandConfig:
|
||||
"""Create or update a command configuration."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_command_config(self, handler_full_name: str) -> None:
|
||||
"""Delete a single command configuration."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_command_configs(self, handler_full_names: list[str]) -> None:
|
||||
"""Bulk delete command configurations."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def list_command_conflicts(
|
||||
self,
|
||||
status: str | None = None,
|
||||
) -> list[CommandConflict]:
|
||||
"""List recorded command conflict entries."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def upsert_command_conflict(
|
||||
self,
|
||||
conflict_key: str,
|
||||
handler_full_name: str,
|
||||
plugin_name: str,
|
||||
*,
|
||||
status: str | None = None,
|
||||
resolution: str | None = None,
|
||||
resolved_command: str | None = None,
|
||||
note: str | None = None,
|
||||
extra_data: dict | None = None,
|
||||
auto_generated: bool | None = None,
|
||||
) -> CommandConflict:
|
||||
"""Create or update a conflict record."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_command_conflicts(self, ids: list[int]) -> None:
|
||||
"""Delete conflict records."""
|
||||
...
|
||||
|
||||
# @abc.abstractmethod
|
||||
# async def insert_llm_message(
|
||||
# self,
|
||||
|
||||
@@ -70,6 +70,7 @@ async def migration_conversation_table(
|
||||
logger.info(
|
||||
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
|
||||
)
|
||||
continue
|
||||
if ":" not in conv.user_id:
|
||||
continue
|
||||
session = MessageSesion.from_str(session_str=conv.user_id)
|
||||
@@ -207,6 +208,7 @@ async def migration_webchat_data(
|
||||
logger.info(
|
||||
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
|
||||
)
|
||||
continue
|
||||
if ":" in conv.user_id:
|
||||
continue
|
||||
platform_id = "webchat"
|
||||
|
||||
@@ -127,7 +127,7 @@ class SQLiteDatabase:
|
||||
conn.text_factory = str
|
||||
return conn
|
||||
|
||||
def _exec_sql(self, sql: str, params: tuple = None):
|
||||
def _exec_sql(self, sql: str, params: tuple | None = None):
|
||||
conn = self.conn
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
@@ -224,9 +224,11 @@ class SQLiteDatabase:
|
||||
|
||||
c.close()
|
||||
|
||||
return Stats(platform, [], [])
|
||||
return Stats(platform)
|
||||
|
||||
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
|
||||
def get_conversation_by_user_id(
|
||||
self, user_id: str, cid: str
|
||||
) -> Conversation | None:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
@@ -258,7 +260,7 @@ class SQLiteDatabase:
|
||||
(user_id, cid, history, updated_at, created_at),
|
||||
)
|
||||
|
||||
def get_conversations(self, user_id: str) -> tuple:
|
||||
def get_conversations(self, user_id: str) -> list[Conversation]:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
|
||||
+75
-15
@@ -12,7 +12,7 @@ class PlatformStat(SQLModel, table=True):
|
||||
Note: In astrbot v4, we moved `platform` table to here.
|
||||
"""
|
||||
|
||||
__tablename__ = "platform_stats" # type: ignore
|
||||
__tablename__: str = "platform_stats"
|
||||
|
||||
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
|
||||
timestamp: datetime = Field(nullable=False)
|
||||
@@ -31,9 +31,10 @@ class PlatformStat(SQLModel, table=True):
|
||||
|
||||
|
||||
class ConversationV2(SQLModel, table=True):
|
||||
__tablename__ = "conversations" # type: ignore
|
||||
__tablename__: str = "conversations"
|
||||
|
||||
inner_conversation_id: int = Field(
|
||||
inner_conversation_id: int | None = Field(
|
||||
default=None,
|
||||
primary_key=True,
|
||||
sa_column_kwargs={"autoincrement": True},
|
||||
)
|
||||
@@ -68,7 +69,7 @@ class Persona(SQLModel, table=True):
|
||||
It can be used to customize the behavior of LLMs.
|
||||
"""
|
||||
|
||||
__tablename__ = "personas" # type: ignore
|
||||
__tablename__: str = "personas"
|
||||
|
||||
id: int | None = Field(
|
||||
primary_key=True,
|
||||
@@ -98,7 +99,7 @@ class Persona(SQLModel, table=True):
|
||||
class Preference(SQLModel, table=True):
|
||||
"""This class represents preferences for bots."""
|
||||
|
||||
__tablename__ = "preferences" # type: ignore
|
||||
__tablename__: str = "preferences"
|
||||
|
||||
id: int | None = Field(
|
||||
default=None,
|
||||
@@ -134,7 +135,7 @@ class PlatformMessageHistory(SQLModel, table=True):
|
||||
or platform-specific messages.
|
||||
"""
|
||||
|
||||
__tablename__ = "platform_message_history" # type: ignore
|
||||
__tablename__: str = "platform_message_history"
|
||||
|
||||
id: int | None = Field(
|
||||
primary_key=True,
|
||||
@@ -162,7 +163,7 @@ class PlatformSession(SQLModel, table=True):
|
||||
Each session can have multiple conversations (对话) associated with it.
|
||||
"""
|
||||
|
||||
__tablename__ = "platform_sessions" # type: ignore
|
||||
__tablename__: str = "platform_sessions"
|
||||
|
||||
inner_id: int | None = Field(
|
||||
primary_key=True,
|
||||
@@ -203,7 +204,7 @@ class Attachment(SQLModel, table=True):
|
||||
Attachments can be images, files, or other media types.
|
||||
"""
|
||||
|
||||
__tablename__ = "attachments" # type: ignore
|
||||
__tablename__: str = "attachments"
|
||||
|
||||
inner_attachment_id: int | None = Field(
|
||||
primary_key=True,
|
||||
@@ -233,6 +234,65 @@ class Attachment(SQLModel, table=True):
|
||||
)
|
||||
|
||||
|
||||
class CommandConfig(SQLModel, table=True):
|
||||
"""Per-command configuration overrides for dashboard management."""
|
||||
|
||||
__tablename__ = "command_configs" # type: ignore
|
||||
|
||||
handler_full_name: str = Field(
|
||||
primary_key=True,
|
||||
max_length=512,
|
||||
)
|
||||
plugin_name: str = Field(nullable=False, max_length=255)
|
||||
module_path: str = Field(nullable=False, max_length=255)
|
||||
original_command: str = Field(nullable=False, max_length=255)
|
||||
resolved_command: str | None = Field(default=None, max_length=255)
|
||||
enabled: bool = Field(default=True, nullable=False)
|
||||
keep_original_alias: bool = Field(default=False, nullable=False)
|
||||
conflict_key: str | None = Field(default=None, max_length=255)
|
||||
resolution_strategy: str | None = Field(default=None, max_length=64)
|
||||
note: str | None = Field(default=None, sa_type=Text)
|
||||
extra_data: dict | None = Field(default=None, sa_type=JSON)
|
||||
auto_managed: bool = Field(default=False, nullable=False)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
|
||||
class CommandConflict(SQLModel, table=True):
|
||||
"""Conflict tracking for duplicated command names."""
|
||||
|
||||
__tablename__ = "command_conflicts" # type: ignore
|
||||
|
||||
id: int | None = Field(
|
||||
default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
|
||||
)
|
||||
conflict_key: str = Field(nullable=False, max_length=255)
|
||||
handler_full_name: str = Field(nullable=False, max_length=512)
|
||||
plugin_name: str = Field(nullable=False, max_length=255)
|
||||
status: str = Field(default="pending", max_length=32)
|
||||
resolution: str | None = Field(default=None, max_length=64)
|
||||
resolved_command: str | None = Field(default=None, max_length=255)
|
||||
note: str | None = Field(default=None, sa_type=Text)
|
||||
extra_data: dict | None = Field(default=None, sa_type=JSON)
|
||||
auto_generated: bool = Field(default=False, nullable=False)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"conflict_key",
|
||||
"handler_full_name",
|
||||
name="uix_conflict_handler",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Conversation:
|
||||
"""LLM 对话类
|
||||
@@ -261,17 +321,17 @@ class Personality(TypedDict):
|
||||
在 v4.0.0 版本及之后,推荐使用上面的 Persona 类。并且, mood_imitation_dialogs 字段已被废弃。
|
||||
"""
|
||||
|
||||
prompt: str = ""
|
||||
name: str = ""
|
||||
begin_dialogs: list[str] = []
|
||||
mood_imitation_dialogs: list[str] = []
|
||||
prompt: str
|
||||
name: str
|
||||
begin_dialogs: list[str]
|
||||
mood_imitation_dialogs: list[str]
|
||||
"""情感模拟对话预设。在 v4.0.0 版本及之后,已被废弃。"""
|
||||
tools: list[str] | None = None
|
||||
tools: list[str] | None
|
||||
"""工具列表。None 表示使用所有工具,空列表表示不使用任何工具"""
|
||||
|
||||
# cache
|
||||
_begin_dialogs_processed: list[dict] = []
|
||||
_mood_imitation_dialogs_processed: str = ""
|
||||
_begin_dialogs_processed: list[dict]
|
||||
_mood_imitation_dialogs_processed: str
|
||||
|
||||
|
||||
# ====
|
||||
|
||||
+296
-1
@@ -1,14 +1,18 @@
|
||||
import asyncio
|
||||
import threading
|
||||
import typing as T
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy import CursorResult
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import col, delete, desc, func, or_, select, text, update
|
||||
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.po import (
|
||||
Attachment,
|
||||
CommandConfig,
|
||||
CommandConflict,
|
||||
ConversationV2,
|
||||
Persona,
|
||||
PlatformMessageHistory,
|
||||
@@ -25,6 +29,7 @@ from astrbot.core.db.po import (
|
||||
)
|
||||
|
||||
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
|
||||
TxResult = T.TypeVar("TxResult")
|
||||
|
||||
|
||||
class SQLiteDatabase(BaseDatabase):
|
||||
@@ -105,8 +110,8 @@ class SQLiteDatabase(BaseDatabase):
|
||||
text("""
|
||||
SELECT * FROM platform_stats
|
||||
WHERE timestamp >= :start_time
|
||||
ORDER BY timestamp DESC
|
||||
GROUP BY platform_id
|
||||
ORDER BY timestamp DESC
|
||||
"""),
|
||||
{"start_time": start_time},
|
||||
)
|
||||
@@ -449,6 +454,18 @@ class SQLiteDatabase(BaseDatabase):
|
||||
result = await session.execute(query.offset(offset).limit(page_size))
|
||||
return result.scalars().all()
|
||||
|
||||
async def get_platform_message_history_by_id(
|
||||
self, message_id: int
|
||||
) -> PlatformMessageHistory | None:
|
||||
"""Get a platform message history record by its ID."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
query = select(PlatformMessageHistory).where(
|
||||
PlatformMessageHistory.id == message_id
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def insert_attachment(self, path, type, mime_type):
|
||||
"""Insert a new attachment record."""
|
||||
async with self.get_db() as session:
|
||||
@@ -470,6 +487,48 @@ class SQLiteDatabase(BaseDatabase):
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_attachments(self, attachment_ids: list[str]) -> list:
|
||||
"""Get multiple attachments by their IDs."""
|
||||
if not attachment_ids:
|
||||
return []
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
query = select(Attachment).where(
|
||||
col(Attachment.attachment_id).in_(attachment_ids)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def delete_attachment(self, attachment_id: str) -> bool:
|
||||
"""Delete an attachment by its ID.
|
||||
|
||||
Returns True if the attachment was deleted, False if it was not found.
|
||||
"""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
query = delete(Attachment).where(
|
||||
col(Attachment.attachment_id) == attachment_id
|
||||
)
|
||||
result = T.cast(CursorResult, await session.execute(query))
|
||||
return result.rowcount > 0
|
||||
|
||||
async def delete_attachments(self, attachment_ids: list[str]) -> int:
|
||||
"""Delete multiple attachments by their IDs.
|
||||
|
||||
Returns the number of attachments deleted.
|
||||
"""
|
||||
if not attachment_ids:
|
||||
return 0
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
query = delete(Attachment).where(
|
||||
col(Attachment.attachment_id).in_(attachment_ids)
|
||||
)
|
||||
result = T.cast(CursorResult, await session.execute(query))
|
||||
return result.rowcount
|
||||
|
||||
async def insert_persona(
|
||||
self,
|
||||
persona_id,
|
||||
@@ -615,6 +674,242 @@ class SQLiteDatabase(BaseDatabase):
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# ====
|
||||
# Command Configuration & Conflict Tracking
|
||||
# ====
|
||||
|
||||
async def _run_in_tx(
|
||||
self,
|
||||
fn: Callable[[AsyncSession], Awaitable[TxResult]],
|
||||
) -> TxResult:
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
return await fn(session)
|
||||
|
||||
@staticmethod
|
||||
def _apply_updates(model, **updates) -> None:
|
||||
for field, value in updates.items():
|
||||
if value is not None:
|
||||
setattr(model, field, value)
|
||||
|
||||
@staticmethod
|
||||
def _new_command_config(
|
||||
handler_full_name: str,
|
||||
plugin_name: str,
|
||||
module_path: str,
|
||||
original_command: str,
|
||||
*,
|
||||
resolved_command: str | None = None,
|
||||
enabled: bool | None = None,
|
||||
keep_original_alias: bool | None = None,
|
||||
conflict_key: str | None = None,
|
||||
resolution_strategy: str | None = None,
|
||||
note: str | None = None,
|
||||
extra_data: dict | None = None,
|
||||
auto_managed: bool | None = None,
|
||||
) -> CommandConfig:
|
||||
return CommandConfig(
|
||||
handler_full_name=handler_full_name,
|
||||
plugin_name=plugin_name,
|
||||
module_path=module_path,
|
||||
original_command=original_command,
|
||||
resolved_command=resolved_command,
|
||||
enabled=True if enabled is None else enabled,
|
||||
keep_original_alias=False
|
||||
if keep_original_alias is None
|
||||
else keep_original_alias,
|
||||
conflict_key=conflict_key or original_command,
|
||||
resolution_strategy=resolution_strategy,
|
||||
note=note,
|
||||
extra_data=extra_data,
|
||||
auto_managed=bool(auto_managed),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _new_command_conflict(
|
||||
conflict_key: str,
|
||||
handler_full_name: str,
|
||||
plugin_name: str,
|
||||
*,
|
||||
status: str | None = None,
|
||||
resolution: str | None = None,
|
||||
resolved_command: str | None = None,
|
||||
note: str | None = None,
|
||||
extra_data: dict | None = None,
|
||||
auto_generated: bool | None = None,
|
||||
) -> CommandConflict:
|
||||
return CommandConflict(
|
||||
conflict_key=conflict_key,
|
||||
handler_full_name=handler_full_name,
|
||||
plugin_name=plugin_name,
|
||||
status=status or "pending",
|
||||
resolution=resolution,
|
||||
resolved_command=resolved_command,
|
||||
note=note,
|
||||
extra_data=extra_data,
|
||||
auto_generated=bool(auto_generated),
|
||||
)
|
||||
|
||||
async def get_command_configs(self) -> list[CommandConfig]:
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
result = await session.execute(select(CommandConfig))
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_command_config(
|
||||
self,
|
||||
handler_full_name: str,
|
||||
) -> CommandConfig | None:
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
return await session.get(CommandConfig, handler_full_name)
|
||||
|
||||
async def upsert_command_config(
|
||||
self,
|
||||
handler_full_name: str,
|
||||
plugin_name: str,
|
||||
module_path: str,
|
||||
original_command: str,
|
||||
*,
|
||||
resolved_command: str | None = None,
|
||||
enabled: bool | None = None,
|
||||
keep_original_alias: bool | None = None,
|
||||
conflict_key: str | None = None,
|
||||
resolution_strategy: str | None = None,
|
||||
note: str | None = None,
|
||||
extra_data: dict | None = None,
|
||||
auto_managed: bool | None = None,
|
||||
) -> CommandConfig:
|
||||
async def _op(session: AsyncSession) -> CommandConfig:
|
||||
config = await session.get(CommandConfig, handler_full_name)
|
||||
if not config:
|
||||
config = self._new_command_config(
|
||||
handler_full_name,
|
||||
plugin_name,
|
||||
module_path,
|
||||
original_command,
|
||||
resolved_command=resolved_command,
|
||||
enabled=enabled,
|
||||
keep_original_alias=keep_original_alias,
|
||||
conflict_key=conflict_key,
|
||||
resolution_strategy=resolution_strategy,
|
||||
note=note,
|
||||
extra_data=extra_data,
|
||||
auto_managed=auto_managed,
|
||||
)
|
||||
session.add(config)
|
||||
else:
|
||||
self._apply_updates(
|
||||
config,
|
||||
plugin_name=plugin_name,
|
||||
module_path=module_path,
|
||||
original_command=original_command,
|
||||
resolved_command=resolved_command,
|
||||
enabled=enabled,
|
||||
keep_original_alias=keep_original_alias,
|
||||
conflict_key=conflict_key,
|
||||
resolution_strategy=resolution_strategy,
|
||||
note=note,
|
||||
extra_data=extra_data,
|
||||
auto_managed=auto_managed,
|
||||
)
|
||||
await session.flush()
|
||||
await session.refresh(config)
|
||||
return config
|
||||
|
||||
return await self._run_in_tx(_op)
|
||||
|
||||
async def delete_command_config(self, handler_full_name: str) -> None:
|
||||
await self.delete_command_configs([handler_full_name])
|
||||
|
||||
async def delete_command_configs(self, handler_full_names: list[str]) -> None:
|
||||
if not handler_full_names:
|
||||
return
|
||||
|
||||
async def _op(session: AsyncSession) -> None:
|
||||
await session.execute(
|
||||
delete(CommandConfig).where(
|
||||
col(CommandConfig.handler_full_name).in_(handler_full_names),
|
||||
),
|
||||
)
|
||||
|
||||
await self._run_in_tx(_op)
|
||||
|
||||
async def list_command_conflicts(
|
||||
self,
|
||||
status: str | None = None,
|
||||
) -> list[CommandConflict]:
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
query = select(CommandConflict)
|
||||
if status:
|
||||
query = query.where(CommandConflict.status == status)
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def upsert_command_conflict(
|
||||
self,
|
||||
conflict_key: str,
|
||||
handler_full_name: str,
|
||||
plugin_name: str,
|
||||
*,
|
||||
status: str | None = None,
|
||||
resolution: str | None = None,
|
||||
resolved_command: str | None = None,
|
||||
note: str | None = None,
|
||||
extra_data: dict | None = None,
|
||||
auto_generated: bool | None = None,
|
||||
) -> CommandConflict:
|
||||
async def _op(session: AsyncSession) -> CommandConflict:
|
||||
result = await session.execute(
|
||||
select(CommandConflict).where(
|
||||
CommandConflict.conflict_key == conflict_key,
|
||||
CommandConflict.handler_full_name == handler_full_name,
|
||||
),
|
||||
)
|
||||
record = result.scalar_one_or_none()
|
||||
if not record:
|
||||
record = self._new_command_conflict(
|
||||
conflict_key,
|
||||
handler_full_name,
|
||||
plugin_name,
|
||||
status=status,
|
||||
resolution=resolution,
|
||||
resolved_command=resolved_command,
|
||||
note=note,
|
||||
extra_data=extra_data,
|
||||
auto_generated=auto_generated,
|
||||
)
|
||||
session.add(record)
|
||||
else:
|
||||
self._apply_updates(
|
||||
record,
|
||||
plugin_name=plugin_name,
|
||||
status=status,
|
||||
resolution=resolution,
|
||||
resolved_command=resolved_command,
|
||||
note=note,
|
||||
extra_data=extra_data,
|
||||
auto_generated=auto_generated,
|
||||
)
|
||||
await session.flush()
|
||||
await session.refresh(record)
|
||||
return record
|
||||
|
||||
return await self._run_in_tx(_op)
|
||||
|
||||
async def delete_command_conflicts(self, ids: list[int]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
|
||||
async def _op(session: AsyncSession) -> None:
|
||||
await session.execute(
|
||||
delete(CommandConflict).where(col(CommandConflict.id).in_(ids)),
|
||||
)
|
||||
|
||||
await self._run_in_tx(_op)
|
||||
|
||||
# ====
|
||||
# Deprecated Methods
|
||||
# ====
|
||||
|
||||
@@ -90,4 +90,6 @@ class EmbeddingStorage:
|
||||
path (str): 保存索引的路径
|
||||
|
||||
"""
|
||||
if self.index is None:
|
||||
return
|
||||
faiss.write_index(self.index, self.path)
|
||||
|
||||
@@ -27,7 +27,7 @@ class EventBus:
|
||||
self,
|
||||
event_queue: Queue,
|
||||
pipeline_scheduler_mapping: dict[str, PipelineScheduler],
|
||||
astrbot_config_mgr: AstrBotConfigManager = None,
|
||||
astrbot_config_mgr: AstrBotConfigManager,
|
||||
):
|
||||
self.event_queue = event_queue # 事件队列
|
||||
# abconf uuid -> scheduler
|
||||
@@ -40,6 +40,11 @@ class EventBus:
|
||||
conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin)
|
||||
self._print_event(event, conf_info["name"])
|
||||
scheduler = self.pipeline_scheduler_mapping.get(conf_info["id"])
|
||||
if not scheduler:
|
||||
logger.error(
|
||||
f"PipelineScheduler not found for id: {conf_info['id']}, event ignored."
|
||||
)
|
||||
continue
|
||||
asyncio.create_task(scheduler.execute(event))
|
||||
|
||||
def _print_event(self, event: AstrMessageEvent, conf_name: str):
|
||||
|
||||
@@ -166,7 +166,11 @@ class RetrievalManager:
|
||||
# 5. Rerank
|
||||
first_rerank = None
|
||||
for kb_id in kb_ids:
|
||||
vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
|
||||
vec_db = kb_options[kb_id]["vec_db"]
|
||||
if not isinstance(vec_db, FaissVecDB):
|
||||
logger.warning(f"vec_db for kb_id {kb_id} is not FaissVecDB")
|
||||
continue
|
||||
|
||||
rerank_pi = kb_options[kb_id]["rerank_provider_id"]
|
||||
if (
|
||||
vec_db
|
||||
|
||||
+2
-1
@@ -24,6 +24,7 @@ import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from asyncio import Queue
|
||||
from collections import deque
|
||||
|
||||
@@ -148,7 +149,7 @@ class LogQueueHandler(logging.Handler):
|
||||
self.log_broker.publish(
|
||||
{
|
||||
"level": record.levelname,
|
||||
"time": record.asctime,
|
||||
"time": time.time(),
|
||||
"data": log_entry,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -66,6 +66,9 @@ class ComponentType(str, Enum):
|
||||
class BaseMessageComponent(BaseModel):
|
||||
type: ComponentType
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def toDict(self):
|
||||
data = {}
|
||||
for k, v in self.__dict__.items():
|
||||
@@ -551,7 +554,7 @@ class Node(BaseMessageComponent):
|
||||
id: int | None = 0 # 忽略
|
||||
name: str | None = "" # qq昵称
|
||||
uin: str | None = "0" # qq号
|
||||
content: list[BaseMessageComponent] | None = []
|
||||
content: list[BaseMessageComponent] = []
|
||||
seq: str | list | None = "" # 忽略
|
||||
time: int | None = 0 # 忽略
|
||||
|
||||
@@ -615,7 +618,7 @@ class Nodes(BaseMessageComponent):
|
||||
ret["messages"].append(d)
|
||||
return ret
|
||||
|
||||
async def to_dict(self):
|
||||
async def to_dict(self) -> dict:
|
||||
"""将 Nodes 转换为字典格式,适用于 OneBot JSON 格式"""
|
||||
ret = {"messages": []}
|
||||
for node in self.nodes:
|
||||
@@ -626,12 +629,11 @@ class Nodes(BaseMessageComponent):
|
||||
|
||||
class Json(BaseMessageComponent):
|
||||
type = ComponentType.Json
|
||||
data: str | dict
|
||||
resid: int | None = 0
|
||||
data: dict
|
||||
|
||||
def __init__(self, data, **_):
|
||||
if isinstance(data, dict):
|
||||
data = json.dumps(data)
|
||||
def __init__(self, data: str | dict, **_):
|
||||
if isinstance(data, str):
|
||||
data = json.loads(data)
|
||||
super().__init__(data=data, **_)
|
||||
|
||||
|
||||
@@ -714,15 +716,23 @@ class File(BaseMessageComponent):
|
||||
|
||||
if self.url:
|
||||
await self._download_file()
|
||||
return os.path.abspath(self.file_)
|
||||
if self.file_:
|
||||
return os.path.abspath(self.file_)
|
||||
|
||||
return ""
|
||||
|
||||
async def _download_file(self):
|
||||
"""下载文件"""
|
||||
if not self.url:
|
||||
raise ValueError("Download failed: No URL provided in File component.")
|
||||
download_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
os.makedirs(download_dir, exist_ok=True)
|
||||
file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
|
||||
if self.name:
|
||||
name, ext = os.path.splitext(self.name)
|
||||
filename = f"{name}_{uuid.uuid4().hex[:8]}{ext}"
|
||||
else:
|
||||
filename = f"{uuid.uuid4().hex}"
|
||||
file_path = os.path.join(download_dir, filename)
|
||||
await download_file(self.url, file_path)
|
||||
self.file_ = os.path.abspath(file_path)
|
||||
|
||||
|
||||
@@ -98,8 +98,8 @@ class PersonaManager:
|
||||
self,
|
||||
persona_id: str,
|
||||
system_prompt: str,
|
||||
begin_dialogs: list[str] = None,
|
||||
tools: list[str] = None,
|
||||
begin_dialogs: list[str] | None = None,
|
||||
tools: list[str] | None = None,
|
||||
) -> Persona:
|
||||
"""创建新的 persona。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
|
||||
if await self.db.get_persona_by_id(persona_id):
|
||||
|
||||
@@ -24,7 +24,7 @@ class ContentSafetyCheckStage(Stage):
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
check_text: str | None = None,
|
||||
) -> None | AsyncGenerator[None, None]:
|
||||
) -> AsyncGenerator[None, None]:
|
||||
"""检查内容安全"""
|
||||
text = check_text if check_text else event.get_message_str()
|
||||
ok, info = self.strategy_selector.check(text)
|
||||
|
||||
@@ -11,7 +11,7 @@ from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
||||
|
||||
async def call_handler(
|
||||
event: AstrMessageEvent,
|
||||
handler: T.Callable[..., T.Awaitable[T.Any]],
|
||||
handler: T.Callable[..., T.Awaitable[T.Any] | T.AsyncGenerator[T.Any, None]],
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> T.AsyncGenerator[T.Any, None]:
|
||||
@@ -91,6 +91,7 @@ async def call_event_hook(
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
assert inspect.iscoroutinefunction(handler.handler)
|
||||
logger.debug(
|
||||
f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}",
|
||||
)
|
||||
|
||||
@@ -9,7 +9,7 @@ from astrbot.core import logger
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.conversation_mgr import Conversation
|
||||
from astrbot.core.message.components import Image
|
||||
from astrbot.core.message.components import File, Image, Reply
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
MessageEventResult,
|
||||
@@ -22,6 +22,7 @@ from astrbot.core.provider.entities import (
|
||||
ProviderRequest,
|
||||
)
|
||||
from astrbot.core.star.star_handler import EventType, star_map
|
||||
from astrbot.core.utils.file_extract import extract_file_moonshotai
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.utils.session_lock import session_lock_manager
|
||||
|
||||
@@ -56,6 +57,13 @@ class InternalAgentSubStage(Stage):
|
||||
self.show_reasoning = settings.get("display_reasoning_text", False)
|
||||
self.kb_agentic_mode: bool = conf.get("kb_agentic_mode", False)
|
||||
|
||||
file_extract_conf: dict = settings.get("file_extract", {})
|
||||
self.file_extract_enabled: bool = file_extract_conf.get("enable", False)
|
||||
self.file_extract_prov: str = file_extract_conf.get("provider", "moonshotai")
|
||||
self.file_extract_msh_api_key: str = file_extract_conf.get(
|
||||
"moonshotai_api_key", ""
|
||||
)
|
||||
|
||||
self.conv_manager = ctx.plugin_manager.context.conversation_manager
|
||||
|
||||
def _select_provider(self, event: AstrMessageEvent):
|
||||
@@ -114,6 +122,50 @@ class InternalAgentSubStage(Stage):
|
||||
req.func_tool = ToolSet()
|
||||
req.func_tool.add_tool(KNOWLEDGE_BASE_QUERY_TOOL)
|
||||
|
||||
async def _apply_file_extract(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
):
|
||||
"""Apply file extract to the provider request"""
|
||||
file_paths = []
|
||||
file_names = []
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, File):
|
||||
file_paths.append(await comp.get_file())
|
||||
file_names.append(comp.name)
|
||||
elif isinstance(comp, Reply) and comp.chain:
|
||||
for reply_comp in comp.chain:
|
||||
if isinstance(reply_comp, File):
|
||||
file_paths.append(await reply_comp.get_file())
|
||||
file_names.append(reply_comp.name)
|
||||
if not file_paths:
|
||||
return
|
||||
if not req.prompt:
|
||||
req.prompt = "总结一下文件里面讲了什么?"
|
||||
if self.file_extract_prov == "moonshotai":
|
||||
if not self.file_extract_msh_api_key:
|
||||
logger.error("Moonshot AI API key for file extract is not set")
|
||||
return
|
||||
file_contents = await asyncio.gather(
|
||||
*[
|
||||
extract_file_moonshotai(file_path, self.file_extract_msh_api_key)
|
||||
for file_path in file_paths
|
||||
]
|
||||
)
|
||||
else:
|
||||
logger.error(f"Unsupported file extract provider: {self.file_extract_prov}")
|
||||
return
|
||||
|
||||
# add file extract results to contexts
|
||||
for file_content, file_name in zip(file_contents, file_names):
|
||||
req.contexts.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"File Extract Results of user uploaded files:\n{file_content}\nFile Name: {file_name or 'Unknown'}",
|
||||
},
|
||||
)
|
||||
|
||||
def _truncate_contexts(
|
||||
self,
|
||||
contexts: list[dict],
|
||||
@@ -269,7 +321,12 @@ class InternalAgentSubStage(Stage):
|
||||
elif isinstance(req.tool_calls_result, list):
|
||||
for tcr in req.tool_calls_result:
|
||||
messages.extend(tcr.to_openai_messages())
|
||||
messages.append({"role": "assistant", "content": llm_response.completion_text})
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text or "*No response*",
|
||||
}
|
||||
)
|
||||
messages = list(filter(lambda item: "_no_save" not in item, messages))
|
||||
await self.conv_manager.update_conversation(
|
||||
event.unified_msg_origin,
|
||||
@@ -346,6 +403,17 @@ class InternalAgentSubStage(Stage):
|
||||
|
||||
event.set_extra("provider_request", req)
|
||||
|
||||
# fix contexts json str
|
||||
if isinstance(req.contexts, str):
|
||||
req.contexts = json.loads(req.contexts)
|
||||
|
||||
# apply file extract
|
||||
if self.file_extract_enabled:
|
||||
try:
|
||||
await self._apply_file_extract(event, req)
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred while applying file extract: {e}")
|
||||
|
||||
if not req.prompt and not req.image_urls:
|
||||
return
|
||||
|
||||
@@ -356,10 +424,6 @@ class InternalAgentSubStage(Stage):
|
||||
# apply knowledge base feature
|
||||
await self._apply_kb(event, req)
|
||||
|
||||
# fix contexts json str
|
||||
if isinstance(req.contexts, str):
|
||||
req.contexts = json.loads(req.contexts)
|
||||
|
||||
# truncate contexts to fit max length
|
||||
if req.contexts:
|
||||
req.contexts = self._truncate_contexts(req.contexts)
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core import astrbot_config, logger
|
||||
from astrbot.core.agent.runners.coze.coze_agent_runner import CozeAgentRunner
|
||||
from astrbot.core.agent.runners.dashscope.dashscope_agent_runner import (
|
||||
DashscopeAgentRunner,
|
||||
@@ -57,7 +57,7 @@ async def run_third_party_agent(
|
||||
logger.error(f"Third party agent runner error: {e}")
|
||||
err_msg = (
|
||||
f"\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n"
|
||||
f"错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n"
|
||||
f"错误信息: {e!s}\n\n请在平台日志查看和分享错误详情。\n"
|
||||
)
|
||||
yield MessageChain().message(err_msg)
|
||||
|
||||
@@ -88,12 +88,15 @@ class ThirdPartyAgentSubStage(Stage):
|
||||
return
|
||||
|
||||
self.prov_cfg: dict = next(
|
||||
(p for p in self.conf["provider"] if p["id"] == self.prov_id),
|
||||
(p for p in astrbot_config["provider"] if p["id"] == self.prov_id),
|
||||
{},
|
||||
)
|
||||
if not self.prov_id or not self.prov_cfg:
|
||||
if not self.prov_id:
|
||||
logger.error("没有填写 Agent Runner 提供商 ID,请前往配置页面配置。")
|
||||
return
|
||||
if not self.prov_cfg:
|
||||
logger.error(
|
||||
"Third Party Agent Runner provider ID is not configured properly."
|
||||
f"Agent Runner 提供商 {self.prov_id} 配置不存在,请前往配置页面修改配置。"
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ from ..stage import Stage
|
||||
|
||||
class StarRequestSubStage(Stage):
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.curr_provider = ctx.plugin_manager.context.get_using_provider()
|
||||
self.prompt_prefix = ctx.astrbot_config["provider_settings"]["prompt_prefix"]
|
||||
self.identifier = ctx.astrbot_config["provider_settings"]["identifier"]
|
||||
self.ctx = ctx
|
||||
@@ -24,7 +23,7 @@ class StarRequestSubStage(Stage):
|
||||
async def process(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
) -> AsyncGenerator[None, None]:
|
||||
) -> AsyncGenerator[Any, None]:
|
||||
activated_handlers: list[StarHandlerMetadata] = event.get_extra(
|
||||
"activated_handlers",
|
||||
)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.provider.entities import ProviderRequest
|
||||
from astrbot.core.star.star_handler import StarHandlerMetadata
|
||||
@@ -61,14 +60,7 @@ class ProcessStage(Stage):
|
||||
):
|
||||
# 是否有过发送操作 and 是否是被 @ 或者通过唤醒前缀
|
||||
if (
|
||||
event.get_result() and not event.get_result().is_stopped()
|
||||
event.get_result() and not event.is_stopped()
|
||||
) or not event.get_result():
|
||||
# 事件没有终止传播
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
||||
|
||||
if not provider:
|
||||
logger.info("未找到可用的 LLM 提供商,请先前往配置服务提供商。")
|
||||
return
|
||||
|
||||
async for _ in self.agent_sub_stage.process(event):
|
||||
yield
|
||||
|
||||
@@ -117,7 +117,9 @@ class RespondStage(Stage):
|
||||
if not self.enable_seg:
|
||||
return False
|
||||
|
||||
if self.only_llm_result and not event.get_result().is_llm_result():
|
||||
if (result := event.get_result()) is None:
|
||||
return False
|
||||
if self.only_llm_result and not result.is_llm_result():
|
||||
return False
|
||||
|
||||
if event.get_platform_name() in [
|
||||
@@ -156,7 +158,11 @@ class RespondStage(Stage):
|
||||
result = event.get_result()
|
||||
if result is None:
|
||||
return
|
||||
if event.get_extra("_streaming_finished", False):
|
||||
# prevent some plugin make result content type to LLM_RESULT after streaming finished, lead to send again
|
||||
return
|
||||
if result.result_content_type == ResultContentType.STREAMING_FINISH:
|
||||
event.set_extra("_streaming_finished", True)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
@@ -185,7 +191,7 @@ class RespondStage(Stage):
|
||||
if isinstance(component, Comp.File) and component.file:
|
||||
# 支持 File 消息段的路径映射。
|
||||
component.file = path_Mapping(mappings, component.file)
|
||||
event.get_result().chain[idx] = component
|
||||
result.chain[idx] = component
|
||||
|
||||
# 检查消息链是否为空
|
||||
try:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
@@ -6,6 +7,7 @@ from collections.abc import AsyncGenerator
|
||||
from astrbot.core import file_token_service, html_renderer, logger
|
||||
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
|
||||
from astrbot.core.message.message_event_result import ResultContentType
|
||||
from astrbot.core.pipeline.content_safety_check.stage import ContentSafetyCheckStage
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
@@ -41,6 +43,18 @@ class ResultDecorateStage(Stage):
|
||||
"forward_threshold"
|
||||
]
|
||||
|
||||
trigger_probability = ctx.astrbot_config["provider_tts_settings"].get(
|
||||
"trigger_probability",
|
||||
1,
|
||||
)
|
||||
try:
|
||||
self.tts_trigger_probability = max(
|
||||
0.0,
|
||||
min(float(trigger_probability), 1.0),
|
||||
)
|
||||
except (TypeError, ValueError):
|
||||
self.tts_trigger_probability = 1.0
|
||||
|
||||
# 分段回复
|
||||
self.words_count_threshold = int(
|
||||
ctx.astrbot_config["platform_settings"]["segmented_reply"][
|
||||
@@ -53,7 +67,22 @@ class ResultDecorateStage(Stage):
|
||||
self.only_llm_result = ctx.astrbot_config["platform_settings"][
|
||||
"segmented_reply"
|
||||
]["only_llm_result"]
|
||||
self.split_mode = ctx.astrbot_config["platform_settings"][
|
||||
"segmented_reply"
|
||||
].get("split_mode", "regex")
|
||||
self.regex = ctx.astrbot_config["platform_settings"]["segmented_reply"]["regex"]
|
||||
self.split_words = ctx.astrbot_config["platform_settings"][
|
||||
"segmented_reply"
|
||||
].get("split_words", ["。", "?", "!", "~", "…"])
|
||||
if self.split_words:
|
||||
escaped_words = sorted(
|
||||
[re.escape(word) for word in self.split_words], key=len, reverse=True
|
||||
)
|
||||
self.split_words_pattern = re.compile(
|
||||
f"(.*?({'|'.join(escaped_words)})|.+$)", re.DOTALL
|
||||
)
|
||||
else:
|
||||
self.split_words_pattern = None
|
||||
self.content_cleanup_rule = ctx.astrbot_config["platform_settings"][
|
||||
"segmented_reply"
|
||||
]["content_cleanup_rule"]
|
||||
@@ -69,6 +98,28 @@ class ResultDecorateStage(Stage):
|
||||
self.content_safe_check_stage = stage_cls()
|
||||
await self.content_safe_check_stage.initialize(ctx)
|
||||
|
||||
def _split_text_by_words(self, text: str) -> list[str]:
|
||||
"""使用分段词列表分段文本"""
|
||||
if not self.split_words_pattern:
|
||||
return [text]
|
||||
|
||||
segments = self.split_words_pattern.findall(text)
|
||||
result = []
|
||||
for seg in segments:
|
||||
if isinstance(seg, tuple):
|
||||
content = seg[0]
|
||||
if not isinstance(content, str):
|
||||
continue
|
||||
for word in self.split_words:
|
||||
if content.endswith(word):
|
||||
content = content[: -len(word)]
|
||||
break
|
||||
if content.strip():
|
||||
result.append(content)
|
||||
elif seg and seg.strip():
|
||||
result.append(seg)
|
||||
return result if result else [text]
|
||||
|
||||
async def process(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
@@ -93,11 +144,13 @@ class ResultDecorateStage(Stage):
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain):
|
||||
text += comp.text
|
||||
async for _ in self.content_safe_check_stage.process(
|
||||
event,
|
||||
check_text=text,
|
||||
):
|
||||
yield
|
||||
|
||||
if isinstance(self.content_safe_check_stage, ContentSafetyCheckStage):
|
||||
async for _ in self.content_safe_check_stage.process(
|
||||
event,
|
||||
check_text=text,
|
||||
):
|
||||
yield
|
||||
|
||||
# 发送消息前事件钩子
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
@@ -114,7 +167,8 @@ class ResultDecorateStage(Stage):
|
||||
"启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作",
|
||||
)
|
||||
await handler.handler(event)
|
||||
if event.get_result() is None or not event.get_result().chain:
|
||||
|
||||
if (result := event.get_result()) is None or not result.chain:
|
||||
logger.debug(
|
||||
f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name} 将消息结果清空。",
|
||||
)
|
||||
@@ -161,21 +215,27 @@ class ResultDecorateStage(Stage):
|
||||
# 不分段回复
|
||||
new_chain.append(comp)
|
||||
continue
|
||||
try:
|
||||
split_response = re.findall(
|
||||
self.regex,
|
||||
comp.text,
|
||||
re.DOTALL | re.MULTILINE,
|
||||
)
|
||||
except re.error:
|
||||
logger.error(
|
||||
f"分段回复正则表达式错误,使用默认分段方式: {traceback.format_exc()}",
|
||||
)
|
||||
split_response = re.findall(
|
||||
r".*?[。?!~…]+|.+$",
|
||||
comp.text,
|
||||
re.DOTALL | re.MULTILINE,
|
||||
)
|
||||
|
||||
# 根据 split_mode 选择分段方式
|
||||
if self.split_mode == "words":
|
||||
split_response = self._split_text_by_words(comp.text)
|
||||
else: # regex 模式
|
||||
try:
|
||||
split_response = re.findall(
|
||||
self.regex,
|
||||
comp.text,
|
||||
re.DOTALL | re.MULTILINE,
|
||||
)
|
||||
except re.error:
|
||||
logger.error(
|
||||
f"分段回复正则表达式错误,使用默认分段方式: {traceback.format_exc()}",
|
||||
)
|
||||
split_response = re.findall(
|
||||
r".*?[。?!~…]+|.+$",
|
||||
comp.text,
|
||||
re.DOTALL | re.MULTILINE,
|
||||
)
|
||||
|
||||
if not split_response:
|
||||
new_chain.append(comp)
|
||||
continue
|
||||
@@ -199,7 +259,14 @@ class ResultDecorateStage(Stage):
|
||||
and result.is_llm_result()
|
||||
and SessionServiceManager.should_process_tts_request(event)
|
||||
):
|
||||
if not tts_provider:
|
||||
should_tts = self.tts_trigger_probability >= 1.0 or (
|
||||
self.tts_trigger_probability > 0.0
|
||||
and random.random() <= self.tts_trigger_probability
|
||||
)
|
||||
|
||||
if not should_tts:
|
||||
logger.debug("跳过 TTS:触发概率未命中。")
|
||||
elif not tts_provider:
|
||||
logger.warning(
|
||||
f"会话 {event.unified_msg_origin} 未配置文本转语音模型。",
|
||||
)
|
||||
|
||||
@@ -2,6 +2,10 @@ from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.platform import AstrMessageEvent
|
||||
from astrbot.core.platform.sources.webchat.webchat_event import WebChatMessageEvent
|
||||
from astrbot.core.platform.sources.wecom_ai_bot.wecomai_event import (
|
||||
WecomAIBotMessageEvent,
|
||||
)
|
||||
|
||||
from . import STAGES_ORDER
|
||||
from .context import PipelineContext
|
||||
@@ -78,7 +82,7 @@ class PipelineScheduler:
|
||||
await self._process_stages(event)
|
||||
|
||||
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
|
||||
if event.get_platform_name() in ["webchat", "wecom_ai_bot"]:
|
||||
if isinstance(event, (WebChatMessageEvent, WecomAIBotMessageEvent)):
|
||||
await event.send(None)
|
||||
|
||||
logger.debug("pipeline 执行完毕。")
|
||||
|
||||
@@ -50,6 +50,9 @@ class WakingCheckStage(Stage):
|
||||
"ignore_at_all",
|
||||
False,
|
||||
)
|
||||
self.disable_builtin_commands = self.ctx.astrbot_config.get(
|
||||
"disable_builtin_commands", False
|
||||
)
|
||||
|
||||
async def process(
|
||||
self,
|
||||
@@ -131,6 +134,13 @@ class WakingCheckStage(Stage):
|
||||
EventType.AdapterMessageEvent,
|
||||
plugins_name=event.plugins_name,
|
||||
):
|
||||
if (
|
||||
self.disable_builtin_commands
|
||||
and handler.handler_module_path == "packages.builtin_commands.main"
|
||||
):
|
||||
logger.debug("skipping builtin command")
|
||||
continue
|
||||
|
||||
# filter 需满足 AND 逻辑关系
|
||||
passed = True
|
||||
permission_not_pass = False
|
||||
|
||||
@@ -153,7 +153,9 @@ class AstrMessageEvent(abc.ABC):
|
||||
|
||||
def get_sender_name(self) -> str:
|
||||
"""获取消息发送者的名称。(可能会返回空字符串)"""
|
||||
return self.message_obj.sender.nickname
|
||||
if isinstance(self.message_obj.sender.nickname, str):
|
||||
return self.message_obj.sender.nickname
|
||||
return ""
|
||||
|
||||
def set_extra(self, key, value):
|
||||
"""设置额外的信息。"""
|
||||
@@ -270,7 +272,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
"""
|
||||
self.call_llm = call_llm
|
||||
|
||||
def get_result(self) -> MessageEventResult:
|
||||
def get_result(self) -> MessageEventResult | None:
|
||||
"""获取消息事件的结果。"""
|
||||
return self._result
|
||||
|
||||
@@ -320,7 +322,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
self,
|
||||
prompt: str,
|
||||
func_tool_manager=None,
|
||||
session_id: str = None,
|
||||
session_id: str = "",
|
||||
image_urls: list[str] | None = None,
|
||||
contexts: list | None = None,
|
||||
system_prompt: str = "",
|
||||
|
||||
@@ -54,7 +54,7 @@ class AstrBotMessage:
|
||||
self_id: str # 机器人的识别id
|
||||
session_id: str # 会话id。取决于 unique_session 的设置。
|
||||
message_id: str # 消息id
|
||||
group: Group # 群组
|
||||
group: Group | None # 群组
|
||||
sender: MessageMember # 发送者
|
||||
message: list[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
|
||||
message_str: str # 最直观的纯文本消息字符串
|
||||
@@ -78,7 +78,7 @@ class AstrBotMessage:
|
||||
return ""
|
||||
|
||||
@group_id.setter
|
||||
def group_id(self, value: str):
|
||||
def group_id(self, value: str | None):
|
||||
"""设置 group_id"""
|
||||
if value:
|
||||
if self.group:
|
||||
|
||||
@@ -5,8 +5,9 @@ from asyncio import Queue
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map
|
||||
from astrbot.core.utils.webhook_utils import ensure_platform_webhook_config
|
||||
|
||||
from .platform import Platform
|
||||
from .platform import Platform, PlatformStatus
|
||||
from .register import platform_cls_map
|
||||
from .sources.webchat.webchat_adapter import WebChatAdapter
|
||||
|
||||
@@ -16,8 +17,9 @@ class PlatformManager:
|
||||
self.platform_insts: list[Platform] = []
|
||||
"""加载的 Platform 的实例"""
|
||||
|
||||
self._inst_map = {}
|
||||
self._inst_map: dict[str, dict] = {}
|
||||
|
||||
self.astrbot_config = config
|
||||
self.platforms_config = config["platform"]
|
||||
self.settings = config["platform_settings"]
|
||||
"""NOTE: 这里是 default 的配置文件,以保证最大的兼容性;
|
||||
@@ -29,6 +31,8 @@ class PlatformManager:
|
||||
"""初始化所有平台适配器"""
|
||||
for platform in self.platforms_config:
|
||||
try:
|
||||
if ensure_platform_webhook_config(platform):
|
||||
self.astrbot_config.save_config()
|
||||
await self.load_platform(platform)
|
||||
except Exception as e:
|
||||
logger.error(f"初始化 {platform} 平台适配器失败: {e}")
|
||||
@@ -37,7 +41,10 @@ class PlatformManager:
|
||||
webchat_inst = WebChatAdapter({}, self.settings, self.event_queue)
|
||||
self.platform_insts.append(webchat_inst)
|
||||
asyncio.create_task(
|
||||
self._task_wrapper(asyncio.create_task(webchat_inst.run(), name="webchat")),
|
||||
self._task_wrapper(
|
||||
asyncio.create_task(webchat_inst.run(), name="webchat"),
|
||||
platform=webchat_inst,
|
||||
),
|
||||
)
|
||||
|
||||
async def load_platform(self, platform_config: dict):
|
||||
@@ -107,7 +114,7 @@ class PlatformManager:
|
||||
)
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.error(
|
||||
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。",
|
||||
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->平台日志->安装Pip库 中安装依赖库。",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。")
|
||||
@@ -131,6 +138,7 @@ class PlatformManager:
|
||||
inst.run(),
|
||||
name=f"platform_{platform_config['type']}_{platform_config['id']}",
|
||||
),
|
||||
platform=inst,
|
||||
),
|
||||
)
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
@@ -145,17 +153,28 @@ class PlatformManager:
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def _task_wrapper(self, task: asyncio.Task):
|
||||
async def _task_wrapper(self, task: asyncio.Task, platform: Platform | None = None):
|
||||
# 设置平台状态为运行中
|
||||
if platform:
|
||||
platform.status = PlatformStatus.RUNNING
|
||||
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if platform:
|
||||
platform.status = PlatformStatus.STOPPED
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
tb_str = traceback.format_exc()
|
||||
logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}")
|
||||
for line in traceback.format_exc().split("\n"):
|
||||
for line in tb_str.split("\n"):
|
||||
logger.error(f"| {line}")
|
||||
logger.error("-------")
|
||||
|
||||
# 记录错误到平台实例
|
||||
if platform:
|
||||
platform.record_error(error_msg, tb_str)
|
||||
|
||||
async def reload(self, platform_config: dict):
|
||||
await self.terminate_platform(platform_config["id"])
|
||||
if platform_config["enable"]:
|
||||
@@ -172,9 +191,9 @@ class PlatformManager:
|
||||
logger.info(f"正在尝试终止 {platform_id} 平台适配器 ...")
|
||||
|
||||
# client_id = self._inst_map.pop(platform_id, None)
|
||||
info = self._inst_map.pop(platform_id, None)
|
||||
info = self._inst_map.pop(platform_id)
|
||||
client_id = info["client_id"]
|
||||
inst = info["inst"]
|
||||
inst: Platform = info["inst"]
|
||||
try:
|
||||
self.platform_insts.remove(
|
||||
next(
|
||||
@@ -196,3 +215,46 @@ class PlatformManager:
|
||||
|
||||
def get_insts(self):
|
||||
return self.platform_insts
|
||||
|
||||
def get_all_stats(self) -> dict:
|
||||
"""获取所有平台的统计信息
|
||||
|
||||
Returns:
|
||||
包含所有平台统计信息的字典
|
||||
"""
|
||||
stats_list = []
|
||||
total_errors = 0
|
||||
running_count = 0
|
||||
error_count = 0
|
||||
|
||||
for inst in self.platform_insts:
|
||||
try:
|
||||
stat = inst.get_stats()
|
||||
stats_list.append(stat)
|
||||
total_errors += stat.get("error_count", 0)
|
||||
if stat.get("status") == PlatformStatus.RUNNING.value:
|
||||
running_count += 1
|
||||
elif stat.get("status") == PlatformStatus.ERROR.value:
|
||||
error_count += 1
|
||||
except Exception as e:
|
||||
# 如果获取统计信息失败,记录基本信息
|
||||
logger.warning(f"获取平台统计信息失败: {e}")
|
||||
stats_list.append(
|
||||
{
|
||||
"id": getattr(inst, "config", {}).get("id", "unknown"),
|
||||
"type": "unknown",
|
||||
"status": "unknown",
|
||||
"error_count": 0,
|
||||
"last_error": None,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"platforms": stats_list,
|
||||
"summary": {
|
||||
"total": len(stats_list),
|
||||
"running": running_count,
|
||||
"error": error_count,
|
||||
"total_errors": total_errors,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import abc
|
||||
import uuid
|
||||
from asyncio import Queue
|
||||
from collections.abc import Awaitable
|
||||
from collections.abc import Coroutine
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
@@ -12,15 +15,100 @@ from .message_session import MessageSesion
|
||||
from .platform_metadata import PlatformMetadata
|
||||
|
||||
|
||||
class PlatformStatus(Enum):
|
||||
"""平台运行状态"""
|
||||
|
||||
PENDING = "pending" # 待启动
|
||||
RUNNING = "running" # 运行中
|
||||
ERROR = "error" # 发生错误
|
||||
STOPPED = "stopped" # 已停止
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlatformError:
|
||||
"""平台错误信息"""
|
||||
|
||||
message: str
|
||||
timestamp: datetime = field(default_factory=datetime.now)
|
||||
traceback: str | None = None
|
||||
|
||||
|
||||
class Platform(abc.ABC):
|
||||
def __init__(self, event_queue: Queue):
|
||||
def __init__(self, config: dict, event_queue: Queue):
|
||||
super().__init__()
|
||||
# 平台配置
|
||||
self.config = config
|
||||
# 维护了消息平台的事件队列,EventBus 会从这里取出事件并处理。
|
||||
self._event_queue = event_queue
|
||||
self.client_self_id = uuid.uuid4().hex
|
||||
|
||||
# 平台运行状态
|
||||
self._status: PlatformStatus = PlatformStatus.PENDING
|
||||
self._errors: list[PlatformError] = []
|
||||
self._started_at: datetime | None = None
|
||||
|
||||
@property
|
||||
def status(self) -> PlatformStatus:
|
||||
"""获取平台运行状态"""
|
||||
return self._status
|
||||
|
||||
@status.setter
|
||||
def status(self, value: PlatformStatus):
|
||||
"""设置平台运行状态"""
|
||||
self._status = value
|
||||
if value == PlatformStatus.RUNNING and self._started_at is None:
|
||||
self._started_at = datetime.now()
|
||||
|
||||
@property
|
||||
def errors(self) -> list[PlatformError]:
|
||||
"""获取错误列表"""
|
||||
return self._errors
|
||||
|
||||
@property
|
||||
def last_error(self) -> PlatformError | None:
|
||||
"""获取最近的错误"""
|
||||
return self._errors[-1] if self._errors else None
|
||||
|
||||
def record_error(self, message: str, traceback_str: str | None = None):
|
||||
"""记录一个错误"""
|
||||
self._errors.append(PlatformError(message=message, traceback=traceback_str))
|
||||
self._status = PlatformStatus.ERROR
|
||||
|
||||
def clear_errors(self):
|
||||
"""清除错误记录"""
|
||||
self._errors.clear()
|
||||
if self._status == PlatformStatus.ERROR:
|
||||
self._status = PlatformStatus.RUNNING
|
||||
|
||||
def unified_webhook(self) -> bool:
|
||||
"""是否正在使用统一 Webhook 模式"""
|
||||
return bool(
|
||||
self.config.get("unified_webhook_mode", False)
|
||||
and self.config.get("webhook_uuid")
|
||||
)
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""获取平台统计信息"""
|
||||
meta = self.meta()
|
||||
return {
|
||||
"id": meta.id or self.config.get("id"),
|
||||
"type": meta.name,
|
||||
"display_name": meta.adapter_display_name or meta.name,
|
||||
"status": self._status.value,
|
||||
"started_at": self._started_at.isoformat() if self._started_at else None,
|
||||
"error_count": len(self._errors),
|
||||
"last_error": {
|
||||
"message": self.last_error.message,
|
||||
"timestamp": self.last_error.timestamp.isoformat(),
|
||||
"traceback": self.last_error.traceback,
|
||||
}
|
||||
if self.last_error
|
||||
else None,
|
||||
"unified_webhook": self.unified_webhook(),
|
||||
}
|
||||
|
||||
@abc.abstractmethod
|
||||
def run(self) -> Awaitable[Any]:
|
||||
def run(self) -> Coroutine[Any, Any, None]:
|
||||
"""得到一个平台的运行实例,需要返回一个协程对象。"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -36,7 +124,7 @@ class Platform(abc.ABC):
|
||||
self,
|
||||
session: MessageSesion,
|
||||
message_chain: MessageChain,
|
||||
) -> Awaitable[Any]:
|
||||
) -> None:
|
||||
"""通过会话发送消息。该方法旨在让插件能够直接通过**可持久化的会话数据**发送消息,而不需要保存 event 对象。
|
||||
|
||||
异步方法。
|
||||
@@ -49,3 +137,20 @@ class Platform(abc.ABC):
|
||||
|
||||
def get_client(self):
|
||||
"""获取平台的客户端对象。"""
|
||||
|
||||
async def webhook_callback(self, request: Any) -> Any:
|
||||
"""统一 Webhook 回调入口。
|
||||
|
||||
支持统一 Webhook 模式的平台需要实现此方法。
|
||||
当 Dashboard 收到 /api/platform/webhook/{uuid} 请求时,会调用此方法。
|
||||
|
||||
Args:
|
||||
request: Quart 请求对象
|
||||
|
||||
Returns:
|
||||
响应内容,格式取决于具体平台的要求
|
||||
|
||||
Raises:
|
||||
NotImplementedError: 平台未实现统一 Webhook 模式
|
||||
"""
|
||||
raise NotImplementedError(f"平台 {self.meta().name} 未实现统一 Webhook 模式")
|
||||
|
||||
@@ -7,7 +7,7 @@ class PlatformMetadata:
|
||||
"""平台的名称,即平台的类型,如 aiocqhttp, discord, slack"""
|
||||
description: str
|
||||
"""平台的描述"""
|
||||
id: str | None = None
|
||||
id: str
|
||||
"""平台的唯一标识符,用于配置中识别特定平台"""
|
||||
|
||||
default_config_tmpl: dict | None = None
|
||||
|
||||
@@ -40,6 +40,7 @@ def register_platform_adapter(
|
||||
pm = PlatformMetadata(
|
||||
name=adapter_name,
|
||||
description=desc,
|
||||
id=adapter_name,
|
||||
default_config_tmpl=default_config_tmpl,
|
||||
adapter_display_name=adapter_display_name,
|
||||
logo_path=logo_path,
|
||||
|
||||
@@ -70,16 +70,18 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
bot: CQHttp,
|
||||
event: Event | None,
|
||||
is_group: bool,
|
||||
session_id: str,
|
||||
session_id: str | None,
|
||||
messages: list[dict],
|
||||
):
|
||||
# session_id 必须是纯数字字符串
|
||||
session_id = int(session_id) if session_id.isdigit() else None
|
||||
session_id_int = (
|
||||
int(session_id) if session_id and session_id.isdigit() else None
|
||||
)
|
||||
|
||||
if is_group and isinstance(session_id, int):
|
||||
await bot.send_group_msg(group_id=session_id, message=messages)
|
||||
elif not is_group and isinstance(session_id, int):
|
||||
await bot.send_private_msg(user_id=session_id, message=messages)
|
||||
if is_group and isinstance(session_id_int, int):
|
||||
await bot.send_group_msg(group_id=session_id_int, message=messages)
|
||||
elif not is_group and isinstance(session_id_int, int):
|
||||
await bot.send_private_msg(user_id=session_id_int, message=messages)
|
||||
elif isinstance(event, Event): # 最后兜底
|
||||
await bot.send(event=event, message=messages)
|
||||
else:
|
||||
|
||||
@@ -4,7 +4,7 @@ import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Awaitable
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from aiocqhttp import CQHttp, Event
|
||||
from aiocqhttp.exceptions import ActionFailed
|
||||
@@ -38,9 +38,8 @@ class AiocqhttpAdapter(Platform):
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
super().__init__(platform_config, event_queue)
|
||||
|
||||
self.config = platform_config
|
||||
self.settings = platform_settings
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
self.host = platform_config["ws_reverse_host"]
|
||||
@@ -49,7 +48,7 @@ class AiocqhttpAdapter(Platform):
|
||||
self.metadata = PlatformMetadata(
|
||||
name="aiocqhttp",
|
||||
description="适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
|
||||
id=self.config.get("id"),
|
||||
id=cast(str, self.config.get("id")),
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
@@ -128,7 +127,9 @@ class AiocqhttpAdapter(Platform):
|
||||
"""OneBot V11 请求类事件"""
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = str(event.self_id)
|
||||
abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id)
|
||||
abm.sender = MessageMember(
|
||||
user_id=str(event.user_id), nickname=str(event.user_id)
|
||||
)
|
||||
abm.type = MessageType.OTHER_MESSAGE
|
||||
if event.get("group_id"):
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
@@ -154,7 +155,9 @@ class AiocqhttpAdapter(Platform):
|
||||
"""OneBot V11 通知类事件"""
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = str(event.self_id)
|
||||
abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id)
|
||||
abm.sender = MessageMember(
|
||||
user_id=str(event.user_id), nickname=str(event.user_id)
|
||||
)
|
||||
abm.type = MessageType.OTHER_MESSAGE
|
||||
if event.get("group_id"):
|
||||
abm.group_id = str(event.group_id)
|
||||
@@ -193,6 +196,7 @@ class AiocqhttpAdapter(Platform):
|
||||
@param event: 事件对象
|
||||
@param get_reply: 是否获取回复消息。这个参数是为了防止多个回复嵌套。
|
||||
"""
|
||||
assert event.sender is not None
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = str(event.self_id)
|
||||
abm.sender = MessageMember(
|
||||
@@ -202,6 +206,7 @@ class AiocqhttpAdapter(Platform):
|
||||
if event["message_type"] == "group":
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
abm.group_id = str(event.group_id)
|
||||
abm.group = Group(str(event.group_id))
|
||||
abm.group.group_name = event.get("group_name", "N/A")
|
||||
elif event["message_type"] == "private":
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
@@ -227,7 +232,7 @@ class AiocqhttpAdapter(Platform):
|
||||
await self.bot.send(event, err)
|
||||
except BaseException as e:
|
||||
logger.error(f"回复消息失败: {e}")
|
||||
return None
|
||||
raise ValueError(err)
|
||||
|
||||
# 按消息段类型类型适配
|
||||
for t, m_group in itertools.groupby(event.message, key=lambda x: x["type"]):
|
||||
@@ -246,7 +251,13 @@ class AiocqhttpAdapter(Platform):
|
||||
if m["data"].get("url") and m["data"].get("url").startswith("http"):
|
||||
# Lagrange
|
||||
logger.info("guessing lagrange")
|
||||
file_name = m["data"].get("file_name", "file")
|
||||
# 检查多个可能的文件名字段
|
||||
file_name = (
|
||||
m["data"].get("file_name", "")
|
||||
or m["data"].get("name", "")
|
||||
or m["data"].get("file", "")
|
||||
or "file"
|
||||
)
|
||||
abm.message.append(File(name=file_name, url=m["data"]["url"]))
|
||||
else:
|
||||
try:
|
||||
@@ -265,7 +276,14 @@ class AiocqhttpAdapter(Platform):
|
||||
)
|
||||
if ret and "url" in ret:
|
||||
file_url = ret["url"] # https
|
||||
a = File(name="", url=file_url)
|
||||
# 优先从 API 返回值获取文件名,其次从原始消息数据获取
|
||||
file_name = (
|
||||
ret.get("file_name", "")
|
||||
or ret.get("name", "")
|
||||
or m["data"].get("file", "")
|
||||
or m["data"].get("file_name", "")
|
||||
)
|
||||
a = File(name=file_name, url=file_url)
|
||||
abm.message.append(a)
|
||||
else:
|
||||
logger.error(f"获取文件失败: {ret}")
|
||||
@@ -367,10 +385,25 @@ class AiocqhttpAdapter(Platform):
|
||||
logger.error(f"获取 @ 用户信息失败: {e},此消息段将被忽略。")
|
||||
|
||||
message_str += "".join(at_parts)
|
||||
elif t == "markdown":
|
||||
text = m["data"].get("markdown") or m["data"].get("content", "")
|
||||
abm.message.append(Plain(text=text))
|
||||
message_str += text
|
||||
else:
|
||||
for m in m_group:
|
||||
a = ComponentTypes[t](**m["data"])
|
||||
abm.message.append(a)
|
||||
try:
|
||||
if t not in ComponentTypes:
|
||||
logger.warning(
|
||||
f"不支持的消息段类型,已忽略: {t}, data={m['data']}"
|
||||
)
|
||||
continue
|
||||
a = ComponentTypes[t](**m["data"])
|
||||
abm.message.append(a)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"消息段解析失败: type={t}, data={m['data']}. {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
abm.timestamp = int(time.time())
|
||||
abm.message_str = message_str
|
||||
@@ -403,7 +436,7 @@ class AiocqhttpAdapter(Platform):
|
||||
|
||||
async def shutdown_trigger_placeholder(self):
|
||||
await self.shutdown_event.wait()
|
||||
logger.info("aiocqhttp 适配器已被优雅地关闭")
|
||||
logger.info("aiocqhttp 适配器已被关闭")
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return self.metadata
|
||||
|
||||
@@ -2,6 +2,7 @@ import asyncio
|
||||
import os
|
||||
import threading
|
||||
import uuid
|
||||
from typing import cast
|
||||
|
||||
import aiohttp
|
||||
import dingtalk_stream
|
||||
@@ -47,21 +48,21 @@ class DingtalkPlatformAdapter(Platform):
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
|
||||
self.config = platform_config
|
||||
super().__init__(platform_config, event_queue)
|
||||
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
|
||||
self.client_id = platform_config["client_id"]
|
||||
self.client_secret = platform_config["client_secret"]
|
||||
|
||||
outer_self = self
|
||||
|
||||
class AstrCallbackClient(dingtalk_stream.ChatbotHandler):
|
||||
async def process(self_, message: dingtalk_stream.CallbackMessage):
|
||||
async def process(self, message: dingtalk_stream.CallbackMessage):
|
||||
logger.debug(f"dingtalk: {message.data}")
|
||||
im = dingtalk_stream.ChatbotMessage.from_dict(message.data)
|
||||
abm = await self.convert_msg(im)
|
||||
await self.handle_msg(abm)
|
||||
abm = await outer_self.convert_msg(im)
|
||||
await outer_self.handle_msg(abm)
|
||||
|
||||
return AckMessage.STATUS_OK, "OK"
|
||||
|
||||
@@ -75,14 +76,15 @@ class DingtalkPlatformAdapter(Platform):
|
||||
self.client,
|
||||
)
|
||||
self.client_ = client # 用于 websockets 的 client
|
||||
self._shutdown_event: threading.Event | None = None
|
||||
|
||||
def _id_to_sid(self, dingtalk_id: str | None) -> str | None:
|
||||
def _id_to_sid(self, dingtalk_id: str | None) -> str:
|
||||
if not dingtalk_id:
|
||||
return dingtalk_id
|
||||
return dingtalk_id or "unknown"
|
||||
prefix = "$:LWCP_v1:$"
|
||||
if dingtalk_id.startswith(prefix):
|
||||
return dingtalk_id[len(prefix) :]
|
||||
return dingtalk_id
|
||||
return dingtalk_id or "unknown"
|
||||
|
||||
async def send_by_session(
|
||||
self,
|
||||
@@ -95,7 +97,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
return PlatformMetadata(
|
||||
name="dingtalk",
|
||||
description="钉钉机器人官方 API 适配器",
|
||||
id=self.config.get("id"),
|
||||
id=cast(str, self.config.get("id")),
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
@@ -106,7 +108,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
abm = AstrBotMessage()
|
||||
abm.message = []
|
||||
abm.message_str = ""
|
||||
abm.timestamp = int(message.create_at / 1000)
|
||||
abm.timestamp = int(cast(int, message.create_at) / 1000)
|
||||
abm.type = (
|
||||
MessageType.GROUP_MESSAGE
|
||||
if message.conversation_type == "2"
|
||||
@@ -117,7 +119,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
nickname=message.sender_nick,
|
||||
)
|
||||
abm.self_id = self._id_to_sid(message.chatbot_user_id)
|
||||
abm.message_id = message.message_id
|
||||
abm.message_id = cast(str, message.message_id)
|
||||
abm.raw_message = message
|
||||
|
||||
if abm.type == MessageType.GROUP_MESSAGE:
|
||||
@@ -134,14 +136,16 @@ class DingtalkPlatformAdapter(Platform):
|
||||
else:
|
||||
abm.session_id = abm.sender.user_id
|
||||
|
||||
message_type: str = message.message_type
|
||||
message_type: str = cast(str, message.message_type)
|
||||
match message_type:
|
||||
case "text":
|
||||
abm.message_str = message.text.content.strip()
|
||||
abm.message.append(Plain(abm.message_str))
|
||||
case "richText":
|
||||
rtc: dingtalk_stream.RichTextContent = message.rich_text_content
|
||||
contents: list[dict] = rtc.rich_text_list
|
||||
rtc: dingtalk_stream.RichTextContent = cast(
|
||||
dingtalk_stream.RichTextContent, message.rich_text_content
|
||||
)
|
||||
contents: list[dict] = cast(list[dict], rtc.rich_text_list)
|
||||
for content in contents:
|
||||
plains = ""
|
||||
if "text" in content:
|
||||
@@ -150,7 +154,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
elif "type" in content and content["type"] == "picture":
|
||||
f_path = await self.download_ding_file(
|
||||
content["downloadCode"],
|
||||
message.robot_code,
|
||||
cast(str, message.robot_code),
|
||||
"jpg",
|
||||
)
|
||||
abm.message.append(Image.fromFileSystem(f_path))
|
||||
@@ -195,7 +199,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
logger.error(
|
||||
f"下载钉钉文件失败: {resp.status}, {await resp.text()}",
|
||||
)
|
||||
return None
|
||||
return ""
|
||||
resp_data = await resp.json()
|
||||
download_url = resp_data["data"]["downloadUrl"]
|
||||
await download_file(download_url, f_path)
|
||||
@@ -215,7 +219,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
logger.error(
|
||||
f"获取钉钉机器人 access_token 失败: {resp.status}, {await resp.text()}",
|
||||
)
|
||||
return None
|
||||
return ""
|
||||
return (await resp.json())["data"]["accessToken"]
|
||||
|
||||
async def handle_msg(self, abm: AstrBotMessage):
|
||||
@@ -241,7 +245,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
task.result()
|
||||
except Exception as e:
|
||||
if "Graceful shutdown" in str(e):
|
||||
logger.info("钉钉适配器已被优雅地关闭")
|
||||
logger.info("钉钉适配器已被关闭")
|
||||
return
|
||||
logger.error(f"钉钉机器人启动失败: {e}")
|
||||
|
||||
@@ -250,11 +254,13 @@ class DingtalkPlatformAdapter(Platform):
|
||||
|
||||
async def terminate(self):
|
||||
def monkey_patch_close():
|
||||
raise Exception("Graceful shutdown")
|
||||
raise KeyboardInterrupt("Graceful shutdown")
|
||||
|
||||
self.client_.open_connection = monkey_patch_close
|
||||
await self.client_.websocket.close(code=1000, reason="Graceful shutdown")
|
||||
self._shutdown_event.set()
|
||||
if self.client_.websocket is not None:
|
||||
self.client_.open_connection = monkey_patch_close
|
||||
await self.client_.websocket.close(code=1000, reason="Graceful shutdown")
|
||||
if self._shutdown_event is not None:
|
||||
self._shutdown_event.set()
|
||||
|
||||
def get_client(self):
|
||||
return self.client
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
from typing import cast
|
||||
|
||||
import dingtalk_stream
|
||||
|
||||
@@ -32,7 +33,7 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
||||
client.reply_markdown,
|
||||
segment.text,
|
||||
segment.text,
|
||||
self.message_obj.raw_message,
|
||||
cast(dingtalk_stream.ChatbotMessage, self.message_obj.raw_message),
|
||||
)
|
||||
elif isinstance(segment, Comp.Image):
|
||||
markdown_str = ""
|
||||
@@ -53,7 +54,9 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
||||
client.reply_markdown,
|
||||
"😄",
|
||||
markdown_str,
|
||||
self.message_obj.raw_message,
|
||||
cast(
|
||||
dingtalk_stream.ChatbotMessage, self.message_obj.raw_message
|
||||
),
|
||||
)
|
||||
logger.debug(f"send image: {ret}")
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import sys
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
import discord
|
||||
|
||||
@@ -27,13 +28,16 @@ class DiscordBotClient(discord.Bot):
|
||||
super().__init__(intents=intents, proxy=proxy)
|
||||
|
||||
# 回调函数
|
||||
self.on_message_received = None
|
||||
self.on_ready_once_callback = None
|
||||
self.on_message_received: Callable[[dict], Awaitable[None]] | None = None
|
||||
self.on_ready_once_callback: Callable[[], Awaitable[None]] | None = None
|
||||
self._ready_once_fired = False
|
||||
|
||||
@override
|
||||
async def on_ready(self):
|
||||
"""当机器人成功连接并准备就绪时触发"""
|
||||
if self.user is None:
|
||||
logger.error("[Discord] 客户端未正确加载用户信息 (self.user is None)")
|
||||
return
|
||||
|
||||
logger.info(f"[Discord] 已作为 {self.user} (ID: {self.user.id}) 登录")
|
||||
logger.info("[Discord] 客户端已准备就绪。")
|
||||
|
||||
@@ -49,6 +53,9 @@ class DiscordBotClient(discord.Bot):
|
||||
|
||||
def _create_message_data(self, message: discord.Message) -> dict:
|
||||
"""从 discord.Message 创建数据字典"""
|
||||
if self.user is None:
|
||||
raise RuntimeError("Bot is not ready: self.user is None")
|
||||
|
||||
is_mentioned = self.user in message.mentions
|
||||
return {
|
||||
"message": message,
|
||||
@@ -66,6 +73,12 @@ class DiscordBotClient(discord.Bot):
|
||||
|
||||
def _create_interaction_data(self, interaction: discord.Interaction) -> dict:
|
||||
"""从 discord.Interaction 创建数据字典"""
|
||||
if self.user is None:
|
||||
raise RuntimeError("Bot is not ready: self.user is None")
|
||||
|
||||
if interaction.user is None:
|
||||
raise ValueError("Interaction received without a valid user")
|
||||
|
||||
return {
|
||||
"interaction": interaction,
|
||||
"bot_id": str(self.user.id),
|
||||
@@ -80,7 +93,6 @@ class DiscordBotClient(discord.Bot):
|
||||
"type": "interaction",
|
||||
}
|
||||
|
||||
@override
|
||||
async def on_message(self, message: discord.Message):
|
||||
"""当接收到消息时触发"""
|
||||
if message.author.bot:
|
||||
|
||||
@@ -97,8 +97,8 @@ class DiscordView(BaseMessageComponent):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
components: list[BaseMessageComponent] = None,
|
||||
timeout: float = None,
|
||||
components: list[BaseMessageComponent] | None = None,
|
||||
timeout: float | None = None,
|
||||
):
|
||||
self.components = components or []
|
||||
self.timeout = timeout
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import asyncio
|
||||
import re
|
||||
import sys
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import discord
|
||||
from discord.abc import Messageable
|
||||
from discord.abc import GuildChannel, Messageable, PrivateChannel
|
||||
from discord.channel import DMChannel
|
||||
|
||||
from astrbot import logger
|
||||
@@ -44,10 +44,9 @@ class DiscordPlatformAdapter(Platform):
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
self.config = platform_config
|
||||
super().__init__(platform_config, event_queue)
|
||||
self.settings = platform_settings
|
||||
self.client_self_id = None
|
||||
self.client_self_id: str | None = None
|
||||
self.registered_handlers = []
|
||||
# 指令注册相关
|
||||
self.enable_command_register = self.config.get("discord_command_register", True)
|
||||
@@ -63,6 +62,12 @@ class DiscordPlatformAdapter(Platform):
|
||||
message_chain: MessageChain,
|
||||
):
|
||||
"""通过会话发送消息"""
|
||||
if self.client.user is None:
|
||||
logger.error(
|
||||
"[Discord] 客户端未就绪 (self.client.user is None),无法发送消息"
|
||||
)
|
||||
return
|
||||
|
||||
# 创建一个 message_obj 以便在 event 中使用
|
||||
message_obj = AstrBotMessage()
|
||||
if "_" in session.session_id:
|
||||
@@ -90,7 +95,7 @@ class DiscordPlatformAdapter(Platform):
|
||||
user_id=str(self.client_self_id),
|
||||
nickname=self.client.user.display_name,
|
||||
)
|
||||
message_obj.self_id = self.client_self_id
|
||||
message_obj.self_id = cast(str, self.client_self_id)
|
||||
message_obj.session_id = session.session_id
|
||||
message_obj.message = message_chain.chain
|
||||
|
||||
@@ -111,7 +116,7 @@ class DiscordPlatformAdapter(Platform):
|
||||
return PlatformMetadata(
|
||||
"discord",
|
||||
"Discord 适配器",
|
||||
id=self.config.get("id"),
|
||||
id=cast(str, self.config.get("id")),
|
||||
default_config_tmpl=self.config,
|
||||
support_streaming_message=False,
|
||||
)
|
||||
@@ -161,7 +166,7 @@ class DiscordPlatformAdapter(Platform):
|
||||
|
||||
def _get_message_type(
|
||||
self,
|
||||
channel: Messageable,
|
||||
channel: Messageable | GuildChannel | PrivateChannel,
|
||||
guild_id: int | None = None,
|
||||
) -> MessageType:
|
||||
"""根据 channel 对象和 guild_id 判断消息类型"""
|
||||
@@ -171,13 +176,15 @@ class DiscordPlatformAdapter(Platform):
|
||||
return MessageType.FRIEND_MESSAGE
|
||||
return MessageType.GROUP_MESSAGE
|
||||
|
||||
def _get_channel_id(self, channel: Messageable) -> str:
|
||||
def _get_channel_id(
|
||||
self, channel: Messageable | GuildChannel | PrivateChannel
|
||||
) -> str:
|
||||
"""根据 channel 对象获取ID"""
|
||||
return str(getattr(channel, "id", None))
|
||||
|
||||
def _convert_message_to_abm(self, data: dict) -> AstrBotMessage:
|
||||
"""将普通消息转换为 AstrBotMessage"""
|
||||
message: discord.Message = data["message"]
|
||||
message = data["message"]
|
||||
|
||||
content = message.content
|
||||
|
||||
@@ -234,7 +241,7 @@ class DiscordPlatformAdapter(Platform):
|
||||
)
|
||||
abm.message = message_chain
|
||||
abm.raw_message = message
|
||||
abm.self_id = self.client_self_id
|
||||
abm.self_id = cast(str, self.client_self_id)
|
||||
abm.session_id = str(message.channel.id)
|
||||
abm.message_id = str(message.id)
|
||||
return abm
|
||||
@@ -255,32 +262,52 @@ class DiscordPlatformAdapter(Platform):
|
||||
interaction_followup_webhook=followup_webhook,
|
||||
)
|
||||
|
||||
if self.client.user is None:
|
||||
logger.error(
|
||||
"[Discord] 客户端未就绪 (self.client.user is None),无法处理消息"
|
||||
)
|
||||
return
|
||||
|
||||
# 检查是否为斜杠指令
|
||||
is_slash_command = message_event.interaction_followup_webhook is not None
|
||||
|
||||
# 1. 优先处理斜杠指令
|
||||
if is_slash_command:
|
||||
message_event.is_wake = True
|
||||
message_event.is_at_or_wake_command = True
|
||||
self.commit_event(message_event)
|
||||
return
|
||||
|
||||
# 2. 处理普通消息(提及检测)
|
||||
# 确保 raw_message 是 discord.Message 类型,以便静态检查通过
|
||||
raw_message = message.raw_message
|
||||
if not isinstance(raw_message, discord.Message):
|
||||
logger.warning(
|
||||
f"[Discord] 收到非 Message 类型的消息: {type(raw_message)},已忽略。"
|
||||
)
|
||||
return
|
||||
|
||||
# 检查是否被@(User Mention 或 Bot 拥有的 Role Mention)
|
||||
is_mention = False
|
||||
|
||||
# User Mention
|
||||
if (
|
||||
self.client
|
||||
and self.client.user
|
||||
and hasattr(message.raw_message, "mentions")
|
||||
):
|
||||
if self.client.user in message.raw_message.mentions:
|
||||
is_mention = True
|
||||
# 此时 Pylance 知道 raw_message 是 discord.Message,具有 mentions 属性
|
||||
if self.client.user in raw_message.mentions:
|
||||
is_mention = True
|
||||
|
||||
# Role Mention(Bot 拥有的角色被提及)
|
||||
if not is_mention and hasattr(message.raw_message, "role_mentions"):
|
||||
if not is_mention and raw_message.role_mentions:
|
||||
bot_member = None
|
||||
if hasattr(message.raw_message, "guild") and message.raw_message.guild:
|
||||
if raw_message.guild:
|
||||
try:
|
||||
bot_member = message.raw_message.guild.get_member(
|
||||
bot_member = raw_message.guild.get_member(
|
||||
self.client.user.id,
|
||||
)
|
||||
except Exception:
|
||||
bot_member = None
|
||||
if bot_member and hasattr(bot_member, "roles"):
|
||||
bot_roles = set(bot_member.roles)
|
||||
mentioned_roles = set(message.raw_message.role_mentions)
|
||||
mentioned_roles = set(raw_message.role_mentions)
|
||||
if (
|
||||
bot_roles
|
||||
and mentioned_roles
|
||||
@@ -288,8 +315,8 @@ class DiscordPlatformAdapter(Platform):
|
||||
):
|
||||
is_mention = True
|
||||
|
||||
# 如果是斜杠指令或被@的消息,设置为唤醒状态
|
||||
if is_slash_command or is_mention:
|
||||
# 如果是被@的消息,设置为唤醒状态
|
||||
if is_mention:
|
||||
message_event.is_wake = True
|
||||
message_event.is_at_or_wake_command = True
|
||||
|
||||
@@ -425,7 +452,7 @@ class DiscordPlatformAdapter(Platform):
|
||||
)
|
||||
abm.message = [Plain(text=message_str_for_filter)]
|
||||
abm.raw_message = ctx.interaction
|
||||
abm.self_id = self.client_self_id
|
||||
abm.self_id = cast(str, self.client_self_id)
|
||||
abm.session_id = str(ctx.channel_id)
|
||||
abm.message_id = str(ctx.interaction.id)
|
||||
|
||||
@@ -438,7 +465,7 @@ class DiscordPlatformAdapter(Platform):
|
||||
def _extract_command_info(
|
||||
event_filter: Any,
|
||||
handler_metadata: StarHandlerMetadata,
|
||||
) -> tuple[str, str, CommandFilter] | None:
|
||||
) -> tuple[str, str, CommandFilter | None] | None:
|
||||
"""从事件过滤器中提取指令信息"""
|
||||
cmd_name = None
|
||||
# is_group = False
|
||||
|
||||
@@ -4,8 +4,10 @@ import binascii
|
||||
from collections.abc import AsyncGenerator
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import discord
|
||||
from discord.types.interactions import ComponentInteractionData
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
@@ -85,6 +87,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
channel = await self._get_channel()
|
||||
if not channel:
|
||||
return
|
||||
if not isinstance(channel, discord.abc.Messageable):
|
||||
logger.error(f"[Discord] 频道 {channel.id} 不是可发送消息的类型")
|
||||
return
|
||||
await channel.send(**kwargs)
|
||||
|
||||
except Exception as e:
|
||||
@@ -107,7 +112,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
async def _get_channel(self) -> discord.abc.Messageable | None:
|
||||
async def _get_channel(
|
||||
self,
|
||||
) -> discord.Thread | discord.abc.GuildChannel | discord.abc.PrivateChannel | None:
|
||||
"""获取当前事件对应的频道对象"""
|
||||
try:
|
||||
channel_id = int(self.session_id)
|
||||
@@ -121,7 +128,13 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
async def _parse_to_discord(
|
||||
self,
|
||||
message: MessageChain,
|
||||
) -> tuple[str, list[discord.File], discord.ui.View | None, list[discord.Embed]]:
|
||||
) -> tuple[
|
||||
str,
|
||||
list[discord.File],
|
||||
discord.ui.View | None,
|
||||
list[discord.Embed],
|
||||
str | int | None,
|
||||
]:
|
||||
"""将 MessageChain 解析为 Discord 发送所需的内容"""
|
||||
content_parts = []
|
||||
files = []
|
||||
@@ -261,7 +274,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
self.message_obj.raw_message,
|
||||
"add_reaction",
|
||||
):
|
||||
await self.message_obj.raw_message.add_reaction(emoji)
|
||||
await cast(discord.Message, self.message_obj.raw_message).add_reaction(
|
||||
emoji
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] 添加反应失败: {e}")
|
||||
|
||||
@@ -270,7 +285,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
return (
|
||||
hasattr(self.message_obj, "raw_message")
|
||||
and hasattr(self.message_obj.raw_message, "type")
|
||||
and self.message_obj.raw_message.type
|
||||
and cast(discord.Interaction, self.message_obj.raw_message).type
|
||||
== discord.InteractionType.application_command
|
||||
)
|
||||
|
||||
@@ -279,14 +294,18 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
return (
|
||||
hasattr(self.message_obj, "raw_message")
|
||||
and hasattr(self.message_obj.raw_message, "type")
|
||||
and self.message_obj.raw_message.type == discord.InteractionType.component
|
||||
and cast(discord.Interaction, self.message_obj.raw_message).type
|
||||
== discord.InteractionType.component
|
||||
)
|
||||
|
||||
def get_interaction_custom_id(self) -> str:
|
||||
"""获取交互组件的custom_id"""
|
||||
if self.is_button_interaction():
|
||||
try:
|
||||
return self.message_obj.raw_message.data.get("custom_id", "")
|
||||
return cast(
|
||||
ComponentInteractionData,
|
||||
cast(discord.Interaction, self.message_obj.raw_message).data,
|
||||
).get("custom_id", "")
|
||||
except Exception:
|
||||
pass
|
||||
return ""
|
||||
@@ -299,7 +318,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
):
|
||||
return any(
|
||||
mention.id == int(self.message_obj.self_id)
|
||||
for mention in self.message_obj.raw_message.mentions
|
||||
for mention in cast(
|
||||
discord.Message, self.message_obj.raw_message
|
||||
).mentions
|
||||
)
|
||||
return False
|
||||
|
||||
@@ -309,5 +330,5 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
self.message_obj.raw_message,
|
||||
"clean_content",
|
||||
):
|
||||
return self.message_obj.raw_message.clean_content
|
||||
return cast(discord.Message, self.message_obj.raw_message).clean_content
|
||||
return self.message_str
|
||||
|
||||
@@ -2,10 +2,17 @@ import asyncio
|
||||
import base64
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, cast
|
||||
|
||||
import lark_oapi as lark
|
||||
from lark_oapi.api.im.v1 import *
|
||||
from lark_oapi.api.im.v1 import (
|
||||
CreateMessageRequest,
|
||||
CreateMessageRequestBody,
|
||||
GetMessageResourceRequest,
|
||||
)
|
||||
from lark_oapi.api.im.v1.processor import P2ImMessageReceiveV1Processor
|
||||
|
||||
import astrbot.api.message_components as Comp
|
||||
from astrbot import logger
|
||||
@@ -18,9 +25,11 @@ from astrbot.api.platform import (
|
||||
PlatformMetadata,
|
||||
)
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.utils.webhook_utils import log_webhook_info
|
||||
|
||||
from ...register import register_platform_adapter
|
||||
from .lark_event import LarkMessageEvent
|
||||
from .server import LarkWebhookServer
|
||||
|
||||
|
||||
@register_platform_adapter(
|
||||
@@ -33,9 +42,7 @@ class LarkPlatformAdapter(Platform):
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
|
||||
self.config = platform_config
|
||||
super().__init__(platform_config, event_queue)
|
||||
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
|
||||
@@ -44,9 +51,13 @@ class LarkPlatformAdapter(Platform):
|
||||
self.domain = platform_config.get("domain", lark.FEISHU_DOMAIN)
|
||||
self.bot_name = platform_config.get("lark_bot_name", "astrbot")
|
||||
|
||||
# socket or webhook
|
||||
self.connection_mode = platform_config.get("lark_connection_mode", "socket")
|
||||
|
||||
if not self.bot_name:
|
||||
logger.warning("未设置飞书机器人名称,@ 机器人可能得不到回复。")
|
||||
|
||||
# 初始化 WebSocket 长连接相关配置
|
||||
async def on_msg_event_recv(event: lark.im.v1.P2ImMessageReceiveV1):
|
||||
await self.convert_msg(event)
|
||||
|
||||
@@ -59,6 +70,8 @@ class LarkPlatformAdapter(Platform):
|
||||
.build()
|
||||
)
|
||||
|
||||
self.do_v2_msg_event = do_v2_msg_event
|
||||
|
||||
self.client = lark.ws.Client(
|
||||
app_id=self.appid,
|
||||
app_secret=self.appsecret,
|
||||
@@ -68,14 +81,56 @@ class LarkPlatformAdapter(Platform):
|
||||
)
|
||||
|
||||
self.lark_api = (
|
||||
lark.Client.builder().app_id(self.appid).app_secret(self.appsecret).build()
|
||||
lark.Client.builder()
|
||||
.app_id(self.appid)
|
||||
.app_secret(self.appsecret)
|
||||
.log_level(lark.LogLevel.ERROR)
|
||||
.domain(self.domain)
|
||||
.build()
|
||||
)
|
||||
|
||||
self.webhook_server = None
|
||||
if self.connection_mode == "webhook":
|
||||
self.webhook_server = LarkWebhookServer(platform_config, event_queue)
|
||||
self.webhook_server.set_callback(self.handle_webhook_event)
|
||||
|
||||
self.event_id_timestamps: dict[str, float] = {}
|
||||
|
||||
def _clean_expired_events(self):
|
||||
"""清理超过 30 分钟的事件记录"""
|
||||
current_time = time.time()
|
||||
expired_keys = [
|
||||
event_id
|
||||
for event_id, timestamp in self.event_id_timestamps.items()
|
||||
if current_time - timestamp > 1800
|
||||
]
|
||||
for event_id in expired_keys:
|
||||
del self.event_id_timestamps[event_id]
|
||||
|
||||
def _is_duplicate_event(self, event_id: str) -> bool:
|
||||
"""检查事件是否重复
|
||||
|
||||
Args:
|
||||
event_id: 事件ID
|
||||
|
||||
Returns:
|
||||
True 表示重复事件,False 表示新事件
|
||||
"""
|
||||
self._clean_expired_events()
|
||||
if event_id in self.event_id_timestamps:
|
||||
return True
|
||||
self.event_id_timestamps[event_id] = time.time()
|
||||
return False
|
||||
|
||||
async def send_by_session(
|
||||
self,
|
||||
session: MessageSesion,
|
||||
message_chain: MessageChain,
|
||||
):
|
||||
if self.lark_api.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化,无法发送消息")
|
||||
return
|
||||
|
||||
res = await LarkMessageEvent._convert_to_lark(message_chain, self.lark_api)
|
||||
wrapped = {
|
||||
"zh_cn": {
|
||||
@@ -116,14 +171,25 @@ class LarkPlatformAdapter(Platform):
|
||||
return PlatformMetadata(
|
||||
name="lark",
|
||||
description="飞书机器人官方 API 适配器",
|
||||
id=self.config.get("id"),
|
||||
id=cast(str, self.config.get("id")),
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):
|
||||
if event.event is None:
|
||||
logger.debug("[Lark] 收到空事件(event.event is None)")
|
||||
return
|
||||
message = event.event.message
|
||||
if message is None:
|
||||
logger.debug("[Lark] 事件中没有消息体(message is None)")
|
||||
return
|
||||
|
||||
abm = AstrBotMessage()
|
||||
abm.timestamp = int(message.create_time) / 1000
|
||||
|
||||
if message.create_time:
|
||||
abm.timestamp = int(message.create_time) // 1000
|
||||
else:
|
||||
abm.timestamp = int(time.time())
|
||||
abm.message = []
|
||||
abm.type = (
|
||||
MessageType.GROUP_MESSAGE
|
||||
@@ -138,14 +204,28 @@ class LarkPlatformAdapter(Platform):
|
||||
at_list = {}
|
||||
if message.mentions:
|
||||
for m in message.mentions:
|
||||
at_list[m.key] = Comp.At(qq=m.id.open_id, name=m.name)
|
||||
if m.name == self.bot_name:
|
||||
abm.self_id = m.id.open_id
|
||||
if m.id is None:
|
||||
continue
|
||||
# 飞书 open_id 可能是 None,这里做个防护
|
||||
open_id = m.id.open_id if m.id.open_id else ""
|
||||
at_list[m.key] = Comp.At(qq=open_id, name=m.name)
|
||||
|
||||
content_json_b = json.loads(message.content)
|
||||
if m.name == self.bot_name:
|
||||
if m.id.open_id is not None:
|
||||
abm.self_id = m.id.open_id
|
||||
|
||||
if message.content is None:
|
||||
logger.warning("[Lark] 消息内容为空")
|
||||
return
|
||||
|
||||
try:
|
||||
content_json_b = json.loads(message.content)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"[Lark] 解析消息内容失败: {message.content}")
|
||||
return
|
||||
|
||||
if message.message_type == "text":
|
||||
message_str_raw = content_json_b["text"] # 带有 @ 的消息
|
||||
message_str_raw = content_json_b.get("text", "") # 带有 @ 的消息
|
||||
at_pattern = r"(@_user_\d+)" # 可以根据需求修改正则
|
||||
# at_users = re.findall(at_pattern, message_str_raw)
|
||||
# 拆分文本,去掉AT符号部分
|
||||
@@ -170,27 +250,47 @@ class LarkPlatformAdapter(Platform):
|
||||
content_json_b = _ls
|
||||
elif message.message_type == "image":
|
||||
content_json_b = [
|
||||
{"tag": "img", "image_key": content_json_b["image_key"], "style": []},
|
||||
{
|
||||
"tag": "img",
|
||||
"image_key": content_json_b.get("image_key"),
|
||||
"style": [],
|
||||
},
|
||||
]
|
||||
|
||||
if message.message_type in ("post", "image"):
|
||||
for comp in content_json_b:
|
||||
if comp["tag"] == "at":
|
||||
abm.message.append(at_list[comp["user_id"]])
|
||||
elif comp["tag"] == "text" and comp["text"].strip():
|
||||
if comp.get("tag") == "at":
|
||||
user_id = comp.get("user_id")
|
||||
if user_id in at_list:
|
||||
abm.message.append(at_list[user_id])
|
||||
elif comp.get("tag") == "text" and comp.get("text", "").strip():
|
||||
abm.message.append(Comp.Plain(comp["text"].strip()))
|
||||
elif comp["tag"] == "img":
|
||||
image_key = comp["image_key"]
|
||||
elif comp.get("tag") == "img":
|
||||
image_key = comp.get("image_key")
|
||||
if not image_key:
|
||||
continue
|
||||
|
||||
request = (
|
||||
GetMessageResourceRequest.builder()
|
||||
.message_id(message.message_id)
|
||||
.message_id(cast(str, message.message_id))
|
||||
.file_key(image_key)
|
||||
.type("image")
|
||||
.build()
|
||||
)
|
||||
|
||||
if self.lark_api.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化")
|
||||
continue
|
||||
|
||||
response = await self.lark_api.im.v1.message_resource.aget(request)
|
||||
if not response.success():
|
||||
logger.error(f"无法下载飞书图片: {image_key}")
|
||||
continue
|
||||
|
||||
if response.file is None:
|
||||
logger.error(f"飞书图片响应中不包含文件流: {image_key}")
|
||||
continue
|
||||
|
||||
image_bytes = response.file.read()
|
||||
image_base64 = base64.b64encode(image_bytes).decode()
|
||||
abm.message.append(Comp.Image.fromBase64(image_base64))
|
||||
@@ -198,6 +298,19 @@ class LarkPlatformAdapter(Platform):
|
||||
for comp in abm.message:
|
||||
if isinstance(comp, Comp.Plain):
|
||||
abm.message_str += comp.text
|
||||
|
||||
if message.message_id is None:
|
||||
logger.error("[Lark] 消息缺少 message_id")
|
||||
return
|
||||
|
||||
if (
|
||||
event.event.sender is None
|
||||
or event.event.sender.sender_id is None
|
||||
or event.event.sender.sender_id.open_id is None
|
||||
):
|
||||
logger.error("[Lark] 消息发送者信息不完整")
|
||||
return
|
||||
|
||||
abm.message_id = message.message_id
|
||||
abm.raw_message = message
|
||||
abm.sender = MessageMember(
|
||||
@@ -229,13 +342,61 @@ class LarkPlatformAdapter(Platform):
|
||||
|
||||
self._event_queue.put_nowait(event)
|
||||
|
||||
async def handle_webhook_event(self, event_data: dict):
|
||||
"""处理 Webhook 事件
|
||||
|
||||
Args:
|
||||
event_data: Webhook 事件数据
|
||||
"""
|
||||
try:
|
||||
header = event_data.get("header", {})
|
||||
event_id = header.get("event_id", "")
|
||||
if event_id and self._is_duplicate_event(event_id):
|
||||
logger.debug(f"[Lark Webhook] 跳过重复事件: {event_id}")
|
||||
return
|
||||
event_type = header.get("event_type", "")
|
||||
if event_type == "im.message.receive_v1":
|
||||
processor = P2ImMessageReceiveV1Processor(self.do_v2_msg_event)
|
||||
data = (processor.type())(event_data)
|
||||
processor.do(data)
|
||||
else:
|
||||
logger.debug(f"[Lark Webhook] 未处理的事件类型: {event_type}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Lark Webhook] 处理事件失败: {e}", exc_info=True)
|
||||
|
||||
async def run(self):
|
||||
# self.client.start()
|
||||
await self.client._connect()
|
||||
if self.connection_mode == "webhook":
|
||||
# Webhook 模式
|
||||
if self.webhook_server is None:
|
||||
logger.error("[Lark] Webhook 模式已启用,但 webhook_server 未初始化")
|
||||
return
|
||||
|
||||
webhook_uuid = self.config.get("webhook_uuid")
|
||||
if webhook_uuid:
|
||||
log_webhook_info(f"{self.meta().id}(飞书 Webhook)", webhook_uuid)
|
||||
else:
|
||||
logger.warning("[Lark] Webhook 模式已启用,但未配置 webhook_uuid")
|
||||
else:
|
||||
# 长连接模式
|
||||
await self.client._connect()
|
||||
|
||||
async def webhook_callback(self, request: Any) -> Any:
|
||||
"""统一 Webhook 回调入口"""
|
||||
if not self.webhook_server:
|
||||
return {"error": "Webhook server not initialized"}, 500
|
||||
|
||||
return await self.webhook_server.handle_callback(request)
|
||||
|
||||
async def terminate(self):
|
||||
await self.client._disconnect()
|
||||
logger.info("飞书(Lark) 适配器已被优雅地关闭")
|
||||
if self.connection_mode == "socket":
|
||||
await self.client._disconnect()
|
||||
logger.info("飞书(Lark) 适配器已关闭")
|
||||
|
||||
def get_client(self) -> lark.Client:
|
||||
def get_client(self) -> lark.ws.Client:
|
||||
return self.client
|
||||
|
||||
def unified_webhook(self) -> bool:
|
||||
return bool(
|
||||
self.config.get("lark_connection_mode", "") == "webhook"
|
||||
and self.config.get("webhook_uuid")
|
||||
)
|
||||
|
||||
@@ -5,7 +5,15 @@ import uuid
|
||||
from io import BytesIO
|
||||
|
||||
import lark_oapi as lark
|
||||
from lark_oapi.api.im.v1 import *
|
||||
from lark_oapi.api.im.v1 import (
|
||||
CreateImageRequest,
|
||||
CreateImageRequestBody,
|
||||
CreateMessageReactionRequest,
|
||||
CreateMessageReactionRequestBody,
|
||||
Emoji,
|
||||
ReplyMessageRequest,
|
||||
ReplyMessageRequestBody,
|
||||
)
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
@@ -44,7 +52,7 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
file_path = comp.file.replace("file:///", "")
|
||||
elif comp.file and comp.file.startswith("http"):
|
||||
image_file_path = await download_image_by_url(comp.file)
|
||||
file_path = image_file_path
|
||||
file_path = image_file_path if image_file_path else ""
|
||||
elif comp.file and comp.file.startswith("base64://"):
|
||||
base64_str = comp.file.removeprefix("base64://")
|
||||
image_data = base64.b64decode(base64_str)
|
||||
@@ -54,10 +62,17 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(BytesIO(image_data).getvalue())
|
||||
else:
|
||||
file_path = comp.file
|
||||
file_path = comp.file if comp.file else ""
|
||||
|
||||
if image_file is None:
|
||||
image_file = open(file_path, "rb")
|
||||
if not file_path:
|
||||
logger.error("[Lark] 图片路径为空,无法上传")
|
||||
continue
|
||||
try:
|
||||
image_file = open(file_path, "rb")
|
||||
except Exception as e:
|
||||
logger.error(f"[Lark] 无法打开图片文件: {e}")
|
||||
continue
|
||||
|
||||
request = (
|
||||
CreateImageRequest.builder()
|
||||
@@ -69,9 +84,20 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
.build()
|
||||
)
|
||||
|
||||
if lark_client.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化,无法上传图片")
|
||||
continue
|
||||
|
||||
response = await lark_client.im.v1.image.acreate(request)
|
||||
if not response.success():
|
||||
logger.error(f"无法上传飞书图片({response.code}): {response.msg}")
|
||||
continue
|
||||
|
||||
if response.data is None:
|
||||
logger.error("[Lark] 上传图片成功但未返回数据(data is None)")
|
||||
continue
|
||||
|
||||
image_key = response.data.image_key
|
||||
logger.debug(image_key)
|
||||
ret.append(_stage)
|
||||
@@ -107,6 +133,10 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
.build()
|
||||
)
|
||||
|
||||
if self.bot.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化,无法回复消息")
|
||||
return
|
||||
|
||||
response = await self.bot.im.v1.message.areply(request)
|
||||
|
||||
if not response.success():
|
||||
@@ -115,6 +145,10 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
await super().send(message)
|
||||
|
||||
async def react(self, emoji: str):
|
||||
if self.bot.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化,无法发送表情")
|
||||
return
|
||||
|
||||
request = (
|
||||
CreateMessageReactionRequest.builder()
|
||||
.message_id(self.message_obj.message_id)
|
||||
@@ -125,6 +159,7 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
.build()
|
||||
)
|
||||
|
||||
response = await self.bot.im.v1.message_reaction.acreate(request)
|
||||
if not response.success():
|
||||
logger.error(f"发送飞书表情回应失败({response.code}): {response.msg}")
|
||||
|
||||
@@ -0,0 +1,206 @@
|
||||
"""飞书(Lark) Webhook 服务器实现
|
||||
|
||||
实现飞书事件订阅的 Webhook 模式,支持:
|
||||
1. 请求 URL 验证 (challenge 验证)
|
||||
2. 事件加密/解密 (AES-256-CBC)
|
||||
3. 签名校验 (SHA256)
|
||||
4. 事件接收和处理
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from Crypto.Cipher import AES
|
||||
|
||||
from astrbot.api import logger
|
||||
|
||||
|
||||
class AESCipher:
|
||||
"""AES 加密/解密工具类"""
|
||||
|
||||
def __init__(self, key: str):
|
||||
self.bs = AES.block_size
|
||||
self.key = hashlib.sha256(self.str_to_bytes(key)).digest()
|
||||
|
||||
@staticmethod
|
||||
def str_to_bytes(data):
|
||||
u_type = type(b"".decode("utf8"))
|
||||
if isinstance(data, u_type):
|
||||
return data.encode("utf8")
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def _unpad(s):
|
||||
return s[: -ord(s[len(s) - 1 :])]
|
||||
|
||||
def decrypt(self, enc):
|
||||
iv = enc[: AES.block_size]
|
||||
cipher = AES.new(self.key, AES.MODE_CBC, iv)
|
||||
return self._unpad(cipher.decrypt(enc[AES.block_size :]))
|
||||
|
||||
def decrypt_string(self, enc):
|
||||
enc = base64.b64decode(enc)
|
||||
return self.decrypt(enc).decode("utf8")
|
||||
|
||||
|
||||
class LarkWebhookServer:
|
||||
"""飞书 Webhook 服务器
|
||||
|
||||
仅支持统一 Webhook 模式
|
||||
"""
|
||||
|
||||
def __init__(self, config: dict, event_queue: asyncio.Queue):
|
||||
"""初始化 Webhook 服务器
|
||||
|
||||
Args:
|
||||
config: 飞书配置
|
||||
event_queue: 事件队列
|
||||
"""
|
||||
self.app_id = config["app_id"]
|
||||
self.app_secret = config["app_secret"]
|
||||
self.encrypt_key = config.get("lark_encrypt_key", "")
|
||||
self.verification_token = config.get("lark_verification_token", "")
|
||||
|
||||
self.event_queue = event_queue
|
||||
self.callback: Callable[[dict], Awaitable[None]] | None = None
|
||||
|
||||
# 初始化加密工具
|
||||
self.cipher = None
|
||||
if self.encrypt_key:
|
||||
self.cipher = AESCipher(self.encrypt_key)
|
||||
|
||||
def verify_signature(
|
||||
self,
|
||||
timestamp: str,
|
||||
nonce: str,
|
||||
encrypt_key: str,
|
||||
body: bytes,
|
||||
signature: str,
|
||||
) -> bool:
|
||||
"""验证签名
|
||||
|
||||
Args:
|
||||
timestamp: 请求时间戳
|
||||
nonce: 随机数
|
||||
encrypt_key: 加密密钥
|
||||
body: 请求体
|
||||
signature: 签名
|
||||
|
||||
Returns:
|
||||
签名是否有效
|
||||
"""
|
||||
# 拼接字符串: timestamp + nonce + encrypt_key + body
|
||||
bytes_b1 = (timestamp + nonce + encrypt_key).encode("utf-8")
|
||||
bytes_b = bytes_b1 + body
|
||||
h = hashlib.sha256(bytes_b)
|
||||
calculated_signature = h.hexdigest()
|
||||
return calculated_signature == signature
|
||||
|
||||
def decrypt_event(self, encrypted_data: str) -> dict:
|
||||
"""解密事件数据
|
||||
|
||||
Args:
|
||||
encrypted_data: 加密的事件数据
|
||||
|
||||
Returns:
|
||||
解密后的事件字典
|
||||
"""
|
||||
if not self.cipher:
|
||||
raise ValueError("未配置 encrypt_key,无法解密事件")
|
||||
|
||||
decrypted_str = self.cipher.decrypt_string(encrypted_data)
|
||||
return json.loads(decrypted_str)
|
||||
|
||||
async def handle_challenge(self, event_data: dict) -> dict:
|
||||
"""处理 challenge 验证请求
|
||||
|
||||
Args:
|
||||
event_data: 事件数据
|
||||
|
||||
Returns:
|
||||
包含 challenge 的响应
|
||||
"""
|
||||
challenge = event_data.get("challenge", "")
|
||||
logger.info(f"[Lark Webhook] 收到 challenge 验证请求: {challenge}")
|
||||
|
||||
return {"challenge": challenge}
|
||||
|
||||
async def handle_callback(self, request) -> tuple[dict, int] | dict:
|
||||
"""处理 webhook 回调,可被统一 webhook 入口复用
|
||||
|
||||
Args:
|
||||
request: Quart 请求对象
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
# 获取原始请求体
|
||||
body = await request.get_data()
|
||||
|
||||
try:
|
||||
event_data = await request.json
|
||||
except Exception as e:
|
||||
logger.error(f"[Lark Webhook] 解析请求体失败: {e}")
|
||||
return {"error": "Invalid JSON"}, 400
|
||||
|
||||
if not event_data:
|
||||
logger.error("[Lark Webhook] 请求体为空")
|
||||
return {"error": "Empty request body"}, 400
|
||||
|
||||
# 如果配置了 encrypt_key,进行签名验证
|
||||
if self.encrypt_key:
|
||||
timestamp = request.headers.get("X-Lark-Request-Timestamp", "")
|
||||
nonce = request.headers.get("X-Lark-Request-Nonce", "")
|
||||
signature = request.headers.get("X-Lark-Signature", "")
|
||||
|
||||
if timestamp and nonce and signature:
|
||||
if not self.verify_signature(
|
||||
timestamp, nonce, self.encrypt_key, body, signature
|
||||
):
|
||||
logger.error("[Lark Webhook] 签名验证失败")
|
||||
return {"error": "Invalid signature"}, 401
|
||||
|
||||
# 检查是否是加密事件
|
||||
if "encrypt" in event_data:
|
||||
try:
|
||||
event_data = self.decrypt_event(event_data["encrypt"])
|
||||
logger.debug(f"[Lark Webhook] 解密后的事件: {event_data}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Lark Webhook] 解密事件失败: {e}")
|
||||
return {"error": "Decryption failed"}, 400
|
||||
|
||||
# 验证 token
|
||||
if self.verification_token:
|
||||
header = event_data.get("header", {})
|
||||
if header:
|
||||
token = header.get("token", "")
|
||||
else:
|
||||
token = event_data.get("token", "")
|
||||
if token != self.verification_token:
|
||||
logger.error("[Lark Webhook] Verification Token 不匹配。")
|
||||
return {"error": "Invalid verification token"}, 401
|
||||
|
||||
# 处理 URL 验证 (challenge)
|
||||
if event_data.get("type") == "url_verification":
|
||||
return await self.handle_challenge(event_data)
|
||||
|
||||
# 调用回调函数处理事件
|
||||
if self.callback:
|
||||
try:
|
||||
await self.callback(event_data)
|
||||
except Exception as e:
|
||||
logger.error(f"[Lark Webhook] 处理事件回调失败: {e}", exc_info=True)
|
||||
return {"error": "Event processing failed"}, 500
|
||||
|
||||
return {}
|
||||
|
||||
def set_callback(self, callback: Callable[[dict], Awaitable[None]]):
|
||||
"""设置事件回调函数
|
||||
|
||||
Args:
|
||||
callback: 处理事件的异步函数
|
||||
"""
|
||||
self.callback = callback
|
||||
@@ -1,7 +1,6 @@
|
||||
import asyncio
|
||||
import os
|
||||
import random
|
||||
from collections.abc import Awaitable
|
||||
from typing import Any
|
||||
|
||||
import astrbot.api.message_components as Comp
|
||||
@@ -55,8 +54,7 @@ class MisskeyPlatformAdapter(Platform):
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
self.config = platform_config or {}
|
||||
super().__init__(platform_config or {}, event_queue)
|
||||
self.settings = platform_settings or {}
|
||||
self.instance_url = self.config.get("misskey_instance_url", "")
|
||||
self.access_token = self.config.get("misskey_token", "")
|
||||
@@ -204,7 +202,7 @@ class MisskeyPlatformAdapter(Platform):
|
||||
if not isinstance(message.raw_message, dict):
|
||||
message.raw_message = {}
|
||||
message.raw_message["poll"] = poll
|
||||
message.poll = poll
|
||||
message.__setattr__("poll", poll)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -373,7 +371,7 @@ class MisskeyPlatformAdapter(Platform):
|
||||
self,
|
||||
session: MessageSession,
|
||||
message_chain: MessageChain,
|
||||
) -> Awaitable[Any]:
|
||||
) -> None:
|
||||
if not self.api:
|
||||
logger.error("[Misskey] API 客户端未初始化")
|
||||
return await super().send_by_session(session, message_chain)
|
||||
|
||||
@@ -3,6 +3,7 @@ import base64
|
||||
import os
|
||||
import random
|
||||
import uuid
|
||||
from typing import cast
|
||||
|
||||
import aiofiles
|
||||
import botpy
|
||||
@@ -60,7 +61,10 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
time_since_last_edit = current_time - last_edit_time
|
||||
|
||||
if time_since_last_edit >= throttle_interval:
|
||||
ret = await self._post_send(stream=stream_payload)
|
||||
ret = cast(
|
||||
message.Message,
|
||||
await self._post_send(stream=stream_payload),
|
||||
)
|
||||
stream_payload["index"] += 1
|
||||
stream_payload["id"] = ret["id"]
|
||||
last_edit_time = asyncio.get_event_loop().time()
|
||||
@@ -69,6 +73,8 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
# 结束流式对话,并且传输 buffer 中剩余的消息
|
||||
stream_payload["state"] = 10
|
||||
ret = await self._post_send(stream=stream_payload)
|
||||
else:
|
||||
ret = await self._post_send()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送流式消息时出错: {e}", exc_info=True)
|
||||
@@ -81,7 +87,8 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
return None
|
||||
|
||||
source = self.message_obj.raw_message
|
||||
assert isinstance(
|
||||
|
||||
if not isinstance(
|
||||
source,
|
||||
(
|
||||
botpy.message.Message,
|
||||
@@ -89,7 +96,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
botpy.message.DirectMessage,
|
||||
botpy.message.C2CMessage,
|
||||
),
|
||||
)
|
||||
):
|
||||
logger.warning(f"[QQOfficial] 不支持的消息源类型: {type(source)}")
|
||||
return None
|
||||
|
||||
(
|
||||
plain_text,
|
||||
@@ -106,7 +115,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
):
|
||||
return None
|
||||
|
||||
payload = {
|
||||
payload: dict = {
|
||||
"content": plain_text,
|
||||
"msg_id": self.message_obj.message_id,
|
||||
}
|
||||
@@ -116,8 +125,12 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
|
||||
ret = None
|
||||
|
||||
match type(source):
|
||||
case botpy.message.GroupMessage:
|
||||
match source:
|
||||
case botpy.message.GroupMessage():
|
||||
if not source.group_openid:
|
||||
logger.error("[QQOfficial] GroupMessage 缺少 group_openid")
|
||||
return None
|
||||
|
||||
if image_base64:
|
||||
media = await self.upload_group_and_c2c_image(
|
||||
image_base64,
|
||||
@@ -138,7 +151,8 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
group_openid=source.group_openid,
|
||||
**payload,
|
||||
)
|
||||
case botpy.message.C2CMessage:
|
||||
|
||||
case botpy.message.C2CMessage():
|
||||
if image_base64:
|
||||
media = await self.upload_group_and_c2c_image(
|
||||
image_base64,
|
||||
@@ -167,18 +181,23 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
**payload,
|
||||
)
|
||||
logger.debug(f"Message sent to C2C: {ret}")
|
||||
case botpy.message.Message:
|
||||
|
||||
case botpy.message.Message():
|
||||
if image_path:
|
||||
payload["file_image"] = image_path
|
||||
ret = await self.bot.api.post_message(
|
||||
channel_id=source.channel_id,
|
||||
**payload,
|
||||
)
|
||||
case botpy.message.DirectMessage:
|
||||
|
||||
case botpy.message.DirectMessage():
|
||||
if image_path:
|
||||
payload["file_image"] = image_path
|
||||
ret = await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
|
||||
|
||||
case _:
|
||||
pass
|
||||
|
||||
await super().send(self.send_buffer)
|
||||
|
||||
self.send_buffer = None
|
||||
@@ -196,18 +215,33 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
"file_type": file_type,
|
||||
"srv_send_msg": False,
|
||||
}
|
||||
|
||||
result = None
|
||||
if "openid" in kwargs:
|
||||
payload["openid"] = kwargs["openid"]
|
||||
route = Route("POST", "/v2/users/{openid}/files", openid=kwargs["openid"])
|
||||
return await self.bot.api._http.request(route, json=payload)
|
||||
if "group_openid" in kwargs:
|
||||
result = await self.bot.api._http.request(route, json=payload)
|
||||
elif "group_openid" in kwargs:
|
||||
payload["group_openid"] = kwargs["group_openid"]
|
||||
route = Route(
|
||||
"POST",
|
||||
"/v2/groups/{group_openid}/files",
|
||||
group_openid=kwargs["group_openid"],
|
||||
)
|
||||
return await self.bot.api._http.request(route, json=payload)
|
||||
result = await self.bot.api._http.request(route, json=payload)
|
||||
else:
|
||||
raise ValueError("Invalid upload parameters")
|
||||
|
||||
if not isinstance(result, dict):
|
||||
raise RuntimeError(
|
||||
f"Failed to upload image, response is not dict: {result}"
|
||||
)
|
||||
|
||||
return Media(
|
||||
file_uuid=result["file_uuid"],
|
||||
file_info=result["file_info"],
|
||||
ttl=result.get("ttl", 0),
|
||||
)
|
||||
|
||||
async def upload_group_and_c2c_record(
|
||||
self,
|
||||
@@ -250,11 +284,14 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
result = await self.bot.api._http.request(route, json=payload)
|
||||
|
||||
if result:
|
||||
if not isinstance(result, dict):
|
||||
logger.error(f"上传文件响应格式错误: {result}")
|
||||
return None
|
||||
|
||||
return Media(
|
||||
file_uuid=result.get("file_uuid"),
|
||||
file_info=result.get("file_info"),
|
||||
file_uuid=result["file_uuid"],
|
||||
file_info=result["file_info"],
|
||||
ttl=result.get("ttl", 0),
|
||||
file_id=result.get("id", ""),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"上传请求错误: {e}")
|
||||
@@ -271,7 +308,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
message_reference: message.Reference | None = None,
|
||||
media: message.Media | None = None,
|
||||
msg_id: str | None = None,
|
||||
msg_seq: str = 1,
|
||||
msg_seq: int | None = 1,
|
||||
event_id: str | None = None,
|
||||
markdown: message.MarkdownPayload | None = None,
|
||||
keyboard: message.Keyboard | None = None,
|
||||
@@ -280,7 +317,14 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
payload = locals()
|
||||
payload.pop("self", None)
|
||||
route = Route("POST", "/v2/users/{openid}/messages", openid=openid)
|
||||
return await self.bot.api._http.request(route, json=payload)
|
||||
result = await self.bot.api._http.request(route, json=payload)
|
||||
|
||||
if not isinstance(result, dict):
|
||||
raise RuntimeError(
|
||||
f"Failed to post c2c message, response is not dict: {result}"
|
||||
)
|
||||
|
||||
return message.Message(**result)
|
||||
|
||||
@staticmethod
|
||||
async def _parse_to_qqofficial(message: MessageChain):
|
||||
@@ -300,8 +344,10 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
image_base64 = file_to_base64(image_file_path)
|
||||
elif i.file and i.file.startswith("base64://"):
|
||||
image_base64 = i.file
|
||||
else:
|
||||
elif i.file:
|
||||
image_base64 = file_to_base64(i.file)
|
||||
else:
|
||||
raise ValueError("Unsupported image file format")
|
||||
image_base64 = image_base64.removeprefix("base64://")
|
||||
elif isinstance(i, Record):
|
||||
if i.file:
|
||||
|
||||
@@ -4,6 +4,7 @@ import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import cast
|
||||
|
||||
import botpy
|
||||
import botpy.message
|
||||
@@ -44,7 +45,9 @@ class botClient(Client):
|
||||
MessageType.GROUP_MESSAGE,
|
||||
)
|
||||
abm.session_id = (
|
||||
abm.sender.user_id if self.platform.unique_session else message.group_openid
|
||||
abm.sender.user_id
|
||||
if self.platform.unique_session
|
||||
else cast(str, message.group_openid)
|
||||
)
|
||||
self._commit(abm)
|
||||
|
||||
@@ -97,13 +100,11 @@ class QQOfficialPlatformAdapter(Platform):
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
|
||||
self.config = platform_config
|
||||
super().__init__(platform_config, event_queue)
|
||||
|
||||
self.appid = platform_config["appid"]
|
||||
self.secret = platform_config["secret"]
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
self.unique_session: bool = platform_settings["unique_session"]
|
||||
qq_group = platform_config["enable_group_c2c"]
|
||||
guild_dm = platform_config["enable_guild_direct_message"]
|
||||
|
||||
@@ -139,12 +140,15 @@ class QQOfficialPlatformAdapter(Platform):
|
||||
return PlatformMetadata(
|
||||
name="qq_official",
|
||||
description="QQ 机器人官方 API 适配器",
|
||||
id=self.config.get("id"),
|
||||
id=cast(str, self.config.get("id")),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_from_qqofficial(
|
||||
message: botpy.message.Message | botpy.message.GroupMessage,
|
||||
message: botpy.message.Message
|
||||
| botpy.message.GroupMessage
|
||||
| botpy.message.DirectMessage
|
||||
| botpy.message.C2CMessage,
|
||||
message_type: MessageType,
|
||||
):
|
||||
abm = AstrBotMessage()
|
||||
@@ -152,7 +156,7 @@ class QQOfficialPlatformAdapter(Platform):
|
||||
abm.timestamp = int(time.time())
|
||||
abm.raw_message = message
|
||||
abm.message_id = message.id
|
||||
abm.tag = "qq_official"
|
||||
# abm.tag = "qq_official"
|
||||
msg: list[BaseMessageComponent] = []
|
||||
|
||||
if isinstance(message, botpy.message.GroupMessage) or isinstance(
|
||||
@@ -182,9 +186,9 @@ class QQOfficialPlatformAdapter(Platform):
|
||||
message,
|
||||
botpy.message.DirectMessage,
|
||||
):
|
||||
try:
|
||||
if isinstance(message, botpy.message.Message):
|
||||
abm.self_id = str(message.mentions[0].id)
|
||||
except BaseException as _:
|
||||
else:
|
||||
abm.self_id = ""
|
||||
|
||||
plain_content = message.content.replace(
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
|
||||
import botpy
|
||||
import botpy.message
|
||||
@@ -11,6 +12,7 @@ from astrbot import logger
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, MessageType, Platform, PlatformMetadata
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.utils.webhook_utils import log_webhook_info
|
||||
|
||||
from ...register import register_platform_adapter
|
||||
from ..qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter
|
||||
@@ -34,7 +36,9 @@ class botClient(Client):
|
||||
MessageType.GROUP_MESSAGE,
|
||||
)
|
||||
abm.session_id = (
|
||||
abm.sender.user_id if self.platform.unique_session else message.group_openid
|
||||
abm.sender.user_id
|
||||
if self.platform.unique_session
|
||||
else cast(str, message.group_openid)
|
||||
)
|
||||
self._commit(abm)
|
||||
|
||||
@@ -87,13 +91,12 @@ class QQOfficialWebhookPlatformAdapter(Platform):
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
|
||||
self.config = platform_config
|
||||
super().__init__(platform_config, event_queue)
|
||||
|
||||
self.appid = platform_config["appid"]
|
||||
self.secret = platform_config["secret"]
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
self.unified_webhook_mode = platform_config.get("unified_webhook_mode", False)
|
||||
|
||||
intents = botpy.Intents(
|
||||
public_messages=True,
|
||||
@@ -106,6 +109,7 @@ class QQOfficialWebhookPlatformAdapter(Platform):
|
||||
timeout=20,
|
||||
)
|
||||
self.client.set_platform(self)
|
||||
self.webhook_helper = None
|
||||
|
||||
async def send_by_session(
|
||||
self,
|
||||
@@ -118,7 +122,7 @@ class QQOfficialWebhookPlatformAdapter(Platform):
|
||||
return PlatformMetadata(
|
||||
name="qq_official_webhook",
|
||||
description="QQ 机器人官方 API 适配器",
|
||||
id=self.config.get("id"),
|
||||
id=cast(str, self.config.get("id")),
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
@@ -128,16 +132,37 @@ class QQOfficialWebhookPlatformAdapter(Platform):
|
||||
self.client,
|
||||
)
|
||||
await self.webhook_helper.initialize()
|
||||
await self.webhook_helper.start_polling()
|
||||
|
||||
# 如果启用统一 webhook 模式,则不启动独立服务器
|
||||
webhook_uuid = self.config.get("webhook_uuid")
|
||||
if self.unified_webhook_mode and webhook_uuid:
|
||||
log_webhook_info(f"{self.meta().id}(QQ 官方机器人 Webhook)", webhook_uuid)
|
||||
# 保持运行状态,等待 shutdown
|
||||
await self.webhook_helper.shutdown_event.wait()
|
||||
else:
|
||||
await self.webhook_helper.start_polling()
|
||||
|
||||
def get_client(self) -> botClient:
|
||||
return self.client
|
||||
|
||||
async def webhook_callback(self, request: Any) -> Any:
|
||||
"""统一 Webhook 回调入口"""
|
||||
if not self.webhook_helper:
|
||||
return {"error": "Webhook helper not initialized"}, 500
|
||||
|
||||
# 复用 webhook_helper 的回调处理逻辑
|
||||
return await self.webhook_helper.handle_callback(request)
|
||||
|
||||
async def terminate(self):
|
||||
self.webhook_helper.shutdown_event.set()
|
||||
if self.webhook_helper:
|
||||
self.webhook_helper.shutdown_event.set()
|
||||
await self.client.close()
|
||||
try:
|
||||
await self.webhook_helper.server.shutdown()
|
||||
except Exception as _:
|
||||
pass
|
||||
if self.webhook_helper and not self.unified_webhook_mode:
|
||||
try:
|
||||
await self.webhook_helper.server.shutdown()
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
f"Exception occurred during QQOfficialWebhook server shutdown: {exc}",
|
||||
exc_info=True,
|
||||
)
|
||||
logger.info("QQ 机器人官方 API 适配器已经被优雅地关闭")
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
import quart
|
||||
from botpy import BotAPI, BotHttp, BotWebSocket, Client, ConnectionSession, Token
|
||||
@@ -78,7 +79,19 @@ class QQOfficialWebhook:
|
||||
return response
|
||||
|
||||
async def callback(self):
|
||||
msg: dict = await quart.request.json
|
||||
"""内部服务器的回调入口"""
|
||||
return await self.handle_callback(quart.request)
|
||||
|
||||
async def handle_callback(self, request) -> dict:
|
||||
"""处理 webhook 回调,可被统一 webhook 入口复用
|
||||
|
||||
Args:
|
||||
request: Quart 请求对象
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
msg: dict = await request.json
|
||||
logger.debug(f"收到 qq_official_webhook 回调: {msg}")
|
||||
|
||||
event = msg.get("t")
|
||||
@@ -87,7 +100,7 @@ class QQOfficialWebhook:
|
||||
|
||||
if opcode == 13:
|
||||
# validation
|
||||
signed = await self.webhook_validation(data)
|
||||
signed = await self.webhook_validation(cast(dict, data))
|
||||
print(signed)
|
||||
return signed
|
||||
|
||||
|
||||
@@ -38,8 +38,7 @@ class SatoriPlatformAdapter(Platform):
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
self.config = platform_config
|
||||
super().__init__(platform_config, event_queue)
|
||||
self.settings = platform_settings
|
||||
|
||||
self.api_base_url = self.config.get(
|
||||
|
||||
@@ -4,9 +4,11 @@ import hmac
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from quart import Quart, Response, request
|
||||
from slack_sdk.socket_mode.aiohttp import SocketModeClient
|
||||
from slack_sdk.socket_mode.async_client import AsyncBaseSocketModeClient
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from slack_sdk.socket_mode.response import SocketModeResponse
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
@@ -47,51 +49,62 @@ class SlackWebhookClient:
|
||||
|
||||
@self.app.route(self.path, methods=["POST"])
|
||||
async def slack_events():
|
||||
"""处理 Slack 事件"""
|
||||
try:
|
||||
# 获取请求体和头部
|
||||
body = await request.get_data()
|
||||
event_data = json.loads(body.decode("utf-8"))
|
||||
|
||||
# Verify Slack request signature
|
||||
timestamp = request.headers.get("X-Slack-Request-Timestamp")
|
||||
signature = request.headers.get("X-Slack-Signature")
|
||||
if not timestamp or not signature:
|
||||
return Response("Missing headers", status=400)
|
||||
# Calculate the HMAC signature
|
||||
sig_basestring = f"v0:{timestamp}:{body.decode('utf-8')}"
|
||||
my_signature = (
|
||||
"v0="
|
||||
+ hmac.new(
|
||||
self.signing_secret.encode("utf-8"),
|
||||
sig_basestring.encode("utf-8"),
|
||||
hashlib.sha256,
|
||||
).hexdigest()
|
||||
)
|
||||
# Verify the signature
|
||||
if not hmac.compare_digest(my_signature, signature):
|
||||
logger.warning("Slack request signature verification failed")
|
||||
return Response("Invalid signature", status=400)
|
||||
logger.info(f"Received Slack event: {event_data}")
|
||||
|
||||
# 处理 URL 验证事件
|
||||
if event_data.get("type") == "url_verification":
|
||||
return {"challenge": event_data.get("challenge")}
|
||||
# 处理事件
|
||||
if self.event_handler and event_data.get("type") == "event_callback":
|
||||
await self.event_handler(event_data)
|
||||
|
||||
return Response("", status=200)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理 Slack 事件时出错: {e}")
|
||||
return Response("Internal Server Error", status=500)
|
||||
"""内部服务器的 POST 回调入口"""
|
||||
return await self.handle_callback(request)
|
||||
|
||||
@self.app.route("/health", methods=["GET"])
|
||||
async def health_check():
|
||||
"""健康检查端点"""
|
||||
return {"status": "ok", "service": "slack-webhook"}
|
||||
|
||||
async def handle_callback(self, req):
|
||||
"""处理 Slack 回调请求,可被统一 webhook 入口复用
|
||||
|
||||
Args:
|
||||
req: Quart 请求对象
|
||||
|
||||
Returns:
|
||||
Response 对象或字典
|
||||
"""
|
||||
try:
|
||||
# 获取请求体和头部
|
||||
body = cast(bytes, await req.get_data())
|
||||
event_data = json.loads(body.decode("utf-8"))
|
||||
|
||||
# Verify Slack request signature
|
||||
timestamp = req.headers.get("X-Slack-Request-Timestamp")
|
||||
signature = req.headers.get("X-Slack-Signature")
|
||||
if not timestamp or not signature:
|
||||
return Response("Missing headers", status=400)
|
||||
# Calculate the HMAC signature
|
||||
sig_basestring = f"v0:{timestamp}:{body.decode('utf-8')}"
|
||||
my_signature = (
|
||||
"v0="
|
||||
+ hmac.new(
|
||||
self.signing_secret.encode("utf-8"),
|
||||
sig_basestring.encode("utf-8"),
|
||||
hashlib.sha256,
|
||||
).hexdigest()
|
||||
)
|
||||
# Verify the signature
|
||||
if not hmac.compare_digest(my_signature, signature):
|
||||
logger.warning("Slack request signature verification failed")
|
||||
return Response("Invalid signature", status=400)
|
||||
logger.info(f"Received Slack event: {event_data}")
|
||||
|
||||
# 处理 URL 验证事件
|
||||
if event_data.get("type") == "url_verification":
|
||||
return {"challenge": event_data.get("challenge")}
|
||||
# 处理事件
|
||||
if self.event_handler and event_data.get("type") == "event_callback":
|
||||
await self.event_handler(event_data)
|
||||
|
||||
return Response("", status=200)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理 Slack 事件时出错: {e}")
|
||||
return Response("Internal Server Error", status=500)
|
||||
|
||||
async def start(self):
|
||||
"""启动 Webhook 服务器"""
|
||||
logger.info(
|
||||
@@ -128,9 +141,14 @@ class SlackSocketClient:
|
||||
self.event_handler = event_handler
|
||||
self.socket_client = None
|
||||
|
||||
async def _handle_events(self, _: SocketModeClient, req: SocketModeRequest):
|
||||
async def _handle_events(
|
||||
self, _: AsyncBaseSocketModeClient, req: SocketModeRequest
|
||||
):
|
||||
"""处理 Socket Mode 事件"""
|
||||
try:
|
||||
if self.socket_client is None:
|
||||
raise RuntimeError("Socket client is not initialized")
|
||||
|
||||
# 确认收到事件
|
||||
response = SocketModeResponse(envelope_id=req.envelope_id)
|
||||
await self.socket_client.send_socket_mode_response(response)
|
||||
|
||||
@@ -3,8 +3,7 @@ import base64
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Awaitable
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import aiohttp
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
@@ -21,6 +20,7 @@ from astrbot.api.platform import (
|
||||
PlatformMetadata,
|
||||
)
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.utils.webhook_utils import log_webhook_info
|
||||
|
||||
from ...register import register_platform_adapter
|
||||
from .client import SlackSocketClient, SlackWebhookClient
|
||||
@@ -39,9 +39,7 @@ class SlackAdapter(Platform):
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
|
||||
self.config = platform_config
|
||||
super().__init__(platform_config, event_queue)
|
||||
self.settings = platform_settings
|
||||
self.unique_session = platform_settings.get("unique_session", False)
|
||||
|
||||
@@ -49,6 +47,7 @@ class SlackAdapter(Platform):
|
||||
self.app_token = platform_config.get("app_token")
|
||||
self.signing_secret = platform_config.get("signing_secret")
|
||||
self.connection_mode = platform_config.get("slack_connection_mode", "socket")
|
||||
self.unified_webhook_mode = platform_config.get("unified_webhook_mode", False)
|
||||
self.webhook_host = platform_config.get("slack_webhook_host", "0.0.0.0")
|
||||
self.webhook_port = platform_config.get("slack_webhook_port", 3000)
|
||||
self.webhook_path = platform_config.get(
|
||||
@@ -68,7 +67,7 @@ class SlackAdapter(Platform):
|
||||
self.metadata = PlatformMetadata(
|
||||
name="slack",
|
||||
description="适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。",
|
||||
id=self.config.get("id"),
|
||||
id=cast(str, self.config.get("id")),
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
@@ -118,13 +117,13 @@ class SlackAdapter(Platform):
|
||||
logger.debug(f"[slack] RawMessage {event}")
|
||||
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = self.bot_self_id
|
||||
abm.self_id = cast(str, self.bot_self_id)
|
||||
|
||||
# 获取用户信息
|
||||
user_id = event.get("user", "")
|
||||
try:
|
||||
user_info = await self.web_client.users_info(user=user_id)
|
||||
user_data = user_info["user"]
|
||||
user_data = cast(dict, user_info["user"])
|
||||
user_name = user_data.get("real_name") or user_data.get("name", user_id)
|
||||
except Exception:
|
||||
user_name = user_id
|
||||
@@ -135,7 +134,7 @@ class SlackAdapter(Platform):
|
||||
channel_id = event.get("channel", "")
|
||||
try:
|
||||
channel_info = await self.web_client.conversations_info(channel=channel_id)
|
||||
is_im = channel_info["channel"]["is_im"]
|
||||
is_im = cast(dict, channel_info["channel"])["is_im"]
|
||||
|
||||
if is_im:
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
@@ -178,7 +177,7 @@ class SlackAdapter(Platform):
|
||||
for mention in mentions:
|
||||
try:
|
||||
mentioned_user = await self.web_client.users_info(user=mention)
|
||||
user_data = mentioned_user["user"]
|
||||
user_data = cast(dict, mentioned_user["user"])
|
||||
user_name = user_data.get("real_name") or user_data.get(
|
||||
"name",
|
||||
mention,
|
||||
@@ -329,7 +328,7 @@ class SlackAdapter(Platform):
|
||||
)
|
||||
raise Exception(f"下载文件失败: {resp.status}")
|
||||
|
||||
async def run(self) -> Awaitable[Any]:
|
||||
async def run(self) -> None:
|
||||
self.bot_self_id = await self.get_bot_user_id()
|
||||
logger.info(f"Slack auth test OK. Bot ID: {self.bot_self_id}")
|
||||
|
||||
@@ -361,10 +360,17 @@ class SlackAdapter(Platform):
|
||||
self._handle_webhook_event,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Slack 适配器 (Webhook Mode) 启动中,监听 {self.webhook_host}:{self.webhook_port}{self.webhook_path}...",
|
||||
)
|
||||
await self.webhook_client.start()
|
||||
# 如果启用统一 webhook 模式,则不启动独立服务器
|
||||
webhook_uuid = self.config.get("webhook_uuid")
|
||||
if self.unified_webhook_mode and webhook_uuid:
|
||||
log_webhook_info(f"{self.meta().id}(Slack)", webhook_uuid)
|
||||
# 保持运行状态,等待 shutdown
|
||||
await self.webhook_client.shutdown_event.wait()
|
||||
else:
|
||||
logger.info(
|
||||
f"Slack 适配器 (Webhook Mode) 启动中,监听 {self.webhook_host}:{self.webhook_port}{self.webhook_path}...",
|
||||
)
|
||||
await self.webhook_client.start()
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -391,12 +397,19 @@ class SlackAdapter(Platform):
|
||||
if abm:
|
||||
await self.handle_msg(abm)
|
||||
|
||||
async def webhook_callback(self, request: Any) -> Any:
|
||||
"""统一 Webhook 回调入口"""
|
||||
if self.connection_mode != "webhook" or not self.webhook_client:
|
||||
return {"error": "Slack adapter is not in webhook mode"}, 400
|
||||
|
||||
return await self.webhook_client.handle_callback(request)
|
||||
|
||||
async def terminate(self):
|
||||
if self.socket_client:
|
||||
await self.socket_client.stop()
|
||||
if self.webhook_client:
|
||||
await self.webhook_client.stop()
|
||||
logger.info("Slack 适配器已被优雅地关闭")
|
||||
logger.info("Slack 适配器已被关闭")
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return self.metadata
|
||||
@@ -414,3 +427,10 @@ class SlackAdapter(Platform):
|
||||
|
||||
def get_client(self):
|
||||
return self.web_client
|
||||
|
||||
def unified_webhook(self) -> bool:
|
||||
return bool(
|
||||
self.config.get("unified_webhook_mode", False)
|
||||
and self.config.get("slack_connection_mode", "") == "webhook"
|
||||
and self.config.get("webhook_uuid")
|
||||
)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import re
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncGenerator, Iterable
|
||||
from typing import cast
|
||||
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
|
||||
@@ -31,14 +32,14 @@ class SlackMessageEvent(AstrMessageEvent):
|
||||
async def _from_segment_to_slack_block(
|
||||
segment: BaseMessageComponent,
|
||||
web_client: AsyncWebClient,
|
||||
) -> dict:
|
||||
) -> dict | None:
|
||||
"""将消息段转换为 Slack 块格式"""
|
||||
if isinstance(segment, Plain):
|
||||
return {"type": "section", "text": {"type": "mrkdwn", "text": segment.text}}
|
||||
if isinstance(segment, Image):
|
||||
# upload file
|
||||
url = segment.url or segment.file
|
||||
if url.startswith("http"):
|
||||
if url and url.startswith("http"):
|
||||
return {
|
||||
"type": "image",
|
||||
"image_url": url,
|
||||
@@ -55,7 +56,7 @@ class SlackMessageEvent(AstrMessageEvent):
|
||||
"type": "section",
|
||||
"text": {"type": "mrkdwn", "text": "图片上传失败"},
|
||||
}
|
||||
image_url = response["files"][0]["url_private"]
|
||||
image_url = cast(list, response["files"])[0]["url_private"]
|
||||
logger.debug(f"Slack file upload response: {response}")
|
||||
return {
|
||||
"type": "image",
|
||||
@@ -77,7 +78,7 @@ class SlackMessageEvent(AstrMessageEvent):
|
||||
"type": "section",
|
||||
"text": {"type": "mrkdwn", "text": "文件上传失败"},
|
||||
}
|
||||
file_url = response["files"][0]["permalink"]
|
||||
file_url = cast(list, response["files"])[0]["permalink"]
|
||||
return {
|
||||
"type": "section",
|
||||
"text": {
|
||||
@@ -85,7 +86,6 @@ class SlackMessageEvent(AstrMessageEvent):
|
||||
"text": f"文件: <{file_url}|{segment.name or '文件'}>",
|
||||
},
|
||||
}
|
||||
return {"type": "section", "text": {"type": "mrkdwn", "text": str(segment)}}
|
||||
|
||||
@staticmethod
|
||||
async def _parse_slack_blocks(
|
||||
@@ -115,7 +115,8 @@ class SlackMessageEvent(AstrMessageEvent):
|
||||
segment,
|
||||
web_client,
|
||||
)
|
||||
blocks.append(block)
|
||||
if block:
|
||||
blocks.append(block)
|
||||
|
||||
# 如果最后还有文本内容
|
||||
if text_content.strip():
|
||||
@@ -225,10 +226,10 @@ class SlackMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
|
||||
members = []
|
||||
for member_id in members_response["members"]:
|
||||
for member_id in cast(Iterable, members_response["members"]):
|
||||
try:
|
||||
user_info = await self.web_client.users_info(user=member_id)
|
||||
user_data = user_info["user"]
|
||||
user_data = cast(dict, user_info["user"])
|
||||
members.append(
|
||||
MessageMember(
|
||||
user_id=member_id,
|
||||
@@ -240,7 +241,7 @@ class SlackMessageEvent(AstrMessageEvent):
|
||||
# 如果获取用户信息失败,使用默认信息
|
||||
members.append(MessageMember(user_id=member_id, nickname=member_id))
|
||||
|
||||
channel_data = channel_info["channel"]
|
||||
channel_data = cast(dict, channel_info["channel"])
|
||||
return Group(
|
||||
group_id=channel_id,
|
||||
group_name=channel_data.get("name", ""),
|
||||
|
||||
@@ -42,8 +42,7 @@ class TelegramPlatformAdapter(Platform):
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
self.config = platform_config
|
||||
super().__init__(platform_config, event_queue)
|
||||
self.settings = platform_settings
|
||||
self.client_self_id = uuid.uuid4().hex[:8]
|
||||
|
||||
@@ -381,7 +380,9 @@ class TelegramPlatformAdapter(Platform):
|
||||
f"Telegram document file_path is None, cannot save the file {file_name}.",
|
||||
)
|
||||
else:
|
||||
message.message.append(Comp.File(file=file_path, name=file_name))
|
||||
message.message.append(
|
||||
Comp.File(file=file_path, name=file_name, url=file_path)
|
||||
)
|
||||
|
||||
elif update.message.video:
|
||||
file = await update.message.video.get_file()
|
||||
@@ -423,6 +424,6 @@ class TelegramPlatformAdapter(Platform):
|
||||
if self.application.updater is not None:
|
||||
await self.application.updater.stop()
|
||||
|
||||
logger.info("Telegram 适配器已被优雅地关闭")
|
||||
logger.info("Telegram 适配器已被关闭")
|
||||
except Exception as e:
|
||||
logger.error(f"Telegram 适配器关闭时出错: {e}")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
from typing import Any, cast
|
||||
|
||||
import telegramify_markdown
|
||||
from telegram import ReactionTypeCustomEmoji, ReactionTypeEmoji
|
||||
@@ -17,8 +18,6 @@ from astrbot.api.message_components import (
|
||||
Reply,
|
||||
)
|
||||
from astrbot.api.platform import AstrBotMessage, MessageType, PlatformMetadata
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.io import download_file
|
||||
|
||||
|
||||
class TelegramPlatformEvent(AstrMessageEvent):
|
||||
@@ -97,7 +96,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
"chat_id": user_name,
|
||||
}
|
||||
if has_reply:
|
||||
payload["reply_to_message_id"] = reply_message_id
|
||||
payload["reply_to_message_id"] = str(reply_message_id)
|
||||
if message_thread_id:
|
||||
payload["message_thread_id"] = message_thread_id
|
||||
|
||||
@@ -110,33 +109,30 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
try:
|
||||
md_text = telegramify_markdown.markdownify(
|
||||
chunk,
|
||||
max_line_length=None,
|
||||
normalize_whitespace=False,
|
||||
)
|
||||
await client.send_message(
|
||||
text=md_text,
|
||||
parse_mode="MarkdownV2",
|
||||
**payload,
|
||||
**cast(Any, payload),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"MarkdownV2 send failed: {e}. Using plain text instead.",
|
||||
)
|
||||
await client.send_message(text=chunk, **payload)
|
||||
await client.send_message(text=chunk, **cast(Any, payload))
|
||||
elif isinstance(i, Image):
|
||||
image_path = await i.convert_to_file_path()
|
||||
await client.send_photo(photo=image_path, **payload)
|
||||
await client.send_photo(photo=image_path, **cast(Any, payload))
|
||||
elif isinstance(i, File):
|
||||
if i.file.startswith("https://"):
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, i.name)
|
||||
await download_file(i.file, path)
|
||||
i.file = path
|
||||
|
||||
await client.send_document(document=i.file, filename=i.name, **payload)
|
||||
path = await i.get_file()
|
||||
name = i.name or os.path.basename(path)
|
||||
await client.send_document(
|
||||
document=path, filename=name, **cast(Any, payload)
|
||||
)
|
||||
elif isinstance(i, Record):
|
||||
path = await i.convert_to_file_path()
|
||||
await client.send_voice(voice=path, **payload)
|
||||
await client.send_voice(voice=path, **cast(Any, payload))
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
if self.get_message_type() == MessageType.GROUP_MESSAGE:
|
||||
@@ -204,6 +200,15 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
if isinstance(chain, MessageChain):
|
||||
if chain.type == "break":
|
||||
# 分割符
|
||||
if message_id:
|
||||
try:
|
||||
await self.client.edit_message_text(
|
||||
text=delta,
|
||||
chat_id=payload["chat_id"],
|
||||
message_id=message_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"编辑消息失败(streaming-break): {e!s}")
|
||||
message_id = None # 重置消息 ID
|
||||
delta = "" # 重置 delta
|
||||
continue
|
||||
@@ -214,24 +219,23 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
delta += i.text
|
||||
elif isinstance(i, Image):
|
||||
image_path = await i.convert_to_file_path()
|
||||
await self.client.send_photo(photo=image_path, **payload)
|
||||
await self.client.send_photo(
|
||||
photo=image_path, **cast(Any, payload)
|
||||
)
|
||||
continue
|
||||
elif isinstance(i, File):
|
||||
if i.file.startswith("https://"):
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, i.name)
|
||||
await download_file(i.file, path)
|
||||
i.file = path
|
||||
path = await i.get_file()
|
||||
name = i.name or os.path.basename(path)
|
||||
|
||||
await self.client.send_document(
|
||||
document=i.file,
|
||||
filename=i.name,
|
||||
**payload,
|
||||
document=path,
|
||||
filename=name,
|
||||
**cast(Any, payload),
|
||||
)
|
||||
continue
|
||||
elif isinstance(i, Record):
|
||||
path = await i.convert_to_file_path()
|
||||
await self.client.send_voice(voice=path, **payload)
|
||||
await self.client.send_voice(voice=path, **cast(Any, payload))
|
||||
continue
|
||||
else:
|
||||
logger.warning(f"不支持的消息类型: {type(i)}")
|
||||
@@ -260,7 +264,9 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
else:
|
||||
# delta 长度一般不会大于 4096,因此这里直接发送
|
||||
try:
|
||||
msg = await self.client.send_message(text=delta, **payload)
|
||||
msg = await self.client.send_message(
|
||||
text=delta, **cast(Any, payload)
|
||||
)
|
||||
current_content = delta
|
||||
except Exception as e:
|
||||
logger.warning(f"发送消息失败(streaming): {e!s}")
|
||||
@@ -274,7 +280,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
try:
|
||||
markdown_text = telegramify_markdown.markdownify(
|
||||
delta,
|
||||
max_line_length=None,
|
||||
normalize_whitespace=False,
|
||||
)
|
||||
await self.client.edit_message_text(
|
||||
|
||||
@@ -2,11 +2,13 @@ import asyncio
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Callable
|
||||
from collections.abc import Callable, Coroutine
|
||||
from typing import Any
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.message.components import Image, Plain, Record
|
||||
from astrbot.core import db_helper
|
||||
from astrbot.core.db.po import PlatformMessageHistory
|
||||
from astrbot.core.message.components import File, Image, Plain, Record, Reply, Video
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.platform import (
|
||||
AstrBotMessage,
|
||||
@@ -74,9 +76,8 @@ class WebChatAdapter(Platform):
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
super().__init__(platform_config, event_queue)
|
||||
|
||||
self.config = platform_config
|
||||
self.settings = platform_settings
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
|
||||
@@ -96,6 +97,92 @@ class WebChatAdapter(Platform):
|
||||
await WebChatMessageEvent._send(message_chain, session.session_id)
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
async def _get_message_history(
|
||||
self, message_id: int
|
||||
) -> PlatformMessageHistory | None:
|
||||
return await db_helper.get_platform_message_history_by_id(message_id)
|
||||
|
||||
async def _parse_message_parts(
|
||||
self,
|
||||
message_parts: list,
|
||||
depth: int = 0,
|
||||
max_depth: int = 1,
|
||||
) -> tuple[list, list[str]]:
|
||||
"""解析消息段列表,返回消息组件列表和纯文本列表
|
||||
|
||||
Args:
|
||||
message_parts: 消息段列表
|
||||
depth: 当前递归深度
|
||||
max_depth: 最大递归深度(用于处理 reply)
|
||||
|
||||
Returns:
|
||||
tuple[list, list[str]]: (消息组件列表, 纯文本列表)
|
||||
"""
|
||||
components = []
|
||||
text_parts = []
|
||||
|
||||
for part in message_parts:
|
||||
part_type = part.get("type")
|
||||
if part_type == "plain":
|
||||
text = part.get("text", "")
|
||||
components.append(Plain(text))
|
||||
text_parts.append(text)
|
||||
elif part_type == "reply":
|
||||
message_id = part.get("message_id")
|
||||
reply_chain = []
|
||||
reply_message_str = ""
|
||||
sender_id = None
|
||||
sender_name = None
|
||||
|
||||
# recursively get the content of the referenced message
|
||||
if depth < max_depth and message_id:
|
||||
history = await self._get_message_history(message_id)
|
||||
if history and history.content:
|
||||
reply_parts = history.content.get("message", [])
|
||||
if isinstance(reply_parts, list):
|
||||
(
|
||||
reply_chain,
|
||||
reply_text_parts,
|
||||
) = await self._parse_message_parts(
|
||||
reply_parts,
|
||||
depth=depth + 1,
|
||||
max_depth=max_depth,
|
||||
)
|
||||
reply_message_str = "".join(reply_text_parts)
|
||||
sender_id = history.sender_id
|
||||
sender_name = history.sender_name
|
||||
|
||||
components.append(
|
||||
Reply(
|
||||
id=message_id,
|
||||
chain=reply_chain,
|
||||
message_str=reply_message_str,
|
||||
sender_id=sender_id,
|
||||
sender_nickname=sender_name,
|
||||
)
|
||||
)
|
||||
elif part_type == "image":
|
||||
path = part.get("path")
|
||||
if path:
|
||||
components.append(Image.fromFileSystem(path))
|
||||
elif part_type == "record":
|
||||
path = part.get("path")
|
||||
if path:
|
||||
components.append(Record.fromFileSystem(path))
|
||||
elif part_type == "file":
|
||||
path = part.get("path")
|
||||
if path:
|
||||
filename = part.get("filename") or (
|
||||
os.path.basename(path) if path else "file"
|
||||
)
|
||||
components.append(File(name=filename, file=path))
|
||||
elif part_type == "video":
|
||||
path = part.get("path")
|
||||
if path:
|
||||
components.append(Video.fromFileSystem(path))
|
||||
|
||||
return components, text_parts
|
||||
|
||||
async def convert_message(self, data: tuple) -> AstrBotMessage:
|
||||
username, cid, payload = data
|
||||
|
||||
@@ -108,40 +195,19 @@ class WebChatAdapter(Platform):
|
||||
abm.session_id = f"webchat!{username}!{cid}"
|
||||
|
||||
abm.message_id = str(uuid.uuid4())
|
||||
abm.message = []
|
||||
|
||||
if payload["message"]:
|
||||
abm.message.append(Plain(payload["message"]))
|
||||
if payload["image_url"]:
|
||||
if isinstance(payload["image_url"], list):
|
||||
for img in payload["image_url"]:
|
||||
abm.message.append(
|
||||
Image.fromFileSystem(os.path.join(self.imgs_dir, img)),
|
||||
)
|
||||
else:
|
||||
abm.message.append(
|
||||
Image.fromFileSystem(
|
||||
os.path.join(self.imgs_dir, payload["image_url"]),
|
||||
),
|
||||
)
|
||||
if payload["audio_url"]:
|
||||
if isinstance(payload["audio_url"], list):
|
||||
for audio in payload["audio_url"]:
|
||||
path = os.path.join(self.imgs_dir, audio)
|
||||
abm.message.append(Record(file=path, path=path))
|
||||
else:
|
||||
path = os.path.join(self.imgs_dir, payload["audio_url"])
|
||||
abm.message.append(Record(file=path, path=path))
|
||||
# 处理消息段列表
|
||||
message_parts = payload.get("message", [])
|
||||
abm.message, message_str_parts = await self._parse_message_parts(message_parts)
|
||||
|
||||
logger.debug(f"WebChatAdapter: {abm.message}")
|
||||
|
||||
message_str = payload["message"]
|
||||
abm.timestamp = int(time.time())
|
||||
abm.message_str = message_str
|
||||
abm.message_str = "".join(message_str_parts)
|
||||
abm.raw_message = data
|
||||
return abm
|
||||
|
||||
def run(self) -> Awaitable[Any]:
|
||||
def run(self) -> Coroutine[Any, Any, None]:
|
||||
async def callback(data: tuple):
|
||||
abm = await self.convert_message(data)
|
||||
await self.handle_msg(abm)
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import Image, Plain, Record
|
||||
from astrbot.api.message_components import File, Image, Json, Plain, Record
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
|
||||
from .webchat_queue_mgr import webchat_queue_mgr
|
||||
|
||||
@@ -19,7 +20,9 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
os.makedirs(imgs_dir, exist_ok=True)
|
||||
|
||||
@staticmethod
|
||||
async def _send(message: MessageChain, session_id: str, streaming: bool = False):
|
||||
async def _send(
|
||||
message: MessageChain | None, session_id: str, streaming: bool = False
|
||||
) -> str | None:
|
||||
cid = session_id.split("!")[-1]
|
||||
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
|
||||
if not message:
|
||||
@@ -30,7 +33,7 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
"streaming": False,
|
||||
}, # end means this request is finished
|
||||
)
|
||||
return ""
|
||||
return
|
||||
|
||||
data = ""
|
||||
for comp in message.chain:
|
||||
@@ -39,61 +42,62 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "plain",
|
||||
"cid": cid,
|
||||
"data": data,
|
||||
"streaming": streaming,
|
||||
"chain_type": message.type,
|
||||
},
|
||||
)
|
||||
elif isinstance(comp, Json):
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "plain",
|
||||
"data": json.dumps(comp.data, ensure_ascii=False),
|
||||
"streaming": streaming,
|
||||
"chain_type": message.type,
|
||||
},
|
||||
)
|
||||
elif isinstance(comp, Image):
|
||||
# save image to local
|
||||
filename = str(uuid.uuid4()) + ".jpg"
|
||||
filename = f"{str(uuid.uuid4())}.jpg"
|
||||
path = os.path.join(imgs_dir, filename)
|
||||
if comp.file and comp.file.startswith("file:///"):
|
||||
ph = comp.file[8:]
|
||||
with open(path, "wb") as f:
|
||||
with open(ph, "rb") as f2:
|
||||
f.write(f2.read())
|
||||
elif comp.file.startswith("base64://"):
|
||||
base64_str = comp.file[9:]
|
||||
image_data = base64.b64decode(base64_str)
|
||||
with open(path, "wb") as f:
|
||||
f.write(image_data)
|
||||
elif comp.file and comp.file.startswith("http"):
|
||||
await download_image_by_url(comp.file, path=path)
|
||||
else:
|
||||
with open(path, "wb") as f:
|
||||
with open(comp.file, "rb") as f2:
|
||||
f.write(f2.read())
|
||||
image_base64 = await comp.convert_to_base64()
|
||||
with open(path, "wb") as f:
|
||||
f.write(base64.b64decode(image_base64))
|
||||
data = f"[IMAGE]{filename}"
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "image",
|
||||
"cid": cid,
|
||||
"data": data,
|
||||
"streaming": streaming,
|
||||
},
|
||||
)
|
||||
elif isinstance(comp, Record):
|
||||
# save record to local
|
||||
filename = str(uuid.uuid4()) + ".wav"
|
||||
filename = f"{str(uuid.uuid4())}.wav"
|
||||
path = os.path.join(imgs_dir, filename)
|
||||
if comp.file and comp.file.startswith("file:///"):
|
||||
ph = comp.file[8:]
|
||||
with open(path, "wb") as f:
|
||||
with open(ph, "rb") as f2:
|
||||
f.write(f2.read())
|
||||
elif comp.file and comp.file.startswith("http"):
|
||||
await download_image_by_url(comp.file, path=path)
|
||||
else:
|
||||
with open(path, "wb") as f:
|
||||
with open(comp.file, "rb") as f2:
|
||||
f.write(f2.read())
|
||||
record_base64 = await comp.convert_to_base64()
|
||||
with open(path, "wb") as f:
|
||||
f.write(base64.b64decode(record_base64))
|
||||
data = f"[RECORD]{filename}"
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "record",
|
||||
"cid": cid,
|
||||
"data": data,
|
||||
"streaming": streaming,
|
||||
},
|
||||
)
|
||||
elif isinstance(comp, File):
|
||||
# save file to local
|
||||
file_path = await comp.get_file()
|
||||
original_name = comp.name or os.path.basename(file_path)
|
||||
ext = os.path.splitext(original_name)[1] or ""
|
||||
filename = f"{uuid.uuid4()!s}{ext}"
|
||||
dest_path = os.path.join(imgs_dir, filename)
|
||||
shutil.copy2(file_path, dest_path)
|
||||
data = f"[FILE]{filename}|{original_name}"
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "file",
|
||||
"data": data,
|
||||
"streaming": streaming,
|
||||
},
|
||||
@@ -103,9 +107,9 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
|
||||
return data
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
async def send(self, message: MessageChain | None):
|
||||
await WebChatMessageEvent._send(message, session_id=self.session_id)
|
||||
await super().send(message)
|
||||
await super().send(MessageChain([]))
|
||||
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
final_data = ""
|
||||
@@ -113,24 +117,25 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
cid = self.session_id.split("!")[-1]
|
||||
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
|
||||
async for chain in generator:
|
||||
if chain.type == "break" and final_data:
|
||||
# 分割符
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "break", # break means a segment end
|
||||
"data": final_data,
|
||||
"streaming": True,
|
||||
"cid": cid,
|
||||
},
|
||||
)
|
||||
final_data = ""
|
||||
continue
|
||||
# if chain.type == "break" and final_data:
|
||||
# # 分割符
|
||||
# await web_chat_back_queue.put(
|
||||
# {
|
||||
# "type": "break", # break means a segment end
|
||||
# "data": final_data,
|
||||
# "streaming": True,
|
||||
# },
|
||||
# )
|
||||
# final_data = ""
|
||||
# continue
|
||||
|
||||
r = await WebChatMessageEvent._send(
|
||||
chain,
|
||||
session_id=self.session_id,
|
||||
streaming=True,
|
||||
)
|
||||
if not r:
|
||||
continue
|
||||
if chain.type == "reasoning":
|
||||
reasoning_content += chain.get_plain_text()
|
||||
else:
|
||||
@@ -142,7 +147,6 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
"data": final_data,
|
||||
"reasoning": reasoning_content,
|
||||
"streaming": True,
|
||||
"cid": cid,
|
||||
},
|
||||
)
|
||||
await super().send_streaming(generator, use_fallback)
|
||||
|
||||
@@ -4,6 +4,7 @@ import json
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
from typing import cast
|
||||
|
||||
import aiohttp
|
||||
import anyio
|
||||
@@ -42,10 +43,9 @@ class WeChatPadProAdapter(Platform):
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
super().__init__(platform_config, event_queue)
|
||||
self._shutdown_event = None
|
||||
self.wxnewpass = None
|
||||
self.config = platform_config
|
||||
self.settings = platform_settings
|
||||
self.unique_session = platform_settings.get("unique_session", False)
|
||||
|
||||
@@ -70,7 +70,7 @@ class WeChatPadProAdapter(Platform):
|
||||
)
|
||||
self.base_url = f"http://{self.host}:{self.port}"
|
||||
self.auth_key = None # 用于保存生成的授权码
|
||||
self.wxid = None # 用于保存登录成功后的 wxid
|
||||
self.wxid: str | None = None # 用于保存登录成功后的 wxid
|
||||
self.credentials_file = os.path.join(
|
||||
get_astrbot_data_path(),
|
||||
"wechatpadpro_credentials.json",
|
||||
@@ -399,7 +399,7 @@ class WeChatPadProAdapter(Platform):
|
||||
)
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def handle_websocket_message(self, message: str):
|
||||
async def handle_websocket_message(self, message: str | bytes):
|
||||
"""处理从 WebSocket 接收到的消息。"""
|
||||
logger.debug(f"收到 WebSocket 消息: {message}")
|
||||
try:
|
||||
@@ -431,10 +431,13 @@ class WeChatPadProAdapter(Platform):
|
||||
|
||||
async def convert_message(self, raw_message: dict) -> AstrBotMessage | None:
|
||||
"""将 WeChatPadPro 原始消息转换为 AstrBotMessage。"""
|
||||
if self.wxid is None:
|
||||
logger.error("WeChatPadPro 适配器未登录或未获取到 wxid,无法处理消息。")
|
||||
return None
|
||||
abm = AstrBotMessage()
|
||||
abm.raw_message = raw_message
|
||||
abm.message_id = str(raw_message.get("msg_id"))
|
||||
abm.timestamp = raw_message.get("create_time")
|
||||
abm.timestamp = cast(int, raw_message.get("create_time"))
|
||||
abm.self_id = self.wxid
|
||||
|
||||
if int(time.time()) - abm.timestamp > 180:
|
||||
@@ -447,7 +450,7 @@ class WeChatPadProAdapter(Platform):
|
||||
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
|
||||
content = raw_message.get("content", {}).get("str", "")
|
||||
push_content = raw_message.get("push_content", "")
|
||||
msg_type = raw_message.get("msg_type")
|
||||
msg_type = cast(int, raw_message.get("msg_type"))
|
||||
|
||||
abm.message_str = ""
|
||||
abm.message = []
|
||||
@@ -575,7 +578,7 @@ class WeChatPadProAdapter(Platform):
|
||||
from_user_name: str,
|
||||
to_user_name: str,
|
||||
msg_id: int,
|
||||
):
|
||||
) -> dict | None:
|
||||
"""下载原始图片。"""
|
||||
url = f"{self.base_url}/message/GetMsgBigImg"
|
||||
params = {"key": self.auth_key}
|
||||
@@ -726,12 +729,15 @@ class WeChatPadProAdapter(Platform):
|
||||
# 图片消息
|
||||
from_user_name = raw_message.get("from_user_name", {}).get("str", "")
|
||||
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
|
||||
msg_id = raw_message.get("msg_id")
|
||||
msg_id = cast(int, raw_message.get("msg_id"))
|
||||
image_resp = await self._download_raw_image(
|
||||
from_user_name,
|
||||
to_user_name,
|
||||
msg_id,
|
||||
)
|
||||
if image_resp is None:
|
||||
logger.error(f"下载图片失败: msg_id={msg_id}")
|
||||
return
|
||||
image_bs64_data = (
|
||||
image_resp.get("Data", {}).get("Data", {}).get("Buffer", None)
|
||||
)
|
||||
@@ -772,6 +778,9 @@ class WeChatPadProAdapter(Platform):
|
||||
bufid = 0
|
||||
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
|
||||
new_msg_id = raw_message.get("new_msg_id")
|
||||
if new_msg_id is None:
|
||||
logger.error("语音消息缺少 new_msg_id")
|
||||
return
|
||||
data_parser = GeweDataParser(
|
||||
content=content,
|
||||
is_private_chat=(abm.type != MessageType.GROUP_MESSAGE),
|
||||
@@ -779,6 +788,9 @@ class WeChatPadProAdapter(Platform):
|
||||
)
|
||||
|
||||
voicemsg = data_parser._format_to_xml().find("voicemsg")
|
||||
if voicemsg is None:
|
||||
logger.error("无法从 XML 解析 voicemsg 节点")
|
||||
return
|
||||
bufid = voicemsg.get("bufid") or "0"
|
||||
length = int(voicemsg.get("length") or 0)
|
||||
voice_resp = await self.download_voice(
|
||||
@@ -787,6 +799,9 @@ class WeChatPadProAdapter(Platform):
|
||||
bufid=bufid,
|
||||
length=length,
|
||||
)
|
||||
if voice_resp is None:
|
||||
logger.error(f"下载语音失败: new_msg_id={new_msg_id}")
|
||||
return
|
||||
voice_bs64_data = voice_resp.get("Data", {}).get("Base64", None)
|
||||
if voice_bs64_data:
|
||||
voice_bs64_data = base64.b64decode(voice_bs64_data)
|
||||
@@ -828,7 +843,8 @@ class WeChatPadProAdapter(Platform):
|
||||
try:
|
||||
if self.ws_handle_task:
|
||||
self.ws_handle_task.cancel()
|
||||
self._shutdown_event.set()
|
||||
if self._shutdown_event is not None:
|
||||
self._shutdown_event.set()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -895,8 +911,8 @@ class WeChatPadProAdapter(Platform):
|
||||
|
||||
async def get_contact_details_list(
|
||||
self,
|
||||
room_wx_id_list: list[str] = None,
|
||||
user_names: list[str] = None,
|
||||
room_wx_id_list: list[str] | None = None,
|
||||
user_names: list[str] | None = None,
|
||||
) -> dict | None:
|
||||
"""获取联系人详情列表。"""
|
||||
if room_wx_id_list is None:
|
||||
|
||||
@@ -2,6 +2,8 @@ import asyncio
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, cast
|
||||
|
||||
import quart
|
||||
from requests import Response
|
||||
@@ -24,6 +26,7 @@ from astrbot.api.platform import (
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.webhook_utils import log_webhook_info
|
||||
|
||||
from .wecom_event import WecomPlatformEvent
|
||||
from .wecom_kf import WeChatKF
|
||||
@@ -38,7 +41,7 @@ else:
|
||||
class WecomServer:
|
||||
def __init__(self, event_queue: asyncio.Queue, config: dict):
|
||||
self.server = quart.Quart(__name__)
|
||||
self.port = int(config.get("port"))
|
||||
self.port = int(cast(str, config.get("port")))
|
||||
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
|
||||
self.server.add_url_rule(
|
||||
"/callback/command",
|
||||
@@ -58,12 +61,24 @@ class WecomServer:
|
||||
config["corpid"].strip(),
|
||||
)
|
||||
|
||||
self.callback = None
|
||||
self.callback: Callable[[BaseMessage], Awaitable[None]] | None = None
|
||||
self.shutdown_event = asyncio.Event()
|
||||
|
||||
async def verify(self):
|
||||
logger.info(f"验证请求有效性: {quart.request.args}")
|
||||
args = quart.request.args
|
||||
"""内部服务器的 GET 验证入口"""
|
||||
return await self.handle_verify(quart.request)
|
||||
|
||||
async def handle_verify(self, request) -> str:
|
||||
"""处理验证请求,可被统一 webhook 入口复用
|
||||
|
||||
Args:
|
||||
request: Quart 请求对象
|
||||
|
||||
Returns:
|
||||
验证响应
|
||||
"""
|
||||
logger.info(f"验证请求有效性: {request.args}")
|
||||
args = request.args
|
||||
try:
|
||||
echo_str = self.crypto.check_signature(
|
||||
args.get("msg_signature"),
|
||||
@@ -78,17 +93,29 @@ class WecomServer:
|
||||
raise
|
||||
|
||||
async def callback_command(self):
|
||||
data = await quart.request.get_data()
|
||||
msg_signature = quart.request.args.get("msg_signature")
|
||||
timestamp = quart.request.args.get("timestamp")
|
||||
nonce = quart.request.args.get("nonce")
|
||||
"""内部服务器的 POST 回调入口"""
|
||||
return await self.handle_callback(quart.request)
|
||||
|
||||
async def handle_callback(self, request) -> str:
|
||||
"""处理回调请求,可被统一 webhook 入口复用
|
||||
|
||||
Args:
|
||||
request: Quart 请求对象
|
||||
|
||||
Returns:
|
||||
响应内容
|
||||
"""
|
||||
data = await request.get_data()
|
||||
msg_signature = request.args.get("msg_signature")
|
||||
timestamp = request.args.get("timestamp")
|
||||
nonce = request.args.get("nonce")
|
||||
try:
|
||||
xml = self.crypto.decrypt_message(data, msg_signature, timestamp, nonce)
|
||||
except InvalidSignatureException:
|
||||
logger.error("解密失败,签名异常,请检查配置。")
|
||||
raise
|
||||
else:
|
||||
msg = parse_message(xml)
|
||||
msg = cast(BaseMessage, parse_message(xml))
|
||||
logger.info(f"解析成功: {msg}")
|
||||
|
||||
if self.callback:
|
||||
@@ -118,14 +145,14 @@ class WecomPlatformAdapter(Platform):
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
self.config = platform_config
|
||||
super().__init__(platform_config, event_queue)
|
||||
self.settingss = platform_settings
|
||||
self.client_self_id = uuid.uuid4().hex[:8]
|
||||
self.api_base_url = platform_config.get(
|
||||
"api_base_url",
|
||||
"https://qyapi.weixin.qq.com/cgi-bin/",
|
||||
)
|
||||
self.unified_webhook_mode = platform_config.get("unified_webhook_mode", False)
|
||||
|
||||
if not self.api_base_url:
|
||||
self.api_base_url = "https://qyapi.weixin.qq.com/cgi-bin/"
|
||||
@@ -150,10 +177,10 @@ class WecomPlatformAdapter(Platform):
|
||||
# inject
|
||||
self.wechat_kf_api = WeChatKF(client=self.client)
|
||||
self.wechat_kf_message_api = WeChatKFMessage(self.client)
|
||||
self.client.kf = self.wechat_kf_api
|
||||
self.client.kf_message = self.wechat_kf_message_api
|
||||
self.client.__setattr__("kf", self.wechat_kf_api)
|
||||
self.client.__setattr__("kf_message", self.wechat_kf_message_api)
|
||||
|
||||
self.client.API_BASE_URL = self.api_base_url
|
||||
self.client.__setattr__("API_BASE_URL", self.api_base_url)
|
||||
|
||||
async def callback(msg: BaseMessage):
|
||||
if msg.type == "unknown" and msg._data["Event"] == "kf_msg_or_event":
|
||||
@@ -232,41 +259,53 @@ class WecomPlatformAdapter(Platform):
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
await self.server.start_polling()
|
||||
|
||||
# 如果启用统一 webhook 模式,则不启动独立服务器
|
||||
webhook_uuid = self.config.get("webhook_uuid")
|
||||
if self.unified_webhook_mode and webhook_uuid:
|
||||
log_webhook_info(f"{self.meta().id}(企业微信)", webhook_uuid)
|
||||
# 保持运行状态,等待 shutdown
|
||||
await self.server.shutdown_event.wait()
|
||||
else:
|
||||
await self.server.start_polling()
|
||||
|
||||
async def webhook_callback(self, request: Any) -> Any:
|
||||
"""统一 Webhook 回调入口"""
|
||||
# 根据请求方法分发到不同的处理函数
|
||||
if request.method == "GET":
|
||||
return await self.server.handle_verify(request)
|
||||
else:
|
||||
return await self.server.handle_callback(request)
|
||||
|
||||
async def convert_message(self, msg: BaseMessage) -> AstrBotMessage | None:
|
||||
abm = AstrBotMessage()
|
||||
if msg.type == "text":
|
||||
assert isinstance(msg, TextMessage)
|
||||
if isinstance(msg, TextMessage):
|
||||
abm.message_str = msg.content
|
||||
abm.self_id = str(msg.agent)
|
||||
abm.message = [Plain(msg.content)]
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
abm.sender = MessageMember(
|
||||
msg.source,
|
||||
msg.source,
|
||||
cast(str, msg.source),
|
||||
cast(str, msg.source),
|
||||
)
|
||||
abm.message_id = msg.id
|
||||
abm.timestamp = msg.time
|
||||
abm.message_id = str(msg.id)
|
||||
abm.timestamp = int(cast(int | str, msg.time))
|
||||
abm.session_id = abm.sender.user_id
|
||||
abm.raw_message = msg
|
||||
elif msg.type == "image":
|
||||
assert isinstance(msg, ImageMessage)
|
||||
elif isinstance(msg, ImageMessage):
|
||||
abm.message_str = "[图片]"
|
||||
abm.self_id = str(msg.agent)
|
||||
abm.message = [Image(file=msg.image, url=msg.image)]
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
abm.sender = MessageMember(
|
||||
msg.source,
|
||||
msg.source,
|
||||
cast(str, msg.source),
|
||||
cast(str, msg.source),
|
||||
)
|
||||
abm.message_id = msg.id
|
||||
abm.timestamp = msg.time
|
||||
abm.message_id = str(msg.id)
|
||||
abm.timestamp = int(cast(int | str, msg.time))
|
||||
abm.session_id = abm.sender.user_id
|
||||
abm.raw_message = msg
|
||||
elif msg.type == "voice":
|
||||
assert isinstance(msg, VoiceMessage)
|
||||
|
||||
elif isinstance(msg, VoiceMessage):
|
||||
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
self.client.media.download,
|
||||
@@ -293,11 +332,11 @@ class WecomPlatformAdapter(Platform):
|
||||
abm.message = [Record(file=path_wav, url=path_wav)]
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
abm.sender = MessageMember(
|
||||
msg.source,
|
||||
msg.source,
|
||||
cast(str, msg.source),
|
||||
cast(str, msg.source),
|
||||
)
|
||||
abm.message_id = msg.id
|
||||
abm.timestamp = msg.time
|
||||
abm.message_id = str(msg.id)
|
||||
abm.timestamp = int(cast(int | str, msg.time))
|
||||
abm.session_id = abm.sender.user_id
|
||||
abm.raw_message = msg
|
||||
else:
|
||||
@@ -309,7 +348,7 @@ class WecomPlatformAdapter(Platform):
|
||||
|
||||
async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None:
|
||||
msgtype = msg.get("msgtype")
|
||||
external_userid = msg.get("external_userid")
|
||||
external_userid = cast(str, msg.get("external_userid"))
|
||||
abm = AstrBotMessage()
|
||||
abm.raw_message = msg
|
||||
abm.raw_message["_wechat_kf_flag"] = None # 方便处理
|
||||
@@ -383,4 +422,4 @@ class WecomPlatformAdapter(Platform):
|
||||
await self.server.server.shutdown()
|
||||
except Exception as _:
|
||||
pass
|
||||
logger.info("企业微信 适配器已被优雅地关闭")
|
||||
logger.info("企业微信 适配器已被关闭")
|
||||
|
||||
@@ -16,7 +16,7 @@ try:
|
||||
import pydub
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"检测到 pydub 库未安装,企业微信将无法语音收发。如需使用语音,请前往管理面板 -> 控制台 -> 安装 Pip 库安装 pydub。",
|
||||
"检测到 pydub 库未安装,企业微信将无法语音收发。如需使用语音,请前往管理面板 -> 平台日志 -> 安装 Pip 库安装 pydub。",
|
||||
)
|
||||
|
||||
|
||||
@@ -93,10 +93,10 @@ class WecomPlatformEvent(AstrMessageEvent):
|
||||
if is_wechat_kf:
|
||||
# 微信客服
|
||||
kf_message_api = getattr(self.client, "kf_message", None)
|
||||
if not kf_message_api:
|
||||
if not isinstance(kf_message_api, WeChatKFMessage):
|
||||
logger.warning("未找到微信客服发送消息方法。")
|
||||
return
|
||||
assert isinstance(kf_message_api, WeChatKFMessage)
|
||||
|
||||
user_id = self.get_sender_id()
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
|
||||
@@ -22,6 +22,7 @@ from astrbot.api.platform import (
|
||||
PlatformMetadata,
|
||||
)
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.utils.webhook_utils import log_webhook_info
|
||||
|
||||
from ...register import register_platform_adapter
|
||||
from .wecomai_api import (
|
||||
@@ -103,9 +104,7 @@ class WecomAIBotAdapter(Platform):
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
|
||||
self.config = platform_config
|
||||
super().__init__(platform_config, event_queue)
|
||||
self.settings = platform_settings
|
||||
|
||||
# 初始化配置参数
|
||||
@@ -122,6 +121,7 @@ class WecomAIBotAdapter(Platform):
|
||||
"wecomaibot_friend_message_welcome_text",
|
||||
"",
|
||||
)
|
||||
self.unified_webhook_mode = self.config.get("unified_webhook_mode", False)
|
||||
|
||||
# 平台元数据
|
||||
self.metadata = PlatformMetadata(
|
||||
@@ -425,17 +425,34 @@ class WecomAIBotAdapter(Platform):
|
||||
|
||||
def run(self) -> Awaitable[Any]:
|
||||
"""运行适配器,同时启动HTTP服务器和队列监听器"""
|
||||
logger.info("启动企业微信智能机器人适配器,监听 %s:%d", self.host, self.port)
|
||||
|
||||
async def run_both():
|
||||
# 同时运行HTTP服务器和队列监听器
|
||||
await asyncio.gather(
|
||||
self.server.start_server(),
|
||||
self.queue_listener.run(),
|
||||
)
|
||||
# 如果启用统一 webhook 模式,则不启动独立服务器
|
||||
webhook_uuid = self.config.get("webhook_uuid")
|
||||
if self.unified_webhook_mode and webhook_uuid:
|
||||
log_webhook_info(f"{self.meta().id}(企业微信智能机器人)", webhook_uuid)
|
||||
# 只运行队列监听器
|
||||
await self.queue_listener.run()
|
||||
else:
|
||||
logger.info(
|
||||
"启动企业微信智能机器人适配器,监听 %s:%d", self.host, self.port
|
||||
)
|
||||
# 同时运行HTTP服务器和队列监听器
|
||||
await asyncio.gather(
|
||||
self.server.start_server(),
|
||||
self.queue_listener.run(),
|
||||
)
|
||||
|
||||
return run_both()
|
||||
|
||||
async def webhook_callback(self, request: Any) -> Any:
|
||||
"""统一 Webhook 回调入口"""
|
||||
# 根据请求方法分发到不同的处理函数
|
||||
if request.method == "GET":
|
||||
return await self.server.handle_verify(request)
|
||||
else:
|
||||
return await self.server.handle_callback(request)
|
||||
|
||||
async def terminate(self):
|
||||
"""终止适配器"""
|
||||
logger.info("企业微信智能机器人适配器正在关闭...")
|
||||
|
||||
@@ -39,7 +39,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
|
||||
|
||||
@staticmethod
|
||||
async def _send(
|
||||
message_chain: MessageChain,
|
||||
message_chain: MessageChain | None,
|
||||
stream_id: str,
|
||||
queue_mgr: WecomAIQueueMgr,
|
||||
streaming: bool = False,
|
||||
@@ -90,7 +90,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
|
||||
|
||||
return data
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
async def send(self, message: MessageChain | None):
|
||||
"""发送消息"""
|
||||
raw = self.message_obj.raw_message
|
||||
assert isinstance(raw, dict), (
|
||||
@@ -98,7 +98,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
stream_id = raw.get("stream_id", self.session_id)
|
||||
await WecomAIBotMessageEvent._send(message, stream_id, self.queue_mgr)
|
||||
await super().send(message)
|
||||
await super().send(MessageChain([]))
|
||||
|
||||
async def send_streaming(self, generator, use_fallback=False):
|
||||
"""流式发送消息,参考webchat的send_streaming设计"""
|
||||
|
||||
@@ -59,8 +59,19 @@ class WecomAIBotServer:
|
||||
)
|
||||
|
||||
async def verify_url(self):
|
||||
"""验证回调 URL"""
|
||||
args = quart.request.args
|
||||
"""内部服务器的 GET 验证入口"""
|
||||
return await self.handle_verify(quart.request)
|
||||
|
||||
async def handle_verify(self, request):
|
||||
"""处理 URL 验证请求,可被统一 webhook 入口复用
|
||||
|
||||
Args:
|
||||
request: Quart 请求对象
|
||||
|
||||
Returns:
|
||||
验证响应元组 (content, status_code, headers)
|
||||
"""
|
||||
args = request.args
|
||||
msg_signature = args.get("msg_signature")
|
||||
timestamp = args.get("timestamp")
|
||||
nonce = args.get("nonce")
|
||||
@@ -81,8 +92,19 @@ class WecomAIBotServer:
|
||||
return result, 200, {"Content-Type": "text/plain"}
|
||||
|
||||
async def handle_message(self):
|
||||
"""处理消息回调"""
|
||||
args = quart.request.args
|
||||
"""内部服务器的 POST 消息回调入口"""
|
||||
return await self.handle_callback(quart.request)
|
||||
|
||||
async def handle_callback(self, request):
|
||||
"""处理消息回调,可被统一 webhook 入口复用
|
||||
|
||||
Args:
|
||||
request: Quart 请求对象
|
||||
|
||||
Returns:
|
||||
响应元组 (content, status_code, headers)
|
||||
"""
|
||||
args = request.args
|
||||
msg_signature = args.get("msg_signature")
|
||||
timestamp = args.get("timestamp")
|
||||
nonce = args.get("nonce")
|
||||
@@ -102,7 +124,7 @@ class WecomAIBotServer:
|
||||
|
||||
try:
|
||||
# 获取请求体
|
||||
post_data = await quart.request.get_data()
|
||||
post_data = await request.get_data()
|
||||
|
||||
# 确保 post_data 是 bytes 类型
|
||||
if isinstance(post_data, str):
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import asyncio
|
||||
import sys
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, cast
|
||||
|
||||
import quart
|
||||
from requests import Response
|
||||
@@ -22,6 +24,7 @@ from astrbot.api.platform import (
|
||||
)
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.utils.webhook_utils import log_webhook_info
|
||||
|
||||
from .weixin_offacc_event import WeixinOfficialAccountPlatformEvent
|
||||
|
||||
@@ -31,10 +34,10 @@ else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class WecomServer:
|
||||
class WeixinOfficialAccountServer:
|
||||
def __init__(self, event_queue: asyncio.Queue, config: dict):
|
||||
self.server = quart.Quart(__name__)
|
||||
self.port = int(config.get("port"))
|
||||
self.port = int(cast(int | str, config.get("port")))
|
||||
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
|
||||
self.token = config.get("token")
|
||||
self.encoding_aes_key = config.get("encoding_aes_key")
|
||||
@@ -53,13 +56,25 @@ class WecomServer:
|
||||
|
||||
self.event_queue = event_queue
|
||||
|
||||
self.callback = None
|
||||
self.callback: Callable[[BaseMessage], Awaitable[None]] | None = None
|
||||
self.shutdown_event = asyncio.Event()
|
||||
|
||||
async def verify(self):
|
||||
logger.info(f"验证请求有效性: {quart.request.args}")
|
||||
"""内部服务器的 GET 验证入口"""
|
||||
return await self.handle_verify(quart.request)
|
||||
|
||||
args = quart.request.args
|
||||
async def handle_verify(self, request) -> str:
|
||||
"""处理验证请求,可被统一 webhook 入口复用
|
||||
|
||||
Args:
|
||||
request: Quart 请求对象
|
||||
|
||||
Returns:
|
||||
验证响应
|
||||
"""
|
||||
logger.info(f"验证请求有效性: {request.args}")
|
||||
|
||||
args = request.args
|
||||
if not args.get("signature", None):
|
||||
logger.error("未知的响应,请检查回调地址是否填写正确。")
|
||||
return "err"
|
||||
@@ -77,10 +92,22 @@ class WecomServer:
|
||||
return "err"
|
||||
|
||||
async def callback_command(self):
|
||||
data = await quart.request.get_data()
|
||||
msg_signature = quart.request.args.get("msg_signature")
|
||||
timestamp = quart.request.args.get("timestamp")
|
||||
nonce = quart.request.args.get("nonce")
|
||||
"""内部服务器的 POST 回调入口"""
|
||||
return await self.handle_callback(quart.request)
|
||||
|
||||
async def handle_callback(self, request) -> str:
|
||||
"""处理回调请求,可被统一 webhook 入口复用
|
||||
|
||||
Args:
|
||||
request: Quart 请求对象
|
||||
|
||||
Returns:
|
||||
响应内容
|
||||
"""
|
||||
data = await request.get_data()
|
||||
msg_signature = request.args.get("msg_signature")
|
||||
timestamp = request.args.get("timestamp")
|
||||
nonce = request.args.get("nonce")
|
||||
try:
|
||||
xml = self.crypto.decrypt_message(data, msg_signature, timestamp, nonce)
|
||||
except InvalidSignatureException:
|
||||
@@ -88,6 +115,9 @@ class WecomServer:
|
||||
raise
|
||||
else:
|
||||
msg = parse_message(xml)
|
||||
if not msg:
|
||||
logger.error("解析失败。msg为None。")
|
||||
raise
|
||||
logger.info(f"解析成功: {msg}")
|
||||
|
||||
if self.callback:
|
||||
@@ -123,8 +153,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
self.config = platform_config
|
||||
super().__init__(platform_config, event_queue)
|
||||
self.settingss = platform_settings
|
||||
self.client_self_id = uuid.uuid4().hex[:8]
|
||||
self.api_base_url = platform_config.get(
|
||||
@@ -132,6 +161,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
"https://api.weixin.qq.com/cgi-bin/",
|
||||
)
|
||||
self.active_send_mode = self.config.get("active_send_mode", False)
|
||||
self.unified_webhook_mode = platform_config.get("unified_webhook_mode", False)
|
||||
|
||||
if not self.api_base_url:
|
||||
self.api_base_url = "https://api.weixin.qq.com/cgi-bin/"
|
||||
@@ -143,14 +173,14 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
if not self.api_base_url.endswith("/"):
|
||||
self.api_base_url += "/"
|
||||
|
||||
self.server = WecomServer(self._event_queue, self.config)
|
||||
self.server = WeixinOfficialAccountServer(self._event_queue, self.config)
|
||||
|
||||
self.client = WeChatClient(
|
||||
self.config["appid"].strip(),
|
||||
self.config["secret"].strip(),
|
||||
)
|
||||
|
||||
self.client.API_BASE_URL = self.api_base_url
|
||||
self.client.__setattr__("API_BASE_URL", self.api_base_url)
|
||||
|
||||
# 微信公众号必须 5 秒内进行回复,否则会重试 3 次,我们需要对其进行消息排重
|
||||
# msgid -> Future
|
||||
@@ -162,11 +192,11 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
await self.convert_message(msg, None)
|
||||
else:
|
||||
if msg.id in self.wexin_event_workers:
|
||||
future = self.wexin_event_workers[msg.id]
|
||||
future = self.wexin_event_workers[str(cast(str | int, msg.id))]
|
||||
logger.debug(f"duplicate message id checked: {msg.id}")
|
||||
else:
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self.wexin_event_workers[msg.id] = future
|
||||
self.wexin_event_workers[str(cast(str | int, msg.id))] = future
|
||||
await self.convert_message(msg, future)
|
||||
# I love shield so much!
|
||||
result = await asyncio.wait_for(
|
||||
@@ -174,7 +204,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
60,
|
||||
) # wait for 60s
|
||||
logger.debug(f"Got future result: {result}")
|
||||
self.wexin_event_workers.pop(msg.id, None)
|
||||
self.wexin_event_workers.pop(str(cast(str | int, msg.id)), None)
|
||||
return result # xml. see weixin_offacc_event.py
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
@@ -202,38 +232,53 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
|
||||
@override
|
||||
async def run(self):
|
||||
await self.server.start_polling()
|
||||
# 如果启用统一 webhook 模式,则不启动独立服务器
|
||||
webhook_uuid = self.config.get("webhook_uuid")
|
||||
if self.unified_webhook_mode and webhook_uuid:
|
||||
log_webhook_info(f"{self.meta().id}(微信公众平台)", webhook_uuid)
|
||||
# 保持运行状态,等待 shutdown
|
||||
await self.server.shutdown_event.wait()
|
||||
else:
|
||||
await self.server.start_polling()
|
||||
|
||||
async def webhook_callback(self, request: Any) -> Any:
|
||||
"""统一 Webhook 回调入口"""
|
||||
# 根据请求方法分发到不同的处理函数
|
||||
if request.method == "GET":
|
||||
return await self.server.handle_verify(request)
|
||||
else:
|
||||
return await self.server.handle_callback(request)
|
||||
|
||||
async def convert_message(
|
||||
self,
|
||||
msg,
|
||||
future: asyncio.Future = None,
|
||||
future: asyncio.Future | None = None,
|
||||
) -> AstrBotMessage | None:
|
||||
abm = AstrBotMessage()
|
||||
if isinstance(msg, TextMessage):
|
||||
abm.message_str = msg.content
|
||||
abm.message_str = cast(str, msg.content)
|
||||
abm.self_id = str(msg.target)
|
||||
abm.message = [Plain(msg.content)]
|
||||
abm.message = [Plain(cast(str, msg.content))]
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
abm.sender = MessageMember(
|
||||
msg.source,
|
||||
msg.source,
|
||||
cast(str, msg.source),
|
||||
cast(str, msg.source),
|
||||
)
|
||||
abm.message_id = msg.id
|
||||
abm.timestamp = msg.time
|
||||
abm.message_id = str(cast(str | int, msg.id))
|
||||
abm.timestamp = cast(int, msg.time)
|
||||
abm.session_id = abm.sender.user_id
|
||||
elif msg.type == "image":
|
||||
assert isinstance(msg, ImageMessage)
|
||||
abm.message_str = "[图片]"
|
||||
abm.self_id = str(msg.target)
|
||||
abm.message = [Image(file=msg.image, url=msg.image)]
|
||||
abm.message = [Image(file=cast(str, msg.image), url=cast(str, msg.image))]
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
abm.sender = MessageMember(
|
||||
msg.source,
|
||||
msg.source,
|
||||
cast(str, msg.source),
|
||||
cast(str, msg.source),
|
||||
)
|
||||
abm.message_id = msg.id
|
||||
abm.timestamp = msg.time
|
||||
abm.message_id = str(cast(str | int, msg.id))
|
||||
abm.timestamp = cast(int, msg.time)
|
||||
abm.session_id = abm.sender.user_id
|
||||
elif msg.type == "voice":
|
||||
assert isinstance(msg, VoiceMessage)
|
||||
@@ -265,15 +310,16 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
abm.message = [Record(file=path_wav, url=path_wav)]
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
abm.sender = MessageMember(
|
||||
msg.source,
|
||||
msg.source,
|
||||
cast(str, msg.source),
|
||||
cast(str, msg.source),
|
||||
)
|
||||
abm.message_id = msg.id
|
||||
abm.timestamp = msg.time
|
||||
abm.message_id = str(cast(str | int, msg.id))
|
||||
abm.timestamp = cast(int, msg.time)
|
||||
abm.session_id = abm.sender.user_id
|
||||
else:
|
||||
logger.warning(f"暂未实现的事件: {msg.type}")
|
||||
future.set_result(None)
|
||||
if future:
|
||||
future.set_result(None)
|
||||
return
|
||||
# 很不优雅 :(
|
||||
abm.raw_message = {
|
||||
@@ -303,4 +349,4 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
await self.server.server.shutdown()
|
||||
except Exception as _:
|
||||
pass
|
||||
logger.info("微信公众平台 适配器已被优雅地关闭")
|
||||
logger.info("微信公众平台 适配器已被关闭")
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import cast
|
||||
|
||||
from wechatpy import WeChatClient
|
||||
from wechatpy.replies import ImageReply, TextReply, VoiceReply
|
||||
@@ -13,7 +14,7 @@ try:
|
||||
import pydub
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"检测到 pydub 库未安装,微信公众平台将无法语音收发。如需使用语音,请前往管理面板 -> 控制台 -> 安装 Pip 库安装 pydub。",
|
||||
"检测到 pydub 库未安装,微信公众平台将无法语音收发。如需使用语音,请前往管理面板 -> 平台日志 -> 安装 Pip 库安装 pydub。",
|
||||
)
|
||||
|
||||
|
||||
@@ -85,7 +86,9 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
message_obj = self.message_obj
|
||||
active_send_mode = message_obj.raw_message.get("active_send_mode", False)
|
||||
active_send_mode = cast(dict, message_obj.raw_message).get(
|
||||
"active_send_mode", False
|
||||
)
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
# Split long text messages if needed
|
||||
@@ -96,10 +99,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
||||
else:
|
||||
reply = TextReply(
|
||||
content=chunk,
|
||||
message=self.message_obj.raw_message["message"],
|
||||
message=cast(dict, self.message_obj.raw_message)["message"],
|
||||
)
|
||||
xml = reply.render()
|
||||
future = self.message_obj.raw_message["future"]
|
||||
future = cast(dict, self.message_obj.raw_message)["future"]
|
||||
assert isinstance(future, asyncio.Future)
|
||||
future.set_result(xml)
|
||||
await asyncio.sleep(0.5) # Avoid sending too fast
|
||||
@@ -125,10 +128,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
||||
else:
|
||||
reply = ImageReply(
|
||||
media_id=response["media_id"],
|
||||
message=self.message_obj.raw_message["message"],
|
||||
message=cast(dict, self.message_obj.raw_message)["message"],
|
||||
)
|
||||
xml = reply.render()
|
||||
future = self.message_obj.raw_message["future"]
|
||||
future = cast(dict, self.message_obj.raw_message)["future"]
|
||||
assert isinstance(future, asyncio.Future)
|
||||
future.set_result(xml)
|
||||
|
||||
@@ -160,10 +163,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
||||
else:
|
||||
reply = VoiceReply(
|
||||
media_id=response["media_id"],
|
||||
message=self.message_obj.raw_message["message"],
|
||||
message=cast(dict, self.message_obj.raw_message)["message"],
|
||||
)
|
||||
xml = reply.render()
|
||||
future = self.message_obj.raw_message["future"]
|
||||
future = cast(dict, self.message_obj.raw_message)["future"]
|
||||
assert isinstance(future, asyncio.Future)
|
||||
future.set_result(xml)
|
||||
|
||||
|
||||
@@ -10,12 +10,12 @@ class PlatformMessageHistoryManager:
|
||||
self,
|
||||
platform_id: str,
|
||||
user_id: str,
|
||||
content: list[dict], # TODO: parse from message chain
|
||||
content: dict, # TODO: parse from message chain
|
||||
sender_id: str | None = None,
|
||||
sender_name: str | None = None,
|
||||
):
|
||||
) -> PlatformMessageHistory:
|
||||
"""Insert a new platform message history record."""
|
||||
await self.db.insert_platform_message_history(
|
||||
return await self.db.insert_platform_message_history(
|
||||
platform_id=platform_id,
|
||||
user_id=user_id,
|
||||
content=content,
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import enum
|
||||
import json
|
||||
@@ -12,6 +14,7 @@ import astrbot.core.message.components as Comp
|
||||
from astrbot import logger
|
||||
from astrbot.core.agent.message import (
|
||||
AssistantMessageSegment,
|
||||
ContentPart,
|
||||
ToolCall,
|
||||
ToolCallMessageSegment,
|
||||
)
|
||||
@@ -90,6 +93,8 @@ class ProviderRequest:
|
||||
"""会话 ID"""
|
||||
image_urls: list[str] = field(default_factory=list)
|
||||
"""图片 URL 列表"""
|
||||
extra_user_content_parts: list[ContentPart] = field(default_factory=list)
|
||||
"""额外的用户消息内容部分列表,用于在用户消息后添加额外的内容块(如系统提醒、指令等)。"""
|
||||
func_tool: ToolSet | None = None
|
||||
"""可用的函数工具"""
|
||||
contexts: list[dict] = field(default_factory=list)
|
||||
@@ -164,13 +169,23 @@ class ProviderRequest:
|
||||
|
||||
async def assemble_context(self) -> dict:
|
||||
"""将请求(prompt 和 image_urls)包装成 OpenAI 的消息格式。"""
|
||||
# 构建内容块列表
|
||||
content_blocks = []
|
||||
|
||||
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
|
||||
if self.prompt and self.prompt.strip():
|
||||
content_blocks.append({"type": "text", "text": self.prompt})
|
||||
elif self.image_urls:
|
||||
# 如果没有文本但有图片,添加占位文本
|
||||
content_blocks.append({"type": "text", "text": "[图片]"})
|
||||
|
||||
# 2. 额外的内容块(系统提醒、指令等)
|
||||
if self.extra_user_content_parts:
|
||||
for part in self.extra_user_content_parts:
|
||||
content_blocks.append(part.model_dump())
|
||||
|
||||
# 3. 图片内容
|
||||
if self.image_urls:
|
||||
user_content = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": self.prompt if self.prompt else "[图片]"},
|
||||
],
|
||||
}
|
||||
for image_url in self.image_urls:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
@@ -183,11 +198,21 @@ class ProviderRequest:
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
continue
|
||||
user_content["content"].append(
|
||||
content_blocks.append(
|
||||
{"type": "image_url", "image_url": {"url": image_data}},
|
||||
)
|
||||
return user_content
|
||||
return {"role": "user", "content": self.prompt}
|
||||
|
||||
# 只有当只有一个来自 prompt 的文本块且没有额外内容块时,才降级为简单格式以保持向后兼容
|
||||
if (
|
||||
len(content_blocks) == 1
|
||||
and content_blocks[0]["type"] == "text"
|
||||
and not self.extra_user_content_parts
|
||||
and not self.image_urls
|
||||
):
|
||||
return {"role": "user", "content": content_blocks[0]["text"]}
|
||||
|
||||
# 否则返回多模态格式
|
||||
return {"role": "user", "content": content_blocks}
|
||||
|
||||
async def _encode_image_bs64(self, image_url: str) -> str:
|
||||
"""将图片转换为 base64"""
|
||||
@@ -199,6 +224,38 @@ class ProviderRequest:
|
||||
return ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenUsage:
|
||||
input_other: int = 0
|
||||
"""The number of input tokens, excluding cached tokens."""
|
||||
input_cached: int = 0
|
||||
"""The number of input cached tokens."""
|
||||
output: int = 0
|
||||
"""The number of output tokens."""
|
||||
|
||||
@property
|
||||
def total(self) -> int:
|
||||
return self.input_other + self.input_cached + self.output
|
||||
|
||||
@property
|
||||
def input(self) -> int:
|
||||
return self.input_other + self.input_cached
|
||||
|
||||
def __add__(self, other: TokenUsage) -> TokenUsage:
|
||||
return TokenUsage(
|
||||
input_other=self.input_other + other.input_other,
|
||||
input_cached=self.input_cached + other.input_cached,
|
||||
output=self.output + other.output,
|
||||
)
|
||||
|
||||
def __sub__(self, other: TokenUsage) -> TokenUsage:
|
||||
return TokenUsage(
|
||||
input_other=self.input_other - other.input_other,
|
||||
input_cached=self.input_cached - other.input_cached,
|
||||
output=self.output - other.output,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
role: str
|
||||
@@ -227,6 +284,11 @@ class LLMResponse:
|
||||
is_chunk: bool = False
|
||||
"""Indicates if the response is a chunked response."""
|
||||
|
||||
id: str | None = None
|
||||
"""The ID of the response. For chunked responses, it's the ID of the chunk; for non-chunked responses, it's the ID of the response."""
|
||||
usage: TokenUsage | None = None
|
||||
"""The usage of the response. For chunked responses, it's the usage of the chunk; for non-chunked responses, it's the usage of the response."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
role: str,
|
||||
@@ -241,6 +303,8 @@ class LLMResponse:
|
||||
| AnthropicMessage
|
||||
| None = None,
|
||||
is_chunk: bool = False,
|
||||
id: str | None = None,
|
||||
usage: TokenUsage | None = None,
|
||||
):
|
||||
"""初始化 LLMResponse
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import asyncio
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Awaitable, Callable
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
@@ -118,7 +118,7 @@ class FunctionToolManager:
|
||||
name: str,
|
||||
func_args: list[dict],
|
||||
desc: str,
|
||||
handler: Callable[..., Awaitable[Any]],
|
||||
handler: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
|
||||
) -> FuncTool:
|
||||
params = {
|
||||
"type": "object", # hard-coded here
|
||||
@@ -140,7 +140,7 @@ class FunctionToolManager:
|
||||
name: str,
|
||||
func_args: list,
|
||||
desc: str,
|
||||
handler: Callable[..., Awaitable[Any]],
|
||||
handler: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
|
||||
) -> None:
|
||||
"""添加函数调用工具
|
||||
|
||||
|
||||
+329
-163
@@ -1,5 +1,7 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import traceback
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from astrbot.core import astrbot_config, logger, sp
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
@@ -10,6 +12,7 @@ from .entities import ProviderType
|
||||
from .provider import (
|
||||
EmbeddingProvider,
|
||||
Provider,
|
||||
Providers,
|
||||
RerankProvider,
|
||||
STTProvider,
|
||||
TTSProvider,
|
||||
@@ -17,6 +20,11 @@ from .provider import (
|
||||
from .register import llm_tools, provider_cls_map
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class HasInitialize(Protocol):
|
||||
async def initialize(self) -> None: ...
|
||||
|
||||
|
||||
class ProviderManager:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -25,10 +33,12 @@ class ProviderManager:
|
||||
persona_mgr: PersonaManager,
|
||||
):
|
||||
self.reload_lock = asyncio.Lock()
|
||||
self.resource_lock = asyncio.Lock()
|
||||
self.persona_mgr = persona_mgr
|
||||
self.acm = acm
|
||||
config = acm.confs["default"]
|
||||
self.providers_config: list = config["provider"]
|
||||
self.provider_sources_config: list = config.get("provider_sources", [])
|
||||
self.provider_settings: dict = config["provider_settings"]
|
||||
self.provider_stt_settings: dict = config.get("provider_stt_settings", {})
|
||||
self.provider_tts_settings: dict = config.get("provider_tts_settings", {})
|
||||
@@ -48,7 +58,7 @@ class ProviderManager:
|
||||
"""加载的 Rerank Provider 的实例"""
|
||||
self.inst_map: dict[
|
||||
str,
|
||||
Provider | STTProvider | TTSProvider | EmbeddingProvider | RerankProvider,
|
||||
Providers,
|
||||
] = {}
|
||||
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
|
||||
self.llm_tools = llm_tools
|
||||
@@ -123,15 +133,13 @@ class ProviderManager:
|
||||
self.curr_provider_inst = prov
|
||||
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
|
||||
|
||||
async def get_provider_by_id(self, provider_id: str) -> Provider | None:
|
||||
async def get_provider_by_id(self, provider_id: str) -> Providers | None:
|
||||
"""根据提供商 ID 获取提供商实例"""
|
||||
return self.inst_map.get(provider_id)
|
||||
|
||||
def get_using_provider(
|
||||
self,
|
||||
provider_type: ProviderType,
|
||||
umo=None,
|
||||
) -> Provider | STTProvider | TTSProvider | None:
|
||||
self, provider_type: ProviderType, umo=None
|
||||
) -> Providers | None:
|
||||
"""获取正在使用的提供商实例。
|
||||
|
||||
Args:
|
||||
@@ -143,6 +151,7 @@ class ProviderManager:
|
||||
|
||||
"""
|
||||
provider = None
|
||||
provider_id = None
|
||||
if umo:
|
||||
provider_id = sp.get(
|
||||
f"provider_perf_{provider_type.value}",
|
||||
@@ -180,6 +189,12 @@ class ProviderManager:
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown provider type: {provider_type}")
|
||||
|
||||
if not provider and provider_id:
|
||||
logger.warning(
|
||||
f"没有找到 ID 为 {provider_id} 的提供商,这可能是由于您修改了提供商(模型)ID 导致的。"
|
||||
)
|
||||
|
||||
return provider
|
||||
|
||||
async def initialize(self):
|
||||
@@ -191,7 +206,6 @@ class ProviderManager:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(e)
|
||||
|
||||
# 设置默认提供商
|
||||
selected_provider_id = sp.get(
|
||||
"curr_provider",
|
||||
self.provider_settings.get("default_provider_id"),
|
||||
@@ -210,22 +224,173 @@ class ProviderManager:
|
||||
scope="global",
|
||||
scope_id="global",
|
||||
)
|
||||
self.curr_provider_inst = self.inst_map.get(selected_provider_id)
|
||||
|
||||
temp_provider = (
|
||||
self.inst_map.get(selected_provider_id)
|
||||
if isinstance(selected_provider_id, str)
|
||||
else None
|
||||
)
|
||||
self.curr_provider_inst = (
|
||||
temp_provider if isinstance(temp_provider, Provider) else None
|
||||
)
|
||||
if not self.curr_provider_inst and self.provider_insts:
|
||||
self.curr_provider_inst = self.provider_insts[0]
|
||||
|
||||
self.curr_stt_provider_inst = self.inst_map.get(selected_stt_provider_id)
|
||||
temp_stt = (
|
||||
self.inst_map.get(selected_stt_provider_id)
|
||||
if isinstance(selected_stt_provider_id, str)
|
||||
else None
|
||||
)
|
||||
self.curr_stt_provider_inst = (
|
||||
temp_stt if isinstance(temp_stt, STTProvider) else None
|
||||
)
|
||||
if not self.curr_stt_provider_inst and self.stt_provider_insts:
|
||||
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
||||
|
||||
self.curr_tts_provider_inst = self.inst_map.get(selected_tts_provider_id)
|
||||
temp_tts = (
|
||||
self.inst_map.get(selected_tts_provider_id)
|
||||
if isinstance(selected_tts_provider_id, str)
|
||||
else None
|
||||
)
|
||||
self.curr_tts_provider_inst = (
|
||||
temp_tts if isinstance(temp_tts, TTSProvider) else None
|
||||
)
|
||||
if not self.curr_tts_provider_inst and self.tts_provider_insts:
|
||||
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||||
|
||||
# 初始化 MCP Client 连接
|
||||
asyncio.create_task(self.llm_tools.init_mcp_clients(), name="init_mcp_clients")
|
||||
|
||||
def dynamic_import_provider(self, type: str):
|
||||
"""动态导入提供商适配器模块
|
||||
|
||||
Args:
|
||||
type (str): 提供商请求类型。
|
||||
|
||||
Raises:
|
||||
ImportError: 如果提供商类型未知或无法导入对应模块,则抛出异常。
|
||||
"""
|
||||
match type:
|
||||
case "openai_chat_completion":
|
||||
from .sources.openai_source import (
|
||||
ProviderOpenAIOfficial as ProviderOpenAIOfficial,
|
||||
)
|
||||
case "zhipu_chat_completion":
|
||||
from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
|
||||
case "groq_chat_completion":
|
||||
from .sources.groq_source import ProviderGroq as ProviderGroq
|
||||
case "anthropic_chat_completion":
|
||||
from .sources.anthropic_source import (
|
||||
ProviderAnthropic as ProviderAnthropic,
|
||||
)
|
||||
case "googlegenai_chat_completion":
|
||||
from .sources.gemini_source import (
|
||||
ProviderGoogleGenAI as ProviderGoogleGenAI,
|
||||
)
|
||||
case "sensevoice_stt_selfhost":
|
||||
from .sources.sensevoice_selfhosted_source import (
|
||||
ProviderSenseVoiceSTTSelfHost as ProviderSenseVoiceSTTSelfHost,
|
||||
)
|
||||
case "openai_whisper_api":
|
||||
from .sources.whisper_api_source import (
|
||||
ProviderOpenAIWhisperAPI as ProviderOpenAIWhisperAPI,
|
||||
)
|
||||
case "openai_whisper_selfhost":
|
||||
from .sources.whisper_selfhosted_source import (
|
||||
ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost,
|
||||
)
|
||||
case "xinference_stt":
|
||||
from .sources.xinference_stt_provider import (
|
||||
ProviderXinferenceSTT as ProviderXinferenceSTT,
|
||||
)
|
||||
case "openai_tts_api":
|
||||
from .sources.openai_tts_api_source import (
|
||||
ProviderOpenAITTSAPI as ProviderOpenAITTSAPI,
|
||||
)
|
||||
case "edge_tts":
|
||||
from .sources.edge_tts_source import (
|
||||
ProviderEdgeTTS as ProviderEdgeTTS,
|
||||
)
|
||||
case "gsv_tts_selfhost":
|
||||
from .sources.gsv_selfhosted_source import (
|
||||
ProviderGSVTTS as ProviderGSVTTS,
|
||||
)
|
||||
case "gsvi_tts_api":
|
||||
from .sources.gsvi_tts_source import (
|
||||
ProviderGSVITTS as ProviderGSVITTS,
|
||||
)
|
||||
case "fishaudio_tts_api":
|
||||
from .sources.fishaudio_tts_api_source import (
|
||||
ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI,
|
||||
)
|
||||
case "dashscope_tts":
|
||||
from .sources.dashscope_tts import (
|
||||
ProviderDashscopeTTSAPI as ProviderDashscopeTTSAPI,
|
||||
)
|
||||
case "azure_tts":
|
||||
from .sources.azure_tts_source import (
|
||||
AzureTTSProvider as AzureTTSProvider,
|
||||
)
|
||||
case "minimax_tts_api":
|
||||
from .sources.minimax_tts_api_source import (
|
||||
ProviderMiniMaxTTSAPI as ProviderMiniMaxTTSAPI,
|
||||
)
|
||||
case "volcengine_tts":
|
||||
from .sources.volcengine_tts import (
|
||||
ProviderVolcengineTTS as ProviderVolcengineTTS,
|
||||
)
|
||||
case "gemini_tts":
|
||||
from .sources.gemini_tts_source import (
|
||||
ProviderGeminiTTSAPI as ProviderGeminiTTSAPI,
|
||||
)
|
||||
case "openai_embedding":
|
||||
from .sources.openai_embedding_source import (
|
||||
OpenAIEmbeddingProvider as OpenAIEmbeddingProvider,
|
||||
)
|
||||
case "gemini_embedding":
|
||||
from .sources.gemini_embedding_source import (
|
||||
GeminiEmbeddingProvider as GeminiEmbeddingProvider,
|
||||
)
|
||||
case "vllm_rerank":
|
||||
from .sources.vllm_rerank_source import (
|
||||
VLLMRerankProvider as VLLMRerankProvider,
|
||||
)
|
||||
case "xinference_rerank":
|
||||
from .sources.xinference_rerank_source import (
|
||||
XinferenceRerankProvider as XinferenceRerankProvider,
|
||||
)
|
||||
case "bailian_rerank":
|
||||
from .sources.bailian_rerank_source import (
|
||||
BailianRerankProvider as BailianRerankProvider,
|
||||
)
|
||||
|
||||
def get_merged_provider_config(self, provider_config: dict) -> dict:
|
||||
"""获取 provider 配置和 provider_source 配置合并后的结果
|
||||
|
||||
Returns:
|
||||
dict: 合并后的 provider 配置,key 为 provider id,value 为合并后的配置字典
|
||||
"""
|
||||
pc = copy.deepcopy(provider_config)
|
||||
provider_source_id = pc.get("provider_source_id", "")
|
||||
if provider_source_id:
|
||||
provider_source = None
|
||||
for ps in self.provider_sources_config:
|
||||
if ps.get("id") == provider_source_id:
|
||||
provider_source = ps
|
||||
break
|
||||
|
||||
if provider_source:
|
||||
# 合并配置,provider 的配置优先级更高
|
||||
merged_config = {**provider_source, **pc}
|
||||
# 保持 id 为 provider 的 id,而不是 source 的 id
|
||||
merged_config["id"] = pc["id"]
|
||||
pc = merged_config
|
||||
return pc
|
||||
|
||||
async def load_provider(self, provider_config: dict):
|
||||
# 如果 provider_source_id 存在且不为空,则从 provider_sources 中找到对应的配置并合并
|
||||
provider_config = self.get_merged_provider_config(provider_config)
|
||||
|
||||
if not provider_config["enable"]:
|
||||
logger.info(f"Provider {provider_config['id']} is disabled, skipping")
|
||||
return
|
||||
@@ -238,99 +403,7 @@ class ProviderManager:
|
||||
|
||||
# 动态导入
|
||||
try:
|
||||
match provider_config["type"]:
|
||||
case "openai_chat_completion":
|
||||
from .sources.openai_source import (
|
||||
ProviderOpenAIOfficial as ProviderOpenAIOfficial,
|
||||
)
|
||||
case "zhipu_chat_completion":
|
||||
from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
|
||||
case "groq_chat_completion":
|
||||
from .sources.groq_source import ProviderGroq as ProviderGroq
|
||||
case "anthropic_chat_completion":
|
||||
from .sources.anthropic_source import (
|
||||
ProviderAnthropic as ProviderAnthropic,
|
||||
)
|
||||
case "googlegenai_chat_completion":
|
||||
from .sources.gemini_source import (
|
||||
ProviderGoogleGenAI as ProviderGoogleGenAI,
|
||||
)
|
||||
case "sensevoice_stt_selfhost":
|
||||
from .sources.sensevoice_selfhosted_source import (
|
||||
ProviderSenseVoiceSTTSelfHost as ProviderSenseVoiceSTTSelfHost,
|
||||
)
|
||||
case "openai_whisper_api":
|
||||
from .sources.whisper_api_source import (
|
||||
ProviderOpenAIWhisperAPI as ProviderOpenAIWhisperAPI,
|
||||
)
|
||||
case "openai_whisper_selfhost":
|
||||
from .sources.whisper_selfhosted_source import (
|
||||
ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost,
|
||||
)
|
||||
case "xinference_stt":
|
||||
from .sources.xinference_stt_provider import (
|
||||
ProviderXinferenceSTT as ProviderXinferenceSTT,
|
||||
)
|
||||
case "openai_tts_api":
|
||||
from .sources.openai_tts_api_source import (
|
||||
ProviderOpenAITTSAPI as ProviderOpenAITTSAPI,
|
||||
)
|
||||
case "edge_tts":
|
||||
from .sources.edge_tts_source import (
|
||||
ProviderEdgeTTS as ProviderEdgeTTS,
|
||||
)
|
||||
case "gsv_tts_selfhost":
|
||||
from .sources.gsv_selfhosted_source import (
|
||||
ProviderGSVTTS as ProviderGSVTTS,
|
||||
)
|
||||
case "gsvi_tts_api":
|
||||
from .sources.gsvi_tts_source import (
|
||||
ProviderGSVITTS as ProviderGSVITTS,
|
||||
)
|
||||
case "fishaudio_tts_api":
|
||||
from .sources.fishaudio_tts_api_source import (
|
||||
ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI,
|
||||
)
|
||||
case "dashscope_tts":
|
||||
from .sources.dashscope_tts import (
|
||||
ProviderDashscopeTTSAPI as ProviderDashscopeTTSAPI,
|
||||
)
|
||||
case "azure_tts":
|
||||
from .sources.azure_tts_source import (
|
||||
AzureTTSProvider as AzureTTSProvider,
|
||||
)
|
||||
case "minimax_tts_api":
|
||||
from .sources.minimax_tts_api_source import (
|
||||
ProviderMiniMaxTTSAPI as ProviderMiniMaxTTSAPI,
|
||||
)
|
||||
case "volcengine_tts":
|
||||
from .sources.volcengine_tts import (
|
||||
ProviderVolcengineTTS as ProviderVolcengineTTS,
|
||||
)
|
||||
case "gemini_tts":
|
||||
from .sources.gemini_tts_source import (
|
||||
ProviderGeminiTTSAPI as ProviderGeminiTTSAPI,
|
||||
)
|
||||
case "openai_embedding":
|
||||
from .sources.openai_embedding_source import (
|
||||
OpenAIEmbeddingProvider as OpenAIEmbeddingProvider,
|
||||
)
|
||||
case "gemini_embedding":
|
||||
from .sources.gemini_embedding_source import (
|
||||
GeminiEmbeddingProvider as GeminiEmbeddingProvider,
|
||||
)
|
||||
case "vllm_rerank":
|
||||
from .sources.vllm_rerank_source import (
|
||||
VLLMRerankProvider as VLLMRerankProvider,
|
||||
)
|
||||
case "xinference_rerank":
|
||||
from .sources.xinference_rerank_source import (
|
||||
XinferenceRerankProvider as XinferenceRerankProvider,
|
||||
)
|
||||
case "bailian_rerank":
|
||||
from .sources.bailian_rerank_source import (
|
||||
BailianRerankProvider as BailianRerankProvider,
|
||||
)
|
||||
self.dynamic_import_provider(provider_config["type"])
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.critical(
|
||||
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。",
|
||||
@@ -358,73 +431,103 @@ class ProviderManager:
|
||||
|
||||
provider_metadata.id = provider_config["id"]
|
||||
|
||||
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
|
||||
# STT 任务
|
||||
inst = cls_type(provider_config, self.provider_settings)
|
||||
match provider_metadata.provider_type:
|
||||
case ProviderType.SPEECH_TO_TEXT:
|
||||
# STT 任务
|
||||
if not issubclass(cls_type, STTProvider):
|
||||
raise TypeError(
|
||||
f"Provider class {cls_type} is not a subclass of STTProvider"
|
||||
)
|
||||
inst = cls_type(provider_config, self.provider_settings)
|
||||
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
if isinstance(inst, HasInitialize):
|
||||
await inst.initialize()
|
||||
|
||||
self.stt_provider_insts.append(inst)
|
||||
if (
|
||||
self.provider_stt_settings.get("provider_id")
|
||||
== provider_config["id"]
|
||||
):
|
||||
self.curr_stt_provider_inst = inst
|
||||
logger.info(
|
||||
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。",
|
||||
self.stt_provider_insts.append(inst)
|
||||
if (
|
||||
self.provider_stt_settings.get("provider_id")
|
||||
== provider_config["id"]
|
||||
):
|
||||
self.curr_stt_provider_inst = inst
|
||||
logger.info(
|
||||
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。",
|
||||
)
|
||||
if not self.curr_stt_provider_inst:
|
||||
self.curr_stt_provider_inst = inst
|
||||
|
||||
case ProviderType.TEXT_TO_SPEECH:
|
||||
# TTS 任务
|
||||
if not issubclass(cls_type, TTSProvider):
|
||||
raise TypeError(
|
||||
f"Provider class {cls_type} is not a subclass of TTSProvider"
|
||||
)
|
||||
inst = cls_type(provider_config, self.provider_settings)
|
||||
|
||||
if isinstance(inst, HasInitialize):
|
||||
await inst.initialize()
|
||||
|
||||
self.tts_provider_insts.append(inst)
|
||||
if (
|
||||
self.provider_settings.get("provider_id")
|
||||
== provider_config["id"]
|
||||
):
|
||||
self.curr_tts_provider_inst = inst
|
||||
logger.info(
|
||||
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。",
|
||||
)
|
||||
if not self.curr_tts_provider_inst:
|
||||
self.curr_tts_provider_inst = inst
|
||||
|
||||
case ProviderType.CHAT_COMPLETION:
|
||||
# 文本生成任务
|
||||
if not issubclass(cls_type, Provider):
|
||||
raise TypeError(
|
||||
f"Provider class {cls_type} is not a subclass of Provider"
|
||||
)
|
||||
inst = cls_type(
|
||||
provider_config,
|
||||
self.provider_settings,
|
||||
)
|
||||
if not self.curr_stt_provider_inst:
|
||||
self.curr_stt_provider_inst = inst
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
|
||||
# TTS 任务
|
||||
inst = cls_type(provider_config, self.provider_settings)
|
||||
if isinstance(inst, HasInitialize):
|
||||
await inst.initialize()
|
||||
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
self.provider_insts.append(inst)
|
||||
if (
|
||||
self.provider_settings.get("default_provider_id")
|
||||
== provider_config["id"]
|
||||
):
|
||||
self.curr_provider_inst = inst
|
||||
logger.info(
|
||||
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。",
|
||||
)
|
||||
if not self.curr_provider_inst:
|
||||
self.curr_provider_inst = inst
|
||||
|
||||
self.tts_provider_insts.append(inst)
|
||||
if self.provider_settings.get("provider_id") == provider_config["id"]:
|
||||
self.curr_tts_provider_inst = inst
|
||||
logger.info(
|
||||
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。",
|
||||
case ProviderType.EMBEDDING:
|
||||
if not issubclass(cls_type, EmbeddingProvider):
|
||||
raise TypeError(
|
||||
f"Provider class {cls_type} is not a subclass of EmbeddingProvider"
|
||||
)
|
||||
inst = cls_type(provider_config, self.provider_settings)
|
||||
if isinstance(inst, HasInitialize):
|
||||
await inst.initialize()
|
||||
self.embedding_provider_insts.append(inst)
|
||||
case ProviderType.RERANK:
|
||||
if not issubclass(cls_type, RerankProvider):
|
||||
raise TypeError(
|
||||
f"Provider class {cls_type} is not a subclass of RerankProvider"
|
||||
)
|
||||
inst = cls_type(provider_config, self.provider_settings)
|
||||
if isinstance(inst, HasInitialize):
|
||||
await inst.initialize()
|
||||
self.rerank_provider_insts.append(inst)
|
||||
case _:
|
||||
# 未知供应商抛出异常,确保inst初始化
|
||||
# Should be unreachable
|
||||
raise Exception(
|
||||
f"未知的提供商类型:{provider_metadata.provider_type}"
|
||||
)
|
||||
if not self.curr_tts_provider_inst:
|
||||
self.curr_tts_provider_inst = inst
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
|
||||
# 文本生成任务
|
||||
inst = cls_type(
|
||||
provider_config,
|
||||
self.provider_settings,
|
||||
)
|
||||
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
|
||||
self.provider_insts.append(inst)
|
||||
if (
|
||||
self.provider_settings.get("default_provider_id")
|
||||
== provider_config["id"]
|
||||
):
|
||||
self.curr_provider_inst = inst
|
||||
logger.info(
|
||||
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。",
|
||||
)
|
||||
if not self.curr_provider_inst:
|
||||
self.curr_provider_inst = inst
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.EMBEDDING:
|
||||
inst = cls_type(provider_config, self.provider_settings)
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
self.embedding_provider_insts.append(inst)
|
||||
elif provider_metadata.provider_type == ProviderType.RERANK:
|
||||
inst = cls_type(provider_config, self.provider_settings)
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
self.rerank_provider_insts.append(inst)
|
||||
|
||||
self.inst_map[provider_config["id"]] = inst
|
||||
except Exception as e:
|
||||
@@ -443,6 +546,7 @@ class ProviderManager:
|
||||
|
||||
# 和配置文件保持同步
|
||||
self.providers_config = astrbot_config["provider"]
|
||||
self.provider_sources_config = astrbot_config.get("provider_sources", [])
|
||||
config_ids = [provider["id"] for provider in self.providers_config]
|
||||
logger.info(f"providers in user's config: {config_ids}")
|
||||
for key in list(self.inst_map.keys()):
|
||||
@@ -514,6 +618,68 @@ class ProviderManager:
|
||||
)
|
||||
del self.inst_map[provider_id]
|
||||
|
||||
async def delete_provider(
|
||||
self, provider_id: str | None = None, provider_source_id: str | None = None
|
||||
):
|
||||
"""Delete provider and/or provider source from config and terminate the instances. Config will be saved after deletion."""
|
||||
async with self.resource_lock:
|
||||
# delete from config
|
||||
target_prov_ids = []
|
||||
if provider_id:
|
||||
target_prov_ids.append(provider_id)
|
||||
else:
|
||||
for prov in self.providers_config:
|
||||
if prov.get("provider_source_id") == provider_source_id:
|
||||
target_prov_ids.append(prov.get("id"))
|
||||
config = self.acm.default_conf
|
||||
for tpid in target_prov_ids:
|
||||
await self.terminate_provider(tpid)
|
||||
config["provider"] = [
|
||||
prov for prov in config["provider"] if prov.get("id") != tpid
|
||||
]
|
||||
config.save_config()
|
||||
logger.info(f"Provider {target_prov_ids} 已从配置中删除。")
|
||||
|
||||
async def update_provider(self, origin_provider_id: str, new_config: dict):
|
||||
"""Update provider config and reload the instance. Config will be saved after update."""
|
||||
async with self.resource_lock:
|
||||
npid = new_config.get("id", None)
|
||||
if not npid:
|
||||
raise ValueError("New provider config must have an 'id' field")
|
||||
config = self.acm.default_conf
|
||||
for provider in config["provider"]:
|
||||
if (
|
||||
provider.get("id", None) == npid
|
||||
and provider.get("id", None) != origin_provider_id
|
||||
):
|
||||
raise ValueError(f"Provider ID {npid} already exists")
|
||||
# update config
|
||||
for idx, provider in enumerate(config["provider"]):
|
||||
if provider.get("id", None) == origin_provider_id:
|
||||
config["provider"][idx] = new_config
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"Provider ID {origin_provider_id} not found")
|
||||
config.save_config()
|
||||
# reload instance
|
||||
await self.reload(new_config)
|
||||
|
||||
async def create_provider(self, new_config: dict):
|
||||
"""Add new provider config and load the instance. Config will be saved after addition."""
|
||||
async with self.resource_lock:
|
||||
npid = new_config.get("id", None)
|
||||
if not npid:
|
||||
raise ValueError("New provider config must have an 'id' field")
|
||||
config = self.acm.default_conf
|
||||
for provider in config["provider"]:
|
||||
if provider.get("id", None) == npid:
|
||||
raise ValueError(f"Provider ID {npid} already exists")
|
||||
# add to config
|
||||
config["provider"].append(new_config)
|
||||
config.save_config()
|
||||
# load instance
|
||||
await self.load_provider(new_config)
|
||||
|
||||
async def terminate(self):
|
||||
for provider_inst in self.provider_insts:
|
||||
if hasattr(provider_inst, "terminate"):
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import abc
|
||||
import asyncio
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import TypeAlias, Union
|
||||
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.message import ContentPart, Message
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.provider.entities import (
|
||||
LLMResponse,
|
||||
@@ -11,6 +13,15 @@ from astrbot.core.provider.entities import (
|
||||
ToolCallsResult,
|
||||
)
|
||||
from astrbot.core.provider.register import provider_cls_map
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_path
|
||||
|
||||
Providers: TypeAlias = Union[
|
||||
"Provider",
|
||||
"STTProvider",
|
||||
"TTSProvider",
|
||||
"EmbeddingProvider",
|
||||
"RerankProvider",
|
||||
]
|
||||
|
||||
|
||||
class AbstractProvider(abc.ABC):
|
||||
@@ -43,6 +54,14 @@ class AbstractProvider(abc.ABC):
|
||||
)
|
||||
return meta
|
||||
|
||||
async def test(self):
|
||||
"""test the provider is a
|
||||
|
||||
raises:
|
||||
Exception: if the provider is not available
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class Provider(AbstractProvider):
|
||||
"""Chat Provider"""
|
||||
@@ -84,6 +103,7 @@ class Provider(AbstractProvider):
|
||||
system_prompt: str | None = None,
|
||||
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
|
||||
model: str | None = None,
|
||||
extra_user_content_parts: list[ContentPart] | None = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
"""获得 LLM 的文本对话结果。会使用当前的模型进行对话。
|
||||
@@ -95,6 +115,7 @@ class Provider(AbstractProvider):
|
||||
tools: tool set
|
||||
contexts: 上下文,和 prompt 二选一使用
|
||||
tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
|
||||
extra_user_content_parts: 额外的用户内容块列表,用于在用户消息后添加额外的文本块(如系统提醒、指令等)
|
||||
kwargs: 其他参数
|
||||
|
||||
Notes:
|
||||
@@ -114,6 +135,7 @@ class Provider(AbstractProvider):
|
||||
system_prompt: str | None = None,
|
||||
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
|
||||
model: str | None = None,
|
||||
extra_user_content_parts: list[ContentPart] | None = None,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
"""获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。
|
||||
@@ -125,6 +147,7 @@ class Provider(AbstractProvider):
|
||||
tools: tool set
|
||||
contexts: 上下文,和 prompt 二选一使用
|
||||
tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
|
||||
extra_user_content_parts: 额外的用户内容块列表,用于在用户消息后添加额外的文本块(如系统提醒、指令等)
|
||||
kwargs: 其他参数
|
||||
|
||||
Notes:
|
||||
@@ -132,7 +155,9 @@ class Provider(AbstractProvider):
|
||||
- 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
|
||||
|
||||
"""
|
||||
...
|
||||
if False: # pragma: no cover - make this an async generator for typing
|
||||
yield None # type: ignore
|
||||
raise NotImplementedError()
|
||||
|
||||
async def pop_record(self, context: list):
|
||||
"""弹出 context 第一条非系统提示词对话记录"""
|
||||
@@ -165,6 +190,12 @@ class Provider(AbstractProvider):
|
||||
|
||||
return dicts
|
||||
|
||||
async def test(self, timeout: float = 45.0):
|
||||
await asyncio.wait_for(
|
||||
self.text_chat(prompt="REPLY `PONG` ONLY"),
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
|
||||
class STTProvider(AbstractProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
@@ -177,6 +208,14 @@ class STTProvider(AbstractProvider):
|
||||
"""获取音频的文本"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def test(self):
|
||||
sample_audio_path = os.path.join(
|
||||
get_astrbot_path(),
|
||||
"samples",
|
||||
"stt_health_check.wav",
|
||||
)
|
||||
await self.get_text(sample_audio_path)
|
||||
|
||||
|
||||
class TTSProvider(AbstractProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
@@ -189,6 +228,9 @@ class TTSProvider(AbstractProvider):
|
||||
"""获取文本的音频,返回音频文件路径"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def test(self):
|
||||
await self.get_audio("hi")
|
||||
|
||||
|
||||
class EmbeddingProvider(AbstractProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
@@ -211,6 +253,9 @@ class EmbeddingProvider(AbstractProvider):
|
||||
"""获取向量的维度"""
|
||||
...
|
||||
|
||||
async def test(self):
|
||||
await self.get_embedding("astrbot")
|
||||
|
||||
async def get_embeddings_batch(
|
||||
self,
|
||||
texts: list[str],
|
||||
@@ -294,3 +339,8 @@ class RerankProvider(AbstractProvider):
|
||||
) -> list[RerankResult]:
|
||||
"""获取查询和文档的重排序分数"""
|
||||
...
|
||||
|
||||
async def test(self):
|
||||
result = await self.rerank("Apple", documents=["apple", "banana"])
|
||||
if not result:
|
||||
raise Exception("Rerank provider test failed, no results returned")
|
||||
|
||||
@@ -6,10 +6,13 @@ from mimetypes import guess_type
|
||||
import anthropic
|
||||
from anthropic import AsyncAnthropic
|
||||
from anthropic.types import Message
|
||||
from anthropic.types.message_delta_usage import MessageDeltaUsage
|
||||
from anthropic.types.usage import Usage
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
from astrbot.core.agent.message import ContentPart
|
||||
from astrbot.core.provider.entities import LLMResponse, TokenUsage
|
||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
|
||||
@@ -45,7 +48,7 @@ class ProviderAnthropic(Provider):
|
||||
base_url=self.base_url,
|
||||
)
|
||||
|
||||
self.set_model(provider_config["model_config"]["model"])
|
||||
self.set_model(provider_config.get("model", "unknown"))
|
||||
|
||||
def _prepare_payload(self, messages: list[dict]):
|
||||
"""准备 Anthropic API 的请求 payload
|
||||
@@ -107,12 +110,32 @@ class ProviderAnthropic(Provider):
|
||||
|
||||
return system_prompt, new_messages
|
||||
|
||||
def _extract_usage(self, usage: Usage) -> TokenUsage:
|
||||
# https://docs.claude.com/en/docs/build-with-claude/prompt-caching#tracking-cache-performance
|
||||
return TokenUsage(
|
||||
input_other=usage.input_tokens or 0,
|
||||
input_cached=usage.cache_read_input_tokens or 0,
|
||||
output=usage.output_tokens,
|
||||
)
|
||||
|
||||
def _update_usage(self, token_usage: TokenUsage, usage: MessageDeltaUsage) -> None:
|
||||
if usage.input_tokens is not None:
|
||||
token_usage.input_other = usage.input_tokens
|
||||
if usage.cache_read_input_tokens is not None:
|
||||
token_usage.input_cached = usage.cache_read_input_tokens
|
||||
if usage.output_tokens is not None:
|
||||
token_usage.output = usage.output_tokens
|
||||
|
||||
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
|
||||
if tools:
|
||||
if tool_list := tools.get_func_desc_anthropic_style():
|
||||
payloads["tools"] = tool_list
|
||||
|
||||
completion = await self.client.messages.create(**payloads, stream=False)
|
||||
extra_body = self.provider_config.get("custom_extra_body", {})
|
||||
|
||||
completion = await self.client.messages.create(
|
||||
**payloads, stream=False, extra_body=extra_body
|
||||
)
|
||||
|
||||
assert isinstance(completion, Message)
|
||||
logger.debug(f"completion: {completion}")
|
||||
@@ -131,6 +154,10 @@ class ProviderAnthropic(Provider):
|
||||
llm_response.tools_call_args.append(content_block.input)
|
||||
llm_response.tools_call_name.append(content_block.name)
|
||||
llm_response.tools_call_ids.append(content_block.id)
|
||||
|
||||
llm_response.id = completion.id
|
||||
llm_response.usage = self._extract_usage(completion.usage)
|
||||
|
||||
# TODO(Soulter): 处理 end_turn 情况
|
||||
if not llm_response.completion_text and not llm_response.tools_call_args:
|
||||
raise Exception(f"Anthropic API 返回的 completion 无法解析:{completion}。")
|
||||
@@ -151,10 +178,19 @@ class ProviderAnthropic(Provider):
|
||||
# 用于累积最终结果
|
||||
final_text = ""
|
||||
final_tool_calls = []
|
||||
id = None
|
||||
usage = TokenUsage()
|
||||
extra_body = self.provider_config.get("custom_extra_body", {})
|
||||
|
||||
async with self.client.messages.stream(**payloads) as stream:
|
||||
async with self.client.messages.stream(
|
||||
**payloads, extra_body=extra_body
|
||||
) as stream:
|
||||
assert isinstance(stream, anthropic.AsyncMessageStream)
|
||||
async for event in stream:
|
||||
if event.type == "message_start":
|
||||
# the usage contains input token usage
|
||||
id = event.message.id
|
||||
usage = self._extract_usage(event.message.usage)
|
||||
if event.type == "content_block_start":
|
||||
if event.content_block.type == "text":
|
||||
# 文本块开始
|
||||
@@ -162,6 +198,8 @@ class ProviderAnthropic(Provider):
|
||||
role="assistant",
|
||||
completion_text="",
|
||||
is_chunk=True,
|
||||
usage=usage,
|
||||
id=id,
|
||||
)
|
||||
elif event.content_block.type == "tool_use":
|
||||
# 工具使用块开始,初始化缓冲区
|
||||
@@ -179,6 +217,8 @@ class ProviderAnthropic(Provider):
|
||||
role="assistant",
|
||||
completion_text=event.delta.text,
|
||||
is_chunk=True,
|
||||
usage=usage,
|
||||
id=id,
|
||||
)
|
||||
elif event.delta.type == "input_json_delta":
|
||||
# 工具调用参数增量
|
||||
@@ -215,6 +255,8 @@ class ProviderAnthropic(Provider):
|
||||
tools_call_name=[tool_info["name"]],
|
||||
tools_call_ids=[tool_info["id"]],
|
||||
is_chunk=True,
|
||||
usage=usage,
|
||||
id=id,
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
# JSON 解析失败,跳过这个工具调用
|
||||
@@ -223,11 +265,17 @@ class ProviderAnthropic(Provider):
|
||||
# 清理缓冲区
|
||||
del tool_use_buffer[event.index]
|
||||
|
||||
elif event.type == "message_delta":
|
||||
if event.usage:
|
||||
self._update_usage(usage, event.usage)
|
||||
|
||||
# 返回最终的完整结果
|
||||
final_response = LLMResponse(
|
||||
role="assistant",
|
||||
completion_text=final_text,
|
||||
is_chunk=False,
|
||||
usage=usage,
|
||||
id=id,
|
||||
)
|
||||
|
||||
if final_tool_calls:
|
||||
@@ -249,13 +297,16 @@ class ProviderAnthropic(Provider):
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
extra_user_content_parts=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
new_record = None
|
||||
if prompt is not None:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
new_record = await self.assemble_context(
|
||||
prompt, image_urls, extra_user_content_parts
|
||||
)
|
||||
context_query = self._ensure_message_to_dicts(contexts)
|
||||
if new_record:
|
||||
context_query.append(new_record)
|
||||
@@ -277,10 +328,9 @@ class ProviderAnthropic(Provider):
|
||||
|
||||
system_prompt, new_messages = self._prepare_payload(context_query)
|
||||
|
||||
model_config = self.provider_config.get("model_config", {})
|
||||
model_config["model"] = model or self.get_model()
|
||||
model = model or self.get_model()
|
||||
|
||||
payloads = {"messages": new_messages, **model_config}
|
||||
payloads = {"messages": new_messages, "model": model}
|
||||
|
||||
# Anthropic has a different way of handling system prompts
|
||||
if system_prompt:
|
||||
@@ -290,7 +340,6 @@ class ProviderAnthropic(Provider):
|
||||
try:
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
except Exception as e:
|
||||
logger.error(f"发生了错误。Provider 配置如下: {model_config}")
|
||||
raise e
|
||||
|
||||
return llm_response
|
||||
@@ -305,13 +354,16 @@ class ProviderAnthropic(Provider):
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
extra_user_content_parts=None,
|
||||
**kwargs,
|
||||
):
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
new_record = None
|
||||
if prompt is not None:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
new_record = await self.assemble_context(
|
||||
prompt, image_urls, extra_user_content_parts
|
||||
)
|
||||
context_query = self._ensure_message_to_dicts(contexts)
|
||||
if new_record:
|
||||
context_query.append(new_record)
|
||||
@@ -332,10 +384,9 @@ class ProviderAnthropic(Provider):
|
||||
|
||||
system_prompt, new_messages = self._prepare_payload(context_query)
|
||||
|
||||
model_config = self.provider_config.get("model_config", {})
|
||||
model_config["model"] = model or self.get_model()
|
||||
model = model or self.get_model()
|
||||
|
||||
payloads = {"messages": new_messages, **model_config}
|
||||
payloads = {"messages": new_messages, "model": model}
|
||||
|
||||
# Anthropic has a different way of handling system prompts
|
||||
if system_prompt:
|
||||
@@ -344,48 +395,116 @@ class ProviderAnthropic(Provider):
|
||||
async for llm_response in self._query_stream(payloads, func_tool):
|
||||
yield llm_response
|
||||
|
||||
async def assemble_context(self, text: str, image_urls: list[str] | None = None):
|
||||
async def assemble_context(
|
||||
self,
|
||||
text: str,
|
||||
image_urls: list[str] | None = None,
|
||||
extra_user_content_parts: list[ContentPart] | None = None,
|
||||
):
|
||||
"""组装上下文,支持文本和图片"""
|
||||
if not image_urls:
|
||||
return {"role": "user", "content": text}
|
||||
|
||||
content = []
|
||||
content.append({"type": "text", "text": text})
|
||||
|
||||
for image_url in image_urls:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
elif image_url.startswith("file:///"):
|
||||
image_path = image_url.replace("file:///", "")
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
else:
|
||||
image_data = await self.encode_image_bs64(image_url)
|
||||
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
|
||||
if text:
|
||||
content.append({"type": "text", "text": text})
|
||||
elif image_urls:
|
||||
# 如果没有文本但有图片,添加占位文本
|
||||
content.append({"type": "text", "text": "[图片]"})
|
||||
elif extra_user_content_parts:
|
||||
# 如果只有额外内容块,也需要添加占位文本
|
||||
content.append({"type": "text", "text": " "})
|
||||
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
continue
|
||||
# 2. 额外的内容块(系统提醒、指令等)
|
||||
if extra_user_content_parts:
|
||||
for block in extra_user_content_parts:
|
||||
block_type = block.get("type")
|
||||
|
||||
# Get mime type for the image
|
||||
mime_type, _ = guess_type(image_url)
|
||||
if not mime_type:
|
||||
mime_type = "image/jpeg" # Default to JPEG if can't determine
|
||||
if block_type == "text":
|
||||
# 文本直接添加
|
||||
content.append(block)
|
||||
|
||||
content.append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": mime_type,
|
||||
"data": (
|
||||
image_data.split("base64,")[1]
|
||||
if "base64," in image_data
|
||||
else image_data
|
||||
),
|
||||
elif block_type == "image_url":
|
||||
# 转换 OpenAI 格式的图片为 Anthropic 格式
|
||||
image_url_data = block.get("image_url", {})
|
||||
if isinstance(image_url_data, dict):
|
||||
url = image_url_data.get("url", "")
|
||||
else:
|
||||
# 兼容直接传 URL 字符串的情况
|
||||
url = str(image_url_data)
|
||||
|
||||
if url and url.startswith("data:"):
|
||||
try:
|
||||
# 提取 MIME 类型和 base64 数据
|
||||
mime_type = url.split(":")[1].split(";")[0]
|
||||
base64_data = (
|
||||
url.split("base64,")[1] if "base64," in url else url
|
||||
)
|
||||
content.append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": mime_type,
|
||||
"data": base64_data,
|
||||
},
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"转换 image_url 到 Anthropic 格式失败: {e}")
|
||||
else:
|
||||
logger.warning(f"image_url 不是有效的 data URI: {url[:50]}...")
|
||||
|
||||
else:
|
||||
# 其他类型(如 audio_url)Anthropic 不支持,记录警告
|
||||
logger.debug(f"Anthropic 不支持的内容类型 '{block_type}',已忽略")
|
||||
|
||||
# 3. 图片内容
|
||||
if image_urls:
|
||||
for image_url in image_urls:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
elif image_url.startswith("file:///"):
|
||||
image_path = image_url.replace("file:///", "")
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
else:
|
||||
image_data = await self.encode_image_bs64(image_url)
|
||||
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
continue
|
||||
|
||||
# Get mime type for the image
|
||||
mime_type, _ = guess_type(image_url)
|
||||
if not mime_type:
|
||||
mime_type = "image/jpeg" # Default to JPEG if can't determine
|
||||
|
||||
content.append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": mime_type,
|
||||
"data": (
|
||||
image_data.split("base64,")[1]
|
||||
if "base64," in image_data
|
||||
else image_data
|
||||
),
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# 如果只有主文本且没有额外内容块和图片,返回简单格式以保持向后兼容
|
||||
if (
|
||||
text
|
||||
and not extra_user_content_parts
|
||||
and not image_urls
|
||||
and len(content) == 1
|
||||
and content[0]["type"] == "text"
|
||||
):
|
||||
return {"role": "user", "content": content[0]["text"]}
|
||||
|
||||
# 否则返回多模态格式
|
||||
return {"role": "user", "content": content}
|
||||
|
||||
async def encode_image_bs64(self, image_url: str) -> str:
|
||||
|
||||
@@ -29,15 +29,24 @@ class OTTSProvider:
|
||||
self.last_sync_time = 0
|
||||
self.timeout = Timeout(10.0)
|
||||
self.retry_count = 3
|
||||
self.client = None
|
||||
self._client: AsyncClient | None = None
|
||||
|
||||
@property
|
||||
def client(self) -> AsyncClient:
|
||||
if self._client is None:
|
||||
raise RuntimeError(
|
||||
"Client not initialized. Please use 'async with' context."
|
||||
)
|
||||
return self._client
|
||||
|
||||
async def __aenter__(self):
|
||||
self.client = AsyncClient(timeout=self.timeout)
|
||||
self._client = AsyncClient(timeout=self.timeout)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.client:
|
||||
await self.client.aclose()
|
||||
if self._client:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
async def _sync_time(self):
|
||||
try:
|
||||
@@ -90,6 +99,7 @@ class OTTSProvider:
|
||||
if attempt == self.retry_count - 1:
|
||||
raise RuntimeError(f"OTTS请求失败: {e!s}") from e
|
||||
await asyncio.sleep(0.5 * (attempt + 1))
|
||||
raise RuntimeError("OTTS未返回音频文件")
|
||||
|
||||
|
||||
class AzureNativeProvider(TTSProvider):
|
||||
@@ -105,7 +115,7 @@ class AzureNativeProvider(TTSProvider):
|
||||
self.endpoint = (
|
||||
f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1"
|
||||
)
|
||||
self.client = None
|
||||
self._client: AsyncClient | None = None
|
||||
self.token = None
|
||||
self.token_expire = 0
|
||||
self.voice_params = {
|
||||
@@ -116,8 +126,16 @@ class AzureNativeProvider(TTSProvider):
|
||||
"volume": provider_config.get("azure_tts_volume", "100"),
|
||||
}
|
||||
|
||||
@property
|
||||
def client(self) -> AsyncClient:
|
||||
if self._client is None:
|
||||
raise RuntimeError(
|
||||
"Client not initialized. Please use 'async with' context."
|
||||
)
|
||||
return self._client
|
||||
|
||||
async def __aenter__(self):
|
||||
self.client = AsyncClient(
|
||||
self._client = AsyncClient(
|
||||
headers={
|
||||
"User-Agent": f"AstrBot/{VERSION}",
|
||||
"Content-Type": "application/ssml+xml",
|
||||
@@ -127,8 +145,9 @@ class AzureNativeProvider(TTSProvider):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.client:
|
||||
await self.client.aclose()
|
||||
if self._client:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
async def _refresh_token(self):
|
||||
token_url = (
|
||||
@@ -181,8 +200,11 @@ class AzureTTSProvider(TTSProvider):
|
||||
key_value = provider_config.get("azure_tts_subscription_key", "")
|
||||
self.provider = self._parse_provider(key_value, provider_config)
|
||||
|
||||
def _parse_provider(self, key_value: str, config: dict) -> TTSProvider:
|
||||
def _parse_provider(
|
||||
self, key_value: str, config: dict
|
||||
) -> OTTSProvider | AzureNativeProvider:
|
||||
if key_value.lower().startswith("other["):
|
||||
json_str = ""
|
||||
try:
|
||||
match = re.match(r"other\[(.*)\]", key_value, re.DOTALL)
|
||||
if not match:
|
||||
|
||||
@@ -177,6 +177,10 @@ class BailianRerankProvider(RerankProvider):
|
||||
Returns:
|
||||
重排序结果列表
|
||||
"""
|
||||
if not self.client:
|
||||
logger.error("百炼 Rerank 客户端会话已关闭,返回空结果")
|
||||
return []
|
||||
|
||||
if not documents:
|
||||
logger.warning("文档列表为空,返回空结果")
|
||||
return []
|
||||
|
||||
@@ -36,7 +36,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.chosen_api_key: str = provider_config.get("api_key", "")
|
||||
self.voice: str = provider_config.get("dashscope_tts_voice", "loongstella")
|
||||
self.set_model(provider_config.get("model"))
|
||||
self.set_model(provider_config["model"])
|
||||
self.timeout_ms = float(provider_config.get("timeout", 20)) * 1000
|
||||
dashscope.api_key = self.chosen_api_key
|
||||
|
||||
@@ -71,9 +71,10 @@ class ProviderDashscopeTTSAPI(TTSProvider):
|
||||
|
||||
kwargs = {
|
||||
"model": model,
|
||||
"text": text,
|
||||
"messages": None,
|
||||
"api_key": self.chosen_api_key,
|
||||
"voice": self.voice or "Cherry",
|
||||
"text": text,
|
||||
}
|
||||
if not self.voice:
|
||||
logging.warning(
|
||||
|
||||
@@ -67,7 +67,7 @@ class ProviderEdgeTTS(TTSProvider):
|
||||
from pyffmpeg import FFmpeg
|
||||
|
||||
ff = FFmpeg()
|
||||
ff.convert(input=mp3_path, output=wav_path)
|
||||
ff.convert(input_file=mp3_path, output_file=wav_path)
|
||||
except Exception as e:
|
||||
logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
|
||||
# use ffmpeg command line
|
||||
|
||||
@@ -59,9 +59,9 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {self.chosen_api_key}",
|
||||
}
|
||||
self.set_model(provider_config.get("model"))
|
||||
self.set_model(provider_config["model"])
|
||||
|
||||
async def _get_reference_id_by_character(self, character: str) -> str:
|
||||
async def _get_reference_id_by_character(self, character: str) -> str | None:
|
||||
"""获取角色的reference_id
|
||||
|
||||
Args:
|
||||
@@ -109,7 +109,7 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
||||
pattern = r"^[a-fA-F0-9]{32}$"
|
||||
return bool(re.match(pattern, reference_id.strip()))
|
||||
|
||||
async def _generate_request(self, text: str) -> dict:
|
||||
async def _generate_request(self, text: str) -> ServeTTSRequest:
|
||||
# 向前兼容逻辑:优先使用reference_id,如果没有则使用角色名称查询
|
||||
if self.reference_id and self.reference_id.strip():
|
||||
# 验证reference_id格式
|
||||
@@ -146,5 +146,6 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
||||
async for chunk in response.aiter_bytes():
|
||||
f.write(chunk)
|
||||
return path
|
||||
text = await response.aread()
|
||||
body = await response.aread()
|
||||
text = body.decode("utf-8", errors="replace")
|
||||
raise Exception(f"Fish Audio API请求失败: {text}")
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import cast
|
||||
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
from google.genai.errors import APIError
|
||||
@@ -18,8 +20,8 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
|
||||
self.provider_config = provider_config
|
||||
self.provider_settings = provider_settings
|
||||
|
||||
api_key: str = provider_config.get("embedding_api_key")
|
||||
api_base: str = provider_config.get("embedding_api_base")
|
||||
api_key: str = provider_config["embedding_api_key"]
|
||||
api_base: str = provider_config["embedding_api_base"]
|
||||
timeout: int = int(provider_config.get("timeout", 20))
|
||||
|
||||
http_options = types.HttpOptions(timeout=timeout * 1000)
|
||||
@@ -41,18 +43,26 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
|
||||
model=self.model,
|
||||
contents=text,
|
||||
)
|
||||
assert result.embeddings is not None
|
||||
assert result.embeddings[0].values is not None
|
||||
return result.embeddings[0].values
|
||||
except APIError as e:
|
||||
raise Exception(f"Gemini Embedding API请求失败: {e.message}")
|
||||
|
||||
async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
|
||||
async def get_embeddings(self, text: list[str]) -> list[list[float]]:
|
||||
"""批量获取文本的嵌入"""
|
||||
try:
|
||||
result = await self.client.models.embed_content(
|
||||
model=self.model,
|
||||
contents=texts,
|
||||
contents=cast(types.ContentListUnion, text),
|
||||
)
|
||||
return [embedding.values for embedding in result.embeddings]
|
||||
assert result.embeddings is not None
|
||||
|
||||
embeddings: list[list[float]] = []
|
||||
for embedding in result.embeddings:
|
||||
assert embedding.values is not None
|
||||
embeddings.append(embedding.values)
|
||||
return embeddings
|
||||
except APIError as e:
|
||||
raise Exception(f"Gemini Embedding API批量请求失败: {e.message}")
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import json
|
||||
import logging
|
||||
import random
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import cast
|
||||
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
@@ -12,8 +13,9 @@ from google.genai.errors import APIError
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot import logger
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot.core.agent.message import ContentPart
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
from astrbot.core.provider.entities import LLMResponse, TokenUsage
|
||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
|
||||
@@ -67,7 +69,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
self.api_base = self.api_base[:-1]
|
||||
|
||||
self._init_client()
|
||||
self.set_model(provider_config["model_config"]["model"])
|
||||
self.set_model(provider_config.get("model", "unknown"))
|
||||
self._init_safety_settings()
|
||||
|
||||
def _init_client(self) -> None:
|
||||
@@ -111,9 +113,9 @@ class ProviderGoogleGenAI(Provider):
|
||||
f"检测到 Key 异常({e.message}),且已没有可用的 Key。 当前 Key: {self.chosen_api_key[:12]}...",
|
||||
)
|
||||
raise Exception("达到了 Gemini 速率限制, 请稍后再试...")
|
||||
logger.error(
|
||||
f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}",
|
||||
)
|
||||
# logger.error(
|
||||
# f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}",
|
||||
# )
|
||||
raise e
|
||||
|
||||
async def _prepare_query_config(
|
||||
@@ -126,18 +128,18 @@ class ProviderGoogleGenAI(Provider):
|
||||
) -> types.GenerateContentConfig:
|
||||
"""准备查询配置"""
|
||||
if not modalities:
|
||||
modalities = ["Text"]
|
||||
modalities = ["TEXT"]
|
||||
|
||||
# 流式输出不支持图片模态
|
||||
if (
|
||||
self.provider_settings.get("streaming_response", False)
|
||||
and "Image" in modalities
|
||||
and "IMAGE" in modalities
|
||||
):
|
||||
logger.warning("流式输出不支持图片模态,已自动降级为文本模态")
|
||||
modalities = ["Text"]
|
||||
modalities = ["TEXT"]
|
||||
|
||||
tool_list = []
|
||||
model_name = self.get_model()
|
||||
tool_list: list[types.Tool] | None = []
|
||||
model_name = cast(str, payloads.get("model", self.get_model()))
|
||||
native_coderunner = self.provider_config.get("gm_native_coderunner", False)
|
||||
native_search = self.provider_config.get("gm_native_search", False)
|
||||
url_context = self.provider_config.get("gm_url_context", False)
|
||||
@@ -196,6 +198,53 @@ class ProviderGoogleGenAI(Provider):
|
||||
types.Tool(function_declarations=func_desc["function_declarations"]),
|
||||
]
|
||||
|
||||
# oper thinking config
|
||||
thinking_config = None
|
||||
if model_name in [
|
||||
"gemini-2.5-pro",
|
||||
"gemini-2.5-pro-preview",
|
||||
"gemini-2.5-flash",
|
||||
"gemini-2.5-flash-preview",
|
||||
"gemini-2.5-flash-lite",
|
||||
"gemini-2.5-flash-lite-preview",
|
||||
"gemini-robotics-er-1.5-preview",
|
||||
"gemini-live-2.5-flash-preview-native-audio-09-2025",
|
||||
]:
|
||||
# The thinkingBudget parameter, introduced with the Gemini 2.5 series
|
||||
thinking_budget = self.provider_config.get("gm_thinking_config", {}).get(
|
||||
"budget", 0
|
||||
)
|
||||
if thinking_budget is not None:
|
||||
thinking_config = types.ThinkingConfig(
|
||||
thinking_budget=thinking_budget,
|
||||
)
|
||||
elif model_name in [
|
||||
"gemini-3-pro",
|
||||
"gemini-3-pro-preview",
|
||||
"gemini-3-flash",
|
||||
"gemini-3-flash-preview",
|
||||
"gemini-3-flash-lite",
|
||||
"gemini-3-flash-lite-preview",
|
||||
]:
|
||||
# The thinkingLevel parameter, recommended for Gemini 3 models and onwards
|
||||
# Gemini 2.5 series models don't support thinkingLevel; use thinkingBudget instead.
|
||||
thinking_level = self.provider_config.get("gm_thinking_config", {}).get(
|
||||
"level", "HIGH"
|
||||
)
|
||||
if thinking_level and isinstance(thinking_level, str):
|
||||
thinking_level = thinking_level.upper()
|
||||
if thinking_level not in ["MINIMAL", "LOW", "MEDIUM", "HIGH"]:
|
||||
logger.warning(
|
||||
f"Invalid thinking level: {thinking_level}, using HIGH"
|
||||
)
|
||||
thinking_level = "HIGH"
|
||||
level = types.ThinkingLevel(thinking_level)
|
||||
thinking_config = types.ThinkingConfig()
|
||||
if not hasattr(types.ThinkingConfig, "thinking_level"):
|
||||
setattr(types.ThinkingConfig, "thinking_level", level)
|
||||
else:
|
||||
thinking_config.thinking_level = level
|
||||
|
||||
return types.GenerateContentConfig(
|
||||
system_instruction=system_instruction,
|
||||
temperature=temperature,
|
||||
@@ -213,24 +262,9 @@ class ProviderGoogleGenAI(Provider):
|
||||
logprobs=payloads.get("logprobs"),
|
||||
seed=payloads.get("seed"),
|
||||
response_modalities=modalities,
|
||||
tools=tool_list,
|
||||
tools=cast(types.ToolListUnion | None, tool_list),
|
||||
safety_settings=self.safety_settings if self.safety_settings else None,
|
||||
thinking_config=(
|
||||
types.ThinkingConfig(
|
||||
thinking_budget=min(
|
||||
int(
|
||||
self.provider_config.get("gm_thinking_config", {}).get(
|
||||
"budget",
|
||||
0,
|
||||
),
|
||||
),
|
||||
24576,
|
||||
),
|
||||
)
|
||||
if "gemini-2.5-flash" in self.get_model()
|
||||
and hasattr(types.ThinkingConfig, "thinking_budget")
|
||||
else None
|
||||
),
|
||||
thinking_config=thinking_config,
|
||||
automatic_function_calling=types.AutomaticFunctionCallingConfig(
|
||||
disable=True,
|
||||
),
|
||||
@@ -257,6 +291,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
content_cls: type[types.Content],
|
||||
) -> None:
|
||||
if contents and isinstance(contents[-1], content_cls):
|
||||
assert contents[-1].parts is not None
|
||||
contents[-1].parts.extend(part)
|
||||
else:
|
||||
contents.append(content_cls(parts=part))
|
||||
@@ -345,6 +380,16 @@ class ProviderGoogleGenAI(Provider):
|
||||
]
|
||||
return "".join(thought_buf).strip()
|
||||
|
||||
def _extract_usage(
|
||||
self, usage_metadata: types.GenerateContentResponseUsageMetadata
|
||||
) -> TokenUsage:
|
||||
"""Extract usage from candidate"""
|
||||
return TokenUsage(
|
||||
input_other=usage_metadata.prompt_token_count or 0,
|
||||
input_cached=usage_metadata.cached_content_token_count or 0,
|
||||
output=usage_metadata.candidates_token_count or 0,
|
||||
)
|
||||
|
||||
def _process_content_parts(
|
||||
self,
|
||||
candidate: types.Candidate,
|
||||
@@ -429,9 +474,11 @@ class ProviderGoogleGenAI(Provider):
|
||||
None,
|
||||
)
|
||||
|
||||
modalities = ["Text"]
|
||||
model = payloads.get("model", self.get_model())
|
||||
|
||||
modalities = ["TEXT"]
|
||||
if self.provider_config.get("gm_resp_image_modal", False):
|
||||
modalities.append("Image")
|
||||
modalities.append("IMAGE")
|
||||
|
||||
conversation = self._prepare_conversation(payloads)
|
||||
temperature = payloads.get("temperature", 0.7)
|
||||
@@ -447,8 +494,8 @@ class ProviderGoogleGenAI(Provider):
|
||||
temperature,
|
||||
)
|
||||
result = await self.client.models.generate_content(
|
||||
model=self.get_model(),
|
||||
contents=conversation,
|
||||
model=model,
|
||||
contents=cast(types.ContentListUnion, conversation),
|
||||
config=config,
|
||||
)
|
||||
logger.debug(f"genai result: {result}")
|
||||
@@ -473,11 +520,11 @@ class ProviderGoogleGenAI(Provider):
|
||||
e.message = ""
|
||||
if "Developer instruction is not enabled" in e.message:
|
||||
logger.warning(
|
||||
f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)",
|
||||
f"{model} 不支持 system prompt,已自动去除(影响人格设置)",
|
||||
)
|
||||
system_instruction = None
|
||||
elif "Function calling is not enabled" in e.message:
|
||||
logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除")
|
||||
logger.warning(f"{model} 不支持函数调用,已自动去除")
|
||||
tools = None
|
||||
elif (
|
||||
"Multi-modal output is not supported" in e.message
|
||||
@@ -486,9 +533,9 @@ class ProviderGoogleGenAI(Provider):
|
||||
or "only supports text output" in e.message
|
||||
):
|
||||
logger.warning(
|
||||
f"{self.get_model()} 不支持多模态输出,降级为文本模态",
|
||||
f"{model} 不支持多模态输出,降级为文本模态",
|
||||
)
|
||||
modalities = ["Text"]
|
||||
modalities = ["TEXT"]
|
||||
else:
|
||||
raise
|
||||
continue
|
||||
@@ -499,6 +546,9 @@ class ProviderGoogleGenAI(Provider):
|
||||
result.candidates[0],
|
||||
llm_response,
|
||||
)
|
||||
llm_response.id = result.response_id
|
||||
if result.usage_metadata:
|
||||
llm_response.usage = self._extract_usage(result.usage_metadata)
|
||||
return llm_response
|
||||
|
||||
async def _query_stream(
|
||||
@@ -511,7 +561,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
(msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
|
||||
None,
|
||||
)
|
||||
|
||||
model = payloads.get("model", self.get_model())
|
||||
conversation = self._prepare_conversation(payloads)
|
||||
|
||||
result = None
|
||||
@@ -523,8 +573,8 @@ class ProviderGoogleGenAI(Provider):
|
||||
system_instruction,
|
||||
)
|
||||
result = await self.client.models.generate_content_stream(
|
||||
model=self.get_model(),
|
||||
contents=conversation,
|
||||
model=model,
|
||||
contents=cast(types.ContentListUnion, conversation),
|
||||
config=config,
|
||||
)
|
||||
break
|
||||
@@ -533,11 +583,11 @@ class ProviderGoogleGenAI(Provider):
|
||||
e.message = ""
|
||||
if "Developer instruction is not enabled" in e.message:
|
||||
logger.warning(
|
||||
f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)",
|
||||
f"{model} 不支持 system prompt,已自动去除(影响人格设置)",
|
||||
)
|
||||
system_instruction = None
|
||||
elif "Function calling is not enabled" in e.message:
|
||||
logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除")
|
||||
logger.warning(f"{model} 不支持函数调用,已自动去除")
|
||||
tools = None
|
||||
else:
|
||||
raise
|
||||
@@ -567,6 +617,9 @@ class ProviderGoogleGenAI(Provider):
|
||||
chunk.candidates[0],
|
||||
llm_response,
|
||||
)
|
||||
llm_response.id = chunk.response_id
|
||||
if chunk.usage_metadata:
|
||||
llm_response.usage = self._extract_usage(chunk.usage_metadata)
|
||||
yield llm_response
|
||||
return
|
||||
|
||||
@@ -594,6 +647,9 @@ class ProviderGoogleGenAI(Provider):
|
||||
chunk.candidates[0],
|
||||
final_response,
|
||||
)
|
||||
final_response.id = chunk.response_id
|
||||
if chunk.usage_metadata:
|
||||
final_response.usage = self._extract_usage(chunk.usage_metadata)
|
||||
break
|
||||
|
||||
# Yield final complete response with accumulated text
|
||||
@@ -625,13 +681,16 @@ class ProviderGoogleGenAI(Provider):
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
extra_user_content_parts=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
new_record = None
|
||||
if prompt is not None:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
new_record = await self.assemble_context(
|
||||
prompt, image_urls, extra_user_content_parts
|
||||
)
|
||||
context_query = self._ensure_message_to_dicts(contexts)
|
||||
if new_record:
|
||||
context_query.append(new_record)
|
||||
@@ -650,10 +709,9 @@ class ProviderGoogleGenAI(Provider):
|
||||
for tcr in tool_calls_result:
|
||||
context_query.extend(tcr.to_openai_messages())
|
||||
|
||||
model_config = self.provider_config.get("model_config", {})
|
||||
model_config["model"] = model or self.get_model()
|
||||
model = model or self.get_model()
|
||||
|
||||
payloads = {"messages": context_query, **model_config}
|
||||
payloads = {"messages": context_query, "model": model}
|
||||
|
||||
retry = 10
|
||||
keys = self.api_keys.copy()
|
||||
@@ -678,13 +736,16 @@ class ProviderGoogleGenAI(Provider):
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
extra_user_content_parts=None,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
new_record = None
|
||||
if prompt is not None:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
new_record = await self.assemble_context(
|
||||
prompt, image_urls, extra_user_content_parts
|
||||
)
|
||||
context_query = self._ensure_message_to_dicts(contexts)
|
||||
if new_record:
|
||||
context_query.append(new_record)
|
||||
@@ -703,10 +764,9 @@ class ProviderGoogleGenAI(Provider):
|
||||
for tcr in tool_calls_result:
|
||||
context_query.extend(tcr.to_openai_messages())
|
||||
|
||||
model_config = self.provider_config.get("model_config", {})
|
||||
model_config["model"] = model or self.get_model()
|
||||
model = model or self.get_model()
|
||||
|
||||
payloads = {"messages": context_query, **model_config}
|
||||
payloads = {"messages": context_query, "model": model}
|
||||
|
||||
retry = 10
|
||||
keys = self.api_keys.copy()
|
||||
@@ -744,13 +804,33 @@ class ProviderGoogleGenAI(Provider):
|
||||
self.chosen_api_key = key
|
||||
self._init_client()
|
||||
|
||||
async def assemble_context(self, text: str, image_urls: list[str] | None = None):
|
||||
async def assemble_context(
|
||||
self,
|
||||
text: str,
|
||||
image_urls: list[str] | None = None,
|
||||
extra_user_content_parts: list[ContentPart] | None = None,
|
||||
):
|
||||
"""组装上下文。"""
|
||||
# 构建内容块列表
|
||||
content_blocks = []
|
||||
|
||||
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
|
||||
if text:
|
||||
content_blocks.append({"type": "text", "text": text})
|
||||
elif image_urls:
|
||||
# 如果没有文本但有图片,添加占位文本
|
||||
content_blocks.append({"type": "text", "text": "[图片]"})
|
||||
elif extra_user_content_parts:
|
||||
# 如果只有额外内容块,也需要添加占位文本
|
||||
content_blocks.append({"type": "text", "text": " "})
|
||||
|
||||
# 2. 额外的内容块(系统提醒、指令等)
|
||||
if extra_user_content_parts:
|
||||
for part in extra_user_content_parts:
|
||||
content_blocks.append(part.model_dump())
|
||||
|
||||
# 3. 图片内容
|
||||
if image_urls:
|
||||
user_content = {
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": text if text else "[图片]"}],
|
||||
}
|
||||
for image_url in image_urls:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
@@ -763,14 +843,25 @@ class ProviderGoogleGenAI(Provider):
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
continue
|
||||
user_content["content"].append(
|
||||
content_blocks.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image_data},
|
||||
},
|
||||
)
|
||||
return user_content
|
||||
return {"role": "user", "content": text}
|
||||
|
||||
# 如果只有主文本且没有额外内容块和图片,返回简单格式以保持向后兼容
|
||||
if (
|
||||
text
|
||||
and not extra_user_content_parts
|
||||
and not image_urls
|
||||
and len(content_blocks) == 1
|
||||
and content_blocks[0]["type"] == "text"
|
||||
):
|
||||
return {"role": "user", "content": content_blocks[0]["text"]}
|
||||
|
||||
# 否则返回多模态格式
|
||||
return {"role": "user", "content": content_blocks}
|
||||
|
||||
async def encode_image_bs64(self, image_url: str) -> str:
|
||||
"""将图片转换为 base64"""
|
||||
|
||||
@@ -87,7 +87,7 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
|
||||
|
||||
return json.dumps(dict_body)
|
||||
|
||||
async def _call_tts_stream(self, text: str) -> AsyncIterator[bytes]:
|
||||
async def _call_tts_stream(self, text: str) -> AsyncIterator[str]:
|
||||
"""进行流式请求"""
|
||||
try:
|
||||
async with (
|
||||
@@ -117,7 +117,9 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
|
||||
data = json.loads(message[6:])
|
||||
if "extra_info" in data:
|
||||
continue
|
||||
audio = data.get("data", {}).get("audio")
|
||||
audio: str | None = data.get("data", {}).get(
|
||||
"audio"
|
||||
)
|
||||
if audio is not None:
|
||||
yield audio
|
||||
except json.JSONDecodeError:
|
||||
|
||||
@@ -30,9 +30,9 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
embedding = await self.client.embeddings.create(input=text, model=self.model)
|
||||
return embedding.data[0].embedding
|
||||
|
||||
async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
|
||||
async def get_embeddings(self, text: list[str]) -> list[list[float]]:
|
||||
"""批量获取文本的嵌入"""
|
||||
embeddings = await self.client.embeddings.create(input=texts, model=self.model)
|
||||
embeddings = await self.client.embeddings.create(input=text, model=self.model)
|
||||
return [item.embedding for item in embeddings.data]
|
||||
|
||||
def get_dim(self) -> int:
|
||||
|
||||
@@ -12,14 +12,15 @@ from openai._exceptions import NotFoundError
|
||||
from openai.lib.streaming.chat._completions import ChatCompletionStreamState
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot import logger
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.message import ContentPart, Message
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
||||
from astrbot.core.provider.entities import LLMResponse, TokenUsage, ToolCallsResult
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
|
||||
from ..register import register_provider_adapter
|
||||
@@ -68,8 +69,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
self.client.chat.completions.create,
|
||||
).parameters.keys()
|
||||
|
||||
model_config = provider_config.get("model_config", {})
|
||||
model = model_config.get("model", "unknown")
|
||||
model = provider_config.get("model", "unknown")
|
||||
self.set_model(model)
|
||||
|
||||
self.reasoning_key = "reasoning_content"
|
||||
@@ -208,6 +208,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
# handle the content delta
|
||||
reasoning = self._extract_reasoning_content(chunk)
|
||||
_y = False
|
||||
llm_response.id = chunk.id
|
||||
if reasoning:
|
||||
llm_response.reasoning_content = reasoning
|
||||
_y = True
|
||||
@@ -217,6 +218,8 @@ class ProviderOpenAIOfficial(Provider):
|
||||
chain=[Comp.Plain(completion_text)],
|
||||
)
|
||||
_y = True
|
||||
if chunk.usage:
|
||||
llm_response.usage = self._extract_usage(chunk.usage)
|
||||
if _y:
|
||||
yield llm_response
|
||||
|
||||
@@ -245,6 +248,15 @@ class ProviderOpenAIOfficial(Provider):
|
||||
reasoning_text = str(reasoning_attr)
|
||||
return reasoning_text
|
||||
|
||||
def _extract_usage(self, usage: CompletionUsage) -> TokenUsage:
|
||||
ptd = usage.prompt_tokens_details
|
||||
cached = ptd.cached_tokens if ptd and ptd.cached_tokens else 0
|
||||
return TokenUsage(
|
||||
input_other=usage.prompt_tokens - cached,
|
||||
input_cached=ptd.cached_tokens if ptd and ptd.cached_tokens else 0,
|
||||
output=usage.completion_tokens,
|
||||
)
|
||||
|
||||
async def _parse_openai_completion(
|
||||
self, completion: ChatCompletion, tools: ToolSet | None
|
||||
) -> LLMResponse:
|
||||
@@ -284,6 +296,10 @@ class ProviderOpenAIOfficial(Provider):
|
||||
if isinstance(tool_call, str):
|
||||
# workaround for #1359
|
||||
tool_call = json.loads(tool_call)
|
||||
if tools is None:
|
||||
# 工具集未提供
|
||||
# Should be unreachable
|
||||
raise Exception("工具集未提供")
|
||||
for tool in tools.func_list:
|
||||
if (
|
||||
tool_call.type == "function"
|
||||
@@ -317,6 +333,10 @@ class ProviderOpenAIOfficial(Provider):
|
||||
raise Exception(f"API 返回的 completion 无法解析:{completion}。")
|
||||
|
||||
llm_response.raw_completion = completion
|
||||
llm_response.id = completion.id
|
||||
|
||||
if completion.usage:
|
||||
llm_response.usage = self._extract_usage(completion.usage)
|
||||
|
||||
return llm_response
|
||||
|
||||
@@ -328,6 +348,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
system_prompt: str | None = None,
|
||||
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
|
||||
model: str | None = None,
|
||||
extra_user_content_parts: list[ContentPart] | None = None,
|
||||
**kwargs,
|
||||
) -> tuple:
|
||||
"""准备聊天所需的有效载荷和上下文"""
|
||||
@@ -335,7 +356,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
contexts = []
|
||||
new_record = None
|
||||
if prompt is not None:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
new_record = await self.assemble_context(
|
||||
prompt, image_urls, extra_user_content_parts
|
||||
)
|
||||
context_query = self._ensure_message_to_dicts(contexts)
|
||||
if new_record:
|
||||
context_query.append(new_record)
|
||||
@@ -354,10 +377,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
for tcr in tool_calls_result:
|
||||
context_query.extend(tcr.to_openai_messages())
|
||||
|
||||
model_config = self.provider_config.get("model_config", {})
|
||||
model_config["model"] = model or self.get_model()
|
||||
model = model or self.get_model()
|
||||
|
||||
payloads = {"messages": context_query, **model_config}
|
||||
payloads = {"messages": context_query, "model": model}
|
||||
|
||||
# xAI origin search tool inject
|
||||
self._maybe_inject_xai_search(payloads, **kwargs)
|
||||
@@ -433,7 +455,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
)
|
||||
payloads.pop("tools", None)
|
||||
return False, chosen_key, available_api_keys, payloads, context_query, None
|
||||
logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
|
||||
# logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
|
||||
|
||||
if "tool" in str(e).lower() and "support" in str(e).lower():
|
||||
logger.error("疑似该模型不支持函数调用工具调用。请输入 /tool off_all")
|
||||
@@ -457,6 +479,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
extra_user_content_parts=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
payloads, context_query = await self._prepare_chat_payload(
|
||||
@@ -466,6 +489,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
system_prompt,
|
||||
tool_calls_result,
|
||||
model=model,
|
||||
extra_user_content_parts=extra_user_content_parts,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -520,6 +544,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
extra_user_content_parts=None,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
"""流式对话,与服务商交互并逐步返回结果"""
|
||||
@@ -530,6 +555,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
system_prompt,
|
||||
tool_calls_result,
|
||||
model=model,
|
||||
extra_user_content_parts=extra_user_content_parts,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -605,13 +631,29 @@ class ProviderOpenAIOfficial(Provider):
|
||||
self,
|
||||
text: str,
|
||||
image_urls: list[str] | None = None,
|
||||
extra_user_content_parts: list[ContentPart] | None = None,
|
||||
) -> dict:
|
||||
"""组装成符合 OpenAI 格式的 role 为 user 的消息段"""
|
||||
# 构建内容块列表
|
||||
content_blocks = []
|
||||
|
||||
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
|
||||
if text:
|
||||
content_blocks.append({"type": "text", "text": text})
|
||||
elif image_urls:
|
||||
# 如果没有文本但有图片,添加占位文本
|
||||
content_blocks.append({"type": "text", "text": "[图片]"})
|
||||
elif extra_user_content_parts:
|
||||
# 如果只有额外内容块,也需要添加占位文本
|
||||
content_blocks.append({"type": "text", "text": " "})
|
||||
|
||||
# 2. 额外的内容块(系统提醒、指令等)
|
||||
if extra_user_content_parts:
|
||||
for part in extra_user_content_parts:
|
||||
content_blocks.append(part.model_dump())
|
||||
|
||||
# 3. 图片内容
|
||||
if image_urls:
|
||||
user_content = {
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": text if text else "[图片]"}],
|
||||
}
|
||||
for image_url in image_urls:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
@@ -624,14 +666,25 @@ class ProviderOpenAIOfficial(Provider):
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
continue
|
||||
user_content["content"].append(
|
||||
content_blocks.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image_data},
|
||||
},
|
||||
)
|
||||
return user_content
|
||||
return {"role": "user", "content": text}
|
||||
|
||||
# 如果只有主文本且没有额外内容块和图片,返回简单格式以保持向后兼容
|
||||
if (
|
||||
text
|
||||
and not extra_user_content_parts
|
||||
and not image_urls
|
||||
and len(content_blocks) == 1
|
||||
and content_blocks[0]["type"] == "text"
|
||||
):
|
||||
return {"role": "user", "content": content_blocks[0]["text"]}
|
||||
|
||||
# 否则返回多模态格式
|
||||
return {"role": "user", "content": content_blocks}
|
||||
|
||||
async def encode_image_bs64(self, image_url: str) -> str:
|
||||
"""将图片转换为 base64"""
|
||||
|
||||
@@ -7,6 +7,7 @@ import asyncio
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from funasr_onnx import SenseVoiceSmall
|
||||
from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess
|
||||
@@ -32,7 +33,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
|
||||
provider_settings: dict,
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.set_model(provider_config.get("stt_model"))
|
||||
self.set_model(provider_config["stt_model"])
|
||||
self.model = None
|
||||
self.is_emotion = provider_config.get("is_emotion", False)
|
||||
|
||||
@@ -86,7 +87,9 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
|
||||
loop = asyncio.get_event_loop()
|
||||
res = await loop.run_in_executor(
|
||||
None, # 使用默认的线程池
|
||||
lambda: self.model(audio_url, language="auto", use_itn=True),
|
||||
lambda: cast(SenseVoiceSmall, self.model)(
|
||||
audio_url, language="auto", use_itn=True
|
||||
),
|
||||
)
|
||||
|
||||
# res = self.model(audio_url, language="auto", use_itn=True)
|
||||
|
||||
@@ -44,6 +44,7 @@ class VLLMRerankProvider(RerankProvider):
|
||||
}
|
||||
if top_n is not None:
|
||||
payload["top_n"] = top_n
|
||||
assert self.client is not None
|
||||
async with self.client.post(
|
||||
f"{self.base_url}/v1/rerank",
|
||||
json=payload,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user